From dfe7bd168d9bcf8c53f993f459ab473d893457b0 Mon Sep 17 00:00:00 2001 From: Joseph Batchik Date: Mon, 3 Aug 2015 11:17:38 -0700 Subject: [PATCH 0001/1824] [SPARK-9511] [SQL] Fixed Table Name Parsing The issue was that the tokenizer was parsing "1one" into the numeric 1 using the code on line 110. I added another case to accept strings that start with a number and then have a letter somewhere else in it as well. Author: Joseph Batchik Closes #7844 from JDrit/parse_error and squashes the following commits: b8ca12f [Joseph Batchik] fixed parsing issue by adding another case --- .../spark/sql/catalyst/AbstractSparkSQLParser.scala | 2 ++ .../scala/org/apache/spark/sql/SQLQuerySuite.scala | 10 ++++++++++ 2 files changed, 12 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala index d494ae7b71d1..5898a5f93f38 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala @@ -104,6 +104,8 @@ class SqlLexical extends StdLexical { override lazy val token: Parser[Token] = ( identChar ~ (identChar | digit).* ^^ { case first ~ rest => processIdent((first :: rest).mkString) } + | digit.* ~ identChar ~ (identChar | digit).* ^^ + { case first ~ middle ~ rest => processIdent((first ++ (middle :: rest)).mkString) } | rep1(digit) ~ ('.' ~> digit.*).? ^^ { case i ~ None => NumericLit(i.mkString) case i ~ Some(d) => FloatLit(i.mkString + "." + d.mkString) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index bbadc202a4f0..f1abae072005 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -1604,4 +1604,14 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { checkAnswer(df.select(-df("i")), Row(new CalendarInterval(-(12 * 3 - 3), -(7L * MICROS_PER_WEEK + 123)))) } + + test("SPARK-9511: error with table starting with number") { + val df = sqlContext.sparkContext.parallelize(1 to 10).map(i => (i, i.toString)) + .toDF("num", "str") + df.registerTempTable("1one") + + checkAnswer(sqlContext.sql("select count(num) from 1one"), Row(10)) + + sqlContext.dropTempTable("1one") + } } From 7a9d09f0bb472a1671d3457e1f7108f4c2eb4121 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 3 Aug 2015 11:22:02 -0700 Subject: [PATCH 0002/1824] [SQL][minor] Simplify UnsafeRow.calculateBitSetWidthInBytes. Author: Reynold Xin Closes #7897 from rxin/calculateBitSetWidthInBytes and squashes the following commits: 2e73b3a [Reynold Xin] [SQL][minor] Simplify UnsafeRow.calculateBitSetWidthInBytes. 
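The change below replaces the branch on `numFields % 64` with a single ceiling division: both forms round the field count up to the next multiple of 64 bits and convert that to bytes. A minimal Scala sketch of the equivalence, using the same inputs as the new UnsafeRowSuite assertions (the `oldWidth`/`newWidth` helper names are illustrative only, not Spark API):

    def oldWidth(numFields: Int): Int =
      ((numFields / 64) + (if (numFields % 64 == 0) 0 else 1)) * 8

    def newWidth(numFields: Int): Int =
      ((numFields + 63) / 64) * 8

    // Both yield 0, 8, 8, 8, 16, 16 for these inputs.
    Seq(0, 1, 32, 64, 65, 128).foreach { n =>
      assert(oldWidth(n) == newWidth(n), s"mismatch at $n")
    }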
--- .../spark/sql/catalyst/expressions/UnsafeRow.java | 2 +- .../scala/org/apache/spark/sql/UnsafeRowSuite.scala | 10 ++++++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index f4230cfaba37..e6750fce4fa8 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -59,7 +59,7 @@ public final class UnsafeRow extends MutableRow { ////////////////////////////////////////////////////////////////////////////// public static int calculateBitSetWidthInBytes(int numFields) { - return ((numFields / 64) + (numFields % 64 == 0 ? 0 : 1)) * 8; + return ((numFields + 63)/ 64) * 8; } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala index c5faaa663e74..89bad1bfdab0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala @@ -28,6 +28,16 @@ import org.apache.spark.unsafe.memory.MemoryAllocator import org.apache.spark.unsafe.types.UTF8String class UnsafeRowSuite extends SparkFunSuite { + + test("bitset width calculation") { + assert(UnsafeRow.calculateBitSetWidthInBytes(0) === 0) + assert(UnsafeRow.calculateBitSetWidthInBytes(1) === 8) + assert(UnsafeRow.calculateBitSetWidthInBytes(32) === 8) + assert(UnsafeRow.calculateBitSetWidthInBytes(64) === 8) + assert(UnsafeRow.calculateBitSetWidthInBytes(65) === 16) + assert(UnsafeRow.calculateBitSetWidthInBytes(128) === 16) + } + test("writeToStream") { val row = InternalRow.apply(UTF8String.fromString("hello"), UTF8String.fromString("world"), 123) val arrayBackedUnsafeRow: UnsafeRow = From 703e44bff19f4c394f6f9bff1ce9152cdc68c51e Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Mon, 3 Aug 2015 12:06:58 -0700 Subject: [PATCH 0003/1824] [SPARK-9554] [SQL] Enables in-memory partition pruning by default Author: Cheng Lian Closes #7895 from liancheng/spark-9554/enable-in-memory-partition-pruning and squashes the following commits: 67c403e [Cheng Lian] Enables in-memory partition pruning by default --- sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 387960c4b482..41ba1c7fe057 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -200,7 +200,7 @@ private[spark] object SQLConf { val IN_MEMORY_PARTITION_PRUNING = booleanConf("spark.sql.inMemoryColumnarStorage.partitionPruning", - defaultValue = Some(false), + defaultValue = Some(true), doc = "When true, enable partition pruning for in-memory columnar tables.", isPublic = false) From ff9169a002f1b75231fd25b7d04157a912503038 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Mon, 3 Aug 2015 12:17:46 -0700 Subject: [PATCH 0004/1824] [SPARK-5133] [ML] Added featureImportance to RandomForestClassifier and Regressor Added featureImportance to RandomForestClassifier and Regressor. 
This follows the scikit-learn implementation here: [https://github.com/scikit-learn/scikit-learn/blob/a95203b249c1cf392f86d001ad999e29b2392739/sklearn/tree/_tree.pyx#L3341] CC: yanboliang Would you mind taking a look? Thanks! Author: Joseph K. Bradley Author: Feynman Liang Closes #7838 from jkbradley/dt-feature-importance and squashes the following commits: 72a167a [Joseph K. Bradley] fixed unit test 86cea5f [Joseph K. Bradley] Modified RF featuresImportances to return Vector instead of Map 5aa74f0 [Joseph K. Bradley] finally fixed unit test for real 33df5db [Joseph K. Bradley] fix unit test 42a2d3b [Joseph K. Bradley] fix unit test fe94e72 [Joseph K. Bradley] modified feature importance unit tests cc693ee [Feynman Liang] Add classifier tests 79a6f87 [Feynman Liang] Compare dense vectors in test 21d01fc [Feynman Liang] Added failing SKLearn test ac0b254 [Joseph K. Bradley] Added featureImportance to RandomForestClassifier/Regressor. Need to add unit tests --- .../RandomForestClassifier.scala | 30 ++++- .../ml/regression/RandomForestRegressor.scala | 33 +++++- .../scala/org/apache/spark/ml/tree/Node.scala | 19 +++- .../spark/ml/tree/impl/RandomForest.scala | 92 +++++++++++++++ .../org/apache/spark/ml/tree/treeModels.scala | 6 + .../JavaRandomForestClassifierSuite.java | 2 + .../JavaRandomForestRegressorSuite.java | 2 + .../RandomForestClassifierSuite.scala | 31 ++++- .../org/apache/spark/ml/impl/TreeTests.scala | 18 +++ .../RandomForestRegressorSuite.scala | 27 ++++- .../ml/tree/impl/RandomForestSuite.scala | 107 ++++++++++++++++++ 11 files changed, 351 insertions(+), 16 deletions(-) create mode 100644 mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index 56e80cc8fe6e..b59826a59499 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -95,7 +95,8 @@ final class RandomForestClassifier(override val uid: String) val trees = RandomForest.run(oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed) .map(_.asInstanceOf[DecisionTreeClassificationModel]) - new RandomForestClassificationModel(trees, numClasses) + val numFeatures = oldDataset.first().features.size + new RandomForestClassificationModel(trees, numFeatures, numClasses) } override def copy(extra: ParamMap): RandomForestClassifier = defaultCopy(extra) @@ -118,11 +119,13 @@ object RandomForestClassifier { * features. * @param _trees Decision trees in the ensemble. * Warning: These have null parents. + * @param numFeatures Number of features used by this model */ @Experimental final class RandomForestClassificationModel private[ml] ( override val uid: String, private val _trees: Array[DecisionTreeClassificationModel], + val numFeatures: Int, override val numClasses: Int) extends ProbabilisticClassificationModel[Vector, RandomForestClassificationModel] with TreeEnsembleModel with Serializable { @@ -133,8 +136,8 @@ final class RandomForestClassificationModel private[ml] ( * Construct a random forest classification model, with all trees weighted equally. 
* @param trees Component trees */ - def this(trees: Array[DecisionTreeClassificationModel], numClasses: Int) = - this(Identifiable.randomUID("rfc"), trees, numClasses) + def this(trees: Array[DecisionTreeClassificationModel], numFeatures: Int, numClasses: Int) = + this(Identifiable.randomUID("rfc"), trees, numFeatures, numClasses) override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]] @@ -182,13 +185,30 @@ final class RandomForestClassificationModel private[ml] ( } override def copy(extra: ParamMap): RandomForestClassificationModel = { - copyValues(new RandomForestClassificationModel(uid, _trees, numClasses), extra) + copyValues(new RandomForestClassificationModel(uid, _trees, numFeatures, numClasses), extra) } override def toString: String = { s"RandomForestClassificationModel with $numTrees trees" } + /** + * Estimate of the importance of each feature. + * + * This generalizes the idea of "Gini" importance to other losses, + * following the explanation of Gini importance from "Random Forests" documentation + * by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn. + * + * This feature importance is calculated as follows: + * - Average over trees: + * - importance(feature j) = sum (over nodes which split on feature j) of the gain, + * where gain is scaled by the number of instances passing through node + * - Normalize importances for tree based on total number of training instances used + * to build tree. + * - Normalize feature importance vector to sum to 1. + */ + lazy val featureImportances: Vector = RandomForest.featureImportances(trees, numFeatures) + /** (private[ml]) Convert to a model in the old API */ private[ml] def toOld: OldRandomForestModel = { new OldRandomForestModel(OldAlgo.Classification, _trees.map(_.toOld)) @@ -210,6 +230,6 @@ private[ml] object RandomForestClassificationModel { DecisionTreeClassificationModel.fromOld(tree, null, categoricalFeatures) } val uid = if (parent != null) parent.uid else Identifiable.randomUID("rfc") - new RandomForestClassificationModel(uid, newTrees, numClasses) + new RandomForestClassificationModel(uid, newTrees, -1, numClasses) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala index 17fb1ad5e15d..1ee43c872573 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala @@ -30,7 +30,7 @@ import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestMo import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame import org.apache.spark.sql.functions._ -import org.apache.spark.sql.types.DoubleType + /** * :: Experimental :: @@ -87,7 +87,8 @@ final class RandomForestRegressor(override val uid: String) val trees = RandomForest.run(oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed) .map(_.asInstanceOf[DecisionTreeRegressionModel]) - new RandomForestRegressionModel(trees) + val numFeatures = oldDataset.first().features.size + new RandomForestRegressionModel(trees, numFeatures) } override def copy(extra: ParamMap): RandomForestRegressor = defaultCopy(extra) @@ -108,11 +109,13 @@ object RandomForestRegressor { * [[http://en.wikipedia.org/wiki/Random_forest Random Forest]] model for regression. * It supports both continuous and categorical features. * @param _trees Decision trees in the ensemble. 
+ * @param numFeatures Number of features used by this model */ @Experimental final class RandomForestRegressionModel private[ml] ( override val uid: String, - private val _trees: Array[DecisionTreeRegressionModel]) + private val _trees: Array[DecisionTreeRegressionModel], + val numFeatures: Int) extends PredictionModel[Vector, RandomForestRegressionModel] with TreeEnsembleModel with Serializable { @@ -122,7 +125,8 @@ final class RandomForestRegressionModel private[ml] ( * Construct a random forest regression model, with all trees weighted equally. * @param trees Component trees */ - def this(trees: Array[DecisionTreeRegressionModel]) = this(Identifiable.randomUID("rfr"), trees) + def this(trees: Array[DecisionTreeRegressionModel], numFeatures: Int) = + this(Identifiable.randomUID("rfr"), trees, numFeatures) override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]] @@ -147,13 +151,30 @@ final class RandomForestRegressionModel private[ml] ( } override def copy(extra: ParamMap): RandomForestRegressionModel = { - copyValues(new RandomForestRegressionModel(uid, _trees), extra) + copyValues(new RandomForestRegressionModel(uid, _trees, numFeatures), extra) } override def toString: String = { s"RandomForestRegressionModel with $numTrees trees" } + /** + * Estimate of the importance of each feature. + * + * This generalizes the idea of "Gini" importance to other losses, + * following the explanation of Gini importance from "Random Forests" documentation + * by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn. + * + * This feature importance is calculated as follows: + * - Average over trees: + * - importance(feature j) = sum (over nodes which split on feature j) of the gain, + * where gain is scaled by the number of instances passing through node + * - Normalize importances for tree based on total number of training instances used + * to build tree. + * - Normalize feature importance vector to sum to 1. + */ + lazy val featureImportances: Vector = RandomForest.featureImportances(trees, numFeatures) + /** (private[ml]) Convert to a model in the old API */ private[ml] def toOld: OldRandomForestModel = { new OldRandomForestModel(OldAlgo.Regression, _trees.map(_.toOld)) @@ -173,6 +194,6 @@ private[ml] object RandomForestRegressionModel { // parent for each tree is null since there is no good way to set this. DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures) } - new RandomForestRegressionModel(parent.uid, newTrees) + new RandomForestRegressionModel(parent.uid, newTrees, -1) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala index 8879352a600a..cd2493129390 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala @@ -44,7 +44,7 @@ sealed abstract class Node extends Serializable { * and probabilities. * For classification, the array of class counts must be normalized to a probability distribution. */ - private[tree] def impurityStats: ImpurityCalculator + private[ml] def impurityStats: ImpurityCalculator /** Recursive prediction helper method */ private[ml] def predictImpl(features: Vector): LeafNode @@ -72,6 +72,12 @@ sealed abstract class Node extends Serializable { * @param id Node ID using old format IDs */ private[ml] def toOld(id: Int): OldNode + + /** + * Trace down the tree, and return the largest feature index used in any split. 
+ * @return Max feature index used in a split, or -1 if there are no splits (single leaf node). + */ + private[ml] def maxSplitFeatureIndex(): Int } private[ml] object Node { @@ -109,7 +115,7 @@ private[ml] object Node { final class LeafNode private[ml] ( override val prediction: Double, override val impurity: Double, - override val impurityStats: ImpurityCalculator) extends Node { + override private[ml] val impurityStats: ImpurityCalculator) extends Node { override def toString: String = s"LeafNode(prediction = $prediction, impurity = $impurity)" @@ -129,6 +135,8 @@ final class LeafNode private[ml] ( new OldNode(id, new OldPredict(prediction, prob = impurityStats.prob(prediction)), impurity, isLeaf = true, None, None, None, None) } + + override private[ml] def maxSplitFeatureIndex(): Int = -1 } /** @@ -150,7 +158,7 @@ final class InternalNode private[ml] ( val leftChild: Node, val rightChild: Node, val split: Split, - override val impurityStats: ImpurityCalculator) extends Node { + override private[ml] val impurityStats: ImpurityCalculator) extends Node { override def toString: String = { s"InternalNode(prediction = $prediction, impurity = $impurity, split = $split)" @@ -190,6 +198,11 @@ final class InternalNode private[ml] ( new OldPredict(leftChild.prediction, prob = 0.0), new OldPredict(rightChild.prediction, prob = 0.0)))) } + + override private[ml] def maxSplitFeatureIndex(): Int = { + math.max(split.featureIndex, + math.max(leftChild.maxSplitFeatureIndex(), rightChild.maxSplitFeatureIndex())) + } } private object InternalNode { diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala index a8b90d9d266a..4ac51a475474 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala @@ -26,6 +26,7 @@ import org.apache.spark.Logging import org.apache.spark.ml.classification.DecisionTreeClassificationModel import org.apache.spark.ml.regression.DecisionTreeRegressionModel import org.apache.spark.ml.tree._ +import org.apache.spark.mllib.linalg.{Vectors, Vector} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy} import org.apache.spark.mllib.tree.impl.{BaggedPoint, DTStatsAggregator, DecisionTreeMetadata, @@ -34,6 +35,7 @@ import org.apache.spark.mllib.tree.impurity.ImpurityCalculator import org.apache.spark.mllib.tree.model.ImpurityStats import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel +import org.apache.spark.util.collection.OpenHashMap import org.apache.spark.util.random.{SamplingUtils, XORShiftRandom} @@ -1113,4 +1115,94 @@ private[ml] object RandomForest extends Logging { } } + /** + * Given a Random Forest model, compute the importance of each feature. + * This generalizes the idea of "Gini" importance to other losses, + * following the explanation of Gini importance from "Random Forests" documentation + * by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn. + * + * This feature importance is calculated as follows: + * - Average over trees: + * - importance(feature j) = sum (over nodes which split on feature j) of the gain, + * where gain is scaled by the number of instances passing through node + * - Normalize importances for tree based on total number of training instances used + * to build tree. 
+ * - Normalize feature importance vector to sum to 1. + * + * Note: This should not be used with Gradient-Boosted Trees. It only makes sense for + * independently trained trees. + * @param trees Unweighted forest of trees + * @param numFeatures Number of features in model (even if not all are explicitly used by + * the model). + * If -1, then numFeatures is set based on the max feature index in all trees. + * @return Feature importance values, of length numFeatures. + */ + private[ml] def featureImportances(trees: Array[DecisionTreeModel], numFeatures: Int): Vector = { + val totalImportances = new OpenHashMap[Int, Double]() + trees.foreach { tree => + // Aggregate feature importance vector for this tree + val importances = new OpenHashMap[Int, Double]() + computeFeatureImportance(tree.rootNode, importances) + // Normalize importance vector for this tree, and add it to total. + // TODO: In the future, also support normalizing by tree.rootNode.impurityStats.count? + val treeNorm = importances.map(_._2).sum + if (treeNorm != 0) { + importances.foreach { case (idx, impt) => + val normImpt = impt / treeNorm + totalImportances.changeValue(idx, normImpt, _ + normImpt) + } + } + } + // Normalize importances + normalizeMapValues(totalImportances) + // Construct vector + val d = if (numFeatures != -1) { + numFeatures + } else { + // Find max feature index used in trees + val maxFeatureIndex = trees.map(_.maxSplitFeatureIndex()).max + maxFeatureIndex + 1 + } + if (d == 0) { + assert(totalImportances.size == 0, s"Unknown error in computing RandomForest feature" + + s" importance: No splits in forest, but some non-zero importances.") + } + val (indices, values) = totalImportances.iterator.toSeq.sortBy(_._1).unzip + Vectors.sparse(d, indices.toArray, values.toArray) + } + + /** + * Recursive method for computing feature importances for one tree. + * This walks down the tree, adding to the importance of 1 feature at each node. + * @param node Current node in recursion + * @param importances Aggregate feature importances, modified by this method + */ + private[impl] def computeFeatureImportance( + node: Node, + importances: OpenHashMap[Int, Double]): Unit = { + node match { + case n: InternalNode => + val feature = n.split.featureIndex + val scaledGain = n.gain * n.impurityStats.count + importances.changeValue(feature, scaledGain, _ + scaledGain) + computeFeatureImportance(n.leftChild, importances) + computeFeatureImportance(n.rightChild, importances) + case n: LeafNode => + // do nothing + } + } + + /** + * Normalize the values of this map to sum to 1, in place. + * If all values are 0, this method does nothing. + * @param map Map with non-negative values. + */ + private[impl] def normalizeMapValues(map: OpenHashMap[Int, Double]): Unit = { + val total = map.map(_._2).sum + if (total != 0) { + val keys = map.iterator.map(_._1).toArray + keys.foreach { key => map.changeValue(key, 0.0, _ / total) } + } + } + } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala index 22873909c33f..b77191156f68 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala @@ -53,6 +53,12 @@ private[ml] trait DecisionTreeModel { val header = toString + "\n" header + rootNode.subtreeToString(2) } + + /** + * Trace down the tree, and return the largest feature index used in any split. 
+ * @return Max feature index used in a split, or -1 if there are no splits (single leaf node). + */ + private[ml] def maxSplitFeatureIndex(): Int = rootNode.maxSplitFeatureIndex() } /** diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java index 32d0b3856b7e..a66a1e12927b 100644 --- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java @@ -29,6 +29,7 @@ import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.ml.impl.TreeTests; import org.apache.spark.mllib.classification.LogisticRegressionSuite; +import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.regression.LabeledPoint; import org.apache.spark.sql.DataFrame; @@ -85,6 +86,7 @@ public void runDT() { model.toDebugString(); model.trees(); model.treeWeights(); + Vector importances = model.featureImportances(); /* // TODO: Add test once save/load are implemented. SPARK-6725 diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java index e306ebadfe7c..a00ce5e249c3 100644 --- a/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java @@ -29,6 +29,7 @@ import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.classification.LogisticRegressionSuite; import org.apache.spark.ml.impl.TreeTests; +import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.regression.LabeledPoint; import org.apache.spark.sql.DataFrame; @@ -85,6 +86,7 @@ public void runDT() { model.toDebugString(); model.trees(); model.treeWeights(); + Vector importances = model.featureImportances(); /* // TODO: Add test once save/load are implemented. 
SPARK-6725 diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala index edf848b21a90..6ca4b5aa5fde 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala @@ -67,7 +67,7 @@ class RandomForestClassifierSuite extends SparkFunSuite with MLlibTestSparkConte test("params") { ParamsSuite.checkParams(new RandomForestClassifier) val model = new RandomForestClassificationModel("rfc", - Array(new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0, null), 2)), 2) + Array(new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0, null), 2)), 2, 2) ParamsSuite.checkParams(model) } @@ -149,6 +149,35 @@ class RandomForestClassifierSuite extends SparkFunSuite with MLlibTestSparkConte } } + ///////////////////////////////////////////////////////////////////////////// + // Tests of feature importance + ///////////////////////////////////////////////////////////////////////////// + test("Feature importance with toy data") { + val numClasses = 2 + val rf = new RandomForestClassifier() + .setImpurity("Gini") + .setMaxDepth(3) + .setNumTrees(3) + .setFeatureSubsetStrategy("all") + .setSubsamplingRate(1.0) + .setSeed(123) + + // In this data, feature 1 is very important. + val data: RDD[LabeledPoint] = sc.parallelize(Seq( + new LabeledPoint(0, Vectors.dense(1, 0, 0, 0, 1)), + new LabeledPoint(1, Vectors.dense(1, 1, 0, 1, 0)), + new LabeledPoint(1, Vectors.dense(1, 1, 0, 0, 0)), + new LabeledPoint(0, Vectors.dense(1, 0, 0, 0, 0)), + new LabeledPoint(1, Vectors.dense(1, 1, 0, 0, 0)) + )) + val categoricalFeatures = Map.empty[Int, Int] + val df: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses) + + val importances = rf.fit(df).featureImportances + val mostImportantFeature = importances.argmax + assert(mostImportantFeature === 1) + } + ///////////////////////////////////////////////////////////////////////////// // Tests of model save/load ///////////////////////////////////////////////////////////////////////////// diff --git a/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala b/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala index 778abcba22c1..460849c79f04 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala @@ -124,4 +124,22 @@ private[ml] object TreeTests extends SparkFunSuite { "checkEqual failed since the two tree ensembles were not identical") } } + + /** + * Helper method for constructing a tree for testing. + * Given left, right children, construct a parent node. 
+ * @param split Split for parent node + * @return Parent node with children attached + */ + def buildParentNode(left: Node, right: Node, split: Split): Node = { + val leftImp = left.impurityStats + val rightImp = right.impurityStats + val parentImp = leftImp.copy.add(rightImp) + val leftWeight = leftImp.count / parentImp.count.toDouble + val rightWeight = rightImp.count / parentImp.count.toDouble + val gain = parentImp.calculate() - + (leftWeight * leftImp.calculate() + rightWeight * rightImp.calculate()) + val pred = parentImp.predict + new InternalNode(pred, parentImp.calculate(), gain, left, right, split, parentImp) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala index b24ecaa57c89..992ce9562434 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.ml.regression import org.apache.spark.SparkFunSuite import org.apache.spark.ml.impl.TreeTests +import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => OldRandomForest} import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} @@ -26,7 +27,6 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame - /** * Test suite for [[RandomForestRegressor]]. */ @@ -71,6 +71,31 @@ class RandomForestRegressorSuite extends SparkFunSuite with MLlibTestSparkContex regressionTestWithContinuousFeatures(rf) } + test("Feature importance with toy data") { + val rf = new RandomForestRegressor() + .setImpurity("variance") + .setMaxDepth(3) + .setNumTrees(3) + .setFeatureSubsetStrategy("all") + .setSubsamplingRate(1.0) + .setSeed(123) + + // In this data, feature 1 is very important. + val data: RDD[LabeledPoint] = sc.parallelize(Seq( + new LabeledPoint(0, Vectors.dense(1, 0, 0, 0, 1)), + new LabeledPoint(1, Vectors.dense(1, 1, 0, 1, 0)), + new LabeledPoint(1, Vectors.dense(1, 1, 0, 0, 0)), + new LabeledPoint(0, Vectors.dense(1, 0, 0, 0, 0)), + new LabeledPoint(1, Vectors.dense(1, 1, 0, 0, 0)) + )) + val categoricalFeatures = Map.empty[Int, Int] + val df: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0) + + val importances = rf.fit(df).featureImportances + val mostImportantFeature = importances.argmax + assert(mostImportantFeature === 1) + } + ///////////////////////////////////////////////////////////////////////////// // Tests of model save/load ///////////////////////////////////////////////////////////////////////////// diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala new file mode 100644 index 000000000000..dc852795c7f6 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.tree.impl + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.classification.DecisionTreeClassificationModel +import org.apache.spark.ml.impl.TreeTests +import org.apache.spark.ml.tree.{ContinuousSplit, DecisionTreeModel, LeafNode, Node} +import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.tree.impurity.GiniCalculator +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.util.collection.OpenHashMap + +/** + * Test suite for [[RandomForest]]. + */ +class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { + + import RandomForestSuite.mapToVec + + test("computeFeatureImportance, featureImportances") { + /* Build tree for testing, with this structure: + grandParent + left2 parent + left right + */ + val leftImp = new GiniCalculator(Array(3.0, 2.0, 1.0)) + val left = new LeafNode(0.0, leftImp.calculate(), leftImp) + + val rightImp = new GiniCalculator(Array(1.0, 2.0, 5.0)) + val right = new LeafNode(2.0, rightImp.calculate(), rightImp) + + val parent = TreeTests.buildParentNode(left, right, new ContinuousSplit(0, 0.5)) + val parentImp = parent.impurityStats + + val left2Imp = new GiniCalculator(Array(1.0, 6.0, 1.0)) + val left2 = new LeafNode(0.0, left2Imp.calculate(), left2Imp) + + val grandParent = TreeTests.buildParentNode(left2, parent, new ContinuousSplit(1, 1.0)) + val grandImp = grandParent.impurityStats + + // Test feature importance computed at different subtrees. 
+ def testNode(node: Node, expected: Map[Int, Double]): Unit = { + val map = new OpenHashMap[Int, Double]() + RandomForest.computeFeatureImportance(node, map) + assert(mapToVec(map.toMap) ~== mapToVec(expected) relTol 0.01) + } + + // Leaf node + testNode(left, Map.empty[Int, Double]) + + // Internal node with 2 leaf children + val feature0importance = parentImp.calculate() * parentImp.count - + (leftImp.calculate() * leftImp.count + rightImp.calculate() * rightImp.count) + testNode(parent, Map(0 -> feature0importance)) + + // Full tree + val feature1importance = grandImp.calculate() * grandImp.count - + (left2Imp.calculate() * left2Imp.count + parentImp.calculate() * parentImp.count) + testNode(grandParent, Map(0 -> feature0importance, 1 -> feature1importance)) + + // Forest consisting of (full tree) + (internal node with 2 leafs) + val trees = Array(parent, grandParent).map { root => + new DecisionTreeClassificationModel(root, numClasses = 3).asInstanceOf[DecisionTreeModel] + } + val importances: Vector = RandomForest.featureImportances(trees, 2) + val tree2norm = feature0importance + feature1importance + val expected = Vectors.dense((1.0 + feature0importance / tree2norm) / 2.0, + (feature1importance / tree2norm) / 2.0) + assert(importances ~== expected relTol 0.01) + } + + test("normalizeMapValues") { + val map = new OpenHashMap[Int, Double]() + map(0) = 1.0 + map(2) = 2.0 + RandomForest.normalizeMapValues(map) + val expected = Map(0 -> 1.0 / 3.0, 2 -> 2.0 / 3.0) + assert(mapToVec(map.toMap) ~== mapToVec(expected) relTol 0.01) + } + +} + +private object RandomForestSuite { + + def mapToVec(map: Map[Int, Double]): Vector = { + val size = (map.keys.toSeq :+ 0).max + 1 + val (indices, values) = map.toSeq.sortBy(_._1).unzip + Vectors.sparse(size, indices.toArray, values.toArray) + } +} From ba1c4e138de2ea84b55def4eed2bd363e60aea4d Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Mon, 3 Aug 2015 12:53:44 -0700 Subject: [PATCH 0005/1824] [SPARK-9558][DOCS]Update docs to follow the increase of memory defaults. Now the memory defaults of master and slave in Standalone mode and History Server is 1g, not 512m. So let's update docs. Author: Kousuke Saruta Closes #7896 from sarutak/update-doc-for-daemon-memory and squashes the following commits: a77626c [Kousuke Saruta] Fix docs to follow the update of increase of memory defaults --- conf/spark-env.sh.template | 1 + docs/monitoring.md | 2 +- docs/spark-standalone.md | 2 +- 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/conf/spark-env.sh.template b/conf/spark-env.sh.template index 192d3ae09113..c05fe381a36a 100755 --- a/conf/spark-env.sh.template +++ b/conf/spark-env.sh.template @@ -38,6 +38,7 @@ # - SPARK_WORKER_INSTANCES, to set the number of worker processes per node # - SPARK_WORKER_DIR, to set the working directory of worker processes # - SPARK_WORKER_OPTS, to set config properties only for the worker (e.g. "-Dx=y") +# - SPARK_DAEMON_MEMORY, to allocate to the master, worker and history server themselves (default: 1g). # - SPARK_HISTORY_OPTS, to set config properties only for the history server (e.g. "-Dx=y") # - SPARK_SHUFFLE_OPTS, to set config properties only for the external shuffle service (e.g. "-Dx=y") # - SPARK_DAEMON_JAVA_OPTS, to set config properties for all daemons (e.g. 
"-Dx=y") diff --git a/docs/monitoring.md b/docs/monitoring.md index bcf885fe4e68..cedceb295802 100644 --- a/docs/monitoring.md +++ b/docs/monitoring.md @@ -48,7 +48,7 @@ follows: Environment VariableMeaning SPARK_DAEMON_MEMORY - Memory to allocate to the history server (default: 512m). + Memory to allocate to the history server (default: 1g). SPARK_DAEMON_JAVA_OPTS diff --git a/docs/spark-standalone.md b/docs/spark-standalone.md index 4f71fbc086cd..2fe9ec3542b2 100644 --- a/docs/spark-standalone.md +++ b/docs/spark-standalone.md @@ -152,7 +152,7 @@ You can optionally configure the cluster further by setting environment variable SPARK_DAEMON_MEMORY - Memory to allocate to the Spark master and worker daemons themselves (default: 512m). + Memory to allocate to the Spark master and worker daemons themselves (default: 1g). SPARK_DAEMON_JAVA_OPTS From 8ca287ebbd58985a568341b08040d0efa9d3641a Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Mon, 3 Aug 2015 13:58:00 -0700 Subject: [PATCH 0006/1824] [SPARK-9191] [ML] [Doc] Add ml.PCA user guide and code examples Add ml.PCA user guide document and code examples for Scala/Java/Python. Author: Yanbo Liang Closes #7522 from yanboliang/ml-pca-md and squashes the following commits: 60dec05 [Yanbo Liang] address comments f992abe [Yanbo Liang] Add ml.PCA doc and examples --- docs/ml-features.md | 86 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 86 insertions(+) diff --git a/docs/ml-features.md b/docs/ml-features.md index 54068debe215..fa0ad1f00ab1 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -461,6 +461,92 @@ for binarized_feature, in binarizedFeatures.collect(): +## PCA + +[PCA](http://en.wikipedia.org/wiki/Principal_component_analysis) is a statistical procedure that uses an orthogonal transformation to convert a set of observations of possibly correlated variables into a set of values of linearly uncorrelated variables called principal components. A [PCA](api/scala/index.html#org.apache.spark.ml.feature.PCA) class trains a model to project vectors to a low-dimensional space using PCA. The example below shows how to project 5-dimensional feature vectors into 3-dimensional principal components. + +
+<div data-lang="scala">
+See the [Scala API documentation](api/scala/index.html#org.apache.spark.ml.feature.PCA) for API details. +{% highlight scala %} +import org.apache.spark.ml.feature.PCA +import org.apache.spark.mllib.linalg.Vectors + +val data = Array( + Vectors.sparse(5, Seq((1, 1.0), (3, 7.0))), + Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0), + Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0) +) +val df = sqlContext.createDataFrame(data.map(Tuple1.apply)).toDF("features") +val pca = new PCA() + .setInputCol("features") + .setOutputCol("pcaFeatures") + .setK(3) + .fit(df) +val pcaDF = pca.transform(df) +val result = pcaDF.select("pcaFeatures") +result.show() +{% endhighlight %} +
+
+<div data-lang="java">
+See the [Java API documentation](api/java/org/apache/spark/ml/feature/PCA.html) for API details. +{% highlight java %} +import com.google.common.collect.Lists; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.ml.feature.PCA +import org.apache.spark.ml.feature.PCAModel +import org.apache.spark.mllib.linalg.VectorUDT; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +JavaSparkContext jsc = ... +SQLContext jsql = ... +JavaRDD data = jsc.parallelize(Lists.newArrayList( + RowFactory.create(Vectors.sparse(5, new int[]{1, 3}, new double[]{1.0, 7.0})), + RowFactory.create(Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0)), + RowFactory.create(Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0)) +)); +StructType schema = new StructType(new StructField[] { + new StructField("features", new VectorUDT(), false, Metadata.empty()), +}); +DataFrame df = jsql.createDataFrame(data, schema); +PCAModel pca = new PCA() + .setInputCol("features") + .setOutputCol("pcaFeatures") + .setK(3) + .fit(df); +DataFrame result = pca.transform(df).select("pcaFeatures"); +result.show(); +{% endhighlight %} +
+
+<div data-lang="python">
+See the [Python API documentation](api/python/pyspark.ml.html#pyspark.ml.feature.PCA) for API details. +{% highlight python %} +from pyspark.ml.feature import PCA +from pyspark.mllib.linalg import Vectors + +data = [(Vectors.sparse(5, [(1, 1.0), (3, 7.0)]),), + (Vectors.dense([2.0, 0.0, 3.0, 4.0, 5.0]),), + (Vectors.dense([4.0, 0.0, 0.0, 6.0, 7.0]),)] +df = sqlContext.createDataFrame(data,["features"]) +pca = PCA(k=3, inputCol="features", outputCol="pcaFeatures") +model = pca.fit(df) +result = model.transform(df).select("pcaFeatures") +result.show(truncate=False) +{% endhighlight %} +
+</div>
+ ## PolynomialExpansion [Polynomial expansion](http://en.wikipedia.org/wiki/Polynomial_expansion) is the process of expanding your features into a polynomial space, which is formulated by an n-degree combination of original dimensions. A [PolynomialExpansion](api/scala/index.html#org.apache.spark.ml.feature.PolynomialExpansion) class provides this functionality. The example below shows how to expand your features into a 3-degree polynomial space. From e4765a46833baff1dd7465c4cf50e947de7e8f21 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Mon, 3 Aug 2015 13:59:35 -0700 Subject: [PATCH 0007/1824] [SPARK-9544] [MLLIB] add Python API for RFormula Add Python API for RFormula. Similar to other feature transformers in Python. This is just a thin wrapper over the Scala implementation. ericl MechCoder Author: Xiangrui Meng Closes #7879 from mengxr/SPARK-9544 and squashes the following commits: 3d5ff03 [Xiangrui Meng] add an doctest for . and - 5e969a5 [Xiangrui Meng] fix pydoc 1cd41f8 [Xiangrui Meng] organize imports 3c18b10 [Xiangrui Meng] add Python API for RFormula --- .../apache/spark/ml/feature/RFormula.scala | 21 ++--- python/pyspark/ml/feature.py | 85 ++++++++++++++++++- 2 files changed, 91 insertions(+), 15 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala index d1726917e451..d5360c9217ea 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala @@ -19,16 +19,14 @@ package org.apache.spark.ml.feature import scala.collection.mutable import scala.collection.mutable.ArrayBuffer -import scala.util.parsing.combinator.RegexParsers import org.apache.spark.annotation.Experimental -import org.apache.spark.ml.{Estimator, Model, Transformer, Pipeline, PipelineModel, PipelineStage} +import org.apache.spark.ml.{Estimator, Model, Pipeline, PipelineModel, PipelineStage, Transformer} import org.apache.spark.ml.param.{Param, ParamMap} import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasLabelCol} import org.apache.spark.ml.util.Identifiable import org.apache.spark.mllib.linalg.VectorUDT import org.apache.spark.sql.DataFrame -import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ /** @@ -63,31 +61,26 @@ class RFormula(override val uid: String) extends Estimator[RFormulaModel] with R */ val formula: Param[String] = new Param(this, "formula", "R model formula") - private var parsedFormula: Option[ParsedRFormula] = None - /** * Sets the formula to use for this transformer. Must be called before use. * @group setParam * @param value an R formula in string form (e.g. "y ~ x + z") */ - def setFormula(value: String): this.type = { - parsedFormula = Some(RFormulaParser.parse(value)) - set(formula, value) - this - } + def setFormula(value: String): this.type = set(formula, value) /** @group getParam */ def getFormula: String = $(formula) /** Whether the formula specifies fitting an intercept. 
*/ private[ml] def hasIntercept: Boolean = { - require(parsedFormula.isDefined, "Must call setFormula() first.") - parsedFormula.get.hasIntercept + require(isDefined(formula), "Formula must be defined first.") + RFormulaParser.parse($(formula)).hasIntercept } override def fit(dataset: DataFrame): RFormulaModel = { - require(parsedFormula.isDefined, "Must call setFormula() first.") - val resolvedFormula = parsedFormula.get.resolve(dataset.schema) + require(isDefined(formula), "Formula must be defined first.") + val parsedFormula = RFormulaParser.parse($(formula)) + val resolvedFormula = parsedFormula.resolve(dataset.schema) // StringType terms and terms representing interactions need to be encoded before assembly. // TODO(ekl) add support for feature interactions val encoderStages = ArrayBuffer[PipelineStage]() diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 015e7a9d4900..3f04c41ac5ab 100644 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -24,7 +24,7 @@ __all__ = ['Binarizer', 'HashingTF', 'IDF', 'IDFModel', 'NGram', 'Normalizer', 'OneHotEncoder', 'PolynomialExpansion', 'RegexTokenizer', 'StandardScaler', 'StandardScalerModel', 'StringIndexer', 'StringIndexerModel', 'Tokenizer', 'VectorAssembler', 'VectorIndexer', - 'Word2Vec', 'Word2VecModel', 'PCA', 'PCAModel'] + 'Word2Vec', 'Word2VecModel', 'PCA', 'PCAModel', 'RFormula', 'RFormulaModel'] @inherit_doc @@ -1110,6 +1110,89 @@ class PCAModel(JavaModel): """ +@inherit_doc +class RFormula(JavaEstimator, HasFeaturesCol, HasLabelCol): + """ + .. note:: Experimental + + Implements the transforms required for fitting a dataset against an + R model formula. Currently we support a limited subset of the R + operators, including '~', '+', '-', and '.'. Also see the R formula + docs: + http://stat.ethz.ch/R-manual/R-patched/library/stats/html/formula.html + + >>> df = sqlContext.createDataFrame([ + ... (1.0, 1.0, "a"), + ... (0.0, 2.0, "b"), + ... (0.0, 0.0, "a") + ... ], ["y", "x", "s"]) + >>> rf = RFormula(formula="y ~ x + s") + >>> rf.fit(df).transform(df).show() + +---+---+---+---------+-----+ + | y| x| s| features|label| + +---+---+---+---------+-----+ + |1.0|1.0| a|[1.0,1.0]| 1.0| + |0.0|2.0| b|[2.0,0.0]| 0.0| + |0.0|0.0| a|[0.0,1.0]| 0.0| + +---+---+---+---------+-----+ + ... + >>> rf.fit(df, {rf.formula: "y ~ . - s"}).transform(df).show() + +---+---+---+--------+-----+ + | y| x| s|features|label| + +---+---+---+--------+-----+ + |1.0|1.0| a| [1.0]| 1.0| + |0.0|2.0| b| [2.0]| 0.0| + |0.0|0.0| a| [0.0]| 0.0| + +---+---+---+--------+-----+ + ... + """ + + # a placeholder to make it appear in the generated doc + formula = Param(Params._dummy(), "formula", "R model formula") + + @keyword_only + def __init__(self, formula=None, featuresCol="features", labelCol="label"): + """ + __init__(self, formula=None, featuresCol="features", labelCol="label") + """ + super(RFormula, self).__init__() + self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.RFormula", self.uid) + self.formula = Param(self, "formula", "R model formula") + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + + @keyword_only + def setParams(self, formula=None, featuresCol="features", labelCol="label"): + """ + setParams(self, formula=None, featuresCol="features", labelCol="label") + Sets params for RFormula. + """ + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) + + def setFormula(self, value): + """ + Sets the value of :py:attr:`formula`. 
+ """ + self._paramMap[self.formula] = value + return self + + def getFormula(self): + """ + Gets the value of :py:attr:`formula`. + """ + return self.getOrDefault(self.formula) + + def _create_model(self, java_model): + return RFormulaModel(java_model) + + +class RFormulaModel(JavaModel): + """ + Model fitted by :py:class:`RFormula`. + """ + + if __name__ == "__main__": import doctest from pyspark.context import SparkContext From 702aa9d7fb16c98a50e046edfd76b8a7861d0391 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Mon, 3 Aug 2015 14:22:07 -0700 Subject: [PATCH 0008/1824] [SPARK-8735] [SQL] Expose memory usage for shuffles, joins and aggregations This patch exposes the memory used by internal data structures on the SparkUI. This tracks memory used by all spilling operations and SQL operators backed by Tungsten, e.g. `BroadcastHashJoin`, `ExternalSort`, `GeneratedAggregate` etc. The metric exposed is "peak execution memory", which broadly refers to the peak in-memory sizes of each of these data structure. A separate patch will extend this by linking the new information to the SQL operators themselves. screen shot 2015-07-29 at 7 43 17 pm screen shot 2015-07-29 at 7 43 05 pm [Review on Reviewable](https://reviewable.io/reviews/apache/spark/7770) Author: Andrew Or Closes #7770 from andrewor14/expose-memory-metrics and squashes the following commits: 9abecb9 [Andrew Or] Merge branch 'master' of github.com:apache/spark into expose-memory-metrics f5b0d68 [Andrew Or] Merge branch 'master' of github.com:apache/spark into expose-memory-metrics d7df332 [Andrew Or] Merge branch 'master' of github.com:apache/spark into expose-memory-metrics 8eefbc5 [Andrew Or] Fix non-failing tests 9de2a12 [Andrew Or] Fix tests due to another logical merge conflict 876bfa4 [Andrew Or] Fix failing test after logical merge conflict 361a359 [Andrew Or] Merge branch 'master' of github.com:apache/spark into expose-memory-metrics 40b4802 [Andrew Or] Fix style? d0fef87 [Andrew Or] Fix tests? b3b92f6 [Andrew Or] Address comments 0625d73 [Andrew Or] Merge branch 'master' of github.com:apache/spark into expose-memory-metrics c00a197 [Andrew Or] Fix potential NPEs 10da1cd [Andrew Or] Fix compile 17f4c2d [Andrew Or] Fix compile? a87b4d0 [Andrew Or] Fix compile? 
d70874d [Andrew Or] Fix test compile + address comments 2840b7d [Andrew Or] Merge branch 'master' of github.com:apache/spark into expose-memory-metrics 6aa2f7a [Andrew Or] Merge branch 'master' of github.com:apache/spark into expose-memory-metrics b889a68 [Andrew Or] Minor changes: comments, spacing, style 663a303 [Andrew Or] UnsafeShuffleWriter: update peak memory before close d090a94 [Andrew Or] Fix style 2480d84 [Andrew Or] Expand test coverage 5f1235b [Andrew Or] Merge branch 'master' of github.com:apache/spark into expose-memory-metrics 1ecf678 [Andrew Or] Minor changes: comments, style, unused imports 0b6926c [Andrew Or] Oops 111a05e [Andrew Or] Merge branch 'master' of github.com:apache/spark into expose-memory-metrics a7a39a5 [Andrew Or] Strengthen presence check for accumulator a919eb7 [Andrew Or] Add tests for unsafe shuffle writer 23c845d [Andrew Or] Add tests for SQL operators a757550 [Andrew Or] Address comments b5c51c1 [Andrew Or] Re-enable test in JavaAPISuite 5107691 [Andrew Or] Add tests for internal accumulators 59231e4 [Andrew Or] Fix tests 9528d09 [Andrew Or] Merge branch 'master' of github.com:apache/spark into expose-memory-metrics 5b5e6f3 [Andrew Or] Add peak execution memory to summary table + tooltip 92b4b6b [Andrew Or] Display peak execution memory on the UI eee5437 [Andrew Or] Merge branch 'master' of github.com:apache/spark into expose-memory-metrics d9b9015 [Andrew Or] Track execution memory in unsafe shuffles 770ee54 [Andrew Or] Track execution memory in broadcast joins 9c605a4 [Andrew Or] Track execution memory in GeneratedAggregate 9e824f2 [Andrew Or] Add back execution memory tracking for *ExternalSort 4ef4cb1 [Andrew Or] Merge branch 'master' of github.com:apache/spark into expose-memory-metrics e6c3e2f [Andrew Or] Move internal accumulators creation to Stage a417592 [Andrew Or] Expose memory metrics in UnsafeExternalSorter 3c4f042 [Andrew Or] Track memory usage in ExternalAppendOnlyMap / ExternalSorter bd7ab3f [Andrew Or] Add internal accumulators to TaskContext --- .../unsafe/UnsafeShuffleExternalSorter.java | 27 ++- .../shuffle/unsafe/UnsafeShuffleWriter.java | 38 +++- .../spark/unsafe/map/BytesToBytesMap.java | 8 +- .../unsafe/sort/UnsafeExternalSorter.java | 29 ++- .../org/apache/spark/ui/static/webui.css | 2 +- .../scala/org/apache/spark/Accumulators.scala | 60 +++++- .../scala/org/apache/spark/Aggregator.scala | 24 +-- .../scala/org/apache/spark/TaskContext.scala | 13 +- .../org/apache/spark/TaskContextImpl.scala | 8 + .../org/apache/spark/rdd/CoGroupedRDD.scala | 9 +- .../spark/scheduler/AccumulableInfo.scala | 9 +- .../apache/spark/scheduler/DAGScheduler.scala | 28 ++- .../apache/spark/scheduler/ResultTask.scala | 6 +- .../spark/scheduler/ShuffleMapTask.scala | 10 +- .../org/apache/spark/scheduler/Stage.scala | 16 ++ .../org/apache/spark/scheduler/Task.scala | 18 +- .../shuffle/hash/HashShuffleReader.scala | 8 +- .../scala/org/apache/spark/ui/ToolTips.scala | 7 + .../org/apache/spark/ui/jobs/StagePage.scala | 140 +++++++++---- .../spark/ui/jobs/TaskDetailsClassNames.scala | 1 + .../collection/ExternalAppendOnlyMap.scala | 13 +- .../util/collection/ExternalSorter.scala | 20 +- .../java/org/apache/spark/JavaAPISuite.java | 3 +- .../unsafe/UnsafeShuffleWriterSuite.java | 54 +++++ .../map/AbstractBytesToBytesMapSuite.java | 39 ++++ .../sort/UnsafeExternalSorterSuite.java | 46 +++++ .../org/apache/spark/AccumulatorSuite.scala | 193 +++++++++++++++++- .../org/apache/spark/CacheManagerSuite.scala | 10 +- .../org/apache/spark/rdd/PipedRDDSuite.scala | 2 
+- .../org/apache/spark/scheduler/FakeTask.scala | 6 +- .../scheduler/NotSerializableFakeTask.scala | 2 +- .../spark/scheduler/TaskContextSuite.scala | 7 +- .../spark/scheduler/TaskSetManagerSuite.scala | 2 +- .../shuffle/hash/HashShuffleReaderSuite.scala | 2 +- .../ShuffleBlockFetcherIteratorSuite.scala | 8 +- .../org/apache/spark/ui/StagePageSuite.scala | 76 +++++++ .../ExternalAppendOnlyMapSuite.scala | 15 ++ .../util/collection/ExternalSorterSuite.scala | 14 +- .../execution/UnsafeExternalRowSorter.java | 7 + .../UnsafeFixedWidthAggregationMap.java | 8 + .../sql/execution/GeneratedAggregate.scala | 11 +- .../execution/joins/BroadcastHashJoin.scala | 10 +- .../joins/BroadcastHashOuterJoin.scala | 8 + .../joins/BroadcastLeftSemiJoinHash.scala | 10 +- .../sql/execution/joins/HashedRelation.scala | 22 +- .../org/apache/spark/sql/execution/sort.scala | 12 +- .../org/apache/spark/sql/SQLQuerySuite.scala | 60 ++++-- .../sql/execution/TungstenSortSuite.scala | 12 ++ .../UnsafeFixedWidthAggregationMapSuite.scala | 3 +- .../UnsafeKVExternalSorterSuite.scala | 3 +- .../execution/joins/BroadcastJoinSuite.scala | 94 +++++++++ 51 files changed, 1070 insertions(+), 163 deletions(-) create mode 100644 core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java index 1aa6ba420126..bf4eaa59ff58 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java @@ -20,6 +20,7 @@ import java.io.File; import java.io.IOException; import java.util.LinkedList; +import javax.annotation.Nullable; import scala.Tuple2; @@ -86,9 +87,12 @@ final class UnsafeShuffleExternalSorter { private final LinkedList spills = new LinkedList(); + /** Peak memory used by this sorter so far, in bytes. **/ + private long peakMemoryUsedBytes; + // These variables are reset after spilling: - private UnsafeShuffleInMemorySorter sorter; - private MemoryBlock currentPage = null; + @Nullable private UnsafeShuffleInMemorySorter sorter; + @Nullable private MemoryBlock currentPage = null; private long currentPagePosition = -1; private long freeSpaceInCurrentPage = 0; @@ -106,6 +110,7 @@ public UnsafeShuffleExternalSorter( this.blockManager = blockManager; this.taskContext = taskContext; this.initialSize = initialSize; + this.peakMemoryUsedBytes = initialSize; this.numPartitions = numPartitions; // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided this.fileBufferSizeBytes = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024; @@ -279,10 +284,26 @@ private long getMemoryUsage() { for (MemoryBlock page : allocatedPages) { totalPageSize += page.size(); } - return sorter.getMemoryUsage() + totalPageSize; + return ((sorter == null) ? 0 : sorter.getMemoryUsage()) + totalPageSize; + } + + private void updatePeakMemoryUsed() { + long mem = getMemoryUsage(); + if (mem > peakMemoryUsedBytes) { + peakMemoryUsedBytes = mem; + } + } + + /** + * Return the peak memory used so far, in bytes. 
+ */ + long getPeakMemoryUsedBytes() { + updatePeakMemoryUsed(); + return peakMemoryUsedBytes; } private long freeMemory() { + updatePeakMemoryUsed(); long memoryFreed = 0; for (MemoryBlock block : allocatedPages) { memoryManager.freePage(block); diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java index d47d6fc9c2ac..6e2eeb37c86f 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java @@ -27,6 +27,7 @@ import scala.collection.JavaConversions; import scala.reflect.ClassTag; import scala.reflect.ClassTag$; +import scala.collection.immutable.Map; import com.google.common.annotations.VisibleForTesting; import com.google.common.io.ByteStreams; @@ -78,8 +79,9 @@ public class UnsafeShuffleWriter extends ShuffleWriter { private final SparkConf sparkConf; private final boolean transferToEnabled; - private MapStatus mapStatus = null; - private UnsafeShuffleExternalSorter sorter = null; + @Nullable private MapStatus mapStatus; + @Nullable private UnsafeShuffleExternalSorter sorter; + private long peakMemoryUsedBytes = 0; /** Subclass of ByteArrayOutputStream that exposes `buf` directly. */ private static final class MyByteArrayOutputStream extends ByteArrayOutputStream { @@ -131,9 +133,28 @@ public UnsafeShuffleWriter( @VisibleForTesting public int maxRecordSizeBytes() { + assert(sorter != null); return sorter.maxRecordSizeBytes; } + private void updatePeakMemoryUsed() { + // sorter can be null if this writer is closed + if (sorter != null) { + long mem = sorter.getPeakMemoryUsedBytes(); + if (mem > peakMemoryUsedBytes) { + peakMemoryUsedBytes = mem; + } + } + } + + /** + * Return the peak memory used so far, in bytes. + */ + public long getPeakMemoryUsedBytes() { + updatePeakMemoryUsed(); + return peakMemoryUsedBytes; + } + /** * This convenience method should only be called in test code. */ @@ -144,7 +165,7 @@ public void write(Iterator> records) throws IOException { @Override public void write(scala.collection.Iterator> records) throws IOException { - // Keep track of success so we know if we ecountered an exception + // Keep track of success so we know if we encountered an exception // We do this rather than a standard try/catch/re-throw to handle // generic throwables. 
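The bookkeeping added here (and to UnsafeExternalSorter later in this patch) is small but easy to get wrong: the peak has to be refreshed both when it is read and immediately before pages are released, otherwise freeing memory erases the high-water mark. A minimal, self-contained Scala sketch of that pattern; all names below are illustrative, not Spark APIs.

```scala
// Minimal sketch of the peak-memory bookkeeping used by the sorters in this patch.
class PeakTrackingBuffer {
  private var pageSizes: Vector[Long] = Vector.empty // sizes of allocated pages, in bytes
  private var peakMemoryUsedBytes: Long = 0L

  private def currentMemoryUsed: Long = pageSizes.sum

  private def updatePeakMemoryUsed(): Unit = {
    val mem = currentMemoryUsed
    if (mem > peakMemoryUsedBytes) {
      peakMemoryUsedBytes = mem
    }
  }

  def allocatePage(sizeBytes: Long): Unit = {
    pageSizes = pageSizes :+ sizeBytes
    updatePeakMemoryUsed()
  }

  /** Record the peak before releasing pages, mirroring freeMemory() in the patch. */
  def freeMemory(): Long = {
    updatePeakMemoryUsed()
    val freed = currentMemoryUsed
    pageSizes = Vector.empty
    freed
  }

  /** Refresh on read, mirroring getPeakMemoryUsedBytes() in the patch. */
  def getPeakMemoryUsedBytes: Long = {
    updatePeakMemoryUsed()
    peakMemoryUsedBytes
  }
}
```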
boolean success = false; @@ -189,6 +210,8 @@ private void open() throws IOException { @VisibleForTesting void closeAndWriteOutput() throws IOException { + assert(sorter != null); + updatePeakMemoryUsed(); serBuffer = null; serOutputStream = null; final SpillInfo[] spills = sorter.closeAndGetSpills(); @@ -209,6 +232,7 @@ void closeAndWriteOutput() throws IOException { @VisibleForTesting void insertRecordIntoSorter(Product2 record) throws IOException { + assert(sorter != null); final K key = record._1(); final int partitionId = partitioner.getPartition(key); serBuffer.reset(); @@ -431,6 +455,14 @@ private long[] mergeSpillsWithTransferTo(SpillInfo[] spills, File outputFile) th @Override public Option stop(boolean success) { try { + // Update task metrics from accumulators (null in UnsafeShuffleWriterSuite) + Map> internalAccumulators = + taskContext.internalMetricsToAccumulators(); + if (internalAccumulators != null) { + internalAccumulators.apply(InternalAccumulator.PEAK_EXECUTION_MEMORY()) + .add(getPeakMemoryUsedBytes()); + } + if (stopping) { return Option.apply(null); } else { diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java index 01a66084e918..20347433e16b 100644 --- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java +++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java @@ -505,7 +505,7 @@ public boolean putNewKey( // Here, we'll copy the data into our data pages. Because we only store a relative offset from // the key address instead of storing the absolute address of the value, the key and value // must be stored in the same memory page. - // (8 byte key length) (key) (8 byte value length) (value) + // (8 byte key length) (key) (value) final long requiredSize = 8 + keyLengthBytes + valueLengthBytes; // --- Figure out where to insert the new record --------------------------------------------- @@ -655,7 +655,10 @@ public long getPageSizeBytes() { return pageSizeBytes; } - /** Returns the total amount of memory, in bytes, consumed by this map's managed structures. */ + /** + * Returns the total amount of memory, in bytes, consumed by this map's managed structures. + * Note that this is also the peak memory used by this map, since the map is append-only. + */ public long getTotalMemoryConsumption() { long totalDataPagesSize = 0L; for (MemoryBlock dataPage : dataPages) { @@ -674,7 +677,6 @@ public long getTimeSpentResizingNs() { return timeSpentResizingNs; } - /** * Returns the average number of probes per key lookup. */ diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java index b984301cbbf2..bf5f965a9d8d 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -70,13 +70,14 @@ public final class UnsafeExternalSorter { private final LinkedList spillWriters = new LinkedList<>(); // These variables are reset after spilling: - private UnsafeInMemorySorter inMemSorter; + @Nullable private UnsafeInMemorySorter inMemSorter; // Whether the in-mem sorter is created internally, or passed in from outside. // If it is passed in from outside, we shouldn't release the in-mem sorter's memory. 
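A toy arithmetic check of the BytesToBytesMap note above: each record needs an 8-byte length prefix plus its key and value bytes, and because the map is append-only, the running total consumption doubles as the peak. The numbers below are made up for illustration; only the `8 + keyLengthBytes + valueLengthBytes` formula comes from the patch.

```scala
// Toy model of the append-only record layout: 8-byte length prefix, then key bytes, then value bytes.
def requiredSizeBytes(keyLengthBytes: Int, valueLengthBytes: Int): Long =
  8L + keyLengthBytes + valueLengthBytes

val recordSizes = Seq(requiredSizeBytes(8, 8), requiredSizeBytes(16, 24))
assert(recordSizes == Seq(24L, 48L))
// Nothing is ever removed while the map is live, so the running total is also the peak.
assert(recordSizes.sum == 72L)
```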
private boolean isInMemSorterExternal = false; private MemoryBlock currentPage = null; private long currentPagePosition = -1; private long freeSpaceInCurrentPage = 0; + private long peakMemoryUsedBytes = 0; public static UnsafeExternalSorter createWithExistingInMemorySorter( TaskMemoryManager taskMemoryManager, @@ -183,6 +184,7 @@ public void closeCurrentPage() { * Sort and spill the current records in response to memory pressure. */ public void spill() throws IOException { + assert(inMemSorter != null); logger.info("Thread {} spilling sort data of {} to disk ({} {} so far)", Thread.currentThread().getId(), Utils.bytesToString(getMemoryUsage()), @@ -219,7 +221,22 @@ private long getMemoryUsage() { for (MemoryBlock page : allocatedPages) { totalPageSize += page.size(); } - return inMemSorter.getMemoryUsage() + totalPageSize; + return ((inMemSorter == null) ? 0 : inMemSorter.getMemoryUsage()) + totalPageSize; + } + + private void updatePeakMemoryUsed() { + long mem = getMemoryUsage(); + if (mem > peakMemoryUsedBytes) { + peakMemoryUsedBytes = mem; + } + } + + /** + * Return the peak memory used so far, in bytes. + */ + public long getPeakMemoryUsedBytes() { + updatePeakMemoryUsed(); + return peakMemoryUsedBytes; } @VisibleForTesting @@ -233,6 +250,7 @@ public int getNumberOfAllocatedPages() { * @return the number of bytes freed. */ public long freeMemory() { + updatePeakMemoryUsed(); long memoryFreed = 0; for (MemoryBlock block : allocatedPages) { taskMemoryManager.freePage(block); @@ -277,7 +295,8 @@ public void deleteSpillFiles() { * @return true if the record can be inserted without requiring more allocations, false otherwise. */ private boolean haveSpaceForRecord(int requiredSpace) { - assert (requiredSpace > 0); + assert(requiredSpace > 0); + assert(inMemSorter != null); return (inMemSorter.hasSpaceForAnotherRecord() && (requiredSpace <= freeSpaceInCurrentPage)); } @@ -290,6 +309,7 @@ private boolean haveSpaceForRecord(int requiredSpace) { * the record size. */ private void allocateSpaceForRecord(int requiredSpace) throws IOException { + assert(inMemSorter != null); // TODO: merge these steps to first calculate total memory requirements for this insert, // then try to acquire; no point in acquiring sort buffer only to spill due to no space in the // data page. @@ -350,6 +370,7 @@ public void insertRecord( if (!haveSpaceForRecord(totalSpaceRequired)) { allocateSpaceForRecord(totalSpaceRequired); } + assert(inMemSorter != null); final long recordAddress = taskMemoryManager.encodePageNumberAndOffset(currentPage, currentPagePosition); @@ -382,6 +403,7 @@ public void insertKVRecord( if (!haveSpaceForRecord(totalSpaceRequired)) { allocateSpaceForRecord(totalSpaceRequired); } + assert(inMemSorter != null); final long recordAddress = taskMemoryManager.encodePageNumberAndOffset(currentPage, currentPagePosition); @@ -405,6 +427,7 @@ public void insertKVRecord( } public UnsafeSorterIterator getSortedIterator() throws IOException { + assert(inMemSorter != null); final UnsafeSorterIterator inMemoryIterator = inMemSorter.getSortedIterator(); int numIteratorsToMerge = spillWriters.size() + (inMemoryIterator.hasNext() ? 
1 : 0); if (spillWriters.isEmpty()) { diff --git a/core/src/main/resources/org/apache/spark/ui/static/webui.css b/core/src/main/resources/org/apache/spark/ui/static/webui.css index b1cef4704224..648cd1b10480 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/webui.css +++ b/core/src/main/resources/org/apache/spark/ui/static/webui.css @@ -207,7 +207,7 @@ span.additional-metric-title { /* Hide all additional metrics by default. This is done here rather than using JavaScript to * avoid slow page loads for stage pages with large numbers (e.g., thousands) of tasks. */ .scheduler_delay, .deserialization_time, .fetch_wait_time, .shuffle_read_remote, -.serialization_time, .getting_result_time { +.serialization_time, .getting_result_time, .peak_execution_memory { display: none; } diff --git a/core/src/main/scala/org/apache/spark/Accumulators.scala b/core/src/main/scala/org/apache/spark/Accumulators.scala index eb75f26718e1..b6a0119c696f 100644 --- a/core/src/main/scala/org/apache/spark/Accumulators.scala +++ b/core/src/main/scala/org/apache/spark/Accumulators.scala @@ -152,8 +152,14 @@ class Accumulable[R, T] private[spark] ( in.defaultReadObject() value_ = zero deserialized = true - val taskContext = TaskContext.get() - taskContext.registerAccumulator(this) + // Automatically register the accumulator when it is deserialized with the task closure. + // Note that internal accumulators are deserialized before the TaskContext is created and + // are registered in the TaskContext constructor. + if (!isInternal) { + val taskContext = TaskContext.get() + assume(taskContext != null, "Task context was null when deserializing user accumulators") + taskContext.registerAccumulator(this) + } } override def toString: String = if (value_ == null) "null" else value_.toString @@ -248,10 +254,20 @@ GrowableAccumulableParam[R <% Growable[T] with TraversableOnce[T] with Serializa * @param param helper object defining how to add elements of type `T` * @tparam T result type */ -class Accumulator[T](@transient initialValue: T, param: AccumulatorParam[T], name: Option[String]) - extends Accumulable[T, T](initialValue, param, name) { +class Accumulator[T] private[spark] ( + @transient initialValue: T, + param: AccumulatorParam[T], + name: Option[String], + internal: Boolean) + extends Accumulable[T, T](initialValue, param, name, internal) { + + def this(initialValue: T, param: AccumulatorParam[T], name: Option[String]) = { + this(initialValue, param, name, false) + } - def this(initialValue: T, param: AccumulatorParam[T]) = this(initialValue, param, None) + def this(initialValue: T, param: AccumulatorParam[T]) = { + this(initialValue, param, None, false) + } } /** @@ -342,3 +358,37 @@ private[spark] object Accumulators extends Logging { } } + +private[spark] object InternalAccumulator { + val PEAK_EXECUTION_MEMORY = "peakExecutionMemory" + val TEST_ACCUMULATOR = "testAccumulator" + + // For testing only. + // This needs to be a def since we don't want to reuse the same accumulator across stages. + private def maybeTestAccumulator: Option[Accumulator[Long]] = { + if (sys.props.contains("spark.testing")) { + Some(new Accumulator( + 0L, AccumulatorParam.LongAccumulatorParam, Some(TEST_ACCUMULATOR), internal = true)) + } else { + None + } + } + + /** + * Accumulators for tracking internal metrics. + * + * These accumulators are created with the stage such that all tasks in the stage will + * add to the same set of accumulators. 
We do this to report the distribution of accumulator + * values across all tasks within each stage. + */ + def create(): Seq[Accumulator[Long]] = { + Seq( + // Execution memory refers to the memory used by internal data structures created + // during shuffles, aggregations and joins. The value of this accumulator should be + // approximately the sum of the peak sizes across all such data structures created + // in this task. For SQL jobs, this only tracks all unsafe operators and ExternalSort. + new Accumulator( + 0L, AccumulatorParam.LongAccumulatorParam, Some(PEAK_EXECUTION_MEMORY), internal = true) + ) ++ maybeTestAccumulator.toSeq + } +} diff --git a/core/src/main/scala/org/apache/spark/Aggregator.scala b/core/src/main/scala/org/apache/spark/Aggregator.scala index ceeb58075d34..289aab9bd9e5 100644 --- a/core/src/main/scala/org/apache/spark/Aggregator.scala +++ b/core/src/main/scala/org/apache/spark/Aggregator.scala @@ -58,12 +58,7 @@ case class Aggregator[K, V, C] ( } else { val combiners = new ExternalAppendOnlyMap[K, V, C](createCombiner, mergeValue, mergeCombiners) combiners.insertAll(iter) - // Update task metrics if context is not null - // TODO: Make context non optional in a future release - Option(context).foreach { c => - c.taskMetrics.incMemoryBytesSpilled(combiners.memoryBytesSpilled) - c.taskMetrics.incDiskBytesSpilled(combiners.diskBytesSpilled) - } + updateMetrics(context, combiners) combiners.iterator } } @@ -89,13 +84,18 @@ case class Aggregator[K, V, C] ( } else { val combiners = new ExternalAppendOnlyMap[K, C, C](identity, mergeCombiners, mergeCombiners) combiners.insertAll(iter) - // Update task metrics if context is not null - // TODO: Make context non-optional in a future release - Option(context).foreach { c => - c.taskMetrics.incMemoryBytesSpilled(combiners.memoryBytesSpilled) - c.taskMetrics.incDiskBytesSpilled(combiners.diskBytesSpilled) - } + updateMetrics(context, combiners) combiners.iterator } } + + /** Update task metrics after populating the external map. */ + private def updateMetrics(context: TaskContext, map: ExternalAppendOnlyMap[_, _, _]): Unit = { + Option(context).foreach { c => + c.taskMetrics().incMemoryBytesSpilled(map.memoryBytesSpilled) + c.taskMetrics().incDiskBytesSpilled(map.diskBytesSpilled) + c.internalMetricsToAccumulators( + InternalAccumulator.PEAK_EXECUTION_MEMORY).add(map.peakMemoryUsedBytes) + } + } } diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala index 5d2c551d5851..63cca80b2d73 100644 --- a/core/src/main/scala/org/apache/spark/TaskContext.scala +++ b/core/src/main/scala/org/apache/spark/TaskContext.scala @@ -61,12 +61,12 @@ object TaskContext { protected[spark] def unset(): Unit = taskContext.remove() /** - * Return an empty task context that is not actually used. - * Internal use only. + * An empty task context that does not represent an actual task. */ - private[spark] def empty(): TaskContext = { - new TaskContextImpl(0, 0, 0, 0, null, null) + private[spark] def empty(): TaskContextImpl = { + new TaskContextImpl(0, 0, 0, 0, null, null, Seq.empty) } + } @@ -187,4 +187,9 @@ abstract class TaskContext extends Serializable { * accumulator id and the value of the Map is the latest accumulator local value. */ private[spark] def collectAccumulators(): Map[Long, Any] + + /** + * Accumulators for tracking internal metrics indexed by the name. 
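The same reporting step recurs in Aggregator.updateMetrics, CoGroupedRDD, HashShuffleReader and ExternalSorter within this patch. A hedged sketch of that step, assuming code that lives in the org.apache.spark package (both `internalMetricsToAccumulators` and `InternalAccumulator` are `private[spark]`) and the same null-safe context guard used by Aggregator:

```scala
// Sketch of how a spilling data structure reports its peak to the per-task accumulator.
// `peakBytes` would come from something like ExternalAppendOnlyMap.peakMemoryUsedBytes.
import org.apache.spark.{InternalAccumulator, TaskContext}

def reportPeakExecutionMemory(context: TaskContext, peakBytes: Long): Unit = {
  // Guard against a null context, e.g. when the structure is built outside a running task.
  Option(context).foreach { c =>
    c.internalMetricsToAccumulators(InternalAccumulator.PEAK_EXECUTION_MEMORY)
      .add(peakBytes)
  }
}
```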
+ */ + private[spark] val internalMetricsToAccumulators: Map[String, Accumulator[Long]] } diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala index 9ee168ae016f..5df94c6d3a10 100644 --- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala +++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala @@ -32,6 +32,7 @@ private[spark] class TaskContextImpl( override val attemptNumber: Int, override val taskMemoryManager: TaskMemoryManager, @transient private val metricsSystem: MetricsSystem, + internalAccumulators: Seq[Accumulator[Long]], val runningLocally: Boolean = false, val taskMetrics: TaskMetrics = TaskMetrics.empty) extends TaskContext @@ -114,4 +115,11 @@ private[spark] class TaskContextImpl( private[spark] override def collectAccumulators(): Map[Long, Any] = synchronized { accumulators.mapValues(_.localValue).toMap } + + private[spark] override val internalMetricsToAccumulators: Map[String, Accumulator[Long]] = { + // Explicitly register internal accumulators here because these are + // not captured in the task closure and are already deserialized + internalAccumulators.foreach(registerAccumulator) + internalAccumulators.map { a => (a.name.get, a) }.toMap + } } diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala index 130b58882d8e..9c617fc719cb 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala @@ -23,8 +23,7 @@ import java.io.{IOException, ObjectOutputStream} import scala.collection.mutable.ArrayBuffer -import org.apache.spark.{InterruptibleIterator, Partition, Partitioner, SparkEnv, TaskContext} -import org.apache.spark.{Dependency, OneToOneDependency, ShuffleDependency} +import org.apache.spark._ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.util.collection.{ExternalAppendOnlyMap, AppendOnlyMap, CompactBuffer} import org.apache.spark.util.Utils @@ -169,8 +168,10 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part: for ((it, depNum) <- rddIterators) { map.insertAll(it.map(pair => (pair._1, new CoGroupValue(pair._2, depNum)))) } - context.taskMetrics.incMemoryBytesSpilled(map.memoryBytesSpilled) - context.taskMetrics.incDiskBytesSpilled(map.diskBytesSpilled) + context.taskMetrics().incMemoryBytesSpilled(map.memoryBytesSpilled) + context.taskMetrics().incDiskBytesSpilled(map.diskBytesSpilled) + context.internalMetricsToAccumulators( + InternalAccumulator.PEAK_EXECUTION_MEMORY).add(map.peakMemoryUsedBytes) new InterruptibleIterator(context, map.iterator.asInstanceOf[Iterator[(K, Array[Iterable[_]])]]) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala index e0edd7d4ae96..11d123eec43c 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala @@ -24,11 +24,12 @@ import org.apache.spark.annotation.DeveloperApi * Information about an [[org.apache.spark.Accumulable]] modified during a task or stage. 
*/ @DeveloperApi -class AccumulableInfo ( +class AccumulableInfo private[spark] ( val id: Long, val name: String, val update: Option[String], // represents a partial update within a task - val value: String) { + val value: String, + val internal: Boolean) { override def equals(other: Any): Boolean = other match { case acc: AccumulableInfo => @@ -40,10 +41,10 @@ class AccumulableInfo ( object AccumulableInfo { def apply(id: Long, name: String, update: Option[String], value: String): AccumulableInfo = { - new AccumulableInfo(id, name, update, value) + new AccumulableInfo(id, name, update, value, internal = false) } def apply(id: Long, name: String, value: String): AccumulableInfo = { - new AccumulableInfo(id, name, None, value) + new AccumulableInfo(id, name, None, value, internal = false) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index c4fa277c2125..bb489c6b6e98 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -773,16 +773,26 @@ class DAGScheduler( stage.pendingTasks.clear() // First figure out the indexes of partition ids to compute. - val partitionsToCompute: Seq[Int] = { + val (allPartitions: Seq[Int], partitionsToCompute: Seq[Int]) = { stage match { case stage: ShuffleMapStage => - (0 until stage.numPartitions).filter(id => stage.outputLocs(id).isEmpty) + val allPartitions = 0 until stage.numPartitions + val filteredPartitions = allPartitions.filter { id => stage.outputLocs(id).isEmpty } + (allPartitions, filteredPartitions) case stage: ResultStage => val job = stage.resultOfJob.get - (0 until job.numPartitions).filter(id => !job.finished(id)) + val allPartitions = 0 until job.numPartitions + val filteredPartitions = allPartitions.filter { id => !job.finished(id) } + (allPartitions, filteredPartitions) } } + // Reset internal accumulators only if this stage is not partially submitted + // Otherwise, we may override existing accumulator values from some tasks + if (allPartitions == partitionsToCompute) { + stage.resetInternalAccumulators() + } + val properties = jobIdToActiveJob.get(stage.firstJobId).map(_.properties).orNull runningStages += stage @@ -852,7 +862,8 @@ class DAGScheduler( partitionsToCompute.map { id => val locs = taskIdToLocations(id) val part = stage.rdd.partitions(id) - new ShuffleMapTask(stage.id, stage.latestInfo.attemptId, taskBinary, part, locs) + new ShuffleMapTask(stage.id, stage.latestInfo.attemptId, + taskBinary, part, locs, stage.internalAccumulators) } case stage: ResultStage => @@ -861,7 +872,8 @@ class DAGScheduler( val p: Int = job.partitions(id) val part = stage.rdd.partitions(p) val locs = taskIdToLocations(id) - new ResultTask(stage.id, stage.latestInfo.attemptId, taskBinary, part, locs, id) + new ResultTask(stage.id, stage.latestInfo.attemptId, + taskBinary, part, locs, id, stage.internalAccumulators) } } } catch { @@ -916,9 +928,11 @@ class DAGScheduler( // To avoid UI cruft, ignore cases where value wasn't updated if (acc.name.isDefined && partialValue != acc.zero) { val name = acc.name.get - stage.latestInfo.accumulables(id) = AccumulableInfo(id, name, s"${acc.value}") + val value = s"${acc.value}" + stage.latestInfo.accumulables(id) = + new AccumulableInfo(id, name, None, value, acc.isInternal) event.taskInfo.accumulables += - AccumulableInfo(id, name, Some(s"$partialValue"), s"${acc.value}") + new AccumulableInfo(id, name, 
Some(s"$partialValue"), value, acc.isInternal) } } } catch { diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala index 9c2606e278c5..c4dc080e2b22 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala @@ -45,8 +45,10 @@ private[spark] class ResultTask[T, U]( taskBinary: Broadcast[Array[Byte]], partition: Partition, @transient locs: Seq[TaskLocation], - val outputId: Int) - extends Task[U](stageId, stageAttemptId, partition.index) with Serializable { + val outputId: Int, + internalAccumulators: Seq[Accumulator[Long]]) + extends Task[U](stageId, stageAttemptId, partition.index, internalAccumulators) + with Serializable { @transient private[this] val preferredLocs: Seq[TaskLocation] = { if (locs == null) Nil else locs.toSet.toSeq diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala index 14c8c0096148..f478f9982afe 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala @@ -43,12 +43,14 @@ private[spark] class ShuffleMapTask( stageAttemptId: Int, taskBinary: Broadcast[Array[Byte]], partition: Partition, - @transient private var locs: Seq[TaskLocation]) - extends Task[MapStatus](stageId, stageAttemptId, partition.index) with Logging { + @transient private var locs: Seq[TaskLocation], + internalAccumulators: Seq[Accumulator[Long]]) + extends Task[MapStatus](stageId, stageAttemptId, partition.index, internalAccumulators) + with Logging { /** A constructor used only in test suites. This does not require passing in an RDD. */ def this(partitionId: Int) { - this(0, 0, null, new Partition { override def index: Int = 0 }, null) + this(0, 0, null, new Partition { override def index: Int = 0 }, null, null) } @transient private val preferredLocs: Seq[TaskLocation] = { @@ -69,7 +71,7 @@ private[spark] class ShuffleMapTask( val manager = SparkEnv.get.shuffleManager writer = manager.getWriter[Any, Any](dep.shuffleHandle, partitionId, context) writer.write(rdd.iterator(partition, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]]) - return writer.stop(success = true).get + writer.stop(success = true).get } catch { case e: Exception => try { diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala index 40a333a3e06b..de05ee256dbf 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala @@ -68,6 +68,22 @@ private[spark] abstract class Stage( val name = callSite.shortForm val details = callSite.longForm + private var _internalAccumulators: Seq[Accumulator[Long]] = Seq.empty + + /** Internal accumulators shared across all tasks in this stage. */ + def internalAccumulators: Seq[Accumulator[Long]] = _internalAccumulators + + /** + * Re-initialize the internal accumulators associated with this stage. + * + * This is called every time the stage is submitted, *except* when a subset of tasks + * belonging to this stage has already finished. Otherwise, reinitializing the internal + * accumulators here again will override partial values from the finished tasks. 
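A small worked example of the resubmission guard described above, with made-up partition ids: the per-stage accumulators are re-created only when every partition is being recomputed, so a partial retry keeps the values already merged from tasks that succeeded on the first attempt.

```scala
// Illustrative check of the reset condition the DAGScheduler applies in this patch.
def shouldResetInternalAccumulators(
    allPartitions: Seq[Int],
    partitionsToCompute: Seq[Int]): Boolean = {
  allPartitions == partitionsToCompute
}

assert(shouldResetInternalAccumulators(Seq(0, 1, 2, 3), Seq(0, 1, 2, 3))) // full submission: reset
assert(!shouldResetInternalAccumulators(Seq(0, 1, 2, 3), Seq(1, 3)))      // partial retry: keep values
```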
+ */ + def resetInternalAccumulators(): Unit = { + _internalAccumulators = InternalAccumulator.create() + } + /** * Pointer to the [StageInfo] object for the most recent attempt. This needs to be initialized * here, before any attempts have actually been created, because the DAGScheduler uses this diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index 1978305cfefb..9edf9f048f9f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -23,7 +23,7 @@ import java.nio.ByteBuffer import scala.collection.mutable.HashMap import org.apache.spark.metrics.MetricsSystem -import org.apache.spark.{SparkEnv, TaskContextImpl, TaskContext} +import org.apache.spark.{Accumulator, SparkEnv, TaskContextImpl, TaskContext} import org.apache.spark.executor.TaskMetrics import org.apache.spark.serializer.SerializerInstance import org.apache.spark.unsafe.memory.TaskMemoryManager @@ -47,7 +47,8 @@ import org.apache.spark.util.Utils private[spark] abstract class Task[T]( val stageId: Int, val stageAttemptId: Int, - var partitionId: Int) extends Serializable { + val partitionId: Int, + internalAccumulators: Seq[Accumulator[Long]]) extends Serializable { /** * The key of the Map is the accumulator id and the value of the Map is the latest accumulator @@ -68,12 +69,13 @@ private[spark] abstract class Task[T]( metricsSystem: MetricsSystem) : (T, AccumulatorUpdates) = { context = new TaskContextImpl( - stageId = stageId, - partitionId = partitionId, - taskAttemptId = taskAttemptId, - attemptNumber = attemptNumber, - taskMemoryManager = taskMemoryManager, - metricsSystem = metricsSystem, + stageId, + partitionId, + taskAttemptId, + attemptNumber, + taskMemoryManager, + metricsSystem, + internalAccumulators, runningLocally = false) TaskContext.setTaskContext(context) context.taskMetrics.setHostname(Utils.localHostName()) diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala index de79fa56f017..0c8f08f0f3b1 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala @@ -17,7 +17,7 @@ package org.apache.spark.shuffle.hash -import org.apache.spark.{InterruptibleIterator, Logging, MapOutputTracker, SparkEnv, TaskContext} +import org.apache.spark._ import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReader} import org.apache.spark.storage.{BlockManager, ShuffleBlockFetcherIterator} @@ -100,8 +100,10 @@ private[spark] class HashShuffleReader[K, C]( // the ExternalSorter won't spill to disk. 
val sorter = new ExternalSorter[K, C, C](ordering = Some(keyOrd), serializer = Some(ser)) sorter.insertAll(aggregatedIter) - context.taskMetrics.incMemoryBytesSpilled(sorter.memoryBytesSpilled) - context.taskMetrics.incDiskBytesSpilled(sorter.diskBytesSpilled) + context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled) + context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled) + context.internalMetricsToAccumulators( + InternalAccumulator.PEAK_EXECUTION_MEMORY).add(sorter.peakMemoryUsedBytes) sorter.iterator case None => aggregatedIter diff --git a/core/src/main/scala/org/apache/spark/ui/ToolTips.scala b/core/src/main/scala/org/apache/spark/ui/ToolTips.scala index e2d25e36365f..cb122eaed83d 100644 --- a/core/src/main/scala/org/apache/spark/ui/ToolTips.scala +++ b/core/src/main/scala/org/apache/spark/ui/ToolTips.scala @@ -62,6 +62,13 @@ private[spark] object ToolTips { """Time that the executor spent paused for Java garbage collection while the task was running.""" + val PEAK_EXECUTION_MEMORY = + """Execution memory refers to the memory used by internal data structures created during + shuffles, aggregations and joins when Tungsten is enabled. The value of this accumulator + should be approximately the sum of the peak sizes across all such data structures created + in this task. For SQL jobs, this only tracks all unsafe operators, broadcast joins, and + external sort.""" + val JOB_TIMELINE = """Shows when jobs started and ended and when executors joined or left. Drag to scroll. Click Enable Zooming and use mouse wheel to zoom in/out.""" diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index cf04b5e59239..3954c3d1ef89 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -26,6 +26,7 @@ import scala.xml.{Elem, Node, Unparsed} import org.apache.commons.lang3.StringEscapeUtils +import org.apache.spark.{InternalAccumulator, SparkConf} import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler.{AccumulableInfo, TaskInfo} import org.apache.spark.ui._ @@ -67,6 +68,8 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { // if we find that it's okay. private val MAX_TIMELINE_TASKS = parent.conf.getInt("spark.ui.timeline.tasks.maximum", 1000) + private val displayPeakExecutionMemory = + parent.conf.getOption("spark.sql.unsafe.enabled").exists(_.toBoolean) def render(request: HttpServletRequest): Seq[Node] = { progressListener.synchronized { @@ -114,10 +117,11 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { val stageData = stageDataOption.get val tasks = stageData.taskData.values.toSeq.sortBy(_.taskInfo.launchTime) - val numCompleted = tasks.count(_.taskInfo.finished) - val accumulables = progressListener.stageIdToData((stageId, stageAttemptId)).accumulables - val hasAccumulators = accumulables.size > 0 + + val allAccumulables = progressListener.stageIdToData((stageId, stageAttemptId)).accumulables + val externalAccumulables = allAccumulables.values.filter { acc => !acc.internal } + val hasAccumulators = externalAccumulables.size > 0 val summary =
@@ -221,6 +225,15 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { Getting Result Time + {if (displayPeakExecutionMemory) { +
+              <li>
+                <span data-toggle="tooltip"
+                      title={ToolTips.PEAK_EXECUTION_MEMORY} data-placement="right">
+                  <input type="checkbox" name={TaskDetailsClassNames.PEAK_EXECUTION_MEMORY}/>
+                  <span class="additional-metric-title">Peak Execution Memory</span>
+                </span>
+              </li>
+            }}
    @@ -241,11 +254,12 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { val accumulableTable = UIUtils.listingTable( accumulableHeaders, accumulableRow, - accumulables.values.toSeq) + externalAccumulables.toSeq) val currentTime = System.currentTimeMillis() val (taskTable, taskTableHTML) = try { val _taskTable = new TaskPagedTable( + parent.conf, UIUtils.prependBaseUri(parent.basePath) + s"/stages/stage?id=${stageId}&attempt=${stageAttemptId}", tasks, @@ -294,12 +308,14 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { else { def getDistributionQuantiles(data: Seq[Double]): IndexedSeq[Double] = Distribution(data).get.getQuantiles() - def getFormattedTimeQuantiles(times: Seq[Double]): Seq[Node] = { getDistributionQuantiles(times).map { millis => {UIUtils.formatDuration(millis.toLong)} } } + def getFormattedSizeQuantiles(data: Seq[Double]): Seq[Elem] = { + getDistributionQuantiles(data).map(d => {Utils.bytesToString(d.toLong)}) + } val deserializationTimes = validTasks.map { case TaskUIData(_, metrics, _) => metrics.get.executorDeserializeTime.toDouble @@ -349,6 +365,23 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { +: getFormattedTimeQuantiles(gettingResultTimes) + + val peakExecutionMemory = validTasks.map { case TaskUIData(info, _, _) => + info.accumulables + .find { acc => acc.name == InternalAccumulator.PEAK_EXECUTION_MEMORY } + .map { acc => acc.value.toLong } + .getOrElse(0L) + .toDouble + } + val peakExecutionMemoryQuantiles = { + + + Peak Execution Memory + + +: getFormattedSizeQuantiles(peakExecutionMemory) + } + // The scheduler delay includes the network delay to send the task to the worker // machine and to send back the result (but not the time to fetch the task result, // if it needed to be fetched from the block manager on the worker). @@ -359,10 +392,6 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { title={ToolTips.SCHEDULER_DELAY} data-placement="right">Scheduler Delay val schedulerDelayQuantiles = schedulerDelayTitle +: getFormattedTimeQuantiles(schedulerDelays) - - def getFormattedSizeQuantiles(data: Seq[Double]): Seq[Elem] = - getDistributionQuantiles(data).map(d => {Utils.bytesToString(d.toLong)}) - def getFormattedSizeQuantilesWithRecords(data: Seq[Double], records: Seq[Double]) : Seq[Elem] = { val recordDist = getDistributionQuantiles(records).iterator @@ -466,6 +495,13 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { {serializationQuantiles} , {gettingResultQuantiles}, + if (displayPeakExecutionMemory) { + + {peakExecutionMemoryQuantiles} + + } else { + Nil + }, if (stageData.hasInput) {inputQuantiles} else Nil, if (stageData.hasOutput) {outputQuantiles} else Nil, if (stageData.hasShuffleRead) { @@ -499,7 +535,7 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { val executorTable = new ExecutorTable(stageId, stageAttemptId, parent) val maybeAccumulableTable: Seq[Node] = - if (accumulables.size > 0) {

<h4>Accumulators</h4> ++ accumulableTable } else Seq()
+      if (hasAccumulators) { <h4>Accumulators</h4>
    ++ accumulableTable } else Seq() val content = summary ++ @@ -750,29 +786,30 @@ private[ui] case class TaskTableRowBytesSpilledData( * Contains all data that needs for sorting and generating HTML. Using this one rather than * TaskUIData to avoid creating duplicate contents during sorting the data. */ -private[ui] case class TaskTableRowData( - index: Int, - taskId: Long, - attempt: Int, - speculative: Boolean, - status: String, - taskLocality: String, - executorIdAndHost: String, - launchTime: Long, - duration: Long, - formatDuration: String, - schedulerDelay: Long, - taskDeserializationTime: Long, - gcTime: Long, - serializationTime: Long, - gettingResultTime: Long, - accumulators: Option[String], // HTML - input: Option[TaskTableRowInputData], - output: Option[TaskTableRowOutputData], - shuffleRead: Option[TaskTableRowShuffleReadData], - shuffleWrite: Option[TaskTableRowShuffleWriteData], - bytesSpilled: Option[TaskTableRowBytesSpilledData], - error: String) +private[ui] class TaskTableRowData( + val index: Int, + val taskId: Long, + val attempt: Int, + val speculative: Boolean, + val status: String, + val taskLocality: String, + val executorIdAndHost: String, + val launchTime: Long, + val duration: Long, + val formatDuration: String, + val schedulerDelay: Long, + val taskDeserializationTime: Long, + val gcTime: Long, + val serializationTime: Long, + val gettingResultTime: Long, + val peakExecutionMemoryUsed: Long, + val accumulators: Option[String], // HTML + val input: Option[TaskTableRowInputData], + val output: Option[TaskTableRowOutputData], + val shuffleRead: Option[TaskTableRowShuffleReadData], + val shuffleWrite: Option[TaskTableRowShuffleWriteData], + val bytesSpilled: Option[TaskTableRowBytesSpilledData], + val error: String) private[ui] class TaskDataSource( tasks: Seq[TaskUIData], @@ -816,10 +853,15 @@ private[ui] class TaskDataSource( val serializationTime = metrics.map(_.resultSerializationTime).getOrElse(0L) val gettingResultTime = getGettingResultTime(info, currentTime) - val maybeAccumulators = info.accumulables - val accumulatorsReadable = maybeAccumulators.map { acc => + val (taskInternalAccumulables, taskExternalAccumulables) = + info.accumulables.partition(_.internal) + val externalAccumulableReadable = taskExternalAccumulables.map { acc => StringEscapeUtils.escapeHtml4(s"${acc.name}: ${acc.update.get}") } + val peakExecutionMemoryUsed = taskInternalAccumulables + .find { acc => acc.name == InternalAccumulator.PEAK_EXECUTION_MEMORY } + .map { acc => acc.value.toLong } + .getOrElse(0L) val maybeInput = metrics.flatMap(_.inputMetrics) val inputSortable = maybeInput.map(_.bytesRead).getOrElse(0L) @@ -923,7 +965,7 @@ private[ui] class TaskDataSource( None } - TaskTableRowData( + new TaskTableRowData( info.index, info.taskId, info.attempt, @@ -939,7 +981,8 @@ private[ui] class TaskDataSource( gcTime, serializationTime, gettingResultTime, - if (hasAccumulators) Some(accumulatorsReadable.mkString("
    ")) else None, + peakExecutionMemoryUsed, + if (hasAccumulators) Some(externalAccumulableReadable.mkString("
    ")) else None, input, output, shuffleRead, @@ -1006,6 +1049,10 @@ private[ui] class TaskDataSource( override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = Ordering.Long.compare(x.gettingResultTime, y.gettingResultTime) } + case "Peak Execution Memory" => new Ordering[TaskTableRowData] { + override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = + Ordering.Long.compare(x.peakExecutionMemoryUsed, y.peakExecutionMemoryUsed) + } case "Accumulators" => if (hasAccumulators) { new Ordering[TaskTableRowData] { @@ -1132,6 +1179,7 @@ private[ui] class TaskDataSource( } private[ui] class TaskPagedTable( + conf: SparkConf, basePath: String, data: Seq[TaskUIData], hasAccumulators: Boolean, @@ -1143,7 +1191,11 @@ private[ui] class TaskPagedTable( currentTime: Long, pageSize: Int, sortColumn: String, - desc: Boolean) extends PagedTable[TaskTableRowData]{ + desc: Boolean) extends PagedTable[TaskTableRowData] { + + // We only track peak memory used for unsafe operators + private val displayPeakExecutionMemory = + conf.getOption("spark.sql.unsafe.enabled").exists(_.toBoolean) override def tableId: String = "" @@ -1195,6 +1247,13 @@ private[ui] class TaskPagedTable( ("GC Time", ""), ("Result Serialization Time", TaskDetailsClassNames.RESULT_SERIALIZATION_TIME), ("Getting Result Time", TaskDetailsClassNames.GETTING_RESULT_TIME)) ++ + { + if (displayPeakExecutionMemory) { + Seq(("Peak Execution Memory", TaskDetailsClassNames.PEAK_EXECUTION_MEMORY)) + } else { + Nil + } + } ++ {if (hasAccumulators) Seq(("Accumulators", "")) else Nil} ++ {if (hasInput) Seq(("Input Size / Records", "")) else Nil} ++ {if (hasOutput) Seq(("Output Size / Records", "")) else Nil} ++ @@ -1271,6 +1330,11 @@ private[ui] class TaskPagedTable( {UIUtils.formatDuration(task.gettingResultTime)} + {if (displayPeakExecutionMemory) { + + {Utils.bytesToString(task.peakExecutionMemoryUsed)} + + }} {if (task.accumulators.nonEmpty) { {Unparsed(task.accumulators.get)} }} diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/TaskDetailsClassNames.scala b/core/src/main/scala/org/apache/spark/ui/jobs/TaskDetailsClassNames.scala index 9bf67db8acde..d2dfc5a32915 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/TaskDetailsClassNames.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/TaskDetailsClassNames.scala @@ -31,4 +31,5 @@ private[spark] object TaskDetailsClassNames { val SHUFFLE_READ_REMOTE_SIZE = "shuffle_read_remote" val RESULT_SERIALIZATION_TIME = "serialization_time" val GETTING_RESULT_TIME = "getting_result_time" + val PEAK_EXECUTION_MEMORY = "peak_execution_memory" } diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala index d166037351c3..f929b12606f0 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala @@ -89,6 +89,7 @@ class ExternalAppendOnlyMap[K, V, C]( // Number of bytes spilled in total private var _diskBytesSpilled = 0L + def diskBytesSpilled: Long = _diskBytesSpilled // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided private val fileBufferSize = @@ -97,6 +98,10 @@ class ExternalAppendOnlyMap[K, V, C]( // Write metrics for current spill private var curWriteMetrics: ShuffleWriteMetrics = _ + // Peak size of the in-memory map observed so far, in bytes + private var _peakMemoryUsedBytes: Long = 0L + 
def peakMemoryUsedBytes: Long = _peakMemoryUsedBytes + private val keyComparator = new HashComparator[K] private val ser = serializer.newInstance() @@ -126,7 +131,11 @@ class ExternalAppendOnlyMap[K, V, C]( while (entries.hasNext) { curEntry = entries.next() - if (maybeSpill(currentMap, currentMap.estimateSize())) { + val estimatedSize = currentMap.estimateSize() + if (estimatedSize > _peakMemoryUsedBytes) { + _peakMemoryUsedBytes = estimatedSize + } + if (maybeSpill(currentMap, estimatedSize)) { currentMap = new SizeTrackingAppendOnlyMap[K, C] } currentMap.changeValue(curEntry._1, update) @@ -207,8 +216,6 @@ class ExternalAppendOnlyMap[K, V, C]( spilledMaps.append(new DiskMapIterator(file, blockId, batchSizes)) } - def diskBytesSpilled: Long = _diskBytesSpilled - /** * Return an iterator that merges the in-memory map with the spilled maps. * If no spill has occurred, simply return the in-memory map's iterator. diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index ba7ec834d622..19287edbaf16 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -152,6 +152,9 @@ private[spark] class ExternalSorter[K, V, C]( private var _diskBytesSpilled = 0L def diskBytesSpilled: Long = _diskBytesSpilled + // Peak size of the in-memory data structure observed so far, in bytes + private var _peakMemoryUsedBytes: Long = 0L + def peakMemoryUsedBytes: Long = _peakMemoryUsedBytes // A comparator for keys K that orders them within a partition to allow aggregation or sorting. // Can be a partial ordering by hash code if a total ordering is not provided through by the @@ -224,15 +227,22 @@ private[spark] class ExternalSorter[K, V, C]( return } + var estimatedSize = 0L if (usingMap) { - if (maybeSpill(map, map.estimateSize())) { + estimatedSize = map.estimateSize() + if (maybeSpill(map, estimatedSize)) { map = new PartitionedAppendOnlyMap[K, C] } } else { - if (maybeSpill(buffer, buffer.estimateSize())) { + estimatedSize = buffer.estimateSize() + if (maybeSpill(buffer, estimatedSize)) { buffer = newBuffer() } } + + if (estimatedSize > _peakMemoryUsedBytes) { + _peakMemoryUsedBytes = estimatedSize + } } /** @@ -684,8 +694,10 @@ private[spark] class ExternalSorter[K, V, C]( } } - context.taskMetrics.incMemoryBytesSpilled(memoryBytesSpilled) - context.taskMetrics.incDiskBytesSpilled(diskBytesSpilled) + context.taskMetrics().incMemoryBytesSpilled(memoryBytesSpilled) + context.taskMetrics().incDiskBytesSpilled(diskBytesSpilled) + context.internalMetricsToAccumulators( + InternalAccumulator.PEAK_EXECUTION_MEMORY).add(peakMemoryUsedBytes) lengths } diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java index e948ca33471a..ffe4b4baffb2 100644 --- a/core/src/test/java/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java @@ -51,7 +51,6 @@ import org.apache.spark.api.java.*; import org.apache.spark.api.java.function.*; -import org.apache.spark.executor.TaskMetrics; import org.apache.spark.input.PortableDataStream; import org.apache.spark.partial.BoundedDouble; import org.apache.spark.partial.PartialResult; @@ -1011,7 +1010,7 @@ public void persist() { @Test public void iterator() { JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5), 2); - TaskContext context = new TaskContextImpl(0, 0, 
0L, 0, null, null, false, new TaskMetrics()); + TaskContext context = TaskContext$.MODULE$.empty(); Assert.assertEquals(1, rdd.iterator(rdd.partitions().get(0), context).next().intValue()); } diff --git a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java index 04fc09b323db..98c32bbc298d 100644 --- a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java @@ -190,6 +190,7 @@ public Tuple2 answer( }); when(taskContext.taskMetrics()).thenReturn(taskMetrics); + when(taskContext.internalMetricsToAccumulators()).thenReturn(null); when(shuffleDep.serializer()).thenReturn(Option.apply(serializer)); when(shuffleDep.partitioner()).thenReturn(hashPartitioner); @@ -542,4 +543,57 @@ public void spillFilesAreDeletedWhenStoppingAfterError() throws IOException { writer.stop(false); assertSpillFilesWereCleanedUp(); } + + @Test + public void testPeakMemoryUsed() throws Exception { + final long recordLengthBytes = 8; + final long pageSizeBytes = 256; + final long numRecordsPerPage = pageSizeBytes / recordLengthBytes; + final SparkConf conf = new SparkConf().set("spark.buffer.pageSize", pageSizeBytes + "b"); + final UnsafeShuffleWriter writer = + new UnsafeShuffleWriter( + blockManager, + shuffleBlockResolver, + taskMemoryManager, + shuffleMemoryManager, + new UnsafeShuffleHandle(0, 1, shuffleDep), + 0, // map id + taskContext, + conf); + + // Peak memory should be monotonically increasing. More specifically, every time + // we allocate a new page it should increase by exactly the size of the page. + long previousPeakMemory = writer.getPeakMemoryUsedBytes(); + long newPeakMemory; + try { + for (int i = 0; i < numRecordsPerPage * 10; i++) { + writer.insertRecordIntoSorter(new Tuple2(1, 1)); + newPeakMemory = writer.getPeakMemoryUsedBytes(); + if (i % numRecordsPerPage == 0) { + // We allocated a new page for this record, so peak memory should change + assertEquals(previousPeakMemory + pageSizeBytes, newPeakMemory); + } else { + assertEquals(previousPeakMemory, newPeakMemory); + } + previousPeakMemory = newPeakMemory; + } + + // Spilling should not change peak memory + writer.forceSorterToSpill(); + newPeakMemory = writer.getPeakMemoryUsedBytes(); + assertEquals(previousPeakMemory, newPeakMemory); + for (int i = 0; i < numRecordsPerPage; i++) { + writer.insertRecordIntoSorter(new Tuple2(1, 1)); + } + newPeakMemory = writer.getPeakMemoryUsedBytes(); + assertEquals(previousPeakMemory, newPeakMemory); + + // Closing the writer should not change peak memory + writer.closeAndWriteOutput(); + newPeakMemory = writer.getPeakMemoryUsedBytes(); + assertEquals(previousPeakMemory, newPeakMemory); + } finally { + writer.stop(false); + } + } } diff --git a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java index dbb7c662d787..0e23a64fb74b 100644 --- a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java +++ b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java @@ -25,6 +25,7 @@ import org.mockito.invocation.InvocationOnMock; import org.mockito.stubbing.Answer; import static org.hamcrest.Matchers.greaterThan; +import static org.junit.Assert.*; import static org.mockito.AdditionalMatchers.geq; import static org.mockito.Mockito.*; @@ 
-495,4 +496,42 @@ public void resizingLargeMap() { map.growAndRehash(); map.free(); } + + @Test + public void testTotalMemoryConsumption() { + final long recordLengthBytes = 24; + final long pageSizeBytes = 256 + 8; // 8 bytes for end-of-page marker + final long numRecordsPerPage = (pageSizeBytes - 8) / recordLengthBytes; + final BytesToBytesMap map = new BytesToBytesMap( + taskMemoryManager, shuffleMemoryManager, 1024, pageSizeBytes); + + // Since BytesToBytesMap is append-only, we expect the total memory consumption to be + // monotonically increasing. More specifically, every time we allocate a new page it + // should increase by exactly the size of the page. In this regard, the memory usage + // at any given time is also the peak memory used. + long previousMemory = map.getTotalMemoryConsumption(); + long newMemory; + try { + for (long i = 0; i < numRecordsPerPage * 10; i++) { + final long[] value = new long[]{i}; + map.lookup(value, PlatformDependent.LONG_ARRAY_OFFSET, 8).putNewKey( + value, + PlatformDependent.LONG_ARRAY_OFFSET, + 8, + value, + PlatformDependent.LONG_ARRAY_OFFSET, + 8); + newMemory = map.getTotalMemoryConsumption(); + if (i % numRecordsPerPage == 0) { + // We allocated a new page for this record, so peak memory should change + assertEquals(previousMemory + pageSizeBytes, newMemory); + } else { + assertEquals(previousMemory, newMemory); + } + previousMemory = newMemory; + } + } finally { + map.free(); + } + } } diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java index 52fa8bcd57e7..c11949d57a0e 100644 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java @@ -247,4 +247,50 @@ public void testFillingPage() throws Exception { assertSpillFilesWereCleanedUp(); } + @Test + public void testPeakMemoryUsed() throws Exception { + final long recordLengthBytes = 8; + final long pageSizeBytes = 256; + final long numRecordsPerPage = pageSizeBytes / recordLengthBytes; + final UnsafeExternalSorter sorter = UnsafeExternalSorter.create( + taskMemoryManager, + shuffleMemoryManager, + blockManager, + taskContext, + recordComparator, + prefixComparator, + 1024, + pageSizeBytes); + + // Peak memory should be monotonically increasing. More specifically, every time + // we allocate a new page it should increase by exactly the size of the page. 
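The map and sorter tests in this hunk rely on a simple invariant: with fixed-size records packed into fixed-size pages, the reported peak jumps by exactly one page whenever a record opens a new page and stays flat otherwise. A small arithmetic sketch of that expectation; the sizes match the tests, but the helper itself is only an illustration, not a Spark API.

```scala
// Expected peak after inserting `numRecords` fixed-size records into fixed-size pages.
def expectedPeakBytes(numRecords: Long, recordLengthBytes: Long, pageSizeBytes: Long): Long = {
  val recordsPerPage = pageSizeBytes / recordLengthBytes
  val pagesNeeded = (numRecords + recordsPerPage - 1) / recordsPerPage // ceiling division
  pagesNeeded * pageSizeBytes
}

assert(expectedPeakBytes(1, 8, 256) == 256)  // the first record allocates the first page
assert(expectedPeakBytes(32, 8, 256) == 256) // 32 records of 8 bytes still fit in one 256-byte page
assert(expectedPeakBytes(33, 8, 256) == 512) // the 33rd record forces a second page
```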
+ long previousPeakMemory = sorter.getPeakMemoryUsedBytes(); + long newPeakMemory; + try { + for (int i = 0; i < numRecordsPerPage * 10; i++) { + insertNumber(sorter, i); + newPeakMemory = sorter.getPeakMemoryUsedBytes(); + if (i % numRecordsPerPage == 0) { + // We allocated a new page for this record, so peak memory should change + assertEquals(previousPeakMemory + pageSizeBytes, newPeakMemory); + } else { + assertEquals(previousPeakMemory, newPeakMemory); + } + previousPeakMemory = newPeakMemory; + } + + // Spilling should not change peak memory + sorter.spill(); + newPeakMemory = sorter.getPeakMemoryUsedBytes(); + assertEquals(previousPeakMemory, newPeakMemory); + for (int i = 0; i < numRecordsPerPage; i++) { + insertNumber(sorter, i); + } + newPeakMemory = sorter.getPeakMemoryUsedBytes(); + assertEquals(previousPeakMemory, newPeakMemory); + } finally { + sorter.freeMemory(); + } + } + } diff --git a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala index e942d6579b2f..48f549575f4d 100644 --- a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala @@ -18,13 +18,17 @@ package org.apache.spark import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer import scala.ref.WeakReference import org.scalatest.Matchers +import org.scalatest.exceptions.TestFailedException +import org.apache.spark.scheduler._ -class AccumulatorSuite extends SparkFunSuite with Matchers with LocalSparkContext { +class AccumulatorSuite extends SparkFunSuite with Matchers with LocalSparkContext { + import InternalAccumulator._ implicit def setAccum[A]: AccumulableParam[mutable.Set[A], A] = new AccumulableParam[mutable.Set[A], A] { @@ -155,4 +159,191 @@ class AccumulatorSuite extends SparkFunSuite with Matchers with LocalSparkContex assert(!Accumulators.originals.get(accId).isDefined) } + test("internal accumulators in TaskContext") { + val accums = InternalAccumulator.create() + val taskContext = new TaskContextImpl(0, 0, 0, 0, null, null, accums) + val internalMetricsToAccums = taskContext.internalMetricsToAccumulators + val collectedInternalAccums = taskContext.collectInternalAccumulators() + val collectedAccums = taskContext.collectAccumulators() + assert(internalMetricsToAccums.size > 0) + assert(internalMetricsToAccums.values.forall(_.isInternal)) + assert(internalMetricsToAccums.contains(TEST_ACCUMULATOR)) + val testAccum = internalMetricsToAccums(TEST_ACCUMULATOR) + assert(collectedInternalAccums.size === internalMetricsToAccums.size) + assert(collectedInternalAccums.size === collectedAccums.size) + assert(collectedInternalAccums.contains(testAccum.id)) + assert(collectedAccums.contains(testAccum.id)) + } + + test("internal accumulators in a stage") { + val listener = new SaveInfoListener + val numPartitions = 10 + sc = new SparkContext("local", "test") + sc.addSparkListener(listener) + // Have each task add 1 to the internal accumulator + sc.parallelize(1 to 100, numPartitions).mapPartitions { iter => + TaskContext.get().internalMetricsToAccumulators(TEST_ACCUMULATOR) += 1 + iter + }.count() + val stageInfos = listener.getCompletedStageInfos + val taskInfos = listener.getCompletedTaskInfos + assert(stageInfos.size === 1) + assert(taskInfos.size === numPartitions) + // The accumulator values should be merged in the stage + val stageAccum = findAccumulableInfo(stageInfos.head.accumulables.values, TEST_ACCUMULATOR) + assert(stageAccum.value.toLong === 
numPartitions) + // The accumulator should be updated locally on each task + val taskAccumValues = taskInfos.map { taskInfo => + val taskAccum = findAccumulableInfo(taskInfo.accumulables, TEST_ACCUMULATOR) + assert(taskAccum.update.isDefined) + assert(taskAccum.update.get.toLong === 1) + taskAccum.value.toLong + } + // Each task should keep track of the partial value on the way, i.e. 1, 2, ... numPartitions + assert(taskAccumValues.sorted === (1L to numPartitions).toSeq) + } + + test("internal accumulators in multiple stages") { + val listener = new SaveInfoListener + val numPartitions = 10 + sc = new SparkContext("local", "test") + sc.addSparkListener(listener) + // Each stage creates its own set of internal accumulators so the + // values for the same metric should not be mixed up across stages + sc.parallelize(1 to 100, numPartitions) + .map { i => (i, i) } + .mapPartitions { iter => + TaskContext.get().internalMetricsToAccumulators(TEST_ACCUMULATOR) += 1 + iter + } + .reduceByKey { case (x, y) => x + y } + .mapPartitions { iter => + TaskContext.get().internalMetricsToAccumulators(TEST_ACCUMULATOR) += 10 + iter + } + .repartition(numPartitions * 2) + .mapPartitions { iter => + TaskContext.get().internalMetricsToAccumulators(TEST_ACCUMULATOR) += 100 + iter + } + .count() + // We ran 3 stages, and the accumulator values should be distinct + val stageInfos = listener.getCompletedStageInfos + assert(stageInfos.size === 3) + val firstStageAccum = findAccumulableInfo(stageInfos(0).accumulables.values, TEST_ACCUMULATOR) + val secondStageAccum = findAccumulableInfo(stageInfos(1).accumulables.values, TEST_ACCUMULATOR) + val thirdStageAccum = findAccumulableInfo(stageInfos(2).accumulables.values, TEST_ACCUMULATOR) + assert(firstStageAccum.value.toLong === numPartitions) + assert(secondStageAccum.value.toLong === numPartitions * 10) + assert(thirdStageAccum.value.toLong === numPartitions * 2 * 100) + } + + test("internal accumulators in fully resubmitted stages") { + testInternalAccumulatorsWithFailedTasks((i: Int) => true) // fail all tasks + } + + test("internal accumulators in partially resubmitted stages") { + testInternalAccumulatorsWithFailedTasks((i: Int) => i % 2 == 0) // fail a subset + } + + /** + * Return the accumulable info that matches the specified name. + */ + private def findAccumulableInfo( + accums: Iterable[AccumulableInfo], + name: String): AccumulableInfo = { + accums.find { a => a.name == name }.getOrElse { + throw new TestFailedException(s"internal accumulator '$name' not found", 0) + } + } + + /** + * Test whether internal accumulators are merged properly if some tasks fail. 
+ */ + private def testInternalAccumulatorsWithFailedTasks(failCondition: (Int => Boolean)): Unit = { + val listener = new SaveInfoListener + val numPartitions = 10 + val numFailedPartitions = (0 until numPartitions).count(failCondition) + // This says use 1 core and retry tasks up to 2 times + sc = new SparkContext("local[1, 2]", "test") + sc.addSparkListener(listener) + sc.parallelize(1 to 100, numPartitions).mapPartitionsWithIndex { case (i, iter) => + val taskContext = TaskContext.get() + taskContext.internalMetricsToAccumulators(TEST_ACCUMULATOR) += 1 + // Fail the first attempts of a subset of the tasks + if (failCondition(i) && taskContext.attemptNumber() == 0) { + throw new Exception("Failing a task intentionally.") + } + iter + }.count() + val stageInfos = listener.getCompletedStageInfos + val taskInfos = listener.getCompletedTaskInfos + assert(stageInfos.size === 1) + assert(taskInfos.size === numPartitions + numFailedPartitions) + val stageAccum = findAccumulableInfo(stageInfos.head.accumulables.values, TEST_ACCUMULATOR) + // We should not double count values in the merged accumulator + assert(stageAccum.value.toLong === numPartitions) + val taskAccumValues = taskInfos.flatMap { taskInfo => + if (!taskInfo.failed) { + // If a task succeeded, its update value should always be 1 + val taskAccum = findAccumulableInfo(taskInfo.accumulables, TEST_ACCUMULATOR) + assert(taskAccum.update.isDefined) + assert(taskAccum.update.get.toLong === 1) + Some(taskAccum.value.toLong) + } else { + // If a task failed, we should not get its accumulator values + assert(taskInfo.accumulables.isEmpty) + None + } + } + assert(taskAccumValues.sorted === (1L to numPartitions).toSeq) + } + +} + +private[spark] object AccumulatorSuite { + + /** + * Run one or more Spark jobs and verify that the peak execution memory accumulator + * is updated afterwards. + */ + def verifyPeakExecutionMemorySet( + sc: SparkContext, + testName: String)(testBody: => Unit): Unit = { + val listener = new SaveInfoListener + sc.addSparkListener(listener) + // Verify that the accumulator does not already exist + sc.parallelize(1 to 10).count() + val accums = listener.getCompletedStageInfos.flatMap(_.accumulables.values) + assert(!accums.exists(_.name == InternalAccumulator.PEAK_EXECUTION_MEMORY)) + testBody + // Verify that peak execution memory is updated + val accum = listener.getCompletedStageInfos + .flatMap(_.accumulables.values) + .find(_.name == InternalAccumulator.PEAK_EXECUTION_MEMORY) + .getOrElse { + throw new TestFailedException( + s"peak execution memory accumulator not set in '$testName'", 0) + } + assert(accum.value.toLong > 0) + } +} + +/** + * A simple listener that keeps track of the TaskInfos and StageInfos of all completed jobs. 
+ */ +private class SaveInfoListener extends SparkListener { + private val completedStageInfos: ArrayBuffer[StageInfo] = new ArrayBuffer[StageInfo] + private val completedTaskInfos: ArrayBuffer[TaskInfo] = new ArrayBuffer[TaskInfo] + + def getCompletedStageInfos: Seq[StageInfo] = completedStageInfos.toArray.toSeq + def getCompletedTaskInfos: Seq[TaskInfo] = completedTaskInfos.toArray.toSeq + + override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = { + completedStageInfos += stageCompleted.stageInfo + } + + override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = { + completedTaskInfos += taskEnd.taskInfo + } } diff --git a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala index 618a5fb24710..cb8bd04e496a 100644 --- a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala @@ -21,7 +21,7 @@ import org.mockito.Mockito._ import org.scalatest.BeforeAndAfter import org.scalatest.mock.MockitoSugar -import org.apache.spark.executor.DataReadMethod +import org.apache.spark.executor.{DataReadMethod, TaskMetrics} import org.apache.spark.rdd.RDD import org.apache.spark.storage._ @@ -65,7 +65,7 @@ class CacheManagerSuite extends SparkFunSuite with LocalSparkContext with Before // in blockManager.put is a losing battle. You have been warned. blockManager = sc.env.blockManager cacheManager = sc.env.cacheManager - val context = new TaskContextImpl(0, 0, 0, 0, null, null) + val context = TaskContext.empty() val computeValue = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY) val getValue = blockManager.get(RDDBlockId(rdd.id, split.index)) assert(computeValue.toList === List(1, 2, 3, 4)) @@ -77,7 +77,7 @@ class CacheManagerSuite extends SparkFunSuite with LocalSparkContext with Before val result = new BlockResult(Array(5, 6, 7).iterator, DataReadMethod.Memory, 12) when(blockManager.get(RDDBlockId(0, 0))).thenReturn(Some(result)) - val context = new TaskContextImpl(0, 0, 0, 0, null, null) + val context = TaskContext.empty() val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY) assert(value.toList === List(5, 6, 7)) } @@ -86,14 +86,14 @@ class CacheManagerSuite extends SparkFunSuite with LocalSparkContext with Before // Local computation should not persist the resulting value, so don't expect a put(). 
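A note on the TaskContext.empty() substitutions in the test diffs above: the factory's own definition is not part of the hunks shown here, so its exact body is an assumption. The motivation is visible in the tests themselves: TaskContextImpl's constructor now takes an internalAccumulators sequence, and routing test code through a single helper avoids touching every call site the next time the signature grows. A plausible sketch, reusing only the constructor shape already shown in these tests:

    // Sketch only; the real definition is not included in this section of the patch.
    // Assumes the widened TaskContextImpl constructor used in the tests above.
    private[spark] def empty(): TaskContextImpl = {
      new TaskContextImpl(0, 0, 0, 0, null, null, internalAccumulators = Seq.empty)
    }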
when(blockManager.get(RDDBlockId(0, 0))).thenReturn(None) - val context = new TaskContextImpl(0, 0, 0, 0, null, null, true) + val context = new TaskContextImpl(0, 0, 0, 0, null, null, Seq.empty, runningLocally = true) val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY) assert(value.toList === List(1, 2, 3, 4)) } test("verify task metrics updated correctly") { cacheManager = sc.env.cacheManager - val context = new TaskContextImpl(0, 0, 0, 0, null, null) + val context = TaskContext.empty() cacheManager.getOrCompute(rdd3, split, context, StorageLevel.MEMORY_ONLY) assert(context.taskMetrics.updatedBlocks.getOrElse(Seq()).size === 2) } diff --git a/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala index 3e8816a4c65b..5f73ec867596 100644 --- a/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala @@ -175,7 +175,7 @@ class PipedRDDSuite extends SparkFunSuite with SharedSparkContext { } val hadoopPart1 = generateFakeHadoopPartition() val pipedRdd = new PipedRDD(nums, "printenv " + varName) - val tContext = new TaskContextImpl(0, 0, 0, 0, null, null) + val tContext = TaskContext.empty() val rddIter = pipedRdd.compute(hadoopPart1, tContext) val arr = rddIter.toArray assert(arr(0) == "/some/path") diff --git a/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala b/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala index b3ca150195a5..f7e16af9d3a9 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala @@ -19,9 +19,11 @@ package org.apache.spark.scheduler import org.apache.spark.TaskContext -class FakeTask(stageId: Int, prefLocs: Seq[TaskLocation] = Nil) extends Task[Int](stageId, 0, 0) { +class FakeTask( + stageId: Int, + prefLocs: Seq[TaskLocation] = Nil) + extends Task[Int](stageId, 0, 0, Seq.empty) { override def runTask(context: TaskContext): Int = 0 - override def preferredLocations: Seq[TaskLocation] = prefLocs } diff --git a/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala b/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala index 383855caefa2..f33324792495 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala @@ -25,7 +25,7 @@ import org.apache.spark.TaskContext * A Task implementation that fails to serialize. 
*/ private[spark] class NotSerializableFakeTask(myId: Int, stageId: Int) - extends Task[Array[Byte]](stageId, 0, 0) { + extends Task[Array[Byte]](stageId, 0, 0, Seq.empty) { override def runTask(context: TaskContext): Array[Byte] = Array.empty[Byte] override def preferredLocations: Seq[TaskLocation] = Seq[TaskLocation]() diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala index 9201d1e1f328..450ab7b9fe92 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala @@ -57,8 +57,9 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark } val closureSerializer = SparkEnv.get.closureSerializer.newInstance() val func = (c: TaskContext, i: Iterator[String]) => i.next() - val task = new ResultTask[String, String](0, 0, - sc.broadcast(closureSerializer.serialize((rdd, func)).array), rdd.partitions(0), Seq(), 0) + val taskBinary = sc.broadcast(closureSerializer.serialize((rdd, func)).array) + val task = new ResultTask[String, String]( + 0, 0, taskBinary, rdd.partitions(0), Seq.empty, 0, Seq.empty) intercept[RuntimeException] { task.run(0, 0, null) } @@ -66,7 +67,7 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark } test("all TaskCompletionListeners should be called even if some fail") { - val context = new TaskContextImpl(0, 0, 0, 0, null, null) + val context = TaskContext.empty() val listener = mock(classOf[TaskCompletionListener]) context.addTaskCompletionListener(_ => throw new Exception("blah")) context.addTaskCompletionListener(listener) diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index 3abb99c4b2b5..f7cc4bb61d57 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -136,7 +136,7 @@ class FakeTaskScheduler(sc: SparkContext, liveExecutors: (String, String)* /* ex /** * A Task implementation that results in a large serialized task. 
*/ -class LargeTask(stageId: Int) extends Task[Array[Byte]](stageId, 0, 0) { +class LargeTask(stageId: Int) extends Task[Array[Byte]](stageId, 0, 0, Seq.empty) { val randomBuffer = new Array[Byte](TaskSetManager.TASK_SIZE_TO_WARN_KB * 1024) val random = new Random(0) random.nextBytes(randomBuffer) diff --git a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala index db718ecabbdb..05b3afef5b83 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala @@ -138,7 +138,7 @@ class HashShuffleReaderSuite extends SparkFunSuite with LocalSparkContext { shuffleHandle, reduceId, reduceId + 1, - new TaskContextImpl(0, 0, 0, 0, null, null), + TaskContext.empty(), blockManager, mapOutputTracker) diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index cf8bd8ae6962..828153bdbfc4 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -29,7 +29,7 @@ import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer import org.scalatest.PrivateMethodTester -import org.apache.spark.{SparkFunSuite, TaskContextImpl} +import org.apache.spark.{SparkFunSuite, TaskContext} import org.apache.spark.network._ import org.apache.spark.network.buffer.ManagedBuffer import org.apache.spark.network.shuffle.BlockFetchingListener @@ -95,7 +95,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT ) val iterator = new ShuffleBlockFetcherIterator( - new TaskContextImpl(0, 0, 0, 0, null, null), + TaskContext.empty(), transfer, blockManager, blocksByAddress, @@ -165,7 +165,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)) - val taskContext = new TaskContextImpl(0, 0, 0, 0, null, null) + val taskContext = TaskContext.empty() val iterator = new ShuffleBlockFetcherIterator( taskContext, transfer, @@ -227,7 +227,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)) - val taskContext = new TaskContextImpl(0, 0, 0, 0, null, null) + val taskContext = TaskContext.empty() val iterator = new ShuffleBlockFetcherIterator( taskContext, transfer, diff --git a/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala b/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala new file mode 100644 index 000000000000..98f9314f31df --- /dev/null +++ b/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ui + +import javax.servlet.http.HttpServletRequest + +import scala.xml.Node + +import org.mockito.Mockito.{mock, when, RETURNS_SMART_NULLS} + +import org.apache.spark.{LocalSparkContext, SparkConf, SparkFunSuite, Success} +import org.apache.spark.executor.TaskMetrics +import org.apache.spark.scheduler._ +import org.apache.spark.ui.jobs.{JobProgressListener, StagePage, StagesTab} +import org.apache.spark.ui.scope.RDDOperationGraphListener + +class StagePageSuite extends SparkFunSuite with LocalSparkContext { + + test("peak execution memory only displayed if unsafe is enabled") { + val unsafeConf = "spark.sql.unsafe.enabled" + val conf = new SparkConf().set(unsafeConf, "true") + val html = renderStagePage(conf).toString().toLowerCase + val targetString = "peak execution memory" + assert(html.contains(targetString)) + // Disable unsafe and make sure it's not there + val conf2 = new SparkConf().set(unsafeConf, "false") + val html2 = renderStagePage(conf2).toString().toLowerCase + assert(!html2.contains(targetString)) + } + + /** + * Render a stage page started with the given conf and return the HTML. + * This also runs a dummy stage to populate the page with useful content. + */ + private def renderStagePage(conf: SparkConf): Seq[Node] = { + val jobListener = new JobProgressListener(conf) + val graphListener = new RDDOperationGraphListener(conf) + val tab = mock(classOf[StagesTab], RETURNS_SMART_NULLS) + val request = mock(classOf[HttpServletRequest]) + when(tab.conf).thenReturn(conf) + when(tab.progressListener).thenReturn(jobListener) + when(tab.operationGraphListener).thenReturn(graphListener) + when(tab.appName).thenReturn("testing") + when(tab.headerTabs).thenReturn(Seq.empty) + when(request.getParameter("id")).thenReturn("0") + when(request.getParameter("attempt")).thenReturn("0") + val page = new StagePage(tab) + + // Simulate a stage in job progress listener + val stageInfo = new StageInfo(0, 0, "dummy", 1, Seq.empty, Seq.empty, "details") + val taskInfo = new TaskInfo(0, 0, 0, 0, "0", "localhost", TaskLocality.ANY, false) + jobListener.onStageSubmitted(SparkListenerStageSubmitted(stageInfo)) + jobListener.onTaskStart(SparkListenerTaskStart(0, 0, taskInfo)) + taskInfo.markSuccessful() + jobListener.onTaskEnd( + SparkListenerTaskEnd(0, 0, "result", Success, taskInfo, TaskMetrics.empty)) + jobListener.onStageCompleted(SparkListenerStageCompleted(stageInfo)) + page.render(request) + } + +} diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala index 9c362f0de707..12e9bafcc92c 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala @@ -399,4 +399,19 @@ class ExternalAppendOnlyMapSuite extends SparkFunSuite with LocalSparkContext { sc.stop() } + test("external aggregation updates peak execution memory") { + val conf = createSparkConf(loadDefaults = false) + 
.set("spark.shuffle.memoryFraction", "0.001") + .set("spark.shuffle.manager", "hash") // make sure we're not also using ExternalSorter + sc = new SparkContext("local", "test", conf) + // No spilling + AccumulatorSuite.verifyPeakExecutionMemorySet(sc, "external map without spilling") { + sc.parallelize(1 to 10, 2).map { i => (i, i) }.reduceByKey(_ + _).count() + } + // With spilling + AccumulatorSuite.verifyPeakExecutionMemorySet(sc, "external map with spilling") { + sc.parallelize(1 to 1000 * 1000, 2).map { i => (i, i) }.reduceByKey(_ + _).count() + } + } + } diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala index 986cd8623d14..bdb0f4d507a7 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala @@ -692,7 +692,7 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { sortWithoutBreakingSortingContracts(createSparkConf(true, false)) } - def sortWithoutBreakingSortingContracts(conf: SparkConf) { + private def sortWithoutBreakingSortingContracts(conf: SparkConf) { conf.set("spark.shuffle.memoryFraction", "0.01") conf.set("spark.shuffle.manager", "sort") sc = new SparkContext("local-cluster[1,1,1024]", "test", conf) @@ -743,5 +743,15 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { } sorter2.stop() - } + } + + test("sorting updates peak execution memory") { + val conf = createSparkConf(loadDefaults = false, kryo = false) + .set("spark.shuffle.manager", "sort") + sc = new SparkContext("local", "test", conf) + // Avoid aggregating here to make sure we're not also using ExternalAppendOnlyMap + AccumulatorSuite.verifyPeakExecutionMemorySet(sc, "external sorter") { + sc.parallelize(1 to 1000, 2).repartition(100).count() + } + } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java index 5e4c6232c947..193906d24790 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java @@ -106,6 +106,13 @@ void spill() throws IOException { sorter.spill(); } + /** + * Return the peak memory used so far, in bytes. + */ + public long getPeakMemoryUsage() { + return sorter.getPeakMemoryUsedBytes(); + } + private void cleanupResources() { sorter.freeMemory(); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java index 9e2c9334a7be..43d06ce9bdfa 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java @@ -208,6 +208,14 @@ public void close() { }; } + /** + * The memory used by this map's managed structures, in bytes. + * Note that this is also the peak memory used by this map, since the map is append-only. + */ + public long getMemoryUsage() { + return map.getTotalMemoryConsumption(); + } + /** * Free the memory associated with this map. This is idempotent and can be called multiple times. 
*/ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala index cd87b8deba0c..bf4905dc1eef 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution import java.io.IOException -import org.apache.spark.{SparkEnv, TaskContext} +import org.apache.spark.{InternalAccumulator, SparkEnv, TaskContext} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow @@ -263,11 +263,12 @@ case class GeneratedAggregate( assert(iter.hasNext, "There should be at least one row for this path") log.info("Using Unsafe-based aggregator") val pageSizeBytes = SparkEnv.get.conf.getSizeAsBytes("spark.buffer.pageSize", "64m") + val taskContext = TaskContext.get() val aggregationMap = new UnsafeFixedWidthAggregationMap( newAggregationBuffer(EmptyRow), aggregationBufferSchema, groupKeySchema, - TaskContext.get.taskMemoryManager(), + taskContext.taskMemoryManager(), SparkEnv.get.shuffleMemoryManager, 1024 * 16, // initial capacity pageSizeBytes, @@ -284,6 +285,10 @@ case class GeneratedAggregate( updateProjection.target(aggregationBuffer)(joinedRow(aggregationBuffer, currentRow)) } + // Record memory used in the process + taskContext.internalMetricsToAccumulators( + InternalAccumulator.PEAK_EXECUTION_MEMORY).add(aggregationMap.getMemoryUsage) + new Iterator[InternalRow] { private[this] val mapIterator = aggregationMap.iterator() private[this] val resultProjection = resultProjectionBuilder() @@ -300,7 +305,7 @@ case class GeneratedAggregate( } else { // This is the last element in the iterator, so let's free the buffer. Before we do, // though, we need to make a defensive copy of the result so that we don't return an - // object that might contain dangling pointers to the freed memory + // object that might contain dangling pointers to the freed memory. 
val resultCopy = result.copy() aggregationMap.free() resultCopy diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala index 624efc1b1d73..e73e2523a777 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.joins import scala.concurrent._ import scala.concurrent.duration._ +import org.apache.spark.{InternalAccumulator, TaskContext} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow @@ -70,7 +71,14 @@ case class BroadcastHashJoin( val broadcastRelation = Await.result(broadcastFuture, timeout) streamedPlan.execute().mapPartitions { streamedIter => - hashJoin(streamedIter, broadcastRelation.value) + val hashedRelation = broadcastRelation.value + hashedRelation match { + case unsafe: UnsafeHashedRelation => + TaskContext.get().internalMetricsToAccumulators( + InternalAccumulator.PEAK_EXECUTION_MEMORY).add(unsafe.getUnsafeSize) + case _ => + } + hashJoin(streamedIter, hashedRelation) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala index 309716a0efcc..c35e439cc9de 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.joins import scala.concurrent._ import scala.concurrent.duration._ +import org.apache.spark.{InternalAccumulator, TaskContext} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow @@ -75,6 +76,13 @@ case class BroadcastHashOuterJoin( val hashTable = broadcastRelation.value val keyGenerator = streamedKeyGenerator + hashTable match { + case unsafe: UnsafeHashedRelation => + TaskContext.get().internalMetricsToAccumulators( + InternalAccumulator.PEAK_EXECUTION_MEMORY).add(unsafe.getUnsafeSize) + case _ => + } + joinType match { case LeftOuter => streamedIter.flatMap(currentRow => { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala index a60593911f94..5bd06fbdca60 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution.joins +import org.apache.spark.{InternalAccumulator, TaskContext} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow @@ -51,7 +52,14 @@ case class BroadcastLeftSemiJoinHash( val broadcastedRelation = sparkContext.broadcast(hashRelation) left.execute().mapPartitions { streamIter => - hashSemiJoin(streamIter, broadcastedRelation.value) + val hashedRelation = broadcastedRelation.value + hashedRelation match { + case unsafe: UnsafeHashedRelation => + TaskContext.get().internalMetricsToAccumulators( + 
InternalAccumulator.PEAK_EXECUTION_MEMORY).add(unsafe.getUnsafeSize) + case _ => + } + hashSemiJoin(streamIter, hashedRelation) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index cc8bbfd2f894..58b4236f7b5b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -183,8 +183,27 @@ private[joins] final class UnsafeHashedRelation( private[joins] def this() = this(null) // Needed for serialization // Use BytesToBytesMap in executor for better performance (it's created when deserialization) + // This is used in broadcast joins and distributed mode only @transient private[this] var binaryMap: BytesToBytesMap = _ + /** + * Return the size of the unsafe map on the executors. + * + * For broadcast joins, this hashed relation is bigger on the driver because it is + * represented as a Java hash map there. While serializing the map to the executors, + * however, we rehash the contents in a binary map to reduce the memory footprint on + * the executors. + * + * For non-broadcast joins or in local mode, return 0. + */ + def getUnsafeSize: Long = { + if (binaryMap != null) { + binaryMap.getTotalMemoryConsumption + } else { + 0 + } + } + override def get(key: InternalRow): Seq[InternalRow] = { val unsafeKey = key.asInstanceOf[UnsafeRow] @@ -214,7 +233,7 @@ private[joins] final class UnsafeHashedRelation( } } else { - // Use the JavaHashMap in Local mode or ShuffleHashJoin + // Use the Java HashMap in local mode or for non-broadcast joins (e.g. ShuffleHashJoin) hashTable.get(unsafeKey) } } @@ -316,6 +335,7 @@ private[joins] object UnsafeHashedRelation { keyGenerator: UnsafeProjection, sizeEstimate: Int): HashedRelation = { + // Use a Java hash table here because unsafe maps expect fixed size records val hashTable = new JavaHashMap[UnsafeRow, CompactBuffer[UnsafeRow]](sizeEstimate) // Create a mapping of buildKeys -> rows diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala index 92cf328c76cb..3192b6ebe907 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution +import org.apache.spark.{InternalAccumulator, TaskContext} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors._ @@ -76,6 +77,11 @@ case class ExternalSort( val sorter = new ExternalSorter[InternalRow, Null, InternalRow](ordering = Some(ordering)) sorter.insertAll(iterator.map(r => (r.copy(), null))) val baseIterator = sorter.iterator.map(_._1) + val context = TaskContext.get() + context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled) + context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled) + context.internalMetricsToAccumulators( + InternalAccumulator.PEAK_EXECUTION_MEMORY).add(sorter.peakMemoryUsedBytes) // TODO(marmbrus): The complex type signature below thwarts inference for no reason. 
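Every operator touched so far in this commit (GeneratedAggregate, the three broadcast joins, ExternalSort, and TungstenSort just below) records its memory footprint the same way: once its data structure is fully populated, and before that structure is freed, it asks for a byte count (getMemoryUsage, getUnsafeSize, peakMemoryUsedBytes, or getPeakMemoryUsage) and adds it to the task's internal peak-execution-memory accumulator, which then surfaces in StageInfo.accumulables for the listener-based checks in AccumulatorSuite. Reduced to its essentials (the wrapper name below is made up for illustration; the calls inside it are the ones used in the diffs above):

    import org.apache.spark.{InternalAccumulator, TaskContext}

    // Runs on an executor, inside the operator's mapPartitions closure.
    // `recordPeakExecutionMemory` is a hypothetical helper, not a Spark API.
    def recordPeakExecutionMemory(bytesUsed: Long): Unit = {
      TaskContext.get().internalMetricsToAccumulators(
        InternalAccumulator.PEAK_EXECUTION_MEMORY).add(bytesUsed)
    }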
CompletionIterator[InternalRow, Iterator[InternalRow]](baseIterator, sorter.stop()) }, preservesPartitioning = true) @@ -137,7 +143,11 @@ case class TungstenSort( if (testSpillFrequency > 0) { sorter.setTestSpillFrequency(testSpillFrequency) } - sorter.sort(iter.asInstanceOf[Iterator[UnsafeRow]]) + val sortedIterator = sorter.sort(iter.asInstanceOf[Iterator[UnsafeRow]]) + val taskContext = TaskContext.get() + taskContext.internalMetricsToAccumulators( + InternalAccumulator.PEAK_EXECUTION_MEMORY).add(sorter.getPeakMemoryUsage) + sortedIterator }, preservesPartitioning = true) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index f1abae072005..29dfcf257522 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -21,6 +21,7 @@ import java.sql.Timestamp import org.scalatest.BeforeAndAfterAll +import org.apache.spark.AccumulatorSuite import org.apache.spark.sql.catalyst.analysis.FunctionRegistry import org.apache.spark.sql.catalyst.DefaultParserDialect import org.apache.spark.sql.catalyst.errors.DialectException @@ -258,6 +259,23 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { } } + private def testCodeGen(sqlText: String, expectedResults: Seq[Row]): Unit = { + val df = sql(sqlText) + // First, check if we have GeneratedAggregate. + val hasGeneratedAgg = df.queryExecution.executedPlan + .collect { case _: GeneratedAggregate | _: aggregate.Aggregate => true } + .nonEmpty + if (!hasGeneratedAgg) { + fail( + s""" + |Codegen is enabled, but query $sqlText does not have GeneratedAggregate in the plan. + |${df.queryExecution.simpleString} + """.stripMargin) + } + // Then, check results. + checkAnswer(df, expectedResults) + } + test("aggregation with codegen") { val originalValue = sqlContext.conf.codegenEnabled sqlContext.setConf(SQLConf.CODEGEN_ENABLED, true) @@ -267,26 +285,6 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { .unionAll(sqlContext.table("testData")) .registerTempTable("testData3x") - def testCodeGen(sqlText: String, expectedResults: Seq[Row]): Unit = { - val df = sql(sqlText) - // First, check if we have GeneratedAggregate. - var hasGeneratedAgg = false - df.queryExecution.executedPlan.foreach { - case generatedAgg: GeneratedAggregate => hasGeneratedAgg = true - case newAggregate: aggregate.Aggregate => hasGeneratedAgg = true - case _ => - } - if (!hasGeneratedAgg) { - fail( - s""" - |Codegen is enabled, but query $sqlText does not have GeneratedAggregate in the plan. - |${df.queryExecution.simpleString} - """.stripMargin) - } - // Then, check results. - checkAnswer(df, expectedResults) - } - try { // Just to group rows. 
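The hoisted testCodeGen helper above also swaps the old mutable-flag scan of the physical plan for collect { ... }.nonEmpty, which expresses the same check without a var. The two idioms are equivalent; on a plain collection (illustration only, not Spark code):

    val nodes: Seq[Any] = Seq(1, "GeneratedAggregate", 2.0)
    // old style: mutable flag updated by a foreach with a catch-all case
    var found = false
    nodes.foreach { case _: String => found = true; case _ => }
    // new style: collect the matches and test for non-emptiness
    val found2 = nodes.collect { case s: String => s }.nonEmpty
    assert(found == found2)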
testCodeGen( @@ -1605,6 +1603,28 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { Row(new CalendarInterval(-(12 * 3 - 3), -(7L * MICROS_PER_WEEK + 123)))) } + test("aggregation with codegen updates peak execution memory") { + withSQLConf( + (SQLConf.CODEGEN_ENABLED.key, "true"), + (SQLConf.USE_SQL_AGGREGATE2.key, "false")) { + val sc = sqlContext.sparkContext + AccumulatorSuite.verifyPeakExecutionMemorySet(sc, "aggregation with codegen") { + testCodeGen( + "SELECT key, count(value) FROM testData GROUP BY key", + (1 to 100).map(i => Row(i, 1))) + } + } + } + + test("external sorting updates peak execution memory") { + withSQLConf((SQLConf.EXTERNAL_SORT.key, "true")) { + val sc = sqlContext.sparkContext + AccumulatorSuite.verifyPeakExecutionMemorySet(sc, "external sort") { + sortTest() + } + } + } + test("SPARK-9511: error with table starting with number") { val df = sqlContext.sparkContext.parallelize(1 to 10).map(i => (i, i.toString)) .toDF("num", "str") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala index c7949848513c..88bce0e319f9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala @@ -21,6 +21,7 @@ import scala.util.Random import org.scalatest.BeforeAndAfterAll +import org.apache.spark.AccumulatorSuite import org.apache.spark.sql.{RandomDataGenerator, Row, SQLConf} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.test.TestSQLContext @@ -59,6 +60,17 @@ class TungstenSortSuite extends SparkPlanTest with BeforeAndAfterAll { ) } + test("sorting updates peak execution memory") { + val sc = TestSQLContext.sparkContext + AccumulatorSuite.verifyPeakExecutionMemorySet(sc, "unsafe external sort") { + checkThatPlansAgree( + (1 to 100).map(v => Tuple1(v)).toDF("a"), + (child: SparkPlan) => TungstenSort('a.asc :: Nil, true, child), + (child: SparkPlan) => Sort('a.asc :: Nil, global = true, child), + sortAnswers = false) + } + } + // Test sorting on different data types for ( dataType <- DataTypeTestUtils.atomicTypes ++ Set(NullType); diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala index 7c591f6143b9..ef827b0fe9b5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala @@ -69,7 +69,8 @@ class UnsafeFixedWidthAggregationMapSuite extends SparkFunSuite with Matchers { taskAttemptId = Random.nextInt(10000), attemptNumber = 0, taskMemoryManager = taskMemoryManager, - metricsSystem = null)) + metricsSystem = null, + internalAccumulators = Seq.empty)) try { f diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala index 0282b25b9dd5..601a5a07ad00 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala @@ -76,7 +76,8 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite { taskAttemptId = 98456, 
attemptNumber = 0, taskMemoryManager = taskMemMgr, - metricsSystem = null)) + metricsSystem = null, + internalAccumulators = Seq.empty)) // Create the data converters val kExternalConverter = CatalystTypeConverters.createToCatalystConverter(keySchema) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala new file mode 100644 index 000000000000..0554e11d252b --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala @@ -0,0 +1,94 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +// TODO: uncomment the test here! It is currently failing due to +// bad interaction with org.apache.spark.sql.test.TestSQLContext. + +// scalastyle:off +//package org.apache.spark.sql.execution.joins +// +//import scala.reflect.ClassTag +// +//import org.scalatest.BeforeAndAfterAll +// +//import org.apache.spark.{AccumulatorSuite, SparkConf, SparkContext} +//import org.apache.spark.sql.functions._ +//import org.apache.spark.sql.{SQLConf, SQLContext, QueryTest} +// +///** +// * Test various broadcast join operators with unsafe enabled. +// * +// * This needs to be its own suite because [[org.apache.spark.sql.test.TestSQLContext]] runs +// * in local mode, but for tests in this suite we need to run Spark in local-cluster mode. +// * In particular, the use of [[org.apache.spark.unsafe.map.BytesToBytesMap]] in +// * [[org.apache.spark.sql.execution.joins.UnsafeHashedRelation]] is not triggered without +// * serializing the hashed relation, which does not happen in local mode. +// */ +//class BroadcastJoinSuite extends QueryTest with BeforeAndAfterAll { +// private var sc: SparkContext = null +// private var sqlContext: SQLContext = null +// +// /** +// * Create a new [[SQLContext]] running in local-cluster mode with unsafe and codegen enabled. +// */ +// override def beforeAll(): Unit = { +// super.beforeAll() +// val conf = new SparkConf() +// .setMaster("local-cluster[2,1,1024]") +// .setAppName("testing") +// sc = new SparkContext(conf) +// sqlContext = new SQLContext(sc) +// sqlContext.setConf(SQLConf.UNSAFE_ENABLED, true) +// sqlContext.setConf(SQLConf.CODEGEN_ENABLED, true) +// } +// +// override def afterAll(): Unit = { +// sc.stop() +// sc = null +// sqlContext = null +// } +// +// /** +// * Test whether the specified broadcast join updates the peak execution memory accumulator. 
+// */ +// private def testBroadcastJoin[T: ClassTag](name: String, joinType: String): Unit = { +// AccumulatorSuite.verifyPeakExecutionMemorySet(sc, name) { +// val df1 = sqlContext.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value") +// val df2 = sqlContext.createDataFrame(Seq((1, "1"), (2, "2"))).toDF("key", "value") +// // Comparison at the end is for broadcast left semi join +// val joinExpression = df1("key") === df2("key") && df1("value") > df2("value") +// val df3 = df1.join(broadcast(df2), joinExpression, joinType) +// val plan = df3.queryExecution.executedPlan +// assert(plan.collect { case p: T => p }.size === 1) +// plan.executeCollect() +// } +// } +// +// test("unsafe broadcast hash join updates peak execution memory") { +// testBroadcastJoin[BroadcastHashJoin]("unsafe broadcast hash join", "inner") +// } +// +// test("unsafe broadcast hash outer join updates peak execution memory") { +// testBroadcastJoin[BroadcastHashOuterJoin]("unsafe broadcast hash outer join", "left_outer") +// } +// +// test("unsafe broadcast left semi join updates peak execution memory") { +// testBroadcastJoin[BroadcastLeftSemiJoinHash]("unsafe broadcast left semi join", "leftsemi") +// } +// +//} +// scalastyle:on From b2e4b85d2db0320e9cbfaf5a5542f749f1f11cf4 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 3 Aug 2015 14:51:15 -0700 Subject: [PATCH 0009/1824] Revert "[SPARK-9372] [SQL] Filter nulls in join keys" This reverts commit 687c8c37150f4c93f8e57d86bb56321a4891286b. --- .../catalyst/expressions/nullFunctions.scala | 48 +--- .../sql/catalyst/optimizer/Optimizer.scala | 64 ++--- .../plans/logical/basicOperators.scala | 32 +-- .../expressions/ExpressionEvalHelper.scala | 4 +- .../expressions/MathFunctionsSuite.scala | 3 +- .../expressions/NullFunctionsSuite.scala | 49 +--- .../spark/sql/DataFrameNaFunctions.scala | 2 +- .../scala/org/apache/spark/sql/SQLConf.scala | 6 - .../org/apache/spark/sql/SQLContext.scala | 5 +- .../extendedOperatorOptimizations.scala | 160 ------------ .../optimizer/FilterNullsInJoinKeySuite.scala | 236 ------------------ 11 files changed, 37 insertions(+), 572 deletions(-) delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/optimizer/extendedOperatorOptimizations.scala delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/optimizer/FilterNullsInJoinKeySuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala index d58c4756938c..287718fab7f0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala @@ -210,58 +210,14 @@ case class IsNotNull(child: Expression) extends UnaryExpression with Predicate { } } -/** - * A predicate that is evaluated to be true if there are at least `n` null values. 
- */ -case class AtLeastNNulls(n: Int, children: Seq[Expression]) extends Predicate { - override def nullable: Boolean = false - override def foldable: Boolean = children.forall(_.foldable) - override def toString: String = s"AtLeastNNulls($n, ${children.mkString(",")})" - - private[this] val childrenArray = children.toArray - - override def eval(input: InternalRow): Boolean = { - var numNulls = 0 - var i = 0 - while (i < childrenArray.length && numNulls < n) { - val evalC = childrenArray(i).eval(input) - if (evalC == null) { - numNulls += 1 - } - i += 1 - } - numNulls >= n - } - - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val numNulls = ctx.freshName("numNulls") - val code = children.map { e => - val eval = e.gen(ctx) - s""" - if ($numNulls < $n) { - ${eval.code} - if (${eval.isNull}) { - $numNulls += 1; - } - } - """ - }.mkString("\n") - s""" - int $numNulls = 0; - $code - boolean ${ev.isNull} = false; - boolean ${ev.primitive} = $numNulls >= $n; - """ - } -} /** * A predicate that is evaluated to be true if there are at least `n` non-null and non-NaN values. */ -case class AtLeastNNonNullNans(n: Int, children: Seq[Expression]) extends Predicate { +case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate { override def nullable: Boolean = false override def foldable: Boolean = children.forall(_.foldable) - override def toString: String = s"AtLeastNNonNullNans($n, ${children.mkString(",")})" + override def toString: String = s"AtLeastNNulls(n, ${children.mkString(",")})" private[this] val childrenArray = children.toArray diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index e4b6294dc7b8..29d706dcb39a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -31,14 +31,8 @@ import org.apache.spark.sql.types._ abstract class Optimizer extends RuleExecutor[LogicalPlan] -class DefaultOptimizer extends Optimizer { - - /** - * Override to provide additional rules for the "Operator Optimizations" batch. - */ - val extendedOperatorOptimizationRules: Seq[Rule[LogicalPlan]] = Nil - - lazy val batches = +object DefaultOptimizer extends Optimizer { + val batches = // SubQueries are only needed for analysis and can be removed before execution. 
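For reference, the predicate that survives this revert counts children that are neither null nor NaN and succeeds when that count reaches n; the variant being deleted here counted nulls instead. The NullFunctionsSuite hunk later in this diff pins the behaviour down; a worked instance of it, in the context of that suite (which mixes in ExpressionEvalHelper):

    // Matches the expectations in the NullFunctionsSuite changes below.
    val mix = Seq(Literal("x"),
      Literal.create(null, StringType),
      Literal.create(null, DoubleType),
      Literal(Double.NaN),
      Literal(5f))                                          // 2 non-null, non-NaN children
    checkEvaluation(AtLeastNNonNulls(2, mix), true, EmptyRow)    // 2 >= 2
    checkEvaluation(AtLeastNNonNulls(3, mix), false, EmptyRow)   // only 2 children qualify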
Batch("Remove SubQueries", FixedPoint(100), EliminateSubQueries) :: @@ -47,27 +41,26 @@ class DefaultOptimizer extends Optimizer { RemoveLiteralFromGroupExpressions) :: Batch("Operator Optimizations", FixedPoint(100), // Operator push down - SetOperationPushDown :: - SamplePushDown :: - PushPredicateThroughJoin :: - PushPredicateThroughProject :: - PushPredicateThroughGenerate :: - ColumnPruning :: + SetOperationPushDown, + SamplePushDown, + PushPredicateThroughJoin, + PushPredicateThroughProject, + PushPredicateThroughGenerate, + ColumnPruning, // Operator combine - ProjectCollapsing :: - CombineFilters :: - CombineLimits :: + ProjectCollapsing, + CombineFilters, + CombineLimits, // Constant folding - NullPropagation :: - OptimizeIn :: - ConstantFolding :: - LikeSimplification :: - BooleanSimplification :: - RemovePositive :: - SimplifyFilters :: - SimplifyCasts :: - SimplifyCaseConversionExpressions :: - extendedOperatorOptimizationRules.toList : _*) :: + NullPropagation, + OptimizeIn, + ConstantFolding, + LikeSimplification, + BooleanSimplification, + RemovePositive, + SimplifyFilters, + SimplifyCasts, + SimplifyCaseConversionExpressions) :: Batch("Decimal Optimizations", FixedPoint(100), DecimalAggregates) :: Batch("LocalRelation", FixedPoint(100), @@ -229,18 +222,12 @@ object ColumnPruning extends Rule[LogicalPlan] { } /** Applies a projection only when the child is producing unnecessary attributes */ - private def prunedChild(c: LogicalPlan, allReferences: AttributeSet) = { + private def prunedChild(c: LogicalPlan, allReferences: AttributeSet) = if ((c.outputSet -- allReferences.filter(c.outputSet.contains)).nonEmpty) { - // We need to preserve the nullability of c's output. - // So, we first create a outputMap and if a reference is from the output of - // c, we use that output attribute from c. - val outputMap = AttributeMap(c.output.map(attr => (attr, attr))) - val projectList = allReferences.filter(outputMap.contains).map(outputMap).toSeq - Project(projectList, c) + Project(allReferences.filter(c.outputSet.contains).toSeq, c) } else { c } - } } /** @@ -530,13 +517,6 @@ object BooleanSimplification extends Rule[LogicalPlan] with PredicateHelper { */ object CombineFilters extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case Filter(Not(AtLeastNNulls(1, e1)), Filter(Not(AtLeastNNulls(1, e2)), grandChild)) => - // If we are combining two expressions Not(AtLeastNNulls(1, e1)) and - // Not(AtLeastNNulls(1, e2)) - // (this is used to make sure there is no null in the result of e1 and e2 and - // they are added by FilterNullsInJoinKey optimziation rule), we can - // just create a Not(AtLeastNNulls(1, (e1 ++ e2).distinct)). 
- Filter(Not(AtLeastNNulls(1, (e1 ++ e2).distinct)), grandChild) case ff @ Filter(fc, nf @ Filter(nc, grandChild)) => Filter(And(nc, fc), grandChild) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 54b5f4977266..aacfc86ab0e4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -86,37 +86,7 @@ case class Generate( } case class Filter(condition: Expression, child: LogicalPlan) extends UnaryNode { - /** - * Indicates if `atLeastNNulls` is used to check if atLeastNNulls.children - * have at least one null value and atLeastNNulls.children are all attributes. - */ - private def isAtLeastOneNullOutputAttributes(atLeastNNulls: AtLeastNNulls): Boolean = { - val expressions = atLeastNNulls.children - val n = atLeastNNulls.n - if (n != 1) { - // AtLeastNNulls is not used to check if atLeastNNulls.children have - // at least one null value. - false - } else { - // AtLeastNNulls is used to check if atLeastNNulls.children have - // at least one null value. We need to make sure all atLeastNNulls.children - // are attributes. - expressions.forall(_.isInstanceOf[Attribute]) - } - } - - override def output: Seq[Attribute] = condition match { - case Not(a: AtLeastNNulls) if isAtLeastOneNullOutputAttributes(a) => - // The condition is used to make sure that there is no null value in - // a.children. - val nonNullableAttributes = AttributeSet(a.children.asInstanceOf[Seq[Attribute]]) - child.output.map { - case attr if nonNullableAttributes.contains(attr) => - attr.withNullability(false) - case attr => attr - } - case _ => child.output - } + override def output: Seq[Attribute] = child.output } case class Union(left: LogicalPlan, right: LogicalPlan) extends BinaryNode { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index 3e5515129874..a41185b4d875 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -31,8 +31,6 @@ import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project} trait ExpressionEvalHelper { self: SparkFunSuite => - protected val defaultOptimizer = new DefaultOptimizer - protected def create_row(values: Any*): InternalRow = { InternalRow.fromSeq(values.map(CatalystTypeConverters.convertToCatalyst)) } @@ -188,7 +186,7 @@ trait ExpressionEvalHelper { expected: Any, inputRow: InternalRow = EmptyRow): Unit = { val plan = Project(Alias(expression, s"Optimized($expression)")() :: Nil, OneRowRelation) - val optimizedPlan = defaultOptimizer.execute(plan) + val optimizedPlan = DefaultOptimizer.execute(plan) checkEvaluationWithoutCodegen(optimizedPlan.expressions.head, expected, inputRow) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala index 649a5b44dc03..9fcb548af6bb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala +++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala @@ -23,6 +23,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateProjection, GenerateMutableProjection} import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.optimizer.DefaultOptimizer import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project} import org.apache.spark.sql.types._ @@ -148,7 +149,7 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { expression: Expression, inputRow: InternalRow = EmptyRow): Unit = { val plan = Project(Alias(expression, s"Optimized($expression)")() :: Nil, OneRowRelation) - val optimizedPlan = defaultOptimizer.execute(plan) + val optimizedPlan = DefaultOptimizer.execute(plan) checkNaNWithoutCodegen(optimizedPlan.expressions.head, inputRow) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala index bf197124d8db..ace6c15dc841 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala @@ -77,7 +77,7 @@ class NullFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { } } - test("AtLeastNNonNullNans") { + test("AtLeastNNonNulls") { val mix = Seq(Literal("x"), Literal.create(null, StringType), Literal.create(null, DoubleType), @@ -96,46 +96,11 @@ class NullFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { Literal(Float.MaxValue), Literal(false)) - checkEvaluation(AtLeastNNonNullNans(0, mix), true, EmptyRow) - checkEvaluation(AtLeastNNonNullNans(2, mix), true, EmptyRow) - checkEvaluation(AtLeastNNonNullNans(3, mix), false, EmptyRow) - checkEvaluation(AtLeastNNonNullNans(0, nanOnly), true, EmptyRow) - checkEvaluation(AtLeastNNonNullNans(3, nanOnly), true, EmptyRow) - checkEvaluation(AtLeastNNonNullNans(4, nanOnly), false, EmptyRow) - checkEvaluation(AtLeastNNonNullNans(0, nullOnly), true, EmptyRow) - checkEvaluation(AtLeastNNonNullNans(3, nullOnly), true, EmptyRow) - checkEvaluation(AtLeastNNonNullNans(4, nullOnly), false, EmptyRow) - } - - test("AtLeastNNull") { - val mix = Seq(Literal("x"), - Literal.create(null, StringType), - Literal.create(null, DoubleType), - Literal(Double.NaN), - Literal(5f)) - - val nanOnly = Seq(Literal("x"), - Literal(10.0), - Literal(Float.NaN), - Literal(math.log(-2)), - Literal(Double.MaxValue)) - - val nullOnly = Seq(Literal("x"), - Literal.create(null, DoubleType), - Literal.create(null, DecimalType.USER_DEFAULT), - Literal(Float.MaxValue), - Literal(false)) - - checkEvaluation(AtLeastNNulls(0, mix), true, EmptyRow) - checkEvaluation(AtLeastNNulls(1, mix), true, EmptyRow) - checkEvaluation(AtLeastNNulls(2, mix), true, EmptyRow) - checkEvaluation(AtLeastNNulls(3, mix), false, EmptyRow) - checkEvaluation(AtLeastNNulls(0, nanOnly), true, EmptyRow) - checkEvaluation(AtLeastNNulls(1, nanOnly), false, EmptyRow) - checkEvaluation(AtLeastNNulls(2, nanOnly), false, EmptyRow) - checkEvaluation(AtLeastNNulls(0, nullOnly), true, EmptyRow) - checkEvaluation(AtLeastNNulls(1, nullOnly), true, EmptyRow) - checkEvaluation(AtLeastNNulls(2, nullOnly), true, EmptyRow) - checkEvaluation(AtLeastNNulls(3, nullOnly), false, EmptyRow) + 
checkEvaluation(AtLeastNNonNulls(2, mix), true, EmptyRow) + checkEvaluation(AtLeastNNonNulls(3, mix), false, EmptyRow) + checkEvaluation(AtLeastNNonNulls(3, nanOnly), true, EmptyRow) + checkEvaluation(AtLeastNNonNulls(4, nanOnly), false, EmptyRow) + checkEvaluation(AtLeastNNonNulls(3, nullOnly), true, EmptyRow) + checkEvaluation(AtLeastNNonNulls(4, nullOnly), false, EmptyRow) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala index ea85f0657a72..a4fd4cf3b330 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala @@ -122,7 +122,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { def drop(minNonNulls: Int, cols: Seq[String]): DataFrame = { // Filtering condition: // only keep the row if it has at least `minNonNulls` non-null and non-NaN values. - val predicate = AtLeastNNonNullNans(minNonNulls, cols.map(name => df.resolve(name))) + val predicate = AtLeastNNonNulls(minNonNulls, cols.map(name => df.resolve(name))) df.filter(Column(predicate)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 41ba1c7fe057..f836122b3e0e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -413,10 +413,6 @@ private[spark] object SQLConf { "spark.sql.useSerializer2", defaultValue = Some(true), isPublic = false) - val ADVANCED_SQL_OPTIMIZATION = booleanConf( - "spark.sql.advancedOptimization", - defaultValue = Some(true), isPublic = false) - object Deprecated { val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" } @@ -488,8 +484,6 @@ private[sql] class SQLConf extends Serializable with CatalystConf { private[spark] def useSqlSerializer2: Boolean = getConf(USE_SQL_SERIALIZER2) - private[spark] def advancedSqlOptimizations: Boolean = getConf(ADVANCED_SQL_OPTIMIZATION) - private[spark] def autoBroadcastJoinThreshold: Int = getConf(AUTO_BROADCASTJOIN_THRESHOLD) private[spark] def defaultSizeInBytes: Long = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 31e2b508d485..dbb2a0984654 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -41,7 +41,6 @@ import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.catalyst.{InternalRow, ParserDialect, _} import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.datasources._ -import org.apache.spark.sql.optimizer.FilterNullsInJoinKey import org.apache.spark.sql.sources.BaseRelation import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -157,9 +156,7 @@ class SQLContext(@transient val sparkContext: SparkContext) } @transient - protected[sql] lazy val optimizer: Optimizer = new DefaultOptimizer { - override val extendedOperatorOptimizationRules = FilterNullsInJoinKey(self) :: Nil - } + protected[sql] lazy val optimizer: Optimizer = DefaultOptimizer @transient protected[sql] val ddlParser = new DDLParser(sqlParser.parse(_)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/optimizer/extendedOperatorOptimizations.scala b/sql/core/src/main/scala/org/apache/spark/sql/optimizer/extendedOperatorOptimizations.scala 
deleted file mode 100644 index 5a4dde575696..000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/optimizer/extendedOperatorOptimizations.scala +++ /dev/null @@ -1,160 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.optimizer - -import org.apache.spark.sql.SQLContext -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys -import org.apache.spark.sql.catalyst.plans.{Inner, LeftOuter, RightOuter, LeftSemi} -import org.apache.spark.sql.catalyst.plans.logical.{Project, Filter, Join, LogicalPlan} -import org.apache.spark.sql.catalyst.rules.Rule - -/** - * An optimization rule used to insert Filters to filter out rows whose equal join keys - * have at least one null values. For this kind of rows, they will not contribute to - * the join results of equal joins because a null does not equal another null. We can - * filter them out before shuffling join input rows. For example, we have two tables - * - * table1(key String, value Int) - * "str1"|1 - * null |2 - * - * table2(key String, value Int) - * "str1"|3 - * null |4 - * - * For a inner equal join, the result will be - * "str1"|1|"str1"|3 - * - * those two rows having null as the value of key will not contribute to the result. - * So, we can filter them out early. - * - * This optimization rule can be disabled by setting spark.sql.advancedOptimization to false. - * - */ -case class FilterNullsInJoinKey( - sqlContext: SQLContext) - extends Rule[LogicalPlan] { - - /** - * Checks if we need to add a Filter operator. We will add a Filter when - * there is any attribute in `keys` whose corresponding attribute of `keys` - * in `plan.output` is still nullable (`nullable` field is `true`). - */ - private def needsFilter(keys: Seq[Expression], plan: LogicalPlan): Boolean = { - val keyAttributeSet = AttributeSet(keys.filter(_.isInstanceOf[Attribute])) - plan.output.filter(keyAttributeSet.contains).exists(_.nullable) - } - - /** - * Adds a Filter operator to make sure that every attribute in `keys` is non-nullable. - */ - private def addFilterIfNecessary( - keys: Seq[Expression], - child: LogicalPlan): LogicalPlan = { - // We get all attributes from keys. - val attributes = keys.filter(_.isInstanceOf[Attribute]) - - // Then, we create a Filter to make sure these attributes are non-nullable. - val filter = - if (attributes.nonEmpty) { - Filter(Not(AtLeastNNulls(1, attributes)), child) - } else { - child - } - - filter - } - - /** - * We reconstruct the join condition. - */ - private def reconstructJoinCondition( - leftKeys: Seq[Expression], - rightKeys: Seq[Expression], - otherPredicate: Option[Expression]): Expression = { - // First, we rewrite the equal condition part. 
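The effect described in the scaladoc above can also be written out by hand at the DataFrame level, which is sometimes the easiest way to see why the rewrite is safe. A sketch with invented DataFrames table1 and table2, not the rule itself:

    // For an inner equi-join, a row whose join key is null can never match,
    // so dropping such rows before the join changes nothing except the amount
    // of data shuffled.
    val t1 = table1.filter(table1("key").isNotNull)
    val t2 = table2.filter(table2("key").isNotNull)

    val joined      = table1.join(table2, table1("key") === table2("key"))
    val prefiltered = t1.join(t2, t1("key") === t2("key"))
    // joined and prefiltered return the same rows for an inner join.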
When we extract those keys, - // we use splitConjunctivePredicates. So, it is safe to use .reduce(And). - val rewrittenEqualJoinCondition = leftKeys.zip(rightKeys).map { - case (l, r) => EqualTo(l, r) - }.reduce(And) - - // Then, we add otherPredicate. When we extract those equal condition part, - // we use splitConjunctivePredicates. So, it is safe to use - // And(rewrittenEqualJoinCondition, c). - val rewrittenJoinCondition = otherPredicate - .map(c => And(rewrittenEqualJoinCondition, c)) - .getOrElse(rewrittenEqualJoinCondition) - - rewrittenJoinCondition - } - - def apply(plan: LogicalPlan): LogicalPlan = { - if (!sqlContext.conf.advancedSqlOptimizations) { - plan - } else { - plan transform { - case join: Join => join match { - // For a inner join having equal join condition part, we can add filters - // to both sides of the join operator. - case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) - if needsFilter(leftKeys, left) || needsFilter(rightKeys, right) => - val withLeftFilter = addFilterIfNecessary(leftKeys, left) - val withRightFilter = addFilterIfNecessary(rightKeys, right) - val rewrittenJoinCondition = - reconstructJoinCondition(leftKeys, rightKeys, condition) - - Join(withLeftFilter, withRightFilter, Inner, Some(rewrittenJoinCondition)) - - // For a left outer join having equal join condition part, we can add a filter - // to the right side of the join operator. - case ExtractEquiJoinKeys(LeftOuter, leftKeys, rightKeys, condition, left, right) - if needsFilter(rightKeys, right) => - val withRightFilter = addFilterIfNecessary(rightKeys, right) - val rewrittenJoinCondition = - reconstructJoinCondition(leftKeys, rightKeys, condition) - - Join(left, withRightFilter, LeftOuter, Some(rewrittenJoinCondition)) - - // For a right outer join having equal join condition part, we can add a filter - // to the left side of the join operator. - case ExtractEquiJoinKeys(RightOuter, leftKeys, rightKeys, condition, left, right) - if needsFilter(leftKeys, left) => - val withLeftFilter = addFilterIfNecessary(leftKeys, left) - val rewrittenJoinCondition = - reconstructJoinCondition(leftKeys, rightKeys, condition) - - Join(withLeftFilter, right, RightOuter, Some(rewrittenJoinCondition)) - - // For a left semi join having equal join condition part, we can add filters - // to both sides of the join operator. - case ExtractEquiJoinKeys(LeftSemi, leftKeys, rightKeys, condition, left, right) - if needsFilter(leftKeys, left) || needsFilter(rightKeys, right) => - val withLeftFilter = addFilterIfNecessary(leftKeys, left) - val withRightFilter = addFilterIfNecessary(rightKeys, right) - val rewrittenJoinCondition = - reconstructJoinCondition(leftKeys, rightKeys, condition) - - Join(withLeftFilter, withRightFilter, LeftSemi, Some(rewrittenJoinCondition)) - - case other => other - } - } - } - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/optimizer/FilterNullsInJoinKeySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/optimizer/FilterNullsInJoinKeySuite.scala deleted file mode 100644 index f98e4acafbf2..000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/optimizer/FilterNullsInJoinKeySuite.scala +++ /dev/null @@ -1,236 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. 
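For quick reference, the apply() method above and the suite that follows pin down which inputs the rule may pre-filter on null equi-join keys; this is only a restatement of that behaviour:

    // Join type  -> side(s) the rule may pre-filter on null equi-join keys
    // Inner      -> both sides
    // LeftOuter  -> right side only (left rows must be preserved)
    // RightOuter -> left side only
    // LeftSemi   -> both sides
    // FullOuter  -> neither (the rule does not fire)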
- * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.optimizer - -import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries -import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.expressions.{Not, AtLeastNNulls} -import org.apache.spark.sql.catalyst.optimizer._ -import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.plans.logical.{Filter, LocalRelation, LogicalPlan} -import org.apache.spark.sql.catalyst.rules.RuleExecutor -import org.apache.spark.sql.test.TestSQLContext - -/** This is the test suite for FilterNullsInJoinKey optimization rule. */ -class FilterNullsInJoinKeySuite extends PlanTest { - - // We add predicate pushdown rules at here to make sure we do not - // create redundant Filter operators. Also, because the attribute ordering of - // the Project operator added by ColumnPruning may be not deterministic - // (the ordering may depend on the testing environment), - // we first construct the plan with expected Filter operators and then - // run the optimizer to add the the Project for column pruning. - object Optimize extends RuleExecutor[LogicalPlan] { - val batches = - Batch("Subqueries", Once, - EliminateSubQueries) :: - Batch("Operator Optimizations", FixedPoint(100), - FilterNullsInJoinKey(TestSQLContext), // This is the rule we test in this suite. - CombineFilters, - PushPredicateThroughProject, - BooleanSimplification, - PushPredicateThroughJoin, - PushPredicateThroughGenerate, - ColumnPruning, - ProjectCollapsing) :: Nil - } - - val leftRelation = LocalRelation('a.int, 'b.int, 'c.int, 'd.int) - - val rightRelation = LocalRelation('e.int, 'f.int, 'g.int, 'h.int) - - test("inner join") { - val joinCondition = - ('a === 'e && 'b + 1 === 'f) && ('d > 'h || 'd === 'g) - - val joinedPlan = - leftRelation - .join(rightRelation, Inner, Some(joinCondition)) - .select('a, 'f, 'd, 'h) - - val optimized = Optimize.execute(joinedPlan.analyze) - - // For an inner join, FilterNullsInJoinKey add filter to both side. - val correctLeft = - leftRelation - .where(!(AtLeastNNulls(1, 'a.expr :: Nil))) - - val correctRight = - rightRelation.where(!(AtLeastNNulls(1, 'e.expr :: 'f.expr :: Nil))) - - val correctAnswer = - correctLeft - .join(correctRight, Inner, Some(joinCondition)) - .select('a, 'f, 'd, 'h) - - comparePlans(optimized, Optimize.execute(correctAnswer.analyze)) - } - - test("make sure we do not keep adding filters") { - val thirdRelation = LocalRelation('i.int, 'j.int, 'k.int, 'l.int) - val joinedPlan = - leftRelation - .join(rightRelation, Inner, Some('a === 'e)) - .join(thirdRelation, Inner, Some('b === 'i && 'a === 'j)) - - val optimized = Optimize.execute(joinedPlan.analyze) - val conditions = optimized.collect { - case Filter(condition @ Not(AtLeastNNulls(1, exprs)), _) => exprs - } - - // Make sure that we have three Not(AtLeastNNulls(1, exprs)) for those three tables. 
- assert(conditions.length === 3) - - // Make sure attribtues are indeed a, b, e, i, and j. - assert( - conditions.flatMap(exprs => exprs).toSet === - joinedPlan.select('a, 'b, 'e, 'i, 'j).analyze.output.toSet) - } - - test("inner join (partially optimized)") { - val joinCondition = - ('a + 2 === 'e && 'b + 1 === 'f) && ('d > 'h || 'd === 'g) - - val joinedPlan = - leftRelation - .join(rightRelation, Inner, Some(joinCondition)) - .select('a, 'f, 'd, 'h) - - val optimized = Optimize.execute(joinedPlan.analyze) - - // We cannot extract attribute from the left join key. - val correctRight = - rightRelation.where(!(AtLeastNNulls(1, 'e.expr :: 'f.expr :: Nil))) - - val correctAnswer = - leftRelation - .join(correctRight, Inner, Some(joinCondition)) - .select('a, 'f, 'd, 'h) - - comparePlans(optimized, Optimize.execute(correctAnswer.analyze)) - } - - test("inner join (not optimized)") { - val nonOptimizedJoinConditions = - Some('c - 100 + 'd === 'g + 1 - 'h) :: - Some('d > 'h || 'c === 'g) :: - Some('d + 'g + 'c > 'd - 'h) :: Nil - - nonOptimizedJoinConditions.foreach { joinCondition => - val joinedPlan = - leftRelation - .join(rightRelation.select('f, 'g, 'h), Inner, joinCondition) - .select('a, 'c, 'f, 'd, 'h, 'g) - - val optimized = Optimize.execute(joinedPlan.analyze) - - comparePlans(optimized, Optimize.execute(joinedPlan.analyze)) - } - } - - test("left outer join") { - val joinCondition = - ('a === 'e && 'b + 1 === 'f) && ('d > 'h || 'd === 'g) - - val joinedPlan = - leftRelation - .join(rightRelation, LeftOuter, Some(joinCondition)) - .select('a, 'f, 'd, 'h) - - val optimized = Optimize.execute(joinedPlan.analyze) - - // For a left outer join, FilterNullsInJoinKey add filter to the right side. - val correctRight = - rightRelation.where(!(AtLeastNNulls(1, 'e.expr :: 'f.expr :: Nil))) - - val correctAnswer = - leftRelation - .join(correctRight, LeftOuter, Some(joinCondition)) - .select('a, 'f, 'd, 'h) - - comparePlans(optimized, Optimize.execute(correctAnswer.analyze)) - } - - test("right outer join") { - val joinCondition = - ('a === 'e && 'b + 1 === 'f) && ('d > 'h || 'd === 'g) - - val joinedPlan = - leftRelation - .join(rightRelation, RightOuter, Some(joinCondition)) - .select('a, 'f, 'd, 'h) - - val optimized = Optimize.execute(joinedPlan.analyze) - - // For a right outer join, FilterNullsInJoinKey add filter to the left side. - val correctLeft = - leftRelation - .where(!(AtLeastNNulls(1, 'a.expr :: Nil))) - - val correctAnswer = - correctLeft - .join(rightRelation, RightOuter, Some(joinCondition)) - .select('a, 'f, 'd, 'h) - - - comparePlans(optimized, Optimize.execute(correctAnswer.analyze)) - } - - test("full outer join") { - val joinCondition = - ('a === 'e && 'b + 1 === 'f) && ('d > 'h || 'd === 'g) - - val joinedPlan = - leftRelation - .join(rightRelation, FullOuter, Some(joinCondition)) - .select('a, 'f, 'd, 'h) - - // FilterNullsInJoinKey does not fire for a full outer join. - val optimized = Optimize.execute(joinedPlan.analyze) - - comparePlans(optimized, Optimize.execute(joinedPlan.analyze)) - } - - test("left semi join") { - val joinCondition = - ('a === 'e && 'b + 1 === 'f) && ('d > 'h || 'd === 'g) - - val joinedPlan = - leftRelation - .join(rightRelation, LeftSemi, Some(joinCondition)) - .select('a, 'd) - - val optimized = Optimize.execute(joinedPlan.analyze) - - // For a left semi join, FilterNullsInJoinKey add filter to both side. 
- val correctLeft = - leftRelation - .where(!(AtLeastNNulls(1, 'a.expr :: Nil))) - - val correctRight = - rightRelation.where(!(AtLeastNNulls(1, 'e.expr :: 'f.expr :: Nil))) - - val correctAnswer = - correctLeft - .join(correctRight, LeftSemi, Some(joinCondition)) - .select('a, 'd) - - comparePlans(optimized, Optimize.execute(correctAnswer.analyze)) - } -} From a2409d1c8e8ddec04b529ac6f6a12b5993f0eeda Mon Sep 17 00:00:00 2001 From: Steve Loughran Date: Mon, 3 Aug 2015 15:24:34 -0700 Subject: [PATCH 0010/1824] [SPARK-8064] [SQL] Build against Hive 1.2.1 Cherry picked the parts of the initial SPARK-8064 WiP branch needed to get sql/hive to compile against hive 1.2.1. That's the ASF release packaged under org.apache.hive, not any fork. Tests not run yet: that's what the machines are for Author: Steve Loughran Author: Cheng Lian Author: Michael Armbrust Author: Patrick Wendell Closes #7191 from steveloughran/stevel/feature/SPARK-8064-hive-1.2-002 and squashes the following commits: 7556d85 [Cheng Lian] Updates .q files and corresponding golden files ef4af62 [Steve Loughran] Merge commit '6a92bb09f46a04d6cd8c41bdba3ecb727ebb9030' into stevel/feature/SPARK-8064-hive-1.2-002 6a92bb0 [Cheng Lian] Overrides HiveConf time vars dcbb391 [Cheng Lian] Adds com.twitter:parquet-hadoop-bundle:1.6.0 for Hive Parquet SerDe 0bbe475 [Steve Loughran] SPARK-8064 scalastyle rejects the standard Hadoop ASF license header... fdf759b [Steve Loughran] SPARK-8064 classpath dependency suite to be in sync with shading in final (?) hive-exec spark 7a6c727 [Steve Loughran] SPARK-8064 switch to second staging repo of the spark-hive artifacts. This one has the protobuf-shaded hive-exec jar 376c003 [Steve Loughran] SPARK-8064 purge duplicate protobuf declaration 2c74697 [Steve Loughran] SPARK-8064 switch to the protobuf shaded hive-exec jar with tests to chase it down cc44020 [Steve Loughran] SPARK-8064 remove hadoop.version from runtest.py, as profile will fix that automatically. 6901fa9 [Steve Loughran] SPARK-8064 explicit protobuf import da310dc [Michael Armbrust] Fixes for Hive tests. a775a75 [Steve Loughran] SPARK-8064 cherry-pick-incomplete 7404f34 [Patrick Wendell] Add spark-hive staging repo 832c164 [Steve Loughran] SPARK-8064 try to supress compiler warnings on Complex.java pasted-thrift-code 312c0d4 [Steve Loughran] SPARK-8064 maven/ivy dependency purge; calcite declaration needed fa5ae7b [Steve Loughran] HIVE-8064 fix up hive-thriftserver dependencies and cut back on evicted references in the hive- packages; this keeps mvn and ivy resolution compatible, as the reconciliation policy is "by hand" c188048 [Steve Loughran] SPARK-8064 manage the Hive depencencies to that -things that aren't needed are excluded -sql/hive built with ivy is in sync with the maven reconciliation policy, rather than latest-first 4c8be8d [Cheng Lian] WIP: Partial fix for Thrift server and CLI tests 314eb3c [Steve Loughran] SPARK-8064 deprecation warning noise in one of the tests 17b0341 [Steve Loughran] SPARK-8064 IDE-hinted cleanups of Complex.java to reduce compiler warnings. It's all autogenerated code, so still ugly. 
d029b92 [Steve Loughran] SPARK-8064 rely on unescaping to have already taken place, so go straight to map of serde options 23eca7e [Steve Loughran] HIVE-8064 handle raw and escaped property tokens 54d9b06 [Steve Loughran] SPARK-8064 fix compilation regression surfacing from rebase 0b12d5f [Steve Loughran] HIVE-8064 use subset of hive complex type whose types deserialize fce73b6 [Steve Loughran] SPARK-8064 poms rely implicitly on the version of kryo chill provides fd3aa5d [Steve Loughran] SPARK-8064 version of hive to d/l from ivy is 1.2.1 dc73ece [Steve Loughran] SPARK-8064 revert to master's determinstic pushdown strategy d3c1e4a [Steve Loughran] SPARK-8064 purge UnionType 051cc21 [Steve Loughran] SPARK-8064 switch to an unshaded version of hive-exec-core, which must have been built with Kryo 2.21. This currently looks for a (locally built) version 1.2.1.spark 6684c60 [Steve Loughran] SPARK-8064 ignore RTE raised in blocking process.exitValue() call e6121e5 [Steve Loughran] SPARK-8064 address review comments aa43dc6 [Steve Loughran] SPARK-8064 more robust teardown on JavaMetastoreDatasourcesSuite f2bff01 [Steve Loughran] SPARK-8064 better takeup of asynchronously caught error text 8b1ef38 [Steve Loughran] SPARK-8064: on failures executing spark-submit in HiveSparkSubmitSuite, print command line and all logged output. 5a9ce6b [Steve Loughran] SPARK-8064 add explicit reason for kv split failure, rather than array OOB. *does not address the issue* 642b63a [Steve Loughran] SPARK-8064 reinstate something cut briefly during rebasing 97194dc [Steve Loughran] SPARK-8064 add extra logging to the YarnClusterSuite classpath test. There should be no reason why this is failing on jenkins, but as it is (and presumably its CP-related), improve the logging including any exception raised. 335357f [Steve Loughran] SPARK-8064 fail fast on thrive process spawning tests on exit codes and/or error string patterns seen in log. 3ed872f [Steve Loughran] SPARK-8064 rename field double to dbl bca55e5 [Steve Loughran] SPARK-8064 missed one of the `date` escapes 41d6479 [Steve Loughran] SPARK-8064 wrap tests with withTable() calls to avoid table-exists exceptions 2bc29a4 [Steve Loughran] SPARK-8064 ParquetSuites to escape `date` field name 1ab9bc4 [Steve Loughran] SPARK-8064 TestHive to use sered2.thrift.test.Complex bf3a249 [Steve Loughran] SPARK-8064: more resubmit than fix; tighten startup timeout to 60s. Still no obvious reason why jersey server code in spark-assembly isn't being picked up -it hasn't been shaded c829b8f [Steve Loughran] SPARK-8064: reinstate yarn-rm-server dependencies to hive-exec to ensure that jersey server is on classpath on hadoop versions < 2.6 0b0f738 [Steve Loughran] SPARK-8064: thrift server startup to fail fast on any exception in the main thread 13abaf1 [Steve Loughran] SPARK-8064 Hive compatibilty tests sin sync with explain/show output from Hive 1.2.1 d14d5ea [Steve Loughran] SPARK-8064: DATE is now a predicate; you can't use it as a field in select ops 26eef1c [Steve Loughran] SPARK-8064: HIVE-9039 renamed TOK_UNION => TOK_UNIONALL while adding TOK_UNIONDISTINCT 3d64523 [Steve Loughran] SPARK-8064 improve diagns on uknown token; fix scalastyle failure d0360f6 [Steve Loughran] SPARK-8064: delicate merge in of the branch vanzin/hive-1.1 1126e5a [Steve Loughran] SPARK-8064: name of unrecognized file format wasn't appearing in error text 8cb09c4 [Steve Loughran] SPARK-8064: test resilience/assertion improvements. 
Independent of the rest of the work; can be backported to earlier versions dec12cb [Steve Loughran] SPARK-8064: when a CLI suite test fails include the full output text in the raised exception; this ensures that the stdout/stderr is included in jenkins reports, so it becomes possible to diagnose the cause. 463a670 [Steve Loughran] SPARK-8064 run-tests.py adds a hadoop-2.6 profile, and changes info messages to say "w/Hive 1.2.1" in console output 2531099 [Steve Loughran] SPARK-8064 successful attempt to get rid of pentaho as a transitive dependency of hive-exec 1d59100 [Steve Loughran] SPARK-8064 (unsuccessful) attempt to get rid of pentaho as a transitive dependency of hive-exec 75733fc [Steve Loughran] SPARK-8064 change thrift binary startup message to "Starting ThriftBinaryCLIService on port" 3ebc279 [Steve Loughran] SPARK-8064 move strings used to check for http/bin thrift services up into constants c80979d [Steve Loughran] SPARK-8064: SparkSQLCLIDriver drops remote mode support. CLISuite Tests pass instead of timing out: undetected regression? 27e8370 [Steve Loughran] SPARK-8064 fix some style & IDE warnings 00e50d6 [Steve Loughran] SPARK-8064 stop excluding hive shims from dependency (commented out , for now) cb4f142 [Steve Loughran] SPARK-8054 cut pentaho dependency from calcite f7aa9cb [Steve Loughran] SPARK-8064 everything compiles with some commenting and moving of classes into a hive package 6c310b4 [Steve Loughran] SPARK-8064 subclass Hive ServerOptionsProcessor to make it public again f61a675 [Steve Loughran] SPARK-8064 thrift server switched to Hive 1.2.1, though it doesn't compile everywhere 4890b9d [Steve Loughran] SPARK-8064, build against Hive 1.2.1 --- core/pom.xml | 20 - dev/run-tests.py | 7 +- pom.xml | 654 +++++++++- sbin/spark-daemon.sh | 2 +- sql/catalyst/pom.xml | 1 - .../parquet/ParquetCompatibilityTest.scala | 13 +- sql/hive-thriftserver/pom.xml | 22 +- .../HiveServerServerOptionsProcessor.scala | 37 + .../hive/thriftserver/HiveThriftServer2.scala | 27 +- .../SparkExecuteStatementOperation.scala | 9 +- .../hive/thriftserver/SparkSQLCLIDriver.scala | 56 +- .../thriftserver/SparkSQLCLIService.scala | 13 +- .../thriftserver/SparkSQLSessionManager.scala | 11 +- .../sql/hive/thriftserver/CliSuite.scala | 75 +- .../HiveThriftServer2Suites.scala | 40 +- .../execution/HiveCompatibilitySuite.scala | 29 +- sql/hive/pom.xml | 92 +- .../apache/spark/sql/hive/HiveContext.scala | 114 +- .../spark/sql/hive/HiveMetastoreCatalog.scala | 5 +- .../org/apache/spark/sql/hive/HiveQl.scala | 97 +- .../org/apache/spark/sql/hive/HiveShim.scala | 15 +- .../sql/hive/client/ClientInterface.scala | 4 + .../spark/sql/hive/client/ClientWrapper.scala | 5 +- .../spark/sql/hive/client/HiveShim.scala | 2 +- .../hive/client/IsolatedClientLoader.scala | 2 +- .../spark/sql/hive/client/package.scala | 2 +- .../hive/execution/InsertIntoHiveTable.scala | 2 +- .../hive/execution/ScriptTransformation.scala | 6 +- .../org/apache/spark/sql/hive/hiveUDFs.scala | 2 +- .../spark/sql/hive/hiveWriterContainers.scala | 2 +- .../spark/sql/hive/orc/OrcFilters.scala | 6 +- .../apache/spark/sql/hive/test/TestHive.scala | 36 +- .../apache/spark/sql/hive/test/Complex.java | 1139 +++++++++++++++++ .../hive/JavaMetastoreDataSourcesSuite.java | 6 +- ...perator-0-ee7f6a60a9792041b85b18cda56429bf | 1 + ..._string-1-db089ff46f9826c7883198adacdfad59 | 6 +- ...tar_by-5-41d474f5e6d7c61c36f74b4bec4e9e44} | 0 ...e_alter-3-2a91d52719cf4552ebeb867204552a26 | 2 +- ...b_table-4-b585371b624cbab2616a49f553a870a0 | 2 +- 
...limited-1-2a91d52719cf4552ebeb867204552a26 | 2 +- ...e_serde-1-2a91d52719cf4552ebeb867204552a26 | 2 +- ...nctions-0-45a7762c39f1b0f26f076220e2764043 | 21 + ...perties-1-be4adb893c7f946ebd76a648ce3cc1ae | 2 +- ...date_add-1-efb60fcbd6d78ad35257fb1ec39ace2 | 4 +- ...ate_sub-1-7efeb74367835ade71e5e42b22f8ced4 | 4 +- ...atediff-1-34ae7a68b13c2bc9a89f61acf2edd4c5 | 2 +- ...udf_day-0-c4c503756384ff1220222d84fd25e756 | 2 +- .../udf_day-1-87168babe1110fe4c38269843414ca4 | 11 +- ...ofmonth-0-7b2caf942528656555cf19c261a18502 | 2 +- ...ofmonth-1-ca24d07102ad264d79ff30c64a73a7e8 | 11 +- .../udf_if-0-b7ffa85b5785cccef2af1b285348cc2c | 2 +- .../udf_if-1-30cf7f51f92b5684e556deff3032d49a | 2 +- .../udf_if-1-b7ffa85b5785cccef2af1b285348cc2c | 2 +- .../udf_if-2-30cf7f51f92b5684e556deff3032d49a | 2 +- ..._minute-0-9a38997c1f41f4afe00faa0abc471aee | 2 +- ..._minute-1-16995573ac4f4a1b047ad6ee88699e48 | 8 +- ...f_month-0-9a38997c1f41f4afe00faa0abc471aee | 2 +- ...f_month-1-16995573ac4f4a1b047ad6ee88699e48 | 8 +- ...udf_std-1-6759bde0e50a3607b7c3fd5a93cbd027 | 2 +- ..._stddev-1-18e1d598820013453fad45852e1a303d | 2 +- ...union3-0-99620f72f0282904846a596ca5b3e46c} | 0 ...union3-2-90ca96ea59fd45cf0af8c020ae77c908} | 0 ...union3-3-72b149ccaef751bcfe55d5ca37cb5fd7} | 0 .../clientpositive/parenthesis_star_by.q | 2 +- .../src/test/queries/clientpositive/union3.q | 11 +- .../sql/hive/ClasspathDependenciesSuite.scala | 110 ++ .../spark/sql/hive/HiveSparkSubmitSuite.scala | 29 +- .../sql/hive/InsertIntoHiveTableSuite.scala | 7 +- .../hive/ParquetHiveCompatibilitySuite.scala | 9 + .../spark/sql/hive/StatisticsSuite.scala | 3 + .../spark/sql/hive/client/VersionsSuite.scala | 6 +- .../sql/hive/execution/HiveQuerySuite.scala | 89 +- .../sql/hive/execution/PruningSuite.scala | 8 +- .../sql/hive/execution/SQLQuerySuite.scala | 140 +- .../hive/orc/OrcHadoopFsRelationSuite.scala | 8 +- .../hive/orc/OrcPartitionDiscoverySuite.scala | 3 +- .../apache/spark/sql/hive/parquetSuites.scala | 327 ++--- yarn/pom.xml | 10 - .../spark/deploy/yarn/YarnClusterSuite.scala | 24 +- 79 files changed, 2861 insertions(+), 584 deletions(-) create mode 100644 sql/hive-thriftserver/src/main/scala/org/apache/hive/service/server/HiveServerServerOptionsProcessor.scala create mode 100644 sql/hive/src/test/java/org/apache/spark/sql/hive/test/Complex.java create mode 100644 sql/hive/src/test/resources/golden/! 
operator-0-ee7f6a60a9792041b85b18cda56429bf rename sql/hive/src/test/resources/golden/{parenthesis_star_by-5-6888c7f7894910538d82eefa23443189 => parenthesis_star_by-5-41d474f5e6d7c61c36f74b4bec4e9e44} (100%) rename sql/hive/src/test/resources/golden/{union3-0-6a8a35102de1b0b88c6721a704eb174d => union3-0-99620f72f0282904846a596ca5b3e46c} (100%) rename sql/hive/src/test/resources/golden/{union3-2-2a1dcd937f117f1955a169592b96d5f9 => union3-2-90ca96ea59fd45cf0af8c020ae77c908} (100%) rename sql/hive/src/test/resources/golden/{union3-3-8fc63f8edb2969a63cd4485f1867ba97 => union3-3-72b149ccaef751bcfe55d5ca37cb5fd7} (100%) create mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/hive/ClasspathDependenciesSuite.scala diff --git a/core/pom.xml b/core/pom.xml index 202678779150..0e53a79fd223 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -46,30 +46,10 @@ com.twitter chill_${scala.binary.version} - - - org.ow2.asm - asm - - - org.ow2.asm - asm-commons - - com.twitter chill-java - - - org.ow2.asm - asm - - - org.ow2.asm - asm-commons - - org.apache.hadoop diff --git a/dev/run-tests.py b/dev/run-tests.py index b6d181418f02..d1852b95bb29 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -273,6 +273,7 @@ def get_hadoop_profiles(hadoop_version): "hadoop2.0": ["-Phadoop-1", "-Dhadoop.version=2.0.0-mr1-cdh4.1.1"], "hadoop2.2": ["-Pyarn", "-Phadoop-2.2"], "hadoop2.3": ["-Pyarn", "-Phadoop-2.3", "-Dhadoop.version=2.3.0"], + "hadoop2.6": ["-Pyarn", "-Phadoop-2.6"], } if hadoop_version in sbt_maven_hadoop_profiles: @@ -289,7 +290,7 @@ def build_spark_maven(hadoop_version): mvn_goals = ["clean", "package", "-DskipTests"] profiles_and_goals = build_profiles + mvn_goals - print("[info] Building Spark (w/Hive 0.13.1) using Maven with these arguments: ", + print("[info] Building Spark (w/Hive 1.2.1) using Maven with these arguments: ", " ".join(profiles_and_goals)) exec_maven(profiles_and_goals) @@ -305,14 +306,14 @@ def build_spark_sbt(hadoop_version): "streaming-kinesis-asl-assembly/assembly"] profiles_and_goals = build_profiles + sbt_goals - print("[info] Building Spark (w/Hive 0.13.1) using SBT with these arguments: ", + print("[info] Building Spark (w/Hive 1.2.1) using SBT with these arguments: ", " ".join(profiles_and_goals)) exec_sbt(profiles_and_goals) def build_apache_spark(build_tool, hadoop_version): - """Will build Spark against Hive v0.13.1 given the passed in build tool (either `sbt` or + """Will build Spark against Hive v1.2.1 given the passed in build tool (either `sbt` or `maven`). 
Defaults to using `sbt`.""" set_title_and_block("Building Spark", "BLOCK_BUILD") diff --git a/pom.xml b/pom.xml index be0dac953abf..a958cec867ea 100644 --- a/pom.xml +++ b/pom.xml @@ -134,11 +134,12 @@ 2.4.0 org.spark-project.hive - 0.13.1a + 1.2.1.spark - 0.13.1 + 1.2.1 10.10.1.1 1.7.0 + 1.6.0 1.2.4 8.1.14.v20131031 3.0.0.v201112011016 @@ -151,7 +152,10 @@ 0.7.1 1.9.16 1.2.1 + 4.3.2 + + 3.1 3.4.1 2.10.4 2.10 @@ -161,6 +165,23 @@ 2.4.4 1.1.1.7 1.1.2 + 1.2.0-incubating + 1.10 + + 2.6 + + 3.3.2 + 3.2.10 + 2.7.8 + 1.9 + 2.5 + 3.5.2 + 1.3.9 + 0.9.2 + + + false + ${java.home} spring-releases Spring Release Repository https://repo.spring.io/libs-release - true + false false @@ -402,12 +431,17 @@ org.apache.commons commons-lang3 - 3.3.2 + ${commons-lang3.version} + + + org.apache.commons + commons-lang + ${commons-lang2.version} commons-codec commons-codec - 1.10 + ${commons-codec.version} org.apache.commons @@ -422,7 +456,12 @@ com.google.code.findbugs jsr305 - 1.3.9 + ${jsr305.version} + + + commons-httpclient + commons-httpclient + ${httpclient.classic.version} org.apache.httpcomponents @@ -439,6 +478,16 @@ selenium-java 2.42.2 test + + + com.google.guava + guava + + + io.netty + netty + + @@ -624,15 +673,26 @@ com.sun.jersey jersey-server - 1.9 + ${jersey.version} ${hadoop.deps.scope} com.sun.jersey jersey-core - 1.9 + ${jersey.version} ${hadoop.deps.scope} + + com.sun.jersey + jersey-json + ${jersey.version} + + + stax + stax-api + + + org.scala-lang scala-compiler @@ -1022,58 +1082,499 @@ hive-beeline ${hive.version} ${hive.deps.scope} + + + ${hive.group} + hive-common + + + ${hive.group} + hive-exec + + + ${hive.group} + hive-jdbc + + + ${hive.group} + hive-metastore + + + ${hive.group} + hive-service + + + ${hive.group} + hive-shims + + + org.apache.thrift + libthrift + + + org.slf4j + slf4j-api + + + org.slf4j + slf4j-log4j12 + + + log4j + log4j + + + commons-logging + commons-logging + + ${hive.group} hive-cli ${hive.version} ${hive.deps.scope} + + + ${hive.group} + hive-common + + + ${hive.group} + hive-exec + + + ${hive.group} + hive-jdbc + + + ${hive.group} + hive-metastore + + + ${hive.group} + hive-serde + + + ${hive.group} + hive-service + + + ${hive.group} + hive-shims + + + org.apache.thrift + libthrift + + + org.slf4j + slf4j-api + + + org.slf4j + slf4j-log4j12 + + + log4j + log4j + + + commons-logging + commons-logging + + ${hive.group} - hive-exec + hive-common ${hive.version} ${hive.deps.scope} + + ${hive.group} + hive-shims + + + org.apache.ant + ant + + + org.apache.zookeeper + zookeeper + + + org.slf4j + slf4j-api + + + org.slf4j + slf4j-log4j12 + + + log4j + log4j + commons-logging commons-logging + + + + + ${hive.group} + hive-exec + + ${hive.version} + ${hive.deps.scope} + + + + + ${hive.group} + hive-metastore + + + ${hive.group} + hive-shims + + + ${hive.group} + hive-ant + + + + ${hive.group} + spark-client + + + + + ant + ant + + + org.apache.ant + ant + com.esotericsoftware.kryo kryo + + commons-codec + commons-codec + + + commons-httpclient + commons-httpclient + org.apache.avro avro-mapred + + + org.apache.calcite + calcite-core + + + org.apache.curator + apache-curator + + + org.apache.curator + curator-client + + + org.apache.curator + curator-framework + + + org.apache.thrift + libthrift + + + org.apache.thrift + libfb303 + + + org.apache.zookeeper + zookeeper + + + org.slf4j + slf4j-api + + + org.slf4j + slf4j-log4j12 + + + log4j + log4j + + + commons-logging + commons-logging + ${hive.group} hive-jdbc ${hive.version} - ${hive.deps.scope} + + + ${hive.group} 
+ hive-common + + + ${hive.group} + hive-common + + + ${hive.group} + hive-metastore + + + ${hive.group} + hive-serde + + + ${hive.group} + hive-service + + + ${hive.group} + hive-shims + + + org.apache.httpcomponents + httpclient + + + org.apache.httpcomponents + httpcore + + + org.apache.curator + curator-framework + + + org.apache.thrift + libthrift + + + org.apache.thrift + libfb303 + + + org.apache.zookeeper + zookeeper + + + org.slf4j + slf4j-api + + + org.slf4j + slf4j-log4j12 + + + log4j + log4j + + + commons-logging + commons-logging + + + ${hive.group} hive-metastore ${hive.version} ${hive.deps.scope} + + + ${hive.group} + hive-serde + + + ${hive.group} + hive-shims + + + org.apache.thrift + libfb303 + + + org.apache.thrift + libthrift + + + com.google.guava + guava + + + org.slf4j + slf4j-api + + + org.slf4j + slf4j-log4j12 + + + ${hive.group} hive-serde ${hive.version} ${hive.deps.scope} + + ${hive.group} + hive-common + + + ${hive.group} + hive-shims + + + commons-codec + commons-codec + + + com.google.code.findbugs + jsr305 + + + org.apache.avro + avro + + + org.apache.thrift + libthrift + + + org.apache.thrift + libfb303 + + + org.slf4j + slf4j-api + + + org.slf4j + slf4j-log4j12 + + + log4j + log4j + commons-logging commons-logging + + + + + ${hive.group} + hive-service + ${hive.version} + ${hive.deps.scope} + + + ${hive.group} + hive-common + + + ${hive.group} + hive-exec + + + ${hive.group} + hive-metastore + + + ${hive.group} + hive-shims + + + commons-codec + commons-codec + + + org.apache.curator + curator-framework + + + org.apache.curator + curator-recipes + + + org.apache.thrift + libfb303 + + + org.apache.thrift + libthrift + + + + + + + ${hive.group} + hive-shims + ${hive.version} + ${hive.deps.scope} + + + com.google.guava + guava + + + org.apache.hadoop + hadoop-yarn-server-resourcemanager + + + org.apache.curator + curator-framework + + + org.apache.thrift + libthrift + + + org.apache.zookeeper + zookeeper + + + org.slf4j + slf4j-api + + + org.slf4j + slf4j-log4j12 + + + log4j + log4j + commons-logging - commons-logging-api + commons-logging @@ -1095,6 +1596,12 @@ ${parquet.version} ${parquet.test.deps.scope} + + com.twitter + parquet-hadoop-bundle + ${hive.parquet.version} + runtime + org.apache.flume flume-ng-core @@ -1135,6 +1642,125 @@ + + org.apache.calcite + calcite-core + ${calcite.version} + + + com.fasterxml.jackson.core + jackson-annotations + + + com.fasterxml.jackson.core + jackson-core + + + com.fasterxml.jackson.core + jackson-databind + + + com.google.guava + guava + + + com.google.code.findbugs + jsr305 + + + org.codehaus.janino + janino + + + + org.hsqldb + hsqldb + + + org.pentaho + pentaho-aggdesigner-algorithm + + + + + org.apache.calcite + calcite-avatica + ${calcite.version} + + + com.fasterxml.jackson.core + jackson-annotations + + + com.fasterxml.jackson.core + jackson-core + + + com.fasterxml.jackson.core + jackson-databind + + + + + org.codehaus.janino + janino + ${janino.version} + + + joda-time + joda-time + ${joda.version} + + + org.jodd + jodd-core + ${jodd.version} + + + org.datanucleus + datanucleus-core + ${datanucleus-core.version} + + + org.apache.thrift + libthrift + ${libthrift.version} + + + org.apache.httpcomponents + httpclient + + + org.apache.httpcomponents + httpcore + + + org.slf4j + slf4j-api + + + + + org.apache.thrift + libfb303 + ${libthrift.version} + + + org.apache.httpcomponents + httpclient + + + org.apache.httpcomponents + httpcore + + + org.slf4j + slf4j-api + + + @@ -1271,6 +1897,8 @@ false true true + + 
src false @@ -1305,6 +1933,8 @@ false true true + + __not_used__ diff --git a/sbin/spark-daemon.sh b/sbin/spark-daemon.sh index de762acc8fa0..0fbe795822fb 100755 --- a/sbin/spark-daemon.sh +++ b/sbin/spark-daemon.sh @@ -29,7 +29,7 @@ # SPARK_NICENESS The scheduling priority for daemons. Defaults to 0. ## -usage="Usage: spark-daemon.sh [--config ] (start|stop|status) " +usage="Usage: spark-daemon.sh [--config ] (start|stop|submit|status) " # if no args specified, show usage if [ $# -le 1 ]; then diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml index f4b1cc3a4ffe..75ab575dfde8 100644 --- a/sql/catalyst/pom.xml +++ b/sql/catalyst/pom.xml @@ -66,7 +66,6 @@ org.codehaus.janino janino - 2.7.8 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetCompatibilityTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetCompatibilityTest.scala index b4cdfd9e98f6..57478931cd50 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetCompatibilityTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetCompatibilityTest.scala @@ -31,6 +31,14 @@ import org.apache.spark.util.Utils abstract class ParquetCompatibilityTest extends QueryTest with ParquetTest with BeforeAndAfterAll { protected var parquetStore: File = _ + /** + * Optional path to a staging subdirectory which may be created during query processing + * (Hive does this). + * Parquet files under this directory will be ignored in [[readParquetSchema()]] + * @return an optional staging directory to ignore when scanning for parquet files. + */ + protected def stagingDir: Option[String] = None + override protected def beforeAll(): Unit = { parquetStore = Utils.createTempDir(namePrefix = "parquet-compat_") parquetStore.delete() @@ -43,7 +51,10 @@ abstract class ParquetCompatibilityTest extends QueryTest with ParquetTest with def readParquetSchema(path: String): MessageType = { val fsPath = new Path(path) val fs = fsPath.getFileSystem(configuration) - val parquetFiles = fs.listStatus(fsPath).toSeq.filterNot(_.getPath.getName.startsWith("_")) + val parquetFiles = fs.listStatus(fsPath).toSeq.filterNot { status => + status.getPath.getName.startsWith("_") || + stagingDir.map(status.getPath.getName.startsWith).getOrElse(false) + } val footers = ParquetFileReader.readAllFootersInParallel(configuration, parquetFiles, true) footers.head.getParquetMetadata.getFileMetaData.getSchema } diff --git a/sql/hive-thriftserver/pom.xml b/sql/hive-thriftserver/pom.xml index 73e6ccdb1eaf..2dfbcb2425a3 100644 --- a/sql/hive-thriftserver/pom.xml +++ b/sql/hive-thriftserver/pom.xml @@ -60,21 +60,31 @@ ${hive.group} hive-jdbc + + ${hive.group} + hive-service + ${hive.group} hive-beeline + + com.sun.jersey + jersey-core + + + com.sun.jersey + jersey-json + + + com.sun.jersey + jersey-server + org.seleniumhq.selenium selenium-java test - - - io.netty - netty - - diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/hive/service/server/HiveServerServerOptionsProcessor.scala b/sql/hive-thriftserver/src/main/scala/org/apache/hive/service/server/HiveServerServerOptionsProcessor.scala new file mode 100644 index 000000000000..2228f651e238 --- /dev/null +++ b/sql/hive-thriftserver/src/main/scala/org/apache/hive/service/server/HiveServerServerOptionsProcessor.scala @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
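Looking back at the readParquetSchema change earlier in this patch, the Option plumbing in the new filter is easy to misread; the same predicate in isolation, with purely illustrative file names:

    // True for files readParquetSchema should skip: "_"-prefixed metadata
    // files, plus anything under the optional Hive staging-directory prefix.
    def shouldSkip(name: String, stagingDir: Option[String]): Boolean =
      name.startsWith("_") || stagingDir.map(name.startsWith).getOrElse(false)

    shouldSkip("_metadata", None)                                // true
    shouldSkip("part-r-00000.gz.parquet", None)                  // false
    shouldSkip(".hive-staging_hive_2015", Some(".hive-staging")) // true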
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hive.service.server + +import org.apache.hive.service.server.HiveServer2.{StartOptionExecutor, ServerOptionsProcessor} + +/** + * Class to upgrade a package-private class to public, and + * implement a `process()` operation consistent with + * the behavior of older Hive versions + * @param serverName name of the hive server + */ +private[apache] class HiveServerServerOptionsProcessor(serverName: String) + extends ServerOptionsProcessor(serverName) { + + def process(args: Array[String]): Boolean = { + // A parse failure automatically triggers a system exit + val response = super.parse(args) + val executor = response.getServerOptionsExecutor() + // return true if the parsed option was to start the service + executor.isInstanceOf[StartOptionExecutor] + } +} diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala index b7db80d93f85..9c047347cb58 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala @@ -17,6 +17,9 @@ package org.apache.spark.sql.hive.thriftserver +import java.util.Locale +import java.util.concurrent.atomic.AtomicBoolean + import scala.collection.mutable import scala.collection.mutable.ArrayBuffer @@ -24,7 +27,7 @@ import org.apache.commons.logging.LogFactory import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.hive.service.cli.thrift.{ThriftBinaryCLIService, ThriftHttpCLIService} -import org.apache.hive.service.server.{HiveServer2, ServerOptionsProcessor} +import org.apache.hive.service.server.{HiveServerServerOptionsProcessor, HiveServer2} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd, SparkListenerJobStart} @@ -65,7 +68,7 @@ object HiveThriftServer2 extends Logging { } def main(args: Array[String]) { - val optionsProcessor = new ServerOptionsProcessor("HiveThriftServer2") + val optionsProcessor = new HiveServerServerOptionsProcessor("HiveThriftServer2") if (!optionsProcessor.process(args)) { System.exit(-1) } @@ -241,9 +244,12 @@ object HiveThriftServer2 extends Logging { private[hive] class HiveThriftServer2(hiveContext: HiveContext) extends HiveServer2 with ReflectedCompositeService { + // state is tracked internally so that the server only attempts to shut down if it successfully + // started, and then once only. 
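The guarded start/stop described in the comment above is a small, reusable idiom; a minimal standalone sketch (the class name is invented):

    import java.util.concurrent.atomic.AtomicBoolean

    class GuardedService {
      private val started = new AtomicBoolean(false)

      def start(): Unit = {
        // ... real startup work goes here ...
        started.set(true)  // only flag success once startup has finished
      }

      def stop(): Unit = {
        // getAndSet(false) makes shutdown run at most once, and only if
        // start() completed, even when stop() is called from several threads.
        if (started.getAndSet(false)) {
          // ... real shutdown work goes here ...
        }
      }
    }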
+ private val started = new AtomicBoolean(false) override def init(hiveConf: HiveConf) { - val sparkSqlCliService = new SparkSQLCLIService(hiveContext) + val sparkSqlCliService = new SparkSQLCLIService(this, hiveContext) setSuperField(this, "cliService", sparkSqlCliService) addService(sparkSqlCliService) @@ -259,8 +265,19 @@ private[hive] class HiveThriftServer2(hiveContext: HiveContext) } private def isHTTPTransportMode(hiveConf: HiveConf): Boolean = { - val transportMode: String = hiveConf.getVar(ConfVars.HIVE_SERVER2_TRANSPORT_MODE) - transportMode.equalsIgnoreCase("http") + val transportMode = hiveConf.getVar(ConfVars.HIVE_SERVER2_TRANSPORT_MODE) + transportMode.toLowerCase(Locale.ENGLISH).equals("http") + } + + + override def start(): Unit = { + super.start() + started.set(true) } + override def stop(): Unit = { + if (started.getAndSet(false)) { + super.stop() + } + } } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala index e8758887ff3a..833bf62d47d0 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala @@ -32,8 +32,7 @@ import org.apache.hive.service.cli._ import org.apache.hadoop.hive.ql.metadata.Hive import org.apache.hadoop.hive.ql.metadata.HiveException import org.apache.hadoop.hive.ql.session.SessionState -import org.apache.hadoop.hive.shims.ShimLoader -import org.apache.hadoop.security.UserGroupInformation +import org.apache.hadoop.hive.shims.Utils import org.apache.hive.service.cli.operation.ExecuteStatementOperation import org.apache.hive.service.cli.session.HiveSession @@ -146,7 +145,7 @@ private[hive] class SparkExecuteStatementOperation( } else { val parentSessionState = SessionState.get() val hiveConf = getConfigForOperation() - val sparkServiceUGI = ShimLoader.getHadoopShims.getUGIForConf(hiveConf) + val sparkServiceUGI = Utils.getUGI() val sessionHive = getCurrentHive() val currentSqlSession = hiveContext.currentSession @@ -174,7 +173,7 @@ private[hive] class SparkExecuteStatementOperation( } try { - ShimLoader.getHadoopShims().doAs(sparkServiceUGI, doAsAction) + sparkServiceUGI.doAs(doAsAction) } catch { case e: Exception => setOperationException(new HiveSQLException(e)) @@ -201,7 +200,7 @@ private[hive] class SparkExecuteStatementOperation( } } - private def runInternal(): Unit = { + override def runInternal(): Unit = { statementId = UUID.randomUUID().toString logInfo(s"Running query '$statement' with $statementId") setState(OperationState.RUNNING) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala index f66a17b20915..d3886142b388 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala @@ -20,9 +20,10 @@ package org.apache.spark.sql.hive.thriftserver import scala.collection.JavaConversions._ import java.io._ -import java.util.{ArrayList => JArrayList} +import java.util.{ArrayList => JArrayList, Locale} -import jline.{ConsoleReader, History} +import 
jline.console.ConsoleReader +import jline.console.history.FileHistory import org.apache.commons.lang3.StringUtils import org.apache.commons.logging.LogFactory @@ -40,6 +41,10 @@ import org.apache.spark.Logging import org.apache.spark.sql.hive.HiveContext import org.apache.spark.util.Utils +/** + * This code doesn't support remote connections in Hive 1.2+, as the underlying CliDriver + * has dropped its support. + */ private[hive] object SparkSQLCLIDriver extends Logging { private var prompt = "spark-sql" private var continuedPrompt = "".padTo(prompt.length, ' ') @@ -111,16 +116,9 @@ private[hive] object SparkSQLCLIDriver extends Logging { // Clean up after we exit Utils.addShutdownHook { () => SparkSQLEnv.stop() } + val remoteMode = isRemoteMode(sessionState) // "-h" option has been passed, so connect to Hive thrift server. - if (sessionState.getHost != null) { - sessionState.connect() - if (sessionState.isRemoteMode) { - prompt = s"[${sessionState.getHost}:${sessionState.getPort}]" + prompt - continuedPrompt = "".padTo(prompt.length, ' ') - } - } - - if (!sessionState.isRemoteMode) { + if (!remoteMode) { // Hadoop-20 and above - we need to augment classpath using hiveconf // components. // See also: code in ExecDriver.java @@ -131,6 +129,9 @@ private[hive] object SparkSQLCLIDriver extends Logging { } conf.setClassLoader(loader) Thread.currentThread().setContextClassLoader(loader) + } else { + // Hive 1.2 + not supported in CLI + throw new RuntimeException("Remote operations not supported") } val cli = new SparkSQLCLIDriver @@ -171,14 +172,14 @@ private[hive] object SparkSQLCLIDriver extends Logging { val reader = new ConsoleReader() reader.setBellEnabled(false) // reader.setDebug(new PrintWriter(new FileWriter("writer.debug", true))) - CliDriver.getCommandCompletor.foreach((e) => reader.addCompletor(e)) + CliDriver.getCommandCompleter.foreach((e) => reader.addCompleter(e)) val historyDirectory = System.getProperty("user.home") try { if (new File(historyDirectory).exists()) { val historyFile = historyDirectory + File.separator + ".hivehistory" - reader.setHistory(new History(new File(historyFile))) + reader.setHistory(new FileHistory(new File(historyFile))) } else { logWarning("WARNING: Directory for Hive history file: " + historyDirectory + " does not exist. History will not be available during this session.") @@ -190,10 +191,14 @@ private[hive] object SparkSQLCLIDriver extends Logging { logWarning(e.getMessage) } + // TODO: missing +/* val clientTransportTSocketField = classOf[CliSessionState].getDeclaredField("transport") clientTransportTSocketField.setAccessible(true) transport = clientTransportTSocketField.get(sessionState).asInstanceOf[TSocket] +*/ + transport = null var ret = 0 var prefix = "" @@ -230,6 +235,13 @@ private[hive] object SparkSQLCLIDriver extends Logging { System.exit(ret) } + + + def isRemoteMode(state: CliSessionState): Boolean = { + // sessionState.isRemoteMode + state.isHiveServerQuery + } + } private[hive] class SparkSQLCLIDriver extends CliDriver with Logging { @@ -239,25 +251,33 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging { private val console = new SessionState.LogHelper(LOG) + private val isRemoteMode = { + SparkSQLCLIDriver.isRemoteMode(sessionState) + } + private val conf: Configuration = if (sessionState != null) sessionState.getConf else new Configuration() // Force initializing SparkSQLEnv. This is put here but not object SparkSQLCliDriver // because the Hive unit tests do not go through the main() code path. 
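The replacement of ShimLoader calls with direct Hadoop security calls in SparkExecuteStatementOperation above boils down to one pattern; a minimal sketch of running work as a given UGI (the helper name is invented):

    import java.security.PrivilegedExceptionAction
    import org.apache.hadoop.security.UserGroupInformation

    // Runs body as ugi, the way the operation now calls
    // sparkServiceUGI.doAs(...) directly instead of going through Hive shims.
    def runAs[T](ugi: UserGroupInformation)(body: => T): T =
      ugi.doAs(new PrivilegedExceptionAction[T] {
        override def run(): T = body
      })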
- if (!sessionState.isRemoteMode) { + if (!isRemoteMode) { SparkSQLEnv.init() + } else { + // Hive 1.2 + not supported in CLI + throw new RuntimeException("Remote operations not supported") } override def processCmd(cmd: String): Int = { val cmd_trimmed: String = cmd.trim() + val cmd_lower = cmd_trimmed.toLowerCase(Locale.ENGLISH) val tokens: Array[String] = cmd_trimmed.split("\\s+") val cmd_1: String = cmd_trimmed.substring(tokens(0).length()).trim() - if (cmd_trimmed.toLowerCase.equals("quit") || - cmd_trimmed.toLowerCase.equals("exit") || - tokens(0).equalsIgnoreCase("source") || + if (cmd_lower.equals("quit") || + cmd_lower.equals("exit") || + tokens(0).toLowerCase(Locale.ENGLISH).equals("source") || cmd_trimmed.startsWith("!") || tokens(0).toLowerCase.equals("list") || - sessionState.isRemoteMode) { + isRemoteMode) { val start = System.currentTimeMillis() super.processCmd(cmd) val end = System.currentTimeMillis() diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala index 41f647d5f8c5..644165acf70a 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala @@ -23,11 +23,12 @@ import javax.security.auth.login.LoginException import org.apache.commons.logging.Log import org.apache.hadoop.hive.conf.HiveConf -import org.apache.hadoop.hive.shims.ShimLoader +import org.apache.hadoop.hive.shims.Utils import org.apache.hadoop.security.UserGroupInformation import org.apache.hive.service.Service.STATE import org.apache.hive.service.auth.HiveAuthFactory import org.apache.hive.service.cli._ +import org.apache.hive.service.server.HiveServer2 import org.apache.hive.service.{AbstractService, Service, ServiceException} import org.apache.spark.sql.hive.HiveContext @@ -35,22 +36,22 @@ import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._ import scala.collection.JavaConversions._ -private[hive] class SparkSQLCLIService(hiveContext: HiveContext) - extends CLIService +private[hive] class SparkSQLCLIService(hiveServer: HiveServer2, hiveContext: HiveContext) + extends CLIService(hiveServer) with ReflectedCompositeService { override def init(hiveConf: HiveConf) { setSuperField(this, "hiveConf", hiveConf) - val sparkSqlSessionManager = new SparkSQLSessionManager(hiveContext) + val sparkSqlSessionManager = new SparkSQLSessionManager(hiveServer, hiveContext) setSuperField(this, "sessionManager", sparkSqlSessionManager) addService(sparkSqlSessionManager) var sparkServiceUGI: UserGroupInformation = null - if (ShimLoader.getHadoopShims.isSecurityEnabled) { + if (UserGroupInformation.isSecurityEnabled) { try { HiveAuthFactory.loginFromKeytab(hiveConf) - sparkServiceUGI = ShimLoader.getHadoopShims.getUGIForConf(hiveConf) + sparkServiceUGI = Utils.getUGI() setSuperField(this, "serviceUGI", sparkServiceUGI) } catch { case e @ (_: IOException | _: LoginException) => diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala index 2d5ee6800228..92ac0ec3fca2 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala +++ 
b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala @@ -25,14 +25,15 @@ import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.hive.service.cli.SessionHandle import org.apache.hive.service.cli.session.SessionManager import org.apache.hive.service.cli.thrift.TProtocolVersion +import org.apache.hive.service.server.HiveServer2 import org.apache.spark.sql.hive.HiveContext import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._ import org.apache.spark.sql.hive.thriftserver.server.SparkSQLOperationManager -private[hive] class SparkSQLSessionManager(hiveContext: HiveContext) - extends SessionManager +private[hive] class SparkSQLSessionManager(hiveServer: HiveServer2, hiveContext: HiveContext) + extends SessionManager(hiveServer) with ReflectedCompositeService { private lazy val sparkSqlOperationManager = new SparkSQLOperationManager(hiveContext) @@ -55,12 +56,14 @@ private[hive] class SparkSQLSessionManager(hiveContext: HiveContext) protocol: TProtocolVersion, username: String, passwd: String, + ipAddress: String, sessionConf: java.util.Map[String, String], withImpersonation: Boolean, delegationToken: String): SessionHandle = { hiveContext.openSession() - val sessionHandle = super.openSession( - protocol, username, passwd, sessionConf, withImpersonation, delegationToken) + val sessionHandle = + super.openSession(protocol, username, passwd, ipAddress, sessionConf, withImpersonation, + delegationToken) val session = super.getSession(sessionHandle) HiveThriftServer2.listener.onSessionCreated( session.getIpAddress, sessionHandle.getSessionId.toString, session.getUsername) diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala index df80d04b4080..121b3e077f71 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala @@ -23,6 +23,7 @@ import scala.collection.mutable.ArrayBuffer import scala.concurrent.duration._ import scala.concurrent.{Await, Promise} import scala.sys.process.{Process, ProcessLogger} +import scala.util.Failure import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.scalatest.BeforeAndAfter @@ -37,31 +38,46 @@ import org.apache.spark.util.Utils class CliSuite extends SparkFunSuite with BeforeAndAfter with Logging { val warehousePath = Utils.createTempDir() val metastorePath = Utils.createTempDir() + val scratchDirPath = Utils.createTempDir() before { - warehousePath.delete() - metastorePath.delete() + warehousePath.delete() + metastorePath.delete() + scratchDirPath.delete() } after { - warehousePath.delete() - metastorePath.delete() + warehousePath.delete() + metastorePath.delete() + scratchDirPath.delete() } + /** + * Run a CLI operation and expect all the queries and expected answers to be returned. + * @param timeout maximum time for the commands to complete + * @param extraArgs any extra arguments + * @param errorResponses a sequence of strings whose presence in the stdout of the forked process + * is taken as an immediate error condition. That is: if a line beginning + * with one of these strings is found, fail the test immediately. 
+ * The default value is `Seq("Error:")` + * + * @param queriesAndExpectedAnswers one or more tupes of query + answer + */ def runCliWithin( timeout: FiniteDuration, - extraArgs: Seq[String] = Seq.empty)( + extraArgs: Seq[String] = Seq.empty, + errorResponses: Seq[String] = Seq("Error:"))( queriesAndExpectedAnswers: (String, String)*): Unit = { val (queries, expectedAnswers) = queriesAndExpectedAnswers.unzip - val cliScript = "../../bin/spark-sql".split("/").mkString(File.separator) - val command = { + val cliScript = "../../bin/spark-sql".split("/").mkString(File.separator) val jdbcUrl = s"jdbc:derby:;databaseName=$metastorePath;create=true" s"""$cliScript | --master local | --hiveconf ${ConfVars.METASTORECONNECTURLKEY}=$jdbcUrl | --hiveconf ${ConfVars.METASTOREWAREHOUSE}=$warehousePath + | --hiveconf ${ConfVars.SCRATCHDIR}=$scratchDirPath """.stripMargin.split("\\s+").toSeq ++ extraArgs } @@ -81,6 +97,12 @@ class CliSuite extends SparkFunSuite with BeforeAndAfter with Logging { if (next == expectedAnswers.size) { foundAllExpectedAnswers.trySuccess(()) } + } else { + errorResponses.foreach( r => { + if (line.startsWith(r)) { + foundAllExpectedAnswers.tryFailure( + new RuntimeException(s"Failed with error line '$line'")) + }}) } } @@ -88,16 +110,44 @@ class CliSuite extends SparkFunSuite with BeforeAndAfter with Logging { val process = (Process(command, None) #< queryStream).run( ProcessLogger(captureOutput("stdout"), captureOutput("stderr"))) + // catch the output value + class exitCodeCatcher extends Runnable { + var exitValue = 0 + + override def run(): Unit = { + try { + exitValue = process.exitValue() + } catch { + case rte: RuntimeException => + // ignored as it will get triggered when the process gets destroyed + logDebug("Ignoring exception while waiting for exit code", rte) + } + if (exitValue != 0) { + // process exited: fail fast + foundAllExpectedAnswers.tryFailure( + new RuntimeException(s"Failed with exit code $exitValue")) + } + } + } + // spin off the code catche thread. No attempt is made to kill this + // as it will exit once the launched process terminates. + val codeCatcherThread = new Thread(new exitCodeCatcher()) + codeCatcherThread.start() + try { - Await.result(foundAllExpectedAnswers.future, timeout) + Await.ready(foundAllExpectedAnswers.future, timeout) + foundAllExpectedAnswers.future.value match { + case Some(Failure(t)) => throw t + case _ => + } } catch { case cause: Throwable => - logError( + val message = s""" |======================= |CliSuite failure output |======================= |Spark SQL CLI command line: ${command.mkString(" ")} - | + |Exception: $cause |Executed query $next "${queries(next)}", |But failed to capture expected output "${expectedAnswers(next)}" within $timeout. 
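The switch from Await.result to Await.ready plus an inspection of future.value in CliSuite above is what lets the watcher thread fail the test; the same pattern in isolation, with a trivial stand-in for the watcher:

    import scala.concurrent.{Await, Promise}
    import scala.concurrent.duration._
    import scala.util.Failure

    val done = Promise[Unit]()

    // Stand-in for the exit-code watcher: another thread completes the
    // promise, either with a success or (on error) via tryFailure.
    new Thread(new Runnable {
      override def run(): Unit = done.trySuccess(())
    }).start()

    // Await.ready waits but, unlike Await.result, does not rethrow the
    // future's failure; inspecting the value lets the caller decide how to
    // surface an error injected from the other thread.
    Await.ready(done.future, 1.minute)
    done.future.value match {
      case Some(Failure(t)) => throw t
      case _ =>  // completed successfully
    }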
| @@ -105,8 +155,9 @@ class CliSuite extends SparkFunSuite with BeforeAndAfter with Logging { |=========================== |End CliSuite failure output |=========================== - """.stripMargin, cause) - throw cause + """.stripMargin + logError(message, cause) + fail(message, cause) } finally { process.destroy() } diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala index 39b31523e07c..8374629b5d45 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.hive.thriftserver import java.io.File import java.net.URL -import java.nio.charset.StandardCharsets import java.sql.{Date, DriverManager, SQLException, Statement} import scala.collection.mutable.ArrayBuffer @@ -492,7 +491,7 @@ abstract class HiveThriftServer2Test extends SparkFunSuite with BeforeAndAfterAl new File(s"$tempLog4jConf/log4j.properties"), UTF_8) - tempLog4jConf + File.pathSeparator + sys.props("java.class.path") + tempLog4jConf // + File.pathSeparator + sys.props("java.class.path") } s"""$startScript @@ -508,6 +507,20 @@ abstract class HiveThriftServer2Test extends SparkFunSuite with BeforeAndAfterAl """.stripMargin.split("\\s+").toSeq } + /** + * String to scan for when looking for the thrift binary endpoint running. + * This can change across Hive versions. + */ + val THRIFT_BINARY_SERVICE_LIVE = "Starting ThriftBinaryCLIService on port" + + /** + * String to scan for when looking for the thrift HTTP endpoint running. + * This can change across Hive versions. + */ + val THRIFT_HTTP_SERVICE_LIVE = "Started ThriftHttpCLIService in http" + + val SERVER_STARTUP_TIMEOUT = 1.minute + private def startThriftServer(port: Int, attempt: Int) = { warehousePath = Utils.createTempDir() warehousePath.delete() @@ -545,23 +558,26 @@ abstract class HiveThriftServer2Test extends SparkFunSuite with BeforeAndAfterAl // Ensures that the following "tail" command won't fail. logPath.createNewFile() + val successLines = Seq(THRIFT_BINARY_SERVICE_LIVE, THRIFT_HTTP_SERVICE_LIVE) + val failureLines = Seq("HiveServer2 is stopped", "Exception in thread", "Error:") logTailingProcess = // Using "-n +0" to make sure all lines in the log file are checked. Process(s"/usr/bin/env tail -n +0 -f ${logPath.getCanonicalPath}").run(ProcessLogger( (line: String) => { diagnosisBuffer += line - - if (line.contains("ThriftBinaryCLIService listening on") || - line.contains("Started ThriftHttpCLIService in http")) { - serverStarted.trySuccess(()) - } else if (line.contains("HiveServer2 is stopped")) { - // This log line appears when the server fails to start and terminates gracefully (e.g. - // because of port contention).
- serverStarted.tryFailure(new RuntimeException("Failed to start HiveThriftServer2")) - } + successLines.foreach(r => { + if (line.contains(r)) { + serverStarted.trySuccess(()) + } + }) + failureLines.foreach(r => { + if (line.contains(r)) { + serverStarted.tryFailure(new RuntimeException(s"Failed with output '$line'")) + } + }) })) - Await.result(serverStarted.future, 2.minute) + Await.result(serverStarted.future, SERVER_STARTUP_TIMEOUT) } private def stopThriftServer(): Unit = { diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index 53d5b22b527b..c46a4a4b0be5 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -267,7 +267,34 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "date_udf", // Unlike Hive, we do support log base in (0, 1.0], therefore disable this - "udf7" + "udf7", + + // Trivial changes to DDL output + "compute_stats_empty_table", + "compute_stats_long", + "create_view_translate", + "show_create_table_serde", + "show_tblproperties", + + // Odd changes to output + "merge4", + + // Thrift is broken... + "inputddl8", + + // Hive changed ordering of ddl: + "varchar_union1", + + // Parser changes in Hive 1.2 + "input25", + "input26", + + // Uses invalid table name + "innerjoin", + + // classpath problems + "compute_stats.*", + "udf_bitmap_.*" ) /** diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml index b00f320318be..be1607476e25 100644 --- a/sql/hive/pom.xml +++ b/sql/hive/pom.xml @@ -36,6 +36,11 @@ + + + com.twitter + parquet-hadoop-bundle + org.apache.spark spark-core_${scala.binary.version} @@ -53,32 +58,42 @@ spark-sql_${scala.binary.version} ${project.version} + - org.codehaus.jackson - jackson-mapper-asl + ${hive.group} + hive-exec + ${hive.group} - hive-serde + hive-metastore + org.apache.avro @@ -91,6 +106,55 @@ avro-mapred ${avro.mapred.classifier} + + commons-httpclient + commons-httpclient + + + org.apache.calcite + calcite-avatica + + + org.apache.calcite + calcite-core + + + org.apache.httpcomponents + httpclient + + + org.codehaus.jackson + jackson-mapper-asl + + + + commons-codec + commons-codec + + + joda-time + joda-time + + + org.jodd + jodd-core + + + com.google.code.findbugs + jsr305 + + + org.datanucleus + datanucleus-core + + + org.apache.thrift + libthrift + + + org.apache.thrift + libfb303 + org.scalacheck scalacheck_${scala.binary.version} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 110f51a30586..567d7fa12ff1 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -20,15 +20,18 @@ package org.apache.spark.sql.hive import java.io.File import java.net.{URL, URLClassLoader} import java.sql.Timestamp +import java.util.concurrent.TimeUnit import scala.collection.JavaConversions._ import scala.collection.mutable.HashMap import scala.language.implicitConversions +import scala.concurrent.duration._ import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.hive.common.StatsSetupConst import org.apache.hadoop.hive.common.`type`.HiveDecimal import
org.apache.hadoop.hive.conf.HiveConf +import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.hadoop.hive.ql.metadata.Table import org.apache.hadoop.hive.ql.parse.VariableSubstitution import org.apache.hadoop.hive.ql.session.SessionState @@ -164,6 +167,16 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging { } SessionState.setCurrentSessionState(executionHive.state) + /** + * Overrides default Hive configurations to avoid breaking changes to Spark SQL users. + * - allow SQL11 keywords to be used as identifiers + */ + private[sql] def defaultOverides() = { + setConf(ConfVars.HIVE_SUPPORT_SQL11_RESERVED_KEYWORDS.varname, "false") + } + + defaultOverides() + /** * The copy of the Hive client that is used to retrieve metadata from the Hive MetaStore. * The version of the Hive client that is used here must match the metastore that is configured @@ -252,6 +265,10 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging { } protected[sql] override def parseSql(sql: String): LogicalPlan = { + var state = SessionState.get() + if (state == null) { + SessionState.setCurrentSessionState(tlSession.get().asInstanceOf[SQLSession].sessionState) + } super.parseSql(substitutor.substitute(hiveconf, sql)) } @@ -298,10 +315,21 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging { // Can we use fs.getContentSummary in future? // Seems fs.getContentSummary returns wrong table size on Jenkins. So we use // countFileSize to count the table size. + val stagingDir = metadataHive.getConf(HiveConf.ConfVars.STAGINGDIR.varname, + HiveConf.ConfVars.STAGINGDIR.defaultStrVal) + def calculateTableSize(fs: FileSystem, path: Path): Long = { val fileStatus = fs.getFileStatus(path) val size = if (fileStatus.isDir) { - fs.listStatus(path).map(status => calculateTableSize(fs, status.getPath)).sum + fs.listStatus(path) + .map { status => + if (!status.getPath().getName().startsWith(stagingDir)) { + calculateTableSize(fs, status.getPath) + } else { + 0L + } + } + .sum } else { fileStatus.getLen } @@ -398,7 +426,58 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging { } /** Overridden by child classes that need to set configuration before the client init. */ - protected def configure(): Map[String, String] = Map.empty + protected def configure(): Map[String, String] = { + // Hive 0.14.0 introduces timeout operations in HiveConf, and changes default values of a bunch + // of time `ConfVar`s by adding time suffixes (`s`, `ms`, and `d` etc.). This breaks backwards- + // compatibility when users are trying to connect to a Hive metastore of a lower version, + // because these options are expected to be integral values in lower versions of Hive. + // + // Here we enumerate all time `ConfVar`s and convert their values to numeric strings according + // to their output time units.
+ Seq( + ConfVars.METASTORE_CLIENT_CONNECT_RETRY_DELAY -> TimeUnit.SECONDS, + ConfVars.METASTORE_CLIENT_SOCKET_TIMEOUT -> TimeUnit.SECONDS, + ConfVars.METASTORE_CLIENT_SOCKET_LIFETIME -> TimeUnit.SECONDS, + ConfVars.HMSHANDLERINTERVAL -> TimeUnit.MILLISECONDS, + ConfVars.METASTORE_EVENT_DB_LISTENER_TTL -> TimeUnit.SECONDS, + ConfVars.METASTORE_EVENT_CLEAN_FREQ -> TimeUnit.SECONDS, + ConfVars.METASTORE_EVENT_EXPIRY_DURATION -> TimeUnit.SECONDS, + ConfVars.METASTORE_AGGREGATE_STATS_CACHE_TTL -> TimeUnit.SECONDS, + ConfVars.METASTORE_AGGREGATE_STATS_CACHE_MAX_WRITER_WAIT -> TimeUnit.MILLISECONDS, + ConfVars.METASTORE_AGGREGATE_STATS_CACHE_MAX_READER_WAIT -> TimeUnit.MILLISECONDS, + ConfVars.HIVES_AUTO_PROGRESS_TIMEOUT -> TimeUnit.SECONDS, + ConfVars.HIVE_LOG_INCREMENTAL_PLAN_PROGRESS_INTERVAL -> TimeUnit.MILLISECONDS, + ConfVars.HIVE_STATS_JDBC_TIMEOUT -> TimeUnit.SECONDS, + ConfVars.HIVE_STATS_RETRIES_WAIT -> TimeUnit.MILLISECONDS, + ConfVars.HIVE_LOCK_SLEEP_BETWEEN_RETRIES -> TimeUnit.SECONDS, + ConfVars.HIVE_ZOOKEEPER_SESSION_TIMEOUT -> TimeUnit.MILLISECONDS, + ConfVars.HIVE_ZOOKEEPER_CONNECTION_BASESLEEPTIME -> TimeUnit.MILLISECONDS, + ConfVars.HIVE_TXN_TIMEOUT -> TimeUnit.SECONDS, + ConfVars.HIVE_COMPACTOR_WORKER_TIMEOUT -> TimeUnit.SECONDS, + ConfVars.HIVE_COMPACTOR_CHECK_INTERVAL -> TimeUnit.SECONDS, + ConfVars.HIVE_COMPACTOR_CLEANER_RUN_INTERVAL -> TimeUnit.MILLISECONDS, + ConfVars.HIVE_SERVER2_THRIFT_HTTP_MAX_IDLE_TIME -> TimeUnit.MILLISECONDS, + ConfVars.HIVE_SERVER2_THRIFT_HTTP_WORKER_KEEPALIVE_TIME -> TimeUnit.SECONDS, + ConfVars.HIVE_SERVER2_THRIFT_HTTP_COOKIE_MAX_AGE -> TimeUnit.SECONDS, + ConfVars.HIVE_SERVER2_THRIFT_LOGIN_BEBACKOFF_SLOT_LENGTH -> TimeUnit.MILLISECONDS, + ConfVars.HIVE_SERVER2_THRIFT_LOGIN_TIMEOUT -> TimeUnit.SECONDS, + ConfVars.HIVE_SERVER2_THRIFT_WORKER_KEEPALIVE_TIME -> TimeUnit.SECONDS, + ConfVars.HIVE_SERVER2_ASYNC_EXEC_SHUTDOWN_TIMEOUT -> TimeUnit.SECONDS, + ConfVars.HIVE_SERVER2_ASYNC_EXEC_KEEPALIVE_TIME -> TimeUnit.SECONDS, + ConfVars.HIVE_SERVER2_LONG_POLLING_TIMEOUT -> TimeUnit.MILLISECONDS, + ConfVars.HIVE_SERVER2_SESSION_CHECK_INTERVAL -> TimeUnit.MILLISECONDS, + ConfVars.HIVE_SERVER2_IDLE_SESSION_TIMEOUT -> TimeUnit.MILLISECONDS, + ConfVars.HIVE_SERVER2_IDLE_OPERATION_TIMEOUT -> TimeUnit.MILLISECONDS, + ConfVars.SERVER_READ_SOCKET_TIMEOUT -> TimeUnit.SECONDS, + ConfVars.HIVE_LOCALIZE_RESOURCE_WAIT_INTERVAL -> TimeUnit.MILLISECONDS, + ConfVars.SPARK_CLIENT_FUTURE_TIMEOUT -> TimeUnit.SECONDS, + ConfVars.SPARK_JOB_MONITOR_TIMEOUT -> TimeUnit.SECONDS, + ConfVars.SPARK_RPC_CLIENT_CONNECT_TIMEOUT -> TimeUnit.MILLISECONDS, + ConfVars.SPARK_RPC_CLIENT_HANDSHAKE_TIMEOUT -> TimeUnit.MILLISECONDS + ).map { case (confVar, unit) => + confVar.varname -> hiveconf.getTimeVar(confVar, unit).toString + }.toMap + } protected[hive] class SQLSession extends super.SQLSession { protected[sql] override lazy val conf: SQLConf = new SQLConf { @@ -515,19 +594,23 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging { private[hive] object HiveContext { /** The version of hive used internally by Spark SQL. */ - val hiveExecutionVersion: String = "0.13.1" + val hiveExecutionVersion: String = "1.2.1" val HIVE_METASTORE_VERSION: String = "spark.sql.hive.metastore.version" val HIVE_METASTORE_JARS = stringConf("spark.sql.hive.metastore.jars", defaultValue = Some("builtin"), - doc = "Location of the jars that should be used to instantiate the HiveMetastoreClient. This" + - " property can be one of three options: " + - "1. 
\"builtin\" Use Hive 0.13.1, which is bundled with the Spark assembly jar when " + - "-Phive is enabled. When this option is chosen, " + - "spark.sql.hive.metastore.version must be either 0.13.1 or not defined. " + - "2. \"maven\" Use Hive jars of specified version downloaded from Maven repositories." + - "3. A classpath in the standard format for both Hive and Hadoop.") - + doc = s""" + | Location of the jars that should be used to instantiate the HiveMetastoreClient. + | This property can be one of three options: " + | 1. "builtin" + | Use Hive ${hiveExecutionVersion}, which is bundled with the Spark assembly jar when + | -Phive is enabled. When this option is chosen, + | spark.sql.hive.metastore.version must be either + | ${hiveExecutionVersion} or not defined. + | 2. "maven" + | Use Hive jars of specified version downloaded from Maven repositories. + | 3. A classpath in the standard format for both Hive and Hadoop. + """.stripMargin) val CONVERT_METASTORE_PARQUET = booleanConf("spark.sql.hive.convertMetastoreParquet", defaultValue = Some(true), doc = "When set to false, Spark SQL will use the Hive SerDe for parquet tables instead of " + @@ -566,17 +649,18 @@ private[hive] object HiveContext { /** Constructs a configuration for hive, where the metastore is located in a temp directory. */ def newTemporaryConfiguration(): Map[String, String] = { val tempDir = Utils.createTempDir() - val localMetastore = new File(tempDir, "metastore").getAbsolutePath + val localMetastore = new File(tempDir, "metastore") val propMap: HashMap[String, String] = HashMap() // We have to mask all properties in hive-site.xml that relates to metastore data source // as we used a local metastore here. HiveConf.ConfVars.values().foreach { confvar => if (confvar.varname.contains("datanucleus") || confvar.varname.contains("jdo")) { - propMap.put(confvar.varname, confvar.defaultVal) + propMap.put(confvar.varname, confvar.getDefaultExpr()) } } - propMap.put("javax.jdo.option.ConnectionURL", - s"jdbc:derby:;databaseName=$localMetastore;create=true") + propMap.put(HiveConf.ConfVars.METASTOREWAREHOUSE.varname, localMetastore.toURI.toString) + propMap.put(HiveConf.ConfVars.METASTORECONNECTURLKEY.varname, + s"jdbc:derby:;databaseName=${localMetastore.getAbsolutePath};create=true") propMap.put("datanucleus.rdbms.datastoreAdapterClassName", "org.datanucleus.store.rdbms.adapter.DerbyAdapter") propMap.toMap diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index a8c9b4fa71b9..16c186627f6c 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -649,11 +649,12 @@ private[hive] case class MetastoreRelation table.outputFormat.foreach(sd.setOutputFormat) val serdeInfo = new org.apache.hadoop.hive.metastore.api.SerDeInfo - sd.setSerdeInfo(serdeInfo) table.serde.foreach(serdeInfo.setSerializationLib) + sd.setSerdeInfo(serdeInfo) + val serdeParameters = new java.util.HashMap[String, String]() - serdeInfo.setParameters(serdeParameters) table.serdeProperties.foreach { case (k, v) => serdeParameters.put(k, v) } + serdeInfo.setParameters(serdeParameters) new Table(tTable) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index e6df64d2642b..e2fdfc6163a0 100644 --- 
a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.hive import java.sql.Date +import java.util.Locale import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.serde.serdeConstants @@ -80,6 +81,7 @@ private[hive] object HiveQl extends Logging { "TOK_ALTERDATABASE_PROPERTIES", "TOK_ALTERINDEX_PROPERTIES", "TOK_ALTERINDEX_REBUILD", + "TOK_ALTERTABLE", "TOK_ALTERTABLE_ADDCOLS", "TOK_ALTERTABLE_ADDPARTS", "TOK_ALTERTABLE_ALTERPARTS", @@ -94,6 +96,7 @@ private[hive] object HiveQl extends Logging { "TOK_ALTERTABLE_SKEWED", "TOK_ALTERTABLE_TOUCH", "TOK_ALTERTABLE_UNARCHIVE", + "TOK_ALTERVIEW", "TOK_ALTERVIEW_ADDPARTS", "TOK_ALTERVIEW_AS", "TOK_ALTERVIEW_DROPPARTS", @@ -248,7 +251,7 @@ private[hive] object HiveQl extends Logging { * Otherwise, there will be Null pointer exception, * when retrieving properties form HiveConf. */ - val hContext = new Context(hiveConf) + val hContext = new Context(SessionState.get().getConf()) val node = ParseUtils.findRootNonNullToken((new ParseDriver).parse(sql, hContext)) hContext.clear() node @@ -577,12 +580,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C "TOK_TABLESKEWED", // Skewed by "TOK_TABLEROWFORMAT", "TOK_TABLESERIALIZER", - "TOK_FILEFORMAT_GENERIC", // For file formats not natively supported by Hive. - "TOK_TBLSEQUENCEFILE", // Stored as SequenceFile - "TOK_TBLTEXTFILE", // Stored as TextFile - "TOK_TBLRCFILE", // Stored as RCFile - "TOK_TBLORCFILE", // Stored as ORC File - "TOK_TBLPARQUETFILE", // Stored as PARQUET + "TOK_FILEFORMAT_GENERIC", "TOK_TABLEFILEFORMAT", // User-provided InputFormat and OutputFormat "TOK_STORAGEHANDLER", // Storage handler "TOK_TABLELOCATION", @@ -706,36 +704,51 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C tableDesc = tableDesc.copy(serdeProperties = tableDesc.serdeProperties ++ serdeParams) } case Token("TOK_FILEFORMAT_GENERIC", child :: Nil) => - throw new SemanticException( - "Unrecognized file format in STORED AS clause:${child.getText}") + child.getText().toLowerCase(Locale.ENGLISH) match { + case "orc" => + tableDesc = tableDesc.copy( + inputFormat = Option("org.apache.hadoop.hive.ql.io.orc.OrcInputFormat"), + outputFormat = Option("org.apache.hadoop.hive.ql.io.orc.OrcOutputFormat")) + if (tableDesc.serde.isEmpty) { + tableDesc = tableDesc.copy( + serde = Option("org.apache.hadoop.hive.ql.io.orc.OrcSerde")) + } - case Token("TOK_TBLRCFILE", Nil) => - tableDesc = tableDesc.copy( - inputFormat = Option("org.apache.hadoop.hive.ql.io.RCFileInputFormat"), - outputFormat = Option("org.apache.hadoop.hive.ql.io.RCFileOutputFormat")) - if (tableDesc.serde.isEmpty) { - tableDesc = tableDesc.copy( - serde = Option("org.apache.hadoop.hive.serde2.columnar.LazyBinaryColumnarSerDe")) - } + case "parquet" => + tableDesc = tableDesc.copy( + inputFormat = + Option("org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat"), + outputFormat = + Option("org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat")) + if (tableDesc.serde.isEmpty) { + tableDesc = tableDesc.copy( + serde = Option("org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe")) + } - case Token("TOK_TBLORCFILE", Nil) => - tableDesc = tableDesc.copy( - inputFormat = Option("org.apache.hadoop.hive.ql.io.orc.OrcInputFormat"), - outputFormat = Option("org.apache.hadoop.hive.ql.io.orc.OrcOutputFormat")) - if (tableDesc.serde.isEmpty) { - 
tableDesc = tableDesc.copy( - serde = Option("org.apache.hadoop.hive.ql.io.orc.OrcSerde")) - } + case "rcfile" => + tableDesc = tableDesc.copy( + inputFormat = Option("org.apache.hadoop.hive.ql.io.RCFileInputFormat"), + outputFormat = Option("org.apache.hadoop.hive.ql.io.RCFileOutputFormat")) + if (tableDesc.serde.isEmpty) { + tableDesc = tableDesc.copy( + serde = Option("org.apache.hadoop.hive.serde2.columnar.LazyBinaryColumnarSerDe")) + } - case Token("TOK_TBLPARQUETFILE", Nil) => - tableDesc = tableDesc.copy( - inputFormat = - Option("org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat"), - outputFormat = - Option("org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat")) - if (tableDesc.serde.isEmpty) { - tableDesc = tableDesc.copy( - serde = Option("org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe")) + case "textfile" => + tableDesc = tableDesc.copy( + inputFormat = + Option("org.apache.hadoop.mapred.TextInputFormat"), + outputFormat = + Option("org.apache.hadoop.hive.ql.io.IgnoreKeyTextOutputFormat")) + + case "sequencefile" => + tableDesc = tableDesc.copy( + inputFormat = Option("org.apache.hadoop.mapred.SequenceFileInputFormat"), + outputFormat = Option("org.apache.hadoop.mapred.SequenceFileOutputFormat")) + + case _ => + throw new SemanticException( + s"Unrecognized file format in STORED AS clause: ${child.getText}") } case Token("TOK_TABLESERIALIZER", @@ -751,7 +764,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C case Token("TOK_TABLEPROPERTIES", list :: Nil) => tableDesc = tableDesc.copy(properties = tableDesc.properties ++ getProperties(list)) - case list @ Token("TOK_TABLEFILEFORMAT", _) => + case list @ Token("TOK_TABLEFILEFORMAT", children) => tableDesc = tableDesc.copy( inputFormat = Option(BaseSemanticAnalyzer.unescapeSQLString(list.getChild(0).getText)), @@ -889,7 +902,8 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C Token("TOK_TABLEPROPLIST", propsClause) :: Nil) :: Nil) :: Nil => val serdeProps = propsClause.map { case Token("TOK_TABLEPROPERTY", Token(name, Nil) :: Token(value, Nil) :: Nil) => - (name, value) + (BaseSemanticAnalyzer.unescapeSQLString(name), + BaseSemanticAnalyzer.unescapeSQLString(value)) } (Nil, Some(BaseSemanticAnalyzer.unescapeSQLString(serdeClass)), serdeProps) @@ -1037,10 +1051,11 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C // return With plan if there is CTE cteRelations.map(With(query, _)).getOrElse(query) - case Token("TOK_UNION", left :: right :: Nil) => Union(nodeToPlan(left), nodeToPlan(right)) + // HIVE-9039 renamed TOK_UNION => TOK_UNIONALL while adding TOK_UNIONDISTINCT + case Token("TOK_UNIONALL", left :: right :: Nil) => Union(nodeToPlan(left), nodeToPlan(right)) case a: ASTNode => - throw new NotImplementedError(s"No parse rules for:\n ${dumpTree(a).toString} ") + throw new NotImplementedError(s"No parse rules for $node:\n ${dumpTree(a).toString} ") } val allJoinTokens = "(TOK_.*JOIN)".r @@ -1251,7 +1266,8 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C InsertIntoTable(UnresolvedRelation(tableIdent, None), partitionKeys, query, overwrite, true) case a: ASTNode => - throw new NotImplementedError(s"No parse rules for:\n ${dumpTree(a).toString} ") + throw new NotImplementedError(s"No parse rules for ${a.getName}:" + + s"\n ${dumpTree(a).toString} ") } protected def selExprNodeToExpr(node: Node): Option[Expression] = node match { @@ -1274,7 +1290,8 @@ 
https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C case Token("TOK_HINTLIST", _) => None case a: ASTNode => - throw new NotImplementedError(s"No parse rules for:\n ${dumpTree(a).toString} ") + throw new NotImplementedError(s"No parse rules for ${a.getName }:" + + s"\n ${dumpTree(a).toString } ") } protected val escapedIdentifier = "`([^`]+)`".r diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala index a357bb39ca7f..267074f3ad10 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala @@ -20,6 +20,8 @@ package org.apache.spark.sql.hive import java.io.{InputStream, OutputStream} import java.rmi.server.UID +import org.apache.avro.Schema + /* Implicit conversions */ import scala.collection.JavaConversions._ import scala.language.implicitConversions @@ -33,7 +35,7 @@ import org.apache.hadoop.fs.Path import org.apache.hadoop.hive.ql.exec.{UDF, Utilities} import org.apache.hadoop.hive.ql.plan.{FileSinkDesc, TableDesc} import org.apache.hadoop.hive.serde2.ColumnProjectionUtils -import org.apache.hadoop.hive.serde2.avro.AvroGenericRecordWritable +import org.apache.hadoop.hive.serde2.avro.{AvroGenericRecordWritable, AvroSerdeUtils} import org.apache.hadoop.hive.serde2.objectinspector.primitive.HiveDecimalObjectInspector import org.apache.hadoop.io.Writable @@ -82,10 +84,19 @@ private[hive] object HiveShim { * Bug introduced in hive-0.13. AvroGenericRecordWritable has a member recordReaderID that * is needed to initialize before serialization. */ - def prepareWritable(w: Writable): Writable = { + def prepareWritable(w: Writable, serDeProps: Seq[(String, String)]): Writable = { w match { case w: AvroGenericRecordWritable => w.setRecordReaderID(new UID()) + // In Hive 1.1, the record's schema may need to be initialized manually or a NPE will + // be thrown. + if (w.getFileSchema() == null) { + serDeProps + .find(_._1 == AvroSerdeUtils.AvroTableProperties.SCHEMA_LITERAL.getPropName()) + .foreach { kv => + w.setFileSchema(new Schema.Parser().parse(kv._2)) + } + } case _ => } w diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala index d834b4e83e04..a82e152dcda2 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala @@ -87,6 +87,10 @@ private[hive] case class HiveTable( * shared classes. */ private[hive] trait ClientInterface { + + /** Returns the configuration for the given key in the current session. */ + def getConf(key: String, defaultValue: String): String + /** * Runs a HiveQL command using Hive, returning the results as a list of strings. Each row will * result in one string. 
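The new `getConf` hook on `ClientInterface` (implemented by `ClientWrapper` in the next hunk) is what lets `HiveContext` read metastore-side settings such as the staging directory used in the table-size calculation earlier in this patch. A minimal usage sketch, assuming `client` is any `ClientInterface` implementation; the key and fallback simply mirror the calls already shown above:

  // Read hive.exec.stagingdir from the current Hive session, falling back to
  // the default value baked into HiveConf.
  val stagingDir: String = client.getConf(
    HiveConf.ConfVars.STAGINGDIR.varname,
    HiveConf.ConfVars.STAGINGDIR.defaultStrVal)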
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala index 6e0912da5862..dc372be0e5a3 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala @@ -38,7 +38,6 @@ import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.execution.QueryExecutionException import org.apache.spark.util.{CircularBuffer, Utils} - /** * A class that wraps the HiveClient and converts its responses to externally visible classes. * Note that this class is typically loaded with an internal classloader for each instantiation, @@ -115,6 +114,10 @@ private[hive] class ClientWrapper( /** Returns the configuration for the current session. */ def conf: HiveConf = SessionState.get().getConf + override def getConf(key: String, defaultValue: String): String = { + conf.get(key, defaultValue) + } + // TODO: should be a def?s // When we create this val client, the HiveConf of it (conf) is the one associated with state. @GuardedBy("this") diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala index 956997e5f9dc..6e826ce55220 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala @@ -512,7 +512,7 @@ private[client] class Shim_v1_2 extends Shim_v1_1 { listBucketingEnabled: Boolean): Unit = { loadDynamicPartitionsMethod.invoke(hive, loadPath, tableName, partSpec, replace: JBoolean, numDP: JInteger, holdDDLTime: JBoolean, listBucketingEnabled: JBoolean, JBoolean.FALSE, - 0: JLong) + 0L: JLong) } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala index 97fb98199991..f58bc7d7a0af 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala @@ -55,7 +55,7 @@ private[hive] object IsolatedClientLoader { case "14" | "0.14" | "0.14.0" => hive.v14 case "1.0" | "1.0.0" => hive.v1_0 case "1.1" | "1.1.0" => hive.v1_1 - case "1.2" | "1.2.0" => hive.v1_2 + case "1.2" | "1.2.0" | "1.2.1" => hive.v1_2 } private def downloadVersion(version: HiveVersion, ivyPath: Option[String]): Seq[URL] = { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala index b48082fe4b36..0503691a4424 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala @@ -56,7 +56,7 @@ package object client { "net.hydromatic:linq4j", "net.hydromatic:quidem")) - case object v1_2 extends HiveVersion("1.2.0", + case object v1_2 extends HiveVersion("1.2.1", exclusions = Seq("eigenbase:eigenbase-properties", "org.apache.curator:*", "org.pentaho:pentaho-aggdesigner-algorithm", diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index 40a6a3215668..12c667e6e92d 100644 --- 
a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -129,7 +129,7 @@ case class InsertIntoHiveTable( // instances within the closure, since Serializer is not serializable while TableDesc is. val tableDesc = table.tableDesc val tableLocation = table.hiveQlTable.getDataLocation - val tmpLocation = hiveContext.getExternalTmpPath(tableLocation.toUri) + val tmpLocation = hiveContext.getExternalTmpPath(tableLocation) val fileSinkConf = new FileSinkDesc(tmpLocation.toString, tableDesc, false) val isCompressed = sc.hiveconf.getBoolean( ConfVars.COMPRESSRESULT.varname, ConfVars.COMPRESSRESULT.defaultBoolVal) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala index 7e3342cc84c0..fbb86406f40c 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala @@ -247,7 +247,7 @@ private class ScriptTransformationWriterThread( } else { val writable = inputSerde.serialize( row.asInstanceOf[GenericInternalRow].values, inputSoi) - prepareWritable(writable).write(dataOutputStream) + prepareWritable(writable, ioschema.outputSerdeProps).write(dataOutputStream) } } outputStream.close() @@ -345,9 +345,7 @@ case class HiveScriptIOSchema ( val columnTypesNames = columnTypes.map(_.toTypeInfo.getTypeName()).mkString(",") - var propsMap = serdeProps.map(kv => { - (kv._1.split("'")(1), kv._2.split("'")(1)) - }).toMap + (serdeConstants.LIST_COLUMNS -> columns.mkString(",")) + var propsMap = serdeProps.toMap + (serdeConstants.LIST_COLUMNS -> columns.mkString(",")) propsMap = propsMap + (serdeConstants.LIST_COLUMN_TYPES -> columnTypesNames) val properties = new Properties() diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala index abe5c6900313..8a86a87368f2 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala @@ -249,7 +249,7 @@ private[spark] object ResolveHiveWindowFunction extends Rule[LogicalPlan] { // Get the class of this function. // In Hive 0.12, there is no windowFunctionInfo.getFunctionClass. So, we use // windowFunctionInfo.getfInfo().getFunctionClass for both Hive 0.13 and Hive 0.13.1. 
- val functionClass = windowFunctionInfo.getfInfo().getFunctionClass + val functionClass = windowFunctionInfo.getFunctionClass() val newChildren = // Rank(), DENSE_RANK(), CUME_DIST(), and PERCENT_RANK() do not take explicit // input parameters and requires implicit parameters, which diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala index 8850e060d2a7..684ea1d137b4 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala @@ -171,7 +171,7 @@ private[spark] class SparkHiveDynamicPartitionWriterContainer( import SparkHiveDynamicPartitionWriterContainer._ private val defaultPartName = jobConf.get( - ConfVars.DEFAULTPARTITIONNAME.varname, ConfVars.DEFAULTPARTITIONNAME.defaultVal) + ConfVars.DEFAULTPARTITIONNAME.varname, ConfVars.DEFAULTPARTITIONNAME.defaultStrVal) @transient private var writers: mutable.HashMap[String, FileSinkOperator.RecordWriter] = _ diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala index ddd5d24717ad..86142e5d66f3 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.hive.orc import org.apache.hadoop.hive.common.`type`.{HiveChar, HiveDecimal, HiveVarchar} -import org.apache.hadoop.hive.ql.io.sarg.SearchArgument +import org.apache.hadoop.hive.ql.io.sarg.{SearchArgumentFactory, SearchArgument} import org.apache.hadoop.hive.ql.io.sarg.SearchArgument.Builder import org.apache.hadoop.hive.serde2.io.DateWritable @@ -33,13 +33,13 @@ import org.apache.spark.sql.sources._ private[orc] object OrcFilters extends Logging { def createFilter(expr: Array[Filter]): Option[SearchArgument] = { expr.reduceOption(And).flatMap { conjunction => - val builder = SearchArgument.FACTORY.newBuilder() + val builder = SearchArgumentFactory.newBuilder() buildSearchArgument(conjunction, builder).map(_.build()) } } private def buildSearchArgument(expression: Filter, builder: Builder): Option[Builder] = { - def newBuilder = SearchArgument.FACTORY.newBuilder() + def newBuilder = SearchArgumentFactory.newBuilder() def isSearchableLiteral(value: Any): Boolean = value match { // These are types recognized by the `SearchArgumentImpl.BuilderImpl.boxLiteral()` method. 
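The `OrcFilters` change above only swaps the builder entry point: Hive 1.2 removes the `SearchArgument.FACTORY` field, so the predicate builder now comes from `SearchArgumentFactory.newBuilder()`; the filter translation itself is untouched. A small caller-side sketch (illustrative only, since `OrcFilters` is `private[orc]`), using the data source `Filter`/`EqualTo` types already imported in this file:

  // Try to push the predicate `col = 1` down to the ORC reader; a None result
  // means it could not be expressed as a Hive SearchArgument.
  val filters: Array[Filter] = Array(EqualTo("col", 1))
  val maybeSarg: Option[SearchArgument] = OrcFilters.createFilter(filters)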
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index 7bbdef90cd6b..8d0bf46e8fad 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -20,29 +20,25 @@ package org.apache.spark.sql.hive.test import java.io.File import java.util.{Set => JavaSet} -import org.apache.hadoop.hive.conf.HiveConf +import scala.collection.mutable +import scala.language.implicitConversions + import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.hadoop.hive.ql.exec.FunctionRegistry import org.apache.hadoop.hive.ql.io.avro.{AvroContainerInputFormat, AvroContainerOutputFormat} -import org.apache.hadoop.hive.ql.metadata.Table -import org.apache.hadoop.hive.ql.parse.VariableSubstitution import org.apache.hadoop.hive.ql.processors._ import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe import org.apache.hadoop.hive.serde2.avro.AvroSerDe -import org.apache.spark.sql.catalyst.CatalystConf +import org.apache.spark.sql.SQLConf import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.CacheTableCommand import org.apache.spark.sql.hive._ import org.apache.spark.sql.hive.execution.HiveNativeCommand -import org.apache.spark.sql.SQLConf import org.apache.spark.util.Utils import org.apache.spark.{SparkConf, SparkContext} -import scala.collection.mutable -import scala.language.implicitConversions - /* Implicit conversions */ import scala.collection.JavaConversions._ @@ -83,15 +79,25 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { hiveconf.set("hive.plan.serialization.format", "javaXML") - lazy val warehousePath = Utils.createTempDir() + lazy val warehousePath = Utils.createTempDir(namePrefix = "warehouse-") + + lazy val scratchDirPath = { + val dir = Utils.createTempDir(namePrefix = "scratch-") + dir.delete() + dir + } private lazy val temporaryConfig = newTemporaryConfiguration() /** Sets up the system initially or after a RESET command */ - protected override def configure(): Map[String, String] = - temporaryConfig ++ Map( - ConfVars.METASTOREWAREHOUSE.varname -> warehousePath.toString, - ConfVars.METASTORE_INTEGER_JDO_PUSHDOWN.varname -> "true") + protected override def configure(): Map[String, String] = { + super.configure() ++ temporaryConfig ++ Map( + ConfVars.METASTOREWAREHOUSE.varname -> warehousePath.toURI.toString, + ConfVars.METASTORE_INTEGER_JDO_PUSHDOWN.varname -> "true", + ConfVars.SCRATCHDIR.varname -> scratchDirPath.toURI.toString, + ConfVars.METASTORE_CLIENT_CONNECT_RETRY_DELAY.varname -> "1" + ) + } val testTempDir = Utils.createTempDir() @@ -244,7 +250,6 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { }), TestTable("src_thrift", () => { import org.apache.hadoop.hive.serde2.thrift.ThriftDeserializer - import org.apache.hadoop.hive.serde2.thrift.test.Complex import org.apache.hadoop.mapred.{SequenceFileInputFormat, SequenceFileOutputFormat} import org.apache.thrift.protocol.TBinaryProtocol @@ -253,7 +258,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { |CREATE TABLE src_thrift(fake INT) |ROW FORMAT SERDE '${classOf[ThriftDeserializer].getName}' |WITH SERDEPROPERTIES( - | 'serialization.class'='${classOf[Complex].getName}', + | 'serialization.class'='org.apache.spark.sql.hive.test.Complex', | 
'serialization.format'='${classOf[TBinaryProtocol].getName}' |) |STORED AS @@ -437,6 +442,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { case (k, v) => metadataHive.runSqlHive(s"SET $k=$v") } + defaultOverides() runSqlHive("USE default") diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/test/Complex.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/test/Complex.java new file mode 100644 index 000000000000..e010112bb932 --- /dev/null +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/test/Complex.java @@ -0,0 +1,1139 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.hive.test; + +import org.apache.commons.lang.builder.HashCodeBuilder; +import org.apache.thrift.scheme.IScheme; +import org.apache.thrift.scheme.SchemeFactory; +import org.apache.thrift.scheme.StandardScheme; + +import org.apache.hadoop.hive.serde2.thrift.test.IntString; +import org.apache.thrift.scheme.TupleScheme; +import org.apache.thrift.protocol.TTupleProtocol; +import org.apache.thrift.EncodingUtils; +import java.util.List; +import java.util.ArrayList; +import java.util.Map; +import java.util.HashMap; +import java.util.EnumMap; +import java.util.EnumSet; +import java.util.Collections; +import java.util.BitSet; + +/** + * This is a fork of Hive 0.13's org/apache/hadoop/hive/serde2/thrift/test/Complex.java, which + * does not contain union fields that are not supported by Spark SQL. 
+ */ + +@SuppressWarnings({"ALL", "unchecked"}) +public class Complex implements org.apache.thrift.TBase, java.io.Serializable, Cloneable { + private static final org.apache.thrift.protocol.TStruct STRUCT_DESC = new org.apache.thrift.protocol.TStruct("Complex"); + + private static final org.apache.thrift.protocol.TField AINT_FIELD_DESC = new org.apache.thrift.protocol.TField("aint", org.apache.thrift.protocol.TType.I32, (short)1); + private static final org.apache.thrift.protocol.TField A_STRING_FIELD_DESC = new org.apache.thrift.protocol.TField("aString", org.apache.thrift.protocol.TType.STRING, (short)2); + private static final org.apache.thrift.protocol.TField LINT_FIELD_DESC = new org.apache.thrift.protocol.TField("lint", org.apache.thrift.protocol.TType.LIST, (short)3); + private static final org.apache.thrift.protocol.TField L_STRING_FIELD_DESC = new org.apache.thrift.protocol.TField("lString", org.apache.thrift.protocol.TType.LIST, (short)4); + private static final org.apache.thrift.protocol.TField LINT_STRING_FIELD_DESC = new org.apache.thrift.protocol.TField("lintString", org.apache.thrift.protocol.TType.LIST, (short)5); + private static final org.apache.thrift.protocol.TField M_STRING_STRING_FIELD_DESC = new org.apache.thrift.protocol.TField("mStringString", org.apache.thrift.protocol.TType.MAP, (short)6); + + private static final Map, SchemeFactory> schemes = new HashMap, SchemeFactory>(); + static { + schemes.put(StandardScheme.class, new ComplexStandardSchemeFactory()); + schemes.put(TupleScheme.class, new ComplexTupleSchemeFactory()); + } + + private int aint; // required + private String aString; // required + private List lint; // required + private List lString; // required + private List lintString; // required + private Map mStringString; // required + + /** The set of fields this struct contains, along with convenience methods for finding and manipulating them. */ + public enum _Fields implements org.apache.thrift.TFieldIdEnum { + AINT((short)1, "aint"), + A_STRING((short)2, "aString"), + LINT((short)3, "lint"), + L_STRING((short)4, "lString"), + LINT_STRING((short)5, "lintString"), + M_STRING_STRING((short)6, "mStringString"); + + private static final Map byName = new HashMap(); + + static { + for (_Fields field : EnumSet.allOf(_Fields.class)) { + byName.put(field.getFieldName(), field); + } + } + + /** + * Find the _Fields constant that matches fieldId, or null if its not found. + */ + public static _Fields findByThriftId(int fieldId) { + switch(fieldId) { + case 1: // AINT + return AINT; + case 2: // A_STRING + return A_STRING; + case 3: // LINT + return LINT; + case 4: // L_STRING + return L_STRING; + case 5: // LINT_STRING + return LINT_STRING; + case 6: // M_STRING_STRING + return M_STRING_STRING; + default: + return null; + } + } + + /** + * Find the _Fields constant that matches fieldId, throwing an exception + * if it is not found. + */ + public static _Fields findByThriftIdOrThrow(int fieldId) { + _Fields fields = findByThriftId(fieldId); + if (fields == null) throw new IllegalArgumentException("Field " + fieldId + " doesn't exist!"); + return fields; + } + + /** + * Find the _Fields constant that matches name, or null if its not found. 
+ */ + public static _Fields findByName(String name) { + return byName.get(name); + } + + private final short _thriftId; + private final String _fieldName; + + _Fields(short thriftId, String fieldName) { + _thriftId = thriftId; + _fieldName = fieldName; + } + + public short getThriftFieldId() { + return _thriftId; + } + + public String getFieldName() { + return _fieldName; + } + } + + // isset id assignments + private static final int __AINT_ISSET_ID = 0; + private byte __isset_bitfield = 0; + public static final Map<_Fields, org.apache.thrift.meta_data.FieldMetaData> metaDataMap; + static { + Map<_Fields, org.apache.thrift.meta_data.FieldMetaData> tmpMap = new EnumMap<_Fields, org.apache.thrift.meta_data.FieldMetaData>(_Fields.class); + tmpMap.put(_Fields.AINT, new org.apache.thrift.meta_data.FieldMetaData("aint", org.apache.thrift.TFieldRequirementType.DEFAULT, + new org.apache.thrift.meta_data.FieldValueMetaData(org.apache.thrift.protocol.TType.I32))); + tmpMap.put(_Fields.A_STRING, new org.apache.thrift.meta_data.FieldMetaData("aString", org.apache.thrift.TFieldRequirementType.DEFAULT, + new org.apache.thrift.meta_data.FieldValueMetaData(org.apache.thrift.protocol.TType.STRING))); + tmpMap.put(_Fields.LINT, new org.apache.thrift.meta_data.FieldMetaData("lint", org.apache.thrift.TFieldRequirementType.DEFAULT, + new org.apache.thrift.meta_data.ListMetaData(org.apache.thrift.protocol.TType.LIST, + new org.apache.thrift.meta_data.FieldValueMetaData(org.apache.thrift.protocol.TType.I32)))); + tmpMap.put(_Fields.L_STRING, new org.apache.thrift.meta_data.FieldMetaData("lString", org.apache.thrift.TFieldRequirementType.DEFAULT, + new org.apache.thrift.meta_data.ListMetaData(org.apache.thrift.protocol.TType.LIST, + new org.apache.thrift.meta_data.FieldValueMetaData(org.apache.thrift.protocol.TType.STRING)))); + tmpMap.put(_Fields.LINT_STRING, new org.apache.thrift.meta_data.FieldMetaData("lintString", org.apache.thrift.TFieldRequirementType.DEFAULT, + new org.apache.thrift.meta_data.ListMetaData(org.apache.thrift.protocol.TType.LIST, + new org.apache.thrift.meta_data.StructMetaData(org.apache.thrift.protocol.TType.STRUCT, IntString.class)))); + tmpMap.put(_Fields.M_STRING_STRING, new org.apache.thrift.meta_data.FieldMetaData("mStringString", org.apache.thrift.TFieldRequirementType.DEFAULT, + new org.apache.thrift.meta_data.MapMetaData(org.apache.thrift.protocol.TType.MAP, + new org.apache.thrift.meta_data.FieldValueMetaData(org.apache.thrift.protocol.TType.STRING), + new org.apache.thrift.meta_data.FieldValueMetaData(org.apache.thrift.protocol.TType.STRING)))); + metaDataMap = Collections.unmodifiableMap(tmpMap); + org.apache.thrift.meta_data.FieldMetaData.addStructMetaDataMap(Complex.class, metaDataMap); + } + + public Complex() { + } + + public Complex( + int aint, + String aString, + List lint, + List lString, + List lintString, + Map mStringString) + { + this(); + this.aint = aint; + setAintIsSet(true); + this.aString = aString; + this.lint = lint; + this.lString = lString; + this.lintString = lintString; + this.mStringString = mStringString; + } + + /** + * Performs a deep copy on other. 
+ */ + public Complex(Complex other) { + __isset_bitfield = other.__isset_bitfield; + this.aint = other.aint; + if (other.isSetAString()) { + this.aString = other.aString; + } + if (other.isSetLint()) { + List __this__lint = new ArrayList(); + for (Integer other_element : other.lint) { + __this__lint.add(other_element); + } + this.lint = __this__lint; + } + if (other.isSetLString()) { + List __this__lString = new ArrayList(); + for (String other_element : other.lString) { + __this__lString.add(other_element); + } + this.lString = __this__lString; + } + if (other.isSetLintString()) { + List __this__lintString = new ArrayList(); + for (IntString other_element : other.lintString) { + __this__lintString.add(new IntString(other_element)); + } + this.lintString = __this__lintString; + } + if (other.isSetMStringString()) { + Map __this__mStringString = new HashMap(); + for (Map.Entry other_element : other.mStringString.entrySet()) { + + String other_element_key = other_element.getKey(); + String other_element_value = other_element.getValue(); + + String __this__mStringString_copy_key = other_element_key; + + String __this__mStringString_copy_value = other_element_value; + + __this__mStringString.put(__this__mStringString_copy_key, __this__mStringString_copy_value); + } + this.mStringString = __this__mStringString; + } + } + + public Complex deepCopy() { + return new Complex(this); + } + + @Override + public void clear() { + setAintIsSet(false); + this.aint = 0; + this.aString = null; + this.lint = null; + this.lString = null; + this.lintString = null; + this.mStringString = null; + } + + public int getAint() { + return this.aint; + } + + public void setAint(int aint) { + this.aint = aint; + setAintIsSet(true); + } + + public void unsetAint() { + __isset_bitfield = EncodingUtils.clearBit(__isset_bitfield, __AINT_ISSET_ID); + } + + /** Returns true if field aint is set (has been assigned a value) and false otherwise */ + public boolean isSetAint() { + return EncodingUtils.testBit(__isset_bitfield, __AINT_ISSET_ID); + } + + public void setAintIsSet(boolean value) { + __isset_bitfield = EncodingUtils.setBit(__isset_bitfield, __AINT_ISSET_ID, value); + } + + public String getAString() { + return this.aString; + } + + public void setAString(String aString) { + this.aString = aString; + } + + public void unsetAString() { + this.aString = null; + } + + /** Returns true if field aString is set (has been assigned a value) and false otherwise */ + public boolean isSetAString() { + return this.aString != null; + } + + public void setAStringIsSet(boolean value) { + if (!value) { + this.aString = null; + } + } + + public int getLintSize() { + return (this.lint == null) ? 0 : this.lint.size(); + } + + public java.util.Iterator getLintIterator() { + return (this.lint == null) ? null : this.lint.iterator(); + } + + public void addToLint(int elem) { + if (this.lint == null) { + this.lint = new ArrayList<>(); + } + this.lint.add(elem); + } + + public List getLint() { + return this.lint; + } + + public void setLint(List lint) { + this.lint = lint; + } + + public void unsetLint() { + this.lint = null; + } + + /** Returns true if field lint is set (has been assigned a value) and false otherwise */ + public boolean isSetLint() { + return this.lint != null; + } + + public void setLintIsSet(boolean value) { + if (!value) { + this.lint = null; + } + } + + public int getLStringSize() { + return (this.lString == null) ? 
0 : this.lString.size(); + } + + public java.util.Iterator getLStringIterator() { + return (this.lString == null) ? null : this.lString.iterator(); + } + + public void addToLString(String elem) { + if (this.lString == null) { + this.lString = new ArrayList(); + } + this.lString.add(elem); + } + + public List getLString() { + return this.lString; + } + + public void setLString(List lString) { + this.lString = lString; + } + + public void unsetLString() { + this.lString = null; + } + + /** Returns true if field lString is set (has been assigned a value) and false otherwise */ + public boolean isSetLString() { + return this.lString != null; + } + + public void setLStringIsSet(boolean value) { + if (!value) { + this.lString = null; + } + } + + public int getLintStringSize() { + return (this.lintString == null) ? 0 : this.lintString.size(); + } + + public java.util.Iterator getLintStringIterator() { + return (this.lintString == null) ? null : this.lintString.iterator(); + } + + public void addToLintString(IntString elem) { + if (this.lintString == null) { + this.lintString = new ArrayList<>(); + } + this.lintString.add(elem); + } + + public List getLintString() { + return this.lintString; + } + + public void setLintString(List lintString) { + this.lintString = lintString; + } + + public void unsetLintString() { + this.lintString = null; + } + + /** Returns true if field lintString is set (has been assigned a value) and false otherwise */ + public boolean isSetLintString() { + return this.lintString != null; + } + + public void setLintStringIsSet(boolean value) { + if (!value) { + this.lintString = null; + } + } + + public int getMStringStringSize() { + return (this.mStringString == null) ? 0 : this.mStringString.size(); + } + + public void putToMStringString(String key, String val) { + if (this.mStringString == null) { + this.mStringString = new HashMap(); + } + this.mStringString.put(key, val); + } + + public Map getMStringString() { + return this.mStringString; + } + + public void setMStringString(Map mStringString) { + this.mStringString = mStringString; + } + + public void unsetMStringString() { + this.mStringString = null; + } + + /** Returns true if field mStringString is set (has been assigned a value) and false otherwise */ + public boolean isSetMStringString() { + return this.mStringString != null; + } + + public void setMStringStringIsSet(boolean value) { + if (!value) { + this.mStringString = null; + } + } + + public void setFieldValue(_Fields field, Object value) { + switch (field) { + case AINT: + if (value == null) { + unsetAint(); + } else { + setAint((Integer)value); + } + break; + + case A_STRING: + if (value == null) { + unsetAString(); + } else { + setAString((String)value); + } + break; + + case LINT: + if (value == null) { + unsetLint(); + } else { + setLint((List)value); + } + break; + + case L_STRING: + if (value == null) { + unsetLString(); + } else { + setLString((List)value); + } + break; + + case LINT_STRING: + if (value == null) { + unsetLintString(); + } else { + setLintString((List)value); + } + break; + + case M_STRING_STRING: + if (value == null) { + unsetMStringString(); + } else { + setMStringString((Map)value); + } + break; + + } + } + + public Object getFieldValue(_Fields field) { + switch (field) { + case AINT: + return Integer.valueOf(getAint()); + + case A_STRING: + return getAString(); + + case LINT: + return getLint(); + + case L_STRING: + return getLString(); + + case LINT_STRING: + return getLintString(); + + case M_STRING_STRING: + return 
getMStringString(); + + } + throw new IllegalStateException(); + } + + /** Returns true if field corresponding to fieldID is set (has been assigned a value) and false otherwise */ + public boolean isSet(_Fields field) { + if (field == null) { + throw new IllegalArgumentException(); + } + + switch (field) { + case AINT: + return isSetAint(); + case A_STRING: + return isSetAString(); + case LINT: + return isSetLint(); + case L_STRING: + return isSetLString(); + case LINT_STRING: + return isSetLintString(); + case M_STRING_STRING: + return isSetMStringString(); + } + throw new IllegalStateException(); + } + + @Override + public boolean equals(Object that) { + if (that == null) + return false; + if (that instanceof Complex) + return this.equals((Complex)that); + return false; + } + + public boolean equals(Complex that) { + if (that == null) + return false; + + boolean this_present_aint = true; + boolean that_present_aint = true; + if (this_present_aint || that_present_aint) { + if (!(this_present_aint && that_present_aint)) + return false; + if (this.aint != that.aint) + return false; + } + + boolean this_present_aString = true && this.isSetAString(); + boolean that_present_aString = true && that.isSetAString(); + if (this_present_aString || that_present_aString) { + if (!(this_present_aString && that_present_aString)) + return false; + if (!this.aString.equals(that.aString)) + return false; + } + + boolean this_present_lint = true && this.isSetLint(); + boolean that_present_lint = true && that.isSetLint(); + if (this_present_lint || that_present_lint) { + if (!(this_present_lint && that_present_lint)) + return false; + if (!this.lint.equals(that.lint)) + return false; + } + + boolean this_present_lString = true && this.isSetLString(); + boolean that_present_lString = true && that.isSetLString(); + if (this_present_lString || that_present_lString) { + if (!(this_present_lString && that_present_lString)) + return false; + if (!this.lString.equals(that.lString)) + return false; + } + + boolean this_present_lintString = true && this.isSetLintString(); + boolean that_present_lintString = true && that.isSetLintString(); + if (this_present_lintString || that_present_lintString) { + if (!(this_present_lintString && that_present_lintString)) + return false; + if (!this.lintString.equals(that.lintString)) + return false; + } + + boolean this_present_mStringString = true && this.isSetMStringString(); + boolean that_present_mStringString = true && that.isSetMStringString(); + if (this_present_mStringString || that_present_mStringString) { + if (!(this_present_mStringString && that_present_mStringString)) + return false; + if (!this.mStringString.equals(that.mStringString)) + return false; + } + + return true; + } + + @Override + public int hashCode() { + HashCodeBuilder builder = new HashCodeBuilder(); + + boolean present_aint = true; + builder.append(present_aint); + if (present_aint) + builder.append(aint); + + boolean present_aString = true && (isSetAString()); + builder.append(present_aString); + if (present_aString) + builder.append(aString); + + boolean present_lint = true && (isSetLint()); + builder.append(present_lint); + if (present_lint) + builder.append(lint); + + boolean present_lString = true && (isSetLString()); + builder.append(present_lString); + if (present_lString) + builder.append(lString); + + boolean present_lintString = true && (isSetLintString()); + builder.append(present_lintString); + if (present_lintString) + builder.append(lintString); + + boolean present_mStringString = true 
&& (isSetMStringString()); + builder.append(present_mStringString); + if (present_mStringString) + builder.append(mStringString); + + return builder.toHashCode(); + } + + public int compareTo(Complex other) { + if (!getClass().equals(other.getClass())) { + return getClass().getName().compareTo(other.getClass().getName()); + } + + int lastComparison = 0; + Complex typedOther = (Complex)other; + + lastComparison = Boolean.valueOf(isSetAint()).compareTo(typedOther.isSetAint()); + if (lastComparison != 0) { + return lastComparison; + } + if (isSetAint()) { + lastComparison = org.apache.thrift.TBaseHelper.compareTo(this.aint, typedOther.aint); + if (lastComparison != 0) { + return lastComparison; + } + } + lastComparison = Boolean.valueOf(isSetAString()).compareTo(typedOther.isSetAString()); + if (lastComparison != 0) { + return lastComparison; + } + if (isSetAString()) { + lastComparison = org.apache.thrift.TBaseHelper.compareTo(this.aString, typedOther.aString); + if (lastComparison != 0) { + return lastComparison; + } + } + lastComparison = Boolean.valueOf(isSetLint()).compareTo(typedOther.isSetLint()); + if (lastComparison != 0) { + return lastComparison; + } + if (isSetLint()) { + lastComparison = org.apache.thrift.TBaseHelper.compareTo(this.lint, typedOther.lint); + if (lastComparison != 0) { + return lastComparison; + } + } + lastComparison = Boolean.valueOf(isSetLString()).compareTo(typedOther.isSetLString()); + if (lastComparison != 0) { + return lastComparison; + } + if (isSetLString()) { + lastComparison = org.apache.thrift.TBaseHelper.compareTo(this.lString, typedOther.lString); + if (lastComparison != 0) { + return lastComparison; + } + } + lastComparison = Boolean.valueOf(isSetLintString()).compareTo(typedOther.isSetLintString()); + if (lastComparison != 0) { + return lastComparison; + } + if (isSetLintString()) { + lastComparison = org.apache.thrift.TBaseHelper.compareTo(this.lintString, typedOther.lintString); + if (lastComparison != 0) { + return lastComparison; + } + } + lastComparison = Boolean.valueOf(isSetMStringString()).compareTo(typedOther.isSetMStringString()); + if (lastComparison != 0) { + return lastComparison; + } + if (isSetMStringString()) { + lastComparison = org.apache.thrift.TBaseHelper.compareTo(this.mStringString, typedOther.mStringString); + if (lastComparison != 0) { + return lastComparison; + } + } + return 0; + } + + public _Fields fieldForId(int fieldId) { + return _Fields.findByThriftId(fieldId); + } + + public void read(org.apache.thrift.protocol.TProtocol iprot) throws org.apache.thrift.TException { + schemes.get(iprot.getScheme()).getScheme().read(iprot, this); + } + + public void write(org.apache.thrift.protocol.TProtocol oprot) throws org.apache.thrift.TException { + schemes.get(oprot.getScheme()).getScheme().write(oprot, this); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder("Complex("); + boolean first = true; + + sb.append("aint:"); + sb.append(this.aint); + first = false; + if (!first) sb.append(", "); + sb.append("aString:"); + if (this.aString == null) { + sb.append("null"); + } else { + sb.append(this.aString); + } + first = false; + if (!first) sb.append(", "); + sb.append("lint:"); + if (this.lint == null) { + sb.append("null"); + } else { + sb.append(this.lint); + } + first = false; + if (!first) sb.append(", "); + sb.append("lString:"); + if (this.lString == null) { + sb.append("null"); + } else { + sb.append(this.lString); + } + first = false; + if (!first) sb.append(", "); + 
sb.append("lintString:"); + if (this.lintString == null) { + sb.append("null"); + } else { + sb.append(this.lintString); + } + first = false; + if (!first) sb.append(", "); + sb.append("mStringString:"); + if (this.mStringString == null) { + sb.append("null"); + } else { + sb.append(this.mStringString); + } + first = false; + sb.append(")"); + return sb.toString(); + } + + public void validate() throws org.apache.thrift.TException { + // check for required fields + // check for sub-struct validity + } + + private void writeObject(java.io.ObjectOutputStream out) throws java.io.IOException { + try { + write(new org.apache.thrift.protocol.TCompactProtocol(new org.apache.thrift.transport.TIOStreamTransport(out))); + } catch (org.apache.thrift.TException te) { + throw new java.io.IOException(te); + } + } + + private void readObject(java.io.ObjectInputStream in) throws java.io.IOException, ClassNotFoundException { + try { + // it doesn't seem like you should have to do this, but java serialization is wacky, and doesn't call the default constructor. + __isset_bitfield = 0; + read(new org.apache.thrift.protocol.TCompactProtocol(new org.apache.thrift.transport.TIOStreamTransport(in))); + } catch (org.apache.thrift.TException te) { + throw new java.io.IOException(te); + } + } + + private static class ComplexStandardSchemeFactory implements SchemeFactory { + public ComplexStandardScheme getScheme() { + return new ComplexStandardScheme(); + } + } + + private static class ComplexStandardScheme extends StandardScheme { + + public void read(org.apache.thrift.protocol.TProtocol iprot, Complex struct) throws org.apache.thrift.TException { + org.apache.thrift.protocol.TField schemeField; + iprot.readStructBegin(); + while (true) + { + schemeField = iprot.readFieldBegin(); + if (schemeField.type == org.apache.thrift.protocol.TType.STOP) { + break; + } + switch (schemeField.id) { + case 1: // AINT + if (schemeField.type == org.apache.thrift.protocol.TType.I32) { + struct.aint = iprot.readI32(); + struct.setAintIsSet(true); + } else { + org.apache.thrift.protocol.TProtocolUtil.skip(iprot, schemeField.type); + } + break; + case 2: // A_STRING + if (schemeField.type == org.apache.thrift.protocol.TType.STRING) { + struct.aString = iprot.readString(); + struct.setAStringIsSet(true); + } else { + org.apache.thrift.protocol.TProtocolUtil.skip(iprot, schemeField.type); + } + break; + case 3: // LINT + if (schemeField.type == org.apache.thrift.protocol.TType.LIST) { + { + org.apache.thrift.protocol.TList _list0 = iprot.readListBegin(); + struct.lint = new ArrayList(_list0.size); + for (int _i1 = 0; _i1 < _list0.size; ++_i1) + { + int _elem2; // required + _elem2 = iprot.readI32(); + struct.lint.add(_elem2); + } + iprot.readListEnd(); + } + struct.setLintIsSet(true); + } else { + org.apache.thrift.protocol.TProtocolUtil.skip(iprot, schemeField.type); + } + break; + case 4: // L_STRING + if (schemeField.type == org.apache.thrift.protocol.TType.LIST) { + { + org.apache.thrift.protocol.TList _list3 = iprot.readListBegin(); + struct.lString = new ArrayList(_list3.size); + for (int _i4 = 0; _i4 < _list3.size; ++_i4) + { + String _elem5; // required + _elem5 = iprot.readString(); + struct.lString.add(_elem5); + } + iprot.readListEnd(); + } + struct.setLStringIsSet(true); + } else { + org.apache.thrift.protocol.TProtocolUtil.skip(iprot, schemeField.type); + } + break; + case 5: // LINT_STRING + if (schemeField.type == org.apache.thrift.protocol.TType.LIST) { + { + org.apache.thrift.protocol.TList _list6 = 
iprot.readListBegin(); + struct.lintString = new ArrayList(_list6.size); + for (int _i7 = 0; _i7 < _list6.size; ++_i7) + { + IntString _elem8; // required + _elem8 = new IntString(); + _elem8.read(iprot); + struct.lintString.add(_elem8); + } + iprot.readListEnd(); + } + struct.setLintStringIsSet(true); + } else { + org.apache.thrift.protocol.TProtocolUtil.skip(iprot, schemeField.type); + } + break; + case 6: // M_STRING_STRING + if (schemeField.type == org.apache.thrift.protocol.TType.MAP) { + { + org.apache.thrift.protocol.TMap _map9 = iprot.readMapBegin(); + struct.mStringString = new HashMap(2*_map9.size); + for (int _i10 = 0; _i10 < _map9.size; ++_i10) + { + String _key11; // required + String _val12; // required + _key11 = iprot.readString(); + _val12 = iprot.readString(); + struct.mStringString.put(_key11, _val12); + } + iprot.readMapEnd(); + } + struct.setMStringStringIsSet(true); + } else { + org.apache.thrift.protocol.TProtocolUtil.skip(iprot, schemeField.type); + } + break; + default: + org.apache.thrift.protocol.TProtocolUtil.skip(iprot, schemeField.type); + } + iprot.readFieldEnd(); + } + iprot.readStructEnd(); + struct.validate(); + } + + public void write(org.apache.thrift.protocol.TProtocol oprot, Complex struct) throws org.apache.thrift.TException { + struct.validate(); + + oprot.writeStructBegin(STRUCT_DESC); + oprot.writeFieldBegin(AINT_FIELD_DESC); + oprot.writeI32(struct.aint); + oprot.writeFieldEnd(); + if (struct.aString != null) { + oprot.writeFieldBegin(A_STRING_FIELD_DESC); + oprot.writeString(struct.aString); + oprot.writeFieldEnd(); + } + if (struct.lint != null) { + oprot.writeFieldBegin(LINT_FIELD_DESC); + { + oprot.writeListBegin(new org.apache.thrift.protocol.TList(org.apache.thrift.protocol.TType.I32, struct.lint.size())); + for (int _iter13 : struct.lint) + { + oprot.writeI32(_iter13); + } + oprot.writeListEnd(); + } + oprot.writeFieldEnd(); + } + if (struct.lString != null) { + oprot.writeFieldBegin(L_STRING_FIELD_DESC); + { + oprot.writeListBegin(new org.apache.thrift.protocol.TList(org.apache.thrift.protocol.TType.STRING, struct.lString.size())); + for (String _iter14 : struct.lString) + { + oprot.writeString(_iter14); + } + oprot.writeListEnd(); + } + oprot.writeFieldEnd(); + } + if (struct.lintString != null) { + oprot.writeFieldBegin(LINT_STRING_FIELD_DESC); + { + oprot.writeListBegin(new org.apache.thrift.protocol.TList(org.apache.thrift.protocol.TType.STRUCT, struct.lintString.size())); + for (IntString _iter15 : struct.lintString) + { + _iter15.write(oprot); + } + oprot.writeListEnd(); + } + oprot.writeFieldEnd(); + } + if (struct.mStringString != null) { + oprot.writeFieldBegin(M_STRING_STRING_FIELD_DESC); + { + oprot.writeMapBegin(new org.apache.thrift.protocol.TMap(org.apache.thrift.protocol.TType.STRING, org.apache.thrift.protocol.TType.STRING, struct.mStringString.size())); + for (Map.Entry _iter16 : struct.mStringString.entrySet()) + { + oprot.writeString(_iter16.getKey()); + oprot.writeString(_iter16.getValue()); + } + oprot.writeMapEnd(); + } + oprot.writeFieldEnd(); + } + oprot.writeFieldStop(); + oprot.writeStructEnd(); + } + + } + + private static class ComplexTupleSchemeFactory implements SchemeFactory { + public ComplexTupleScheme getScheme() { + return new ComplexTupleScheme(); + } + } + + private static class ComplexTupleScheme extends TupleScheme { + + @Override + public void write(org.apache.thrift.protocol.TProtocol prot, Complex struct) throws org.apache.thrift.TException { + TTupleProtocol oprot = (TTupleProtocol) prot; + BitSet 
optionals = new BitSet(); + if (struct.isSetAint()) { + optionals.set(0); + } + if (struct.isSetAString()) { + optionals.set(1); + } + if (struct.isSetLint()) { + optionals.set(2); + } + if (struct.isSetLString()) { + optionals.set(3); + } + if (struct.isSetLintString()) { + optionals.set(4); + } + if (struct.isSetMStringString()) { + optionals.set(5); + } + oprot.writeBitSet(optionals, 6); + if (struct.isSetAint()) { + oprot.writeI32(struct.aint); + } + if (struct.isSetAString()) { + oprot.writeString(struct.aString); + } + if (struct.isSetLint()) { + { + oprot.writeI32(struct.lint.size()); + for (int _iter17 : struct.lint) + { + oprot.writeI32(_iter17); + } + } + } + if (struct.isSetLString()) { + { + oprot.writeI32(struct.lString.size()); + for (String _iter18 : struct.lString) + { + oprot.writeString(_iter18); + } + } + } + if (struct.isSetLintString()) { + { + oprot.writeI32(struct.lintString.size()); + for (IntString _iter19 : struct.lintString) + { + _iter19.write(oprot); + } + } + } + if (struct.isSetMStringString()) { + { + oprot.writeI32(struct.mStringString.size()); + for (Map.Entry _iter20 : struct.mStringString.entrySet()) + { + oprot.writeString(_iter20.getKey()); + oprot.writeString(_iter20.getValue()); + } + } + } + } + + @Override + public void read(org.apache.thrift.protocol.TProtocol prot, Complex struct) throws org.apache.thrift.TException { + TTupleProtocol iprot = (TTupleProtocol) prot; + BitSet incoming = iprot.readBitSet(6); + if (incoming.get(0)) { + struct.aint = iprot.readI32(); + struct.setAintIsSet(true); + } + if (incoming.get(1)) { + struct.aString = iprot.readString(); + struct.setAStringIsSet(true); + } + if (incoming.get(2)) { + { + org.apache.thrift.protocol.TList _list21 = new org.apache.thrift.protocol.TList(org.apache.thrift.protocol.TType.I32, iprot.readI32()); + struct.lint = new ArrayList(_list21.size); + for (int _i22 = 0; _i22 < _list21.size; ++_i22) + { + int _elem23; // required + _elem23 = iprot.readI32(); + struct.lint.add(_elem23); + } + } + struct.setLintIsSet(true); + } + if (incoming.get(3)) { + { + org.apache.thrift.protocol.TList _list24 = new org.apache.thrift.protocol.TList(org.apache.thrift.protocol.TType.STRING, iprot.readI32()); + struct.lString = new ArrayList(_list24.size); + for (int _i25 = 0; _i25 < _list24.size; ++_i25) + { + String _elem26; // required + _elem26 = iprot.readString(); + struct.lString.add(_elem26); + } + } + struct.setLStringIsSet(true); + } + if (incoming.get(4)) { + { + org.apache.thrift.protocol.TList _list27 = new org.apache.thrift.protocol.TList(org.apache.thrift.protocol.TType.STRUCT, iprot.readI32()); + struct.lintString = new ArrayList(_list27.size); + for (int _i28 = 0; _i28 < _list27.size; ++_i28) + { + IntString _elem29; // required + _elem29 = new IntString(); + _elem29.read(iprot); + struct.lintString.add(_elem29); + } + } + struct.setLintStringIsSet(true); + } + if (incoming.get(5)) { + { + org.apache.thrift.protocol.TMap _map30 = new org.apache.thrift.protocol.TMap(org.apache.thrift.protocol.TType.STRING, org.apache.thrift.protocol.TType.STRING, iprot.readI32()); + struct.mStringString = new HashMap(2*_map30.size); + for (int _i31 = 0; _i31 < _map30.size; ++_i31) + { + String _key32; // required + String _val33; // required + _key32 = iprot.readString(); + _val33 = iprot.readString(); + struct.mStringString.put(_key32, _val33); + } + } + struct.setMStringStringIsSet(true); + } + } + } + +} + diff --git a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java 
b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java index 64d1ce92931e..15c2c3deb0d8 100644 --- a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java +++ b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java @@ -90,8 +90,10 @@ public void setUp() throws IOException { @After public void tearDown() throws IOException { // Clean up tables. - sqlContext.sql("DROP TABLE IF EXISTS javaSavedTable"); - sqlContext.sql("DROP TABLE IF EXISTS externalTable"); + if (sqlContext != null) { + sqlContext.sql("DROP TABLE IF EXISTS javaSavedTable"); + sqlContext.sql("DROP TABLE IF EXISTS externalTable"); + } } @Test diff --git a/sql/hive/src/test/resources/golden/! operator-0-ee7f6a60a9792041b85b18cda56429bf b/sql/hive/src/test/resources/golden/! operator-0-ee7f6a60a9792041b85b18cda56429bf new file mode 100644 index 000000000000..d00491fd7e5b --- /dev/null +++ b/sql/hive/src/test/resources/golden/! operator-0-ee7f6a60a9792041b85b18cda56429bf @@ -0,0 +1 @@ +1 diff --git a/sql/hive/src/test/resources/golden/convert_enum_to_string-1-db089ff46f9826c7883198adacdfad59 b/sql/hive/src/test/resources/golden/convert_enum_to_string-1-db089ff46f9826c7883198adacdfad59 index d35bf9093ca9..2383bef94097 100644 --- a/sql/hive/src/test/resources/golden/convert_enum_to_string-1-db089ff46f9826c7883198adacdfad59 +++ b/sql/hive/src/test/resources/golden/convert_enum_to_string-1-db089ff46f9826c7883198adacdfad59 @@ -15,9 +15,9 @@ my_enum_structlist_map map from deserializer my_structlist array>> from deserializer my_enumlist array from deserializer -my_stringset struct<> from deserializer -my_enumset struct<> from deserializer -my_structset struct<> from deserializer +my_stringset array from deserializer +my_enumset array from deserializer +my_structset array>> from deserializer optionals struct<> from deserializer b string diff --git a/sql/hive/src/test/resources/golden/parenthesis_star_by-5-6888c7f7894910538d82eefa23443189 b/sql/hive/src/test/resources/golden/parenthesis_star_by-5-41d474f5e6d7c61c36f74b4bec4e9e44 similarity index 100% rename from sql/hive/src/test/resources/golden/parenthesis_star_by-5-6888c7f7894910538d82eefa23443189 rename to sql/hive/src/test/resources/golden/parenthesis_star_by-5-41d474f5e6d7c61c36f74b4bec4e9e44 diff --git a/sql/hive/src/test/resources/golden/show_create_table_alter-3-2a91d52719cf4552ebeb867204552a26 b/sql/hive/src/test/resources/golden/show_create_table_alter-3-2a91d52719cf4552ebeb867204552a26 index 501bb6ab32f2..7bb2c0ab4398 100644 --- a/sql/hive/src/test/resources/golden/show_create_table_alter-3-2a91d52719cf4552ebeb867204552a26 +++ b/sql/hive/src/test/resources/golden/show_create_table_alter-3-2a91d52719cf4552ebeb867204552a26 @@ -1,4 +1,4 @@ -CREATE TABLE `tmp_showcrt1`( +CREATE TABLE `tmp_showcrt1`( `key` smallint, `value` float) COMMENT 'temporary table' diff --git a/sql/hive/src/test/resources/golden/show_create_table_db_table-4-b585371b624cbab2616a49f553a870a0 b/sql/hive/src/test/resources/golden/show_create_table_db_table-4-b585371b624cbab2616a49f553a870a0 index 90f8415a1c6b..3cc1a57ee3a4 100644 --- a/sql/hive/src/test/resources/golden/show_create_table_db_table-4-b585371b624cbab2616a49f553a870a0 +++ b/sql/hive/src/test/resources/golden/show_create_table_db_table-4-b585371b624cbab2616a49f553a870a0 @@ -1,4 +1,4 @@ -CREATE TABLE `tmp_feng.tmp_showcrt`( +CREATE TABLE `tmp_feng.tmp_showcrt`( `key` string, `value` int) ROW FORMAT SERDE diff --git 
a/sql/hive/src/test/resources/golden/show_create_table_delimited-1-2a91d52719cf4552ebeb867204552a26 b/sql/hive/src/test/resources/golden/show_create_table_delimited-1-2a91d52719cf4552ebeb867204552a26 index 4ee22e523031..b51c71a71f91 100644 --- a/sql/hive/src/test/resources/golden/show_create_table_delimited-1-2a91d52719cf4552ebeb867204552a26 +++ b/sql/hive/src/test/resources/golden/show_create_table_delimited-1-2a91d52719cf4552ebeb867204552a26 @@ -1,4 +1,4 @@ -CREATE TABLE `tmp_showcrt1`( +CREATE TABLE `tmp_showcrt1`( `key` int, `value` string, `newvalue` bigint) diff --git a/sql/hive/src/test/resources/golden/show_create_table_serde-1-2a91d52719cf4552ebeb867204552a26 b/sql/hive/src/test/resources/golden/show_create_table_serde-1-2a91d52719cf4552ebeb867204552a26 index 6fda2570b53f..29189e1d860a 100644 --- a/sql/hive/src/test/resources/golden/show_create_table_serde-1-2a91d52719cf4552ebeb867204552a26 +++ b/sql/hive/src/test/resources/golden/show_create_table_serde-1-2a91d52719cf4552ebeb867204552a26 @@ -1,4 +1,4 @@ -CREATE TABLE `tmp_showcrt1`( +CREATE TABLE `tmp_showcrt1`( `key` int, `value` string, `newvalue` bigint) diff --git a/sql/hive/src/test/resources/golden/show_functions-0-45a7762c39f1b0f26f076220e2764043 b/sql/hive/src/test/resources/golden/show_functions-0-45a7762c39f1b0f26f076220e2764043 index 3049cd6243ad..1b283db3e774 100644 --- a/sql/hive/src/test/resources/golden/show_functions-0-45a7762c39f1b0f26f076220e2764043 +++ b/sql/hive/src/test/resources/golden/show_functions-0-45a7762c39f1b0f26f076220e2764043 @@ -17,6 +17,7 @@ ^ abs acos +add_months and array array_contains @@ -29,6 +30,7 @@ base64 between bin case +cbrt ceil ceiling coalesce @@ -47,7 +49,11 @@ covar_samp create_union cume_dist current_database +current_date +current_timestamp +current_user date_add +date_format date_sub datediff day @@ -65,6 +71,7 @@ ewah_bitmap_empty ewah_bitmap_or exp explode +factorial field find_in_set first_value @@ -73,6 +80,7 @@ format_number from_unixtime from_utc_timestamp get_json_object +greatest hash hex histogram_numeric @@ -81,6 +89,7 @@ if in in_file index +initcap inline instr isnotnull @@ -88,10 +97,13 @@ isnull java_method json_tuple lag +last_day last_value lcase lead +least length +levenshtein like ln locate @@ -109,11 +121,15 @@ max min minute month +months_between named_struct negative +next_day ngrams noop +noopstreaming noopwithmap +noopwithmapstreaming not ntile nvl @@ -147,10 +163,14 @@ rpad rtrim second sentences +shiftleft +shiftright +shiftrightunsigned sign sin size sort_array +soundex space split sqrt @@ -170,6 +190,7 @@ to_unix_timestamp to_utc_timestamp translate trim +trunc ucase unbase64 unhex diff --git a/sql/hive/src/test/resources/golden/show_tblproperties-1-be4adb893c7f946ebd76a648ce3cc1ae b/sql/hive/src/test/resources/golden/show_tblproperties-1-be4adb893c7f946ebd76a648ce3cc1ae index 0f6cc6f44f1f..fdf701f96280 100644 --- a/sql/hive/src/test/resources/golden/show_tblproperties-1-be4adb893c7f946ebd76a648ce3cc1ae +++ b/sql/hive/src/test/resources/golden/show_tblproperties-1-be4adb893c7f946ebd76a648ce3cc1ae @@ -1 +1 @@ -Table tmpfoo does not have property: bar +Table default.tmpfoo does not have property: bar diff --git a/sql/hive/src/test/resources/golden/udf_date_add-1-efb60fcbd6d78ad35257fb1ec39ace2 b/sql/hive/src/test/resources/golden/udf_date_add-1-efb60fcbd6d78ad35257fb1ec39ace2 index 3c91e138d7bd..d8ec084f0b2b 100644 --- a/sql/hive/src/test/resources/golden/udf_date_add-1-efb60fcbd6d78ad35257fb1ec39ace2 +++ 
b/sql/hive/src/test/resources/golden/udf_date_add-1-efb60fcbd6d78ad35257fb1ec39ace2 @@ -1,5 +1,5 @@ date_add(start_date, num_days) - Returns the date that is num_days after start_date. start_date is a string in the format 'yyyy-MM-dd HH:mm:ss' or 'yyyy-MM-dd'. num_days is a number. The time part of start_date is ignored. Example: - > SELECT date_add('2009-30-07', 1) FROM src LIMIT 1; - '2009-31-07' + > SELECT date_add('2009-07-30', 1) FROM src LIMIT 1; + '2009-07-31' diff --git a/sql/hive/src/test/resources/golden/udf_date_sub-1-7efeb74367835ade71e5e42b22f8ced4 b/sql/hive/src/test/resources/golden/udf_date_sub-1-7efeb74367835ade71e5e42b22f8ced4 index 29d663f35c58..169c50003625 100644 --- a/sql/hive/src/test/resources/golden/udf_date_sub-1-7efeb74367835ade71e5e42b22f8ced4 +++ b/sql/hive/src/test/resources/golden/udf_date_sub-1-7efeb74367835ade71e5e42b22f8ced4 @@ -1,5 +1,5 @@ date_sub(start_date, num_days) - Returns the date that is num_days before start_date. start_date is a string in the format 'yyyy-MM-dd HH:mm:ss' or 'yyyy-MM-dd'. num_days is a number. The time part of start_date is ignored. Example: - > SELECT date_sub('2009-30-07', 1) FROM src LIMIT 1; - '2009-29-07' + > SELECT date_sub('2009-07-30', 1) FROM src LIMIT 1; + '2009-07-29' diff --git a/sql/hive/src/test/resources/golden/udf_datediff-1-34ae7a68b13c2bc9a89f61acf2edd4c5 b/sql/hive/src/test/resources/golden/udf_datediff-1-34ae7a68b13c2bc9a89f61acf2edd4c5 index 7ccaee7ad3bd..42197f7ad3e5 100644 --- a/sql/hive/src/test/resources/golden/udf_datediff-1-34ae7a68b13c2bc9a89f61acf2edd4c5 +++ b/sql/hive/src/test/resources/golden/udf_datediff-1-34ae7a68b13c2bc9a89f61acf2edd4c5 @@ -1,5 +1,5 @@ datediff(date1, date2) - Returns the number of days between date1 and date2 date1 and date2 are strings in the format 'yyyy-MM-dd HH:mm:ss' or 'yyyy-MM-dd'. The time parts are ignored.If date1 is earlier than date2, the result is negative. Example: - > SELECT datediff('2009-30-07', '2009-31-07') FROM src LIMIT 1; + > SELECT datediff('2009-07-30', '2009-07-31') FROM src LIMIT 1; 1 diff --git a/sql/hive/src/test/resources/golden/udf_day-0-c4c503756384ff1220222d84fd25e756 b/sql/hive/src/test/resources/golden/udf_day-0-c4c503756384ff1220222d84fd25e756 index d4017178b4e6..09703d10eab7 100644 --- a/sql/hive/src/test/resources/golden/udf_day-0-c4c503756384ff1220222d84fd25e756 +++ b/sql/hive/src/test/resources/golden/udf_day-0-c4c503756384ff1220222d84fd25e756 @@ -1 +1 @@ -day(date) - Returns the date of the month of date +day(param) - Returns the day of the month of date/timestamp, or day component of interval diff --git a/sql/hive/src/test/resources/golden/udf_day-1-87168babe1110fe4c38269843414ca4 b/sql/hive/src/test/resources/golden/udf_day-1-87168babe1110fe4c38269843414ca4 index 6135aafa5086..7c0ec1dc3be5 100644 --- a/sql/hive/src/test/resources/golden/udf_day-1-87168babe1110fe4c38269843414ca4 +++ b/sql/hive/src/test/resources/golden/udf_day-1-87168babe1110fe4c38269843414ca4 @@ -1,6 +1,9 @@ -day(date) - Returns the date of the month of date +day(param) - Returns the day of the month of date/timestamp, or day component of interval Synonyms: dayofmonth -date is a string in the format of 'yyyy-MM-dd HH:mm:ss' or 'yyyy-MM-dd'. -Example: - > SELECT day('2009-30-07', 1) FROM src LIMIT 1; +param can be one of: +1. A string in the format of 'yyyy-MM-dd HH:mm:ss' or 'yyyy-MM-dd'. +2. A date value +3. A timestamp value +4. 
A day-time interval valueExample: + > SELECT day('2009-07-30') FROM src LIMIT 1; 30 diff --git a/sql/hive/src/test/resources/golden/udf_dayofmonth-0-7b2caf942528656555cf19c261a18502 b/sql/hive/src/test/resources/golden/udf_dayofmonth-0-7b2caf942528656555cf19c261a18502 index 47a7018d9d5a..c37eb0ec2e96 100644 --- a/sql/hive/src/test/resources/golden/udf_dayofmonth-0-7b2caf942528656555cf19c261a18502 +++ b/sql/hive/src/test/resources/golden/udf_dayofmonth-0-7b2caf942528656555cf19c261a18502 @@ -1 +1 @@ -dayofmonth(date) - Returns the date of the month of date +dayofmonth(param) - Returns the day of the month of date/timestamp, or day component of interval diff --git a/sql/hive/src/test/resources/golden/udf_dayofmonth-1-ca24d07102ad264d79ff30c64a73a7e8 b/sql/hive/src/test/resources/golden/udf_dayofmonth-1-ca24d07102ad264d79ff30c64a73a7e8 index d9490e20a3b6..9e931f649914 100644 --- a/sql/hive/src/test/resources/golden/udf_dayofmonth-1-ca24d07102ad264d79ff30c64a73a7e8 +++ b/sql/hive/src/test/resources/golden/udf_dayofmonth-1-ca24d07102ad264d79ff30c64a73a7e8 @@ -1,6 +1,9 @@ -dayofmonth(date) - Returns the date of the month of date +dayofmonth(param) - Returns the day of the month of date/timestamp, or day component of interval Synonyms: day -date is a string in the format of 'yyyy-MM-dd HH:mm:ss' or 'yyyy-MM-dd'. -Example: - > SELECT dayofmonth('2009-30-07', 1) FROM src LIMIT 1; +param can be one of: +1. A string in the format of 'yyyy-MM-dd HH:mm:ss' or 'yyyy-MM-dd'. +2. A date value +3. A timestamp value +4. A day-time interval valueExample: + > SELECT dayofmonth('2009-07-30') FROM src LIMIT 1; 30 diff --git a/sql/hive/src/test/resources/golden/udf_if-0-b7ffa85b5785cccef2af1b285348cc2c b/sql/hive/src/test/resources/golden/udf_if-0-b7ffa85b5785cccef2af1b285348cc2c index 2cf0d9d61882..ce583fe81ff6 100644 --- a/sql/hive/src/test/resources/golden/udf_if-0-b7ffa85b5785cccef2af1b285348cc2c +++ b/sql/hive/src/test/resources/golden/udf_if-0-b7ffa85b5785cccef2af1b285348cc2c @@ -1 +1 @@ -There is no documentation for function 'if' +IF(expr1,expr2,expr3) - If expr1 is TRUE (expr1 <> 0 and expr1 <> NULL) then IF() returns expr2; otherwise it returns expr3. IF() returns a numeric or string value, depending on the context in which it is used. diff --git a/sql/hive/src/test/resources/golden/udf_if-1-30cf7f51f92b5684e556deff3032d49a b/sql/hive/src/test/resources/golden/udf_if-1-30cf7f51f92b5684e556deff3032d49a index 2cf0d9d61882..ce583fe81ff6 100644 --- a/sql/hive/src/test/resources/golden/udf_if-1-30cf7f51f92b5684e556deff3032d49a +++ b/sql/hive/src/test/resources/golden/udf_if-1-30cf7f51f92b5684e556deff3032d49a @@ -1 +1 @@ -There is no documentation for function 'if' +IF(expr1,expr2,expr3) - If expr1 is TRUE (expr1 <> 0 and expr1 <> NULL) then IF() returns expr2; otherwise it returns expr3. IF() returns a numeric or string value, depending on the context in which it is used. diff --git a/sql/hive/src/test/resources/golden/udf_if-1-b7ffa85b5785cccef2af1b285348cc2c b/sql/hive/src/test/resources/golden/udf_if-1-b7ffa85b5785cccef2af1b285348cc2c index 2cf0d9d61882..ce583fe81ff6 100644 --- a/sql/hive/src/test/resources/golden/udf_if-1-b7ffa85b5785cccef2af1b285348cc2c +++ b/sql/hive/src/test/resources/golden/udf_if-1-b7ffa85b5785cccef2af1b285348cc2c @@ -1 +1 @@ -There is no documentation for function 'if' +IF(expr1,expr2,expr3) - If expr1 is TRUE (expr1 <> 0 and expr1 <> NULL) then IF() returns expr2; otherwise it returns expr3. IF() returns a numeric or string value, depending on the context in which it is used. 
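For reference, the golden outputs above record the corrected documentation and example results for the day/date UDFs. The sketch below is hypothetical and not part of this patch: it only illustrates how those corrected examples could be exercised against the TestHive context already used by the suites in this series; the suite name, test name, and expected values are taken from or inferred from the golden files above.

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.hive.test.TestHive

// Hypothetical sketch: runs the date/day UDFs whose documentation strings are
// corrected in the golden files above and checks the documented results.
class DateUdfGoldenSketch extends SparkFunSuite {
  test("date UDF examples from the corrected golden files") {
    val row = TestHive.sql(
      "SELECT day('2009-07-30'), dayofmonth('2009-07-30'), " +
        "date_add('2009-07-30', 1), date_sub('2009-07-30', 1), " +
        "datediff('2009-07-31', '2009-07-30')").collect().head
    assert(row.getInt(0) === 30)                       // day('2009-07-30')
    assert(row.getInt(1) === 30)                       // dayofmonth('2009-07-30')
    assert(row.get(2).toString === "2009-07-31")       // date_add, per golden example
    assert(row.get(3).toString === "2009-07-29")       // date_sub, per golden example
    assert(row.getInt(4) === 1)                        // datediff(later, earlier)
  }
}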
diff --git a/sql/hive/src/test/resources/golden/udf_if-2-30cf7f51f92b5684e556deff3032d49a b/sql/hive/src/test/resources/golden/udf_if-2-30cf7f51f92b5684e556deff3032d49a index 2cf0d9d61882..ce583fe81ff6 100644 --- a/sql/hive/src/test/resources/golden/udf_if-2-30cf7f51f92b5684e556deff3032d49a +++ b/sql/hive/src/test/resources/golden/udf_if-2-30cf7f51f92b5684e556deff3032d49a @@ -1 +1 @@ -There is no documentation for function 'if' +IF(expr1,expr2,expr3) - If expr1 is TRUE (expr1 <> 0 and expr1 <> NULL) then IF() returns expr2; otherwise it returns expr3. IF() returns a numeric or string value, depending on the context in which it is used. diff --git a/sql/hive/src/test/resources/golden/udf_minute-0-9a38997c1f41f4afe00faa0abc471aee b/sql/hive/src/test/resources/golden/udf_minute-0-9a38997c1f41f4afe00faa0abc471aee index 231e4f382566..06650592f8d3 100644 --- a/sql/hive/src/test/resources/golden/udf_minute-0-9a38997c1f41f4afe00faa0abc471aee +++ b/sql/hive/src/test/resources/golden/udf_minute-0-9a38997c1f41f4afe00faa0abc471aee @@ -1 +1 @@ -minute(date) - Returns the minute of date +minute(param) - Returns the minute component of the string/timestamp/interval diff --git a/sql/hive/src/test/resources/golden/udf_minute-1-16995573ac4f4a1b047ad6ee88699e48 b/sql/hive/src/test/resources/golden/udf_minute-1-16995573ac4f4a1b047ad6ee88699e48 index ea842ea174ae..08ddc19b84d8 100644 --- a/sql/hive/src/test/resources/golden/udf_minute-1-16995573ac4f4a1b047ad6ee88699e48 +++ b/sql/hive/src/test/resources/golden/udf_minute-1-16995573ac4f4a1b047ad6ee88699e48 @@ -1,6 +1,8 @@ -minute(date) - Returns the minute of date -date is a string in the format of 'yyyy-MM-dd HH:mm:ss' or 'HH:mm:ss'. -Example: +minute(param) - Returns the minute component of the string/timestamp/interval +param can be one of: +1. A string in the format of 'yyyy-MM-dd HH:mm:ss' or 'HH:mm:ss'. +2. A timestamp value +3. A day-time interval valueExample: > SELECT minute('2009-07-30 12:58:59') FROM src LIMIT 1; 58 > SELECT minute('12:58:59') FROM src LIMIT 1; diff --git a/sql/hive/src/test/resources/golden/udf_month-0-9a38997c1f41f4afe00faa0abc471aee b/sql/hive/src/test/resources/golden/udf_month-0-9a38997c1f41f4afe00faa0abc471aee index 231e4f382566..06650592f8d3 100644 --- a/sql/hive/src/test/resources/golden/udf_month-0-9a38997c1f41f4afe00faa0abc471aee +++ b/sql/hive/src/test/resources/golden/udf_month-0-9a38997c1f41f4afe00faa0abc471aee @@ -1 +1 @@ -minute(date) - Returns the minute of date +minute(param) - Returns the minute component of the string/timestamp/interval diff --git a/sql/hive/src/test/resources/golden/udf_month-1-16995573ac4f4a1b047ad6ee88699e48 b/sql/hive/src/test/resources/golden/udf_month-1-16995573ac4f4a1b047ad6ee88699e48 index ea842ea174ae..08ddc19b84d8 100644 --- a/sql/hive/src/test/resources/golden/udf_month-1-16995573ac4f4a1b047ad6ee88699e48 +++ b/sql/hive/src/test/resources/golden/udf_month-1-16995573ac4f4a1b047ad6ee88699e48 @@ -1,6 +1,8 @@ -minute(date) - Returns the minute of date -date is a string in the format of 'yyyy-MM-dd HH:mm:ss' or 'HH:mm:ss'. -Example: +minute(param) - Returns the minute component of the string/timestamp/interval +param can be one of: +1. A string in the format of 'yyyy-MM-dd HH:mm:ss' or 'HH:mm:ss'. +2. A timestamp value +3. 
A day-time interval valueExample: > SELECT minute('2009-07-30 12:58:59') FROM src LIMIT 1; 58 > SELECT minute('12:58:59') FROM src LIMIT 1; diff --git a/sql/hive/src/test/resources/golden/udf_std-1-6759bde0e50a3607b7c3fd5a93cbd027 b/sql/hive/src/test/resources/golden/udf_std-1-6759bde0e50a3607b7c3fd5a93cbd027 index d54ebfbd6fb1..a529b107ff21 100644 --- a/sql/hive/src/test/resources/golden/udf_std-1-6759bde0e50a3607b7c3fd5a93cbd027 +++ b/sql/hive/src/test/resources/golden/udf_std-1-6759bde0e50a3607b7c3fd5a93cbd027 @@ -1,2 +1,2 @@ std(x) - Returns the standard deviation of a set of numbers -Synonyms: stddev_pop, stddev +Synonyms: stddev, stddev_pop diff --git a/sql/hive/src/test/resources/golden/udf_stddev-1-18e1d598820013453fad45852e1a303d b/sql/hive/src/test/resources/golden/udf_stddev-1-18e1d598820013453fad45852e1a303d index 5f674788180e..ac3176a38254 100644 --- a/sql/hive/src/test/resources/golden/udf_stddev-1-18e1d598820013453fad45852e1a303d +++ b/sql/hive/src/test/resources/golden/udf_stddev-1-18e1d598820013453fad45852e1a303d @@ -1,2 +1,2 @@ stddev(x) - Returns the standard deviation of a set of numbers -Synonyms: stddev_pop, std +Synonyms: std, stddev_pop diff --git a/sql/hive/src/test/resources/golden/union3-0-6a8a35102de1b0b88c6721a704eb174d b/sql/hive/src/test/resources/golden/union3-0-99620f72f0282904846a596ca5b3e46c similarity index 100% rename from sql/hive/src/test/resources/golden/union3-0-6a8a35102de1b0b88c6721a704eb174d rename to sql/hive/src/test/resources/golden/union3-0-99620f72f0282904846a596ca5b3e46c diff --git a/sql/hive/src/test/resources/golden/union3-2-2a1dcd937f117f1955a169592b96d5f9 b/sql/hive/src/test/resources/golden/union3-2-90ca96ea59fd45cf0af8c020ae77c908 similarity index 100% rename from sql/hive/src/test/resources/golden/union3-2-2a1dcd937f117f1955a169592b96d5f9 rename to sql/hive/src/test/resources/golden/union3-2-90ca96ea59fd45cf0af8c020ae77c908 diff --git a/sql/hive/src/test/resources/golden/union3-3-8fc63f8edb2969a63cd4485f1867ba97 b/sql/hive/src/test/resources/golden/union3-3-72b149ccaef751bcfe55d5ca37cb5fd7 similarity index 100% rename from sql/hive/src/test/resources/golden/union3-3-8fc63f8edb2969a63cd4485f1867ba97 rename to sql/hive/src/test/resources/golden/union3-3-72b149ccaef751bcfe55d5ca37cb5fd7 diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/parenthesis_star_by.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/parenthesis_star_by.q index 9e036c1a91d3..e911fbf2d2c5 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/parenthesis_star_by.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/parenthesis_star_by.q @@ -5,6 +5,6 @@ SELECT * FROM (SELECT key, value FROM src DISTRIBUTE BY key, value)t ORDER BY ke SELECT key, value FROM src CLUSTER BY (key, value); -SELECT key, value FROM src ORDER BY (key ASC, value ASC); +SELECT key, value FROM src ORDER BY key ASC, value ASC; SELECT key, value FROM src SORT BY (key, value); SELECT * FROM (SELECT key, value FROM src DISTRIBUTE BY (key, value))t ORDER BY key, value; diff --git a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/union3.q b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/union3.q index b26a2e2799f7..a989800cbf85 100644 --- a/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/union3.q +++ b/sql/hive/src/test/resources/ql/src/test/queries/clientpositive/union3.q @@ -1,42 +1,41 @@ +-- SORT_QUERY_RESULTS explain SELECT * FROM ( SELECT 1 AS id FROM (SELECT * FROM src 
LIMIT 1) s1 - CLUSTER BY id UNION ALL SELECT 2 AS id FROM (SELECT * FROM src LIMIT 1) s1 - CLUSTER BY id UNION ALL SELECT 3 AS id FROM (SELECT * FROM src LIMIT 1) s2 UNION ALL SELECT 4 AS id FROM (SELECT * FROM src LIMIT 1) s2 + CLUSTER BY id ) a; CREATE TABLE union_out (id int); -insert overwrite table union_out +insert overwrite table union_out SELECT * FROM ( SELECT 1 AS id FROM (SELECT * FROM src LIMIT 1) s1 - CLUSTER BY id UNION ALL SELECT 2 AS id FROM (SELECT * FROM src LIMIT 1) s1 - CLUSTER BY id UNION ALL SELECT 3 AS id FROM (SELECT * FROM src LIMIT 1) s2 UNION ALL SELECT 4 AS id FROM (SELECT * FROM src LIMIT 1) s2 + CLUSTER BY id ) a; -select * from union_out cluster by id; +select * from union_out; diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ClasspathDependenciesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ClasspathDependenciesSuite.scala new file mode 100644 index 000000000000..34b2edb44b03 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ClasspathDependenciesSuite.scala @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import java.net.URL + +import org.apache.spark.SparkFunSuite + +/** + * Verify that some classes load and that others are not found on the classpath. + * + * + * This is used to detect classpath and shading conflict, especially between + * Spark's required Kryo version and that which can be found in some Hive versions. + */ +class ClasspathDependenciesSuite extends SparkFunSuite { + private val classloader = this.getClass.getClassLoader + + private def assertLoads(classname: String): Unit = { + val resourceURL: URL = Option(findResource(classname)).getOrElse { + fail(s"Class $classname not found as ${resourceName(classname)}") + } + + logInfo(s"Class $classname at $resourceURL") + classloader.loadClass(classname) + } + + private def assertLoads(classes: String*): Unit = { + classes.foreach(assertLoads) + } + + private def findResource(classname: String): URL = { + val resource = resourceName(classname) + classloader.getResource(resource) + } + + private def resourceName(classname: String): String = { + classname.replace(".", "/") + ".class" + } + + private def assertClassNotFound(classname: String): Unit = { + Option(findResource(classname)).foreach { resourceURL => + fail(s"Class $classname found at $resourceURL") + } + + intercept[ClassNotFoundException] { + classloader.loadClass(classname) + } + } + + private def assertClassNotFound(classes: String*): Unit = { + classes.foreach(assertClassNotFound) + } + + private val KRYO = "com.esotericsoftware.kryo.Kryo" + + private val SPARK_HIVE = "org.apache.hive." + private val SPARK_SHADED = "org.spark-project.hive.shaded." 
+ + test("shaded Protobuf") { + assertLoads(SPARK_SHADED + "com.google.protobuf.ServiceException") + } + + test("hive-common") { + assertLoads("org.apache.hadoop.hive.conf.HiveConf") + } + + test("hive-exec") { + assertLoads("org.apache.hadoop.hive.ql.CommandNeedRetryException") + } + + private val STD_INSTANTIATOR = "org.objenesis.strategy.StdInstantiatorStrategy" + + test("unshaded kryo") { + assertLoads(KRYO, STD_INSTANTIATOR) + } + + test("Forbidden Dependencies") { + assertClassNotFound( + SPARK_HIVE + KRYO, + SPARK_SHADED + KRYO, + "org.apache.hive." + KRYO, + "com.esotericsoftware.shaded." + STD_INSTANTIATOR, + SPARK_HIVE + "com.esotericsoftware.shaded." + STD_INSTANTIATOR, + "org.apache.hive.com.esotericsoftware.shaded." + STD_INSTANTIATOR + ) + } + + test("parquet-hadoop-bundle") { + assertLoads( + "parquet.hadoop.ParquetOutputFormat", + "parquet.hadoop.ParquetInputFormat" + ) + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala index 72b35959a491..b8d41065d3f0 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala @@ -19,8 +19,11 @@ package org.apache.spark.sql.hive import java.io.File +import scala.collection.mutable.ArrayBuffer import scala.sys.process.{ProcessLogger, Process} +import org.scalatest.exceptions.TestFailedDueToTimeoutException + import org.apache.spark._ import org.apache.spark.sql.hive.test.{TestHive, TestHiveContext} import org.apache.spark.util.{ResetSystemProperties, Utils} @@ -84,23 +87,39 @@ class HiveSparkSubmitSuite // This is copied from org.apache.spark.deploy.SparkSubmitSuite private def runSparkSubmit(args: Seq[String]): Unit = { val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!")) + val history = ArrayBuffer.empty[String] + val commands = Seq("./bin/spark-submit") ++ args + val commandLine = commands.mkString("'", "' '", "'") val process = Process( - Seq("./bin/spark-submit") ++ args, + commands, new File(sparkHome), "SPARK_TESTING" -> "1", "SPARK_HOME" -> sparkHome ).run(ProcessLogger( // scalastyle:off println - (line: String) => { println(s"out> $line") }, - (line: String) => { println(s"err> $line") } + (line: String) => { println(s"stdout> $line"); history += s"out> $line"}, + (line: String) => { println(s"stderr> $line"); history += s"err> $line" } // scalastyle:on println )) try { - val exitCode = failAfter(180 seconds) { process.exitValue() } + val exitCode = failAfter(180.seconds) { process.exitValue() } if (exitCode != 0) { - fail(s"Process returned with exit code $exitCode. See the log4j logs for more detail.") + // include logs in output. Note that logging is async and may not have completed + // at the time this exception is raised + Thread.sleep(1000) + val historyLog = history.mkString("\n") + fail(s"$commandLine returned with exit code $exitCode." + + s" See the log4j logs for more detail." + + s"\n$historyLog") } + } catch { + case to: TestFailedDueToTimeoutException => + val historyLog = history.mkString("\n") + fail(s"Timeout of $commandLine" + + s" See the log4j logs for more detail." 
+ + s"\n$historyLog", to) + case t: Throwable => throw t } finally { // Ensure we still kill the process in case it timed out process.destroy() diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala index 508695919e9a..d33e81227db8 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.hive import java.io.File +import org.apache.hadoop.hive.conf.HiveConf import org.scalatest.BeforeAndAfter import org.apache.spark.sql.execution.QueryExecutionException @@ -113,6 +114,8 @@ class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter { test("SPARK-4203:random partition directory order") { sql("CREATE TABLE tmp_table (key int, value string)") val tmpDir = Utils.createTempDir() + val stagingDir = new HiveConf().getVar(HiveConf.ConfVars.STAGINGDIR) + sql( s""" |CREATE TABLE table_with_partition(c1 string) @@ -145,7 +148,7 @@ class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter { """.stripMargin) def listFolders(path: File, acc: List[String]): List[List[String]] = { val dir = path.listFiles() - val folders = dir.filter(_.isDirectory).toList + val folders = dir.filter { e => e.isDirectory && !e.getName().startsWith(stagingDir) }.toList if (folders.isEmpty) { List(acc.reverse) } else { @@ -158,7 +161,7 @@ class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter { "p1=a"::"p2=b"::"p3=c"::"p4=c"::"p5=1"::Nil , "p1=a"::"p2=b"::"p3=c"::"p4=c"::"p5=4"::Nil ) - assert(listFolders(tmpDir, List()).sortBy(_.toString()) == expected.sortBy(_.toString)) + assert(listFolders(tmpDir, List()).sortBy(_.toString()) === expected.sortBy(_.toString)) sql("DROP TABLE table_with_partition") sql("DROP TABLE tmp_table") } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala index bb5f1febe9ad..f00d3754c364 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.hive +import org.apache.hadoop.hive.conf.HiveConf + import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.parquet.ParquetCompatibilityTest import org.apache.spark.sql.{Row, SQLConf, SQLContext} @@ -26,6 +28,13 @@ class ParquetHiveCompatibilitySuite extends ParquetCompatibilityTest { override val sqlContext: SQLContext = TestHive + /** + * Set the staging directory (and hence path to ignore Parquet files under) + * to that set by [[HiveConf.ConfVars.STAGINGDIR]]. + */ + override val stagingDir: Option[String] = + Some(new HiveConf().getVar(HiveConf.ConfVars.STAGINGDIR)) + override protected def beforeAll(): Unit = { super.beforeAll() diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index bc72b0172a46..e4fec7e2c8a2 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -54,6 +54,9 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { } } + // Ensure session state is initialized. 
+ ctx.parseSql("use default") + assertAnalyzeCommand( "ANALYZE TABLE Table1 COMPUTE STATISTICS", classOf[HiveNativeCommand]) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala index 3eb127e23d48..f0bb77092c0c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.hive.client import java.io.File +import org.apache.spark.sql.hive.HiveContext import org.apache.spark.{Logging, SparkFunSuite} import org.apache.spark.sql.catalyst.expressions.{NamedExpression, Literal, AttributeReference, EqualTo} import org.apache.spark.sql.catalyst.util.quietly @@ -48,7 +49,9 @@ class VersionsSuite extends SparkFunSuite with Logging { } test("success sanity check") { - val badClient = IsolatedClientLoader.forVersion("13", buildConf(), ivyPath).client + val badClient = IsolatedClientLoader.forVersion(HiveContext.hiveExecutionVersion, + buildConf(), + ivyPath).client val db = new HiveDatabase("default", "") badClient.createDatabase(db) } @@ -91,6 +94,7 @@ class VersionsSuite extends SparkFunSuite with Logging { versions.foreach { version => test(s"$version: create client") { client = null + System.gc() // Hack to avoid SEGV on some JVM versions. client = IsolatedClientLoader.forVersion(version, buildConf(), ivyPath).client } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 11a843becce6..a7cfac51cc09 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -52,14 +52,6 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles")) // Add Locale setting Locale.setDefault(Locale.US) - sql(s"ADD JAR ${TestHive.getHiveFile("TestUDTF.jar").getCanonicalPath()}") - // The function source code can be found at: - // https://cwiki.apache.org/confluence/display/Hive/DeveloperGuide+UDTF - sql( - """ - |CREATE TEMPORARY FUNCTION udtf_count2 - |AS 'org.apache.spark.sql.hive.execution.GenericUDTFCount2' - """.stripMargin) } override def afterAll() { @@ -69,15 +61,6 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { sql("DROP TEMPORARY FUNCTION udtf_count2") } - createQueryTest("Test UDTF.close in Lateral Views", - """ - |SELECT key, cc - |FROM src LATERAL VIEW udtf_count2(value) dd AS cc - """.stripMargin, false) // false mean we have to keep the temp function in registry - - createQueryTest("Test UDTF.close in SELECT", - "SELECT udtf_count2(a) FROM (SELECT 1 AS a FROM src LIMIT 3) table", false) - test("SPARK-4908: concurrent hive native commands") { (1 to 100).par.map { _ => sql("USE default") @@ -176,8 +159,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { createQueryTest("! 
operator", """ |SELECT a FROM ( - | SELECT 1 AS a FROM src LIMIT 1 UNION ALL - | SELECT 2 AS a FROM src LIMIT 1) table + | SELECT 1 AS a UNION ALL SELECT 2 AS a) t |WHERE !(a>1) """.stripMargin) @@ -229,71 +211,6 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { |FROM src LIMIT 1; """.stripMargin) - createQueryTest("count distinct 0 values", - """ - |SELECT COUNT(DISTINCT a) FROM ( - | SELECT 'a' AS a FROM src LIMIT 0) table - """.stripMargin) - - createQueryTest("count distinct 1 value strings", - """ - |SELECT COUNT(DISTINCT a) FROM ( - | SELECT 'a' AS a FROM src LIMIT 1 UNION ALL - | SELECT 'b' AS a FROM src LIMIT 1) table - """.stripMargin) - - createQueryTest("count distinct 1 value", - """ - |SELECT COUNT(DISTINCT a) FROM ( - | SELECT 1 AS a FROM src LIMIT 1 UNION ALL - | SELECT 1 AS a FROM src LIMIT 1) table - """.stripMargin) - - createQueryTest("count distinct 2 values", - """ - |SELECT COUNT(DISTINCT a) FROM ( - | SELECT 1 AS a FROM src LIMIT 1 UNION ALL - | SELECT 2 AS a FROM src LIMIT 1) table - """.stripMargin) - - createQueryTest("count distinct 2 values including null", - """ - |SELECT COUNT(DISTINCT a, 1) FROM ( - | SELECT 1 AS a FROM src LIMIT 1 UNION ALL - | SELECT 1 AS a FROM src LIMIT 1 UNION ALL - | SELECT null AS a FROM src LIMIT 1) table - """.stripMargin) - - createQueryTest("count distinct 1 value + null", - """ - |SELECT COUNT(DISTINCT a) FROM ( - | SELECT 1 AS a FROM src LIMIT 1 UNION ALL - | SELECT 1 AS a FROM src LIMIT 1 UNION ALL - | SELECT null AS a FROM src LIMIT 1) table - """.stripMargin) - - createQueryTest("count distinct 1 value long", - """ - |SELECT COUNT(DISTINCT a) FROM ( - | SELECT 1L AS a FROM src LIMIT 1 UNION ALL - | SELECT 1L AS a FROM src LIMIT 1) table - """.stripMargin) - - createQueryTest("count distinct 2 values long", - """ - |SELECT COUNT(DISTINCT a) FROM ( - | SELECT 1L AS a FROM src LIMIT 1 UNION ALL - | SELECT 2L AS a FROM src LIMIT 1) table - """.stripMargin) - - createQueryTest("count distinct 1 value + null long", - """ - |SELECT COUNT(DISTINCT a) FROM ( - | SELECT 1L AS a FROM src LIMIT 1 UNION ALL - | SELECT 1L AS a FROM src LIMIT 1 UNION ALL - | SELECT null AS a FROM src LIMIT 1) table - """.stripMargin) - createQueryTest("null case", "SELECT case when(true) then 1 else null end FROM src LIMIT 1") @@ -674,7 +591,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { sql( """ |SELECT a FROM ( - | SELECT 1 AS a FROM src LIMIT 1 ) table + | SELECT 1 AS a FROM src LIMIT 1 ) t |WHERE abs(20141202) is not null """.stripMargin).collect() } @@ -987,7 +904,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { .zip(parts) .map { case (k, v) => if (v == "NULL") { - s"$k=${ConfVars.DEFAULTPARTITIONNAME.defaultVal}" + s"$k=${ConfVars.DEFAULTPARTITIONNAME.defaultStrVal}" } else { s"$k=$v" } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala index e83a7dc77e32..3bf8f3ac2048 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala @@ -82,16 +82,16 @@ class PruningSuite extends HiveComparisonTest with BeforeAndAfter { Seq.empty) createPruningTest("Column pruning - non-trivial top project with aliases", - "SELECT c1 * 2 AS double FROM (SELECT key AS c1 FROM src WHERE key > 10) t1 LIMIT 3", - Seq("double"), + "SELECT c1 * 2 AS dbl FROM (SELECT key 
AS c1 FROM src WHERE key > 10) t1 LIMIT 3", + Seq("dbl"), Seq("key"), Seq.empty) // Partition pruning tests createPruningTest("Partition pruning - non-partitioned, non-trivial project", - "SELECT key * 2 AS double FROM src WHERE value IS NOT NULL", - Seq("double"), + "SELECT key * 2 AS dbl FROM src WHERE value IS NOT NULL", + Seq("dbl"), Seq("key", "value"), Seq.empty) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index c4923d83e48f..95c1da6e9796 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -67,6 +67,25 @@ class MyDialect extends DefaultParserDialect class SQLQuerySuite extends QueryTest with SQLTestUtils { override def sqlContext: SQLContext = TestHive + test("UDTF") { + sql(s"ADD JAR ${TestHive.getHiveFile("TestUDTF.jar").getCanonicalPath()}") + // The function source code can be found at: + // https://cwiki.apache.org/confluence/display/Hive/DeveloperGuide+UDTF + sql( + """ + |CREATE TEMPORARY FUNCTION udtf_count2 + |AS 'org.apache.spark.sql.hive.execution.GenericUDTFCount2' + """.stripMargin) + + checkAnswer( + sql("SELECT key, cc FROM src LATERAL VIEW udtf_count2(value) dd AS cc"), + Row(97, 500) :: Row(97, 500) :: Nil) + + checkAnswer( + sql("SELECT udtf_count2(a) FROM (SELECT 1 AS a FROM src LIMIT 3) t"), + Row(3) :: Row(3) :: Nil) + } + test("SPARK-6835: udtf in lateral view") { val df = Seq((1, 1)).toDF("c1", "c2") df.registerTempTable("table1") @@ -264,47 +283,51 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils { setConf(HiveContext.CONVERT_CTAS, true) - sql("CREATE TABLE ctas1 AS SELECT key k, value FROM src ORDER BY k, value") - sql("CREATE TABLE IF NOT EXISTS ctas1 AS SELECT key k, value FROM src ORDER BY k, value") - var message = intercept[AnalysisException] { + try { sql("CREATE TABLE ctas1 AS SELECT key k, value FROM src ORDER BY k, value") - }.getMessage - assert(message.contains("ctas1 already exists")) - checkRelation("ctas1", true) - sql("DROP TABLE ctas1") - - // Specifying database name for query can be converted to data source write path - // is not allowed right now. 
- message = intercept[AnalysisException] { - sql("CREATE TABLE default.ctas1 AS SELECT key k, value FROM src ORDER BY k, value") - }.getMessage - assert( - message.contains("Cannot specify database name in a CTAS statement"), - "When spark.sql.hive.convertCTAS is true, we should not allow " + - "database name specified.") - - sql("CREATE TABLE ctas1 stored as textfile AS SELECT key k, value FROM src ORDER BY k, value") - checkRelation("ctas1", true) - sql("DROP TABLE ctas1") - - sql( - "CREATE TABLE ctas1 stored as sequencefile AS SELECT key k, value FROM src ORDER BY k, value") - checkRelation("ctas1", true) - sql("DROP TABLE ctas1") - - sql("CREATE TABLE ctas1 stored as rcfile AS SELECT key k, value FROM src ORDER BY k, value") - checkRelation("ctas1", false) - sql("DROP TABLE ctas1") - - sql("CREATE TABLE ctas1 stored as orc AS SELECT key k, value FROM src ORDER BY k, value") - checkRelation("ctas1", false) - sql("DROP TABLE ctas1") - - sql("CREATE TABLE ctas1 stored as parquet AS SELECT key k, value FROM src ORDER BY k, value") - checkRelation("ctas1", false) - sql("DROP TABLE ctas1") - - setConf(HiveContext.CONVERT_CTAS, originalConf) + sql("CREATE TABLE IF NOT EXISTS ctas1 AS SELECT key k, value FROM src ORDER BY k, value") + var message = intercept[AnalysisException] { + sql("CREATE TABLE ctas1 AS SELECT key k, value FROM src ORDER BY k, value") + }.getMessage + assert(message.contains("ctas1 already exists")) + checkRelation("ctas1", true) + sql("DROP TABLE ctas1") + + // Specifying database name for query can be converted to data source write path + // is not allowed right now. + message = intercept[AnalysisException] { + sql("CREATE TABLE default.ctas1 AS SELECT key k, value FROM src ORDER BY k, value") + }.getMessage + assert( + message.contains("Cannot specify database name in a CTAS statement"), + "When spark.sql.hive.convertCTAS is true, we should not allow " + + "database name specified.") + + sql("CREATE TABLE ctas1 stored as textfile" + + " AS SELECT key k, value FROM src ORDER BY k, value") + checkRelation("ctas1", true) + sql("DROP TABLE ctas1") + + sql("CREATE TABLE ctas1 stored as sequencefile" + + " AS SELECT key k, value FROM src ORDER BY k, value") + checkRelation("ctas1", true) + sql("DROP TABLE ctas1") + + sql("CREATE TABLE ctas1 stored as rcfile AS SELECT key k, value FROM src ORDER BY k, value") + checkRelation("ctas1", false) + sql("DROP TABLE ctas1") + + sql("CREATE TABLE ctas1 stored as orc AS SELECT key k, value FROM src ORDER BY k, value") + checkRelation("ctas1", false) + sql("DROP TABLE ctas1") + + sql("CREATE TABLE ctas1 stored as parquet AS SELECT key k, value FROM src ORDER BY k, value") + checkRelation("ctas1", false) + sql("DROP TABLE ctas1") + } finally { + setConf(HiveContext.CONVERT_CTAS, originalConf) + sql("DROP TABLE IF EXISTS ctas1") + } } test("SQL Dialect Switching") { @@ -670,22 +693,25 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils { val originalConf = convertCTAS setConf(HiveContext.CONVERT_CTAS, false) - sql("CREATE TABLE explodeTest (key bigInt)") - table("explodeTest").queryExecution.analyzed match { - case metastoreRelation: MetastoreRelation => // OK - case _ => - fail("To correctly test the fix of SPARK-5875, explodeTest should be a MetastoreRelation") - } + try { + sql("CREATE TABLE explodeTest (key bigInt)") + table("explodeTest").queryExecution.analyzed match { + case metastoreRelation: MetastoreRelation => // OK + case _ => + fail("To correctly test the fix of SPARK-5875, explodeTest should be a MetastoreRelation") + 
} - sql(s"INSERT OVERWRITE TABLE explodeTest SELECT explode(a) AS val FROM data") - checkAnswer( - sql("SELECT key from explodeTest"), - (1 to 5).flatMap(i => Row(i) :: Row(i + 1) :: Nil) - ) + sql(s"INSERT OVERWRITE TABLE explodeTest SELECT explode(a) AS val FROM data") + checkAnswer( + sql("SELECT key from explodeTest"), + (1 to 5).flatMap(i => Row(i) :: Row(i + 1) :: Nil) + ) - sql("DROP TABLE explodeTest") - dropTempTable("data") - setConf(HiveContext.CONVERT_CTAS, originalConf) + sql("DROP TABLE explodeTest") + dropTempTable("data") + } finally { + setConf(HiveContext.CONVERT_CTAS, originalConf) + } } test("sanity test for SPARK-6618") { @@ -1058,12 +1084,12 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils { test("SPARK-8588 HiveTypeCoercion.inConversion fires too early") { val df = TestHive.createDataFrame(Seq((1, "2014-01-01"), (2, "2015-01-01"), (3, "2016-01-01"))) - df.toDF("id", "date").registerTempTable("test_SPARK8588") + df.toDF("id", "datef").registerTempTable("test_SPARK8588") checkAnswer( TestHive.sql( """ - |select id, concat(year(date)) - |from test_SPARK8588 where concat(year(date), ' year') in ('2015 year', '2014 year') + |select id, concat(year(datef)) + |from test_SPARK8588 where concat(year(datef), ' year') in ('2015 year', '2014 year') """.stripMargin), Row(1, "2014") :: Row(2, "2015") :: Nil ) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala index af3f468aaa5e..deec0048d24b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala @@ -48,11 +48,9 @@ class OrcHadoopFsRelationSuite extends HadoopFsRelationTest { StructType(dataSchema.fields :+ StructField("p1", IntegerType, nullable = true)) checkQueries( - load( - source = dataSourceName, - options = Map( - "path" -> file.getCanonicalPath, - "dataSchema" -> dataSchemaWithPartition.json))) + read.options(Map( + "path" -> file.getCanonicalPath, + "dataSchema" -> dataSchemaWithPartition.json)).format(dataSourceName).load()) } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala index d463e8fd626f..a46ca9a2c970 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala @@ -31,7 +31,6 @@ import org.scalatest.BeforeAndAfterAll import scala.reflect.ClassTag import scala.reflect.runtime.universe.TypeTag - // The data where the partitioning key exists only in the directory structure. 
case class OrcParData(intField: Int, stringField: String) @@ -40,7 +39,7 @@ case class OrcParDataWithKey(intField: Int, pi: Int, stringField: String, ps: St // TODO This test suite duplicates ParquetPartitionDiscoverySuite a lot class OrcPartitionDiscoverySuite extends QueryTest with BeforeAndAfterAll { - val defaultPartitionName = ConfVars.DEFAULTPARTITIONNAME.defaultVal + val defaultPartitionName = ConfVars.DEFAULTPARTITIONNAME.defaultStrVal def withTempDir(f: File => Unit): Unit = { val dir = Utils.createTempDir().getCanonicalFile diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala index f56fb96c52d3..c4bc60086f6e 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala @@ -60,7 +60,14 @@ case class ParquetDataWithKeyAndComplexTypes( class ParquetMetastoreSuite extends ParquetPartitioningTest { override def beforeAll(): Unit = { super.beforeAll() - + dropTables("partitioned_parquet", + "partitioned_parquet_with_key", + "partitioned_parquet_with_complextypes", + "partitioned_parquet_with_key_and_complextypes", + "normal_parquet", + "jt", + "jt_array", + "test_parquet") sql(s""" create external table partitioned_parquet ( @@ -172,14 +179,14 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { } override def afterAll(): Unit = { - sql("DROP TABLE partitioned_parquet") - sql("DROP TABLE partitioned_parquet_with_key") - sql("DROP TABLE partitioned_parquet_with_complextypes") - sql("DROP TABLE partitioned_parquet_with_key_and_complextypes") - sql("DROP TABLE normal_parquet") - sql("DROP TABLE IF EXISTS jt") - sql("DROP TABLE IF EXISTS jt_array") - sql("DROP TABLE IF EXISTS test_parquet") + dropTables("partitioned_parquet", + "partitioned_parquet_with_key", + "partitioned_parquet_with_complextypes", + "partitioned_parquet_with_key_and_complextypes", + "normal_parquet", + "jt", + "jt_array", + "test_parquet") setConf(HiveContext.CONVERT_METASTORE_PARQUET, false) } @@ -203,6 +210,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { } test("insert into an empty parquet table") { + dropTables("test_insert_parquet") sql( """ |create table test_insert_parquet @@ -228,7 +236,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { sql(s"SELECT intField, stringField FROM test_insert_parquet WHERE intField > 2"), Row(3, "str3") :: Row(4, "str4") :: Nil ) - sql("DROP TABLE IF EXISTS test_insert_parquet") + dropTables("test_insert_parquet") // Create it again. 
sql( @@ -255,118 +263,118 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { sql(s"SELECT intField, stringField FROM test_insert_parquet"), (1 to 10).map(i => Row(i, s"str$i")) ++ (1 to 4).map(i => Row(i, s"str$i")) ) - sql("DROP TABLE IF EXISTS test_insert_parquet") + dropTables("test_insert_parquet") } test("scan a parquet table created through a CTAS statement") { - sql( - """ - |create table test_parquet_ctas ROW FORMAT - |SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' - |STORED AS - | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' - | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' - |AS select * from jt - """.stripMargin) + withTable("test_parquet_ctas") { + sql( + """ + |create table test_parquet_ctas ROW FORMAT + |SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' + |STORED AS + | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' + | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' + |AS select * from jt + """.stripMargin) - checkAnswer( - sql(s"SELECT a, b FROM test_parquet_ctas WHERE a = 1"), - Seq(Row(1, "str1")) - ) + checkAnswer( + sql(s"SELECT a, b FROM test_parquet_ctas WHERE a = 1"), + Seq(Row(1, "str1")) + ) - table("test_parquet_ctas").queryExecution.optimizedPlan match { - case LogicalRelation(_: ParquetRelation) => // OK - case _ => fail( - "test_parquet_ctas should be converted to " + - s"${classOf[ParquetRelation].getCanonicalName}") + table("test_parquet_ctas").queryExecution.optimizedPlan match { + case LogicalRelation(_: ParquetRelation) => // OK + case _ => fail( + "test_parquet_ctas should be converted to " + + s"${classOf[ParquetRelation].getCanonicalName }") + } } - - sql("DROP TABLE IF EXISTS test_parquet_ctas") } test("MetastoreRelation in InsertIntoTable will be converted") { - sql( - """ - |create table test_insert_parquet - |( - | intField INT - |) - |ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' - |STORED AS - | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' - | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' - """.stripMargin) + withTable("test_insert_parquet") { + sql( + """ + |create table test_insert_parquet + |( + | intField INT + |) + |ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' + |STORED AS + | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' + | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' + """.stripMargin) + + val df = sql("INSERT INTO TABLE test_insert_parquet SELECT a FROM jt") + df.queryExecution.executedPlan match { + case ExecutedCommand(InsertIntoHadoopFsRelation(_: ParquetRelation, _, _)) => // OK + case o => fail("test_insert_parquet should be converted to a " + + s"${classOf[ParquetRelation].getCanonicalName} and " + + s"${classOf[InsertIntoDataSource].getCanonicalName} is expcted as the SparkPlan. " + + s"However, found a ${o.toString} ") + } - val df = sql("INSERT INTO TABLE test_insert_parquet SELECT a FROM jt") - df.queryExecution.executedPlan match { - case ExecutedCommand(InsertIntoHadoopFsRelation(_: ParquetRelation, _, _)) => // OK - case o => fail("test_insert_parquet should be converted to a " + - s"${classOf[ParquetRelation].getCanonicalName} and " + - s"${classOf[InsertIntoDataSource].getCanonicalName} is expcted as the SparkPlan. 
" + - s"However, found a ${o.toString} ") + checkAnswer( + sql("SELECT intField FROM test_insert_parquet WHERE test_insert_parquet.intField > 5"), + sql("SELECT a FROM jt WHERE jt.a > 5").collect() + ) } - - checkAnswer( - sql("SELECT intField FROM test_insert_parquet WHERE test_insert_parquet.intField > 5"), - sql("SELECT a FROM jt WHERE jt.a > 5").collect() - ) - - sql("DROP TABLE IF EXISTS test_insert_parquet") } test("MetastoreRelation in InsertIntoHiveTable will be converted") { - sql( - """ - |create table test_insert_parquet - |( - | int_array array - |) - |ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' - |STORED AS - | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' - | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' - """.stripMargin) + withTable("test_insert_parquet") { + sql( + """ + |create table test_insert_parquet + |( + | int_array array + |) + |ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' + |STORED AS + | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' + | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' + """.stripMargin) + + val df = sql("INSERT INTO TABLE test_insert_parquet SELECT a FROM jt_array") + df.queryExecution.executedPlan match { + case ExecutedCommand(InsertIntoHadoopFsRelation(r: ParquetRelation, _, _)) => // OK + case o => fail("test_insert_parquet should be converted to a " + + s"${classOf[ParquetRelation].getCanonicalName} and " + + s"${classOf[InsertIntoDataSource].getCanonicalName} is expcted as the SparkPlan." + + s"However, found a ${o.toString} ") + } - val df = sql("INSERT INTO TABLE test_insert_parquet SELECT a FROM jt_array") - df.queryExecution.executedPlan match { - case ExecutedCommand(InsertIntoHadoopFsRelation(r: ParquetRelation, _, _)) => // OK - case o => fail("test_insert_parquet should be converted to a " + - s"${classOf[ParquetRelation].getCanonicalName} and " + - s"${classOf[InsertIntoDataSource].getCanonicalName} is expcted as the SparkPlan." 
+ - s"However, found a ${o.toString} ") + checkAnswer( + sql("SELECT int_array FROM test_insert_parquet"), + sql("SELECT a FROM jt_array").collect() + ) } - - checkAnswer( - sql("SELECT int_array FROM test_insert_parquet"), - sql("SELECT a FROM jt_array").collect() - ) - - sql("DROP TABLE IF EXISTS test_insert_parquet") } test("SPARK-6450 regression test") { - sql( - """CREATE TABLE IF NOT EXISTS ms_convert (key INT) - |ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' - |STORED AS - | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' - | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' - """.stripMargin) + withTable("ms_convert") { + sql( + """CREATE TABLE IF NOT EXISTS ms_convert (key INT) + |ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' + |STORED AS + | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' + | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' + """.stripMargin) + + // This shouldn't throw AnalysisException + val analyzed = sql( + """SELECT key FROM ms_convert + |UNION ALL + |SELECT key FROM ms_convert + """.stripMargin).queryExecution.analyzed - // This shouldn't throw AnalysisException - val analyzed = sql( - """SELECT key FROM ms_convert - |UNION ALL - |SELECT key FROM ms_convert - """.stripMargin).queryExecution.analyzed - - assertResult(2) { - analyzed.collect { - case r @ LogicalRelation(_: ParquetRelation) => r - }.size + assertResult(2) { + analyzed.collect { + case r@LogicalRelation(_: ParquetRelation) => r + }.size + } } - - sql("DROP TABLE ms_convert") } def collectParquetRelation(df: DataFrame): ParquetRelation = { @@ -379,42 +387,42 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { } test("SPARK-7749: non-partitioned metastore Parquet table lookup should use cached relation") { - sql( - s"""CREATE TABLE nonPartitioned ( - | key INT, - | value STRING - |) - |STORED AS PARQUET - """.stripMargin) - - // First lookup fills the cache - val r1 = collectParquetRelation(table("nonPartitioned")) - // Second lookup should reuse the cache - val r2 = collectParquetRelation(table("nonPartitioned")) - // They should be the same instance - assert(r1 eq r2) - - sql("DROP TABLE nonPartitioned") + withTable("nonPartitioned") { + sql( + s"""CREATE TABLE nonPartitioned ( + | key INT, + | value STRING + |) + |STORED AS PARQUET + """.stripMargin) + + // First lookup fills the cache + val r1 = collectParquetRelation(table("nonPartitioned")) + // Second lookup should reuse the cache + val r2 = collectParquetRelation(table("nonPartitioned")) + // They should be the same instance + assert(r1 eq r2) + } } test("SPARK-7749: partitioned metastore Parquet table lookup should use cached relation") { - sql( - s"""CREATE TABLE partitioned ( - | key INT, - | value STRING - |) - |PARTITIONED BY (part INT) - |STORED AS PARQUET + withTable("partitioned") { + sql( + s"""CREATE TABLE partitioned ( + | key INT, + | value STRING + |) + |PARTITIONED BY (part INT) + |STORED AS PARQUET """.stripMargin) - // First lookup fills the cache - val r1 = collectParquetRelation(table("partitioned")) - // Second lookup should reuse the cache - val r2 = collectParquetRelation(table("partitioned")) - // They should be the same instance - assert(r1 eq r2) - - sql("DROP TABLE partitioned") + // First lookup fills the cache + val r1 = collectParquetRelation(table("partitioned")) + // Second lookup should reuse the cache + val r2 = 
collectParquetRelation(table("partitioned")) + // They should be the same instance + assert(r1 eq r2) + } } test("Caching converted data source Parquet Relations") { @@ -430,8 +438,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { } } - sql("DROP TABLE IF EXISTS test_insert_parquet") - sql("DROP TABLE IF EXISTS test_parquet_partitioned_cache_test") + dropTables("test_insert_parquet", "test_parquet_partitioned_cache_test") sql( """ @@ -479,7 +486,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { | intField INT, | stringField STRING |) - |PARTITIONED BY (date string) + |PARTITIONED BY (`date` string) |ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' |STORED AS | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' @@ -491,7 +498,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { sql( """ |INSERT INTO TABLE test_parquet_partitioned_cache_test - |PARTITION (date='2015-04-01') + |PARTITION (`date`='2015-04-01') |select a, b from jt """.stripMargin) // Right now, insert into a partitioned Parquet is not supported in data source Parquet. @@ -500,7 +507,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { sql( """ |INSERT INTO TABLE test_parquet_partitioned_cache_test - |PARTITION (date='2015-04-02') + |PARTITION (`date`='2015-04-02') |select a, b from jt """.stripMargin) assert(catalog.cachedDataSourceTables.getIfPresent(tableIdentifier) === null) @@ -510,7 +517,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { checkCached(tableIdentifier) // Make sure we can read the data. checkAnswer( - sql("select STRINGField, date, intField from test_parquet_partitioned_cache_test"), + sql("select STRINGField, `date`, intField from test_parquet_partitioned_cache_test"), sql( """ |select b, '2015-04-01', a FROM jt @@ -521,8 +528,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { invalidateTable("test_parquet_partitioned_cache_test") assert(catalog.cachedDataSourceTables.getIfPresent(tableIdentifier) === null) - sql("DROP TABLE test_insert_parquet") - sql("DROP TABLE test_parquet_partitioned_cache_test") + dropTables("test_insert_parquet", "test_parquet_partitioned_cache_test") } } @@ -532,6 +538,11 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { class ParquetSourceSuite extends ParquetPartitioningTest { override def beforeAll(): Unit = { super.beforeAll() + dropTables("partitioned_parquet", + "partitioned_parquet_with_key", + "partitioned_parquet_with_complextypes", + "partitioned_parquet_with_key_and_complextypes", + "normal_parquet") sql( s""" create temporary table partitioned_parquet @@ -635,22 +646,22 @@ class ParquetSourceSuite extends ParquetPartitioningTest { StructField("a", arrayType1, nullable = true) :: Nil) assert(df.schema === expectedSchema1) - df.write.format("parquet").saveAsTable("alwaysNullable") + withTable("alwaysNullable") { + df.write.format("parquet").saveAsTable("alwaysNullable") - val mapType2 = MapType(IntegerType, IntegerType, valueContainsNull = true) - val arrayType2 = ArrayType(IntegerType, containsNull = true) - val expectedSchema2 = - StructType( - StructField("m", mapType2, nullable = true) :: - StructField("a", arrayType2, nullable = true) :: Nil) + val mapType2 = MapType(IntegerType, IntegerType, valueContainsNull = true) + val arrayType2 = ArrayType(IntegerType, containsNull = true) + val expectedSchema2 = + StructType( + StructField("m", mapType2, nullable = true) :: + StructField("a", arrayType2, nullable = 
true) :: Nil) - assert(table("alwaysNullable").schema === expectedSchema2) - - checkAnswer( - sql("SELECT m, a FROM alwaysNullable"), - Row(Map(2 -> 3), Seq(4, 5, 6))) + assert(table("alwaysNullable").schema === expectedSchema2) - sql("DROP TABLE alwaysNullable") + checkAnswer( + sql("SELECT m, a FROM alwaysNullable"), + Row(Map(2 -> 3), Seq(4, 5, 6))) + } } test("Aggregation attribute names can't contain special chars \" ,;{}()\\n\\t=\"") { @@ -738,6 +749,16 @@ abstract class ParquetPartitioningTest extends QueryTest with SQLTestUtils with partitionedTableDirWithKeyAndComplexTypes.delete() } + /** + * Drop named tables if they exist + * @param tableNames tables to drop + */ + def dropTables(tableNames: String*): Unit = { + tableNames.foreach { name => + sql(s"DROP TABLE IF EXISTS $name") + } + } + Seq( "partitioned_parquet", "partitioned_parquet_with_key", diff --git a/yarn/pom.xml b/yarn/pom.xml index 2aeed98285aa..49360c48256e 100644 --- a/yarn/pom.xml +++ b/yarn/pom.xml @@ -30,7 +30,6 @@ Spark Project YARN yarn - 1.9 @@ -125,25 +124,16 @@ com.sun.jersey jersey-core - ${jersey.version} test com.sun.jersey jersey-json - ${jersey.version} test - - - stax - stax-api - - com.sun.jersey jersey-server - ${jersey.version} test diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala index 547863d9a073..eb6e1fd37062 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala @@ -384,19 +384,29 @@ private object YarnClusterDriver extends Logging with Matchers { } -private object YarnClasspathTest { +private object YarnClasspathTest extends Logging { + + var exitCode = 0 + + def error(m: String, ex: Throwable = null): Unit = { + logError(m, ex) + // scalastyle:off println + System.out.println(m) + if (ex != null) { + ex.printStackTrace(System.out) + } + // scalastyle:on println + } def main(args: Array[String]): Unit = { if (args.length != 2) { - // scalastyle:off println - System.err.println( + error( s""" |Invalid command line: ${args.mkString(" ")} | |Usage: YarnClasspathTest [driver result file] [executor result file] """.stripMargin) // scalastyle:on println - System.exit(1) } readResource(args(0)) @@ -406,6 +416,7 @@ private object YarnClasspathTest { } finally { sc.stop() } + System.exit(exitCode) } private def readResource(resultPath: String): Unit = { @@ -415,6 +426,11 @@ private object YarnClasspathTest { val resource = ccl.getResourceAsStream("test.resource") val bytes = ByteStreams.toByteArray(resource) result = new String(bytes, 0, bytes.length, UTF_8) + } catch { + case t: Throwable => + error(s"loading test.resource to $resultPath", t) + // set the exit code if not yet set + exitCode = 2 } finally { Files.write(result, new File(resultPath), UTF_8) } From 13675c742a71cbdc8324701c3694775ce1dd5c62 Mon Sep 17 00:00:00 2001 From: MechCoder Date: Mon, 3 Aug 2015 16:44:25 -0700 Subject: [PATCH 0011/1824] [SPARK-8874] [ML] Add missing methods in Word2Vec Add missing methods 1. getVectors 2. 
findSynonyms to W2Vec scala and python API mengxr Author: MechCoder Closes #7263 from MechCoder/missing_methods_w2vec and squashes the following commits: 149d5ca [MechCoder] minor doc 69d91b7 [MechCoder] [SPARK-8874] [ML] Add missing methods in Word2Vec --- .../apache/spark/ml/feature/Word2Vec.scala | 38 +++++++++++- .../spark/ml/feature/Word2VecSuite.scala | 62 +++++++++++++++++++ 2 files changed, 99 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala index 6ea659095630..b4f46cef798d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala @@ -18,15 +18,17 @@ package org.apache.spark.ml.feature import org.apache.spark.annotation.Experimental +import org.apache.spark.SparkContext import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util.{Identifiable, SchemaUtils} import org.apache.spark.mllib.feature -import org.apache.spark.mllib.linalg.{VectorUDT, Vectors} +import org.apache.spark.mllib.linalg.{VectorUDT, Vector, Vectors} import org.apache.spark.mllib.linalg.BLAS._ import org.apache.spark.sql.DataFrame import org.apache.spark.sql.functions._ +import org.apache.spark.sql.SQLContext import org.apache.spark.sql.types._ /** @@ -146,6 +148,40 @@ class Word2VecModel private[ml] ( wordVectors: feature.Word2VecModel) extends Model[Word2VecModel] with Word2VecBase { + + /** + * Returns a dataframe with two fields, "word" and "vector", with "word" being a String and + * and the vector the DenseVector that it is mapped to. + */ + val getVectors: DataFrame = { + val sc = SparkContext.getOrCreate() + val sqlContext = SQLContext.getOrCreate(sc) + import sqlContext.implicits._ + val wordVec = wordVectors.getVectors.mapValues(vec => Vectors.dense(vec.map(_.toDouble))) + sc.parallelize(wordVec.toSeq).toDF("word", "vector") + } + + /** + * Find "num" number of words closest in similarity to the given word. + * Returns a dataframe with the words and the cosine similarities between the + * synonyms and the given word. + */ + def findSynonyms(word: String, num: Int): DataFrame = { + findSynonyms(wordVectors.transform(word), num) + } + + /** + * Find "num" number of words closest to similarity to the given vector representation + * of the word. Returns a dataframe with the words and the cosine similarities between the + * synonyms and the given word vector. 
+ */ + def findSynonyms(word: Vector, num: Int): DataFrame = { + val sc = SparkContext.getOrCreate() + val sqlContext = SQLContext.getOrCreate(sc) + import sqlContext.implicits._ + sc.parallelize(wordVectors.findSynonyms(word, num)).toDF("word", "similarity") + } + /** @group setParam */ def setInputCol(value: String): this.type = set(inputCol, value) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala index aa6ce533fd88..adcda0e623b2 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala @@ -67,5 +67,67 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext { assert(vector1 ~== vector2 absTol 1E-5, "Transformed vector is different with expected.") } } + + test("getVectors") { + + val sqlContext = new SQLContext(sc) + import sqlContext.implicits._ + + val sentence = "a b " * 100 + "a c " * 10 + val doc = sc.parallelize(Seq(sentence, sentence)).map(line => line.split(" ")) + + val codes = Map( + "a" -> Array(-0.2811822295188904, -0.6356269121170044, -0.3020961284637451), + "b" -> Array(1.0309048891067505, -1.29472815990448, 0.22276712954044342), + "c" -> Array(-0.08456747233867645, 0.5137411952018738, 0.11731560528278351) + ) + val expectedVectors = codes.toSeq.sortBy(_._1).map { case (w, v) => Vectors.dense(v) } + + val docDF = doc.zip(doc).toDF("text", "alsotext") + + val model = new Word2Vec() + .setVectorSize(3) + .setInputCol("text") + .setOutputCol("result") + .setSeed(42L) + .fit(docDF) + + val realVectors = model.getVectors.sort("word").select("vector").map { + case Row(v: Vector) => v + }.collect() + + realVectors.zip(expectedVectors).foreach { + case (real, expected) => + assert(real ~== expected absTol 1E-5, "Actual vector is different from expected.") + } + } + + test("findSynonyms") { + + val sqlContext = new SQLContext(sc) + import sqlContext.implicits._ + + val sentence = "a b " * 100 + "a c " * 10 + val doc = sc.parallelize(Seq(sentence, sentence)).map(line => line.split(" ")) + val docDF = doc.zip(doc).toDF("text", "alsotext") + + val model = new Word2Vec() + .setVectorSize(3) + .setInputCol("text") + .setOutputCol("result") + .setSeed(42L) + .fit(docDF) + + val expectedSimilarity = Array(0.2789285076917586, -0.6336972059851644) + val (synonyms, similarity) = model.findSynonyms("a", 2).map { + case Row(w: String, sim: Double) => (w, sim) + }.collect().unzip + + assert(synonyms.toArray === Array("b", "c")) + expectedSimilarity.zip(similarity).map { + case (expected, actual) => assert(math.abs((expected - actual) / expected) < 1E-5) + } + + } } From 7abaaad5b169520fbf7299808b2bafde089a16a2 Mon Sep 17 00:00:00 2001 From: Shivaram Venkataraman Date: Mon, 3 Aug 2015 17:00:59 -0700 Subject: [PATCH 0012/1824] Add a prerequisites section for building docs This puts all the install commands that need to be run in one section instead of being spread over many paragraphs cc rxin Author: Shivaram Venkataraman Closes #7912 from shivaram/docs-setup-readme and squashes the following commits: cf7a204 [Shivaram Venkataraman] Add a prerequisites section for building docs --- docs/README.md | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/docs/README.md b/docs/README.md index d7652e921f7d..50209896f986 100644 --- a/docs/README.md +++ b/docs/README.md @@ -8,6 +8,16 @@ Read on to learn more about viewing documentation in plain text (i.e., markdown) documentation 
yourself. Why build it yourself? So that you have the docs that correspond to whichever version of Spark you currently have checked out of revision control. +## Prerequisites +The Spark documentation build uses a number of tools to build HTML docs and API docs in Scala, Python +and R. To get started you can run the following commands: + + $ sudo gem install jekyll + $ sudo gem install jekyll-redirect-from + $ sudo pip install Pygments + $ Rscript -e 'install.packages(c("knitr", "devtools"), repos="http://cran.stat.ucla.edu/")' + + ## Generating the Documentation HTML We include the Spark documentation as part of the source (as opposed to using a hosted wiki, such as From b79b4f5f2251ed7efeec1f4b26e45a8ea6b85a6a Mon Sep 17 00:00:00 2001 From: Matthew Brandyberry Date: Mon, 3 Aug 2015 17:36:56 -0700 Subject: [PATCH 0013/1824] [SPARK-9483] Fix UTF8String.getPrefix for big-endian. Previous code assumed little-endian. Author: Matthew Brandyberry Closes #7902 from mtbrandy/SPARK-9483 and squashes the following commits: ec31df8 [Matthew Brandyberry] [SPARK-9483] Changes from review comments. 17d54c6 [Matthew Brandyberry] [SPARK-9483] Fix UTF8String.getPrefix for big-endian. --- .../apache/spark/unsafe/types/UTF8String.java | 40 ++++++++++++++----- 1 file changed, 30 insertions(+), 10 deletions(-) diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index f6c9b87778f8..d80bd57bd204 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -20,6 +20,7 @@ import javax.annotation.Nonnull; import java.io.Serializable; import java.io.UnsupportedEncodingException; +import java.nio.ByteOrder; import java.util.Arrays; import org.apache.spark.unsafe.PlatformDependent; @@ -53,6 +54,8 @@ public final class UTF8String implements Comparable, Serializable { 5, 5, 5, 5, 6, 6}; + private static ByteOrder byteOrder = ByteOrder.nativeOrder(); + public static final UTF8String EMPTY_UTF8 = UTF8String.fromString(""); /** @@ -175,18 +178,35 @@ public long getPrefix() { // If size is greater than 4, assume we have at least 8 bytes of data to fetch. // After getting the data, we use a mask to mask out data that is not part of the string. 
long p; - if (numBytes >= 8) { - p = PlatformDependent.UNSAFE.getLong(base, offset); - } else if (numBytes > 4) { - p = PlatformDependent.UNSAFE.getLong(base, offset); - p = p & ((1L << numBytes * 8) - 1); - } else if (numBytes > 0) { - p = (long) PlatformDependent.UNSAFE.getInt(base, offset); - p = p & ((1L << numBytes * 8) - 1); + long mask = 0; + if (byteOrder == ByteOrder.LITTLE_ENDIAN) { + if (numBytes >= 8) { + p = PlatformDependent.UNSAFE.getLong(base, offset); + } else if (numBytes > 4) { + p = PlatformDependent.UNSAFE.getLong(base, offset); + mask = (1L << (8 - numBytes) * 8) - 1; + } else if (numBytes > 0) { + p = (long) PlatformDependent.UNSAFE.getInt(base, offset); + mask = (1L << (8 - numBytes) * 8) - 1; + } else { + p = 0; + } + p = java.lang.Long.reverseBytes(p); } else { - p = 0; + // byteOrder == ByteOrder.BIG_ENDIAN + if (numBytes >= 8) { + p = PlatformDependent.UNSAFE.getLong(base, offset); + } else if (numBytes > 4) { + p = PlatformDependent.UNSAFE.getLong(base, offset); + mask = (1L << (8 - numBytes) * 8) - 1; + } else if (numBytes > 0) { + p = ((long) PlatformDependent.UNSAFE.getInt(base, offset)) << 32; + mask = (1L << (8 - numBytes) * 8) - 1; + } else { + p = 0; + } } - p = java.lang.Long.reverseBytes(p); + p &= ~mask; return p; } From 1633d0a2612d94151f620c919425026150e69ae1 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Mon, 3 Aug 2015 17:42:03 -0700 Subject: [PATCH 0014/1824] [SPARK-9263] Added flags to exclude dependencies when using --packages While the functionality is there to exclude packages, there are no flags that allow users to exclude dependencies, in case of dependency conflicts. We should provide users with a flag to add dependency exclusions in case the packages are not resolved properly (or not available due to licensing). The flag I added was --packages-exclude, but I'm open on renaming it. I also added property flags in case people would like to use a conf file to provide dependencies, which is possible if there is a long list of dependencies or exclusions. 
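To make the comma-separated `groupId:artifactId` format of the new flag concrete, here is a minimal, hedged Scala sketch of how such a value is split into exclusions. The coordinates are placeholders (one is borrowed from the test fixtures in this patch, the other is made up), and the real handling in SparkSubmit uses `StringUtils.isBlank` rather than the null check shown here.

```scala
// Hypothetical value passed via --exclude-packages (or spark.jars.excludes).
val packagesExclusions = "my.great.dep:mydep,other.group:other-artifact"

// Blank or missing means "no exclusions"; otherwise split on commas into
// groupId:artifactId pairs that later become Ivy exclude rules.
val exclusions: Seq[String] =
  if (packagesExclusions == null || packagesExclusions.trim.isEmpty) Nil
  else packagesExclusions.split(",").toSeq

exclusions.foreach(println)
// my.great.dep:mydep
// other.group:other-artifact
```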
cc andrewor14 vanzin pwendell Author: Burak Yavuz Closes #7599 from brkyvz/packages-exclusions and squashes the following commits: 636f410 [Burak Yavuz] addressed nits 6e54ede [Burak Yavuz] is this the culprit b5e508e [Burak Yavuz] Merge branch 'master' of github.com:apache/spark into packages-exclusions 154f5db [Burak Yavuz] addressed initial comments 1536d7a [Burak Yavuz] Added flags to exclude packages using --packages-exclude --- .../org/apache/spark/deploy/SparkSubmit.scala | 29 +++++++++--------- .../spark/deploy/SparkSubmitArguments.scala | 11 +++++++ .../spark/deploy/SparkSubmitUtilsSuite.scala | 30 +++++++++++++++++++ .../launcher/SparkSubmitOptionParser.java | 2 ++ 4 files changed, 57 insertions(+), 15 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 0b39ee8fe3ba..31185c8e77de 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -24,6 +24,7 @@ import java.security.PrivilegedExceptionAction import scala.collection.mutable.{ArrayBuffer, HashMap, Map} +import org.apache.commons.lang3.StringUtils import org.apache.hadoop.fs.Path import org.apache.hadoop.security.UserGroupInformation import org.apache.ivy.Ivy @@ -37,6 +38,7 @@ import org.apache.ivy.core.settings.IvySettings import org.apache.ivy.plugins.matcher.GlobPatternMatcher import org.apache.ivy.plugins.repository.file.FileRepository import org.apache.ivy.plugins.resolver.{FileSystemResolver, ChainResolver, IBiblioResolver} + import org.apache.spark.api.r.RUtils import org.apache.spark.SPARK_VERSION import org.apache.spark.deploy.rest._ @@ -275,21 +277,18 @@ object SparkSubmit { // Resolve maven dependencies if there are any and add classpath to jars. Add them to py-files // too for packages that include Python code - val resolvedMavenCoordinates = - SparkSubmitUtils.resolveMavenCoordinates( - args.packages, Option(args.repositories), Option(args.ivyRepoPath)) - if (!resolvedMavenCoordinates.trim.isEmpty) { - if (args.jars == null || args.jars.trim.isEmpty) { - args.jars = resolvedMavenCoordinates + val exclusions: Seq[String] = + if (!StringUtils.isBlank(args.packagesExclusions)) { + args.packagesExclusions.split(",") } else { - args.jars += s",$resolvedMavenCoordinates" + Nil } + val resolvedMavenCoordinates = SparkSubmitUtils.resolveMavenCoordinates(args.packages, + Some(args.repositories), Some(args.ivyRepoPath), exclusions = exclusions) + if (!StringUtils.isBlank(resolvedMavenCoordinates)) { + args.jars = mergeFileLists(args.jars, resolvedMavenCoordinates) if (args.isPython) { - if (args.pyFiles == null || args.pyFiles.trim.isEmpty) { - args.pyFiles = resolvedMavenCoordinates - } else { - args.pyFiles += s",$resolvedMavenCoordinates" - } + args.pyFiles = mergeFileLists(args.pyFiles, resolvedMavenCoordinates) } } @@ -736,7 +735,7 @@ object SparkSubmit { * no files, into a single comma-separated string. 
*/ private def mergeFileLists(lists: String*): String = { - val merged = lists.filter(_ != null) + val merged = lists.filterNot(StringUtils.isBlank) .flatMap(_.split(",")) .mkString(",") if (merged == "") null else merged @@ -938,7 +937,7 @@ private[spark] object SparkSubmitUtils { // are supplied to spark-submit val alternateIvyCache = ivyPath.getOrElse("") val packagesDirectory: File = - if (alternateIvyCache.trim.isEmpty) { + if (alternateIvyCache == null || alternateIvyCache.trim.isEmpty) { new File(ivySettings.getDefaultIvyUserDir, "jars") } else { ivySettings.setDefaultIvyUserDir(new File(alternateIvyCache)) @@ -1010,7 +1009,7 @@ private[spark] object SparkSubmitUtils { } } - private def createExclusion( + private[deploy] def createExclusion( coords: String, ivySettings: IvySettings, ivyConfName: String): ExcludeRule = { diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index b3710073e330..44852ce4e84a 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -59,6 +59,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S var packages: String = null var repositories: String = null var ivyRepoPath: String = null + var packagesExclusions: String = null var verbose: Boolean = false var isPython: Boolean = false var pyFiles: String = null @@ -172,6 +173,9 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S name = Option(name).orElse(sparkProperties.get("spark.app.name")).orNull jars = Option(jars).orElse(sparkProperties.get("spark.jars")).orNull ivyRepoPath = sparkProperties.get("spark.jars.ivy").orNull + packages = Option(packages).orElse(sparkProperties.get("spark.jars.packages")).orNull + packagesExclusions = Option(packagesExclusions) + .orElse(sparkProperties.get("spark.jars.excludes")).orNull deployMode = Option(deployMode).orElse(env.get("DEPLOY_MODE")).orNull numExecutors = Option(numExecutors) .getOrElse(sparkProperties.get("spark.executor.instances").orNull) @@ -299,6 +303,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S | childArgs [${childArgs.mkString(" ")}] | jars $jars | packages $packages + | packagesExclusions $packagesExclusions | repositories $repositories | verbose $verbose | @@ -391,6 +396,9 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S case PACKAGES => packages = value + case PACKAGES_EXCLUDE => + packagesExclusions = value + case REPOSITORIES => repositories = value @@ -482,6 +490,9 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S | maven repo, then maven central and any additional remote | repositories given by --repositories. The format for the | coordinates should be groupId:artifactId:version. + | --exclude-packages Comma-separated list of groupId:artifactId, to exclude while + | resolving the dependencies provided in --packages to avoid + | dependency conflicts. | --repositories Comma-separated list of additional remote repositories to | search for the maven coordinates given with --packages. 
| --py-files PY_FILES Comma-separated list of .zip, .egg, or .py files to place diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala index 01ece1a10f46..63c346c1b890 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala @@ -95,6 +95,25 @@ class SparkSubmitUtilsSuite extends SparkFunSuite with BeforeAndAfterAll { assert(md.getDependencies.length === 2) } + test("excludes works correctly") { + val md = SparkSubmitUtils.getModuleDescriptor + val excludes = Seq("a:b", "c:d") + excludes.foreach { e => + md.addExcludeRule(SparkSubmitUtils.createExclusion(e + ":*", new IvySettings, "default")) + } + val rules = md.getAllExcludeRules + assert(rules.length === 2) + val rule1 = rules(0).getId.getModuleId + assert(rule1.getOrganisation === "a") + assert(rule1.getName === "b") + val rule2 = rules(1).getId.getModuleId + assert(rule2.getOrganisation === "c") + assert(rule2.getName === "d") + intercept[IllegalArgumentException] { + SparkSubmitUtils.createExclusion("e:f:g:h", new IvySettings, "default") + } + } + test("ivy path works correctly") { val md = SparkSubmitUtils.getModuleDescriptor val artifacts = for (i <- 0 until 3) yield new MDArtifact(md, s"jar-$i", "jar", "jar") @@ -168,4 +187,15 @@ class SparkSubmitUtilsSuite extends SparkFunSuite with BeforeAndAfterAll { assert(files.indexOf(main.artifactId) >= 0, "Did not return artifact") } } + + test("exclude dependencies end to end") { + val main = new MavenCoordinate("my.great.lib", "mylib", "0.1") + val dep = "my.great.dep:mydep:0.5" + IvyTestUtils.withRepository(main, Some(dep), None) { repo => + val files = SparkSubmitUtils.resolveMavenCoordinates(main.toString, + Some(repo), None, Seq("my.great.dep:mydep"), isTest = true) + assert(files.indexOf(main.artifactId) >= 0, "Did not return artifact") + assert(files.indexOf("my.great.dep") < 0, "Returned excluded artifact") + } + } } diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitOptionParser.java b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitOptionParser.java index b88bba883ac6..5779eb3fc0f7 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitOptionParser.java +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitOptionParser.java @@ -51,6 +51,7 @@ class SparkSubmitOptionParser { protected final String MASTER = "--master"; protected final String NAME = "--name"; protected final String PACKAGES = "--packages"; + protected final String PACKAGES_EXCLUDE = "--exclude-packages"; protected final String PROPERTIES_FILE = "--properties-file"; protected final String PROXY_USER = "--proxy-user"; protected final String PY_FILES = "--py-files"; @@ -105,6 +106,7 @@ class SparkSubmitOptionParser { { NAME }, { NUM_EXECUTORS }, { PACKAGES }, + { PACKAGES_EXCLUDE }, { PRINCIPAL }, { PROPERTIES_FILE }, { PROXY_USER }, From 3b0e44490aebfba30afc147e4a34a63439d985c6 Mon Sep 17 00:00:00 2001 From: CodingCat Date: Mon, 3 Aug 2015 18:20:40 -0700 Subject: [PATCH 0015/1824] [SPARK-8416] highlight and topping the executor threads in thread dumping page https://issues.apache.org/jira/browse/SPARK-8416 To facilitate debugging, I made this patch with three changes: * render the executor-thread and non executor-thread entries with different background colors * put the executor threads on the top of the list * sort the threads alphabetically Author: 
CodingCat Closes #7808 from CodingCat/SPARK-8416 and squashes the following commits: 34fc708 [CodingCat] fix className d7b79dd [CodingCat] lowercase threadName d032882 [CodingCat] sort alphabetically and change the css class name f0513b1 [CodingCat] change the color & group threads by name 2da6e06 [CodingCat] small fix 3fc9f36 [CodingCat] define classes in webui.css 8ee125e [CodingCat] highlight and put on top the executor threads in thread dumping page --- .../org/apache/spark/ui/static/webui.css | 8 +++++++ .../ui/exec/ExecutorThreadDumpPage.scala | 24 ++++++++++++++++--- 2 files changed, 29 insertions(+), 3 deletions(-) diff --git a/core/src/main/resources/org/apache/spark/ui/static/webui.css b/core/src/main/resources/org/apache/spark/ui/static/webui.css index 648cd1b10480..04f3070d25b4 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/webui.css +++ b/core/src/main/resources/org/apache/spark/ui/static/webui.css @@ -224,3 +224,11 @@ span.additional-metric-title { a.expandbutton { cursor: pointer; } + +.executor-thread { + background: #E6E6E6; +} + +.non-executor-thread { + background: #FAFAFA; +} \ No newline at end of file diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala index f0ae95bb8c81..b0a2cb4aa4d4 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala @@ -49,11 +49,29 @@ private[ui] class ExecutorThreadDumpPage(parent: ExecutorsTab) extends WebUIPage val maybeThreadDump = sc.get.getExecutorThreadDump(executorId) val content = maybeThreadDump.map { threadDump => - val dumpRows = threadDump.map { thread => + val dumpRows = threadDump.sortWith { + case (threadTrace1, threadTrace2) => { + val v1 = if (threadTrace1.threadName.contains("Executor task launch")) 1 else 0 + val v2 = if (threadTrace2.threadName.contains("Executor task launch")) 1 else 0 + if (v1 == v2) { + threadTrace1.threadName.toLowerCase < threadTrace2.threadName.toLowerCase + } else { + v1 > v2 + } + } + }.map { thread => + val threadName = thread.threadName + val className = "accordion-heading " + { + if (threadName.contains("Executor task launch")) { + "executor-thread" + } else { + "non-executor-thread" + } + }
    -
    + @@ -594,6 +597,9 @@ rowMat = mat.toRowMatrix() # Convert to an IndexedRowMatrix. indexedRowMat = mat.toIndexedRowMatrix() + +# Convert to a BlockMatrix. +blockMat = mat.toBlockMatrix() {% endhighlight %}
    @@ -661,4 +667,39 @@ matA.validate(); BlockMatrix ata = matA.transpose().multiply(matA); {% endhighlight %}
    + +
    + +A [`BlockMatrix`](api/python/pyspark.mllib.html#pyspark.mllib.linalg.distributed.BlockMatrix) +can be created from an `RDD` of sub-matrix blocks, where a sub-matrix block is a +`((blockRowIndex, blockColIndex), sub-matrix)` tuple. + +{% highlight python %} +from pyspark.mllib.linalg import Matrices +from pyspark.mllib.linalg.distributed import BlockMatrix + +# Create an RDD of sub-matrix blocks. +blocks = sc.parallelize([((0, 0), Matrices.dense(3, 2, [1, 2, 3, 4, 5, 6])), + ((1, 0), Matrices.dense(3, 2, [7, 8, 9, 10, 11, 12]))]) + +# Create a BlockMatrix from an RDD of sub-matrix blocks. +mat = BlockMatrix(blocks, 3, 2) + +# Get its size. +m = mat.numRows() # 6 +n = mat.numCols() # 2 + +# Get the blocks as an RDD of sub-matrix blocks. +blocksRDD = mat.blocks + +# Convert to a LocalMatrix. +localMat = mat.toLocalMatrix() + +# Convert to an IndexedRowMatrix. +indexedRowMat = mat.toIndexedRowMatrix() + +# Convert to a CoordinateMatrix. +coordinateMat = mat.toCoordinateMatrix() +{% endhighlight %} +
    diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index d2b3fae381ac..f585aacd452e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -1128,6 +1128,21 @@ private[python] class PythonMLLibAPI extends Serializable { new CoordinateMatrix(entries, numRows, numCols) } + /** + * Wrapper around BlockMatrix constructor. + */ + def createBlockMatrix(blocks: DataFrame, rowsPerBlock: Int, colsPerBlock: Int, + numRows: Long, numCols: Long): BlockMatrix = { + // We use DataFrames for serialization of sub-matrix blocks from + // Python, so map each Row in the DataFrame back to a + // ((blockRowIndex, blockColIndex), sub-matrix) tuple. + val blockTuples = blocks.map { + case Row(Row(blockRowIndex: Long, blockColIndex: Long), subMatrix: Matrix) => + ((blockRowIndex.toInt, blockColIndex.toInt), subMatrix) + } + new BlockMatrix(blockTuples, rowsPerBlock, colsPerBlock, numRows, numCols) + } + /** * Return the rows of an IndexedRowMatrix. */ @@ -1147,6 +1162,16 @@ private[python] class PythonMLLibAPI extends Serializable { val sqlContext = new SQLContext(coordinateMatrix.entries.sparkContext) sqlContext.createDataFrame(coordinateMatrix.entries) } + + /** + * Return the sub-matrix blocks of a BlockMatrix. + */ + def getMatrixBlocks(blockMatrix: BlockMatrix): DataFrame = { + // We use DataFrames for serialization of sub-matrix blocks to + // Python, so return a DataFrame. + val sqlContext = new SQLContext(blockMatrix.blocks.sparkContext) + sqlContext.createDataFrame(blockMatrix.blocks) + } } /** diff --git a/python/pyspark/mllib/linalg/distributed.py b/python/pyspark/mllib/linalg/distributed.py index 666d83301956..aec407de90aa 100644 --- a/python/pyspark/mllib/linalg/distributed.py +++ b/python/pyspark/mllib/linalg/distributed.py @@ -28,11 +28,12 @@ from pyspark import RDD from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper -from pyspark.mllib.linalg import _convert_to_vector +from pyspark.mllib.linalg import _convert_to_vector, Matrix __all__ = ['DistributedMatrix', 'RowMatrix', 'IndexedRow', - 'IndexedRowMatrix', 'MatrixEntry', 'CoordinateMatrix'] + 'IndexedRowMatrix', 'MatrixEntry', 'CoordinateMatrix', + 'BlockMatrix'] class DistributedMatrix(object): @@ -322,6 +323,35 @@ def toCoordinateMatrix(self): java_coordinate_matrix = self._java_matrix_wrapper.call("toCoordinateMatrix") return CoordinateMatrix(java_coordinate_matrix) + def toBlockMatrix(self, rowsPerBlock=1024, colsPerBlock=1024): + """ + Convert this matrix to a BlockMatrix. + + :param rowsPerBlock: Number of rows that make up each block. + The blocks forming the final rows are not + required to have the given number of rows. + :param colsPerBlock: Number of columns that make up each block. + The blocks forming the final columns are not + required to have the given number of columns. + + >>> rows = sc.parallelize([IndexedRow(0, [1, 2, 3]), + ... IndexedRow(6, [4, 5, 6])]) + >>> mat = IndexedRowMatrix(rows).toBlockMatrix() + + >>> # This IndexedRowMatrix will have 7 effective rows, due to + >>> # the highest row index being 6, and the ensuing + >>> # BlockMatrix will have 7 rows as well. 
+ >>> print(mat.numRows()) + 7 + + >>> print(mat.numCols()) + 3 + """ + java_block_matrix = self._java_matrix_wrapper.call("toBlockMatrix", + rowsPerBlock, + colsPerBlock) + return BlockMatrix(java_block_matrix, rowsPerBlock, colsPerBlock) + class MatrixEntry(object): """ @@ -476,19 +506,18 @@ def toRowMatrix(self): >>> entries = sc.parallelize([MatrixEntry(0, 0, 1.2), ... MatrixEntry(6, 4, 2.1)]) + >>> mat = CoordinateMatrix(entries).toRowMatrix() >>> # This CoordinateMatrix will have 7 effective rows, due to >>> # the highest row index being 6, but the ensuing RowMatrix >>> # will only have 2 rows since there are only entries on 2 >>> # unique rows. - >>> mat = CoordinateMatrix(entries).toRowMatrix() >>> print(mat.numRows()) 2 >>> # This CoordinateMatrix will have 5 columns, due to the >>> # highest column index being 4, and the ensuing RowMatrix >>> # will have 5 columns as well. - >>> mat = CoordinateMatrix(entries).toRowMatrix() >>> print(mat.numCols()) 5 """ @@ -501,33 +530,320 @@ def toIndexedRowMatrix(self): >>> entries = sc.parallelize([MatrixEntry(0, 0, 1.2), ... MatrixEntry(6, 4, 2.1)]) + >>> mat = CoordinateMatrix(entries).toIndexedRowMatrix() >>> # This CoordinateMatrix will have 7 effective rows, due to >>> # the highest row index being 6, and the ensuing >>> # IndexedRowMatrix will have 7 rows as well. - >>> mat = CoordinateMatrix(entries).toIndexedRowMatrix() >>> print(mat.numRows()) 7 >>> # This CoordinateMatrix will have 5 columns, due to the >>> # highest column index being 4, and the ensuing >>> # IndexedRowMatrix will have 5 columns as well. - >>> mat = CoordinateMatrix(entries).toIndexedRowMatrix() >>> print(mat.numCols()) 5 """ java_indexed_row_matrix = self._java_matrix_wrapper.call("toIndexedRowMatrix") return IndexedRowMatrix(java_indexed_row_matrix) + def toBlockMatrix(self, rowsPerBlock=1024, colsPerBlock=1024): + """ + Convert this matrix to a BlockMatrix. + + :param rowsPerBlock: Number of rows that make up each block. + The blocks forming the final rows are not + required to have the given number of rows. + :param colsPerBlock: Number of columns that make up each block. + The blocks forming the final columns are not + required to have the given number of columns. + + >>> entries = sc.parallelize([MatrixEntry(0, 0, 1.2), + ... MatrixEntry(6, 4, 2.1)]) + >>> mat = CoordinateMatrix(entries).toBlockMatrix() + + >>> # This CoordinateMatrix will have 7 effective rows, due to + >>> # the highest row index being 6, and the ensuing + >>> # BlockMatrix will have 7 rows as well. + >>> print(mat.numRows()) + 7 + + >>> # This CoordinateMatrix will have 5 columns, due to the + >>> # highest column index being 4, and the ensuing + >>> # BlockMatrix will have 5 columns as well. + >>> print(mat.numCols()) + 5 + """ + java_block_matrix = self._java_matrix_wrapper.call("toBlockMatrix", + rowsPerBlock, + colsPerBlock) + return BlockMatrix(java_block_matrix, rowsPerBlock, colsPerBlock) + + +def _convert_to_matrix_block_tuple(block): + if (isinstance(block, tuple) and len(block) == 2 + and isinstance(block[0], tuple) and len(block[0]) == 2 + and isinstance(block[1], Matrix)): + blockRowIndex = int(block[0][0]) + blockColIndex = int(block[0][1]) + subMatrix = block[1] + return ((blockRowIndex, blockColIndex), subMatrix) + else: + raise TypeError("Cannot convert type %s into a sub-matrix block tuple" % type(block)) + + +class BlockMatrix(DistributedMatrix): + """ + .. note:: Experimental + + Represents a distributed matrix in blocks of local matrices. 
+ + :param blocks: An RDD of sub-matrix blocks + ((blockRowIndex, blockColIndex), sub-matrix) that + form this distributed matrix. If multiple blocks + with the same index exist, the results for + operations like add and multiply will be + unpredictable. + :param rowsPerBlock: Number of rows that make up each block. + The blocks forming the final rows are not + required to have the given number of rows. + :param colsPerBlock: Number of columns that make up each block. + The blocks forming the final columns are not + required to have the given number of columns. + :param numRows: Number of rows of this matrix. If the supplied + value is less than or equal to zero, the number + of rows will be calculated when `numRows` is + invoked. + :param numCols: Number of columns of this matrix. If the supplied + value is less than or equal to zero, the number + of columns will be calculated when `numCols` is + invoked. + """ + def __init__(self, blocks, rowsPerBlock, colsPerBlock, numRows=0, numCols=0): + """ + Note: This docstring is not shown publicly. + + Create a wrapper over a Java BlockMatrix. + + Publicly, we require that `blocks` be an RDD. However, for + internal usage, `blocks` can also be a Java BlockMatrix + object, in which case we can wrap it directly. This + assists in clean matrix conversions. + + >>> blocks = sc.parallelize([((0, 0), Matrices.dense(3, 2, [1, 2, 3, 4, 5, 6])), + ... ((1, 0), Matrices.dense(3, 2, [7, 8, 9, 10, 11, 12]))]) + >>> mat = BlockMatrix(blocks, 3, 2) + + >>> mat_diff = BlockMatrix(blocks, 3, 2) + >>> (mat_diff._java_matrix_wrapper._java_model == + ... mat._java_matrix_wrapper._java_model) + False + + >>> mat_same = BlockMatrix(mat._java_matrix_wrapper._java_model, 3, 2) + >>> (mat_same._java_matrix_wrapper._java_model == + ... mat._java_matrix_wrapper._java_model) + True + """ + if isinstance(blocks, RDD): + blocks = blocks.map(_convert_to_matrix_block_tuple) + # We use DataFrames for serialization of sub-matrix blocks + # from Python, so first convert the RDD to a DataFrame on + # this side. This will convert each sub-matrix block + # tuple to a Row containing the 'blockRowIndex', + # 'blockColIndex', and 'subMatrix' values, which can + # each be easily serialized. We will convert back to + # ((blockRowIndex, blockColIndex), sub-matrix) tuples on + # the Scala side. + java_matrix = callMLlibFunc("createBlockMatrix", blocks.toDF(), + int(rowsPerBlock), int(colsPerBlock), + long(numRows), long(numCols)) + elif (isinstance(blocks, JavaObject) + and blocks.getClass().getSimpleName() == "BlockMatrix"): + java_matrix = blocks + else: + raise TypeError("blocks should be an RDD of sub-matrix blocks as " + "((int, int), matrix) tuples, got %s" % type(blocks)) + + self._java_matrix_wrapper = JavaModelWrapper(java_matrix) + + @property + def blocks(self): + """ + The RDD of sub-matrix blocks + ((blockRowIndex, blockColIndex), sub-matrix) that form this + distributed matrix. + + >>> mat = BlockMatrix( + ... sc.parallelize([((0, 0), Matrices.dense(3, 2, [1, 2, 3, 4, 5, 6])), + ... ((1, 0), Matrices.dense(3, 2, [7, 8, 9, 10, 11, 12]))]), 3, 2) + >>> blocks = mat.blocks + >>> blocks.first() + ((0, 0), DenseMatrix(3, 2, [1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 0)) + + """ + # We use DataFrames for serialization of sub-matrix blocks + # from Java, so we first convert the RDD of blocks to a + # DataFrame on the Scala/Java side. Then we map each Row in + # the DataFrame back to a sub-matrix block on this side. 
+ blocks_df = callMLlibFunc("getMatrixBlocks", self._java_matrix_wrapper._java_model) + blocks = blocks_df.map(lambda row: ((row[0][0], row[0][1]), row[1])) + return blocks + + @property + def rowsPerBlock(self): + """ + Number of rows that make up each block. + + >>> blocks = sc.parallelize([((0, 0), Matrices.dense(3, 2, [1, 2, 3, 4, 5, 6])), + ... ((1, 0), Matrices.dense(3, 2, [7, 8, 9, 10, 11, 12]))]) + >>> mat = BlockMatrix(blocks, 3, 2) + >>> mat.rowsPerBlock + 3 + """ + return self._java_matrix_wrapper.call("rowsPerBlock") + + @property + def colsPerBlock(self): + """ + Number of columns that make up each block. + + >>> blocks = sc.parallelize([((0, 0), Matrices.dense(3, 2, [1, 2, 3, 4, 5, 6])), + ... ((1, 0), Matrices.dense(3, 2, [7, 8, 9, 10, 11, 12]))]) + >>> mat = BlockMatrix(blocks, 3, 2) + >>> mat.colsPerBlock + 2 + """ + return self._java_matrix_wrapper.call("colsPerBlock") + + @property + def numRowBlocks(self): + """ + Number of rows of blocks in the BlockMatrix. + + >>> blocks = sc.parallelize([((0, 0), Matrices.dense(3, 2, [1, 2, 3, 4, 5, 6])), + ... ((1, 0), Matrices.dense(3, 2, [7, 8, 9, 10, 11, 12]))]) + >>> mat = BlockMatrix(blocks, 3, 2) + >>> mat.numRowBlocks + 2 + """ + return self._java_matrix_wrapper.call("numRowBlocks") + + @property + def numColBlocks(self): + """ + Number of columns of blocks in the BlockMatrix. + + >>> blocks = sc.parallelize([((0, 0), Matrices.dense(3, 2, [1, 2, 3, 4, 5, 6])), + ... ((1, 0), Matrices.dense(3, 2, [7, 8, 9, 10, 11, 12]))]) + >>> mat = BlockMatrix(blocks, 3, 2) + >>> mat.numColBlocks + 1 + """ + return self._java_matrix_wrapper.call("numColBlocks") + + def numRows(self): + """ + Get or compute the number of rows. + + >>> blocks = sc.parallelize([((0, 0), Matrices.dense(3, 2, [1, 2, 3, 4, 5, 6])), + ... ((1, 0), Matrices.dense(3, 2, [7, 8, 9, 10, 11, 12]))]) + + >>> mat = BlockMatrix(blocks, 3, 2) + >>> print(mat.numRows()) + 6 + + >>> mat = BlockMatrix(blocks, 3, 2, 7, 6) + >>> print(mat.numRows()) + 7 + """ + return self._java_matrix_wrapper.call("numRows") + + def numCols(self): + """ + Get or compute the number of cols. + + >>> blocks = sc.parallelize([((0, 0), Matrices.dense(3, 2, [1, 2, 3, 4, 5, 6])), + ... ((1, 0), Matrices.dense(3, 2, [7, 8, 9, 10, 11, 12]))]) + + >>> mat = BlockMatrix(blocks, 3, 2) + >>> print(mat.numCols()) + 2 + + >>> mat = BlockMatrix(blocks, 3, 2, 7, 6) + >>> print(mat.numCols()) + 6 + """ + return self._java_matrix_wrapper.call("numCols") + + def toLocalMatrix(self): + """ + Collect the distributed matrix on the driver as a DenseMatrix. + + >>> blocks = sc.parallelize([((0, 0), Matrices.dense(3, 2, [1, 2, 3, 4, 5, 6])), + ... ((1, 0), Matrices.dense(3, 2, [7, 8, 9, 10, 11, 12]))]) + >>> mat = BlockMatrix(blocks, 3, 2).toLocalMatrix() + + >>> # This BlockMatrix will have 6 effective rows, due to + >>> # having two sub-matrix blocks stacked, each with 3 rows. + >>> # The ensuing DenseMatrix will also have 6 rows. + >>> print(mat.numRows) + 6 + + >>> # This BlockMatrix will have 2 effective columns, due to + >>> # having two sub-matrix blocks stacked, each with 2 + >>> # columns. The ensuing DenseMatrix will also have 2 columns. + >>> print(mat.numCols) + 2 + """ + return self._java_matrix_wrapper.call("toLocalMatrix") + + def toIndexedRowMatrix(self): + """ + Convert this matrix to an IndexedRowMatrix. + + >>> blocks = sc.parallelize([((0, 0), Matrices.dense(3, 2, [1, 2, 3, 4, 5, 6])), + ... 
((1, 0), Matrices.dense(3, 2, [7, 8, 9, 10, 11, 12]))]) + >>> mat = BlockMatrix(blocks, 3, 2).toIndexedRowMatrix() + + >>> # This BlockMatrix will have 6 effective rows, due to + >>> # having two sub-matrix blocks stacked, each with 3 rows. + >>> # The ensuing IndexedRowMatrix will also have 6 rows. + >>> print(mat.numRows()) + 6 + + >>> # This BlockMatrix will have 2 effective columns, due to + >>> # having two sub-matrix blocks stacked, each with 2 columns. + >>> # The ensuing IndexedRowMatrix will also have 2 columns. + >>> print(mat.numCols()) + 2 + """ + java_indexed_row_matrix = self._java_matrix_wrapper.call("toIndexedRowMatrix") + return IndexedRowMatrix(java_indexed_row_matrix) + + def toCoordinateMatrix(self): + """ + Convert this matrix to a CoordinateMatrix. + + >>> blocks = sc.parallelize([((0, 0), Matrices.dense(1, 2, [1, 2])), + ... ((1, 0), Matrices.dense(1, 2, [7, 8]))]) + >>> mat = BlockMatrix(blocks, 1, 2).toCoordinateMatrix() + >>> mat.entries.take(3) + [MatrixEntry(0, 0, 1.0), MatrixEntry(0, 1, 2.0), MatrixEntry(1, 0, 7.0)] + """ + java_coordinate_matrix = self._java_matrix_wrapper.call("toCoordinateMatrix") + return CoordinateMatrix(java_coordinate_matrix) + def _test(): import doctest from pyspark import SparkContext from pyspark.sql import SQLContext + from pyspark.mllib.linalg import Matrices import pyspark.mllib.linalg.distributed globs = pyspark.mllib.linalg.distributed.__dict__.copy() globs['sc'] = SparkContext('local[2]', 'PythonTest', batchSize=2) globs['sqlContext'] = SQLContext(globs['sc']) + globs['Matrices'] = Matrices (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) globs['sc'].stop() if failure_count: From 23d982204bb9ef74d3b788a32ce6608116968719 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Wed, 5 Aug 2015 09:01:45 -0700 Subject: [PATCH 0062/1824] [SPARK-9141] [SQL] Remove project collapsing from DataFrame API Currently we collapse successive projections that are added by `withColumn`. However, this optimization violates the constraint that adding nodes to a plan will never change its analyzed form and thus breaks caching. Instead of doing early optimization, in this PR I just fix some low-hanging slowness in the analyzer. In particular, I add a mechanism for skipping already analyzed subplans, `resolveOperators` and `resolveExpression`. Since trees are generally immutable after construction, it's safe to annotate a plan as already analyzed as any transformation will create a new tree with this bit no longer set. Together these result in a faster analyzer than before, even with added timing instrumentation. 
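For illustration only (not part of this patch): a rule that previously walked the whole tree with `transform` opts into the new skipping behavior simply by calling `resolveOperators` instead. A minimal sketch, assuming the Catalyst classes touched in the diff below are on the classpath; the rule body is a made-up no-op:

```
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project}
import org.apache.spark.sql.catalyst.rules.Rule

// Hypothetical rule, shown only to contrast the two traversal entry points.
object ExampleResolutionRule extends Rule[LogicalPlan] {
  // Before: `plan transform { ... }` revisited every node on every analyzer pass.
  // After: `resolveOperators` applies the partial function post-order but skips any
  // sub-tree whose analyzed bit has already been set by CheckAnalysis.
  def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
    case p: Project if p.resolved => p
  }
}
```

The timings below compare analysis cost for the same workload at each stage of these changes.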
``` Original Code [info] 3430ms [info] 2205ms [info] 1973ms [info] 1982ms [info] 1916ms Without Project Collapsing in DataFrame [info] 44610ms [info] 45977ms [info] 46423ms [info] 46306ms [info] 54723ms With analyzer optimizations [info] 6394ms [info] 4630ms [info] 4388ms [info] 4093ms [info] 4113ms With resolveOperators [info] 2495ms [info] 1380ms [info] 1685ms [info] 1414ms [info] 1240ms ``` Author: Michael Armbrust Closes #7920 from marmbrus/withColumnCache and squashes the following commits: 2145031 [Michael Armbrust] fix hive udfs tests 5a5a525 [Michael Armbrust] remove wrong comment 7a507d5 [Michael Armbrust] style b59d710 [Michael Armbrust] revert small change 1fa5949 [Michael Armbrust] move logic into LogicalPlan, add tests 0e2cb43 [Michael Armbrust] Merge remote-tracking branch 'origin/master' into withColumnCache c926e24 [Michael Armbrust] naming e593a2d [Michael Armbrust] style f5a929e [Michael Armbrust] [SPARK-9141][SQL] Remove project collapsing from DataFrame API 38b1c83 [Michael Armbrust] WIP --- .../sql/catalyst/analysis/Analyzer.scala | 28 ++++---- .../sql/catalyst/analysis/CheckAnalysis.scala | 3 + .../catalyst/analysis/HiveTypeCoercion.scala | 30 ++++---- .../spark/sql/catalyst/plans/QueryPlan.scala | 5 +- .../catalyst/plans/logical/LogicalPlan.scala | 51 ++++++++++++- .../sql/catalyst/rules/RuleExecutor.scala | 22 ++++++ .../spark/sql/catalyst/trees/TreeNode.scala | 64 ++++------------- .../sql/catalyst/plans/LogicalPlanSuite.scala | 72 +++++++++++++++++++ .../org/apache/spark/sql/DataFrame.scala | 6 +- .../spark/sql/execution/SparkPlan.scala | 2 +- .../spark/sql/execution/pythonUDFs.scala | 2 +- .../apache/spark/sql/CachedTableSuite.scala | 20 ++++++ .../org/apache/spark/sql/DataFrameSuite.scala | 12 ---- .../execution/HiveCompatibilitySuite.scala | 5 ++ .../spark/sql/hive/HiveMetastoreCatalog.scala | 5 +- .../org/apache/spark/sql/hive/hiveUDFs.scala | 3 +- .../sql/hive/execution/HiveUDFSuite.scala | 8 +-- .../sql/hive/execution/SQLQuerySuite.scala | 2 + 18 files changed, 234 insertions(+), 106 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index ca17f3e3d06f..6de31f42dd30 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -90,7 +90,7 @@ class Analyzer( */ object CTESubstitution extends Rule[LogicalPlan] { // TODO allow subquery to define CTE - def apply(plan: LogicalPlan): LogicalPlan = plan transform { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case With(child, relations) => substituteCTE(child, relations) case other => other } @@ -116,7 +116,7 @@ class Analyzer( * Substitute child plan with WindowSpecDefinitions. */ object WindowsSubstitution extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transform { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { // Lookup WindowSpecDefinitions. This rule works with unresolved children. 
case WithWindowDefinition(windowDefinitions, child) => child.transform { @@ -140,7 +140,7 @@ class Analyzer( object ResolveAliases extends Rule[LogicalPlan] { private def assignAliases(exprs: Seq[NamedExpression]) = { // The `UnresolvedAlias`s will appear only at root of a expression tree, we don't need - // to transform down the whole tree. + // to resolveOperator down the whole tree. exprs.zipWithIndex.map { case (u @ UnresolvedAlias(child), i) => child match { @@ -156,7 +156,7 @@ class Analyzer( } } - def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case Aggregate(groups, aggs, child) if child.resolved && aggs.exists(_.isInstanceOf[UnresolvedAlias]) => Aggregate(groups, assignAliases(aggs), child) @@ -198,7 +198,7 @@ class Analyzer( Seq.tabulate(1 << c.groupByExprs.length)(i => i) } - def apply(plan: LogicalPlan): LogicalPlan = plan transform { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case a if !a.childrenResolved => a // be sure all of the children are resolved. case a: Cube => GroupingSets(bitmasks(a), a.groupByExprs, a.child, a.aggregations) @@ -261,7 +261,7 @@ class Analyzer( } } - def apply(plan: LogicalPlan): LogicalPlan = plan transform { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case i @ InsertIntoTable(u: UnresolvedRelation, _, _, _, _) => i.copy(table = EliminateSubQueries(getTable(u))) case u: UnresolvedRelation => @@ -274,7 +274,7 @@ class Analyzer( * a logical plan node's children. */ object ResolveReferences extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case p: LogicalPlan if !p.childrenResolved => p // If the projection list contains Stars, expand it. @@ -444,7 +444,7 @@ class Analyzer( * remove these attributes after sorting. */ object ResolveSortReferences extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case s @ Sort(ordering, global, p @ Project(projectList, child)) if !s.resolved && p.resolved => val (newOrdering, missing) = resolveAndFindMissing(ordering, p, child) @@ -519,7 +519,7 @@ class Analyzer( * Replaces [[UnresolvedFunction]]s with concrete [[Expression]]s. */ object ResolveFunctions extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transform { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case q: LogicalPlan => q transformExpressions { case u @ UnresolvedFunction(name, children, isDistinct) => @@ -551,7 +551,7 @@ class Analyzer( * Turns projections that contain aggregate expressions into aggregations. */ object GlobalAggregates extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transform { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case Project(projectList, child) if containsAggregates(projectList) => Aggregate(Nil, projectList, child) } @@ -571,7 +571,7 @@ class Analyzer( * aggregates and then projects them away above the filter. 
*/ object UnresolvedHavingClauseAttributes extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case filter @ Filter(havingCondition, aggregate @ Aggregate(_, originalAggExprs, _)) if aggregate.resolved && containsAggregate(havingCondition) => @@ -601,7 +601,7 @@ class Analyzer( * [[AnalysisException]] is throw. */ object ResolveGenerate extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transform { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case p: Generate if !p.child.resolved || !p.generator.resolved => p case g: Generate if !g.resolved => g.copy(generatorOutput = makeGeneratorOutput(g.generator, g.generatorOutput.map(_.name))) @@ -872,6 +872,8 @@ class Analyzer( // We have to use transformDown at here to make sure the rule of // "Aggregate with Having clause" will be triggered. def apply(plan: LogicalPlan): LogicalPlan = plan transformDown { + + // Aggregate with Having clause. This rule works with an unresolved Aggregate because // a resolved Aggregate will not have Window Functions. case f @ Filter(condition, a @ Aggregate(groupingExprs, aggregateExprs, child)) @@ -927,7 +929,7 @@ class Analyzer( * put them into an inner Project and finally project them away at the outer Project. */ object PullOutNondeterministic extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = plan transform { + override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case p: Project => p case f: Filter => f diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 187b238045f8..39f554c137c9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -47,6 +47,7 @@ trait CheckAnalysis { // We transform up and order the rules so as to catch the first possible failure instead // of the result of cascading resolution failures. plan.foreachUp { + case p if p.analyzed => // Skip already analyzed sub-plans case operator: LogicalPlan => operator transformExpressionsUp { @@ -179,5 +180,7 @@ trait CheckAnalysis { } } extendedCheckRules.foreach(_(plan)) + + plan.foreach(_.setAnalyzed()) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 490f3dc07b6e..970f3c8282c8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -144,7 +144,8 @@ object HiveTypeCoercion { * instances higher in the query tree. */ object PropagateTypes extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transform { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + // No propagation required for leaf nodes. 
case q: LogicalPlan if q.children.isEmpty => q @@ -225,7 +226,9 @@ object HiveTypeCoercion { } } - def apply(plan: LogicalPlan): LogicalPlan = plan transform { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + case p if p.analyzed => p + case u @ Union(left, right) if u.childrenResolved && !u.resolved => val (newLeft, newRight) = widenOutputTypes(u.nodeName, left, right) Union(newLeft, newRight) @@ -242,7 +245,7 @@ object HiveTypeCoercion { * Promotes strings that appear in arithmetic expressions. */ object PromoteStrings extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e @@ -305,7 +308,7 @@ object HiveTypeCoercion { * Convert all expressions in in() list to the left operator type */ object InConversion extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e @@ -372,7 +375,8 @@ object HiveTypeCoercion { ChangeDecimalPrecision(Cast(e, dataType)) } - def apply(plan: LogicalPlan): LogicalPlan = plan transform { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + // fix decimal precision for expressions case q => q.transformExpressions { // Skip nodes whose children have not been resolved yet @@ -466,7 +470,7 @@ object HiveTypeCoercion { )) } - def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e @@ -508,7 +512,7 @@ object HiveTypeCoercion { * truncated version of this number. */ object StringToIntegralCasts extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e @@ -521,7 +525,7 @@ object HiveTypeCoercion { * This ensure that the types for various functions are as expected. */ object FunctionArgumentConversion extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e @@ -575,7 +579,7 @@ object HiveTypeCoercion { * converted to fractional types. */ object Division extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { // Skip nodes who has not been resolved yet, // as this is an extra rule which should be applied at last. case e if !e.resolved => e @@ -592,7 +596,7 @@ object HiveTypeCoercion { * Coerces the type of different branches of a CASE WHEN statement to a common type. 
*/ object CaseWhenCoercion extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { case c: CaseWhenLike if c.childrenResolved && !c.valueTypesEqual => logDebug(s"Input values for null casting ${c.valueTypes.mkString(",")}") val maybeCommonType = findTightestCommonTypeAndPromoteToString(c.valueTypes) @@ -628,7 +632,7 @@ object HiveTypeCoercion { * Coerces the type of different branches of If statement to a common type. */ object IfCoercion extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { // Find tightest common type for If, if the true value and false value have different types. case i @ If(pred, left, right) if left.dataType != right.dataType => findTightestCommonTypeToString(left.dataType, right.dataType).map { widestType => @@ -652,7 +656,7 @@ object HiveTypeCoercion { private val acceptedTypes = Seq(DateType, TimestampType, StringType) - def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e @@ -669,7 +673,7 @@ object HiveTypeCoercion { * Casts types according to the expected input types for [[Expression]]s. */ object ImplicitTypeCasts extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index c610f70d3843..55286f9f2fc5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.plans import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression, VirtualColumn} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.types.{DataType, StructType} @@ -92,7 +93,7 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy val newArgs = productIterator.map(recursiveTransform).toArray - if (changed) makeCopy(newArgs) else this + if (changed) makeCopy(newArgs).asInstanceOf[this.type] else this } /** @@ -124,7 +125,7 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy val newArgs = productIterator.map(recursiveTransform).toArray - if (changed) makeCopy(newArgs) else this + if (changed) makeCopy(newArgs).asInstanceOf[this.type] else this } /** Returns the result of running [[transformExpressions]] on this node diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index bedeaf06adf1..9b52f020093f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -22,11 +22,60 @@ 
import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.QueryPlan -import org.apache.spark.sql.catalyst.trees.TreeNode +import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, TreeNode} abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { + private var _analyzed: Boolean = false + + /** + * Marks this plan as already analyzed. This should only be called by CheckAnalysis. + */ + private[catalyst] def setAnalyzed(): Unit = { _analyzed = true } + + /** + * Returns true if this node and its children have already been gone through analysis and + * verification. Note that this is only an optimization used to avoid analyzing trees that + * have already been analyzed, and can be reset by transformations. + */ + def analyzed: Boolean = _analyzed + + /** + * Returns a copy of this node where `rule` has been recursively applied first to all of its + * children and then itself (post-order). When `rule` does not apply to a given node, it is left + * unchanged. This function is similar to `transformUp`, but skips sub-trees that have already + * been marked as analyzed. + * + * @param rule the function use to transform this nodes children + */ + def resolveOperators(rule: PartialFunction[LogicalPlan, LogicalPlan]): LogicalPlan = { + if (!analyzed) { + val afterRuleOnChildren = transformChildren(rule, (t, r) => t.resolveOperators(r)) + if (this fastEquals afterRuleOnChildren) { + CurrentOrigin.withOrigin(origin) { + rule.applyOrElse(this, identity[LogicalPlan]) + } + } else { + CurrentOrigin.withOrigin(origin) { + rule.applyOrElse(afterRuleOnChildren, identity[LogicalPlan]) + } + } + } else { + this + } + } + + /** + * Recursively transforms the expressions of a tree, skipping nodes that have already + * been analyzed. + */ + def resolveExpressions(r: PartialFunction[Expression, Expression]): LogicalPlan = { + this resolveOperators { + case p => p.transformExpressions(r) + } + } + /** * Computes [[Statistics]] for this plan. The default implementation assumes the output * cardinality is the product of of all child plan's cardinality, i.e. applies in the case diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala index 3f9858b0c4a4..8b824511a79d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala @@ -21,6 +21,23 @@ import org.apache.spark.Logging import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.catalyst.util.sideBySide +import scala.collection.mutable + +object RuleExecutor { + protected val timeMap = new mutable.HashMap[String, Long].withDefault(_ => 0) + + /** Resets statistics about time spent running specific rules */ + def resetTime(): Unit = timeMap.clear() + + /** Dump statistics about time spent running specific rules. 
*/ + def dumpTimeSpent(): String = { + val maxSize = timeMap.keys.map(_.toString.length).max + timeMap.toSeq.sortBy(_._2).reverseMap { case (k, v) => + s"${k.padTo(maxSize, " ").mkString} $v" + }.mkString("\n") + } +} + abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging { /** @@ -41,6 +58,7 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging { /** Defines a sequence of rule batches, to be overridden by the implementation. */ protected val batches: Seq[Batch] + /** * Executes the batches of rules defined by the subclass. The batches are executed serially * using the defined execution strategy. Within each batch, rules are also executed serially. @@ -58,7 +76,11 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging { while (continue) { curPlan = batch.rules.foldLeft(curPlan) { case (plan, rule) => + val startTime = System.nanoTime() val result = rule(plan) + val runTime = System.nanoTime() - startTime + RuleExecutor.timeMap(rule.ruleName) = RuleExecutor.timeMap(rule.ruleName) + runTime + if (!result.fastEquals(plan)) { logTrace( s""" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index 122e9fc5ed77..7971e25188e8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -149,7 +149,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { /** * Returns a copy of this node where `f` has been applied to all the nodes children. */ - def mapChildren(f: BaseType => BaseType): this.type = { + def mapChildren(f: BaseType => BaseType): BaseType = { var changed = false val newArgs = productIterator.map { case arg: TreeNode[_] if containsChild(arg) => @@ -170,7 +170,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { * Returns a copy of this node with the children replaced. * TODO: Validate somewhere (in debug mode?) that children are ordered correctly. */ - def withNewChildren(newChildren: Seq[BaseType]): this.type = { + def withNewChildren(newChildren: Seq[BaseType]): BaseType = { assert(newChildren.size == children.size, "Incorrect number of children") var changed = false val remainingNewChildren = newChildren.toBuffer @@ -229,9 +229,9 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { // Check if unchanged and then possibly return old copy to avoid gc churn. if (this fastEquals afterRule) { - transformChildrenDown(rule) + transformChildren(rule, (t, r) => t.transformDown(r)) } else { - afterRule.transformChildrenDown(rule) + afterRule.transformChildren(rule, (t, r) => t.transformDown(r)) } } @@ -240,11 +240,13 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { * this node. When `rule` does not apply to a given node it is left unchanged. 
* @param rule the function used to transform this nodes children */ - def transformChildrenDown(rule: PartialFunction[BaseType, BaseType]): this.type = { + protected def transformChildren( + rule: PartialFunction[BaseType, BaseType], + nextOperation: (BaseType, PartialFunction[BaseType, BaseType]) => BaseType): BaseType = { var changed = false val newArgs = productIterator.map { case arg: TreeNode[_] if containsChild(arg) => - val newChild = arg.asInstanceOf[BaseType].transformDown(rule) + val newChild = nextOperation(arg.asInstanceOf[BaseType], rule) if (!(newChild fastEquals arg)) { changed = true newChild @@ -252,7 +254,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { arg } case Some(arg: TreeNode[_]) if containsChild(arg) => - val newChild = arg.asInstanceOf[BaseType].transformDown(rule) + val newChild = nextOperation(arg.asInstanceOf[BaseType], rule) if (!(newChild fastEquals arg)) { changed = true Some(newChild) @@ -263,7 +265,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { case d: DataType => d // Avoid unpacking Structs case args: Traversable[_] => args.map { case arg: TreeNode[_] if containsChild(arg) => - val newChild = arg.asInstanceOf[BaseType].transformDown(rule) + val newChild = nextOperation(arg.asInstanceOf[BaseType], rule) if (!(newChild fastEquals arg)) { changed = true newChild @@ -285,7 +287,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { * @param rule the function use to transform this nodes children */ def transformUp(rule: PartialFunction[BaseType, BaseType]): BaseType = { - val afterRuleOnChildren = transformChildrenUp(rule) + val afterRuleOnChildren = transformChildren(rule, (t, r) => t.transformUp(r)) if (this fastEquals afterRuleOnChildren) { CurrentOrigin.withOrigin(origin) { rule.applyOrElse(this, identity[BaseType]) @@ -297,44 +299,6 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { } } - def transformChildrenUp(rule: PartialFunction[BaseType, BaseType]): this.type = { - var changed = false - val newArgs = productIterator.map { - case arg: TreeNode[_] if containsChild(arg) => - val newChild = arg.asInstanceOf[BaseType].transformUp(rule) - if (!(newChild fastEquals arg)) { - changed = true - newChild - } else { - arg - } - case Some(arg: TreeNode[_]) if containsChild(arg) => - val newChild = arg.asInstanceOf[BaseType].transformUp(rule) - if (!(newChild fastEquals arg)) { - changed = true - Some(newChild) - } else { - Some(arg) - } - case m: Map[_, _] => m - case d: DataType => d // Avoid unpacking Structs - case args: Traversable[_] => args.map { - case arg: TreeNode[_] if containsChild(arg) => - val newChild = arg.asInstanceOf[BaseType].transformUp(rule) - if (!(newChild fastEquals arg)) { - changed = true - newChild - } else { - arg - } - case other => other - } - case nonChild: AnyRef => nonChild - case null => null - }.toArray - if (changed) makeCopy(newArgs) else this - } - /** * Args to the constructor that should be copied, but not transformed. * These are appended to the transformed args automatically by makeCopy @@ -348,7 +312,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { * that are not present in the productIterator. * @param newArgs the new product arguments. 
*/ - def makeCopy(newArgs: Array[AnyRef]): this.type = attachTree(this, "makeCopy") { + def makeCopy(newArgs: Array[AnyRef]): BaseType = attachTree(this, "makeCopy") { val ctors = getClass.getConstructors.filter(_.getParameterTypes.size != 0) if (ctors.isEmpty) { sys.error(s"No valid constructor for $nodeName") @@ -359,9 +323,9 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { CurrentOrigin.withOrigin(origin) { // Skip no-arg constructors that are just there for kryo. if (otherCopyArgs.isEmpty) { - defaultCtor.newInstance(newArgs: _*).asInstanceOf[this.type] + defaultCtor.newInstance(newArgs: _*).asInstanceOf[BaseType] } else { - defaultCtor.newInstance((newArgs ++ otherCopyArgs).toArray: _*).asInstanceOf[this.type] + defaultCtor.newInstance((newArgs ++ otherCopyArgs).toArray: _*).asInstanceOf[BaseType] } } } catch { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala new file mode 100644 index 000000000000..797b29f23cbb --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.plans + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.util._ + +/** + * Provides helper methods for comparing plans. 
+ */ +class LogicalPlanSuite extends SparkFunSuite { + private var invocationCount = 0 + private val function: PartialFunction[LogicalPlan, LogicalPlan] = { + case p: Project => + invocationCount += 1 + p + } + + private val testRelation = LocalRelation() + + test("resolveOperator runs on operators") { + invocationCount = 0 + val plan = Project(Nil, testRelation) + plan resolveOperators function + + assert(invocationCount === 1) + } + + test("resolveOperator runs on operators recursively") { + invocationCount = 0 + val plan = Project(Nil, Project(Nil, testRelation)) + plan resolveOperators function + + assert(invocationCount === 2) + } + + test("resolveOperator skips all ready resolved plans") { + invocationCount = 0 + val plan = Project(Nil, Project(Nil, testRelation)) + plan.foreach(_.setAnalyzed()) + plan resolveOperators function + + assert(invocationCount === 0) + } + + test("resolveOperator skips partially resolved plans") { + invocationCount = 0 + val plan1 = Project(Nil, testRelation) + val plan2 = Project(Nil, plan1) + plan1.foreach(_.setAnalyzed()) + plan2 resolveOperators function + + assert(invocationCount === 1) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index db15711202b7..e57acec59d32 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql import java.io.CharArrayWriter import java.util.Properties +import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.unsafe.types.UTF8String import scala.language.implicitConversions @@ -54,7 +55,6 @@ private[sql] object DataFrame { } } - /** * :: Experimental :: * A distributed collection of data organized into named columns. @@ -690,9 +690,7 @@ class DataFrame private[sql]( case Column(explode: Explode) => MultiAlias(explode, Nil) case Column(expr: Expression) => Alias(expr, expr.prettyString)() } - // When user continuously call `select`, speed up analysis by collapsing `Project` - import org.apache.spark.sql.catalyst.optimizer.ProjectCollapsing - Project(namedExpressions.toSeq, ProjectCollapsing(logicalPlan)) + Project(namedExpressions.toSeq, logicalPlan) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 73b237fffece..dbc0cefbe2e1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -67,7 +67,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ private val prepareCalled = new AtomicBoolean(false) /** Overridden make copy also propogates sqlContext to copied plan. */ - override def makeCopy(newArgs: Array[AnyRef]): this.type = { + override def makeCopy(newArgs: Array[AnyRef]): SparkPlan = { SparkPlan.currentContext.set(sqlContext) super.makeCopy(newArgs) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala index dedc7c4dfb4d..59f8b079ab33 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala @@ -65,7 +65,7 @@ private[spark] case class PythonUDF( * multiple child operators. 
*/ private[spark] object ExtractPythonUDFs extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transform { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { // Skip EvaluatePython nodes. case plan: EvaluatePython => plan diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index e9dd7ef226e4..a88df91b1001 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -25,6 +25,7 @@ import org.scalatest.concurrent.Eventually._ import org.apache.spark.Accumulators import org.apache.spark.sql.TestData._ import org.apache.spark.sql.columnar._ +import org.apache.spark.sql.functions._ import org.apache.spark.storage.{StorageLevel, RDDBlockId} case class BigData(s: String) @@ -50,6 +51,25 @@ class CachedTableSuite extends QueryTest { ctx.sparkContext.env.blockManager.get(RDDBlockId(rddId, 0)).nonEmpty } + test("withColumn doesn't invalidate cached dataframe") { + var evalCount = 0 + val myUDF = udf((x: String) => { evalCount += 1; "result" }) + val df = Seq(("test", 1)).toDF("s", "i").select(myUDF($"s")) + df.cache() + + df.collect() + assert(evalCount === 1) + + df.collect() + assert(evalCount === 1) + + val df2 = df.withColumn("newColumn", lit(1)) + df2.collect() + + // We should not reevaluate the cached dataframe + assert(evalCount === 1) + } + test("cache temp table") { testData.select('key).registerTempTable("tempTable") assertCached(sql("SELECT COUNT(*) FROM tempTable"), 0) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index b8f10b00f569..f9cc6d1f3c25 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -686,18 +686,6 @@ class DataFrameSuite extends QueryTest with SQLTestUtils { Seq(Row(2, 1, 2), Row(1, 1, 1))) } - test("SPARK-7276: Project collapse for continuous select") { - var df = testData - for (i <- 1 to 5) { - df = df.select($"*") - } - - import org.apache.spark.sql.catalyst.plans.logical.Project - // make sure df have at most two Projects - val p = df.logicalPlan.asInstanceOf[Project].child.asInstanceOf[Project] - assert(!p.child.isInstanceOf[Project]) - } - test("SPARK-7150 range api") { // numSlice is greater than length val res1 = sqlContext.range(0, 10, 1, 15).select("id") diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index d4fc6c2b6ebc..ab309e0a1d36 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.hive.execution import java.io.File import java.util.{Locale, TimeZone} +import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.scalatest.BeforeAndAfter import org.apache.spark.sql.SQLConf @@ -50,6 +51,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { TestHive.setConf(SQLConf.COLUMN_BATCH_SIZE, 5) // Enable in-memory partition pruning for testing purposes TestHive.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, true) + 
RuleExecutor.resetTime() } override def afterAll() { @@ -58,6 +60,9 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { Locale.setDefault(originalLocale) TestHive.setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize) TestHive.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning) + + // For debugging dump some statistics about how much time was spent in various optimizer rules. + logWarning(RuleExecutor.dumpTimeSpent()) } /** A list of tests deemed out of scope currently and thus completely disregarded. */ diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 16c186627f6c..6b37af99f467 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -391,7 +391,7 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive */ object ParquetConversions extends Rule[LogicalPlan] { override def apply(plan: LogicalPlan): LogicalPlan = { - if (!plan.resolved) { + if (!plan.resolved || plan.analyzed) { return plan } @@ -418,8 +418,7 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive (relation, parquetRelation, attributedRewrites) // Read path - case p @ PhysicalOperation(_, _, relation: MetastoreRelation) - if hive.convertMetastoreParquet && + case relation: MetastoreRelation if hive.convertMetastoreParquet && relation.tableDesc.getSerdeClassName.toLowerCase.contains("parquet") => val parquetRelation = convertToParquetRelation(relation) val attributedRewrites = relation.output.zip(parquetRelation.output) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala index 8a86a87368f2..7182246e466a 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala @@ -133,8 +133,7 @@ private[hive] case class HiveSimpleUDF(funcWrapper: HiveFunctionWrapper, childre @transient private lazy val conversionHelper = new ConversionHelper(method, arguments) - @transient - lazy val dataType = javaClassToDataType(method.getReturnType) + val dataType = javaClassToDataType(method.getReturnType) @transient lazy val returnInspector = ObjectInspectorFactory.getReflectionObjectInspector( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala index 7069afc9f7da..10f2902e5eef 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala @@ -182,7 +182,7 @@ class HiveUDFSuite extends QueryTest { val errMsg = intercept[AnalysisException] { sql("SELECT testUDFToListString(s) FROM inputTable") } - assert(errMsg.getMessage === "List type in java is unsupported because " + + assert(errMsg.getMessage contains "List type in java is unsupported because " + "JVM type erasure makes spark fail to catch a component type in List<>;") sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFToListString") @@ -197,7 +197,7 @@ class HiveUDFSuite extends QueryTest { val errMsg = intercept[AnalysisException] { sql("SELECT testUDFToListInt(s) FROM inputTable") } - assert(errMsg.getMessage === "List type in java is 
unsupported because " + + assert(errMsg.getMessage contains "List type in java is unsupported because " + "JVM type erasure makes spark fail to catch a component type in List<>;") sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFToListInt") @@ -213,7 +213,7 @@ class HiveUDFSuite extends QueryTest { val errMsg = intercept[AnalysisException] { sql("SELECT testUDFToStringIntMap(s) FROM inputTable") } - assert(errMsg.getMessage === "Map type in java is unsupported because " + + assert(errMsg.getMessage contains "Map type in java is unsupported because " + "JVM type erasure makes spark fail to catch key and value types in Map<>;") sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFToStringIntMap") @@ -229,7 +229,7 @@ class HiveUDFSuite extends QueryTest { val errMsg = intercept[AnalysisException] { sql("SELECT testUDFToIntIntMap(s) FROM inputTable") } - assert(errMsg.getMessage === "Map type in java is unsupported because " + + assert(errMsg.getMessage contains "Map type in java is unsupported because " + "JVM type erasure makes spark fail to catch key and value types in Map<>;") sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFToIntIntMap") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index ff9a3694d612..1dff07a6de8a 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -948,6 +948,8 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils { } test("SPARK-7595: Window will cause resolve failed with self join") { + sql("SELECT * FROM src") // Force loading of src table. + checkAnswer(sql( """ |with From 7a969a6967c4ecc0f004b73bff27a75257a94e86 Mon Sep 17 00:00:00 2001 From: linweizhong Date: Wed, 5 Aug 2015 10:16:12 -0700 Subject: [PATCH 0063/1824] [SPARK-9519] [YARN] Confirm stop sc successfully when application was killed Currently, when we kill application on Yarn, then will call sc.stop() at Yarn application state monitor thread, then in YarnClientSchedulerBackend.stop() will call interrupt this will cause SparkContext not stop fully as we will wait executor to exit. Author: linweizhong Closes #7846 from Sephiroth-Lin/SPARK-9519 and squashes the following commits: 1ae736d [linweizhong] Update comments 2e8e365 [linweizhong] Add comment explaining the code ad0e23b [linweizhong] Update 243d2c7 [linweizhong] Confirm stop sc successfully when application was killed --- .../cluster/YarnClientSchedulerBackend.scala | 47 +++++++++++++------ 1 file changed, 32 insertions(+), 15 deletions(-) diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala index d97fa2e2151b..d225061fcd1b 100644 --- a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala +++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala @@ -33,7 +33,7 @@ private[spark] class YarnClientSchedulerBackend( private var client: Client = null private var appId: ApplicationId = null - private var monitorThread: Thread = null + private var monitorThread: MonitorThread = null /** * Create a Yarn client to submit an application to the ResourceManager. @@ -131,24 +131,42 @@ private[spark] class YarnClientSchedulerBackend( } } + /** + * We create this class for SPARK-9519. 
Basically when we interrupt the monitor thread it's + * because the SparkContext is being shut down(sc.stop() called by user code), but if + * monitorApplication return, it means the Yarn application finished before sc.stop() was called, + * which means we should call sc.stop() here, and we don't allow the monitor to be interrupted + * before SparkContext stops successfully. + */ + private class MonitorThread extends Thread { + private var allowInterrupt = true + + override def run() { + try { + val (state, _) = client.monitorApplication(appId, logApplicationReport = false) + logError(s"Yarn application has already exited with state $state!") + allowInterrupt = false + sc.stop() + } catch { + case e: InterruptedException => logInfo("Interrupting monitor thread") + } + } + + def stopMonitor(): Unit = { + if (allowInterrupt) { + this.interrupt() + } + } + } + /** * Monitor the application state in a separate thread. * If the application has exited for any reason, stop the SparkContext. * This assumes both `client` and `appId` have already been set. */ - private def asyncMonitorApplication(): Thread = { + private def asyncMonitorApplication(): MonitorThread = { assert(client != null && appId != null, "Application has not been submitted yet!") - val t = new Thread { - override def run() { - try { - val (state, _) = client.monitorApplication(appId, logApplicationReport = false) - logError(s"Yarn application has already exited with state $state!") - sc.stop() - } catch { - case e: InterruptedException => logInfo("Interrupting monitor thread") - } - } - } + val t = new MonitorThread t.setName("Yarn application state monitor") t.setDaemon(true) t @@ -160,7 +178,7 @@ private[spark] class YarnClientSchedulerBackend( override def stop() { assert(client != null, "Attempted to stop this scheduler before starting it!") if (monitorThread != null) { - monitorThread.interrupt() + monitorThread.stopMonitor() } super.stop() client.stop() @@ -174,5 +192,4 @@ private[spark] class YarnClientSchedulerBackend( super.applicationId } } - } From 1f8c364b9c6636f06986f5f80d5a49b7a7772ac3 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Wed, 5 Aug 2015 11:03:02 -0700 Subject: [PATCH 0064/1824] [SPARK-9141] [SQL] [MINOR] Fix comments of PR #7920 This is a follow-up of https://github.com/apache/spark/pull/7920 to fix comments. Author: Yin Huai Closes #7964 from yhuai/SPARK-9141-follow-up and squashes the following commits: 4d0ee80 [Yin Huai] Fix comments. --- .../org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 3 +-- .../org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala | 3 ++- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 6de31f42dd30..82158e61e3fb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -140,7 +140,7 @@ class Analyzer( object ResolveAliases extends Rule[LogicalPlan] { private def assignAliases(exprs: Seq[NamedExpression]) = { // The `UnresolvedAlias`s will appear only at root of a expression tree, we don't need - // to resolveOperator down the whole tree. + // to traverse the whole tree. exprs.zipWithIndex.map { case (u @ UnresolvedAlias(child), i) => child match { @@ -873,7 +873,6 @@ class Analyzer( // "Aggregate with Having clause" will be triggered. 
def apply(plan: LogicalPlan): LogicalPlan = plan transformDown { - // Aggregate with Having clause. This rule works with an unresolved Aggregate because // a resolved Aggregate will not have Window Functions. case f @ Filter(condition, a @ Aggregate(groupingExprs, aggregateExprs, child)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala index 797b29f23cbb..455a3810c719 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala @@ -23,7 +23,8 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util._ /** - * Provides helper methods for comparing plans. + * This suite is used to test [[LogicalPlan]]'s `resolveOperators` and make sure it can correctly + * skips sub-trees that have already been marked as analyzed. */ class LogicalPlanSuite extends SparkFunSuite { private var invocationCount = 0 From e1e05873fc75781b6dd3f7fadbfb57824f83054e Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 5 Aug 2015 11:38:56 -0700 Subject: [PATCH 0065/1824] [SPARK-9403] [SQL] Add codegen support in In and InSet This continues tarekauel's work in #7778. Author: Liang-Chi Hsieh Author: Tarek Auel Closes #7893 from viirya/codegen_in and squashes the following commits: 81ff97b [Liang-Chi Hsieh] For comments. 47761c6 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into codegen_in cf4bf41 [Liang-Chi Hsieh] For comments. f532b3c [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into codegen_in 446bbcd [Liang-Chi Hsieh] Fix bug. b3d0ab4 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into codegen_in 4610eff [Liang-Chi Hsieh] Relax the types of references and update optimizer test. 224f18e [Liang-Chi Hsieh] Beef up the test cases for In and InSet to include all primitive data types. 86dc8aa [Liang-Chi Hsieh] Only convert In to InSet when the number of items in set is more than the threshold. b7ded7e [Tarek Auel] [SPARK-9403][SQL] codeGen in / inSet --- .../sql/catalyst/expressions/predicates.scala | 63 +++++++++++++++++-- .../sql/catalyst/optimizer/Optimizer.scala | 2 +- .../catalyst/expressions/PredicateSuite.scala | 37 ++++++++++- .../catalyst/optimizer/OptimizeInSuite.scala | 14 ++++- .../datasources/DataSourceStrategy.scala | 7 +++ .../spark/sql/ColumnExpressionSuite.scala | 6 ++ 6 files changed, 119 insertions(+), 10 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index b69bbabee7e8..68c832d7194d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -17,6 +17,9 @@ package org.apache.spark.sql.catalyst.expressions +import scala.collection.mutable + +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenFallback, GeneratedExpressionCode, CodeGenContext} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan @@ -97,32 +100,80 @@ case class Not(child: Expression) /** * Evaluates to `true` if `list` contains `value`. 
*/ -case class In(value: Expression, list: Seq[Expression]) extends Predicate with CodegenFallback { +case class In(value: Expression, list: Seq[Expression]) extends Predicate + with ImplicitCastInputTypes { + + override def inputTypes: Seq[AbstractDataType] = value.dataType +: list.map(_.dataType) + + override def checkInputDataTypes(): TypeCheckResult = { + if (list.exists(l => l.dataType != value.dataType)) { + TypeCheckResult.TypeCheckFailure( + "Arguments must be same type") + } else { + TypeCheckResult.TypeCheckSuccess + } + } + override def children: Seq[Expression] = value +: list - override def nullable: Boolean = true // TODO: Figure out correct nullability semantics of IN. + override def nullable: Boolean = false // TODO: Figure out correct nullability semantics of IN. override def toString: String = s"$value IN ${list.mkString("(", ",", ")")}" override def eval(input: InternalRow): Any = { val evaluatedValue = value.eval(input) list.exists(e => e.eval(input) == evaluatedValue) } -} + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val valueGen = value.gen(ctx) + val listGen = list.map(_.gen(ctx)) + val listCode = listGen.map(x => + s""" + if (!${ev.primitive}) { + ${x.code} + if (${ctx.genEqual(value.dataType, valueGen.primitive, x.primitive)}) { + ${ev.primitive} = true; + } + } + """).mkString("\n") + s""" + ${valueGen.code} + boolean ${ev.primitive} = false; + boolean ${ev.isNull} = false; + $listCode + """ + } +} /** * Optimized version of In clause, when all filter values of In clause are * static. */ -case class InSet(child: Expression, hset: Set[Any]) - extends UnaryExpression with Predicate with CodegenFallback { +case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with Predicate { - override def nullable: Boolean = true // TODO: Figure out correct nullability semantics of IN. + override def nullable: Boolean = false // TODO: Figure out correct nullability semantics of IN. 
override def toString: String = s"$child INSET ${hset.mkString("(", ",", ")")}" override def eval(input: InternalRow): Any = { hset.contains(child.eval(input)) } + + def getHSet(): Set[Any] = hset + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val setName = classOf[Set[Any]].getName + val InSetName = classOf[InSet].getName + val childGen = child.gen(ctx) + ctx.references += this + val hsetTerm = ctx.freshName("hset") + ctx.addMutableState(setName, hsetTerm, + s"$hsetTerm = (($InSetName)expressions[${ctx.references.size - 1}]).getHSet();") + s""" + ${childGen.code} + boolean ${ev.isNull} = false; + boolean ${ev.primitive} = $hsetTerm.contains(${childGen.primitive}); + """ + } } case class And(left: Expression, right: Expression) extends BinaryOperator with Predicate { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 29d706dcb39a..4ab5ac2c61e3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -393,7 +393,7 @@ object ConstantFolding extends Rule[LogicalPlan] { object OptimizeIn extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case q: LogicalPlan => q transformExpressionsDown { - case In(v, list) if !list.exists(!_.isInstanceOf[Literal]) => + case In(v, list) if !list.exists(!_.isInstanceOf[Literal]) && list.size > 10 => val hSet = list.map(e => e.eval(EmptyRow)) InSet(v, HashSet() ++ hSet) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala index d7eb13c50b13..7beef71845e4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala @@ -21,7 +21,8 @@ import scala.collection.immutable.HashSet import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.types.{Decimal, DoubleType, IntegerType, BooleanType} +import org.apache.spark.sql.RandomDataGenerator +import org.apache.spark.sql.types._ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { @@ -118,6 +119,23 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(In(Literal("^Ba*n"), Seq(Literal("^Ba*n"))), true) checkEvaluation(In(Literal("^Ba*n"), Seq(Literal("aa"), Literal("^Ba*n"))), true) checkEvaluation(In(Literal("^Ba*n"), Seq(Literal("aa"), Literal("^n"))), false) + + val primitiveTypes = Seq(IntegerType, FloatType, DoubleType, StringType, ByteType, ShortType, + LongType, BinaryType, BooleanType, DecimalType.USER_DEFAULT, TimestampType) + primitiveTypes.map { t => + val dataGen = RandomDataGenerator.forType(t, nullable = false).get + val inputData = Seq.fill(10) { + val value = dataGen.apply() + value match { + case d: Double if d.isNaN => 0.0d + case f: Float if f.isNaN => 0.0f + case _ => value + } + } + val input = inputData.map(Literal(_)) + checkEvaluation(In(input(0), input.slice(1, 10)), + inputData.slice(1, 10).contains(inputData(0))) + } } test("INSET") { @@ -134,6 +152,23 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(InSet(three, hS), 
false) checkEvaluation(InSet(three, nS), false) checkEvaluation(And(InSet(one, hS), InSet(two, hS)), true) + + val primitiveTypes = Seq(IntegerType, FloatType, DoubleType, StringType, ByteType, ShortType, + LongType, BinaryType, BooleanType, DecimalType.USER_DEFAULT, TimestampType) + primitiveTypes.map { t => + val dataGen = RandomDataGenerator.forType(t, nullable = false).get + val inputData = Seq.fill(10) { + val value = dataGen.apply() + value match { + case d: Double if d.isNaN => 0.0d + case f: Float if f.isNaN => 0.0f + case _ => value + } + } + val input = inputData.map(Literal(_)) + checkEvaluation(InSet(input(0), inputData.slice(1, 10).toSet), + inputData.slice(1, 10).contains(inputData(0))) + } } private val smallValues = Seq(1, Decimal(1), Array(1.toByte), "a", 0f, 0d, false).map(Literal(_)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala index 1d433275fed2..6f7b5b9572e2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala @@ -43,16 +43,26 @@ class OptimizeInSuite extends PlanTest { val testRelation = LocalRelation('a.int, 'b.int, 'c.int) - test("OptimizedIn test: In clause optimized to InSet") { + test("OptimizedIn test: In clause not optimized to InSet when less than 10 items") { val originalQuery = testRelation .where(In(UnresolvedAttribute("a"), Seq(Literal(1), Literal(2)))) .analyze + val optimized = Optimize.execute(originalQuery.analyze) + comparePlans(optimized, originalQuery) + } + + test("OptimizedIn test: In clause optimized to InSet when more than 10 items") { + val originalQuery = + testRelation + .where(In(UnresolvedAttribute("a"), (1 to 11).map(Literal(_)))) + .analyze + val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation - .where(InSet(UnresolvedAttribute("a"), HashSet[Any]() + 1 + 2)) + .where(InSet(UnresolvedAttribute("a"), (1 to 11).toSet)) .analyze comparePlans(optimized, correctAnswer) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index a43bccbe6927..e5dc676b8784 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -366,6 +366,13 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { case expressions.InSet(a: Attribute, set) => Some(sources.In(a.name, set.toArray)) + // Because we only convert In to InSet in Optimizer when there are more than certain + // items. So it is possible we still get an In expression here that needs to be pushed + // down. 
+ case expressions.In(a: Attribute, list) if !list.exists(!_.isInstanceOf[Literal]) => + val hSet = list.map(e => e.eval(EmptyRow)) + Some(sources.In(a.name, hSet.toArray)) + case expressions.IsNull(a: Attribute) => Some(sources.IsNull(a.name)) case expressions.IsNotNull(a: Attribute) => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index 35ca0b4c7cc2..b35138037325 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -357,6 +357,12 @@ class ColumnExpressionSuite extends QueryTest with SQLTestUtils { df.collect().toSeq.filter(r => r.getString(1) == "z" || r.getString(1) == "x")) checkAnswer(df.filter($"b".in("z", "y")), df.collect().toSeq.filter(r => r.getString(1) == "z" || r.getString(1) == "y")) + + val df2 = Seq((1, Seq(1)), (2, Seq(2)), (3, Seq(3))).toDF("a", "b") + + intercept[AnalysisException] { + df2.filter($"a".in($"b")) + } } val booleanData = ctx.createDataFrame(ctx.sparkContext.parallelize( From eb5b8f4a603e0f289bdaa0a2164cde2cfe4feecb Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 5 Aug 2015 12:51:12 -0700 Subject: [PATCH 0066/1824] Closes #7778 since it is done as #7893. From 5f0fb6466f5e3607f7fca9b2371a73b3deef3fdf Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Wed, 5 Aug 2015 14:12:22 -0700 Subject: [PATCH 0067/1824] [SPARK-9649] Fix flaky test MasterSuite - randomize ports ``` Error Message Failed to bind to: /127.0.0.1:7093: Service 'sparkMaster' failed after 16 retries! Stacktrace java.net.BindException: Failed to bind to: /127.0.0.1:7093: Service 'sparkMaster' failed after 16 retries! at org.jboss.netty.bootstrap.ServerBootstrap.bind(ServerBootstrap.java:272) at akka.remote.transport.netty.NettyTransport$$anonfun$listen$1.apply(NettyTransport.scala:393) at akka.remote.transport.netty.NettyTransport$$anonfun$listen$1.apply(NettyTransport.scala:389) at scala.util.Success$$anonfun$map$1.apply(Try.scala:206) at scala.util.Try$.apply(Try.scala:161) ``` Author: Andrew Or Closes #7968 from andrewor14/fix-master-flaky-test and squashes the following commits: fcc42ef [Andrew Or] Randomize port --- .../org/apache/spark/deploy/master/MasterSuite.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala index 30780a0da7f8..ae0e037d822e 100644 --- a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala @@ -93,8 +93,8 @@ class MasterSuite extends SparkFunSuite with Matchers with Eventually with Priva publicAddress = "" ) - val (rpcEnv, uiPort, restPort) = - Master.startRpcEnvAndEndpoint("127.0.0.1", 7077, 8080, conf) + val (rpcEnv, _, _) = + Master.startRpcEnvAndEndpoint("127.0.0.1", 0, 0, conf) try { rpcEnv.setupEndpointRef(Master.SYSTEM_NAME, rpcEnv.address, Master.ENDPOINT_NAME) @@ -343,8 +343,8 @@ class MasterSuite extends SparkFunSuite with Matchers with Eventually with Priva private def makeMaster(conf: SparkConf = new SparkConf): Master = { val securityMgr = new SecurityManager(conf) - val rpcEnv = RpcEnv.create(Master.SYSTEM_NAME, "localhost", 7077, conf, securityMgr) - val master = new Master(rpcEnv, rpcEnv.address, 8080, securityMgr, conf) + val rpcEnv = RpcEnv.create(Master.SYSTEM_NAME, 
"localhost", 0, conf, securityMgr) + val master = new Master(rpcEnv, rpcEnv.address, 0, securityMgr, conf) master } From f9c2a2af1e883b36c5e51b87ef660a1b9ad0f586 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 5 Aug 2015 14:15:57 -0700 Subject: [PATCH 0068/1824] Closes #7474 since it's marked as won't fix. From dac090d1e9be7dec6c5ebdb2a81105b87e853193 Mon Sep 17 00:00:00 2001 From: Feynman Liang Date: Wed, 5 Aug 2015 15:42:18 -0700 Subject: [PATCH 0069/1824] [SPARK-9657] Fix return type of getMaxPatternLength mengxr Author: Feynman Liang Closes #7974 from feynmanliang/SPARK-9657 and squashes the following commits: 7ca533f [Feynman Liang] Fix return type of getMaxPatternLength --- .../src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala index d5f0c926c69b..ad6715b52f33 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala @@ -82,7 +82,7 @@ class PrefixSpan private ( /** * Gets the maximal pattern length (i.e. the length of the longest sequential pattern to consider. */ - def getMaxPatternLength: Double = maxPatternLength + def getMaxPatternLength: Int = maxPatternLength /** * Sets maximal pattern length (default: `10`). From 9c878923db6634effed98c99bf24dd263bb7c6ad Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 5 Aug 2015 16:33:42 -0700 Subject: [PATCH 0070/1824] [SPARK-9054] [SQL] Rename RowOrdering to InterpretedOrdering; use newOrdering in SMJ This patches renames `RowOrdering` to `InterpretedOrdering` and updates SortMergeJoin to use the `SparkPlan` methods for constructing its ordering so that it may benefit from codegen. This is an updated version of #7408. Author: Josh Rosen Closes #7973 from JoshRosen/SPARK-9054 and squashes the following commits: e610655 [Josh Rosen] Add comment RE: Ascending ordering 34b8e0c [Josh Rosen] Import ordering be19a0f [Josh Rosen] [SPARK-9054] [SQL] Rename RowOrdering to InterpretedOrdering; use newOrdering in more places. 
--- .../sql/catalyst/expressions/arithmetic.scala | 4 +-- .../catalyst/expressions/conditionals.scala | 4 +-- .../{RowOrdering.scala => ordering.scala} | 27 ++++++++++--------- .../sql/catalyst/expressions/predicates.scala | 8 +++--- .../spark/sql/catalyst/util/TypeUtils.scala | 4 +-- .../apache/spark/sql/types/StructType.scala | 4 +-- .../expressions/CodeGenerationSuite.scala | 2 +- .../apache/spark/sql/execution/Exchange.scala | 5 +++- .../spark/sql/execution/SparkPlan.scala | 14 ++++++++-- .../spark/sql/execution/basicOperators.scala | 4 ++- .../sql/execution/joins/SortMergeJoin.scala | 9 ++++--- .../UnsafeKVExternalSorterSuite.scala | 6 ++--- 12 files changed, 55 insertions(+), 36 deletions(-) rename sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/{RowOrdering.scala => ordering.scala} (85%) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 5808e3f66de3..98464edf4d39 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -320,7 +320,7 @@ case class MaxOf(left: Expression, right: Expression) extends BinaryArithmetic { override def nullable: Boolean = left.nullable && right.nullable - private lazy val ordering = TypeUtils.getOrdering(dataType) + private lazy val ordering = TypeUtils.getInterpretedOrdering(dataType) override def eval(input: InternalRow): Any = { val input1 = left.eval(input) @@ -374,7 +374,7 @@ case class MinOf(left: Expression, right: Expression) extends BinaryArithmetic { override def nullable: Boolean = left.nullable && right.nullable - private lazy val ordering = TypeUtils.getOrdering(dataType) + private lazy val ordering = TypeUtils.getInterpretedOrdering(dataType) override def eval(input: InternalRow): Any = { val input1 = left.eval(input) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala index 961b1d861680..d51f3d3cef58 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala @@ -319,7 +319,7 @@ case class Least(children: Seq[Expression]) extends Expression { override def nullable: Boolean = children.forall(_.nullable) override def foldable: Boolean = children.forall(_.foldable) - private lazy val ordering = TypeUtils.getOrdering(dataType) + private lazy val ordering = TypeUtils.getInterpretedOrdering(dataType) override def checkInputDataTypes(): TypeCheckResult = { if (children.length <= 1) { @@ -374,7 +374,7 @@ case class Greatest(children: Seq[Expression]) extends Expression { override def nullable: Boolean = children.forall(_.nullable) override def foldable: Boolean = children.forall(_.foldable) - private lazy val ordering = TypeUtils.getOrdering(dataType) + private lazy val ordering = TypeUtils.getInterpretedOrdering(dataType) override def checkInputDataTypes(): TypeCheckResult = { if (children.length <= 1) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/RowOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ordering.scala similarity index 85% rename from 
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/RowOrdering.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ordering.scala index 873f5324c573..6407c73bc97d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/RowOrdering.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ordering.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.types._ /** * An interpreted row ordering comparator. */ -class RowOrdering(ordering: Seq[SortOrder]) extends Ordering[InternalRow] { +class InterpretedOrdering(ordering: Seq[SortOrder]) extends Ordering[InternalRow] { def this(ordering: Seq[SortOrder], inputSchema: Seq[Attribute]) = this(ordering.map(BindReferences.bindReference(_, inputSchema))) @@ -49,9 +49,9 @@ class RowOrdering(ordering: Seq[SortOrder]) extends Ordering[InternalRow] { case dt: AtomicType if order.direction == Descending => dt.ordering.asInstanceOf[Ordering[Any]].reverse.compare(left, right) case s: StructType if order.direction == Ascending => - s.ordering.asInstanceOf[Ordering[Any]].compare(left, right) + s.interpretedOrdering.asInstanceOf[Ordering[Any]].compare(left, right) case s: StructType if order.direction == Descending => - s.ordering.asInstanceOf[Ordering[Any]].reverse.compare(left, right) + s.interpretedOrdering.asInstanceOf[Ordering[Any]].reverse.compare(left, right) case other => throw new IllegalArgumentException(s"Type $other does not support ordered operations") } @@ -65,6 +65,18 @@ class RowOrdering(ordering: Seq[SortOrder]) extends Ordering[InternalRow] { } } +object InterpretedOrdering { + + /** + * Creates a [[InterpretedOrdering]] for the given schema, in natural ascending order. + */ + def forSchema(dataTypes: Seq[DataType]): InterpretedOrdering = { + new InterpretedOrdering(dataTypes.zipWithIndex.map { + case (dt, index) => new SortOrder(BoundReference(index, dt, nullable = true), Ascending) + }) + } +} + object RowOrdering { /** @@ -81,13 +93,4 @@ object RowOrdering { * Returns true iff outputs from the expressions can be ordered. */ def isOrderable(exprs: Seq[Expression]): Boolean = exprs.forall(e => isOrderable(e.dataType)) - - /** - * Creates a [[RowOrdering]] for the given schema, in natural ascending order. 
- */ - def forSchema(dataTypes: Seq[DataType]): RowOrdering = { - new RowOrdering(dataTypes.zipWithIndex.map { - case (dt, index) => new SortOrder(BoundReference(index, dt, nullable = true), Ascending) - }) - } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 68c832d7194d..fe7dffb81598 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -376,7 +376,7 @@ case class LessThan(left: Expression, right: Expression) extends BinaryCompariso override def symbol: String = "<" - private lazy val ordering = TypeUtils.getOrdering(left.dataType) + private lazy val ordering = TypeUtils.getInterpretedOrdering(left.dataType) protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.lt(input1, input2) } @@ -388,7 +388,7 @@ case class LessThanOrEqual(left: Expression, right: Expression) extends BinaryCo override def symbol: String = "<=" - private lazy val ordering = TypeUtils.getOrdering(left.dataType) + private lazy val ordering = TypeUtils.getInterpretedOrdering(left.dataType) protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.lteq(input1, input2) } @@ -400,7 +400,7 @@ case class GreaterThan(left: Expression, right: Expression) extends BinaryCompar override def symbol: String = ">" - private lazy val ordering = TypeUtils.getOrdering(left.dataType) + private lazy val ordering = TypeUtils.getInterpretedOrdering(left.dataType) protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.gt(input1, input2) } @@ -412,7 +412,7 @@ case class GreaterThanOrEqual(left: Expression, right: Expression) extends Binar override def symbol: String = ">=" - private lazy val ordering = TypeUtils.getOrdering(left.dataType) + private lazy val ordering = TypeUtils.getInterpretedOrdering(left.dataType) protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.gteq(input1, input2) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala index 0b41f92c6193..bcf4d78fb937 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala @@ -54,10 +54,10 @@ object TypeUtils { def getNumeric(t: DataType): Numeric[Any] = t.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]] - def getOrdering(t: DataType): Ordering[Any] = { + def getInterpretedOrdering(t: DataType): Ordering[Any] = { t match { case i: AtomicType => i.ordering.asInstanceOf[Ordering[Any]] - case s: StructType => s.ordering.asInstanceOf[Ordering[Any]] + case s: StructType => s.interpretedOrdering.asInstanceOf[Ordering[Any]] } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index 6928707f7bf6..9cbc207538d4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -24,7 +24,7 @@ import org.json4s.JsonDSL._ import org.apache.spark.SparkException import org.apache.spark.annotation.DeveloperApi -import 
org.apache.spark.sql.catalyst.expressions.{AttributeReference, Attribute, RowOrdering} +import org.apache.spark.sql.catalyst.expressions.{InterpretedOrdering, AttributeReference, Attribute, InterpretedOrdering$} /** @@ -301,7 +301,7 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru StructType(newFields) } - private[sql] val ordering = RowOrdering.forSchema(this.fields.map(_.dataType)) + private[sql] val interpretedOrdering = InterpretedOrdering.forSchema(this.fields.map(_.dataType)) } object StructType extends AbstractDataType { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index cc82f7c3f5a7..e310aee22166 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -54,7 +54,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { // GenerateOrdering agrees with RowOrdering. (DataTypeTestUtils.atomicTypes ++ Set(NullType)).foreach { dataType => test(s"GenerateOrdering with $dataType") { - val rowOrdering = RowOrdering.forSchema(Seq(dataType, dataType)) + val rowOrdering = InterpretedOrdering.forSchema(Seq(dataType, dataType)) val genOrdering = GenerateOrdering.generate( BoundReference(0, dataType, nullable = true).asc :: BoundReference(1, dataType, nullable = true).asc :: Nil) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index 05b009d1935b..6ea5eeedf1bb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -156,7 +156,10 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una val mutablePair = new MutablePair[InternalRow, Null]() iter.map(row => mutablePair.update(row.copy(), null)) } - implicit val ordering = new RowOrdering(sortingExpressions, child.output) + // We need to use an interpreted ordering here because generated orderings cannot be + // serialized and this ordering needs to be created on the driver in order to be passed into + // Spark core code. 
+ implicit val ordering = new InterpretedOrdering(sortingExpressions, child.output) new RangePartitioner(numPartitions, rddForSampling, ascending = true) case SinglePartition => new Partitioner { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index dbc0cefbe2e1..2f29067f5646 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -32,6 +32,7 @@ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.types.DataType object SparkPlan { protected[sql] val currentContext = new ThreadLocal[SQLContext]() @@ -309,13 +310,22 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ throw e } else { log.error("Failed to generate ordering, fallback to interpreted", e) - new RowOrdering(order, inputSchema) + new InterpretedOrdering(order, inputSchema) } } } else { - new RowOrdering(order, inputSchema) + new InterpretedOrdering(order, inputSchema) } } + /** + * Creates a row ordering for the given schema, in natural ascending order. + */ + protected def newNaturalAscendingOrdering(dataTypes: Seq[DataType]): Ordering[InternalRow] = { + val order: Seq[SortOrder] = dataTypes.zipWithIndex.map { + case (dt, index) => new SortOrder(BoundReference(index, dt, nullable = true), Ascending) + } + newOrdering(order, Seq.empty) + } } private[sql] trait LeafNode extends SparkPlan { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 477170297c2a..f4677b4ee86b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -212,7 +212,9 @@ case class TakeOrderedAndProject( override def outputPartitioning: Partitioning = SinglePartition - private val ord: RowOrdering = new RowOrdering(sortOrder, child.output) + // We need to use an interpreted ordering here because generated orderings cannot be serialized + // and this ordering needs to be created on the driver in order to be passed into Spark core code. + private val ord: InterpretedOrdering = new InterpretedOrdering(sortOrder, child.output) // TODO: remove @transient after figure out how to clean closure at InsertIntoHiveTable. 
@transient private val projection = projectList.map(new InterpretedProjection(_, child.output)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala index eb595490fbf2..4ae23c186cf7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala @@ -48,9 +48,6 @@ case class SortMergeJoin( override def requiredChildDistribution: Seq[Distribution] = ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil - // this is to manually construct an ordering that can be used to compare keys from both sides - private val keyOrdering: RowOrdering = RowOrdering.forSchema(leftKeys.map(_.dataType)) - override def outputOrdering: Seq[SortOrder] = requiredOrders(leftKeys) override def requiredChildOrdering: Seq[Seq[SortOrder]] = @@ -59,8 +56,10 @@ case class SortMergeJoin( @transient protected lazy val leftKeyGenerator = newProjection(leftKeys, left.output) @transient protected lazy val rightKeyGenerator = newProjection(rightKeys, right.output) - private def requiredOrders(keys: Seq[Expression]): Seq[SortOrder] = + private def requiredOrders(keys: Seq[Expression]): Seq[SortOrder] = { + // This must be ascending in order to agree with the `keyOrdering` defined in `doExecute()`. keys.map(SortOrder(_, Ascending)) + } protected override def doExecute(): RDD[InternalRow] = { val leftResults = left.execute().map(_.copy()) @@ -68,6 +67,8 @@ case class SortMergeJoin( leftResults.zipPartitions(rightResults) { (leftIter, rightIter) => new Iterator[InternalRow] { + // An ordering that can be used to compare keys from both sides. + private[this] val keyOrdering = newNaturalAscendingOrdering(leftKeys.map(_.dataType)) // Mutable per row objects. 
private[this] val joinRow = new JoinedRow private[this] var leftElement: InternalRow = _ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala index 08156f0e39ce..a9515a03acf2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala @@ -22,7 +22,7 @@ import scala.util.Random import org.apache.spark._ import org.apache.spark.sql.{RandomDataGenerator, Row} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} -import org.apache.spark.sql.catalyst.expressions.{UnsafeRow, RowOrdering, UnsafeProjection} +import org.apache.spark.sql.catalyst.expressions.{InterpretedOrdering, UnsafeRow, UnsafeProjection} import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.types._ import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager} @@ -144,8 +144,8 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite { } sorter.cleanupResources() - val keyOrdering = RowOrdering.forSchema(keySchema.map(_.dataType)) - val valueOrdering = RowOrdering.forSchema(valueSchema.map(_.dataType)) + val keyOrdering = InterpretedOrdering.forSchema(keySchema.map(_.dataType)) + val valueOrdering = InterpretedOrdering.forSchema(valueSchema.map(_.dataType)) val kvOrdering = new Ordering[(InternalRow, InternalRow)] { override def compare(x: (InternalRow, InternalRow), y: (InternalRow, InternalRow)): Int = { keyOrdering.compare(x._1, y._1) match { From a018b85716fd510ae95a3c66d676bbdb90f8d4e7 Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Wed, 5 Aug 2015 17:07:55 -0700 Subject: [PATCH 0071/1824] [SPARK-5895] [ML] Add VectorSlicer - updated Add VectorSlicer transformer to spark.ml, with features specified as either indices or names. Transfers feature attributes for selected features. Updated version of [https://github.com/apache/spark/pull/5731] CC: yinxusen This updates your PR. You'll still be the primary author of this PR. CC: mengxr Author: Xusen Yin Author: Joseph K. Bradley Closes #7972 from jkbradley/yinxusen-SPARK-5895 and squashes the following commits: b16e86e [Joseph K. Bradley] fixed scala style 71c65d2 [Joseph K. Bradley] fix import order 86e9739 [Joseph K. Bradley] cleanups per code review 9d8d6f1 [Joseph K. Bradley] style fix 83bc2e9 [Joseph K. 
Bradley] Updated VectorSlicer 98c6939 [Xusen Yin] fix style error ecbf2d3 [Xusen Yin] change interfaces and params f6be302 [Xusen Yin] Merge branch 'master' into SPARK-5895 e4781f2 [Xusen Yin] fix commit error fd154d7 [Xusen Yin] add test suite of vector slicer 17171f8 [Xusen Yin] fix slicer 9ab9747 [Xusen Yin] add vector slicer aa5a0bf [Xusen Yin] add vector slicer --- .../spark/ml/feature/VectorSlicer.scala | 170 ++++++++++++++++++ .../apache/spark/ml/util/MetadataUtils.scala | 17 ++ .../apache/spark/mllib/linalg/Vectors.scala | 24 +++ .../spark/ml/feature/VectorSlicerSuite.scala | 109 +++++++++++ .../spark/mllib/linalg/VectorsSuite.scala | 7 + 5 files changed, 327 insertions(+) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala new file mode 100644 index 000000000000..772bebeff214 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala @@ -0,0 +1,170 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.apache.spark.annotation.Experimental +import org.apache.spark.ml.Transformer +import org.apache.spark.ml.attribute.{Attribute, AttributeGroup} +import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} +import org.apache.spark.ml.param.{IntArrayParam, ParamMap, StringArrayParam} +import org.apache.spark.ml.util.{Identifiable, MetadataUtils, SchemaUtils} +import org.apache.spark.mllib.linalg._ +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types.StructType + +/** + * :: Experimental :: + * This class takes a feature vector and outputs a new feature vector with a subarray of the + * original features. + * + * The subset of features can be specified with either indices ([[setIndices()]]) + * or names ([[setNames()]]). At least one feature must be selected. Duplicate features + * are not allowed, so there can be no overlap between selected indices and names. + * + * The output vector will order features with the selected indices first (in the order given), + * followed by the selected names (in the order given). + */ +@Experimental +final class VectorSlicer(override val uid: String) + extends Transformer with HasInputCol with HasOutputCol { + + def this() = this(Identifiable.randomUID("vectorSlicer")) + + /** + * An array of indices to select features from a vector column. + * There can be no overlap with [[names]]. 
+ * @group param + */ + val indices = new IntArrayParam(this, "indices", + "An array of indices to select features from a vector column." + + " There can be no overlap with names.", VectorSlicer.validIndices) + + setDefault(indices -> Array.empty[Int]) + + /** @group getParam */ + def getIndices: Array[Int] = $(indices) + + /** @group setParam */ + def setIndices(value: Array[Int]): this.type = set(indices, value) + + /** + * An array of feature names to select features from a vector column. + * These names must be specified by ML [[org.apache.spark.ml.attribute.Attribute]]s. + * There can be no overlap with [[indices]]. + * @group param + */ + val names = new StringArrayParam(this, "names", + "An array of feature names to select features from a vector column." + + " There can be no overlap with indices.", VectorSlicer.validNames) + + setDefault(names -> Array.empty[String]) + + /** @group getParam */ + def getNames: Array[String] = $(names) + + /** @group setParam */ + def setNames(value: Array[String]): this.type = set(names, value) + + /** @group setParam */ + def setInputCol(value: String): this.type = set(inputCol, value) + + /** @group setParam */ + def setOutputCol(value: String): this.type = set(outputCol, value) + + override def validateParams(): Unit = { + require($(indices).length > 0 || $(names).length > 0, + s"VectorSlicer requires that at least one feature be selected.") + } + + override def transform(dataset: DataFrame): DataFrame = { + // Validity checks + transformSchema(dataset.schema) + val inputAttr = AttributeGroup.fromStructField(dataset.schema($(inputCol))) + inputAttr.numAttributes.foreach { numFeatures => + val maxIndex = $(indices).max + require(maxIndex < numFeatures, + s"Selected feature index $maxIndex invalid for only $numFeatures input features.") + } + + // Prepare output attributes + val inds = getSelectedFeatureIndices(dataset.schema) + val selectedAttrs: Option[Array[Attribute]] = inputAttr.attributes.map { attrs => + inds.map(index => attrs(index)) + } + val outputAttr = selectedAttrs match { + case Some(attrs) => new AttributeGroup($(outputCol), attrs) + case None => new AttributeGroup($(outputCol), inds.length) + } + + // Select features + val slicer = udf { vec: Vector => + vec match { + case features: DenseVector => Vectors.dense(inds.map(features.apply)) + case features: SparseVector => features.slice(inds) + } + } + dataset.withColumn($(outputCol), + slicer(dataset($(inputCol))).as($(outputCol), outputAttr.toMetadata())) + } + + /** Get the feature indices in order: indices, names */ + private def getSelectedFeatureIndices(schema: StructType): Array[Int] = { + val nameFeatures = MetadataUtils.getFeatureIndicesFromNames(schema($(inputCol)), $(names)) + val indFeatures = $(indices) + val numDistinctFeatures = (nameFeatures ++ indFeatures).distinct.length + lazy val errMsg = "VectorSlicer requires indices and names to be disjoint" + + s" sets of features, but they overlap." + + s" indices: ${indFeatures.mkString("[", ",", "]")}." 
+ + s" names: " + + nameFeatures.zip($(names)).map { case (i, n) => s"$i:$n" }.mkString("[", ",", "]") + require(nameFeatures.length + indFeatures.length == numDistinctFeatures, errMsg) + indFeatures ++ nameFeatures + } + + override def transformSchema(schema: StructType): StructType = { + SchemaUtils.checkColumnType(schema, $(inputCol), new VectorUDT) + + if (schema.fieldNames.contains($(outputCol))) { + throw new IllegalArgumentException(s"Output column ${$(outputCol)} already exists.") + } + val numFeaturesSelected = $(indices).length + $(names).length + val outputAttr = new AttributeGroup($(outputCol), numFeaturesSelected) + val outputFields = schema.fields :+ outputAttr.toStructField() + StructType(outputFields) + } + + override def copy(extra: ParamMap): VectorSlicer = defaultCopy(extra) +} + +private[feature] object VectorSlicer { + + /** Return true if given feature indices are valid */ + def validIndices(indices: Array[Int]): Boolean = { + if (indices.isEmpty) { + true + } else { + indices.length == indices.distinct.length && indices.forall(_ >= 0) + } + } + + /** Return true if given feature names are valid */ + def validNames(names: Array[String]): Boolean = { + names.forall(_.nonEmpty) && names.length == names.distinct.length + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala index 2a1db90f2ca2..fcb517b5f735 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala @@ -20,6 +20,7 @@ package org.apache.spark.ml.util import scala.collection.immutable.HashMap import org.apache.spark.ml.attribute._ +import org.apache.spark.mllib.linalg.VectorUDT import org.apache.spark.sql.types.StructField @@ -74,4 +75,20 @@ private[spark] object MetadataUtils { } } + /** + * Takes a Vector column and a list of feature names, and returns the corresponding list of + * feature indices in the column, in order. + * @param col Vector column which must have feature names specified via attributes + * @param names List of feature names + */ + def getFeatureIndicesFromNames(col: StructField, names: Array[String]): Array[Int] = { + require(col.dataType.isInstanceOf[VectorUDT], s"getFeatureIndicesFromNames expected column $col" + + s" to be Vector type, but it was type ${col.dataType} instead.") + val inputAttr = AttributeGroup.fromStructField(col) + names.map { name => + require(inputAttr.hasAttr(name), + s"getFeatureIndicesFromNames found no feature with name $name in column $col.") + inputAttr.getAttr(name).index.get + } + } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index 96d1f48ba2ba..86c461fa9163 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -766,6 +766,30 @@ class SparseVector( maxIdx } } + + /** + * Create a slice of this vector based on the given indices. + * @param selectedIndices Unsorted list of indices into the vector. + * This does NOT do bound checking. + * @return New SparseVector with values in the order specified by the given indices. + * + * NOTE: The API needs to be discussed before making this public. + * Also, if we have a version assuming indices are sorted, we should optimize it. 
+ */ + private[spark] def slice(selectedIndices: Array[Int]): SparseVector = { + var currentIdx = 0 + val (sliceInds, sliceVals) = selectedIndices.flatMap { origIdx => + val iIdx = java.util.Arrays.binarySearch(this.indices, origIdx) + val i_v = if (iIdx >= 0) { + Iterator((currentIdx, this.values(iIdx))) + } else { + Iterator() + } + currentIdx += 1 + i_v + }.unzip + new SparseVector(selectedIndices.length, sliceInds.toArray, sliceVals.toArray) + } } object SparseVector { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala new file mode 100644 index 000000000000..a6c2fba8360d --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NumericAttribute} +import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.{DataFrame, Row, SQLContext} + +class VectorSlicerSuite extends SparkFunSuite with MLlibTestSparkContext { + + test("params") { + val slicer = new VectorSlicer + ParamsSuite.checkParams(slicer) + assert(slicer.getIndices.length === 0) + assert(slicer.getNames.length === 0) + withClue("VectorSlicer should not have any features selected by default") { + intercept[IllegalArgumentException] { + slicer.validateParams() + } + } + } + + test("feature validity checks") { + import VectorSlicer._ + assert(validIndices(Array(0, 1, 8, 2))) + assert(validIndices(Array.empty[Int])) + assert(!validIndices(Array(-1))) + assert(!validIndices(Array(1, 2, 1))) + + assert(validNames(Array("a", "b"))) + assert(validNames(Array.empty[String])) + assert(!validNames(Array("", "b"))) + assert(!validNames(Array("a", "b", "a"))) + } + + test("Test vector slicer") { + val sqlContext = new SQLContext(sc) + + val data = Array( + Vectors.sparse(5, Seq((0, -2.0), (1, 2.3))), + Vectors.dense(-2.0, 2.3, 0.0, 0.0, 1.0), + Vectors.dense(0.0, 0.0, 0.0, 0.0, 0.0), + Vectors.dense(0.6, -1.1, -3.0, 4.5, 3.3), + Vectors.sparse(5, Seq()) + ) + + // Expected after selecting indices 1, 4 + val expected = Array( + Vectors.sparse(2, Seq((0, 2.3))), + Vectors.dense(2.3, 1.0), + Vectors.dense(0.0, 0.0), + Vectors.dense(-1.1, 3.3), + Vectors.sparse(2, Seq()) + ) + + val defaultAttr = NumericAttribute.defaultAttr + val attrs = Array("f0", "f1", "f2", "f3", "f4").map(defaultAttr.withName) + val attrGroup = new AttributeGroup("features", 
attrs.asInstanceOf[Array[Attribute]]) + + val resultAttrs = Array("f1", "f4").map(defaultAttr.withName) + val resultAttrGroup = new AttributeGroup("expected", resultAttrs.asInstanceOf[Array[Attribute]]) + + val rdd = sc.parallelize(data.zip(expected)).map { case (a, b) => Row(a, b) } + val df = sqlContext.createDataFrame(rdd, + StructType(Array(attrGroup.toStructField(), resultAttrGroup.toStructField()))) + + val vectorSlicer = new VectorSlicer().setInputCol("features").setOutputCol("result") + + def validateResults(df: DataFrame): Unit = { + df.select("result", "expected").collect().foreach { case Row(vec1: Vector, vec2: Vector) => + assert(vec1 === vec2) + } + val resultMetadata = AttributeGroup.fromStructField(df.schema("result")) + val expectedMetadata = AttributeGroup.fromStructField(df.schema("expected")) + assert(resultMetadata.numAttributes === expectedMetadata.numAttributes) + resultMetadata.attributes.get.zip(expectedMetadata.attributes.get).foreach { case (a, b) => + assert(a === b) + } + } + + vectorSlicer.setIndices(Array(1, 4)).setNames(Array.empty) + validateResults(vectorSlicer.transform(df)) + + vectorSlicer.setIndices(Array(1)).setNames(Array("f4")) + validateResults(vectorSlicer.transform(df)) + + vectorSlicer.setIndices(Array.empty).setNames(Array("f1", "f4")) + validateResults(vectorSlicer.transform(df)) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala index 1c37ea5123e8..6508ddeba420 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala @@ -367,4 +367,11 @@ class VectorsSuite extends SparkFunSuite with Logging { val sv1c = sv1.compressed.asInstanceOf[DenseVector] assert(sv1 === sv1c) } + + test("SparseVector.slice") { + val v = new SparseVector(5, Array(1, 2, 4), Array(1.1, 2.2, 4.4)) + assert(v.slice(Array(0, 2)) === new SparseVector(2, Array(1), Array(2.2))) + assert(v.slice(Array(2, 0)) === new SparseVector(2, Array(0), Array(2.2))) + assert(v.slice(Array(2, 0, 3, 4)) === new SparseVector(4, Array(0, 3), Array(2.2, 4.4))) + } } From 8c320e45b5c9ffd7f0e35c1c7e6b5fc355377ea6 Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Wed, 5 Aug 2015 17:28:23 -0700 Subject: [PATCH 0072/1824] [SPARK-6591] [SQL] Python data source load options should auto convert common types into strings JIRA: https://issues.apache.org/jira/browse/SPARK-6591 Author: Yijie Shen Closes #7926 from yjshen/py_dsload_opt and squashes the following commits: b207832 [Yijie Shen] fix style efdf834 [Yijie Shen] resolve comment 7a8f6a2 [Yijie Shen] lowercase 822e769 [Yijie Shen] convert load opts to string --- python/pyspark/sql/readwriter.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index dea8bad79e18..bf6ac084bbbf 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -24,6 +24,16 @@ __all__ = ["DataFrameReader", "DataFrameWriter"] +def to_str(value): + """ + A wrapper over str(), but convert bool values to lower case string + """ + if isinstance(value, bool): + return str(value).lower() + else: + return str(value) + + class DataFrameReader(object): """ Interface used to load a :class:`DataFrame` from external storage systems @@ -77,7 +87,7 @@ def schema(self, schema): def option(self, key, value): """Adds an input option for the underlying data 
source. """ - self._jreader = self._jreader.option(key, value) + self._jreader = self._jreader.option(key, to_str(value)) return self @since(1.4) @@ -85,7 +95,7 @@ def options(self, **options): """Adds input options for the underlying data source. """ for k in options: - self._jreader = self._jreader.option(k, options[k]) + self._jreader = self._jreader.option(k, to_str(options[k])) return self @since(1.4) @@ -97,7 +107,8 @@ def load(self, path=None, format=None, schema=None, **options): :param schema: optional :class:`StructType` for the input schema. :param options: all other string options - >>> df = sqlContext.read.load('python/test_support/sql/parquet_partitioned') + >>> df = sqlContext.read.load('python/test_support/sql/parquet_partitioned', opt1=True, + ... opt2=1, opt3='str') >>> df.dtypes [('name', 'string'), ('year', 'int'), ('month', 'int'), ('day', 'int')] """ From 4399b7b0903d830313ab7e69731c11d587ae567c Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Wed, 5 Aug 2015 17:58:36 -0700 Subject: [PATCH 0073/1824] [SPARK-9651] Fix UnsafeExternalSorterSuite. First, it's probably a bad idea to call generated Scala methods from Java. In this case, the method being called wasn't actually "Utils.createTempDir()", but actually the method that returns the first default argument to the actual createTempDir method, which is just the location of java.io.tmpdir; meaning that all tests in the class were using the same temp dir, and thus affecting each other. Second, spillingOccursInResponseToMemoryPressure was not writing enough records to actually cause a spill. Author: Marcelo Vanzin Closes #7970 from vanzin/SPARK-9651 and squashes the following commits: 74d357f [Marcelo Vanzin] Clean up temp dir on test tear down. a64f36a [Marcelo Vanzin] [SPARK-9651] Fix UnsafeExternalSorterSuite. 
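The first bug is worth a standalone illustration: for every default argument, scalac emits a synthetic accessor named `<method>$default$<n>` that merely returns the default value, so Java code calling that accessor instead of the real method silently gets the default (here the shared `java.io.tmpdir`) and no directory is ever created. Below is a minimal sketch of the mechanism; `Workspace` and `DefaultArgDemo` are hypothetical stand-ins rather than Spark's `Utils`:

```scala
import java.io.File
import java.nio.file.{Files, Paths}

// Hypothetical stand-in for Utils.createTempDir: a method with a default argument.
object Workspace {
  def createTempDir(root: String = System.getProperty("java.io.tmpdir")): File =
    Files.createTempDirectory(Paths.get(root), "test").toFile
}

object DefaultArgDemo extends App {
  // scalac also emits createTempDir$default$1(), which only returns the default
  // value and never runs createTempDir. A Java caller that picks the accessor by
  // mistake gets the shared tmpdir path back, the trap this suite fell into.
  val accessor = Workspace.getClass.getMethods
    .find(_.getName == "createTempDir$default$1").get
  println(accessor.invoke(Workspace)) // prints something like /tmp; creates nothing
  println(Workspace.createTempDir())  // actually creates and returns a fresh directory
}
```

The patch avoids the trap by passing explicit arguments to `createTempDir` and deleting the directory in `tearDown`, so the tests stop sharing state through a single directory.
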
--- .../sort/UnsafeExternalSorterSuite.java | 21 ++++++++++++------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java index 968185bde78a..117745f9a9c0 100644 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java @@ -101,7 +101,7 @@ public OutputStream apply(OutputStream stream) { public void setUp() { MockitoAnnotations.initMocks(this); sparkConf = new SparkConf(); - tempDir = new File(Utils.createTempDir$default$1()); + tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "unsafe-test"); shuffleMemoryManager = new ShuffleMemoryManager(Long.MAX_VALUE); spillFilesCreated.clear(); taskContext = mock(TaskContext.class); @@ -143,13 +143,18 @@ public DiskBlockObjectWriter answer(InvocationOnMock invocationOnMock) throws Th @After public void tearDown() { - long leakedUnsafeMemory = taskMemoryManager.cleanUpAllAllocatedMemory(); - if (shuffleMemoryManager != null) { - long leakedShuffleMemory = shuffleMemoryManager.getMemoryConsumptionForThisTask(); - shuffleMemoryManager = null; - assertEquals(0L, leakedShuffleMemory); + try { + long leakedUnsafeMemory = taskMemoryManager.cleanUpAllAllocatedMemory(); + if (shuffleMemoryManager != null) { + long leakedShuffleMemory = shuffleMemoryManager.getMemoryConsumptionForThisTask(); + shuffleMemoryManager = null; + assertEquals(0L, leakedShuffleMemory); + } + assertEquals(0, leakedUnsafeMemory); + } finally { + Utils.deleteRecursively(tempDir); + tempDir = null; } - assertEquals(0, leakedUnsafeMemory); } private void assertSpillFilesWereCleanedUp() { @@ -234,7 +239,7 @@ public void testSortingEmptyArrays() throws Exception { public void spillingOccursInResponseToMemoryPressure() throws Exception { shuffleMemoryManager = new ShuffleMemoryManager(pageSizeBytes * 2); final UnsafeExternalSorter sorter = newSorter(); - final int numRecords = 100000; + final int numRecords = (int) pageSizeBytes / 4; for (int i = 0; i <= numRecords; i++) { insertNumber(sorter, numRecords - i); } From 4581badbc8aa7e5a37ba7f7f83cc3860240f5dd3 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Wed, 5 Aug 2015 19:19:09 -0700 Subject: [PATCH 0074/1824] [SPARK-9611] [SQL] Fixes a few corner cases when we spill a UnsafeFixedWidthAggregationMap This PR has the following three small fixes. 1. UnsafeKVExternalSorter does not use 0 as the initialSize to create an UnsafeInMemorySorter if its BytesToBytesMap is empty. 2. We will not spill an InMemorySorter if it is empty. 3. We will not add a SpillReader to a SpillMerger if this SpillReader is empty. JIRA: https://issues.apache.org/jira/browse/SPARK-9611 Author: Yin Huai Closes #7948 from yhuai/unsafeEmptyMap and squashes the following commits: 9727abe [Yin Huai] Address Josh's comments. 34b6f76 [Yin Huai] 1. UnsafeKVExternalSorter does not use 0 as the initialSize to create an UnsafeInMemorySorter if its BytesToBytesMap is empty. 2. Do not spill a InMemorySorter if it is empty. 3. Do not add spill to SpillMerger if this spill is empty.
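Of the three fixes, the third is the subtle one: as the comment added to UnsafeSorterSpillMerger below explains, an empty reader sitting in the merger's priority queue makes `hasNext` answer true once per queued reader and leaks empty records into the merged output, so the guard is simply to enqueue only sources that hold at least one record. A self-contained sketch of that pattern (illustrative names, not Spark's classes):

```scala
import scala.collection.mutable

// Illustration of the "skip empty spills" guard: a k-way merge over sorted
// sources in which a source is enqueued only if it has at least one record.
class KWayMerger[T](implicit ord: Ordering[T]) {
  // Each entry pairs a source's current head with the rest of that source; the
  // reversed ordering turns Scala's max-heap PriorityQueue into a min-heap.
  private val queue = mutable.PriorityQueue.empty[(T, Iterator[T])](
    Ordering.by[(T, Iterator[T]), T](_._1).reverse)

  /** Add a sorted source, skipping it entirely when it is empty. */
  def addSourceIfNotEmpty(source: Iterator[T]): Unit =
    if (source.hasNext) queue.enqueue((source.next(), source))

  def sortedIterator: Iterator[T] = new Iterator[T] {
    override def hasNext: Boolean = queue.nonEmpty
    override def next(): T = {
      val (head, rest) = queue.dequeue()
      if (rest.hasNext) queue.enqueue((rest.next(), rest))
      head
    }
  }
}

object KWayMergerDemo extends App {
  val merger = new KWayMerger[Int]
  Seq(Iterator(1, 4, 9), Iterator[Int](), Iterator(2, 3)).foreach(merger.addSourceIfNotEmpty)
  println(merger.sortedIterator.toList) // List(1, 2, 3, 4, 9)
}
```

The same shape appears in the UnsafeExternalSorter hunk that follows: `spill()` now writes a spill file only when the in-memory sorter actually holds records, so empty spills never reach the merger in the first place.
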
--- .../unsafe/sort/UnsafeExternalSorter.java | 36 +++--- .../unsafe/sort/UnsafeSorterSpillMerger.java | 12 +- .../sql/execution/UnsafeKVExternalSorter.java | 6 +- .../UnsafeFixedWidthAggregationMapSuite.scala | 108 +++++++++++++++++- 4 files changed, 141 insertions(+), 21 deletions(-) diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java index e6ddd08e5fa9..8f78fc5a4162 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -191,24 +191,29 @@ public void spill() throws IOException { spillWriters.size(), spillWriters.size() > 1 ? " times" : " time"); - final UnsafeSorterSpillWriter spillWriter = - new UnsafeSorterSpillWriter(blockManager, fileBufferSizeBytes, writeMetrics, - inMemSorter.numRecords()); - spillWriters.add(spillWriter); - final UnsafeSorterIterator sortedRecords = inMemSorter.getSortedIterator(); - while (sortedRecords.hasNext()) { - sortedRecords.loadNext(); - final Object baseObject = sortedRecords.getBaseObject(); - final long baseOffset = sortedRecords.getBaseOffset(); - final int recordLength = sortedRecords.getRecordLength(); - spillWriter.write(baseObject, baseOffset, recordLength, sortedRecords.getKeyPrefix()); + // We only write out contents of the inMemSorter if it is not empty. + if (inMemSorter.numRecords() > 0) { + final UnsafeSorterSpillWriter spillWriter = + new UnsafeSorterSpillWriter(blockManager, fileBufferSizeBytes, writeMetrics, + inMemSorter.numRecords()); + spillWriters.add(spillWriter); + final UnsafeSorterIterator sortedRecords = inMemSorter.getSortedIterator(); + while (sortedRecords.hasNext()) { + sortedRecords.loadNext(); + final Object baseObject = sortedRecords.getBaseObject(); + final long baseOffset = sortedRecords.getBaseOffset(); + final int recordLength = sortedRecords.getRecordLength(); + spillWriter.write(baseObject, baseOffset, recordLength, sortedRecords.getKeyPrefix()); + } + spillWriter.close(); } - spillWriter.close(); + final long spillSize = freeMemory(); // Note that this is more-or-less going to be a multiple of the page size, so wasted space in // pages will currently be counted as memory spilled even though that space isn't actually // written to disk. This also counts the space needed to store the sorter's pointer array. 
taskContext.taskMetrics().incMemoryBytesSpilled(spillSize); + initializeForWriting(); } @@ -505,12 +510,11 @@ public UnsafeSorterIterator getSortedIterator() throws IOException { final UnsafeSorterSpillMerger spillMerger = new UnsafeSorterSpillMerger(recordComparator, prefixComparator, numIteratorsToMerge); for (UnsafeSorterSpillWriter spillWriter : spillWriters) { - spillMerger.addSpill(spillWriter.getReader(blockManager)); + spillMerger.addSpillIfNotEmpty(spillWriter.getReader(blockManager)); } spillWriters.clear(); - if (inMemoryIterator.hasNext()) { - spillMerger.addSpill(inMemoryIterator); - } + spillMerger.addSpillIfNotEmpty(inMemoryIterator); + return spillMerger.getSortedIterator(); } } diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java index 8272c2a5be0d..3874a9f9cbdb 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java @@ -47,11 +47,19 @@ public int compare(UnsafeSorterIterator left, UnsafeSorterIterator right) { priorityQueue = new PriorityQueue(numSpills, comparator); } - public void addSpill(UnsafeSorterIterator spillReader) throws IOException { + /** + * Add an UnsafeSorterIterator to this merger + */ + public void addSpillIfNotEmpty(UnsafeSorterIterator spillReader) throws IOException { if (spillReader.hasNext()) { + // We only add the spillReader to the priorityQueue if it is not empty. We do this to + // make sure the hasNext method of UnsafeSorterIterator returned by getSortedIterator + // does not return wrong result because hasNext will returns true + // at least priorityQueue.size() times. If we allow n spillReaders in the + // priorityQueue, we will have n extra empty records in the result of the UnsafeSorterIterator. spillReader.loadNext(); + priorityQueue.add(spillReader); } - priorityQueue.add(spillReader); } public UnsafeSorterIterator getSortedIterator() throws IOException { diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java index 86a563df992d..6c1cf136d9b8 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java @@ -82,8 +82,11 @@ public UnsafeKVExternalSorter(StructType keySchema, StructType valueSchema, pageSizeBytes); } else { // Insert the records into the in-memory sorter. + // We will use the number of elements in the map as the initialSize of the + // UnsafeInMemorySorter. Because UnsafeInMemorySorter does not accept 0 as the initialSize, + // we will use 1 as its initial size if the map is empty. 
final UnsafeInMemorySorter inMemSorter = new UnsafeInMemorySorter( - taskMemoryManager, recordComparator, prefixComparator, map.numElements()); + taskMemoryManager, recordComparator, prefixComparator, Math.max(1, map.numElements())); final int numKeyFields = keySchema.size(); BytesToBytesMap.BytesToBytesMapIterator iter = map.iterator(); @@ -214,7 +217,6 @@ public boolean next() throws IOException { // Note that recordLen = keyLen + valueLen + 4 bytes (for the keyLen itself) int keyLen = PlatformDependent.UNSAFE.getInt(baseObj, recordOffset); int valueLen = recordLen - keyLen - 4; - key.pointTo(baseObj, recordOffset + 4, numKeyFields, keyLen); value.pointTo(baseObj, recordOffset + 4 + keyLen, numValueFields, valueLen); diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala index ef827b0fe9b5..b513c970ccfe 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala @@ -23,7 +23,7 @@ import scala.util.{Try, Random} import org.scalatest.Matchers -import org.apache.spark.sql.catalyst.expressions.UnsafeProjection +import org.apache.spark.sql.catalyst.expressions.{UnsafeRow, UnsafeProjection} import org.apache.spark.{TaskContextImpl, TaskContext, SparkFunSuite} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.test.TestSQLContext @@ -231,4 +231,110 @@ class UnsafeFixedWidthAggregationMapSuite extends SparkFunSuite with Matchers { map.free() } + + testWithMemoryLeakDetection("test external sorting with an empty map") { + // Calling this make sure we have block manager and everything else setup. + TestSQLContext + + val map = new UnsafeFixedWidthAggregationMap( + emptyAggregationBuffer, + aggBufferSchema, + groupKeySchema, + taskMemoryManager, + shuffleMemoryManager, + 128, // initial capacity + PAGE_SIZE_BYTES, + false // disable perf metrics + ) + + // Convert the map into a sorter + val sorter = map.destructAndCreateExternalSorter() + + // Add more keys to the sorter and make sure the results come out sorted. + val additionalKeys = randomStrings(1024) + val keyConverter = UnsafeProjection.create(groupKeySchema) + val valueConverter = UnsafeProjection.create(aggBufferSchema) + + additionalKeys.zipWithIndex.foreach { case (str, i) => + val k = InternalRow(UTF8String.fromString(str)) + val v = InternalRow(str.length) + sorter.insertKV(keyConverter.apply(k), valueConverter.apply(v)) + + if ((i % 100) == 0) { + shuffleMemoryManager.markAsOutOfMemory() + sorter.closeCurrentPage() + } + } + + val out = new scala.collection.mutable.ArrayBuffer[String] + val iter = sorter.sortedIterator() + while (iter.next()) { + // At here, we also test if copy is correct. + val key = iter.getKey.copy() + val value = iter.getValue.copy() + assert(key.getString(0).length === value.getInt(0)) + out += key.getString(0) + } + + assert(out === (additionalKeys).sorted) + + map.free() + } + + testWithMemoryLeakDetection("test external sorting with empty records") { + // Calling this make sure we have block manager and everything else setup. + TestSQLContext + + // Memory consumption in the beginning of the task. 
+ val initialMemoryConsumption = shuffleMemoryManager.getMemoryConsumptionForThisTask() + + val map = new UnsafeFixedWidthAggregationMap( + emptyAggregationBuffer, + StructType(Nil), + StructType(Nil), + taskMemoryManager, + shuffleMemoryManager, + 128, // initial capacity + PAGE_SIZE_BYTES, + false // disable perf metrics + ) + + (1 to 10).foreach { i => + val buf = map.getAggregationBuffer(UnsafeRow.createFromByteArray(0, 0)) + assert(buf != null) + } + + // Convert the map into a sorter. Right now, it contains one record. + val sorter = map.destructAndCreateExternalSorter() + + withClue(s"destructAndCreateExternalSorter should release memory used by the map") { + // 4096 * 16 is the initial size allocated for the pointer/prefix array in the in-mem sorter. + assert(shuffleMemoryManager.getMemoryConsumptionForThisTask() === + initialMemoryConsumption + 4096 * 16) + } + + // Add more keys to the sorter and make sure the results come out sorted. + (1 to 4096).foreach { i => + sorter.insertKV(UnsafeRow.createFromByteArray(0, 0), UnsafeRow.createFromByteArray(0, 0)) + + if ((i % 100) == 0) { + shuffleMemoryManager.markAsOutOfMemory() + sorter.closeCurrentPage() + } + } + + var count = 0 + val iter = sorter.sortedIterator() + while (iter.next()) { + // At here, we also test if copy is correct. + iter.getKey.copy() + iter.getValue.copy() + count += 1; + } + + // 1 record was from the map and 4096 records were explicitly inserted. + assert(count === 4097) + + map.free() + } } From 119b59053870df7be899bf5c1c0d321406af96f9 Mon Sep 17 00:00:00 2001 From: Cheng Hao Date: Thu, 6 Aug 2015 11:13:44 +0800 Subject: [PATCH 0075/1824] [SPARK-6923] [SPARK-7550] [SQL] Persists data source relations in Hive compatible format when possible This PR is a fork of PR #5733 authored by chenghao-intel. For committers who's going to merge this PR, please set the author to "Cheng Hao ". ---- When a data source relation meets the following requirements, we persist it in Hive compatible format, so that other systems like Hive can access it: 1. It's a `HadoopFsRelation` 2. It has only one input path 3. It's non-partitioned 4. It's data source provider can be naturally mapped to a Hive builtin SerDe (e.g. 
ORC and Parquet) Author: Cheng Lian Author: Cheng Hao Closes #7967 from liancheng/spark-6923/refactoring-pr-5733 and squashes the following commits: 5175ee6 [Cheng Lian] Fixes an oudated comment 3870166 [Cheng Lian] Fixes build error and comments 864acee [Cheng Lian] Refactors PR #5733 3490cdc [Cheng Hao] update the scaladoc 6f57669 [Cheng Hao] write schema info to hivemetastore for data source --- .../org/apache/spark/sql/DataFrame.scala | 53 +++++-- .../apache/spark/sql/DataFrameWriter.scala | 7 + .../spark/sql/hive/HiveMetastoreCatalog.scala | 146 +++++++++++++++++- .../org/apache/spark/sql/hive/HiveQl.scala | 49 ++---- .../spark/sql/hive/orc/OrcRelation.scala | 6 +- .../sql/hive/HiveMetastoreCatalogSuite.scala | 133 ++++++++++++++-- 6 files changed, 324 insertions(+), 70 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index e57acec59d32..405b5a4a9a7f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -20,9 +20,6 @@ package org.apache.spark.sql import java.io.CharArrayWriter import java.util.Properties -import org.apache.spark.sql.test.TestSQLContext -import org.apache.spark.unsafe.types.UTF8String - import scala.language.implicitConversions import scala.reflect.ClassTag import scala.reflect.runtime.universe.TypeTag @@ -42,7 +39,7 @@ import org.apache.spark.sql.catalyst.plans.{Inner, JoinType} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection, SqlParser} import org.apache.spark.sql.execution.{EvaluatePython, ExplainCommand, LogicalRDD, SQLExecution} import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, LogicalRelation} -import org.apache.spark.sql.json.{JacksonGenerator, JSONRelation} +import org.apache.spark.sql.json.JacksonGenerator import org.apache.spark.sql.sources.HadoopFsRelation import org.apache.spark.sql.types._ import org.apache.spark.storage.StorageLevel @@ -1650,8 +1647,12 @@ class DataFrame private[sql]( * an RDD out to a parquet file, and then register that file as a table. This "table" can then * be the target of an `insertInto`. * - * Also note that while this function can persist the table metadata into Hive's metastore, - * the table will NOT be accessible from Hive, until SPARK-7550 is resolved. + * When the DataFrame is created from a non-partitioned [[HadoopFsRelation]] with a single input + * path, and the data source provider can be mapped to an existing Hive builtin SerDe (i.e. ORC + * and Parquet), the table is persisted in a Hive compatible format, which means other systems + * like Hive will be able to read this table. Otherwise, the table is persisted in a Spark SQL + * specific format. + * * @group output * @deprecated As of 1.4.0, replaced by `write().saveAsTable(tableName)`. */ @@ -1669,8 +1670,12 @@ class DataFrame private[sql]( * an RDD out to a parquet file, and then register that file as a table. This "table" can then * be the target of an `insertInto`. * - * Also note that while this function can persist the table metadata into Hive's metastore, - * the table will NOT be accessible from Hive, until SPARK-7550 is resolved. + * When the DataFrame is created from a non-partitioned [[HadoopFsRelation]] with a single input + * path, and the data source provider can be mapped to an existing Hive builtin SerDe (i.e. 
ORC + * and Parquet), the table is persisted in a Hive compatible format, which means other systems + * like Hive will be able to read this table. Otherwise, the table is persisted in a Spark SQL + * specific format. + * * @group output * @deprecated As of 1.4.0, replaced by `write().mode(mode).saveAsTable(tableName)`. */ @@ -1689,8 +1694,12 @@ class DataFrame private[sql]( * an RDD out to a parquet file, and then register that file as a table. This "table" can then * be the target of an `insertInto`. * - * Also note that while this function can persist the table metadata into Hive's metastore, - * the table will NOT be accessible from Hive, until SPARK-7550 is resolved. + * When the DataFrame is created from a non-partitioned [[HadoopFsRelation]] with a single input + * path, and the data source provider can be mapped to an existing Hive builtin SerDe (i.e. ORC + * and Parquet), the table is persisted in a Hive compatible format, which means other systems + * like Hive will be able to read this table. Otherwise, the table is persisted in a Spark SQL + * specific format. + * * @group output * @deprecated As of 1.4.0, replaced by `write().format(source).saveAsTable(tableName)`. */ @@ -1709,8 +1718,12 @@ class DataFrame private[sql]( * an RDD out to a parquet file, and then register that file as a table. This "table" can then * be the target of an `insertInto`. * - * Also note that while this function can persist the table metadata into Hive's metastore, - * the table will NOT be accessible from Hive, until SPARK-7550 is resolved. + * When the DataFrame is created from a non-partitioned [[HadoopFsRelation]] with a single input + * path, and the data source provider can be mapped to an existing Hive builtin SerDe (i.e. ORC + * and Parquet), the table is persisted in a Hive compatible format, which means other systems + * like Hive will be able to read this table. Otherwise, the table is persisted in a Spark SQL + * specific format. + * * @group output * @deprecated As of 1.4.0, replaced by `write().mode(mode).saveAsTable(tableName)`. */ @@ -1728,8 +1741,12 @@ class DataFrame private[sql]( * an RDD out to a parquet file, and then register that file as a table. This "table" can then * be the target of an `insertInto`. * - * Also note that while this function can persist the table metadata into Hive's metastore, - * the table will NOT be accessible from Hive, until SPARK-7550 is resolved. + * When the DataFrame is created from a non-partitioned [[HadoopFsRelation]] with a single input + * path, and the data source provider can be mapped to an existing Hive builtin SerDe (i.e. ORC + * and Parquet), the table is persisted in a Hive compatible format, which means other systems + * like Hive will be able to read this table. Otherwise, the table is persisted in a Spark SQL + * specific format. + * * @group output * @deprecated As of 1.4.0, replaced by * `write().format(source).mode(mode).options(options).saveAsTable(tableName)`. @@ -1754,8 +1771,12 @@ class DataFrame private[sql]( * an RDD out to a parquet file, and then register that file as a table. This "table" can then * be the target of an `insertInto`. * - * Also note that while this function can persist the table metadata into Hive's metastore, - * the table will NOT be accessible from Hive, until SPARK-7550 is resolved. + * When the DataFrame is created from a non-partitioned [[HadoopFsRelation]] with a single input + * path, and the data source provider can be mapped to an existing Hive builtin SerDe (i.e. 
ORC + * and Parquet), the table is persisted in a Hive compatible format, which means other systems + * like Hive will be able to read this table. Otherwise, the table is persisted in a Spark SQL + * specific format. + * * @group output * @deprecated As of 1.4.0, replaced by * `write().format(source).mode(mode).options(options).saveAsTable(tableName)`. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 7e3318cefe62..2a4992db09bc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.catalyst.plans.logical.InsertIntoTable import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, ResolvedDataSource} import org.apache.spark.sql.jdbc.{JDBCWriteDetails, JdbcUtils} +import org.apache.spark.sql.sources.HadoopFsRelation /** @@ -185,6 +186,12 @@ final class DataFrameWriter private[sql](df: DataFrame) { * When `mode` is `Append`, the schema of the [[DataFrame]] need to be * the same as that of the existing table, and format or options will be ignored. * + * When the DataFrame is created from a non-partitioned [[HadoopFsRelation]] with a single input + * path, and the data source provider can be mapped to an existing Hive builtin SerDe (i.e. ORC + * and Parquet), the table is persisted in a Hive compatible format, which means other systems + * like Hive will be able to read this table. Otherwise, the table is persisted in a Spark SQL + * specific format. + * * @since 1.4.0 */ def saveAsTable(tableName: String): Unit = { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 6b37af99f467..1523ebe9d549 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -18,11 +18,13 @@ package org.apache.spark.sql.hive import scala.collection.JavaConversions._ +import scala.collection.mutable import com.google.common.base.Objects import com.google.common.cache.{CacheBuilder, CacheLoader, LoadingCache} import org.apache.hadoop.fs.Path import org.apache.hadoop.hive.common.StatsSetupConst +import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.metastore.Warehouse import org.apache.hadoop.hive.metastore.api.FieldSchema import org.apache.hadoop.hive.ql.metadata._ @@ -40,9 +42,59 @@ import org.apache.spark.sql.execution.datasources import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, LogicalRelation, Partition => ParquetPartition, PartitionSpec, ResolvedDataSource} import org.apache.spark.sql.hive.client._ import org.apache.spark.sql.parquet.ParquetRelation +import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ import org.apache.spark.sql.{AnalysisException, SQLContext, SaveMode} +private[hive] case class HiveSerDe( + inputFormat: Option[String] = None, + outputFormat: Option[String] = None, + serde: Option[String] = None) + +private[hive] object HiveSerDe { + /** + * Get the Hive SerDe information from the data source abbreviation string or classname. + * + * @param source Currently the source abbreviation can be one of the following: + * SequenceFile, RCFile, ORC, PARQUET, and case insensitive. 
+ * @param hiveConf Hive Conf + * @return HiveSerDe associated with the specified source + */ + def sourceToSerDe(source: String, hiveConf: HiveConf): Option[HiveSerDe] = { + val serdeMap = Map( + "sequencefile" -> + HiveSerDe( + inputFormat = Option("org.apache.hadoop.mapred.SequenceFileInputFormat"), + outputFormat = Option("org.apache.hadoop.mapred.SequenceFileOutputFormat")), + + "rcfile" -> + HiveSerDe( + inputFormat = Option("org.apache.hadoop.hive.ql.io.RCFileInputFormat"), + outputFormat = Option("org.apache.hadoop.hive.ql.io.RCFileOutputFormat"), + serde = Option(hiveConf.getVar(HiveConf.ConfVars.HIVEDEFAULTRCFILESERDE))), + + "orc" -> + HiveSerDe( + inputFormat = Option("org.apache.hadoop.hive.ql.io.orc.OrcInputFormat"), + outputFormat = Option("org.apache.hadoop.hive.ql.io.orc.OrcOutputFormat"), + serde = Option("org.apache.hadoop.hive.ql.io.orc.OrcSerde")), + + "parquet" -> + HiveSerDe( + inputFormat = Option("org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat"), + outputFormat = Option("org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat"), + serde = Option("org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe"))) + + val key = source.toLowerCase match { + case _ if source.startsWith("org.apache.spark.sql.parquet") => "parquet" + case _ if source.startsWith("org.apache.spark.sql.orc") => "orc" + case _ => source.toLowerCase + } + + serdeMap.get(key) + } +} + private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: HiveContext) extends Catalog with Logging { @@ -164,15 +216,15 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive processDatabaseAndTableName(database, tableIdent.table) } - val tableProperties = new scala.collection.mutable.HashMap[String, String] + val tableProperties = new mutable.HashMap[String, String] tableProperties.put("spark.sql.sources.provider", provider) // Saves optional user specified schema. Serialized JSON schema string may be too long to be // stored into a single metastore SerDe property. In this case, we split the JSON string and // store each part as a separate SerDe property. - if (userSpecifiedSchema.isDefined) { + userSpecifiedSchema.foreach { schema => val threshold = conf.schemaStringLengthThreshold - val schemaJsonString = userSpecifiedSchema.get.json + val schemaJsonString = schema.json // Split the JSON string. val parts = schemaJsonString.grouped(threshold).toSeq tableProperties.put("spark.sql.sources.schema.numParts", parts.size.toString) @@ -194,7 +246,7 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive // The table does not have a specified schema, which means that the schema will be inferred // when we load the table. So, we are not expecting partition columns and we will discover // partitions when we load the table. However, if there are specified partition columns, - // we simplily ignore them and provide a warning message.. + // we simply ignore them and provide a warning message. logWarning( s"The schema and partitions of table $tableIdent will be inferred when it is loaded. 
" + s"Specified partition columns (${partitionColumns.mkString(",")}) will be ignored.") @@ -210,7 +262,11 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive ManagedTable } - client.createTable( + val maybeSerDe = HiveSerDe.sourceToSerDe(provider, hive.hiveconf) + val dataSource = ResolvedDataSource( + hive, userSpecifiedSchema, partitionColumns, provider, options) + + def newSparkSQLSpecificMetastoreTable(): HiveTable = { HiveTable( specifiedDatabase = Option(dbName), name = tblName, @@ -218,7 +274,83 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive partitionColumns = metastorePartitionColumns, tableType = tableType, properties = tableProperties.toMap, - serdeProperties = options)) + serdeProperties = options) + } + + def newHiveCompatibleMetastoreTable(relation: HadoopFsRelation, serde: HiveSerDe): HiveTable = { + def schemaToHiveColumn(schema: StructType): Seq[HiveColumn] = { + schema.map { field => + HiveColumn( + name = field.name, + hiveType = HiveMetastoreTypes.toMetastoreType(field.dataType), + comment = "") + } + } + + val partitionColumns = schemaToHiveColumn(relation.partitionColumns) + val dataColumns = schemaToHiveColumn(relation.schema).filterNot(partitionColumns.contains) + + HiveTable( + specifiedDatabase = Option(dbName), + name = tblName, + schema = dataColumns, + partitionColumns = partitionColumns, + tableType = tableType, + properties = tableProperties.toMap, + serdeProperties = options, + location = Some(relation.paths.head), + viewText = None, // TODO We need to place the SQL string here. + inputFormat = serde.inputFormat, + outputFormat = serde.outputFormat, + serde = serde.serde) + } + + // TODO: Support persisting partitioned data source relations in Hive compatible format + val hiveTable = (maybeSerDe, dataSource.relation) match { + case (Some(serde), relation: HadoopFsRelation) + if relation.paths.length == 1 && relation.partitionColumns.isEmpty => + logInfo { + "Persisting data source relation with a single input path into Hive metastore in Hive " + + s"compatible format. Input path: ${relation.paths.head}" + } + newHiveCompatibleMetastoreTable(relation, serde) + + case (Some(serde), relation: HadoopFsRelation) if relation.partitionColumns.nonEmpty => + logWarning { + val paths = relation.paths.mkString(", ") + "Persisting partitioned data source relation into Hive metastore in " + + s"Spark SQL specific format, which is NOT compatible with Hive. Input path(s): " + + paths.mkString("\n", "\n", "") + } + newSparkSQLSpecificMetastoreTable() + + case (Some(serde), relation: HadoopFsRelation) => + logWarning { + val paths = relation.paths.mkString(", ") + "Persisting data source relation with multiple input paths into Hive metastore in " + + s"Spark SQL specific format, which is NOT compatible with Hive. Input paths: " + + paths.mkString("\n", "\n", "") + } + newSparkSQLSpecificMetastoreTable() + + case (Some(serde), _) => + logWarning { + s"Data source relation is not a ${classOf[HadoopFsRelation].getSimpleName}. " + + "Persisting it into Hive metastore in Spark SQL specific format, " + + "which is NOT compatible with Hive." + } + newSparkSQLSpecificMetastoreTable() + + case _ => + logWarning { + s"Couldn't find corresponding Hive SerDe for data source provider $provider. " + + "Persisting data source relation into Hive metastore in Spark SQL specific format, " + + "which is NOT compatible with Hive." 
+ } + newSparkSQLSpecificMetastoreTable() + } + + client.createTable(hiveTable) } def hiveDefaultTableFilePath(tableName: String): String = { @@ -463,7 +595,7 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive case p: LogicalPlan if !p.childrenResolved => p case p: LogicalPlan if p.resolved => p case p @ CreateTableAsSelect(table, child, allowExisting) => - val schema = if (table.schema.size > 0) { + val schema = if (table.schema.nonEmpty) { table.schema } else { child.output.map { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index f43e403ce9a9..7d7b4b916730 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -32,6 +32,7 @@ import org.apache.hadoop.hive.ql.session.SessionState import org.apache.spark.Logging import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ @@ -261,8 +262,8 @@ private[hive] object HiveQl extends Logging { /** * Returns the HiveConf */ - private[this] def hiveConf(): HiveConf = { - val ss = SessionState.get() // SessionState is lazy initializaion, it can be null here + private[this] def hiveConf: HiveConf = { + val ss = SessionState.get() // SessionState is lazy initialization, it can be null here if (ss == null) { new HiveConf() } else { @@ -604,38 +605,18 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C serde = None, viewText = None) - // default storage type abbriviation (e.g. RCFile, ORC, PARQUET etc.) + // default storage type abbreviation (e.g. RCFile, ORC, PARQUET etc.) 
val defaultStorageType = hiveConf.getVar(HiveConf.ConfVars.HIVEDEFAULTFILEFORMAT) - // handle the default format for the storage type abbriviation - tableDesc = if ("SequenceFile".equalsIgnoreCase(defaultStorageType)) { - tableDesc.copy( - inputFormat = Option("org.apache.hadoop.mapred.SequenceFileInputFormat"), - outputFormat = Option("org.apache.hadoop.mapred.SequenceFileOutputFormat")) - } else if ("RCFile".equalsIgnoreCase(defaultStorageType)) { - tableDesc.copy( - inputFormat = Option("org.apache.hadoop.hive.ql.io.RCFileInputFormat"), - outputFormat = Option("org.apache.hadoop.hive.ql.io.RCFileOutputFormat"), - serde = Option(hiveConf.getVar(HiveConf.ConfVars.HIVEDEFAULTRCFILESERDE))) - } else if ("ORC".equalsIgnoreCase(defaultStorageType)) { - tableDesc.copy( - inputFormat = Option("org.apache.hadoop.hive.ql.io.orc.OrcInputFormat"), - outputFormat = Option("org.apache.hadoop.hive.ql.io.orc.OrcOutputFormat"), - serde = Option("org.apache.hadoop.hive.ql.io.orc.OrcSerde")) - } else if ("PARQUET".equalsIgnoreCase(defaultStorageType)) { - tableDesc.copy( - inputFormat = - Option("org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat"), - outputFormat = - Option("org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat"), - serde = - Option("org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe")) - } else { - tableDesc.copy( - inputFormat = - Option("org.apache.hadoop.mapred.TextInputFormat"), - outputFormat = - Option("org.apache.hadoop.hive.ql.io.IgnoreKeyTextOutputFormat")) - } + // handle the default format for the storage type abbreviation + val hiveSerDe = HiveSerDe.sourceToSerDe(defaultStorageType, hiveConf).getOrElse { + HiveSerDe( + inputFormat = Option("org.apache.hadoop.mapred.TextInputFormat"), + outputFormat = Option("org.apache.hadoop.hive.ql.io.IgnoreKeyTextOutputFormat")) + } + + hiveSerDe.inputFormat.foreach(f => tableDesc = tableDesc.copy(inputFormat = Some(f))) + hiveSerDe.outputFormat.foreach(f => tableDesc = tableDesc.copy(outputFormat = Some(f))) + hiveSerDe.serde.foreach(f => tableDesc = tableDesc.copy(serde = Some(f))) children.collect { case list @ Token("TOK_TABCOLLIST", _) => @@ -908,7 +889,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C } (Nil, Some(BaseSemanticAnalyzer.unescapeSQLString(serdeClass)), serdeProps) - case Nil => (Nil, Option(hiveConf().getVar(ConfVars.HIVESCRIPTSERDE)), Nil) + case Nil => (Nil, Option(hiveConf.getVar(ConfVars.HIVESCRIPTSERDE)), Nil) } val (inRowFormat, inSerdeClass, inSerdeProps) = matchSerDe(inputSerdeClause) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala index 6fa599734892..4a310ff4e901 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala @@ -291,9 +291,11 @@ private[orc] case class OrcTableScan( // Sets requested columns addColumnIds(attributes, relation, conf) - if (inputPaths.nonEmpty) { - FileInputFormat.setInputPaths(job, inputPaths.map(_.getPath): _*) + if (inputPaths.isEmpty) { + // the input path probably be pruned, return an empty RDD. 
+ return sqlContext.sparkContext.emptyRDD[InternalRow] } + FileInputFormat.setInputPaths(job, inputPaths.map(_.getPath): _*) val inputFormatClass = classOf[OrcInputFormat] diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala index 983c013bcf86..332c3ec0c28b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala @@ -17,31 +17,142 @@ package org.apache.spark.sql.hive -import org.apache.spark.{Logging, SparkFunSuite} -import org.apache.spark.sql.hive.test.TestHive +import java.io.File -import org.apache.spark.sql.test.ExamplePointUDT +import org.apache.spark.sql.hive.client.{ExternalTable, HiveColumn, ManagedTable} +import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.hive.test.TestHive._ +import org.apache.spark.sql.hive.test.TestHive.implicits._ +import org.apache.spark.sql.sources.DataSourceTest +import org.apache.spark.sql.test.{ExamplePointUDT, SQLTestUtils} import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.{Row, SaveMode} +import org.apache.spark.{Logging, SparkFunSuite} + class HiveMetastoreCatalogSuite extends SparkFunSuite with Logging { test("struct field should accept underscore in sub-column name") { - val metastr = "struct" - - val datatype = HiveMetastoreTypes.toDataType(metastr) - assert(datatype.isInstanceOf[StructType]) + val hiveTypeStr = "struct" + val dateType = HiveMetastoreTypes.toDataType(hiveTypeStr) + assert(dateType.isInstanceOf[StructType]) } test("udt to metastore type conversion") { val udt = new ExamplePointUDT - assert(HiveMetastoreTypes.toMetastoreType(udt) === - HiveMetastoreTypes.toMetastoreType(udt.sqlType)) + assertResult(HiveMetastoreTypes.toMetastoreType(udt.sqlType)) { + HiveMetastoreTypes.toMetastoreType(udt) + } } test("duplicated metastore relations") { - import TestHive.implicits._ - val df = TestHive.sql("SELECT * FROM src") + val df = sql("SELECT * FROM src") logInfo(df.queryExecution.toString) df.as('a).join(df.as('b), $"a.key" === $"b.key") } } + +class DataSourceWithHiveMetastoreCatalogSuite extends DataSourceTest with SQLTestUtils { + override val sqlContext = TestHive + + private val testDF = (1 to 2).map(i => (i, s"val_$i")).toDF("d1", "d2").coalesce(1) + + Seq( + "parquet" -> ( + "org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat", + "org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat", + "org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe" + ), + + "orc" -> ( + "org.apache.hadoop.hive.ql.io.orc.OrcInputFormat", + "org.apache.hadoop.hive.ql.io.orc.OrcOutputFormat", + "org.apache.hadoop.hive.ql.io.orc.OrcSerde" + ) + ).foreach { case (provider, (inputFormat, outputFormat, serde)) => + test(s"Persist non-partitioned $provider relation into metastore as managed table") { + withTable("t") { + testDF + .write + .mode(SaveMode.Overwrite) + .format(provider) + .saveAsTable("t") + + val hiveTable = catalog.client.getTable("default", "t") + assert(hiveTable.inputFormat === Some(inputFormat)) + assert(hiveTable.outputFormat === Some(outputFormat)) + assert(hiveTable.serde === Some(serde)) + + assert(!hiveTable.isPartitioned) + assert(hiveTable.tableType === ManagedTable) + + val columns = hiveTable.schema + assert(columns.map(_.name) === Seq("d1", "d2")) + assert(columns.map(_.hiveType) === Seq("int", "string")) + + 
checkAnswer(table("t"), testDF) + assert(runSqlHive("SELECT * FROM t") === Seq("1\tval_1", "2\tval_2")) + } + } + + test(s"Persist non-partitioned $provider relation into metastore as external table") { + withTempPath { dir => + withTable("t") { + val path = dir.getCanonicalFile + + testDF + .write + .mode(SaveMode.Overwrite) + .format(provider) + .option("path", path.toString) + .saveAsTable("t") + + val hiveTable = catalog.client.getTable("default", "t") + assert(hiveTable.inputFormat === Some(inputFormat)) + assert(hiveTable.outputFormat === Some(outputFormat)) + assert(hiveTable.serde === Some(serde)) + + assert(hiveTable.tableType === ExternalTable) + assert(hiveTable.location.get === path.toURI.toString.stripSuffix(File.separator)) + + val columns = hiveTable.schema + assert(columns.map(_.name) === Seq("d1", "d2")) + assert(columns.map(_.hiveType) === Seq("int", "string")) + + checkAnswer(table("t"), testDF) + assert(runSqlHive("SELECT * FROM t") === Seq("1\tval_1", "2\tval_2")) + } + } + } + + test(s"Persist non-partitioned $provider relation into metastore as managed table using CTAS") { + withTempPath { dir => + withTable("t") { + val path = dir.getCanonicalPath + + sql( + s"""CREATE TABLE t USING $provider + |OPTIONS (path '$path') + |AS SELECT 1 AS d1, "val_1" AS d2 + """.stripMargin) + + val hiveTable = catalog.client.getTable("default", "t") + assert(hiveTable.inputFormat === Some(inputFormat)) + assert(hiveTable.outputFormat === Some(outputFormat)) + assert(hiveTable.serde === Some(serde)) + + assert(hiveTable.isPartitioned === false) + assert(hiveTable.tableType === ExternalTable) + assert(hiveTable.partitionColumns.length === 0) + + val columns = hiveTable.schema + assert(columns.map(_.name) === Seq("d1", "d2")) + assert(columns.map(_.hiveType) === Seq("int", "string")) + + checkAnswer(table("t"), Row(1, "val_1")) + assert(runSqlHive("SELECT * FROM t") === Seq("1\tval_1")) + } + } + } + } +} From 9270bd06fd0b16892e3f37213b5bc7813ea11fdd Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 5 Aug 2015 21:50:14 -0700 Subject: [PATCH 0076/1824] [SPARK-9674][SQL] Remove GeneratedAggregate. The new aggregate replaces the old GeneratedAggregate. Author: Reynold Xin Closes #7983 from rxin/remove-generated-agg and squashes the following commits: 8334aae [Reynold Xin] [SPARK-9674][SQL] Remove GeneratedAggregate. --- .../sql/execution/GeneratedAggregate.scala | 352 ------------------ .../spark/sql/execution/SparkStrategies.scala | 34 -- .../org/apache/spark/sql/SQLQuerySuite.scala | 5 +- .../spark/sql/execution/AggregateSuite.scala | 48 --- 4 files changed, 2 insertions(+), 437 deletions(-) delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/AggregateSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala deleted file mode 100644 index bf4905dc1eef..000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala +++ /dev/null @@ -1,352 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. 
- * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution - -import java.io.IOException - -import org.apache.spark.{InternalAccumulator, SparkEnv, TaskContext} -import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.catalyst.trees._ -import org.apache.spark.sql.types._ - -case class AggregateEvaluation( - schema: Seq[Attribute], - initialValues: Seq[Expression], - update: Seq[Expression], - result: Expression) - -/** - * :: DeveloperApi :: - * Alternate version of aggregation that leverages projection and thus code generation. - * Aggregations are converted into a set of projections from a aggregation buffer tuple back onto - * itself. Currently only used for simple aggregations like SUM, COUNT, or AVERAGE are supported. - * - * @param partial if true then aggregation is done partially on local data without shuffling to - * ensure all values where `groupingExpressions` are equal are present. - * @param groupingExpressions expressions that are evaluated to determine grouping. - * @param aggregateExpressions expressions that are computed for each group. - * @param unsafeEnabled whether to allow Unsafe-based aggregation buffers to be used. - * @param child the input data source. - */ -@DeveloperApi -case class GeneratedAggregate( - partial: Boolean, - groupingExpressions: Seq[Expression], - aggregateExpressions: Seq[NamedExpression], - unsafeEnabled: Boolean, - child: SparkPlan) - extends UnaryNode { - - override def requiredChildDistribution: Seq[Distribution] = - if (partial) { - UnspecifiedDistribution :: Nil - } else { - if (groupingExpressions == Nil) { - AllTuples :: Nil - } else { - ClusteredDistribution(groupingExpressions) :: Nil - } - } - - override def output: Seq[Attribute] = aggregateExpressions.map(_.toAttribute) - - protected override def doExecute(): RDD[InternalRow] = { - val aggregatesToCompute = aggregateExpressions.flatMap { a => - a.collect { case agg: AggregateExpression1 => agg} - } - - // If you add any new function support, please add tests in org.apache.spark.sql.SQLQuerySuite - // (in test "aggregation with codegen"). 
- val computeFunctions = aggregatesToCompute.map { - case c @ Count(expr) => - // If we're evaluating UnscaledValue(x), we can do Count on x directly, since its - // UnscaledValue will be null if and only if x is null; helps with Average on decimals - val toCount = expr match { - case UnscaledValue(e) => e - case _ => expr - } - val currentCount = AttributeReference("currentCount", LongType, nullable = false)() - val initialValue = Literal(0L) - val updateFunction = If(IsNotNull(toCount), Add(currentCount, Literal(1L)), currentCount) - val result = currentCount - - AggregateEvaluation(currentCount :: Nil, initialValue :: Nil, updateFunction :: Nil, result) - - case s @ Sum(expr) => - val calcType = - expr.dataType match { - case DecimalType.Fixed(p, s) => - DecimalType.bounded(p + 10, s) - case _ => - expr.dataType - } - - val currentSum = AttributeReference("currentSum", calcType, nullable = true)() - val initialValue = Literal.create(null, calcType) - - // Coalesce avoids double calculation... - // but really, common sub expression elimination would be better.... - val zero = Cast(Literal(0), calcType) - val updateFunction = Coalesce( - Add( - Coalesce(currentSum :: zero :: Nil), - Cast(expr, calcType) - ) :: currentSum :: Nil) - val result = - expr.dataType match { - case DecimalType.Fixed(_, _) => - Cast(currentSum, s.dataType) - case _ => currentSum - } - - AggregateEvaluation(currentSum :: Nil, initialValue :: Nil, updateFunction :: Nil, result) - - case m @ Max(expr) => - val currentMax = AttributeReference("currentMax", expr.dataType, nullable = true)() - val initialValue = Literal.create(null, expr.dataType) - val updateMax = MaxOf(currentMax, expr) - - AggregateEvaluation( - currentMax :: Nil, - initialValue :: Nil, - updateMax :: Nil, - currentMax) - - case m @ Min(expr) => - val currentMin = AttributeReference("currentMin", expr.dataType, nullable = true)() - val initialValue = Literal.create(null, expr.dataType) - val updateMin = MinOf(currentMin, expr) - - AggregateEvaluation( - currentMin :: Nil, - initialValue :: Nil, - updateMin :: Nil, - currentMin) - - case CollectHashSet(Seq(expr)) => - val set = - AttributeReference("hashSet", new OpenHashSetUDT(expr.dataType), nullable = false)() - val initialValue = NewSet(expr.dataType) - val addToSet = AddItemToSet(expr, set) - - AggregateEvaluation( - set :: Nil, - initialValue :: Nil, - addToSet :: Nil, - set) - - case CombineSetsAndCount(inputSet) => - val elementType = inputSet.dataType.asInstanceOf[OpenHashSetUDT].elementType - val set = - AttributeReference("hashSet", new OpenHashSetUDT(elementType), nullable = false)() - val initialValue = NewSet(elementType) - val collectSets = CombineSets(set, inputSet) - - AggregateEvaluation( - set :: Nil, - initialValue :: Nil, - collectSets :: Nil, - CountSet(set)) - - case o => sys.error(s"$o can't be codegened.") - } - - val computationSchema = computeFunctions.flatMap(_.schema) - - val resultMap: Map[TreeNodeRef, Expression] = - aggregatesToCompute.zip(computeFunctions).map { - case (agg, func) => new TreeNodeRef(agg) -> func.result - }.toMap - - val namedGroups = groupingExpressions.zipWithIndex.map { - case (ne: NamedExpression, _) => (ne, ne.toAttribute) - case (e, i) => (e, Alias(e, s"GroupingExpr$i")().toAttribute) - } - - // The set of expressions that produce the final output given the aggregation buffer and the - // grouping expressions. 
- val resultExpressions = aggregateExpressions.map(_.transform { - case e: Expression if resultMap.contains(new TreeNodeRef(e)) => resultMap(new TreeNodeRef(e)) - case e: Expression => - namedGroups.collectFirst { - case (expr, attr) if expr semanticEquals e => attr - }.getOrElse(e) - }) - - val aggregationBufferSchema: StructType = StructType.fromAttributes(computationSchema) - - val groupKeySchema: StructType = { - val fields = groupingExpressions.zipWithIndex.map { case (expr, idx) => - // This is a dummy field name - StructField(idx.toString, expr.dataType, expr.nullable) - } - StructType(fields) - } - - val schemaSupportsUnsafe: Boolean = { - UnsafeFixedWidthAggregationMap.supportsAggregationBufferSchema(aggregationBufferSchema) && - UnsafeProjection.canSupport(groupKeySchema) - } - - child.execute().mapPartitions { iter => - // Builds a new custom class for holding the results of aggregation for a group. - val initialValues = computeFunctions.flatMap(_.initialValues) - val newAggregationBuffer = newProjection(initialValues, child.output) - log.info(s"Initial values: ${initialValues.mkString(",")}") - - // A projection that computes the group given an input tuple. - val groupProjection = newProjection(groupingExpressions, child.output) - log.info(s"Grouping Projection: ${groupingExpressions.mkString(",")}") - - // A projection that is used to update the aggregate values for a group given a new tuple. - // This projection should be targeted at the current values for the group and then applied - // to a joined row of the current values with the new input row. - val updateExpressions = computeFunctions.flatMap(_.update) - val updateSchema = computeFunctions.flatMap(_.schema) ++ child.output - val updateProjection = newMutableProjection(updateExpressions, updateSchema)() - log.info(s"Update Expressions: ${updateExpressions.mkString(",")}") - - // A projection that produces the final result, given a computation. - val resultProjectionBuilder = - newMutableProjection( - resultExpressions, - namedGroups.map(_._2) ++ computationSchema) - log.info(s"Result Projection: ${resultExpressions.mkString(",")}") - - val joinedRow = new JoinedRow - - if (!iter.hasNext) { - // This is an empty input, so return early so that we do not allocate data structures - // that won't be cleaned up (see SPARK-8357). - if (groupingExpressions.isEmpty) { - // This is a global aggregate, so return an empty aggregation buffer. - val resultProjection = resultProjectionBuilder() - Iterator(resultProjection(newAggregationBuffer(EmptyRow))) - } else { - // This is a grouped aggregate, so return an empty iterator. - Iterator[InternalRow]() - } - } else if (groupingExpressions.isEmpty) { - // TODO: Codegening anything other than the updateProjection is probably over kill. 
- val buffer = newAggregationBuffer(EmptyRow).asInstanceOf[MutableRow] - var currentRow: InternalRow = null - updateProjection.target(buffer) - - while (iter.hasNext) { - currentRow = iter.next() - updateProjection(joinedRow(buffer, currentRow)) - } - - val resultProjection = resultProjectionBuilder() - Iterator(resultProjection(buffer)) - - } else if (unsafeEnabled && schemaSupportsUnsafe) { - assert(iter.hasNext, "There should be at least one row for this path") - log.info("Using Unsafe-based aggregator") - val pageSizeBytes = SparkEnv.get.conf.getSizeAsBytes("spark.buffer.pageSize", "64m") - val taskContext = TaskContext.get() - val aggregationMap = new UnsafeFixedWidthAggregationMap( - newAggregationBuffer(EmptyRow), - aggregationBufferSchema, - groupKeySchema, - taskContext.taskMemoryManager(), - SparkEnv.get.shuffleMemoryManager, - 1024 * 16, // initial capacity - pageSizeBytes, - false // disable tracking of performance metrics - ) - - while (iter.hasNext) { - val currentRow: InternalRow = iter.next() - val groupKey: InternalRow = groupProjection(currentRow) - val aggregationBuffer = aggregationMap.getAggregationBuffer(groupKey) - if (aggregationBuffer == null) { - throw new IOException("Could not allocate memory to grow aggregation buffer") - } - updateProjection.target(aggregationBuffer)(joinedRow(aggregationBuffer, currentRow)) - } - - // Record memory used in the process - taskContext.internalMetricsToAccumulators( - InternalAccumulator.PEAK_EXECUTION_MEMORY).add(aggregationMap.getMemoryUsage) - - new Iterator[InternalRow] { - private[this] val mapIterator = aggregationMap.iterator() - private[this] val resultProjection = resultProjectionBuilder() - private[this] var _hasNext = mapIterator.next() - - def hasNext: Boolean = _hasNext - - def next(): InternalRow = { - if (_hasNext) { - val result = resultProjection(joinedRow(mapIterator.getKey, mapIterator.getValue)) - _hasNext = mapIterator.next() - if (_hasNext) { - result - } else { - // This is the last element in the iterator, so let's free the buffer. Before we do, - // though, we need to make a defensive copy of the result so that we don't return an - // object that might contain dangling pointers to the freed memory. - val resultCopy = result.copy() - aggregationMap.free() - resultCopy - } - } else { - throw new java.util.NoSuchElementException - } - } - } - } else { - if (unsafeEnabled) { - log.info("Not using Unsafe-based aggregator because it is not supported for this schema") - } - val buffers = new java.util.HashMap[InternalRow, MutableRow]() - - var currentRow: InternalRow = null - while (iter.hasNext) { - currentRow = iter.next() - val currentGroup = groupProjection(currentRow) - var currentBuffer = buffers.get(currentGroup) - if (currentBuffer == null) { - currentBuffer = newAggregationBuffer(EmptyRow).asInstanceOf[MutableRow] - buffers.put(currentGroup, currentBuffer) - } - // Target the projection at the current aggregation buffer and then project the updated - // values. 
- updateProjection.target(currentBuffer)(joinedRow(currentBuffer, currentRow)) - } - - new Iterator[InternalRow] { - private[this] val resultIterator = buffers.entrySet.iterator() - private[this] val resultProjection = resultProjectionBuilder() - - def hasNext: Boolean = resultIterator.hasNext - - def next(): InternalRow = { - val currentGroup = resultIterator.next() - resultProjection(joinedRow(currentGroup.getKey, currentGroup.getValue)) - } - } - } - } - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 952ba7d45c13..a730ffbb217c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -136,32 +136,6 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { object HashAggregation extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { // Aggregations that can be performed in two phases, before and after the shuffle. - - // Cases where all aggregates can be codegened. - case PartialAggregation( - namedGroupingAttributes, - rewrittenAggregateExpressions, - groupingExpressions, - partialComputation, - child) - if canBeCodeGened( - allAggregates(partialComputation) ++ - allAggregates(rewrittenAggregateExpressions)) && - codegenEnabled && - !canBeConvertedToNewAggregation(plan) => - execution.GeneratedAggregate( - partial = false, - namedGroupingAttributes, - rewrittenAggregateExpressions, - unsafeEnabled, - execution.GeneratedAggregate( - partial = true, - groupingExpressions, - partialComputation, - unsafeEnabled, - planLater(child))) :: Nil - - // Cases where some aggregate can not be codegened case PartialAggregation( namedGroupingAttributes, rewrittenAggregateExpressions, @@ -192,14 +166,6 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case _ => false } - def canBeCodeGened(aggs: Seq[AggregateExpression1]): Boolean = aggs.forall { - case _: Sum | _: Count | _: Max | _: Min | _: CombineSetsAndCount => true - // The generated set implementation is pretty limited ATM. - case CollectHashSet(exprs) if exprs.size == 1 && - Seq(IntegerType, LongType).contains(exprs.head.dataType) => true - case _ => false - } - def allAggregates(exprs: Seq[Expression]): Seq[AggregateExpression1] = exprs.flatMap(_.collect { case a: AggregateExpression1 => a }) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 29dfcf257522..cef40dd324d9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -26,7 +26,6 @@ import org.apache.spark.sql.catalyst.analysis.FunctionRegistry import org.apache.spark.sql.catalyst.DefaultParserDialect import org.apache.spark.sql.catalyst.errors.DialectException import org.apache.spark.sql.execution.aggregate -import org.apache.spark.sql.execution.GeneratedAggregate import org.apache.spark.sql.functions._ import org.apache.spark.sql.TestData._ import org.apache.spark.sql.test.SQLTestUtils @@ -263,7 +262,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { val df = sql(sqlText) // First, check if we have GeneratedAggregate. 
val hasGeneratedAgg = df.queryExecution.executedPlan - .collect { case _: GeneratedAggregate | _: aggregate.Aggregate => true } + .collect { case _: aggregate.Aggregate => true } .nonEmpty if (!hasGeneratedAgg) { fail( @@ -1603,7 +1602,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { Row(new CalendarInterval(-(12 * 3 - 3), -(7L * MICROS_PER_WEEK + 123)))) } - test("aggregation with codegen updates peak execution memory") { + ignore("aggregation with codegen updates peak execution memory") { withSQLConf( (SQLConf.CODEGEN_ENABLED.key, "true"), (SQLConf.USE_SQL_AGGREGATE2.key, "false")) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/AggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/AggregateSuite.scala deleted file mode 100644 index 20def6bef0c1..000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/AggregateSuite.scala +++ /dev/null @@ -1,48 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution - -import org.apache.spark.sql.SQLConf -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.test.TestSQLContext - -class AggregateSuite extends SparkPlanTest { - - test("SPARK-8357 unsafe aggregation path should not leak memory with empty input") { - val codegenDefault = TestSQLContext.getConf(SQLConf.CODEGEN_ENABLED) - val unsafeDefault = TestSQLContext.getConf(SQLConf.UNSAFE_ENABLED) - try { - TestSQLContext.setConf(SQLConf.CODEGEN_ENABLED, true) - TestSQLContext.setConf(SQLConf.UNSAFE_ENABLED, true) - val df = Seq.empty[(Int, Int)].toDF("a", "b") - checkAnswer( - df, - GeneratedAggregate( - partial = true, - Seq(df.col("b").expr), - Seq(Alias(Count(df.col("a").expr), "cnt")()), - unsafeEnabled = true, - _: SparkPlan), - Seq.empty - ) - } finally { - TestSQLContext.setConf(SQLConf.CODEGEN_ENABLED, codegenDefault) - TestSQLContext.setConf(SQLConf.UNSAFE_ENABLED, unsafeDefault) - } - } -} From d5a9af3230925c347d0904fe7f2402e468e80bc8 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Wed, 5 Aug 2015 21:50:35 -0700 Subject: [PATCH 0077/1824] [SPARK-9664] [SQL] Remove UDAFRegistration and add apply to UserDefinedAggregateFunction. https://issues.apache.org/jira/browse/SPARK-9664 Author: Yin Huai Closes #7982 from yhuai/udafRegister and squashes the following commits: 0cc2287 [Yin Huai] Remove UDAFRegistration and add apply to UserDefinedAggregateFunction. 
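A short sketch of how the reworked API reads from user code; `MyDoubleSum`, the `value` column, and `sqlContext`/`df` are borrowed from the test changes later in this patch, so treat this as an illustration rather than a verbatim excerpt:

```scala
import org.apache.spark.sql.functions.{callUDF, col}

// Registration now goes through udf.register (sqlContext.udaf is gone), and the
// UDAF object itself can be applied to Columns directly via the new apply methods.
val myDoubleSum = sqlContext.udf.register("mydoublesum", new MyDoubleSum)

val byName   = df.groupBy().agg(callUDF("mydoublesum", col("value")))  // by registered name
val byObject = df.groupBy().agg(myDoubleSum(col("value")))             // apply(exprs: Column*)
val distinct = df.groupBy().agg(myDoubleSum(true, col("value")))       // apply(isDistinct, exprs: Column*)
```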
--- .../org/apache/spark/sql/SQLContext.scala | 3 -- .../apache/spark/sql/UDAFRegistration.scala | 36 ------------------- .../apache/spark/sql/UDFRegistration.scala | 16 +++++++++ .../spark/sql/execution/aggregate/udaf.scala | 8 ++--- .../apache/spark/sql/expressions/udaf.scala | 32 ++++++++++++++++- .../org/apache/spark/sql/functions.scala | 1 + .../spark/sql/hive/JavaDataFrameSuite.java | 26 ++++++++++++++ .../execution/AggregationQuerySuite.scala | 4 +-- 8 files changed, 80 insertions(+), 46 deletions(-) delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/UDAFRegistration.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index ffc2baf7a882..6f8ffb54402a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -291,9 +291,6 @@ class SQLContext(@transient val sparkContext: SparkContext) @transient val udf: UDFRegistration = new UDFRegistration(this) - @transient - val udaf: UDAFRegistration = new UDAFRegistration(this) - /** * Returns true if the table is currently cached in-memory. * @group cachemgmt diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDAFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDAFRegistration.scala deleted file mode 100644 index 0d4e30f29255..000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/UDAFRegistration.scala +++ /dev/null @@ -1,36 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql - -import org.apache.spark.Logging -import org.apache.spark.sql.catalyst.expressions.{Expression} -import org.apache.spark.sql.execution.aggregate.ScalaUDAF -import org.apache.spark.sql.expressions.UserDefinedAggregateFunction - -class UDAFRegistration private[sql] (sqlContext: SQLContext) extends Logging { - - private val functionRegistry = sqlContext.functionRegistry - - def register( - name: String, - func: UserDefinedAggregateFunction): UserDefinedAggregateFunction = { - def builder(children: Seq[Expression]) = ScalaUDAF(children, func) - functionRegistry.registerFunction(name, builder) - func - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala index 7cd7421a518c..1f270560d7bc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala @@ -26,6 +26,8 @@ import org.apache.spark.Logging import org.apache.spark.sql.api.java._ import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUDF} +import org.apache.spark.sql.execution.aggregate.ScalaUDAF +import org.apache.spark.sql.expressions.UserDefinedAggregateFunction import org.apache.spark.sql.types.DataType /** @@ -52,6 +54,20 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { functionRegistry.registerFunction(name, udf.builder) } + /** + * Register a user-defined aggregate function (UDAF). + * @param name the name of the UDAF. + * @param udaf the UDAF needs to be registered. + * @return the registered UDAF. + */ + def register( + name: String, + udaf: UserDefinedAggregateFunction): UserDefinedAggregateFunction = { + def builder(children: Seq[Expression]) = ScalaUDAF(children, udaf) + functionRegistry.registerFunction(name, builder) + udaf + } + // scalastyle:off /* register 0-22 were generated by this script diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala index 5fafc916bfa0..7619f3ec9f0a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala @@ -316,7 +316,7 @@ private[sql] case class ScalaUDAF( override lazy val cloneBufferAttributes = bufferAttributes.map(_.newInstance()) - private[this] val childrenSchema: StructType = { + private[this] lazy val childrenSchema: StructType = { val inputFields = children.zipWithIndex.map { case (child, index) => StructField(s"input$index", child.dataType, child.nullable, Metadata.empty) @@ -337,16 +337,16 @@ private[sql] case class ScalaUDAF( } } - private[this] val inputToScalaConverters: Any => Any = + private[this] lazy val inputToScalaConverters: Any => Any = CatalystTypeConverters.createToScalaConverter(childrenSchema) - private[this] val bufferValuesToCatalystConverters: Array[Any => Any] = { + private[this] lazy val bufferValuesToCatalystConverters: Array[Any => Any] = { bufferSchema.fields.map { field => CatalystTypeConverters.createToCatalystConverter(field.dataType) } } - private[this] val bufferValuesToScalaConverters: Array[Any => Any] = { + private[this] lazy val bufferValuesToScalaConverters: Array[Any => Any] = { bufferSchema.fields.map { field => CatalystTypeConverters.createToScalaConverter(field.dataType) } diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala index 278dd438fab4..5180871585f2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala @@ -17,7 +17,10 @@ package org.apache.spark.sql.expressions -import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.expressions.ScalaUDF +import org.apache.spark.sql.catalyst.expressions.aggregate.{Complete, AggregateExpression2} +import org.apache.spark.sql.execution.aggregate.ScalaUDAF +import org.apache.spark.sql.{Column, Row} import org.apache.spark.sql.types._ import org.apache.spark.annotation.Experimental @@ -87,6 +90,33 @@ abstract class UserDefinedAggregateFunction extends Serializable { * aggregation buffer. */ def evaluate(buffer: Row): Any + + /** + * Creates a [[Column]] for this UDAF with given [[Column]]s as arguments. + */ + @scala.annotation.varargs + def apply(exprs: Column*): Column = { + val aggregateExpression = + AggregateExpression2( + ScalaUDAF(exprs.map(_.expr), this), + Complete, + isDistinct = false) + Column(aggregateExpression) + } + + /** + * Creates a [[Column]] for this UDAF with given [[Column]]s as arguments. + * If `isDistinct` is true, this UDAF is working on distinct input values. + */ + @scala.annotation.varargs + def apply(isDistinct: Boolean, exprs: Column*): Column = { + val aggregateExpression = + AggregateExpression2( + ScalaUDAF(exprs.map(_.expr), this), + Complete, + isDistinct = isDistinct) + Column(aggregateExpression) + } } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 5a10c3891ad6..39aa905c8532 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -2500,6 +2500,7 @@ object functions { * @group udf_funcs * @since 1.5.0 */ + @scala.annotation.varargs def callUDF(udfName: String, cols: Column*): Column = { UnresolvedFunction(udfName, cols.map(_.expr), isDistinct = false) } diff --git a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaDataFrameSuite.java b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaDataFrameSuite.java index 613b2bcc80e3..21b053f07a3b 100644 --- a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaDataFrameSuite.java +++ b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaDataFrameSuite.java @@ -29,8 +29,12 @@ import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.sql.*; import org.apache.spark.sql.expressions.Window; +import org.apache.spark.sql.expressions.UserDefinedAggregateFunction; +import static org.apache.spark.sql.functions.*; import org.apache.spark.sql.hive.HiveContext; import org.apache.spark.sql.hive.test.TestHive$; +import org.apache.spark.sql.expressions.UserDefinedAggregateFunction; +import test.org.apache.spark.sql.hive.aggregate.MyDoubleSum; public class JavaDataFrameSuite { private transient JavaSparkContext sc; @@ -77,4 +81,26 @@ public void saveTableAndQueryIt() { " ROWS BETWEEN 1 preceding and 1 following) " + "FROM window_table").collectAsList()); } + + @Test + public void testUDAF() { + DataFrame df = hc.range(0, 100).unionAll(hc.range(0, 100)).select(col("id").as("value")); + UserDefinedAggregateFunction udaf = new MyDoubleSum(); + UserDefinedAggregateFunction registeredUDAF = hc.udf().register("mydoublesum", udaf); + // 
Create Columns for the UDAF. For now, callUDF does not take an argument to specific if + // we want to use distinct aggregation. + DataFrame aggregatedDF = + df.groupBy() + .agg( + udaf.apply(true, col("value")), + udaf.apply(col("value")), + registeredUDAF.apply(col("value")), + callUDF("mydoublesum", col("value"))); + + List expectedResult = new ArrayList(); + expectedResult.add(RowFactory.create(4950.0, 9900.0, 9900.0, 9900.0)); + checkAnswer( + aggregatedDF, + expectedResult); + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala index 6f0db27775e4..4b35c8fd8353 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala @@ -73,8 +73,8 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Be emptyDF.registerTempTable("emptyTable") // Register UDAFs - sqlContext.udaf.register("mydoublesum", new MyDoubleSum) - sqlContext.udaf.register("mydoubleavg", new MyDoubleAvg) + sqlContext.udf.register("mydoublesum", new MyDoubleSum) + sqlContext.udf.register("mydoubleavg", new MyDoubleAvg) } override def afterAll(): Unit = { From aead18ffca36830e854fba32a1cac11a0b2e31d5 Mon Sep 17 00:00:00 2001 From: "zhichao.li" Date: Thu, 6 Aug 2015 09:02:30 -0700 Subject: [PATCH 0078/1824] [SPARK-8266] [SQL] add function translate ![translate](http://www.w3resource.com/PostgreSQL/postgresql-translate-function.png) Author: zhichao.li Closes #7709 from zhichao-li/translate and squashes the following commits: 9418088 [zhichao.li] refine checking condition f2ab77a [zhichao.li] clone string 9d88f2d [zhichao.li] fix indent 6aa2962 [zhichao.li] style e575ead [zhichao.li] add python api 9d4bab0 [zhichao.li] add special case for fodable and refactor unittest eda7ad6 [zhichao.li] update to use TernaryExpression cdfd4be [zhichao.li] add function translate --- python/pyspark/sql/functions.py | 16 ++++ .../catalyst/analysis/FunctionRegistry.scala | 1 + .../sql/catalyst/expressions/Expression.scala | 4 +- .../expressions/stringOperations.scala | 79 ++++++++++++++++++- .../expressions/StringExpressionsSuite.scala | 14 ++++ .../org/apache/spark/sql/functions.scala | 21 +++-- .../spark/sql/StringFunctionsSuite.scala | 6 ++ .../apache/spark/unsafe/types/UTF8String.java | 16 ++++ .../spark/unsafe/types/UTF8StringSuite.java | 31 ++++++++ 9 files changed, 180 insertions(+), 8 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 9f0d71d7960c..b5c6a01f1885 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1290,6 +1290,22 @@ def length(col): return Column(sc._jvm.functions.length(_to_java_column(col))) +@ignore_unicode_prefix +@since(1.5) +def translate(srcCol, matching, replace): + """A function translate any character in the `srcCol` by a character in `matching`. + The characters in `replace` is corresponding to the characters in `matching`. + The translate will happen when any character in the string matching with the character + in the `matching`. 
+ + >>> sqlContext.createDataFrame([('translate',)], ['a']).select(translate('a', "rnlt", "123")\ + .alias('r')).collect() + [Row(r=u'1a2s3ae')] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.translate(_to_java_column(srcCol), matching, replace)) + + # ---------------------- Collection functions ------------------------------ @since(1.4) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 94c355f838fa..cd5a90d78815 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -203,6 +203,7 @@ object FunctionRegistry { expression[Substring]("substr"), expression[Substring]("substring"), expression[SubstringIndex]("substring_index"), + expression[StringTranslate]("translate"), expression[StringTrim]("trim"), expression[UnBase64]("unbase64"), expression[Upper]("ucase"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index ef2fc2e8c29d..0b98f555a1d6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -444,7 +444,7 @@ abstract class TernaryExpression extends Expression { override def nullable: Boolean = children.exists(_.nullable) /** - * Default behavior of evaluation according to the default nullability of BinaryExpression. + * Default behavior of evaluation according to the default nullability of TernaryExpression. * If subclass of BinaryExpression override nullable, probably should also override this. */ override def eval(input: InternalRow): Any = { @@ -463,7 +463,7 @@ abstract class TernaryExpression extends Expression { } /** - * Called by default [[eval]] implementation. If subclass of BinaryExpression keep the default + * Called by default [[eval]] implementation. If subclass of TernaryExpression keep the default * nullability, they can override this method to save null-check code. If we need full control * of evaluation process, we should override [[eval]]. 
*/ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index 0cc785d9f3a4..76666bd6b3d2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -18,7 +18,9 @@ package org.apache.spark.sql.catalyst.expressions import java.text.DecimalFormat -import java.util.{Arrays, Locale} +import java.util.Arrays +import java.util.{Map => JMap, HashMap} +import java.util.Locale import java.util.regex.{MatchResult, Pattern} import org.apache.commons.lang3.StringEscapeUtils @@ -349,6 +351,81 @@ case class EndsWith(left: Expression, right: Expression) } } +object StringTranslate { + + def buildDict(matchingString: UTF8String, replaceString: UTF8String) + : JMap[Character, Character] = { + val matching = matchingString.toString() + val replace = replaceString.toString() + val dict = new HashMap[Character, Character]() + var i = 0 + while (i < matching.length()) { + val rep = if (i < replace.length()) replace.charAt(i) else '\0' + if (null == dict.get(matching.charAt(i))) { + dict.put(matching.charAt(i), rep) + } + i += 1 + } + dict + } +} + +/** + * A function translate any character in the `srcExpr` by a character in `replaceExpr`. + * The characters in `replaceExpr` is corresponding to the characters in `matchingExpr`. + * The translate will happen when any character in the string matching with the character + * in the `matchingExpr`. + */ +case class StringTranslate(srcExpr: Expression, matchingExpr: Expression, replaceExpr: Expression) + extends TernaryExpression with ImplicitCastInputTypes { + + @transient private var lastMatching: UTF8String = _ + @transient private var lastReplace: UTF8String = _ + @transient private var dict: JMap[Character, Character] = _ + + override def nullSafeEval(srcEval: Any, matchingEval: Any, replaceEval: Any): Any = { + if (matchingEval != lastMatching || replaceEval != lastReplace) { + lastMatching = matchingEval.asInstanceOf[UTF8String].clone() + lastReplace = replaceEval.asInstanceOf[UTF8String].clone() + dict = StringTranslate.buildDict(lastMatching, lastReplace) + } + srcEval.asInstanceOf[UTF8String].translate(dict) + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val termLastMatching = ctx.freshName("lastMatching") + val termLastReplace = ctx.freshName("lastReplace") + val termDict = ctx.freshName("dict") + val classNameDict = classOf[JMap[Character, Character]].getCanonicalName + + ctx.addMutableState("UTF8String", termLastMatching, s"${termLastMatching} = null;") + ctx.addMutableState("UTF8String", termLastReplace, s"${termLastReplace} = null;") + ctx.addMutableState(classNameDict, termDict, s"${termDict} = null;") + + nullSafeCodeGen(ctx, ev, (src, matching, replace) => { + val check = if (matchingExpr.foldable && replaceExpr.foldable) { + s"${termDict} == null" + } else { + s"!${matching}.equals(${termLastMatching}) || !${replace}.equals(${termLastReplace})" + } + s"""if ($check) { + // Not all of them is literal or matching or replace value changed + ${termLastMatching} = ${matching}.clone(); + ${termLastReplace} = ${replace}.clone(); + ${termDict} = org.apache.spark.sql.catalyst.expressions.StringTranslate + .buildDict(${termLastMatching}, ${termLastReplace}); + } + ${ev.primitive} = ${src}.translate(${termDict}); + """ + 
}) + } + + override def dataType: DataType = StringType + override def inputTypes: Seq[DataType] = Seq(StringType, StringType, StringType) + override def children: Seq[Expression] = srcExpr :: matchingExpr :: replaceExpr :: Nil + override def prettyName: String = "translate" +} + /** * A function that returns the index (1-based) of the given string (left) in the comma- * delimited list (right). Returns 0, if the string wasn't found or if the given diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala index 23f36ca43d66..426dc272471a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala @@ -431,6 +431,20 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(SoundEx(Literal("!!")), "!!") } + test("translate") { + checkEvaluation( + StringTranslate(Literal("translate"), Literal("rnlt"), Literal("123")), "1a2s3ae") + checkEvaluation(StringTranslate(Literal("translate"), Literal(""), Literal("123")), "translate") + checkEvaluation(StringTranslate(Literal("translate"), Literal("rnlt"), Literal("")), "asae") + // test for multiple mapping + checkEvaluation(StringTranslate(Literal("abcd"), Literal("aba"), Literal("123")), "12cd") + checkEvaluation(StringTranslate(Literal("abcd"), Literal("aba"), Literal("12")), "12cd") + // scalastyle:off + // non ascii characters are not allowed in the source code, so we disable the scalastyle. + checkEvaluation(StringTranslate(Literal("花花世界"), Literal("花界"), Literal("ab")), "aa世b") + // scalastyle:on + } + test("TRIM/LTRIM/RTRIM") { val s = 'a.string.at(0) checkEvaluation(StringTrim(Literal(" aa ")), "aa", create_row(" abdef ")) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 39aa905c8532..79c5f596661d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1100,11 +1100,11 @@ object functions { } /** - * Computes hex value of the given column. - * - * @group math_funcs - * @since 1.5.0 - */ + * Computes hex value of the given column. + * + * @group math_funcs + * @since 1.5.0 + */ def hex(column: Column): Column = Hex(column.expr) /** @@ -1863,6 +1863,17 @@ object functions { def substring_index(str: Column, delim: String, count: Int): Column = SubstringIndex(str.expr, lit(delim).expr, lit(count).expr) + /* Translate any character in the src by a character in replaceString. + * The characters in replaceString is corresponding to the characters in matchingString. + * The translate will happen when any character in the string matching with the character + * in the matchingString. + * + * @group string_funcs + * @since 1.5.0 + */ + def translate(src: Column, matchingString: String, replaceString: String): Column = + StringTranslate(src.expr, lit(matchingString).expr, lit(replaceString).expr) + /** * Trim the spaces from both ends for the specified string column. 
* diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala index ab5da6ee79f1..ca298b243441 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala @@ -128,6 +128,12 @@ class StringFunctionsSuite extends QueryTest { // scalastyle:on } + test("string translate") { + val df = Seq(("translate", "")).toDF("a", "b") + checkAnswer(df.select(translate($"a", "rnlt", "123")), Row("1a2s3ae")) + checkAnswer(df.selectExpr("""translate(a, "rnlt", "")"""), Row("asae")) + } + test("string trim functions") { val df = Seq((" example ", "")).toDF("a", "b") diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index febbe3d4e54d..d1014426c0f4 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -22,6 +22,7 @@ import java.io.UnsupportedEncodingException; import java.nio.ByteOrder; import java.util.Arrays; +import java.util.Map; import org.apache.spark.unsafe.PlatformDependent; import org.apache.spark.unsafe.array.ByteArrayMethods; @@ -795,6 +796,21 @@ public UTF8String[] split(UTF8String pattern, int limit) { return res; } + // TODO: Need to use `Code Point` here instead of Char in case the character longer than 2 bytes + public UTF8String translate(Map dict) { + String srcStr = this.toString(); + + StringBuilder sb = new StringBuilder(); + for(int k = 0; k< srcStr.length(); k++) { + if (null == dict.get(srcStr.charAt(k))) { + sb.append(srcStr.charAt(k)); + } else if ('\0' != dict.get(srcStr.charAt(k))){ + sb.append(dict.get(srcStr.charAt(k))); + } + } + return fromString(sb.toString()); + } + @Override public String toString() { try { diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java index b30c94c1c1f8..98aa8a2469a7 100644 --- a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java +++ b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java @@ -19,7 +19,9 @@ import java.io.UnsupportedEncodingException; import java.util.Arrays; +import java.util.HashMap; +import com.google.common.collect.ImmutableMap; import org.junit.Test; import static junit.framework.Assert.*; @@ -391,6 +393,35 @@ public void levenshteinDistance() { assertEquals(fromString("世界åƒä¸–").levenshteinDistance(fromString("åƒa世b")),4); } + @Test + public void translate() { + assertEquals( + fromString("1a2s3ae"), + fromString("translate").translate(ImmutableMap.of( + 'r', '1', + 'n', '2', + 'l', '3', + 't', '\0' + ))); + assertEquals( + fromString("translate"), + fromString("translate").translate(new HashMap())); + assertEquals( + fromString("asae"), + fromString("translate").translate(ImmutableMap.of( + 'r', '\0', + 'n', '\0', + 'l', '\0', + 't', '\0' + ))); + assertEquals( + fromString("aa世b"), + fromString("花花世界").translate(ImmutableMap.of( + '花', 'a', + '界', 'b' + ))); + } + @Test public void createBlankString() { assertEquals(fromString(" "), blankString(1)); From 5b965d64ee1687145ba793da749659c8f67384e8 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 6 Aug 2015 09:10:57 -0700 Subject: [PATCH 0079/1824] [SPARK-9644] [SQL] Support update DecimalType with precision > 18 in UnsafeRow In order to 
support update a varlength (actually fixed length) object, the space should be preserved even it's null. And, we can't call setNullAt(i) for it anymore, we because setNullAt(i) will remove the offset of the preserved space, should call setDecimal(i, null, precision) instead. After this, we can do hash based aggregation on DecimalType with precision > 18. In a tests, this could decrease the end-to-end run time of aggregation query from 37 seconds (sort based) to 24 seconds (hash based). cc rxin Author: Davies Liu Closes #7978 from davies/update_decimal and squashes the following commits: bed8100 [Davies Liu] isSettable -> isMutable 923c9eb [Davies Liu] address comments and fix bug 385891d [Davies Liu] Merge branch 'master' of github.com:apache/spark into update_decimal 36a1872 [Davies Liu] fix tests cd6c524 [Davies Liu] support set decimal with precision > 18 --- .../sql/catalyst/expressions/UnsafeRow.java | 74 +++++++++++++++---- .../expressions/UnsafeRowWriters.java | 41 ++++++---- .../codegen/GenerateMutableProjection.scala | 15 +++- .../codegen/GenerateUnsafeProjection.scala | 53 +++++++------ .../spark/sql/catalyst/expressions/rows.scala | 8 +- .../expressions/UnsafeRowConverterSuite.scala | 17 ++++- .../UnsafeFixedWidthAggregationMap.java | 4 +- .../SortBasedAggregationIterator.scala | 4 +- .../UnsafeFixedWidthAggregationMapSuite.scala | 2 +- .../spark/unsafe/PlatformDependent.java | 26 +++++++ 10 files changed, 183 insertions(+), 61 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index e3e1622de08b..e829acb6285f 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -65,11 +65,11 @@ public static int calculateBitSetWidthInBytes(int numFields) { /** * Field types that can be updated in place in UnsafeRows (e.g. we support set() for these types) */ - public static final Set settableFieldTypes; + public static final Set mutableFieldTypes; - // DecimalType(precision <= 18) is settable + // DecimalType is also mutable static { - settableFieldTypes = Collections.unmodifiableSet( + mutableFieldTypes = Collections.unmodifiableSet( new HashSet<>( Arrays.asList(new DataType[] { NullType, @@ -87,12 +87,16 @@ public static int calculateBitSetWidthInBytes(int numFields) { public static boolean isFixedLength(DataType dt) { if (dt instanceof DecimalType) { - return ((DecimalType) dt).precision() < Decimal.MAX_LONG_DIGITS(); + return ((DecimalType) dt).precision() <= Decimal.MAX_LONG_DIGITS(); } else { - return settableFieldTypes.contains(dt); + return mutableFieldTypes.contains(dt); } } + public static boolean isMutable(DataType dt) { + return mutableFieldTypes.contains(dt) || dt instanceof DecimalType; + } + ////////////////////////////////////////////////////////////////////////////// // Private fields and methods ////////////////////////////////////////////////////////////////////////////// @@ -238,17 +242,45 @@ public void setFloat(int ordinal, float value) { PlatformDependent.UNSAFE.putFloat(baseObject, getFieldOffset(ordinal), value); } + /** + * Updates the decimal column. + * + * Note: In order to support update a decimal with precision > 18, CAN NOT call + * setNullAt() for this column. 
+ */ @Override public void setDecimal(int ordinal, Decimal value, int precision) { assertIndexIsValid(ordinal); - if (value == null) { - setNullAt(ordinal); - } else { - if (precision <= Decimal.MAX_LONG_DIGITS()) { + if (precision <= Decimal.MAX_LONG_DIGITS()) { + // compact format + if (value == null) { + setNullAt(ordinal); + } else { setLong(ordinal, value.toUnscaledLong()); + } + } else { + // fixed length + long cursor = getLong(ordinal) >>> 32; + assert cursor > 0 : "invalid cursor " + cursor; + // zero-out the bytes + PlatformDependent.UNSAFE.putLong(baseObject, baseOffset + cursor, 0L); + PlatformDependent.UNSAFE.putLong(baseObject, baseOffset + cursor + 8, 0L); + + if (value == null) { + setNullAt(ordinal); + // keep the offset for future update + PlatformDependent.UNSAFE.putLong(baseObject, getFieldOffset(ordinal), cursor << 32); } else { - // TODO(davies): support update decimal (hold a bounded space even it's null) - throw new UnsupportedOperationException(); + + final BigInteger integer = value.toJavaBigDecimal().unscaledValue(); + final int[] mag = (int[]) PlatformDependent.UNSAFE.getObjectVolatile(integer, + PlatformDependent.BIG_INTEGER_MAG_OFFSET); + assert(mag.length <= 4); + + // Write the bytes to the variable length portion. + PlatformDependent.copyMemory(mag, PlatformDependent.INT_ARRAY_OFFSET, + baseObject, baseOffset + cursor, mag.length * 4); + setLong(ordinal, (cursor << 32) | ((long) (((integer.signum() + 1) << 8) + mag.length))); } } } @@ -343,6 +375,8 @@ public double getDouble(int ordinal) { return PlatformDependent.UNSAFE.getDouble(baseObject, getFieldOffset(ordinal)); } + private static byte[] EMPTY = new byte[0]; + @Override public Decimal getDecimal(int ordinal, int precision, int scale) { if (isNullAt(ordinal)) { @@ -351,10 +385,20 @@ public Decimal getDecimal(int ordinal, int precision, int scale) { if (precision <= Decimal.MAX_LONG_DIGITS()) { return Decimal.apply(getLong(ordinal), precision, scale); } else { - byte[] bytes = getBinary(ordinal); - BigInteger bigInteger = new BigInteger(bytes); - BigDecimal javaDecimal = new BigDecimal(bigInteger, scale); - return Decimal.apply(new scala.math.BigDecimal(javaDecimal), precision, scale); + long offsetAndSize = getLong(ordinal); + long offset = offsetAndSize >>> 32; + int signum = ((int) (offsetAndSize & 0xfff) >> 8); + assert signum >=0 && signum <= 2 : "invalid signum " + signum; + int size = (int) (offsetAndSize & 0xff); + int[] mag = new int[size]; + PlatformDependent.copyMemory(baseObject, baseOffset + offset, + mag, PlatformDependent.INT_ARRAY_OFFSET, size * 4); + + // create a BigInteger using signum and mag + BigInteger v = new BigInteger(0, EMPTY); // create the initial object + PlatformDependent.UNSAFE.putInt(v, PlatformDependent.BIG_INTEGER_SIGNUM_OFFSET, signum - 1); + PlatformDependent.UNSAFE.putObjectVolatile(v, PlatformDependent.BIG_INTEGER_MAG_OFFSET, mag); + return Decimal.apply(new BigDecimal(v, scale), precision, scale); } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java index 31928731545d..28e7ec0a0f12 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java @@ -17,9 +17,10 @@ package org.apache.spark.sql.catalyst.expressions; +import java.math.BigInteger; + import 
org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.types.Decimal; -import org.apache.spark.sql.types.MapData; import org.apache.spark.unsafe.PlatformDependent; import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.types.ByteArray; @@ -47,29 +48,41 @@ public static int write(UnsafeRow target, int ordinal, int cursor, Decimal input /** Writer for Decimal with precision larger than 18. */ public static class DecimalWriter { - + private static final int SIZE = 16; public static int getSize(Decimal input) { // bounded size - return 16; + return SIZE; } public static int write(UnsafeRow target, int ordinal, int cursor, Decimal input) { + final Object base = target.getBaseObject(); final long offset = target.getBaseOffset() + cursor; - final byte[] bytes = input.toJavaBigDecimal().unscaledValue().toByteArray(); - final int numBytes = bytes.length; - assert(numBytes <= 16); - // zero-out the bytes - PlatformDependent.UNSAFE.putLong(target.getBaseObject(), offset, 0L); - PlatformDependent.UNSAFE.putLong(target.getBaseObject(), offset + 8, 0L); + PlatformDependent.UNSAFE.putLong(base, offset, 0L); + PlatformDependent.UNSAFE.putLong(base, offset + 8, 0L); + + if (input == null) { + target.setNullAt(ordinal); + // keep the offset and length for update + int fieldOffset = UnsafeRow.calculateBitSetWidthInBytes(target.numFields()) + ordinal * 8; + PlatformDependent.UNSAFE.putLong(base, target.getBaseOffset() + fieldOffset, + ((long) cursor) << 32); + return SIZE; + } - // Write the bytes to the variable length portion. - PlatformDependent.copyMemory(bytes, PlatformDependent.BYTE_ARRAY_OFFSET, - target.getBaseObject(), offset, numBytes); + final BigInteger integer = input.toJavaBigDecimal().unscaledValue(); + int signum = integer.signum() + 1; + final int[] mag = (int[]) PlatformDependent.UNSAFE.getObjectVolatile(integer, + PlatformDependent.BIG_INTEGER_MAG_OFFSET); + assert(mag.length <= 4); + // Write the bytes to the variable length portion. + PlatformDependent.copyMemory(mag, PlatformDependent.INT_ARRAY_OFFSET, + base, target.getBaseOffset() + cursor, mag.length * 4); // Set the fixed length portion. 
- target.setLong(ordinal, (((long) cursor) << 32) | ((long) numBytes)); - return 16; + target.setLong(ordinal, (((long) cursor) << 32) | ((long) ((signum << 8) + mag.length))); + + return SIZE; } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index e4a8fc24dac2..ac58423cd884 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -21,6 +21,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp +import org.apache.spark.sql.types.DecimalType // MutableProjection is not accessible in Java abstract class BaseMutableProjection extends MutableProjection @@ -43,14 +44,26 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu case (NoOp, _) => "" case (e, i) => val evaluationCode = e.gen(ctx) - evaluationCode.code + + if (e.dataType.isInstanceOf[DecimalType]) { + // Can't call setNullAt on DecimalType, because we need to keep the offset s""" + ${evaluationCode.code} + if (${evaluationCode.isNull}) { + ${ctx.setColumn("mutableRow", e.dataType, i, null)}; + } else { + ${ctx.setColumn("mutableRow", e.dataType, i, evaluationCode.primitive)}; + } + """ + } else { + s""" + ${evaluationCode.code} if (${evaluationCode.isNull}) { mutableRow.setNullAt($i); } else { ${ctx.setColumn("mutableRow", e.dataType, i, evaluationCode.primitive)}; } """ + } } // collect projections into blocks as function has 64kb codesize limit in JVM val projectionBlocks = new ArrayBuffer[String]() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 71f8ea09f077..d8912df694a1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -45,10 +45,10 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro /** Returns true iff we support this data type. */ def canSupport(dataType: DataType): Boolean = dataType match { + case NullType => true case t: AtomicType => true case _: CalendarIntervalType => true case t: StructType => t.toSeq.forall(field => canSupport(field.dataType)) - case NullType => true case t: ArrayType if canSupport(t.elementType) => true case MapType(kt, vt, _) if canSupport(kt) && canSupport(vt) => true case _ => false @@ -56,7 +56,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro def genAdditionalSize(dt: DataType, ev: GeneratedExpressionCode): String = dt match { case t: DecimalType if t.precision > Decimal.MAX_LONG_DIGITS => - s" + (${ev.isNull} ? 0 : $DecimalWriter.getSize(${ev.primitive}))" + s" + $DecimalWriter.getSize(${ev.primitive})" case StringType => s" + (${ev.isNull} ? 
0 : $StringWriter.getSize(${ev.primitive}))" case BinaryType => @@ -76,41 +76,41 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro ctx: CodeGenContext, fieldType: DataType, ev: GeneratedExpressionCode, - primitive: String, + target: String, index: Int, cursor: String): String = fieldType match { case _ if ctx.isPrimitiveType(fieldType) => - s"${ctx.setColumn(primitive, fieldType, index, ev.primitive)}" + s"${ctx.setColumn(target, fieldType, index, ev.primitive)}" case t: DecimalType if t.precision <= Decimal.MAX_LONG_DIGITS => s""" // make sure Decimal object has the same scale as DecimalType if (${ev.primitive}.changePrecision(${t.precision}, ${t.scale})) { - $CompactDecimalWriter.write($primitive, $index, $cursor, ${ev.primitive}); + $CompactDecimalWriter.write($target, $index, $cursor, ${ev.primitive}); } else { - $primitive.setNullAt($index); + $target.setNullAt($index); } """ case t: DecimalType if t.precision > Decimal.MAX_LONG_DIGITS => s""" // make sure Decimal object has the same scale as DecimalType if (${ev.primitive}.changePrecision(${t.precision}, ${t.scale})) { - $cursor += $DecimalWriter.write($primitive, $index, $cursor, ${ev.primitive}); + $cursor += $DecimalWriter.write($target, $index, $cursor, ${ev.primitive}); } else { - $primitive.setNullAt($index); + $cursor += $DecimalWriter.write($target, $index, $cursor, null); } """ case StringType => - s"$cursor += $StringWriter.write($primitive, $index, $cursor, ${ev.primitive})" + s"$cursor += $StringWriter.write($target, $index, $cursor, ${ev.primitive})" case BinaryType => - s"$cursor += $BinaryWriter.write($primitive, $index, $cursor, ${ev.primitive})" + s"$cursor += $BinaryWriter.write($target, $index, $cursor, ${ev.primitive})" case CalendarIntervalType => - s"$cursor += $IntervalWriter.write($primitive, $index, $cursor, ${ev.primitive})" + s"$cursor += $IntervalWriter.write($target, $index, $cursor, ${ev.primitive})" case _: StructType => - s"$cursor += $StructWriter.write($primitive, $index, $cursor, ${ev.primitive})" + s"$cursor += $StructWriter.write($target, $index, $cursor, ${ev.primitive})" case _: ArrayType => - s"$cursor += $ArrayWriter.write($primitive, $index, $cursor, ${ev.primitive})" + s"$cursor += $ArrayWriter.write($target, $index, $cursor, ${ev.primitive})" case _: MapType => - s"$cursor += $MapWriter.write($primitive, $index, $cursor, ${ev.primitive})" + s"$cursor += $MapWriter.write($target, $index, $cursor, ${ev.primitive})" case NullType => "" case _ => throw new UnsupportedOperationException(s"Not supported DataType: $fieldType") @@ -146,13 +146,24 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val fieldWriters = inputTypes.zip(convertedFields).zipWithIndex.map { case ((dt, ev), i) => val update = genFieldWriter(ctx, dt, ev, output, i, cursor) - s""" - if (${ev.isNull}) { - $output.setNullAt($i); - } else { - $update; - } - """ + if (dt.isInstanceOf[DecimalType]) { + // Can't call setNullAt() for DecimalType + s""" + if (${ev.isNull}) { + $cursor += $DecimalWriter.write($output, $i, $cursor, null); + } else { + $update; + } + """ + } else { + s""" + if (${ev.isNull}) { + $output.setNullAt($i); + } else { + $update; + } + """ + } }.mkString("\n") val code = s""" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala index 5e5de1d1dc6a..7657fb535dcf 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala @@ -20,7 +20,6 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} /** * An extended interface to [[InternalRow]] that allows the values for each column to be updated. @@ -39,6 +38,13 @@ abstract class MutableRow extends InternalRow { def setLong(i: Int, value: Long): Unit = { update(i, value) } def setFloat(i: Int, value: Float): Unit = { update(i, value) } def setDouble(i: Int, value: Double): Unit = { update(i, value) } + + /** + * Update the decimal column at `i`. + * + * Note: In order to support update decimal with precision > 18 in UnsafeRow, + * CAN NOT call setNullAt() for decimal column on UnsafeRow, call setDecimal(i, null, precision). + */ def setDecimal(i: Int, value: Decimal, precision: Int) { update(i, value) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala index 59491c5ba160..8c7220319363 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala @@ -123,7 +123,8 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { DoubleType, StringType, BinaryType, - DecimalType.USER_DEFAULT + DecimalType.USER_DEFAULT, + DecimalType.SYSTEM_DEFAULT // ArrayType(IntegerType) ) val converter = UnsafeProjection.create(fieldTypes) @@ -151,6 +152,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { assert(createdFromNull.getUTF8String(8) === null) assert(createdFromNull.getBinary(9) === null) assert(createdFromNull.getDecimal(10, 10, 0) === null) + assert(createdFromNull.getDecimal(11, 38, 18) === null) // assert(createdFromNull.get(11) === null) // If we have an UnsafeRow with columns that are initially non-null and we null out those @@ -169,6 +171,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { r.update(8, UTF8String.fromString("hello")) r.update(9, "world".getBytes) r.setDecimal(10, Decimal(10), 10) + r.setDecimal(11, Decimal(10.00, 38, 18), 38) // r.update(11, Array(11)) r } @@ -187,10 +190,17 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { assert(setToNullAfterCreation.getBinary(9) === rowWithNoNullColumns.getBinary(9)) assert(setToNullAfterCreation.getDecimal(10, 10, 0) === rowWithNoNullColumns.getDecimal(10, 10, 0)) + assert(setToNullAfterCreation.getDecimal(11, 38, 18) === + rowWithNoNullColumns.getDecimal(11, 38, 18)) // assert(setToNullAfterCreation.get(11) === rowWithNoNullColumns.get(11)) for (i <- fieldTypes.indices) { - setToNullAfterCreation.setNullAt(i) + // Cann't call setNullAt() on DecimalType + if (i == 11) { + setToNullAfterCreation.setDecimal(11, null, 38) + } else { + setToNullAfterCreation.setNullAt(i) + } } // There are some garbage left in the var-length area assert(Arrays.equals(createdFromNull.getBytes, setToNullAfterCreation.getBytes())) @@ -206,6 +216,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { // setToNullAfterCreation.update(8, UTF8String.fromString("hello")) // 
setToNullAfterCreation.update(9, "world".getBytes) setToNullAfterCreation.setDecimal(10, Decimal(10), 10) + setToNullAfterCreation.setDecimal(11, Decimal(10.00, 38, 18), 38) // setToNullAfterCreation.update(11, Array(11)) assert(setToNullAfterCreation.isNullAt(0) === rowWithNoNullColumns.isNullAt(0)) @@ -220,6 +231,8 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { // assert(setToNullAfterCreation.get(9) === rowWithNoNullColumns.get(9)) assert(setToNullAfterCreation.getDecimal(10, 10, 0) === rowWithNoNullColumns.getDecimal(10, 10, 0)) + assert(setToNullAfterCreation.getDecimal(11, 38, 18) === + rowWithNoNullColumns.getDecimal(11, 38, 18)) // assert(setToNullAfterCreation.get(11) === rowWithNoNullColumns.get(11)) } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java index 43d06ce9bdfa..02458030b00e 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java @@ -72,7 +72,7 @@ public final class UnsafeFixedWidthAggregationMap { */ public static boolean supportsAggregationBufferSchema(StructType schema) { for (StructField field: schema.fields()) { - if (!UnsafeRow.isFixedLength(field.dataType())) { + if (!UnsafeRow.isMutable(field.dataType())) { return false; } } @@ -111,8 +111,6 @@ public UnsafeFixedWidthAggregationMap( // Initialize the buffer for aggregation value final UnsafeProjection valueProjection = UnsafeProjection.create(aggregationBufferSchema); this.emptyAggregationBuffer = valueProjection.apply(emptyAggregationBuffer).getBytes(); - assert(this.emptyAggregationBuffer.length == aggregationBufferSchema.length() * 8 + - UnsafeRow.calculateBitSetWidthInBytes(aggregationBufferSchema.length())); } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala index 78bcee16c9d0..40f6bff53d2b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala @@ -20,8 +20,6 @@ package org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression2, AggregateFunction2} -import org.apache.spark.sql.execution.UnsafeFixedWidthAggregationMap -import org.apache.spark.sql.types.StructType import org.apache.spark.unsafe.KVIterator /** @@ -57,7 +55,7 @@ class SortBasedAggregationIterator( val bufferRowSize: Int = bufferSchema.length val genericMutableBuffer = new GenericMutableRow(bufferRowSize) - val useUnsafeBuffer = bufferSchema.map(_.dataType).forall(UnsafeRow.isFixedLength) + val useUnsafeBuffer = bufferSchema.map(_.dataType).forall(UnsafeRow.isMutable) val buffer = if (useUnsafeBuffer) { val unsafeProjection = diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala index b513c970ccfe..e03473041c3e 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala @@ -93,7 +93,7 @@ class UnsafeFixedWidthAggregationMapSuite extends SparkFunSuite with Matchers { testWithMemoryLeakDetection("supported schemas") { assert(supportsAggregationBufferSchema( StructType(StructField("x", DecimalType.USER_DEFAULT) :: Nil))) - assert(!supportsAggregationBufferSchema( + assert(supportsAggregationBufferSchema( StructType(StructField("x", DecimalType.SYSTEM_DEFAULT) :: Nil))) assert(!supportsAggregationBufferSchema(StructType(StructField("x", StringType) :: Nil))) assert( diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/PlatformDependent.java b/unsafe/src/main/java/org/apache/spark/unsafe/PlatformDependent.java index 192c6714b240..b2de2a2590f0 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/PlatformDependent.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/PlatformDependent.java @@ -18,6 +18,7 @@ package org.apache.spark.unsafe; import java.lang.reflect.Field; +import java.math.BigInteger; import sun.misc.Unsafe; @@ -87,6 +88,14 @@ public static void putDouble(Object object, long offset, double value) { _UNSAFE.putDouble(object, offset, value); } + public static Object getObjectVolatile(Object object, long offset) { + return _UNSAFE.getObjectVolatile(object, offset); + } + + public static void putObjectVolatile(Object object, long offset, Object value) { + _UNSAFE.putObjectVolatile(object, offset, value); + } + public static long allocateMemory(long size) { return _UNSAFE.allocateMemory(size); } @@ -107,6 +116,10 @@ public static void freeMemory(long address) { public static final int DOUBLE_ARRAY_OFFSET; + // Support for resetting final fields while deserializing + public static final long BIG_INTEGER_SIGNUM_OFFSET; + public static final long BIG_INTEGER_MAG_OFFSET; + /** * Limits the number of bytes to copy per {@link Unsafe#copyMemory(long, long, long)} to * allow safepoint polling during a large copy. @@ -129,11 +142,24 @@ public static void freeMemory(long address) { INT_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(int[].class); LONG_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(long[].class); DOUBLE_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(double[].class); + + long signumOffset = 0; + long magOffset = 0; + try { + signumOffset = _UNSAFE.objectFieldOffset(BigInteger.class.getDeclaredField("signum")); + magOffset = _UNSAFE.objectFieldOffset(BigInteger.class.getDeclaredField("mag")); + } catch (Exception ex) { + // should not happen + } + BIG_INTEGER_SIGNUM_OFFSET = signumOffset; + BIG_INTEGER_MAG_OFFSET = magOffset; } else { BYTE_ARRAY_OFFSET = 0; INT_ARRAY_OFFSET = 0; LONG_ARRAY_OFFSET = 0; DOUBLE_ARRAY_OFFSET = 0; + BIG_INTEGER_SIGNUM_OFFSET = 0; + BIG_INTEGER_MAG_OFFSET = 0; } } From 93085c992e40dbc06714cb1a64c838e25e683a6f Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 6 Aug 2015 09:12:41 -0700 Subject: [PATCH 0080/1824] [SPARK-9482] [SQL] Fix thread-safety issue of using UnsafeProjection in join This PR also changes to use `def` instead of `lazy val` for UnsafeProjection, because it's not thread safe. TODO: clean up the debug code once the flaky test has passed 100 times. 
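To illustrate the race this patch avoids, here is a minimal, self-contained Scala sketch (not Spark code; the StatefulProjection class below is hypothetical): a projection that reuses an internal mutable buffer must not be shared across threads through a `lazy val`, while exposing it through a `def` gives every caller its own instance.

object LazyValVsDefSketch {
  // Stand-in for a code-generated projection: it reuses one mutable buffer,
  // so concurrent callers going through a shared instance would trample each other's state.
  class StatefulProjection {
    private val buffer = new Array[Long](1)
    def apply(x: Long): Long = { buffer(0) = x * 2; buffer(0) }
  }

  // Unsafe pattern: one instance is created lazily and then shared by all threads.
  lazy val sharedProjection = new StatefulProjection

  // Safe pattern (what the patch switches to): each call site builds its own instance.
  def freshProjection: StatefulProjection = new StatefulProjection

  def main(args: Array[String]): Unit = {
    val threads = (1 to 4).map { _ =>
      new Thread(new Runnable {
        def run(): Unit = {
          val proj = freshProjection // private to this thread; using sharedProjection here could race
          var n = 0L
          while (n < 100000L) { assert(proj(n) == n * 2); n += 1 }
        }
      })
    }
    threads.foreach(_.start())
    threads.foreach(_.join())
    println("no interference observed with per-thread projections")
  }
}

The trade-off of the `def` pattern is that the projection is rebuilt per use rather than cached, which is presumably acceptable here because correctness of concurrent joins is at stake.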
Author: Davies Liu Closes #7940 from davies/semijoin and squashes the following commits: 93baac7 [Davies Liu] fix outerjoin 5c40ded [Davies Liu] address comments aa3de46 [Davies Liu] Merge branch 'master' of github.com:apache/spark into semijoin 7590a25 [Davies Liu] Merge branch 'master' of github.com:apache/spark into semijoin 2d4085b [Davies Liu] use def for resultProjection 0833407 [Davies Liu] Merge branch 'semijoin' of github.com:davies/spark into semijoin e0d8c71 [Davies Liu] use lazy val 6a59e8f [Davies Liu] Update HashedRelation.scala 0fdacaf [Davies Liu] fix broadcast and thread-safety of UnsafeProjection 2fc3ef6 [Davies Liu] reproduce failure in semijoin --- .../execution/joins/BroadcastHashJoin.scala | 6 ++--- .../joins/BroadcastHashOuterJoin.scala | 20 ++++++---------- .../joins/BroadcastNestedLoopJoin.scala | 17 ++++++++------ .../spark/sql/execution/joins/HashJoin.scala | 4 ++-- .../sql/execution/joins/HashOuterJoin.scala | 23 ++++++++++--------- .../sql/execution/joins/HashSemiJoin.scala | 8 +++---- .../sql/execution/joins/HashedRelation.scala | 4 ++-- .../joins/ShuffledHashOuterJoin.scala | 6 +++-- 8 files changed, 44 insertions(+), 44 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala index ec1a148342fc..f7a68e4f5d44 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala @@ -20,14 +20,14 @@ package org.apache.spark.sql.execution.joins import scala.concurrent._ import scala.concurrent.duration._ -import org.apache.spark.{InternalAccumulator, TaskContext} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.plans.physical.{Distribution, Partitioning, UnspecifiedDistribution} -import org.apache.spark.sql.execution.{BinaryNode, SparkPlan, SQLExecution} +import org.apache.spark.sql.execution.{BinaryNode, SQLExecution, SparkPlan} import org.apache.spark.util.ThreadUtils +import org.apache.spark.{InternalAccumulator, TaskContext} /** * :: DeveloperApi :: @@ -102,6 +102,6 @@ case class BroadcastHashJoin( object BroadcastHashJoin { - private val broadcastHashJoinExecutionContext = ExecutionContext.fromExecutorService( + private[joins] val broadcastHashJoinExecutionContext = ExecutionContext.fromExecutorService( ThreadUtils.newDaemonCachedThreadPool("broadcast-hash-join", 128)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala index e342fd914d32..a3626de49aea 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala @@ -20,15 +20,14 @@ package org.apache.spark.sql.execution.joins import scala.concurrent._ import scala.concurrent.duration._ -import org.apache.spark.{InternalAccumulator, TaskContext} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, Distribution, 
UnspecifiedDistribution} +import org.apache.spark.sql.catalyst.plans.physical.{Distribution, Partitioning, UnspecifiedDistribution} import org.apache.spark.sql.catalyst.plans.{JoinType, LeftOuter, RightOuter} -import org.apache.spark.sql.execution.{BinaryNode, SparkPlan, SQLExecution} -import org.apache.spark.util.ThreadUtils +import org.apache.spark.sql.execution.{BinaryNode, SQLExecution, SparkPlan} +import org.apache.spark.{InternalAccumulator, TaskContext} /** * :: DeveloperApi :: @@ -76,7 +75,7 @@ case class BroadcastHashOuterJoin( val hashed = HashedRelation(input.iterator, buildKeyGenerator, input.size) sparkContext.broadcast(hashed) } - }(BroadcastHashOuterJoin.broadcastHashOuterJoinExecutionContext) + }(BroadcastHashJoin.broadcastHashJoinExecutionContext) } protected override def doPrepare(): Unit = { @@ -98,19 +97,20 @@ case class BroadcastHashOuterJoin( case _ => } + val resultProj = resultProjection joinType match { case LeftOuter => streamedIter.flatMap(currentRow => { val rowKey = keyGenerator(currentRow) joinedRow.withLeft(currentRow) - leftOuterIterator(rowKey, joinedRow, hashTable.get(rowKey)) + leftOuterIterator(rowKey, joinedRow, hashTable.get(rowKey), resultProj) }) case RightOuter => streamedIter.flatMap(currentRow => { val rowKey = keyGenerator(currentRow) joinedRow.withRight(currentRow) - rightOuterIterator(rowKey, hashTable.get(rowKey), joinedRow) + rightOuterIterator(rowKey, hashTable.get(rowKey), joinedRow, resultProj) }) case x => @@ -120,9 +120,3 @@ case class BroadcastHashOuterJoin( } } } - -object BroadcastHashOuterJoin { - - private val broadcastHashOuterJoinExecutionContext = ExecutionContext.fromExecutorService( - ThreadUtils.newDaemonCachedThreadPool("broadcast-hash-outer-join", 128)) -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala index 83b726a8e289..23aebf4b068b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala @@ -47,7 +47,7 @@ case class BroadcastNestedLoopJoin( override def outputsUnsafeRows: Boolean = left.outputsUnsafeRows || right.outputsUnsafeRows override def canProcessUnsafeRows: Boolean = true - @transient private[this] lazy val resultProjection: InternalRow => InternalRow = { + private[this] def genResultProjection: InternalRow => InternalRow = { if (outputsUnsafeRows) { UnsafeProjection.create(schema) } else { @@ -88,6 +88,7 @@ case class BroadcastNestedLoopJoin( val leftNulls = new GenericMutableRow(left.output.size) val rightNulls = new GenericMutableRow(right.output.size) + val resultProj = genResultProjection streamedIter.foreach { streamedRow => var i = 0 @@ -97,11 +98,11 @@ case class BroadcastNestedLoopJoin( val broadcastedRow = broadcastedRelation.value(i) buildSide match { case BuildRight if boundCondition(joinedRow(streamedRow, broadcastedRow)) => - matchedRows += resultProjection(joinedRow(streamedRow, broadcastedRow)).copy() + matchedRows += resultProj(joinedRow(streamedRow, broadcastedRow)).copy() streamRowMatched = true includedBroadcastTuples += i case BuildLeft if boundCondition(joinedRow(broadcastedRow, streamedRow)) => - matchedRows += resultProjection(joinedRow(broadcastedRow, streamedRow)).copy() + matchedRows += resultProj(joinedRow(broadcastedRow, streamedRow)).copy() streamRowMatched = true includedBroadcastTuples += i 
case _ => @@ -111,9 +112,9 @@ case class BroadcastNestedLoopJoin( (streamRowMatched, joinType, buildSide) match { case (false, LeftOuter | FullOuter, BuildRight) => - matchedRows += resultProjection(joinedRow(streamedRow, rightNulls)).copy() + matchedRows += resultProj(joinedRow(streamedRow, rightNulls)).copy() case (false, RightOuter | FullOuter, BuildLeft) => - matchedRows += resultProjection(joinedRow(leftNulls, streamedRow)).copy() + matchedRows += resultProj(joinedRow(leftNulls, streamedRow)).copy() case _ => } } @@ -127,6 +128,8 @@ case class BroadcastNestedLoopJoin( val leftNulls = new GenericMutableRow(left.output.size) val rightNulls = new GenericMutableRow(right.output.size) + val resultProj = genResultProjection + /** Rows from broadcasted joined with nulls. */ val broadcastRowsWithNulls: Seq[InternalRow] = { val buf: CompactBuffer[InternalRow] = new CompactBuffer() @@ -138,7 +141,7 @@ case class BroadcastNestedLoopJoin( joinedRow.withLeft(leftNulls) while (i < rel.length) { if (!allIncludedBroadcastTuples.contains(i)) { - buf += resultProjection(joinedRow.withRight(rel(i))).copy() + buf += resultProj(joinedRow.withRight(rel(i))).copy() } i += 1 } @@ -147,7 +150,7 @@ case class BroadcastNestedLoopJoin( joinedRow.withRight(rightNulls) while (i < rel.length) { if (!allIncludedBroadcastTuples.contains(i)) { - buf += resultProjection(joinedRow.withLeft(rel(i))).copy() + buf += resultProj(joinedRow.withLeft(rel(i))).copy() } i += 1 } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala index 6b3d1652923f..5e9cd9fd2345 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala @@ -52,14 +52,14 @@ trait HashJoin { override def canProcessUnsafeRows: Boolean = isUnsafeMode override def canProcessSafeRows: Boolean = !isUnsafeMode - @transient protected lazy val buildSideKeyGenerator: Projection = + protected def buildSideKeyGenerator: Projection = if (isUnsafeMode) { UnsafeProjection.create(buildKeys, buildPlan.output) } else { newMutableProjection(buildKeys, buildPlan.output)() } - @transient protected lazy val streamSideKeyGenerator: Projection = + protected def streamSideKeyGenerator: Projection = if (isUnsafeMode) { UnsafeProjection.create(streamedKeys, streamedPlan.output) } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala index a323aea4ea2c..346337e64245 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala @@ -76,14 +76,14 @@ trait HashOuterJoin { override def canProcessUnsafeRows: Boolean = isUnsafeMode override def canProcessSafeRows: Boolean = !isUnsafeMode - @transient protected lazy val buildKeyGenerator: Projection = + protected def buildKeyGenerator: Projection = if (isUnsafeMode) { UnsafeProjection.create(buildKeys, buildPlan.output) } else { newMutableProjection(buildKeys, buildPlan.output)() } - @transient protected[this] lazy val streamedKeyGenerator: Projection = { + protected[this] def streamedKeyGenerator: Projection = { if (isUnsafeMode) { UnsafeProjection.create(streamedKeys, streamedPlan.output) } else { @@ -91,7 +91,7 @@ trait HashOuterJoin { } } - @transient private[this] lazy val 
resultProjection: InternalRow => InternalRow = { + protected[this] def resultProjection: InternalRow => InternalRow = { if (isUnsafeMode) { UnsafeProjection.create(self.schema) } else { @@ -113,7 +113,8 @@ trait HashOuterJoin { protected[this] def leftOuterIterator( key: InternalRow, joinedRow: JoinedRow, - rightIter: Iterable[InternalRow]): Iterator[InternalRow] = { + rightIter: Iterable[InternalRow], + resultProjection: InternalRow => InternalRow): Iterator[InternalRow] = { val ret: Iterable[InternalRow] = { if (!key.anyNull) { val temp = if (rightIter != null) { @@ -124,12 +125,12 @@ trait HashOuterJoin { List.empty } if (temp.isEmpty) { - resultProjection(joinedRow.withRight(rightNullRow)).copy :: Nil + resultProjection(joinedRow.withRight(rightNullRow)) :: Nil } else { temp } } else { - resultProjection(joinedRow.withRight(rightNullRow)).copy :: Nil + resultProjection(joinedRow.withRight(rightNullRow)) :: Nil } } ret.iterator @@ -138,24 +139,24 @@ trait HashOuterJoin { protected[this] def rightOuterIterator( key: InternalRow, leftIter: Iterable[InternalRow], - joinedRow: JoinedRow): Iterator[InternalRow] = { + joinedRow: JoinedRow, + resultProjection: InternalRow => InternalRow): Iterator[InternalRow] = { val ret: Iterable[InternalRow] = { if (!key.anyNull) { val temp = if (leftIter != null) { leftIter.collect { - case l if boundCondition(joinedRow.withLeft(l)) => - resultProjection(joinedRow).copy() + case l if boundCondition(joinedRow.withLeft(l)) => resultProjection(joinedRow).copy() } } else { List.empty } if (temp.isEmpty) { - resultProjection(joinedRow.withLeft(leftNullRow)).copy :: Nil + resultProjection(joinedRow.withLeft(leftNullRow)) :: Nil } else { temp } } else { - resultProjection(joinedRow.withLeft(leftNullRow)).copy :: Nil + resultProjection(joinedRow.withLeft(leftNullRow)) :: Nil } } ret.iterator diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala index 97fde8f975bf..47a7d370f541 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala @@ -43,14 +43,14 @@ trait HashSemiJoin { override def canProcessUnsafeRows: Boolean = supportUnsafe override def canProcessSafeRows: Boolean = !supportUnsafe - @transient protected lazy val leftKeyGenerator: Projection = + protected def leftKeyGenerator: Projection = if (supportUnsafe) { UnsafeProjection.create(leftKeys, left.output) } else { newMutableProjection(leftKeys, left.output)() } - @transient protected lazy val rightKeyGenerator: Projection = + protected def rightKeyGenerator: Projection = if (supportUnsafe) { UnsafeProjection.create(rightKeys, right.output) } else { @@ -62,12 +62,11 @@ trait HashSemiJoin { protected def buildKeyHashSet(buildIter: Iterator[InternalRow]): java.util.Set[InternalRow] = { val hashSet = new java.util.HashSet[InternalRow]() - var currentRow: InternalRow = null // Create a Hash set of buildKeys val rightKey = rightKeyGenerator while (buildIter.hasNext) { - currentRow = buildIter.next() + val currentRow = buildIter.next() val rowKey = rightKey(currentRow) if (!rowKey.anyNull) { val keyExists = hashSet.contains(rowKey) @@ -76,6 +75,7 @@ trait HashSemiJoin { } } } + hashSet } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index 
58b4236f7b5b..3f257ecdd156 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -17,12 +17,11 @@ package org.apache.spark.sql.execution.joins -import java.io.{IOException, Externalizable, ObjectInput, ObjectOutput} +import java.io.{Externalizable, IOException, ObjectInput, ObjectOutput} import java.nio.ByteOrder import java.util.{HashMap => JavaHashMap} import org.apache.spark.shuffle.ShuffleMemoryManager -import org.apache.spark.{SparkConf, SparkEnv, TaskContext} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.SparkSqlSerializer @@ -31,6 +30,7 @@ import org.apache.spark.unsafe.map.BytesToBytesMap import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager} import org.apache.spark.util.Utils import org.apache.spark.util.collection.CompactBuffer +import org.apache.spark.{SparkConf, SparkEnv} /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala index eee8ad800f98..6a8c35efca8f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala @@ -60,19 +60,21 @@ case class ShuffledHashOuterJoin( case LeftOuter => val hashed = HashedRelation(rightIter, buildKeyGenerator) val keyGenerator = streamedKeyGenerator + val resultProj = resultProjection leftIter.flatMap( currentRow => { val rowKey = keyGenerator(currentRow) joinedRow.withLeft(currentRow) - leftOuterIterator(rowKey, joinedRow, hashed.get(rowKey)) + leftOuterIterator(rowKey, joinedRow, hashed.get(rowKey), resultProj) }) case RightOuter => val hashed = HashedRelation(leftIter, buildKeyGenerator) val keyGenerator = streamedKeyGenerator + val resultProj = resultProjection rightIter.flatMap ( currentRow => { val rowKey = keyGenerator(currentRow) joinedRow.withRight(currentRow) - rightOuterIterator(rowKey, hashed.get(rowKey), joinedRow) + rightOuterIterator(rowKey, hashed.get(rowKey), joinedRow, resultProj) }) case FullOuter => From 9f94c85ff35df6289371f80edde51c2aa6c4bcdc Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Thu, 6 Aug 2015 09:53:53 -0700 Subject: [PATCH 0081/1824] [SPARK-9593] [SQL] [HOTFIX] Makes the Hadoop shims loading fix more robust This is a follow-up of #7929. We found that the Jenkins SBT master build still fails because of the Hadoop shims loading issue. But the failure doesn't appear to be deterministic. My suspicion is that the Hadoop `VersionInfo` class may fail to inspect the Hadoop version, and the shims loading branch is skipped. This PR tries to make the fix more robust: 1. When Hadoop version is available, we load `Hadoop20SShims` for versions <= 2.0.x as srowen suggested in PR #7929. 2. Otherwise, we use `Path.getPathWithoutSchemeAndAuthority` as a probe method, which doesn't exist in Hadoop 1.x or 2.0.x. If this method is not found, `Hadoop20SShims` is also loaded. 
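As a rough, standalone illustration of this two-step decision (a condensed sketch only, assuming hadoop-common is on the classpath; it does not reproduce the actual ClientWrapper change shown below), the logic first consults Hadoop's reported version string and falls back to probing for a method that only exists in newer Hadoop releases:

import org.apache.hadoop.fs.Path
import org.apache.hadoop.util.VersionInfo

object ShimsProbeSketch {
  private val VersionPattern = """(\d+)\.(\d+).*""".r

  // Returns true when the older Hadoop20SShims should be forced.
  def needsHadoop20SShims: Boolean = VersionInfo.getVersion match {
    case null =>
      // Version unknown: probe for an API that is absent from Hadoop 1.x and 2.0.x.
      !classOf[Path].getDeclaredMethods.exists(_.getName == "getPathWithoutSchemeAndAuthority")
    case VersionPattern(major, minor) =>
      // Hadoop 1.x and 2.0.x still need the "20S" shims.
      major.toInt < 2 || (major.toInt == 2 && minor.toInt == 0)
    case _ =>
      false
  }

  def main(args: Array[String]): Unit =
    println("force Hadoop20SShims: " + needsHadoop20SShims)
}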
Author: Cheng Lian Closes #7994 from liancheng/spark-9593/fix-hadoop-shims and squashes the following commits: e1d3d70 [Cheng Lian] Fixes typo in comments 8d971da [Cheng Lian] Makes the Hadoop shims loading fix more robust --- .../spark/sql/hive/client/ClientWrapper.scala | 88 ++++++++++++------- 1 file changed, 55 insertions(+), 33 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala index 211a3b879c1b..3d05b583cf9e 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala @@ -68,45 +68,67 @@ private[hive] class ClientWrapper( // !! HACK ALERT !! // - // This method is a surgical fix for Hadoop version 2.0.0-mr1-cdh4.1.1, which is used by Spark EC2 - // scripts. We should remove this after upgrading Spark EC2 scripts to some more recent Hadoop - // version in the future. - // // Internally, Hive `ShimLoader` tries to load different versions of Hadoop shims by checking - // version information gathered from Hadoop jar files. If the major version number is 1, - // `Hadoop20SShims` will be loaded. Otherwise, if the major version number is 2, `Hadoop23Shims` - // will be chosen. + // major version number gathered from Hadoop jar files: + // + // - For major version number 1, load `Hadoop20SShims`, where "20S" stands for Hadoop 0.20 with + // security. + // - For major version number 2, load `Hadoop23Shims`, where "23" stands for Hadoop 0.23. // - // However, part of APIs in Hadoop 2.0.x and 2.1.x versions were in flux due to historical - // reasons. So 2.0.0-mr1-cdh4.1.1 is actually more Hadoop-1-like and should be used together with - // `Hadoop20SShims`, but `Hadoop20SShims` is chose because the major version number here is 2. + // However, APIs in Hadoop 2.0.x and 2.1.x versions were in flux due to historical reasons. It + // turns out that Hadoop 2.0.x versions should also be used together with `Hadoop20SShims`, but + // `Hadoop23Shims` is chosen because the major version number here is 2. // - // Here we check for this specific version and loads `Hadoop20SShims` via reflection. Note that - // we can't check for string literal "2.0.0-mr1-cdh4.1.1" because the obtained version string - // comes from Maven artifact org.apache.hadoop:hadoop-common:2.0.0-cdh4.1.1, which doesn't have - // the "mr1" tag in its version string. + // To fix this issue, we try to inspect Hadoop version via `org.apache.hadoop.utils.VersionInfo` + // and load `Hadoop20SShims` for Hadoop 1.x and 2.0.x versions. If Hadoop version information is + // not available, we decide whether to override the shims or not by checking for existence of a + // probe method which doesn't exist in Hadoop 1.x or 2.0.x versions. 
private def overrideHadoopShims(): Unit = { - val VersionPattern = """2\.0\.0.*cdh4.*""".r - - VersionInfo.getVersion match { - case VersionPattern() => - val shimClassName = "org.apache.hadoop.hive.shims.Hadoop20SShims" - logInfo(s"Loading Hadoop shims $shimClassName") - - try { - val shimsField = classOf[ShimLoader].getDeclaredField("hadoopShims") - // scalastyle:off classforname - val shimsClass = Class.forName(shimClassName) - // scalastyle:on classforname - val shims = classOf[HadoopShims].cast(shimsClass.newInstance()) - shimsField.setAccessible(true) - shimsField.set(null, shims) - } catch { case cause: Throwable => - logError(s"Failed to load $shimClassName") - // Falls back to normal Hive `ShimLoader` logic + val hadoopVersion = VersionInfo.getVersion + val VersionPattern = """(\d+)\.(\d+).*""".r + + hadoopVersion match { + case null => + logError("Failed to inspect Hadoop version") + + // Using "Path.getPathWithoutSchemeAndAuthority" as the probe method. + val probeMethod = "getPathWithoutSchemeAndAuthority" + if (!classOf[Path].getDeclaredMethods.exists(_.getName == probeMethod)) { + logInfo( + s"Method ${classOf[Path].getCanonicalName}.$probeMethod not found, " + + s"we are probably using Hadoop 1.x or 2.0.x") + loadHadoop20SShims() + } + + case VersionPattern(majorVersion, minorVersion) => + logInfo(s"Inspected Hadoop version: $hadoopVersion") + + // Loads Hadoop20SShims for 1.x and 2.0.x versions + val (major, minor) = (majorVersion.toInt, minorVersion.toInt) + if (major < 2 || (major == 2 && minor == 0)) { + loadHadoop20SShims() } + } + + // Logs the actual loaded Hadoop shims class + val loadedShimsClassName = ShimLoader.getHadoopShims.getClass.getCanonicalName + logInfo(s"Loaded $loadedShimsClassName for Hadoop version $hadoopVersion") + } - case _ => + private def loadHadoop20SShims(): Unit = { + val hadoop20SShimsClassName = "org.apache.hadoop.hive.shims.Hadoop20SShims" + logInfo(s"Loading Hadoop shims $hadoop20SShimsClassName") + + try { + val shimsField = classOf[ShimLoader].getDeclaredField("hadoopShims") + // scalastyle:off classforname + val shimsClass = Class.forName(hadoop20SShimsClassName) + // scalastyle:on classforname + val shims = classOf[HadoopShims].cast(shimsClass.newInstance()) + shimsField.setAccessible(true) + shimsField.set(null, shims) + } catch { case cause: Throwable => + throw new RuntimeException(s"Failed to load $hadoop20SShimsClassName", cause) } } From c5c6aded641048a3e66ac79d9e84d34e4b1abae7 Mon Sep 17 00:00:00 2001 From: MechCoder Date: Thu, 6 Aug 2015 10:08:33 -0700 Subject: [PATCH 0082/1824] [SPARK-9112] [ML] Implement Stats for LogisticRegression I have added support for stats in LogisticRegression. The API is similar to that of LinearRegression with LogisticRegressionTrainingSummary and LogisticRegressionSummary I have some queries and asked them inline. 
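
As a quick orientation before the diff, a usage sketch of the new summary API follows. The class and method names are taken from the patch below; the `training` DataFrame is an assumption for the example (any DataFrame with "label" and "features" columns).

// Usage sketch only; the summary classes are introduced by this patch.
import org.apache.spark.ml.classification.{BinaryLogisticRegressionSummary, LogisticRegression}
import org.apache.spark.sql.DataFrame

val training: DataFrame = ???  // assumed: a DataFrame with "label" and "features" columns

val model = new LogisticRegression().setMaxIter(10).fit(training)

// Training-time statistics are attached to the fitted model.
val summary = model.summary
println(s"iterations: ${summary.totalIterations}")
println(s"objective history: ${summary.objectiveHistory.mkString(", ")}")

// Binary-classification metrics live on the concrete summary type.
summary match {
  case bin: BinaryLogisticRegressionSummary =>
    println(s"area under ROC: ${bin.areaUnderROC}")
    bin.roc.show()  // DataFrame with columns FPR and TPR
  case _ => // other summary implementations
}
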
Author: MechCoder Closes #7538 from MechCoder/log_reg_stats and squashes the following commits: 2e9f7c7 [MechCoder] Change defs into lazy vals d775371 [MechCoder] Clean up class inheritance 9586125 [MechCoder] Add abstraction to handle Multiclass Metrics 40ad8ef [MechCoder] minor 640376a [MechCoder] remove unnecessary dataframe stuff and add docs 80d9954 [MechCoder] Added tests fbed861 [MechCoder] DataFrame support for metrics 70a0fc4 [MechCoder] [SPARK-9112] [ML] Implement Stats for LogisticRegression --- .../classification/LogisticRegression.scala | 166 +++++++++++++++++- .../JavaLogisticRegressionSuite.java | 9 + .../LogisticRegressionSuite.scala | 37 +++- 3 files changed, 209 insertions(+), 3 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 0d073839259c..f55134d25885 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -30,10 +30,12 @@ import org.apache.spark.ml.util.Identifiable import org.apache.spark.mllib.linalg._ import org.apache.spark.mllib.linalg.BLAS._ import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Row, SQLContext} +import org.apache.spark.sql.functions.{col, udf} import org.apache.spark.storage.StorageLevel /** @@ -284,7 +286,13 @@ class LogisticRegression(override val uid: String) if (handlePersistence) instances.unpersist() - copyValues(new LogisticRegressionModel(uid, weights, intercept)) + val model = copyValues(new LogisticRegressionModel(uid, weights, intercept)) + val logRegSummary = new BinaryLogisticRegressionTrainingSummary( + model.transform(dataset), + $(probabilityCol), + $(labelCol), + objectiveHistory) + model.setSummary(logRegSummary) } override def copy(extra: ParamMap): LogisticRegression = defaultCopy(extra) @@ -319,6 +327,38 @@ class LogisticRegressionModel private[ml] ( override val numClasses: Int = 2 + private var trainingSummary: Option[LogisticRegressionTrainingSummary] = None + + /** + * Gets summary of model on training set. An exception is + * thrown if `trainingSummary == None`. + */ + def summary: LogisticRegressionTrainingSummary = trainingSummary match { + case Some(summ) => summ + case None => + throw new SparkException( + "No training summary available for this LogisticRegressionModel", + new NullPointerException()) + } + + private[classification] def setSummary( + summary: LogisticRegressionTrainingSummary): this.type = { + this.trainingSummary = Some(summary) + this + } + + /** Indicates whether a training summary exists for this model instance. */ + def hasSummary: Boolean = trainingSummary.isDefined + + /** + * Evaluates the model on a testset. + * @param dataset Test dataset to evaluate model on. + */ + // TODO: decide on a good name before exposing to public API + private[classification] def evaluate(dataset: DataFrame): LogisticRegressionSummary = { + new BinaryLogisticRegressionSummary(this.transform(dataset), $(probabilityCol), $(labelCol)) + } + /** * Predict label for the given feature vector. * The behavior of this can be adjusted using [[thresholds]]. 
@@ -440,6 +480,128 @@ private[classification] class MultiClassSummarizer extends Serializable { } } +/** + * Abstraction for multinomial Logistic Regression Training results. + */ +sealed trait LogisticRegressionTrainingSummary extends LogisticRegressionSummary { + + /** objective function (scaled loss + regularization) at each iteration. */ + def objectiveHistory: Array[Double] + + /** Number of training iterations until termination */ + def totalIterations: Int = objectiveHistory.length + +} + +/** + * Abstraction for Logistic Regression Results for a given model. + */ +sealed trait LogisticRegressionSummary extends Serializable { + + /** Dataframe outputted by the model's `transform` method. */ + def predictions: DataFrame + + /** Field in "predictions" which gives the calibrated probability of each sample as a vector. */ + def probabilityCol: String + + /** Field in "predictions" which gives the the true label of each sample. */ + def labelCol: String + +} + +/** + * :: Experimental :: + * Logistic regression training results. + * @param predictions dataframe outputted by the model's `transform` method. + * @param probabilityCol field in "predictions" which gives the calibrated probability of + * each sample as a vector. + * @param labelCol field in "predictions" which gives the true label of each sample. + * @param objectiveHistory objective function (scaled loss + regularization) at each iteration. + */ +@Experimental +class BinaryLogisticRegressionTrainingSummary private[classification] ( + predictions: DataFrame, + probabilityCol: String, + labelCol: String, + val objectiveHistory: Array[Double]) + extends BinaryLogisticRegressionSummary(predictions, probabilityCol, labelCol) + with LogisticRegressionTrainingSummary { + +} + +/** + * :: Experimental :: + * Binary Logistic regression results for a given model. + * @param predictions dataframe outputted by the model's `transform` method. + * @param probabilityCol field in "predictions" which gives the calibrated probability of + * each sample. + * @param labelCol field in "predictions" which gives the true label of each sample. + */ +@Experimental +class BinaryLogisticRegressionSummary private[classification] ( + @transient override val predictions: DataFrame, + override val probabilityCol: String, + override val labelCol: String) extends LogisticRegressionSummary { + + private val sqlContext = predictions.sqlContext + import sqlContext.implicits._ + + /** + * Returns a BinaryClassificationMetrics object. + */ + // TODO: Allow the user to vary the number of bins using a setBins method in + // BinaryClassificationMetrics. For now the default is set to 100. + @transient private val binaryMetrics = new BinaryClassificationMetrics( + predictions.select(probabilityCol, labelCol).map { + case Row(score: Vector, label: Double) => (score(1), label) + }, 100 + ) + + /** + * Returns the receiver operating characteristic (ROC) curve, + * which is an Dataframe having two fields (FPR, TPR) + * with (0.0, 0.0) prepended and (1.0, 1.0) appended to it. + * @see http://en.wikipedia.org/wiki/Receiver_operating_characteristic + */ + @transient lazy val roc: DataFrame = binaryMetrics.roc().toDF("FPR", "TPR") + + /** + * Computes the area under the receiver operating characteristic (ROC) curve. + */ + lazy val areaUnderROC: Double = binaryMetrics.areaUnderROC() + + /** + * Returns the precision-recall curve, which is an Dataframe containing + * two fields recall, precision with (0.0, 1.0) prepended to it. 
+ */ + @transient lazy val pr: DataFrame = binaryMetrics.pr().toDF("recall", "precision") + + /** + * Returns a dataframe with two fields (threshold, F-Measure) curve with beta = 1.0. + */ + @transient lazy val fMeasureByThreshold: DataFrame = { + binaryMetrics.fMeasureByThreshold().toDF("threshold", "F-Measure") + } + + /** + * Returns a dataframe with two fields (threshold, precision) curve. + * Every possible probability obtained in transforming the dataset are used + * as thresholds used in calculating the precision. + */ + @transient lazy val precisionByThreshold: DataFrame = { + binaryMetrics.precisionByThreshold().toDF("threshold", "precision") + } + + /** + * Returns a dataframe with two fields (threshold, recall) curve. + * Every possible probability obtained in transforming the dataset are used + * as thresholds used in calculating the recall. + */ + @transient lazy val recallByThreshold: DataFrame = { + binaryMetrics.recallByThreshold().toDF("threshold", "recall") + } +} + /** * LogisticAggregator computes the gradient and loss for binary logistic loss function, as used * in binary classification for samples in sparse or dense vector in a online fashion. diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java index fb1de51163f2..7e9aa383728f 100644 --- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java @@ -152,4 +152,13 @@ public void logisticRegressionPredictorClassifierMethods() { } } } + + @Test + public void logisticRegressionTrainingSummary() { + LogisticRegression lr = new LogisticRegression(); + LogisticRegressionModel model = lr.fit(dataset); + + LogisticRegressionTrainingSummary summary = model.summary(); + assert(summary.totalIterations() == summary.objectiveHistory().length); + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index da13dcb42d1c..8c3d4590f5ae 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -723,6 +723,41 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { val weightsR = Vectors.dense(0.0, 0.0, 0.0, 0.0) assert(model1.intercept ~== interceptR relTol 1E-5) - assert(model1.weights ~= weightsR absTol 1E-6) + assert(model1.weights ~== weightsR absTol 1E-6) + } + + test("evaluate on test set") { + // Evaluate on test set should be same as that of the transformed training data. 
+ val lr = new LogisticRegression() + .setMaxIter(10) + .setRegParam(1.0) + .setThreshold(0.6) + val model = lr.fit(dataset) + val summary = model.summary.asInstanceOf[BinaryLogisticRegressionSummary] + + val sameSummary = model.evaluate(dataset).asInstanceOf[BinaryLogisticRegressionSummary] + assert(summary.areaUnderROC === sameSummary.areaUnderROC) + assert(summary.roc.collect() === sameSummary.roc.collect()) + assert(summary.pr.collect === sameSummary.pr.collect()) + assert( + summary.fMeasureByThreshold.collect() === sameSummary.fMeasureByThreshold.collect()) + assert(summary.recallByThreshold.collect() === sameSummary.recallByThreshold.collect()) + assert( + summary.precisionByThreshold.collect() === sameSummary.precisionByThreshold.collect()) + } + + test("statistics on training data") { + // Test that loss is monotonically decreasing. + val lr = new LogisticRegression() + .setMaxIter(10) + .setRegParam(1.0) + .setThreshold(0.6) + val model = lr.fit(dataset) + assert( + model.summary + .objectiveHistory + .sliding(2) + .forall(x => x(0) >= x(1))) + } } From 076ec056818a65216eaf51aa5b3bd8f697c34748 Mon Sep 17 00:00:00 2001 From: MechCoder Date: Thu, 6 Aug 2015 10:09:58 -0700 Subject: [PATCH 0083/1824] [SPARK-9533] [PYSPARK] [ML] Add missing methods in Word2Vec ML After https://github.com/apache/spark/pull/7263 it is pretty straightforward to add the Python wrappers. Author: MechCoder Closes #7930 from MechCoder/spark-9533 and squashes the following commits: 1bea394 [MechCoder] make getVectors a lazy val 5522756 [MechCoder] [SPARK-9533] [PySpark] [ML] Add missing methods in Word2Vec ML --- .../apache/spark/ml/feature/Word2Vec.scala | 2 +- python/pyspark/ml/feature.py | 40 +++++++++++++++++++ 2 files changed, 41 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala index b4f46cef798d..29acc3eb5865 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala @@ -153,7 +153,7 @@ class Word2VecModel private[ml] ( * Returns a dataframe with two fields, "word" and "vector", with "word" being a String and * and the vector the DenseVector that it is mapped to. */ - val getVectors: DataFrame = { + @transient lazy val getVectors: DataFrame = { val sc = SparkContext.getOrCreate() val sqlContext = SQLContext.getOrCreate(sc) import sqlContext.implicits._ diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 3f04c41ac5ab..cb4dfa21298c 100644 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -15,11 +15,16 @@ # limitations under the License.
# +import sys +if sys.version > '3': + basestring = str + from pyspark.rdd import ignore_unicode_prefix from pyspark.ml.param.shared import * from pyspark.ml.util import keyword_only from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaTransformer from pyspark.mllib.common import inherit_doc +from pyspark.mllib.linalg import _convert_to_vector __all__ = ['Binarizer', 'HashingTF', 'IDF', 'IDFModel', 'NGram', 'Normalizer', 'OneHotEncoder', 'PolynomialExpansion', 'RegexTokenizer', 'StandardScaler', 'StandardScalerModel', @@ -954,6 +959,23 @@ class Word2Vec(JavaEstimator, HasStepSize, HasMaxIter, HasSeed, HasInputCol, Has >>> sent = ("a b " * 100 + "a c " * 10).split(" ") >>> doc = sqlContext.createDataFrame([(sent,), (sent,)], ["sentence"]) >>> model = Word2Vec(vectorSize=5, seed=42, inputCol="sentence", outputCol="model").fit(doc) + >>> model.getVectors().show() + +----+--------------------+ + |word| vector| + +----+--------------------+ + | a|[-0.3511952459812...| + | b|[0.29077222943305...| + | c|[0.02315592765808...| + +----+--------------------+ + ... + >>> model.findSynonyms("a", 2).show() + +----+-------------------+ + |word| similarity| + +----+-------------------+ + | b|0.29255685145799626| + | c|-0.5414068302988307| + +----+-------------------+ + ... >>> model.transform(doc).head().model DenseVector([-0.0422, -0.5138, -0.2546, 0.6885, 0.276]) """ @@ -1047,6 +1069,24 @@ class Word2VecModel(JavaModel): Model fitted by Word2Vec. """ + def getVectors(self): + """ + Returns the vector representation of the words as a dataframe + with two fields, word and vector. + """ + return self._call_java("getVectors") + + def findSynonyms(self, word, num): + """ + Find "num" number of words closest in similarity to "word". + word can be a string or vector representation. + Returns a dataframe with two fields word and similarity (which + gives the cosine similarity). + """ + if not isinstance(word, basestring): + word = _convert_to_vector(word) + return self._call_java("findSynonyms", word, num) + @inherit_doc class PCA(JavaEstimator, HasInputCol, HasOutputCol): From 98e69467d4fda2c26a951409b5b7c6f1e9345ce4 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Thu, 6 Aug 2015 10:29:40 -0700 Subject: [PATCH 0084/1824] [SPARK-9615] [SPARK-9616] [SQL] [MLLIB] Bugs related to FrequentItems when merging and with Tungsten In short: 1- FrequentItems should not use the InternalRow representation, because the keys in the map get messed up. For example, every key in the Map correspond to the very last element observed in the partition, when the elements are strings. 
2- Merging two partitions had a bug: **Existing behavior with size 3** Partition A -> Map(1 -> 3, 2 -> 3, 3 -> 4) Partition B -> Map(4 -> 25) Result -> Map() **Correct Behavior:** Partition A -> Map(1 -> 3, 2 -> 3, 3 -> 4) Partition B -> Map(4 -> 25) Result -> Map(3 -> 1, 4 -> 22) cc mengxr rxin JoshRosen Author: Burak Yavuz Closes #7945 from brkyvz/freq-fix and squashes the following commits: 07fa001 [Burak Yavuz] address 2 1dc61a8 [Burak Yavuz] address 1 506753e [Burak Yavuz] fixed and added reg test 47bfd50 [Burak Yavuz] pushing --- .../sql/execution/stat/FrequentItems.scala | 26 +++++++++++-------- .../apache/spark/sql/DataFrameStatSuite.scala | 24 ++++++++++++++--- 2 files changed, 36 insertions(+), 14 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala index 9329148aa233..db463029aedf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala @@ -20,17 +20,15 @@ package org.apache.spark.sql.execution.stat import scala.collection.mutable.{Map => MutableMap} import org.apache.spark.Logging -import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.types._ -import org.apache.spark.sql.{Column, DataFrame} +import org.apache.spark.sql.{Row, Column, DataFrame} private[sql] object FrequentItems extends Logging { /** A helper class wrapping `MutableMap[Any, Long]` for simplicity. */ private class FreqItemCounter(size: Int) extends Serializable { val baseMap: MutableMap[Any, Long] = MutableMap.empty[Any, Long] - /** * Add a new example to the counts if it exists, otherwise deduct the count * from existing items. @@ -42,9 +40,15 @@ private[sql] object FrequentItems extends Logging { if (baseMap.size < size) { baseMap += key -> count } else { - // TODO: Make this more efficient... A flatMap? 
- baseMap.retain((k, v) => v > count) - baseMap.transform((k, v) => v - count) + val minCount = baseMap.values.min + val remainder = count - minCount + if (remainder >= 0) { + baseMap += key -> count // something will get kicked out, so we can add this + baseMap.retain((k, v) => v > minCount) + baseMap.transform((k, v) => v - minCount) + } else { + baseMap.transform((k, v) => v - count) + } } } this @@ -90,12 +94,12 @@ private[sql] object FrequentItems extends Logging { (name, originalSchema.fields(index).dataType) }.toArray - val freqItems = df.select(cols.map(Column(_)) : _*).queryExecution.toRdd.aggregate(countMaps)( + val freqItems = df.select(cols.map(Column(_)) : _*).rdd.aggregate(countMaps)( seqOp = (counts, row) => { var i = 0 while (i < numCols) { val thisMap = counts(i) - val key = row.get(i, colInfo(i)._2) + val key = row.get(i) thisMap.add(key, 1L) i += 1 } @@ -110,13 +114,13 @@ private[sql] object FrequentItems extends Logging { baseCounts } ) - val justItems = freqItems.map(m => m.baseMap.keys.toArray).map(new GenericArrayData(_)) - val resultRow = InternalRow(justItems : _*) + val justItems = freqItems.map(m => m.baseMap.keys.toArray) + val resultRow = Row(justItems : _*) // append frequent Items to the column name for easy debugging val outputCols = colInfo.map { v => StructField(v._1 + "_freqItems", ArrayType(v._2, false)) } val schema = StructType(outputCols).toAttributes - new DataFrame(df.sqlContext, LocalRelation(schema, Seq(resultRow))) + new DataFrame(df.sqlContext, LocalRelation.fromExternalRows(schema, Seq(resultRow))) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index 07a675e64f52..0e7659f443ec 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -123,12 +123,30 @@ class DataFrameStatSuite extends QueryTest { val results = df.stat.freqItems(Array("numbers", "letters"), 0.1) val items = results.collect().head - items.getSeq[Int](0) should contain (1) - items.getSeq[String](1) should contain (toLetter(1)) + assert(items.getSeq[Int](0).contains(1)) + assert(items.getSeq[String](1).contains(toLetter(1))) val singleColResults = df.stat.freqItems(Array("negDoubles"), 0.1) val items2 = singleColResults.collect().head - items2.getSeq[Double](0) should contain (-1.0) + assert(items2.getSeq[Double](0).contains(-1.0)) + } + + test("Frequent Items 2") { + val rows = sqlCtx.sparkContext.parallelize(Seq.empty[Int], 4) + // this is a regression test, where when merging partitions, we omitted values with higher + // counts than those that existed in the map when the map was full. This test should also fail + // if anything like SPARK-9614 is observed once again + val df = rows.mapPartitionsWithIndex { (idx, iter) => + if (idx == 3) { // must come from one of the later merges, therefore higher partition index + Iterator("3", "3", "3", "3", "3") + } else { + Iterator("0", "1", "2", "3", "4") + } + }.toDF("a") + val results = df.stat.freqItems(Array("a"), 0.25) + val items = results.collect().head.getSeq[String](0) + assert(items.contains("3")) + assert(items.length === 1) } test("sampleBy") { From 5e1b0ef07942a041195b3decd05d86c289bc8d2b Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 6 Aug 2015 10:39:16 -0700 Subject: [PATCH 0085/1824] [SPARK-9659][SQL] Rename inSet to isin to match Pandas function. 
Inspiration drawn from this blog post: https://lab.getbase.com/pandarize-spark-dataframes/ Author: Reynold Xin Closes #7977 from rxin/isin and squashes the following commits: 9b1d3d6 [Reynold Xin] Added return. 2197d37 [Reynold Xin] Fixed test case. 7c1b6cf [Reynold Xin] Import warnings. 4f4a35d [Reynold Xin] [SPARK-9659][SQL] Rename inSet to isin to match Pandas function. --- python/pyspark/sql/column.py | 20 ++++++++++++++++++- .../scala/org/apache/spark/sql/Column.scala | 13 +++++++++++- .../spark/sql/ColumnExpressionSuite.scala | 14 ++++++------- 3 files changed, 38 insertions(+), 9 deletions(-) diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py index 0a85da7443d3..8af8637cf948 100644 --- a/python/pyspark/sql/column.py +++ b/python/pyspark/sql/column.py @@ -16,6 +16,7 @@ # import sys +import warnings if sys.version >= '3': basestring = str @@ -254,12 +255,29 @@ def inSet(self, *cols): [Row(age=5, name=u'Bob')] >>> df[df.age.inSet([1, 2, 3])].collect() [Row(age=2, name=u'Alice')] + + .. note:: Deprecated in 1.5, use :func:`Column.isin` instead. + """ + warnings.warn("inSet is deprecated. Use isin() instead.") + return self.isin(*cols) + + @ignore_unicode_prefix + @since(1.5) + def isin(self, *cols): + """ + A boolean expression that is evaluated to true if the value of this + expression is contained by the evaluated values of the arguments. + + >>> df[df.name.isin("Bob", "Mike")].collect() + [Row(age=5, name=u'Bob')] + >>> df[df.age.isin([1, 2, 3])].collect() + [Row(age=2, name=u'Alice')] """ if len(cols) == 1 and isinstance(cols[0], (list, set)): cols = cols[0] cols = [c._jc if isinstance(c, Column) else _create_column_from_literal(c) for c in cols] sc = SparkContext._active_spark_context - jc = getattr(self._jc, "in")(_to_seq(sc, cols)) + jc = getattr(self._jc, "isin")(_to_seq(sc, cols)) return Column(jc) # order diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index b25dcbca82b9..75365fbcec75 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -627,8 +627,19 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ + @deprecated("use isin", "1.5.0") @scala.annotation.varargs - def in(list: Any*): Column = In(expr, list.map(lit(_).expr)) + def in(list: Any*): Column = isin(list : _*) + + /** + * A boolean expression that is evaluated to true if the value of this expression is contained + * by the evaluated values of the arguments. + * + * @group expr_ops + * @since 1.5.0 + */ + @scala.annotation.varargs + def isin(list: Any*): Column = In(expr, list.map(lit(_).expr)) /** * SQL like expression. 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index b35138037325..e1b3443d7499 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -345,23 +345,23 @@ class ColumnExpressionSuite extends QueryTest with SQLTestUtils { test("in") { val df = Seq((1, "x"), (2, "y"), (3, "z")).toDF("a", "b") - checkAnswer(df.filter($"a".in(1, 2)), + checkAnswer(df.filter($"a".isin(1, 2)), df.collect().toSeq.filter(r => r.getInt(0) == 1 || r.getInt(0) == 2)) - checkAnswer(df.filter($"a".in(3, 2)), + checkAnswer(df.filter($"a".isin(3, 2)), df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 2)) - checkAnswer(df.filter($"a".in(3, 1)), + checkAnswer(df.filter($"a".isin(3, 1)), df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 1)) - checkAnswer(df.filter($"b".in("y", "x")), + checkAnswer(df.filter($"b".isin("y", "x")), df.collect().toSeq.filter(r => r.getString(1) == "y" || r.getString(1) == "x")) - checkAnswer(df.filter($"b".in("z", "x")), + checkAnswer(df.filter($"b".isin("z", "x")), df.collect().toSeq.filter(r => r.getString(1) == "z" || r.getString(1) == "x")) - checkAnswer(df.filter($"b".in("z", "y")), + checkAnswer(df.filter($"b".isin("z", "y")), df.collect().toSeq.filter(r => r.getString(1) == "z" || r.getString(1) == "y")) val df2 = Seq((1, Seq(1)), (2, Seq(2)), (3, Seq(3))).toDF("a", "b") intercept[AnalysisException] { - df2.filter($"a".in($"b")) + df2.filter($"a".isin($"b")) } } From 6e009cb9c4d7a395991e10dab427f37019283758 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 6 Aug 2015 10:40:54 -0700 Subject: [PATCH 0086/1824] [SPARK-9632][SQL] update InternalRow.toSeq to make it accept data type info Author: Wenchen Fan Closes #7955 from cloud-fan/toSeq and squashes the following commits: 21665e2 [Wenchen Fan] fix hive again... 
4addf29 [Wenchen Fan] fix hive bc16c59 [Wenchen Fan] minor fix 33d802c [Wenchen Fan] pass data type info to InternalRow.toSeq 3dd033e [Wenchen Fan] move the default special getters implementation from InternalRow to BaseGenericInternalRow --- .../spark/sql/catalyst/InternalRow.scala | 132 ++---------------- .../sql/catalyst/expressions/Projection.scala | 12 +- .../expressions/SpecificMutableRow.scala | 5 +- .../codegen/GenerateProjection.scala | 8 +- .../spark/sql/catalyst/expressions/rows.scala | 132 +++++++++++++++++- .../expressions/CodeGenerationSuite.scala | 2 +- .../spark/sql/columnar/ColumnStats.scala | 51 +++---- .../columnar/InMemoryColumnarTableScan.scala | 11 +- .../spark/sql/execution/debug/package.scala | 4 +- .../apache/spark/sql/sources/interfaces.scala | 4 +- .../spark/sql/columnar/ColumnStatsSuite.scala | 54 +++---- .../spark/sql/hive/HiveInspectors.scala | 6 +- .../hive/execution/ScriptTransformation.scala | 21 ++- .../spark/sql/hive/hiveWriterContainers.scala | 24 ++-- .../spark/sql/hive/HiveInspectorSuite.scala | 10 +- 15 files changed, 259 insertions(+), 217 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala index 7d17cca80879..85b4bf3b6aef 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala @@ -18,8 +18,7 @@ package org.apache.spark.sql.catalyst import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.types.{DataType, MapData, ArrayData, Decimal} -import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} +import org.apache.spark.sql.types.{DataType, StructType} /** * An abstract class for row used internal in Spark SQL, which only contain the columns as @@ -32,8 +31,6 @@ abstract class InternalRow extends SpecializedGetters with Serializable { // This is only use for test and will throw a null pointer exception if the position is null. def getString(ordinal: Int): String = getUTF8String(ordinal).toString - override def toString: String = mkString("[", ",", "]") - /** * Make a copy of the current [[InternalRow]] object. */ @@ -50,136 +47,25 @@ abstract class InternalRow extends SpecializedGetters with Serializable { false } - // Subclasses of InternalRow should implement all special getters and equals/hashCode, - // or implement this genericGet. 
- protected def genericGet(ordinal: Int): Any = throw new IllegalStateException( - "Concrete internal rows should implement genericGet, " + - "or implement all special getters and equals/hashCode") - - // default implementation (slow) - private def getAs[T](ordinal: Int) = genericGet(ordinal).asInstanceOf[T] - override def isNullAt(ordinal: Int): Boolean = getAs[AnyRef](ordinal) eq null - override def get(ordinal: Int, dataType: DataType): AnyRef = getAs(ordinal) - override def getBoolean(ordinal: Int): Boolean = getAs(ordinal) - override def getByte(ordinal: Int): Byte = getAs(ordinal) - override def getShort(ordinal: Int): Short = getAs(ordinal) - override def getInt(ordinal: Int): Int = getAs(ordinal) - override def getLong(ordinal: Int): Long = getAs(ordinal) - override def getFloat(ordinal: Int): Float = getAs(ordinal) - override def getDouble(ordinal: Int): Double = getAs(ordinal) - override def getDecimal(ordinal: Int, precision: Int, scale: Int): Decimal = getAs(ordinal) - override def getUTF8String(ordinal: Int): UTF8String = getAs(ordinal) - override def getBinary(ordinal: Int): Array[Byte] = getAs(ordinal) - override def getArray(ordinal: Int): ArrayData = getAs(ordinal) - override def getInterval(ordinal: Int): CalendarInterval = getAs(ordinal) - override def getMap(ordinal: Int): MapData = getAs(ordinal) - override def getStruct(ordinal: Int, numFields: Int): InternalRow = getAs(ordinal) - - override def equals(o: Any): Boolean = { - if (!o.isInstanceOf[InternalRow]) { - return false - } - - val other = o.asInstanceOf[InternalRow] - if (other eq null) { - return false - } - - val len = numFields - if (len != other.numFields) { - return false - } - - var i = 0 - while (i < len) { - if (isNullAt(i) != other.isNullAt(i)) { - return false - } - if (!isNullAt(i)) { - val o1 = genericGet(i) - val o2 = other.genericGet(i) - o1 match { - case b1: Array[Byte] => - if (!o2.isInstanceOf[Array[Byte]] || - !java.util.Arrays.equals(b1, o2.asInstanceOf[Array[Byte]])) { - return false - } - case f1: Float if java.lang.Float.isNaN(f1) => - if (!o2.isInstanceOf[Float] || ! java.lang.Float.isNaN(o2.asInstanceOf[Float])) { - return false - } - case d1: Double if java.lang.Double.isNaN(d1) => - if (!o2.isInstanceOf[Double] || ! java.lang.Double.isNaN(o2.asInstanceOf[Double])) { - return false - } - case _ => if (o1 != o2) { - return false - } - } - } - i += 1 - } - true - } - - // Custom hashCode function that matches the efficient code generated version. - override def hashCode: Int = { - var result: Int = 37 - var i = 0 - val len = numFields - while (i < len) { - val update: Int = - if (isNullAt(i)) { - 0 - } else { - genericGet(i) match { - case b: Boolean => if (b) 0 else 1 - case b: Byte => b.toInt - case s: Short => s.toInt - case i: Int => i - case l: Long => (l ^ (l >>> 32)).toInt - case f: Float => java.lang.Float.floatToIntBits(f) - case d: Double => - val b = java.lang.Double.doubleToLongBits(d) - (b ^ (b >>> 32)).toInt - case a: Array[Byte] => java.util.Arrays.hashCode(a) - case other => other.hashCode() - } - } - result = 37 * result + update - i += 1 - } - result - } - /* ---------------------- utility methods for Scala ---------------------- */ /** * Return a Scala Seq representing the row. Elements are placed in the same order in the Seq. 
*/ - // todo: remove this as it needs the generic getter - def toSeq: Seq[Any] = { - val n = numFields - val values = new Array[Any](n) + def toSeq(fieldTypes: Seq[DataType]): Seq[Any] = { + val len = numFields + assert(len == fieldTypes.length) + + val values = new Array[Any](len) var i = 0 - while (i < n) { - values.update(i, genericGet(i)) + while (i < len) { + values(i) = get(i, fieldTypes(i)) i += 1 } values } - /** Displays all elements of this sequence in a string (without a separator). */ - def mkString: String = toSeq.mkString - - /** Displays all elements of this sequence in a string using a separator string. */ - def mkString(sep: String): String = toSeq.mkString(sep) - - /** - * Displays all elements of this traversable or iterator in a string using - * start, end, and separator strings. - */ - def mkString(start: String, sep: String, end: String): String = toSeq.mkString(start, sep, end) + def toSeq(schema: StructType): Seq[Any] = toSeq(schema.map(_.dataType)) } object InternalRow { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index 4296b4b123fc..59ce7fc4f2c6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -203,7 +203,11 @@ class JoinedRow extends InternalRow { this } - override def toSeq: Seq[Any] = row1.toSeq ++ row2.toSeq + override def toSeq(fieldTypes: Seq[DataType]): Seq[Any] = { + assert(fieldTypes.length == row1.numFields + row2.numFields) + val (left, right) = fieldTypes.splitAt(row1.numFields) + row1.toSeq(left) ++ row2.toSeq(right) + } override def numFields: Int = row1.numFields + row2.numFields @@ -276,11 +280,11 @@ class JoinedRow extends InternalRow { if ((row1 eq null) && (row2 eq null)) { "[ empty row ]" } else if (row1 eq null) { - row2.mkString("[", ",", "]") + row2.toString } else if (row2 eq null) { - row1.mkString("[", ",", "]") + row1.toString } else { - mkString("[", ",", "]") + s"{${row1.toString} + ${row2.toString}}" } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala index b94df6bd66e0..4f56f94bd4ca 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala @@ -192,7 +192,8 @@ final class MutableAny extends MutableValue { * based on the dataTypes of each column. The intent is to decrease garbage when modifying the * values of primitive columns. 
*/ -final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableRow { +final class SpecificMutableRow(val values: Array[MutableValue]) + extends MutableRow with BaseGenericInternalRow { def this(dataTypes: Seq[DataType]) = this( @@ -213,8 +214,6 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR override def numFields: Int = values.length - override def toSeq: Seq[Any] = values.map(_.boxed) - override def setNullAt(i: Int): Unit = { values(i).isNull = true } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala index c04fe734d554..c744e84d822e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.expressions.codegen +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ @@ -25,6 +26,8 @@ import org.apache.spark.sql.types._ */ abstract class BaseProjection extends Projection {} +abstract class CodeGenMutableRow extends MutableRow with BaseGenericInternalRow + /** * Generates bytecode that produces a new [[InternalRow]] object based on a fixed set of input * [[Expression Expressions]] and a given input [[InternalRow]]. The returned [[InternalRow]] @@ -171,7 +174,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { return new SpecificRow((InternalRow) r); } - final class SpecificRow extends ${classOf[MutableRow].getName} { + final class SpecificRow extends ${classOf[CodeGenMutableRow].getName} { $columns @@ -184,7 +187,8 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { public void setNullAt(int i) { nullBits[i] = true; } public boolean isNullAt(int i) { return nullBits[i]; } - protected Object genericGet(int i) { + @Override + public Object genericGet(int i) { if (isNullAt(i)) return null; switch (i) { $getCases diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala index 7657fb535dcf..207e66779266 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala @@ -21,6 +21,130 @@ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types._ +/** + * An extended version of [[InternalRow]] that implements all special getters, toString + * and equals/hashCode by `genericGet`. 
+ */ +trait BaseGenericInternalRow extends InternalRow { + + protected def genericGet(ordinal: Int): Any + + // default implementation (slow) + private def getAs[T](ordinal: Int) = genericGet(ordinal).asInstanceOf[T] + override def isNullAt(ordinal: Int): Boolean = getAs[AnyRef](ordinal) eq null + override def get(ordinal: Int, dataType: DataType): AnyRef = getAs(ordinal) + override def getBoolean(ordinal: Int): Boolean = getAs(ordinal) + override def getByte(ordinal: Int): Byte = getAs(ordinal) + override def getShort(ordinal: Int): Short = getAs(ordinal) + override def getInt(ordinal: Int): Int = getAs(ordinal) + override def getLong(ordinal: Int): Long = getAs(ordinal) + override def getFloat(ordinal: Int): Float = getAs(ordinal) + override def getDouble(ordinal: Int): Double = getAs(ordinal) + override def getDecimal(ordinal: Int, precision: Int, scale: Int): Decimal = getAs(ordinal) + override def getUTF8String(ordinal: Int): UTF8String = getAs(ordinal) + override def getBinary(ordinal: Int): Array[Byte] = getAs(ordinal) + override def getArray(ordinal: Int): ArrayData = getAs(ordinal) + override def getInterval(ordinal: Int): CalendarInterval = getAs(ordinal) + override def getMap(ordinal: Int): MapData = getAs(ordinal) + override def getStruct(ordinal: Int, numFields: Int): InternalRow = getAs(ordinal) + + override def toString(): String = { + if (numFields == 0) { + "[empty row]" + } else { + val sb = new StringBuilder + sb.append("[") + sb.append(genericGet(0)) + val len = numFields + var i = 1 + while (i < len) { + sb.append(",") + sb.append(genericGet(i)) + i += 1 + } + sb.append("]") + sb.toString() + } + } + + override def equals(o: Any): Boolean = { + if (!o.isInstanceOf[BaseGenericInternalRow]) { + return false + } + + val other = o.asInstanceOf[BaseGenericInternalRow] + if (other eq null) { + return false + } + + val len = numFields + if (len != other.numFields) { + return false + } + + var i = 0 + while (i < len) { + if (isNullAt(i) != other.isNullAt(i)) { + return false + } + if (!isNullAt(i)) { + val o1 = genericGet(i) + val o2 = other.genericGet(i) + o1 match { + case b1: Array[Byte] => + if (!o2.isInstanceOf[Array[Byte]] || + !java.util.Arrays.equals(b1, o2.asInstanceOf[Array[Byte]])) { + return false + } + case f1: Float if java.lang.Float.isNaN(f1) => + if (!o2.isInstanceOf[Float] || ! java.lang.Float.isNaN(o2.asInstanceOf[Float])) { + return false + } + case d1: Double if java.lang.Double.isNaN(d1) => + if (!o2.isInstanceOf[Double] || ! java.lang.Double.isNaN(o2.asInstanceOf[Double])) { + return false + } + case _ => if (o1 != o2) { + return false + } + } + } + i += 1 + } + true + } + + // Custom hashCode function that matches the efficient code generated version. + override def hashCode: Int = { + var result: Int = 37 + var i = 0 + val len = numFields + while (i < len) { + val update: Int = + if (isNullAt(i)) { + 0 + } else { + genericGet(i) match { + case b: Boolean => if (b) 0 else 1 + case b: Byte => b.toInt + case s: Short => s.toInt + case i: Int => i + case l: Long => (l ^ (l >>> 32)).toInt + case f: Float => java.lang.Float.floatToIntBits(f) + case d: Double => + val b = java.lang.Double.doubleToLongBits(d) + (b ^ (b >>> 32)).toInt + case a: Array[Byte] => java.util.Arrays.hashCode(a) + case other => other.hashCode() + } + } + result = 37 * result + update + i += 1 + } + result + } +} + /** * An extended interface to [[InternalRow]] that allows the values for each column to be updated. 
* Setting a value through a primitive function implicitly marks that column as not null. @@ -82,7 +206,7 @@ class GenericRowWithSchema(values: Array[Any], override val schema: StructType) * Note that, while the array is not copied, and thus could technically be mutated after creation, * this is not allowed. */ -class GenericInternalRow(private[sql] val values: Array[Any]) extends InternalRow { +class GenericInternalRow(private[sql] val values: Array[Any]) extends BaseGenericInternalRow { /** No-arg constructor for serialization. */ protected def this() = this(null) @@ -90,7 +214,7 @@ class GenericInternalRow(private[sql] val values: Array[Any]) extends InternalRo override protected def genericGet(ordinal: Int) = values(ordinal) - override def toSeq: Seq[Any] = values + override def toSeq(fieldTypes: Seq[DataType]): Seq[Any] = values override def numFields: Int = values.length @@ -109,7 +233,7 @@ class GenericInternalRowWithSchema(values: Array[Any], val schema: StructType) def fieldIndex(name: String): Int = schema.fieldIndex(name) } -class GenericMutableRow(values: Array[Any]) extends MutableRow { +class GenericMutableRow(values: Array[Any]) extends MutableRow with BaseGenericInternalRow { /** No-arg constructor for serialization. */ protected def this() = this(null) @@ -117,7 +241,7 @@ class GenericMutableRow(values: Array[Any]) extends MutableRow { override protected def genericGet(ordinal: Int) = values(ordinal) - override def toSeq: Seq[Any] = values + override def toSeq(fieldTypes: Seq[DataType]): Seq[Any] = values override def numFields: Int = values.length diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index e310aee22166..e323467af5f4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -87,7 +87,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { val length = 5000 val expressions = List.fill(length)(EqualTo(Literal(1), Literal(1))) val plan = GenerateMutableProjection.generate(expressions)() - val actual = plan(new GenericMutableRow(length)).toSeq + val actual = plan(new GenericMutableRow(length)).toSeq(expressions.map(_.dataType)) val expected = Seq.fill(length)(true) if (!checkResult(actual, expected)) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala index af1a8ecca9b5..5cbd52bc0590 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.columnar import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference} +import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, Attribute, AttributeMap, AttributeReference} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -66,7 +66,7 @@ private[sql] sealed trait ColumnStats extends Serializable { * Column statistics represented as a single row, currently including closed lower bound, closed * upper bound and null count. 
*/ - def collectedStatistics: InternalRow + def collectedStatistics: GenericInternalRow } /** @@ -75,7 +75,8 @@ private[sql] sealed trait ColumnStats extends Serializable { private[sql] class NoopColumnStats extends ColumnStats { override def gatherStats(row: InternalRow, ordinal: Int): Unit = super.gatherStats(row, ordinal) - override def collectedStatistics: InternalRow = InternalRow(null, null, nullCount, count, 0L) + override def collectedStatistics: GenericInternalRow = + new GenericInternalRow(Array[Any](null, null, nullCount, count, 0L)) } private[sql] class BooleanColumnStats extends ColumnStats { @@ -92,8 +93,8 @@ private[sql] class BooleanColumnStats extends ColumnStats { } } - override def collectedStatistics: InternalRow = - InternalRow(lower, upper, nullCount, count, sizeInBytes) + override def collectedStatistics: GenericInternalRow = + new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) } private[sql] class ByteColumnStats extends ColumnStats { @@ -110,8 +111,8 @@ private[sql] class ByteColumnStats extends ColumnStats { } } - override def collectedStatistics: InternalRow = - InternalRow(lower, upper, nullCount, count, sizeInBytes) + override def collectedStatistics: GenericInternalRow = + new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) } private[sql] class ShortColumnStats extends ColumnStats { @@ -128,8 +129,8 @@ private[sql] class ShortColumnStats extends ColumnStats { } } - override def collectedStatistics: InternalRow = - InternalRow(lower, upper, nullCount, count, sizeInBytes) + override def collectedStatistics: GenericInternalRow = + new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) } private[sql] class IntColumnStats extends ColumnStats { @@ -146,8 +147,8 @@ private[sql] class IntColumnStats extends ColumnStats { } } - override def collectedStatistics: InternalRow = - InternalRow(lower, upper, nullCount, count, sizeInBytes) + override def collectedStatistics: GenericInternalRow = + new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) } private[sql] class LongColumnStats extends ColumnStats { @@ -164,8 +165,8 @@ private[sql] class LongColumnStats extends ColumnStats { } } - override def collectedStatistics: InternalRow = - InternalRow(lower, upper, nullCount, count, sizeInBytes) + override def collectedStatistics: GenericInternalRow = + new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) } private[sql] class FloatColumnStats extends ColumnStats { @@ -182,8 +183,8 @@ private[sql] class FloatColumnStats extends ColumnStats { } } - override def collectedStatistics: InternalRow = - InternalRow(lower, upper, nullCount, count, sizeInBytes) + override def collectedStatistics: GenericInternalRow = + new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) } private[sql] class DoubleColumnStats extends ColumnStats { @@ -200,8 +201,8 @@ private[sql] class DoubleColumnStats extends ColumnStats { } } - override def collectedStatistics: InternalRow = - InternalRow(lower, upper, nullCount, count, sizeInBytes) + override def collectedStatistics: GenericInternalRow = + new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) } private[sql] class StringColumnStats extends ColumnStats { @@ -218,8 +219,8 @@ private[sql] class StringColumnStats extends ColumnStats { } } - override def collectedStatistics: InternalRow = - InternalRow(lower, upper, nullCount, count, sizeInBytes) + override def 
collectedStatistics: GenericInternalRow = + new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) } private[sql] class BinaryColumnStats extends ColumnStats { @@ -230,8 +231,8 @@ private[sql] class BinaryColumnStats extends ColumnStats { } } - override def collectedStatistics: InternalRow = - InternalRow(null, null, nullCount, count, sizeInBytes) + override def collectedStatistics: GenericInternalRow = + new GenericInternalRow(Array[Any](null, null, nullCount, count, sizeInBytes)) } private[sql] class FixedDecimalColumnStats(precision: Int, scale: Int) extends ColumnStats { @@ -248,8 +249,8 @@ private[sql] class FixedDecimalColumnStats(precision: Int, scale: Int) extends C } } - override def collectedStatistics: InternalRow = - InternalRow(lower, upper, nullCount, count, sizeInBytes) + override def collectedStatistics: GenericInternalRow = + new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) } private[sql] class GenericColumnStats(dataType: DataType) extends ColumnStats { @@ -262,8 +263,8 @@ private[sql] class GenericColumnStats(dataType: DataType) extends ColumnStats { } } - override def collectedStatistics: InternalRow = - InternalRow(null, null, nullCount, count, sizeInBytes) + override def collectedStatistics: GenericInternalRow = + new GenericInternalRow(Array[Any](null, null, nullCount, count, sizeInBytes)) } private[sql] class DateColumnStats extends IntColumnStats diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala index 5d5b0697d701..d553bb6169ec 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala @@ -148,7 +148,7 @@ private[sql] case class InMemoryRelation( } val stats = InternalRow.fromSeq(columnBuilders.map(_.columnStats.collectedStatistics) - .flatMap(_.toSeq)) + .flatMap(_.values)) batchStats += stats CachedBatch(columnBuilders.map(_.build().array()), stats) @@ -330,10 +330,11 @@ private[sql] case class InMemoryColumnarTableScan( if (inMemoryPartitionPruningEnabled) { cachedBatchIterator.filter { cachedBatch => if (!partitionFilter(cachedBatch.stats)) { - def statsString: String = relation.partitionStatistics.schema - .zip(cachedBatch.stats.toSeq) - .map { case (a, s) => s"${a.name}: $s" } - .mkString(", ") + def statsString: String = relation.partitionStatistics.schema.zipWithIndex.map { + case (a, i) => + val value = cachedBatch.stats.get(i, a.dataType) + s"${a.name}: $value" + }.mkString(", ") logInfo(s"Skipping partition based on stats $statsString") false } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala index c37007f1eece..dd3858ea2b52 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala @@ -156,8 +156,8 @@ package object debug { def typeCheck(data: Any, schema: DataType): Unit = (data, schema) match { case (null, _) => - case (row: InternalRow, StructType(fields)) => - row.toSeq.zip(fields.map(_.dataType)).foreach { case(d, t) => typeCheck(d, t) } + case (row: InternalRow, s: StructType) => + row.toSeq(s).zip(s.map(_.dataType)).foreach { case(d, t) => typeCheck(d, t) } case (a: ArrayData, ArrayType(elemType, 
_)) => a.foreach(elemType, (_, e) => { typeCheck(e, elemType) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index 7126145ddc01..c04557e5a081 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -461,8 +461,8 @@ abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[Partitio val spec = discoverPartitions() val partitionColumnTypes = spec.partitionColumns.map(_.dataType) val castedPartitions = spec.partitions.map { case p @ Partition(values, path) => - val literals = values.toSeq.zip(partitionColumnTypes).map { - case (value, dataType) => Literal.create(value, dataType) + val literals = partitionColumnTypes.zipWithIndex.map { case (dt, i) => + Literal.create(values.get(i, dt), dt) } val castedValues = partitionSchema.zip(literals).map { case (field, literal) => Cast(literal, field.dataType).eval() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala index 16e0187ed20a..d0430d2a60e7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala @@ -19,33 +19,36 @@ package org.apache.spark.sql.columnar import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow import org.apache.spark.sql.types._ class ColumnStatsSuite extends SparkFunSuite { - testColumnStats(classOf[BooleanColumnStats], BOOLEAN, InternalRow(true, false, 0)) - testColumnStats(classOf[ByteColumnStats], BYTE, InternalRow(Byte.MaxValue, Byte.MinValue, 0)) - testColumnStats(classOf[ShortColumnStats], SHORT, InternalRow(Short.MaxValue, Short.MinValue, 0)) - testColumnStats(classOf[IntColumnStats], INT, InternalRow(Int.MaxValue, Int.MinValue, 0)) - testColumnStats(classOf[DateColumnStats], DATE, InternalRow(Int.MaxValue, Int.MinValue, 0)) - testColumnStats(classOf[LongColumnStats], LONG, InternalRow(Long.MaxValue, Long.MinValue, 0)) + testColumnStats(classOf[BooleanColumnStats], BOOLEAN, createRow(true, false, 0)) + testColumnStats(classOf[ByteColumnStats], BYTE, createRow(Byte.MaxValue, Byte.MinValue, 0)) + testColumnStats(classOf[ShortColumnStats], SHORT, createRow(Short.MaxValue, Short.MinValue, 0)) + testColumnStats(classOf[IntColumnStats], INT, createRow(Int.MaxValue, Int.MinValue, 0)) + testColumnStats(classOf[DateColumnStats], DATE, createRow(Int.MaxValue, Int.MinValue, 0)) + testColumnStats(classOf[LongColumnStats], LONG, createRow(Long.MaxValue, Long.MinValue, 0)) testColumnStats(classOf[TimestampColumnStats], TIMESTAMP, - InternalRow(Long.MaxValue, Long.MinValue, 0)) - testColumnStats(classOf[FloatColumnStats], FLOAT, InternalRow(Float.MaxValue, Float.MinValue, 0)) + createRow(Long.MaxValue, Long.MinValue, 0)) + testColumnStats(classOf[FloatColumnStats], FLOAT, createRow(Float.MaxValue, Float.MinValue, 0)) testColumnStats(classOf[DoubleColumnStats], DOUBLE, - InternalRow(Double.MaxValue, Double.MinValue, 0)) - testColumnStats(classOf[StringColumnStats], STRING, InternalRow(null, null, 0)) - testDecimalColumnStats(InternalRow(null, null, 0)) + createRow(Double.MaxValue, Double.MinValue, 0)) + testColumnStats(classOf[StringColumnStats], STRING, createRow(null, null, 0)) + 
testDecimalColumnStats(createRow(null, null, 0)) + + def createRow(values: Any*): GenericInternalRow = new GenericInternalRow(values.toArray) def testColumnStats[T <: AtomicType, U <: ColumnStats]( columnStatsClass: Class[U], columnType: NativeColumnType[T], - initialStatistics: InternalRow): Unit = { + initialStatistics: GenericInternalRow): Unit = { val columnStatsName = columnStatsClass.getSimpleName test(s"$columnStatsName: empty") { val columnStats = columnStatsClass.newInstance() - columnStats.collectedStatistics.toSeq.zip(initialStatistics.toSeq).foreach { + columnStats.collectedStatistics.values.zip(initialStatistics.values).foreach { case (actual, expected) => assert(actual === expected) } } @@ -61,11 +64,11 @@ class ColumnStatsSuite extends SparkFunSuite { val ordering = columnType.dataType.ordering.asInstanceOf[Ordering[T#InternalType]] val stats = columnStats.collectedStatistics - assertResult(values.min(ordering), "Wrong lower bound")(stats.get(0, null)) - assertResult(values.max(ordering), "Wrong upper bound")(stats.get(1, null)) - assertResult(10, "Wrong null count")(stats.get(2, null)) - assertResult(20, "Wrong row count")(stats.get(3, null)) - assertResult(stats.get(4, null), "Wrong size in bytes") { + assertResult(values.min(ordering), "Wrong lower bound")(stats.values(0)) + assertResult(values.max(ordering), "Wrong upper bound")(stats.values(1)) + assertResult(10, "Wrong null count")(stats.values(2)) + assertResult(20, "Wrong row count")(stats.values(3)) + assertResult(stats.values(4), "Wrong size in bytes") { rows.map { row => if (row.isNullAt(0)) 4 else columnType.actualSize(row, 0) }.sum @@ -73,14 +76,15 @@ class ColumnStatsSuite extends SparkFunSuite { } } - def testDecimalColumnStats[T <: AtomicType, U <: ColumnStats](initialStatistics: InternalRow) { + def testDecimalColumnStats[T <: AtomicType, U <: ColumnStats]( + initialStatistics: GenericInternalRow): Unit = { val columnStatsName = classOf[FixedDecimalColumnStats].getSimpleName val columnType = FIXED_DECIMAL(15, 10) test(s"$columnStatsName: empty") { val columnStats = new FixedDecimalColumnStats(15, 10) - columnStats.collectedStatistics.toSeq.zip(initialStatistics.toSeq).foreach { + columnStats.collectedStatistics.values.zip(initialStatistics.values).foreach { case (actual, expected) => assert(actual === expected) } } @@ -96,11 +100,11 @@ class ColumnStatsSuite extends SparkFunSuite { val ordering = columnType.dataType.ordering.asInstanceOf[Ordering[T#InternalType]] val stats = columnStats.collectedStatistics - assertResult(values.min(ordering), "Wrong lower bound")(stats.get(0, null)) - assertResult(values.max(ordering), "Wrong upper bound")(stats.get(1, null)) - assertResult(10, "Wrong null count")(stats.get(2, null)) - assertResult(20, "Wrong row count")(stats.get(3, null)) - assertResult(stats.get(4, null), "Wrong size in bytes") { + assertResult(values.min(ordering), "Wrong lower bound")(stats.values(0)) + assertResult(values.max(ordering), "Wrong upper bound")(stats.values(1)) + assertResult(10, "Wrong null count")(stats.values(2)) + assertResult(20, "Wrong row count")(stats.values(3)) + assertResult(stats.values(4), "Wrong size in bytes") { rows.map { row => if (row.isNullAt(0)) 4 else columnType.actualSize(row, 0) }.sum diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index 39d798d072ae..9824dad23959 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ 
b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -390,8 +390,10 @@ private[hive] trait HiveInspectors { (o: Any) => { if (o != null) { val struct = soi.create() - (soi.getAllStructFieldRefs, wrappers, o.asInstanceOf[InternalRow].toSeq).zipped.foreach { - (field, wrapper, data) => soi.setStructFieldData(struct, field, wrapper(data)) + val row = o.asInstanceOf[InternalRow] + soi.getAllStructFieldRefs.zip(wrappers).zipWithIndex.foreach { + case ((field, wrapper), i) => + soi.setStructFieldData(struct, field, wrapper(row.get(i, schema(i).dataType))) } struct } else { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala index a6a343d39599..ade27454b9d2 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala @@ -88,6 +88,7 @@ case class ScriptTransformation( // external process. That process's output will be read by this current thread. val writerThread = new ScriptTransformationWriterThread( inputIterator, + input.map(_.dataType), outputProjection, inputSerde, inputSoi, @@ -201,6 +202,7 @@ case class ScriptTransformation( private class ScriptTransformationWriterThread( iter: Iterator[InternalRow], + inputSchema: Seq[DataType], outputProjection: Projection, @Nullable inputSerde: AbstractSerDe, @Nullable inputSoi: ObjectInspector, @@ -226,12 +228,25 @@ private class ScriptTransformationWriterThread( // We can't use Utils.tryWithSafeFinally here because we also need a `catch` block, so // let's use a variable to record whether the `finally` block was hit due to an exception var threwException: Boolean = true + val len = inputSchema.length try { iter.map(outputProjection).foreach { row => if (inputSerde == null) { - val data = row.mkString("", ioschema.inputRowFormatMap("TOK_TABLEROWFORMATFIELD"), - ioschema.inputRowFormatMap("TOK_TABLEROWFORMATLINES")).getBytes("utf-8") - outputStream.write(data) + val data = if (len == 0) { + ioschema.inputRowFormatMap("TOK_TABLEROWFORMATLINES") + } else { + val sb = new StringBuilder + sb.append(row.get(0, inputSchema(0))) + var i = 1 + while (i < len) { + sb.append(ioschema.inputRowFormatMap("TOK_TABLEROWFORMATFIELD")) + sb.append(row.get(i, inputSchema(i))) + i += 1 + } + sb.append(ioschema.inputRowFormatMap("TOK_TABLEROWFORMATLINES")) + sb.toString() + } + outputStream.write(data.getBytes("utf-8")) } else { val writable = inputSerde.serialize( row.asInstanceOf[GenericInternalRow].values, inputSoi) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala index 684ea1d137b4..8dc796b056a7 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala @@ -211,18 +211,18 @@ private[spark] class SparkHiveDynamicPartitionWriterContainer( } } - val dynamicPartPath = dynamicPartColNames - .zip(row.toSeq.takeRight(dynamicPartColNames.length)) - .map { case (col, rawVal) => - val string = if (rawVal == null) null else convertToHiveRawString(col, rawVal) - val colString = - if (string == null || string.isEmpty) { - defaultPartName - } else { - FileUtils.escapePathName(string, defaultPartName) - } - s"/$col=$colString" - }.mkString + val nonDynamicPartLen = 
row.numFields - dynamicPartColNames.length + val dynamicPartPath = dynamicPartColNames.zipWithIndex.map { case (colName, i) => + val rawVal = row.get(nonDynamicPartLen + i, schema(colName).dataType) + val string = if (rawVal == null) null else convertToHiveRawString(colName, rawVal) + val colString = + if (string == null || string.isEmpty) { + defaultPartName + } else { + FileUtils.escapePathName(string, defaultPartName) + } + s"/$colName=$colString" + }.mkString def newWriter(): FileSinkOperator.RecordWriter = { val newFileSinkDesc = new FileSinkDesc( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala index 99e95fb92130..81a70b8d4226 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala @@ -133,8 +133,8 @@ class HiveInspectorSuite extends SparkFunSuite with HiveInspectors { } } - def checkValues(row1: Seq[Any], row2: InternalRow): Unit = { - row1.zip(row2.toSeq).foreach { case (r1, r2) => + def checkValues(row1: Seq[Any], row2: InternalRow, row2Schema: StructType): Unit = { + row1.zip(row2.toSeq(row2Schema)).foreach { case (r1, r2) => checkValue(r1, r2) } } @@ -211,8 +211,10 @@ class HiveInspectorSuite extends SparkFunSuite with HiveInspectors { case (t, idx) => StructField(s"c_$idx", t) }) val inspector = toInspector(dt) - checkValues(row, - unwrap(wrap(InternalRow.fromSeq(row), inspector, dt), inspector).asInstanceOf[InternalRow]) + checkValues( + row, + unwrap(wrap(InternalRow.fromSeq(row), inspector, dt), inspector).asInstanceOf[InternalRow], + dt) checkValue(null, unwrap(wrap(null, toInspector(dt), dt), toInspector(dt))) } From 2eca46a17a3d46a605804ff89c010017da91e1bc Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 6 Aug 2015 11:15:37 -0700 Subject: [PATCH 0087/1824] Revert "[SPARK-9632][SQL] update InternalRow.toSeq to make it accept data type info" This reverts commit 6e009cb9c4d7a395991e10dab427f37019283758. 
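For context, a minimal sketch of the toSeq API shape this revert switches between (illustrative only; the row values and field types below are hypothetical and not taken from the patch):

    // Sketch, assuming a two-column GenericInternalRow holding an int and a string.
    import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
    import org.apache.spark.sql.types.{IntegerType, StringType}
    import org.apache.spark.unsafe.types.UTF8String

    val row = new GenericInternalRow(Array[Any](1, UTF8String.fromString("a")))

    // With the reverted change, callers passed the field types explicitly:
    // val typed = row.toSeq(Seq(IntegerType, StringType))

    // After this revert, toSeq takes no arguments and falls back to the generic getter:
    val untyped: Seq[Any] = row.toSeq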
--- .../spark/sql/catalyst/InternalRow.scala | 132 ++++++++++++++++-- .../sql/catalyst/expressions/Projection.scala | 12 +- .../expressions/SpecificMutableRow.scala | 5 +- .../codegen/GenerateProjection.scala | 8 +- .../spark/sql/catalyst/expressions/rows.scala | 132 +----------------- .../expressions/CodeGenerationSuite.scala | 2 +- .../spark/sql/columnar/ColumnStats.scala | 51 ++++--- .../columnar/InMemoryColumnarTableScan.scala | 11 +- .../spark/sql/execution/debug/package.scala | 4 +- .../apache/spark/sql/sources/interfaces.scala | 4 +- .../spark/sql/columnar/ColumnStatsSuite.scala | 54 ++++--- .../spark/sql/hive/HiveInspectors.scala | 6 +- .../hive/execution/ScriptTransformation.scala | 21 +-- .../spark/sql/hive/hiveWriterContainers.scala | 24 ++-- .../spark/sql/hive/HiveInspectorSuite.scala | 10 +- 15 files changed, 217 insertions(+), 259 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala index 85b4bf3b6aef..7d17cca80879 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala @@ -18,7 +18,8 @@ package org.apache.spark.sql.catalyst import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.types.{DataType, StructType} +import org.apache.spark.sql.types.{DataType, MapData, ArrayData, Decimal} +import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} /** * An abstract class for row used internal in Spark SQL, which only contain the columns as @@ -31,6 +32,8 @@ abstract class InternalRow extends SpecializedGetters with Serializable { // This is only use for test and will throw a null pointer exception if the position is null. def getString(ordinal: Int): String = getUTF8String(ordinal).toString + override def toString: String = mkString("[", ",", "]") + /** * Make a copy of the current [[InternalRow]] object. */ @@ -47,25 +50,136 @@ abstract class InternalRow extends SpecializedGetters with Serializable { false } + // Subclasses of InternalRow should implement all special getters and equals/hashCode, + // or implement this genericGet. 
+ protected def genericGet(ordinal: Int): Any = throw new IllegalStateException( + "Concrete internal rows should implement genericGet, " + + "or implement all special getters and equals/hashCode") + + // default implementation (slow) + private def getAs[T](ordinal: Int) = genericGet(ordinal).asInstanceOf[T] + override def isNullAt(ordinal: Int): Boolean = getAs[AnyRef](ordinal) eq null + override def get(ordinal: Int, dataType: DataType): AnyRef = getAs(ordinal) + override def getBoolean(ordinal: Int): Boolean = getAs(ordinal) + override def getByte(ordinal: Int): Byte = getAs(ordinal) + override def getShort(ordinal: Int): Short = getAs(ordinal) + override def getInt(ordinal: Int): Int = getAs(ordinal) + override def getLong(ordinal: Int): Long = getAs(ordinal) + override def getFloat(ordinal: Int): Float = getAs(ordinal) + override def getDouble(ordinal: Int): Double = getAs(ordinal) + override def getDecimal(ordinal: Int, precision: Int, scale: Int): Decimal = getAs(ordinal) + override def getUTF8String(ordinal: Int): UTF8String = getAs(ordinal) + override def getBinary(ordinal: Int): Array[Byte] = getAs(ordinal) + override def getArray(ordinal: Int): ArrayData = getAs(ordinal) + override def getInterval(ordinal: Int): CalendarInterval = getAs(ordinal) + override def getMap(ordinal: Int): MapData = getAs(ordinal) + override def getStruct(ordinal: Int, numFields: Int): InternalRow = getAs(ordinal) + + override def equals(o: Any): Boolean = { + if (!o.isInstanceOf[InternalRow]) { + return false + } + + val other = o.asInstanceOf[InternalRow] + if (other eq null) { + return false + } + + val len = numFields + if (len != other.numFields) { + return false + } + + var i = 0 + while (i < len) { + if (isNullAt(i) != other.isNullAt(i)) { + return false + } + if (!isNullAt(i)) { + val o1 = genericGet(i) + val o2 = other.genericGet(i) + o1 match { + case b1: Array[Byte] => + if (!o2.isInstanceOf[Array[Byte]] || + !java.util.Arrays.equals(b1, o2.asInstanceOf[Array[Byte]])) { + return false + } + case f1: Float if java.lang.Float.isNaN(f1) => + if (!o2.isInstanceOf[Float] || ! java.lang.Float.isNaN(o2.asInstanceOf[Float])) { + return false + } + case d1: Double if java.lang.Double.isNaN(d1) => + if (!o2.isInstanceOf[Double] || ! java.lang.Double.isNaN(o2.asInstanceOf[Double])) { + return false + } + case _ => if (o1 != o2) { + return false + } + } + } + i += 1 + } + true + } + + // Custom hashCode function that matches the efficient code generated version. + override def hashCode: Int = { + var result: Int = 37 + var i = 0 + val len = numFields + while (i < len) { + val update: Int = + if (isNullAt(i)) { + 0 + } else { + genericGet(i) match { + case b: Boolean => if (b) 0 else 1 + case b: Byte => b.toInt + case s: Short => s.toInt + case i: Int => i + case l: Long => (l ^ (l >>> 32)).toInt + case f: Float => java.lang.Float.floatToIntBits(f) + case d: Double => + val b = java.lang.Double.doubleToLongBits(d) + (b ^ (b >>> 32)).toInt + case a: Array[Byte] => java.util.Arrays.hashCode(a) + case other => other.hashCode() + } + } + result = 37 * result + update + i += 1 + } + result + } + /* ---------------------- utility methods for Scala ---------------------- */ /** * Return a Scala Seq representing the row. Elements are placed in the same order in the Seq. 
*/ - def toSeq(fieldTypes: Seq[DataType]): Seq[Any] = { - val len = numFields - assert(len == fieldTypes.length) - - val values = new Array[Any](len) + // todo: remove this as it needs the generic getter + def toSeq: Seq[Any] = { + val n = numFields + val values = new Array[Any](n) var i = 0 - while (i < len) { - values(i) = get(i, fieldTypes(i)) + while (i < n) { + values.update(i, genericGet(i)) i += 1 } values } - def toSeq(schema: StructType): Seq[Any] = toSeq(schema.map(_.dataType)) + /** Displays all elements of this sequence in a string (without a separator). */ + def mkString: String = toSeq.mkString + + /** Displays all elements of this sequence in a string using a separator string. */ + def mkString(sep: String): String = toSeq.mkString(sep) + + /** + * Displays all elements of this traversable or iterator in a string using + * start, end, and separator strings. + */ + def mkString(start: String, sep: String, end: String): String = toSeq.mkString(start, sep, end) } object InternalRow { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index 59ce7fc4f2c6..4296b4b123fc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -203,11 +203,7 @@ class JoinedRow extends InternalRow { this } - override def toSeq(fieldTypes: Seq[DataType]): Seq[Any] = { - assert(fieldTypes.length == row1.numFields + row2.numFields) - val (left, right) = fieldTypes.splitAt(row1.numFields) - row1.toSeq(left) ++ row2.toSeq(right) - } + override def toSeq: Seq[Any] = row1.toSeq ++ row2.toSeq override def numFields: Int = row1.numFields + row2.numFields @@ -280,11 +276,11 @@ class JoinedRow extends InternalRow { if ((row1 eq null) && (row2 eq null)) { "[ empty row ]" } else if (row1 eq null) { - row2.toString + row2.mkString("[", ",", "]") } else if (row2 eq null) { - row1.toString + row1.mkString("[", ",", "]") } else { - s"{${row1.toString} + ${row2.toString}}" + mkString("[", ",", "]") } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala index 4f56f94bd4ca..b94df6bd66e0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala @@ -192,8 +192,7 @@ final class MutableAny extends MutableValue { * based on the dataTypes of each column. The intent is to decrease garbage when modifying the * values of primitive columns. 
*/ -final class SpecificMutableRow(val values: Array[MutableValue]) - extends MutableRow with BaseGenericInternalRow { +final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableRow { def this(dataTypes: Seq[DataType]) = this( @@ -214,6 +213,8 @@ final class SpecificMutableRow(val values: Array[MutableValue]) override def numFields: Int = values.length + override def toSeq: Seq[Any] = values.map(_.boxed) + override def setNullAt(i: Int): Unit = { values(i).isNull = true } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala index c744e84d822e..c04fe734d554 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst.expressions.codegen -import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ @@ -26,8 +25,6 @@ import org.apache.spark.sql.types._ */ abstract class BaseProjection extends Projection {} -abstract class CodeGenMutableRow extends MutableRow with BaseGenericInternalRow - /** * Generates bytecode that produces a new [[InternalRow]] object based on a fixed set of input * [[Expression Expressions]] and a given input [[InternalRow]]. The returned [[InternalRow]] @@ -174,7 +171,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { return new SpecificRow((InternalRow) r); } - final class SpecificRow extends ${classOf[CodeGenMutableRow].getName} { + final class SpecificRow extends ${classOf[MutableRow].getName} { $columns @@ -187,8 +184,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { public void setNullAt(int i) { nullBits[i] = true; } public boolean isNullAt(int i) { return nullBits[i]; } - @Override - public Object genericGet(int i) { + protected Object genericGet(int i) { if (isNullAt(i)) return null; switch (i) { $getCases diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala index 207e66779266..7657fb535dcf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala @@ -21,130 +21,6 @@ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types._ -/** - * An extended version of [[InternalRow]] that implements all special getters, toString - * and equals/hashCode by `genericGet`. 
- */ -trait BaseGenericInternalRow extends InternalRow { - - protected def genericGet(ordinal: Int): Any - - // default implementation (slow) - private def getAs[T](ordinal: Int) = genericGet(ordinal).asInstanceOf[T] - override def isNullAt(ordinal: Int): Boolean = getAs[AnyRef](ordinal) eq null - override def get(ordinal: Int, dataType: DataType): AnyRef = getAs(ordinal) - override def getBoolean(ordinal: Int): Boolean = getAs(ordinal) - override def getByte(ordinal: Int): Byte = getAs(ordinal) - override def getShort(ordinal: Int): Short = getAs(ordinal) - override def getInt(ordinal: Int): Int = getAs(ordinal) - override def getLong(ordinal: Int): Long = getAs(ordinal) - override def getFloat(ordinal: Int): Float = getAs(ordinal) - override def getDouble(ordinal: Int): Double = getAs(ordinal) - override def getDecimal(ordinal: Int, precision: Int, scale: Int): Decimal = getAs(ordinal) - override def getUTF8String(ordinal: Int): UTF8String = getAs(ordinal) - override def getBinary(ordinal: Int): Array[Byte] = getAs(ordinal) - override def getArray(ordinal: Int): ArrayData = getAs(ordinal) - override def getInterval(ordinal: Int): CalendarInterval = getAs(ordinal) - override def getMap(ordinal: Int): MapData = getAs(ordinal) - override def getStruct(ordinal: Int, numFields: Int): InternalRow = getAs(ordinal) - - override def toString(): String = { - if (numFields == 0) { - "[empty row]" - } else { - val sb = new StringBuilder - sb.append("[") - sb.append(genericGet(0)) - val len = numFields - var i = 1 - while (i < len) { - sb.append(",") - sb.append(genericGet(i)) - i += 1 - } - sb.append("]") - sb.toString() - } - } - - override def equals(o: Any): Boolean = { - if (!o.isInstanceOf[BaseGenericInternalRow]) { - return false - } - - val other = o.asInstanceOf[BaseGenericInternalRow] - if (other eq null) { - return false - } - - val len = numFields - if (len != other.numFields) { - return false - } - - var i = 0 - while (i < len) { - if (isNullAt(i) != other.isNullAt(i)) { - return false - } - if (!isNullAt(i)) { - val o1 = genericGet(i) - val o2 = other.genericGet(i) - o1 match { - case b1: Array[Byte] => - if (!o2.isInstanceOf[Array[Byte]] || - !java.util.Arrays.equals(b1, o2.asInstanceOf[Array[Byte]])) { - return false - } - case f1: Float if java.lang.Float.isNaN(f1) => - if (!o2.isInstanceOf[Float] || ! java.lang.Float.isNaN(o2.asInstanceOf[Float])) { - return false - } - case d1: Double if java.lang.Double.isNaN(d1) => - if (!o2.isInstanceOf[Double] || ! java.lang.Double.isNaN(o2.asInstanceOf[Double])) { - return false - } - case _ => if (o1 != o2) { - return false - } - } - } - i += 1 - } - true - } - - // Custom hashCode function that matches the efficient code generated version. - override def hashCode: Int = { - var result: Int = 37 - var i = 0 - val len = numFields - while (i < len) { - val update: Int = - if (isNullAt(i)) { - 0 - } else { - genericGet(i) match { - case b: Boolean => if (b) 0 else 1 - case b: Byte => b.toInt - case s: Short => s.toInt - case i: Int => i - case l: Long => (l ^ (l >>> 32)).toInt - case f: Float => java.lang.Float.floatToIntBits(f) - case d: Double => - val b = java.lang.Double.doubleToLongBits(d) - (b ^ (b >>> 32)).toInt - case a: Array[Byte] => java.util.Arrays.hashCode(a) - case other => other.hashCode() - } - } - result = 37 * result + update - i += 1 - } - result - } -} - /** * An extended interface to [[InternalRow]] that allows the values for each column to be updated. 
* Setting a value through a primitive function implicitly marks that column as not null. @@ -206,7 +82,7 @@ class GenericRowWithSchema(values: Array[Any], override val schema: StructType) * Note that, while the array is not copied, and thus could technically be mutated after creation, * this is not allowed. */ -class GenericInternalRow(private[sql] val values: Array[Any]) extends BaseGenericInternalRow { +class GenericInternalRow(private[sql] val values: Array[Any]) extends InternalRow { /** No-arg constructor for serialization. */ protected def this() = this(null) @@ -214,7 +90,7 @@ class GenericInternalRow(private[sql] val values: Array[Any]) extends BaseGeneri override protected def genericGet(ordinal: Int) = values(ordinal) - override def toSeq(fieldTypes: Seq[DataType]): Seq[Any] = values + override def toSeq: Seq[Any] = values override def numFields: Int = values.length @@ -233,7 +109,7 @@ class GenericInternalRowWithSchema(values: Array[Any], val schema: StructType) def fieldIndex(name: String): Int = schema.fieldIndex(name) } -class GenericMutableRow(values: Array[Any]) extends MutableRow with BaseGenericInternalRow { +class GenericMutableRow(values: Array[Any]) extends MutableRow { /** No-arg constructor for serialization. */ protected def this() = this(null) @@ -241,7 +117,7 @@ class GenericMutableRow(values: Array[Any]) extends MutableRow with BaseGenericI override protected def genericGet(ordinal: Int) = values(ordinal) - override def toSeq(fieldTypes: Seq[DataType]): Seq[Any] = values + override def toSeq: Seq[Any] = values override def numFields: Int = values.length diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index e323467af5f4..e310aee22166 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -87,7 +87,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { val length = 5000 val expressions = List.fill(length)(EqualTo(Literal(1), Literal(1))) val plan = GenerateMutableProjection.generate(expressions)() - val actual = plan(new GenericMutableRow(length)).toSeq(expressions.map(_.dataType)) + val actual = plan(new GenericMutableRow(length)).toSeq val expected = Seq.fill(length)(true) if (!checkResult(actual, expected)) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala index 5cbd52bc0590..af1a8ecca9b5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.columnar import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, Attribute, AttributeMap, AttributeReference} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -66,7 +66,7 @@ private[sql] sealed trait ColumnStats extends Serializable { * Column statistics represented as a single row, currently including closed lower bound, closed * upper bound and null count. 
*/ - def collectedStatistics: GenericInternalRow + def collectedStatistics: InternalRow } /** @@ -75,8 +75,7 @@ private[sql] sealed trait ColumnStats extends Serializable { private[sql] class NoopColumnStats extends ColumnStats { override def gatherStats(row: InternalRow, ordinal: Int): Unit = super.gatherStats(row, ordinal) - override def collectedStatistics: GenericInternalRow = - new GenericInternalRow(Array[Any](null, null, nullCount, count, 0L)) + override def collectedStatistics: InternalRow = InternalRow(null, null, nullCount, count, 0L) } private[sql] class BooleanColumnStats extends ColumnStats { @@ -93,8 +92,8 @@ private[sql] class BooleanColumnStats extends ColumnStats { } } - override def collectedStatistics: GenericInternalRow = - new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) + override def collectedStatistics: InternalRow = + InternalRow(lower, upper, nullCount, count, sizeInBytes) } private[sql] class ByteColumnStats extends ColumnStats { @@ -111,8 +110,8 @@ private[sql] class ByteColumnStats extends ColumnStats { } } - override def collectedStatistics: GenericInternalRow = - new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) + override def collectedStatistics: InternalRow = + InternalRow(lower, upper, nullCount, count, sizeInBytes) } private[sql] class ShortColumnStats extends ColumnStats { @@ -129,8 +128,8 @@ private[sql] class ShortColumnStats extends ColumnStats { } } - override def collectedStatistics: GenericInternalRow = - new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) + override def collectedStatistics: InternalRow = + InternalRow(lower, upper, nullCount, count, sizeInBytes) } private[sql] class IntColumnStats extends ColumnStats { @@ -147,8 +146,8 @@ private[sql] class IntColumnStats extends ColumnStats { } } - override def collectedStatistics: GenericInternalRow = - new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) + override def collectedStatistics: InternalRow = + InternalRow(lower, upper, nullCount, count, sizeInBytes) } private[sql] class LongColumnStats extends ColumnStats { @@ -165,8 +164,8 @@ private[sql] class LongColumnStats extends ColumnStats { } } - override def collectedStatistics: GenericInternalRow = - new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) + override def collectedStatistics: InternalRow = + InternalRow(lower, upper, nullCount, count, sizeInBytes) } private[sql] class FloatColumnStats extends ColumnStats { @@ -183,8 +182,8 @@ private[sql] class FloatColumnStats extends ColumnStats { } } - override def collectedStatistics: GenericInternalRow = - new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) + override def collectedStatistics: InternalRow = + InternalRow(lower, upper, nullCount, count, sizeInBytes) } private[sql] class DoubleColumnStats extends ColumnStats { @@ -201,8 +200,8 @@ private[sql] class DoubleColumnStats extends ColumnStats { } } - override def collectedStatistics: GenericInternalRow = - new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) + override def collectedStatistics: InternalRow = + InternalRow(lower, upper, nullCount, count, sizeInBytes) } private[sql] class StringColumnStats extends ColumnStats { @@ -219,8 +218,8 @@ private[sql] class StringColumnStats extends ColumnStats { } } - override def collectedStatistics: GenericInternalRow = - new GenericInternalRow(Array[Any](lower, upper, nullCount, count, 
sizeInBytes)) + override def collectedStatistics: InternalRow = + InternalRow(lower, upper, nullCount, count, sizeInBytes) } private[sql] class BinaryColumnStats extends ColumnStats { @@ -231,8 +230,8 @@ private[sql] class BinaryColumnStats extends ColumnStats { } } - override def collectedStatistics: GenericInternalRow = - new GenericInternalRow(Array[Any](null, null, nullCount, count, sizeInBytes)) + override def collectedStatistics: InternalRow = + InternalRow(null, null, nullCount, count, sizeInBytes) } private[sql] class FixedDecimalColumnStats(precision: Int, scale: Int) extends ColumnStats { @@ -249,8 +248,8 @@ private[sql] class FixedDecimalColumnStats(precision: Int, scale: Int) extends C } } - override def collectedStatistics: GenericInternalRow = - new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) + override def collectedStatistics: InternalRow = + InternalRow(lower, upper, nullCount, count, sizeInBytes) } private[sql] class GenericColumnStats(dataType: DataType) extends ColumnStats { @@ -263,8 +262,8 @@ private[sql] class GenericColumnStats(dataType: DataType) extends ColumnStats { } } - override def collectedStatistics: GenericInternalRow = - new GenericInternalRow(Array[Any](null, null, nullCount, count, sizeInBytes)) + override def collectedStatistics: InternalRow = + InternalRow(null, null, nullCount, count, sizeInBytes) } private[sql] class DateColumnStats extends IntColumnStats diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala index d553bb6169ec..5d5b0697d701 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala @@ -148,7 +148,7 @@ private[sql] case class InMemoryRelation( } val stats = InternalRow.fromSeq(columnBuilders.map(_.columnStats.collectedStatistics) - .flatMap(_.values)) + .flatMap(_.toSeq)) batchStats += stats CachedBatch(columnBuilders.map(_.build().array()), stats) @@ -330,11 +330,10 @@ private[sql] case class InMemoryColumnarTableScan( if (inMemoryPartitionPruningEnabled) { cachedBatchIterator.filter { cachedBatch => if (!partitionFilter(cachedBatch.stats)) { - def statsString: String = relation.partitionStatistics.schema.zipWithIndex.map { - case (a, i) => - val value = cachedBatch.stats.get(i, a.dataType) - s"${a.name}: $value" - }.mkString(", ") + def statsString: String = relation.partitionStatistics.schema + .zip(cachedBatch.stats.toSeq) + .map { case (a, s) => s"${a.name}: $s" } + .mkString(", ") logInfo(s"Skipping partition based on stats $statsString") false } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala index dd3858ea2b52..c37007f1eece 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala @@ -156,8 +156,8 @@ package object debug { def typeCheck(data: Any, schema: DataType): Unit = (data, schema) match { case (null, _) => - case (row: InternalRow, s: StructType) => - row.toSeq(s).zip(s.map(_.dataType)).foreach { case(d, t) => typeCheck(d, t) } + case (row: InternalRow, StructType(fields)) => + row.toSeq.zip(fields.map(_.dataType)).foreach { case(d, t) => typeCheck(d, t) } case (a: ArrayData, ArrayType(elemType, _)) 
=> a.foreach(elemType, (_, e) => { typeCheck(e, elemType) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index c04557e5a081..7126145ddc01 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -461,8 +461,8 @@ abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[Partitio val spec = discoverPartitions() val partitionColumnTypes = spec.partitionColumns.map(_.dataType) val castedPartitions = spec.partitions.map { case p @ Partition(values, path) => - val literals = partitionColumnTypes.zipWithIndex.map { case (dt, i) => - Literal.create(values.get(i, dt), dt) + val literals = values.toSeq.zip(partitionColumnTypes).map { + case (value, dataType) => Literal.create(value, dataType) } val castedValues = partitionSchema.zip(literals).map { case (field, literal) => Cast(literal, field.dataType).eval() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala index d0430d2a60e7..16e0187ed20a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala @@ -19,36 +19,33 @@ package org.apache.spark.sql.columnar import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.GenericInternalRow import org.apache.spark.sql.types._ class ColumnStatsSuite extends SparkFunSuite { - testColumnStats(classOf[BooleanColumnStats], BOOLEAN, createRow(true, false, 0)) - testColumnStats(classOf[ByteColumnStats], BYTE, createRow(Byte.MaxValue, Byte.MinValue, 0)) - testColumnStats(classOf[ShortColumnStats], SHORT, createRow(Short.MaxValue, Short.MinValue, 0)) - testColumnStats(classOf[IntColumnStats], INT, createRow(Int.MaxValue, Int.MinValue, 0)) - testColumnStats(classOf[DateColumnStats], DATE, createRow(Int.MaxValue, Int.MinValue, 0)) - testColumnStats(classOf[LongColumnStats], LONG, createRow(Long.MaxValue, Long.MinValue, 0)) + testColumnStats(classOf[BooleanColumnStats], BOOLEAN, InternalRow(true, false, 0)) + testColumnStats(classOf[ByteColumnStats], BYTE, InternalRow(Byte.MaxValue, Byte.MinValue, 0)) + testColumnStats(classOf[ShortColumnStats], SHORT, InternalRow(Short.MaxValue, Short.MinValue, 0)) + testColumnStats(classOf[IntColumnStats], INT, InternalRow(Int.MaxValue, Int.MinValue, 0)) + testColumnStats(classOf[DateColumnStats], DATE, InternalRow(Int.MaxValue, Int.MinValue, 0)) + testColumnStats(classOf[LongColumnStats], LONG, InternalRow(Long.MaxValue, Long.MinValue, 0)) testColumnStats(classOf[TimestampColumnStats], TIMESTAMP, - createRow(Long.MaxValue, Long.MinValue, 0)) - testColumnStats(classOf[FloatColumnStats], FLOAT, createRow(Float.MaxValue, Float.MinValue, 0)) + InternalRow(Long.MaxValue, Long.MinValue, 0)) + testColumnStats(classOf[FloatColumnStats], FLOAT, InternalRow(Float.MaxValue, Float.MinValue, 0)) testColumnStats(classOf[DoubleColumnStats], DOUBLE, - createRow(Double.MaxValue, Double.MinValue, 0)) - testColumnStats(classOf[StringColumnStats], STRING, createRow(null, null, 0)) - testDecimalColumnStats(createRow(null, null, 0)) - - def createRow(values: Any*): GenericInternalRow = new GenericInternalRow(values.toArray) + InternalRow(Double.MaxValue, Double.MinValue, 0)) + 
testColumnStats(classOf[StringColumnStats], STRING, InternalRow(null, null, 0)) + testDecimalColumnStats(InternalRow(null, null, 0)) def testColumnStats[T <: AtomicType, U <: ColumnStats]( columnStatsClass: Class[U], columnType: NativeColumnType[T], - initialStatistics: GenericInternalRow): Unit = { + initialStatistics: InternalRow): Unit = { val columnStatsName = columnStatsClass.getSimpleName test(s"$columnStatsName: empty") { val columnStats = columnStatsClass.newInstance() - columnStats.collectedStatistics.values.zip(initialStatistics.values).foreach { + columnStats.collectedStatistics.toSeq.zip(initialStatistics.toSeq).foreach { case (actual, expected) => assert(actual === expected) } } @@ -64,11 +61,11 @@ class ColumnStatsSuite extends SparkFunSuite { val ordering = columnType.dataType.ordering.asInstanceOf[Ordering[T#InternalType]] val stats = columnStats.collectedStatistics - assertResult(values.min(ordering), "Wrong lower bound")(stats.values(0)) - assertResult(values.max(ordering), "Wrong upper bound")(stats.values(1)) - assertResult(10, "Wrong null count")(stats.values(2)) - assertResult(20, "Wrong row count")(stats.values(3)) - assertResult(stats.values(4), "Wrong size in bytes") { + assertResult(values.min(ordering), "Wrong lower bound")(stats.get(0, null)) + assertResult(values.max(ordering), "Wrong upper bound")(stats.get(1, null)) + assertResult(10, "Wrong null count")(stats.get(2, null)) + assertResult(20, "Wrong row count")(stats.get(3, null)) + assertResult(stats.get(4, null), "Wrong size in bytes") { rows.map { row => if (row.isNullAt(0)) 4 else columnType.actualSize(row, 0) }.sum @@ -76,15 +73,14 @@ class ColumnStatsSuite extends SparkFunSuite { } } - def testDecimalColumnStats[T <: AtomicType, U <: ColumnStats]( - initialStatistics: GenericInternalRow): Unit = { + def testDecimalColumnStats[T <: AtomicType, U <: ColumnStats](initialStatistics: InternalRow) { val columnStatsName = classOf[FixedDecimalColumnStats].getSimpleName val columnType = FIXED_DECIMAL(15, 10) test(s"$columnStatsName: empty") { val columnStats = new FixedDecimalColumnStats(15, 10) - columnStats.collectedStatistics.values.zip(initialStatistics.values).foreach { + columnStats.collectedStatistics.toSeq.zip(initialStatistics.toSeq).foreach { case (actual, expected) => assert(actual === expected) } } @@ -100,11 +96,11 @@ class ColumnStatsSuite extends SparkFunSuite { val ordering = columnType.dataType.ordering.asInstanceOf[Ordering[T#InternalType]] val stats = columnStats.collectedStatistics - assertResult(values.min(ordering), "Wrong lower bound")(stats.values(0)) - assertResult(values.max(ordering), "Wrong upper bound")(stats.values(1)) - assertResult(10, "Wrong null count")(stats.values(2)) - assertResult(20, "Wrong row count")(stats.values(3)) - assertResult(stats.values(4), "Wrong size in bytes") { + assertResult(values.min(ordering), "Wrong lower bound")(stats.get(0, null)) + assertResult(values.max(ordering), "Wrong upper bound")(stats.get(1, null)) + assertResult(10, "Wrong null count")(stats.get(2, null)) + assertResult(20, "Wrong row count")(stats.get(3, null)) + assertResult(stats.get(4, null), "Wrong size in bytes") { rows.map { row => if (row.isNullAt(0)) 4 else columnType.actualSize(row, 0) }.sum diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index 9824dad23959..39d798d072ae 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ 
b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -390,10 +390,8 @@ private[hive] trait HiveInspectors { (o: Any) => { if (o != null) { val struct = soi.create() - val row = o.asInstanceOf[InternalRow] - soi.getAllStructFieldRefs.zip(wrappers).zipWithIndex.foreach { - case ((field, wrapper), i) => - soi.setStructFieldData(struct, field, wrapper(row.get(i, schema(i).dataType))) + (soi.getAllStructFieldRefs, wrappers, o.asInstanceOf[InternalRow].toSeq).zipped.foreach { + (field, wrapper, data) => soi.setStructFieldData(struct, field, wrapper(data)) } struct } else { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala index ade27454b9d2..a6a343d39599 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala @@ -88,7 +88,6 @@ case class ScriptTransformation( // external process. That process's output will be read by this current thread. val writerThread = new ScriptTransformationWriterThread( inputIterator, - input.map(_.dataType), outputProjection, inputSerde, inputSoi, @@ -202,7 +201,6 @@ case class ScriptTransformation( private class ScriptTransformationWriterThread( iter: Iterator[InternalRow], - inputSchema: Seq[DataType], outputProjection: Projection, @Nullable inputSerde: AbstractSerDe, @Nullable inputSoi: ObjectInspector, @@ -228,25 +226,12 @@ private class ScriptTransformationWriterThread( // We can't use Utils.tryWithSafeFinally here because we also need a `catch` block, so // let's use a variable to record whether the `finally` block was hit due to an exception var threwException: Boolean = true - val len = inputSchema.length try { iter.map(outputProjection).foreach { row => if (inputSerde == null) { - val data = if (len == 0) { - ioschema.inputRowFormatMap("TOK_TABLEROWFORMATLINES") - } else { - val sb = new StringBuilder - sb.append(row.get(0, inputSchema(0))) - var i = 1 - while (i < len) { - sb.append(ioschema.inputRowFormatMap("TOK_TABLEROWFORMATFIELD")) - sb.append(row.get(i, inputSchema(i))) - i += 1 - } - sb.append(ioschema.inputRowFormatMap("TOK_TABLEROWFORMATLINES")) - sb.toString() - } - outputStream.write(data.getBytes("utf-8")) + val data = row.mkString("", ioschema.inputRowFormatMap("TOK_TABLEROWFORMATFIELD"), + ioschema.inputRowFormatMap("TOK_TABLEROWFORMATLINES")).getBytes("utf-8") + outputStream.write(data) } else { val writable = inputSerde.serialize( row.asInstanceOf[GenericInternalRow].values, inputSoi) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala index 8dc796b056a7..684ea1d137b4 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala @@ -211,18 +211,18 @@ private[spark] class SparkHiveDynamicPartitionWriterContainer( } } - val nonDynamicPartLen = row.numFields - dynamicPartColNames.length - val dynamicPartPath = dynamicPartColNames.zipWithIndex.map { case (colName, i) => - val rawVal = row.get(nonDynamicPartLen + i, schema(colName).dataType) - val string = if (rawVal == null) null else convertToHiveRawString(colName, rawVal) - val colString = - if (string == null || string.isEmpty) { - defaultPartName - } else { - 
FileUtils.escapePathName(string, defaultPartName) - } - s"/$colName=$colString" - }.mkString + val dynamicPartPath = dynamicPartColNames + .zip(row.toSeq.takeRight(dynamicPartColNames.length)) + .map { case (col, rawVal) => + val string = if (rawVal == null) null else convertToHiveRawString(col, rawVal) + val colString = + if (string == null || string.isEmpty) { + defaultPartName + } else { + FileUtils.escapePathName(string, defaultPartName) + } + s"/$col=$colString" + }.mkString def newWriter(): FileSinkOperator.RecordWriter = { val newFileSinkDesc = new FileSinkDesc( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala index 81a70b8d4226..99e95fb92130 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala @@ -133,8 +133,8 @@ class HiveInspectorSuite extends SparkFunSuite with HiveInspectors { } } - def checkValues(row1: Seq[Any], row2: InternalRow, row2Schema: StructType): Unit = { - row1.zip(row2.toSeq(row2Schema)).foreach { case (r1, r2) => + def checkValues(row1: Seq[Any], row2: InternalRow): Unit = { + row1.zip(row2.toSeq).foreach { case (r1, r2) => checkValue(r1, r2) } } @@ -211,10 +211,8 @@ class HiveInspectorSuite extends SparkFunSuite with HiveInspectors { case (t, idx) => StructField(s"c_$idx", t) }) val inspector = toInspector(dt) - checkValues( - row, - unwrap(wrap(InternalRow.fromSeq(row), inspector, dt), inspector).asInstanceOf[InternalRow], - dt) + checkValues(row, + unwrap(wrap(InternalRow.fromSeq(row), inspector, dt), inspector).asInstanceOf[InternalRow]) checkValue(null, unwrap(wrap(null, toInspector(dt), dt), toInspector(dt))) } From cdd53b762bf358616b313e3334b5f6945caf9ab1 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Thu, 6 Aug 2015 11:15:54 -0700 Subject: [PATCH 0088/1824] [SPARK-9632] [SQL] [HOT-FIX] Fix build. seems https://github.com/apache/spark/pull/7955 breaks the build. Author: Yin Huai Closes #8001 from yhuai/SPARK-9632-fixBuild and squashes the following commits: 6c257dd [Yin Huai] Fix build. --- .../scala/org/apache/spark/sql/catalyst/expressions/rows.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala index 7657fb535dcf..fd42fac3d2cd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} /** * An extended interface to [[InternalRow]] that allows the values for each column to be updated. From 0d7aac99da660cc42eb5a9be8e262bd9bd8a770f Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Thu, 6 Aug 2015 19:29:42 +0100 Subject: [PATCH 0089/1824] [SPARK-9641] [DOCS] spark.shuffle.service.port is not documented Document spark.shuffle.service.{enabled,port} CC sryza tgravescs This is pretty minimal; is there more to say here about the service? 
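For reference, a short sketch of how these two properties might be set alongside dynamic allocation (illustrative only; the values shown are the defaults documented in the patch below):

    // Sketch: enabling the external shuffle service via SparkConf.
    import org.apache.spark.SparkConf

    val conf = new SparkConf()
      // Required when spark.dynamicAllocation.enabled is "true"; the service
      // preserves shuffle files written by executors so they can be removed safely.
      .set("spark.shuffle.service.enabled", "true")
      // Default port on which the external shuffle service runs.
      .set("spark.shuffle.service.port", "7337")
      .set("spark.dynamicAllocation.enabled", "true")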
Author: Sean Owen Closes #7991 from srowen/SPARK-9641 and squashes the following commits: 3bb946e [Sean Owen] Add link to docs for setup and config of external shuffle service 2302e01 [Sean Owen] Document spark.shuffle.service.{enabled,port} --- docs/configuration.md | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/docs/configuration.md b/docs/configuration.md index 24b606356a14..c60dd16839c0 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -473,6 +473,25 @@ Apart from these, the following properties are also available, and may be useful spark.storage.memoryFraction. + + spark.shuffle.service.enabled + false + + Enables the external shuffle service. This service preserves the shuffle files written by + executors so the executors can be safely removed. This must be enabled if + spark.dynamicAllocation.enabled is "true". The external shuffle service + must be set up in order to enable it. See + dynamic allocation + configuration and setup documentation for more information. + + + + spark.shuffle.service.port + 7337 + + Port on which the external shuffle service will run. + + spark.shuffle.sort.bypassMergeThreshold 200 From a1bbf1bc5c51cd796015ac159799cf024de6fa07 Mon Sep 17 00:00:00 2001 From: Nilanjan Raychaudhuri Date: Thu, 6 Aug 2015 12:50:08 -0700 Subject: [PATCH 0090/1824] [SPARK-8978] [STREAMING] Implements the DirectKafkaRateController MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Author: Dean Wampler Author: Nilanjan Raychaudhuri Author: François Garillot Closes #7796 from dragos/topic/streaming-bp/kafka-direct and squashes the following commits: 50d1f21 [Nilanjan Raychaudhuri] Taking care of the remaining nits 648c8b1 [Dean Wampler] Refactored rate controller test to be more predictable and run faster. e43f678 [Nilanjan Raychaudhuri] fixing doc and nits ce19d2a [Dean Wampler] Removing an unreliable assertion. 9615320 [Dean Wampler] Give me a break... 6372478 [Dean Wampler] Found a few ways to make this test more robust... 9e69e37 [Dean Wampler] Attempt to fix flakey test that fails in CI, but not locally :( d3db1ea [Dean Wampler] Fixing stylecheck errors. d04a288 [Nilanjan Raychaudhuri] adding test to make sure rate controller is used to calculate maxMessagesPerPartition b6ecb67 [Nilanjan Raychaudhuri] Fixed styling issue 3110267 [Nilanjan Raychaudhuri] [SPARK-8978][Streaming] Implements the DirectKafkaRateController 393c580 [François Garillot] [SPARK-8978][Streaming] Implements the DirectKafkaRateController 51e78c6 [Nilanjan Raychaudhuri] Rename and fix build failure 2795509 [Nilanjan Raychaudhuri] Added missing RateController 19200f5 [Dean Wampler] Removed usage of infix notation. Changed a private variable name to be more consistent with usage. 
aa4a70b [François Garillot] [SPARK-8978][Streaming] Implements the DirectKafkaController --- .../kafka/DirectKafkaInputDStream.scala | 47 ++++++++-- .../kafka/DirectKafkaStreamSuite.scala | 89 +++++++++++++++++++ 2 files changed, 127 insertions(+), 9 deletions(-) diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala index 48a1933d92f8..8a177077775c 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala @@ -29,7 +29,8 @@ import org.apache.spark.{Logging, SparkException} import org.apache.spark.streaming.{StreamingContext, Time} import org.apache.spark.streaming.dstream._ import org.apache.spark.streaming.kafka.KafkaCluster.LeaderOffset -import org.apache.spark.streaming.scheduler.StreamInputInfo +import org.apache.spark.streaming.scheduler.{RateController, StreamInputInfo} +import org.apache.spark.streaming.scheduler.rate.RateEstimator /** * A stream of {@link org.apache.spark.streaming.kafka.KafkaRDD} where @@ -61,7 +62,7 @@ class DirectKafkaInputDStream[ val kafkaParams: Map[String, String], val fromOffsets: Map[TopicAndPartition, Long], messageHandler: MessageAndMetadata[K, V] => R -) extends InputDStream[R](ssc_) with Logging { + ) extends InputDStream[R](ssc_) with Logging { val maxRetries = context.sparkContext.getConf.getInt( "spark.streaming.kafka.maxRetries", 1) @@ -71,14 +72,35 @@ class DirectKafkaInputDStream[ protected[streaming] override val checkpointData = new DirectKafkaInputDStreamCheckpointData + + /** + * Asynchronously maintains & sends new rate limits to the receiver through the receiver tracker. 
+ */ + override protected[streaming] val rateController: Option[RateController] = { + if (RateController.isBackPressureEnabled(ssc.conf)) { + Some(new DirectKafkaRateController(id, + RateEstimator.create(ssc.conf, ssc_.graph.batchDuration))) + } else { + None + } + } + protected val kc = new KafkaCluster(kafkaParams) - protected val maxMessagesPerPartition: Option[Long] = { - val ratePerSec = context.sparkContext.getConf.getInt( + private val maxRateLimitPerPartition: Int = context.sparkContext.getConf.getInt( "spark.streaming.kafka.maxRatePerPartition", 0) - if (ratePerSec > 0) { + protected def maxMessagesPerPartition: Option[Long] = { + val estimatedRateLimit = rateController.map(_.getLatestRate().toInt) + val numPartitions = currentOffsets.keys.size + + val effectiveRateLimitPerPartition = estimatedRateLimit + .filter(_ > 0) + .map(limit => Math.min(maxRateLimitPerPartition, (limit / numPartitions))) + .getOrElse(maxRateLimitPerPartition) + + if (effectiveRateLimitPerPartition > 0) { val secsPerBatch = context.graph.batchDuration.milliseconds.toDouble / 1000 - Some((secsPerBatch * ratePerSec).toLong) + Some((secsPerBatch * effectiveRateLimitPerPartition).toLong) } else { None } @@ -170,11 +192,18 @@ class DirectKafkaInputDStream[ val leaders = KafkaCluster.checkErrors(kc.findLeaders(topics)) batchForTime.toSeq.sortBy(_._1)(Time.ordering).foreach { case (t, b) => - logInfo(s"Restoring KafkaRDD for time $t ${b.mkString("[", ", ", "]")}") - generatedRDDs += t -> new KafkaRDD[K, V, U, T, R]( - context.sparkContext, kafkaParams, b.map(OffsetRange(_)), leaders, messageHandler) + logInfo(s"Restoring KafkaRDD for time $t ${b.mkString("[", ", ", "]")}") + generatedRDDs += t -> new KafkaRDD[K, V, U, T, R]( + context.sparkContext, kafkaParams, b.map(OffsetRange(_)), leaders, messageHandler) } } } + /** + * A RateController to retrieve the rate from RateEstimator. + */ + private[streaming] class DirectKafkaRateController(id: Int, estimator: RateEstimator) + extends RateController(id, estimator) { + override def publish(rate: Long): Unit = () + } } diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala index 5b3c79444aa6..02225d5aa7cc 100644 --- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala +++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala @@ -20,6 +20,9 @@ package org.apache.spark.streaming.kafka import java.io.File import java.util.concurrent.atomic.AtomicLong +import org.apache.spark.streaming.kafka.KafkaCluster.LeaderOffset +import org.apache.spark.streaming.scheduler.rate.RateEstimator + import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import scala.concurrent.duration._ @@ -350,6 +353,77 @@ class DirectKafkaStreamSuite ssc.stop() } + test("using rate controller") { + val topic = "backpressure" + val topicPartition = TopicAndPartition(topic, 0) + kafkaTestUtils.createTopic(topic) + val kafkaParams = Map( + "metadata.broker.list" -> kafkaTestUtils.brokerAddress, + "auto.offset.reset" -> "smallest" + ) + + val batchIntervalMilliseconds = 100 + val estimator = new ConstantEstimator(100) + val messageKeys = (1 to 200).map(_.toString) + val messages = messageKeys.map((_, 1)).toMap + + val sparkConf = new SparkConf() + // Safe, even with streaming, because we're using the direct API. 
+ // Using 1 core is useful to make the test more predictable. + .setMaster("local[1]") + .setAppName(this.getClass.getSimpleName) + .set("spark.streaming.kafka.maxRatePerPartition", "100") + + // Setup the streaming context + ssc = new StreamingContext(sparkConf, Milliseconds(batchIntervalMilliseconds)) + + val kafkaStream = withClue("Error creating direct stream") { + val kc = new KafkaCluster(kafkaParams) + val messageHandler = (mmd: MessageAndMetadata[String, String]) => (mmd.key, mmd.message) + val m = kc.getEarliestLeaderOffsets(Set(topicPartition)) + .fold(e => Map.empty[TopicAndPartition, Long], m => m.mapValues(lo => lo.offset)) + + new DirectKafkaInputDStream[String, String, StringDecoder, StringDecoder, (String, String)]( + ssc, kafkaParams, m, messageHandler) { + override protected[streaming] val rateController = + Some(new DirectKafkaRateController(id, estimator)) + } + } + + val collectedData = + new mutable.ArrayBuffer[Array[String]]() with mutable.SynchronizedBuffer[Array[String]] + + // Used for assertion failure messages. + def dataToString: String = + collectedData.map(_.mkString("[", ",", "]")).mkString("{", ", ", "}") + + // This is to collect the raw data received from Kafka + kafkaStream.foreachRDD { (rdd: RDD[(String, String)], time: Time) => + val data = rdd.map { _._2 }.collect() + collectedData += data + } + + ssc.start() + + // Try different rate limits. + // Send data to Kafka and wait for arrays of data to appear matching the rate. + Seq(100, 50, 20).foreach { rate => + collectedData.clear() // Empty this buffer on each pass. + estimator.updateRate(rate) // Set a new rate. + // Expect blocks of data equal to "rate", scaled by the interval length in secs. + val expectedSize = Math.round(rate * batchIntervalMilliseconds * 0.001) + kafkaTestUtils.sendMessages(topic, messages) + eventually(timeout(5.seconds), interval(batchIntervalMilliseconds.milliseconds)) { + // Assert that rate estimator values are used to determine maxMessagesPerPartition. + // Funky "-" in message makes the complete assertion message read better. + assert(collectedData.exists(_.size == expectedSize), + s" - No arrays of size $expectedSize for rate $rate found in $dataToString") + } + } + + ssc.stop() + } + /** Get the generated offset ranges from the DirectKafkaStream */ private def getOffsetRanges[K, V]( kafkaStream: DStream[(K, V)]): Seq[(Time, Array[OffsetRange])] = { @@ -381,3 +455,18 @@ object DirectKafkaStreamSuite { } } } + +private[streaming] class ConstantEstimator(@volatile private var rate: Long) + extends RateEstimator { + + def updateRate(newRate: Long): Unit = { + rate = newRate + } + + def compute( + time: Long, + elements: Long, + processingDelay: Long, + schedulingDelay: Long): Option[Double] = Some(rate) +} + From 1f62f104c7a2aeac625b17d9e5ac62f1f10a2b21 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 6 Aug 2015 13:11:59 -0700 Subject: [PATCH 0091/1824] [SPARK-9632][SQL] update InternalRow.toSeq to make it accept data type info This re-applies #7955, which was reverted due to a race condition to fix build breaking. Author: Wenchen Fan Author: Reynold Xin Closes #8002 from rxin/InternalRow-toSeq and squashes the following commits: 332416a [Reynold Xin] Merge pull request #7955 from cloud-fan/toSeq 21665e2 [Wenchen Fan] fix hive again... 
4addf29 [Wenchen Fan] fix hive bc16c59 [Wenchen Fan] minor fix 33d802c [Wenchen Fan] pass data type info to InternalRow.toSeq 3dd033e [Wenchen Fan] move the default special getters implementation from InternalRow to BaseGenericInternalRow --- .../spark/sql/catalyst/InternalRow.scala | 132 ++---------------- .../sql/catalyst/expressions/Projection.scala | 12 +- .../expressions/SpecificMutableRow.scala | 5 +- .../codegen/GenerateProjection.scala | 8 +- .../spark/sql/catalyst/expressions/rows.scala | 132 +++++++++++++++++- .../expressions/CodeGenerationSuite.scala | 2 +- .../spark/sql/columnar/ColumnStats.scala | 51 +++---- .../columnar/InMemoryColumnarTableScan.scala | 11 +- .../spark/sql/execution/debug/package.scala | 4 +- .../apache/spark/sql/sources/interfaces.scala | 4 +- .../spark/sql/columnar/ColumnStatsSuite.scala | 54 +++---- .../spark/sql/hive/HiveInspectors.scala | 6 +- .../hive/execution/ScriptTransformation.scala | 21 ++- .../spark/sql/hive/hiveWriterContainers.scala | 24 ++-- .../spark/sql/hive/HiveInspectorSuite.scala | 10 +- 15 files changed, 259 insertions(+), 217 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala index 7d17cca80879..85b4bf3b6aef 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala @@ -18,8 +18,7 @@ package org.apache.spark.sql.catalyst import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.types.{DataType, MapData, ArrayData, Decimal} -import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} +import org.apache.spark.sql.types.{DataType, StructType} /** * An abstract class for row used internal in Spark SQL, which only contain the columns as @@ -32,8 +31,6 @@ abstract class InternalRow extends SpecializedGetters with Serializable { // This is only use for test and will throw a null pointer exception if the position is null. def getString(ordinal: Int): String = getUTF8String(ordinal).toString - override def toString: String = mkString("[", ",", "]") - /** * Make a copy of the current [[InternalRow]] object. */ @@ -50,136 +47,25 @@ abstract class InternalRow extends SpecializedGetters with Serializable { false } - // Subclasses of InternalRow should implement all special getters and equals/hashCode, - // or implement this genericGet. 
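As a usage sketch of the change this commit message describes (the schema, field names and values below are invented for illustration; the classes are Spark's internal catalyst types touched by this patch):

    import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
    import org.apache.spark.sql.types._
    import org.apache.spark.unsafe.types.UTF8String

    val schema = StructType(Seq(
      StructField("id", IntegerType),
      StructField("name", StringType)))
    // Internal rows store strings as UTF8String.
    val row = new GenericInternalRow(Array[Any](1, UTF8String.fromString("spark")))

    // The untyped row.toSeq is removed; callers now pass the field types (or a StructType)
    // so the row can be read through its typed getters.
    val values: Seq[Any] = row.toSeq(schema)   // same as row.toSeq(schema.map(_.dataType))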
- protected def genericGet(ordinal: Int): Any = throw new IllegalStateException( - "Concrete internal rows should implement genericGet, " + - "or implement all special getters and equals/hashCode") - - // default implementation (slow) - private def getAs[T](ordinal: Int) = genericGet(ordinal).asInstanceOf[T] - override def isNullAt(ordinal: Int): Boolean = getAs[AnyRef](ordinal) eq null - override def get(ordinal: Int, dataType: DataType): AnyRef = getAs(ordinal) - override def getBoolean(ordinal: Int): Boolean = getAs(ordinal) - override def getByte(ordinal: Int): Byte = getAs(ordinal) - override def getShort(ordinal: Int): Short = getAs(ordinal) - override def getInt(ordinal: Int): Int = getAs(ordinal) - override def getLong(ordinal: Int): Long = getAs(ordinal) - override def getFloat(ordinal: Int): Float = getAs(ordinal) - override def getDouble(ordinal: Int): Double = getAs(ordinal) - override def getDecimal(ordinal: Int, precision: Int, scale: Int): Decimal = getAs(ordinal) - override def getUTF8String(ordinal: Int): UTF8String = getAs(ordinal) - override def getBinary(ordinal: Int): Array[Byte] = getAs(ordinal) - override def getArray(ordinal: Int): ArrayData = getAs(ordinal) - override def getInterval(ordinal: Int): CalendarInterval = getAs(ordinal) - override def getMap(ordinal: Int): MapData = getAs(ordinal) - override def getStruct(ordinal: Int, numFields: Int): InternalRow = getAs(ordinal) - - override def equals(o: Any): Boolean = { - if (!o.isInstanceOf[InternalRow]) { - return false - } - - val other = o.asInstanceOf[InternalRow] - if (other eq null) { - return false - } - - val len = numFields - if (len != other.numFields) { - return false - } - - var i = 0 - while (i < len) { - if (isNullAt(i) != other.isNullAt(i)) { - return false - } - if (!isNullAt(i)) { - val o1 = genericGet(i) - val o2 = other.genericGet(i) - o1 match { - case b1: Array[Byte] => - if (!o2.isInstanceOf[Array[Byte]] || - !java.util.Arrays.equals(b1, o2.asInstanceOf[Array[Byte]])) { - return false - } - case f1: Float if java.lang.Float.isNaN(f1) => - if (!o2.isInstanceOf[Float] || ! java.lang.Float.isNaN(o2.asInstanceOf[Float])) { - return false - } - case d1: Double if java.lang.Double.isNaN(d1) => - if (!o2.isInstanceOf[Double] || ! java.lang.Double.isNaN(o2.asInstanceOf[Double])) { - return false - } - case _ => if (o1 != o2) { - return false - } - } - } - i += 1 - } - true - } - - // Custom hashCode function that matches the efficient code generated version. - override def hashCode: Int = { - var result: Int = 37 - var i = 0 - val len = numFields - while (i < len) { - val update: Int = - if (isNullAt(i)) { - 0 - } else { - genericGet(i) match { - case b: Boolean => if (b) 0 else 1 - case b: Byte => b.toInt - case s: Short => s.toInt - case i: Int => i - case l: Long => (l ^ (l >>> 32)).toInt - case f: Float => java.lang.Float.floatToIntBits(f) - case d: Double => - val b = java.lang.Double.doubleToLongBits(d) - (b ^ (b >>> 32)).toInt - case a: Array[Byte] => java.util.Arrays.hashCode(a) - case other => other.hashCode() - } - } - result = 37 * result + update - i += 1 - } - result - } - /* ---------------------- utility methods for Scala ---------------------- */ /** * Return a Scala Seq representing the row. Elements are placed in the same order in the Seq. 
*/ - // todo: remove this as it needs the generic getter - def toSeq: Seq[Any] = { - val n = numFields - val values = new Array[Any](n) + def toSeq(fieldTypes: Seq[DataType]): Seq[Any] = { + val len = numFields + assert(len == fieldTypes.length) + + val values = new Array[Any](len) var i = 0 - while (i < n) { - values.update(i, genericGet(i)) + while (i < len) { + values(i) = get(i, fieldTypes(i)) i += 1 } values } - /** Displays all elements of this sequence in a string (without a separator). */ - def mkString: String = toSeq.mkString - - /** Displays all elements of this sequence in a string using a separator string. */ - def mkString(sep: String): String = toSeq.mkString(sep) - - /** - * Displays all elements of this traversable or iterator in a string using - * start, end, and separator strings. - */ - def mkString(start: String, sep: String, end: String): String = toSeq.mkString(start, sep, end) + def toSeq(schema: StructType): Seq[Any] = toSeq(schema.map(_.dataType)) } object InternalRow { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index 4296b4b123fc..59ce7fc4f2c6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -203,7 +203,11 @@ class JoinedRow extends InternalRow { this } - override def toSeq: Seq[Any] = row1.toSeq ++ row2.toSeq + override def toSeq(fieldTypes: Seq[DataType]): Seq[Any] = { + assert(fieldTypes.length == row1.numFields + row2.numFields) + val (left, right) = fieldTypes.splitAt(row1.numFields) + row1.toSeq(left) ++ row2.toSeq(right) + } override def numFields: Int = row1.numFields + row2.numFields @@ -276,11 +280,11 @@ class JoinedRow extends InternalRow { if ((row1 eq null) && (row2 eq null)) { "[ empty row ]" } else if (row1 eq null) { - row2.mkString("[", ",", "]") + row2.toString } else if (row2 eq null) { - row1.mkString("[", ",", "]") + row1.toString } else { - mkString("[", ",", "]") + s"{${row1.toString} + ${row2.toString}}" } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala index b94df6bd66e0..4f56f94bd4ca 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala @@ -192,7 +192,8 @@ final class MutableAny extends MutableValue { * based on the dataTypes of each column. The intent is to decrease garbage when modifying the * values of primitive columns. 
*/ -final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableRow { +final class SpecificMutableRow(val values: Array[MutableValue]) + extends MutableRow with BaseGenericInternalRow { def this(dataTypes: Seq[DataType]) = this( @@ -213,8 +214,6 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR override def numFields: Int = values.length - override def toSeq: Seq[Any] = values.map(_.boxed) - override def setNullAt(i: Int): Unit = { values(i).isNull = true } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala index c04fe734d554..c744e84d822e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.expressions.codegen +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ @@ -25,6 +26,8 @@ import org.apache.spark.sql.types._ */ abstract class BaseProjection extends Projection {} +abstract class CodeGenMutableRow extends MutableRow with BaseGenericInternalRow + /** * Generates bytecode that produces a new [[InternalRow]] object based on a fixed set of input * [[Expression Expressions]] and a given input [[InternalRow]]. The returned [[InternalRow]] @@ -171,7 +174,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { return new SpecificRow((InternalRow) r); } - final class SpecificRow extends ${classOf[MutableRow].getName} { + final class SpecificRow extends ${classOf[CodeGenMutableRow].getName} { $columns @@ -184,7 +187,8 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { public void setNullAt(int i) { nullBits[i] = true; } public boolean isNullAt(int i) { return nullBits[i]; } - protected Object genericGet(int i) { + @Override + public Object genericGet(int i) { if (isNullAt(i)) return null; switch (i) { $getCases diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala index fd42fac3d2cd..11d10b2d8a48 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala @@ -22,6 +22,130 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} +/** + * An extended version of [[InternalRow]] that implements all special getters, toString + * and equals/hashCode by `genericGet`. 
+ */ +trait BaseGenericInternalRow extends InternalRow { + + protected def genericGet(ordinal: Int): Any + + // default implementation (slow) + private def getAs[T](ordinal: Int) = genericGet(ordinal).asInstanceOf[T] + override def isNullAt(ordinal: Int): Boolean = getAs[AnyRef](ordinal) eq null + override def get(ordinal: Int, dataType: DataType): AnyRef = getAs(ordinal) + override def getBoolean(ordinal: Int): Boolean = getAs(ordinal) + override def getByte(ordinal: Int): Byte = getAs(ordinal) + override def getShort(ordinal: Int): Short = getAs(ordinal) + override def getInt(ordinal: Int): Int = getAs(ordinal) + override def getLong(ordinal: Int): Long = getAs(ordinal) + override def getFloat(ordinal: Int): Float = getAs(ordinal) + override def getDouble(ordinal: Int): Double = getAs(ordinal) + override def getDecimal(ordinal: Int, precision: Int, scale: Int): Decimal = getAs(ordinal) + override def getUTF8String(ordinal: Int): UTF8String = getAs(ordinal) + override def getBinary(ordinal: Int): Array[Byte] = getAs(ordinal) + override def getArray(ordinal: Int): ArrayData = getAs(ordinal) + override def getInterval(ordinal: Int): CalendarInterval = getAs(ordinal) + override def getMap(ordinal: Int): MapData = getAs(ordinal) + override def getStruct(ordinal: Int, numFields: Int): InternalRow = getAs(ordinal) + + override def toString(): String = { + if (numFields == 0) { + "[empty row]" + } else { + val sb = new StringBuilder + sb.append("[") + sb.append(genericGet(0)) + val len = numFields + var i = 1 + while (i < len) { + sb.append(",") + sb.append(genericGet(i)) + i += 1 + } + sb.append("]") + sb.toString() + } + } + + override def equals(o: Any): Boolean = { + if (!o.isInstanceOf[BaseGenericInternalRow]) { + return false + } + + val other = o.asInstanceOf[BaseGenericInternalRow] + if (other eq null) { + return false + } + + val len = numFields + if (len != other.numFields) { + return false + } + + var i = 0 + while (i < len) { + if (isNullAt(i) != other.isNullAt(i)) { + return false + } + if (!isNullAt(i)) { + val o1 = genericGet(i) + val o2 = other.genericGet(i) + o1 match { + case b1: Array[Byte] => + if (!o2.isInstanceOf[Array[Byte]] || + !java.util.Arrays.equals(b1, o2.asInstanceOf[Array[Byte]])) { + return false + } + case f1: Float if java.lang.Float.isNaN(f1) => + if (!o2.isInstanceOf[Float] || ! java.lang.Float.isNaN(o2.asInstanceOf[Float])) { + return false + } + case d1: Double if java.lang.Double.isNaN(d1) => + if (!o2.isInstanceOf[Double] || ! java.lang.Double.isNaN(o2.asInstanceOf[Double])) { + return false + } + case _ => if (o1 != o2) { + return false + } + } + } + i += 1 + } + true + } + + // Custom hashCode function that matches the efficient code generated version. + override def hashCode: Int = { + var result: Int = 37 + var i = 0 + val len = numFields + while (i < len) { + val update: Int = + if (isNullAt(i)) { + 0 + } else { + genericGet(i) match { + case b: Boolean => if (b) 0 else 1 + case b: Byte => b.toInt + case s: Short => s.toInt + case i: Int => i + case l: Long => (l ^ (l >>> 32)).toInt + case f: Float => java.lang.Float.floatToIntBits(f) + case d: Double => + val b = java.lang.Double.doubleToLongBits(d) + (b ^ (b >>> 32)).toInt + case a: Array[Byte] => java.util.Arrays.hashCode(a) + case other => other.hashCode() + } + } + result = 37 * result + update + i += 1 + } + result + } +} + /** * An extended interface to [[InternalRow]] that allows the values for each column to be updated. 
* Setting a value through a primitive function implicitly marks that column as not null. @@ -83,7 +207,7 @@ class GenericRowWithSchema(values: Array[Any], override val schema: StructType) * Note that, while the array is not copied, and thus could technically be mutated after creation, * this is not allowed. */ -class GenericInternalRow(private[sql] val values: Array[Any]) extends InternalRow { +class GenericInternalRow(private[sql] val values: Array[Any]) extends BaseGenericInternalRow { /** No-arg constructor for serialization. */ protected def this() = this(null) @@ -91,7 +215,7 @@ class GenericInternalRow(private[sql] val values: Array[Any]) extends InternalRo override protected def genericGet(ordinal: Int) = values(ordinal) - override def toSeq: Seq[Any] = values + override def toSeq(fieldTypes: Seq[DataType]): Seq[Any] = values override def numFields: Int = values.length @@ -110,7 +234,7 @@ class GenericInternalRowWithSchema(values: Array[Any], val schema: StructType) def fieldIndex(name: String): Int = schema.fieldIndex(name) } -class GenericMutableRow(values: Array[Any]) extends MutableRow { +class GenericMutableRow(values: Array[Any]) extends MutableRow with BaseGenericInternalRow { /** No-arg constructor for serialization. */ protected def this() = this(null) @@ -118,7 +242,7 @@ class GenericMutableRow(values: Array[Any]) extends MutableRow { override protected def genericGet(ordinal: Int) = values(ordinal) - override def toSeq: Seq[Any] = values + override def toSeq(fieldTypes: Seq[DataType]): Seq[Any] = values override def numFields: Int = values.length diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index e310aee22166..e323467af5f4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -87,7 +87,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { val length = 5000 val expressions = List.fill(length)(EqualTo(Literal(1), Literal(1))) val plan = GenerateMutableProjection.generate(expressions)() - val actual = plan(new GenericMutableRow(length)).toSeq + val actual = plan(new GenericMutableRow(length)).toSeq(expressions.map(_.dataType)) val expected = Seq.fill(length)(true) if (!checkResult(actual, expected)) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala index af1a8ecca9b5..5cbd52bc0590 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.columnar import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference} +import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, Attribute, AttributeMap, AttributeReference} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -66,7 +66,7 @@ private[sql] sealed trait ColumnStats extends Serializable { * Column statistics represented as a single row, currently including closed lower bound, closed * upper bound and null count. 
*/ - def collectedStatistics: InternalRow + def collectedStatistics: GenericInternalRow } /** @@ -75,7 +75,8 @@ private[sql] sealed trait ColumnStats extends Serializable { private[sql] class NoopColumnStats extends ColumnStats { override def gatherStats(row: InternalRow, ordinal: Int): Unit = super.gatherStats(row, ordinal) - override def collectedStatistics: InternalRow = InternalRow(null, null, nullCount, count, 0L) + override def collectedStatistics: GenericInternalRow = + new GenericInternalRow(Array[Any](null, null, nullCount, count, 0L)) } private[sql] class BooleanColumnStats extends ColumnStats { @@ -92,8 +93,8 @@ private[sql] class BooleanColumnStats extends ColumnStats { } } - override def collectedStatistics: InternalRow = - InternalRow(lower, upper, nullCount, count, sizeInBytes) + override def collectedStatistics: GenericInternalRow = + new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) } private[sql] class ByteColumnStats extends ColumnStats { @@ -110,8 +111,8 @@ private[sql] class ByteColumnStats extends ColumnStats { } } - override def collectedStatistics: InternalRow = - InternalRow(lower, upper, nullCount, count, sizeInBytes) + override def collectedStatistics: GenericInternalRow = + new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) } private[sql] class ShortColumnStats extends ColumnStats { @@ -128,8 +129,8 @@ private[sql] class ShortColumnStats extends ColumnStats { } } - override def collectedStatistics: InternalRow = - InternalRow(lower, upper, nullCount, count, sizeInBytes) + override def collectedStatistics: GenericInternalRow = + new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) } private[sql] class IntColumnStats extends ColumnStats { @@ -146,8 +147,8 @@ private[sql] class IntColumnStats extends ColumnStats { } } - override def collectedStatistics: InternalRow = - InternalRow(lower, upper, nullCount, count, sizeInBytes) + override def collectedStatistics: GenericInternalRow = + new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) } private[sql] class LongColumnStats extends ColumnStats { @@ -164,8 +165,8 @@ private[sql] class LongColumnStats extends ColumnStats { } } - override def collectedStatistics: InternalRow = - InternalRow(lower, upper, nullCount, count, sizeInBytes) + override def collectedStatistics: GenericInternalRow = + new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) } private[sql] class FloatColumnStats extends ColumnStats { @@ -182,8 +183,8 @@ private[sql] class FloatColumnStats extends ColumnStats { } } - override def collectedStatistics: InternalRow = - InternalRow(lower, upper, nullCount, count, sizeInBytes) + override def collectedStatistics: GenericInternalRow = + new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) } private[sql] class DoubleColumnStats extends ColumnStats { @@ -200,8 +201,8 @@ private[sql] class DoubleColumnStats extends ColumnStats { } } - override def collectedStatistics: InternalRow = - InternalRow(lower, upper, nullCount, count, sizeInBytes) + override def collectedStatistics: GenericInternalRow = + new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) } private[sql] class StringColumnStats extends ColumnStats { @@ -218,8 +219,8 @@ private[sql] class StringColumnStats extends ColumnStats { } } - override def collectedStatistics: InternalRow = - InternalRow(lower, upper, nullCount, count, sizeInBytes) + override def 
collectedStatistics: GenericInternalRow = + new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) } private[sql] class BinaryColumnStats extends ColumnStats { @@ -230,8 +231,8 @@ private[sql] class BinaryColumnStats extends ColumnStats { } } - override def collectedStatistics: InternalRow = - InternalRow(null, null, nullCount, count, sizeInBytes) + override def collectedStatistics: GenericInternalRow = + new GenericInternalRow(Array[Any](null, null, nullCount, count, sizeInBytes)) } private[sql] class FixedDecimalColumnStats(precision: Int, scale: Int) extends ColumnStats { @@ -248,8 +249,8 @@ private[sql] class FixedDecimalColumnStats(precision: Int, scale: Int) extends C } } - override def collectedStatistics: InternalRow = - InternalRow(lower, upper, nullCount, count, sizeInBytes) + override def collectedStatistics: GenericInternalRow = + new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) } private[sql] class GenericColumnStats(dataType: DataType) extends ColumnStats { @@ -262,8 +263,8 @@ private[sql] class GenericColumnStats(dataType: DataType) extends ColumnStats { } } - override def collectedStatistics: InternalRow = - InternalRow(null, null, nullCount, count, sizeInBytes) + override def collectedStatistics: GenericInternalRow = + new GenericInternalRow(Array[Any](null, null, nullCount, count, sizeInBytes)) } private[sql] class DateColumnStats extends IntColumnStats diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala index 5d5b0697d701..d553bb6169ec 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala @@ -148,7 +148,7 @@ private[sql] case class InMemoryRelation( } val stats = InternalRow.fromSeq(columnBuilders.map(_.columnStats.collectedStatistics) - .flatMap(_.toSeq)) + .flatMap(_.values)) batchStats += stats CachedBatch(columnBuilders.map(_.build().array()), stats) @@ -330,10 +330,11 @@ private[sql] case class InMemoryColumnarTableScan( if (inMemoryPartitionPruningEnabled) { cachedBatchIterator.filter { cachedBatch => if (!partitionFilter(cachedBatch.stats)) { - def statsString: String = relation.partitionStatistics.schema - .zip(cachedBatch.stats.toSeq) - .map { case (a, s) => s"${a.name}: $s" } - .mkString(", ") + def statsString: String = relation.partitionStatistics.schema.zipWithIndex.map { + case (a, i) => + val value = cachedBatch.stats.get(i, a.dataType) + s"${a.name}: $value" + }.mkString(", ") logInfo(s"Skipping partition based on stats $statsString") false } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala index c37007f1eece..dd3858ea2b52 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala @@ -156,8 +156,8 @@ package object debug { def typeCheck(data: Any, schema: DataType): Unit = (data, schema) match { case (null, _) => - case (row: InternalRow, StructType(fields)) => - row.toSeq.zip(fields.map(_.dataType)).foreach { case(d, t) => typeCheck(d, t) } + case (row: InternalRow, s: StructType) => + row.toSeq(s).zip(s.map(_.dataType)).foreach { case(d, t) => typeCheck(d, t) } case (a: ArrayData, ArrayType(elemType, 
_)) => a.foreach(elemType, (_, e) => { typeCheck(e, elemType) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index 7126145ddc01..c04557e5a081 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -461,8 +461,8 @@ abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[Partitio val spec = discoverPartitions() val partitionColumnTypes = spec.partitionColumns.map(_.dataType) val castedPartitions = spec.partitions.map { case p @ Partition(values, path) => - val literals = values.toSeq.zip(partitionColumnTypes).map { - case (value, dataType) => Literal.create(value, dataType) + val literals = partitionColumnTypes.zipWithIndex.map { case (dt, i) => + Literal.create(values.get(i, dt), dt) } val castedValues = partitionSchema.zip(literals).map { case (field, literal) => Cast(literal, field.dataType).eval() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala index 16e0187ed20a..d0430d2a60e7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala @@ -19,33 +19,36 @@ package org.apache.spark.sql.columnar import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow import org.apache.spark.sql.types._ class ColumnStatsSuite extends SparkFunSuite { - testColumnStats(classOf[BooleanColumnStats], BOOLEAN, InternalRow(true, false, 0)) - testColumnStats(classOf[ByteColumnStats], BYTE, InternalRow(Byte.MaxValue, Byte.MinValue, 0)) - testColumnStats(classOf[ShortColumnStats], SHORT, InternalRow(Short.MaxValue, Short.MinValue, 0)) - testColumnStats(classOf[IntColumnStats], INT, InternalRow(Int.MaxValue, Int.MinValue, 0)) - testColumnStats(classOf[DateColumnStats], DATE, InternalRow(Int.MaxValue, Int.MinValue, 0)) - testColumnStats(classOf[LongColumnStats], LONG, InternalRow(Long.MaxValue, Long.MinValue, 0)) + testColumnStats(classOf[BooleanColumnStats], BOOLEAN, createRow(true, false, 0)) + testColumnStats(classOf[ByteColumnStats], BYTE, createRow(Byte.MaxValue, Byte.MinValue, 0)) + testColumnStats(classOf[ShortColumnStats], SHORT, createRow(Short.MaxValue, Short.MinValue, 0)) + testColumnStats(classOf[IntColumnStats], INT, createRow(Int.MaxValue, Int.MinValue, 0)) + testColumnStats(classOf[DateColumnStats], DATE, createRow(Int.MaxValue, Int.MinValue, 0)) + testColumnStats(classOf[LongColumnStats], LONG, createRow(Long.MaxValue, Long.MinValue, 0)) testColumnStats(classOf[TimestampColumnStats], TIMESTAMP, - InternalRow(Long.MaxValue, Long.MinValue, 0)) - testColumnStats(classOf[FloatColumnStats], FLOAT, InternalRow(Float.MaxValue, Float.MinValue, 0)) + createRow(Long.MaxValue, Long.MinValue, 0)) + testColumnStats(classOf[FloatColumnStats], FLOAT, createRow(Float.MaxValue, Float.MinValue, 0)) testColumnStats(classOf[DoubleColumnStats], DOUBLE, - InternalRow(Double.MaxValue, Double.MinValue, 0)) - testColumnStats(classOf[StringColumnStats], STRING, InternalRow(null, null, 0)) - testDecimalColumnStats(InternalRow(null, null, 0)) + createRow(Double.MaxValue, Double.MinValue, 0)) + testColumnStats(classOf[StringColumnStats], STRING, createRow(null, null, 0)) + 
testDecimalColumnStats(createRow(null, null, 0)) + + def createRow(values: Any*): GenericInternalRow = new GenericInternalRow(values.toArray) def testColumnStats[T <: AtomicType, U <: ColumnStats]( columnStatsClass: Class[U], columnType: NativeColumnType[T], - initialStatistics: InternalRow): Unit = { + initialStatistics: GenericInternalRow): Unit = { val columnStatsName = columnStatsClass.getSimpleName test(s"$columnStatsName: empty") { val columnStats = columnStatsClass.newInstance() - columnStats.collectedStatistics.toSeq.zip(initialStatistics.toSeq).foreach { + columnStats.collectedStatistics.values.zip(initialStatistics.values).foreach { case (actual, expected) => assert(actual === expected) } } @@ -61,11 +64,11 @@ class ColumnStatsSuite extends SparkFunSuite { val ordering = columnType.dataType.ordering.asInstanceOf[Ordering[T#InternalType]] val stats = columnStats.collectedStatistics - assertResult(values.min(ordering), "Wrong lower bound")(stats.get(0, null)) - assertResult(values.max(ordering), "Wrong upper bound")(stats.get(1, null)) - assertResult(10, "Wrong null count")(stats.get(2, null)) - assertResult(20, "Wrong row count")(stats.get(3, null)) - assertResult(stats.get(4, null), "Wrong size in bytes") { + assertResult(values.min(ordering), "Wrong lower bound")(stats.values(0)) + assertResult(values.max(ordering), "Wrong upper bound")(stats.values(1)) + assertResult(10, "Wrong null count")(stats.values(2)) + assertResult(20, "Wrong row count")(stats.values(3)) + assertResult(stats.values(4), "Wrong size in bytes") { rows.map { row => if (row.isNullAt(0)) 4 else columnType.actualSize(row, 0) }.sum @@ -73,14 +76,15 @@ class ColumnStatsSuite extends SparkFunSuite { } } - def testDecimalColumnStats[T <: AtomicType, U <: ColumnStats](initialStatistics: InternalRow) { + def testDecimalColumnStats[T <: AtomicType, U <: ColumnStats]( + initialStatistics: GenericInternalRow): Unit = { val columnStatsName = classOf[FixedDecimalColumnStats].getSimpleName val columnType = FIXED_DECIMAL(15, 10) test(s"$columnStatsName: empty") { val columnStats = new FixedDecimalColumnStats(15, 10) - columnStats.collectedStatistics.toSeq.zip(initialStatistics.toSeq).foreach { + columnStats.collectedStatistics.values.zip(initialStatistics.values).foreach { case (actual, expected) => assert(actual === expected) } } @@ -96,11 +100,11 @@ class ColumnStatsSuite extends SparkFunSuite { val ordering = columnType.dataType.ordering.asInstanceOf[Ordering[T#InternalType]] val stats = columnStats.collectedStatistics - assertResult(values.min(ordering), "Wrong lower bound")(stats.get(0, null)) - assertResult(values.max(ordering), "Wrong upper bound")(stats.get(1, null)) - assertResult(10, "Wrong null count")(stats.get(2, null)) - assertResult(20, "Wrong row count")(stats.get(3, null)) - assertResult(stats.get(4, null), "Wrong size in bytes") { + assertResult(values.min(ordering), "Wrong lower bound")(stats.values(0)) + assertResult(values.max(ordering), "Wrong upper bound")(stats.values(1)) + assertResult(10, "Wrong null count")(stats.values(2)) + assertResult(20, "Wrong row count")(stats.values(3)) + assertResult(stats.values(4), "Wrong size in bytes") { rows.map { row => if (row.isNullAt(0)) 4 else columnType.actualSize(row, 0) }.sum diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index 39d798d072ae..9824dad23959 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ 
b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -390,8 +390,10 @@ private[hive] trait HiveInspectors { (o: Any) => { if (o != null) { val struct = soi.create() - (soi.getAllStructFieldRefs, wrappers, o.asInstanceOf[InternalRow].toSeq).zipped.foreach { - (field, wrapper, data) => soi.setStructFieldData(struct, field, wrapper(data)) + val row = o.asInstanceOf[InternalRow] + soi.getAllStructFieldRefs.zip(wrappers).zipWithIndex.foreach { + case ((field, wrapper), i) => + soi.setStructFieldData(struct, field, wrapper(row.get(i, schema(i).dataType))) } struct } else { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala index a6a343d39599..ade27454b9d2 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala @@ -88,6 +88,7 @@ case class ScriptTransformation( // external process. That process's output will be read by this current thread. val writerThread = new ScriptTransformationWriterThread( inputIterator, + input.map(_.dataType), outputProjection, inputSerde, inputSoi, @@ -201,6 +202,7 @@ case class ScriptTransformation( private class ScriptTransformationWriterThread( iter: Iterator[InternalRow], + inputSchema: Seq[DataType], outputProjection: Projection, @Nullable inputSerde: AbstractSerDe, @Nullable inputSoi: ObjectInspector, @@ -226,12 +228,25 @@ private class ScriptTransformationWriterThread( // We can't use Utils.tryWithSafeFinally here because we also need a `catch` block, so // let's use a variable to record whether the `finally` block was hit due to an exception var threwException: Boolean = true + val len = inputSchema.length try { iter.map(outputProjection).foreach { row => if (inputSerde == null) { - val data = row.mkString("", ioschema.inputRowFormatMap("TOK_TABLEROWFORMATFIELD"), - ioschema.inputRowFormatMap("TOK_TABLEROWFORMATLINES")).getBytes("utf-8") - outputStream.write(data) + val data = if (len == 0) { + ioschema.inputRowFormatMap("TOK_TABLEROWFORMATLINES") + } else { + val sb = new StringBuilder + sb.append(row.get(0, inputSchema(0))) + var i = 1 + while (i < len) { + sb.append(ioschema.inputRowFormatMap("TOK_TABLEROWFORMATFIELD")) + sb.append(row.get(i, inputSchema(i))) + i += 1 + } + sb.append(ioschema.inputRowFormatMap("TOK_TABLEROWFORMATLINES")) + sb.toString() + } + outputStream.write(data.getBytes("utf-8")) } else { val writable = inputSerde.serialize( row.asInstanceOf[GenericInternalRow].values, inputSoi) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala index 684ea1d137b4..8dc796b056a7 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala @@ -211,18 +211,18 @@ private[spark] class SparkHiveDynamicPartitionWriterContainer( } } - val dynamicPartPath = dynamicPartColNames - .zip(row.toSeq.takeRight(dynamicPartColNames.length)) - .map { case (col, rawVal) => - val string = if (rawVal == null) null else convertToHiveRawString(col, rawVal) - val colString = - if (string == null || string.isEmpty) { - defaultPartName - } else { - FileUtils.escapePathName(string, defaultPartName) - } - s"/$col=$colString" - }.mkString + val nonDynamicPartLen = 
row.numFields - dynamicPartColNames.length + val dynamicPartPath = dynamicPartColNames.zipWithIndex.map { case (colName, i) => + val rawVal = row.get(nonDynamicPartLen + i, schema(colName).dataType) + val string = if (rawVal == null) null else convertToHiveRawString(colName, rawVal) + val colString = + if (string == null || string.isEmpty) { + defaultPartName + } else { + FileUtils.escapePathName(string, defaultPartName) + } + s"/$colName=$colString" + }.mkString def newWriter(): FileSinkOperator.RecordWriter = { val newFileSinkDesc = new FileSinkDesc( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala index 99e95fb92130..81a70b8d4226 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala @@ -133,8 +133,8 @@ class HiveInspectorSuite extends SparkFunSuite with HiveInspectors { } } - def checkValues(row1: Seq[Any], row2: InternalRow): Unit = { - row1.zip(row2.toSeq).foreach { case (r1, r2) => + def checkValues(row1: Seq[Any], row2: InternalRow, row2Schema: StructType): Unit = { + row1.zip(row2.toSeq(row2Schema)).foreach { case (r1, r2) => checkValue(r1, r2) } } @@ -211,8 +211,10 @@ class HiveInspectorSuite extends SparkFunSuite with HiveInspectors { case (t, idx) => StructField(s"c_$idx", t) }) val inspector = toInspector(dt) - checkValues(row, - unwrap(wrap(InternalRow.fromSeq(row), inspector, dt), inspector).asInstanceOf[InternalRow]) + checkValues( + row, + unwrap(wrap(InternalRow.fromSeq(row), inspector, dt), inspector).asInstanceOf[InternalRow], + dt) checkValue(null, unwrap(wrap(null, toInspector(dt), dt), toInspector(dt))) } From 54c0789a05a783ce90e0e9848079be442a82966b Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Thu, 6 Aug 2015 13:29:31 -0700 Subject: [PATCH 0092/1824] [SPARK-9493] [ML] add featureIndex to handle vector features in IsotonicRegression This PR contains the following changes: * add `featureIndex` to handle vector features (in order to chain isotonic regression easily with output from logistic regression * make getter/setter names consistent with params * remove inheritance from Regressor because it is tricky to handle both `DoubleType` and `VectorType` * simplify test data generation jkbradley zapletal-martin Author: Xiangrui Meng Closes #7952 from mengxr/SPARK-9493 and squashes the following commits: 8818ac3 [Xiangrui Meng] address comments 05e2216 [Xiangrui Meng] address comments 8d08090 [Xiangrui Meng] add featureIndex to handle vector features make getter/setter names consistent with params remove inheritance from Regressor --- .../ml/regression/IsotonicRegression.scala | 202 +++++++++++++----- .../regression/IsotonicRegressionSuite.scala | 82 ++++--- 2 files changed, 194 insertions(+), 90 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala index 4ece8cf8cf0b..f570590960a6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala @@ -17,44 +17,113 @@ package org.apache.spark.ml.regression +import org.apache.spark.Logging import org.apache.spark.annotation.Experimental -import org.apache.spark.ml.PredictorParams -import org.apache.spark.ml.param.{Param, ParamMap, BooleanParam} -import 
org.apache.spark.ml.util.{SchemaUtils, Identifiable} -import org.apache.spark.mllib.regression.{IsotonicRegression => MLlibIsotonicRegression} -import org.apache.spark.mllib.regression.{IsotonicRegressionModel => MLlibIsotonicRegressionModel} +import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasLabelCol, HasPredictionCol} +import org.apache.spark.ml.util.{Identifiable, SchemaUtils} +import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors} +import org.apache.spark.mllib.regression.{IsotonicRegression => MLlibIsotonicRegression, IsotonicRegressionModel => MLlibIsotonicRegressionModel} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.types.{DoubleType, DataType} -import org.apache.spark.sql.{Row, DataFrame} +import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.functions.{col, lit, udf} +import org.apache.spark.sql.types.{DoubleType, StructType} import org.apache.spark.storage.StorageLevel /** * Params for isotonic regression. */ -private[regression] trait IsotonicRegressionParams extends PredictorParams { +private[regression] trait IsotonicRegressionBase extends Params with HasFeaturesCol + with HasLabelCol with HasPredictionCol with Logging { /** - * Param for weight column name. - * TODO: Move weightCol to sharedParams. - * + * Param for weight column name (default: none). * @group param */ + // TODO: Move weightCol to sharedParams. final val weightCol: Param[String] = - new Param[String](this, "weightCol", "weight column name") + new Param[String](this, "weightCol", + "weight column name. If this is not set or empty, we treat all instance weights as 1.0.") /** @group getParam */ final def getWeightCol: String = $(weightCol) /** - * Param for isotonic parameter. - * Isotonic (increasing) or antitonic (decreasing) sequence. + * Param for whether the output sequence should be isotonic/increasing (true) or + * antitonic/decreasing (false). * @group param */ final val isotonic: BooleanParam = - new BooleanParam(this, "isotonic", "isotonic (increasing) or antitonic (decreasing) sequence") + new BooleanParam(this, "isotonic", + "whether the output sequence should be isotonic/increasing (true) or" + + "antitonic/decreasing (false)") /** @group getParam */ - final def getIsotonicParam: Boolean = $(isotonic) + final def getIsotonic: Boolean = $(isotonic) + + /** + * Param for the index of the feature if [[featuresCol]] is a vector column (default: `0`), no + * effect otherwise. + * @group param + */ + final val featureIndex: IntParam = new IntParam(this, "featureIndex", + "The index of the feature if featuresCol is a vector column, no effect otherwise.") + + /** @group getParam */ + final def getFeatureIndex: Int = $(featureIndex) + + setDefault(isotonic -> true, featureIndex -> 0) + + /** Checks whether the input has weight column. */ + protected[ml] def hasWeightCol: Boolean = { + isDefined(weightCol) && $(weightCol) != "" + } + + /** + * Extracts (label, feature, weight) from input dataset. 
+ */ + protected[ml] def extractWeightedLabeledPoints( + dataset: DataFrame): RDD[(Double, Double, Double)] = { + val f = if (dataset.schema($(featuresCol)).dataType.isInstanceOf[VectorUDT]) { + val idx = $(featureIndex) + val extract = udf { v: Vector => v(idx) } + extract(col($(featuresCol))) + } else { + col($(featuresCol)) + } + val w = if (hasWeightCol) { + col($(weightCol)) + } else { + lit(1.0) + } + dataset.select(col($(labelCol)), f, w) + .map { case Row(label: Double, feature: Double, weights: Double) => + (label, feature, weights) + } + } + + /** + * Validates and transforms input schema. + * @param schema input schema + * @param fitting whether this is in fitting or prediction + * @return output schema + */ + protected[ml] def validateAndTransformSchema( + schema: StructType, + fitting: Boolean): StructType = { + if (fitting) { + SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType) + if (hasWeightCol) { + SchemaUtils.checkColumnType(schema, $(weightCol), DoubleType) + } else { + logInfo("The weight column is not defined. Treat all instance weights as 1.0.") + } + } + val featuresType = schema($(featuresCol)).dataType + require(featuresType == DoubleType || featuresType.isInstanceOf[VectorUDT]) + SchemaUtils.appendColumn(schema, $(predictionCol), DoubleType) + } } /** @@ -67,52 +136,46 @@ private[regression] trait IsotonicRegressionParams extends PredictorParams { * Uses [[org.apache.spark.mllib.regression.IsotonicRegression]]. */ @Experimental -class IsotonicRegression(override val uid: String) - extends Regressor[Double, IsotonicRegression, IsotonicRegressionModel] - with IsotonicRegressionParams { +class IsotonicRegression(override val uid: String) extends Estimator[IsotonicRegressionModel] + with IsotonicRegressionBase { def this() = this(Identifiable.randomUID("isoReg")) - /** - * Set the isotonic parameter. - * Default is true. - * @group setParam - */ - def setIsotonicParam(value: Boolean): this.type = set(isotonic, value) - setDefault(isotonic -> true) + /** @group setParam */ + def setLabelCol(value: String): this.type = set(labelCol, value) - /** - * Set weight column param. - * Default is weight. 
- * @group setParam - */ - def setWeightParam(value: String): this.type = set(weightCol, value) - setDefault(weightCol -> "weight") + /** @group setParam */ + def setFeaturesCol(value: String): this.type = set(featuresCol, value) - override private[ml] def featuresDataType: DataType = DoubleType + /** @group setParam */ + def setPredictionCol(value: String): this.type = set(predictionCol, value) - override def copy(extra: ParamMap): IsotonicRegression = defaultCopy(extra) + /** @group setParam */ + def setIsotonic(value: Boolean): this.type = set(isotonic, value) - private[this] def extractWeightedLabeledPoints( - dataset: DataFrame): RDD[(Double, Double, Double)] = { + /** @group setParam */ + def setWeightCol(value: String): this.type = set(weightCol, value) - dataset.select($(labelCol), $(featuresCol), $(weightCol)) - .map { case Row(label: Double, features: Double, weights: Double) => - (label, features, weights) - } - } + /** @group setParam */ + def setFeatureIndex(value: Int): this.type = set(featureIndex, value) - override protected def train(dataset: DataFrame): IsotonicRegressionModel = { - SchemaUtils.checkColumnType(dataset.schema, $(weightCol), DoubleType) + override def copy(extra: ParamMap): IsotonicRegression = defaultCopy(extra) + + override def fit(dataset: DataFrame): IsotonicRegressionModel = { + validateAndTransformSchema(dataset.schema, fitting = true) // Extract columns from data. If dataset is persisted, do not persist oldDataset. val instances = extractWeightedLabeledPoints(dataset) val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK) val isotonicRegression = new MLlibIsotonicRegression().setIsotonic($(isotonic)) - val parentModel = isotonicRegression.run(instances) + val oldModel = isotonicRegression.run(instances) - new IsotonicRegressionModel(uid, parentModel) + copyValues(new IsotonicRegressionModel(uid, oldModel).setParent(this)) + } + + override def transformSchema(schema: StructType): StructType = { + validateAndTransformSchema(schema, fitting = true) } } @@ -123,22 +186,49 @@ class IsotonicRegression(override val uid: String) * * For detailed rules see [[org.apache.spark.mllib.regression.IsotonicRegressionModel.predict()]]. * - * @param parentModel A [[org.apache.spark.mllib.regression.IsotonicRegressionModel]] - * model trained by [[org.apache.spark.mllib.regression.IsotonicRegression]]. + * @param oldModel A [[org.apache.spark.mllib.regression.IsotonicRegressionModel]] + * model trained by [[org.apache.spark.mllib.regression.IsotonicRegression]]. */ +@Experimental class IsotonicRegressionModel private[ml] ( override val uid: String, - private[ml] val parentModel: MLlibIsotonicRegressionModel) - extends RegressionModel[Double, IsotonicRegressionModel] - with IsotonicRegressionParams { + private val oldModel: MLlibIsotonicRegressionModel) + extends Model[IsotonicRegressionModel] with IsotonicRegressionBase { - override def featuresDataType: DataType = DoubleType + /** @group setParam */ + def setFeaturesCol(value: String): this.type = set(featuresCol, value) - override protected def predict(features: Double): Double = { - parentModel.predict(features) - } + /** @group setParam */ + def setPredictionCol(value: String): this.type = set(predictionCol, value) + + /** @group setParam */ + def setFeatureIndex(value: Int): this.type = set(featureIndex, value) + + /** Boundaries in increasing order for which predictions are known. 
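To see how the reworked estimator described in the commit message is meant to be used with a vector features column, here is a sketch based on the new test suite below (it assumes a SQLContext named sqlContext is in scope; the data is made up):

    import org.apache.spark.ml.regression.IsotonicRegression
    import org.apache.spark.mllib.linalg.Vectors

    val df = sqlContext.createDataFrame(Seq(
      (4.0, Vectors.dense(0.0, 1.0)),
      (3.0, Vectors.dense(0.0, 2.0)),
      (5.0, Vectors.dense(0.0, 3.0))
    )).toDF("label", "features")

    val model = new IsotonicRegression()
      .setIsotonic(true)      // fit an increasing (isotonic) sequence
      .setFeatureIndex(1)     // regress against the second component of the vector column
      .fit(df)

    model.transform(df).select("features", "prediction").show()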
*/ + def boundaries: Vector = Vectors.dense(oldModel.boundaries) + + /** + * Predictions associated with the boundaries at the same index, monotone because of isotonic + * regression. + */ + def predictions: Vector = Vectors.dense(oldModel.predictions) override def copy(extra: ParamMap): IsotonicRegressionModel = { - copyValues(new IsotonicRegressionModel(uid, parentModel), extra) + copyValues(new IsotonicRegressionModel(uid, oldModel), extra) + } + + override def transform(dataset: DataFrame): DataFrame = { + val predict = dataset.schema($(featuresCol)).dataType match { + case DoubleType => + udf { feature: Double => oldModel.predict(feature) } + case _: VectorUDT => + val idx = $(featureIndex) + udf { features: Vector => oldModel.predict(features(idx)) } + } + dataset.withColumn($(predictionCol), predict(col($(featuresCol)))) + } + + override def transformSchema(schema: StructType): StructType = { + validateAndTransformSchema(schema, fitting = false) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala index 66e4b170bae8..c0ab00b68a2f 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala @@ -19,57 +19,46 @@ package org.apache.spark.ml.regression import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.types.{DoubleType, StructField, StructType} import org.apache.spark.sql.{DataFrame, Row} class IsotonicRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { - private val schema = StructType( - Array( - StructField("label", DoubleType), - StructField("features", DoubleType), - StructField("weight", DoubleType))) - - private val predictionSchema = StructType(Array(StructField("features", DoubleType))) - private def generateIsotonicInput(labels: Seq[Double]): DataFrame = { - val data = Seq.tabulate(labels.size)(i => Row(labels(i), i.toDouble, 1d)) - val parallelData = sc.parallelize(data) - - sqlContext.createDataFrame(parallelData, schema) + sqlContext.createDataFrame( + labels.zipWithIndex.map { case (label, i) => (label, i.toDouble, 1.0) } + ).toDF("label", "features", "weight") } private def generatePredictionInput(features: Seq[Double]): DataFrame = { - val data = Seq.tabulate(features.size)(i => Row(features(i))) - - val parallelData = sc.parallelize(data) - sqlContext.createDataFrame(parallelData, predictionSchema) + sqlContext.createDataFrame(features.map(Tuple1.apply)) + .toDF("features") } test("isotonic regression predictions") { val dataset = generateIsotonicInput(Seq(1, 2, 3, 1, 6, 17, 16, 17, 18)) - val trainer = new IsotonicRegression().setIsotonicParam(true) + val ir = new IsotonicRegression().setIsotonic(true) - val model = trainer.fit(dataset) + val model = ir.fit(dataset) val predictions = model .transform(dataset) - .select("prediction").map { - case Row(pred) => pred + .select("prediction").map { case Row(pred) => + pred }.collect() assert(predictions === Array(1, 2, 2, 2, 6, 16.5, 16.5, 17, 18)) - assert(model.parentModel.boundaries === Array(0, 1, 3, 4, 5, 6, 7, 8)) - assert(model.parentModel.predictions === Array(1, 2, 2, 6, 16.5, 16.5, 17.0, 18.0)) - assert(model.parentModel.isotonic) + assert(model.boundaries === Vectors.dense(0, 1, 3, 4, 5, 6, 
7, 8)) + assert(model.predictions === Vectors.dense(1, 2, 2, 6, 16.5, 16.5, 17.0, 18.0)) + assert(model.getIsotonic) } test("antitonic regression predictions") { val dataset = generateIsotonicInput(Seq(7, 5, 3, 5, 1)) - val trainer = new IsotonicRegression().setIsotonicParam(false) + val ir = new IsotonicRegression().setIsotonic(false) - val model = trainer.fit(dataset) + val model = ir.fit(dataset) val features = generatePredictionInput(Seq(-2.0, -1.0, 0.5, 0.75, 1.0, 2.0, 9.0)) val predictions = model @@ -94,9 +83,10 @@ class IsotonicRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { val ir = new IsotonicRegression() assert(ir.getLabelCol === "label") assert(ir.getFeaturesCol === "features") - assert(ir.getWeightCol === "weight") assert(ir.getPredictionCol === "prediction") - assert(ir.getIsotonicParam === true) + assert(!ir.isDefined(ir.weightCol)) + assert(ir.getIsotonic) + assert(ir.getFeatureIndex === 0) val model = ir.fit(dataset) model.transform(dataset) @@ -105,21 +95,22 @@ class IsotonicRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { assert(model.getLabelCol === "label") assert(model.getFeaturesCol === "features") - assert(model.getWeightCol === "weight") assert(model.getPredictionCol === "prediction") - assert(model.getIsotonicParam === true) + assert(!model.isDefined(model.weightCol)) + assert(model.getIsotonic) + assert(model.getFeatureIndex === 0) assert(model.hasParent) } test("set parameters") { val isotonicRegression = new IsotonicRegression() - .setIsotonicParam(false) - .setWeightParam("w") + .setIsotonic(false) + .setWeightCol("w") .setFeaturesCol("f") .setLabelCol("l") .setPredictionCol("p") - assert(isotonicRegression.getIsotonicParam === false) + assert(!isotonicRegression.getIsotonic) assert(isotonicRegression.getWeightCol === "w") assert(isotonicRegression.getFeaturesCol === "f") assert(isotonicRegression.getLabelCol === "l") @@ -130,7 +121,7 @@ class IsotonicRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { val dataset = generateIsotonicInput(Seq(1, 2, 3)) intercept[IllegalArgumentException] { - new IsotonicRegression().setWeightParam("w").fit(dataset) + new IsotonicRegression().setWeightCol("w").fit(dataset) } intercept[IllegalArgumentException] { @@ -145,4 +136,27 @@ class IsotonicRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { new IsotonicRegression().fit(dataset).setFeaturesCol("f").transform(dataset) } } + + test("vector features column with feature index") { + val dataset = sqlContext.createDataFrame(Seq( + (4.0, Vectors.dense(0.0, 1.0)), + (3.0, Vectors.dense(0.0, 2.0)), + (5.0, Vectors.sparse(2, Array(1), Array(3.0)))) + ).toDF("label", "features") + + val ir = new IsotonicRegression() + .setFeatureIndex(1) + + val model = ir.fit(dataset) + + val features = generatePredictionInput(Seq(2.0, 3.0, 4.0, 5.0)) + + val predictions = model + .transform(features) + .select("prediction").map { + case Row(pred) => pred + }.collect() + + assert(predictions === Array(3.5, 5.0, 5.0, 5.0)) + } } From abfedb9cd70af60c8290bd2f5a5cec1047845ba0 Mon Sep 17 00:00:00 2001 From: Christian Kadner Date: Thu, 6 Aug 2015 14:15:42 -0700 Subject: [PATCH 0093/1824] [SPARK-9211] [SQL] [TEST] normalize line separators before generating MD5 hash The golden answer file names for the existing Hive comparison tests were generated using a MD5 hash of the query text which uses Unix-style line separator characters `\n` (LF). 
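(For illustration only: the intended effect is that the hash no longer depends on the platform's line separator. The helper below is a hedged sketch, not the test harness code, which instead replaces System.lineSeparator() with "\n" before hashing.)

    import java.security.MessageDigest

    def md5OfNormalized(query: String): String = {
      val digest = MessageDigest.getInstance("MD5")
      // Normalize Windows CRLF line endings to LF so the hash matches the golden answer
      // file names that were generated on Unix.
      digest.update(query.replaceAll("\r\n", "\n").getBytes("utf-8"))
      new java.math.BigInteger(1, digest.digest).toString(16)
    }

    // The same query text now hashes identically regardless of line-ending style.
    assert(md5OfNormalized("SELECT 1\nFROM src") == md5OfNormalized("SELECT 1\r\nFROM src"))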
This PR ensures that all occurrences of the Windows-style line separator `\r\n` (CR) are replaced with `\n` (LF) before generating the MD5 hash to produce an identical MD5 hash for golden answer file names generated on Windows. Author: Christian Kadner Closes #7563 from ckadner/SPARK-9211_working and squashes the following commits: d541db0 [Christian Kadner] [SPARK-9211][SQL] normalize line separators before MD5 hash --- .../spark/sql/hive/execution/HiveComparisonTest.scala | 2 +- .../apache/spark/sql/hive/execution/HiveQuerySuite.scala | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala index 638b9c810372..2bdb0e11878e 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala @@ -124,7 +124,7 @@ abstract class HiveComparisonTest protected val cacheDigest = java.security.MessageDigest.getInstance("MD5") protected def getMd5(str: String): String = { val digest = java.security.MessageDigest.getInstance("MD5") - digest.update(str.getBytes("utf-8")) + digest.update(str.replaceAll(System.lineSeparator(), "\n").getBytes("utf-8")) new java.math.BigInteger(1, digest.digest).toString(16) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index edb27553671d..83f9f3eaa3a5 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -427,7 +427,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { |'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe' |USING 'cat' AS (tKey, tValue) ROW FORMAT SERDE |'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe' FROM src; - """.stripMargin.replaceAll("\n", " ")) + """.stripMargin.replaceAll(System.lineSeparator(), " ")) test("transform with SerDe2") { @@ -446,7 +446,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { |('avro.schema.literal'='{"namespace": "testing.hive.avro.serde","name": |"src","type": "record","fields": [{"name":"key","type":"int"}]}') |FROM small_src - """.stripMargin.replaceAll("\n", " ")).collect().head + """.stripMargin.replaceAll(System.lineSeparator(), " ")).collect().head assert(expected(0) === res(0)) } @@ -458,7 +458,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { |('serialization.last.column.takes.rest'='true') USING 'cat' AS (tKey, tValue) |ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe' |WITH SERDEPROPERTIES ('serialization.last.column.takes.rest'='true') FROM src; - """.stripMargin.replaceAll("\n", " ")) + """.stripMargin.replaceAll(System.lineSeparator(), " ")) createQueryTest("transform with SerDe4", """ @@ -467,7 +467,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { |('serialization.last.column.takes.rest'='true') USING 'cat' ROW FORMAT SERDE |'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe' WITH SERDEPROPERTIES |('serialization.last.column.takes.rest'='true') FROM src; - """.stripMargin.replaceAll("\n", " ")) + """.stripMargin.replaceAll(System.lineSeparator(), " ")) createQueryTest("LIKE", "SELECT * FROM src WHERE value LIKE '%1%'") From 
21fdfd7d6f89adbd37066c169e6ba9ccd337683e Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 6 Aug 2015 14:33:29 -0700 Subject: [PATCH 0094/1824] [SPARK-9548][SQL] Add a destructive iterator for BytesToBytesMap This pull request adds a destructive iterator to BytesToBytesMap. When used, the iterator frees pages as it traverses them. This is part of the effort to avoid starving when we have more than one operators that can exhaust memory. This is based on #7924, but fixes a bug there (Don't use destructive iterator in UnsafeKVExternalSorter). Closes #7924. Author: Liang-Chi Hsieh Author: Reynold Xin Closes #8003 from rxin/map-destructive-iterator and squashes the following commits: 6b618c3 [Reynold Xin] Don't use destructive iterator in UnsafeKVExternalSorter. a7bd8ec [Reynold Xin] Merge remote-tracking branch 'viirya/destructive_iter' into map-destructive-iterator 7652083 [Liang-Chi Hsieh] For comments: add destructiveIterator(), modify unit test, remove code block. 4a3e9de [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into destructive_iter 581e9e3 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into destructive_iter f0ff783 [Liang-Chi Hsieh] No need to free last page. 9e9d2a3 [Liang-Chi Hsieh] Add a destructive iterator for BytesToBytesMap. --- .../spark/unsafe/map/BytesToBytesMap.java | 33 +++++++++++++++-- .../map/AbstractBytesToBytesMapSuite.java | 37 ++++++++++++++++--- .../UnsafeFixedWidthAggregationMap.java | 7 +++- .../sql/execution/UnsafeKVExternalSorter.java | 5 ++- 4 files changed, 71 insertions(+), 11 deletions(-) diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java index 20347433e16b..5ac3736ac62a 100644 --- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java +++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java @@ -227,22 +227,35 @@ public static final class BytesToBytesMapIterator implements Iterator private final Iterator dataPagesIterator; private final Location loc; - private MemoryBlock currentPage; + private MemoryBlock currentPage = null; private int currentRecordNumber = 0; private Object pageBaseObject; private long offsetInPage; + // If this iterator destructive or not. When it is true, it frees each page as it moves onto + // next one. + private boolean destructive = false; + private BytesToBytesMap bmap; + private BytesToBytesMapIterator( - int numRecords, Iterator dataPagesIterator, Location loc) { + int numRecords, Iterator dataPagesIterator, Location loc, + boolean destructive, BytesToBytesMap bmap) { this.numRecords = numRecords; this.dataPagesIterator = dataPagesIterator; this.loc = loc; + this.destructive = destructive; + this.bmap = bmap; if (dataPagesIterator.hasNext()) { advanceToNextPage(); } } private void advanceToNextPage() { + if (destructive && currentPage != null) { + dataPagesIterator.remove(); + this.bmap.taskMemoryManager.freePage(currentPage); + this.bmap.shuffleMemoryManager.release(currentPage.size()); + } currentPage = dataPagesIterator.next(); pageBaseObject = currentPage.getBaseObject(); offsetInPage = currentPage.getBaseOffset(); @@ -281,7 +294,21 @@ public void remove() { * `lookup()`, the behavior of the returned iterator is undefined. 
*/ public BytesToBytesMapIterator iterator() { - return new BytesToBytesMapIterator(numElements, dataPages.iterator(), loc); + return new BytesToBytesMapIterator(numElements, dataPages.iterator(), loc, false, this); + } + + /** + * Returns a destructive iterator for iterating over the entries of this map. It frees each page + * as it moves onto next one. Notice: it is illegal to call any method on the map after + * `destructiveIterator()` has been called. + * + * For efficiency, all calls to `next()` will return the same {@link Location} object. + * + * If any other lookups or operations are performed on this map while iterating over it, including + * `lookup()`, the behavior of the returned iterator is undefined. + */ + public BytesToBytesMapIterator destructiveIterator() { + return new BytesToBytesMapIterator(numElements, dataPages.iterator(), loc, true, this); } /** diff --git a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java index 0e23a64fb74b..3c5003380162 100644 --- a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java +++ b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java @@ -183,8 +183,7 @@ public void setAndRetrieveAKey() { } } - @Test - public void iteratorTest() throws Exception { + private void iteratorTestBase(boolean destructive) throws Exception { final int size = 4096; BytesToBytesMap map = new BytesToBytesMap( taskMemoryManager, shuffleMemoryManager, size / 2, PAGE_SIZE_BYTES); @@ -216,7 +215,14 @@ public void iteratorTest() throws Exception { } } final java.util.BitSet valuesSeen = new java.util.BitSet(size); - final Iterator iter = map.iterator(); + final Iterator iter; + if (destructive) { + iter = map.destructiveIterator(); + } else { + iter = map.iterator(); + } + int numPages = map.getNumDataPages(); + int countFreedPages = 0; while (iter.hasNext()) { final BytesToBytesMap.Location loc = iter.next(); Assert.assertTrue(loc.isDefined()); @@ -228,11 +234,22 @@ public void iteratorTest() throws Exception { if (keyLength == 0) { Assert.assertTrue("value " + value + " was not divisible by 5", value % 5 == 0); } else { - final long key = PlatformDependent.UNSAFE.getLong( - keyAddress.getBaseObject(), keyAddress.getBaseOffset()); + final long key = PlatformDependent.UNSAFE.getLong( + keyAddress.getBaseObject(), keyAddress.getBaseOffset()); Assert.assertEquals(value, key); } valuesSeen.set((int) value); + if (destructive) { + // The iterator moves onto next page and frees previous page + if (map.getNumDataPages() < numPages) { + numPages = map.getNumDataPages(); + countFreedPages++; + } + } + } + if (destructive) { + // Latest page is not freed by iterator but by map itself + Assert.assertEquals(countFreedPages, numPages - 1); } Assert.assertEquals(size, valuesSeen.cardinality()); } finally { @@ -240,6 +257,16 @@ public void iteratorTest() throws Exception { } } + @Test + public void iteratorTest() throws Exception { + iteratorTestBase(false); + } + + @Test + public void destructiveIteratorTest() throws Exception { + iteratorTestBase(true); + } + @Test public void iteratingOverDataPagesWithWastedSpace() throws Exception { final int NUM_ENTRIES = 1000 * 1000; diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java index 02458030b00e..efb33530dac8 100644 --- 
a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java @@ -154,14 +154,17 @@ public UnsafeRow getAggregationBuffer(InternalRow groupingKey) { } /** - * Returns an iterator over the keys and values in this map. + * Returns an iterator over the keys and values in this map. This uses destructive iterator of + * BytesToBytesMap. So it is illegal to call any other method on this map after `iterator()` has + * been called. * * For efficiency, each call returns the same object. */ public KVIterator iterator() { return new KVIterator() { - private final BytesToBytesMap.BytesToBytesMapIterator mapLocationIterator = map.iterator(); + private final BytesToBytesMap.BytesToBytesMapIterator mapLocationIterator = + map.destructiveIterator(); private final UnsafeRow key = new UnsafeRow(); private final UnsafeRow value = new UnsafeRow(); diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java index 6c1cf136d9b8..9a65c9d3a404 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java @@ -88,8 +88,11 @@ public UnsafeKVExternalSorter(StructType keySchema, StructType valueSchema, final UnsafeInMemorySorter inMemSorter = new UnsafeInMemorySorter( taskMemoryManager, recordComparator, prefixComparator, Math.max(1, map.numElements())); - final int numKeyFields = keySchema.size(); + // We cannot use the destructive iterator here because we are reusing the existing memory + // pages in BytesToBytesMap to hold records during sorting. + // The only new memory we are allocating is the pointer/prefix array. BytesToBytesMap.BytesToBytesMapIterator iter = map.iterator(); + final int numKeyFields = keySchema.size(); UnsafeRow row = new UnsafeRow(); while (iter.hasNext()) { final BytesToBytesMap.Location loc = iter.next(); From 0a078303d08ad2bb92b9a8a6969563d75b512290 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Thu, 6 Aug 2015 14:35:30 -0700 Subject: [PATCH 0095/1824] [SPARK-9556] [SPARK-9619] [SPARK-9624] [STREAMING] Make BlockGenerator more robust and make all BlockGenerators subscribe to rate limit updates In some receivers, instead of using the default `BlockGenerator` in `ReceiverSupervisorImpl`, custom generator with their custom listeners are used for reliability (see [`ReliableKafkaReceiver`](https://github.com/apache/spark/blob/master/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/ReliableKafkaReceiver.scala#L99) and [updated `KinesisReceiver`](https://github.com/apache/spark/pull/7825/files)). These custom generators do not receive rate updates. This PR modifies the code to allow custom `BlockGenerator`s to be created through the `ReceiverSupervisorImpl` so that they can be kept track and rate updates can be applied. In the process, I did some simplification, and de-flaki-fication of some rate controller related tests. In particular. - Renamed `Receiver.executor` to `Receiver.supervisor` (to match `ReceiverSupervisor`) - Made `RateControllerSuite` faster (by increasing batch interval) and less flaky - Changed a few internal API to return the current rate of block generators as Long instead of Option\[Long\] (was inconsistent at places). 
- Updated existing `ReceiverTrackerSuite` to test that custom block generators get rate updates as well. Author: Tathagata Das Closes #7913 from tdas/SPARK-9556 and squashes the following commits: 41d4461 [Tathagata Das] fix scala style eb9fd59 [Tathagata Das] Updated kinesis receiver d24994d [Tathagata Das] Updated BlockGeneratorSuite to use manual clock in BlockGenerator d70608b [Tathagata Das] Updated BlockGenerator with states and proper synchronization f6bd47e [Tathagata Das] Merge remote-tracking branch 'apache-github/master' into SPARK-9556 31da173 [Tathagata Das] Fix bug 12116df [Tathagata Das] Add BlockGeneratorSuite 74bd069 [Tathagata Das] Fix style 989bb5c [Tathagata Das] Made BlockGenerator fail is used after stop, and added better unit tests for it 3ff618c [Tathagata Das] Fix test b40eff8 [Tathagata Das] slight refactoring f0df0f1 [Tathagata Das] Scala style fixes 51759cb [Tathagata Das] Refactored rate controller tests and added the ability to update rate of any custom block generator --- .../org/apache/spark/util/ManualClock.scala | 2 +- .../kafka/ReliableKafkaReceiver.scala | 2 +- .../streaming/kinesis/KinesisReceiver.scala | 2 +- .../streaming/receiver/ActorReceiver.scala | 8 +- .../streaming/receiver/BlockGenerator.scala | 131 ++++++--- .../streaming/receiver/RateLimiter.scala | 3 +- .../spark/streaming/receiver/Receiver.scala | 52 ++-- .../receiver/ReceiverSupervisor.scala | 27 +- .../receiver/ReceiverSupervisorImpl.scala | 33 ++- .../spark/streaming/CheckpointSuite.scala | 16 +- .../spark/streaming/ReceiverSuite.scala | 31 +-- .../receiver/BlockGeneratorSuite.scala | 253 ++++++++++++++++++ .../scheduler/RateControllerSuite.scala | 64 ++--- .../scheduler/ReceiverTrackerSuite.scala | 129 +++++---- 14 files changed, 534 insertions(+), 219 deletions(-) create mode 100644 streaming/src/test/scala/org/apache/spark/streaming/receiver/BlockGeneratorSuite.scala diff --git a/core/src/main/scala/org/apache/spark/util/ManualClock.scala b/core/src/main/scala/org/apache/spark/util/ManualClock.scala index 171855406198..e7a65d74a440 100644 --- a/core/src/main/scala/org/apache/spark/util/ManualClock.scala +++ b/core/src/main/scala/org/apache/spark/util/ManualClock.scala @@ -58,7 +58,7 @@ private[spark] class ManualClock(private var time: Long) extends Clock { */ def waitTillTime(targetTime: Long): Long = synchronized { while (time < targetTime) { - wait(100) + wait(10) } getTimeMillis() } diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/ReliableKafkaReceiver.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/ReliableKafkaReceiver.scala index 75f0dfc22b9d..764d170934aa 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/ReliableKafkaReceiver.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/ReliableKafkaReceiver.scala @@ -96,7 +96,7 @@ class ReliableKafkaReceiver[ blockOffsetMap = new ConcurrentHashMap[StreamBlockId, Map[TopicAndPartition, Long]]() // Initialize the block generator for storing Kafka message. 
- blockGenerator = new BlockGenerator(new GeneratedBlockHandler, streamId, conf) + blockGenerator = supervisor.createBlockGenerator(new GeneratedBlockHandler) if (kafkaParams.contains(AUTO_OFFSET_COMMIT) && kafkaParams(AUTO_OFFSET_COMMIT) == "true") { logWarning(s"$AUTO_OFFSET_COMMIT should be set to false in ReliableKafkaReceiver, " + diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala index a4baeec0846b..22324e821ce9 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala @@ -136,7 +136,7 @@ private[kinesis] class KinesisReceiver( * The KCL creates and manages the receiving/processing thread pool through Worker.run(). */ override def onStart() { - blockGenerator = new BlockGenerator(new GeneratedBlockHandler, streamId, SparkEnv.get.conf) + blockGenerator = supervisor.createBlockGenerator(new GeneratedBlockHandler) workerId = Utils.localHostName() + ":" + UUID.randomUUID() diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ActorReceiver.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ActorReceiver.scala index cd309788a771..7ec74016a1c2 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ActorReceiver.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ActorReceiver.scala @@ -144,7 +144,7 @@ private[streaming] class ActorReceiver[T: ClassTag]( receiverSupervisorStrategy: SupervisorStrategy ) extends Receiver[T](storageLevel) with Logging { - protected lazy val supervisor = SparkEnv.get.actorSystem.actorOf(Props(new Supervisor), + protected lazy val actorSupervisor = SparkEnv.get.actorSystem.actorOf(Props(new Supervisor), "Supervisor" + streamId) class Supervisor extends Actor { @@ -191,11 +191,11 @@ private[streaming] class ActorReceiver[T: ClassTag]( } def onStart(): Unit = { - supervisor - logInfo("Supervision tree for receivers initialized at:" + supervisor.path) + actorSupervisor + logInfo("Supervision tree for receivers initialized at:" + actorSupervisor.path) } def onStop(): Unit = { - supervisor ! PoisonPill + actorSupervisor ! PoisonPill } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala index 92b51ce39234..794dece370b2 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala @@ -21,10 +21,10 @@ import java.util.concurrent.{ArrayBlockingQueue, TimeUnit} import scala.collection.mutable.ArrayBuffer -import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.{SparkException, Logging, SparkConf} import org.apache.spark.storage.StreamBlockId import org.apache.spark.streaming.util.RecurringTimer -import org.apache.spark.util.SystemClock +import org.apache.spark.util.{Clock, SystemClock} /** Listener object for BlockGenerator events */ private[streaming] trait BlockGeneratorListener { @@ -69,16 +69,35 @@ private[streaming] trait BlockGeneratorListener { * named blocks at regular intervals. 
This class starts two threads, * one to periodically start a new batch and prepare the previous batch of as a block, * the other to push the blocks into the block manager. + * + * Note: Do not create BlockGenerator instances directly inside receivers. Use + * `ReceiverSupervisor.createBlockGenerator` to create a BlockGenerator and use it. */ private[streaming] class BlockGenerator( listener: BlockGeneratorListener, receiverId: Int, - conf: SparkConf + conf: SparkConf, + clock: Clock = new SystemClock() ) extends RateLimiter(conf) with Logging { private case class Block(id: StreamBlockId, buffer: ArrayBuffer[Any]) - private val clock = new SystemClock() + /** + * The BlockGenerator can be in 5 possible states, in the order as follows. + * - Initialized: Nothing has been started + * - Active: start() has been called, and it is generating blocks on added data. + * - StoppedAddingData: stop() has been called, the adding of data has been stopped, + * but blocks are still being generated and pushed. + * - StoppedGeneratingBlocks: Generating of blocks has been stopped, but + * they are still being pushed. + * - StoppedAll: Everything has stopped, and the BlockGenerator object can be GCed. + */ + private object GeneratorState extends Enumeration { + type GeneratorState = Value + val Initialized, Active, StoppedAddingData, StoppedGeneratingBlocks, StoppedAll = Value + } + import GeneratorState._ + private val blockIntervalMs = conf.getTimeAsMs("spark.streaming.blockInterval", "200ms") require(blockIntervalMs > 0, s"'spark.streaming.blockInterval' should be a positive value") @@ -89,59 +108,100 @@ private[streaming] class BlockGenerator( private val blockPushingThread = new Thread() { override def run() { keepPushingBlocks() } } @volatile private var currentBuffer = new ArrayBuffer[Any] - @volatile private var stopped = false + @volatile private var state = Initialized /** Start block generating and pushing threads. */ - def start() { - blockIntervalTimer.start() - blockPushingThread.start() - logInfo("Started BlockGenerator") + def start(): Unit = synchronized { + if (state == Initialized) { + state = Active + blockIntervalTimer.start() + blockPushingThread.start() + logInfo("Started BlockGenerator") + } else { + throw new SparkException( + s"Cannot start BlockGenerator as its not in the Initialized state [state = $state]") + } } - /** Stop all threads. */ - def stop() { + /** + * Stop everything in the right order such that all the data added is pushed out correctly. + * - First, stop adding data to the current buffer. + * - Second, stop generating blocks. + * - Finally, wait for queue of to-be-pushed blocks to be drained. + */ + def stop(): Unit = { + // Set the state to stop adding data + synchronized { + if (state == Active) { + state = StoppedAddingData + } else { + logWarning(s"Cannot stop BlockGenerator as its not in the Active state [state = $state]") + return + } + } + + // Stop generating blocks and set the state for block pushing thread to start draining the queue logInfo("Stopping BlockGenerator") blockIntervalTimer.stop(interruptTimer = false) - stopped = true - logInfo("Waiting for block pushing thread") + synchronized { state = StoppedGeneratingBlocks } + + // Wait for the queue to drain and mark generated as stopped + logInfo("Waiting for block pushing thread to terminate") blockPushingThread.join() + synchronized { state = StoppedAll } logInfo("Stopped BlockGenerator") } /** - * Push a single data item into the buffer. 
All received data items - * will be periodically pushed into BlockManager. + * Push a single data item into the buffer. */ - def addData (data: Any): Unit = synchronized { - waitToPush() - currentBuffer += data + def addData(data: Any): Unit = synchronized { + if (state == Active) { + waitToPush() + currentBuffer += data + } else { + throw new SparkException( + "Cannot add data as BlockGenerator has not been started or has been stopped") + } } /** * Push a single data item into the buffer. After buffering the data, the - * `BlockGeneratorListener.onAddData` callback will be called. All received data items - * will be periodically pushed into BlockManager. + * `BlockGeneratorListener.onAddData` callback will be called. */ def addDataWithCallback(data: Any, metadata: Any): Unit = synchronized { - waitToPush() - currentBuffer += data - listener.onAddData(data, metadata) + if (state == Active) { + waitToPush() + currentBuffer += data + listener.onAddData(data, metadata) + } else { + throw new SparkException( + "Cannot add data as BlockGenerator has not been started or has been stopped") + } } /** * Push multiple data items into the buffer. After buffering the data, the - * `BlockGeneratorListener.onAddData` callback will be called. All received data items - * will be periodically pushed into BlockManager. Note that all the data items is guaranteed - * to be present in a single block. + * `BlockGeneratorListener.onAddData` callback will be called. Note that all the data items + * are atomically added to the buffer, and are hence guaranteed to be present in a single block. */ def addMultipleDataWithCallback(dataIterator: Iterator[Any], metadata: Any): Unit = synchronized { - dataIterator.foreach { data => - waitToPush() - currentBuffer += data + if (state == Active) { + dataIterator.foreach { data => + waitToPush() + currentBuffer += data + } + listener.onAddData(dataIterator, metadata) + } else { + throw new SparkException( + "Cannot add data as BlockGenerator has not been started or has been stopped") } - listener.onAddData(dataIterator, metadata) } + def isActive(): Boolean = state == Active + + def isStopped(): Boolean = state == StoppedAll + /** Change the buffer to which single records are added to. */ private def updateCurrentBuffer(time: Long): Unit = synchronized { try { @@ -165,18 +225,21 @@ private[streaming] class BlockGenerator( /** Keep pushing blocks to the BlockManager. */ private def keepPushingBlocks() { logInfo("Started block pushing thread") + + def isGeneratingBlocks = synchronized { state == Active || state == StoppedAddingData } try { - while (!stopped) { - Option(blocksForPushing.poll(100, TimeUnit.MILLISECONDS)) match { + while (isGeneratingBlocks) { + Option(blocksForPushing.poll(10, TimeUnit.MILLISECONDS)) match { case Some(block) => pushBlock(block) case None => } } - // Push out the blocks that are still left + + // At this point, state is StoppedGeneratingBlock. So drain the queue of to-be-pushed blocks. 
logInfo("Pushing out the last " + blocksForPushing.size() + " blocks") while (!blocksForPushing.isEmpty) { - logDebug("Getting block ") val block = blocksForPushing.take() + logDebug(s"Pushing block $block") pushBlock(block) logInfo("Blocks left to push " + blocksForPushing.size()) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/RateLimiter.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/RateLimiter.scala index f663def4c051..bca1fbc8fda2 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/RateLimiter.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/RateLimiter.scala @@ -45,8 +45,7 @@ private[receiver] abstract class RateLimiter(conf: SparkConf) extends Logging { /** * Return the current rate limit. If no limit has been set so far, it returns {{{Long.MaxValue}}}. */ - def getCurrentLimit: Long = - rateLimiter.getRate.toLong + def getCurrentLimit: Long = rateLimiter.getRate.toLong /** * Set the rate limit to `newRate`. The new rate will not exceed the maximum rate configured by diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/Receiver.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/Receiver.scala index 7504fa44d9fa..554aae0117b2 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/Receiver.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/Receiver.scala @@ -116,12 +116,12 @@ abstract class Receiver[T](val storageLevel: StorageLevel) extends Serializable * being pushed into Spark's memory. */ def store(dataItem: T) { - executor.pushSingle(dataItem) + supervisor.pushSingle(dataItem) } /** Store an ArrayBuffer of received data as a data block into Spark's memory. */ def store(dataBuffer: ArrayBuffer[T]) { - executor.pushArrayBuffer(dataBuffer, None, None) + supervisor.pushArrayBuffer(dataBuffer, None, None) } /** @@ -130,12 +130,12 @@ abstract class Receiver[T](val storageLevel: StorageLevel) extends Serializable * for being used in the corresponding InputDStream. */ def store(dataBuffer: ArrayBuffer[T], metadata: Any) { - executor.pushArrayBuffer(dataBuffer, Some(metadata), None) + supervisor.pushArrayBuffer(dataBuffer, Some(metadata), None) } /** Store an iterator of received data as a data block into Spark's memory. */ def store(dataIterator: Iterator[T]) { - executor.pushIterator(dataIterator, None, None) + supervisor.pushIterator(dataIterator, None, None) } /** @@ -144,12 +144,12 @@ abstract class Receiver[T](val storageLevel: StorageLevel) extends Serializable * for being used in the corresponding InputDStream. */ def store(dataIterator: java.util.Iterator[T], metadata: Any) { - executor.pushIterator(dataIterator, Some(metadata), None) + supervisor.pushIterator(dataIterator, Some(metadata), None) } /** Store an iterator of received data as a data block into Spark's memory. */ def store(dataIterator: java.util.Iterator[T]) { - executor.pushIterator(dataIterator, None, None) + supervisor.pushIterator(dataIterator, None, None) } /** @@ -158,7 +158,7 @@ abstract class Receiver[T](val storageLevel: StorageLevel) extends Serializable * for being used in the corresponding InputDStream. */ def store(dataIterator: Iterator[T], metadata: Any) { - executor.pushIterator(dataIterator, Some(metadata), None) + supervisor.pushIterator(dataIterator, Some(metadata), None) } /** @@ -167,7 +167,7 @@ abstract class Receiver[T](val storageLevel: StorageLevel) extends Serializable * that Spark is configured to use. 
*/ def store(bytes: ByteBuffer) { - executor.pushBytes(bytes, None, None) + supervisor.pushBytes(bytes, None, None) } /** @@ -176,12 +176,12 @@ abstract class Receiver[T](val storageLevel: StorageLevel) extends Serializable * for being used in the corresponding InputDStream. */ def store(bytes: ByteBuffer, metadata: Any) { - executor.pushBytes(bytes, Some(metadata), None) + supervisor.pushBytes(bytes, Some(metadata), None) } /** Report exceptions in receiving data. */ def reportError(message: String, throwable: Throwable) { - executor.reportError(message, throwable) + supervisor.reportError(message, throwable) } /** @@ -193,7 +193,7 @@ abstract class Receiver[T](val storageLevel: StorageLevel) extends Serializable * The `message` will be reported to the driver. */ def restart(message: String) { - executor.restartReceiver(message) + supervisor.restartReceiver(message) } /** @@ -205,7 +205,7 @@ abstract class Receiver[T](val storageLevel: StorageLevel) extends Serializable * The `message` and `exception` will be reported to the driver. */ def restart(message: String, error: Throwable) { - executor.restartReceiver(message, Some(error)) + supervisor.restartReceiver(message, Some(error)) } /** @@ -215,22 +215,22 @@ abstract class Receiver[T](val storageLevel: StorageLevel) extends Serializable * in a background thread. */ def restart(message: String, error: Throwable, millisecond: Int) { - executor.restartReceiver(message, Some(error), millisecond) + supervisor.restartReceiver(message, Some(error), millisecond) } /** Stop the receiver completely. */ def stop(message: String) { - executor.stop(message, None) + supervisor.stop(message, None) } /** Stop the receiver completely due to an exception */ def stop(message: String, error: Throwable) { - executor.stop(message, Some(error)) + supervisor.stop(message, Some(error)) } /** Check if the receiver has started or not. */ def isStarted(): Boolean = { - executor.isReceiverStarted() + supervisor.isReceiverStarted() } /** @@ -238,7 +238,7 @@ abstract class Receiver[T](val storageLevel: StorageLevel) extends Serializable * the receiving of data should be stopped. */ def isStopped(): Boolean = { - executor.isReceiverStopped() + supervisor.isReceiverStopped() } /** @@ -257,7 +257,7 @@ abstract class Receiver[T](val storageLevel: StorageLevel) extends Serializable private var id: Int = -1 /** Handler object that runs the receiver. This is instantiated lazily in the worker. */ - private[streaming] var executor_ : ReceiverSupervisor = null + @transient private var _supervisor : ReceiverSupervisor = null /** Set the ID of the DStream that this receiver is associated with. */ private[streaming] def setReceiverId(id_ : Int) { @@ -265,15 +265,17 @@ abstract class Receiver[T](val storageLevel: StorageLevel) extends Serializable } /** Attach Network Receiver executor to this receiver. */ - private[streaming] def attachExecutor(exec: ReceiverSupervisor) { - assert(executor_ == null) - executor_ = exec + private[streaming] def attachSupervisor(exec: ReceiverSupervisor) { + assert(_supervisor == null) + _supervisor = exec } - /** Get the attached executor. */ - private def executor: ReceiverSupervisor = { - assert(executor_ != null, "Executor has not been attached to this receiver") - executor_ + /** Get the attached supervisor. */ + private[streaming] def supervisor: ReceiverSupervisor = { + assert(_supervisor != null, + "A ReceiverSupervisor have not been attached to the receiver yet. 
Maybe you are starting " + + "some computation in the receiver before the Receiver.onStart() has been called.") + _supervisor } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala index e98017a63756..158d1ba2f183 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala @@ -44,8 +44,8 @@ private[streaming] abstract class ReceiverSupervisor( } import ReceiverState._ - // Attach the executor to the receiver - receiver.attachExecutor(this) + // Attach the supervisor to the receiver + receiver.attachSupervisor(this) private val futureExecutionContext = ExecutionContext.fromExecutorService( ThreadUtils.newDaemonCachedThreadPool("receiver-supervisor-future", 128)) @@ -60,7 +60,7 @@ private[streaming] abstract class ReceiverSupervisor( private val defaultRestartDelay = conf.getInt("spark.streaming.receiverRestartDelay", 2000) /** The current maximum rate limit for this receiver. */ - private[streaming] def getCurrentRateLimit: Option[Long] = None + private[streaming] def getCurrentRateLimit: Long = Long.MaxValue /** Exception associated with the stopping of the receiver */ @volatile protected var stoppingError: Throwable = null @@ -92,13 +92,30 @@ private[streaming] abstract class ReceiverSupervisor( optionalBlockId: Option[StreamBlockId] ) + /** + * Create a custom [[BlockGenerator]] that the receiver implementation can directly control + * using their provided [[BlockGeneratorListener]]. + * + * Note: Do not explicitly start or stop the `BlockGenerator`, the `ReceiverSupervisorImpl` + * will take care of it. + */ + def createBlockGenerator(blockGeneratorListener: BlockGeneratorListener): BlockGenerator + /** Report errors. */ def reportError(message: String, throwable: Throwable) - /** Called when supervisor is started */ + /** + * Called when supervisor is started. + * Note that this must be called before the receiver.onStart() is called to ensure + * things like [[BlockGenerator]]s are started before the receiver starts sending data. + */ protected def onStart() { } - /** Called when supervisor is stopped */ + /** + * Called when supervisor is stopped. + * Note that this must be called after the receiver.onStop() is called to ensure + * things like [[BlockGenerator]]s are cleaned up after the receiver stops sending data. + */ protected def onStop(message: String, error: Option[Throwable]) { } /** Called when receiver is started. 
Return true if the driver accepts us */ diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala index 0d802f83549a..59ef58d232ee 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala @@ -20,6 +20,7 @@ package org.apache.spark.streaming.receiver import java.nio.ByteBuffer import java.util.concurrent.atomic.AtomicLong +import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import com.google.common.base.Throwables @@ -81,15 +82,20 @@ private[streaming] class ReceiverSupervisorImpl( cleanupOldBlocks(threshTime) case UpdateRateLimit(eps) => logInfo(s"Received a new rate limit: $eps.") - blockGenerator.updateRate(eps) + registeredBlockGenerators.foreach { bg => + bg.updateRate(eps) + } } }) /** Unique block ids if one wants to add blocks directly */ private val newBlockId = new AtomicLong(System.currentTimeMillis()) + private val registeredBlockGenerators = new mutable.ArrayBuffer[BlockGenerator] + with mutable.SynchronizedBuffer[BlockGenerator] + /** Divides received data records into data blocks for pushing in BlockManager. */ - private val blockGenerator = new BlockGenerator(new BlockGeneratorListener { + private val defaultBlockGeneratorListener = new BlockGeneratorListener { def onAddData(data: Any, metadata: Any): Unit = { } def onGenerateBlock(blockId: StreamBlockId): Unit = { } @@ -101,14 +107,15 @@ private[streaming] class ReceiverSupervisorImpl( def onPushBlock(blockId: StreamBlockId, arrayBuffer: ArrayBuffer[_]) { pushArrayBuffer(arrayBuffer, None, Some(blockId)) } - }, streamId, env.conf) + } + private val defaultBlockGenerator = createBlockGenerator(defaultBlockGeneratorListener) - override private[streaming] def getCurrentRateLimit: Option[Long] = - Some(blockGenerator.getCurrentLimit) + /** Get the current rate limit of the default block generator */ + override private[streaming] def getCurrentRateLimit: Long = defaultBlockGenerator.getCurrentLimit /** Push a single record of received data into block generator. */ def pushSingle(data: Any) { - blockGenerator.addData(data) + defaultBlockGenerator.addData(data) } /** Store an ArrayBuffer of received data as a data block into Spark's memory. 
*/ @@ -162,11 +169,11 @@ private[streaming] class ReceiverSupervisorImpl( } override protected def onStart() { - blockGenerator.start() + registeredBlockGenerators.foreach { _.start() } } override protected def onStop(message: String, error: Option[Throwable]) { - blockGenerator.stop() + registeredBlockGenerators.foreach { _.stop() } env.rpcEnv.stop(endpoint) } @@ -183,6 +190,16 @@ private[streaming] class ReceiverSupervisorImpl( logInfo("Stopped receiver " + streamId) } + override def createBlockGenerator( + blockGeneratorListener: BlockGeneratorListener): BlockGenerator = { + // Cleanup BlockGenerators that have already been stopped + registeredBlockGenerators --= registeredBlockGenerators.filter{ _.isStopped() } + + val newBlockGenerator = new BlockGenerator(blockGeneratorListener, streamId, env.conf) + registeredBlockGenerators += newBlockGenerator + newBlockGenerator + } + /** Generate new block ID */ private def nextBlockId = StreamBlockId(streamId, newBlockId.getAndIncrement) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala index 67c2d900940a..1bba7a143edf 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.streaming import java.io.File -import scala.collection.mutable.{SynchronizedBuffer, ArrayBuffer} +import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer} import scala.reflect.ClassTag import com.google.common.base.Charsets @@ -33,7 +33,7 @@ import org.scalatest.concurrent.Eventually._ import org.scalatest.time.SpanSugar._ import org.apache.spark.streaming.dstream.{DStream, FileInputDStream} -import org.apache.spark.streaming.scheduler.{RateLimitInputDStream, ConstantEstimator, SingletonTestRateReceiver} +import org.apache.spark.streaming.scheduler.{ConstantEstimator, RateTestInputDStream, RateTestReceiver} import org.apache.spark.util.{Clock, ManualClock, Utils} /** @@ -397,26 +397,24 @@ class CheckpointSuite extends TestSuiteBase { ssc = new StreamingContext(conf, batchDuration) ssc.checkpoint(checkpointDir) - val dstream = new RateLimitInputDStream(ssc) { + val dstream = new RateTestInputDStream(ssc) { override val rateController = - Some(new ReceiverRateController(id, new ConstantEstimator(200.0))) + Some(new ReceiverRateController(id, new ConstantEstimator(200))) } - SingletonTestRateReceiver.reset() val output = new TestOutputStreamWithPartitions(dstream.checkpoint(batchDuration * 2)) output.register() runStreams(ssc, 5, 5) - SingletonTestRateReceiver.reset() ssc = new StreamingContext(checkpointDir) ssc.start() val outputNew = advanceTimeWithRealDelay(ssc, 2) - eventually(timeout(5.seconds)) { - assert(dstream.getCurrentRateLimit === Some(200)) + eventually(timeout(10.seconds)) { + assert(RateTestReceiver.getActive().nonEmpty) + assert(RateTestReceiver.getActive().get.getDefaultBlockGeneratorRateLimit() === 200) } ssc.stop() - ssc = null } // This tests whether file input stream remembers what files were seen before diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala index 13b4d17c8618..01279b34f73d 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala @@ -129,32 +129,6 @@ class 
ReceiverSuite extends TestSuiteBase with Timeouts with Serializable { } } - test("block generator") { - val blockGeneratorListener = new FakeBlockGeneratorListener - val blockIntervalMs = 200 - val conf = new SparkConf().set("spark.streaming.blockInterval", s"${blockIntervalMs}ms") - val blockGenerator = new BlockGenerator(blockGeneratorListener, 1, conf) - val expectedBlocks = 5 - val waitTime = expectedBlocks * blockIntervalMs + (blockIntervalMs / 2) - val generatedData = new ArrayBuffer[Int] - - // Generate blocks - val startTime = System.currentTimeMillis() - blockGenerator.start() - var count = 0 - while(System.currentTimeMillis - startTime < waitTime) { - blockGenerator.addData(count) - generatedData += count - count += 1 - Thread.sleep(10) - } - blockGenerator.stop() - - val recordedData = blockGeneratorListener.arrayBuffers.flatten - assert(blockGeneratorListener.arrayBuffers.size > 0) - assert(recordedData.toSet === generatedData.toSet) - } - ignore("block generator throttling") { val blockGeneratorListener = new FakeBlockGeneratorListener val blockIntervalMs = 100 @@ -348,6 +322,11 @@ class ReceiverSuite extends TestSuiteBase with Timeouts with Serializable { } override protected def onReceiverStart(): Boolean = true + + override def createBlockGenerator( + blockGeneratorListener: BlockGeneratorListener): BlockGenerator = { + null + } } /** diff --git a/streaming/src/test/scala/org/apache/spark/streaming/receiver/BlockGeneratorSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/receiver/BlockGeneratorSuite.scala new file mode 100644 index 000000000000..a38cc603f219 --- /dev/null +++ b/streaming/src/test/scala/org/apache/spark/streaming/receiver/BlockGeneratorSuite.scala @@ -0,0 +1,253 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.streaming.receiver + +import scala.collection.mutable + +import org.scalatest.BeforeAndAfter +import org.scalatest.Matchers._ +import org.scalatest.concurrent.Timeouts._ +import org.scalatest.concurrent.Eventually._ +import org.scalatest.time.SpanSugar._ + +import org.apache.spark.storage.StreamBlockId +import org.apache.spark.util.ManualClock +import org.apache.spark.{SparkException, SparkConf, SparkFunSuite} + +class BlockGeneratorSuite extends SparkFunSuite with BeforeAndAfter { + + private val blockIntervalMs = 10 + private val conf = new SparkConf().set("spark.streaming.blockInterval", s"${blockIntervalMs}ms") + @volatile private var blockGenerator: BlockGenerator = null + + after { + if (blockGenerator != null) { + blockGenerator.stop() + } + } + + test("block generation and data callbacks") { + val listener = new TestBlockGeneratorListener + val clock = new ManualClock() + + require(blockIntervalMs > 5) + require(listener.onAddDataCalled === false) + require(listener.onGenerateBlockCalled === false) + require(listener.onPushBlockCalled === false) + + // Verify that creating the generator does not start it + blockGenerator = new BlockGenerator(listener, 0, conf, clock) + assert(blockGenerator.isActive() === false, "block generator active before start()") + assert(blockGenerator.isStopped() === false, "block generator stopped before start()") + assert(listener.onAddDataCalled === false) + assert(listener.onGenerateBlockCalled === false) + assert(listener.onPushBlockCalled === false) + + // Verify start marks the generator active, but does not call the callbacks + blockGenerator.start() + assert(blockGenerator.isActive() === true, "block generator active after start()") + assert(blockGenerator.isStopped() === false, "block generator stopped after start()") + withClue("callbacks called before adding data") { + assert(listener.onAddDataCalled === false) + assert(listener.onGenerateBlockCalled === false) + assert(listener.onPushBlockCalled === false) + } + + // Verify whether addData() adds data that is present in generated blocks + val data1 = 1 to 10 + data1.foreach { blockGenerator.addData _ } + withClue("callbacks called on adding data without metadata and without block generation") { + assert(listener.onAddDataCalled === false) // should be called only with addDataWithCallback() + assert(listener.onGenerateBlockCalled === false) + assert(listener.onPushBlockCalled === false) + } + clock.advance(blockIntervalMs) // advance clock to generate blocks + withClue("blocks not generated or pushed") { + eventually(timeout(1 second)) { + assert(listener.onGenerateBlockCalled === true) + assert(listener.onPushBlockCalled === true) + } + } + listener.pushedData should contain theSameElementsInOrderAs (data1) + assert(listener.onAddDataCalled === false) // should be called only with addDataWithCallback() + + // Verify addDataWithCallback() add data+metadata and and callbacks are called correctly + val data2 = 11 to 20 + val metadata2 = data2.map { _.toString } + data2.zip(metadata2).foreach { case (d, m) => blockGenerator.addDataWithCallback(d, m) } + assert(listener.onAddDataCalled === true) + listener.addedData should contain theSameElementsInOrderAs (data2) + listener.addedMetadata should contain theSameElementsInOrderAs (metadata2) + clock.advance(blockIntervalMs) // advance clock to generate blocks + eventually(timeout(1 second)) { + listener.pushedData should contain theSameElementsInOrderAs (data1 ++ data2) + } + + // Verify addMultipleDataWithCallback() 
add data+metadata and and callbacks are called correctly + val data3 = 21 to 30 + val metadata3 = "metadata" + blockGenerator.addMultipleDataWithCallback(data3.iterator, metadata3) + listener.addedMetadata should contain theSameElementsInOrderAs (metadata2 :+ metadata3) + clock.advance(blockIntervalMs) // advance clock to generate blocks + eventually(timeout(1 second)) { + listener.pushedData should contain theSameElementsInOrderAs (data1 ++ data2 ++ data3) + } + + // Stop the block generator by starting the stop on a different thread and + // then advancing the manual clock for the stopping to proceed. + val thread = stopBlockGenerator(blockGenerator) + eventually(timeout(1 second), interval(10 milliseconds)) { + clock.advance(blockIntervalMs) + assert(blockGenerator.isStopped() === true) + } + thread.join() + + // Verify that the generator cannot be used any more + intercept[SparkException] { + blockGenerator.addData(1) + } + intercept[SparkException] { + blockGenerator.addDataWithCallback(1, 1) + } + intercept[SparkException] { + blockGenerator.addMultipleDataWithCallback(Iterator(1), 1) + } + intercept[SparkException] { + blockGenerator.start() + } + blockGenerator.stop() // Calling stop again should be fine + } + + test("stop ensures correct shutdown") { + val listener = new TestBlockGeneratorListener + val clock = new ManualClock() + blockGenerator = new BlockGenerator(listener, 0, conf, clock) + require(listener.onGenerateBlockCalled === false) + blockGenerator.start() + assert(blockGenerator.isActive() === true, "block generator") + assert(blockGenerator.isStopped() === false) + + val data = 1 to 1000 + data.foreach { blockGenerator.addData _ } + + // Verify that stop() shutdowns everything in the right order + // - First, stop receiving new data + // - Second, wait for final block with all buffered data to be generated + // - Finally, wait for all blocks to be pushed + clock.advance(1) // to make sure that the timer for another interval to complete + val thread = stopBlockGenerator(blockGenerator) + eventually(timeout(1 second), interval(10 milliseconds)) { + assert(blockGenerator.isActive() === false) + } + assert(blockGenerator.isStopped() === false) + + // Verify that data cannot be added + intercept[SparkException] { + blockGenerator.addData(1) + } + intercept[SparkException] { + blockGenerator.addDataWithCallback(1, null) + } + intercept[SparkException] { + blockGenerator.addMultipleDataWithCallback(Iterator(1), null) + } + + // Verify that stop() stays blocked until another block containing all the data is generated + // This intercept always succeeds, as the body either will either throw a timeout exception + // (expected as stop() should never complete) or a SparkException (unexpected as stop() + // completed and thread terminated). 
+ val exception = intercept[Exception] { + failAfter(200 milliseconds) { + thread.join() + throw new SparkException( + "BlockGenerator.stop() completed before generating timer was stopped") + } + } + exception should not be a [SparkException] + + + // Verify that the final data is present in the final generated block and + // pushed before complete stop + assert(blockGenerator.isStopped() === false) // generator has not stopped yet + clock.advance(blockIntervalMs) // force block generation + failAfter(1 second) { + thread.join() + } + assert(blockGenerator.isStopped() === true) // generator has finally been completely stopped + assert(listener.pushedData === data, "All data not pushed by stop()") + } + + test("block push errors are reported") { + val listener = new TestBlockGeneratorListener { + @volatile var errorReported = false + override def onPushBlock( + blockId: StreamBlockId, arrayBuffer: mutable.ArrayBuffer[_]): Unit = { + throw new SparkException("test") + } + override def onError(message: String, throwable: Throwable): Unit = { + errorReported = true + } + } + blockGenerator = new BlockGenerator(listener, 0, conf) + blockGenerator.start() + assert(listener.errorReported === false) + blockGenerator.addData(1) + eventually(timeout(1 second), interval(10 milliseconds)) { + assert(listener.errorReported === true) + } + blockGenerator.stop() + } + + /** + * Helper method to stop the block generator with manual clock in a different thread, + * so that the main thread can advance the clock that allows the stopping to proceed. + */ + private def stopBlockGenerator(blockGenerator: BlockGenerator): Thread = { + val thread = new Thread() { + override def run(): Unit = { + blockGenerator.stop() + } + } + thread.start() + thread + } + + /** A listener for BlockGenerator that records the data in the callbacks */ + private class TestBlockGeneratorListener extends BlockGeneratorListener { + val pushedData = new mutable.ArrayBuffer[Any] with mutable.SynchronizedBuffer[Any] + val addedData = new mutable.ArrayBuffer[Any] with mutable.SynchronizedBuffer[Any] + val addedMetadata = new mutable.ArrayBuffer[Any] with mutable.SynchronizedBuffer[Any] + @volatile var onGenerateBlockCalled = false + @volatile var onAddDataCalled = false + @volatile var onPushBlockCalled = false + + override def onPushBlock(blockId: StreamBlockId, arrayBuffer: mutable.ArrayBuffer[_]): Unit = { + pushedData ++= arrayBuffer + onPushBlockCalled = true + } + override def onError(message: String, throwable: Throwable): Unit = {} + override def onGenerateBlock(blockId: StreamBlockId): Unit = { + onGenerateBlockCalled = true + } + override def onAddData(data: Any, metadata: Any): Unit = { + addedData += data + addedMetadata += metadata + onAddDataCalled = true + } + } +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/RateControllerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/RateControllerSuite.scala index 921da773f6c1..1eb52b7029a2 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/RateControllerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/RateControllerSuite.scala @@ -18,10 +18,7 @@ package org.apache.spark.streaming.scheduler import scala.collection.mutable -import scala.reflect.ClassTag -import scala.util.control.NonFatal -import org.scalatest.Matchers._ import org.scalatest.concurrent.Eventually._ import org.scalatest.time.SpanSugar._ @@ -32,72 +29,63 @@ class RateControllerSuite extends TestSuiteBase { 
override def useManualClock: Boolean = false - test("rate controller publishes updates") { + override def batchDuration: Duration = Milliseconds(50) + + test("RateController - rate controller publishes updates after batches complete") { val ssc = new StreamingContext(conf, batchDuration) withStreamingContext(ssc) { ssc => - val dstream = new RateLimitInputDStream(ssc) + val dstream = new RateTestInputDStream(ssc) dstream.register() ssc.start() eventually(timeout(10.seconds)) { - assert(dstream.publishCalls > 0) + assert(dstream.publishedRates > 0) } } } - test("publish rates reach receivers") { + test("ReceiverRateController - published rates reach receivers") { val ssc = new StreamingContext(conf, batchDuration) withStreamingContext(ssc) { ssc => - val dstream = new RateLimitInputDStream(ssc) { + val estimator = new ConstantEstimator(100) + val dstream = new RateTestInputDStream(ssc) { override val rateController = - Some(new ReceiverRateController(id, new ConstantEstimator(200.0))) + Some(new ReceiverRateController(id, estimator)) } dstream.register() - SingletonTestRateReceiver.reset() ssc.start() - eventually(timeout(10.seconds)) { - assert(dstream.getCurrentRateLimit === Some(200)) + // Wait for receiver to start + eventually(timeout(5.seconds)) { + RateTestReceiver.getActive().nonEmpty } - } - } - test("multiple publish rates reach receivers") { - val ssc = new StreamingContext(conf, batchDuration) - withStreamingContext(ssc) { ssc => - val rates = Seq(100L, 200L, 300L) - - val dstream = new RateLimitInputDStream(ssc) { - override val rateController = - Some(new ReceiverRateController(id, new ConstantEstimator(rates.map(_.toDouble): _*))) + // Update rate in the estimator and verify whether the rate was published to the receiver + def updateRateAndVerify(rate: Long): Unit = { + estimator.updateRate(rate) + eventually(timeout(5.seconds)) { + assert(RateTestReceiver.getActive().get.getDefaultBlockGeneratorRateLimit() === rate) + } } - SingletonTestRateReceiver.reset() - dstream.register() - - val observedRates = mutable.HashSet.empty[Long] - ssc.start() - eventually(timeout(20.seconds)) { - dstream.getCurrentRateLimit.foreach(observedRates += _) - // Long.MaxValue (essentially, no rate limit) is the initial rate limit for any Receiver - observedRates should contain theSameElementsAs (rates :+ Long.MaxValue) + // Verify multiple rate update + Seq(100, 200, 300).foreach { rate => + updateRateAndVerify(rate) } } } } -private[streaming] class ConstantEstimator(rates: Double*) extends RateEstimator { - private var idx: Int = 0 +private[streaming] class ConstantEstimator(@volatile private var rate: Long) + extends RateEstimator { - private def nextRate(): Double = { - val rate = rates(idx) - idx = (idx + 1) % rates.size - rate + def updateRate(newRate: Long): Unit = { + rate = newRate } def compute( time: Long, elements: Long, processingDelay: Long, - schedulingDelay: Long): Option[Double] = Some(nextRate()) + schedulingDelay: Long): Option[Double] = Some(rate) } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala index afad5f16dbc7..dd292ba4dd94 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala @@ -17,48 +17,43 @@ package org.apache.spark.streaming.scheduler +import scala.collection.mutable.ArrayBuffer + 
import org.scalatest.concurrent.Eventually._ import org.scalatest.time.SpanSugar._ -import org.apache.spark.SparkConf +import org.apache.spark.storage.{StorageLevel, StreamBlockId} import org.apache.spark.streaming._ -import org.apache.spark.streaming.receiver._ import org.apache.spark.streaming.dstream.ReceiverInputDStream -import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming.receiver._ /** Testsuite for receiver scheduling */ class ReceiverTrackerSuite extends TestSuiteBase { - val sparkConf = new SparkConf().setMaster("local[8]").setAppName("test") - - test("Receiver tracker - propagates rate limit") { - withStreamingContext(new StreamingContext(sparkConf, Milliseconds(100))) { ssc => - object ReceiverStartedWaiter extends StreamingListener { - @volatile - var started = false - - override def onReceiverStarted(receiverStarted: StreamingListenerReceiverStarted): Unit = { - started = true - } - } - ssc.addStreamingListener(ReceiverStartedWaiter) + test("send rate update to receivers") { + withStreamingContext(new StreamingContext(conf, Milliseconds(100))) { ssc => ssc.scheduler.listenerBus.start(ssc.sc) - SingletonTestRateReceiver.reset() val newRateLimit = 100L - val inputDStream = new RateLimitInputDStream(ssc) + val inputDStream = new RateTestInputDStream(ssc) val tracker = new ReceiverTracker(ssc) tracker.start() try { // we wait until the Receiver has registered with the tracker, // otherwise our rate update is lost eventually(timeout(5 seconds)) { - assert(ReceiverStartedWaiter.started) + assert(RateTestReceiver.getActive().nonEmpty) } + + + // Verify that the rate of the block generator in the receiver get updated + val activeReceiver = RateTestReceiver.getActive().get tracker.sendRateUpdate(inputDStream.id, newRateLimit) - // this is an async message, we need to wait a bit for it to be processed - eventually(timeout(3 seconds)) { - assert(inputDStream.getCurrentRateLimit.get === newRateLimit) + eventually(timeout(5 seconds)) { + assert(activeReceiver.getDefaultBlockGeneratorRateLimit() === newRateLimit, + "default block generator did not receive rate update") + assert(activeReceiver.getCustomBlockGeneratorRateLimit() === newRateLimit, + "other block generator did not receive rate update") } } finally { tracker.stop(false) @@ -67,69 +62,73 @@ class ReceiverTrackerSuite extends TestSuiteBase { } } -/** - * An input DStream with a hard-coded receiver that gives access to internals for testing. - * - * @note Make sure to call {{{SingletonDummyReceiver.reset()}}} before using this in a test, - * or otherwise you may get {{{NotSerializableException}}} when trying to serialize - * the receiver. - * @see [[[SingletonDummyReceiver]]]. 
- */ -private[streaming] class RateLimitInputDStream(@transient ssc_ : StreamingContext) +/** An input DStream with for testing rate controlling */ +private[streaming] class RateTestInputDStream(@transient ssc_ : StreamingContext) extends ReceiverInputDStream[Int](ssc_) { - override def getReceiver(): RateTestReceiver = SingletonTestRateReceiver - - def getCurrentRateLimit: Option[Long] = { - invokeExecutorMethod.getCurrentRateLimit - } + override def getReceiver(): Receiver[Int] = new RateTestReceiver(id) @volatile - var publishCalls = 0 + var publishedRates = 0 override val rateController: Option[RateController] = { - Some(new RateController(id, new ConstantEstimator(100.0)) { + Some(new RateController(id, new ConstantEstimator(100)) { override def publish(rate: Long): Unit = { - publishCalls += 1 + publishedRates += 1 } }) } +} - private def invokeExecutorMethod: ReceiverSupervisor = { - val c = classOf[Receiver[_]] - val ex = c.getDeclaredMethod("executor") - ex.setAccessible(true) - ex.invoke(SingletonTestRateReceiver).asInstanceOf[ReceiverSupervisor] +/** A receiver implementation for testing rate controlling */ +private[streaming] class RateTestReceiver(receiverId: Int, host: Option[String] = None) + extends Receiver[Int](StorageLevel.MEMORY_ONLY) { + + private lazy val customBlockGenerator = supervisor.createBlockGenerator( + new BlockGeneratorListener { + override def onPushBlock(blockId: StreamBlockId, arrayBuffer: ArrayBuffer[_]): Unit = {} + override def onError(message: String, throwable: Throwable): Unit = {} + override def onGenerateBlock(blockId: StreamBlockId): Unit = {} + override def onAddData(data: Any, metadata: Any): Unit = {} + } + ) + + setReceiverId(receiverId) + + override def onStart(): Unit = { + customBlockGenerator + RateTestReceiver.registerReceiver(this) } -} -/** - * A Receiver as an object so we can read its rate limit. Make sure to call `reset()` when - * reusing this receiver, otherwise a non-null `executor_` field will prevent it from being - * serialized when receivers are installed on executors. - * - * @note It's necessary to be a top-level object, or else serialization would create another - * one on the executor side and we won't be able to read its rate limit. - */ -private[streaming] object SingletonTestRateReceiver extends RateTestReceiver(0) { + override def onStop(): Unit = { + RateTestReceiver.deregisterReceiver() + } + + override def preferredLocation: Option[String] = host - /** Reset the object to be usable in another test. */ - def reset(): Unit = { - executor_ = null + def getDefaultBlockGeneratorRateLimit(): Long = { + supervisor.getCurrentRateLimit + } + + def getCustomBlockGeneratorRateLimit(): Long = { + customBlockGenerator.getCurrentLimit } } /** - * Dummy receiver implementation + * A helper object to RateTestReceiver that give access to the currently active RateTestReceiver + * instance. 
*/ -private[streaming] class RateTestReceiver(receiverId: Int, host: Option[String] = None) - extends Receiver[Int](StorageLevel.MEMORY_ONLY) { +private[streaming] object RateTestReceiver { + @volatile private var activeReceiver: RateTestReceiver = null - setReceiverId(receiverId) - - override def onStart(): Unit = {} + def registerReceiver(receiver: RateTestReceiver): Unit = { + activeReceiver = receiver + } - override def onStop(): Unit = {} + def deregisterReceiver(): Unit = { + activeReceiver = null + } - override def preferredLocation: Option[String] = host + def getActive(): Option[RateTestReceiver] = Option(activeReceiver) } From 1723e34893f9b087727ea0e5c8b335645f42c295 Mon Sep 17 00:00:00 2001 From: cody koeninger Date: Thu, 6 Aug 2015 14:37:25 -0700 Subject: [PATCH 0096/1824] =?UTF-8?q?[DOCS]=20[STREAMING]=20make=20the=20e?= =?UTF-8?q?xisting=20parameter=20docs=20for=20OffsetRange=20ac=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …tually visible Author: cody koeninger Closes #7995 from koeninger/doc-fixes and squashes the following commits: 87af9ea [cody koeninger] [Docs][Streaming] make the existing parameter docs for OffsetRange actually visible --- .../org/apache/spark/streaming/kafka/OffsetRange.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/OffsetRange.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/OffsetRange.scala index f326e7f1f6f8..2f8981d4898b 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/OffsetRange.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/OffsetRange.scala @@ -42,16 +42,16 @@ trait HasOffsetRanges { * :: Experimental :: * Represents a range of offsets from a single Kafka TopicAndPartition. Instances of this class * can be created with `OffsetRange.create()`. + * @param topic Kafka topic name + * @param partition Kafka partition id + * @param fromOffset Inclusive starting offset + * @param untilOffset Exclusive ending offset */ @Experimental final class OffsetRange private( - /** Kafka topic name */ val topic: String, - /** Kafka partition id */ val partition: Int, - /** inclusive starting offset */ val fromOffset: Long, - /** exclusive ending offset */ val untilOffset: Long) extends Serializable { import OffsetRange.OffsetRangeTuple From 346209097e88fe79015359e40b49c32cc0bdc439 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Thu, 6 Aug 2015 14:39:36 -0700 Subject: [PATCH 0097/1824] [SPARK-9639] [STREAMING] Fix a potential NPE in Streaming JobScheduler Because `JobScheduler.stop(false)` may set `eventLoop` to null when `JobHandler` is running, then it's possible that when `post` is called, `eventLoop` happens to null. This PR fixed this bug and also set threads in `jobExecutor` to `daemon`. 
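In short, the fix reads the @volatile `eventLoop` reference into a local variable before each use, so a concurrent `stop(false)` that nulls the field cannot cause an NPE between the null check and the `post` call. Below is a minimal, self-contained sketch of that guard pattern; the class and method names here are hypothetical simplifications for illustration, not the actual Spark `JobScheduler` code.

    // Hypothetical sketch: read a @volatile field once per use so a concurrent
    // stop() cannot null it out between the check and the call.
    trait EventLoop[E] { def post(event: E): Unit }

    class MiniScheduler {
      @volatile private var eventLoop: EventLoop[String] =
        new EventLoop[String] { def post(event: String): Unit = println(s"posted $event") }

      // May run concurrently with handleJob(); drops the reference.
      def stop(): Unit = { eventLoop = null }

      def handleJob(run: () => Unit): Unit = {
        var loop = eventLoop            // single volatile read into a local
        if (loop != null) {
          loop.post("JobStarted")
          run()
          loop = eventLoop              // re-read: stop() may have happened while the job ran
          if (loop != null) {
            loop.post("JobCompleted")
          }
        }
        // else: the scheduler has already been stopped, so the events are simply dropped
      }
    }

    object MiniSchedulerDemo extends App {
      val scheduler = new MiniScheduler
      scheduler.handleJob(() => ())   // posts JobStarted / JobCompleted
      scheduler.stop()
      scheduler.handleJob(() => ())   // no NPE: events are silently dropped
    }

The actual change below applies the same idea inside `JobHandler.run`, re-reading `eventLoop` after the job finishes before posting `JobCompleted`, and separately switches the job executor to daemon threads via `ThreadUtils.newDaemonFixedThreadPool`.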
Author: zsxwing Closes #7960 from zsxwing/fix-npe and squashes the following commits: b0864c4 [zsxwing] Fix a potential NPE in Streaming JobScheduler --- .../streaming/scheduler/JobScheduler.scala | 32 +++++++++++++------ 1 file changed, 22 insertions(+), 10 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala index 7e735562dca3..6d4cdc4aa6b1 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala @@ -17,7 +17,7 @@ package org.apache.spark.streaming.scheduler -import java.util.concurrent.{TimeUnit, ConcurrentHashMap, Executors} +import java.util.concurrent.{ConcurrentHashMap, TimeUnit} import scala.collection.JavaConversions._ import scala.util.{Failure, Success} @@ -25,7 +25,7 @@ import scala.util.{Failure, Success} import org.apache.spark.Logging import org.apache.spark.rdd.PairRDDFunctions import org.apache.spark.streaming._ -import org.apache.spark.util.EventLoop +import org.apache.spark.util.{EventLoop, ThreadUtils} private[scheduler] sealed trait JobSchedulerEvent @@ -44,7 +44,8 @@ class JobScheduler(val ssc: StreamingContext) extends Logging { // https://gist.github.com/AlainODea/1375759b8720a3f9f094 private val jobSets: java.util.Map[Time, JobSet] = new ConcurrentHashMap[Time, JobSet] private val numConcurrentJobs = ssc.conf.getInt("spark.streaming.concurrentJobs", 1) - private val jobExecutor = Executors.newFixedThreadPool(numConcurrentJobs) + private val jobExecutor = + ThreadUtils.newDaemonFixedThreadPool(numConcurrentJobs, "streaming-job-executor") private val jobGenerator = new JobGenerator(this) val clock = jobGenerator.clock val listenerBus = new StreamingListenerBus() @@ -193,14 +194,25 @@ class JobScheduler(val ssc: StreamingContext) extends Logging { ssc.sc.setLocalProperty(JobScheduler.BATCH_TIME_PROPERTY_KEY, job.time.milliseconds.toString) ssc.sc.setLocalProperty(JobScheduler.OUTPUT_OP_ID_PROPERTY_KEY, job.outputOpId.toString) try { - eventLoop.post(JobStarted(job)) - // Disable checks for existing output directories in jobs launched by the streaming - // scheduler, since we may need to write output to an existing directory during checkpoint - // recovery; see SPARK-4835 for more details. - PairRDDFunctions.disableOutputSpecValidation.withValue(true) { - job.run() + // We need to assign `eventLoop` to a temp variable. Otherwise, because + // `JobScheduler.stop(false)` may set `eventLoop` to null when this method is running, then + // it's possible that when `post` is called, `eventLoop` happens to null. + var _eventLoop = eventLoop + if (_eventLoop != null) { + _eventLoop.post(JobStarted(job)) + // Disable checks for existing output directories in jobs launched by the streaming + // scheduler, since we may need to write output to an existing directory during checkpoint + // recovery; see SPARK-4835 for more details. + PairRDDFunctions.disableOutputSpecValidation.withValue(true) { + job.run() + } + _eventLoop = eventLoop + if (_eventLoop != null) { + _eventLoop.post(JobCompleted(job)) + } + } else { + // JobScheduler has been stopped. 
} - eventLoop.post(JobCompleted(job)) } finally { ssc.sc.setLocalProperty(JobScheduler.BATCH_TIME_PROPERTY_KEY, null) ssc.sc.setLocalProperty(JobScheduler.OUTPUT_OP_ID_PROPERTY_KEY, null) From 3504bf3aa9f7b75c0985f04ce2944833d8c5b5bd Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Thu, 6 Aug 2015 15:04:44 -0700 Subject: [PATCH 0098/1824] [SPARK-9630] [SQL] Clean up new aggregate operators (SPARK-9240 follow up) This is the followup of https://github.com/apache/spark/pull/7813. It renames `HybridUnsafeAggregationIterator` to `TungstenAggregationIterator` and makes it only work with `UnsafeRow`. Also, I add a `TungstenAggregate` that uses `TungstenAggregationIterator` and make `SortBasedAggregate` (renamed from `SortBasedAggregate`) only works with `SafeRow`. Author: Yin Huai Closes #7954 from yhuai/agg-followUp and squashes the following commits: 4d2f4fc [Yin Huai] Add comments and free map. 0d7ddb9 [Yin Huai] Add TungstenAggregationQueryWithControlledFallbackSuite to test fall back process. 91d69c2 [Yin Huai] Rename UnsafeHybridAggregationIterator to TungstenAggregateIteraotr and make it only work with UnsafeRow. --- .../expressions/aggregate/functions.scala | 14 +- .../spark/sql/execution/SparkStrategies.scala | 3 +- .../sql/execution/UnsafeRowSerializer.scala | 20 +- .../sql/execution/aggregate/Aggregate.scala | 182 ----- .../aggregate/SortBasedAggregate.scala | 103 +++ .../SortBasedAggregationIterator.scala | 26 - .../aggregate/TungstenAggregate.scala | 102 +++ .../TungstenAggregationIterator.scala | 667 ++++++++++++++++++ .../UnsafeHybridAggregationIterator.scala | 372 ---------- .../spark/sql/execution/aggregate/utils.scala | 260 +++++-- .../org/apache/spark/sql/SQLQuerySuite.scala | 2 +- .../execution/AggregationQuerySuite.scala | 104 ++- 12 files changed, 1192 insertions(+), 663 deletions(-) delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/Aggregate.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UnsafeHybridAggregationIterator.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala index 88fb516e64aa..a73024d6adba 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala @@ -31,8 +31,11 @@ case class Average(child: Expression) extends AlgebraicAggregate { override def dataType: DataType = resultType // Expected input data type. - // TODO: Once we remove the old code path, we can use our analyzer to cast NullType - // to the default data type of the NumericType. + // TODO: Right now, we replace old aggregate functions (based on AggregateExpression1) to the + // new version at planning time (after analysis phase). For now, NullType is added at here + // to make it resolved when we have cases like `select avg(null)`. + // We can use our analyzer to cast NullType to the default data type of the NumericType once + // we remove the old aggregate functions. 
Then, we will not need NullType at here. override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(NumericType, NullType)) private val resultType = child.dataType match { @@ -256,12 +259,19 @@ case class Sum(child: Expression) extends AlgebraicAggregate { override def dataType: DataType = resultType // Expected input data type. + // TODO: Right now, we replace old aggregate functions (based on AggregateExpression1) to the + // new version at planning time (after analysis phase). For now, NullType is added at here + // to make it resolved when we have cases like `select sum(null)`. + // We can use our analyzer to cast NullType to the default data type of the NumericType once + // we remove the old aggregate functions. Then, we will not need NullType at here. override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(LongType, DoubleType, DecimalType, NullType)) private val resultType = child.dataType match { case DecimalType.Fixed(precision, scale) => DecimalType.bounded(precision + 10, scale) + // TODO: Remove this line once we remove the NullType from inputTypes. + case NullType => IntegerType case _ => child.dataType } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index a730ffbb217c..c5aaebe67322 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -191,8 +191,9 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { // aggregate function to the corresponding attribute of the function. val aggregateFunctionMap = aggregateExpressions.map { agg => val aggregateFunction = agg.aggregateFunction + val attribtue = Alias(aggregateFunction, aggregateFunction.toString)().toAttribute (aggregateFunction, agg.isDistinct) -> - Alias(aggregateFunction, aggregateFunction.toString)().toAttribute + (aggregateFunction -> attribtue) }.toMap val (functionsWithDistinct, functionsWithoutDistinct) = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala index 16498da080c8..39f8f992a9f0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution -import java.io.{DataInputStream, DataOutputStream, OutputStream, InputStream} +import java.io._ import java.nio.ByteBuffer import scala.reflect.ClassTag @@ -58,11 +58,26 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInst */ override def serializeStream(out: OutputStream): SerializationStream = new SerializationStream { private[this] var writeBuffer: Array[Byte] = new Array[Byte](4096) + // When `out` is backed by ChainedBufferOutputStream, we will get an + // UnsupportedOperationException when we call dOut.writeInt because it internally calls + // ChainedBufferOutputStream's write(b: Int), which is not supported. + // To workaround this issue, we create an array for sorting the int value. + // To reproduce the problem, use dOut.writeInt(row.getSizeInBytes) and + // run SparkSqlSerializer2SortMergeShuffleSuite. 
+ private[this] var intBuffer: Array[Byte] = new Array[Byte](4) private[this] val dOut: DataOutputStream = new DataOutputStream(out) override def writeValue[T: ClassTag](value: T): SerializationStream = { val row = value.asInstanceOf[UnsafeRow] - dOut.writeInt(row.getSizeInBytes) + val size = row.getSizeInBytes + // This part is based on DataOutputStream's writeInt. + // It is for dOut.writeInt(row.getSizeInBytes). + intBuffer(0) = ((size >>> 24) & 0xFF).toByte + intBuffer(1) = ((size >>> 16) & 0xFF).toByte + intBuffer(2) = ((size >>> 8) & 0xFF).toByte + intBuffer(3) = ((size >>> 0) & 0xFF).toByte + dOut.write(intBuffer, 0, 4) + row.writeToStream(out, writeBuffer) this } @@ -90,6 +105,7 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInst override def close(): Unit = { writeBuffer = null + intBuffer = null dOut.writeInt(EOF) dOut.close() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/Aggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/Aggregate.scala deleted file mode 100644 index cf568dc04867..000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/Aggregate.scala +++ /dev/null @@ -1,182 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.aggregate - -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.errors._ -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.catalyst.plans.physical.{UnspecifiedDistribution, ClusteredDistribution, AllTuples, Distribution} -import org.apache.spark.sql.execution.{UnsafeFixedWidthAggregationMap, SparkPlan, UnaryNode} -import org.apache.spark.sql.types.StructType - -/** - * An Aggregate Operator used to evaluate [[AggregateFunction2]]. Based on the data types - * of the grouping expressions and aggregate functions, it determines if it uses - * sort-based aggregation and hybrid (hash-based with sort-based as the fallback) to - * process input rows. 
- */ -case class Aggregate( - requiredChildDistributionExpressions: Option[Seq[Expression]], - groupingExpressions: Seq[NamedExpression], - nonCompleteAggregateExpressions: Seq[AggregateExpression2], - nonCompleteAggregateAttributes: Seq[Attribute], - completeAggregateExpressions: Seq[AggregateExpression2], - completeAggregateAttributes: Seq[Attribute], - initialInputBufferOffset: Int, - resultExpressions: Seq[NamedExpression], - child: SparkPlan) - extends UnaryNode { - - private[this] val allAggregateExpressions = - nonCompleteAggregateExpressions ++ completeAggregateExpressions - - private[this] val hasNonAlgebricAggregateFunctions = - !allAggregateExpressions.forall(_.aggregateFunction.isInstanceOf[AlgebraicAggregate]) - - // Use the hybrid iterator if (1) unsafe is enabled, (2) the schemata of - // grouping key and aggregation buffer is supported; and (3) all - // aggregate functions are algebraic. - private[this] val supportsHybridIterator: Boolean = { - val aggregationBufferSchema: StructType = - StructType.fromAttributes( - allAggregateExpressions.flatMap(_.aggregateFunction.bufferAttributes)) - val groupKeySchema: StructType = - StructType.fromAttributes(groupingExpressions.map(_.toAttribute)) - - val schemaSupportsUnsafe: Boolean = - UnsafeFixedWidthAggregationMap.supportsAggregationBufferSchema(aggregationBufferSchema) && - UnsafeProjection.canSupport(groupKeySchema) - - // TODO: Use the hybrid iterator for non-algebric aggregate functions. - sqlContext.conf.unsafeEnabled && schemaSupportsUnsafe && !hasNonAlgebricAggregateFunctions - } - - // We need to use sorted input if we have grouping expressions, and - // we cannot use the hybrid iterator or the hybrid is disabled. - private[this] val requiresSortedInput: Boolean = { - groupingExpressions.nonEmpty && !supportsHybridIterator - } - - override def canProcessUnsafeRows: Boolean = !hasNonAlgebricAggregateFunctions - - // If result expressions' data types are all fixed length, we generate unsafe rows - // (We have this requirement instead of check the result of UnsafeProjection.canSupport - // is because we use a mutable projection to generate the result). - override def outputsUnsafeRows: Boolean = { - // resultExpressions.map(_.dataType).forall(UnsafeRow.isFixedLength) - // TODO: Supports generating UnsafeRows. We can just re-enable the line above and fix - // any issue we get. - false - } - - override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute) - - override def requiredChildDistribution: List[Distribution] = { - requiredChildDistributionExpressions match { - case Some(exprs) if exprs.length == 0 => AllTuples :: Nil - case Some(exprs) if exprs.length > 0 => ClusteredDistribution(exprs) :: Nil - case None => UnspecifiedDistribution :: Nil - } - } - - override def requiredChildOrdering: Seq[Seq[SortOrder]] = { - if (requiresSortedInput) { - // TODO: We should not sort the input rows if they are just in reversed order. - groupingExpressions.map(SortOrder(_, Ascending)) :: Nil - } else { - Seq.fill(children.size)(Nil) - } - } - - override def outputOrdering: Seq[SortOrder] = { - if (requiresSortedInput) { - // It is possible that the child.outputOrdering starts with the required - // ordering expressions (e.g. we require [a] as the sort expression and the - // child's outputOrdering is [a, b]). We can only guarantee the output rows - // are sorted by values of groupingExpressions. 
- groupingExpressions.map(SortOrder(_, Ascending)) - } else { - Nil - } - } - - protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") { - child.execute().mapPartitions { iter => - // Because the constructor of an aggregation iterator will read at least the first row, - // we need to get the value of iter.hasNext first. - val hasInput = iter.hasNext - val useHybridIterator = - hasInput && - supportsHybridIterator && - groupingExpressions.nonEmpty - if (useHybridIterator) { - UnsafeHybridAggregationIterator.createFromInputIterator( - groupingExpressions, - nonCompleteAggregateExpressions, - nonCompleteAggregateAttributes, - completeAggregateExpressions, - completeAggregateAttributes, - initialInputBufferOffset, - resultExpressions, - newMutableProjection _, - child.output, - iter, - outputsUnsafeRows) - } else { - if (!hasInput && groupingExpressions.nonEmpty) { - // This is a grouped aggregate and the input iterator is empty, - // so return an empty iterator. - Iterator[InternalRow]() - } else { - val outputIter = SortBasedAggregationIterator.createFromInputIterator( - groupingExpressions, - nonCompleteAggregateExpressions, - nonCompleteAggregateAttributes, - completeAggregateExpressions, - completeAggregateAttributes, - initialInputBufferOffset, - resultExpressions, - newMutableProjection _ , - newProjection _, - child.output, - iter, - outputsUnsafeRows) - if (!hasInput && groupingExpressions.isEmpty) { - // There is no input and there is no grouping expressions. - // We need to output a single row as the output. - Iterator[InternalRow](outputIter.outputForEmptyGroupingKeyWithoutInput()) - } else { - outputIter - } - } - } - } - } - - override def simpleString: String = { - val iterator = if (supportsHybridIterator && groupingExpressions.nonEmpty) { - classOf[UnsafeHybridAggregationIterator].getSimpleName - } else { - classOf[SortBasedAggregationIterator].getSimpleName - } - - s"""NewAggregate with $iterator ${groupingExpressions} ${allAggregateExpressions}""" - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala new file mode 100644 index 000000000000..ad428ad663f3 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.aggregate + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.errors._ +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.plans.physical.{UnspecifiedDistribution, ClusteredDistribution, AllTuples, Distribution} +import org.apache.spark.sql.execution.{UnsafeFixedWidthAggregationMap, SparkPlan, UnaryNode} +import org.apache.spark.sql.types.StructType + +case class SortBasedAggregate( + requiredChildDistributionExpressions: Option[Seq[Expression]], + groupingExpressions: Seq[NamedExpression], + nonCompleteAggregateExpressions: Seq[AggregateExpression2], + nonCompleteAggregateAttributes: Seq[Attribute], + completeAggregateExpressions: Seq[AggregateExpression2], + completeAggregateAttributes: Seq[Attribute], + initialInputBufferOffset: Int, + resultExpressions: Seq[NamedExpression], + child: SparkPlan) + extends UnaryNode { + + override def outputsUnsafeRows: Boolean = false + + override def canProcessUnsafeRows: Boolean = false + + override def canProcessSafeRows: Boolean = true + + override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute) + + override def requiredChildDistribution: List[Distribution] = { + requiredChildDistributionExpressions match { + case Some(exprs) if exprs.length == 0 => AllTuples :: Nil + case Some(exprs) if exprs.length > 0 => ClusteredDistribution(exprs) :: Nil + case None => UnspecifiedDistribution :: Nil + } + } + + override def requiredChildOrdering: Seq[Seq[SortOrder]] = { + groupingExpressions.map(SortOrder(_, Ascending)) :: Nil + } + + override def outputOrdering: Seq[SortOrder] = { + groupingExpressions.map(SortOrder(_, Ascending)) + } + + protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") { + child.execute().mapPartitions { iter => + // Because the constructor of an aggregation iterator will read at least the first row, + // we need to get the value of iter.hasNext first. + val hasInput = iter.hasNext + if (!hasInput && groupingExpressions.nonEmpty) { + // This is a grouped aggregate and the input iterator is empty, + // so return an empty iterator. + Iterator[InternalRow]() + } else { + val outputIter = SortBasedAggregationIterator.createFromInputIterator( + groupingExpressions, + nonCompleteAggregateExpressions, + nonCompleteAggregateAttributes, + completeAggregateExpressions, + completeAggregateAttributes, + initialInputBufferOffset, + resultExpressions, + newMutableProjection _, + newProjection _, + child.output, + iter, + outputsUnsafeRows) + if (!hasInput && groupingExpressions.isEmpty) { + // There is no input and there is no grouping expressions. + // We need to output a single row as the output. 
+ Iterator[InternalRow](outputIter.outputForEmptyGroupingKeyWithoutInput()) + } else { + outputIter + } + } + } + } + + override def simpleString: String = { + val allAggregateExpressions = nonCompleteAggregateExpressions ++ completeAggregateExpressions + s"""SortBasedAggregate ${groupingExpressions} ${allAggregateExpressions}""" + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala index 40f6bff53d2b..67ebafde25ad 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala @@ -204,31 +204,5 @@ object SortBasedAggregationIterator { newMutableProjection, outputsUnsafeRows) } - - def createFromKVIterator( - groupingKeyAttributes: Seq[Attribute], - valueAttributes: Seq[Attribute], - inputKVIterator: KVIterator[InternalRow, InternalRow], - nonCompleteAggregateExpressions: Seq[AggregateExpression2], - nonCompleteAggregateAttributes: Seq[Attribute], - completeAggregateExpressions: Seq[AggregateExpression2], - completeAggregateAttributes: Seq[Attribute], - initialInputBufferOffset: Int, - resultExpressions: Seq[NamedExpression], - newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection), - outputsUnsafeRows: Boolean): SortBasedAggregationIterator = { - new SortBasedAggregationIterator( - groupingKeyAttributes, - valueAttributes, - inputKVIterator, - nonCompleteAggregateExpressions, - nonCompleteAggregateAttributes, - completeAggregateExpressions, - completeAggregateAttributes, - initialInputBufferOffset, - resultExpressions, - newMutableProjection, - outputsUnsafeRows) - } // scalastyle:on } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala new file mode 100644 index 000000000000..5a0b4d47d62f --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.aggregate + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.errors._ +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression2 +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.physical.{UnspecifiedDistribution, ClusteredDistribution, AllTuples, Distribution} +import org.apache.spark.sql.execution.{UnaryNode, SparkPlan} + +case class TungstenAggregate( + requiredChildDistributionExpressions: Option[Seq[Expression]], + groupingExpressions: Seq[NamedExpression], + nonCompleteAggregateExpressions: Seq[AggregateExpression2], + completeAggregateExpressions: Seq[AggregateExpression2], + initialInputBufferOffset: Int, + resultExpressions: Seq[NamedExpression], + child: SparkPlan) + extends UnaryNode { + + override def outputsUnsafeRows: Boolean = true + + override def canProcessUnsafeRows: Boolean = true + + override def canProcessSafeRows: Boolean = false + + override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute) + + override def requiredChildDistribution: List[Distribution] = { + requiredChildDistributionExpressions match { + case Some(exprs) if exprs.length == 0 => AllTuples :: Nil + case Some(exprs) if exprs.length > 0 => ClusteredDistribution(exprs) :: Nil + case None => UnspecifiedDistribution :: Nil + } + } + + // This is for testing. We force TungstenAggregationIterator to fall back to sort-based + // aggregation once it has processed a given number of input rows. + private val testFallbackStartsAt: Option[Int] = { + sqlContext.getConf("spark.sql.TungstenAggregate.testFallbackStartsAt", null) match { + case null | "" => None + case fallbackStartsAt => Some(fallbackStartsAt.toInt) + } + } + + protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") { + child.execute().mapPartitions { iter => + val hasInput = iter.hasNext + if (!hasInput && groupingExpressions.nonEmpty) { + // This is a grouped aggregate and the input iterator is empty, + // so return an empty iterator. 
+ Iterator.empty.asInstanceOf[Iterator[UnsafeRow]] + } else { + val aggregationIterator = + new TungstenAggregationIterator( + groupingExpressions, + nonCompleteAggregateExpressions, + completeAggregateExpressions, + initialInputBufferOffset, + resultExpressions, + newMutableProjection, + child.output, + iter.asInstanceOf[Iterator[UnsafeRow]], + testFallbackStartsAt) + + if (!hasInput && groupingExpressions.isEmpty) { + Iterator.single[UnsafeRow](aggregationIterator.outputForEmptyGroupingKeyWithoutInput()) + } else { + aggregationIterator + } + } + } + } + + override def simpleString: String = { + val allAggregateExpressions = nonCompleteAggregateExpressions ++ completeAggregateExpressions + + testFallbackStartsAt match { + case None => s"TungstenAggregate ${groupingExpressions} ${allAggregateExpressions}" + case Some(fallbackStartsAt) => + s"TungstenAggregateWithControlledFallback ${groupingExpressions} " + + s"${allAggregateExpressions} fallbackStartsAt=$fallbackStartsAt" + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala new file mode 100644 index 000000000000..b9d44aace100 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala @@ -0,0 +1,667 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.aggregate + +import org.apache.spark.unsafe.KVIterator +import org.apache.spark.{Logging, SparkEnv, TaskContext} +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner +import org.apache.spark.sql.execution.{UnsafeKVExternalSorter, UnsafeFixedWidthAggregationMap} +import org.apache.spark.sql.types.StructType + +/** + * An iterator used to evaluate aggregate functions. It operates on [[UnsafeRow]]s. + * + * This iterator first uses hash-based aggregation to process input rows. It uses + * a hash map to store groups and their corresponding aggregation buffers. If we + * this map cannot allocate memory from [[org.apache.spark.shuffle.ShuffleMemoryManager]], + * it switches to sort-based aggregation. The process of the switch has the following step: + * - Step 1: Sort all entries of the hash map based on values of grouping expressions and + * spill them to disk. + * - Step 2: Create a external sorter based on the spilled sorted map entries. + * - Step 3: Redirect all input rows to the external sorter. + * - Step 4: Get a sorted [[KVIterator]] from the external sorter. + * - Step 5: Initialize sort-based aggregation. 
+ * Then, this iterator works in the way of sort-based aggregation. + * + * The code of this class is organized as follows: + * - Part 1: Initializing aggregate functions. + * - Part 2: Methods and fields used by setting aggregation buffer values, + * processing input rows from inputIter, and generating output + * rows. + * - Part 3: Methods and fields used by hash-based aggregation. + * - Part 4: The function used to switch this iterator from hash-based + * aggregation to sort-based aggregation. + * - Part 5: Methods and fields used by sort-based aggregation. + * - Part 6: Loads input and process input rows. + * - Part 7: Public methods of this iterator. + * - Part 8: A utility function used to generate a result when there is no + * input and there is no grouping expression. + * + * @param groupingExpressions + * expressions for grouping keys + * @param nonCompleteAggregateExpressions + * [[AggregateExpression2]] containing [[AggregateFunction2]]s with mode [[Partial]], + * [[PartialMerge]], or [[Final]]. + * @param completeAggregateExpressions + * [[AggregateExpression2]] containing [[AggregateFunction2]]s with mode [[Complete]]. + * @param initialInputBufferOffset + * If this iterator is used to handle functions with mode [[PartialMerge]] or [[Final]]. + * The input rows have the format of `grouping keys + aggregation buffer`. + * This offset indicates the starting position of aggregation buffer in a input row. + * @param resultExpressions + * expressions for generating output rows. + * @param newMutableProjection + * the function used to create mutable projections. + * @param originalInputAttributes + * attributes of representing input rows from `inputIter`. + * @param inputIter + * the iterator containing input [[UnsafeRow]]s. + */ +class TungstenAggregationIterator( + groupingExpressions: Seq[NamedExpression], + nonCompleteAggregateExpressions: Seq[AggregateExpression2], + completeAggregateExpressions: Seq[AggregateExpression2], + initialInputBufferOffset: Int, + resultExpressions: Seq[NamedExpression], + newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection), + originalInputAttributes: Seq[Attribute], + inputIter: Iterator[UnsafeRow], + testFallbackStartsAt: Option[Int]) + extends Iterator[UnsafeRow] with Logging { + + /////////////////////////////////////////////////////////////////////////// + // Part 1: Initializing aggregate functions. + /////////////////////////////////////////////////////////////////////////// + + // A Seq containing all AggregateExpressions. + // It is important that all AggregateExpressions with the mode Partial, PartialMerge or Final + // are at the beginning of the allAggregateExpressions. + private[this] val allAggregateExpressions: Seq[AggregateExpression2] = + nonCompleteAggregateExpressions ++ completeAggregateExpressions + + // Check to make sure we do not have more than three modes in our AggregateExpressions. + // If we have, users are hitting a bug and we throw an IllegalStateException. + if (allAggregateExpressions.map(_.mode).distinct.length > 2) { + throw new IllegalStateException( + s"$allAggregateExpressions should have no more than 2 kinds of modes.") + } + + // + // The modes of AggregateExpressions. Right now, we can handle the following mode: + // - Partial-only: + // All AggregateExpressions have the mode of Partial. + // For this case, aggregationMode is (Some(Partial), None). + // - PartialMerge-only: + // All AggregateExpressions have the mode of PartialMerge). 
+ // For this case, aggregationMode is (Some(PartialMerge), None). + // - Final-only: + // All AggregateExpressions have the mode of Final. + // For this case, aggregationMode is (Some(Final), None). + // - Final-Complete: + // Some AggregateExpressions have the mode of Final and + // others have the mode of Complete. For this case, + // aggregationMode is (Some(Final), Some(Complete)). + // - Complete-only: + // nonCompleteAggregateExpressions is empty and we have AggregateExpressions + // with mode Complete in completeAggregateExpressions. For this case, + // aggregationMode is (None, Some(Complete)). + // - Grouping-only: + // There is no AggregateExpression. For this case, AggregationMode is (None,None). + // + private[this] var aggregationMode: (Option[AggregateMode], Option[AggregateMode]) = { + nonCompleteAggregateExpressions.map(_.mode).distinct.headOption -> + completeAggregateExpressions.map(_.mode).distinct.headOption + } + + // All aggregate functions. TungstenAggregationIterator only handles AlgebraicAggregates. + // If there is any functions that is not an AlgebraicAggregate, we throw an + // IllegalStateException. + private[this] val allAggregateFunctions: Array[AlgebraicAggregate] = { + if (!allAggregateExpressions.forall(_.aggregateFunction.isInstanceOf[AlgebraicAggregate])) { + throw new IllegalStateException( + "Only AlgebraicAggregates should be passed in TungstenAggregationIterator.") + } + + allAggregateExpressions + .map(_.aggregateFunction.asInstanceOf[AlgebraicAggregate]) + .toArray + } + + /////////////////////////////////////////////////////////////////////////// + // Part 2: Methods and fields used by setting aggregation buffer values, + // processing input rows from inputIter, and generating output + // rows. + /////////////////////////////////////////////////////////////////////////// + + // The projection used to initialize buffer values. + private[this] val algebraicInitialProjection: MutableProjection = { + val initExpressions = allAggregateFunctions.flatMap(_.initialValues) + newMutableProjection(initExpressions, Nil)() + } + + // Creates a new aggregation buffer and initializes buffer values. + // This functions should be only called at most three times (when we create the hash map, + // when we switch to sort-based aggregation, and when we create the re-used buffer for + // sort-based aggregation). + private def createNewAggregationBuffer(): UnsafeRow = { + val bufferSchema = allAggregateFunctions.flatMap(_.bufferAttributes) + val bufferRowSize: Int = bufferSchema.length + + val genericMutableBuffer = new GenericMutableRow(bufferRowSize) + val unsafeProjection = + UnsafeProjection.create(bufferSchema.map(_.dataType)) + val buffer = unsafeProjection.apply(genericMutableBuffer) + algebraicInitialProjection.target(buffer)(EmptyRow) + buffer + } + + // Creates a function used to process a row based on the given inputAttributes. 
+ private def generateProcessRow( + inputAttributes: Seq[Attribute]): (UnsafeRow, UnsafeRow) => Unit = { + + val aggregationBufferAttributes = allAggregateFunctions.flatMap(_.bufferAttributes) + val aggregationBufferSchema = StructType.fromAttributes(aggregationBufferAttributes) + val inputSchema = StructType.fromAttributes(inputAttributes) + val unsafeRowJoiner = + GenerateUnsafeRowJoiner.create(aggregationBufferSchema, inputSchema) + + aggregationMode match { + // Partial-only + case (Some(Partial), None) => + val updateExpressions = allAggregateFunctions.flatMap(_.updateExpressions) + val algebraicUpdateProjection = + newMutableProjection(updateExpressions, aggregationBufferAttributes ++ inputAttributes)() + + (currentBuffer: UnsafeRow, row: UnsafeRow) => { + algebraicUpdateProjection.target(currentBuffer) + algebraicUpdateProjection(unsafeRowJoiner.join(currentBuffer, row)) + } + + // PartialMerge-only or Final-only + case (Some(PartialMerge), None) | (Some(Final), None) => + val mergeExpressions = allAggregateFunctions.flatMap(_.mergeExpressions) + // This projection is used to merge buffer values for all AlgebraicAggregates. + val algebraicMergeProjection = + newMutableProjection( + mergeExpressions, + aggregationBufferAttributes ++ inputAttributes)() + + (currentBuffer: UnsafeRow, row: UnsafeRow) => { + // Process all algebraic aggregate functions. + algebraicMergeProjection.target(currentBuffer) + algebraicMergeProjection(unsafeRowJoiner.join(currentBuffer, row)) + } + + // Final-Complete + case (Some(Final), Some(Complete)) => + val nonCompleteAggregateFunctions: Array[AlgebraicAggregate] = + allAggregateFunctions.take(nonCompleteAggregateExpressions.length) + val completeAggregateFunctions: Array[AlgebraicAggregate] = + allAggregateFunctions.takeRight(completeAggregateExpressions.length) + + val completeOffsetExpressions = + Seq.fill(completeAggregateFunctions.map(_.bufferAttributes.length).sum)(NoOp) + val mergeExpressions = + nonCompleteAggregateFunctions.flatMap(_.mergeExpressions) ++ completeOffsetExpressions + val finalAlgebraicMergeProjection = + newMutableProjection( + mergeExpressions, + aggregationBufferAttributes ++ inputAttributes)() + + // We do not touch buffer values of aggregate functions with the Final mode. + val finalOffsetExpressions = + Seq.fill(nonCompleteAggregateFunctions.map(_.bufferAttributes.length).sum)(NoOp) + val updateExpressions = + finalOffsetExpressions ++ completeAggregateFunctions.flatMap(_.updateExpressions) + val completeAlgebraicUpdateProjection = + newMutableProjection(updateExpressions, aggregationBufferAttributes ++ inputAttributes)() + + (currentBuffer: UnsafeRow, row: UnsafeRow) => { + val input = unsafeRowJoiner.join(currentBuffer, row) + // For all aggregate functions with mode Complete, update the given currentBuffer. + completeAlgebraicUpdateProjection.target(currentBuffer)(input) + + // For all aggregate functions with mode Final, merge buffer values in row to + // currentBuffer. 
+ finalAlgebraicMergeProjection.target(currentBuffer)(input) + } + + // Complete-only + case (None, Some(Complete)) => + val completeAggregateFunctions: Array[AlgebraicAggregate] = + allAggregateFunctions.takeRight(completeAggregateExpressions.length) + + val updateExpressions = + completeAggregateFunctions.flatMap(_.updateExpressions) + val completeAlgebraicUpdateProjection = + newMutableProjection(updateExpressions, aggregationBufferAttributes ++ inputAttributes)() + + (currentBuffer: UnsafeRow, row: UnsafeRow) => { + completeAlgebraicUpdateProjection.target(currentBuffer) + // For all aggregate functions with mode Complete, update the given currentBuffer. + completeAlgebraicUpdateProjection(unsafeRowJoiner.join(currentBuffer, row)) + } + + // Grouping only. + case (None, None) => (currentBuffer: UnsafeRow, row: UnsafeRow) => {} + + case other => + throw new IllegalStateException( + s"${aggregationMode} should not be passed into TungstenAggregationIterator.") + } + } + + // Creates a function used to generate output rows. + private def generateResultProjection(): (UnsafeRow, UnsafeRow) => UnsafeRow = { + + val groupingAttributes = groupingExpressions.map(_.toAttribute) + val groupingKeySchema = StructType.fromAttributes(groupingAttributes) + val bufferAttributes = allAggregateFunctions.flatMap(_.bufferAttributes) + val bufferSchema = StructType.fromAttributes(bufferAttributes) + val unsafeRowJoiner = GenerateUnsafeRowJoiner.create(groupingKeySchema, bufferSchema) + + aggregationMode match { + // Partial-only or PartialMerge-only: every output row is basically the values of + // the grouping expressions and the corresponding aggregation buffer. + case (Some(Partial), None) | (Some(PartialMerge), None) => + (currentGroupingKey: UnsafeRow, currentBuffer: UnsafeRow) => { + unsafeRowJoiner.join(currentGroupingKey, currentBuffer) + } + + // Final-only, Complete-only and Final-Complete: a output row is generated based on + // resultExpressions. + case (Some(Final), None) | (Some(Final) | None, Some(Complete)) => + val resultProjection = + UnsafeProjection.create(resultExpressions, groupingAttributes ++ bufferAttributes) + + (currentGroupingKey: UnsafeRow, currentBuffer: UnsafeRow) => { + resultProjection(unsafeRowJoiner.join(currentGroupingKey, currentBuffer)) + } + + // Grouping-only: a output row is generated from values of grouping expressions. + case (None, None) => + val resultProjection = + UnsafeProjection.create(resultExpressions, groupingAttributes) + + (currentGroupingKey: UnsafeRow, currentBuffer: UnsafeRow) => { + resultProjection(currentGroupingKey) + } + + case other => + throw new IllegalStateException( + s"${aggregationMode} should not be passed into TungstenAggregationIterator.") + } + } + + // An UnsafeProjection used to extract grouping keys from the input rows. + private[this] val groupProjection = + UnsafeProjection.create(groupingExpressions, originalInputAttributes) + + // A function used to process a input row. Its first argument is the aggregation buffer + // and the second argument is the input row. + private[this] var processRow: (UnsafeRow, UnsafeRow) => Unit = + generateProcessRow(originalInputAttributes) + + // A function used to generate output rows based on the grouping keys (first argument) + // and the corresponding aggregation buffer (second argument). + private[this] var generateOutput: (UnsafeRow, UnsafeRow) => UnsafeRow = + generateResultProjection() + + // An aggregation buffer containing initial buffer values. 
It is used to + // initialize other aggregation buffers. + private[this] val initialAggregationBuffer: UnsafeRow = createNewAggregationBuffer() + + /////////////////////////////////////////////////////////////////////////// + // Part 3: Methods and fields used by hash-based aggregation. + /////////////////////////////////////////////////////////////////////////// + + // This is the hash map used for hash-based aggregation. It is backed by an + // UnsafeFixedWidthAggregationMap and it is used to store + // all groups and their corresponding aggregation buffers for hash-based aggregation. + private[this] val hashMap = new UnsafeFixedWidthAggregationMap( + initialAggregationBuffer, + StructType.fromAttributes(allAggregateFunctions.flatMap(_.bufferAttributes)), + StructType.fromAttributes(groupingExpressions.map(_.toAttribute)), + TaskContext.get.taskMemoryManager(), + SparkEnv.get.shuffleMemoryManager, + 1024 * 16, // initial capacity + SparkEnv.get.conf.getSizeAsBytes("spark.buffer.pageSize", "64m"), + false // disable tracking of performance metrics + ) + + // The function used to read and process input rows. When processing input rows, + // it first uses hash-based aggregation by putting groups and their buffers in + // hashMap. If we could not allocate more memory for the map, we switch to + // sort-based aggregation (by calling switchToSortBasedAggregation). + private def processInputs(): Unit = { + while (!sortBased && inputIter.hasNext) { + val newInput = inputIter.next() + val groupingKey = groupProjection.apply(newInput) + val buffer: UnsafeRow = hashMap.getAggregationBuffer(groupingKey) + if (buffer == null) { + // buffer == null means that we could not allocate more memory. + // Now, we need to spill the map and switch to sort-based aggregation. + switchToSortBasedAggregation(groupingKey, newInput) + } else { + processRow(buffer, newInput) + } + } + } + + // This function is only used for testing. It basically the same as processInputs except + // that it switch to sort-based aggregation after `fallbackStartsAt` input rows have + // been processed. + private def processInputsWithControlledFallback(fallbackStartsAt: Int): Unit = { + var i = 0 + while (!sortBased && inputIter.hasNext) { + val newInput = inputIter.next() + val groupingKey = groupProjection.apply(newInput) + val buffer: UnsafeRow = if (i < fallbackStartsAt) { + hashMap.getAggregationBuffer(groupingKey) + } else { + null + } + if (buffer == null) { + // buffer == null means that we could not allocate more memory. + // Now, we need to spill the map and switch to sort-based aggregation. + switchToSortBasedAggregation(groupingKey, newInput) + } else { + processRow(buffer, newInput) + } + i += 1 + } + } + + // The iterator created from hashMap. It is used to generate output rows when we + // are using hash-based aggregation. + private[this] var aggregationBufferMapIterator: KVIterator[UnsafeRow, UnsafeRow] = null + + // Indicates if aggregationBufferMapIterator still has key-value pairs. + private[this] var mapIteratorHasNext: Boolean = false + + /////////////////////////////////////////////////////////////////////////// + // Part 4: The function used to switch this iterator from hash-based + // aggregation to sort-based aggregation. 
+ /////////////////////////////////////////////////////////////////////////// + + private def switchToSortBasedAggregation(firstKey: UnsafeRow, firstInput: UnsafeRow): Unit = { + logInfo("falling back to sort based aggregation.") + // Step 1: Get the ExternalSorter containing sorted entries of the map. + val externalSorter: UnsafeKVExternalSorter = hashMap.destructAndCreateExternalSorter() + + // Step 2: Free the memory used by the map. + hashMap.free() + + // Step 3: If we have aggregate function with mode Partial or Complete, + // we need to process input rows to get aggregation buffer. + // So, later in the sort-based aggregation iterator, we can do merge. + // If aggregate functions are with mode Final and PartialMerge, + // we just need to project the aggregation buffer from an input row. + val needsProcess = aggregationMode match { + case (Some(Partial), None) => true + case (None, Some(Complete)) => true + case (Some(Final), Some(Complete)) => true + case _ => false + } + + if (needsProcess) { + // First, we create a buffer. + val buffer = createNewAggregationBuffer() + + // Process firstKey and firstInput. + // Initialize buffer. + buffer.copyFrom(initialAggregationBuffer) + processRow(buffer, firstInput) + externalSorter.insertKV(firstKey, buffer) + + // Process the rest of input rows. + while (inputIter.hasNext) { + val newInput = inputIter.next() + val groupingKey = groupProjection.apply(newInput) + buffer.copyFrom(initialAggregationBuffer) + processRow(buffer, newInput) + externalSorter.insertKV(groupingKey, buffer) + } + } else { + // When needsProcess is false, the format of input rows is groupingKey + aggregation buffer. + // We need to project the aggregation buffer part from an input row. + val buffer = createNewAggregationBuffer() + // The originalInputAttributes are using cloneBufferAttributes. So, we need to use + // allAggregateFunctions.flatMap(_.cloneBufferAttributes). + val bufferExtractor = newMutableProjection( + allAggregateFunctions.flatMap(_.cloneBufferAttributes), + originalInputAttributes)() + bufferExtractor.target(buffer) + + // Insert firstKey and its buffer. + bufferExtractor(firstInput) + externalSorter.insertKV(firstKey, buffer) + + // Insert the rest of input rows. + while (inputIter.hasNext) { + val newInput = inputIter.next() + val groupingKey = groupProjection.apply(newInput) + bufferExtractor(newInput) + externalSorter.insertKV(groupingKey, buffer) + } + } + + // Set aggregationMode, processRow, and generateOutput for sort-based aggregation. + val newAggregationMode = aggregationMode match { + case (Some(Partial), None) => (Some(PartialMerge), None) + case (None, Some(Complete)) => (Some(Final), None) + case (Some(Final), Some(Complete)) => (Some(Final), None) + case other => other + } + aggregationMode = newAggregationMode + + // Basically the value of the KVIterator returned by externalSorter + // will just aggregation buffer. At here, we use cloneBufferAttributes. + val newInputAttributes: Seq[Attribute] = + allAggregateFunctions.flatMap(_.cloneBufferAttributes) + + // Set up new processRow and generateOutput. + processRow = generateProcessRow(newInputAttributes) + generateOutput = generateResultProjection() + + // Step 5: Get the sorted iterator from the externalSorter. + sortedKVIterator = externalSorter.sortedIterator() + + // Step 6: Pre-load the first key-value pair from the sorted iterator to make + // hasNext idempotent. + sortedInputHasNewGroup = sortedKVIterator.next() + + // Copy the first key and value (aggregation buffer). 
+ if (sortedInputHasNewGroup) { + val key = sortedKVIterator.getKey + val value = sortedKVIterator.getValue + nextGroupingKey = key.copy() + currentGroupingKey = key.copy() + firstRowInNextGroup = value.copy() + } + + // Step 7: set sortBased to true. + sortBased = true + } + + /////////////////////////////////////////////////////////////////////////// + // Part 5: Methods and fields used by sort-based aggregation. + /////////////////////////////////////////////////////////////////////////// + + // Indicates if we are using sort-based aggregation. Because we first try to use + // hash-based aggregation, its initial value is false. + private[this] var sortBased: Boolean = false + + // The KVIterator containing input rows for the sort-based aggregation. It will be + // set in switchToSortBasedAggregation when we switch to sort-based aggregation. + private[this] var sortedKVIterator: UnsafeKVExternalSorter#KVSorterIterator = null + + // The grouping key of the current group. + private[this] var currentGroupingKey: UnsafeRow = null + + // The grouping key of next group. + private[this] var nextGroupingKey: UnsafeRow = null + + // The first row of next group. + private[this] var firstRowInNextGroup: UnsafeRow = null + + // Indicates if we has new group of rows from the sorted input iterator. + private[this] var sortedInputHasNewGroup: Boolean = false + + // The aggregation buffer used by the sort-based aggregation. + private[this] val sortBasedAggregationBuffer: UnsafeRow = createNewAggregationBuffer() + + // Processes rows in the current group. It will stop when it find a new group. + private def processCurrentSortedGroup(): Unit = { + // First, we need to copy nextGroupingKey to currentGroupingKey. + currentGroupingKey.copyFrom(nextGroupingKey) + // Now, we will start to find all rows belonging to this group. + // We create a variable to track if we see the next group. + var findNextPartition = false + // firstRowInNextGroup is the first row of this group. We first process it. + processRow(sortBasedAggregationBuffer, firstRowInNextGroup) + + // The search will stop when we see the next group or there is no + // input row left in the iter. + // Pre-load the first key-value pair to make the condition of the while loop + // has no action (we do not trigger loading a new key-value pair + // when we evaluate the condition). + var hasNext = sortedKVIterator.next() + while (!findNextPartition && hasNext) { + // Get the grouping key and value (aggregation buffer). + val groupingKey = sortedKVIterator.getKey + val inputAggregationBuffer = sortedKVIterator.getValue + + // Check if the current row belongs the current input row. + if (currentGroupingKey.equals(groupingKey)) { + processRow(sortBasedAggregationBuffer, inputAggregationBuffer) + + hasNext = sortedKVIterator.next() + } else { + // We find a new group. + findNextPartition = true + // copyFrom will fail when + nextGroupingKey.copyFrom(groupingKey) // = groupingKey.copy() + firstRowInNextGroup.copyFrom(inputAggregationBuffer) // = inputAggregationBuffer.copy() + + } + } + // We have not seen a new group. It means that there is no new row in the input + // iter. The current group is the last group of the sortedKVIterator. + if (!findNextPartition) { + sortedInputHasNewGroup = false + sortedKVIterator.close() + } + } + + /////////////////////////////////////////////////////////////////////////// + // Part 6: Loads input rows and setup aggregationBufferMapIterator if we + // have not switched to sort-based aggregation. 
+ /////////////////////////////////////////////////////////////////////////// + + // Starts to process input rows. + testFallbackStartsAt match { + case None => + processInputs() + case Some(fallbackStartsAt) => + // This is the testing path. processInputsWithControlledFallback is same as processInputs + // except that it switches to sort-based aggregation after `fallbackStartsAt` input rows + // have been processed. + processInputsWithControlledFallback(fallbackStartsAt) + } + + // If we did not switch to sort-based aggregation in processInputs, + // we pre-load the first key-value pair from the map (to make hasNext idempotent). + if (!sortBased) { + // First, set aggregationBufferMapIterator. + aggregationBufferMapIterator = hashMap.iterator() + // Pre-load the first key-value pair from the aggregationBufferMapIterator. + mapIteratorHasNext = aggregationBufferMapIterator.next() + // If the map is empty, we just free it. + if (!mapIteratorHasNext) { + hashMap.free() + } + } + + /////////////////////////////////////////////////////////////////////////// + // Par 7: Iterator's public methods. + /////////////////////////////////////////////////////////////////////////// + + override final def hasNext: Boolean = { + (sortBased && sortedInputHasNewGroup) || (!sortBased && mapIteratorHasNext) + } + + override final def next(): UnsafeRow = { + if (hasNext) { + if (sortBased) { + // Process the current group. + processCurrentSortedGroup() + // Generate output row for the current group. + val outputRow = generateOutput(currentGroupingKey, sortBasedAggregationBuffer) + // Initialize buffer values for the next group. + sortBasedAggregationBuffer.copyFrom(initialAggregationBuffer) + + outputRow + } else { + // We did not fall back to sort-based aggregation. + val result = + generateOutput( + aggregationBufferMapIterator.getKey, + aggregationBufferMapIterator.getValue) + + // Pre-load next key-value pair form aggregationBufferMapIterator to make hasNext + // idempotent. + mapIteratorHasNext = aggregationBufferMapIterator.next() + + if (!mapIteratorHasNext) { + // If there is no input from aggregationBufferMapIterator, we copy current result. + val resultCopy = result.copy() + // Then, we free the map. + hashMap.free() + + resultCopy + } else { + result + } + } + } else { + // no more result + throw new NoSuchElementException + } + } + + /////////////////////////////////////////////////////////////////////////// + // Part 8: A utility function used to generate a output row when there is no + // input and there is no grouping expression. + /////////////////////////////////////////////////////////////////////////// + def outputForEmptyGroupingKeyWithoutInput(): UnsafeRow = { + if (groupingExpressions.isEmpty) { + sortBasedAggregationBuffer.copyFrom(initialAggregationBuffer) + // We create a output row and copy it. So, we can free the map. 
+ val resultCopy = + generateOutput(UnsafeRow.createFromByteArray(0, 0), sortBasedAggregationBuffer).copy() + hashMap.free() + resultCopy + } else { + throw new IllegalStateException( + "This method should not be called when groupingExpressions is not empty.") + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UnsafeHybridAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UnsafeHybridAggregationIterator.scala deleted file mode 100644 index b465787fe8cb..000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UnsafeHybridAggregationIterator.scala +++ /dev/null @@ -1,372 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.aggregate - -import org.apache.spark.unsafe.KVIterator -import org.apache.spark.{SparkEnv, TaskContext} -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.execution.{UnsafeKVExternalSorter, UnsafeFixedWidthAggregationMap} -import org.apache.spark.sql.types.StructType - -/** - * An iterator used to evaluate [[AggregateFunction2]]. - * It first tries to use in-memory hash-based aggregation. If we cannot allocate more - * space for the hash map, we spill the sorted map entries, free the map, and then - * switch to sort-based aggregation. - */ -class UnsafeHybridAggregationIterator( - groupingKeyAttributes: Seq[Attribute], - valueAttributes: Seq[Attribute], - inputKVIterator: KVIterator[UnsafeRow, InternalRow], - nonCompleteAggregateExpressions: Seq[AggregateExpression2], - nonCompleteAggregateAttributes: Seq[Attribute], - completeAggregateExpressions: Seq[AggregateExpression2], - completeAggregateAttributes: Seq[Attribute], - initialInputBufferOffset: Int, - resultExpressions: Seq[NamedExpression], - newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection), - outputsUnsafeRows: Boolean) - extends AggregationIterator( - groupingKeyAttributes, - valueAttributes, - nonCompleteAggregateExpressions, - nonCompleteAggregateAttributes, - completeAggregateExpressions, - completeAggregateAttributes, - initialInputBufferOffset, - resultExpressions, - newMutableProjection, - outputsUnsafeRows) { - - require(groupingKeyAttributes.nonEmpty) - - /////////////////////////////////////////////////////////////////////////// - // Unsafe Aggregation buffers - /////////////////////////////////////////////////////////////////////////// - - // This is the Unsafe Aggregation Map used to store all buffers. 
- private[this] val buffers = new UnsafeFixedWidthAggregationMap( - newBuffer, - StructType.fromAttributes(allAggregateFunctions.flatMap(_.bufferAttributes)), - StructType.fromAttributes(groupingKeyAttributes), - TaskContext.get.taskMemoryManager(), - SparkEnv.get.shuffleMemoryManager, - 1024 * 16, // initial capacity - SparkEnv.get.conf.getSizeAsBytes("spark.buffer.pageSize", "64m"), - false // disable tracking of performance metrics - ) - - override protected def newBuffer: UnsafeRow = { - val bufferSchema = allAggregateFunctions.flatMap(_.bufferAttributes) - val bufferRowSize: Int = bufferSchema.length - - val genericMutableBuffer = new GenericMutableRow(bufferRowSize) - val unsafeProjection = - UnsafeProjection.create(bufferSchema.map(_.dataType)) - val buffer = unsafeProjection.apply(genericMutableBuffer) - initializeBuffer(buffer) - buffer - } - - /////////////////////////////////////////////////////////////////////////// - // Methods and variables related to switching to sort-based aggregation - /////////////////////////////////////////////////////////////////////////// - private[this] var sortBased = false - - private[this] var sortBasedAggregationIterator: SortBasedAggregationIterator = _ - - // The value part of the input KV iterator is used to store original input values of - // aggregate functions, we need to convert them to aggregation buffers. - private def processOriginalInput( - firstKey: UnsafeRow, - firstValue: InternalRow): KVIterator[UnsafeRow, UnsafeRow] = { - new KVIterator[UnsafeRow, UnsafeRow] { - private[this] var isFirstRow = true - - private[this] var groupingKey: UnsafeRow = _ - - private[this] val buffer: UnsafeRow = newBuffer - - override def next(): Boolean = { - initializeBuffer(buffer) - if (isFirstRow) { - isFirstRow = false - groupingKey = firstKey - processRow(buffer, firstValue) - - true - } else if (inputKVIterator.next()) { - groupingKey = inputKVIterator.getKey() - val value = inputKVIterator.getValue() - processRow(buffer, value) - - true - } else { - false - } - } - - override def getKey(): UnsafeRow = { - groupingKey - } - - override def getValue(): UnsafeRow = { - buffer - } - - override def close(): Unit = { - // Do nothing. - } - } - } - - // The value of the input KV Iterator has the format of groupingExprs + aggregation buffer. - // We need to project the aggregation buffer out. - private def projectInputBufferToUnsafe( - firstKey: UnsafeRow, - firstValue: InternalRow): KVIterator[UnsafeRow, UnsafeRow] = { - new KVIterator[UnsafeRow, UnsafeRow] { - private[this] var isFirstRow = true - - private[this] var groupingKey: UnsafeRow = _ - - private[this] val bufferSchema = allAggregateFunctions.flatMap(_.bufferAttributes) - - private[this] val value: UnsafeRow = { - val genericMutableRow = new GenericMutableRow(bufferSchema.length) - UnsafeProjection.create(bufferSchema.map(_.dataType)).apply(genericMutableRow) - } - - private[this] val projectInputBuffer = { - newMutableProjection(bufferSchema, valueAttributes)().target(value) - } - - override def next(): Boolean = { - if (isFirstRow) { - isFirstRow = false - groupingKey = firstKey - projectInputBuffer(firstValue) - - true - } else if (inputKVIterator.next()) { - groupingKey = inputKVIterator.getKey() - projectInputBuffer(inputKVIterator.getValue()) - - true - } else { - false - } - } - - override def getKey(): UnsafeRow = { - groupingKey - } - - override def getValue(): UnsafeRow = { - value - } - - override def close(): Unit = { - // Do nothing. 
- } - } - } - - /** - * We need to fall back to sort based aggregation because we do not have enough memory - * for our in-memory hash map (i.e. `buffers`). - */ - private def switchToSortBasedAggregation( - currentGroupingKey: UnsafeRow, - currentRow: InternalRow): Unit = { - logInfo("falling back to sort based aggregation.") - - // Step 1: Get the ExternalSorter containing entries of the map. - val externalSorter = buffers.destructAndCreateExternalSorter() - - // Step 2: Free the memory used by the map. - buffers.free() - - // Step 3: If we have aggregate function with mode Partial or Complete, - // we need to process them to get aggregation buffer. - // So, later in the sort-based aggregation iterator, we can do merge. - // If aggregate functions are with mode Final and PartialMerge, - // we just need to project the aggregation buffer from the input. - val needsProcess = aggregationMode match { - case (Some(Partial), None) => true - case (None, Some(Complete)) => true - case (Some(Final), Some(Complete)) => true - case _ => false - } - - val processedIterator = if (needsProcess) { - processOriginalInput(currentGroupingKey, currentRow) - } else { - // The input value's format is groupingExprs + buffer. - // We need to project the buffer part out. - projectInputBufferToUnsafe(currentGroupingKey, currentRow) - } - - // Step 4: Redirect processedIterator to externalSorter. - while (processedIterator.next()) { - externalSorter.insertKV(processedIterator.getKey(), processedIterator.getValue()) - } - - // Step 5: Get the sorted iterator from the externalSorter. - val sortedKVIterator: UnsafeKVExternalSorter#KVSorterIterator = externalSorter.sortedIterator() - - // Step 6: We now create a SortBasedAggregationIterator based on sortedKVIterator. - // For a aggregate function with mode Partial, its mode in the SortBasedAggregationIterator - // will be PartialMerge. For a aggregate function with mode Complete, - // its mode in the SortBasedAggregationIterator will be Final. - val newNonCompleteAggregateExpressions = allAggregateExpressions.map { - case AggregateExpression2(func, Partial, isDistinct) => - AggregateExpression2(func, PartialMerge, isDistinct) - case AggregateExpression2(func, Complete, isDistinct) => - AggregateExpression2(func, Final, isDistinct) - case other => other - } - val newNonCompleteAggregateAttributes = - nonCompleteAggregateAttributes ++ completeAggregateAttributes - - val newValueAttributes = - allAggregateExpressions.flatMap(_.aggregateFunction.cloneBufferAttributes) - - sortBasedAggregationIterator = SortBasedAggregationIterator.createFromKVIterator( - groupingKeyAttributes = groupingKeyAttributes, - valueAttributes = newValueAttributes, - inputKVIterator = sortedKVIterator.asInstanceOf[KVIterator[InternalRow, InternalRow]], - nonCompleteAggregateExpressions = newNonCompleteAggregateExpressions, - nonCompleteAggregateAttributes = newNonCompleteAggregateAttributes, - completeAggregateExpressions = Nil, - completeAggregateAttributes = Nil, - initialInputBufferOffset = 0, - resultExpressions = resultExpressions, - newMutableProjection = newMutableProjection, - outputsUnsafeRows = outputsUnsafeRows) - } - - /////////////////////////////////////////////////////////////////////////// - // Methods used to initialize this iterator. - /////////////////////////////////////////////////////////////////////////// - - /** Starts to read input rows and falls back to sort-based aggregation if necessary. 
*/ - protected def initialize(): Unit = { - var hasNext = inputKVIterator.next() - while (!sortBased && hasNext) { - val groupingKey = inputKVIterator.getKey() - val currentRow = inputKVIterator.getValue() - val buffer = buffers.getAggregationBuffer(groupingKey) - if (buffer == null) { - // buffer == null means that we could not allocate more memory. - // Now, we need to spill the map and switch to sort-based aggregation. - switchToSortBasedAggregation(groupingKey, currentRow) - sortBased = true - } else { - processRow(buffer, currentRow) - hasNext = inputKVIterator.next() - } - } - } - - // This is the starting point of this iterator. - initialize() - - // Creates the iterator for the Hash Aggregation Map after we have populated - // contents of that map. - private[this] val aggregationBufferMapIterator = buffers.iterator() - - private[this] var _mapIteratorHasNext = false - - // Pre-load the first key-value pair from the map to make hasNext idempotent. - if (!sortBased) { - _mapIteratorHasNext = aggregationBufferMapIterator.next() - // If the map is empty, we just free it. - if (!_mapIteratorHasNext) { - buffers.free() - } - } - - /////////////////////////////////////////////////////////////////////////// - // Iterator's public methods - /////////////////////////////////////////////////////////////////////////// - - override final def hasNext: Boolean = { - (sortBased && sortBasedAggregationIterator.hasNext) || (!sortBased && _mapIteratorHasNext) - } - - - override final def next(): InternalRow = { - if (hasNext) { - if (sortBased) { - sortBasedAggregationIterator.next() - } else { - // We did not fall back to the sort-based aggregation. - val result = - generateOutput( - aggregationBufferMapIterator.getKey, - aggregationBufferMapIterator.getValue) - // Pre-load next key-value pair form aggregationBufferMapIterator. 
- _mapIteratorHasNext = aggregationBufferMapIterator.next() - - if (!_mapIteratorHasNext) { - val resultCopy = result.copy() - buffers.free() - resultCopy - } else { - result - } - } - } else { - // no more result - throw new NoSuchElementException - } - } -} - -object UnsafeHybridAggregationIterator { - // scalastyle:off - def createFromInputIterator( - groupingExprs: Seq[NamedExpression], - nonCompleteAggregateExpressions: Seq[AggregateExpression2], - nonCompleteAggregateAttributes: Seq[Attribute], - completeAggregateExpressions: Seq[AggregateExpression2], - completeAggregateAttributes: Seq[Attribute], - initialInputBufferOffset: Int, - resultExpressions: Seq[NamedExpression], - newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection), - inputAttributes: Seq[Attribute], - inputIter: Iterator[InternalRow], - outputsUnsafeRows: Boolean): UnsafeHybridAggregationIterator = { - new UnsafeHybridAggregationIterator( - groupingExprs.map(_.toAttribute), - inputAttributes, - AggregationIterator.unsafeKVIterator(groupingExprs, inputAttributes, inputIter), - nonCompleteAggregateExpressions, - nonCompleteAggregateAttributes, - completeAggregateExpressions, - completeAggregateAttributes, - initialInputBufferOffset, - resultExpressions, - newMutableProjection, - outputsUnsafeRows) - } - // scalastyle:on -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala index 960be08f84d9..80816a095ea8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala @@ -17,20 +17,41 @@ package org.apache.spark.sql.execution.aggregate +import scala.collection.mutable + import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.{UnsafeFixedWidthAggregationMap, SparkPlan} +import org.apache.spark.sql.types.StructType /** * Utility functions used by the query planner to convert our plan to new aggregation code path. */ object Utils { + def supportsTungstenAggregate( + groupingExpressions: Seq[Expression], + aggregateBufferAttributes: Seq[Attribute]): Boolean = { + val aggregationBufferSchema = StructType.fromAttributes(aggregateBufferAttributes) + + UnsafeFixedWidthAggregationMap.supportsAggregationBufferSchema(aggregationBufferSchema) && + UnsafeProjection.canSupport(groupingExpressions) + } + def planAggregateWithoutDistinct( groupingExpressions: Seq[Expression], aggregateExpressions: Seq[AggregateExpression2], - aggregateFunctionMap: Map[(AggregateFunction2, Boolean), Attribute], + aggregateFunctionMap: Map[(AggregateFunction2, Boolean), (AggregateFunction2, Attribute)], resultExpressions: Seq[NamedExpression], child: SparkPlan): Seq[SparkPlan] = { + // Check if we can use TungstenAggregate. + val usesTungstenAggregate = + child.sqlContext.conf.unsafeEnabled && + aggregateExpressions.forall(_.aggregateFunction.isInstanceOf[AlgebraicAggregate]) && + supportsTungstenAggregate( + groupingExpressions, + aggregateExpressions.flatMap(_.aggregateFunction.bufferAttributes)) + + // 1. Create an Aggregate Operator for partial aggregations. 
val namedGroupingExpressions = groupingExpressions.map { case ne: NamedExpression => ne -> ne @@ -44,11 +65,23 @@ object Utils { val groupExpressionMap = namedGroupingExpressions.toMap val namedGroupingAttributes = namedGroupingExpressions.map(_._2.toAttribute) val partialAggregateExpressions = aggregateExpressions.map(_.copy(mode = Partial)) - val partialAggregateAttributes = partialAggregateExpressions.flatMap { agg => - agg.aggregateFunction.bufferAttributes - } - val partialAggregate = - Aggregate( + val partialAggregateAttributes = + partialAggregateExpressions.flatMap(_.aggregateFunction.bufferAttributes) + val partialResultExpressions = + namedGroupingAttributes ++ + partialAggregateExpressions.flatMap(_.aggregateFunction.cloneBufferAttributes) + + val partialAggregate = if (usesTungstenAggregate) { + TungstenAggregate( + requiredChildDistributionExpressions = None: Option[Seq[Expression]], + groupingExpressions = namedGroupingExpressions.map(_._2), + nonCompleteAggregateExpressions = partialAggregateExpressions, + completeAggregateExpressions = Nil, + initialInputBufferOffset = 0, + resultExpressions = partialResultExpressions, + child = child) + } else { + SortBasedAggregate( requiredChildDistributionExpressions = None: Option[Seq[Expression]], groupingExpressions = namedGroupingExpressions.map(_._2), nonCompleteAggregateExpressions = partialAggregateExpressions, @@ -56,29 +89,57 @@ object Utils { completeAggregateExpressions = Nil, completeAggregateAttributes = Nil, initialInputBufferOffset = 0, - resultExpressions = namedGroupingAttributes ++ partialAggregateAttributes, + resultExpressions = partialResultExpressions, child = child) + } // 2. Create an Aggregate Operator for final aggregations. val finalAggregateExpressions = aggregateExpressions.map(_.copy(mode = Final)) val finalAggregateAttributes = finalAggregateExpressions.map { - expr => aggregateFunctionMap(expr.aggregateFunction, expr.isDistinct) + expr => aggregateFunctionMap(expr.aggregateFunction, expr.isDistinct)._2 } - val rewrittenResultExpressions = resultExpressions.map { expr => - expr.transformDown { - case agg: AggregateExpression2 => - aggregateFunctionMap(agg.aggregateFunction, agg.isDistinct).toAttribute - case expression => - // We do not rely on the equality check at here since attributes may - // different cosmetically. Instead, we use semanticEquals. - groupExpressionMap.collectFirst { - case (expr, ne) if expr semanticEquals expression => ne.toAttribute - }.getOrElse(expression) - }.asInstanceOf[NamedExpression] - } - val finalAggregate = - Aggregate( + + val finalAggregate = if (usesTungstenAggregate) { + val rewrittenResultExpressions = resultExpressions.map { expr => + expr.transformDown { + case agg: AggregateExpression2 => + // aggregateFunctionMap contains unique aggregate functions. + val aggregateFunction = + aggregateFunctionMap(agg.aggregateFunction, agg.isDistinct)._1 + aggregateFunction.asInstanceOf[AlgebraicAggregate].evaluateExpression + case expression => + // We do not rely on the equality check at here since attributes may + // different cosmetically. Instead, we use semanticEquals. 
+ groupExpressionMap.collectFirst { + case (expr, ne) if expr semanticEquals expression => ne.toAttribute + }.getOrElse(expression) + }.asInstanceOf[NamedExpression] + } + + TungstenAggregate( + requiredChildDistributionExpressions = Some(namedGroupingAttributes), + groupingExpressions = namedGroupingAttributes, + nonCompleteAggregateExpressions = finalAggregateExpressions, + completeAggregateExpressions = Nil, + initialInputBufferOffset = namedGroupingAttributes.length, + resultExpressions = rewrittenResultExpressions, + child = partialAggregate) + } else { + val rewrittenResultExpressions = resultExpressions.map { expr => + expr.transformDown { + case agg: AggregateExpression2 => + aggregateFunctionMap(agg.aggregateFunction, agg.isDistinct)._2 + case expression => + // We do not rely on the equality check at here since attributes may + // different cosmetically. Instead, we use semanticEquals. + groupExpressionMap.collectFirst { + case (expr, ne) if expr semanticEquals expression => ne.toAttribute + }.getOrElse(expression) + }.asInstanceOf[NamedExpression] + } + + SortBasedAggregate( requiredChildDistributionExpressions = Some(namedGroupingAttributes), groupingExpressions = namedGroupingAttributes, nonCompleteAggregateExpressions = finalAggregateExpressions, @@ -88,6 +149,7 @@ object Utils { initialInputBufferOffset = namedGroupingAttributes.length, resultExpressions = rewrittenResultExpressions, child = partialAggregate) + } finalAggregate :: Nil } @@ -96,10 +158,18 @@ object Utils { groupingExpressions: Seq[Expression], functionsWithDistinct: Seq[AggregateExpression2], functionsWithoutDistinct: Seq[AggregateExpression2], - aggregateFunctionMap: Map[(AggregateFunction2, Boolean), Attribute], + aggregateFunctionMap: Map[(AggregateFunction2, Boolean), (AggregateFunction2, Attribute)], resultExpressions: Seq[NamedExpression], child: SparkPlan): Seq[SparkPlan] = { + val aggregateExpressions = functionsWithDistinct ++ functionsWithoutDistinct + val usesTungstenAggregate = + child.sqlContext.conf.unsafeEnabled && + aggregateExpressions.forall(_.aggregateFunction.isInstanceOf[AlgebraicAggregate]) && + supportsTungstenAggregate( + groupingExpressions, + aggregateExpressions.flatMap(_.aggregateFunction.bufferAttributes)) + // 1. Create an Aggregate Operator for partial aggregations. // The grouping expressions are original groupingExpressions and // distinct columns. For example, for avg(distinct value) ... 
group by key @@ -129,19 +199,26 @@ object Utils { val distinctColumnExpressionMap = namedDistinctColumnExpressions.toMap val distinctColumnAttributes = namedDistinctColumnExpressions.map(_._2.toAttribute) - val partialAggregateExpressions = functionsWithoutDistinct.map { - case AggregateExpression2(aggregateFunction, mode, _) => - AggregateExpression2(aggregateFunction, Partial, false) - } - val partialAggregateAttributes = partialAggregateExpressions.flatMap { agg => - agg.aggregateFunction.bufferAttributes - } + val partialAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Partial)) + val partialAggregateAttributes = + partialAggregateExpressions.flatMap(_.aggregateFunction.bufferAttributes) val partialAggregateGroupingExpressions = (namedGroupingExpressions ++ namedDistinctColumnExpressions).map(_._2) val partialAggregateResult = - namedGroupingAttributes ++ distinctColumnAttributes ++ partialAggregateAttributes - val partialAggregate = - Aggregate( + namedGroupingAttributes ++ + distinctColumnAttributes ++ + partialAggregateExpressions.flatMap(_.aggregateFunction.cloneBufferAttributes) + val partialAggregate = if (usesTungstenAggregate) { + TungstenAggregate( + requiredChildDistributionExpressions = None: Option[Seq[Expression]], + groupingExpressions = partialAggregateGroupingExpressions, + nonCompleteAggregateExpressions = partialAggregateExpressions, + completeAggregateExpressions = Nil, + initialInputBufferOffset = 0, + resultExpressions = partialAggregateResult, + child = child) + } else { + SortBasedAggregate( requiredChildDistributionExpressions = None: Option[Seq[Expression]], groupingExpressions = partialAggregateGroupingExpressions, nonCompleteAggregateExpressions = partialAggregateExpressions, @@ -151,20 +228,27 @@ object Utils { initialInputBufferOffset = 0, resultExpressions = partialAggregateResult, child = child) + } // 2. Create an Aggregate Operator for partial merge aggregations. 
- val partialMergeAggregateExpressions = functionsWithoutDistinct.map { - case AggregateExpression2(aggregateFunction, mode, _) => - AggregateExpression2(aggregateFunction, PartialMerge, false) - } + val partialMergeAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge)) val partialMergeAggregateAttributes = - partialMergeAggregateExpressions.flatMap { agg => - agg.aggregateFunction.bufferAttributes - } + partialMergeAggregateExpressions.flatMap(_.aggregateFunction.bufferAttributes) val partialMergeAggregateResult = - namedGroupingAttributes ++ distinctColumnAttributes ++ partialMergeAggregateAttributes - val partialMergeAggregate = - Aggregate( + namedGroupingAttributes ++ + distinctColumnAttributes ++ + partialMergeAggregateExpressions.flatMap(_.aggregateFunction.cloneBufferAttributes) + val partialMergeAggregate = if (usesTungstenAggregate) { + TungstenAggregate( + requiredChildDistributionExpressions = Some(namedGroupingAttributes), + groupingExpressions = namedGroupingAttributes ++ distinctColumnAttributes, + nonCompleteAggregateExpressions = partialMergeAggregateExpressions, + completeAggregateExpressions = Nil, + initialInputBufferOffset = (namedGroupingAttributes ++ distinctColumnAttributes).length, + resultExpressions = partialMergeAggregateResult, + child = partialAggregate) + } else { + SortBasedAggregate( requiredChildDistributionExpressions = Some(namedGroupingAttributes), groupingExpressions = namedGroupingAttributes ++ distinctColumnAttributes, nonCompleteAggregateExpressions = partialMergeAggregateExpressions, @@ -174,48 +258,91 @@ object Utils { initialInputBufferOffset = (namedGroupingAttributes ++ distinctColumnAttributes).length, resultExpressions = partialMergeAggregateResult, child = partialAggregate) + } // 3. Create an Aggregate Operator for partial merge aggregations. - val finalAggregateExpressions = functionsWithoutDistinct.map { - case AggregateExpression2(aggregateFunction, mode, _) => - AggregateExpression2(aggregateFunction, Final, false) - } + val finalAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Final)) val finalAggregateAttributes = finalAggregateExpressions.map { - expr => aggregateFunctionMap(expr.aggregateFunction, expr.isDistinct) + expr => aggregateFunctionMap(expr.aggregateFunction, expr.isDistinct)._2 } + // Create a map to store those rewritten aggregate functions. We always need to use + // both function and its corresponding isDistinct flag as the key because function itself + // does not knows if it is has distinct keyword or now. + val rewrittenAggregateFunctions = + mutable.Map.empty[(AggregateFunction2, Boolean), AggregateFunction2] val (completeAggregateExpressions, completeAggregateAttributes) = functionsWithDistinct.map { // Children of an AggregateFunction with DISTINCT keyword has already // been evaluated. At here, we need to replace original children // to AttributeReferences. - case agg @ AggregateExpression2(aggregateFunction, mode, isDistinct) => + case agg @ AggregateExpression2(aggregateFunction, mode, true) => val rewrittenAggregateFunction = aggregateFunction.transformDown { case expr if distinctColumnExpressionMap.contains(expr) => distinctColumnExpressionMap(expr).toAttribute }.asInstanceOf[AggregateFunction2] + // Because we have rewritten the aggregate function, we use rewrittenAggregateFunctions + // to track the old version and the new version of this function. 
+ rewrittenAggregateFunctions += (aggregateFunction, true) -> rewrittenAggregateFunction // We rewrite the aggregate function to a non-distinct aggregation because // its input will have distinct arguments. + // We just keep the isDistinct setting to true, so when users look at the query plan, + // they still can see distinct aggregations. val rewrittenAggregateExpression = - AggregateExpression2(rewrittenAggregateFunction, Complete, false) + AggregateExpression2(rewrittenAggregateFunction, Complete, true) - val aggregateFunctionAttribute = aggregateFunctionMap(agg.aggregateFunction, isDistinct) + val aggregateFunctionAttribute = + aggregateFunctionMap(agg.aggregateFunction, true)._2 (rewrittenAggregateExpression -> aggregateFunctionAttribute) }.unzip - val rewrittenResultExpressions = resultExpressions.map { expr => - expr.transform { - case agg: AggregateExpression2 => - aggregateFunctionMap(agg.aggregateFunction, agg.isDistinct).toAttribute - case expression => - // We do not rely on the equality check at here since attributes may - // different cosmetically. Instead, we use semanticEquals. - groupExpressionMap.collectFirst { - case (expr, ne) if expr semanticEquals expression => ne.toAttribute - }.getOrElse(expression) - }.asInstanceOf[NamedExpression] - } - val finalAndCompleteAggregate = - Aggregate( + val finalAndCompleteAggregate = if (usesTungstenAggregate) { + val rewrittenResultExpressions = resultExpressions.map { expr => + expr.transform { + case agg: AggregateExpression2 => + val function = agg.aggregateFunction + val isDistinct = agg.isDistinct + val aggregateFunction = + if (rewrittenAggregateFunctions.contains(function, isDistinct)) { + // If this function has been rewritten, we get the rewritten version from + // rewrittenAggregateFunctions. + rewrittenAggregateFunctions(function, isDistinct) + } else { + // Oterwise, we get it from aggregateFunctionMap, which contains unique + // aggregate functions that have not been rewritten. + aggregateFunctionMap(function, isDistinct)._1 + } + aggregateFunction.asInstanceOf[AlgebraicAggregate].evaluateExpression + case expression => + // We do not rely on the equality check at here since attributes may + // different cosmetically. Instead, we use semanticEquals. + groupExpressionMap.collectFirst { + case (expr, ne) if expr semanticEquals expression => ne.toAttribute + }.getOrElse(expression) + }.asInstanceOf[NamedExpression] + } + + TungstenAggregate( + requiredChildDistributionExpressions = Some(namedGroupingAttributes), + groupingExpressions = namedGroupingAttributes, + nonCompleteAggregateExpressions = finalAggregateExpressions, + completeAggregateExpressions = completeAggregateExpressions, + initialInputBufferOffset = (namedGroupingAttributes ++ distinctColumnAttributes).length, + resultExpressions = rewrittenResultExpressions, + child = partialMergeAggregate) + } else { + val rewrittenResultExpressions = resultExpressions.map { expr => + expr.transform { + case agg: AggregateExpression2 => + aggregateFunctionMap(agg.aggregateFunction, agg.isDistinct)._2 + case expression => + // We do not rely on the equality check at here since attributes may + // different cosmetically. Instead, we use semanticEquals. 
+ groupExpressionMap.collectFirst { + case (expr, ne) if expr semanticEquals expression => ne.toAttribute + }.getOrElse(expression) + }.asInstanceOf[NamedExpression] + } + SortBasedAggregate( requiredChildDistributionExpressions = Some(namedGroupingAttributes), groupingExpressions = namedGroupingAttributes, nonCompleteAggregateExpressions = finalAggregateExpressions, @@ -225,6 +352,7 @@ object Utils { initialInputBufferOffset = (namedGroupingAttributes ++ distinctColumnAttributes).length, resultExpressions = rewrittenResultExpressions, child = partialMergeAggregate) + } finalAndCompleteAggregate :: Nil } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index cef40dd324d9..c64aa7a07dc2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -262,7 +262,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { val df = sql(sqlText) // First, check if we have GeneratedAggregate. val hasGeneratedAgg = df.queryExecution.executedPlan - .collect { case _: aggregate.Aggregate => true } + .collect { case _: aggregate.TungstenAggregate => true } .nonEmpty if (!hasGeneratedAgg) { fail( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala index 4b35c8fd8353..7b5aa4763fd9 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala @@ -21,9 +21,9 @@ import org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} -import org.apache.spark.sql.{SQLConf, AnalysisException, QueryTest, Row} +import org.apache.spark.sql._ import org.scalatest.BeforeAndAfterAll -import test.org.apache.spark.sql.hive.aggregate.{MyDoubleAvg, MyDoubleSum} +import _root_.test.org.apache.spark.sql.hive.aggregate.{MyDoubleAvg, MyDoubleSum} abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with BeforeAndAfterAll { @@ -141,6 +141,22 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Be Nil) } + test("null literal") { + checkAnswer( + sqlContext.sql( + """ + |SELECT + | AVG(null), + | COUNT(null), + | FIRST(null), + | LAST(null), + | MAX(null), + | MIN(null), + | SUM(null) + """.stripMargin), + Row(null, 0, null, null, null, null, null) :: Nil) + } + test("only do grouping") { checkAnswer( sqlContext.sql( @@ -266,13 +282,6 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Be |SELECT avg(value) FROM agg1 """.stripMargin), Row(11.125) :: Nil) - - checkAnswer( - sqlContext.sql( - """ - |SELECT avg(null) - """.stripMargin), - Row(null) :: Nil) } test("udaf") { @@ -364,7 +373,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Be | max(distinct value1) |FROM agg2 """.stripMargin), - Row(-60, 70.0, 101.0/9.0, 5.6, 100.0)) + Row(-60, 70.0, 101.0/9.0, 5.6, 100)) checkAnswer( sqlContext.sql( @@ -402,6 +411,23 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Be Row(2, 100.0, 3.0, 0.0, 100.0, 1.0/3.0 + 100.0) :: Row(3, null, 3.0, null, null, null) :: Row(null, 
110.0, 60.0, 30.0, 110.0, 110.0) :: Nil) + + checkAnswer( + sqlContext.sql( + """ + |SELECT + | count(value1), + | count(*), + | count(1), + | count(DISTINCT value1), + | key + |FROM agg2 + |GROUP BY key + """.stripMargin), + Row(3, 3, 3, 2, 1) :: + Row(3, 4, 4, 2, 2) :: + Row(0, 2, 2, 0, 3) :: + Row(3, 4, 4, 3, null) :: Nil) } test("test count") { @@ -496,7 +522,8 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Be |FROM agg1 |GROUP BY key """.stripMargin).queryExecution.executedPlan.collect { - case agg: aggregate.Aggregate => agg + case agg: aggregate.SortBasedAggregate => agg + case agg: aggregate.TungstenAggregate => agg } val message = "We should fallback to the old aggregation code path if " + @@ -537,3 +564,58 @@ class TungstenAggregationQuerySuite extends AggregationQuerySuite { sqlContext.setConf(SQLConf.UNSAFE_ENABLED.key, originalUnsafeEnabled.toString) } } + +class TungstenAggregationQueryWithControlledFallbackSuite extends AggregationQuerySuite { + + var originalUnsafeEnabled: Boolean = _ + + override def beforeAll(): Unit = { + originalUnsafeEnabled = sqlContext.conf.unsafeEnabled + sqlContext.setConf(SQLConf.UNSAFE_ENABLED.key, "true") + super.beforeAll() + } + + override def afterAll(): Unit = { + super.afterAll() + sqlContext.setConf(SQLConf.UNSAFE_ENABLED.key, originalUnsafeEnabled.toString) + sqlContext.conf.unsetConf("spark.sql.TungstenAggregate.testFallbackStartsAt") + } + + override protected def checkAnswer(actual: DataFrame, expectedAnswer: Seq[Row]): Unit = { + (0 to 2).foreach { fallbackStartsAt => + sqlContext.setConf( + "spark.sql.TungstenAggregate.testFallbackStartsAt", + fallbackStartsAt.toString) + + // Create a new df to make sure its physical operator picks up + // spark.sql.TungstenAggregate.testFallbackStartsAt. + val newActual = DataFrame(sqlContext, actual.logicalPlan) + + QueryTest.checkAnswer(newActual, expectedAnswer) match { + case Some(errorMessage) => + val newErrorMessage = + s""" + |The following aggregation query failed when using TungstenAggregate with + |controlled fallback (it falls back to sort-based aggregation once it has processed + |$fallbackStartsAt input rows). The query is + |${actual.queryExecution} + | + |$errorMessage + """.stripMargin + + fail(newErrorMessage) + case None => + } + } + } + + // Override it to make sure we call the actually overridden checkAnswer. + override protected def checkAnswer(df: DataFrame, expectedAnswer: Row): Unit = { + checkAnswer(df, Seq(expectedAnswer)) + } + + // Override it to make sure we call the actually overridden checkAnswer. + override protected def checkAnswer(df: DataFrame, expectedAnswer: DataFrame): Unit = { + checkAnswer(df, expectedAnswer.collect()) + } +} From e234ea1b49d30bb6c8b8c001bd98c43de290dcff Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Thu, 6 Aug 2015 15:30:27 -0700 Subject: [PATCH 0099/1824] [SPARK-9645] [YARN] [CORE] Allow shuffle service to read shuffle files. Spark should not mess with the permissions of directories created by the cluster manager. Here, by setting the block manager dir permissions to 700, the shuffle service (running as the YARN user) wouldn't be able to serve shuffle files created by applications. Also, the code to protect the local app dir was missing in standalone's Worker; that has been now added. Since all processes run as the same user in standalone, `chmod 700` should not cause problems. 
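To make the permission reasoning concrete, here is a minimal sketch — not part of the patch, using only plain java.nio.file APIs on a POSIX filesystem, with a hypothetical object name — of what mode-700 protection on a local directory means: only the owning OS user can list, read, or traverse it. That is exactly why an external shuffle service running as a different user (the YARN account) could not serve shuffle files stored under a chmod-700 block manager directory, and why the patch instead applies the protection to the per-application directory in the standalone Worker, where everything runs as the same user.

    import java.nio.file.Files
    import java.nio.file.attribute.PosixFilePermissions

    object Chmod700Sketch {  // hypothetical name, for illustration only
      def main(args: Array[String]): Unit = {
        val appDir = Files.createTempDirectory("executor-sketch")
        // rwx------ : the owner may list, read and traverse the directory; any other
        // user (for example a shuffle service running under the YARN account) may not.
        Files.setPosixFilePermissions(appDir, PosixFilePermissions.fromString("rwx------"))
        println(Files.getPosixFilePermissions(appDir)) // prints the owner-only permission set
      }
    }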
Author: Marcelo Vanzin Closes #7966 from vanzin/SPARK-9645 and squashes the following commits: 6e07b31 [Marcelo Vanzin] Protect the app dir in standalone mode. 384ba6a [Marcelo Vanzin] [SPARK-9645] [yarn] [core] Allow shuffle service to read shuffle files. --- .../main/scala/org/apache/spark/deploy/worker/Worker.scala | 4 +++- .../scala/org/apache/spark/storage/DiskBlockManager.scala | 1 - 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index 6792d3310b06..79b1536d9401 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -428,7 +428,9 @@ private[deploy] class Worker( // application finishes. val appLocalDirs = appDirectories.get(appId).getOrElse { Utils.getOrCreateLocalRootDirs(conf).map { dir => - Utils.createDirectory(dir, namePrefix = "executor").getAbsolutePath() + val appDir = Utils.createDirectory(dir, namePrefix = "executor") + Utils.chmod700(appDir) + appDir.getAbsolutePath() }.toSeq } appDirectories(appId) = appLocalDirs diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala index 5f537692a16c..56a33d5ca7d6 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala @@ -133,7 +133,6 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon Utils.getConfiguredLocalDirs(conf).flatMap { rootDir => try { val localDir = Utils.createDirectory(rootDir, "blockmgr") - Utils.chmod700(localDir) logInfo(s"Created local directory at $localDir") Some(localDir) } catch { From 681e3024b6c2fcb54b42180d94d3ba3eed52a2d4 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Thu, 6 Aug 2015 23:43:52 +0100 Subject: [PATCH 0100/1824] [SPARK-9633] [BUILD] SBT download locations outdated; need an update Remove 2 defunct SBT download URLs and replace with the 1 known download URL. Also, use https. Follow up on https://github.com/apache/spark/pull/7792 Author: Sean Owen Closes #7956 from srowen/SPARK-9633 and squashes the following commits: caa40bd [Sean Owen] Remove 2 defunct SBT download URLs and replace with the 1 known download URL. Also, use https. 
--- build/sbt-launch-lib.bash | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/build/sbt-launch-lib.bash b/build/sbt-launch-lib.bash index 7930a38b9674..615f84839465 100755 --- a/build/sbt-launch-lib.bash +++ b/build/sbt-launch-lib.bash @@ -38,8 +38,7 @@ dlog () { acquire_sbt_jar () { SBT_VERSION=`awk -F "=" '/sbt\.version/ {print $2}' ./project/build.properties` - URL1=http://typesafe.artifactoryonline.com/typesafe/ivy-releases/org.scala-sbt/sbt-launch/${SBT_VERSION}/sbt-launch.jar - URL2=http://repo.typesafe.com/typesafe/ivy-releases/org.scala-sbt/sbt-launch/${SBT_VERSION}/sbt-launch.jar + URL1=https://dl.bintray.com/typesafe/ivy-releases/org.scala-sbt/sbt-launch/${SBT_VERSION}/sbt-launch.jar JAR=build/sbt-launch-${SBT_VERSION}.jar sbt_jar=$JAR @@ -51,12 +50,10 @@ acquire_sbt_jar () { printf "Attempting to fetch sbt\n" JAR_DL="${JAR}.part" if [ $(command -v curl) ]; then - (curl --fail --location --silent ${URL1} > "${JAR_DL}" ||\ - (rm -f "${JAR_DL}" && curl --fail --location --silent ${URL2} > "${JAR_DL}")) &&\ + curl --fail --location --silent ${URL1} > "${JAR_DL}" &&\ mv "${JAR_DL}" "${JAR}" elif [ $(command -v wget) ]; then - (wget --quiet ${URL1} -O "${JAR_DL}" ||\ - (rm -f "${JAR_DL}" && wget --quiet ${URL2} -O "${JAR_DL}")) &&\ + wget --quiet ${URL1} -O "${JAR_DL}" &&\ mv "${JAR_DL}" "${JAR}" else printf "You do not have curl or wget installed, please install sbt manually from http://www.scala-sbt.org/\n" From baf4587a569b49e39020c04c2785041bdd00789b Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Thu, 6 Aug 2015 17:03:14 -0700 Subject: [PATCH 0101/1824] [SPARK-9691] [SQL] PySpark SQL rand function treats seed 0 as no seed https://issues.apache.org/jira/browse/SPARK-9691 jkbradley rxin Author: Yin Huai Closes #7999 from yhuai/pythonRand and squashes the following commits: 4187e0c [Yin Huai] Regression test. a985ef9 [Yin Huai] Use "if seed is not None" instead "if seed" because "if seed" returns false when seed is 0. --- python/pyspark/sql/functions.py | 4 ++-- python/pyspark/sql/tests.py | 10 ++++++++++ 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index b5c6a01f1885..95f46044d324 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -268,7 +268,7 @@ def rand(seed=None): """Generates a random column with i.i.d. samples from U[0.0, 1.0]. """ sc = SparkContext._active_spark_context - if seed: + if seed is not None: jc = sc._jvm.functions.rand(seed) else: jc = sc._jvm.functions.rand() @@ -280,7 +280,7 @@ def randn(seed=None): """Generates a column with i.i.d. samples from the standard normal distribution. """ sc = SparkContext._active_spark_context - if seed: + if seed is not None: jc = sc._jvm.functions.randn(seed) else: jc = sc._jvm.functions.randn() diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index ebd3ea8db6a4..1e3444dd9e3b 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -629,6 +629,16 @@ def test_rand_functions(self): for row in rndn: assert row[1] >= -4.0 and row[1] <= 4.0, "got: %s" % row[1] + # If the specified seed is 0, we should use it. 
+ # https://issues.apache.org/jira/browse/SPARK-9691 + rnd1 = df.select('key', functions.rand(0)).collect() + rnd2 = df.select('key', functions.rand(0)).collect() + self.assertEqual(sorted(rnd1), sorted(rnd2)) + + rndn1 = df.select('key', functions.randn(0)).collect() + rndn2 = df.select('key', functions.randn(0)).collect() + self.assertEqual(sorted(rndn1), sorted(rndn2)) + def test_between_function(self): df = self.sc.parallelize([ Row(a=1, b=2, c=3), From 4e70e8256ce2f45b438642372329eac7b1e9e8cf Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 6 Aug 2015 17:30:31 -0700 Subject: [PATCH 0102/1824] [SPARK-9228] [SQL] use tungsten.enabled in public for both of codegen/unsafe spark.sql.tungsten.enabled will be the default value for both codegen and unsafe, they are kept internally for debug/testing. cc marmbrus rxin Author: Davies Liu Closes #7998 from davies/tungsten and squashes the following commits: c1c16da [Davies Liu] update doc 1a47be1 [Davies Liu] use tungsten.enabled for both of codegen/unsafe --- docs/sql-programming-guide.md | 6 +++--- .../scala/org/apache/spark/sql/SQLConf.scala | 20 ++++++++++++------- .../spark/sql/execution/SparkPlan.scala | 8 +++++++- .../spark/sql/execution/joins/HashJoin.scala | 3 ++- .../sql/execution/joins/HashOuterJoin.scala | 2 +- .../sql/execution/joins/HashSemiJoin.scala | 3 ++- 6 files changed, 28 insertions(+), 14 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 3ea77e82422f..6c317175d327 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1884,11 +1884,11 @@ that these options will be deprecated in future release as more optimizations ar - spark.sql.codegen + spark.sql.tungsten.enabled true - When true, code will be dynamically generated at runtime for expression evaluation in a specific - query. For some queries with complicated expression this option can lead to significant speed-ups. + When true, use the optimized Tungsten physical execution backend which explicitly manages memory + and dynamically generates bytecode for expression evaluation. 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index f836122b3e0e..ef35c133d9cc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -223,14 +223,21 @@ private[spark] object SQLConf { defaultValue = Some(200), doc = "The default number of partitions to use when shuffling data for joins or aggregations.") - val CODEGEN_ENABLED = booleanConf("spark.sql.codegen", + val TUNGSTEN_ENABLED = booleanConf("spark.sql.tungsten.enabled", defaultValue = Some(true), + doc = "When true, use the optimized Tungsten physical execution backend which explicitly " + + "manages memory and dynamically generates bytecode for expression evaluation.") + + val CODEGEN_ENABLED = booleanConf("spark.sql.codegen", + defaultValue = Some(true), // use TUNGSTEN_ENABLED as default doc = "When true, code will be dynamically generated at runtime for expression evaluation in" + - " a specific query.") + " a specific query.", + isPublic = false) val UNSAFE_ENABLED = booleanConf("spark.sql.unsafe.enabled", - defaultValue = Some(true), - doc = "When true, use the new optimized Tungsten physical execution backend.") + defaultValue = Some(true), // use TUNGSTEN_ENABLED as default + doc = "When true, use the new optimized Tungsten physical execution backend.", + isPublic = false) val DIALECT = stringConf( "spark.sql.dialect", @@ -427,7 +434,6 @@ private[spark] object SQLConf { * * SQLConf is thread-safe (internally synchronized, so safe to be used in multiple threads). */ - private[sql] class SQLConf extends Serializable with CatalystConf { import SQLConf._ @@ -474,11 +480,11 @@ private[sql] class SQLConf extends Serializable with CatalystConf { private[spark] def sortMergeJoinEnabled: Boolean = getConf(SORTMERGE_JOIN) - private[spark] def codegenEnabled: Boolean = getConf(CODEGEN_ENABLED) + private[spark] def codegenEnabled: Boolean = getConf(CODEGEN_ENABLED, getConf(TUNGSTEN_ENABLED)) def caseSensitiveAnalysis: Boolean = getConf(SQLConf.CASE_SENSITIVE) - private[spark] def unsafeEnabled: Boolean = getConf(UNSAFE_ENABLED) + private[spark] def unsafeEnabled: Boolean = getConf(UNSAFE_ENABLED, getConf(TUNGSTEN_ENABLED)) private[spark] def useSqlAggregate2: Boolean = getConf(USE_SQL_AGGREGATE2) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 2f29067f5646..3fff79cd1b28 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -55,12 +55,18 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ protected def sparkContext = sqlContext.sparkContext // sqlContext will be null when we are being deserialized on the slaves. In this instance - // the value of codegenEnabled will be set by the desserializer after the constructor has run. + // the value of codegenEnabled/unsafeEnabled will be set by the desserializer after the + // constructor has run. val codegenEnabled: Boolean = if (sqlContext != null) { sqlContext.conf.codegenEnabled } else { false } + val unsafeEnabled: Boolean = if (sqlContext != null) { + sqlContext.conf.unsafeEnabled + } else { + false + } /** * Whether the "prepare" method is called. 
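The SQLConf change above relies on the second argument of getConf acting as a fallback default, so the internal flags inherit their value from the public Tungsten flag unless they are set explicitly. A stripped-down sketch of that pattern (hypothetical class and method names, not Spark code):

    class ConfSketch {
      private val settings = scala.collection.mutable.Map.empty[String, String]

      def set(key: String, value: String): Unit = settings(key) = value

      // Look the key up; if it was never set explicitly, fall back to the supplied default.
      private def get(key: String, default: => Boolean): Boolean =
        settings.get(key).map(_.toBoolean).getOrElse(default)

      def tungstenEnabled: Boolean = get("spark.sql.tungsten.enabled", default = true)
      // The internal flags inherit from the public flag unless overridden.
      def codegenEnabled: Boolean = get("spark.sql.codegen", tungstenEnabled)
      def unsafeEnabled: Boolean = get("spark.sql.unsafe.enabled", tungstenEnabled)
    }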
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala index 5e9cd9fd2345..22d46d1c3e3b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala @@ -44,7 +44,8 @@ trait HashJoin { override def output: Seq[Attribute] = left.output ++ right.output protected[this] def isUnsafeMode: Boolean = { - (self.codegenEnabled && UnsafeProjection.canSupport(buildKeys) + (self.codegenEnabled && self.unsafeEnabled + && UnsafeProjection.canSupport(buildKeys) && UnsafeProjection.canSupport(self.schema)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala index 346337e64245..701bd3cd8637 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala @@ -67,7 +67,7 @@ trait HashOuterJoin { } protected[this] def isUnsafeMode: Boolean = { - (self.codegenEnabled && joinType != FullOuter + (self.codegenEnabled && self.unsafeEnabled && joinType != FullOuter && UnsafeProjection.canSupport(buildKeys) && UnsafeProjection.canSupport(self.schema)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala index 47a7d370f541..82dd6eb7e7ed 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala @@ -33,7 +33,8 @@ trait HashSemiJoin { override def output: Seq[Attribute] = left.output protected[this] def supportUnsafe: Boolean = { - (self.codegenEnabled && UnsafeProjection.canSupport(leftKeys) + (self.codegenEnabled && self.unsafeEnabled + && UnsafeProjection.canSupport(leftKeys) && UnsafeProjection.canSupport(rightKeys) && UnsafeProjection.canSupport(left.schema) && UnsafeProjection.canSupport(right.schema)) From 0867b23c74a3e6347d718b67ddabff17b468eded Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Thu, 6 Aug 2015 17:31:16 -0700 Subject: [PATCH 0103/1824] [SPARK-9650][SQL] Fix quoting behavior on interpolated column names Make sure that `$"column"` is consistent with other methods with respect to backticks. Adds a bunch of tests for various ways of constructing columns. 
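The resulting rule, shared by the $"..." interpolator, col(...), and plain string selection, is that an unquoted dot separates name parts while backticks protect a dot that is literally part of a column name. A short usage sketch (assuming the usual import sqlContext.implicits._ so the $ interpolator is in scope, and a DataFrame df aliased as "a" with a column literally named name.with.dot, as in the tests below):

    import org.apache.spark.sql.functions.col

    df.select($"`name.with.dot`")         // backticks: one column whose name contains dots
    df.select(col("a.`name.with.dot`"))   // qualifier "a", then the backtick-quoted column name
    df.select($"a.`name.with.dot`")       // the same thing written with the interpolator
    // df.select($"name.with.dot")        // without backticks this parses as nested parts name.with.dot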
Author: Michael Armbrust Closes #7969 from marmbrus/namesWithDots and squashes the following commits: 53ef3d7 [Michael Armbrust] [SPARK-9650][SQL] Fix quoting behavior on interpolated column names 2bf7a92 [Michael Armbrust] WIP --- .../sql/catalyst/analysis/unresolved.scala | 57 ++++++++++++++++ .../catalyst/plans/logical/LogicalPlan.scala | 42 +----------- .../scala/org/apache/spark/sql/Column.scala | 2 +- .../org/apache/spark/sql/SQLContext.scala | 2 +- .../spark/sql/ColumnExpressionSuite.scala | 68 +++++++++++++++++++ 5 files changed, 128 insertions(+), 43 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index 03da45b09f92..43ee3191935e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.analysis +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.errors import org.apache.spark.sql.catalyst.expressions._ @@ -69,8 +70,64 @@ case class UnresolvedAttribute(nameParts: Seq[String]) extends Attribute with Un } object UnresolvedAttribute { + /** + * Creates an [[UnresolvedAttribute]], parsing segments separated by dots ('.'). + */ def apply(name: String): UnresolvedAttribute = new UnresolvedAttribute(name.split("\\.")) + + /** + * Creates an [[UnresolvedAttribute]], from a single quoted string (for example using backticks in + * HiveQL. Since the string is consider quoted, no processing is done on the name. + */ def quoted(name: String): UnresolvedAttribute = new UnresolvedAttribute(Seq(name)) + + /** + * Creates an [[UnresolvedAttribute]] from a string in an embedded language. In this case + * we treat it as a quoted identifier, except for '.', which must be further quoted using + * backticks if it is part of a column name. + */ + def quotedString(name: String): UnresolvedAttribute = + new UnresolvedAttribute(parseAttributeName(name)) + + /** + * Used to split attribute name by dot with backticks rule. + * Backticks must appear in pairs, and the quoted string must be a complete name part, + * which means `ab..c`e.f is not allowed. + * Escape character is not supported now, so we can't use backtick inside name part. + */ + def parseAttributeName(name: String): Seq[String] = { + def e = new AnalysisException(s"syntax error in attribute name: $name") + val nameParts = scala.collection.mutable.ArrayBuffer.empty[String] + val tmp = scala.collection.mutable.ArrayBuffer.empty[Char] + var inBacktick = false + var i = 0 + while (i < name.length) { + val char = name(i) + if (inBacktick) { + if (char == '`') { + inBacktick = false + if (i + 1 < name.length && name(i + 1) != '.') throw e + } else { + tmp += char + } + } else { + if (char == '`') { + if (tmp.nonEmpty) throw e + inBacktick = true + } else if (char == '.') { + if (name(i - 1) == '.' 
|| i == name.length - 1) throw e + nameParts += tmp.mkString + tmp.clear() + } else { + tmp += char + } + } + i += 1 + } + if (inBacktick) throw e + nameParts += tmp.mkString + nameParts.toSeq + } } case class UnresolvedFunction( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 9b52f020093f..c290e6acb361 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -179,47 +179,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { def resolveQuoted( name: String, resolver: Resolver): Option[NamedExpression] = { - resolve(parseAttributeName(name), output, resolver) - } - - /** - * Internal method, used to split attribute name by dot with backticks rule. - * Backticks must appear in pairs, and the quoted string must be a complete name part, - * which means `ab..c`e.f is not allowed. - * Escape character is not supported now, so we can't use backtick inside name part. - */ - private def parseAttributeName(name: String): Seq[String] = { - val e = new AnalysisException(s"syntax error in attribute name: $name") - val nameParts = scala.collection.mutable.ArrayBuffer.empty[String] - val tmp = scala.collection.mutable.ArrayBuffer.empty[Char] - var inBacktick = false - var i = 0 - while (i < name.length) { - val char = name(i) - if (inBacktick) { - if (char == '`') { - inBacktick = false - if (i + 1 < name.length && name(i + 1) != '.') throw e - } else { - tmp += char - } - } else { - if (char == '`') { - if (tmp.nonEmpty) throw e - inBacktick = true - } else if (char == '.') { - if (name(i - 1) == '.' || i == name.length - 1) throw e - nameParts += tmp.mkString - tmp.clear() - } else { - tmp += char - } - } - i += 1 - } - if (inBacktick) throw e - nameParts += tmp.mkString - nameParts.toSeq + resolve(UnresolvedAttribute.parseAttributeName(name), output, resolver) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 75365fbcec75..27bd08484734 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -54,7 +54,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { def this(name: String) = this(name match { case "*" => UnresolvedStar(None) case _ if name.endsWith(".*") => UnresolvedStar(Some(name.substring(0, name.length - 2))) - case _ => UnresolvedAttribute(name) + case _ => UnresolvedAttribute.quotedString(name) }) /** Creates a column based on the given expression. 
*/ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 6f8ffb54402a..075c0ea2544b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -343,7 +343,7 @@ class SQLContext(@transient val sparkContext: SparkContext) */ implicit class StringToColumn(val sc: StringContext) { def $(args: Any*): ColumnName = { - new ColumnName(sc.s(args : _*)) + new ColumnName(sc.s(args: _*)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index e1b3443d7499..6a09a3b72c08 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -32,6 +32,74 @@ class ColumnExpressionSuite extends QueryTest with SQLTestUtils { override def sqlContext(): SQLContext = ctx + test("column names with space") { + val df = Seq((1, "a")).toDF("name with space", "name.with.dot") + + checkAnswer( + df.select(df("name with space")), + Row(1) :: Nil) + + checkAnswer( + df.select($"name with space"), + Row(1) :: Nil) + + checkAnswer( + df.select(col("name with space")), + Row(1) :: Nil) + + checkAnswer( + df.select("name with space"), + Row(1) :: Nil) + + checkAnswer( + df.select(expr("`name with space`")), + Row(1) :: Nil) + } + + test("column names with dot") { + val df = Seq((1, "a")).toDF("name with space", "name.with.dot").as("a") + + checkAnswer( + df.select(df("`name.with.dot`")), + Row("a") :: Nil) + + checkAnswer( + df.select($"`name.with.dot`"), + Row("a") :: Nil) + + checkAnswer( + df.select(col("`name.with.dot`")), + Row("a") :: Nil) + + checkAnswer( + df.select("`name.with.dot`"), + Row("a") :: Nil) + + checkAnswer( + df.select(expr("`name.with.dot`")), + Row("a") :: Nil) + + checkAnswer( + df.select(df("a.`name.with.dot`")), + Row("a") :: Nil) + + checkAnswer( + df.select($"a.`name.with.dot`"), + Row("a") :: Nil) + + checkAnswer( + df.select(col("a.`name.with.dot`")), + Row("a") :: Nil) + + checkAnswer( + df.select("a.`name.with.dot`"), + Row("a") :: Nil) + + checkAnswer( + df.select(expr("a.`name.with.dot`")), + Row("a") :: Nil) + } + test("alias") { val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList") assert(df.select(df("a").as("b")).columns.head === "b") From 49b1504fe3733eb36a7fc6317ec19aeba5d46f97 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 6 Aug 2015 17:36:12 -0700 Subject: [PATCH 0104/1824] Revert "[SPARK-9228] [SQL] use tungsten.enabled in public for both of codegen/unsafe" This reverts commit 4e70e8256ce2f45b438642372329eac7b1e9e8cf. 
--- docs/sql-programming-guide.md | 6 +++--- .../scala/org/apache/spark/sql/SQLConf.scala | 20 +++++++------------ .../spark/sql/execution/SparkPlan.scala | 8 +------- .../spark/sql/execution/joins/HashJoin.scala | 3 +-- .../sql/execution/joins/HashOuterJoin.scala | 2 +- .../sql/execution/joins/HashSemiJoin.scala | 3 +-- 6 files changed, 14 insertions(+), 28 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 6c317175d327..3ea77e82422f 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1884,11 +1884,11 @@ that these options will be deprecated in future release as more optimizations ar - spark.sql.tungsten.enabled + spark.sql.codegen true - When true, use the optimized Tungsten physical execution backend which explicitly manages memory - and dynamically generates bytecode for expression evaluation. + When true, code will be dynamically generated at runtime for expression evaluation in a specific + query. For some queries with complicated expression this option can lead to significant speed-ups. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index ef35c133d9cc..f836122b3e0e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -223,21 +223,14 @@ private[spark] object SQLConf { defaultValue = Some(200), doc = "The default number of partitions to use when shuffling data for joins or aggregations.") - val TUNGSTEN_ENABLED = booleanConf("spark.sql.tungsten.enabled", - defaultValue = Some(true), - doc = "When true, use the optimized Tungsten physical execution backend which explicitly " + - "manages memory and dynamically generates bytecode for expression evaluation.") - val CODEGEN_ENABLED = booleanConf("spark.sql.codegen", - defaultValue = Some(true), // use TUNGSTEN_ENABLED as default + defaultValue = Some(true), doc = "When true, code will be dynamically generated at runtime for expression evaluation in" + - " a specific query.", - isPublic = false) + " a specific query.") val UNSAFE_ENABLED = booleanConf("spark.sql.unsafe.enabled", - defaultValue = Some(true), // use TUNGSTEN_ENABLED as default - doc = "When true, use the new optimized Tungsten physical execution backend.", - isPublic = false) + defaultValue = Some(true), + doc = "When true, use the new optimized Tungsten physical execution backend.") val DIALECT = stringConf( "spark.sql.dialect", @@ -434,6 +427,7 @@ private[spark] object SQLConf { * * SQLConf is thread-safe (internally synchronized, so safe to be used in multiple threads). 
*/ + private[sql] class SQLConf extends Serializable with CatalystConf { import SQLConf._ @@ -480,11 +474,11 @@ private[sql] class SQLConf extends Serializable with CatalystConf { private[spark] def sortMergeJoinEnabled: Boolean = getConf(SORTMERGE_JOIN) - private[spark] def codegenEnabled: Boolean = getConf(CODEGEN_ENABLED, getConf(TUNGSTEN_ENABLED)) + private[spark] def codegenEnabled: Boolean = getConf(CODEGEN_ENABLED) def caseSensitiveAnalysis: Boolean = getConf(SQLConf.CASE_SENSITIVE) - private[spark] def unsafeEnabled: Boolean = getConf(UNSAFE_ENABLED, getConf(TUNGSTEN_ENABLED)) + private[spark] def unsafeEnabled: Boolean = getConf(UNSAFE_ENABLED) private[spark] def useSqlAggregate2: Boolean = getConf(USE_SQL_AGGREGATE2) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 3fff79cd1b28..2f29067f5646 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -55,18 +55,12 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ protected def sparkContext = sqlContext.sparkContext // sqlContext will be null when we are being deserialized on the slaves. In this instance - // the value of codegenEnabled/unsafeEnabled will be set by the desserializer after the - // constructor has run. + // the value of codegenEnabled will be set by the desserializer after the constructor has run. val codegenEnabled: Boolean = if (sqlContext != null) { sqlContext.conf.codegenEnabled } else { false } - val unsafeEnabled: Boolean = if (sqlContext != null) { - sqlContext.conf.unsafeEnabled - } else { - false - } /** * Whether the "prepare" method is called. 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala index 22d46d1c3e3b..5e9cd9fd2345 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala @@ -44,8 +44,7 @@ trait HashJoin { override def output: Seq[Attribute] = left.output ++ right.output protected[this] def isUnsafeMode: Boolean = { - (self.codegenEnabled && self.unsafeEnabled - && UnsafeProjection.canSupport(buildKeys) + (self.codegenEnabled && UnsafeProjection.canSupport(buildKeys) && UnsafeProjection.canSupport(self.schema)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala index 701bd3cd8637..346337e64245 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala @@ -67,7 +67,7 @@ trait HashOuterJoin { } protected[this] def isUnsafeMode: Boolean = { - (self.codegenEnabled && self.unsafeEnabled && joinType != FullOuter + (self.codegenEnabled && joinType != FullOuter && UnsafeProjection.canSupport(buildKeys) && UnsafeProjection.canSupport(self.schema)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala index 82dd6eb7e7ed..47a7d370f541 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala @@ -33,8 +33,7 @@ trait HashSemiJoin { override def output: Seq[Attribute] = left.output protected[this] def supportUnsafe: Boolean = { - (self.codegenEnabled && self.unsafeEnabled - && UnsafeProjection.canSupport(leftKeys) + (self.codegenEnabled && UnsafeProjection.canSupport(leftKeys) && UnsafeProjection.canSupport(rightKeys) && UnsafeProjection.canSupport(left.schema) && UnsafeProjection.canSupport(right.schema)) From b87825310ac87485672868bf6a9ed01d154a3626 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 6 Aug 2015 18:25:38 -0700 Subject: [PATCH 0105/1824] [SPARK-9692] Remove SqlNewHadoopRDD's generated Tuple2 and InterruptibleIterator. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit A small performance optimization – we don't need to generate a Tuple2 and then immediately discard the key. We also don't need an extra wrapper from InterruptibleIterator. Author: Reynold Xin Closes #8000 from rxin/SPARK-9692 and squashes the following commits: 1d4d0b3 [Reynold Xin] [SPARK-9692] Remove SqlNewHadoopRDD's generated Tuple2 and InterruptibleIterator. 
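A minimal, self-contained illustration of the per-record cost this change avoids (the data and names below are made up for the example; the real change is in SqlNewHadoopRDD in the diff that follows): mapping each record into a (key, value) pair only to discard the key allocates one Tuple2 per record, whereas the RDD now hands back the value directly.

    // Illustrative only -- plain Scala collections standing in for the record reader.
    val rows = Seq("row1", "row2", "row3")

    // Before: an Iterator[(Void, String)] whose keys were never used downstream;
    // every element costs a Tuple2 allocation, and callers had to strip it with .values.
    val withKeys: Iterator[(Void, String)] = rows.iterator.map(v => (null.asInstanceOf[Void], v))
    val valuesOnly = withKeys.map(_._2)

    // After: the iterator produces the values directly, with no pair to discard
    // and no extra InterruptibleIterator wrapper around it.
    val direct: Iterator[String] = rows.iterator

    assert(valuesOnly.toList == direct.toList)
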
--- .../apache/spark/rdd/SqlNewHadoopRDD.scala | 44 +++++++------------ .../spark/sql/parquet/ParquetRelation.scala | 3 +- 2 files changed, 18 insertions(+), 29 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala index 35e44cb59c1b..6a95e44c57fe 100644 --- a/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala @@ -26,14 +26,12 @@ import org.apache.hadoop.conf.{Configurable, Configuration} import org.apache.hadoop.io.Writable import org.apache.hadoop.mapreduce._ import org.apache.hadoop.mapreduce.lib.input.{CombineFileSplit, FileSplit} -import org.apache.spark.annotation.DeveloperApi import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.executor.DataReadMethod import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.{Partition => SparkPartition, _} -import org.apache.spark.rdd.NewHadoopRDD.NewHadoopMapPartitionsWithSplitRDD import org.apache.spark.storage.StorageLevel import org.apache.spark.util.{SerializableConfiguration, Utils} @@ -60,18 +58,16 @@ private[spark] class SqlNewHadoopPartition( * and the executor side to the shared Hadoop Configuration. * * Note: This is RDD is basically a cloned version of [[org.apache.spark.rdd.NewHadoopRDD]] with - * changes based on [[org.apache.spark.rdd.HadoopRDD]]. In future, this functionality will be - * folded into core. + * changes based on [[org.apache.spark.rdd.HadoopRDD]]. */ -private[spark] class SqlNewHadoopRDD[K, V]( +private[spark] class SqlNewHadoopRDD[V: ClassTag]( @transient sc : SparkContext, broadcastedConf: Broadcast[SerializableConfiguration], @transient initDriverSideJobFuncOpt: Option[Job => Unit], initLocalJobFuncOpt: Option[Job => Unit], - inputFormatClass: Class[_ <: InputFormat[K, V]], - keyClass: Class[K], + inputFormatClass: Class[_ <: InputFormat[Void, V]], valueClass: Class[V]) - extends RDD[(K, V)](sc, Nil) + extends RDD[V](sc, Nil) with SparkHadoopMapReduceUtil with Logging { @@ -120,8 +116,8 @@ private[spark] class SqlNewHadoopRDD[K, V]( override def compute( theSplit: SparkPartition, - context: TaskContext): InterruptibleIterator[(K, V)] = { - val iter = new Iterator[(K, V)] { + context: TaskContext): Iterator[V] = { + val iter = new Iterator[V] { val split = theSplit.asInstanceOf[SqlNewHadoopPartition] logInfo("Input split: " + split.serializableHadoopSplit) val conf = getConf(isDriverSide = false) @@ -154,17 +150,20 @@ private[spark] class SqlNewHadoopRDD[K, V]( configurable.setConf(conf) case _ => } - private var reader = format.createRecordReader( + private[this] var reader = format.createRecordReader( split.serializableHadoopSplit.value, hadoopAttemptContext) reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext) // Register an on-task-completion callback to close the input stream. 
context.addTaskCompletionListener(context => close()) - var havePair = false - var finished = false - var recordsSinceMetricsUpdate = 0 + + private[this] var havePair = false + private[this] var finished = false override def hasNext: Boolean = { + if (context.isInterrupted) { + throw new TaskKilledException + } if (!finished && !havePair) { finished = !reader.nextKeyValue if (finished) { @@ -178,7 +177,7 @@ private[spark] class SqlNewHadoopRDD[K, V]( !finished } - override def next(): (K, V) = { + override def next(): V = { if (!hasNext) { throw new java.util.NoSuchElementException("End of stream") } @@ -186,7 +185,7 @@ private[spark] class SqlNewHadoopRDD[K, V]( if (!finished) { inputMetrics.incRecordsRead(1) } - (reader.getCurrentKey, reader.getCurrentValue) + reader.getCurrentValue } private def close() { @@ -212,23 +211,14 @@ private[spark] class SqlNewHadoopRDD[K, V]( } } } catch { - case e: Exception => { + case e: Exception => if (!Utils.inShutdown()) { logWarning("Exception in RecordReader.close()", e) } - } } } } - new InterruptibleIterator(context, iter) - } - - /** Maps over a partition, providing the InputSplit that was used as the base of the partition. */ - @DeveloperApi - def mapPartitionsWithInputSplit[U: ClassTag]( - f: (InputSplit, Iterator[(K, V)]) => Iterator[U], - preservesPartitioning: Boolean = false): RDD[U] = { - new NewHadoopMapPartitionsWithSplitRDD(this, f, preservesPartitioning) + iter } override def getPreferredLocations(hsplit: SparkPartition): Seq[String] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala index b4337a48dbd8..29c388c22ef9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala @@ -291,7 +291,6 @@ private[sql] class ParquetRelation( initDriverSideJobFuncOpt = Some(setInputPaths), initLocalJobFuncOpt = Some(initLocalJobFuncOpt), inputFormatClass = classOf[ParquetInputFormat[InternalRow]], - keyClass = classOf[Void], valueClass = classOf[InternalRow]) { val cacheMetadata = useMetadataCache @@ -328,7 +327,7 @@ private[sql] class ParquetRelation( new SqlNewHadoopPartition(id, i, rawSplits(i).asInstanceOf[InputSplit with Writable]) } } - }.values.asInstanceOf[RDD[Row]] // type erasure hack to pass RDD[InternalRow] as RDD[Row] + }.asInstanceOf[RDD[Row]] // type erasure hack to pass RDD[InternalRow] as RDD[Row] } } From 014a9f9d8c9521180f7a448cc7cc96cc00537d5c Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Thu, 6 Aug 2015 19:04:57 -0700 Subject: [PATCH 0106/1824] [SPARK-9709] [SQL] Avoid starving unsafe operators that use sort The issue is that a task may run multiple sorts, and the sorts run by the child operator (i.e. parent RDD) may acquire all available memory such that other sorts in the same task do not have enough to proceed. This manifests itself in an `IOException("Unable to acquire X bytes of memory")` thrown by `UnsafeExternalSorter`. The solution is to reserve a page in each sorter in the chain before computing the child operator's (parent RDD's) partitions. This requires us to use a new special RDD that does some preparation before computing the parent's partitions. 
Author: Andrew Or Closes #8011 from andrewor14/unsafe-starve-memory and squashes the following commits: 35b69a4 [Andrew Or] Simplify test 0b07782 [Andrew Or] Minor: update comments 5d5afdf [Andrew Or] Merge branch 'master' of github.com:apache/spark into unsafe-starve-memory 254032e [Andrew Or] Add tests 234acbd [Andrew Or] Reserve a page in sorter when preparing each partition b889e08 [Andrew Or] MapPartitionsWithPreparationRDD --- .../unsafe/sort/UnsafeExternalSorter.java | 43 ++++++++----- .../apache/spark/rdd/MapPartitionsRDD.scala | 3 + .../rdd/MapPartitionsWithPreparationRDD.scala | 49 +++++++++++++++ .../spark/shuffle/ShuffleMemoryManager.scala | 2 +- .../sort/UnsafeExternalSorterSuite.java | 19 +++++- ...MapPartitionsWithPreparationRDDSuite.scala | 60 +++++++++++++++++++ .../spark/sql/execution/SparkPlan.scala | 2 +- .../org/apache/spark/sql/execution/sort.scala | 28 +++++++-- 8 files changed, 184 insertions(+), 22 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithPreparationRDD.scala create mode 100644 core/src/test/scala/org/apache/spark/rdd/MapPartitionsWithPreparationRDDSuite.scala diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java index 8f78fc5a4162..4c54ba4bce40 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -138,6 +138,11 @@ private UnsafeExternalSorter( this.inMemSorter = existingInMemorySorter; } + // Acquire a new page as soon as we construct the sorter to ensure that we have at + // least one page to work with. Otherwise, other operators in the same task may starve + // this sorter (SPARK-9709). + acquireNewPage(); + // Register a cleanup task with TaskContext to ensure that memory is guaranteed to be freed at // the end of the task. This is necessary to avoid memory leaks in when the downstream operator // does not fully consume the sorter's output (e.g. sort followed by limit). @@ -343,22 +348,32 @@ private void acquireNewPageIfNecessary(int requiredSpace) throws IOException { throw new IOException("Required space " + requiredSpace + " is greater than page size (" + pageSizeBytes + ")"); } else { - final long memoryAcquired = shuffleMemoryManager.tryToAcquire(pageSizeBytes); - if (memoryAcquired < pageSizeBytes) { - shuffleMemoryManager.release(memoryAcquired); - spill(); - final long memoryAcquiredAfterSpilling = shuffleMemoryManager.tryToAcquire(pageSizeBytes); - if (memoryAcquiredAfterSpilling != pageSizeBytes) { - shuffleMemoryManager.release(memoryAcquiredAfterSpilling); - throw new IOException("Unable to acquire " + pageSizeBytes + " bytes of memory"); - } - } - currentPage = taskMemoryManager.allocatePage(pageSizeBytes); - currentPagePosition = currentPage.getBaseOffset(); - freeSpaceInCurrentPage = pageSizeBytes; - allocatedPages.add(currentPage); + acquireNewPage(); + } + } + } + + /** + * Acquire a new page from the {@link ShuffleMemoryManager}. + * + * If there is not enough space to allocate the new page, spill all existing ones + * and try again. If there is still not enough space, report error to the caller. 
+ */ + private void acquireNewPage() throws IOException { + final long memoryAcquired = shuffleMemoryManager.tryToAcquire(pageSizeBytes); + if (memoryAcquired < pageSizeBytes) { + shuffleMemoryManager.release(memoryAcquired); + spill(); + final long memoryAcquiredAfterSpilling = shuffleMemoryManager.tryToAcquire(pageSizeBytes); + if (memoryAcquiredAfterSpilling != pageSizeBytes) { + shuffleMemoryManager.release(memoryAcquiredAfterSpilling); + throw new IOException("Unable to acquire " + pageSizeBytes + " bytes of memory"); } } + currentPage = taskMemoryManager.allocatePage(pageSizeBytes); + currentPagePosition = currentPage.getBaseOffset(); + freeSpaceInCurrentPage = pageSizeBytes; + allocatedPages.add(currentPage); } /** diff --git a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala index a838aac6e8d1..4312d3a41775 100644 --- a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala @@ -21,6 +21,9 @@ import scala.reflect.ClassTag import org.apache.spark.{Partition, TaskContext} +/** + * An RDD that applies the provided function to every partition of the parent RDD. + */ private[spark] class MapPartitionsRDD[U: ClassTag, T: ClassTag]( prev: RDD[T], f: (TaskContext, Int, Iterator[T]) => Iterator[U], // (TaskContext, partition index, iterator) diff --git a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithPreparationRDD.scala b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithPreparationRDD.scala new file mode 100644 index 000000000000..b475bd8d79f8 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithPreparationRDD.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rdd + +import scala.reflect.ClassTag + +import org.apache.spark.{Partition, Partitioner, TaskContext} + +/** + * An RDD that applies a user provided function to every partition of the parent RDD, and + * additionally allows the user to prepare each partition before computing the parent partition. + */ +private[spark] class MapPartitionsWithPreparationRDD[U: ClassTag, T: ClassTag, M: ClassTag]( + prev: RDD[T], + preparePartition: () => M, + executePartition: (TaskContext, Int, M, Iterator[T]) => Iterator[U], + preservesPartitioning: Boolean = false) + extends RDD[U](prev) { + + override val partitioner: Option[Partitioner] = { + if (preservesPartitioning) firstParent[T].partitioner else None + } + + override def getPartitions: Array[Partition] = firstParent[T].partitions + + /** + * Prepare a partition before computing it from its parent. 
+ */ + override def compute(partition: Partition, context: TaskContext): Iterator[U] = { + val preparedArgument = preparePartition() + val parentIterator = firstParent[T].iterator(partition, context) + executePartition(context, partition.index, preparedArgument, parentIterator) + } +} diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala index 00c1e078a441..e3d229cc9982 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala @@ -124,7 +124,7 @@ private[spark] class ShuffleMemoryManager(maxMemory: Long) extends Logging { } } -private object ShuffleMemoryManager { +private[spark] object ShuffleMemoryManager { /** * Figure out the shuffle memory limit from a SparkConf. We currently have both a fraction * of the memory pool and a safety factor since collections can sometimes grow bigger than diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java index 117745f9a9c0..f5300373d87e 100644 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java @@ -340,7 +340,8 @@ public void testPeakMemoryUsed() throws Exception { for (int i = 0; i < numRecordsPerPage * 10; i++) { insertNumber(sorter, i); newPeakMemory = sorter.getPeakMemoryUsedBytes(); - if (i % numRecordsPerPage == 0) { + // The first page is pre-allocated on instantiation + if (i % numRecordsPerPage == 0 && i > 0) { // We allocated a new page for this record, so peak memory should change assertEquals(previousPeakMemory + pageSizeBytes, newPeakMemory); } else { @@ -364,5 +365,21 @@ public void testPeakMemoryUsed() throws Exception { } } + @Test + public void testReservePageOnInstantiation() throws Exception { + final UnsafeExternalSorter sorter = newSorter(); + try { + assertEquals(1, sorter.getNumberOfAllocatedPages()); + // Inserting a new record doesn't allocate more memory since we already have a page + long peakMemory = sorter.getPeakMemoryUsedBytes(); + insertNumber(sorter, 100); + assertEquals(peakMemory, sorter.getPeakMemoryUsedBytes()); + assertEquals(1, sorter.getNumberOfAllocatedPages()); + } finally { + sorter.cleanupResources(); + assertSpillFilesWereCleanedUp(); + } + } + } diff --git a/core/src/test/scala/org/apache/spark/rdd/MapPartitionsWithPreparationRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/MapPartitionsWithPreparationRDDSuite.scala new file mode 100644 index 000000000000..c16930e7d649 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/rdd/MapPartitionsWithPreparationRDDSuite.scala @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rdd + +import scala.collection.mutable + +import org.apache.spark.{LocalSparkContext, SparkContext, SparkFunSuite, TaskContext} + +class MapPartitionsWithPreparationRDDSuite extends SparkFunSuite with LocalSparkContext { + + test("prepare called before parent partition is computed") { + sc = new SparkContext("local", "test") + + // Have the parent partition push a number to the list + val parent = sc.parallelize(1 to 100, 1).mapPartitions { iter => + TestObject.things.append(20) + iter + } + + // Push a different number during the prepare phase + val preparePartition = () => { TestObject.things.append(10) } + + // Push yet another number during the execution phase + val executePartition = ( + taskContext: TaskContext, + partitionIndex: Int, + notUsed: Unit, + parentIterator: Iterator[Int]) => { + TestObject.things.append(30) + TestObject.things.iterator + } + + // Verify that the numbers are pushed in the order expected + val result = { + new MapPartitionsWithPreparationRDD[Int, Int, Unit]( + parent, preparePartition, executePartition).collect() + } + assert(result === Array(10, 20, 30)) + } + +} + +private object TestObject { + val things = new mutable.ListBuffer[Int] +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 2f29067f5646..490428965a61 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -158,7 +158,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ */ final def prepare(): Unit = { if (prepareCalled.compareAndSet(false, true)) { - doPrepare + doPrepare() children.foreach(_.prepare()) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala index 3192b6ebe907..7f69cdb08aa7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution import org.apache.spark.{InternalAccumulator, TaskContext} -import org.apache.spark.rdd.RDD +import org.apache.spark.rdd.{MapPartitionsWithPreparationRDD, RDD} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ @@ -123,7 +123,12 @@ case class TungstenSort( val schema = child.schema val childOutput = child.output val pageSize = sparkContext.conf.getSizeAsBytes("spark.buffer.pageSize", "64m") - child.execute().mapPartitions({ iter => + + /** + * Set up the sorter in each partition before computing the parent partition. + * This makes sure our sorter is not starved by other sorters used in the same task. 
+ */ + def preparePartition(): UnsafeExternalRowSorter = { val ordering = newOrdering(sortOrder, childOutput) // The comparator for comparing prefix @@ -143,12 +148,25 @@ case class TungstenSort( if (testSpillFrequency > 0) { sorter.setTestSpillFrequency(testSpillFrequency) } - val sortedIterator = sorter.sort(iter.asInstanceOf[Iterator[UnsafeRow]]) - val taskContext = TaskContext.get() + sorter + } + + /** Compute a partition using the sorter already set up previously. */ + def executePartition( + taskContext: TaskContext, + partitionIndex: Int, + sorter: UnsafeExternalRowSorter, + parentIterator: Iterator[InternalRow]): Iterator[InternalRow] = { + val sortedIterator = sorter.sort(parentIterator.asInstanceOf[Iterator[UnsafeRow]]) taskContext.internalMetricsToAccumulators( InternalAccumulator.PEAK_EXECUTION_MEMORY).add(sorter.getPeakMemoryUsage) sortedIterator - }, preservesPartitioning = true) + } + + // Note: we need to set up the external sorter in each partition before computing + // the parent partition, so we cannot simply use `mapPartitions` here (SPARK-9709). + new MapPartitionsWithPreparationRDD[InternalRow, InternalRow, UnsafeExternalRowSorter]( + child.execute(), preparePartition, executePartition, preservesPartitioning = true) } } From 17284db314f52bdb2065482b8a49656f7683d30a Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 6 Aug 2015 17:30:31 -0700 Subject: [PATCH 0107/1824] [SPARK-9228] [SQL] use tungsten.enabled in public for both of codegen/unsafe spark.sql.tungsten.enabled will be the default value for both codegen and unsafe, they are kept internally for debug/testing. cc marmbrus rxin Author: Davies Liu Closes #7998 from davies/tungsten and squashes the following commits: c1c16da [Davies Liu] update doc 1a47be1 [Davies Liu] use tungsten.enabled for both of codegen/unsafe (cherry picked from commit 4e70e8256ce2f45b438642372329eac7b1e9e8cf) Signed-off-by: Reynold Xin --- docs/sql-programming-guide.md | 6 +++--- .../scala/org/apache/spark/sql/SQLConf.scala | 20 ++++++++++++------- .../spark/sql/execution/SparkPlan.scala | 8 +++++++- .../spark/sql/execution/joins/HashJoin.scala | 3 ++- .../sql/execution/joins/HashOuterJoin.scala | 2 +- .../sql/execution/joins/HashSemiJoin.scala | 3 ++- 6 files changed, 28 insertions(+), 14 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 3ea77e82422f..6c317175d327 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1884,11 +1884,11 @@ that these options will be deprecated in future release as more optimizations ar - spark.sql.codegen + spark.sql.tungsten.enabled true - When true, code will be dynamically generated at runtime for expression evaluation in a specific - query. For some queries with complicated expression this option can lead to significant speed-ups. + When true, use the optimized Tungsten physical execution backend which explicitly manages memory + and dynamically generates bytecode for expression evaluation. 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index f836122b3e0e..ef35c133d9cc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -223,14 +223,21 @@ private[spark] object SQLConf { defaultValue = Some(200), doc = "The default number of partitions to use when shuffling data for joins or aggregations.") - val CODEGEN_ENABLED = booleanConf("spark.sql.codegen", + val TUNGSTEN_ENABLED = booleanConf("spark.sql.tungsten.enabled", defaultValue = Some(true), + doc = "When true, use the optimized Tungsten physical execution backend which explicitly " + + "manages memory and dynamically generates bytecode for expression evaluation.") + + val CODEGEN_ENABLED = booleanConf("spark.sql.codegen", + defaultValue = Some(true), // use TUNGSTEN_ENABLED as default doc = "When true, code will be dynamically generated at runtime for expression evaluation in" + - " a specific query.") + " a specific query.", + isPublic = false) val UNSAFE_ENABLED = booleanConf("spark.sql.unsafe.enabled", - defaultValue = Some(true), - doc = "When true, use the new optimized Tungsten physical execution backend.") + defaultValue = Some(true), // use TUNGSTEN_ENABLED as default + doc = "When true, use the new optimized Tungsten physical execution backend.", + isPublic = false) val DIALECT = stringConf( "spark.sql.dialect", @@ -427,7 +434,6 @@ private[spark] object SQLConf { * * SQLConf is thread-safe (internally synchronized, so safe to be used in multiple threads). */ - private[sql] class SQLConf extends Serializable with CatalystConf { import SQLConf._ @@ -474,11 +480,11 @@ private[sql] class SQLConf extends Serializable with CatalystConf { private[spark] def sortMergeJoinEnabled: Boolean = getConf(SORTMERGE_JOIN) - private[spark] def codegenEnabled: Boolean = getConf(CODEGEN_ENABLED) + private[spark] def codegenEnabled: Boolean = getConf(CODEGEN_ENABLED, getConf(TUNGSTEN_ENABLED)) def caseSensitiveAnalysis: Boolean = getConf(SQLConf.CASE_SENSITIVE) - private[spark] def unsafeEnabled: Boolean = getConf(UNSAFE_ENABLED) + private[spark] def unsafeEnabled: Boolean = getConf(UNSAFE_ENABLED, getConf(TUNGSTEN_ENABLED)) private[spark] def useSqlAggregate2: Boolean = getConf(USE_SQL_AGGREGATE2) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 490428965a61..719ad432e2fe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -55,12 +55,18 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ protected def sparkContext = sqlContext.sparkContext // sqlContext will be null when we are being deserialized on the slaves. In this instance - // the value of codegenEnabled will be set by the desserializer after the constructor has run. + // the value of codegenEnabled/unsafeEnabled will be set by the desserializer after the + // constructor has run. val codegenEnabled: Boolean = if (sqlContext != null) { sqlContext.conf.codegenEnabled } else { false } + val unsafeEnabled: Boolean = if (sqlContext != null) { + sqlContext.conf.unsafeEnabled + } else { + false + } /** * Whether the "prepare" method is called. 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala index 5e9cd9fd2345..22d46d1c3e3b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala @@ -44,7 +44,8 @@ trait HashJoin { override def output: Seq[Attribute] = left.output ++ right.output protected[this] def isUnsafeMode: Boolean = { - (self.codegenEnabled && UnsafeProjection.canSupport(buildKeys) + (self.codegenEnabled && self.unsafeEnabled + && UnsafeProjection.canSupport(buildKeys) && UnsafeProjection.canSupport(self.schema)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala index 346337e64245..701bd3cd8637 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala @@ -67,7 +67,7 @@ trait HashOuterJoin { } protected[this] def isUnsafeMode: Boolean = { - (self.codegenEnabled && joinType != FullOuter + (self.codegenEnabled && self.unsafeEnabled && joinType != FullOuter && UnsafeProjection.canSupport(buildKeys) && UnsafeProjection.canSupport(self.schema)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala index 47a7d370f541..82dd6eb7e7ed 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala @@ -33,7 +33,8 @@ trait HashSemiJoin { override def output: Seq[Attribute] = left.output protected[this] def supportUnsafe: Boolean = { - (self.codegenEnabled && UnsafeProjection.canSupport(leftKeys) + (self.codegenEnabled && self.unsafeEnabled + && UnsafeProjection.canSupport(leftKeys) && UnsafeProjection.canSupport(rightKeys) && UnsafeProjection.canSupport(left.schema) && UnsafeProjection.canSupport(right.schema)) From fe12277b40082585e40e1bdf6aa2ebcfe80ed83f Mon Sep 17 00:00:00 2001 From: Jeff Zhang Date: Thu, 6 Aug 2015 21:03:47 -0700 Subject: [PATCH 0108/1824] Fix doc typo Straightforward fix on doc typo Author: Jeff Zhang Closes #8019 from zjffdu/master and squashes the following commits: aed6e64 [Jeff Zhang] Fix doc typo --- docs/tuning.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/tuning.md b/docs/tuning.md index 572c7270e499..6936912a6be5 100644 --- a/docs/tuning.md +++ b/docs/tuning.md @@ -240,7 +240,7 @@ worth optimizing. ## Data Locality Data locality can have a major impact on the performance of Spark jobs. If data and the code that -operates on it are together than computation tends to be fast. But if code and data are separated, +operates on it are together then computation tends to be fast. But if code and data are separated, one must move to the other. Typically it is faster to ship serialized code from place to place than a chunk of data because code size is much smaller than data. Spark builds its scheduling around this general principle of data locality. 
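The SPARK-9228 patch above (re-applied below as 0107) makes spark.sql.tungsten.enabled the single public switch, with spark.sql.codegen and spark.sql.unsafe.enabled kept internal and defaulting to its value. A minimal usage sketch, assuming a plain local SQLContext -- the app name, master, and variable names here are illustrative and not taken from the patches; only the config keys come from SQLConf:

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.sql.SQLContext

    // Illustrative only: a local context for trying out the flag.
    val sc = new SparkContext(
      new SparkConf().setAppName("tungsten-flag-demo").setMaster("local[*]"))
    val sqlContext = new SQLContext(sc)

    // Turning off the one public flag disables both codegen and the unsafe backend,
    // because the internal flags fall back to TUNGSTEN_ENABLED when not set explicitly.
    sqlContext.setConf("spark.sql.tungsten.enabled", "false")

    // The internal flags remain settable individually for debugging/testing.
    sqlContext.setConf("spark.sql.codegen", "true")
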
From 672f467668da1cf20895ee57652489c306120288 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Thu, 6 Aug 2015 21:42:42 -0700 Subject: [PATCH 0109/1824] [SPARK-8057][Core]Call TaskAttemptContext.getTaskAttemptID using Reflection Someone may use the Spark core jar in the maven repo with hadoop 1. SPARK-2075 has already resolved the compatibility issue to support it. But `SparkHadoopMapRedUtil.commitTask` broke it recently. This PR uses Reflection to call `TaskAttemptContext.getTaskAttemptID` to fix the compatibility issue. Author: zsxwing Closes #6599 from zsxwing/SPARK-8057 and squashes the following commits: f7a343c [zsxwing] Remove the redundant import 6b7f1af [zsxwing] Call TaskAttemptContext.getTaskAttemptID using Reflection --- .../org/apache/spark/deploy/SparkHadoopUtil.scala | 14 ++++++++++++++ .../spark/mapred/SparkHadoopMapRedUtil.scala | 3 ++- 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala index e06b06e06fb4..7e9dba42bebd 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala @@ -34,6 +34,8 @@ import org.apache.hadoop.fs.{FileStatus, FileSystem, Path, PathFilter} import org.apache.hadoop.hdfs.security.token.delegation.DelegationTokenIdentifier import org.apache.hadoop.mapred.JobConf import org.apache.hadoop.mapreduce.JobContext +import org.apache.hadoop.mapreduce.{TaskAttemptContext => MapReduceTaskAttemptContext} +import org.apache.hadoop.mapreduce.{TaskAttemptID => MapReduceTaskAttemptID} import org.apache.hadoop.security.{Credentials, UserGroupInformation} import org.apache.spark.annotation.DeveloperApi @@ -194,6 +196,18 @@ class SparkHadoopUtil extends Logging { method.invoke(context).asInstanceOf[Configuration] } + /** + * Using reflection to call `getTaskAttemptID` from TaskAttemptContext. If we directly + * call `TaskAttemptContext.getTaskAttemptID`, it will generate different byte codes + * for Hadoop 1.+ and Hadoop 2.+ because TaskAttemptContext is class in Hadoop 1.+ + * while it's interface in Hadoop 2.+. + */ + def getTaskAttemptIDFromTaskAttemptContext( + context: MapReduceTaskAttemptContext): MapReduceTaskAttemptID = { + val method = context.getClass.getMethod("getTaskAttemptID") + method.invoke(context).asInstanceOf[MapReduceTaskAttemptID] + } + /** * Get [[FileStatus]] objects for all leaf children (files) under the given base path. 
If the * given path points to a file, return a single-element collection containing [[FileStatus]] of diff --git a/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala b/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala index 87df42748be4..f405b732e472 100644 --- a/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala +++ b/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala @@ -24,6 +24,7 @@ import org.apache.hadoop.mapred._ import org.apache.hadoop.mapreduce.{TaskAttemptContext => MapReduceTaskAttemptContext} import org.apache.hadoop.mapreduce.{OutputCommitter => MapReduceOutputCommitter} +import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.executor.CommitDeniedException import org.apache.spark.{Logging, SparkEnv, TaskContext} import org.apache.spark.util.{Utils => SparkUtils} @@ -93,7 +94,7 @@ object SparkHadoopMapRedUtil extends Logging { splitId: Int, attemptId: Int): Unit = { - val mrTaskAttemptID = mrTaskContext.getTaskAttemptID + val mrTaskAttemptID = SparkHadoopUtil.get.getTaskAttemptIDFromTaskAttemptContext(mrTaskContext) // Called after we have decided to commit def performCommit(): Unit = { From f0cda587fb80bf2f1ba53d35dc9dc87bf72ee338 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Thu, 6 Aug 2015 22:49:01 -0700 Subject: [PATCH 0110/1824] [SPARK-7550] [SQL] [MINOR] Fixes logs when persisting DataFrames Author: Cheng Lian Closes #8021 from liancheng/spark-7550/fix-logs and squashes the following commits: b7bd0ed [Cheng Lian] Fixes logs --- .../org/apache/spark/sql/hive/HiveMetastoreCatalog.scala | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 1523ebe9d549..7198a32df4a0 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -317,19 +317,17 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive case (Some(serde), relation: HadoopFsRelation) if relation.partitionColumns.nonEmpty => logWarning { - val paths = relation.paths.mkString(", ") "Persisting partitioned data source relation into Hive metastore in " + s"Spark SQL specific format, which is NOT compatible with Hive. Input path(s): " + - paths.mkString("\n", "\n", "") + relation.paths.mkString("\n", "\n", "") } newSparkSQLSpecificMetastoreTable() case (Some(serde), relation: HadoopFsRelation) => logWarning { - val paths = relation.paths.mkString(", ") "Persisting data source relation with multiple input paths into Hive metastore in " + s"Spark SQL specific format, which is NOT compatible with Hive. Input paths: " + - paths.mkString("\n", "\n", "") + relation.paths.mkString("\n", "\n", "") } newSparkSQLSpecificMetastoreTable() From 7aaed1b114751a24835204b8c588533d5c5ffaf0 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Thu, 6 Aug 2015 22:52:23 -0700 Subject: [PATCH 0111/1824] [SPARK-8862][SQL]Support multiple SQLContexts in Web UI This is a follow-up PR to solve the UI issue when there are multiple SQLContexts. Each SQLContext has a separate tab and contains queries which are executed by this SQLContext. 
multiple sqlcontexts Author: zsxwing Closes #7962 from zsxwing/multi-sqlcontext-ui and squashes the following commits: cf661e1 [zsxwing] sql -> SQL 39b0c97 [zsxwing] Support multiple SQLContexts in Web UI --- .../org/apache/spark/sql/ui/AllExecutionsPage.scala | 2 +- .../main/scala/org/apache/spark/sql/ui/SQLTab.scala | 12 ++++++++++-- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ui/AllExecutionsPage.scala b/sql/core/src/main/scala/org/apache/spark/sql/ui/AllExecutionsPage.scala index 727fc4b37fa4..cb7ca60b2fe4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/ui/AllExecutionsPage.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/ui/AllExecutionsPage.scala @@ -178,7 +178,7 @@ private[ui] abstract class ExecutionTable( "%s/jobs/job?id=%s".format(UIUtils.prependBaseUri(parent.basePath), jobId) private def executionURL(executionID: Long): String = - "%s/sql/execution?id=%s".format(UIUtils.prependBaseUri(parent.basePath), executionID) + s"${UIUtils.prependBaseUri(parent.basePath)}/${parent.prefix}/execution?id=$executionID" } private[ui] class RunningExecutionTable( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ui/SQLTab.scala b/sql/core/src/main/scala/org/apache/spark/sql/ui/SQLTab.scala index a9e522630397..3bba0afaf14e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/ui/SQLTab.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/ui/SQLTab.scala @@ -17,13 +17,14 @@ package org.apache.spark.sql.ui +import java.util.concurrent.atomic.AtomicInteger + import org.apache.spark.Logging import org.apache.spark.sql.SQLContext import org.apache.spark.ui.{SparkUI, SparkUITab} private[sql] class SQLTab(sqlContext: SQLContext, sparkUI: SparkUI) - extends SparkUITab(sparkUI, "sql") with Logging { - + extends SparkUITab(sparkUI, SQLTab.nextTabName) with Logging { val parent = sparkUI val listener = sqlContext.listener @@ -38,4 +39,11 @@ private[sql] class SQLTab(sqlContext: SQLContext, sparkUI: SparkUI) private[sql] object SQLTab { private val STATIC_RESOURCE_DIR = "org/apache/spark/sql/ui/static" + + private val nextTabId = new AtomicInteger(0) + + private def nextTabName: String = { + val nextId = nextTabId.getAndIncrement() + if (nextId == 0) "SQL" else s"SQL${nextId}" + } } From 4309262ec9146d7158ee9957a128bb152289d557 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 6 Aug 2015 23:18:29 -0700 Subject: [PATCH 0112/1824] [SPARK-9700] Pick default page size more intelligently. Previously, we use 64MB as the default page size, which was way too big for a lot of Spark applications (especially for single node). This patch changes it so that the default page size, if unset by the user, is determined by the number of cores available and the total execution memory available. Author: Reynold Xin Closes #8012 from rxin/pagesize and squashes the following commits: 16f4756 [Reynold Xin] Fixed failing test. 5afd570 [Reynold Xin] private... 0d5fb98 [Reynold Xin] Update default value. 674a6cd [Reynold Xin] Address review feedback. dc00e05 [Reynold Xin] Merge with master. 73ebdb6 [Reynold Xin] [SPARK-9700] Pick default page size more intelligently. 
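Before the diff: the heuristic this patch introduces picks the default page size as the next power of two of (execution memory / cores / safety factor of 8), clamped to the range 1 MB to 64 MB. A standalone restatement of that arithmetic, with names chosen here purely for illustration (the real logic lands in ShuffleMemoryManager.getPageSize below):

    // Illustrative restatement of the default page-size heuristic added in this patch;
    // the constants mirror ShuffleMemoryManager.getPageSize in the diff that follows.
    def nextPowerOf2(num: Long): Long = {
      val highBit = java.lang.Long.highestOneBit(num)
      if (highBit == num) num else highBit << 1
    }

    def defaultPageSize(maxExecutionMemoryBytes: Long, numCores: Int): Long = {
      val minPageSize = 1L * 1024 * 1024      // 1 MB floor
      val maxPageSize = 64L * minPageSize     // 64 MB ceiling
      val safetyFactor = 8
      val size = nextPowerOf2(maxExecutionMemoryBytes / numCores / safetyFactor)
      math.min(maxPageSize, math.max(minPageSize, size))
    }

    // Example: 1 GB of execution memory shared by 8 cores => 16 MB pages.
    // defaultPageSize(1L << 30, 8) == 16L * 1024 * 1024

Setting spark.buffer.pageSize still overrides this computed default, so the old behavior can be restored explicitly if needed.
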
--- R/run-tests.sh | 2 +- .../unsafe/UnsafeShuffleExternalSorter.java | 3 +- .../spark/unsafe/map/BytesToBytesMap.java | 8 +-- .../unsafe/sort/UnsafeExternalSorter.java | 1 - .../scala/org/apache/spark/SparkConf.scala | 7 +++ .../scala/org/apache/spark/SparkContext.scala | 2 +- .../scala/org/apache/spark/SparkEnv.scala | 2 +- .../spark/shuffle/ShuffleMemoryManager.scala | 53 +++++++++++++++++-- .../unsafe/UnsafeShuffleWriterSuite.java | 5 +- .../map/AbstractBytesToBytesMapSuite.java | 6 +-- .../sort/UnsafeExternalSorterSuite.java | 4 +- .../shuffle/ShuffleMemoryManagerSuite.scala | 14 ++--- python/pyspark/java_gateway.py | 1 - .../TungstenAggregationIterator.scala | 2 +- .../sql/execution/joins/HashedRelation.scala | 16 +++--- .../org/apache/spark/sql/execution/sort.scala | 4 +- ...ypes.scala => ParquetTypesConverter.scala} | 0 .../execution/TestShuffleMemoryManager.scala | 2 +- .../apache/spark/sql/hive/test/TestHive.scala | 1 - .../spark/unsafe/array/ByteArrayMethods.java | 6 +++ 20 files changed, 93 insertions(+), 46 deletions(-) rename sql/core/src/main/scala/org/apache/spark/sql/parquet/{ParquetTypes.scala => ParquetTypesConverter.scala} (100%) diff --git a/R/run-tests.sh b/R/run-tests.sh index 18a1e13bdc65..e82ad0ba2cd0 100755 --- a/R/run-tests.sh +++ b/R/run-tests.sh @@ -23,7 +23,7 @@ FAILED=0 LOGFILE=$FWDIR/unit-tests.out rm -f $LOGFILE -SPARK_TESTING=1 $FWDIR/../bin/sparkR --conf spark.buffer.pageSize=4m --driver-java-options "-Dlog4j.configuration=file:$FWDIR/log4j.properties" $FWDIR/pkg/tests/run-all.R 2>&1 | tee -a $LOGFILE +SPARK_TESTING=1 $FWDIR/../bin/sparkR --driver-java-options "-Dlog4j.configuration=file:$FWDIR/log4j.properties" $FWDIR/pkg/tests/run-all.R 2>&1 | tee -a $LOGFILE FAILED=$((PIPESTATUS[0]||$FAILED)) if [[ $FAILED != 0 ]]; then diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java index bf4eaa59ff58..f6e0913a7a0b 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java @@ -115,8 +115,7 @@ public UnsafeShuffleExternalSorter( // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided this.fileBufferSizeBytes = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024; this.pageSizeBytes = (int) Math.min( - PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES, - conf.getSizeAsBytes("spark.buffer.pageSize", "64m")); + PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES, shuffleMemoryManager.pageSizeBytes()); this.maxRecordSizeBytes = pageSizeBytes - 4; this.writeMetrics = writeMetrics; initializeForWriting(); diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java index 5ac3736ac62a..0636ae7c8df1 100644 --- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java +++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java @@ -642,7 +642,7 @@ public boolean putNewKey( private void allocate(int capacity) { assert (capacity >= 0); // The capacity needs to be divisible by 64 so that our bit set can be sized properly - capacity = Math.max((int) Math.min(MAX_CAPACITY, nextPowerOf2(capacity)), 64); + capacity = Math.max((int) Math.min(MAX_CAPACITY, ByteArrayMethods.nextPowerOf2(capacity)), 64); assert (capacity <= MAX_CAPACITY); longArray = new 
LongArray(MemoryBlock.fromLongArray(new long[capacity * 2])); bitset = new BitSet(MemoryBlock.fromLongArray(new long[capacity / 64])); @@ -770,10 +770,4 @@ void growAndRehash() { timeSpentResizingNs += System.nanoTime() - resizeStartTime; } } - - /** Returns the next number greater or equal num that is power of 2. */ - private static long nextPowerOf2(long num) { - final long highBit = Long.highestOneBit(num); - return (highBit == num) ? num : highBit << 1; - } } diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java index 4c54ba4bce40..5ebbf9b068fd 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -127,7 +127,6 @@ private UnsafeExternalSorter( // Use getSizeAsKb (not bytes) to maintain backwards compatibility for units // this.fileBufferSizeBytes = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024; this.fileBufferSizeBytes = 32 * 1024; - // this.pageSizeBytes = conf.getSizeAsBytes("spark.buffer.pageSize", "64m"); this.pageSizeBytes = pageSizeBytes; this.writeMetrics = new ShuffleWriteMetrics(); diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index 08bab4bf2739..8ff154fb5e33 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -249,6 +249,13 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { Utils.byteStringAsBytes(get(key, defaultValue)) } + /** + * Get a size parameter as bytes, falling back to a default if not set. + */ + def getSizeAsBytes(key: String, defaultValue: Long): Long = { + Utils.byteStringAsBytes(get(key, defaultValue + "B")) + } + /** * Get a size parameter as Kibibytes; throws a NoSuchElementException if it's not set. If no * suffix is provided then Kibibytes are assumed. diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 0c0705325b16..566268643690 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -629,7 +629,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * [[org.apache.spark.SparkContext.setLocalProperty]]. */ def getLocalProperty(key: String): String = - Option(localProperties.get).map(_.getProperty(key)).getOrElse(null) + Option(localProperties.get).map(_.getProperty(key)).orNull /** Set a human readable description of the current job. 
*/ def setJobDescription(value: String) { diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index adfece4d6e7c..a796e7285019 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -324,7 +324,7 @@ object SparkEnv extends Logging { val shuffleMgrClass = shortShuffleMgrNames.getOrElse(shuffleMgrName.toLowerCase, shuffleMgrName) val shuffleManager = instantiateClass[ShuffleManager](shuffleMgrClass) - val shuffleMemoryManager = new ShuffleMemoryManager(conf) + val shuffleMemoryManager = ShuffleMemoryManager.create(conf, numUsableCores) val blockTransferService = conf.get("spark.shuffle.blockTransferService", "netty").toLowerCase match { diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala index e3d229cc9982..8c3a72644c38 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala @@ -19,6 +19,9 @@ package org.apache.spark.shuffle import scala.collection.mutable +import com.google.common.annotations.VisibleForTesting + +import org.apache.spark.unsafe.array.ByteArrayMethods import org.apache.spark.{Logging, SparkException, SparkConf, TaskContext} /** @@ -34,11 +37,19 @@ import org.apache.spark.{Logging, SparkException, SparkConf, TaskContext} * set of active tasks and redo the calculations of 1 / 2N and 1 / N in waiting tasks whenever * this set changes. This is all done by synchronizing access on "this" to mutate state and using * wait() and notifyAll() to signal changes. + * + * Use `ShuffleMemoryManager.create()` factory method to create a new instance. + * + * @param maxMemory total amount of memory available for execution, in bytes. + * @param pageSizeBytes number of bytes for each page, by default. */ -private[spark] class ShuffleMemoryManager(maxMemory: Long) extends Logging { - private val taskMemory = new mutable.HashMap[Long, Long]() // taskAttemptId -> memory bytes +private[spark] +class ShuffleMemoryManager protected ( + val maxMemory: Long, + val pageSizeBytes: Long) + extends Logging { - def this(conf: SparkConf) = this(ShuffleMemoryManager.getMaxMemory(conf)) + private val taskMemory = new mutable.HashMap[Long, Long]() // taskAttemptId -> memory bytes private def currentTaskAttemptId(): Long = { // In case this is called on the driver, return an invalid task attempt id. @@ -124,15 +135,49 @@ private[spark] class ShuffleMemoryManager(maxMemory: Long) extends Logging { } } + private[spark] object ShuffleMemoryManager { + + def create(conf: SparkConf, numCores: Int): ShuffleMemoryManager = { + val maxMemory = ShuffleMemoryManager.getMaxMemory(conf) + val pageSize = ShuffleMemoryManager.getPageSize(conf, maxMemory, numCores) + new ShuffleMemoryManager(maxMemory, pageSize) + } + + def create(maxMemory: Long, pageSizeBytes: Long): ShuffleMemoryManager = { + new ShuffleMemoryManager(maxMemory, pageSizeBytes) + } + + @VisibleForTesting + def createForTesting(maxMemory: Long): ShuffleMemoryManager = { + new ShuffleMemoryManager(maxMemory, 4 * 1024 * 1024) + } + /** * Figure out the shuffle memory limit from a SparkConf. We currently have both a fraction * of the memory pool and a safety factor since collections can sometimes grow bigger than * the size we target before we estimate their sizes again. 
*/ - def getMaxMemory(conf: SparkConf): Long = { + private def getMaxMemory(conf: SparkConf): Long = { val memoryFraction = conf.getDouble("spark.shuffle.memoryFraction", 0.2) val safetyFraction = conf.getDouble("spark.shuffle.safetyFraction", 0.8) (Runtime.getRuntime.maxMemory * memoryFraction * safetyFraction).toLong } + + /** + * Sets the page size, in bytes. + * + * If user didn't explicitly set "spark.buffer.pageSize", we figure out the default value + * by looking at the number of cores available to the process, and the total amount of memory, + * and then divide it by a factor of safety. + */ + private def getPageSize(conf: SparkConf, maxMemory: Long, numCores: Int): Long = { + val minPageSize = 1L * 1024 * 1024 // 1MB + val maxPageSize = 64L * minPageSize // 64MB + val cores = if (numCores > 0) numCores else Runtime.getRuntime.availableProcessors() + val safetyFactor = 8 + val size = ByteArrayMethods.nextPowerOf2(maxMemory / cores / safetyFactor) + val default = math.min(maxPageSize, math.max(minPageSize, size)) + conf.getSizeAsBytes("spark.buffer.pageSize", default) + } } diff --git a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java index 98c32bbc298d..c68354ba49a4 100644 --- a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java @@ -115,6 +115,7 @@ public void setUp() throws IOException { taskMetrics = new TaskMetrics(); when(shuffleMemoryManager.tryToAcquire(anyLong())).then(returnsFirstArg()); + when(shuffleMemoryManager.pageSizeBytes()).thenReturn(128L * 1024 * 1024); when(blockManager.diskBlockManager()).thenReturn(diskBlockManager); when(blockManager.getDiskWriter( @@ -549,14 +550,14 @@ public void testPeakMemoryUsed() throws Exception { final long recordLengthBytes = 8; final long pageSizeBytes = 256; final long numRecordsPerPage = pageSizeBytes / recordLengthBytes; - final SparkConf conf = new SparkConf().set("spark.buffer.pageSize", pageSizeBytes + "b"); + when(shuffleMemoryManager.pageSizeBytes()).thenReturn(pageSizeBytes); final UnsafeShuffleWriter writer = new UnsafeShuffleWriter( blockManager, shuffleBlockResolver, taskMemoryManager, shuffleMemoryManager, - new UnsafeShuffleHandle(0, 1, shuffleDep), + new UnsafeShuffleHandle<>(0, 1, shuffleDep), 0, // map id taskContext, conf); diff --git a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java index 3c5003380162..0b11562980b8 100644 --- a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java +++ b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java @@ -48,7 +48,7 @@ public abstract class AbstractBytesToBytesMapSuite { @Before public void setup() { - shuffleMemoryManager = new ShuffleMemoryManager(Long.MAX_VALUE); + shuffleMemoryManager = ShuffleMemoryManager.create(Long.MAX_VALUE, PAGE_SIZE_BYTES); taskMemoryManager = new TaskMemoryManager(new ExecutorMemoryManager(getMemoryAllocator())); // Mocked memory manager for tests that check the maximum array size, since actually allocating // such large arrays will cause us to run out of memory in our tests. 
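For reference, the default page size chosen by the new getPageSize above can be reproduced with a few lines of arithmetic: a 1 MB floor, a 64 MB ceiling, and the next power of two of maxMemory / cores / safetyFactor. The following standalone Scala sketch mirrors that heuristic only; the object name and the sample figures are illustrative and not part of the patch.

object PageSizeHeuristicSketch {
  // Mirrors ByteArrayMethods.nextPowerOf2: smallest power of two >= num.
  private def nextPowerOf2(num: Long): Long = {
    val highBit = java.lang.Long.highestOneBit(num)
    if (highBit == num) num else highBit << 1
  }

  // Same clamping logic as getPageSize above, without the SparkConf lookup.
  def defaultPageSize(maxMemory: Long, numCores: Int): Long = {
    val minPageSize = 1L * 1024 * 1024   // 1 MB floor
    val maxPageSize = 64L * minPageSize  // 64 MB ceiling
    val cores = if (numCores > 0) numCores else Runtime.getRuntime.availableProcessors()
    val safetyFactor = 8
    val size = nextPowerOf2(maxMemory / cores / safetyFactor)
    math.min(maxPageSize, math.max(minPageSize, size))
  }

  def main(args: Array[String]): Unit = {
    println(defaultPageSize(4L * 1024 * 1024 * 1024, 8)) // 4 GB / 8 cores -> 64 MB pages
    println(defaultPageSize(512L * 1024 * 1024, 8))      // 512 MB / 8 cores -> 8 MB pages
  }
}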
@@ -441,7 +441,7 @@ public void randomizedTestWithRecordsLargerThanPageSize() { @Test public void failureToAllocateFirstPage() { - shuffleMemoryManager = new ShuffleMemoryManager(1024); + shuffleMemoryManager = ShuffleMemoryManager.createForTesting(1024); BytesToBytesMap map = new BytesToBytesMap(taskMemoryManager, shuffleMemoryManager, 1, PAGE_SIZE_BYTES); try { @@ -461,7 +461,7 @@ public void failureToAllocateFirstPage() { @Test public void failureToGrow() { - shuffleMemoryManager = new ShuffleMemoryManager(1024 * 10); + shuffleMemoryManager = ShuffleMemoryManager.createForTesting(1024 * 10); BytesToBytesMap map = new BytesToBytesMap(taskMemoryManager, shuffleMemoryManager, 1, 1024); try { boolean success = true; diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java index f5300373d87e..83049b8a21fc 100644 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java @@ -102,7 +102,7 @@ public void setUp() { MockitoAnnotations.initMocks(this); sparkConf = new SparkConf(); tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "unsafe-test"); - shuffleMemoryManager = new ShuffleMemoryManager(Long.MAX_VALUE); + shuffleMemoryManager = ShuffleMemoryManager.create(Long.MAX_VALUE, pageSizeBytes); spillFilesCreated.clear(); taskContext = mock(TaskContext.class); when(taskContext.taskMetrics()).thenReturn(new TaskMetrics()); @@ -237,7 +237,7 @@ public void testSortingEmptyArrays() throws Exception { @Test public void spillingOccursInResponseToMemoryPressure() throws Exception { - shuffleMemoryManager = new ShuffleMemoryManager(pageSizeBytes * 2); + shuffleMemoryManager = ShuffleMemoryManager.create(pageSizeBytes * 2, pageSizeBytes); final UnsafeExternalSorter sorter = newSorter(); final int numRecords = (int) pageSizeBytes / 4; for (int i = 0; i <= numRecords; i++) { diff --git a/core/src/test/scala/org/apache/spark/shuffle/ShuffleMemoryManagerSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/ShuffleMemoryManagerSuite.scala index f495b6a03795..6d45b1a101be 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/ShuffleMemoryManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/ShuffleMemoryManagerSuite.scala @@ -24,7 +24,7 @@ import org.mockito.Mockito._ import org.scalatest.concurrent.Timeouts import org.scalatest.time.SpanSugar._ -import org.apache.spark.{SparkFunSuite, TaskContext} +import org.apache.spark.{SparkConf, SparkFunSuite, TaskContext} class ShuffleMemoryManagerSuite extends SparkFunSuite with Timeouts { @@ -50,7 +50,7 @@ class ShuffleMemoryManagerSuite extends SparkFunSuite with Timeouts { } test("single task requesting memory") { - val manager = new ShuffleMemoryManager(1000L) + val manager = ShuffleMemoryManager.createForTesting(maxMemory = 1000L) assert(manager.tryToAcquire(100L) === 100L) assert(manager.tryToAcquire(400L) === 400L) @@ -72,7 +72,7 @@ class ShuffleMemoryManagerSuite extends SparkFunSuite with Timeouts { // Two threads request 500 bytes first, wait for each other to get it, and then request // 500 more; we should immediately return 0 as both are now at 1 / N - val manager = new ShuffleMemoryManager(1000L) + val manager = ShuffleMemoryManager.createForTesting(maxMemory = 1000L) class State { var t1Result1 = -1L @@ -124,7 +124,7 @@ class 
ShuffleMemoryManagerSuite extends SparkFunSuite with Timeouts { // Two tasks request 250 bytes first, wait for each other to get it, and then request // 500 more; we should only grant 250 bytes to each of them on this second request - val manager = new ShuffleMemoryManager(1000L) + val manager = ShuffleMemoryManager.createForTesting(maxMemory = 1000L) class State { var t1Result1 = -1L @@ -176,7 +176,7 @@ class ShuffleMemoryManagerSuite extends SparkFunSuite with Timeouts { // for a bit and releases 250 bytes, which should then be granted to t2. Further requests // by t2 will return false right away because it now has 1 / 2N of the memory. - val manager = new ShuffleMemoryManager(1000L) + val manager = ShuffleMemoryManager.createForTesting(maxMemory = 1000L) class State { var t1Requested = false @@ -241,7 +241,7 @@ class ShuffleMemoryManagerSuite extends SparkFunSuite with Timeouts { // t1 grabs 1000 bytes and then waits until t2 is ready to make a request. It sleeps // for a bit and releases all its memory. t2 should now be able to grab all the memory. - val manager = new ShuffleMemoryManager(1000L) + val manager = ShuffleMemoryManager.createForTesting(maxMemory = 1000L) class State { var t1Requested = false @@ -307,7 +307,7 @@ class ShuffleMemoryManagerSuite extends SparkFunSuite with Timeouts { } test("tasks should not be granted a negative size") { - val manager = new ShuffleMemoryManager(1000L) + val manager = ShuffleMemoryManager.createForTesting(maxMemory = 1000L) manager.tryToAcquire(700L) val latch = new CountDownLatch(1) diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py index 60be85e53e2a..cd4c55f79f18 100644 --- a/python/pyspark/java_gateway.py +++ b/python/pyspark/java_gateway.py @@ -54,7 +54,6 @@ def launch_gateway(): if os.environ.get("SPARK_TESTING"): submit_args = ' '.join([ "--conf spark.ui.enabled=false", - "--conf spark.buffer.pageSize=4mb", submit_args ]) command = [os.path.join(SPARK_HOME, script)] + shlex.split(submit_args) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala index b9d44aace100..4d5e98a3e90c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala @@ -342,7 +342,7 @@ class TungstenAggregationIterator( TaskContext.get.taskMemoryManager(), SparkEnv.get.shuffleMemoryManager, 1024 * 16, // initial capacity - SparkEnv.get.conf.getSizeAsBytes("spark.buffer.pageSize", "64m"), + SparkEnv.get.shuffleMemoryManager.pageSizeBytes, false // disable tracking of performance metrics ) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index 3f257ecdd156..953abf409f22 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -282,17 +282,15 @@ private[joins] final class UnsafeHashedRelation( // This is used in Broadcast, shared by multiple tasks, so we use on-heap memory val taskMemoryManager = new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)) + val pageSizeBytes = Option(SparkEnv.get).map(_.shuffleMemoryManager.pageSizeBytes) + .getOrElse(new 
SparkConf().getSizeAsBytes("spark.buffer.pageSize", "16m")) + // Dummy shuffle memory manager which always grants all memory allocation requests. // We use this because it doesn't make sense count shared broadcast variables' memory usage // towards individual tasks' quotas. In the future, we should devise a better way of handling // this. - val shuffleMemoryManager = new ShuffleMemoryManager(new SparkConf()) { - override def tryToAcquire(numBytes: Long): Long = numBytes - override def release(numBytes: Long): Unit = {} - } - - val pageSizeBytes = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf()) - .getSizeAsBytes("spark.buffer.pageSize", "64m") + val shuffleMemoryManager = + ShuffleMemoryManager.create(maxMemory = Long.MaxValue, pageSizeBytes = pageSizeBytes) binaryMap = new BytesToBytesMap( taskMemoryManager, @@ -306,11 +304,11 @@ private[joins] final class UnsafeHashedRelation( while (i < nKeys) { val keySize = in.readInt() val valuesSize = in.readInt() - if (keySize > keyBuffer.size) { + if (keySize > keyBuffer.length) { keyBuffer = new Array[Byte](keySize) } in.readFully(keyBuffer, 0, keySize) - if (valuesSize > valuesBuffer.size) { + if (valuesSize > valuesBuffer.length) { valuesBuffer = new Array[Byte](valuesSize) } in.readFully(valuesBuffer, 0, valuesSize) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala index 7f69cdb08aa7..e31693047012 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution -import org.apache.spark.{InternalAccumulator, TaskContext} +import org.apache.spark.{SparkEnv, InternalAccumulator, TaskContext} import org.apache.spark.rdd.{MapPartitionsWithPreparationRDD, RDD} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors._ @@ -122,7 +122,7 @@ case class TungstenSort( protected override def doExecute(): RDD[InternalRow] = { val schema = child.schema val childOutput = child.output - val pageSize = sparkContext.conf.getSizeAsBytes("spark.buffer.pageSize", "64m") + val pageSize = SparkEnv.get.shuffleMemoryManager.pageSizeBytes /** * Set up the sorter in each partition before computing the parent partition. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypesConverter.scala similarity index 100% rename from sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala rename to sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypesConverter.scala diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/TestShuffleMemoryManager.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/TestShuffleMemoryManager.scala index 53de2d0f0771..48c3938ff87b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/TestShuffleMemoryManager.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/TestShuffleMemoryManager.scala @@ -22,7 +22,7 @@ import org.apache.spark.shuffle.ShuffleMemoryManager /** * A [[ShuffleMemoryManager]] that can be controlled to run out of memory. 
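The "16m" fallback above is parsed with the usual Spark size-string convention (binary units, so "m" means mebibytes). As a rough illustration of that convention only, the simplified Scala parser below decodes such strings; it is a stand-in, not Spark's Utils.byteStringAsBytes.

object SizeStringSketch {
  // Simplified, illustrative parser for size strings such as "16m" or "32k".
  def sizeAsBytes(size: String): Long = {
    val s = size.trim.toLowerCase
    val (digits, suffix) = s.span(_.isDigit)
    val value = digits.toLong
    suffix match {
      case "" | "b"   => value
      case "k" | "kb" => value * 1024L
      case "m" | "mb" => value * 1024L * 1024L
      case "g" | "gb" => value * 1024L * 1024L * 1024L
      case other      => throw new IllegalArgumentException(s"Unknown size suffix: $other")
    }
  }

  def main(args: Array[String]): Unit = {
    println(sizeAsBytes("16m")) // 16777216
    println(sizeAsBytes("32k")) // 32768
  }
}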
*/ -class TestShuffleMemoryManager extends ShuffleMemoryManager(Long.MaxValue) { +class TestShuffleMemoryManager extends ShuffleMemoryManager(Long.MaxValue, 4 * 1024 * 1024) { private var oom = false override def tryToAcquire(numBytes: Long): Long = { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index 167086db5bfe..296cc5c5e0b0 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -52,7 +52,6 @@ object TestHive .set("spark.sql.test", "") .set("spark.sql.hive.metastore.barrierPrefixes", "org.apache.spark.sql.hive.execution.PairSerDe") - .set("spark.buffer.pageSize", "4m") // SPARK-8910 .set("spark.ui.enabled", "false"))) diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java b/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java index cf693d01a4f5..70b81ce015dd 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java @@ -25,6 +25,12 @@ private ByteArrayMethods() { // Private constructor, since this class only contains static methods. } + /** Returns the next number greater or equal num that is power of 2. */ + public static long nextPowerOf2(long num) { + final long highBit = Long.highestOneBit(num); + return (highBit == num) ? num : highBit << 1; + } + public static int roundNumberOfBytesToNearestWord(int numBytes) { int remainder = numBytes & 0x07; // This is equivalent to `numBytes % 8` if (remainder == 0) { From 15bd6f338dff4bcab4a1a3a2c568655022e49c32 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 6 Aug 2015 23:40:38 -0700 Subject: [PATCH 0113/1824] [SPARK-9453] [SQL] support records larger than page size in UnsafeShuffleExternalSorter This patch follows exactly #7891 (except testing) Author: Davies Liu Closes #8005 from davies/larger_record and squashes the following commits: f9c4aff [Davies Liu] address comments 9de5c72 [Davies Liu] support records larger than page size in UnsafeShuffleExternalSorter --- .../unsafe/UnsafeShuffleExternalSorter.java | 143 +++++++++++------- .../shuffle/unsafe/UnsafeShuffleWriter.java | 10 +- .../unsafe/UnsafeShuffleWriterSuite.java | 60 ++------ 3 files changed, 103 insertions(+), 110 deletions(-) diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java index f6e0913a7a0b..925b60a14588 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java @@ -17,10 +17,10 @@ package org.apache.spark.shuffle.unsafe; +import javax.annotation.Nullable; import java.io.File; import java.io.IOException; import java.util.LinkedList; -import javax.annotation.Nullable; import scala.Tuple2; @@ -34,8 +34,11 @@ import org.apache.spark.serializer.DummySerializerInstance; import org.apache.spark.serializer.SerializerInstance; import org.apache.spark.shuffle.ShuffleMemoryManager; -import org.apache.spark.storage.*; +import org.apache.spark.storage.BlockManager; +import org.apache.spark.storage.DiskBlockObjectWriter; +import org.apache.spark.storage.TempShuffleBlockId; import org.apache.spark.unsafe.PlatformDependent; +import 
org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.unsafe.memory.TaskMemoryManager; import org.apache.spark.util.Utils; @@ -68,7 +71,7 @@ final class UnsafeShuffleExternalSorter { private final int pageSizeBytes; @VisibleForTesting final int maxRecordSizeBytes; - private final TaskMemoryManager memoryManager; + private final TaskMemoryManager taskMemoryManager; private final ShuffleMemoryManager shuffleMemoryManager; private final BlockManager blockManager; private final TaskContext taskContext; @@ -91,7 +94,7 @@ final class UnsafeShuffleExternalSorter { private long peakMemoryUsedBytes; // These variables are reset after spilling: - @Nullable private UnsafeShuffleInMemorySorter sorter; + @Nullable private UnsafeShuffleInMemorySorter inMemSorter; @Nullable private MemoryBlock currentPage = null; private long currentPagePosition = -1; private long freeSpaceInCurrentPage = 0; @@ -105,7 +108,7 @@ public UnsafeShuffleExternalSorter( int numPartitions, SparkConf conf, ShuffleWriteMetrics writeMetrics) throws IOException { - this.memoryManager = memoryManager; + this.taskMemoryManager = memoryManager; this.shuffleMemoryManager = shuffleMemoryManager; this.blockManager = blockManager; this.taskContext = taskContext; @@ -133,7 +136,7 @@ private void initializeForWriting() throws IOException { throw new IOException("Could not acquire " + memoryRequested + " bytes of memory"); } - this.sorter = new UnsafeShuffleInMemorySorter(initialSize); + this.inMemSorter = new UnsafeShuffleInMemorySorter(initialSize); } /** @@ -160,7 +163,7 @@ private void writeSortedFile(boolean isLastFile) throws IOException { // This call performs the actual sort. final UnsafeShuffleInMemorySorter.UnsafeShuffleSorterIterator sortedRecords = - sorter.getSortedIterator(); + inMemSorter.getSortedIterator(); // Currently, we need to open a new DiskBlockObjectWriter for each partition; we can avoid this // after SPARK-5581 is fixed. @@ -206,8 +209,8 @@ private void writeSortedFile(boolean isLastFile) throws IOException { } final long recordPointer = sortedRecords.packedRecordPointer.getRecordPointer(); - final Object recordPage = memoryManager.getPage(recordPointer); - final long recordOffsetInPage = memoryManager.getOffsetInPage(recordPointer); + final Object recordPage = taskMemoryManager.getPage(recordPointer); + final long recordOffsetInPage = taskMemoryManager.getOffsetInPage(recordPointer); int dataRemaining = PlatformDependent.UNSAFE.getInt(recordPage, recordOffsetInPage); long recordReadPosition = recordOffsetInPage + 4; // skip over record length while (dataRemaining > 0) { @@ -269,9 +272,9 @@ void spill() throws IOException { spills.size() > 1 ? " times" : " time"); writeSortedFile(false); - final long sorterMemoryUsage = sorter.getMemoryUsage(); - sorter = null; - shuffleMemoryManager.release(sorterMemoryUsage); + final long inMemSorterMemoryUsage = inMemSorter.getMemoryUsage(); + inMemSorter = null; + shuffleMemoryManager.release(inMemSorterMemoryUsage); final long spillSize = freeMemory(); taskContext.taskMetrics().incMemoryBytesSpilled(spillSize); @@ -283,7 +286,7 @@ private long getMemoryUsage() { for (MemoryBlock page : allocatedPages) { totalPageSize += page.size(); } - return ((sorter == null) ? 0 : sorter.getMemoryUsage()) + totalPageSize; + return ((inMemSorter == null) ? 
0 : inMemSorter.getMemoryUsage()) + totalPageSize; } private void updatePeakMemoryUsed() { @@ -305,7 +308,7 @@ private long freeMemory() { updatePeakMemoryUsed(); long memoryFreed = 0; for (MemoryBlock block : allocatedPages) { - memoryManager.freePage(block); + taskMemoryManager.freePage(block); shuffleMemoryManager.release(block.size()); memoryFreed += block.size(); } @@ -319,54 +322,53 @@ private long freeMemory() { /** * Force all memory and spill files to be deleted; called by shuffle error-handling code. */ - public void cleanupAfterError() { + public void cleanupResources() { freeMemory(); for (SpillInfo spill : spills) { if (spill.file.exists() && !spill.file.delete()) { logger.error("Unable to delete spill file {}", spill.file.getPath()); } } - if (sorter != null) { - shuffleMemoryManager.release(sorter.getMemoryUsage()); - sorter = null; + if (inMemSorter != null) { + shuffleMemoryManager.release(inMemSorter.getMemoryUsage()); + inMemSorter = null; } } /** - * Checks whether there is enough space to insert a new record into the sorter. - * - * @param requiredSpace the required space in the data page, in bytes, including space for storing - * the record size. - - * @return true if the record can be inserted without requiring more allocations, false otherwise. - */ - private boolean haveSpaceForRecord(int requiredSpace) { - assert (requiredSpace > 0); - return (sorter.hasSpaceForAnotherRecord() && (requiredSpace <= freeSpaceInCurrentPage)); - } - - /** - * Allocates more memory in order to insert an additional record. This will request additional - * memory from the {@link ShuffleMemoryManager} and spill if the requested memory can not be - * obtained. - * - * @param requiredSpace the required space in the data page, in bytes, including space for storing - * the record size. + * Checks whether there is enough space to insert an additional record in to the sort pointer + * array and grows the array if additional space is required. If the required space cannot be + * obtained, then the in-memory data will be spilled to disk. */ - private void allocateSpaceForRecord(int requiredSpace) throws IOException { - if (!sorter.hasSpaceForAnotherRecord()) { + private void growPointerArrayIfNecessary() throws IOException { + assert(inMemSorter != null); + if (!inMemSorter.hasSpaceForAnotherRecord()) { logger.debug("Attempting to expand sort pointer array"); - final long oldPointerArrayMemoryUsage = sorter.getMemoryUsage(); + final long oldPointerArrayMemoryUsage = inMemSorter.getMemoryUsage(); final long memoryToGrowPointerArray = oldPointerArrayMemoryUsage * 2; final long memoryAcquired = shuffleMemoryManager.tryToAcquire(memoryToGrowPointerArray); if (memoryAcquired < memoryToGrowPointerArray) { shuffleMemoryManager.release(memoryAcquired); spill(); } else { - sorter.expandPointerArray(); + inMemSorter.expandPointerArray(); shuffleMemoryManager.release(oldPointerArrayMemoryUsage); } } + } + + /** + * Allocates more memory in order to insert an additional record. This will request additional + * memory from the {@link ShuffleMemoryManager} and spill if the requested memory can not be + * obtained. + * + * @param requiredSpace the required space in the data page, in bytes, including space for storing + * the record size. This must be less than or equal to the page size (records + * that exceed the page size are handled via a different code path which uses + * special overflow pages). 
+ */ + private void acquireNewPageIfNecessary(int requiredSpace) throws IOException { + growPointerArrayIfNecessary(); if (requiredSpace > freeSpaceInCurrentPage) { logger.trace("Required space {} is less than free space in current page ({})", requiredSpace, freeSpaceInCurrentPage); @@ -387,7 +389,7 @@ private void allocateSpaceForRecord(int requiredSpace) throws IOException { throw new IOException("Unable to acquire " + pageSizeBytes + " bytes of memory"); } } - currentPage = memoryManager.allocatePage(pageSizeBytes); + currentPage = taskMemoryManager.allocatePage(pageSizeBytes); currentPagePosition = currentPage.getBaseOffset(); freeSpaceInCurrentPage = pageSizeBytes; allocatedPages.add(currentPage); @@ -403,27 +405,58 @@ public void insertRecord( long recordBaseOffset, int lengthInBytes, int partitionId) throws IOException { + + growPointerArrayIfNecessary(); // Need 4 bytes to store the record length. final int totalSpaceRequired = lengthInBytes + 4; - if (!haveSpaceForRecord(totalSpaceRequired)) { - allocateSpaceForRecord(totalSpaceRequired); + + // --- Figure out where to insert the new record ---------------------------------------------- + + final MemoryBlock dataPage; + long dataPagePosition; + boolean useOverflowPage = totalSpaceRequired > pageSizeBytes; + if (useOverflowPage) { + long overflowPageSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(totalSpaceRequired); + // The record is larger than the page size, so allocate a special overflow page just to hold + // that record. + final long memoryGranted = shuffleMemoryManager.tryToAcquire(overflowPageSize); + if (memoryGranted != overflowPageSize) { + shuffleMemoryManager.release(memoryGranted); + spill(); + final long memoryGrantedAfterSpill = shuffleMemoryManager.tryToAcquire(overflowPageSize); + if (memoryGrantedAfterSpill != overflowPageSize) { + shuffleMemoryManager.release(memoryGrantedAfterSpill); + throw new IOException("Unable to acquire " + overflowPageSize + " bytes of memory"); + } + } + MemoryBlock overflowPage = taskMemoryManager.allocatePage(overflowPageSize); + allocatedPages.add(overflowPage); + dataPage = overflowPage; + dataPagePosition = overflowPage.getBaseOffset(); + } else { + // The record is small enough to fit in a regular data page, but the current page might not + // have enough space to hold it (or no pages have been allocated yet). 
+ acquireNewPageIfNecessary(totalSpaceRequired); + dataPage = currentPage; + dataPagePosition = currentPagePosition; + // Update bookkeeping information + freeSpaceInCurrentPage -= totalSpaceRequired; + currentPagePosition += totalSpaceRequired; } + final Object dataPageBaseObject = dataPage.getBaseObject(); final long recordAddress = - memoryManager.encodePageNumberAndOffset(currentPage, currentPagePosition); - final Object dataPageBaseObject = currentPage.getBaseObject(); - PlatformDependent.UNSAFE.putInt(dataPageBaseObject, currentPagePosition, lengthInBytes); - currentPagePosition += 4; - freeSpaceInCurrentPage -= 4; + taskMemoryManager.encodePageNumberAndOffset(dataPage, dataPagePosition); + PlatformDependent.UNSAFE.putInt(dataPageBaseObject, dataPagePosition, lengthInBytes); + dataPagePosition += 4; PlatformDependent.copyMemory( recordBaseObject, recordBaseOffset, dataPageBaseObject, - currentPagePosition, + dataPagePosition, lengthInBytes); - currentPagePosition += lengthInBytes; - freeSpaceInCurrentPage -= lengthInBytes; - sorter.insertRecord(recordAddress, partitionId); + assert(inMemSorter != null); + inMemSorter.insertRecord(recordAddress, partitionId); } /** @@ -435,14 +468,14 @@ public void insertRecord( */ public SpillInfo[] closeAndGetSpills() throws IOException { try { - if (sorter != null) { + if (inMemSorter != null) { // Do not count the final file towards the spill count. writeSortedFile(true); freeMemory(); } return spills.toArray(new SpillInfo[spills.size()]); } catch (IOException e) { - cleanupAfterError(); + cleanupResources(); throw e; } } diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java index 6e2eeb37c86f..02084f9122e0 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java @@ -17,17 +17,17 @@ package org.apache.spark.shuffle.unsafe; +import javax.annotation.Nullable; import java.io.*; import java.nio.channels.FileChannel; import java.util.Iterator; -import javax.annotation.Nullable; import scala.Option; import scala.Product2; import scala.collection.JavaConversions; +import scala.collection.immutable.Map; import scala.reflect.ClassTag; import scala.reflect.ClassTag$; -import scala.collection.immutable.Map; import com.google.common.annotations.VisibleForTesting; import com.google.common.io.ByteStreams; @@ -38,10 +38,10 @@ import org.apache.spark.*; import org.apache.spark.annotation.Private; +import org.apache.spark.executor.ShuffleWriteMetrics; import org.apache.spark.io.CompressionCodec; import org.apache.spark.io.CompressionCodec$; import org.apache.spark.io.LZFCompressionCodec; -import org.apache.spark.executor.ShuffleWriteMetrics; import org.apache.spark.network.util.LimitedInputStream; import org.apache.spark.scheduler.MapStatus; import org.apache.spark.scheduler.MapStatus$; @@ -178,7 +178,7 @@ public void write(scala.collection.Iterator> records) throws IOEx } finally { if (sorter != null) { try { - sorter.cleanupAfterError(); + sorter.cleanupResources(); } catch (Exception e) { // Only throw this error if we won't be masking another // error. 
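The core of the change above is the placement decision in insertRecord: a record whose total footprint (payload plus the 4-byte length prefix) exceeds the page size gets its own word-aligned overflow page instead of failing. A minimal standalone Scala sketch of that decision follows; the names are illustrative, and memory acquisition and spilling are omitted.

object OverflowPageSketch {
  final case class Placement(useOverflowPage: Boolean, pageSize: Long)

  // Mirrors ByteArrayMethods.roundNumberOfBytesToNearestWord for long sizes.
  def roundToWord(numBytes: Long): Long = {
    val remainder = numBytes & 0x07
    if (remainder == 0) numBytes else numBytes + (8 - remainder)
  }

  def placeRecord(recordLengthInBytes: Long, pageSizeBytes: Long): Placement = {
    val totalSpaceRequired = recordLengthInBytes + 4 // 4 bytes for the stored record length
    if (totalSpaceRequired > pageSizeBytes) {
      // Too big for a regular page: allocate a dedicated, word-aligned overflow page.
      Placement(useOverflowPage = true, pageSize = roundToWord(totalSpaceRequired))
    } else {
      // Fits in a regular data page of the configured size.
      Placement(useOverflowPage = false, pageSize = pageSizeBytes)
    }
  }

  def main(args: Array[String]): Unit = {
    println(placeRecord(recordLengthInBytes = 100, pageSizeBytes = 1024))  // regular page
    println(placeRecord(recordLengthInBytes = 5000, pageSizeBytes = 1024)) // overflow page of 5008 bytes
  }
}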
@@ -482,7 +482,7 @@ public Option stop(boolean success) { if (sorter != null) { // If sorter is non-null, then this implies that we called stop() in response to an error, // so we need to clean up memory and spill files created by the sorter - sorter.cleanupAfterError(); + sorter.cleanupResources(); } } } diff --git a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java index c68354ba49a4..94650be536b5 100644 --- a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java @@ -475,62 +475,22 @@ public void writeRecordsThatAreBiggerThanDiskWriteBufferSize() throws Exception @Test public void writeRecordsThatAreBiggerThanMaxRecordSize() throws Exception { - // Use a custom serializer so that we have exact control over the size of serialized data. - final Serializer byteArraySerializer = new Serializer() { - @Override - public SerializerInstance newInstance() { - return new SerializerInstance() { - @Override - public SerializationStream serializeStream(final OutputStream s) { - return new SerializationStream() { - @Override - public void flush() { } - - @Override - public SerializationStream writeObject(T t, ClassTag ev1) { - byte[] bytes = (byte[]) t; - try { - s.write(bytes); - } catch (IOException e) { - throw new RuntimeException(e); - } - return this; - } - - @Override - public void close() { } - }; - } - public ByteBuffer serialize(T t, ClassTag ev1) { return null; } - public DeserializationStream deserializeStream(InputStream s) { return null; } - public T deserialize(ByteBuffer b, ClassLoader l, ClassTag ev1) { return null; } - public T deserialize(ByteBuffer bytes, ClassTag ev1) { return null; } - }; - } - }; - when(shuffleDep.serializer()).thenReturn(Option.apply(byteArraySerializer)); final UnsafeShuffleWriter writer = createWriter(false); - // Insert a record and force a spill so that there's something to clean up: - writer.insertRecordIntoSorter(new Tuple2(new byte[1], new byte[1])); - writer.forceSorterToSpill(); + final ArrayList> dataToWrite = new ArrayList>(); + dataToWrite.add(new Tuple2(1, ByteBuffer.wrap(new byte[1]))); // We should be able to write a record that's right _at_ the max record size final byte[] atMaxRecordSize = new byte[writer.maxRecordSizeBytes()]; new Random(42).nextBytes(atMaxRecordSize); - writer.insertRecordIntoSorter(new Tuple2(new byte[0], atMaxRecordSize)); - writer.forceSorterToSpill(); - // Inserting a record that's larger than the max record size should fail: + dataToWrite.add(new Tuple2(2, ByteBuffer.wrap(atMaxRecordSize))); + // Inserting a record that's larger than the max record size final byte[] exceedsMaxRecordSize = new byte[writer.maxRecordSizeBytes() + 1]; new Random(42).nextBytes(exceedsMaxRecordSize); - Product2 hugeRecord = - new Tuple2(new byte[0], exceedsMaxRecordSize); - try { - // Here, we write through the public `write()` interface instead of the test-only - // `insertRecordIntoSorter` interface: - writer.write(Collections.singletonList(hugeRecord).iterator()); - fail("Expected exception to be thrown"); - } catch (IOException e) { - // Pass - } + dataToWrite.add(new Tuple2(3, ByteBuffer.wrap(exceedsMaxRecordSize))); + writer.write(dataToWrite.iterator()); + writer.stop(true); + assertEquals( + HashMultiset.create(dataToWrite), + HashMultiset.create(readRecordsFromFile())); assertSpillFilesWereCleanedUp(); } From 
e57d6b56137bf3557efe5acea3ad390c1987b257 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 7 Aug 2015 00:00:43 -0700 Subject: [PATCH 0114/1824] [SPARK-9683] [SQL] copy UTF8String when convert unsafe array/map to safe When we convert unsafe row to safe row, we will do copy if the column is struct or string type. However, the string inside unsafe array/map are not copied, which may cause problems. Author: Wenchen Fan Closes #7990 from cloud-fan/copy and squashes the following commits: c13d1e3 [Wenchen Fan] change test name fe36294 [Wenchen Fan] we should deep copy UTF8String when convert unsafe row to safe row --- .../sql/catalyst/expressions/FromUnsafe.scala | 3 ++ .../execution/RowFormatConvertersSuite.scala | 38 ++++++++++++++++++- 2 files changed, 40 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/FromUnsafe.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/FromUnsafe.scala index 3caf0fb3410c..9b960b136f98 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/FromUnsafe.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/FromUnsafe.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String case class FromUnsafe(child: Expression) extends UnaryExpression with ExpectsInputTypes with CodegenFallback { @@ -52,6 +53,8 @@ case class FromUnsafe(child: Expression) extends UnaryExpression } new GenericArrayData(result) + case StringType => value.asInstanceOf[UTF8String].clone() + case MapType(kt, vt, _) => val map = value.asInstanceOf[UnsafeMapData] val safeKeyArray = convert(map.keys, ArrayType(kt)).asInstanceOf[GenericArrayData] diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala index 707cd9c6d939..8208b25b5708 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala @@ -17,9 +17,13 @@ package org.apache.spark.sql.execution +import org.apache.spark.rdd.RDD import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.expressions.{Literal, IsNull} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Attribute, Literal, IsNull} import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.types.{GenericArrayData, ArrayType, StructType, StringType} +import org.apache.spark.unsafe.types.UTF8String class RowFormatConvertersSuite extends SparkPlanTest { @@ -87,4 +91,36 @@ class RowFormatConvertersSuite extends SparkPlanTest { input.map(Row.fromTuple) ) } + + test("SPARK-9683: copy UTF8String when convert unsafe array/map to safe") { + SparkPlan.currentContext.set(TestSQLContext) + val schema = ArrayType(StringType) + val rows = (1 to 100).map { i => + InternalRow(new GenericArrayData(Array[Any](UTF8String.fromString(i.toString)))) + } + val relation = LocalTableScan(Seq(AttributeReference("t", schema)()), rows) + + val plan = + DummyPlan( + ConvertToSafe( + ConvertToUnsafe(relation))) + assert(plan.execute().collect().map(_.getUTF8String(0).toString) === (1 to 100).map(_.toString)) + } +} + +case class DummyPlan(child: 
SparkPlan) extends UnaryNode { + + override protected def doExecute(): RDD[InternalRow] = { + child.execute().mapPartitions { iter => + // cache all strings to make sure we have deep copied UTF8String inside incoming + // safe InternalRow. + val strings = new scala.collection.mutable.ArrayBuffer[UTF8String] + iter.foreach { row => + strings += row.getArray(0).getUTF8String(0) + } + strings.map(InternalRow(_)).iterator + } + } + + override def output: Seq[Attribute] = Seq(AttributeReference("a", StringType)()) } From ebfd91c542aaead343cb154277fcf9114382fee7 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Fri, 7 Aug 2015 00:09:58 -0700 Subject: [PATCH 0115/1824] [SPARK-9467][SQL]Add SQLMetric to specialize accumulators to avoid boxing This PR adds SQLMetric/SQLMetricParam/SQLMetricValue to specialize accumulators to avoid boxing. All SQL metrics should use these classes rather than `Accumulator`. Author: zsxwing Closes #7996 from zsxwing/sql-accu and squashes the following commits: 14a5f0a [zsxwing] Address comments 367ca23 [zsxwing] Use localValue directly to avoid changing Accumulable 42f50c3 [zsxwing] Add SQLMetric to specialize accumulators to avoid boxing --- .../scala/org/apache/spark/Accumulators.scala | 2 +- .../scala/org/apache/spark/SparkContext.scala | 15 -- .../spark/sql/execution/SparkPlan.scala | 33 ++-- .../spark/sql/execution/basicOperators.scala | 11 +- .../apache/spark/sql/metric/SQLMetrics.scala | 149 ++++++++++++++++++ .../org/apache/spark/sql/ui/SQLListener.scala | 17 +- .../apache/spark/sql/ui/SparkPlanGraph.scala | 8 +- .../spark/sql/metric/SQLMetricsSuite.scala | 145 +++++++++++++++++ .../spark/sql/ui/SQLListenerSuite.scala | 5 +- 9 files changed, 338 insertions(+), 47 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/metric/SQLMetrics.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/metric/SQLMetricsSuite.scala diff --git a/core/src/main/scala/org/apache/spark/Accumulators.scala b/core/src/main/scala/org/apache/spark/Accumulators.scala index 462d5c96d480..064246dfa7fc 100644 --- a/core/src/main/scala/org/apache/spark/Accumulators.scala +++ b/core/src/main/scala/org/apache/spark/Accumulators.scala @@ -257,7 +257,7 @@ GrowableAccumulableParam[R <% Growable[T] with TraversableOnce[T] with Serializa */ class Accumulator[T] private[spark] ( @transient private[spark] val initialValue: T, - private[spark] val param: AccumulatorParam[T], + param: AccumulatorParam[T], name: Option[String], internal: Boolean) extends Accumulable[T, T](initialValue, param, name, internal) { diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 566268643690..9ced44131b0d 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -1238,21 +1238,6 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli acc } - /** - * Create an [[org.apache.spark.Accumulator]] variable of a given type, with a name for display - * in the Spark UI. Tasks can "add" values to the accumulator using the `+=` method. Only the - * driver can access the accumulator's `value`. The latest local value of such accumulator will be - * sent back to the driver via heartbeats. 
- * - * @tparam T type that can be added to the accumulator, must be thread safe - */ - private[spark] def internalAccumulator[T](initialValue: T, name: String)( - implicit param: AccumulatorParam[T]): Accumulator[T] = { - val acc = new Accumulator(initialValue, param, Some(name), internal = true) - cleaner.foreach(_.registerAccumulatorForCleanup(acc)) - acc - } - /** * Create an [[org.apache.spark.Accumulable]] shared variable, to which tasks can add values * with `+=`. Only the driver can access the accumuable's `value`. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 719ad432e2fe..1915496d1620 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -21,7 +21,7 @@ import java.util.concurrent.atomic.AtomicBoolean import scala.collection.mutable.ArrayBuffer -import org.apache.spark.{Accumulator, Logging} +import org.apache.spark.Logging import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.{RDD, RDDOperationScope} import org.apache.spark.sql.SQLContext @@ -32,6 +32,7 @@ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.metric.{IntSQLMetric, LongSQLMetric, SQLMetric, SQLMetrics} import org.apache.spark.sql.types.DataType object SparkPlan { @@ -84,22 +85,30 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ */ protected[sql] def trackNumOfRowsEnabled: Boolean = false - private lazy val numOfRowsAccumulator = sparkContext.internalAccumulator(0L, "number of rows") + private lazy val defaultMetrics: Map[String, SQLMetric[_, _]] = + if (trackNumOfRowsEnabled) { + Map("numRows" -> SQLMetrics.createLongMetric(sparkContext, "number of rows")) + } + else { + Map.empty + } /** - * Return all accumulators containing metrics of this SparkPlan. + * Return all metrics containing metrics of this SparkPlan. */ - private[sql] def accumulators: Map[String, Accumulator[_]] = if (trackNumOfRowsEnabled) { - Map("numRows" -> numOfRowsAccumulator) - } else { - Map.empty - } + private[sql] def metrics: Map[String, SQLMetric[_, _]] = defaultMetrics + + /** + * Return a IntSQLMetric according to the name. + */ + private[sql] def intMetric(name: String): IntSQLMetric = + metrics(name).asInstanceOf[IntSQLMetric] /** - * Return the accumulator according to the name. + * Return a LongSQLMetric according to the name. */ - private[sql] def accumulator[T](name: String): Accumulator[T] = - accumulators(name).asInstanceOf[Accumulator[T]] + private[sql] def longMetric(name: String): LongSQLMetric = + metrics(name).asInstanceOf[LongSQLMetric] // TODO: Move to `DistributedPlan` /** Specifies how data is partitioned across different nodes in the cluster. 
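The metrics plumbing above boils down to incrementing a long-valued metric as each row flows through the iterator produced by doExecute(). A minimal standalone Scala sketch of that pattern is shown below; LongMetric here is an illustrative stand-in, not Spark's SQLMetric.

object RowCountingSketch {
  // Minimal stand-in for a long-valued metric.
  final class LongMetric(val name: String) {
    private var value = 0L
    def +=(v: Long): Unit = value += v
    def current: Long = value
  }

  // Wrap an iterator so every element that flows through bumps the metric,
  // mirroring how execute() wraps doExecute()'s rows with `numRows += 1`.
  def countRows[T](rows: Iterator[T], numRows: LongMetric): Iterator[T] =
    rows.map { row => numRows += 1; row }

  def main(args: Array[String]): Unit = {
    val numRows = new LongMetric("number of rows")
    countRows(Iterator(1, 2, 3), numRows).foreach(_ => ())
    println(numRows.current) // 3
  }
}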
*/ @@ -148,7 +157,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ RDDOperationScope.withScope(sparkContext, nodeName, false, true) { prepare() if (trackNumOfRowsEnabled) { - val numRows = accumulator[Long]("numRows") + val numRows = longMetric("numRows") doExecute().map { row => numRows += 1 row diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index f4677b4ee86b..0680f31d40f6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.metric.SQLMetrics import org.apache.spark.sql.types.StructType import org.apache.spark.util.collection.ExternalSorter import org.apache.spark.util.collection.unsafe.sort.PrefixComparator @@ -81,13 +82,13 @@ case class TungstenProject(projectList: Seq[NamedExpression], child: SparkPlan) case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output - private[sql] override lazy val accumulators = Map( - "numInputRows" -> sparkContext.internalAccumulator(0L, "number of input rows"), - "numOutputRows" -> sparkContext.internalAccumulator(0L, "number of output rows")) + private[sql] override lazy val metrics = Map( + "numInputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of input rows"), + "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) protected override def doExecute(): RDD[InternalRow] = { - val numInputRows = accumulator[Long]("numInputRows") - val numOutputRows = accumulator[Long]("numOutputRows") + val numInputRows = longMetric("numInputRows") + val numOutputRows = longMetric("numOutputRows") child.execute().mapPartitions { iter => val predicate = newPredicate(condition, child.output) iter.filter { row => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/metric/SQLMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/metric/SQLMetrics.scala new file mode 100644 index 000000000000..3b907e5da789 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/metric/SQLMetrics.scala @@ -0,0 +1,149 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.metric + +import org.apache.spark.{Accumulable, AccumulableParam, SparkContext} + +/** + * Create a layer for specialized metric. 
We cannot add `@specialized` to + * `Accumulable/AccumulableParam` because it will break Java source compatibility. + * + * An implementation of SQLMetric should override `+=` and `add` to avoid boxing. + */ +private[sql] abstract class SQLMetric[R <: SQLMetricValue[T], T]( + name: String, val param: SQLMetricParam[R, T]) + extends Accumulable[R, T](param.zero, param, Some(name), true) + +/** + * Create a layer for specialized metric. We cannot add `@specialized` to + * `Accumulable/AccumulableParam` because it will break Java source compatibility. + */ +private[sql] trait SQLMetricParam[R <: SQLMetricValue[T], T] extends AccumulableParam[R, T] { + + def zero: R +} + +/** + * Create a layer for specialized metric. We cannot add `@specialized` to + * `Accumulable/AccumulableParam` because it will break Java source compatibility. + */ +private[sql] trait SQLMetricValue[T] extends Serializable { + + def value: T + + override def toString: String = value.toString +} + +/** + * A wrapper of Long to avoid boxing and unboxing when using Accumulator + */ +private[sql] class LongSQLMetricValue(private var _value : Long) extends SQLMetricValue[Long] { + + def add(incr: Long): LongSQLMetricValue = { + _value += incr + this + } + + // Although there is a boxing here, it's fine because it's only called in SQLListener + override def value: Long = _value +} + +/** + * A wrapper of Int to avoid boxing and unboxing when using Accumulator + */ +private[sql] class IntSQLMetricValue(private var _value: Int) extends SQLMetricValue[Int] { + + def add(term: Int): IntSQLMetricValue = { + _value += term + this + } + + // Although there is a boxing here, it's fine because it's only called in SQLListener + override def value: Int = _value +} + +/** + * A specialized long Accumulable to avoid boxing and unboxing when using Accumulator's + * `+=` and `add`. + */ +private[sql] class LongSQLMetric private[metric](name: String) + extends SQLMetric[LongSQLMetricValue, Long](name, LongSQLMetricParam) { + + override def +=(term: Long): Unit = { + localValue.add(term) + } + + override def add(term: Long): Unit = { + localValue.add(term) + } +} + +/** + * A specialized int Accumulable to avoid boxing and unboxing when using Accumulator's + * `+=` and `add`. 
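Why the specialized `+=`/`add` overloads above matter: with a generic accumulator the element type is erased, so every Long passed in is boxed to java.lang.Long, whereas a monomorphic Long parameter stays a primitive in bytecode. The standalone Scala sketch below illustrates the difference; the classes are illustrative and are not Spark's Accumulator or SQLMetric.

object BoxingSketch {
  // Generic accumulator: T is erased, so each Long argument is boxed on every call.
  final class GenericAcc[T](private var value: T)(merge: (T, T) => T) {
    def add(term: T): Unit = value = merge(value, term)
    def current: T = value
  }

  // Specialized accumulator: `term` is a primitive long, no boxing on the hot path.
  final class LongAcc(private var value: Long) {
    def add(term: Long): Unit = value += term
    def current: Long = value
  }

  def main(args: Array[String]): Unit = {
    val g = new GenericAcc[Long](0L)(_ + _) // boxes each argument
    val l = new LongAcc(0L)                 // stays primitive
    (1L to 5L).foreach { i => g.add(i); l.add(i) }
    println((g.current, l.current)) // (15, 15)
  }
}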
+ */ +private[sql] class IntSQLMetric private[metric](name: String) + extends SQLMetric[IntSQLMetricValue, Int](name, IntSQLMetricParam) { + + override def +=(term: Int): Unit = { + localValue.add(term) + } + + override def add(term: Int): Unit = { + localValue.add(term) + } +} + +private object LongSQLMetricParam extends SQLMetricParam[LongSQLMetricValue, Long] { + + override def addAccumulator(r: LongSQLMetricValue, t: Long): LongSQLMetricValue = r.add(t) + + override def addInPlace(r1: LongSQLMetricValue, r2: LongSQLMetricValue): LongSQLMetricValue = + r1.add(r2.value) + + override def zero(initialValue: LongSQLMetricValue): LongSQLMetricValue = zero + + override def zero: LongSQLMetricValue = new LongSQLMetricValue(0L) +} + +private object IntSQLMetricParam extends SQLMetricParam[IntSQLMetricValue, Int] { + + override def addAccumulator(r: IntSQLMetricValue, t: Int): IntSQLMetricValue = r.add(t) + + override def addInPlace(r1: IntSQLMetricValue, r2: IntSQLMetricValue): IntSQLMetricValue = + r1.add(r2.value) + + override def zero(initialValue: IntSQLMetricValue): IntSQLMetricValue = zero + + override def zero: IntSQLMetricValue = new IntSQLMetricValue(0) +} + +private[sql] object SQLMetrics { + + def createIntMetric(sc: SparkContext, name: String): IntSQLMetric = { + val acc = new IntSQLMetric(name) + sc.cleaner.foreach(_.registerAccumulatorForCleanup(acc)) + acc + } + + def createLongMetric(sc: SparkContext, name: String): LongSQLMetric = { + val acc = new LongSQLMetric(name) + sc.cleaner.foreach(_.registerAccumulatorForCleanup(acc)) + acc + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ui/SQLListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/ui/SQLListener.scala index e7b1dd1ffac6..2fd4fc658d06 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/ui/SQLListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/ui/SQLListener.scala @@ -21,11 +21,12 @@ import scala.collection.mutable import com.google.common.annotations.VisibleForTesting -import org.apache.spark.{AccumulatorParam, JobExecutionStatus, Logging} +import org.apache.spark.{JobExecutionStatus, Logging} import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler._ import org.apache.spark.sql.SQLContext import org.apache.spark.sql.execution.SQLExecution +import org.apache.spark.sql.metric.{SQLMetricParam, SQLMetricValue} private[sql] class SQLListener(sqlContext: SQLContext) extends SparkListener with Logging { @@ -36,8 +37,6 @@ private[sql] class SQLListener(sqlContext: SQLContext) extends SparkListener wit // Old data in the following fields must be removed in "trimExecutionsIfNecessary". // If adding new fields, make sure "trimExecutionsIfNecessary" can clean up old data - - // VisibleForTesting private val _executionIdToData = mutable.HashMap[Long, SQLExecutionUIData]() /** @@ -270,9 +269,10 @@ private[sql] class SQLListener(sqlContext: SQLContext) extends SparkListener wit accumulatorUpdate <- taskMetrics.accumulatorUpdates.toSeq) yield { accumulatorUpdate } - }.filter { case (id, _) => executionUIData.accumulatorMetrics.keySet(id) } + }.filter { case (id, _) => executionUIData.accumulatorMetrics.contains(id) } mergeAccumulatorUpdates(accumulatorUpdates, accumulatorId => - executionUIData.accumulatorMetrics(accumulatorId).accumulatorParam) + executionUIData.accumulatorMetrics(accumulatorId).metricParam). 
+ mapValues(_.asInstanceOf[SQLMetricValue[_]].value) case None => // This execution has been dropped Map.empty @@ -281,10 +281,11 @@ private[sql] class SQLListener(sqlContext: SQLContext) extends SparkListener wit private def mergeAccumulatorUpdates( accumulatorUpdates: Seq[(Long, Any)], - paramFunc: Long => AccumulatorParam[Any]): Map[Long, Any] = { + paramFunc: Long => SQLMetricParam[SQLMetricValue[Any], Any]): Map[Long, Any] = { accumulatorUpdates.groupBy(_._1).map { case (accumulatorId, values) => val param = paramFunc(accumulatorId) - (accumulatorId, values.map(_._2).reduceLeft(param.addInPlace)) + (accumulatorId, + values.map(_._2.asInstanceOf[SQLMetricValue[Any]]).foldLeft(param.zero)(param.addInPlace)) } } @@ -336,7 +337,7 @@ private[ui] class SQLExecutionUIData( private[ui] case class SQLPlanMetric( name: String, accumulatorId: Long, - accumulatorParam: AccumulatorParam[Any]) + metricParam: SQLMetricParam[SQLMetricValue[Any], Any]) /** * Store all accumulatorUpdates for all tasks in a Spark stage. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ui/SparkPlanGraph.scala b/sql/core/src/main/scala/org/apache/spark/sql/ui/SparkPlanGraph.scala index 7910c163ba45..1ba50b95becc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/ui/SparkPlanGraph.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/ui/SparkPlanGraph.scala @@ -21,8 +21,8 @@ import java.util.concurrent.atomic.AtomicLong import scala.collection.mutable -import org.apache.spark.AccumulatorParam import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.metric.{SQLMetricParam, SQLMetricValue} /** * A graph used for storing information of an executionPlan of DataFrame. @@ -61,9 +61,9 @@ private[sql] object SparkPlanGraph { nodeIdGenerator: AtomicLong, nodes: mutable.ArrayBuffer[SparkPlanGraphNode], edges: mutable.ArrayBuffer[SparkPlanGraphEdge]): SparkPlanGraphNode = { - val metrics = plan.accumulators.toSeq.map { case (key, accumulator) => - SQLPlanMetric(accumulator.name.getOrElse(key), accumulator.id, - accumulator.param.asInstanceOf[AccumulatorParam[Any]]) + val metrics = plan.metrics.toSeq.map { case (key, metric) => + SQLPlanMetric(metric.name.getOrElse(key), metric.id, + metric.param.asInstanceOf[SQLMetricParam[SQLMetricValue[Any], Any]]) } val node = SparkPlanGraphNode( nodeIdGenerator.getAndIncrement(), plan.nodeName, plan.simpleString, metrics) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/metric/SQLMetricsSuite.scala new file mode 100644 index 000000000000..d22160f5384f --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/metric/SQLMetricsSuite.scala @@ -0,0 +1,145 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/ + +package org.apache.spark.sql.metric + +import java.io.{ByteArrayInputStream, ByteArrayOutputStream} + +import scala.collection.mutable + +import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm._ +import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm.Opcodes._ + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.util.Utils + + +class SQLMetricsSuite extends SparkFunSuite { + + test("LongSQLMetric should not box Long") { + val l = SQLMetrics.createLongMetric(TestSQLContext.sparkContext, "long") + val f = () => { l += 1L } + BoxingFinder.getClassReader(f.getClass).foreach { cl => + val boxingFinder = new BoxingFinder() + cl.accept(boxingFinder, 0) + assert(boxingFinder.boxingInvokes.isEmpty, s"Found boxing: ${boxingFinder.boxingInvokes}") + } + } + + test("IntSQLMetric should not box Int") { + val l = SQLMetrics.createIntMetric(TestSQLContext.sparkContext, "Int") + val f = () => { l += 1 } + BoxingFinder.getClassReader(f.getClass).foreach { cl => + val boxingFinder = new BoxingFinder() + cl.accept(boxingFinder, 0) + assert(boxingFinder.boxingInvokes.isEmpty, s"Found boxing: ${boxingFinder.boxingInvokes}") + } + } + + test("Normal accumulator should do boxing") { + // We need this test to make sure BoxingFinder works. + val l = TestSQLContext.sparkContext.accumulator(0L) + val f = () => { l += 1L } + BoxingFinder.getClassReader(f.getClass).foreach { cl => + val boxingFinder = new BoxingFinder() + cl.accept(boxingFinder, 0) + assert(boxingFinder.boxingInvokes.nonEmpty, "Found find boxing in this test") + } + } +} + +private case class MethodIdentifier[T](cls: Class[T], name: String, desc: String) + +/** + * If `method` is null, search all methods of this class recursively to find if they do some boxing. + * If `method` is specified, only search this method of the class to speed up the searching. + * + * This method will skip the methods in `visitedMethods` to avoid potential infinite cycles. + */ +private class BoxingFinder( + method: MethodIdentifier[_] = null, + val boxingInvokes: mutable.Set[String] = mutable.Set.empty, + visitedMethods: mutable.Set[MethodIdentifier[_]] = mutable.Set.empty) + extends ClassVisitor(ASM4) { + + private val primitiveBoxingClassName = + Set("java/lang/Long", + "java/lang/Double", + "java/lang/Integer", + "java/lang/Float", + "java/lang/Short", + "java/lang/Character", + "java/lang/Byte", + "java/lang/Boolean") + + override def visitMethod( + access: Int, name: String, desc: String, sig: String, exceptions: Array[String]): + MethodVisitor = { + if (method != null && (method.name != name || method.desc != desc)) { + // If method is specified, skip other methods. 
+ return new MethodVisitor(ASM4) {} + } + + new MethodVisitor(ASM4) { + override def visitMethodInsn(op: Int, owner: String, name: String, desc: String) { + if (op == INVOKESPECIAL && name == "<init>" || op == INVOKESTATIC && name == "valueOf") { + if (primitiveBoxingClassName.contains(owner)) { + // Find boxing methods, e.g, new java.lang.Long(l) or java.lang.Long.valueOf(l) + boxingInvokes.add(s"$owner.$name") + } + } else { + // scalastyle:off classforname + val classOfMethodOwner = Class.forName(owner.replace('/', '.'), false, + Thread.currentThread.getContextClassLoader) + // scalastyle:on classforname + val m = MethodIdentifier(classOfMethodOwner, name, desc) + if (!visitedMethods.contains(m)) { + // Keep track of visited methods to avoid potential infinite cycles + visitedMethods += m + BoxingFinder.getClassReader(classOfMethodOwner).foreach { cl => + visitedMethods += m + cl.accept(new BoxingFinder(m, boxingInvokes, visitedMethods), 0) + } + } + } + } + } + } +} + +private object BoxingFinder { + + def getClassReader(cls: Class[_]): Option[ClassReader] = { + val className = cls.getName.replaceFirst("^.*\\.", "") + ".class" + val resourceStream = cls.getResourceAsStream(className) + val baos = new ByteArrayOutputStream(128) + // Copy data over, before delegating to ClassReader - + // else we can run out of open file handles. + Utils.copyStream(resourceStream, baos, true) + // ASM4 doesn't support Java 8 classes, which requires ASM5. + // So if the class is ASM5 (E.g., java.lang.Long when using JDK8 runtime to run these codes), + // then ClassReader will throw IllegalArgumentException, + // However, since this is only for testing, it's safe to skip these classes. + try { + Some(new ClassReader(new ByteArrayInputStream(baos.toByteArray))) + } catch { + case _: IllegalArgumentException => None + } + } + +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ui/SQLListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ui/SQLListenerSuite.scala index f1fcaf59532b..69a561e16aa1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ui/SQLListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ui/SQLListenerSuite.scala @@ -21,6 +21,7 @@ import java.util.Properties import org.apache.spark.{SparkException, SparkContext, SparkConf, SparkFunSuite} import org.apache.spark.executor.TaskMetrics +import org.apache.spark.sql.metric.LongSQLMetricValue import org.apache.spark.scheduler._ import org.apache.spark.sql.{DataFrame, SQLContext} import org.apache.spark.sql.execution.SQLExecution @@ -65,9 +66,9 @@ class SQLListenerSuite extends SparkFunSuite { speculative = false ) - private def createTaskMetrics(accumulatorUpdates: Map[Long, Any]): TaskMetrics = { + private def createTaskMetrics(accumulatorUpdates: Map[Long, Long]): TaskMetrics = { val metrics = new TaskMetrics - metrics.setAccumulatorsUpdater(() => accumulatorUpdates) + metrics.setAccumulatorsUpdater(() => accumulatorUpdates.mapValues(new LongSQLMetricValue(_))) metrics.updateAccumulators() metrics } From 76eaa701833a2ff23b50147d70ced41e85719572 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 7 Aug 2015 11:02:53 -0700 Subject: [PATCH 0116/1824] [SPARK-9674][SPARK-9667] Remove SparkSqlSerializer2 It is now subsumed by various Tungsten operators. Author: Reynold Xin Closes #7981 from rxin/SPARK-9674 and squashes the following commits: 144f96e [Reynold Xin] Re-enable test 58b7332 [Reynold Xin] Disable failing list. fb797e3 [Reynold Xin] Match all UDTs. be9f243 [Reynold Xin] Updated if.
71fc99c [Reynold Xin] [SPARK-9674][SPARK-9667] Remove GeneratedAggregate & SparkSqlSerializer2. --- .../scala/org/apache/spark/sql/SQLConf.scala | 6 - .../apache/spark/sql/execution/Exchange.scala | 48 +- .../sql/execution/SparkSqlSerializer2.scala | 426 ------------------ .../execution/SparkSqlSerializer2Suite.scala | 221 --------- 4 files changed, 24 insertions(+), 677 deletions(-) delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index ef35c133d9cc..45d3d8c86351 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -416,10 +416,6 @@ private[spark] object SQLConf { val USE_SQL_AGGREGATE2 = booleanConf("spark.sql.useAggregate2", defaultValue = Some(true), doc = "") - val USE_SQL_SERIALIZER2 = booleanConf( - "spark.sql.useSerializer2", - defaultValue = Some(true), isPublic = false) - object Deprecated { val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" } @@ -488,8 +484,6 @@ private[sql] class SQLConf extends Serializable with CatalystConf { private[spark] def useSqlAggregate2: Boolean = getConf(USE_SQL_AGGREGATE2) - private[spark] def useSqlSerializer2: Boolean = getConf(USE_SQL_SERIALIZER2) - private[spark] def autoBroadcastJoinThreshold: Int = getConf(AUTO_BROADCASTJOIN_THRESHOLD) private[spark] def defaultSizeInBytes: Long = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index 6ea5eeedf1bb..60087f2ca4a3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.errors.attachTree import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.types.UserDefinedType import org.apache.spark.util.MutablePair import org.apache.spark.{HashPartitioner, Partitioner, RangePartitioner, SparkEnv} @@ -39,21 +40,34 @@ import org.apache.spark.{HashPartitioner, Partitioner, RangePartitioner, SparkEn @DeveloperApi case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends UnaryNode { - override def outputPartitioning: Partitioning = newPartitioning - - override def output: Seq[Attribute] = child.output - - override def outputsUnsafeRows: Boolean = child.outputsUnsafeRows + override def nodeName: String = if (tungstenMode) "TungstenExchange" else "Exchange" - override def canProcessSafeRows: Boolean = true - - override def canProcessUnsafeRows: Boolean = { + /** + * Returns true iff the children outputs aggregate UDTs that are not part of the SQL type. + * This only happens with the old aggregate implementation and should be removed in 1.6. + */ + private lazy val tungstenMode: Boolean = { + val unserializableUDT = child.schema.exists(_.dataType match { + case _: UserDefinedType[_] => true + case _ => false + }) // Do not use the Unsafe path if we are using a RangePartitioning, since this may lead to // an interpreted RowOrdering being applied to an UnsafeRow, which will lead to // ClassCastExceptions at runtime. 
This check can be removed after SPARK-9054 is fixed. - !newPartitioning.isInstanceOf[RangePartitioning] + !unserializableUDT && !newPartitioning.isInstanceOf[RangePartitioning] } + override def outputPartitioning: Partitioning = newPartitioning + + override def output: Seq[Attribute] = child.output + + // This setting is somewhat counterintuitive: + // If the schema works with UnsafeRow, then we tell the planner that we don't support safe row, + // so the planner inserts a converter to convert data into UnsafeRow if needed. + override def outputsUnsafeRows: Boolean = tungstenMode + override def canProcessSafeRows: Boolean = !tungstenMode + override def canProcessUnsafeRows: Boolean = tungstenMode + /** * Determines whether records must be defensively copied before being sent to the shuffle. * Several of Spark's shuffle components will buffer deserialized Java objects in memory. The @@ -124,23 +138,9 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una private val serializer: Serializer = { val rowDataTypes = child.output.map(_.dataType).toArray - // It is true when there is no field that needs to be write out. - // For now, we will not use SparkSqlSerializer2 when noField is true. - val noField = rowDataTypes == null || rowDataTypes.length == 0 - - val useSqlSerializer2 = - child.sqlContext.conf.useSqlSerializer2 && // SparkSqlSerializer2 is enabled. - SparkSqlSerializer2.support(rowDataTypes) && // The schema of row is supported. - !noField - - if (child.outputsUnsafeRows) { - logInfo("Using UnsafeRowSerializer.") + if (tungstenMode) { new UnsafeRowSerializer(child.output.size) - } else if (useSqlSerializer2) { - logInfo("Using SparkSqlSerializer2.") - new SparkSqlSerializer2(rowDataTypes) } else { - logInfo("Using SparkSqlSerializer.") new SparkSqlSerializer(sparkConf) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala deleted file mode 100644 index e811f1de3e6d..000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala +++ /dev/null @@ -1,426 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution - -import java.io._ -import java.math.{BigDecimal, BigInteger} -import java.nio.ByteBuffer - -import scala.reflect.ClassTag - -import org.apache.spark.Logging -import org.apache.spark.serializer._ -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{MutableRow, SpecificMutableRow} -import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String - -/** - * The serialization stream for [[SparkSqlSerializer2]]. 
It assumes that the object passed in - * its `writeObject` are [[Product2]]. The serialization functions for the key and value of the - * [[Product2]] are constructed based on their schemata. - * The benefit of this serialization stream is that compared with general-purpose serializers like - * Kryo and Java serializer, it can significantly reduce the size of serialized and has a lower - * allocation cost, which can benefit the shuffle operation. Right now, its main limitations are: - * 1. It does not support complex types, i.e. Map, Array, and Struct. - * 2. It assumes that the objects passed in are [[Product2]]. So, it cannot be used when - * [[org.apache.spark.util.collection.ExternalSorter]]'s merge sort operation is used because - * the objects passed in the serializer are not in the type of [[Product2]]. Also also see - * the comment of the `serializer` method in [[Exchange]] for more information on it. - */ -private[sql] class Serializer2SerializationStream( - rowSchema: Array[DataType], - out: OutputStream) - extends SerializationStream with Logging { - - private val rowOut = new DataOutputStream(new BufferedOutputStream(out)) - private val writeRowFunc = SparkSqlSerializer2.createSerializationFunction(rowSchema, rowOut) - - override def writeObject[T: ClassTag](t: T): SerializationStream = { - val kv = t.asInstanceOf[Product2[InternalRow, InternalRow]] - writeKey(kv._1) - writeValue(kv._2) - - this - } - - override def writeKey[T: ClassTag](t: T): SerializationStream = { - // No-op. - this - } - - override def writeValue[T: ClassTag](t: T): SerializationStream = { - writeRowFunc(t.asInstanceOf[InternalRow]) - this - } - - def flush(): Unit = { - rowOut.flush() - } - - def close(): Unit = { - rowOut.close() - } -} - -/** - * The corresponding deserialization stream for [[Serializer2SerializationStream]]. - */ -private[sql] class Serializer2DeserializationStream( - rowSchema: Array[DataType], - in: InputStream) - extends DeserializationStream with Logging { - - private val rowIn = new DataInputStream(new BufferedInputStream(in)) - - private def rowGenerator(schema: Array[DataType]): () => (MutableRow) = { - if (schema == null) { - () => null - } else { - // It is safe to reuse the mutable row. - val mutableRow = new SpecificMutableRow(schema) - () => mutableRow - } - } - - // Functions used to return rows for key and value. - private val getRow = rowGenerator(rowSchema) - // Functions used to read a serialized row from the InputStream and deserialize it. - private val readRowFunc = SparkSqlSerializer2.createDeserializationFunction(rowSchema, rowIn) - - override def readObject[T: ClassTag](): T = { - readValue() - } - - override def readKey[T: ClassTag](): T = { - null.asInstanceOf[T] // intentionally left blank. 
- } - - override def readValue[T: ClassTag](): T = { - readRowFunc(getRow()).asInstanceOf[T] - } - - override def close(): Unit = { - rowIn.close() - } -} - -private[sql] class SparkSqlSerializer2Instance( - rowSchema: Array[DataType]) - extends SerializerInstance { - - def serialize[T: ClassTag](t: T): ByteBuffer = - throw new UnsupportedOperationException("Not supported.") - - def deserialize[T: ClassTag](bytes: ByteBuffer): T = - throw new UnsupportedOperationException("Not supported.") - - def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T = - throw new UnsupportedOperationException("Not supported.") - - def serializeStream(s: OutputStream): SerializationStream = { - new Serializer2SerializationStream(rowSchema, s) - } - - def deserializeStream(s: InputStream): DeserializationStream = { - new Serializer2DeserializationStream(rowSchema, s) - } -} - -/** - * SparkSqlSerializer2 is a special serializer that creates serialization function and - * deserialization function based on the schema of data. It assumes that values passed in - * are Rows. - */ -private[sql] class SparkSqlSerializer2(rowSchema: Array[DataType]) - extends Serializer - with Logging - with Serializable{ - - def newInstance(): SerializerInstance = new SparkSqlSerializer2Instance(rowSchema) - - override def supportsRelocationOfSerializedObjects: Boolean = { - // SparkSqlSerializer2 is stateless and writes no stream headers - true - } -} - -private[sql] object SparkSqlSerializer2 { - - final val NULL = 0 - final val NOT_NULL = 1 - - /** - * Check if rows with the given schema can be serialized with ShuffleSerializer. - * Right now, we do not support a schema having complex types or UDTs, or all data types - * of fields are NullTypes. - */ - def support(schema: Array[DataType]): Boolean = { - if (schema == null) return true - - var allNullTypes = true - var i = 0 - while (i < schema.length) { - schema(i) match { - case NullType => // Do nothing - case udt: UserDefinedType[_] => - allNullTypes = false - return false - case array: ArrayType => - allNullTypes = false - return false - case map: MapType => - allNullTypes = false - return false - case struct: StructType => - allNullTypes = false - return false - case _ => - allNullTypes = false - } - i += 1 - } - - // If types of fields are all NullTypes, we return false. - // Otherwise, we return true. - return !allNullTypes - } - - /** - * The util function to create the serialization function based on the given schema. - */ - def createSerializationFunction(schema: Array[DataType], out: DataOutputStream) - : InternalRow => Unit = { - (row: InternalRow) => - // If the schema is null, the returned function does nothing when it get called. - if (schema != null) { - var i = 0 - while (i < schema.length) { - schema(i) match { - // When we write values to the underlying stream, we also first write the null byte - // first. Then, if the value is not null, we write the contents out. - - case NullType => // Write nothing. 
- - case BooleanType => - if (row.isNullAt(i)) { - out.writeByte(NULL) - } else { - out.writeByte(NOT_NULL) - out.writeBoolean(row.getBoolean(i)) - } - - case ByteType => - if (row.isNullAt(i)) { - out.writeByte(NULL) - } else { - out.writeByte(NOT_NULL) - out.writeByte(row.getByte(i)) - } - - case ShortType => - if (row.isNullAt(i)) { - out.writeByte(NULL) - } else { - out.writeByte(NOT_NULL) - out.writeShort(row.getShort(i)) - } - - case IntegerType | DateType => - if (row.isNullAt(i)) { - out.writeByte(NULL) - } else { - out.writeByte(NOT_NULL) - out.writeInt(row.getInt(i)) - } - - case LongType | TimestampType => - if (row.isNullAt(i)) { - out.writeByte(NULL) - } else { - out.writeByte(NOT_NULL) - out.writeLong(row.getLong(i)) - } - - case FloatType => - if (row.isNullAt(i)) { - out.writeByte(NULL) - } else { - out.writeByte(NOT_NULL) - out.writeFloat(row.getFloat(i)) - } - - case DoubleType => - if (row.isNullAt(i)) { - out.writeByte(NULL) - } else { - out.writeByte(NOT_NULL) - out.writeDouble(row.getDouble(i)) - } - - case StringType => - if (row.isNullAt(i)) { - out.writeByte(NULL) - } else { - out.writeByte(NOT_NULL) - val bytes = row.getUTF8String(i).getBytes - out.writeInt(bytes.length) - out.write(bytes) - } - - case BinaryType => - if (row.isNullAt(i)) { - out.writeByte(NULL) - } else { - out.writeByte(NOT_NULL) - val bytes = row.getBinary(i) - out.writeInt(bytes.length) - out.write(bytes) - } - - case decimal: DecimalType => - if (row.isNullAt(i)) { - out.writeByte(NULL) - } else { - out.writeByte(NOT_NULL) - val value = row.getDecimal(i, decimal.precision, decimal.scale) - val javaBigDecimal = value.toJavaBigDecimal - // First, write out the unscaled value. - val bytes: Array[Byte] = javaBigDecimal.unscaledValue().toByteArray - out.writeInt(bytes.length) - out.write(bytes) - // Then, write out the scale. - out.writeInt(javaBigDecimal.scale()) - } - } - i += 1 - } - } - } - - /** - * The util function to create the deserialization function based on the given schema. - */ - def createDeserializationFunction( - schema: Array[DataType], - in: DataInputStream): (MutableRow) => InternalRow = { - if (schema == null) { - (mutableRow: MutableRow) => null - } else { - (mutableRow: MutableRow) => { - var i = 0 - while (i < schema.length) { - schema(i) match { - // When we read values from the underlying stream, we also first read the null byte - // first. Then, if the value is not null, we update the field of the mutable row. - - case NullType => mutableRow.setNullAt(i) // Read nothing. 
- - case BooleanType => - if (in.readByte() == NULL) { - mutableRow.setNullAt(i) - } else { - mutableRow.setBoolean(i, in.readBoolean()) - } - - case ByteType => - if (in.readByte() == NULL) { - mutableRow.setNullAt(i) - } else { - mutableRow.setByte(i, in.readByte()) - } - - case ShortType => - if (in.readByte() == NULL) { - mutableRow.setNullAt(i) - } else { - mutableRow.setShort(i, in.readShort()) - } - - case IntegerType | DateType => - if (in.readByte() == NULL) { - mutableRow.setNullAt(i) - } else { - mutableRow.setInt(i, in.readInt()) - } - - case LongType | TimestampType => - if (in.readByte() == NULL) { - mutableRow.setNullAt(i) - } else { - mutableRow.setLong(i, in.readLong()) - } - - case FloatType => - if (in.readByte() == NULL) { - mutableRow.setNullAt(i) - } else { - mutableRow.setFloat(i, in.readFloat()) - } - - case DoubleType => - if (in.readByte() == NULL) { - mutableRow.setNullAt(i) - } else { - mutableRow.setDouble(i, in.readDouble()) - } - - case StringType => - if (in.readByte() == NULL) { - mutableRow.setNullAt(i) - } else { - val length = in.readInt() - val bytes = new Array[Byte](length) - in.readFully(bytes) - mutableRow.update(i, UTF8String.fromBytes(bytes)) - } - - case BinaryType => - if (in.readByte() == NULL) { - mutableRow.setNullAt(i) - } else { - val length = in.readInt() - val bytes = new Array[Byte](length) - in.readFully(bytes) - mutableRow.update(i, bytes) - } - - case decimal: DecimalType => - if (in.readByte() == NULL) { - mutableRow.setNullAt(i) - } else { - // First, read in the unscaled value. - val length = in.readInt() - val bytes = new Array[Byte](length) - in.readFully(bytes) - val unscaledVal = new BigInteger(bytes) - // Then, read the scale. - val scale = in.readInt() - // Finally, create the Decimal object and set it in the row. - mutableRow.update(i, - Decimal(new BigDecimal(unscaledVal, scale), decimal.precision, decimal.scale)) - } - } - i += 1 - } - - mutableRow - } - } - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala deleted file mode 100644 index 7978ed57a937..000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala +++ /dev/null @@ -1,221 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.execution - -import java.sql.{Timestamp, Date} - -import org.apache.spark.sql.test.TestSQLContext -import org.scalatest.BeforeAndAfterAll - -import org.apache.spark.rdd.ShuffledRDD -import org.apache.spark.serializer.Serializer -import org.apache.spark.{ShuffleDependency, SparkFunSuite} -import org.apache.spark.sql.types._ -import org.apache.spark.sql.Row -import org.apache.spark.sql.{MyDenseVectorUDT, QueryTest} - -class SparkSqlSerializer2DataTypeSuite extends SparkFunSuite { - // Make sure that we will not use serializer2 for unsupported data types. - def checkSupported(dataType: DataType, isSupported: Boolean): Unit = { - val testName = - s"${if (dataType == null) null else dataType.toString} is " + - s"${if (isSupported) "supported" else "unsupported"}" - - test(testName) { - assert(SparkSqlSerializer2.support(Array(dataType)) === isSupported) - } - } - - checkSupported(null, isSupported = true) - checkSupported(BooleanType, isSupported = true) - checkSupported(ByteType, isSupported = true) - checkSupported(ShortType, isSupported = true) - checkSupported(IntegerType, isSupported = true) - checkSupported(LongType, isSupported = true) - checkSupported(FloatType, isSupported = true) - checkSupported(DoubleType, isSupported = true) - checkSupported(DateType, isSupported = true) - checkSupported(TimestampType, isSupported = true) - checkSupported(StringType, isSupported = true) - checkSupported(BinaryType, isSupported = true) - checkSupported(DecimalType(10, 5), isSupported = true) - checkSupported(DecimalType.SYSTEM_DEFAULT, isSupported = true) - - // If NullType is the only data type in the schema, we do not support it. - checkSupported(NullType, isSupported = false) - // For now, ArrayType, MapType, and StructType are not supported. - checkSupported(ArrayType(DoubleType, true), isSupported = false) - checkSupported(ArrayType(StringType, false), isSupported = false) - checkSupported(MapType(IntegerType, StringType, true), isSupported = false) - checkSupported(MapType(IntegerType, ArrayType(DoubleType), false), isSupported = false) - checkSupported(StructType(StructField("a", IntegerType, true) :: Nil), isSupported = false) - // UDTs are not supported right now. - checkSupported(new MyDenseVectorUDT, isSupported = false) -} - -abstract class SparkSqlSerializer2Suite extends QueryTest with BeforeAndAfterAll { - var allColumns: String = _ - val serializerClass: Class[Serializer] = - classOf[SparkSqlSerializer2].asInstanceOf[Class[Serializer]] - var numShufflePartitions: Int = _ - var useSerializer2: Boolean = _ - - protected lazy val ctx = TestSQLContext - - override def beforeAll(): Unit = { - numShufflePartitions = ctx.conf.numShufflePartitions - useSerializer2 = ctx.conf.useSqlSerializer2 - - ctx.sql("set spark.sql.useSerializer2=true") - - val supportedTypes = - Seq(StringType, BinaryType, NullType, BooleanType, - ByteType, ShortType, IntegerType, LongType, - FloatType, DoubleType, DecimalType.SYSTEM_DEFAULT, DecimalType(6, 5), - DateType, TimestampType) - - val fields = supportedTypes.zipWithIndex.map { case (dataType, index) => - StructField(s"col$index", dataType, true) - } - allColumns = fields.map(_.name).mkString(",") - val schema = StructType(fields) - - // Create a RDD with all data types supported by SparkSqlSerializer2. 
- val rdd = - ctx.sparkContext.parallelize((1 to 1000), 10).map { i => - Row( - s"str${i}: test serializer2.", - s"binary${i}: test serializer2.".getBytes("UTF-8"), - null, - i % 2 == 0, - i.toByte, - i.toShort, - i, - Long.MaxValue - i.toLong, - (i + 0.25).toFloat, - (i + 0.75), - BigDecimal(Long.MaxValue.toString + ".12345"), - new java.math.BigDecimal(s"${i % 9 + 1}" + ".23456"), - new Date(i), - new Timestamp(i)) - } - - ctx.createDataFrame(rdd, schema).registerTempTable("shuffle") - - super.beforeAll() - } - - override def afterAll(): Unit = { - ctx.dropTempTable("shuffle") - ctx.sql(s"set spark.sql.shuffle.partitions=$numShufflePartitions") - ctx.sql(s"set spark.sql.useSerializer2=$useSerializer2") - super.afterAll() - } - - def checkSerializer[T <: Serializer]( - executedPlan: SparkPlan, - expectedSerializerClass: Class[T]): Unit = { - executedPlan.foreach { - case exchange: Exchange => - val shuffledRDD = exchange.execute() - val dependency = shuffledRDD.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]] - val serializerNotSetMessage = - s"Expected $expectedSerializerClass as the serializer of Exchange. " + - s"However, the serializer was not set." - val serializer = dependency.serializer.getOrElse(fail(serializerNotSetMessage)) - val isExpectedSerializer = - serializer.getClass == expectedSerializerClass || - serializer.getClass == classOf[UnsafeRowSerializer] - val wrongSerializerErrorMessage = - s"Expected ${expectedSerializerClass.getCanonicalName} or " + - s"${classOf[UnsafeRowSerializer].getCanonicalName}. But " + - s"${serializer.getClass.getCanonicalName} is used." - assert(isExpectedSerializer, wrongSerializerErrorMessage) - case _ => // Ignore other nodes. - } - } - - test("key schema and value schema are not nulls") { - val df = ctx.sql(s"SELECT DISTINCT ${allColumns} FROM shuffle") - checkSerializer(df.queryExecution.executedPlan, serializerClass) - checkAnswer( - df, - ctx.table("shuffle").collect()) - } - - test("key schema is null") { - val aggregations = allColumns.split(",").map(c => s"COUNT($c)").mkString(",") - val df = ctx.sql(s"SELECT $aggregations FROM shuffle") - checkSerializer(df.queryExecution.executedPlan, serializerClass) - checkAnswer( - df, - Row(1000, 1000, 0, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000)) - } - - test("value schema is null") { - val df = ctx.sql(s"SELECT col0 FROM shuffle ORDER BY col0") - checkSerializer(df.queryExecution.executedPlan, serializerClass) - assert(df.map(r => r.getString(0)).collect().toSeq === - ctx.table("shuffle").select("col0").map(r => r.getString(0)).collect().sorted.toSeq) - } - - test("no map output field") { - val df = ctx.sql(s"SELECT 1 + 1 FROM shuffle") - checkSerializer(df.queryExecution.executedPlan, classOf[SparkSqlSerializer]) - } - - test("types of fields are all NullTypes") { - // Test range partitioning code path. - val nulls = ctx.sql(s"SELECT null as a, null as b, null as c") - val df = nulls.unionAll(nulls).sort("a") - checkSerializer(df.queryExecution.executedPlan, classOf[SparkSqlSerializer]) - checkAnswer( - df, - Row(null, null, null) :: Row(null, null, null) :: Nil) - - // Test hash partitioning code path. - val oneRow = ctx.sql(s"SELECT DISTINCT null, null, null FROM shuffle") - checkSerializer(oneRow.queryExecution.executedPlan, classOf[SparkSqlSerializer]) - checkAnswer( - oneRow, - Row(null, null, null)) - } -} - -/** Tests SparkSqlSerializer2 with sort based shuffle without sort merge. 
*/ -class SparkSqlSerializer2SortShuffleSuite extends SparkSqlSerializer2Suite { - override def beforeAll(): Unit = { - super.beforeAll() - // Sort merge will not be triggered. - val bypassMergeThreshold = - ctx.sparkContext.conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200) - ctx.sql(s"set spark.sql.shuffle.partitions=${bypassMergeThreshold-1}") - } -} - -/** For now, we will use SparkSqlSerializer for sort based shuffle with sort merge. */ -class SparkSqlSerializer2SortMergeShuffleSuite extends SparkSqlSerializer2Suite { - - override def beforeAll(): Unit = { - super.beforeAll() - // To trigger the sort merge. - val bypassMergeThreshold = - ctx.sparkContext.conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200) - ctx.sql(s"set spark.sql.shuffle.partitions=${bypassMergeThreshold + 1}") - } -} From 2432c2e239f66049a7a7d7e0591204abcc993f1a Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 7 Aug 2015 11:28:43 -0700 Subject: [PATCH 0117/1824] [SPARK-8382] [SQL] Improve Analysis Unit test framework Author: Wenchen Fan Closes #8025 from cloud-fan/analysis and squashes the following commits: 51461b1 [Wenchen Fan] move test file to test folder ec88ace [Wenchen Fan] Improve Analysis Unit test framework --- .../analysis/AnalysisErrorSuite.scala | 48 +++++----------- .../sql/catalyst/analysis/AnalysisSuite.scala | 55 +------------------ .../sql/catalyst/analysis/AnalysisTest.scala | 33 +---------- .../sql/catalyst/analysis/TestRelations.scala | 51 +++++++++++++++++ .../BooleanSimplificationSuite.scala | 19 ++++--- 5 files changed, 79 insertions(+), 127 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TestRelations.scala diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index 26935c6e3b24..63b475b6366c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -42,8 +42,8 @@ case class UnresolvedTestPlan() extends LeafNode { override def output: Seq[Attribute] = Nil } -class AnalysisErrorSuite extends SparkFunSuite with BeforeAndAfter { - import AnalysisSuite._ +class AnalysisErrorSuite extends AnalysisTest with BeforeAndAfter { + import TestRelations._ def errorTest( name: String, @@ -51,15 +51,7 @@ class AnalysisErrorSuite extends SparkFunSuite with BeforeAndAfter { errorMessages: Seq[String], caseSensitive: Boolean = true): Unit = { test(name) { - val error = intercept[AnalysisException] { - if (caseSensitive) { - caseSensitiveAnalyze(plan) - } else { - caseInsensitiveAnalyze(plan) - } - } - - errorMessages.foreach(m => assert(error.getMessage.toLowerCase.contains(m.toLowerCase))) + assertAnalysisError(plan, errorMessages, caseSensitive) } } @@ -69,21 +61,21 @@ class AnalysisErrorSuite extends SparkFunSuite with BeforeAndAfter { "single invalid type, single arg", testRelation.select(TestFunction(dateLit :: Nil, IntegerType :: Nil).as('a)), "cannot resolve" :: "testfunction" :: "argument 1" :: "requires int type" :: - "'null' is of date type" ::Nil) + "'null' is of date type" :: Nil) errorTest( "single invalid type, second arg", testRelation.select( TestFunction(dateLit :: dateLit :: Nil, DateType :: IntegerType :: Nil).as('a)), "cannot resolve" :: "testfunction" :: "argument 2" :: "requires int type" :: - "'null' is of date type" ::Nil) + 
"'null' is of date type" :: Nil) errorTest( "multiple invalid type", testRelation.select( TestFunction(dateLit :: dateLit :: Nil, IntegerType :: IntegerType :: Nil).as('a)), "cannot resolve" :: "testfunction" :: "argument 1" :: "argument 2" :: - "requires int type" :: "'null' is of date type" ::Nil) + "requires int type" :: "'null' is of date type" :: Nil) errorTest( "unresolved window function", @@ -169,11 +161,7 @@ class AnalysisErrorSuite extends SparkFunSuite with BeforeAndAfter { assert(plan.resolved) - val message = intercept[AnalysisException] { - caseSensitiveAnalyze(plan) - }.getMessage - - assert(message.contains("resolved attribute(s) a#1 missing from a#2")) + assertAnalysisError(plan, "resolved attribute(s) a#1 missing from a#2" :: Nil) } test("error test for self-join") { @@ -194,10 +182,8 @@ class AnalysisErrorSuite extends SparkFunSuite with BeforeAndAfter { AttributeReference("a", BinaryType)(exprId = ExprId(2)), AttributeReference("b", IntegerType)(exprId = ExprId(1)))) - val error = intercept[AnalysisException] { - caseSensitiveAnalyze(plan) - } - assert(error.message.contains("binary type expression a cannot be used in grouping expression")) + assertAnalysisError(plan, + "binary type expression a cannot be used in grouping expression" :: Nil) val plan2 = Aggregate( @@ -207,10 +193,8 @@ class AnalysisErrorSuite extends SparkFunSuite with BeforeAndAfter { AttributeReference("a", MapType(IntegerType, StringType))(exprId = ExprId(2)), AttributeReference("b", IntegerType)(exprId = ExprId(1)))) - val error2 = intercept[AnalysisException] { - caseSensitiveAnalyze(plan2) - } - assert(error2.message.contains("map type expression a cannot be used in grouping expression")) + assertAnalysisError(plan2, + "map type expression a cannot be used in grouping expression" :: Nil) } test("Join can't work on binary and map types") { @@ -226,10 +210,7 @@ class AnalysisErrorSuite extends SparkFunSuite with BeforeAndAfter { Some(EqualTo(AttributeReference("a", BinaryType)(exprId = ExprId(2)), AttributeReference("c", BinaryType)(exprId = ExprId(4))))) - val error = intercept[AnalysisException] { - caseSensitiveAnalyze(plan) - } - assert(error.message.contains("binary type expression a cannot be used in join conditions")) + assertAnalysisError(plan, "binary type expression a cannot be used in join conditions" :: Nil) val plan2 = Join( @@ -243,9 +224,6 @@ class AnalysisErrorSuite extends SparkFunSuite with BeforeAndAfter { Some(EqualTo(AttributeReference("a", MapType(IntegerType, StringType))(exprId = ExprId(2)), AttributeReference("c", MapType(IntegerType, StringType))(exprId = ExprId(4))))) - val error2 = intercept[AnalysisException] { - caseSensitiveAnalyze(plan2) - } - assert(error2.message.contains("map type expression a cannot be used in join conditions")) + assertAnalysisError(plan2, "map type expression a cannot be used in join conditions" :: Nil) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 221b4e92f086..c944bc69e25b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -24,61 +24,8 @@ import org.apache.spark.sql.catalyst.SimpleCatalystConf import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ -// todo: remove this and use AnalysisTest instead. 
-object AnalysisSuite { - val caseSensitiveConf = new SimpleCatalystConf(true) - val caseInsensitiveConf = new SimpleCatalystConf(false) - - val caseSensitiveCatalog = new SimpleCatalog(caseSensitiveConf) - val caseInsensitiveCatalog = new SimpleCatalog(caseInsensitiveConf) - - val caseSensitiveAnalyzer = - new Analyzer(caseSensitiveCatalog, EmptyFunctionRegistry, caseSensitiveConf) { - override val extendedResolutionRules = EliminateSubQueries :: Nil - } - val caseInsensitiveAnalyzer = - new Analyzer(caseInsensitiveCatalog, EmptyFunctionRegistry, caseInsensitiveConf) { - override val extendedResolutionRules = EliminateSubQueries :: Nil - } - - def caseSensitiveAnalyze(plan: LogicalPlan): Unit = - caseSensitiveAnalyzer.checkAnalysis(caseSensitiveAnalyzer.execute(plan)) - - def caseInsensitiveAnalyze(plan: LogicalPlan): Unit = - caseInsensitiveAnalyzer.checkAnalysis(caseInsensitiveAnalyzer.execute(plan)) - - val testRelation = LocalRelation(AttributeReference("a", IntegerType, nullable = true)()) - val testRelation2 = LocalRelation( - AttributeReference("a", StringType)(), - AttributeReference("b", StringType)(), - AttributeReference("c", DoubleType)(), - AttributeReference("d", DecimalType(10, 2))(), - AttributeReference("e", ShortType)()) - - val nestedRelation = LocalRelation( - AttributeReference("top", StructType( - StructField("duplicateField", StringType) :: - StructField("duplicateField", StringType) :: - StructField("differentCase", StringType) :: - StructField("differentcase", StringType) :: Nil - ))()) - - val nestedRelation2 = LocalRelation( - AttributeReference("top", StructType( - StructField("aField", StringType) :: - StructField("bField", StringType) :: - StructField("cField", StringType) :: Nil - ))()) - - val listRelation = LocalRelation( - AttributeReference("list", ArrayType(IntegerType))()) - - caseSensitiveCatalog.registerTable(Seq("TaBlE"), testRelation) - caseInsensitiveCatalog.registerTable(Seq("TaBlE"), testRelation) -} - - class AnalysisSuite extends AnalysisTest { + import TestRelations._ test("union project *") { val plan = (1 to 100) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala index fdb4f28950da..ee1f8f54251e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala @@ -17,40 +17,11 @@ package org.apache.spark.sql.catalyst.analysis -import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.SimpleCatalystConf -import org.apache.spark.sql.types._ trait AnalysisTest extends PlanTest { - val testRelation = LocalRelation(AttributeReference("a", IntegerType, nullable = true)()) - - val testRelation2 = LocalRelation( - AttributeReference("a", StringType)(), - AttributeReference("b", StringType)(), - AttributeReference("c", DoubleType)(), - AttributeReference("d", DecimalType(10, 2))(), - AttributeReference("e", ShortType)()) - - val nestedRelation = LocalRelation( - AttributeReference("top", StructType( - StructField("duplicateField", StringType) :: - StructField("duplicateField", StringType) :: - StructField("differentCase", StringType) :: - StructField("differentcase", StringType) :: Nil - ))()) - - val 
nestedRelation2 = LocalRelation( - AttributeReference("top", StructType( - StructField("aField", StringType) :: - StructField("bField", StringType) :: - StructField("cField", StringType) :: Nil - ))()) - - val listRelation = LocalRelation( - AttributeReference("list", ArrayType(IntegerType))()) val (caseSensitiveAnalyzer, caseInsensitiveAnalyzer) = { val caseSensitiveConf = new SimpleCatalystConf(true) @@ -59,8 +30,8 @@ trait AnalysisTest extends PlanTest { val caseSensitiveCatalog = new SimpleCatalog(caseSensitiveConf) val caseInsensitiveCatalog = new SimpleCatalog(caseInsensitiveConf) - caseSensitiveCatalog.registerTable(Seq("TaBlE"), testRelation) - caseInsensitiveCatalog.registerTable(Seq("TaBlE"), testRelation) + caseSensitiveCatalog.registerTable(Seq("TaBlE"), TestRelations.testRelation) + caseInsensitiveCatalog.registerTable(Seq("TaBlE"), TestRelations.testRelation) new Analyzer(caseSensitiveCatalog, EmptyFunctionRegistry, caseSensitiveConf) { override val extendedResolutionRules = EliminateSubQueries :: Nil diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TestRelations.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TestRelations.scala new file mode 100644 index 000000000000..05b870705e7e --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TestRelations.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.catalyst.expressions.AttributeReference +import org.apache.spark.sql.catalyst.plans.logical.LocalRelation +import org.apache.spark.sql.types._ + +object TestRelations { + val testRelation = LocalRelation(AttributeReference("a", IntegerType, nullable = true)()) + + val testRelation2 = LocalRelation( + AttributeReference("a", StringType)(), + AttributeReference("b", StringType)(), + AttributeReference("c", DoubleType)(), + AttributeReference("d", DecimalType(10, 2))(), + AttributeReference("e", ShortType)()) + + val nestedRelation = LocalRelation( + AttributeReference("top", StructType( + StructField("duplicateField", StringType) :: + StructField("duplicateField", StringType) :: + StructField("differentCase", StringType) :: + StructField("differentcase", StringType) :: Nil + ))()) + + val nestedRelation2 = LocalRelation( + AttributeReference("top", StructType( + StructField("aField", StringType) :: + StructField("bField", StringType) :: + StructField("cField", StringType) :: Nil + ))()) + + val listRelation = LocalRelation( + AttributeReference("list", ArrayType(IntegerType))()) +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala index d4916ea8d273..1877cff1334b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala @@ -17,7 +17,8 @@ package org.apache.spark.sql.catalyst.optimizer -import org.apache.spark.sql.catalyst.analysis.{AnalysisSuite, EliminateSubQueries} +import org.apache.spark.sql.catalyst.SimpleCatalystConf +import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.PlanTest @@ -88,20 +89,24 @@ class BooleanSimplificationSuite extends PlanTest with PredicateHelper { ('a === 'b || 'b > 3 && 'a > 3 && 'a < 5)) } - private def caseInsensitiveAnalyse(plan: LogicalPlan) = - AnalysisSuite.caseInsensitiveAnalyzer.execute(plan) + private val caseInsensitiveAnalyzer = + new Analyzer(EmptyCatalog, EmptyFunctionRegistry, new SimpleCatalystConf(false)) test("(a && b) || (a && c) => a && (b || c) when case insensitive") { - val plan = caseInsensitiveAnalyse(testRelation.where(('a > 2 && 'b > 3) || ('A > 2 && 'b < 5))) + val plan = caseInsensitiveAnalyzer.execute( + testRelation.where(('a > 2 && 'b > 3) || ('A > 2 && 'b < 5))) val actual = Optimize.execute(plan) - val expected = caseInsensitiveAnalyse(testRelation.where('a > 2 && ('b > 3 || 'b < 5))) + val expected = caseInsensitiveAnalyzer.execute( + testRelation.where('a > 2 && ('b > 3 || 'b < 5))) comparePlans(actual, expected) } test("(a || b) && (a || c) => a || (b && c) when case insensitive") { - val plan = caseInsensitiveAnalyse(testRelation.where(('a > 2 || 'b > 3) && ('A > 2 || 'b < 5))) + val plan = caseInsensitiveAnalyzer.execute( + testRelation.where(('a > 2 || 'b > 3) && ('A > 2 || 'b < 5))) val actual = Optimize.execute(plan) - val expected = caseInsensitiveAnalyse(testRelation.where('a > 2 || ('b > 3 && 'b < 5))) + val expected = caseInsensitiveAnalyzer.execute( + testRelation.where('a > 2 || ('b > 3 && 'b < 5))) comparePlans(actual, expected) } } From 
9897cc5e3d6c70f7e45e887e2c6fc24dfa1adada Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 7 Aug 2015 11:29:13 -0700 Subject: [PATCH 0118/1824] [SPARK-9736] [SQL] JoinedRow.anyNull should delegate to the underlying rows. JoinedRow.anyNull currently loops through every field to check for null, which is inefficient if the underlying rows are UnsafeRows. It should just delegate to the underlying implementation. Author: Reynold Xin Closes #8027 from rxin/SPARK-9736 and squashes the following commits: 03a2e92 [Reynold Xin] Include all files. 90f1add [Reynold Xin] [SPARK-9736][SQL] JoinedRow.anyNull should delegate to the underlying rows. --- .../spark/sql/catalyst/InternalRow.scala | 10 +- .../sql/catalyst/expressions/JoinedRow.scala | 144 ++++++++++++++++++ .../sql/catalyst/expressions/Projection.scala | 119 --------------- .../spark/sql/catalyst/expressions/rows.scala | 12 +- 4 files changed, 156 insertions(+), 129 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/JoinedRow.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala index 85b4bf3b6aef..eba95c5c8b90 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala @@ -37,15 +37,7 @@ abstract class InternalRow extends SpecializedGetters with Serializable { def copy(): InternalRow /** Returns true if there are any NULL values in this row. */ - def anyNull: Boolean = { - val len = numFields - var i = 0 - while (i < len) { - if (isNullAt(i)) { return true } - i += 1 - } - false - } + def anyNull: Boolean /* ---------------------- utility methods for Scala ---------------------- */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/JoinedRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/JoinedRow.scala new file mode 100644 index 000000000000..b76757c93523 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/JoinedRow.scala @@ -0,0 +1,144 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} + + +/** + * A mutable wrapper that makes two rows appear as a single concatenated row. Designed to + * be instantiated once per thread and reused. 
+ */ +class JoinedRow extends InternalRow { + private[this] var row1: InternalRow = _ + private[this] var row2: InternalRow = _ + + def this(left: InternalRow, right: InternalRow) = { + this() + row1 = left + row2 = right + } + + /** Updates this JoinedRow to used point at two new base rows. Returns itself. */ + def apply(r1: InternalRow, r2: InternalRow): InternalRow = { + row1 = r1 + row2 = r2 + this + } + + /** Updates this JoinedRow by updating its left base row. Returns itself. */ + def withLeft(newLeft: InternalRow): InternalRow = { + row1 = newLeft + this + } + + /** Updates this JoinedRow by updating its right base row. Returns itself. */ + def withRight(newRight: InternalRow): InternalRow = { + row2 = newRight + this + } + + override def toSeq(fieldTypes: Seq[DataType]): Seq[Any] = { + assert(fieldTypes.length == row1.numFields + row2.numFields) + val (left, right) = fieldTypes.splitAt(row1.numFields) + row1.toSeq(left) ++ row2.toSeq(right) + } + + override def numFields: Int = row1.numFields + row2.numFields + + override def get(i: Int, dt: DataType): AnyRef = + if (i < row1.numFields) row1.get(i, dt) else row2.get(i - row1.numFields, dt) + + override def isNullAt(i: Int): Boolean = + if (i < row1.numFields) row1.isNullAt(i) else row2.isNullAt(i - row1.numFields) + + override def getBoolean(i: Int): Boolean = + if (i < row1.numFields) row1.getBoolean(i) else row2.getBoolean(i - row1.numFields) + + override def getByte(i: Int): Byte = + if (i < row1.numFields) row1.getByte(i) else row2.getByte(i - row1.numFields) + + override def getShort(i: Int): Short = + if (i < row1.numFields) row1.getShort(i) else row2.getShort(i - row1.numFields) + + override def getInt(i: Int): Int = + if (i < row1.numFields) row1.getInt(i) else row2.getInt(i - row1.numFields) + + override def getLong(i: Int): Long = + if (i < row1.numFields) row1.getLong(i) else row2.getLong(i - row1.numFields) + + override def getFloat(i: Int): Float = + if (i < row1.numFields) row1.getFloat(i) else row2.getFloat(i - row1.numFields) + + override def getDouble(i: Int): Double = + if (i < row1.numFields) row1.getDouble(i) else row2.getDouble(i - row1.numFields) + + override def getDecimal(i: Int, precision: Int, scale: Int): Decimal = { + if (i < row1.numFields) { + row1.getDecimal(i, precision, scale) + } else { + row2.getDecimal(i - row1.numFields, precision, scale) + } + } + + override def getUTF8String(i: Int): UTF8String = + if (i < row1.numFields) row1.getUTF8String(i) else row2.getUTF8String(i - row1.numFields) + + override def getBinary(i: Int): Array[Byte] = + if (i < row1.numFields) row1.getBinary(i) else row2.getBinary(i - row1.numFields) + + override def getArray(i: Int): ArrayData = + if (i < row1.numFields) row1.getArray(i) else row2.getArray(i - row1.numFields) + + override def getInterval(i: Int): CalendarInterval = + if (i < row1.numFields) row1.getInterval(i) else row2.getInterval(i - row1.numFields) + + override def getMap(i: Int): MapData = + if (i < row1.numFields) row1.getMap(i) else row2.getMap(i - row1.numFields) + + override def getStruct(i: Int, numFields: Int): InternalRow = { + if (i < row1.numFields) { + row1.getStruct(i, numFields) + } else { + row2.getStruct(i - row1.numFields, numFields) + } + } + + override def anyNull: Boolean = row1.anyNull || row2.anyNull + + override def copy(): InternalRow = { + val copy1 = row1.copy() + val copy2 = row2.copy() + new JoinedRow(copy1, copy2) + } + + override def toString: String = { + // Make sure toString never throws NullPointerException. 
+ if ((row1 eq null) && (row2 eq null)) { + "[ empty row ]" + } else if (row1 eq null) { + row2.toString + } else if (row2 eq null) { + row1.toString + } else { + s"{${row1.toString} + ${row2.toString}}" + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index 59ce7fc4f2c6..796bc327a3db 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -169,122 +169,3 @@ object FromUnsafeProjection { GenerateSafeProjection.generate(exprs) } } - -/** - * A mutable wrapper that makes two rows appear as a single concatenated row. Designed to - * be instantiated once per thread and reused. - */ -class JoinedRow extends InternalRow { - private[this] var row1: InternalRow = _ - private[this] var row2: InternalRow = _ - - def this(left: InternalRow, right: InternalRow) = { - this() - row1 = left - row2 = right - } - - /** Updates this JoinedRow to used point at two new base rows. Returns itself. */ - def apply(r1: InternalRow, r2: InternalRow): InternalRow = { - row1 = r1 - row2 = r2 - this - } - - /** Updates this JoinedRow by updating its left base row. Returns itself. */ - def withLeft(newLeft: InternalRow): InternalRow = { - row1 = newLeft - this - } - - /** Updates this JoinedRow by updating its right base row. Returns itself. */ - def withRight(newRight: InternalRow): InternalRow = { - row2 = newRight - this - } - - override def toSeq(fieldTypes: Seq[DataType]): Seq[Any] = { - assert(fieldTypes.length == row1.numFields + row2.numFields) - val (left, right) = fieldTypes.splitAt(row1.numFields) - row1.toSeq(left) ++ row2.toSeq(right) - } - - override def numFields: Int = row1.numFields + row2.numFields - - override def get(i: Int, dt: DataType): AnyRef = - if (i < row1.numFields) row1.get(i, dt) else row2.get(i - row1.numFields, dt) - - override def isNullAt(i: Int): Boolean = - if (i < row1.numFields) row1.isNullAt(i) else row2.isNullAt(i - row1.numFields) - - override def getBoolean(i: Int): Boolean = - if (i < row1.numFields) row1.getBoolean(i) else row2.getBoolean(i - row1.numFields) - - override def getByte(i: Int): Byte = - if (i < row1.numFields) row1.getByte(i) else row2.getByte(i - row1.numFields) - - override def getShort(i: Int): Short = - if (i < row1.numFields) row1.getShort(i) else row2.getShort(i - row1.numFields) - - override def getInt(i: Int): Int = - if (i < row1.numFields) row1.getInt(i) else row2.getInt(i - row1.numFields) - - override def getLong(i: Int): Long = - if (i < row1.numFields) row1.getLong(i) else row2.getLong(i - row1.numFields) - - override def getFloat(i: Int): Float = - if (i < row1.numFields) row1.getFloat(i) else row2.getFloat(i - row1.numFields) - - override def getDouble(i: Int): Double = - if (i < row1.numFields) row1.getDouble(i) else row2.getDouble(i - row1.numFields) - - override def getDecimal(i: Int, precision: Int, scale: Int): Decimal = { - if (i < row1.numFields) { - row1.getDecimal(i, precision, scale) - } else { - row2.getDecimal(i - row1.numFields, precision, scale) - } - } - - override def getUTF8String(i: Int): UTF8String = - if (i < row1.numFields) row1.getUTF8String(i) else row2.getUTF8String(i - row1.numFields) - - override def getBinary(i: Int): Array[Byte] = - if (i < row1.numFields) row1.getBinary(i) else row2.getBinary(i - row1.numFields) - - override def 
getArray(i: Int): ArrayData = - if (i < row1.numFields) row1.getArray(i) else row2.getArray(i - row1.numFields) - - override def getInterval(i: Int): CalendarInterval = - if (i < row1.numFields) row1.getInterval(i) else row2.getInterval(i - row1.numFields) - - override def getMap(i: Int): MapData = - if (i < row1.numFields) row1.getMap(i) else row2.getMap(i - row1.numFields) - - override def getStruct(i: Int, numFields: Int): InternalRow = { - if (i < row1.numFields) { - row1.getStruct(i, numFields) - } else { - row2.getStruct(i - row1.numFields, numFields) - } - } - - override def copy(): InternalRow = { - val copy1 = row1.copy() - val copy2 = row2.copy() - new JoinedRow(copy1, copy2) - } - - override def toString: String = { - // Make sure toString never throws NullPointerException. - if ((row1 eq null) && (row2 eq null)) { - "[ empty row ]" - } else if (row1 eq null) { - row2.toString - } else if (row2 eq null) { - row1.toString - } else { - s"{${row1.toString} + ${row2.toString}}" - } - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala index 11d10b2d8a48..017efd2a166a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala @@ -49,7 +49,17 @@ trait BaseGenericInternalRow extends InternalRow { override def getMap(ordinal: Int): MapData = getAs(ordinal) override def getStruct(ordinal: Int, numFields: Int): InternalRow = getAs(ordinal) - override def toString(): String = { + override def anyNull: Boolean = { + val len = numFields + var i = 0 + while (i < len) { + if (isNullAt(i)) { return true } + i += 1 + } + false + } + + override def toString: String = { if (numFields == 0) { "[empty row]" } else { From aeddeafc03d77a5149d2c8f9489b0ca83e6b3e03 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 7 Aug 2015 13:26:03 -0700 Subject: [PATCH 0119/1824] [SPARK-9667][SQL] followup: Use GenerateUnsafeProjection.canSupport to test Exchange supported data types. This way we recursively test the data types. cc chenghao-intel Author: Reynold Xin Closes #8036 from rxin/cansupport and squashes the following commits: f7302ff [Reynold Xin] Can GenerateUnsafeProjection.canSupport to test Exchange supported data types. 
--- .../org/apache/spark/sql/execution/Exchange.scala | 15 ++++----------- 1 file changed, 4 insertions(+), 11 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index 60087f2ca4a3..49bb72980086 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -27,9 +27,9 @@ import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors.attachTree import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.types.UserDefinedType import org.apache.spark.util.MutablePair import org.apache.spark.{HashPartitioner, Partitioner, RangePartitioner, SparkEnv} @@ -43,18 +43,11 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una override def nodeName: String = if (tungstenMode) "TungstenExchange" else "Exchange" /** - * Returns true iff the children outputs aggregate UDTs that are not part of the SQL type. - * This only happens with the old aggregate implementation and should be removed in 1.6. + * Returns true iff we can support the data type, and we are not doing range partitioning. */ private lazy val tungstenMode: Boolean = { - val unserializableUDT = child.schema.exists(_.dataType match { - case _: UserDefinedType[_] => true - case _ => false - }) - // Do not use the Unsafe path if we are using a RangePartitioning, since this may lead to - // an interpreted RowOrdering being applied to an UnsafeRow, which will lead to - // ClassCastExceptions at runtime. This check can be removed after SPARK-9054 is fixed. - !unserializableUDT && !newPartitioning.isInstanceOf[RangePartitioning] + GenerateUnsafeProjection.canSupport(child.schema) && + !newPartitioning.isInstanceOf[RangePartitioning] } override def outputPartitioning: Partitioning = newPartitioning From 05d04e10a8ea030bea840c3c5ba93ecac479a039 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 7 Aug 2015 13:41:45 -0700 Subject: [PATCH 0120/1824] [SPARK-9733][SQL] Improve physical plan explain for data sources All data sources show up as "PhysicalRDD" in physical plan explain. It'd be better if we can show the name of the data source. 
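The before/after plans below show the effect. As a toy sketch of the idea (not Spark's actual classes), a leaf scan node can carry a description of its relation and surface it through `simpleString`, so `explain()` prints the data source name instead of a generic node name:

```scala
// Illustrative only: a leaf scan node keeps an extra description of its data
// source and exposes it via simpleString, mirroring the PhysicalRDD change below.
case class ScanNode(output: Seq[String], extraInformation: String) {
  def simpleString: String = "Scan " + extraInformation + output.mkString("[", ",", "]")
}

// ScanNode(Seq("date#0", "cat#1", "count#2"),
//          "ParquetRelation[file:/scratch/rxin/spark/sales4]").simpleString
// => "Scan ParquetRelation[file:/scratch/rxin/spark/sales4][date#0,cat#1,count#2]"
```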
Without this patch: ``` == Physical Plan == NewAggregate with UnsafeHybridAggregationIterator ArrayBuffer(date#0, cat#1) ArrayBuffer((sum(CAST((CAST(count#2, IntegerType) + 1), LongType))2,mode=Final,isDistinct=false)) Exchange hashpartitioning(date#0,cat#1) NewAggregate with UnsafeHybridAggregationIterator ArrayBuffer(date#0, cat#1) ArrayBuffer((sum(CAST((CAST(count#2, IntegerType) + 1), LongType))2,mode=Partial,isDistinct=false)) PhysicalRDD [date#0,cat#1,count#2], MapPartitionsRDD[3] at ``` With this patch: ``` == Physical Plan == TungstenAggregate(key=[date#0,cat#1], value=[(sum(CAST((CAST(count#2, IntegerType) + 1), LongType)),mode=Final,isDistinct=false)] Exchange hashpartitioning(date#0,cat#1) TungstenAggregate(key=[date#0,cat#1], value=[(sum(CAST((CAST(count#2, IntegerType) + 1), LongType)),mode=Partial,isDistinct=false)] ConvertToUnsafe Scan ParquetRelation[file:/scratch/rxin/spark/sales4][date#0,cat#1,count#2] ``` Author: Reynold Xin Closes #8024 from rxin/SPARK-9733 and squashes the following commits: 811b90e [Reynold Xin] Fixed Python test case. 52cab77 [Reynold Xin] Cast. eea9ccc [Reynold Xin] Fix test case. fcecb22 [Reynold Xin] [SPARK-9733][SQL] Improve explain message for data source scan node. --- python/pyspark/sql/dataframe.py | 4 +--- .../spark/sql/catalyst/expressions/Cast.scala | 4 ++-- .../expressions/aggregate/interfaces.scala | 2 +- .../org/apache/spark/sql/SQLContext.scala | 4 ---- .../spark/sql/execution/ExistingRDD.scala | 15 ++++++++++++- .../spark/sql/execution/SparkStrategies.scala | 4 ++-- .../aggregate/TungstenAggregate.scala | 9 +++++--- .../datasources/DataSourceStrategy.scala | 22 +++++++++++++------ .../apache/spark/sql/sources/interfaces.scala | 2 +- .../execution/RowFormatConvertersSuite.scala | 4 ++-- .../sql/hive/execution/HiveExplainSuite.scala | 2 +- 11 files changed, 45 insertions(+), 27 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 0f3480c23918..47d5a6a43a84 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -212,8 +212,7 @@ def explain(self, extended=False): :param extended: boolean, default ``False``. If ``False``, prints only the physical plan. >>> df.explain() - PhysicalRDD [age#0,name#1], MapPartitionsRDD[...] at applySchemaToPythonRDD at\ - NativeMethodAccessorImpl.java:... + Scan PhysicalRDD[age#0,name#1] >>> df.explain(True) == Parsed Logical Plan == @@ -224,7 +223,6 @@ def explain(self, extended=False): ... == Physical Plan == ... 
- == RDD == """ if extended: print(self._jdf.queryExecution().toString()) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 39f99700c8a2..946c5a9c04f1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -107,6 +107,8 @@ object Cast { case class Cast(child: Expression, dataType: DataType) extends UnaryExpression with CodegenFallback { + override def toString: String = s"cast($child as ${dataType.simpleString})" + override def checkInputDataTypes(): TypeCheckResult = { if (Cast.canCast(child.dataType, dataType)) { TypeCheckResult.TypeCheckSuccess @@ -118,8 +120,6 @@ case class Cast(child: Expression, dataType: DataType) override def nullable: Boolean = Cast.forceNullable(child.dataType, dataType) || child.nullable - override def toString: String = s"CAST($child, $dataType)" - // [[func]] assumes the input is no longer null because eval already does the null check. @inline private[this] def buildCast[T](a: Any, func: T => Any): Any = func(a.asInstanceOf[T]) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala index 4abfdfe87d5e..576d8c7a3a68 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala @@ -93,7 +93,7 @@ private[sql] case class AggregateExpression2( AttributeSet(childReferences) } - override def toString: String = s"(${aggregateFunction}2,mode=$mode,isDistinct=$isDistinct)" + override def toString: String = s"(${aggregateFunction},mode=$mode,isDistinct=$isDistinct)" } abstract class AggregateFunction2 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 075c0ea2544b..832572571cab 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -1011,9 +1011,6 @@ class SQLContext(@transient val sparkContext: SparkContext) def output = analyzed.output.map(o => s"${o.name}: ${o.dataType.simpleString}").mkString(", ") - // TODO previously will output RDD details by run (${stringOrError(toRdd.toDebugString)}) - // however, the `toRdd` will cause the real execution, which is not what we want. - // We need to think about how to avoid the side effect. 
s"""== Parsed Logical Plan == |${stringOrError(logical)} |== Analyzed Logical Plan == @@ -1024,7 +1021,6 @@ class SQLContext(@transient val sparkContext: SparkContext) |== Physical Plan == |${stringOrError(executedPlan)} |Code Generation: ${stringOrError(executedPlan.codegenEnabled)} - |== RDD == """.stripMargin.trim } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index fbaa8e276ddb..cae7ca5cbdc8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.{InternalRow, CatalystTypeConverters} import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericMutableRow} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics} +import org.apache.spark.sql.sources.BaseRelation import org.apache.spark.sql.types.DataType import org.apache.spark.sql.{Row, SQLContext} @@ -95,11 +96,23 @@ private[sql] case class LogicalRDD( /** Physical plan node for scanning data from an RDD. */ private[sql] case class PhysicalRDD( output: Seq[Attribute], - rdd: RDD[InternalRow]) extends LeafNode { + rdd: RDD[InternalRow], + extraInformation: String) extends LeafNode { override protected[sql] val trackNumOfRowsEnabled = true protected override def doExecute(): RDD[InternalRow] = rdd + + override def simpleString: String = "Scan " + extraInformation + output.mkString("[", ",", "]") +} + +private[sql] object PhysicalRDD { + def createFromDataSource( + output: Seq[Attribute], + rdd: RDD[InternalRow], + relation: BaseRelation): PhysicalRDD = { + PhysicalRDD(output, rdd, relation.toString) + } } /** Logical plan node for scanning data from a local collection. 
*/ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index c5aaebe67322..c4b9b5acea4d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -363,12 +363,12 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { execution.Generate( generator, join = join, outer = outer, g.output, planLater(child)) :: Nil case logical.OneRowRelation => - execution.PhysicalRDD(Nil, singleRowRdd) :: Nil + execution.PhysicalRDD(Nil, singleRowRdd, "OneRowRelation") :: Nil case logical.RepartitionByExpression(expressions, child) => execution.Exchange(HashPartitioning(expressions, numPartitions), planLater(child)) :: Nil case e @ EvaluatePython(udf, child, _) => BatchPythonEvaluation(udf, e.output, planLater(child)) :: Nil - case LogicalRDD(output, rdd) => PhysicalRDD(output, rdd) :: Nil + case LogicalRDD(output, rdd) => PhysicalRDD(output, rdd, "PhysicalRDD") :: Nil case BroadcastHint(child) => apply(child) case _ => Nil } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala index 5a0b4d47d62f..c3dcbd2b71ee 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala @@ -93,10 +93,13 @@ case class TungstenAggregate( val allAggregateExpressions = nonCompleteAggregateExpressions ++ completeAggregateExpressions testFallbackStartsAt match { - case None => s"TungstenAggregate ${groupingExpressions} ${allAggregateExpressions}" + case None => + val keyString = groupingExpressions.mkString("[", ",", "]") + val valueString = allAggregateExpressions.mkString("[", ",", "]") + s"TungstenAggregate(key=$keyString, value=$valueString" case Some(fallbackStartsAt) => - s"TungstenAggregateWithControlledFallback ${groupingExpressions} " + - s"${allAggregateExpressions} fallbackStartsAt=$fallbackStartsAt" + s"TungstenAggregateWithControlledFallback $groupingExpressions " + + s"$allAggregateExpressions fallbackStartsAt=$fallbackStartsAt" } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index e5dc676b8784..5b5fa8c93ec5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -101,8 +101,9 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { (a, f) => toCatalystRDD(l, a, t.buildScan(a.map(_.name).toArray, f, t.paths, confBroadcast))) :: Nil - case l @ LogicalRelation(t: TableScan) => - execution.PhysicalRDD(l.output, toCatalystRDD(l, t.buildScan())) :: Nil + case l @ LogicalRelation(baseRelation: TableScan) => + execution.PhysicalRDD.createFromDataSource( + l.output, toCatalystRDD(l, baseRelation.buildScan()), baseRelation) :: Nil case i @ logical.InsertIntoTable( l @ LogicalRelation(t: InsertableRelation), part, query, overwrite, false) if part.isEmpty => @@ -169,7 +170,10 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { new 
UnionRDD(relation.sqlContext.sparkContext, perPartitionRows) } - execution.PhysicalRDD(projections.map(_.toAttribute), unionedRows) + execution.PhysicalRDD.createFromDataSource( + projections.map(_.toAttribute), + unionedRows, + logicalRelation.relation) } // TODO: refactor this thing. It is very complicated because it does projection internally. @@ -299,14 +303,18 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { projects.asInstanceOf[Seq[Attribute]] // Safe due to if above. .map(relation.attributeMap) // Match original case of attributes. - val scan = execution.PhysicalRDD(projects.map(_.toAttribute), - scanBuilder(requestedColumns, pushedFilters)) + val scan = execution.PhysicalRDD.createFromDataSource( + projects.map(_.toAttribute), + scanBuilder(requestedColumns, pushedFilters), + relation.relation) filterCondition.map(execution.Filter(_, scan)).getOrElse(scan) } else { val requestedColumns = (projectSet ++ filterSet).map(relation.attributeMap).toSeq - val scan = execution.PhysicalRDD(requestedColumns, - scanBuilder(requestedColumns, pushedFilters)) + val scan = execution.PhysicalRDD.createFromDataSource( + requestedColumns, + scanBuilder(requestedColumns, pushedFilters), + relation.relation) execution.Project(projects, filterCondition.map(execution.Filter(_, scan)).getOrElse(scan)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index c04557e5a081..0b2929661b65 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -383,7 +383,7 @@ private[sql] abstract class OutputWriterInternal extends OutputWriter { abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[PartitionSpec]) extends BaseRelation with Logging { - logInfo("Constructing HadoopFsRelation") + override def toString: String = getClass.getSimpleName + paths.mkString("[", ",", "]") def this() = this(None) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala index 8208b25b5708..322966f42378 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala @@ -32,9 +32,9 @@ class RowFormatConvertersSuite extends SparkPlanTest { case c: ConvertToSafe => c } - private val outputsSafe = ExternalSort(Nil, false, PhysicalRDD(Seq.empty, null)) + private val outputsSafe = ExternalSort(Nil, false, PhysicalRDD(Seq.empty, null, "name")) assert(!outputsSafe.outputsUnsafeRows) - private val outputsUnsafe = TungstenSort(Nil, false, PhysicalRDD(Seq.empty, null)) + private val outputsUnsafe = TungstenSort(Nil, false, PhysicalRDD(Seq.empty, null, "name")) assert(outputsUnsafe.outputsUnsafeRows) test("planner should insert unsafe->safe conversions when required") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala index 697211222b90..8215dd6c2e71 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala @@ -36,7 +36,7 @@ class HiveExplainSuite extends QueryTest { "== Analyzed Logical 
Plan ==", "== Optimized Logical Plan ==", "== Physical Plan ==", - "Code Generation", "== RDD ==") + "Code Generation") } test("explain create table command") { From 881548ab20fa4c4b635c51d956b14bd13981e2f4 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Fri, 7 Aug 2015 14:20:13 -0700 Subject: [PATCH 0121/1824] [SPARK-9674] Re-enable ignored test in SQLQuerySuite The original code that this test tests is removed in https://github.com/apache/spark/commit/9270bd06fd0b16892e3f37213b5bc7813ea11fdd. It was ignored shortly before that so we never caught it. This patch re-enables the test and adds the code necessary to make it pass. JoshRosen yhuai Author: Andrew Or Closes #8015 from andrewor14/SPARK-9674 and squashes the following commits: 225eac2 [Andrew Or] Merge branch 'master' of github.com:apache/spark into SPARK-9674 8c24209 [Andrew Or] Fix NPE e541d64 [Andrew Or] Track aggregation memory for both sort and hash 0be3a42 [Andrew Or] Fix test --- .../spark/unsafe/map/BytesToBytesMap.java | 37 +++++++++++++++++-- .../map/AbstractBytesToBytesMapSuite.java | 20 ++++++---- .../UnsafeFixedWidthAggregationMap.java | 7 ++-- .../sql/execution/UnsafeKVExternalSorter.java | 7 ++++ .../TungstenAggregationIterator.scala | 32 +++++++++++++--- .../org/apache/spark/sql/SQLQuerySuite.scala | 8 ++-- 6 files changed, 85 insertions(+), 26 deletions(-) diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java index 0636ae7c8df1..7f79cd13aab4 100644 --- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java +++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java @@ -109,7 +109,7 @@ public final class BytesToBytesMap { * Position {@code 2 * i} in the array is used to track a pointer to the key at index {@code i}, * while position {@code 2 * i + 1} in the array holds key's full 32-bit hashcode. */ - private LongArray longArray; + @Nullable private LongArray longArray; // TODO: we're wasting 32 bits of space here; we can probably store fewer bits of the hashcode // and exploit word-alignment to use fewer bits to hold the address. This might let us store // only one long per map entry, increasing the chance that this array will fit in cache at the @@ -124,7 +124,7 @@ public final class BytesToBytesMap { * A {@link BitSet} used to track location of the map where the key is set. * Size of the bitset should be half of the size of the long array. 
*/ - private BitSet bitset; + @Nullable private BitSet bitset; private final double loadFactor; @@ -166,6 +166,8 @@ public final class BytesToBytesMap { private long numHashCollisions = 0; + private long peakMemoryUsedBytes = 0L; + public BytesToBytesMap( TaskMemoryManager taskMemoryManager, ShuffleMemoryManager shuffleMemoryManager, @@ -321,6 +323,9 @@ public Location lookup( Object keyBaseObject, long keyBaseOffset, int keyRowLengthBytes) { + assert(bitset != null); + assert(longArray != null); + if (enablePerfMetrics) { numKeyLookups++; } @@ -410,6 +415,7 @@ private void updateAddressesAndSizes(final Object page, final long offsetInPage) } private Location with(int pos, int keyHashcode, boolean isDefined) { + assert(longArray != null); this.pos = pos; this.isDefined = isDefined; this.keyHashcode = keyHashcode; @@ -525,6 +531,9 @@ public boolean putNewKey( assert (!isDefined) : "Can only set value once for a key"; assert (keyLengthBytes % 8 == 0); assert (valueLengthBytes % 8 == 0); + assert(bitset != null); + assert(longArray != null); + if (numElements == MAX_CAPACITY) { throw new IllegalStateException("BytesToBytesMap has reached maximum capacity"); } @@ -658,6 +667,7 @@ private void allocate(int capacity) { * This method is idempotent and can be called multiple times. */ public void free() { + updatePeakMemoryUsed(); longArray = null; bitset = null; Iterator dataPagesIterator = dataPages.iterator(); @@ -684,14 +694,30 @@ public long getPageSizeBytes() { /** * Returns the total amount of memory, in bytes, consumed by this map's managed structures. - * Note that this is also the peak memory used by this map, since the map is append-only. */ public long getTotalMemoryConsumption() { long totalDataPagesSize = 0L; for (MemoryBlock dataPage : dataPages) { totalDataPagesSize += dataPage.size(); } - return totalDataPagesSize + bitset.memoryBlock().size() + longArray.memoryBlock().size(); + return totalDataPagesSize + + ((bitset != null) ? bitset.memoryBlock().size() : 0L) + + ((longArray != null) ? longArray.memoryBlock().size() : 0L); + } + + private void updatePeakMemoryUsed() { + long mem = getTotalMemoryConsumption(); + if (mem > peakMemoryUsedBytes) { + peakMemoryUsedBytes = mem; + } + } + + /** + * Return the peak memory used so far, in bytes. + */ + public long getPeakMemoryUsedBytes() { + updatePeakMemoryUsed(); + return peakMemoryUsedBytes; } /** @@ -731,6 +757,9 @@ int getNumDataPages() { */ @VisibleForTesting void growAndRehash() { + assert(bitset != null); + assert(longArray != null); + long resizeStartTime = -1; if (enablePerfMetrics) { resizeStartTime = System.nanoTime(); diff --git a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java index 0b11562980b8..e56a3f0b6d12 100644 --- a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java +++ b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java @@ -525,7 +525,7 @@ public void resizingLargeMap() { } @Test - public void testTotalMemoryConsumption() { + public void testPeakMemoryUsed() { final long recordLengthBytes = 24; final long pageSizeBytes = 256 + 8; // 8 bytes for end-of-page marker final long numRecordsPerPage = (pageSizeBytes - 8) / recordLengthBytes; @@ -536,8 +536,8 @@ public void testTotalMemoryConsumption() { // monotonically increasing. More specifically, every time we allocate a new page it // should increase by exactly the size of the page. 
In this regard, the memory usage // at any given time is also the peak memory used. - long previousMemory = map.getTotalMemoryConsumption(); - long newMemory; + long previousPeakMemory = map.getPeakMemoryUsedBytes(); + long newPeakMemory; try { for (long i = 0; i < numRecordsPerPage * 10; i++) { final long[] value = new long[]{i}; @@ -548,15 +548,21 @@ public void testTotalMemoryConsumption() { value, PlatformDependent.LONG_ARRAY_OFFSET, 8); - newMemory = map.getTotalMemoryConsumption(); + newPeakMemory = map.getPeakMemoryUsedBytes(); if (i % numRecordsPerPage == 0) { // We allocated a new page for this record, so peak memory should change - assertEquals(previousMemory + pageSizeBytes, newMemory); + assertEquals(previousPeakMemory + pageSizeBytes, newPeakMemory); } else { - assertEquals(previousMemory, newMemory); + assertEquals(previousPeakMemory, newPeakMemory); } - previousMemory = newMemory; + previousPeakMemory = newPeakMemory; } + + // Freeing the map should not change the peak memory + map.free(); + newPeakMemory = map.getPeakMemoryUsedBytes(); + assertEquals(previousPeakMemory, newPeakMemory); + } finally { map.free(); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java index efb33530dac8..b08a4a13a28b 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java @@ -210,11 +210,10 @@ public void close() { } /** - * The memory used by this map's managed structures, in bytes. - * Note that this is also the peak memory used by this map, since the map is append-only. + * Return the peak memory used so far, in bytes. */ - public long getMemoryUsage() { - return map.getTotalMemoryConsumption(); + public long getPeakMemoryUsedBytes() { + return map.getPeakMemoryUsedBytes(); } /** diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java index 9a65c9d3a404..69d6784713a2 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java @@ -159,6 +159,13 @@ public KVSorterIterator sortedIterator() throws IOException { } } + /** + * Return the peak memory used so far, in bytes. + */ + public long getPeakMemoryUsedBytes() { + return sorter.getPeakMemoryUsedBytes(); + } + /** * Marks the current page as no-more-space-available, and as a result, either allocate a * new page or spill when we see the next record. 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala index 4d5e98a3e90c..440bef32f4e9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.aggregate import org.apache.spark.unsafe.KVIterator -import org.apache.spark.{Logging, SparkEnv, TaskContext} +import org.apache.spark.{InternalAccumulator, Logging, SparkEnv, TaskContext} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner @@ -397,14 +397,20 @@ class TungstenAggregationIterator( private[this] var mapIteratorHasNext: Boolean = false /////////////////////////////////////////////////////////////////////////// - // Part 4: The function used to switch this iterator from hash-based - // aggregation to sort-based aggregation. + // Part 3: Methods and fields used by sort-based aggregation. /////////////////////////////////////////////////////////////////////////// + // This sorter is used for sort-based aggregation. It is initialized as soon as + // we switch from hash-based to sort-based aggregation. Otherwise, it is not used. + private[this] var externalSorter: UnsafeKVExternalSorter = null + + /** + * Switch to sort-based aggregation when the hash-based approach is unable to acquire memory. + */ private def switchToSortBasedAggregation(firstKey: UnsafeRow, firstInput: UnsafeRow): Unit = { logInfo("falling back to sort based aggregation.") // Step 1: Get the ExternalSorter containing sorted entries of the map. - val externalSorter: UnsafeKVExternalSorter = hashMap.destructAndCreateExternalSorter() + externalSorter = hashMap.destructAndCreateExternalSorter() // Step 2: Free the memory used by the map. hashMap.free() @@ -601,7 +607,7 @@ class TungstenAggregationIterator( } /////////////////////////////////////////////////////////////////////////// - // Par 7: Iterator's public methods. + // Part 7: Iterator's public methods. /////////////////////////////////////////////////////////////////////////// override final def hasNext: Boolean = { @@ -610,7 +616,7 @@ class TungstenAggregationIterator( override final def next(): UnsafeRow = { if (hasNext) { - if (sortBased) { + val res = if (sortBased) { // Process the current group. processCurrentSortedGroup() // Generate output row for the current group. @@ -641,6 +647,19 @@ class TungstenAggregationIterator( result } } + + // If this is the last record, update the task's peak memory usage. Since we destroy + // the map to create the sorter, their memory usages should not overlap, so it is safe + // to just use the max of the two. + if (!hasNext) { + val mapMemory = hashMap.getPeakMemoryUsedBytes + val sorterMemory = Option(externalSorter).map(_.getPeakMemoryUsedBytes).getOrElse(0L) + val peakMemory = Math.max(mapMemory, sorterMemory) + TaskContext.get().internalMetricsToAccumulators( + InternalAccumulator.PEAK_EXECUTION_MEMORY).add(peakMemory) + } + + res } else { // no more result throw new NoSuchElementException @@ -651,6 +670,7 @@ class TungstenAggregationIterator( // Part 8: A utility function used to generate a output row when there is no // input and there is no grouping expression. 
/////////////////////////////////////////////////////////////////////////// + def outputForEmptyGroupingKeyWithoutInput(): UnsafeRow = { if (groupingExpressions.isEmpty) { sortBasedAggregationBuffer.copyFrom(initialAggregationBuffer) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index c64aa7a07dc2..b14ef9bab90c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -267,7 +267,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { if (!hasGeneratedAgg) { fail( s""" - |Codegen is enabled, but query $sqlText does not have GeneratedAggregate in the plan. + |Codegen is enabled, but query $sqlText does not have TungstenAggregate in the plan. |${df.queryExecution.simpleString} """.stripMargin) } @@ -1602,10 +1602,8 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { Row(new CalendarInterval(-(12 * 3 - 3), -(7L * MICROS_PER_WEEK + 123)))) } - ignore("aggregation with codegen updates peak execution memory") { - withSQLConf( - (SQLConf.CODEGEN_ENABLED.key, "true"), - (SQLConf.USE_SQL_AGGREGATE2.key, "false")) { + test("aggregation with codegen updates peak execution memory") { + withSQLConf((SQLConf.CODEGEN_ENABLED.key, "true")) { val sc = sqlContext.sparkContext AccumulatorSuite.verifyPeakExecutionMemorySet(sc, "aggregation with codegen") { testCodeGen( From e2fbbe73111d4624390f596a19a1799c86a05f6c Mon Sep 17 00:00:00 2001 From: Dariusz Kobylarz Date: Fri, 7 Aug 2015 14:51:03 -0700 Subject: [PATCH 0122/1824] [SPARK-8481] [MLLIB] GaussianMixtureModel predict accepting single vector Resubmit of [https://github.com/apache/spark/pull/6906] for adding single-vec predict to GMMs CC: dkobylarz mengxr To be merged with master and branch-1.5 Primary author: dkobylarz Author: Dariusz Kobylarz Closes #8039 from jkbradley/gmm-predict-vec and squashes the following commits: bfbedc4 [Dariusz Kobylarz] [SPARK-8481] [MLlib] GaussianMixtureModel predict accepting single vector --- .../mllib/clustering/GaussianMixtureModel.scala | 13 +++++++++++++ .../mllib/clustering/GaussianMixtureSuite.scala | 10 ++++++++++ 2 files changed, 23 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala index cb807c803810..76aeebd703d4 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala @@ -66,6 +66,12 @@ class GaussianMixtureModel( responsibilityMatrix.map(r => r.indexOf(r.max)) } + /** Maps given point to its cluster index. */ + def predict(point: Vector): Int = { + val r = computeSoftAssignments(point.toBreeze.toDenseVector, gaussians, weights, k) + r.indexOf(r.max) + } + /** Java-friendly version of [[predict()]] */ def predict(points: JavaRDD[Vector]): JavaRDD[java.lang.Integer] = predict(points.rdd).toJavaRDD().asInstanceOf[JavaRDD[java.lang.Integer]] @@ -83,6 +89,13 @@ class GaussianMixtureModel( } } + /** + * Given the input vector, return the membership values to all mixture components. 
+ */ + def predictSoft(point: Vector): Array[Double] = { + computeSoftAssignments(point.toBreeze.toDenseVector, gaussians, weights, k) + } + /** * Compute the partial assignments for each vector */ diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala index b218d72f1268..b636d02f786e 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala @@ -148,6 +148,16 @@ class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext { } } + test("model prediction, parallel and local") { + val data = sc.parallelize(GaussianTestData.data) + val gmm = new GaussianMixture().setK(2).setSeed(0).run(data) + + val batchPredictions = gmm.predict(data) + batchPredictions.zip(data).collect().foreach { case (batchPred, datum) => + assert(batchPred === gmm.predict(datum)) + } + } + object GaussianTestData { val data = Array( From 902334fd55bbe40a57c1de2a9bdb25eddf1c8cf6 Mon Sep 17 00:00:00 2001 From: Bertrand Dechoux Date: Fri, 7 Aug 2015 16:07:24 -0700 Subject: [PATCH 0123/1824] [SPARK-9748] [MLLIB] Centriod typo in KMeansModel A minor typo (centriod -> centroid). Readable variable names help every users. Author: Bertrand Dechoux Closes #8037 from BertrandDechoux/kmeans-typo and squashes the following commits: 47632fe [Bertrand Dechoux] centriod typo --- .../apache/spark/mllib/clustering/KMeansModel.scala | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala index 8ecb3df11d95..96359024fa22 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala @@ -120,11 +120,11 @@ object KMeansModel extends Loader[KMeansModel] { assert(className == thisClassName) assert(formatVersion == thisFormatVersion) val k = (metadata \ "k").extract[Int] - val centriods = sqlContext.read.parquet(Loader.dataPath(path)) - Loader.checkSchema[Cluster](centriods.schema) - val localCentriods = centriods.map(Cluster.apply).collect() - assert(k == localCentriods.size) - new KMeansModel(localCentriods.sortBy(_.id).map(_.point)) + val centroids = sqlContext.read.parquet(Loader.dataPath(path)) + Loader.checkSchema[Cluster](centroids.schema) + val localCentroids = centroids.map(Cluster.apply).collect() + assert(k == localCentroids.size) + new KMeansModel(localCentroids.sortBy(_.id).map(_.point)) } } } From 49702bd738de681255a7177339510e0e1b25a8db Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Fri, 7 Aug 2015 16:24:50 -0700 Subject: [PATCH 0124/1824] [SPARK-8890] [SQL] Fallback on sorting when writing many dynamic partitions Previously, we would open a new file for each new dynamic written out using `HadoopFsRelation`. For formats like parquet this is very costly due to the buffers required to get good compression. In this PR I refactor the code allowing us to fall back on an external sort when many partitions are seen. As such each task will open no more than `spark.sql.sources.maxFiles` files. I also did the following cleanup: - Instead of keying the file HashMap on an expensive to compute string representation of the partition, we now use a fairly cheap UnsafeProjection that avoids heap allocations. 
- The control flow for instantiating and invoking a writer container has been simplified. Now instead of switching in two places based on the use of partitioning, the specific writer container must implement a single method `writeRows` that is invoked using `runJob`. - `InternalOutputWriter` has been removed. Instead we have a `private[sql]` method `writeInternal` that converts and calls the public method. This method can be overridden by internal datasources to avoid the conversion. This change remove a lot of code duplication and per-row `asInstanceOf` checks. - `commands.scala` has been split up. Author: Michael Armbrust Closes #8010 from marmbrus/fsWriting and squashes the following commits: 00804fe [Michael Armbrust] use shuffleMemoryManager.pageSizeBytes 775cc49 [Michael Armbrust] Merge remote-tracking branch 'origin/master' into fsWriting 17b690e [Michael Armbrust] remove comment 40f0372 [Michael Armbrust] address comments f5675bd [Michael Armbrust] char -> string 7e2d0a4 [Michael Armbrust] make sure we close current writer 8100100 [Michael Armbrust] delete empty commands.scala 71cc717 [Michael Armbrust] update comment 8ec75ac [Michael Armbrust] [SPARK-8890][SQL] Fallback on sorting when writing many dynamic partitions --- .../scala/org/apache/spark/sql/SQLConf.scala | 8 +- .../datasources/InsertIntoDataSource.scala | 64 ++ .../InsertIntoHadoopFsRelation.scala | 165 +++++ .../datasources/WriterContainer.scala | 404 ++++++++++++ .../sql/execution/datasources/commands.scala | 606 ------------------ .../apache/spark/sql/json/JSONRelation.scala | 6 +- .../spark/sql/parquet/ParquetRelation.scala | 6 +- .../apache/spark/sql/sources/interfaces.scala | 17 +- .../sql/sources/PartitionedWriteSuite.scala | 56 ++ .../spark/sql/hive/orc/OrcRelation.scala | 6 +- 10 files changed, 715 insertions(+), 623 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSource.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/commands.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 45d3d8c86351..e9de14f02550 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -366,17 +366,21 @@ private[spark] object SQLConf { "storing additional schema information in Hive's metastore.", isPublic = false) - // Whether to perform partition discovery when loading external data sources. Default to true. val PARTITION_DISCOVERY_ENABLED = booleanConf("spark.sql.sources.partitionDiscovery.enabled", defaultValue = Some(true), doc = "When true, automtically discover data partitions.") - // Whether to perform partition column type inference. Default to true. 
val PARTITION_COLUMN_TYPE_INFERENCE = booleanConf("spark.sql.sources.partitionColumnTypeInference.enabled", defaultValue = Some(true), doc = "When true, automatically infer the data types for partitioned columns.") + val PARTITION_MAX_FILES = + intConf("spark.sql.sources.maxConcurrentWrites", + defaultValue = Some(5), + doc = "The maximum number of concurent files to open before falling back on sorting when " + + "writing out files using dynamic partitioning.") + // The output committer class used by HadoopFsRelation. The specified class needs to be a // subclass of org.apache.hadoop.mapreduce.OutputCommitter. // diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSource.scala new file mode 100644 index 000000000000..6ccde7693bd3 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSource.scala @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources + +import java.io.IOException +import java.util.{Date, UUID} + +import scala.collection.JavaConversions.asScalaIterator + +import org.apache.hadoop.fs.Path +import org.apache.hadoop.mapreduce._ +import org.apache.hadoop.mapreduce.lib.output.{FileOutputCommitter => MapReduceFileOutputCommitter, FileOutputFormat} +import org.apache.spark._ +import org.apache.spark.mapred.SparkHadoopMapRedUtil +import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateProjection +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project} +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} +import org.apache.spark.sql.execution.{RunnableCommand, SQLExecution} +import org.apache.spark.sql.sources._ +import org.apache.spark.sql.types.StringType +import org.apache.spark.util.{Utils, SerializableConfiguration} + + +/** + * Inserts the results of `query` in to a relation that extends [[InsertableRelation]]. + */ +private[sql] case class InsertIntoDataSource( + logicalRelation: LogicalRelation, + query: LogicalPlan, + overwrite: Boolean) + extends RunnableCommand { + + override def run(sqlContext: SQLContext): Seq[Row] = { + val relation = logicalRelation.relation.asInstanceOf[InsertableRelation] + val data = DataFrame(sqlContext, query) + // Apply the schema of the existing table to the new data. 
+ val df = sqlContext.internalCreateDataFrame(data.queryExecution.toRdd, logicalRelation.schema) + relation.insert(df, overwrite) + + // Invalidate the cache. + sqlContext.cacheManager.invalidateCache(logicalRelation) + + Seq.empty[Row] + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala new file mode 100644 index 000000000000..735d52f80886 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala @@ -0,0 +1,165 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources + +import java.io.IOException + +import org.apache.hadoop.fs.Path +import org.apache.hadoop.mapreduce._ +import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat +import org.apache.spark._ +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.execution.{RunnableCommand, SQLExecution} +import org.apache.spark.sql.sources._ +import org.apache.spark.util.Utils + + +/** + * A command for writing data to a [[HadoopFsRelation]]. Supports both overwriting and appending. + * Writing to dynamic partitions is also supported. Each [[InsertIntoHadoopFsRelation]] issues a + * single write job, and owns a UUID that identifies this job. Each concrete implementation of + * [[HadoopFsRelation]] should use this UUID together with task id to generate unique file path for + * each task output file. This UUID is passed to executor side via a property named + * `spark.sql.sources.writeJobUUID`. + * + * Different writer containers, [[DefaultWriterContainer]] and [[DynamicPartitionWriterContainer]] + * are used to write to normal tables and tables with dynamic partitions. + * + * Basic work flow of this command is: + * + * 1. Driver side setup, including output committer initialization and data source specific + * preparation work for the write job to be issued. + * 2. Issues a write job consists of one or more executor side tasks, each of which writes all + * rows within an RDD partition. + * 3. If no exception is thrown in a task, commits that task, otherwise aborts that task; If any + * exception is thrown during task commitment, also aborts that task. + * 4. If all tasks are committed, commit the job, otherwise aborts the job; If any exception is + * thrown during job commitment, also aborts the job. 
+ */ +private[sql] case class InsertIntoHadoopFsRelation( + @transient relation: HadoopFsRelation, + @transient query: LogicalPlan, + mode: SaveMode) + extends RunnableCommand { + + override def run(sqlContext: SQLContext): Seq[Row] = { + require( + relation.paths.length == 1, + s"Cannot write to multiple destinations: ${relation.paths.mkString(",")}") + + val hadoopConf = sqlContext.sparkContext.hadoopConfiguration + val outputPath = new Path(relation.paths.head) + val fs = outputPath.getFileSystem(hadoopConf) + val qualifiedOutputPath = outputPath.makeQualified(fs.getUri, fs.getWorkingDirectory) + + val pathExists = fs.exists(qualifiedOutputPath) + val doInsertion = (mode, pathExists) match { + case (SaveMode.ErrorIfExists, true) => + throw new AnalysisException(s"path $qualifiedOutputPath already exists.") + case (SaveMode.Overwrite, true) => + Utils.tryOrIOException { + if (!fs.delete(qualifiedOutputPath, true /* recursively */)) { + throw new IOException(s"Unable to clear output " + + s"directory $qualifiedOutputPath prior to writing to it") + } + } + true + case (SaveMode.Append, _) | (SaveMode.Overwrite, _) | (SaveMode.ErrorIfExists, false) => + true + case (SaveMode.Ignore, exists) => + !exists + case (s, exists) => + throw new IllegalStateException(s"unsupported save mode $s ($exists)") + } + // If we are appending data to an existing dir. + val isAppend = pathExists && (mode == SaveMode.Append) + + if (doInsertion) { + val job = new Job(hadoopConf) + job.setOutputKeyClass(classOf[Void]) + job.setOutputValueClass(classOf[InternalRow]) + FileOutputFormat.setOutputPath(job, qualifiedOutputPath) + + // A partitioned relation schema's can be different from the input logicalPlan, since + // partition columns are all moved after data column. We Project to adjust the ordering. + // TODO: this belongs in the analyzer. + val project = Project( + relation.schema.map(field => UnresolvedAttribute.quoted(field.name)), query) + val queryExecution = DataFrame(sqlContext, project).queryExecution + + SQLExecution.withNewExecutionId(sqlContext, queryExecution) { + val df = sqlContext.internalCreateDataFrame(queryExecution.toRdd, relation.schema) + val partitionColumns = relation.partitionColumns.fieldNames + + // Some pre-flight checks. + require( + df.schema == relation.schema, + s"""DataFrame must have the same schema as the relation to which is inserted. + |DataFrame schema: ${df.schema} + |Relation schema: ${relation.schema} + """.stripMargin) + val partitionColumnsInSpec = relation.partitionColumns.fieldNames + require( + partitionColumnsInSpec.sameElements(partitionColumns), + s"""Partition columns mismatch. + |Expected: ${partitionColumnsInSpec.mkString(", ")} + |Actual: ${partitionColumns.mkString(", ")} + """.stripMargin) + + val writerContainer = if (partitionColumns.isEmpty) { + new DefaultWriterContainer(relation, job, isAppend) + } else { + val output = df.queryExecution.executedPlan.output + val (partitionOutput, dataOutput) = + output.partition(a => partitionColumns.contains(a.name)) + + new DynamicPartitionWriterContainer( + relation, + job, + partitionOutput, + dataOutput, + output, + PartitioningUtils.DEFAULT_PARTITION_NAME, + sqlContext.conf.getConf(SQLConf.PARTITION_MAX_FILES), + isAppend) + } + + // This call shouldn't be put into the `try` block below because it only initializes and + // prepares the job, any exception thrown from here shouldn't cause abortJob() to be called. 
+ writerContainer.driverSideSetup() + + try { + sqlContext.sparkContext.runJob(df.queryExecution.toRdd, writerContainer.writeRows _) + writerContainer.commitJob() + relation.refresh() + } catch { case cause: Throwable => + logError("Aborting job.", cause) + writerContainer.abortJob() + throw new SparkException("Job aborted.", cause) + } + } + } else { + logInfo("Skipping insertion into a relation that already exists.") + } + + Seq.empty[Row] + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala new file mode 100644 index 000000000000..2f11f4042240 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala @@ -0,0 +1,404 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources + +import java.util.{Date, UUID} + +import scala.collection.JavaConverters._ + +import org.apache.hadoop.fs.Path +import org.apache.hadoop.mapreduce._ +import org.apache.hadoop.mapreduce.lib.output.{FileOutputCommitter => MapReduceFileOutputCommitter} +import org.apache.spark._ +import org.apache.spark.mapred.SparkHadoopMapRedUtil +import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.execution.UnsafeKVExternalSorter +import org.apache.spark.sql.sources.{HadoopFsRelation, OutputWriter, OutputWriterFactory} +import org.apache.spark.sql.types.{StructType, StringType} +import org.apache.spark.util.SerializableConfiguration + + +private[sql] abstract class BaseWriterContainer( + @transient val relation: HadoopFsRelation, + @transient job: Job, + isAppend: Boolean) + extends SparkHadoopMapReduceUtil + with Logging + with Serializable { + + protected val dataSchema = relation.dataSchema + + protected val serializableConf = new SerializableConfiguration(job.getConfiguration) + + // This UUID is used to avoid output file name collision between different appending write jobs. + // These jobs may belong to different SparkContext instances. Concrete data source implementations + // may use this UUID to generate unique file names (e.g., `part-r--.parquet`). + // The reason why this ID is used to identify a job rather than a single task output file is + // that, speculative tasks must generate the same output file name as the original task. + private val uniqueWriteJobId = UUID.randomUUID() + + // This is only used on driver side. + @transient private val jobContext: JobContext = job + + // The following fields are initialized and used on both driver and executor side. 
+ @transient protected var outputCommitter: OutputCommitter = _ + @transient private var jobId: JobID = _ + @transient private var taskId: TaskID = _ + @transient private var taskAttemptId: TaskAttemptID = _ + @transient protected var taskAttemptContext: TaskAttemptContext = _ + + protected val outputPath: String = { + assert( + relation.paths.length == 1, + s"Cannot write to multiple destinations: ${relation.paths.mkString(",")}") + relation.paths.head + } + + protected var outputWriterFactory: OutputWriterFactory = _ + + private var outputFormatClass: Class[_ <: OutputFormat[_, _]] = _ + + def writeRows(taskContext: TaskContext, iterator: Iterator[InternalRow]): Unit + + def driverSideSetup(): Unit = { + setupIDs(0, 0, 0) + setupConf() + + // This UUID is sent to executor side together with the serialized `Configuration` object within + // the `Job` instance. `OutputWriters` on the executor side should use this UUID to generate + // unique task output files. + job.getConfiguration.set("spark.sql.sources.writeJobUUID", uniqueWriteJobId.toString) + + // Order of the following two lines is important. For Hadoop 1, TaskAttemptContext constructor + // clones the Configuration object passed in. If we initialize the TaskAttemptContext first, + // configurations made in prepareJobForWrite(job) are not populated into the TaskAttemptContext. + // + // Also, the `prepareJobForWrite` call must happen before initializing output format and output + // committer, since their initialization involve the job configuration, which can be potentially + // decorated in `prepareJobForWrite`. + outputWriterFactory = relation.prepareJobForWrite(job) + taskAttemptContext = newTaskAttemptContext(serializableConf.value, taskAttemptId) + + outputFormatClass = job.getOutputFormatClass + outputCommitter = newOutputCommitter(taskAttemptContext) + outputCommitter.setupJob(jobContext) + } + + def executorSideSetup(taskContext: TaskContext): Unit = { + setupIDs(taskContext.stageId(), taskContext.partitionId(), taskContext.attemptNumber()) + setupConf() + taskAttemptContext = newTaskAttemptContext(serializableConf.value, taskAttemptId) + outputCommitter = newOutputCommitter(taskAttemptContext) + outputCommitter.setupTask(taskAttemptContext) + } + + protected def getWorkPath: String = { + outputCommitter match { + // FileOutputCommitter writes to a temporary location returned by `getWorkPath`. + case f: MapReduceFileOutputCommitter => f.getWorkPath.toString + case _ => outputPath + } + } + + private def newOutputCommitter(context: TaskAttemptContext): OutputCommitter = { + val defaultOutputCommitter = outputFormatClass.newInstance().getOutputCommitter(context) + + if (isAppend) { + // If we are appending data to an existing dir, we will only use the output committer + // associated with the file output format since it is not safe to use a custom + // committer for appending. For example, in S3, direct parquet output committer may + // leave partial data in the destination dir when the the appending job fails. 
+ logInfo( + s"Using output committer class ${defaultOutputCommitter.getClass.getCanonicalName} " + + "for appending.") + defaultOutputCommitter + } else { + val committerClass = context.getConfiguration.getClass( + SQLConf.OUTPUT_COMMITTER_CLASS.key, null, classOf[OutputCommitter]) + + Option(committerClass).map { clazz => + logInfo(s"Using user defined output committer class ${clazz.getCanonicalName}") + + // Every output format based on org.apache.hadoop.mapreduce.lib.output.OutputFormat + // has an associated output committer. To override this output committer, + // we will first try to use the output committer set in SQLConf.OUTPUT_COMMITTER_CLASS. + // If a data source needs to override the output committer, it needs to set the + // output committer in prepareForWrite method. + if (classOf[MapReduceFileOutputCommitter].isAssignableFrom(clazz)) { + // The specified output committer is a FileOutputCommitter. + // So, we will use the FileOutputCommitter-specified constructor. + val ctor = clazz.getDeclaredConstructor(classOf[Path], classOf[TaskAttemptContext]) + ctor.newInstance(new Path(outputPath), context) + } else { + // The specified output committer is just a OutputCommitter. + // So, we will use the no-argument constructor. + val ctor = clazz.getDeclaredConstructor() + ctor.newInstance() + } + }.getOrElse { + // If output committer class is not set, we will use the one associated with the + // file output format. + logInfo( + s"Using output committer class ${defaultOutputCommitter.getClass.getCanonicalName}") + defaultOutputCommitter + } + } + } + + private def setupIDs(jobId: Int, splitId: Int, attemptId: Int): Unit = { + this.jobId = SparkHadoopWriter.createJobID(new Date, jobId) + this.taskId = new TaskID(this.jobId, true, splitId) + this.taskAttemptId = new TaskAttemptID(taskId, attemptId) + } + + private def setupConf(): Unit = { + serializableConf.value.set("mapred.job.id", jobId.toString) + serializableConf.value.set("mapred.tip.id", taskAttemptId.getTaskID.toString) + serializableConf.value.set("mapred.task.id", taskAttemptId.toString) + serializableConf.value.setBoolean("mapred.task.is.map", true) + serializableConf.value.setInt("mapred.task.partition", 0) + } + + def commitTask(): Unit = { + SparkHadoopMapRedUtil.commitTask( + outputCommitter, taskAttemptContext, jobId.getId, taskId.getId, taskAttemptId.getId) + } + + def abortTask(): Unit = { + if (outputCommitter != null) { + outputCommitter.abortTask(taskAttemptContext) + } + logError(s"Task attempt $taskAttemptId aborted.") + } + + def commitJob(): Unit = { + outputCommitter.commitJob(jobContext) + logInfo(s"Job $jobId committed.") + } + + def abortJob(): Unit = { + if (outputCommitter != null) { + outputCommitter.abortJob(jobContext, JobStatus.State.FAILED) + } + logError(s"Job $jobId aborted.") + } +} + +/** + * A writer that writes all of the rows in a partition to a single file. + */ +private[sql] class DefaultWriterContainer( + @transient relation: HadoopFsRelation, + @transient job: Job, + isAppend: Boolean) + extends BaseWriterContainer(relation, job, isAppend) { + + def writeRows(taskContext: TaskContext, iterator: Iterator[InternalRow]): Unit = { + executorSideSetup(taskContext) + taskAttemptContext.getConfiguration.set("spark.sql.sources.output.path", outputPath) + val writer = outputWriterFactory.newInstance(getWorkPath, dataSchema, taskAttemptContext) + writer.initConverter(dataSchema) + + // If anything below fails, we should abort the task. 
+ try { + while (iterator.hasNext) { + val internalRow = iterator.next() + writer.writeInternal(internalRow) + } + + commitTask() + } catch { + case cause: Throwable => + logError("Aborting task.", cause) + abortTask() + throw new SparkException("Task failed while writing rows.", cause) + } + + def commitTask(): Unit = { + try { + assert(writer != null, "OutputWriter instance should have been initialized") + writer.close() + super.commitTask() + } catch { + case cause: Throwable => + // This exception will be handled in `InsertIntoHadoopFsRelation.insert$writeRows`, and + // will cause `abortTask()` to be invoked. + throw new RuntimeException("Failed to commit task", cause) + } + } + + def abortTask(): Unit = { + try { + writer.close() + } finally { + super.abortTask() + } + } + } +} + +/** + * A writer that dynamically opens files based on the given partition columns. Internally this is + * done by maintaining a HashMap of open files until `maxFiles` is reached. If this occurs, the + * writer externally sorts the remaining rows and then writes out them out one file at a time. + */ +private[sql] class DynamicPartitionWriterContainer( + @transient relation: HadoopFsRelation, + @transient job: Job, + partitionColumns: Seq[Attribute], + dataColumns: Seq[Attribute], + inputSchema: Seq[Attribute], + defaultPartitionName: String, + maxOpenFiles: Int, + isAppend: Boolean) + extends BaseWriterContainer(relation, job, isAppend) { + + def writeRows(taskContext: TaskContext, iterator: Iterator[InternalRow]): Unit = { + val outputWriters = new java.util.HashMap[InternalRow, OutputWriter] + executorSideSetup(taskContext) + + // Returns the partition key given an input row + val getPartitionKey = UnsafeProjection.create(partitionColumns, inputSchema) + // Returns the data columns to be written given an input row + val getOutputRow = UnsafeProjection.create(dataColumns, inputSchema) + + // Expressions that given a partition key build a string like: col1=val/col2=val/... + val partitionStringExpression = partitionColumns.zipWithIndex.flatMap { case (c, i) => + val escaped = + ScalaUDF( + PartitioningUtils.escapePathName _, StringType, Seq(Cast(c, StringType)), Seq(StringType)) + val str = If(IsNull(c), Literal(defaultPartitionName), escaped) + val partitionName = Literal(c.name + "=") :: str :: Nil + if (i == 0) partitionName else Literal(Path.SEPARATOR_CHAR.toString) :: partitionName + } + + // Returns the partition path given a partition key. + val getPartitionString = + UnsafeProjection.create(Concat(partitionStringExpression) :: Nil, partitionColumns) + + // If anything below fails, we should abort the task. + try { + // This will be filled in if we have to fall back on sorting. 
+ var sorter: UnsafeKVExternalSorter = null + while (iterator.hasNext && sorter == null) { + val inputRow = iterator.next() + val currentKey = getPartitionKey(inputRow) + var currentWriter = outputWriters.get(currentKey) + + if (currentWriter == null) { + if (outputWriters.size < maxOpenFiles) { + currentWriter = newOutputWriter(currentKey) + outputWriters.put(currentKey.copy(), currentWriter) + currentWriter.writeInternal(getOutputRow(inputRow)) + } else { + logInfo(s"Maximum partitions reached, falling back on sorting.") + sorter = new UnsafeKVExternalSorter( + StructType.fromAttributes(partitionColumns), + StructType.fromAttributes(dataColumns), + SparkEnv.get.blockManager, + SparkEnv.get.shuffleMemoryManager, + SparkEnv.get.shuffleMemoryManager.pageSizeBytes) + sorter.insertKV(currentKey, getOutputRow(inputRow)) + } + } else { + currentWriter.writeInternal(getOutputRow(inputRow)) + } + } + + // If the sorter is not null that means that we reached the maxFiles above and need to finish + // using external sort. + if (sorter != null) { + while (iterator.hasNext) { + val currentRow = iterator.next() + sorter.insertKV(getPartitionKey(currentRow), getOutputRow(currentRow)) + } + + logInfo(s"Sorting complete. Writing out partition files one at a time.") + + val sortedIterator = sorter.sortedIterator() + var currentKey: InternalRow = null + var currentWriter: OutputWriter = null + try { + while (sortedIterator.next()) { + if (currentKey != sortedIterator.getKey) { + if (currentWriter != null) { + currentWriter.close() + } + currentKey = sortedIterator.getKey.copy() + logDebug(s"Writing partition: $currentKey") + + // Either use an existing file from before, or open a new one. + currentWriter = outputWriters.remove(currentKey) + if (currentWriter == null) { + currentWriter = newOutputWriter(currentKey) + } + } + + currentWriter.writeInternal(sortedIterator.getValue) + } + } finally { + if (currentWriter != null) { currentWriter.close() } + } + } + + commitTask() + } catch { + case cause: Throwable => + logError("Aborting task.", cause) + abortTask() + throw new SparkException("Task failed while writing rows.", cause) + } + + /** Open and returns a new OutputWriter given a partition key. */ + def newOutputWriter(key: InternalRow): OutputWriter = { + val partitionPath = getPartitionString(key).getString(0) + val path = new Path(getWorkPath, partitionPath) + taskAttemptContext.getConfiguration.set( + "spark.sql.sources.output.path", new Path(outputPath, partitionPath).toString) + val newWriter = outputWriterFactory.newInstance(path.toString, dataSchema, taskAttemptContext) + newWriter.initConverter(dataSchema) + newWriter + } + + def clearOutputWriters(): Unit = { + outputWriters.asScala.values.foreach(_.close()) + outputWriters.clear() + } + + def commitTask(): Unit = { + try { + clearOutputWriters() + super.commitTask() + } catch { + case cause: Throwable => + throw new RuntimeException("Failed to commit task", cause) + } + } + + def abortTask(): Unit = { + try { + clearOutputWriters() + } finally { + super.abortTask() + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/commands.scala deleted file mode 100644 index 42668979c9a3..000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/commands.scala +++ /dev/null @@ -1,606 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. 
See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.datasources - -import java.io.IOException -import java.util.{Date, UUID} - -import scala.collection.JavaConversions.asScalaIterator - -import org.apache.hadoop.fs.Path -import org.apache.hadoop.mapreduce._ -import org.apache.hadoop.mapreduce.lib.output.{FileOutputCommitter => MapReduceFileOutputCommitter, FileOutputFormat} -import org.apache.spark._ -import org.apache.spark.mapred.SparkHadoopMapRedUtil -import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil -import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.GenerateProjection -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project} -import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} -import org.apache.spark.sql.execution.{RunnableCommand, SQLExecution} -import org.apache.spark.sql.sources._ -import org.apache.spark.sql.types.StringType -import org.apache.spark.util.{Utils, SerializableConfiguration} - - -private[sql] case class InsertIntoDataSource( - logicalRelation: LogicalRelation, - query: LogicalPlan, - overwrite: Boolean) - extends RunnableCommand { - - override def run(sqlContext: SQLContext): Seq[Row] = { - val relation = logicalRelation.relation.asInstanceOf[InsertableRelation] - val data = DataFrame(sqlContext, query) - // Apply the schema of the existing table to the new data. - val df = sqlContext.internalCreateDataFrame(data.queryExecution.toRdd, logicalRelation.schema) - relation.insert(df, overwrite) - - // Invalidate the cache. - sqlContext.cacheManager.invalidateCache(logicalRelation) - - Seq.empty[Row] - } -} - -/** - * A command for writing data to a [[HadoopFsRelation]]. Supports both overwriting and appending. - * Writing to dynamic partitions is also supported. Each [[InsertIntoHadoopFsRelation]] issues a - * single write job, and owns a UUID that identifies this job. Each concrete implementation of - * [[HadoopFsRelation]] should use this UUID together with task id to generate unique file path for - * each task output file. This UUID is passed to executor side via a property named - * `spark.sql.sources.writeJobUUID`. - * - * Different writer containers, [[DefaultWriterContainer]] and [[DynamicPartitionWriterContainer]] - * are used to write to normal tables and tables with dynamic partitions. - * - * Basic work flow of this command is: - * - * 1. Driver side setup, including output committer initialization and data source specific - * preparation work for the write job to be issued. - * 2. Issues a write job consists of one or more executor side tasks, each of which writes all - * rows within an RDD partition. - * 3. 
If no exception is thrown in a task, commits that task, otherwise aborts that task; If any - * exception is thrown during task commitment, also aborts that task. - * 4. If all tasks are committed, commit the job, otherwise aborts the job; If any exception is - * thrown during job commitment, also aborts the job. - */ -private[sql] case class InsertIntoHadoopFsRelation( - @transient relation: HadoopFsRelation, - @transient query: LogicalPlan, - mode: SaveMode) - extends RunnableCommand { - - override def run(sqlContext: SQLContext): Seq[Row] = { - require( - relation.paths.length == 1, - s"Cannot write to multiple destinations: ${relation.paths.mkString(",")}") - - val hadoopConf = sqlContext.sparkContext.hadoopConfiguration - val outputPath = new Path(relation.paths.head) - val fs = outputPath.getFileSystem(hadoopConf) - val qualifiedOutputPath = outputPath.makeQualified(fs.getUri, fs.getWorkingDirectory) - - val pathExists = fs.exists(qualifiedOutputPath) - val doInsertion = (mode, pathExists) match { - case (SaveMode.ErrorIfExists, true) => - throw new AnalysisException(s"path $qualifiedOutputPath already exists.") - case (SaveMode.Overwrite, true) => - Utils.tryOrIOException { - if (!fs.delete(qualifiedOutputPath, true /* recursively */)) { - throw new IOException(s"Unable to clear output " + - s"directory $qualifiedOutputPath prior to writing to it") - } - } - true - case (SaveMode.Append, _) | (SaveMode.Overwrite, _) | (SaveMode.ErrorIfExists, false) => - true - case (SaveMode.Ignore, exists) => - !exists - case (s, exists) => - throw new IllegalStateException(s"unsupported save mode $s ($exists)") - } - // If we are appending data to an existing dir. - val isAppend = pathExists && (mode == SaveMode.Append) - - if (doInsertion) { - val job = new Job(hadoopConf) - job.setOutputKeyClass(classOf[Void]) - job.setOutputValueClass(classOf[InternalRow]) - FileOutputFormat.setOutputPath(job, qualifiedOutputPath) - - // We create a DataFrame by applying the schema of relation to the data to make sure. - // We are writing data based on the expected schema, - - // For partitioned relation r, r.schema's column ordering can be different from the column - // ordering of data.logicalPlan (partition columns are all moved after data column). We - // need a Project to adjust the ordering, so that inside InsertIntoHadoopFsRelation, we can - // safely apply the schema of r.schema to the data. - val project = Project( - relation.schema.map(field => new UnresolvedAttribute(Seq(field.name))), query) - - val queryExecution = DataFrame(sqlContext, project).queryExecution - SQLExecution.withNewExecutionId(sqlContext, queryExecution) { - val df = sqlContext.internalCreateDataFrame(queryExecution.toRdd, relation.schema) - - val partitionColumns = relation.partitionColumns.fieldNames - if (partitionColumns.isEmpty) { - insert(new DefaultWriterContainer(relation, job, isAppend), df) - } else { - val writerContainer = new DynamicPartitionWriterContainer( - relation, job, partitionColumns, PartitioningUtils.DEFAULT_PARTITION_NAME, isAppend) - insertWithDynamicPartitions(sqlContext, writerContainer, df, partitionColumns) - } - } - } - - Seq.empty[Row] - } - - /** - * Inserts the content of the [[DataFrame]] into a table without any partitioning columns. 
- */ - private def insert(writerContainer: BaseWriterContainer, df: DataFrame): Unit = { - // Uses local vals for serialization - val needsConversion = relation.needConversion - val dataSchema = relation.dataSchema - - // This call shouldn't be put into the `try` block below because it only initializes and - // prepares the job, any exception thrown from here shouldn't cause abortJob() to be called. - writerContainer.driverSideSetup() - - try { - df.sqlContext.sparkContext.runJob(df.queryExecution.toRdd, writeRows _) - writerContainer.commitJob() - relation.refresh() - } catch { case cause: Throwable => - logError("Aborting job.", cause) - writerContainer.abortJob() - throw new SparkException("Job aborted.", cause) - } - - def writeRows(taskContext: TaskContext, iterator: Iterator[InternalRow]): Unit = { - // If anything below fails, we should abort the task. - try { - writerContainer.executorSideSetup(taskContext) - - if (needsConversion) { - val converter = CatalystTypeConverters.createToScalaConverter(dataSchema) - .asInstanceOf[InternalRow => Row] - while (iterator.hasNext) { - val internalRow = iterator.next() - writerContainer.outputWriterForRow(internalRow).write(converter(internalRow)) - } - } else { - while (iterator.hasNext) { - val internalRow = iterator.next() - writerContainer.outputWriterForRow(internalRow) - .asInstanceOf[OutputWriterInternal].writeInternal(internalRow) - } - } - - writerContainer.commitTask() - } catch { case cause: Throwable => - logError("Aborting task.", cause) - writerContainer.abortTask() - throw new SparkException("Task failed while writing rows.", cause) - } - } - } - - /** - * Inserts the content of the [[DataFrame]] into a table with partitioning columns. - */ - private def insertWithDynamicPartitions( - sqlContext: SQLContext, - writerContainer: BaseWriterContainer, - df: DataFrame, - partitionColumns: Array[String]): Unit = { - // Uses a local val for serialization - val needsConversion = relation.needConversion - val dataSchema = relation.dataSchema - - require( - df.schema == relation.schema, - s"""DataFrame must have the same schema as the relation to which is inserted. - |DataFrame schema: ${df.schema} - |Relation schema: ${relation.schema} - """.stripMargin) - - val partitionColumnsInSpec = relation.partitionColumns.fieldNames - require( - partitionColumnsInSpec.sameElements(partitionColumns), - s"""Partition columns mismatch. - |Expected: ${partitionColumnsInSpec.mkString(", ")} - |Actual: ${partitionColumns.mkString(", ")} - """.stripMargin) - - val output = df.queryExecution.executedPlan.output - val (partitionOutput, dataOutput) = output.partition(a => partitionColumns.contains(a.name)) - val codegenEnabled = df.sqlContext.conf.codegenEnabled - - // This call shouldn't be put into the `try` block below because it only initializes and - // prepares the job, any exception thrown from here shouldn't cause abortJob() to be called. - writerContainer.driverSideSetup() - - try { - df.sqlContext.sparkContext.runJob(df.queryExecution.toRdd, writeRows _) - writerContainer.commitJob() - relation.refresh() - } catch { case cause: Throwable => - logError("Aborting job.", cause) - writerContainer.abortJob() - throw new SparkException("Job aborted.", cause) - } - - def writeRows(taskContext: TaskContext, iterator: Iterator[InternalRow]): Unit = { - // If anything below fails, we should abort the task. - try { - writerContainer.executorSideSetup(taskContext) - - // Projects all partition columns and casts them to strings to build partition directories. 
- val partitionCasts = partitionOutput.map(Cast(_, StringType)) - val partitionProj = newProjection(codegenEnabled, partitionCasts, output) - val dataProj = newProjection(codegenEnabled, dataOutput, output) - - if (needsConversion) { - val converter = CatalystTypeConverters.createToScalaConverter(dataSchema) - .asInstanceOf[InternalRow => Row] - while (iterator.hasNext) { - val internalRow = iterator.next() - val partitionPart = partitionProj(internalRow) - val dataPart = converter(dataProj(internalRow)) - writerContainer.outputWriterForRow(partitionPart).write(dataPart) - } - } else { - while (iterator.hasNext) { - val internalRow = iterator.next() - val partitionPart = partitionProj(internalRow) - val dataPart = dataProj(internalRow) - writerContainer.outputWriterForRow(partitionPart) - .asInstanceOf[OutputWriterInternal].writeInternal(dataPart) - } - } - - writerContainer.commitTask() - } catch { case cause: Throwable => - logError("Aborting task.", cause) - writerContainer.abortTask() - throw new SparkException("Task failed while writing rows.", cause) - } - } - } - - // This is copied from SparkPlan, probably should move this to a more general place. - private def newProjection( - codegenEnabled: Boolean, - expressions: Seq[Expression], - inputSchema: Seq[Attribute]): Projection = { - log.debug( - s"Creating Projection: $expressions, inputSchema: $inputSchema, codegen:$codegenEnabled") - if (codegenEnabled) { - - try { - GenerateProjection.generate(expressions, inputSchema) - } catch { - case e: Exception => - if (sys.props.contains("spark.testing")) { - throw e - } else { - log.error("failed to generate projection, fallback to interpreted", e) - new InterpretedProjection(expressions, inputSchema) - } - } - } else { - new InterpretedProjection(expressions, inputSchema) - } - } -} - -private[sql] abstract class BaseWriterContainer( - @transient val relation: HadoopFsRelation, - @transient job: Job, - isAppend: Boolean) - extends SparkHadoopMapReduceUtil - with Logging - with Serializable { - - protected val serializableConf = new SerializableConfiguration(job.getConfiguration) - - // This UUID is used to avoid output file name collision between different appending write jobs. - // These jobs may belong to different SparkContext instances. Concrete data source implementations - // may use this UUID to generate unique file names (e.g., `part-r--.parquet`). - // The reason why this ID is used to identify a job rather than a single task output file is - // that, speculative tasks must generate the same output file name as the original task. - private val uniqueWriteJobId = UUID.randomUUID() - - // This is only used on driver side. - @transient private val jobContext: JobContext = job - - // The following fields are initialized and used on both driver and executor side. 
- @transient protected var outputCommitter: OutputCommitter = _ - @transient private var jobId: JobID = _ - @transient private var taskId: TaskID = _ - @transient private var taskAttemptId: TaskAttemptID = _ - @transient protected var taskAttemptContext: TaskAttemptContext = _ - - protected val outputPath: String = { - assert( - relation.paths.length == 1, - s"Cannot write to multiple destinations: ${relation.paths.mkString(",")}") - relation.paths.head - } - - protected val dataSchema = relation.dataSchema - - protected var outputWriterFactory: OutputWriterFactory = _ - - private var outputFormatClass: Class[_ <: OutputFormat[_, _]] = _ - - def driverSideSetup(): Unit = { - setupIDs(0, 0, 0) - setupConf() - - // This UUID is sent to executor side together with the serialized `Configuration` object within - // the `Job` instance. `OutputWriters` on the executor side should use this UUID to generate - // unique task output files. - job.getConfiguration.set("spark.sql.sources.writeJobUUID", uniqueWriteJobId.toString) - - // Order of the following two lines is important. For Hadoop 1, TaskAttemptContext constructor - // clones the Configuration object passed in. If we initialize the TaskAttemptContext first, - // configurations made in prepareJobForWrite(job) are not populated into the TaskAttemptContext. - // - // Also, the `prepareJobForWrite` call must happen before initializing output format and output - // committer, since their initialization involve the job configuration, which can be potentially - // decorated in `prepareJobForWrite`. - outputWriterFactory = relation.prepareJobForWrite(job) - taskAttemptContext = newTaskAttemptContext(serializableConf.value, taskAttemptId) - - outputFormatClass = job.getOutputFormatClass - outputCommitter = newOutputCommitter(taskAttemptContext) - outputCommitter.setupJob(jobContext) - } - - def executorSideSetup(taskContext: TaskContext): Unit = { - setupIDs(taskContext.stageId(), taskContext.partitionId(), taskContext.attemptNumber()) - setupConf() - taskAttemptContext = newTaskAttemptContext(serializableConf.value, taskAttemptId) - outputCommitter = newOutputCommitter(taskAttemptContext) - outputCommitter.setupTask(taskAttemptContext) - initWriters() - } - - protected def getWorkPath: String = { - outputCommitter match { - // FileOutputCommitter writes to a temporary location returned by `getWorkPath`. - case f: MapReduceFileOutputCommitter => f.getWorkPath.toString - case _ => outputPath - } - } - - private def newOutputCommitter(context: TaskAttemptContext): OutputCommitter = { - val defaultOutputCommitter = outputFormatClass.newInstance().getOutputCommitter(context) - - if (isAppend) { - // If we are appending data to an existing dir, we will only use the output committer - // associated with the file output format since it is not safe to use a custom - // committer for appending. For example, in S3, direct parquet output committer may - // leave partial data in the destination dir when the the appending job fails. - logInfo( - s"Using output committer class ${defaultOutputCommitter.getClass.getCanonicalName} " + - "for appending.") - defaultOutputCommitter - } else { - val committerClass = context.getConfiguration.getClass( - SQLConf.OUTPUT_COMMITTER_CLASS.key, null, classOf[OutputCommitter]) - - Option(committerClass).map { clazz => - logInfo(s"Using user defined output committer class ${clazz.getCanonicalName}") - - // Every output format based on org.apache.hadoop.mapreduce.lib.output.OutputFormat - // has an associated output committer. 
To override this output committer, - // we will first try to use the output committer set in SQLConf.OUTPUT_COMMITTER_CLASS. - // If a data source needs to override the output committer, it needs to set the - // output committer in prepareForWrite method. - if (classOf[MapReduceFileOutputCommitter].isAssignableFrom(clazz)) { - // The specified output committer is a FileOutputCommitter. - // So, we will use the FileOutputCommitter-specified constructor. - val ctor = clazz.getDeclaredConstructor(classOf[Path], classOf[TaskAttemptContext]) - ctor.newInstance(new Path(outputPath), context) - } else { - // The specified output committer is just a OutputCommitter. - // So, we will use the no-argument constructor. - val ctor = clazz.getDeclaredConstructor() - ctor.newInstance() - } - }.getOrElse { - // If output committer class is not set, we will use the one associated with the - // file output format. - logInfo( - s"Using output committer class ${defaultOutputCommitter.getClass.getCanonicalName}") - defaultOutputCommitter - } - } - } - - private def setupIDs(jobId: Int, splitId: Int, attemptId: Int): Unit = { - this.jobId = SparkHadoopWriter.createJobID(new Date, jobId) - this.taskId = new TaskID(this.jobId, true, splitId) - this.taskAttemptId = new TaskAttemptID(taskId, attemptId) - } - - private def setupConf(): Unit = { - serializableConf.value.set("mapred.job.id", jobId.toString) - serializableConf.value.set("mapred.tip.id", taskAttemptId.getTaskID.toString) - serializableConf.value.set("mapred.task.id", taskAttemptId.toString) - serializableConf.value.setBoolean("mapred.task.is.map", true) - serializableConf.value.setInt("mapred.task.partition", 0) - } - - // Called on executor side when writing rows - def outputWriterForRow(row: InternalRow): OutputWriter - - protected def initWriters(): Unit - - def commitTask(): Unit = { - SparkHadoopMapRedUtil.commitTask( - outputCommitter, taskAttemptContext, jobId.getId, taskId.getId, taskAttemptId.getId) - } - - def abortTask(): Unit = { - if (outputCommitter != null) { - outputCommitter.abortTask(taskAttemptContext) - } - logError(s"Task attempt $taskAttemptId aborted.") - } - - def commitJob(): Unit = { - outputCommitter.commitJob(jobContext) - logInfo(s"Job $jobId committed.") - } - - def abortJob(): Unit = { - if (outputCommitter != null) { - outputCommitter.abortJob(jobContext, JobStatus.State.FAILED) - } - logError(s"Job $jobId aborted.") - } -} - -private[sql] class DefaultWriterContainer( - @transient relation: HadoopFsRelation, - @transient job: Job, - isAppend: Boolean) - extends BaseWriterContainer(relation, job, isAppend) { - - @transient private var writer: OutputWriter = _ - - override protected def initWriters(): Unit = { - taskAttemptContext.getConfiguration.set("spark.sql.sources.output.path", outputPath) - writer = outputWriterFactory.newInstance(getWorkPath, dataSchema, taskAttemptContext) - } - - override def outputWriterForRow(row: InternalRow): OutputWriter = writer - - override def commitTask(): Unit = { - try { - assert(writer != null, "OutputWriter instance should have been initialized") - writer.close() - super.commitTask() - } catch { case cause: Throwable => - // This exception will be handled in `InsertIntoHadoopFsRelation.insert$writeRows`, and will - // cause `abortTask()` to be invoked. 
- throw new RuntimeException("Failed to commit task", cause) - } - } - - override def abortTask(): Unit = { - try { - // It's possible that the task fails before `writer` gets initialized - if (writer != null) { - writer.close() - } - } finally { - super.abortTask() - } - } -} - -private[sql] class DynamicPartitionWriterContainer( - @transient relation: HadoopFsRelation, - @transient job: Job, - partitionColumns: Array[String], - defaultPartitionName: String, - isAppend: Boolean) - extends BaseWriterContainer(relation, job, isAppend) { - - // All output writers are created on executor side. - @transient protected var outputWriters: java.util.HashMap[String, OutputWriter] = _ - - override protected def initWriters(): Unit = { - outputWriters = new java.util.HashMap[String, OutputWriter] - } - - // The `row` argument is supposed to only contain partition column values which have been casted - // to strings. - override def outputWriterForRow(row: InternalRow): OutputWriter = { - val partitionPath = { - val partitionPathBuilder = new StringBuilder - var i = 0 - - while (i < partitionColumns.length) { - val col = partitionColumns(i) - val partitionValueString = { - val string = row.getUTF8String(i) - if (string.eq(null)) { - defaultPartitionName - } else { - PartitioningUtils.escapePathName(string.toString) - } - } - - if (i > 0) { - partitionPathBuilder.append(Path.SEPARATOR_CHAR) - } - - partitionPathBuilder.append(s"$col=$partitionValueString") - i += 1 - } - - partitionPathBuilder.toString() - } - - val writer = outputWriters.get(partitionPath) - if (writer.eq(null)) { - val path = new Path(getWorkPath, partitionPath) - taskAttemptContext.getConfiguration.set( - "spark.sql.sources.output.path", new Path(outputPath, partitionPath).toString) - val newWriter = outputWriterFactory.newInstance(path.toString, dataSchema, taskAttemptContext) - outputWriters.put(partitionPath, newWriter) - newWriter - } else { - writer - } - } - - private def clearOutputWriters(): Unit = { - if (!outputWriters.isEmpty) { - asScalaIterator(outputWriters.values().iterator()).foreach(_.close()) - outputWriters.clear() - } - } - - override def commitTask(): Unit = { - try { - clearOutputWriters() - super.commitTask() - } catch { case cause: Throwable => - throw new RuntimeException("Failed to commit task", cause) - } - } - - override def abortTask(): Unit = { - try { - clearOutputWriters() - } finally { - super.abortTask() - } - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala index 5d371402877c..10f1367e6984 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala @@ -152,7 +152,7 @@ private[json] class JsonOutputWriter( path: String, dataSchema: StructType, context: TaskAttemptContext) - extends OutputWriterInternal with SparkHadoopMapRedUtil with Logging { + extends OutputWriter with SparkHadoopMapRedUtil with Logging { val writer = new CharArrayWriter() // create the Generator without separator inserted between 2 records @@ -170,7 +170,9 @@ private[json] class JsonOutputWriter( }.getRecordWriter(context) } - override def writeInternal(row: InternalRow): Unit = { + override def write(row: Row): Unit = throw new UnsupportedOperationException("call writeInternal") + + override protected[sql] def writeInternal(row: InternalRow): Unit = { JacksonGenerator(dataSchema, gen, row) gen.flush() diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala index 29c388c22ef9..48009b2fd007 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala @@ -62,7 +62,7 @@ private[sql] class DefaultSource extends HadoopFsRelationProvider { // NOTE: This class is instantiated and used on executor side only, no need to be serializable. private[sql] class ParquetOutputWriter(path: String, context: TaskAttemptContext) - extends OutputWriterInternal { + extends OutputWriter { private val recordWriter: RecordWriter[Void, InternalRow] = { val outputFormat = { @@ -87,7 +87,9 @@ private[sql] class ParquetOutputWriter(path: String, context: TaskAttemptContext outputFormat.getRecordWriter(context) } - override def writeInternal(row: InternalRow): Unit = recordWriter.write(null, row) + override def write(row: Row): Unit = throw new UnsupportedOperationException("call writeInternal") + + override protected[sql] def writeInternal(row: InternalRow): Unit = recordWriter.write(null, row) override def close(): Unit = recordWriter.close(context) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index 0b2929661b65..c5b7ee73eb78 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -342,18 +342,17 @@ abstract class OutputWriter { * @since 1.4.0 */ def close(): Unit -} -/** - * This is an internal, private version of [[OutputWriter]] with an writeInternal method that - * accepts an [[InternalRow]] rather than an [[Row]]. Data sources that return this must have - * the conversion flag set to false. - */ -private[sql] abstract class OutputWriterInternal extends OutputWriter { + private var converter: InternalRow => Row = _ - override def write(row: Row): Unit = throw new UnsupportedOperationException + protected[sql] def initConverter(dataSchema: StructType) = { + converter = + CatalystTypeConverters.createToScalaConverter(dataSchema).asInstanceOf[InternalRow => Row] + } - def writeInternal(row: InternalRow): Unit + protected[sql] def writeInternal(row: InternalRow): Unit = { + write(converter(row)) + } } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala new file mode 100644 index 000000000000..c86ddd7c83e5 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources + +import org.apache.spark.sql.{Row, QueryTest} +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.util.Utils + +class PartitionedWriteSuite extends QueryTest { + import TestSQLContext.implicits._ + + test("write many partitions") { + val path = Utils.createTempDir() + path.delete() + + val df = TestSQLContext.range(100).select($"id", lit(1).as("data")) + df.write.partitionBy("id").save(path.getCanonicalPath) + + checkAnswer( + TestSQLContext.read.load(path.getCanonicalPath), + (0 to 99).map(Row(1, _)).toSeq) + + Utils.deleteRecursively(path) + } + + test("write many partitions with repeats") { + val path = Utils.createTempDir() + path.delete() + + val base = TestSQLContext.range(100) + val df = base.unionAll(base).select($"id", lit(1).as("data")) + df.write.partitionBy("id").save(path.getCanonicalPath) + + checkAnswer( + TestSQLContext.read.load(path.getCanonicalPath), + (0 to 99).map(Row(1, _)).toSeq ++ (0 to 99).map(Row(1, _)).toSeq) + + Utils.deleteRecursively(path) + } +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala index 4a310ff4e901..7c8704b47f28 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala @@ -66,7 +66,7 @@ private[orc] class OrcOutputWriter( path: String, dataSchema: StructType, context: TaskAttemptContext) - extends OutputWriterInternal with SparkHadoopMapRedUtil with HiveInspectors { + extends OutputWriter with SparkHadoopMapRedUtil with HiveInspectors { private val serializer = { val table = new Properties() @@ -120,7 +120,9 @@ private[orc] class OrcOutputWriter( ).asInstanceOf[RecordWriter[NullWritable, Writable]] } - override def writeInternal(row: InternalRow): Unit = { + override def write(row: Row): Unit = throw new UnsupportedOperationException("call writeInternal") + + override protected[sql] def writeInternal(row: InternalRow): Unit = { var i = 0 while (i < row.numFields) { reusableOutputBuffer(i) = wrappers(i)(row.get(i, dataSchema(i).dataType)) From cd540c1e59561ad1fdac59af6170944c60e685d8 Mon Sep 17 00:00:00 2001 From: Feynman Liang Date: Fri, 7 Aug 2015 17:19:48 -0700 Subject: [PATCH 0125/1824] [SPARK-9756] [ML] Make constructors in ML decision trees private These should be made private until there is a public constructor for providing `rootNode: Node` to use these constructors. 
jkbradley Author: Feynman Liang Closes #8046 from feynmanliang/SPARK-9756 and squashes the following commits: 2cbdf08 [Feynman Liang] Make RFRegressionModel aux constructor private a06f596 [Feynman Liang] Make constructors in ML decision trees private --- .../spark/ml/classification/DecisionTreeClassifier.scala | 2 +- .../spark/ml/classification/RandomForestClassifier.scala | 5 ++++- .../apache/spark/ml/regression/DecisionTreeRegressor.scala | 2 +- .../apache/spark/ml/regression/RandomForestRegressor.scala | 2 +- 4 files changed, 7 insertions(+), 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala index f2b992f8ba24..29598f3f05c2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala @@ -117,7 +117,7 @@ final class DecisionTreeClassificationModel private[ml] ( * Construct a decision tree classification model. * @param rootNode Root node of tree, with other nodes attached. */ - def this(rootNode: Node, numClasses: Int) = + private[ml] def this(rootNode: Node, numClasses: Int) = this(Identifiable.randomUID("dtc"), rootNode, numClasses) override protected def predict(features: Vector): Double = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index b59826a59499..156050aaf7a4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -136,7 +136,10 @@ final class RandomForestClassificationModel private[ml] ( * Construct a random forest classification model, with all trees weighted equally. * @param trees Component trees */ - def this(trees: Array[DecisionTreeClassificationModel], numFeatures: Int, numClasses: Int) = + private[ml] def this( + trees: Array[DecisionTreeClassificationModel], + numFeatures: Int, + numClasses: Int) = this(Identifiable.randomUID("rfc"), trees, numFeatures, numClasses) override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]] diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala index 4d30e4b5548a..dc94a1401454 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala @@ -107,7 +107,7 @@ final class DecisionTreeRegressionModel private[ml] ( * Construct a decision tree regression model. * @param rootNode Root node of tree, with other nodes attached. 
*/ - def this(rootNode: Node) = this(Identifiable.randomUID("dtr"), rootNode) + private[ml] def this(rootNode: Node) = this(Identifiable.randomUID("dtr"), rootNode) override protected def predict(features: Vector): Double = { rootNode.predictImpl(features).prediction diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala index 1ee43c872573..db75c0d26392 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala @@ -125,7 +125,7 @@ final class RandomForestRegressionModel private[ml] ( * Construct a random forest regression model, with all trees weighted equally. * @param trees Component trees */ - def this(trees: Array[DecisionTreeRegressionModel], numFeatures: Int) = + private[ml] def this(trees: Array[DecisionTreeRegressionModel], numFeatures: Int) = this(Identifiable.randomUID("rfr"), trees, numFeatures) override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]] From 85be65b39ce669f937a898195a844844d757666b Mon Sep 17 00:00:00 2001 From: Feynman Liang Date: Fri, 7 Aug 2015 17:21:12 -0700 Subject: [PATCH 0126/1824] [SPARK-9719] [ML] Clean up Naive Bayes doc Small documentation cleanups, including: * Adds documentation for `pi` and `theta` * setParam to `setModelType` Author: Feynman Liang Closes #8047 from feynmanliang/SPARK-9719 and squashes the following commits: b372438 [Feynman Liang] Clean up naive bayes doc --- .../scala/org/apache/spark/ml/classification/NaiveBayes.scala | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala index b46b676204e0..97cbaf1fa876 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala @@ -86,6 +86,7 @@ class NaiveBayes(override val uid: String) * Set the model type using a string (case-sensitive). * Supported options: "multinomial" and "bernoulli". * Default is "multinomial" + * @group setParam */ def setModelType(value: String): this.type = set(modelType, value) setDefault(modelType -> OldNaiveBayes.Multinomial) @@ -101,6 +102,9 @@ class NaiveBayes(override val uid: String) /** * Model produced by [[NaiveBayes]] + * @param pi log of class priors, whose dimension is C (number of classes) + * @param theta log of class conditional probabilities, whose dimension is C (number of classes) + * by D (number of features) */ class NaiveBayesModel private[ml] ( override val uid: String, From 998f4ff94df1d9db1c9e32c04091017c25cd4e81 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 7 Aug 2015 19:09:28 -0700 Subject: [PATCH 0127/1824] [SPARK-9754][SQL] Remove TypeCheck in debug package. TypeCheck no longer applies in the new "Tungsten" world. Author: Reynold Xin Closes #8043 from rxin/SPARK-9754 and squashes the following commits: 4ec471e [Reynold Xin] [SPARK-9754][SQL] Remove TypeCheck in debug package. 
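For context, a minimal usage sketch of the debugging hook that remains after this removal. It is illustrative only and not part of the commit; it assumes the existing `DebugQuery` implicit in the `debug` package object (visible in the diff below) and borrows `TestSQLContext` from the test code earlier in this series.

    import org.apache.spark.sql.execution.debug._   // brings the DebugQuery implicit into scope
    import org.apache.spark.sql.test.TestSQLContext

    // Executes the plan once and collects per-operator row counts; typeCheck() no longer exists.
    TestSQLContext.range(100).debug()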
--- .../spark/sql/execution/debug/package.scala | 104 +----------------- .../sql/execution/debug/DebuggingSuite.scala | 4 - 2 files changed, 4 insertions(+), 104 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala index dd3858ea2b52..74892e4e13fa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala @@ -17,21 +17,16 @@ package org.apache.spark.sql.execution -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.unsafe.types.UTF8String - import scala.collection.mutable.HashSet -import org.apache.spark.{AccumulatorParam, Accumulator, Logging} -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.rdd.RDD import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.trees.TreeNodeRef -import org.apache.spark.sql.types._ +import org.apache.spark.{Accumulator, AccumulatorParam, Logging} /** - * :: DeveloperApi :: * Contains methods for debugging query execution. * * Usage: @@ -53,10 +48,8 @@ package object debug { } /** - * :: DeveloperApi :: * Augments [[DataFrame]]s with debug methods. */ - @DeveloperApi implicit class DebugQuery(query: DataFrame) extends Logging { def debug(): Unit = { val plan = query.queryExecution.executedPlan @@ -72,23 +65,6 @@ package object debug { case _ => } } - - def typeCheck(): Unit = { - val plan = query.queryExecution.executedPlan - val visited = new collection.mutable.HashSet[TreeNodeRef]() - val debugPlan = plan transform { - case s: SparkPlan if !visited.contains(new TreeNodeRef(s)) => - visited += new TreeNodeRef(s) - TypeCheck(s) - } - try { - logDebug(s"Results returned: ${debugPlan.execute().count()}") - } catch { - case e: Exception => - def unwrap(e: Throwable): Throwable = if (e.getCause == null) e else unwrap(e.getCause) - logDebug(s"Deepest Error: ${unwrap(e)}") - } - } } private[sql] case class DebugNode(child: SparkPlan) extends UnaryNode { @@ -148,76 +124,4 @@ package object debug { } } } - - /** - * Helper functions for checking that runtime types match a given schema. 
- */ - private[sql] object TypeCheck { - def typeCheck(data: Any, schema: DataType): Unit = (data, schema) match { - case (null, _) => - - case (row: InternalRow, s: StructType) => - row.toSeq(s).zip(s.map(_.dataType)).foreach { case(d, t) => typeCheck(d, t) } - case (a: ArrayData, ArrayType(elemType, _)) => - a.foreach(elemType, (_, e) => { - typeCheck(e, elemType) - }) - case (m: MapData, MapType(keyType, valueType, _)) => - m.keyArray().foreach(keyType, (_, e) => { - typeCheck(e, keyType) - }) - m.valueArray().foreach(valueType, (_, e) => { - typeCheck(e, valueType) - }) - - case (_: Long, LongType) => - case (_: Int, IntegerType) => - case (_: UTF8String, StringType) => - case (_: Float, FloatType) => - case (_: Byte, ByteType) => - case (_: Short, ShortType) => - case (_: Boolean, BooleanType) => - case (_: Double, DoubleType) => - case (_: Int, DateType) => - case (_: Long, TimestampType) => - case (v, udt: UserDefinedType[_]) => typeCheck(v, udt.sqlType) - - case (d, t) => sys.error(s"Invalid data found: got $d (${d.getClass}) expected $t") - } - } - - /** - * Augments [[DataFrame]]s with debug methods. - */ - private[sql] case class TypeCheck(child: SparkPlan) extends SparkPlan { - import TypeCheck._ - - override def nodeName: String = "" - - /* Only required when defining this class in a REPL. - override def makeCopy(args: Array[Object]): this.type = - TypeCheck(args(0).asInstanceOf[SparkPlan]).asInstanceOf[this.type] - */ - - def output: Seq[Attribute] = child.output - - def children: List[SparkPlan] = child :: Nil - - protected override def doExecute(): RDD[InternalRow] = { - child.execute().map { row => - try typeCheck(row, child.schema) catch { - case e: Exception => - sys.error( - s""" - |ERROR WHEN TYPE CHECKING QUERY - |============================== - |$e - |======== BAD TREE ============ - |$child - """.stripMargin) - } - row - } - } - } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala index 8ec3985e0036..239deb797384 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala @@ -25,8 +25,4 @@ class DebuggingSuite extends SparkFunSuite { test("DataFrame.debug()") { testData.debug() } - - test("DataFrame.typeCheck()") { - testData.typeCheck() - } } From c564b27447ed99e55b359b3df1d586d5766b85ea Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Fri, 7 Aug 2015 20:04:17 -0700 Subject: [PATCH 0128/1824] [SPARK-9753] [SQL] TungstenAggregate should also accept InternalRow instead of just UnsafeRow https://issues.apache.org/jira/browse/SPARK-9753 This PR makes TungstenAggregate to accept `InternalRow` instead of just `UnsafeRow`. Also, it adds an `getAggregationBufferFromUnsafeRow` method to `UnsafeFixedWidthAggregationMap`. It is useful when we already have grouping keys stored in `UnsafeRow`s. Finally, it wraps `InputStream` and `OutputStream` in `UnsafeRowSerializer` with `BufferedInputStream` and `BufferedOutputStream`, respectively. Author: Yin Huai Closes #8041 from yhuai/joinedRowForProjection and squashes the following commits: 7753e34 [Yin Huai] Use BufferedInputStream and BufferedOutputStream. d68b74e [Yin Huai] Use joinedRow instead of UnsafeRowJoiner. e93c009 [Yin Huai] Add getAggregationBufferFromUnsafeRow for cases that the given groupingKeyRow is already an UnsafeRow. 
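As a side note for readers, the stream-buffering change called out above follows the standard java.io decorator pattern. The sketch below is illustrative only; the helper names are hypothetical and the real change is inside UnsafeRowSerializer in the diff that follows.

    import java.io.{BufferedInputStream, BufferedOutputStream}
    import java.io.{DataInputStream, DataOutputStream, InputStream, OutputStream}

    // Wrap the raw streams once so the serializer's many small writeInt/readInt and
    // byte-array calls hit an in-memory buffer instead of the underlying stream directly.
    def wrapForWrite(out: OutputStream): DataOutputStream =
      new DataOutputStream(new BufferedOutputStream(out))

    def wrapForRead(in: InputStream): DataInputStream =
      new DataInputStream(new BufferedInputStream(in))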
--- .../UnsafeFixedWidthAggregationMap.java | 4 ++ .../sql/execution/UnsafeRowSerializer.scala | 30 +++-------- .../aggregate/TungstenAggregate.scala | 4 +- .../TungstenAggregationIterator.scala | 51 +++++++++---------- 4 files changed, 39 insertions(+), 50 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java index b08a4a13a28b..00218f213054 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java @@ -121,6 +121,10 @@ public UnsafeFixedWidthAggregationMap( public UnsafeRow getAggregationBuffer(InternalRow groupingKey) { final UnsafeRow unsafeGroupingKeyRow = this.groupingKeyProjection.apply(groupingKey); + return getAggregationBufferFromUnsafeRow(unsafeGroupingKeyRow); + } + + public UnsafeRow getAggregationBufferFromUnsafeRow(UnsafeRow unsafeGroupingKeyRow) { // Probe our map using the serialized key final BytesToBytesMap.Location loc = map.lookup( unsafeGroupingKeyRow.getBaseObject(), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala index 39f8f992a9f0..6c7e5cacc99e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala @@ -58,27 +58,14 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInst */ override def serializeStream(out: OutputStream): SerializationStream = new SerializationStream { private[this] var writeBuffer: Array[Byte] = new Array[Byte](4096) - // When `out` is backed by ChainedBufferOutputStream, we will get an - // UnsupportedOperationException when we call dOut.writeInt because it internally calls - // ChainedBufferOutputStream's write(b: Int), which is not supported. - // To workaround this issue, we create an array for sorting the int value. - // To reproduce the problem, use dOut.writeInt(row.getSizeInBytes) and - // run SparkSqlSerializer2SortMergeShuffleSuite. - private[this] var intBuffer: Array[Byte] = new Array[Byte](4) - private[this] val dOut: DataOutputStream = new DataOutputStream(out) + private[this] val dOut: DataOutputStream = + new DataOutputStream(new BufferedOutputStream(out)) override def writeValue[T: ClassTag](value: T): SerializationStream = { val row = value.asInstanceOf[UnsafeRow] - val size = row.getSizeInBytes - // This part is based on DataOutputStream's writeInt. - // It is for dOut.writeInt(row.getSizeInBytes). 
- intBuffer(0) = ((size >>> 24) & 0xFF).toByte - intBuffer(1) = ((size >>> 16) & 0xFF).toByte - intBuffer(2) = ((size >>> 8) & 0xFF).toByte - intBuffer(3) = ((size >>> 0) & 0xFF).toByte - dOut.write(intBuffer, 0, 4) - - row.writeToStream(out, writeBuffer) + + dOut.writeInt(row.getSizeInBytes) + row.writeToStream(dOut, writeBuffer) this } @@ -105,7 +92,6 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInst override def close(): Unit = { writeBuffer = null - intBuffer = null dOut.writeInt(EOF) dOut.close() } @@ -113,7 +99,7 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInst override def deserializeStream(in: InputStream): DeserializationStream = { new DeserializationStream { - private[this] val dIn: DataInputStream = new DataInputStream(in) + private[this] val dIn: DataInputStream = new DataInputStream(new BufferedInputStream(in)) // 1024 is a default buffer size; this buffer will grow to accommodate larger rows private[this] var rowBuffer: Array[Byte] = new Array[Byte](1024) private[this] var row: UnsafeRow = new UnsafeRow() @@ -129,7 +115,7 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInst if (rowBuffer.length < rowSize) { rowBuffer = new Array[Byte](rowSize) } - ByteStreams.readFully(in, rowBuffer, 0, rowSize) + ByteStreams.readFully(dIn, rowBuffer, 0, rowSize) row.pointTo(rowBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, numFields, rowSize) rowSize = dIn.readInt() // read the next row's size if (rowSize == EOF) { // We are returning the last row in this stream @@ -163,7 +149,7 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInst if (rowBuffer.length < rowSize) { rowBuffer = new Array[Byte](rowSize) } - ByteStreams.readFully(in, rowBuffer, 0, rowSize) + ByteStreams.readFully(dIn, rowBuffer, 0, rowSize) row.pointTo(rowBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, numFields, rowSize) row.asInstanceOf[T] } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala index c3dcbd2b71ee..1694794a53d9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala @@ -39,7 +39,7 @@ case class TungstenAggregate( override def canProcessUnsafeRows: Boolean = true - override def canProcessSafeRows: Boolean = false + override def canProcessSafeRows: Boolean = true override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute) @@ -77,7 +77,7 @@ case class TungstenAggregate( resultExpressions, newMutableProjection, child.output, - iter.asInstanceOf[Iterator[UnsafeRow]], + iter, testFallbackStartsAt) if (!hasInput && groupingExpressions.isEmpty) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala index 440bef32f4e9..32160906c3bc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala @@ -22,6 +22,7 @@ import org.apache.spark.{InternalAccumulator, Logging, SparkEnv, TaskContext} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ 
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.{UnsafeKVExternalSorter, UnsafeFixedWidthAggregationMap} import org.apache.spark.sql.types.StructType @@ -46,8 +47,7 @@ import org.apache.spark.sql.types.StructType * processing input rows from inputIter, and generating output * rows. * - Part 3: Methods and fields used by hash-based aggregation. - * - Part 4: The function used to switch this iterator from hash-based - * aggregation to sort-based aggregation. + * - Part 4: Methods and fields used when we switch to sort-based aggregation. * - Part 5: Methods and fields used by sort-based aggregation. * - Part 6: Loads input and process input rows. * - Part 7: Public methods of this iterator. @@ -82,7 +82,7 @@ class TungstenAggregationIterator( resultExpressions: Seq[NamedExpression], newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection), originalInputAttributes: Seq[Attribute], - inputIter: Iterator[UnsafeRow], + inputIter: Iterator[InternalRow], testFallbackStartsAt: Option[Int]) extends Iterator[UnsafeRow] with Logging { @@ -174,13 +174,10 @@ class TungstenAggregationIterator( // Creates a function used to process a row based on the given inputAttributes. private def generateProcessRow( - inputAttributes: Seq[Attribute]): (UnsafeRow, UnsafeRow) => Unit = { + inputAttributes: Seq[Attribute]): (UnsafeRow, InternalRow) => Unit = { val aggregationBufferAttributes = allAggregateFunctions.flatMap(_.bufferAttributes) - val aggregationBufferSchema = StructType.fromAttributes(aggregationBufferAttributes) - val inputSchema = StructType.fromAttributes(inputAttributes) - val unsafeRowJoiner = - GenerateUnsafeRowJoiner.create(aggregationBufferSchema, inputSchema) + val joinedRow = new JoinedRow() aggregationMode match { // Partial-only @@ -189,9 +186,9 @@ class TungstenAggregationIterator( val algebraicUpdateProjection = newMutableProjection(updateExpressions, aggregationBufferAttributes ++ inputAttributes)() - (currentBuffer: UnsafeRow, row: UnsafeRow) => { + (currentBuffer: UnsafeRow, row: InternalRow) => { algebraicUpdateProjection.target(currentBuffer) - algebraicUpdateProjection(unsafeRowJoiner.join(currentBuffer, row)) + algebraicUpdateProjection(joinedRow(currentBuffer, row)) } // PartialMerge-only or Final-only @@ -203,10 +200,10 @@ class TungstenAggregationIterator( mergeExpressions, aggregationBufferAttributes ++ inputAttributes)() - (currentBuffer: UnsafeRow, row: UnsafeRow) => { + (currentBuffer: UnsafeRow, row: InternalRow) => { // Process all algebraic aggregate functions. algebraicMergeProjection.target(currentBuffer) - algebraicMergeProjection(unsafeRowJoiner.join(currentBuffer, row)) + algebraicMergeProjection(joinedRow(currentBuffer, row)) } // Final-Complete @@ -233,8 +230,8 @@ class TungstenAggregationIterator( val completeAlgebraicUpdateProjection = newMutableProjection(updateExpressions, aggregationBufferAttributes ++ inputAttributes)() - (currentBuffer: UnsafeRow, row: UnsafeRow) => { - val input = unsafeRowJoiner.join(currentBuffer, row) + (currentBuffer: UnsafeRow, row: InternalRow) => { + val input = joinedRow(currentBuffer, row) // For all aggregate functions with mode Complete, update the given currentBuffer. 
completeAlgebraicUpdateProjection.target(currentBuffer)(input) @@ -253,14 +250,14 @@ class TungstenAggregationIterator( val completeAlgebraicUpdateProjection = newMutableProjection(updateExpressions, aggregationBufferAttributes ++ inputAttributes)() - (currentBuffer: UnsafeRow, row: UnsafeRow) => { + (currentBuffer: UnsafeRow, row: InternalRow) => { completeAlgebraicUpdateProjection.target(currentBuffer) // For all aggregate functions with mode Complete, update the given currentBuffer. - completeAlgebraicUpdateProjection(unsafeRowJoiner.join(currentBuffer, row)) + completeAlgebraicUpdateProjection(joinedRow(currentBuffer, row)) } // Grouping only. - case (None, None) => (currentBuffer: UnsafeRow, row: UnsafeRow) => {} + case (None, None) => (currentBuffer: UnsafeRow, row: InternalRow) => {} case other => throw new IllegalStateException( @@ -272,15 +269,16 @@ class TungstenAggregationIterator( private def generateResultProjection(): (UnsafeRow, UnsafeRow) => UnsafeRow = { val groupingAttributes = groupingExpressions.map(_.toAttribute) - val groupingKeySchema = StructType.fromAttributes(groupingAttributes) val bufferAttributes = allAggregateFunctions.flatMap(_.bufferAttributes) - val bufferSchema = StructType.fromAttributes(bufferAttributes) - val unsafeRowJoiner = GenerateUnsafeRowJoiner.create(groupingKeySchema, bufferSchema) aggregationMode match { // Partial-only or PartialMerge-only: every output row is basically the values of // the grouping expressions and the corresponding aggregation buffer. case (Some(Partial), None) | (Some(PartialMerge), None) => + val groupingKeySchema = StructType.fromAttributes(groupingAttributes) + val bufferSchema = StructType.fromAttributes(bufferAttributes) + val unsafeRowJoiner = GenerateUnsafeRowJoiner.create(groupingKeySchema, bufferSchema) + (currentGroupingKey: UnsafeRow, currentBuffer: UnsafeRow) => { unsafeRowJoiner.join(currentGroupingKey, currentBuffer) } @@ -288,11 +286,12 @@ class TungstenAggregationIterator( // Final-only, Complete-only and Final-Complete: a output row is generated based on // resultExpressions. case (Some(Final), None) | (Some(Final) | None, Some(Complete)) => + val joinedRow = new JoinedRow() val resultProjection = UnsafeProjection.create(resultExpressions, groupingAttributes ++ bufferAttributes) (currentGroupingKey: UnsafeRow, currentBuffer: UnsafeRow) => { - resultProjection(unsafeRowJoiner.join(currentGroupingKey, currentBuffer)) + resultProjection(joinedRow(currentGroupingKey, currentBuffer)) } // Grouping-only: a output row is generated from values of grouping expressions. @@ -316,7 +315,7 @@ class TungstenAggregationIterator( // A function used to process a input row. Its first argument is the aggregation buffer // and the second argument is the input row. - private[this] var processRow: (UnsafeRow, UnsafeRow) => Unit = + private[this] var processRow: (UnsafeRow, InternalRow) => Unit = generateProcessRow(originalInputAttributes) // A function used to generate output rows based on the grouping keys (first argument) @@ -354,7 +353,7 @@ class TungstenAggregationIterator( while (!sortBased && inputIter.hasNext) { val newInput = inputIter.next() val groupingKey = groupProjection.apply(newInput) - val buffer: UnsafeRow = hashMap.getAggregationBuffer(groupingKey) + val buffer: UnsafeRow = hashMap.getAggregationBufferFromUnsafeRow(groupingKey) if (buffer == null) { // buffer == null means that we could not allocate more memory. // Now, we need to spill the map and switch to sort-based aggregation. 
@@ -374,7 +373,7 @@ class TungstenAggregationIterator( val newInput = inputIter.next() val groupingKey = groupProjection.apply(newInput) val buffer: UnsafeRow = if (i < fallbackStartsAt) { - hashMap.getAggregationBuffer(groupingKey) + hashMap.getAggregationBufferFromUnsafeRow(groupingKey) } else { null } @@ -397,7 +396,7 @@ class TungstenAggregationIterator( private[this] var mapIteratorHasNext: Boolean = false /////////////////////////////////////////////////////////////////////////// - // Part 3: Methods and fields used by sort-based aggregation. + // Part 4: Methods and fields used when we switch to sort-based aggregation. /////////////////////////////////////////////////////////////////////////// // This sorter is used for sort-based aggregation. It is initialized as soon as @@ -407,7 +406,7 @@ class TungstenAggregationIterator( /** * Switch to sort-based aggregation when the hash-based approach is unable to acquire memory. */ - private def switchToSortBasedAggregation(firstKey: UnsafeRow, firstInput: UnsafeRow): Unit = { + private def switchToSortBasedAggregation(firstKey: UnsafeRow, firstInput: InternalRow): Unit = { logInfo("falling back to sort based aggregation.") // Step 1: Get the ExternalSorter containing sorted entries of the map. externalSorter = hashMap.destructAndCreateExternalSorter() From ef062c15992b0d08554495b8ea837bef3fabf6e9 Mon Sep 17 00:00:00 2001 From: Carson Wang Date: Fri, 7 Aug 2015 23:36:26 -0700 Subject: [PATCH 0129/1824] [SPARK-9731] Standalone scheduling incorrect cores if spark.executor.cores is not set The issue only happens if `spark.executor.cores` is not set and executor memory is set to a high value. For example, if we have a worker with 4G and 10 cores and we set `spark.executor.memory` to 3G, then only 1 core is assigned to the executor. The correct number should be 10 cores. I've added a unit test to illustrate the issue. Author: Carson Wang Closes #8017 from carsonwang/SPARK-9731 and squashes the following commits: d09ec48 [Carson Wang] Fix code style 86b651f [Carson Wang] Simplify the code 943cc4c [Carson Wang] fix scheduling correct cores to executors --- .../apache/spark/deploy/master/Master.scala | 26 ++++++++++--------- .../spark/deploy/master/MasterSuite.scala | 15 +++++++++++ 2 files changed, 29 insertions(+), 12 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index e38e437fe1c5..9217202b69a6 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -581,20 +581,22 @@ private[deploy] class Master( /** Return whether the specified worker can launch an executor for this app. */ def canLaunchExecutor(pos: Int): Boolean = { + val keepScheduling = coresToAssign >= minCoresPerExecutor + val enoughCores = usableWorkers(pos).coresFree - assignedCores(pos) >= minCoresPerExecutor + // If we allow multiple executors per worker, then we can always launch new executors. - // Otherwise, we may have already started assigning cores to the executor on this worker. + // Otherwise, if there is already an executor on this worker, just give it more cores. 
val launchingNewExecutor = !oneExecutorPerWorker || assignedExecutors(pos) == 0 - val underLimit = - if (launchingNewExecutor) { - assignedExecutors.sum + app.executors.size < app.executorLimit - } else { - true - } - val assignedMemory = assignedExecutors(pos) * memoryPerExecutor - usableWorkers(pos).memoryFree - assignedMemory >= memoryPerExecutor && - usableWorkers(pos).coresFree - assignedCores(pos) >= minCoresPerExecutor && - coresToAssign >= minCoresPerExecutor && - underLimit + if (launchingNewExecutor) { + val assignedMemory = assignedExecutors(pos) * memoryPerExecutor + val enoughMemory = usableWorkers(pos).memoryFree - assignedMemory >= memoryPerExecutor + val underLimit = assignedExecutors.sum + app.executors.size < app.executorLimit + keepScheduling && enoughCores && enoughMemory && underLimit + } else { + // We're adding cores to an existing executor, so no need + // to check memory and executor limits + keepScheduling && enoughCores + } } // Keep launching executors until no more workers can accommodate any diff --git a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala index ae0e037d822e..20d0201a364a 100644 --- a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala @@ -151,6 +151,14 @@ class MasterSuite extends SparkFunSuite with Matchers with Eventually with Priva basicScheduling(spreadOut = false) } + test("basic scheduling with more memory - spread out") { + basicSchedulingWithMoreMemory(spreadOut = true) + } + + test("basic scheduling with more memory - no spread out") { + basicSchedulingWithMoreMemory(spreadOut = false) + } + test("scheduling with max cores - spread out") { schedulingWithMaxCores(spreadOut = true) } @@ -214,6 +222,13 @@ class MasterSuite extends SparkFunSuite with Matchers with Eventually with Priva assert(scheduledCores === Array(10, 10, 10)) } + private def basicSchedulingWithMoreMemory(spreadOut: Boolean): Unit = { + val master = makeMaster() + val appInfo = makeAppInfo(3072) + val scheduledCores = scheduleExecutorsOnWorkers(master, appInfo, workerInfos, spreadOut) + assert(scheduledCores === Array(10, 10, 10)) + } + private def schedulingWithMaxCores(spreadOut: Boolean): Unit = { val master = makeMaster() val appInfo1 = makeAppInfo(1024, maxCores = Some(8)) From 11caf1ce290b6931647c2f71268f847d1d48930e Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Sat, 8 Aug 2015 18:09:48 +0800 Subject: [PATCH 0130/1824] [SPARK-4176] [SQL] [MINOR] Should use unscaled Long to write decimals for precision <= 18 rather than 8 This PR fixes a minor bug introduced in #7455: when writing decimals, we should use the unscaled Long for better performance when the precision <= 18 rather than 8 (should be a typo). This bug doesn't affect correctness, but hurts Parquet decimal writing performance. This PR also replaced similar magic numbers with newly defined constants. 
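As a quick sanity check on the new constants (an editor's illustration, not part of the patch): the largest decimal precision that always fits in an n-byte signed integer is the number of base-10 digits of 2^(8n-1) - 1, which gives 9 for INT32 and 18 for INT64, hence the precision <= 18 cut-over. A rough standalone Python sketch of the same formula as `maxPrecisionForBytes`:

    import math

    def max_precision_for_bytes(num_bytes):
        # base-10 digits guaranteed to fit in a signed integer of num_bytes bytes
        return int(math.floor(math.log10(2 ** (8 * num_bytes - 1) - 1)))

    print(max_precision_for_bytes(4))  # 9  -> MAX_PRECISION_FOR_INT32
    print(max_precision_for_bytes(8))  # 18 -> MAX_PRECISION_FOR_INT64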
Author: Cheng Lian Closes #8031 from liancheng/spark-4176/minor-fix-for-writing-decimals and squashes the following commits: 10d4ea3 [Cheng Lian] Should use unscaled Long to write decimals for precision <= 18 rather than 8 --- .../sql/parquet/CatalystRowConverter.scala | 2 +- .../sql/parquet/CatalystSchemaConverter.scala | 29 +++++++++++-------- 2 files changed, 18 insertions(+), 13 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRowConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRowConverter.scala index 6938b071065c..4fe8a39f20ab 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRowConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRowConverter.scala @@ -264,7 +264,7 @@ private[parquet] class CatalystRowConverter( val scale = decimalType.scale val bytes = value.getBytes - if (precision <= 8) { + if (precision <= CatalystSchemaConverter.MAX_PRECISION_FOR_INT64) { // Constructs a `Decimal` with an unscaled `Long` value if possible. var unscaled = 0L var i = 0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala index d43ca95b4eea..b12149dcf1c9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala @@ -25,6 +25,7 @@ import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName._ import org.apache.parquet.schema.Type.Repetition._ import org.apache.parquet.schema._ +import org.apache.spark.sql.parquet.CatalystSchemaConverter.{MAX_PRECISION_FOR_INT32, MAX_PRECISION_FOR_INT64, maxPrecisionForBytes} import org.apache.spark.sql.types._ import org.apache.spark.sql.{AnalysisException, SQLConf} @@ -155,7 +156,7 @@ private[parquet] class CatalystSchemaConverter( case INT_16 => ShortType case INT_32 | null => IntegerType case DATE => DateType - case DECIMAL => makeDecimalType(maxPrecisionForBytes(4)) + case DECIMAL => makeDecimalType(MAX_PRECISION_FOR_INT32) case TIME_MILLIS => typeNotImplemented() case _ => illegalType() } @@ -163,7 +164,7 @@ private[parquet] class CatalystSchemaConverter( case INT64 => originalType match { case INT_64 | null => LongType - case DECIMAL => makeDecimalType(maxPrecisionForBytes(8)) + case DECIMAL => makeDecimalType(MAX_PRECISION_FOR_INT64) case TIMESTAMP_MILLIS => typeNotImplemented() case _ => illegalType() } @@ -405,7 +406,7 @@ private[parquet] class CatalystSchemaConverter( // Uses INT32 for 1 <= precision <= 9 case DecimalType.Fixed(precision, scale) - if precision <= maxPrecisionForBytes(4) && followParquetFormatSpec => + if precision <= MAX_PRECISION_FOR_INT32 && followParquetFormatSpec => Types .primitive(INT32, repetition) .as(DECIMAL) @@ -415,7 +416,7 @@ private[parquet] class CatalystSchemaConverter( // Uses INT64 for 1 <= precision <= 18 case DecimalType.Fixed(precision, scale) - if precision <= maxPrecisionForBytes(8) && followParquetFormatSpec => + if precision <= MAX_PRECISION_FOR_INT64 && followParquetFormatSpec => Types .primitive(INT64, repetition) .as(DECIMAL) @@ -534,14 +535,6 @@ private[parquet] class CatalystSchemaConverter( throw new AnalysisException(s"Unsupported data type $field.dataType") } } - - // Max precision of a decimal value stored in `numBytes` bytes - private def maxPrecisionForBytes(numBytes: Int): Int = { - Math.round( // convert double to long - 
Math.floor(Math.log10( // number of base-10 digits - Math.pow(2, 8 * numBytes - 1) - 1))) // max value stored in numBytes - .asInstanceOf[Int] - } } @@ -584,4 +577,16 @@ private[parquet] object CatalystSchemaConverter { computeMinBytesForPrecision(precision) } } + + val MAX_PRECISION_FOR_INT32 = maxPrecisionForBytes(4) + + val MAX_PRECISION_FOR_INT64 = maxPrecisionForBytes(8) + + // Max precision of a decimal value stored in `numBytes` bytes + def maxPrecisionForBytes(numBytes: Int): Int = { + Math.round( // convert double to long + Math.floor(Math.log10( // number of base-10 digits + Math.pow(2, 8 * numBytes - 1) - 1))) // max value stored in numBytes + .asInstanceOf[Int] + } } From 106c0789d8c83c7081bc9a335df78ba728e95872 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Sat, 8 Aug 2015 08:33:14 -0700 Subject: [PATCH 0131/1824] [SPARK-9738] [SQL] remove FromUnsafe and add its codegen version to GenerateSafe In https://github.com/apache/spark/pull/7752 we added `FromUnsafe` to convert nested unsafe data like array/map/struct to safe versions. It's a quick solution and we already have `GenerateSafe` to do the conversion which is codegened. So we should remove `FromUnsafe` and implement its codegen version in `GenerateSafe`. Author: Wenchen Fan Closes #8029 from cloud-fan/from-unsafe and squashes the following commits: ed40d8f [Wenchen Fan] add the copy back a93fd4b [Wenchen Fan] cogengen FromUnsafe --- .../sql/catalyst/expressions/FromUnsafe.scala | 70 ---------- .../sql/catalyst/expressions/Projection.scala | 8 +- .../codegen/GenerateSafeProjection.scala | 120 +++++++++++++----- .../execution/RowFormatConvertersSuite.scala | 4 +- 4 files changed, 95 insertions(+), 107 deletions(-) delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/FromUnsafe.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/FromUnsafe.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/FromUnsafe.scala deleted file mode 100644 index 9b960b136f98..000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/FromUnsafe.scala +++ /dev/null @@ -1,70 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License.
- */ - -package org.apache.spark.sql.catalyst.expressions - -import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback -import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String - -case class FromUnsafe(child: Expression) extends UnaryExpression - with ExpectsInputTypes with CodegenFallback { - - override def inputTypes: Seq[AbstractDataType] = - Seq(TypeCollection(ArrayType, StructType, MapType)) - - override def dataType: DataType = child.dataType - - private def convert(value: Any, dt: DataType): Any = dt match { - case StructType(fields) => - val row = value.asInstanceOf[UnsafeRow] - val result = new Array[Any](fields.length) - fields.map(_.dataType).zipWithIndex.foreach { case (dt, i) => - if (!row.isNullAt(i)) { - result(i) = convert(row.get(i, dt), dt) - } - } - new GenericInternalRow(result) - - case ArrayType(elementType, _) => - val array = value.asInstanceOf[UnsafeArrayData] - val length = array.numElements() - val result = new Array[Any](length) - var i = 0 - while (i < length) { - if (!array.isNullAt(i)) { - result(i) = convert(array.get(i, elementType), elementType) - } - i += 1 - } - new GenericArrayData(result) - - case StringType => value.asInstanceOf[UTF8String].clone() - - case MapType(kt, vt, _) => - val map = value.asInstanceOf[UnsafeMapData] - val safeKeyArray = convert(map.keys, ArrayType(kt)).asInstanceOf[GenericArrayData] - val safeValueArray = convert(map.values, ArrayType(vt)).asInstanceOf[GenericArrayData] - new ArrayBasedMapData(safeKeyArray, safeValueArray) - - case _ => value - } - - override def nullSafeEval(input: Any): Any = { - convert(input, dataType) - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index 796bc327a3db..afe52e6a667e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -152,13 +152,7 @@ object FromUnsafeProjection { */ def apply(fields: Seq[DataType]): Projection = { create(fields.zipWithIndex.map(x => { - val b = new BoundReference(x._2, x._1, true) - // todo: this is quite slow, maybe remove this whole projection after remove generic getter of - // InternalRow? 
- b.dataType match { - case _: StructType | _: ArrayType | _: MapType => FromUnsafe(b) - case _ => b - } + new BoundReference(x._2, x._1, true) })) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala index f06ffc5449e7..ef08ddf041af 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala @@ -21,7 +21,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp -import org.apache.spark.sql.types.{StringType, StructType, DataType} +import org.apache.spark.sql.types._ /** @@ -36,34 +36,94 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] protected def bind(in: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] = in.map(BindReferences.bindReference(_, inputSchema)) - private def genUpdater( + private def createCodeForStruct( ctx: CodeGenContext, - setter: String, - dataType: DataType, - ordinal: Int, - value: String): String = { - dataType match { - case struct: StructType => - val rowTerm = ctx.freshName("row") - val updates = struct.map(_.dataType).zipWithIndex.map { case (dt, i) => - val colTerm = ctx.freshName("col") - s""" - if ($value.isNullAt($i)) { - $rowTerm.setNullAt($i); - } else { - ${ctx.javaType(dt)} $colTerm = ${ctx.getValue(value, dt, s"$i")}; - ${genUpdater(ctx, rowTerm, dt, i, colTerm)}; - } - """ - }.mkString("\n") - s""" - $genericMutableRowType $rowTerm = new $genericMutableRowType(${struct.fields.length}); - $updates - $setter.update($ordinal, $rowTerm.copy()); - """ - case _ => - ctx.setColumn(setter, dataType, ordinal, value) - } + input: String, + schema: StructType): GeneratedExpressionCode = { + val tmp = ctx.freshName("tmp") + val output = ctx.freshName("safeRow") + val values = ctx.freshName("values") + val rowClass = classOf[GenericInternalRow].getName + + val fieldWriters = schema.map(_.dataType).zipWithIndex.map { case (dt, i) => + val converter = convertToSafe(ctx, ctx.getValue(tmp, dt, i.toString), dt) + s""" + if (!$tmp.isNullAt($i)) { + ${converter.code} + $values[$i] = ${converter.primitive}; + } + """ + }.mkString("\n") + + val code = s""" + final InternalRow $tmp = $input; + final Object[] $values = new Object[${schema.length}]; + $fieldWriters + final InternalRow $output = new $rowClass($values); + """ + + GeneratedExpressionCode(code, "false", output) + } + + private def createCodeForArray( + ctx: CodeGenContext, + input: String, + elementType: DataType): GeneratedExpressionCode = { + val tmp = ctx.freshName("tmp") + val output = ctx.freshName("safeArray") + val values = ctx.freshName("values") + val numElements = ctx.freshName("numElements") + val index = ctx.freshName("index") + val arrayClass = classOf[GenericArrayData].getName + + val elementConverter = convertToSafe(ctx, ctx.getValue(tmp, elementType, index), elementType) + val code = s""" + final ArrayData $tmp = $input; + final int $numElements = $tmp.numElements(); + final Object[] $values = new Object[$numElements]; + for (int $index = 0; $index < $numElements; $index++) { + if (!$tmp.isNullAt($index)) { + ${elementConverter.code} + $values[$index] = ${elementConverter.primitive}; + } + } + final 
ArrayData $output = new $arrayClass($values); + """ + + GeneratedExpressionCode(code, "false", output) + } + + private def createCodeForMap( + ctx: CodeGenContext, + input: String, + keyType: DataType, + valueType: DataType): GeneratedExpressionCode = { + val tmp = ctx.freshName("tmp") + val output = ctx.freshName("safeMap") + val mapClass = classOf[ArrayBasedMapData].getName + + val keyConverter = createCodeForArray(ctx, s"$tmp.keyArray()", keyType) + val valueConverter = createCodeForArray(ctx, s"$tmp.valueArray()", valueType) + val code = s""" + final MapData $tmp = $input; + ${keyConverter.code} + ${valueConverter.code} + final MapData $output = new $mapClass(${keyConverter.primitive}, ${valueConverter.primitive}); + """ + + GeneratedExpressionCode(code, "false", output) + } + + private def convertToSafe( + ctx: CodeGenContext, + input: String, + dataType: DataType): GeneratedExpressionCode = dataType match { + case s: StructType => createCodeForStruct(ctx, input, s) + case ArrayType(elementType, _) => createCodeForArray(ctx, input, elementType) + case MapType(keyType, valueType, _) => createCodeForMap(ctx, input, keyType, valueType) + // UTF8String act as a pointer if it's inside UnsafeRow, so copy it to make it safe. + case StringType => GeneratedExpressionCode("", "false", s"$input.clone()") + case _ => GeneratedExpressionCode("", "false", input) } protected def create(expressions: Seq[Expression]): Projection = { @@ -72,12 +132,14 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] case (NoOp, _) => "" case (e, i) => val evaluationCode = e.gen(ctx) + val converter = convertToSafe(ctx, evaluationCode.primitive, e.dataType) evaluationCode.code + s""" if (${evaluationCode.isNull}) { mutableRow.setNullAt($i); } else { - ${genUpdater(ctx, "mutableRow", e.dataType, i, evaluationCode.primitive)}; + ${converter.code} + ${ctx.setColumn("mutableRow", e.dataType, i, converter.primitive)}; } """ } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala index 322966f42378..dd08e9025a92 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala @@ -112,7 +112,9 @@ case class DummyPlan(child: SparkPlan) extends UnaryNode { override protected def doExecute(): RDD[InternalRow] = { child.execute().mapPartitions { iter => - // cache all strings to make sure we have deep copied UTF8String inside incoming + // This `DummyPlan` is in safe mode, so we don't need to do copy even we hold some + // values gotten from the incoming rows. + // we cache all strings here to make sure we have deep copied UTF8String inside incoming // safe InternalRow. val strings = new scala.collection.mutable.ArrayBuffer[UTF8String] iter.foreach { row => From 74a6541aa82bcd7a052b2e57b5ca55b7c316495b Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Sat, 8 Aug 2015 08:36:14 -0700 Subject: [PATCH 0132/1824] [SPARK-4561] [PYSPARK] [SQL] turn Row into dict recursively Add an option `recursive` to `Row.asDict()`, when True (default is False), it will convert the nested Row into dict. 
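A minimal usage sketch (editor's note, mirroring the doctests added in the diff below):

    from pyspark.sql import Row

    row = Row(key=1, value=Row(name='a', age=2))
    row.asDict()      # {'key': 1, 'value': Row(age=2, name='a')}, nested Row kept as-is
    row.asDict(True)  # {'key': 1, 'value': {'name': 'a', 'age': 2}}, converted recursively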
Author: Davies Liu Closes #8006 from davies/as_dict and squashes the following commits: 922cc5a [Davies Liu] turn Row into dict recursively --- python/pyspark/sql/types.py | 27 +++++++++++++++++++++++++-- 1 file changed, 25 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 6f74b7162f7c..e2e6f03ae9fd 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -1197,13 +1197,36 @@ def __new__(self, *args, **kwargs): else: raise ValueError("No args or kwargs") - def asDict(self): + def asDict(self, recursive=False): """ Return as an dict + + :param recursive: turns the nested Row as dict (default: False). + + >>> Row(name="Alice", age=11).asDict() == {'name': 'Alice', 'age': 11} + True + >>> row = Row(key=1, value=Row(name='a', age=2)) + >>> row.asDict() == {'key': 1, 'value': Row(age=2, name='a')} + True + >>> row.asDict(True) == {'key': 1, 'value': {'name': 'a', 'age': 2}} + True """ if not hasattr(self, "__fields__"): raise TypeError("Cannot convert a Row class into dict") - return dict(zip(self.__fields__, self)) + + if recursive: + def conv(obj): + if isinstance(obj, Row): + return obj.asDict(True) + elif isinstance(obj, list): + return [conv(o) for o in obj] + elif isinstance(obj, dict): + return dict((k, conv(v)) for k, v in obj.items()) + else: + return obj + return dict(zip(self.__fields__, (conv(o) for o in self))) + else: + return dict(zip(self.__fields__, self)) # let object acts like class def __call__(self, *args): From ac507a03c3371cd5404ca195ee0ba0306badfc23 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Sat, 8 Aug 2015 08:38:18 -0700 Subject: [PATCH 0133/1824] [SPARK-6902] [SQL] [PYSPARK] Row should be read-only Raise a read-only exception when a user tries to mutate a Row.
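In user code the change looks like this (editor's sketch, mirroring the new unit test):

    from pyspark.sql import Row

    row = Row(a=1, b=2)
    assert row.a == 1      # reads still work
    try:
        row.a = 3          # mutation is now rejected
    except Exception as e:
        print(e)           # Row is read-only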
Author: Davies Liu Closes #8009 from davies/readonly_row and squashes the following commits: 8722f3f [Davies Liu] add tests 05a3d36 [Davies Liu] Row should be read-only --- python/pyspark/sql/tests.py | 15 +++++++++++++++ python/pyspark/sql/types.py | 5 +++++ 2 files changed, 20 insertions(+) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 1e3444dd9e3b..38c83c427a74 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -179,6 +179,21 @@ def tearDownClass(cls): ReusedPySparkTestCase.tearDownClass() shutil.rmtree(cls.tempdir.name, ignore_errors=True) + def test_row_should_be_read_only(self): + row = Row(a=1, b=2) + self.assertEqual(1, row.a) + + def foo(): + row.a = 3 + self.assertRaises(Exception, foo) + + row2 = self.sqlCtx.range(10).first() + self.assertEqual(0, row2.id) + + def foo2(): + row2.id = 2 + self.assertRaises(Exception, foo2) + def test_range(self): self.assertEqual(self.sqlCtx.range(1, 1).count(), 0) self.assertEqual(self.sqlCtx.range(1, 0, -1).count(), 1) diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index e2e6f03ae9fd..c083bf89905b 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -1246,6 +1246,11 @@ def __getattr__(self, item): except ValueError: raise AttributeError(item) + def __setattr__(self, key, value): + if key != '__fields__': + raise Exception("Row is read-only") + self.__dict__[key] = value + def __reduce__(self): """Returns a tuple so Python knows how to pickle Row.""" if hasattr(self, "__fields__"): From 23695f1d2d7ef9f3ea92cebcd96b1cf0e8904eb4 Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Sat, 8 Aug 2015 11:01:25 -0700 Subject: [PATCH 0134/1824] [SPARK-9728][SQL]Support CalendarIntervalType in HiveQL This PR enables converting interval term in HiveQL to CalendarInterval Literal. 
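For example, a year-to-month term such as interval '10-9' year to month becomes a single month count, while the day-time and single-unit forms are folded into microseconds. A rough standalone Python sketch of the year-month arithmetic, omitting the range checks the Java code enforces (editor's illustration, not part of the patch):

    def year_month_to_months(s):
        # '10-9' -> 129 months; a leading '-' negates the whole interval
        sign = -1 if s.startswith('-') else 1
        years, months = s.lstrip('+-').split('-')
        return sign * (int(years) * 12 + int(months))

    print(year_month_to_months('10-9'))   # 129, i.e. 10 years 9 months
    print(year_month_to_months('-8-10'))  # -106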
JIRA: https://issues.apache.org/jira/browse/SPARK-9728 Author: Yijie Shen Closes #8034 from yjshen/interval_hiveql and squashes the following commits: 7fe9a5e [Yijie Shen] declare throw exception and add unit test fce7795 [Yijie Shen] convert hiveql interval term into CalendarInterval literal --- .../org/apache/spark/sql/hive/HiveQl.scala | 25 +++ .../apache/spark/sql/hive/HiveQlSuite.scala | 15 ++ .../sql/hive/execution/SQLQuerySuite.scala | 22 +++ .../spark/unsafe/types/CalendarInterval.java | 156 ++++++++++++++++++ .../unsafe/types/CalendarIntervalSuite.java | 91 ++++++++++ 5 files changed, 309 insertions(+) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 7d7b4b916730..c3f29350101d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -45,6 +45,7 @@ import org.apache.spark.sql.hive.HiveShim._ import org.apache.spark.sql.hive.client._ import org.apache.spark.sql.hive.execution.{HiveNativeCommand, DropTable, AnalyzeTable, HiveScriptIOSchema} import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.CalendarInterval import org.apache.spark.util.random.RandomSampler /* Implicit conversions */ @@ -1519,6 +1520,30 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C case ast: ASTNode if ast.getType == HiveParser.TOK_CHARSETLITERAL => Literal(BaseSemanticAnalyzer.charSetString(ast.getChild(0).getText, ast.getChild(1).getText)) + case ast: ASTNode if ast.getType == HiveParser.TOK_INTERVAL_YEAR_MONTH_LITERAL => + Literal(CalendarInterval.fromYearMonthString(ast.getText)) + + case ast: ASTNode if ast.getType == HiveParser.TOK_INTERVAL_DAY_TIME_LITERAL => + Literal(CalendarInterval.fromDayTimeString(ast.getText)) + + case ast: ASTNode if ast.getType == HiveParser.TOK_INTERVAL_YEAR_LITERAL => + Literal(CalendarInterval.fromSingleUnitString("year", ast.getText)) + + case ast: ASTNode if ast.getType == HiveParser.TOK_INTERVAL_MONTH_LITERAL => + Literal(CalendarInterval.fromSingleUnitString("month", ast.getText)) + + case ast: ASTNode if ast.getType == HiveParser.TOK_INTERVAL_DAY_LITERAL => + Literal(CalendarInterval.fromSingleUnitString("day", ast.getText)) + + case ast: ASTNode if ast.getType == HiveParser.TOK_INTERVAL_HOUR_LITERAL => + Literal(CalendarInterval.fromSingleUnitString("hour", ast.getText)) + + case ast: ASTNode if ast.getType == HiveParser.TOK_INTERVAL_MINUTE_LITERAL => + Literal(CalendarInterval.fromSingleUnitString("minute", ast.getText)) + + case ast: ASTNode if ast.getType == HiveParser.TOK_INTERVAL_SECOND_LITERAL => + Literal(CalendarInterval.fromSingleUnitString("second", ast.getText)) + case a: ASTNode => throw new NotImplementedError( s"""No parse rules for ASTNode type: ${a.getType}, text: ${a.getText} : diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala index f765395e148a..79cf40aba4bf 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala @@ -175,4 +175,19 @@ class HiveQlSuite extends SparkFunSuite with BeforeAndAfterAll { assert(desc.serde == Option("org.apache.hadoop.hive.serde2.columnar.ColumnarSerDe")) assert(desc.properties == Map(("tbl_p1" -> "p11"), ("tbl_p2" -> "p22"))) } + + test("Invalid interval term should throw AnalysisException") 
{ + def assertError(sql: String, errorMessage: String): Unit = { + val e = intercept[AnalysisException] { + HiveQl.parseSql(sql) + } + assert(e.getMessage.contains(errorMessage)) + } + assertError("select interval '42-32' year to month", + "month 32 outside range [0, 11]") + assertError("select interval '5 49:12:15' day to second", + "hour 49 outside range [0, 23]") + assertError("select interval '.1111111111' second", + "nanosecond 1111111111 outside range") + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 1dff07a6de8a..2fa7ae3fa2e1 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -33,6 +33,7 @@ import org.apache.spark.sql.hive.{HiveContext, HiveQLDialect, MetastoreRelation} import org.apache.spark.sql.parquet.ParquetRelation import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.CalendarInterval case class Nested1(f1: Nested2) case class Nested2(f2: Nested3) @@ -1115,4 +1116,25 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils { checkAnswer(sql("SELECT a.`c.b`, `b.$q`[0].`a@!.q`, `q.w`.`w.i&`[0] FROM t"), Row(1, 1, 1)) } + + test("Convert hive interval term into Literal of CalendarIntervalType") { + checkAnswer(sql("select interval '10-9' year to month"), + Row(CalendarInterval.fromString("interval 10 years 9 months"))) + checkAnswer(sql("select interval '20 15:40:32.99899999' day to second"), + Row(CalendarInterval.fromString("interval 2 weeks 6 days 15 hours 40 minutes " + + "32 seconds 99 milliseconds 899 microseconds"))) + checkAnswer(sql("select interval '30' year"), + Row(CalendarInterval.fromString("interval 30 years"))) + checkAnswer(sql("select interval '25' month"), + Row(CalendarInterval.fromString("interval 25 months"))) + checkAnswer(sql("select interval '-100' day"), + Row(CalendarInterval.fromString("interval -14 weeks -2 days"))) + checkAnswer(sql("select interval '40' hour"), + Row(CalendarInterval.fromString("interval 1 days 16 hours"))) + checkAnswer(sql("select interval '80' minute"), + Row(CalendarInterval.fromString("interval 1 hour 20 minutes"))) + checkAnswer(sql("select interval '299.889987299' second"), + Row(CalendarInterval.fromString( + "interval 4 minutes 59 seconds 889 milliseconds 987 microseconds"))) + } } diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java index 92a5e4f86f23..30e175807636 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java @@ -50,6 +50,14 @@ private static String unitRegex(String unit) { unitRegex("week") + unitRegex("day") + unitRegex("hour") + unitRegex("minute") + unitRegex("second") + unitRegex("millisecond") + unitRegex("microsecond")); + private static Pattern yearMonthPattern = + Pattern.compile("^(?:['|\"])?([+|-])?(\\d+)-(\\d+)(?:['|\"])?$"); + + private static Pattern dayTimePattern = + Pattern.compile("^(?:['|\"])?([+|-])?(\\d+) (\\d+):(\\d+):(\\d+)(\\.(\\d+))?(?:['|\"])?$"); + + private static Pattern quoteTrimPattern = Pattern.compile("^(?:['|\"])?(.*?)(?:['|\"])?$"); + private static long toLong(String s) { if (s == null) { return 0; @@ -79,6 +87,154 @@ public static 
CalendarInterval fromString(String s) { } } + public static long toLongWithRange(String fieldName, + String s, long minValue, long maxValue) throws IllegalArgumentException { + long result = 0; + if (s != null) { + result = Long.valueOf(s); + if (result < minValue || result > maxValue) { + throw new IllegalArgumentException(String.format("%s %d outside range [%d, %d]", + fieldName, result, minValue, maxValue)); + } + } + return result; + } + + /** + * Parse YearMonth string in form: [-]YYYY-MM + * + * adapted from HiveIntervalYearMonth.valueOf + */ + public static CalendarInterval fromYearMonthString(String s) throws IllegalArgumentException { + CalendarInterval result = null; + if (s == null) { + throw new IllegalArgumentException("Interval year-month string was null"); + } + s = s.trim(); + Matcher m = yearMonthPattern.matcher(s); + if (!m.matches()) { + throw new IllegalArgumentException( + "Interval string does not match year-month format of 'y-m': " + s); + } else { + try { + int sign = m.group(1) != null && m.group(1).equals("-") ? -1 : 1; + int years = (int) toLongWithRange("year", m.group(2), 0, Integer.MAX_VALUE); + int months = (int) toLongWithRange("month", m.group(3), 0, 11); + result = new CalendarInterval(sign * (years * 12 + months), 0); + } catch (Exception e) { + throw new IllegalArgumentException( + "Error parsing interval year-month string: " + e.getMessage(), e); + } + } + return result; + } + + /** + * Parse dayTime string in form: [-]d HH:mm:ss.nnnnnnnnn + * + * adapted from HiveIntervalDayTime.valueOf + */ + public static CalendarInterval fromDayTimeString(String s) throws IllegalArgumentException { + CalendarInterval result = null; + if (s == null) { + throw new IllegalArgumentException("Interval day-time string was null"); + } + s = s.trim(); + Matcher m = dayTimePattern.matcher(s); + if (!m.matches()) { + throw new IllegalArgumentException( + "Interval string does not match day-time format of 'd h:m:s.n': " + s); + } else { + try { + int sign = m.group(1) != null && m.group(1).equals("-") ? 
-1 : 1; + long days = toLongWithRange("day", m.group(2), 0, Integer.MAX_VALUE); + long hours = toLongWithRange("hour", m.group(3), 0, 23); + long minutes = toLongWithRange("minute", m.group(4), 0, 59); + long seconds = toLongWithRange("second", m.group(5), 0, 59); + // Hive allow nanosecond precision interval + long nanos = toLongWithRange("nanosecond", m.group(7), 0L, 999999999L); + result = new CalendarInterval(0, sign * ( + days * MICROS_PER_DAY + hours * MICROS_PER_HOUR + minutes * MICROS_PER_MINUTE + + seconds * MICROS_PER_SECOND + nanos / 1000L)); + } catch (Exception e) { + throw new IllegalArgumentException( + "Error parsing interval day-time string: " + e.getMessage(), e); + } + } + return result; + } + + public static CalendarInterval fromSingleUnitString(String unit, String s) + throws IllegalArgumentException { + + CalendarInterval result = null; + if (s == null) { + throw new IllegalArgumentException(String.format("Interval %s string was null", unit)); + } + s = s.trim(); + Matcher m = quoteTrimPattern.matcher(s); + if (!m.matches()) { + throw new IllegalArgumentException( + "Interval string does not match day-time format of 'd h:m:s.n': " + s); + } else { + try { + if (unit.equals("year")) { + int year = (int) toLongWithRange("year", m.group(1), + Integer.MIN_VALUE / 12, Integer.MAX_VALUE / 12); + result = new CalendarInterval(year * 12, 0L); + + } else if (unit.equals("month")) { + int month = (int) toLongWithRange("month", m.group(1), + Integer.MIN_VALUE, Integer.MAX_VALUE); + result = new CalendarInterval(month, 0L); + + } else if (unit.equals("day")) { + long day = toLongWithRange("day", m.group(1), + Long.MIN_VALUE / MICROS_PER_DAY, Long.MAX_VALUE / MICROS_PER_DAY); + result = new CalendarInterval(0, day * MICROS_PER_DAY); + + } else if (unit.equals("hour")) { + long hour = toLongWithRange("hour", m.group(1), + Long.MIN_VALUE / MICROS_PER_HOUR, Long.MAX_VALUE / MICROS_PER_HOUR); + result = new CalendarInterval(0, hour * MICROS_PER_HOUR); + + } else if (unit.equals("minute")) { + long minute = toLongWithRange("minute", m.group(1), + Long.MIN_VALUE / MICROS_PER_MINUTE, Long.MAX_VALUE / MICROS_PER_MINUTE); + result = new CalendarInterval(0, minute * MICROS_PER_MINUTE); + + } else if (unit.equals("second")) { + long micros = parseSecondNano(m.group(1)); + result = new CalendarInterval(0, micros); + } + } catch (Exception e) { + throw new IllegalArgumentException("Error parsing interval string: " + e.getMessage(), e); + } + } + return result; + } + + /** + * Parse second_nano string in ss.nnnnnnnnn format to microseconds + */ + public static long parseSecondNano(String secondNano) throws IllegalArgumentException { + String[] parts = secondNano.split("\\."); + if (parts.length == 1) { + return toLongWithRange("second", parts[0], Long.MIN_VALUE / MICROS_PER_SECOND, + Long.MAX_VALUE / MICROS_PER_SECOND) * MICROS_PER_SECOND; + + } else if (parts.length == 2) { + long seconds = parts[0].equals("") ? 
0L : toLongWithRange("second", parts[0], + Long.MIN_VALUE / MICROS_PER_SECOND, Long.MAX_VALUE / MICROS_PER_SECOND); + long nanos = toLongWithRange("nanosecond", parts[1], 0L, 999999999L); + return seconds * MICROS_PER_SECOND + nanos / 1000L; + + } else { + throw new IllegalArgumentException( + "Interval string does not match second-nano format of ss.nnnnnnnnn"); + } + } + public final int months; public final long microseconds; diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/types/CalendarIntervalSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/types/CalendarIntervalSuite.java index 6274b92b47dd..80d4982c4b57 100644 --- a/unsafe/src/test/java/org/apache/spark/unsafe/types/CalendarIntervalSuite.java +++ b/unsafe/src/test/java/org/apache/spark/unsafe/types/CalendarIntervalSuite.java @@ -101,6 +101,97 @@ public void fromStringTest() { assertEquals(CalendarInterval.fromString(input), null); } + @Test + public void fromYearMonthStringTest() { + String input; + CalendarInterval i; + + input = "99-10"; + i = new CalendarInterval(99 * 12 + 10, 0L); + assertEquals(CalendarInterval.fromYearMonthString(input), i); + + input = "-8-10"; + i = new CalendarInterval(-8 * 12 - 10, 0L); + assertEquals(CalendarInterval.fromYearMonthString(input), i); + + try { + input = "99-15"; + CalendarInterval.fromYearMonthString(input); + fail("Expected to throw an exception for the invalid input"); + } catch (IllegalArgumentException e) { + assertTrue(e.getMessage().contains("month 15 outside range")); + } + } + + @Test + public void fromDayTimeStringTest() { + String input; + CalendarInterval i; + + input = "5 12:40:30.999999999"; + i = new CalendarInterval(0, 5 * MICROS_PER_DAY + 12 * MICROS_PER_HOUR + + 40 * MICROS_PER_MINUTE + 30 * MICROS_PER_SECOND + 999999L); + assertEquals(CalendarInterval.fromDayTimeString(input), i); + + input = "10 0:12:0.888"; + i = new CalendarInterval(0, 10 * MICROS_PER_DAY + 12 * MICROS_PER_MINUTE); + assertEquals(CalendarInterval.fromDayTimeString(input), i); + + input = "-3 0:0:0"; + i = new CalendarInterval(0, -3 * MICROS_PER_DAY); + assertEquals(CalendarInterval.fromDayTimeString(input), i); + + try { + input = "5 30:12:20"; + CalendarInterval.fromDayTimeString(input); + fail("Expected to throw an exception for the invalid input"); + } catch (IllegalArgumentException e) { + assertTrue(e.getMessage().contains("hour 30 outside range")); + } + + try { + input = "5 30-12"; + CalendarInterval.fromDayTimeString(input); + fail("Expected to throw an exception for the invalid input"); + } catch (IllegalArgumentException e) { + assertTrue(e.getMessage().contains("not match day-time format")); + } + } + + @Test + public void fromSingleUnitStringTest() { + String input; + CalendarInterval i; + + input = "12"; + i = new CalendarInterval(12 * 12, 0L); + assertEquals(CalendarInterval.fromSingleUnitString("year", input), i); + + input = "100"; + i = new CalendarInterval(0, 100 * MICROS_PER_DAY); + assertEquals(CalendarInterval.fromSingleUnitString("day", input), i); + + input = "1999.38888"; + i = new CalendarInterval(0, 1999 * MICROS_PER_SECOND + 38); + assertEquals(CalendarInterval.fromSingleUnitString("second", input), i); + + try { + input = String.valueOf(Integer.MAX_VALUE); + CalendarInterval.fromSingleUnitString("year", input); + fail("Expected to throw an exception for the invalid input"); + } catch (IllegalArgumentException e) { + assertTrue(e.getMessage().contains("outside range")); + } + + try { + input = String.valueOf(Long.MAX_VALUE / MICROS_PER_HOUR + 1); + 
CalendarInterval.fromSingleUnitString("hour", input); + fail("Expected to throw an exception for the invalid input"); + } catch (IllegalArgumentException e) { + assertTrue(e.getMessage().contains("outside range")); + } + } + @Test public void addTest() { String input = "interval 3 month 1 hour"; From a3aec918bed22f8e33cf91dc0d6e712e6653c7d2 Mon Sep 17 00:00:00 2001 From: Joseph Batchik Date: Sat, 8 Aug 2015 11:03:01 -0700 Subject: [PATCH 0135/1824] [SPARK-9486][SQL] Add data source aliasing for external packages Users currently have to provide the full class name for external data sources, like: `sqlContext.read.format("com.databricks.spark.avro").load(path)` This allows external data source packages to register themselves using a Service Loader so that they can add custom alias like: `sqlContext.read.format("avro").load(path)` This makes it so that using external data source packages uses the same format as the internal data sources like parquet, json, etc. Author: Joseph Batchik Author: Joseph Batchik Closes #7802 from JDrit/service_loader and squashes the following commits: 49a01ec [Joseph Batchik] fixed a couple of format / error bugs e5e93b2 [Joseph Batchik] modified rat file to only excluded added services 72b349a [Joseph Batchik] fixed error with orc data source actually 9f93ea7 [Joseph Batchik] fixed error with orc data source 87b7f1c [Joseph Batchik] fixed typo 101cd22 [Joseph Batchik] removing unneeded changes 8f3cf43 [Joseph Batchik] merged in changes b63d337 [Joseph Batchik] merged in master 95ae030 [Joseph Batchik] changed the new trait to be used as a mixin for data source to register themselves 74db85e [Joseph Batchik] reformatted class loader ac2270d [Joseph Batchik] removing some added test a6926db [Joseph Batchik] added test cases for data source loader 208a2a8 [Joseph Batchik] changes to do error catching if there are multiple data sources 946186e [Joseph Batchik] started working on service loader --- .rat-excludes | 1 + ...pache.spark.sql.sources.DataSourceRegister | 3 + .../spark/sql/execution/datasources/ddl.scala | 52 ++++++------ .../apache/spark/sql/jdbc/JDBCRelation.scala | 5 +- .../apache/spark/sql/json/JSONRelation.scala | 5 +- .../spark/sql/parquet/ParquetRelation.scala | 5 +- .../apache/spark/sql/sources/interfaces.scala | 21 +++++ ...pache.spark.sql.sources.DataSourceRegister | 3 + .../sql/sources/DDLSourceLoadSuite.scala | 85 +++++++++++++++++++ ...pache.spark.sql.sources.DataSourceRegister | 1 + .../spark/sql/hive/orc/OrcRelation.scala | 5 +- 11 files changed, 156 insertions(+), 30 deletions(-) create mode 100644 sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister create mode 100644 sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/sources/DDLSourceLoadSuite.scala create mode 100644 sql/hive/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister diff --git a/.rat-excludes b/.rat-excludes index 236c2db05367..72771465846b 100644 --- a/.rat-excludes +++ b/.rat-excludes @@ -93,3 +93,4 @@ INDEX .lintr gen-java.* .*avpr +org.apache.spark.sql.sources.DataSourceRegister diff --git a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister new file mode 100644 index 000000000000..cc32d4b72748 --- /dev/null +++ 
b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -0,0 +1,3 @@ +org.apache.spark.sql.jdbc.DefaultSource +org.apache.spark.sql.json.DefaultSource +org.apache.spark.sql.parquet.DefaultSource diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala index 0cdb407ad57b..8c2f297e4245 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala @@ -17,7 +17,12 @@ package org.apache.spark.sql.execution.datasources +import java.util.ServiceLoader + +import scala.collection.Iterator +import scala.collection.JavaConversions._ import scala.language.{existentials, implicitConversions} +import scala.util.{Failure, Success, Try} import scala.util.matching.Regex import org.apache.hadoop.fs.Path @@ -190,37 +195,32 @@ private[sql] class DDLParser( } } -private[sql] object ResolvedDataSource { - - private val builtinSources = Map( - "jdbc" -> "org.apache.spark.sql.jdbc.DefaultSource", - "json" -> "org.apache.spark.sql.json.DefaultSource", - "parquet" -> "org.apache.spark.sql.parquet.DefaultSource", - "orc" -> "org.apache.spark.sql.hive.orc.DefaultSource" - ) +private[sql] object ResolvedDataSource extends Logging { /** Given a provider name, look up the data source class definition. */ def lookupDataSource(provider: String): Class[_] = { + val provider2 = s"$provider.DefaultSource" val loader = Utils.getContextOrSparkClassLoader - - if (builtinSources.contains(provider)) { - return loader.loadClass(builtinSources(provider)) - } - - try { - loader.loadClass(provider) - } catch { - case cnf: java.lang.ClassNotFoundException => - try { - loader.loadClass(provider + ".DefaultSource") - } catch { - case cnf: java.lang.ClassNotFoundException => - if (provider.startsWith("org.apache.spark.sql.hive.orc")) { - sys.error("The ORC data source must be used with Hive support enabled.") - } else { - sys.error(s"Failed to load class for data source: $provider") - } + val serviceLoader = ServiceLoader.load(classOf[DataSourceRegister], loader) + + serviceLoader.iterator().filter(_.format().equalsIgnoreCase(provider)).toList match { + /** the provider format did not match any given registered aliases */ + case Nil => Try(loader.loadClass(provider)).orElse(Try(loader.loadClass(provider2))) match { + case Success(dataSource) => dataSource + case Failure(error) => if (provider.startsWith("org.apache.spark.sql.hive.orc")) { + throw new ClassNotFoundException( + "The ORC data source must be used with Hive support enabled.", error) + } else { + throw new ClassNotFoundException( + s"Failed to load class for data source: $provider", error) } + } + /** there is exactly one registered alias */ + case head :: Nil => head.getClass + /** There are multiple registered aliases for the input */ + case sources => sys.error(s"Multiple sources found for $provider, " + + s"(${sources.map(_.getClass.getName).mkString(", ")}), " + + "please specify the fully qualified class name") } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala index 41d0ecb4bbfb..48d97ced9ca0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala @@ -77,7 +77,10 @@ private[sql] object JDBCRelation { } } 
-private[sql] class DefaultSource extends RelationProvider { +private[sql] class DefaultSource extends RelationProvider with DataSourceRegister { + + def format(): String = "jdbc" + /** Returns a new base relation with the given parameters. */ override def createRelation( sqlContext: SQLContext, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala index 10f1367e6984..b34a272ec547 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala @@ -37,7 +37,10 @@ import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.StructType import org.apache.spark.sql.{AnalysisException, Row, SQLContext} -private[sql] class DefaultSource extends HadoopFsRelationProvider { +private[sql] class DefaultSource extends HadoopFsRelationProvider with DataSourceRegister { + + def format(): String = "json" + override def createRelation( sqlContext: SQLContext, paths: Array[String], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala index 48009b2fd007..b6db71b5b8a6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala @@ -49,7 +49,10 @@ import org.apache.spark.sql.types.{DataType, StructType} import org.apache.spark.util.{SerializableConfiguration, Utils} -private[sql] class DefaultSource extends HadoopFsRelationProvider { +private[sql] class DefaultSource extends HadoopFsRelationProvider with DataSourceRegister { + + def format(): String = "parquet" + override def createRelation( sqlContext: SQLContext, paths: Array[String], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index c5b7ee73eb78..4aafec0e2df2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -37,6 +37,27 @@ import org.apache.spark.sql.types.StructType import org.apache.spark.sql._ import org.apache.spark.util.SerializableConfiguration +/** + * ::DeveloperApi:: + * Data sources should implement this trait so that they can register an alias to their data source. + * This allows users to give the data source alias as the format type over the fully qualified + * class name. + * + * ex: parquet.DefaultSource.format = "parquet". + * + * A new instance of this class with be instantiated each time a DDL call is made. + */ +@DeveloperApi +trait DataSourceRegister { + + /** + * The string that represents the format that this data source provider uses. This is + * overridden by children to provide a nice alias for the data source, + * ex: override def format(): String = "parquet" + */ + def format(): String +} + /** * ::DeveloperApi:: * Implemented by objects that produce relations for a specific kind of data source. 
When diff --git a/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister new file mode 100644 index 000000000000..cfd7889b4ac2 --- /dev/null +++ b/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -0,0 +1,3 @@ +org.apache.spark.sql.sources.FakeSourceOne +org.apache.spark.sql.sources.FakeSourceTwo +org.apache.spark.sql.sources.FakeSourceThree diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLSourceLoadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLSourceLoadSuite.scala new file mode 100644 index 000000000000..1a4d41b02ca6 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLSourceLoadSuite.scala @@ -0,0 +1,85 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.sources + +import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.types.{StringType, StructField, StructType} + +class FakeSourceOne extends RelationProvider with DataSourceRegister { + + def format(): String = "Fluet da Bomb" + + override def createRelation(cont: SQLContext, param: Map[String, String]): BaseRelation = + new BaseRelation { + override def sqlContext: SQLContext = cont + + override def schema: StructType = + StructType(Seq(StructField("stringType", StringType, nullable = false))) + } +} + +class FakeSourceTwo extends RelationProvider with DataSourceRegister { + + def format(): String = "Fluet da Bomb" + + override def createRelation(cont: SQLContext, param: Map[String, String]): BaseRelation = + new BaseRelation { + override def sqlContext: SQLContext = cont + + override def schema: StructType = + StructType(Seq(StructField("stringType", StringType, nullable = false))) + } +} + +class FakeSourceThree extends RelationProvider with DataSourceRegister { + + def format(): String = "gathering quorum" + + override def createRelation(cont: SQLContext, param: Map[String, String]): BaseRelation = + new BaseRelation { + override def sqlContext: SQLContext = cont + + override def schema: StructType = + StructType(Seq(StructField("stringType", StringType, nullable = false))) + } +} +// please note that the META-INF/services had to be modified for the test directory for this to work +class DDLSourceLoadSuite extends DataSourceTest { + + test("data sources with the same name") { + intercept[RuntimeException] { + caseInsensitiveContext.read.format("Fluet da Bomb").load() + } + } + + test("load data source from format alias") { + caseInsensitiveContext.read.format("gathering quorum").load().schema == + StructType(Seq(StructField("stringType", StringType, nullable = false))) + } + + test("specify full classname with 
duplicate formats") { + caseInsensitiveContext.read.format("org.apache.spark.sql.sources.FakeSourceOne") + .load().schema == StructType(Seq(StructField("stringType", StringType, nullable = false))) + } + + test("Loading Orc") { + intercept[ClassNotFoundException] { + caseInsensitiveContext.read.format("orc").load() + } + } +} diff --git a/sql/hive/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/sql/hive/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister new file mode 100644 index 000000000000..4a774fbf1fdf --- /dev/null +++ b/sql/hive/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -0,0 +1 @@ +org.apache.spark.sql.hive.orc.DefaultSource diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala index 7c8704b47f28..0c344c63fde3 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala @@ -47,7 +47,10 @@ import org.apache.spark.util.SerializableConfiguration /* Implicit conversions */ import scala.collection.JavaConversions._ -private[sql] class DefaultSource extends HadoopFsRelationProvider { +private[sql] class DefaultSource extends HadoopFsRelationProvider with DataSourceRegister { + + def format(): String = "orc" + def createRelation( sqlContext: SQLContext, paths: Array[String], From 25c363e93bc79119c5ba5c228fcad620061cff62 Mon Sep 17 00:00:00 2001 From: CodingCat Date: Sat, 8 Aug 2015 18:22:46 -0700 Subject: [PATCH 0136/1824] [MINOR] inaccurate comments for showString() Author: CodingCat Closes #8050 from CodingCat/minor and squashes the following commits: 5bc4b89 [CodingCat] inaccurate comments --- sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 405b5a4a9a7f..570b8b2d5928 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -168,7 +168,7 @@ class DataFrame private[sql]( } /** - * Internal API for Python + * Compose the string representing rows for output * @param _numRows Number of rows to show * @param truncate Whether truncate long strings and align cells right */ From 3ca995b78f373251081f6877623649bfba3040b2 Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Sat, 8 Aug 2015 21:05:50 -0700 Subject: [PATCH 0137/1824] [SPARK-6212] [SQL] The EXPLAIN output of CTAS only shows the analyzed plan JIRA: https://issues.apache.org/jira/browse/SPARK-6212 Author: Yijie Shen Closes #7986 from yjshen/ctas_explain and squashes the following commits: bb6fee5 [Yijie Shen] refine test f731041 [Yijie Shen] address comment b2cf8ab [Yijie Shen] bug fix bd7eb20 [Yijie Shen] ctas explain --- .../apache/spark/sql/execution/commands.scala | 2 ++ .../hive/execution/CreateTableAsSelect.scala | 4 ++- .../sql/hive/execution/HiveExplainSuite.scala | 35 +++++++++++++++++-- 3 files changed, 38 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala index 6b83025d5a15..95209e663451 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala @@ -69,6 +69,8 @@ private[sql] case class ExecutedCommand(cmd: RunnableCommand) extends SparkPlan val converted = sideEffectResult.map(convert(_).asInstanceOf[InternalRow]) sqlContext.sparkContext.parallelize(converted, 1) } + + override def argString: String = cmd.toString } /** diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala index 84358cb73c9e..8422287e177e 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala @@ -40,6 +40,8 @@ case class CreateTableAsSelect( def database: String = tableDesc.database def tableName: String = tableDesc.name + override def children: Seq[LogicalPlan] = Seq(query) + override def run(sqlContext: SQLContext): Seq[Row] = { val hiveContext = sqlContext.asInstanceOf[HiveContext] lazy val metastoreRelation: MetastoreRelation = { @@ -91,6 +93,6 @@ case class CreateTableAsSelect( } override def argString: String = { - s"[Database:$database, TableName: $tableName, InsertIntoHiveTable]\n" + query.toString + s"[Database:$database, TableName: $tableName, InsertIntoHiveTable]" } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala index 8215dd6c2e71..44c5b80392fa 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala @@ -17,13 +17,18 @@ package org.apache.spark.sql.hive.execution -import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.{SQLContext, QueryTest} +import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ +import org.apache.spark.sql.test.SQLTestUtils /** * A set of tests that validates support for Hive Explain command. 
*/ -class HiveExplainSuite extends QueryTest { +class HiveExplainSuite extends QueryTest with SQLTestUtils { + + def sqlContext: SQLContext = TestHive + test("explain extended command") { checkExistence(sql(" explain select * from src where key=123 "), true, "== Physical Plan ==") @@ -74,4 +79,30 @@ class HiveExplainSuite extends QueryTest { "Limit", "src") } + + test("SPARK-6212: The EXPLAIN output of CTAS only shows the analyzed plan") { + withTempTable("jt") { + val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str$i"}""")) + read.json(rdd).registerTempTable("jt") + val outputs = sql( + s""" + |EXPLAIN EXTENDED + |CREATE TABLE t1 + |AS + |SELECT * FROM jt + """.stripMargin).collect().map(_.mkString).mkString + + val shouldContain = + "== Parsed Logical Plan ==" :: "== Analyzed Logical Plan ==" :: "Subquery" :: + "== Optimized Logical Plan ==" :: "== Physical Plan ==" :: + "CreateTableAsSelect" :: "InsertIntoHiveTable" :: "jt" :: Nil + for (key <- shouldContain) { + assert(outputs.contains(key), s"$key doesn't exist in result") + } + + val physicalIndex = outputs.indexOf("== Physical Plan ==") + assert(!outputs.substring(physicalIndex).contains("Subquery"), + "Physical Plan should not contain Subquery since it's eliminated by optimizer") + } + } } From e9c36938ba972b6fe3c9f6228508e3c9f1c876b2 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sun, 9 Aug 2015 10:58:36 -0700 Subject: [PATCH 0138/1824] [SPARK-9752][SQL] Support UnsafeRow in Sample operator. In order for this to work, I had to disable gap sampling. Author: Reynold Xin Closes #8040 from rxin/SPARK-9752 and squashes the following commits: f9e248c [Reynold Xin] Fix the test case for real this time. adbccb3 [Reynold Xin] Fixed test case. 589fb23 [Reynold Xin] Merge branch 'SPARK-9752' of github.com:rxin/spark into SPARK-9752 55ccddc [Reynold Xin] Fixed core test. 78fa895 [Reynold Xin] [SPARK-9752][SQL] Support UnsafeRow in Sample operator. c9e7112 [Reynold Xin] [SPARK-9752][SQL] Support UnsafeRow in Sample operator. --- .../spark/util/random/RandomSampler.scala | 18 ++++++---- .../spark/sql/execution/basicOperators.scala | 18 +++++++--- .../apache/spark/sql/DataFrameStatSuite.scala | 35 +++++++++++++++++++ .../org/apache/spark/sql/DataFrameSuite.scala | 17 --------- 4 files changed, 61 insertions(+), 27 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala b/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala index 786b97ad7b9e..c156b03cdb7c 100644 --- a/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala +++ b/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala @@ -176,10 +176,15 @@ class BernoulliSampler[T: ClassTag](fraction: Double) extends RandomSampler[T, T * A sampler for sampling with replacement, based on values drawn from Poisson distribution. * * @param fraction the sampling fraction (with replacement) + * @param useGapSamplingIfPossible if true, use gap sampling when sampling ratio is low. * @tparam T item type */ @DeveloperApi -class PoissonSampler[T: ClassTag](fraction: Double) extends RandomSampler[T, T] { +class PoissonSampler[T: ClassTag]( + fraction: Double, + useGapSamplingIfPossible: Boolean) extends RandomSampler[T, T] { + + def this(fraction: Double) = this(fraction, useGapSamplingIfPossible = true) /** Epsilon slop to avoid failure from floating point jitter. 
*/ require( @@ -199,17 +204,18 @@ class PoissonSampler[T: ClassTag](fraction: Double) extends RandomSampler[T, T] override def sample(items: Iterator[T]): Iterator[T] = { if (fraction <= 0.0) { Iterator.empty - } else if (fraction <= RandomSampler.defaultMaxGapSamplingFraction) { - new GapSamplingReplacementIterator(items, fraction, rngGap, RandomSampler.rngEpsilon) + } else if (useGapSamplingIfPossible && + fraction <= RandomSampler.defaultMaxGapSamplingFraction) { + new GapSamplingReplacementIterator(items, fraction, rngGap, RandomSampler.rngEpsilon) } else { - items.flatMap { item => { + items.flatMap { item => val count = rng.sample() if (count == 0) Iterator.empty else Iterator.fill(count)(item) - }} + } } } - override def clone: PoissonSampler[T] = new PoissonSampler[T](fraction) + override def clone: PoissonSampler[T] = new PoissonSampler[T](fraction, useGapSamplingIfPossible) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 0680f31d40f6..c5d1ed0937b1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.rdd.{RDD, ShuffledRDD} +import org.apache.spark.rdd.{PartitionwiseSampledRDD, RDD, ShuffledRDD} import org.apache.spark.shuffle.sort.SortShuffleManager import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow @@ -30,6 +30,7 @@ import org.apache.spark.sql.metric.SQLMetrics import org.apache.spark.sql.types.StructType import org.apache.spark.util.collection.ExternalSorter import org.apache.spark.util.collection.unsafe.sort.PrefixComparator +import org.apache.spark.util.random.PoissonSampler import org.apache.spark.util.{CompletionIterator, MutablePair} import org.apache.spark.{HashPartitioner, SparkEnv} @@ -130,12 +131,21 @@ case class Sample( { override def output: Seq[Attribute] = child.output - // TODO: How to pick seed? + override def outputsUnsafeRows: Boolean = child.outputsUnsafeRows + override def canProcessUnsafeRows: Boolean = true + override def canProcessSafeRows: Boolean = true + protected override def doExecute(): RDD[InternalRow] = { if (withReplacement) { - child.execute().map(_.copy()).sample(withReplacement, upperBound - lowerBound, seed) + // Disable gap sampling since the gap sampling method buffers two rows internally, + // requiring us to copy the row, which is more expensive than the random number generator. 
+ new PartitionwiseSampledRDD[InternalRow, InternalRow]( + child.execute(), + new PoissonSampler[InternalRow](upperBound - lowerBound, useGapSamplingIfPossible = false), + preservesPartitioning = true, + seed) } else { - child.execute().map(_.copy()).randomSampleWithRange(lowerBound, upperBound, seed) + child.execute().randomSampleWithRange(lowerBound, upperBound, seed) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index 0e7659f443ec..8f5984e4a8ce 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -30,6 +30,41 @@ class DataFrameStatSuite extends QueryTest { private def toLetter(i: Int): String = (i + 97).toChar.toString + test("sample with replacement") { + val n = 100 + val data = sqlCtx.sparkContext.parallelize(1 to n, 2).toDF("id") + checkAnswer( + data.sample(withReplacement = true, 0.05, seed = 13), + Seq(5, 10, 52, 73).map(Row(_)) + ) + } + + test("sample without replacement") { + val n = 100 + val data = sqlCtx.sparkContext.parallelize(1 to n, 2).toDF("id") + checkAnswer( + data.sample(withReplacement = false, 0.05, seed = 13), + Seq(16, 23, 88, 100).map(Row(_)) + ) + } + + test("randomSplit") { + val n = 600 + val data = sqlCtx.sparkContext.parallelize(1 to n, 2).toDF("id") + for (seed <- 1 to 5) { + val splits = data.randomSplit(Array[Double](1, 2, 3), seed) + assert(splits.length == 3, "wrong number of splits") + + assert(splits.reduce((a, b) => a.unionAll(b)).sort("id").collect().toList == + data.collect().toList, "incomplete or wrong split") + + val s = splits.map(_.count()) + assert(math.abs(s(0) - 100) < 50) // std = 9.13 + assert(math.abs(s(1) - 200) < 50) // std = 11.55 + assert(math.abs(s(2) - 300) < 50) // std = 12.25 + } + } + test("pearson correlation") { val df = Seq.tabulate(10)(i => (i, 2 * i, i * -1.0)).toDF("a", "b", "c") val corr1 = df.stat.corr("a", "b", "pearson") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index f9cc6d1f3c25..0212637a829e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -415,23 +415,6 @@ class DataFrameSuite extends QueryTest with SQLTestUtils { assert(df.schema.map(_.name) === Seq("key", "valueRenamed", "newCol")) } - test("randomSplit") { - val n = 600 - val data = sqlContext.sparkContext.parallelize(1 to n, 2).toDF("id") - for (seed <- 1 to 5) { - val splits = data.randomSplit(Array[Double](1, 2, 3), seed) - assert(splits.length == 3, "wrong number of splits") - - assert(splits.reduce((a, b) => a.unionAll(b)).sort("id").collect().toList == - data.collect().toList, "incomplete or wrong split") - - val s = splits.map(_.count()) - assert(math.abs(s(0) - 100) < 50) // std = 9.13 - assert(math.abs(s(1) - 200) < 50) // std = 11.55 - assert(math.abs(s(2) - 300) < 50) // std = 12.25 - } - } - test("describe") { val describeTestData = Seq( ("Bob", 16, 176), From 68ccc6e184598822b19a880fdd4597b66a1c2d92 Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Sun, 9 Aug 2015 11:44:51 -0700 Subject: [PATCH 0139/1824] [SPARK-8930] [SQL] Throw a AnalysisException with meaningful messages if DataFrame#explode takes a star in expressions Author: Yijie Shen Closes #8057 from yjshen/explode_star and squashes the following commits: eae181d 
[Yijie Shen] change explaination message 54c9d11 [Yijie Shen] meaning message for * in explode --- .../spark/sql/catalyst/analysis/Analyzer.scala | 4 +++- .../sql/catalyst/analysis/AnalysisTest.scala | 4 +++- .../org/apache/spark/sql/DataFrameSuite.scala | 15 +++++++++++++++ 3 files changed, 21 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 82158e61e3fb..a684dbc3afa4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -408,7 +408,7 @@ class Analyzer( /** * Returns true if `exprs` contains a [[Star]]. */ - protected def containsStar(exprs: Seq[Expression]): Boolean = + def containsStar(exprs: Seq[Expression]): Boolean = exprs.exists(_.collect { case _: Star => true }.nonEmpty) } @@ -602,6 +602,8 @@ class Analyzer( */ object ResolveGenerate extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + case g: Generate if ResolveReferences.containsStar(g.generator.children) => + failAnalysis("Cannot explode *, explode can only be applied on a specific column.") case p: Generate if !p.child.resolved || !p.generator.resolved => p case g: Generate if !g.resolved => g.copy(generatorOutput = makeGeneratorOutput(g.generator, g.generatorOutput.map(_.name))) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala index ee1f8f54251e..53b3695a86be 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala @@ -71,6 +71,8 @@ trait AnalysisTest extends PlanTest { val e = intercept[Exception] { analyzer.checkAnalysis(analyzer.execute(inputPlan)) } - expectedErrors.forall(e.getMessage.contains) + assert(expectedErrors.map(_.toLowerCase).forall(e.getMessage.toLowerCase.contains), + s"Expected to throw Exception contains: ${expectedErrors.mkString(", ")}, " + + s"actually we get ${e.getMessage}") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 0212637a829e..c49f256be550 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -134,6 +134,21 @@ class DataFrameSuite extends QueryTest with SQLTestUtils { ) } + test("SPARK-8930: explode should fail with a meaningful message if it takes a star") { + val df = Seq(("1", "1,2"), ("2", "4"), ("3", "7,8,9")).toDF("prefix", "csv") + val e = intercept[AnalysisException] { + df.explode($"*") { case Row(prefix: String, csv: String) => + csv.split(",").map(v => Tuple1(prefix + ":" + v)).toSeq + }.queryExecution.assertAnalyzed() + } + assert(e.getMessage.contains( + "Cannot explode *, explode can only be applied on a specific column.")) + + df.explode('prefix, 'csv) { case Row(prefix: String, csv: String) => + csv.split(",").map(v => Tuple1(prefix + ":" + v)).toSeq + }.queryExecution.assertAnalyzed() + } + test("explode alias and star") { val df = Seq((Array("a"), 1)).toDF("a", "b") From 86fa4ba6d13f909cb508b7cb3b153d586fe59bc3 Mon Sep 17 00:00:00 2001 From: Yadong Qi Date: Sun, 9 
Aug 2015 19:54:05 +0100 Subject: [PATCH 0140/1824] [SPARK-9737] [YARN] Add the suggested configuration when required executor memory is above the max threshold of this cluster on YARN mode Author: Yadong Qi Closes #8028 from watermen/SPARK-9737 and squashes the following commits: 48bdf3d [Yadong Qi] Add suggested configuration. --- .../main/scala/org/apache/spark/deploy/yarn/Client.scala | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index fc11bbf97e2e..b4ba3f022160 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -203,12 +203,14 @@ private[spark] class Client( val executorMem = args.executorMemory + executorMemoryOverhead if (executorMem > maxMem) { throw new IllegalArgumentException(s"Required executor memory (${args.executorMemory}" + - s"+$executorMemoryOverhead MB) is above the max threshold ($maxMem MB) of this cluster!") + s"+$executorMemoryOverhead MB) is above the max threshold ($maxMem MB) of this cluster! " + + "Please increase the value of 'yarn.scheduler.maximum-allocation-mb'.") } val amMem = args.amMemory + amMemoryOverhead if (amMem > maxMem) { throw new IllegalArgumentException(s"Required AM memory (${args.amMemory}" + - s"+$amMemoryOverhead MB) is above the max threshold ($maxMem MB) of this cluster!") + s"+$amMemoryOverhead MB) is above the max threshold ($maxMem MB) of this cluster! " + + "Please increase the value of 'yarn.scheduler.maximum-allocation-mb'.") } logInfo("Will allocate AM container, with %d MB memory including %d MB overhead".format( amMem, From a863348fd85848e0d4325c4de359da12e5f548d2 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sun, 9 Aug 2015 13:43:31 -0700 Subject: [PATCH 0141/1824] Disable JobGeneratorSuite "Do not clear received block data too soon". --- .../apache/spark/streaming/scheduler/JobGeneratorSuite.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/JobGeneratorSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/JobGeneratorSuite.scala index a2dbae149f31..9b6cd4bc4e31 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/JobGeneratorSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/JobGeneratorSuite.scala @@ -56,7 +56,8 @@ class JobGeneratorSuite extends TestSuiteBase { // 4. allow subsequent batches to be generated (to allow premature deletion of 3rd batch metadata) // 5. verify whether 3rd batch's block metadata still exists // - test("SPARK-6222: Do not clear received block data too soon") { + // TODO: SPARK-7420 enable this test + ignore("SPARK-6222: Do not clear received block data too soon") { import JobGeneratorSuite._ val checkpointDir = Utils.createTempDir() val testConf = conf From 23cf5af08d98da771c41571c00a2f5cafedfebdd Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sun, 9 Aug 2015 14:26:01 -0700 Subject: [PATCH 0142/1824] [SPARK-9703] [SQL] Refactor EnsureRequirements to avoid certain unnecessary shuffles This pull request refactors the `EnsureRequirements` planning rule in order to avoid the addition of certain unnecessary shuffles. As an example of how unnecessary shuffles can occur, consider SortMergeJoin, which requires clustered distribution and sorted ordering of its children's input rows. 
Say that both of SMJ's children produce unsorted output but are both SinglePartition. In this case, we will need to inject sort operators but should not need to inject Exchanges. Unfortunately, it looks like the EnsureRequirements unnecessarily repartitions using a hash partitioning. This patch solves this problem by refactoring `EnsureRequirements` to properly implement the `compatibleWith` checks that were broken in earlier implementations. See the significant inline comments for a better description of how this works. The majority of this PR is new comments and test cases, with few actual changes to the code. Author: Josh Rosen Closes #7988 from JoshRosen/exchange-fixes and squashes the following commits: 38006e7 [Josh Rosen] Rewrite EnsureRequirements _yet again_ to make things even simpler 0983f75 [Josh Rosen] More guarantees vs. compatibleWith cleanup; delete BroadcastPartitioning. 8784bd9 [Josh Rosen] Giant comment explaining compatibleWith vs. guarantees 1307c50 [Josh Rosen] Update conditions for requiring child compatibility. 18cddeb [Josh Rosen] Rename DummyPlan to DummySparkPlan. 2c7e126 [Josh Rosen] Merge remote-tracking branch 'origin/master' into exchange-fixes fee65c4 [Josh Rosen] Further refinement to comments / reasoning 642b0bb [Josh Rosen] Further expand comment / reasoning 06aba0c [Josh Rosen] Add more comments 8dbc845 [Josh Rosen] Add even more tests. 4f08278 [Josh Rosen] Fix the test by adding the compatibility check to EnsureRequirements a1c12b9 [Josh Rosen] Add failing test to demonstrate allCompatible bug 0725a34 [Josh Rosen] Small assertion cleanup. 5172ac5 [Josh Rosen] Add test for requiresChildrenToProduceSameNumberOfPartitions. 2e0f33a [Josh Rosen] Write a more generic test for EnsureRequirements. 752b8de [Josh Rosen] style fix c628daf [Josh Rosen] Revert accidental ExchangeSuite change. c9fb231 [Josh Rosen] Rewrite exchange to fix better handle this case. adcc742 [Josh Rosen] Move test to PlannerSuite. 0675956 [Josh Rosen] Preserving ordering and partitioning in row format converters also does not help. cc5669c [Josh Rosen] Adding outputPartitioning to Repartition does not fix the test. 2dfc648 [Josh Rosen] Add failing test illustrating bad exchange planning. --- .../plans/physical/partitioning.scala | 128 +++++++++++++-- .../apache/spark/sql/execution/Exchange.scala | 104 ++++++------ .../spark/sql/execution/basicOperators.scala | 5 + .../sql/execution/rowFormatConverters.scala | 5 + .../spark/sql/execution/PlannerSuite.scala | 151 ++++++++++++++++++ 5 files changed, 328 insertions(+), 65 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index ec659ce789c2..5a89a90b735a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -75,6 +75,37 @@ case class OrderedDistribution(ordering: Seq[SortOrder]) extends Distribution { def clustering: Set[Expression] = ordering.map(_.child).toSet } +/** + * Describes how an operator's output is split across partitions. The `compatibleWith`, + * `guarantees`, and `satisfies` methods describe relationships between child partitionings, + * target partitionings, and [[Distribution]]s. 
These relations are described more precisely in + * their individual method docs, but at a high level: + * + * - `satisfies` is a relationship between partitionings and distributions. + * - `compatibleWith` is relationships between an operator's child output partitionings. + * - `guarantees` is a relationship between a child's existing output partitioning and a target + * output partitioning. + * + * Diagrammatically: + * + * +--------------+ + * | Distribution | + * +--------------+ + * ^ + * | + * satisfies + * | + * +--------------+ +--------------+ + * | Child | | Target | + * +----| Partitioning |----guarantees--->| Partitioning | + * | +--------------+ +--------------+ + * | ^ + * | | + * | compatibleWith + * | | + * +------------+ + * + */ sealed trait Partitioning { /** Returns the number of partitions that the data is split across */ val numPartitions: Int @@ -90,9 +121,66 @@ sealed trait Partitioning { /** * Returns true iff we can say that the partitioning scheme of this [[Partitioning]] * guarantees the same partitioning scheme described by `other`. + * + * Compatibility of partitionings is only checked for operators that have multiple children + * and that require a specific child output [[Distribution]], such as joins. + * + * Intuitively, partitionings are compatible if they route the same partitioning key to the same + * partition. For instance, two hash partitionings are only compatible if they produce the same + * number of output partitionings and hash records according to the same hash function and + * same partitioning key schema. + * + * Put another way, two partitionings are compatible with each other if they satisfy all of the + * same distribution guarantees. */ - // TODO: Add an example once we have the `nullSafe` concept. - def guarantees(other: Partitioning): Boolean + def compatibleWith(other: Partitioning): Boolean + + /** + * Returns true iff we can say that the partitioning scheme of this [[Partitioning]] guarantees + * the same partitioning scheme described by `other`. If a `A.guarantees(B)`, then repartitioning + * the child's output according to `B` will be unnecessary. `guarantees` is used as a performance + * optimization to allow the exchange planner to avoid redundant repartitionings. By default, + * a partitioning only guarantees partitionings that are equal to itself (i.e. the same number + * of partitions, same strategy (range or hash), etc). + * + * In order to enable more aggressive optimization, this strict equality check can be relaxed. + * For example, say that the planner needs to repartition all of an operator's children so that + * they satisfy the [[AllTuples]] distribution. One way to do this is to repartition all children + * to have the [[SinglePartition]] partitioning. If one of the operator's children already happens + * to be hash-partitioned with a single partition then we do not need to re-shuffle this child; + * this repartitioning can be avoided if a single-partition [[HashPartitioning]] `guarantees` + * [[SinglePartition]]. + * + * The SinglePartition example given above is not particularly interesting; guarantees' real + * value occurs for more advanced partitioning strategies. SPARK-7871 will introduce a notion + * of null-safe partitionings, under which partitionings can specify whether rows whose + * partitioning keys contain null values will be grouped into the same partition or whether they + * will have an unknown / random distribution. 
If a partitioning does not require nulls to be + * clustered then a partitioning which _does_ cluster nulls will guarantee the null clustered + * partitioning. The converse is not true, however: a partitioning which clusters nulls cannot + * be guaranteed by one which does not cluster them. Thus, in general `guarantees` is not a + * symmetric relation. + * + * Another way to think about `guarantees`: if `A.guarantees(B)`, then any partitioning of rows + * produced by `A` could have also been produced by `B`. + */ + def guarantees(other: Partitioning): Boolean = this == other +} + +object Partitioning { + def allCompatible(partitionings: Seq[Partitioning]): Boolean = { + // Note: this assumes transitivity + partitionings.sliding(2).map { + case Seq(a) => true + case Seq(a, b) => + if (a.numPartitions != b.numPartitions) { + assert(!a.compatibleWith(b) && !b.compatibleWith(a)) + false + } else { + a.compatibleWith(b) && b.compatibleWith(a) + } + }.forall(_ == true) + } } case class UnknownPartitioning(numPartitions: Int) extends Partitioning { @@ -101,6 +189,8 @@ case class UnknownPartitioning(numPartitions: Int) extends Partitioning { case _ => false } + override def compatibleWith(other: Partitioning): Boolean = false + override def guarantees(other: Partitioning): Boolean = false } @@ -109,21 +199,9 @@ case object SinglePartition extends Partitioning { override def satisfies(required: Distribution): Boolean = true - override def guarantees(other: Partitioning): Boolean = other match { - case SinglePartition => true - case _ => false - } -} - -case object BroadcastPartitioning extends Partitioning { - val numPartitions = 1 + override def compatibleWith(other: Partitioning): Boolean = other.numPartitions == 1 - override def satisfies(required: Distribution): Boolean = true - - override def guarantees(other: Partitioning): Boolean = other match { - case BroadcastPartitioning => true - case _ => false - } + override def guarantees(other: Partitioning): Boolean = other.numPartitions == 1 } /** @@ -147,6 +225,12 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) case _ => false } + override def compatibleWith(other: Partitioning): Boolean = other match { + case o: HashPartitioning => + this.clusteringSet == o.clusteringSet && this.numPartitions == o.numPartitions + case _ => false + } + override def guarantees(other: Partitioning): Boolean = other match { case o: HashPartitioning => this.clusteringSet == o.clusteringSet && this.numPartitions == o.numPartitions @@ -185,6 +269,11 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int) case _ => false } + override def compatibleWith(other: Partitioning): Boolean = other match { + case o: RangePartitioning => this == o + case _ => false + } + override def guarantees(other: Partitioning): Boolean = other match { case o: RangePartitioning => this == o case _ => false @@ -228,6 +317,13 @@ case class PartitioningCollection(partitionings: Seq[Partitioning]) override def satisfies(required: Distribution): Boolean = partitionings.exists(_.satisfies(required)) + /** + * Returns true if any `partitioning` of this collection is compatible with + * the given [[Partitioning]]. + */ + override def compatibleWith(other: Partitioning): Boolean = + partitionings.exists(_.compatibleWith(other)) + /** * Returns true if any `partitioning` of this collection guarantees * the given [[Partitioning]]. 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index 49bb72980086..b89e634761eb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -190,66 +190,72 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una * of input data meets the * [[org.apache.spark.sql.catalyst.plans.physical.Distribution Distribution]] requirements for * each operator by inserting [[Exchange]] Operators where required. Also ensure that the - * required input partition ordering requirements are met. + * input partition ordering requirements are met. */ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[SparkPlan] { // TODO: Determine the number of partitions. - def numPartitions: Int = sqlContext.conf.numShufflePartitions + private def numPartitions: Int = sqlContext.conf.numShufflePartitions - def apply(plan: SparkPlan): SparkPlan = plan.transformUp { - case operator: SparkPlan => - // Adds Exchange or Sort operators as required - def addOperatorsIfNecessary( - partitioning: Partitioning, - rowOrdering: Seq[SortOrder], - child: SparkPlan): SparkPlan = { - - def addShuffleIfNecessary(child: SparkPlan): SparkPlan = { - if (!child.outputPartitioning.guarantees(partitioning)) { - Exchange(partitioning, child) - } else { - child - } - } + /** + * Given a required distribution, returns a partitioning that satisfies that distribution. + */ + private def canonicalPartitioning(requiredDistribution: Distribution): Partitioning = { + requiredDistribution match { + case AllTuples => SinglePartition + case ClusteredDistribution(clustering) => HashPartitioning(clustering, numPartitions) + case OrderedDistribution(ordering) => RangePartitioning(ordering, numPartitions) + case dist => sys.error(s"Do not know how to satisfy distribution $dist") + } + } - def addSortIfNecessary(child: SparkPlan): SparkPlan = { - - if (rowOrdering.nonEmpty) { - // If child.outputOrdering is [a, b] and rowOrdering is [a], we do not need to sort. 
- val minSize = Seq(rowOrdering.size, child.outputOrdering.size).min - if (minSize == 0 || rowOrdering.take(minSize) != child.outputOrdering.take(minSize)) { - sqlContext.planner.BasicOperators.getSortOperator(rowOrdering, global = false, child) - } else { - child - } - } else { - child - } - } + private def ensureDistributionAndOrdering(operator: SparkPlan): SparkPlan = { + val requiredChildDistributions: Seq[Distribution] = operator.requiredChildDistribution + val requiredChildOrderings: Seq[Seq[SortOrder]] = operator.requiredChildOrdering + var children: Seq[SparkPlan] = operator.children - addSortIfNecessary(addShuffleIfNecessary(child)) + // Ensure that the operator's children satisfy their output distribution requirements: + children = children.zip(requiredChildDistributions).map { case (child, distribution) => + if (child.outputPartitioning.satisfies(distribution)) { + child + } else { + Exchange(canonicalPartitioning(distribution), child) } + } - val requirements = - (operator.requiredChildDistribution, operator.requiredChildOrdering, operator.children) - - val fixedChildren = requirements.zipped.map { - case (AllTuples, rowOrdering, child) => - addOperatorsIfNecessary(SinglePartition, rowOrdering, child) - case (ClusteredDistribution(clustering), rowOrdering, child) => - addOperatorsIfNecessary(HashPartitioning(clustering, numPartitions), rowOrdering, child) - case (OrderedDistribution(ordering), rowOrdering, child) => - addOperatorsIfNecessary(RangePartitioning(ordering, numPartitions), rowOrdering, child) - - case (UnspecifiedDistribution, Seq(), child) => + // If the operator has multiple children and specifies child output distributions (e.g. join), + // then the children's output partitionings must be compatible: + if (children.length > 1 + && requiredChildDistributions.toSet != Set(UnspecifiedDistribution) + && !Partitioning.allCompatible(children.map(_.outputPartitioning))) { + children = children.zip(requiredChildDistributions).map { case (child, distribution) => + val targetPartitioning = canonicalPartitioning(distribution) + if (child.outputPartitioning.guarantees(targetPartitioning)) { child - case (UnspecifiedDistribution, rowOrdering, child) => - sqlContext.planner.BasicOperators.getSortOperator(rowOrdering, global = false, child) + } else { + Exchange(targetPartitioning, child) + } + } + } - case (dist, ordering, _) => - sys.error(s"Don't know how to ensure $dist with ordering $ordering") + // Now that we've performed any necessary shuffles, add sorts to guarantee output orderings: + children = children.zip(requiredChildOrderings).map { case (child, requiredOrdering) => + if (requiredOrdering.nonEmpty) { + // If child.outputOrdering is [a, b] and requiredOrdering is [a], we do not need to sort. 
+ val minSize = Seq(requiredOrdering.size, child.outputOrdering.size).min + if (minSize == 0 || requiredOrdering.take(minSize) != child.outputOrdering.take(minSize)) { + sqlContext.planner.BasicOperators.getSortOperator(requiredOrdering, global = false, child) + } else { + child + } + } else { + child } + } - operator.withNewChildren(fixedChildren) + operator.withNewChildren(children) + } + + def apply(plan: SparkPlan): SparkPlan = plan.transformUp { + case operator: SparkPlan => ensureDistributionAndOrdering(operator) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index c5d1ed0937b1..24950f26061f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -256,6 +256,11 @@ case class Repartition(numPartitions: Int, shuffle: Boolean, child: SparkPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output + override def outputPartitioning: Partitioning = { + if (numPartitions == 1) SinglePartition + else UnknownPartitioning(numPartitions) + } + protected override def doExecute(): RDD[InternalRow] = { child.execute().map(_.copy()).coalesce(numPartitions, shuffle) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala index 29f3beb3cb3c..855555dd1d4c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala @@ -21,6 +21,7 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.catalyst.rules.Rule /** @@ -33,6 +34,8 @@ case class ConvertToUnsafe(child: SparkPlan) extends UnaryNode { require(UnsafeProjection.canSupport(child.schema), s"Cannot convert ${child.schema} to Unsafe") override def output: Seq[Attribute] = child.output + override def outputPartitioning: Partitioning = child.outputPartitioning + override def outputOrdering: Seq[SortOrder] = child.outputOrdering override def outputsUnsafeRows: Boolean = true override def canProcessUnsafeRows: Boolean = false override def canProcessSafeRows: Boolean = true @@ -51,6 +54,8 @@ case class ConvertToUnsafe(child: SparkPlan) extends UnaryNode { @DeveloperApi case class ConvertToSafe(child: SparkPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output + override def outputPartitioning: Partitioning = child.outputPartitioning + override def outputOrdering: Seq[SortOrder] = child.outputOrdering override def outputsUnsafeRows: Boolean = false override def canProcessUnsafeRows: Boolean = true override def canProcessSafeRows: Boolean = false diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 18b0e54dc7c5..5582caa0d366 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -18,9 +18,13 @@ package org.apache.spark.sql.execution import org.apache.spark.SparkFunSuite +import 
org.apache.spark.rdd.RDD import org.apache.spark.sql.TestData._ +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Literal, SortOrder} import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, ShuffledHashJoin} import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.{SQLTestUtils, TestSQLContext} @@ -202,4 +206,151 @@ class PlannerSuite extends SparkFunSuite with SQLTestUtils { } } } + + // --- Unit tests of EnsureRequirements --------------------------------------------------------- + + // When it comes to testing whether EnsureRequirements properly ensures distribution requirements, + // there two dimensions that need to be considered: are the child partitionings compatible and + // do they satisfy the distribution requirements? As a result, we need at least four test cases. + + private def assertDistributionRequirementsAreSatisfied(outputPlan: SparkPlan): Unit = { + if (outputPlan.children.length > 1 + && outputPlan.requiredChildDistribution.toSet != Set(UnspecifiedDistribution)) { + val childPartitionings = outputPlan.children.map(_.outputPartitioning) + if (!Partitioning.allCompatible(childPartitionings)) { + fail(s"Partitionings are not compatible: $childPartitionings") + } + } + outputPlan.children.zip(outputPlan.requiredChildDistribution).foreach { + case (child, requiredDist) => + assert(child.outputPartitioning.satisfies(requiredDist), + s"$child output partitioning does not satisfy $requiredDist:\n$outputPlan") + } + } + + test("EnsureRequirements with incompatible child partitionings which satisfy distribution") { + // Consider an operator that requires inputs that are clustered by two expressions (e.g. 
+ // sort merge join where there are multiple columns in the equi-join condition) + val clusteringA = Literal(1) :: Nil + val clusteringB = Literal(2) :: Nil + val distribution = ClusteredDistribution(clusteringA ++ clusteringB) + // Say that the left and right inputs are each partitioned by _one_ of the two join columns: + val leftPartitioning = HashPartitioning(clusteringA, 1) + val rightPartitioning = HashPartitioning(clusteringB, 1) + // Individually, each input's partitioning satisfies the clustering distribution: + assert(leftPartitioning.satisfies(distribution)) + assert(rightPartitioning.satisfies(distribution)) + // However, these partitionings are not compatible with each other, so we still need to + // repartition both inputs prior to performing the join: + assert(!leftPartitioning.compatibleWith(rightPartitioning)) + assert(!rightPartitioning.compatibleWith(leftPartitioning)) + val inputPlan = DummySparkPlan( + children = Seq( + DummySparkPlan(outputPartitioning = leftPartitioning), + DummySparkPlan(outputPartitioning = rightPartitioning) + ), + requiredChildDistribution = Seq(distribution, distribution), + requiredChildOrdering = Seq(Seq.empty, Seq.empty) + ) + val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan) + assertDistributionRequirementsAreSatisfied(outputPlan) + if (outputPlan.collect { case Exchange(_, _) => true }.isEmpty) { + fail(s"Exchange should have been added:\n$outputPlan") + } + } + + test("EnsureRequirements with child partitionings with different numbers of output partitions") { + // This is similar to the previous test, except it checks that partitionings are not compatible + // unless they produce the same number of partitions. + val clustering = Literal(1) :: Nil + val distribution = ClusteredDistribution(clustering) + val inputPlan = DummySparkPlan( + children = Seq( + DummySparkPlan(outputPartitioning = HashPartitioning(clustering, 1)), + DummySparkPlan(outputPartitioning = HashPartitioning(clustering, 2)) + ), + requiredChildDistribution = Seq(distribution, distribution), + requiredChildOrdering = Seq(Seq.empty, Seq.empty) + ) + val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan) + assertDistributionRequirementsAreSatisfied(outputPlan) + } + + test("EnsureRequirements with compatible child partitionings that do not satisfy distribution") { + val distribution = ClusteredDistribution(Literal(1) :: Nil) + // The left and right inputs have compatible partitionings but they do not satisfy the + // distribution because they are clustered on different columns. Thus, we need to shuffle. + val childPartitioning = HashPartitioning(Literal(2) :: Nil, 1) + assert(!childPartitioning.satisfies(distribution)) + val inputPlan = DummySparkPlan( + children = Seq( + DummySparkPlan(outputPartitioning = childPartitioning), + DummySparkPlan(outputPartitioning = childPartitioning) + ), + requiredChildDistribution = Seq(distribution, distribution), + requiredChildOrdering = Seq(Seq.empty, Seq.empty) + ) + val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan) + assertDistributionRequirementsAreSatisfied(outputPlan) + if (outputPlan.collect { case Exchange(_, _) => true }.isEmpty) { + fail(s"Exchange should have been added:\n$outputPlan") + } + } + + test("EnsureRequirements with compatible child partitionings that satisfy distribution") { + // In this case, all requirements are satisfied and no exchange should be added. 
+ val distribution = ClusteredDistribution(Literal(1) :: Nil) + val childPartitioning = HashPartitioning(Literal(1) :: Nil, 5) + assert(childPartitioning.satisfies(distribution)) + val inputPlan = DummySparkPlan( + children = Seq( + DummySparkPlan(outputPartitioning = childPartitioning), + DummySparkPlan(outputPartitioning = childPartitioning) + ), + requiredChildDistribution = Seq(distribution, distribution), + requiredChildOrdering = Seq(Seq.empty, Seq.empty) + ) + val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan) + assertDistributionRequirementsAreSatisfied(outputPlan) + if (outputPlan.collect { case Exchange(_, _) => true }.nonEmpty) { + fail(s"Exchange should not have been added:\n$outputPlan") + } + } + + // This is a regression test for SPARK-9703 + test("EnsureRequirements should not repartition if only ordering requirement is unsatisfied") { + // Consider an operator that imposes both output distribution and ordering requirements on its + // children, such as sort sort merge join. If the distribution requirements are satisfied but + // the output ordering requirements are unsatisfied, then the planner should only add sorts and + // should not need to add additional shuffles / exchanges. + val outputOrdering = Seq(SortOrder(Literal(1), Ascending)) + val distribution = ClusteredDistribution(Literal(1) :: Nil) + val inputPlan = DummySparkPlan( + children = Seq( + DummySparkPlan(outputPartitioning = SinglePartition), + DummySparkPlan(outputPartitioning = SinglePartition) + ), + requiredChildDistribution = Seq(distribution, distribution), + requiredChildOrdering = Seq(outputOrdering, outputOrdering) + ) + val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan) + assertDistributionRequirementsAreSatisfied(outputPlan) + if (outputPlan.collect { case Exchange(_, _) => true }.nonEmpty) { + fail(s"No Exchanges should have been added:\n$outputPlan") + } + } + + // --------------------------------------------------------------------------------------------- +} + +// Used for unit-testing EnsureRequirements +private case class DummySparkPlan( + override val children: Seq[SparkPlan] = Nil, + override val outputOrdering: Seq[SortOrder] = Nil, + override val outputPartitioning: Partitioning = UnknownPartitioning(0), + override val requiredChildDistribution: Seq[Distribution] = Nil, + override val requiredChildOrdering: Seq[Seq[SortOrder]] = Nil + ) extends SparkPlan { + override protected def doExecute(): RDD[InternalRow] = throw new NotImplementedError + override def output: Seq[Attribute] = Seq.empty } From 46025616b414eaf1da01fcc1255d8041ea1554bc Mon Sep 17 00:00:00 2001 From: Shivaram Venkataraman Date: Sun, 9 Aug 2015 14:30:30 -0700 Subject: [PATCH 0143/1824] [CORE] [SPARK-9760] Use Option instead of Some for Ivy repos This was introduced in #7599 cc rxin brkyvz Author: Shivaram Venkataraman Closes #8055 from shivaram/spark-packages-repo-fix and squashes the following commits: 890f306 [Shivaram Venkataraman] Remove test case 51d69ee [Shivaram Venkataraman] Add test case for --packages without --repository c02e0b4 [Shivaram Venkataraman] Use Option instead of Some for Ivy repos --- core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 1186bed48525..7ac6cbce4cd1 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ 
b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -284,7 +284,7 @@ object SparkSubmit { Nil } val resolvedMavenCoordinates = SparkSubmitUtils.resolveMavenCoordinates(args.packages, - Some(args.repositories), Some(args.ivyRepoPath), exclusions = exclusions) + Option(args.repositories), Option(args.ivyRepoPath), exclusions = exclusions) if (!StringUtils.isBlank(resolvedMavenCoordinates)) { args.jars = mergeFileLists(args.jars, resolvedMavenCoordinates) if (args.isPython) { From be80def0d07ed0f45d60453f4f82500d8c4c9106 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Sun, 9 Aug 2015 22:33:53 -0700 Subject: [PATCH 0144/1824] [SPARK-9777] [SQL] Window operator can accept UnsafeRows https://issues.apache.org/jira/browse/SPARK-9777 Author: Yin Huai Closes #8064 from yhuai/windowUnsafe and squashes the following commits: 8fb3537 [Yin Huai] Set canProcessUnsafeRows to true. --- .../src/main/scala/org/apache/spark/sql/execution/Window.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala index fe9f2c702817..0269d6d4b7a1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala @@ -101,6 +101,8 @@ case class Window( override def outputOrdering: Seq[SortOrder] = child.outputOrdering + override def canProcessUnsafeRows: Boolean = true + /** * Create a bound ordering object for a given frame type and offset. A bound ordering object is * used to determine which input row lies within the frame boundaries of an output row. From e3fef0f9e17b1766a3869cb80ce7e4cd521cb7b6 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Mon, 10 Aug 2015 09:07:08 -0700 Subject: [PATCH 0145/1824] [SPARK-9743] [SQL] Fixes JSONRelation refreshing PR #7696 added two `HadoopFsRelation.refresh()` calls ([this] [1], and [this] [2]) in `DataSourceStrategy` to make test case `InsertSuite.save directly to the path of a JSON table` pass. However, this forces every `HadoopFsRelation` table scan to do a refresh, which can be super expensive for tables with a large number of partitions. The reason why the original test case fails without the `refresh()` calls is that the old JSON relation builds the base RDD with the input paths, while `HadoopFsRelation` provides `FileStatus`es of leaf files. With the old JSON relation, we can create a temporary table based on a path, write data to that path, and then read the newly written data without refreshing the table. This is no longer true for `HadoopFsRelation`. This PR removes those two expensive refresh calls, and moves the refresh into `JSONRelation` to fix this issue. We might want to update the `HadoopFsRelation` interface to provide better support for this use case.
[1]: https://github.com/apache/spark/blob/ebfd91c542aaead343cb154277fcf9114382fee7/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala#L63 [2]: https://github.com/apache/spark/blob/ebfd91c542aaead343cb154277fcf9114382fee7/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala#L91 Author: Cheng Lian Closes #8035 from liancheng/spark-9743/fix-json-relation-refreshing and squashes the following commits: ec1957d [Cheng Lian] Fixes JSONRelation refreshing --- .../datasources/DataSourceStrategy.scala | 2 -- .../apache/spark/sql/json/JSONRelation.scala | 19 +++++++++++++++---- .../apache/spark/sql/sources/interfaces.scala | 2 +- .../spark/sql/sources/InsertSuite.scala | 10 +++++----- 4 files changed, 21 insertions(+), 12 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 5b5fa8c93ec5..78a4acdf4b1b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -60,7 +60,6 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { // Scanning partitioned HadoopFsRelation case PhysicalOperation(projects, filters, l @ LogicalRelation(t: HadoopFsRelation)) if t.partitionSpec.partitionColumns.nonEmpty => - t.refresh() val selectedPartitions = prunePartitions(filters, t.partitionSpec).toArray logInfo { @@ -88,7 +87,6 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { // Scanning non-partitioned HadoopFsRelation case PhysicalOperation(projects, filters, l @ LogicalRelation(t: HadoopFsRelation)) => - t.refresh() // See buildPartitionedTableScan for the reason that we need to create a shard // broadcast HadoopConf. 
val sharedHadoopConf = SparkHadoopUtil.get.conf diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala index b34a272ec547..5bb9e62310a5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala @@ -22,20 +22,22 @@ import java.io.CharArrayWriter import com.fasterxml.jackson.core.JsonFactory import com.google.common.base.Objects import org.apache.hadoop.fs.{FileStatus, Path} -import org.apache.hadoop.io.{Text, LongWritable, NullWritable} +import org.apache.hadoop.io.{LongWritable, NullWritable, Text} import org.apache.hadoop.mapred.{JobConf, TextInputFormat} -import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat -import org.apache.hadoop.mapreduce.{RecordWriter, TaskAttemptContext, Job} import org.apache.hadoop.mapreduce.lib.input.FileInputFormat +import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat +import org.apache.hadoop.mapreduce.{Job, RecordWriter, TaskAttemptContext} + import org.apache.spark.Logging +import org.apache.spark.broadcast.Broadcast import org.apache.spark.mapred.SparkHadoopMapRedUtil - import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.datasources.PartitionSpec import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.StructType import org.apache.spark.sql.{AnalysisException, Row, SQLContext} +import org.apache.spark.util.SerializableConfiguration private[sql] class DefaultSource extends HadoopFsRelationProvider with DataSourceRegister { @@ -108,6 +110,15 @@ private[sql] class JSONRelation( jsonSchema } + override private[sql] def buildScan( + requiredColumns: Array[String], + filters: Array[Filter], + inputPaths: Array[String], + broadcastedConf: Broadcast[SerializableConfiguration]): RDD[Row] = { + refresh() + super.buildScan(requiredColumns, filters, inputPaths, broadcastedConf) + } + override def buildScan( requiredColumns: Array[String], filters: Array[Filter], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index 4aafec0e2df2..6bcabbab4f77 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -555,7 +555,7 @@ abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[Partitio }) } - private[sql] final def buildScan( + private[sql] def buildScan( requiredColumns: Array[String], filters: Array[Filter], inputPaths: Array[String], diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala index 39d18d712ef8..cdbfaf6455fe 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala @@ -32,9 +32,9 @@ class InsertSuite extends DataSourceTest with BeforeAndAfterAll { var path: File = null - override def beforeAll: Unit = { + override def beforeAll(): Unit = { path = Utils.createTempDir() - val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}""")) + val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str$i"}""")) caseInsensitiveContext.read.json(rdd).registerTempTable("jt") sql( s""" @@ -46,7 +46,7 @@ class 
InsertSuite extends DataSourceTest with BeforeAndAfterAll { """.stripMargin) } - override def afterAll: Unit = { + override def afterAll(): Unit = { caseInsensitiveContext.dropTempTable("jsonTable") caseInsensitiveContext.dropTempTable("jt") Utils.deleteRecursively(path) @@ -110,7 +110,7 @@ class InsertSuite extends DataSourceTest with BeforeAndAfterAll { ) // Writing the table to less part files. - val rdd1 = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}"""), 5) + val rdd1 = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str$i"}"""), 5) caseInsensitiveContext.read.json(rdd1).registerTempTable("jt1") sql( s""" @@ -122,7 +122,7 @@ class InsertSuite extends DataSourceTest with BeforeAndAfterAll { ) // Writing the table to more part files. - val rdd2 = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}"""), 10) + val rdd2 = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str$i"}"""), 10) caseInsensitiveContext.read.json(rdd2).registerTempTable("jt2") sql( s""" From 0f3366a4c740147a7a7519922642912e2dd238f8 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Mon, 10 Aug 2015 10:10:40 -0700 Subject: [PATCH 0146/1824] [SPARK-9710] [TEST] Fix RPackageUtilsSuite when R is not available. RUtils.isRInstalled throws an exception if R is not installed, instead of returning false. Fix that. Author: Marcelo Vanzin Closes #8008 from vanzin/SPARK-9710 and squashes the following commits: df72d8c [Marcelo Vanzin] [SPARK-9710] [test] Fix RPackageUtilsSuite when R is not available. --- core/src/main/scala/org/apache/spark/api/r/RUtils.scala | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/r/RUtils.scala b/core/src/main/scala/org/apache/spark/api/r/RUtils.scala index 93b3bea57867..427b2bc7cbcb 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RUtils.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RUtils.scala @@ -67,7 +67,11 @@ private[spark] object RUtils { /** Check if R is installed before running tests that use R commands. */ def isRInstalled: Boolean = { - val builder = new ProcessBuilder(Seq("R", "--version")) - builder.start().waitFor() == 0 + try { + val builder = new ProcessBuilder(Seq("R", "--version")) + builder.start().waitFor() == 0 + } catch { + case e: Exception => false + } } } From 00b655cced637e1c3b750c19266086b9dcd7c158 Mon Sep 17 00:00:00 2001 From: Feynman Liang Date: Mon, 10 Aug 2015 11:01:45 -0700 Subject: [PATCH 0147/1824] [SPARK-9755] [MLLIB] Add docs to MultivariateOnlineSummarizer methods Adds method documentations back to `MultivariateOnlineSummarizer`, which were present in 1.4 but disappeared somewhere along the way to 1.5. jkbradley Author: Feynman Liang Closes #8045 from feynmanliang/SPARK-9755 and squashes the following commits: af67fde [Feynman Liang] Add MultivariateOnlineSummarizer docs --- .../stat/MultivariateOnlineSummarizer.scala | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala index 62da9f2ef22a..64e4be0ebb97 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala @@ -153,6 +153,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S } /** + * Sample mean of each dimension. 
+ * * @since 1.1.0 */ override def mean: Vector = { @@ -168,6 +170,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S } /** + * Sample variance of each dimension. + * * @since 1.1.0 */ override def variance: Vector = { @@ -193,11 +197,15 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S } /** + * Sample size. + * * @since 1.1.0 */ override def count: Long = totalCnt /** + * Number of nonzero elements in each dimension. + * * @since 1.1.0 */ override def numNonzeros: Vector = { @@ -207,6 +215,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S } /** + * Maximum value of each dimension. + * * @since 1.1.0 */ override def max: Vector = { @@ -221,6 +231,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S } /** + * Minimum value of each dimension. + * * @since 1.1.0 */ override def min: Vector = { @@ -235,6 +247,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S } /** + * L2 (Euclidian) norm of each dimension. + * * @since 1.2.0 */ override def normL2: Vector = { @@ -252,6 +266,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S } /** + * L1 norm of each dimension. + * * @since 1.2.0 */ override def normL1: Vector = { From d285212756168200383bf4df2c951bd80a492a7c Mon Sep 17 00:00:00 2001 From: Mahmoud Lababidi Date: Mon, 10 Aug 2015 13:02:01 -0700 Subject: [PATCH 0148/1824] Fixed AtmoicReference<> Example Author: Mahmoud Lababidi Closes #8076 from lababidi/master and squashes the following commits: af4553b [Mahmoud Lababidi] Fixed AtmoicReference<> Example --- docs/streaming-kafka-integration.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/streaming-kafka-integration.md b/docs/streaming-kafka-integration.md index 775d508d4879..7571e22575ef 100644 --- a/docs/streaming-kafka-integration.md +++ b/docs/streaming-kafka-integration.md @@ -152,7 +152,7 @@ Next, we discuss how to use this approach in your streaming application.
    // Hold a reference to the current offset ranges, so it can be used downstream - final AtomicReference<OffsetRange[]> offsetRanges = new AtomicReference(); + final AtomicReference<OffsetRange[]> offsetRanges = new AtomicReference<>(); directKafkaStream.transformToPair( new Function<JavaPairRDD<String, String>, JavaPairRDD<String, String>>() { From 0fe66744f16854fc8cd8a72174de93a788e3cf6c Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 10 Aug 2015 13:05:03 -0700 Subject: [PATCH 0149/1824] [SPARK-9784] [SQL] Exchange.isUnsafe should check whether codegen and unsafe are enabled Exchange.isUnsafe should check whether codegen and unsafe are enabled. Author: Josh Rosen Closes #8073 from JoshRosen/SPARK-9784 and squashes the following commits: 7a1019f [Josh Rosen] [SPARK-9784] Exchange.isUnsafe should check whether codegen and unsafe are enabled --- .../main/scala/org/apache/spark/sql/execution/Exchange.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index b89e634761eb..029f2264a6a2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -46,7 +46,7 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una * Returns true iff we can support the data type, and we are not doing range partitioning. */ private lazy val tungstenMode: Boolean = { - GenerateUnsafeProjection.canSupport(child.schema) && + unsafeEnabled && codegenEnabled && GenerateUnsafeProjection.canSupport(child.schema) && !newPartitioning.isInstanceOf[RangePartitioning] } From 40ed2af587cedadc6e5249031857a922b3b234ca Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 10 Aug 2015 13:49:23 -0700 Subject: [PATCH 0150/1824] [SPARK-9763][SQL] Minimize exposure of internal SQL classes. There are a few changes in this pull request: 1. Moved all data sources to execution.datasources, except the public JDBC APIs. 2. In order to maintain backward compatibility from 1, added a backward compatibility translation map in data source resolution. 3. Moved ui and metric package into execution. 4. Added more documentation on some internal classes. 5. Renamed DataSourceRegister.format -> shortName. 6. Added "override" modifier on shortName. 7. Removed IntSQLMetric. Author: Reynold Xin Closes #8056 from rxin/SPARK-9763 and squashes the following commits: 9df4801 [Reynold Xin] Removed hardcoded name in test cases. d9babc6 [Reynold Xin] Shorten. e484419 [Reynold Xin] Removed VisibleForTesting. 171b812 [Reynold Xin] MimaExcludes. 2041389 [Reynold Xin] Compile ... 79dda42 [Reynold Xin] Compile. 0818ba3 [Reynold Xin] Removed IntSQLMetric. c46884f [Reynold Xin] Two more fixes. f9aa88d [Reynold Xin] [SPARK-9763][SQL] Minimize exposure of internal SQL classes.
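For readers following items (2), (5) and (6) above, a minimal, self-contained Scala sketch of how short-name registration combined with a backward-compatibility translation map can resolve provider names is shown below. It is illustrative only: the object name ProviderLookup, the trait definition, and the two-entry map are assumptions made for the example, not the code this patch introduces (the actual implementation appears later in ResolvedDataSource.scala in this diff).

// Illustrative sketch only, not the patch's exact code. Shows how old fully
// qualified provider names can keep working after the classes move packages.
import java.util.ServiceLoader
import scala.collection.JavaConverters._

trait DataSourceRegister {
  // Short alias advertised by a data source, e.g. "jdbc", "json", "parquet".
  def shortName(): String
}

object ProviderLookup {
  // Old names are rewritten to their new locations before lookup (assumed entries).
  private val backwardCompatibilityMap = Map(
    "org.apache.spark.sql.jdbc" ->
      "org.apache.spark.sql.execution.datasources.jdbc.DefaultSource",
    "org.apache.spark.sql.json" ->
      "org.apache.spark.sql.execution.datasources.json.DefaultSource")

  def lookupDataSource(provider0: String, loader: ClassLoader): Class[_] = {
    val provider = backwardCompatibilityMap.getOrElse(provider0, provider0)
    // Sources registered via META-INF/services are matched by their short alias.
    val matches = ServiceLoader.load(classOf[DataSourceRegister], loader)
      .iterator().asScala
      .filter(_.shortName().equalsIgnoreCase(provider))
      .toList
    matches match {
      case Nil =>
        // No alias matched: fall back to treating the provider as a class name.
        loader.loadClass(provider)
      case single :: Nil =>
        single.getClass
      case many =>
        sys.error(s"Multiple sources found for $provider: " +
          many.map(_.getClass.getName).mkString(", "))
    }
  }
}

With this shape, callers can keep passing either a short alias or one of the pre-move fully qualified names, and the relocated implementation classes stay out of the public surface.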
--- project/MimaExcludes.scala | 24 +- ...pache.spark.sql.sources.DataSourceRegister | 6 +- .../ui/static/spark-sql-viz.css | 0 .../ui/static/spark-sql-viz.js | 0 .../org/apache/spark/sql/DataFrame.scala | 2 +- .../apache/spark/sql/DataFrameReader.scala | 6 +- .../apache/spark/sql/DataFrameWriter.scala | 6 +- .../org/apache/spark/sql/SQLContext.scala | 2 +- .../spark/sql/execution/SQLExecution.scala | 2 +- .../spark/sql/execution/SparkPlan.scala | 8 +- .../spark/sql/execution/basicOperators.scala | 2 +- .../sql/execution/datasources/DDLParser.scala | 185 +++++++++ .../execution/datasources/DefaultSource.scala | 64 ++++ .../datasources/InsertIntoDataSource.scala | 23 +- .../datasources/ResolvedDataSource.scala | 204 ++++++++++ .../spark/sql/execution/datasources/ddl.scala | 352 +----------------- .../datasources/jdbc/DefaultSource.scala | 62 +++ .../datasources/jdbc/DriverRegistry.scala | 60 +++ .../datasources/jdbc/DriverWrapper.scala | 48 +++ .../datasources}/jdbc/JDBCRDD.scala | 9 +- .../datasources}/jdbc/JDBCRelation.scala | 41 +- .../datasources/jdbc/JdbcUtils.scala | 219 +++++++++++ .../datasources}/json/InferSchema.scala | 4 +- .../datasources}/json/JSONRelation.scala | 7 +- .../datasources}/json/JacksonGenerator.scala | 2 +- .../datasources}/json/JacksonParser.scala | 4 +- .../datasources}/json/JacksonUtils.scala | 2 +- .../parquet/CatalystReadSupport.scala | 2 +- .../parquet/CatalystRecordMaterializer.scala | 2 +- .../parquet/CatalystRowConverter.scala | 2 +- .../parquet/CatalystSchemaConverter.scala | 4 +- .../DirectParquetOutputCommitter.scala | 2 +- .../parquet/ParquetConverter.scala | 2 +- .../datasources}/parquet/ParquetFilters.scala | 2 +- .../parquet/ParquetRelation.scala | 4 +- .../parquet/ParquetTableSupport.scala | 2 +- .../parquet/ParquetTypesConverter.scala | 2 +- .../{ => execution}/metric/SQLMetrics.scala | 36 +- .../apache/spark/sql/execution/package.scala | 8 +- .../ui/AllExecutionsPage.scala | 2 +- .../{ => execution}/ui/ExecutionPage.scala | 2 +- .../sql/{ => execution}/ui/SQLListener.scala | 7 +- .../spark/sql/{ => execution}/ui/SQLTab.scala | 6 +- .../{ => execution}/ui/SparkPlanGraph.scala | 4 +- .../org/apache/spark/sql/jdbc/JdbcUtils.scala | 52 --- .../org/apache/spark/sql/jdbc/jdbc.scala | 250 ------------- .../apache/spark/sql/sources/interfaces.scala | 15 +- .../parquet/test/avro/CompatibilityTest.java | 4 +- .../parquet/test/avro/Nested.java | 30 +- .../parquet/test/avro/ParquetAvroCompat.java | 106 +++--- .../org/apache/spark/sql/DataFrameSuite.scala | 4 +- .../datasources}/json/JsonSuite.scala | 4 +- .../datasources}/json/TestJsonData.scala | 2 +- .../ParquetAvroCompatibilitySuite.scala | 4 +- .../parquet/ParquetCompatibilityTest.scala | 2 +- .../parquet/ParquetFilterSuite.scala | 2 +- .../datasources}/parquet/ParquetIOSuite.scala | 4 +- .../ParquetPartitionDiscoverySuite.scala | 2 +- .../parquet/ParquetQuerySuite.scala | 2 +- .../parquet/ParquetSchemaSuite.scala | 2 +- .../datasources}/parquet/ParquetTest.scala | 2 +- .../ParquetThriftCompatibilitySuite.scala | 2 +- .../metric/SQLMetricsSuite.scala | 12 +- .../{ => execution}/ui/SQLListenerSuite.scala | 4 +- .../sources/CreateTableAsSelectSuite.scala | 20 +- .../sql/sources/DDLSourceLoadSuite.scala | 59 +-- .../sql/sources/ResolvedDataSourceSuite.scala | 39 +- .../spark/sql/hive/HiveMetastoreCatalog.scala | 2 +- .../spark/sql/hive/orc/OrcRelation.scala | 6 +- .../spark/sql/hive/HiveParquetSuite.scala | 2 +- .../sql/hive/MetastoreDataSourcesSuite.scala | 2 +- .../hive/ParquetHiveCompatibilitySuite.scala 
| 2 +- .../sql/hive/execution/SQLQuerySuite.scala | 2 +- .../apache/spark/sql/hive/parquetSuites.scala | 2 +- .../ParquetHadoopFsRelationSuite.scala | 4 +- .../SimpleTextHadoopFsRelationSuite.scala | 2 +- 76 files changed, 1114 insertions(+), 966 deletions(-) rename sql/core/src/main/resources/org/apache/spark/sql/{ => execution}/ui/static/spark-sql-viz.css (100%) rename sql/core/src/main/resources/org/apache/spark/sql/{ => execution}/ui/static/spark-sql-viz.js (100%) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DDLParser.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DefaultSource.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DefaultSource.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DriverRegistry.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DriverWrapper.scala rename sql/core/src/main/scala/org/apache/spark/sql/{ => execution/datasources}/jdbc/JDBCRDD.scala (98%) rename sql/core/src/main/scala/org/apache/spark/sql/{ => execution/datasources}/jdbc/JDBCRelation.scala (71%) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala rename sql/core/src/main/scala/org/apache/spark/sql/{ => execution/datasources}/json/InferSchema.scala (98%) rename sql/core/src/main/scala/org/apache/spark/sql/{ => execution/datasources}/json/JSONRelation.scala (97%) rename sql/core/src/main/scala/org/apache/spark/sql/{ => execution/datasources}/json/JacksonGenerator.scala (98%) rename sql/core/src/main/scala/org/apache/spark/sql/{ => execution/datasources}/json/JacksonParser.scala (98%) rename sql/core/src/main/scala/org/apache/spark/sql/{ => execution/datasources}/json/JacksonUtils.scala (95%) rename sql/core/src/main/scala/org/apache/spark/sql/{ => execution/datasources}/parquet/CatalystReadSupport.scala (99%) rename sql/core/src/main/scala/org/apache/spark/sql/{ => execution/datasources}/parquet/CatalystRecordMaterializer.scala (96%) rename sql/core/src/main/scala/org/apache/spark/sql/{ => execution/datasources}/parquet/CatalystRowConverter.scala (99%) rename sql/core/src/main/scala/org/apache/spark/sql/{ => execution/datasources}/parquet/CatalystSchemaConverter.scala (99%) rename sql/core/src/main/scala/org/apache/spark/sql/{ => execution/datasources}/parquet/DirectParquetOutputCommitter.scala (98%) rename sql/core/src/main/scala/org/apache/spark/sql/{ => execution/datasources}/parquet/ParquetConverter.scala (96%) rename sql/core/src/main/scala/org/apache/spark/sql/{ => execution/datasources}/parquet/ParquetFilters.scala (99%) rename sql/core/src/main/scala/org/apache/spark/sql/{ => execution/datasources}/parquet/ParquetRelation.scala (99%) rename sql/core/src/main/scala/org/apache/spark/sql/{ => execution/datasources}/parquet/ParquetTableSupport.scala (99%) rename sql/core/src/main/scala/org/apache/spark/sql/{ => execution/datasources}/parquet/ParquetTypesConverter.scala (99%) rename sql/core/src/main/scala/org/apache/spark/sql/{ => execution}/metric/SQLMetrics.scala (77%) rename sql/core/src/main/scala/org/apache/spark/sql/{ => execution}/ui/AllExecutionsPage.scala (99%) rename sql/core/src/main/scala/org/apache/spark/sql/{ => execution}/ui/ExecutionPage.scala (99%) rename 
sql/core/src/main/scala/org/apache/spark/sql/{ => execution}/ui/SQLListener.scala (98%) rename sql/core/src/main/scala/org/apache/spark/sql/{ => execution}/ui/SQLTab.scala (90%) rename sql/core/src/main/scala/org/apache/spark/sql/{ => execution}/ui/SparkPlanGraph.scala (97%) delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcUtils.scala delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala rename sql/core/src/test/gen-java/org/apache/spark/sql/{ => execution/datasources}/parquet/test/avro/CompatibilityTest.java (93%) rename sql/core/src/test/gen-java/org/apache/spark/sql/{ => execution/datasources}/parquet/test/avro/Nested.java (78%) rename sql/core/src/test/gen-java/org/apache/spark/sql/{ => execution/datasources}/parquet/test/avro/ParquetAvroCompat.java (83%) rename sql/core/src/test/scala/org/apache/spark/sql/{ => execution/datasources}/json/JsonSuite.scala (99%) rename sql/core/src/test/scala/org/apache/spark/sql/{ => execution/datasources}/json/TestJsonData.scala (99%) rename sql/core/src/test/scala/org/apache/spark/sql/{ => execution/datasources}/parquet/ParquetAvroCompatibilitySuite.scala (96%) rename sql/core/src/test/scala/org/apache/spark/sql/{ => execution/datasources}/parquet/ParquetCompatibilityTest.scala (97%) rename sql/core/src/test/scala/org/apache/spark/sql/{ => execution/datasources}/parquet/ParquetFilterSuite.scala (99%) rename sql/core/src/test/scala/org/apache/spark/sql/{ => execution/datasources}/parquet/ParquetIOSuite.scala (99%) rename sql/core/src/test/scala/org/apache/spark/sql/{ => execution/datasources}/parquet/ParquetPartitionDiscoverySuite.scala (99%) rename sql/core/src/test/scala/org/apache/spark/sql/{ => execution/datasources}/parquet/ParquetQuerySuite.scala (99%) rename sql/core/src/test/scala/org/apache/spark/sql/{ => execution/datasources}/parquet/ParquetSchemaSuite.scala (99%) rename sql/core/src/test/scala/org/apache/spark/sql/{ => execution/datasources}/parquet/ParquetTest.scala (98%) rename sql/core/src/test/scala/org/apache/spark/sql/{ => execution/datasources}/parquet/ParquetThriftCompatibilitySuite.scala (98%) rename sql/core/src/test/scala/org/apache/spark/sql/{ => execution}/metric/SQLMetricsSuite.scala (92%) rename sql/core/src/test/scala/org/apache/spark/sql/{ => execution}/ui/SQLListenerSuite.scala (99%) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index b60ae784c379..90261ca3d61a 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -62,8 +62,6 @@ object MimaExcludes { "org.apache.spark.ml.classification.LogisticCostFun.this"), // SQL execution is considered private. excludePackage("org.apache.spark.sql.execution"), - // Parquet support is considered private. 
- excludePackage("org.apache.spark.sql.parquet"), // The old JSON RDD is removed in favor of streaming Jackson ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.json.JsonRDD$"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.json.JsonRDD"), @@ -155,7 +153,27 @@ object MimaExcludes { ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.SqlNewHadoopRDD$NewHadoopMapPartitionsWithSplitRDD$"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PartitionSpec$"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DescribeCommand"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DDLException") + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DDLException"), + // SPARK-9763 Minimize exposure of internal SQL classes + excludePackage("org.apache.spark.sql.parquet"), + excludePackage("org.apache.spark.sql.json"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCRDD$DecimalConversion$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCPartition"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JdbcUtils$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCRDD$DecimalConversion"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCPartitioningInfo$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCPartition$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.package"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCRDD$JDBCConversion"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCRDD$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.package$DriverWrapper"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCRDD"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCPartitioningInfo"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JdbcUtils"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.DefaultSource"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCRelation$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.package$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCRelation") ) ++ Seq( // SPARK-4751 Dynamic allocation for standalone mode ProblemFilters.exclude[MissingMethodProblem]( diff --git a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister index cc32d4b72748..ca50000b4756 100644 --- a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister +++ b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -1,3 +1,3 @@ -org.apache.spark.sql.jdbc.DefaultSource -org.apache.spark.sql.json.DefaultSource -org.apache.spark.sql.parquet.DefaultSource +org.apache.spark.sql.execution.datasources.jdbc.DefaultSource +org.apache.spark.sql.execution.datasources.json.DefaultSource +org.apache.spark.sql.execution.datasources.parquet.DefaultSource diff --git a/sql/core/src/main/resources/org/apache/spark/sql/ui/static/spark-sql-viz.css 
b/sql/core/src/main/resources/org/apache/spark/sql/execution/ui/static/spark-sql-viz.css similarity index 100% rename from sql/core/src/main/resources/org/apache/spark/sql/ui/static/spark-sql-viz.css rename to sql/core/src/main/resources/org/apache/spark/sql/execution/ui/static/spark-sql-viz.css diff --git a/sql/core/src/main/resources/org/apache/spark/sql/ui/static/spark-sql-viz.js b/sql/core/src/main/resources/org/apache/spark/sql/execution/ui/static/spark-sql-viz.js similarity index 100% rename from sql/core/src/main/resources/org/apache/spark/sql/ui/static/spark-sql-viz.js rename to sql/core/src/main/resources/org/apache/spark/sql/execution/ui/static/spark-sql-viz.js diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 570b8b2d5928..27b994f1f0ca 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -39,7 +39,7 @@ import org.apache.spark.sql.catalyst.plans.{Inner, JoinType} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection, SqlParser} import org.apache.spark.sql.execution.{EvaluatePython, ExplainCommand, LogicalRDD, SQLExecution} import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, LogicalRelation} -import org.apache.spark.sql.json.JacksonGenerator +import org.apache.spark.sql.execution.datasources.json.JacksonGenerator import org.apache.spark.sql.sources.HadoopFsRelation import org.apache.spark.sql.types._ import org.apache.spark.storage.StorageLevel diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 85f33c5e9952..9ea955b01001 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -25,10 +25,10 @@ import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaRDD import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.rdd.RDD +import org.apache.spark.sql.execution.datasources.jdbc.{JDBCPartition, JDBCPartitioningInfo, JDBCRelation} +import org.apache.spark.sql.execution.datasources.json.JSONRelation +import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation import org.apache.spark.sql.execution.datasources.{LogicalRelation, ResolvedDataSource} -import org.apache.spark.sql.jdbc.{JDBCPartition, JDBCPartitioningInfo, JDBCRelation} -import org.apache.spark.sql.json.JSONRelation -import org.apache.spark.sql.parquet.ParquetRelation import org.apache.spark.sql.types.StructType import org.apache.spark.{Logging, Partition} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 2a4992db09bc..5fa11da4c38c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -23,8 +23,8 @@ import org.apache.spark.annotation.Experimental import org.apache.spark.sql.catalyst.{SqlParser, TableIdentifier} import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.catalyst.plans.logical.InsertIntoTable +import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, ResolvedDataSource} -import 
org.apache.spark.sql.jdbc.{JDBCWriteDetails, JdbcUtils} import org.apache.spark.sql.sources.HadoopFsRelation @@ -264,7 +264,7 @@ final class DataFrameWriter private[sql](df: DataFrame) { // Create the table if the table didn't exist. if (!tableExists) { - val schema = JDBCWriteDetails.schemaString(df, url) + val schema = JdbcUtils.schemaString(df, url) val sql = s"CREATE TABLE $table ($schema)" conn.prepareStatement(sql).executeUpdate() } @@ -272,7 +272,7 @@ final class DataFrameWriter private[sql](df: DataFrame) { conn.close() } - JDBCWriteDetails.saveTable(df, url, table, connectionProperties) + JdbcUtils.saveTable(df, url, table, connectionProperties) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 832572571cab..f73bb0488c98 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -43,7 +43,7 @@ import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.sources.BaseRelation import org.apache.spark.sql.types._ -import org.apache.spark.sql.ui.{SQLListener, SQLTab} +import org.apache.spark.sql.execution.ui.{SQLListener, SQLTab} import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala index 97f1323e9783..cee58218a885 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala @@ -21,7 +21,7 @@ import java.util.concurrent.atomic.AtomicLong import org.apache.spark.SparkContext import org.apache.spark.sql.SQLContext -import org.apache.spark.sql.ui.SparkPlanGraph +import org.apache.spark.sql.execution.ui.SparkPlanGraph import org.apache.spark.util.Utils private[sql] object SQLExecution { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 1915496d1620..9ba5cf2d2b39 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.metric.{IntSQLMetric, LongSQLMetric, SQLMetric, SQLMetrics} +import org.apache.spark.sql.execution.metric.{LongSQLMetric, SQLMetric, SQLMetrics} import org.apache.spark.sql.types.DataType object SparkPlan { @@ -98,12 +98,6 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ */ private[sql] def metrics: Map[String, SQLMetric[_, _]] = defaultMetrics - /** - * Return a IntSQLMetric according to the name. - */ - private[sql] def intMetric(name: String): IntSQLMetric = - metrics(name).asInstanceOf[IntSQLMetric] - /** * Return a LongSQLMetric according to the name. 
*/ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 24950f26061f..bf2de244c8e4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.metric.SQLMetrics +import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.types.StructType import org.apache.spark.util.collection.ExternalSorter import org.apache.spark.util.collection.unsafe.sort.PrefixComparator diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DDLParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DDLParser.scala new file mode 100644 index 000000000000..6c462fa30461 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DDLParser.scala @@ -0,0 +1,185 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.execution.datasources + +import scala.language.implicitConversions +import scala.util.matching.Regex + +import org.apache.spark.Logging +import org.apache.spark.sql.SaveMode +import org.apache.spark.sql.catalyst.{TableIdentifier, AbstractSparkSQLParser} +import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.types._ + + +/** + * A parser for foreign DDL commands. 
+ */ +class DDLParser(parseQuery: String => LogicalPlan) + extends AbstractSparkSQLParser with DataTypeParser with Logging { + + def parse(input: String, exceptionOnError: Boolean): LogicalPlan = { + try { + parse(input) + } catch { + case ddlException: DDLException => throw ddlException + case _ if !exceptionOnError => parseQuery(input) + case x: Throwable => throw x + } + } + + // Keyword is a convention with AbstractSparkSQLParser, which will scan all of the `Keyword` + // properties via reflection the class in runtime for constructing the SqlLexical object + protected val CREATE = Keyword("CREATE") + protected val TEMPORARY = Keyword("TEMPORARY") + protected val TABLE = Keyword("TABLE") + protected val IF = Keyword("IF") + protected val NOT = Keyword("NOT") + protected val EXISTS = Keyword("EXISTS") + protected val USING = Keyword("USING") + protected val OPTIONS = Keyword("OPTIONS") + protected val DESCRIBE = Keyword("DESCRIBE") + protected val EXTENDED = Keyword("EXTENDED") + protected val AS = Keyword("AS") + protected val COMMENT = Keyword("COMMENT") + protected val REFRESH = Keyword("REFRESH") + + protected lazy val ddl: Parser[LogicalPlan] = createTable | describeTable | refreshTable + + protected def start: Parser[LogicalPlan] = ddl + + /** + * `CREATE [TEMPORARY] TABLE avroTable [IF NOT EXISTS] + * USING org.apache.spark.sql.avro + * OPTIONS (path "../hive/src/test/resources/data/files/episodes.avro")` + * or + * `CREATE [TEMPORARY] TABLE avroTable(intField int, stringField string...) [IF NOT EXISTS] + * USING org.apache.spark.sql.avro + * OPTIONS (path "../hive/src/test/resources/data/files/episodes.avro")` + * or + * `CREATE [TEMPORARY] TABLE avroTable [IF NOT EXISTS] + * USING org.apache.spark.sql.avro + * OPTIONS (path "../hive/src/test/resources/data/files/episodes.avro")` + * AS SELECT ... + */ + protected lazy val createTable: Parser[LogicalPlan] = { + // TODO: Support database.table. + (CREATE ~> TEMPORARY.? <~ TABLE) ~ (IF ~> NOT <~ EXISTS).? ~ ident ~ + tableCols.? ~ (USING ~> className) ~ (OPTIONS ~> options).? ~ (AS ~> restInput).? ^^ { + case temp ~ allowExisting ~ tableName ~ columns ~ provider ~ opts ~ query => + if (temp.isDefined && allowExisting.isDefined) { + throw new DDLException( + "a CREATE TEMPORARY TABLE statement does not allow IF NOT EXISTS clause.") + } + + val options = opts.getOrElse(Map.empty[String, String]) + if (query.isDefined) { + if (columns.isDefined) { + throw new DDLException( + "a CREATE TABLE AS SELECT statement does not allow column definitions.") + } + // When IF NOT EXISTS clause appears in the query, the save mode will be ignore. 
+ val mode = if (allowExisting.isDefined) { + SaveMode.Ignore + } else if (temp.isDefined) { + SaveMode.Overwrite + } else { + SaveMode.ErrorIfExists + } + + val queryPlan = parseQuery(query.get) + CreateTableUsingAsSelect(tableName, + provider, + temp.isDefined, + Array.empty[String], + mode, + options, + queryPlan) + } else { + val userSpecifiedSchema = columns.flatMap(fields => Some(StructType(fields))) + CreateTableUsing( + tableName, + userSpecifiedSchema, + provider, + temp.isDefined, + options, + allowExisting.isDefined, + managedIfNoPath = false) + } + } + } + + protected lazy val tableCols: Parser[Seq[StructField]] = "(" ~> repsep(column, ",") <~ ")" + + /* + * describe [extended] table avroTable + * This will display all columns of table `avroTable` includes column_name,column_type,comment + */ + protected lazy val describeTable: Parser[LogicalPlan] = + (DESCRIBE ~> opt(EXTENDED)) ~ (ident <~ ".").? ~ ident ^^ { + case e ~ db ~ tbl => + val tblIdentifier = db match { + case Some(dbName) => + Seq(dbName, tbl) + case None => + Seq(tbl) + } + DescribeCommand(UnresolvedRelation(tblIdentifier, None), e.isDefined) + } + + protected lazy val refreshTable: Parser[LogicalPlan] = + REFRESH ~> TABLE ~> (ident <~ ".").? ~ ident ^^ { + case maybeDatabaseName ~ tableName => + RefreshTable(TableIdentifier(tableName, maybeDatabaseName)) + } + + protected lazy val options: Parser[Map[String, String]] = + "(" ~> repsep(pair, ",") <~ ")" ^^ { case s: Seq[(String, String)] => s.toMap } + + protected lazy val className: Parser[String] = repsep(ident, ".") ^^ { case s => s.mkString(".")} + + override implicit def regexToParser(regex: Regex): Parser[String] = acceptMatch( + s"identifier matching regex $regex", { + case lexical.Identifier(str) if regex.unapplySeq(str).isDefined => str + case lexical.Keyword(str) if regex.unapplySeq(str).isDefined => str + } + ) + + protected lazy val optionPart: Parser[String] = "[_a-zA-Z][_a-zA-Z0-9]*".r ^^ { + case name => name + } + + protected lazy val optionName: Parser[String] = repsep(optionPart, ".") ^^ { + case parts => parts.mkString(".") + } + + protected lazy val pair: Parser[(String, String)] = + optionName ~ stringLit ^^ { case k ~ v => (k, v) } + + protected lazy val column: Parser[StructField] = + ident ~ dataType ~ (COMMENT ~> stringLit).? ^^ { case columnName ~ typ ~ cm => + val meta = cm match { + case Some(comment) => + new MetadataBuilder().putString(COMMENT.str.toLowerCase, comment).build() + case None => Metadata.empty + } + + StructField(columnName, typ, nullable = true, meta) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DefaultSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DefaultSource.scala new file mode 100644 index 000000000000..6e4cc4de7f65 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DefaultSource.scala @@ -0,0 +1,64 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. 
You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.execution.datasources + +import java.util.Properties + +import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.execution.datasources.jdbc.{JDBCRelation, JDBCPartitioningInfo, DriverRegistry} +import org.apache.spark.sql.sources.{BaseRelation, DataSourceRegister, RelationProvider} + + +class DefaultSource extends RelationProvider with DataSourceRegister { + + override def shortName(): String = "jdbc" + + /** Returns a new base relation with the given parameters. */ + override def createRelation( + sqlContext: SQLContext, + parameters: Map[String, String]): BaseRelation = { + val url = parameters.getOrElse("url", sys.error("Option 'url' not specified")) + val driver = parameters.getOrElse("driver", null) + val table = parameters.getOrElse("dbtable", sys.error("Option 'dbtable' not specified")) + val partitionColumn = parameters.getOrElse("partitionColumn", null) + val lowerBound = parameters.getOrElse("lowerBound", null) + val upperBound = parameters.getOrElse("upperBound", null) + val numPartitions = parameters.getOrElse("numPartitions", null) + + if (driver != null) DriverRegistry.register(driver) + + if (partitionColumn != null + && (lowerBound == null || upperBound == null || numPartitions == null)) { + sys.error("Partitioning incompletely specified") + } + + val partitionInfo = if (partitionColumn == null) { + null + } else { + JDBCPartitioningInfo( + partitionColumn, + lowerBound.toLong, + upperBound.toLong, + numPartitions.toInt) + } + val parts = JDBCRelation.columnPartition(partitionInfo) + val properties = new Properties() // Additional properties that we will pass to getConnection + parameters.foreach(kv => properties.setProperty(kv._1, kv._2)) + JDBCRelation(url, table, parts, properties)(sqlContext) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSource.scala index 6ccde7693bd3..3b7dc2e8d021 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSource.scala @@ -17,27 +17,10 @@ package org.apache.spark.sql.execution.datasources -import java.io.IOException -import java.util.{Date, UUID} - -import scala.collection.JavaConversions.asScalaIterator - -import org.apache.hadoop.fs.Path -import org.apache.hadoop.mapreduce._ -import org.apache.hadoop.mapreduce.lib.output.{FileOutputCommitter => MapReduceFileOutputCommitter, FileOutputFormat} -import org.apache.spark._ -import org.apache.spark.mapred.SparkHadoopMapRedUtil -import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.GenerateProjection -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project} -import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} 
-import org.apache.spark.sql.execution.{RunnableCommand, SQLExecution} -import org.apache.spark.sql.sources._ -import org.apache.spark.sql.types.StringType -import org.apache.spark.util.{Utils, SerializableConfiguration} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.execution.RunnableCommand +import org.apache.spark.sql.sources.InsertableRelation /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala new file mode 100644 index 000000000000..7770bbd712f0 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala @@ -0,0 +1,204 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.execution.datasources + +import java.util.ServiceLoader + +import scala.collection.JavaConversions._ +import scala.language.{existentials, implicitConversions} +import scala.util.{Success, Failure, Try} + +import org.apache.hadoop.fs.Path + +import org.apache.spark.Logging +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.sql.{DataFrame, SaveMode, AnalysisException, SQLContext} +import org.apache.spark.sql.sources._ +import org.apache.spark.sql.types.{CalendarIntervalType, StructType} +import org.apache.spark.util.Utils + + +case class ResolvedDataSource(provider: Class[_], relation: BaseRelation) + + +object ResolvedDataSource extends Logging { + + /** A map to maintain backward compatibility in case we move data sources around. */ + private val backwardCompatibilityMap = Map( + "org.apache.spark.sql.jdbc" -> classOf[jdbc.DefaultSource].getCanonicalName, + "org.apache.spark.sql.jdbc.DefaultSource" -> classOf[jdbc.DefaultSource].getCanonicalName, + "org.apache.spark.sql.json" -> classOf[json.DefaultSource].getCanonicalName, + "org.apache.spark.sql.json.DefaultSource" -> classOf[json.DefaultSource].getCanonicalName, + "org.apache.spark.sql.parquet" -> classOf[parquet.DefaultSource].getCanonicalName, + "org.apache.spark.sql.parquet.DefaultSource" -> classOf[parquet.DefaultSource].getCanonicalName + ) + + /** Given a provider name, look up the data source class definition. 
*/ + def lookupDataSource(provider0: String): Class[_] = { + val provider = backwardCompatibilityMap.getOrElse(provider0, provider0) + val provider2 = s"$provider.DefaultSource" + val loader = Utils.getContextOrSparkClassLoader + val serviceLoader = ServiceLoader.load(classOf[DataSourceRegister], loader) + + serviceLoader.iterator().filter(_.shortName().equalsIgnoreCase(provider)).toList match { + /** the provider format did not match any given registered aliases */ + case Nil => Try(loader.loadClass(provider)).orElse(Try(loader.loadClass(provider2))) match { + case Success(dataSource) => dataSource + case Failure(error) => + if (provider.startsWith("org.apache.spark.sql.hive.orc")) { + throw new ClassNotFoundException( + "The ORC data source must be used with Hive support enabled.", error) + } else { + throw new ClassNotFoundException( + s"Failed to load class for data source: $provider.", error) + } + } + /** there is exactly one registered alias */ + case head :: Nil => head.getClass + /** There are multiple registered aliases for the input */ + case sources => sys.error(s"Multiple sources found for $provider, " + + s"(${sources.map(_.getClass.getName).mkString(", ")}), " + + "please specify the fully qualified class name.") + } + } + + /** Create a [[ResolvedDataSource]] for reading data in. */ + def apply( + sqlContext: SQLContext, + userSpecifiedSchema: Option[StructType], + partitionColumns: Array[String], + provider: String, + options: Map[String, String]): ResolvedDataSource = { + val clazz: Class[_] = lookupDataSource(provider) + def className: String = clazz.getCanonicalName + val relation = userSpecifiedSchema match { + case Some(schema: StructType) => clazz.newInstance() match { + case dataSource: SchemaRelationProvider => + dataSource.createRelation(sqlContext, new CaseInsensitiveMap(options), schema) + case dataSource: HadoopFsRelationProvider => + val maybePartitionsSchema = if (partitionColumns.isEmpty) { + None + } else { + Some(partitionColumnsSchema(schema, partitionColumns)) + } + + val caseInsensitiveOptions = new CaseInsensitiveMap(options) + val paths = { + val patternPath = new Path(caseInsensitiveOptions("path")) + val fs = patternPath.getFileSystem(sqlContext.sparkContext.hadoopConfiguration) + val qualifiedPattern = patternPath.makeQualified(fs.getUri, fs.getWorkingDirectory) + SparkHadoopUtil.get.globPathIfNecessary(qualifiedPattern).map(_.toString).toArray + } + + val dataSchema = + StructType(schema.filterNot(f => partitionColumns.contains(f.name))).asNullable + + dataSource.createRelation( + sqlContext, + paths, + Some(dataSchema), + maybePartitionsSchema, + caseInsensitiveOptions) + case dataSource: org.apache.spark.sql.sources.RelationProvider => + throw new AnalysisException(s"$className does not allow user-specified schemas.") + case _ => + throw new AnalysisException(s"$className is not a RelationProvider.") + } + + case None => clazz.newInstance() match { + case dataSource: RelationProvider => + dataSource.createRelation(sqlContext, new CaseInsensitiveMap(options)) + case dataSource: HadoopFsRelationProvider => + val caseInsensitiveOptions = new CaseInsensitiveMap(options) + val paths = { + val patternPath = new Path(caseInsensitiveOptions("path")) + val fs = patternPath.getFileSystem(sqlContext.sparkContext.hadoopConfiguration) + val qualifiedPattern = patternPath.makeQualified(fs.getUri, fs.getWorkingDirectory) + SparkHadoopUtil.get.globPathIfNecessary(qualifiedPattern).map(_.toString).toArray + } + dataSource.createRelation(sqlContext, paths, None, 
None, caseInsensitiveOptions) + case dataSource: org.apache.spark.sql.sources.SchemaRelationProvider => + throw new AnalysisException( + s"A schema needs to be specified when using $className.") + case _ => + throw new AnalysisException( + s"$className is neither a RelationProvider nor a FSBasedRelationProvider.") + } + } + new ResolvedDataSource(clazz, relation) + } + + private def partitionColumnsSchema( + schema: StructType, + partitionColumns: Array[String]): StructType = { + StructType(partitionColumns.map { col => + schema.find(_.name == col).getOrElse { + throw new RuntimeException(s"Partition column $col not found in schema $schema") + } + }).asNullable + } + + /** Create a [[ResolvedDataSource]] for saving the content of the given DataFrame. */ + def apply( + sqlContext: SQLContext, + provider: String, + partitionColumns: Array[String], + mode: SaveMode, + options: Map[String, String], + data: DataFrame): ResolvedDataSource = { + if (data.schema.map(_.dataType).exists(_.isInstanceOf[CalendarIntervalType])) { + throw new AnalysisException("Cannot save interval data type into external storage.") + } + val clazz: Class[_] = lookupDataSource(provider) + val relation = clazz.newInstance() match { + case dataSource: CreatableRelationProvider => + dataSource.createRelation(sqlContext, mode, options, data) + case dataSource: HadoopFsRelationProvider => + // Don't glob path for the write path. The contracts here are: + // 1. Only one output path can be specified on the write path; + // 2. Output path must be a legal HDFS style file system path; + // 3. It's OK that the output path doesn't exist yet; + val caseInsensitiveOptions = new CaseInsensitiveMap(options) + val outputPath = { + val path = new Path(caseInsensitiveOptions("path")) + val fs = path.getFileSystem(sqlContext.sparkContext.hadoopConfiguration) + path.makeQualified(fs.getUri, fs.getWorkingDirectory) + } + val dataSchema = StructType(data.schema.filterNot(f => partitionColumns.contains(f.name))) + val r = dataSource.createRelation( + sqlContext, + Array(outputPath.toString), + Some(dataSchema.asNullable), + Some(partitionColumnsSchema(data.schema, partitionColumns)), + caseInsensitiveOptions) + + // For partitioned relation r, r.schema's column ordering can be different from the column + // ordering of data.logicalPlan (partition columns are all moved after data column). This + // will be adjusted within InsertIntoHadoopFsRelation. 
+ sqlContext.executePlan( + InsertIntoHadoopFsRelation( + r, + data.logicalPlan, + mode)).toRdd + r + case _ => + sys.error(s"${clazz.getCanonicalName} does not allow create table as select.") + } + ResolvedDataSource(clazz, relation) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala index 8c2f297e4245..ecd304c30cde 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala @@ -17,340 +17,12 @@ package org.apache.spark.sql.execution.datasources -import java.util.ServiceLoader - -import scala.collection.Iterator -import scala.collection.JavaConversions._ -import scala.language.{existentials, implicitConversions} -import scala.util.{Failure, Success, Try} -import scala.util.matching.Regex - -import org.apache.hadoop.fs.Path - -import org.apache.spark.Logging -import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation +import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.{AbstractSparkSQLParser, TableIdentifier} import org.apache.spark.sql.execution.RunnableCommand -import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ -import org.apache.spark.sql.{AnalysisException, DataFrame, Row, SQLContext, SaveMode} -import org.apache.spark.util.Utils - -/** - * A parser for foreign DDL commands. - */ -private[sql] class DDLParser( - parseQuery: String => LogicalPlan) - extends AbstractSparkSQLParser with DataTypeParser with Logging { - - def parse(input: String, exceptionOnError: Boolean): LogicalPlan = { - try { - parse(input) - } catch { - case ddlException: DDLException => throw ddlException - case _ if !exceptionOnError => parseQuery(input) - case x: Throwable => throw x - } - } - - // Keyword is a convention with AbstractSparkSQLParser, which will scan all of the `Keyword` - // properties via reflection the class in runtime for constructing the SqlLexical object - protected val CREATE = Keyword("CREATE") - protected val TEMPORARY = Keyword("TEMPORARY") - protected val TABLE = Keyword("TABLE") - protected val IF = Keyword("IF") - protected val NOT = Keyword("NOT") - protected val EXISTS = Keyword("EXISTS") - protected val USING = Keyword("USING") - protected val OPTIONS = Keyword("OPTIONS") - protected val DESCRIBE = Keyword("DESCRIBE") - protected val EXTENDED = Keyword("EXTENDED") - protected val AS = Keyword("AS") - protected val COMMENT = Keyword("COMMENT") - protected val REFRESH = Keyword("REFRESH") - - protected lazy val ddl: Parser[LogicalPlan] = createTable | describeTable | refreshTable - - protected def start: Parser[LogicalPlan] = ddl - - /** - * `CREATE [TEMPORARY] TABLE avroTable [IF NOT EXISTS] - * USING org.apache.spark.sql.avro - * OPTIONS (path "../hive/src/test/resources/data/files/episodes.avro")` - * or - * `CREATE [TEMPORARY] TABLE avroTable(intField int, stringField string...) [IF NOT EXISTS] - * USING org.apache.spark.sql.avro - * OPTIONS (path "../hive/src/test/resources/data/files/episodes.avro")` - * or - * `CREATE [TEMPORARY] TABLE avroTable [IF NOT EXISTS] - * USING org.apache.spark.sql.avro - * OPTIONS (path "../hive/src/test/resources/data/files/episodes.avro")` - * AS SELECT ... 
- */ - protected lazy val createTable: Parser[LogicalPlan] = - // TODO: Support database.table. - (CREATE ~> TEMPORARY.? <~ TABLE) ~ (IF ~> NOT <~ EXISTS).? ~ ident ~ - tableCols.? ~ (USING ~> className) ~ (OPTIONS ~> options).? ~ (AS ~> restInput).? ^^ { - case temp ~ allowExisting ~ tableName ~ columns ~ provider ~ opts ~ query => - if (temp.isDefined && allowExisting.isDefined) { - throw new DDLException( - "a CREATE TEMPORARY TABLE statement does not allow IF NOT EXISTS clause.") - } - - val options = opts.getOrElse(Map.empty[String, String]) - if (query.isDefined) { - if (columns.isDefined) { - throw new DDLException( - "a CREATE TABLE AS SELECT statement does not allow column definitions.") - } - // When IF NOT EXISTS clause appears in the query, the save mode will be ignore. - val mode = if (allowExisting.isDefined) { - SaveMode.Ignore - } else if (temp.isDefined) { - SaveMode.Overwrite - } else { - SaveMode.ErrorIfExists - } - - val queryPlan = parseQuery(query.get) - CreateTableUsingAsSelect(tableName, - provider, - temp.isDefined, - Array.empty[String], - mode, - options, - queryPlan) - } else { - val userSpecifiedSchema = columns.flatMap(fields => Some(StructType(fields))) - CreateTableUsing( - tableName, - userSpecifiedSchema, - provider, - temp.isDefined, - options, - allowExisting.isDefined, - managedIfNoPath = false) - } - } - - protected lazy val tableCols: Parser[Seq[StructField]] = "(" ~> repsep(column, ",") <~ ")" - - /* - * describe [extended] table avroTable - * This will display all columns of table `avroTable` includes column_name,column_type,comment - */ - protected lazy val describeTable: Parser[LogicalPlan] = - (DESCRIBE ~> opt(EXTENDED)) ~ (ident <~ ".").? ~ ident ^^ { - case e ~ db ~ tbl => - val tblIdentifier = db match { - case Some(dbName) => - Seq(dbName, tbl) - case None => - Seq(tbl) - } - DescribeCommand(UnresolvedRelation(tblIdentifier, None), e.isDefined) - } - - protected lazy val refreshTable: Parser[LogicalPlan] = - REFRESH ~> TABLE ~> (ident <~ ".").? ~ ident ^^ { - case maybeDatabaseName ~ tableName => - RefreshTable(TableIdentifier(tableName, maybeDatabaseName)) - } - - protected lazy val options: Parser[Map[String, String]] = - "(" ~> repsep(pair, ",") <~ ")" ^^ { case s: Seq[(String, String)] => s.toMap } - - protected lazy val className: Parser[String] = repsep(ident, ".") ^^ { case s => s.mkString(".")} - - override implicit def regexToParser(regex: Regex): Parser[String] = acceptMatch( - s"identifier matching regex $regex", { - case lexical.Identifier(str) if regex.unapplySeq(str).isDefined => str - case lexical.Keyword(str) if regex.unapplySeq(str).isDefined => str - } - ) - - protected lazy val optionPart: Parser[String] = "[_a-zA-Z][_a-zA-Z0-9]*".r ^^ { - case name => name - } - - protected lazy val optionName: Parser[String] = repsep(optionPart, ".") ^^ { - case parts => parts.mkString(".") - } - - protected lazy val pair: Parser[(String, String)] = - optionName ~ stringLit ^^ { case k ~ v => (k, v) } - - protected lazy val column: Parser[StructField] = - ident ~ dataType ~ (COMMENT ~> stringLit).? ^^ { case columnName ~ typ ~ cm => - val meta = cm match { - case Some(comment) => - new MetadataBuilder().putString(COMMENT.str.toLowerCase, comment).build() - case None => Metadata.empty - } - - StructField(columnName, typ, nullable = true, meta) - } -} - -private[sql] object ResolvedDataSource extends Logging { - - /** Given a provider name, look up the data source class definition. 
*/ - def lookupDataSource(provider: String): Class[_] = { - val provider2 = s"$provider.DefaultSource" - val loader = Utils.getContextOrSparkClassLoader - val serviceLoader = ServiceLoader.load(classOf[DataSourceRegister], loader) - - serviceLoader.iterator().filter(_.format().equalsIgnoreCase(provider)).toList match { - /** the provider format did not match any given registered aliases */ - case Nil => Try(loader.loadClass(provider)).orElse(Try(loader.loadClass(provider2))) match { - case Success(dataSource) => dataSource - case Failure(error) => if (provider.startsWith("org.apache.spark.sql.hive.orc")) { - throw new ClassNotFoundException( - "The ORC data source must be used with Hive support enabled.", error) - } else { - throw new ClassNotFoundException( - s"Failed to load class for data source: $provider", error) - } - } - /** there is exactly one registered alias */ - case head :: Nil => head.getClass - /** There are multiple registered aliases for the input */ - case sources => sys.error(s"Multiple sources found for $provider, " + - s"(${sources.map(_.getClass.getName).mkString(", ")}), " + - "please specify the fully qualified class name") - } - } - - /** Create a [[ResolvedDataSource]] for reading data in. */ - def apply( - sqlContext: SQLContext, - userSpecifiedSchema: Option[StructType], - partitionColumns: Array[String], - provider: String, - options: Map[String, String]): ResolvedDataSource = { - val clazz: Class[_] = lookupDataSource(provider) - def className: String = clazz.getCanonicalName - val relation = userSpecifiedSchema match { - case Some(schema: StructType) => clazz.newInstance() match { - case dataSource: SchemaRelationProvider => - dataSource.createRelation(sqlContext, new CaseInsensitiveMap(options), schema) - case dataSource: HadoopFsRelationProvider => - val maybePartitionsSchema = if (partitionColumns.isEmpty) { - None - } else { - Some(partitionColumnsSchema(schema, partitionColumns)) - } - - val caseInsensitiveOptions = new CaseInsensitiveMap(options) - val paths = { - val patternPath = new Path(caseInsensitiveOptions("path")) - val fs = patternPath.getFileSystem(sqlContext.sparkContext.hadoopConfiguration) - val qualifiedPattern = patternPath.makeQualified(fs.getUri, fs.getWorkingDirectory) - SparkHadoopUtil.get.globPathIfNecessary(qualifiedPattern).map(_.toString).toArray - } - - val dataSchema = - StructType(schema.filterNot(f => partitionColumns.contains(f.name))).asNullable - - dataSource.createRelation( - sqlContext, - paths, - Some(dataSchema), - maybePartitionsSchema, - caseInsensitiveOptions) - case dataSource: org.apache.spark.sql.sources.RelationProvider => - throw new AnalysisException(s"$className does not allow user-specified schemas.") - case _ => - throw new AnalysisException(s"$className is not a RelationProvider.") - } - - case None => clazz.newInstance() match { - case dataSource: RelationProvider => - dataSource.createRelation(sqlContext, new CaseInsensitiveMap(options)) - case dataSource: HadoopFsRelationProvider => - val caseInsensitiveOptions = new CaseInsensitiveMap(options) - val paths = { - val patternPath = new Path(caseInsensitiveOptions("path")) - val fs = patternPath.getFileSystem(sqlContext.sparkContext.hadoopConfiguration) - val qualifiedPattern = patternPath.makeQualified(fs.getUri, fs.getWorkingDirectory) - SparkHadoopUtil.get.globPathIfNecessary(qualifiedPattern).map(_.toString).toArray - } - dataSource.createRelation(sqlContext, paths, None, None, caseInsensitiveOptions) - case dataSource: 
org.apache.spark.sql.sources.SchemaRelationProvider => - throw new AnalysisException( - s"A schema needs to be specified when using $className.") - case _ => - throw new AnalysisException( - s"$className is neither a RelationProvider nor a FSBasedRelationProvider.") - } - } - new ResolvedDataSource(clazz, relation) - } - - private def partitionColumnsSchema( - schema: StructType, - partitionColumns: Array[String]): StructType = { - StructType(partitionColumns.map { col => - schema.find(_.name == col).getOrElse { - throw new RuntimeException(s"Partition column $col not found in schema $schema") - } - }).asNullable - } - - /** Create a [[ResolvedDataSource]] for saving the content of the given [[DataFrame]]. */ - def apply( - sqlContext: SQLContext, - provider: String, - partitionColumns: Array[String], - mode: SaveMode, - options: Map[String, String], - data: DataFrame): ResolvedDataSource = { - if (data.schema.map(_.dataType).exists(_.isInstanceOf[CalendarIntervalType])) { - throw new AnalysisException("Cannot save interval data type into external storage.") - } - val clazz: Class[_] = lookupDataSource(provider) - val relation = clazz.newInstance() match { - case dataSource: CreatableRelationProvider => - dataSource.createRelation(sqlContext, mode, options, data) - case dataSource: HadoopFsRelationProvider => - // Don't glob path for the write path. The contracts here are: - // 1. Only one output path can be specified on the write path; - // 2. Output path must be a legal HDFS style file system path; - // 3. It's OK that the output path doesn't exist yet; - val caseInsensitiveOptions = new CaseInsensitiveMap(options) - val outputPath = { - val path = new Path(caseInsensitiveOptions("path")) - val fs = path.getFileSystem(sqlContext.sparkContext.hadoopConfiguration) - path.makeQualified(fs.getUri, fs.getWorkingDirectory) - } - val dataSchema = StructType(data.schema.filterNot(f => partitionColumns.contains(f.name))) - val r = dataSource.createRelation( - sqlContext, - Array(outputPath.toString), - Some(dataSchema.asNullable), - Some(partitionColumnsSchema(data.schema, partitionColumns)), - caseInsensitiveOptions) - - // For partitioned relation r, r.schema's column ordering can be different from the column - // ordering of data.logicalPlan (partition columns are all moved after data column). This - // will be adjusted within InsertIntoHadoopFsRelation. - sqlContext.executePlan( - InsertIntoHadoopFsRelation( - r, - data.logicalPlan, - mode)).toRdd - r - case _ => - sys.error(s"${clazz.getCanonicalName} does not allow create table as select.") - } - new ResolvedDataSource(clazz, relation) - } -} - -private[sql] case class ResolvedDataSource(provider: Class[_], relation: BaseRelation) +import org.apache.spark.sql.{DataFrame, Row, SQLContext, SaveMode} /** * Returned for the "DESCRIBE [EXTENDED] [dbName.]tableName" command. @@ -358,11 +30,12 @@ private[sql] case class ResolvedDataSource(provider: Class[_], relation: BaseRel * @param isExtended True if "DESCRIBE EXTENDED" is used. Otherwise, false. * It is effective only when the table is a Hive table. */ -private[sql] case class DescribeCommand( +case class DescribeCommand( table: LogicalPlan, isExtended: Boolean) extends LogicalPlan with Command { override def children: Seq[LogicalPlan] = Seq.empty + override val output: Seq[Attribute] = Seq( // Column names are based on Hive. 
AttributeReference("col_name", StringType, nullable = false, @@ -370,7 +43,8 @@ private[sql] case class DescribeCommand( AttributeReference("data_type", StringType, nullable = false, new MetadataBuilder().putString("comment", "data type of the column").build())(), AttributeReference("comment", StringType, nullable = false, - new MetadataBuilder().putString("comment", "comment of the column").build())()) + new MetadataBuilder().putString("comment", "comment of the column").build())() + ) } /** @@ -378,7 +52,7 @@ private[sql] case class DescribeCommand( * @param allowExisting If it is true, we will do nothing when the table already exists. * If it is false, an exception will be thrown */ -private[sql] case class CreateTableUsing( +case class CreateTableUsing( tableName: String, userSpecifiedSchema: Option[StructType], provider: String, @@ -397,7 +71,7 @@ private[sql] case class CreateTableUsing( * can analyze the logical plan that will be used to populate the table. * So, [[PreWriteCheck]] can detect cases that are not allowed. */ -private[sql] case class CreateTableUsingAsSelect( +case class CreateTableUsingAsSelect( tableName: String, provider: String, temporary: Boolean, @@ -410,7 +84,7 @@ private[sql] case class CreateTableUsingAsSelect( // override lazy val resolved = databaseName != None && childrenResolved } -private[sql] case class CreateTempTableUsing( +case class CreateTempTableUsing( tableName: String, userSpecifiedSchema: Option[StructType], provider: String, @@ -425,7 +99,7 @@ private[sql] case class CreateTempTableUsing( } } -private[sql] case class CreateTempTableUsingAsSelect( +case class CreateTempTableUsingAsSelect( tableName: String, provider: String, partitionColumns: Array[String], @@ -443,7 +117,7 @@ private[sql] case class CreateTempTableUsingAsSelect( } } -private[sql] case class RefreshTable(tableIdent: TableIdentifier) +case class RefreshTable(tableIdent: TableIdentifier) extends RunnableCommand { override def run(sqlContext: SQLContext): Seq[Row] = { @@ -472,7 +146,7 @@ private[sql] case class RefreshTable(tableIdent: TableIdentifier) /** * Builds a map in which keys are case insensitive */ -protected[sql] class CaseInsensitiveMap(map: Map[String, String]) extends Map[String, String] +class CaseInsensitiveMap(map: Map[String, String]) extends Map[String, String] with Serializable { val baseMap = map.map(kv => kv.copy(_1 = kv._1.toLowerCase)) @@ -490,4 +164,4 @@ protected[sql] class CaseInsensitiveMap(map: Map[String, String]) extends Map[St /** * The exception thrown from the DDL parser. */ -protected[sql] class DDLException(message: String) extends Exception(message) +class DDLException(message: String) extends RuntimeException(message) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DefaultSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DefaultSource.scala new file mode 100644 index 000000000000..6773afc794f9 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DefaultSource.scala @@ -0,0 +1,62 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. 
You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.execution.datasources.jdbc + +import java.util.Properties + +import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.sources.{BaseRelation, RelationProvider, DataSourceRegister} + +class DefaultSource extends RelationProvider with DataSourceRegister { + + override def shortName(): String = "jdbc" + + /** Returns a new base relation with the given parameters. */ + override def createRelation( + sqlContext: SQLContext, + parameters: Map[String, String]): BaseRelation = { + val url = parameters.getOrElse("url", sys.error("Option 'url' not specified")) + val driver = parameters.getOrElse("driver", null) + val table = parameters.getOrElse("dbtable", sys.error("Option 'dbtable' not specified")) + val partitionColumn = parameters.getOrElse("partitionColumn", null) + val lowerBound = parameters.getOrElse("lowerBound", null) + val upperBound = parameters.getOrElse("upperBound", null) + val numPartitions = parameters.getOrElse("numPartitions", null) + + if (driver != null) DriverRegistry.register(driver) + + if (partitionColumn != null + && (lowerBound == null || upperBound == null || numPartitions == null)) { + sys.error("Partitioning incompletely specified") + } + + val partitionInfo = if (partitionColumn == null) { + null + } else { + JDBCPartitioningInfo( + partitionColumn, + lowerBound.toLong, + upperBound.toLong, + numPartitions.toInt) + } + val parts = JDBCRelation.columnPartition(partitionInfo) + val properties = new Properties() // Additional properties that we will pass to getConnection + parameters.foreach(kv => properties.setProperty(kv._1, kv._2)) + JDBCRelation(url, table, parts, properties)(sqlContext) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DriverRegistry.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DriverRegistry.scala new file mode 100644 index 000000000000..7ccd61ed469e --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DriverRegistry.scala @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
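A minimal usage sketch for the "jdbc" provider defined above, read through the DataFrameReader; the connection URL, table and bounds are hypothetical, and a SQLContext named sqlContext is assumed.

// Partitioned read: lowerBound, upperBound and numPartitions must all be set
// once partitionColumn is given, as enforced in createRelation above.
val people = sqlContext.read
  .format("jdbc")
  .option("url", "jdbc:postgresql://localhost/testdb")
  .option("dbtable", "people")
  .option("partitionColumn", "id")
  .option("lowerBound", "0")
  .option("upperBound", "100000")
  .option("numPartitions", "4")
  .load()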
+ */ + +package org.apache.spark.sql.execution.datasources.jdbc + +import java.sql.{Driver, DriverManager} + +import scala.collection.mutable + +import org.apache.spark.Logging +import org.apache.spark.util.Utils + +/** + * java.sql.DriverManager is always loaded by bootstrap classloader, + * so it can't load JDBC drivers accessible by Spark ClassLoader. + * + * To solve the problem, drivers from user-supplied jars are wrapped into thin wrapper. + */ +object DriverRegistry extends Logging { + + private val wrapperMap: mutable.Map[String, DriverWrapper] = mutable.Map.empty + + def register(className: String): Unit = { + val cls = Utils.getContextOrSparkClassLoader.loadClass(className) + if (cls.getClassLoader == null) { + logTrace(s"$className has been loaded with bootstrap ClassLoader, wrapper is not required") + } else if (wrapperMap.get(className).isDefined) { + logTrace(s"Wrapper for $className already exists") + } else { + synchronized { + if (wrapperMap.get(className).isEmpty) { + val wrapper = new DriverWrapper(cls.newInstance().asInstanceOf[Driver]) + DriverManager.registerDriver(wrapper) + wrapperMap(className) = wrapper + logTrace(s"Wrapper for $className registered") + } + } + } + } + + def getDriverClassName(url: String): String = DriverManager.getDriver(url) match { + case wrapper: DriverWrapper => wrapper.wrapped.getClass.getCanonicalName + case driver => driver.getClass.getCanonicalName + } +} + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DriverWrapper.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DriverWrapper.scala new file mode 100644 index 000000000000..18263fe227d0 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DriverWrapper.scala @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.jdbc + +import java.sql.{Connection, Driver, DriverPropertyInfo, SQLFeatureNotSupportedException} +import java.util.Properties + +/** + * A wrapper for a JDBC Driver to work around SPARK-6913. + * + * The problem is in `java.sql.DriverManager` class that can't access drivers loaded by + * Spark ClassLoader. 
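A minimal sketch of how DriverRegistry above is typically exercised: registering a JDBC driver that ships in a user-supplied jar, then resolving the underlying driver class for a URL. The driver class name and URL are hypothetical.

import org.apache.spark.sql.execution.datasources.jdbc.DriverRegistry

// Wraps the driver (when it was loaded by a non-bootstrap class loader) and
// registers the wrapper with java.sql.DriverManager.
DriverRegistry.register("org.postgresql.Driver")
// Unwraps a DriverWrapper if present and returns the real driver's class name.
val driverClass: String = DriverRegistry.getDriverClassName("jdbc:postgresql://localhost/testdb")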
+ */ +class DriverWrapper(val wrapped: Driver) extends Driver { + override def acceptsURL(url: String): Boolean = wrapped.acceptsURL(url) + + override def jdbcCompliant(): Boolean = wrapped.jdbcCompliant() + + override def getPropertyInfo(url: String, info: Properties): Array[DriverPropertyInfo] = { + wrapped.getPropertyInfo(url, info) + } + + override def getMinorVersion: Int = wrapped.getMinorVersion + + def getParentLogger: java.util.logging.Logger = { + throw new SQLFeatureNotSupportedException( + s"${this.getClass.getName}.getParentLogger is not yet implemented.") + } + + override def connect(url: String, info: Properties): Connection = wrapped.connect(url, info) + + override def getMajorVersion: Int = wrapped.getMajorVersion +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala similarity index 98% rename from sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala index 3cf70db6b7b0..8eab6a0adccc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.jdbc +package org.apache.spark.sql.execution.datasources.jdbc import java.sql.{Connection, DriverManager, ResultSet, ResultSetMetaData, SQLException} import java.util.Properties @@ -26,6 +26,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.SpecificMutableRow import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.jdbc.JdbcDialects import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -180,9 +181,8 @@ private[sql] object JDBCRDD extends Logging { try { if (driver != null) DriverRegistry.register(driver) } catch { - case e: ClassNotFoundException => { - logWarning(s"Couldn't find class $driver", e); - } + case e: ClassNotFoundException => + logWarning(s"Couldn't find class $driver", e) } DriverManager.getConnection(url, properties) } @@ -344,7 +344,6 @@ private[sql] class JDBCRDD( }).toArray } - /** * Runs the SQL query against the JDBC driver. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala similarity index 71% rename from sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala index 48d97ced9ca0..f9300dc2cb52 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.jdbc +package org.apache.spark.sql.execution.datasources.jdbc import java.util.Properties @@ -77,45 +77,6 @@ private[sql] object JDBCRelation { } } -private[sql] class DefaultSource extends RelationProvider with DataSourceRegister { - - def format(): String = "jdbc" - - /** Returns a new base relation with the given parameters. 
*/ - override def createRelation( - sqlContext: SQLContext, - parameters: Map[String, String]): BaseRelation = { - val url = parameters.getOrElse("url", sys.error("Option 'url' not specified")) - val driver = parameters.getOrElse("driver", null) - val table = parameters.getOrElse("dbtable", sys.error("Option 'dbtable' not specified")) - val partitionColumn = parameters.getOrElse("partitionColumn", null) - val lowerBound = parameters.getOrElse("lowerBound", null) - val upperBound = parameters.getOrElse("upperBound", null) - val numPartitions = parameters.getOrElse("numPartitions", null) - - if (driver != null) DriverRegistry.register(driver) - - if (partitionColumn != null - && (lowerBound == null || upperBound == null || numPartitions == null)) { - sys.error("Partitioning incompletely specified") - } - - val partitionInfo = if (partitionColumn == null) { - null - } else { - JDBCPartitioningInfo( - partitionColumn, - lowerBound.toLong, - upperBound.toLong, - numPartitions.toInt) - } - val parts = JDBCRelation.columnPartition(partitionInfo) - val properties = new Properties() // Additional properties that we will pass to getConnection - parameters.foreach(kv => properties.setProperty(kv._1, kv._2)) - JDBCRelation(url, table, parts, properties)(sqlContext) - } -} - private[sql] case class JDBCRelation( url: String, table: String, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala new file mode 100644 index 000000000000..039c13bf163c --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -0,0 +1,219 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.jdbc + +import java.sql.{Connection, DriverManager, PreparedStatement} +import java.util.Properties + +import scala.util.Try + +import org.apache.spark.Logging +import org.apache.spark.sql.jdbc.JdbcDialects +import org.apache.spark.sql.types._ +import org.apache.spark.sql.{DataFrame, Row} + +/** + * Util functions for JDBC tables. + */ +object JdbcUtils extends Logging { + + /** + * Establishes a JDBC connection. + */ + def createConnection(url: String, connectionProperties: Properties): Connection = { + DriverManager.getConnection(url, connectionProperties) + } + + /** + * Returns true if the table already exists in the JDBC database. + */ + def tableExists(conn: Connection, table: String): Boolean = { + // Somewhat hacky, but there isn't a good way to identify whether a table exists for all + // SQL database systems, considering "table" could also include the database name. 
+ Try(conn.prepareStatement(s"SELECT 1 FROM $table LIMIT 1").executeQuery().next()).isSuccess + } + + /** + * Drops a table from the JDBC database. + */ + def dropTable(conn: Connection, table: String): Unit = { + conn.prepareStatement(s"DROP TABLE $table").executeUpdate() + } + + /** + * Returns a PreparedStatement that inserts a row into table via conn. + */ + def insertStatement(conn: Connection, table: String, rddSchema: StructType): PreparedStatement = { + val sql = new StringBuilder(s"INSERT INTO $table VALUES (") + var fieldsLeft = rddSchema.fields.length + while (fieldsLeft > 0) { + sql.append("?") + if (fieldsLeft > 1) sql.append(", ") else sql.append(")") + fieldsLeft = fieldsLeft - 1 + } + conn.prepareStatement(sql.toString()) + } + + /** + * Saves a partition of a DataFrame to the JDBC database. This is done in + * a single database transaction in order to avoid repeatedly inserting + * data as much as possible. + * + * It is still theoretically possible for rows in a DataFrame to be + * inserted into the database more than once if a stage somehow fails after + * the commit occurs but before the stage can return successfully. + * + * This is not a closure inside saveTable() because apparently cosmetic + * implementation changes elsewhere might easily render such a closure + * non-Serializable. Instead, we explicitly close over all variables that + * are used. + */ + def savePartition( + getConnection: () => Connection, + table: String, + iterator: Iterator[Row], + rddSchema: StructType, + nullTypes: Array[Int]): Iterator[Byte] = { + val conn = getConnection() + var committed = false + try { + conn.setAutoCommit(false) // Everything in the same db transaction. + val stmt = insertStatement(conn, table, rddSchema) + try { + while (iterator.hasNext) { + val row = iterator.next() + val numFields = rddSchema.fields.length + var i = 0 + while (i < numFields) { + if (row.isNullAt(i)) { + stmt.setNull(i + 1, nullTypes(i)) + } else { + rddSchema.fields(i).dataType match { + case IntegerType => stmt.setInt(i + 1, row.getInt(i)) + case LongType => stmt.setLong(i + 1, row.getLong(i)) + case DoubleType => stmt.setDouble(i + 1, row.getDouble(i)) + case FloatType => stmt.setFloat(i + 1, row.getFloat(i)) + case ShortType => stmt.setInt(i + 1, row.getShort(i)) + case ByteType => stmt.setInt(i + 1, row.getByte(i)) + case BooleanType => stmt.setBoolean(i + 1, row.getBoolean(i)) + case StringType => stmt.setString(i + 1, row.getString(i)) + case BinaryType => stmt.setBytes(i + 1, row.getAs[Array[Byte]](i)) + case TimestampType => stmt.setTimestamp(i + 1, row.getAs[java.sql.Timestamp](i)) + case DateType => stmt.setDate(i + 1, row.getAs[java.sql.Date](i)) + case t: DecimalType => stmt.setBigDecimal(i + 1, row.getDecimal(i)) + case _ => throw new IllegalArgumentException( + s"Can't translate non-null value for field $i") + } + } + i = i + 1 + } + stmt.executeUpdate() + } + } finally { + stmt.close() + } + conn.commit() + committed = true + } finally { + if (!committed) { + // The stage must fail. We got here through an exception path, so + // let the exception through unless rollback() or close() want to + // tell the user about another problem. + conn.rollback() + conn.close() + } else { + // The stage must succeed. We cannot propagate any exception close() might throw. + try { + conn.close() + } catch { + case e: Exception => logWarning("Transaction succeeded, but closing failed", e) + } + } + } + Array[Byte]().iterator + } + + /** + * Compute the schema string for this RDD. 
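A minimal sketch of the lower-level helpers above, assuming an open java.sql.Connection named conn and a StructType named rddSchema; the table name is hypothetical.

// Probes for the table with a cheap SELECT, as tableExists does internally.
val exists: Boolean = JdbcUtils.tableExists(conn, "people")
// Builds "INSERT INTO people VALUES (?, ?, ...)" with one placeholder per field.
val stmt = JdbcUtils.insertStatement(conn, "people", rddSchema)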
+ */ + def schemaString(df: DataFrame, url: String): String = { + val sb = new StringBuilder() + val dialect = JdbcDialects.get(url) + df.schema.fields foreach { field => { + val name = field.name + val typ: String = + dialect.getJDBCType(field.dataType).map(_.databaseTypeDefinition).getOrElse( + field.dataType match { + case IntegerType => "INTEGER" + case LongType => "BIGINT" + case DoubleType => "DOUBLE PRECISION" + case FloatType => "REAL" + case ShortType => "INTEGER" + case ByteType => "BYTE" + case BooleanType => "BIT(1)" + case StringType => "TEXT" + case BinaryType => "BLOB" + case TimestampType => "TIMESTAMP" + case DateType => "DATE" + case t: DecimalType => s"DECIMAL(${t.precision}},${t.scale}})" + case _ => throw new IllegalArgumentException(s"Don't know how to save $field to JDBC") + }) + val nullable = if (field.nullable) "" else "NOT NULL" + sb.append(s", $name $typ $nullable") + }} + if (sb.length < 2) "" else sb.substring(2) + } + + /** + * Saves the RDD to the database in a single transaction. + */ + def saveTable( + df: DataFrame, + url: String, + table: String, + properties: Properties = new Properties()) { + val dialect = JdbcDialects.get(url) + val nullTypes: Array[Int] = df.schema.fields.map { field => + dialect.getJDBCType(field.dataType).map(_.jdbcNullType).getOrElse( + field.dataType match { + case IntegerType => java.sql.Types.INTEGER + case LongType => java.sql.Types.BIGINT + case DoubleType => java.sql.Types.DOUBLE + case FloatType => java.sql.Types.REAL + case ShortType => java.sql.Types.INTEGER + case ByteType => java.sql.Types.INTEGER + case BooleanType => java.sql.Types.BIT + case StringType => java.sql.Types.CLOB + case BinaryType => java.sql.Types.BLOB + case TimestampType => java.sql.Types.TIMESTAMP + case DateType => java.sql.Types.DATE + case t: DecimalType => java.sql.Types.DECIMAL + case _ => throw new IllegalArgumentException( + s"Can't translate null value for field $field") + }) + } + + val rddSchema = df.schema + val driver: String = DriverRegistry.getDriverClassName(url) + val getConnection: () => Connection = JDBCRDD.getConnector(driver, url, properties) + df.foreachPartition { iterator => + savePartition(getConnection, table, iterator, rddSchema, nullTypes) + } + } + +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala similarity index 98% rename from sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala index ec5668c6b95a..b6f3410bad69 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala @@ -15,13 +15,13 @@ * limitations under the License. 
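A minimal end-to-end sketch of writing through JdbcUtils.saveTable as defined above; a DataFrame named df is assumed, and the URL, table and credentials are hypothetical.

import java.util.Properties
import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils

val props = new Properties()
props.setProperty("user", "test")        // hypothetical credentials
props.setProperty("password", "secret")
// Maps each field to a JDBC null type via the dialect, then writes every
// partition through savePartition, each in its own transaction.
JdbcUtils.saveTable(df, "jdbc:postgresql://localhost/testdb", "people", props)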
*/ -package org.apache.spark.sql.json +package org.apache.spark.sql.execution.datasources.json import com.fasterxml.jackson.core._ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion -import org.apache.spark.sql.json.JacksonUtils.nextUntil +import org.apache.spark.sql.execution.datasources.json.JacksonUtils.nextUntil import org.apache.spark.sql.types._ private[sql] object InferSchema { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala similarity index 97% rename from sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala index 5bb9e62310a5..114c8b211891 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.json +package org.apache.spark.sql.execution.datasources.json import java.io.CharArrayWriter @@ -39,9 +39,10 @@ import org.apache.spark.sql.types.StructType import org.apache.spark.sql.{AnalysisException, Row, SQLContext} import org.apache.spark.util.SerializableConfiguration -private[sql] class DefaultSource extends HadoopFsRelationProvider with DataSourceRegister { - def format(): String = "json" +class DefaultSource extends HadoopFsRelationProvider with DataSourceRegister { + + override def shortName(): String = "json" override def createRelation( sqlContext: SQLContext, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonGenerator.scala similarity index 98% rename from sql/core/src/main/scala/org/apache/spark/sql/json/JacksonGenerator.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonGenerator.scala index d734e7e8904b..37c2b5a296c1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonGenerator.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.json +package org.apache.spark.sql.execution.datasources.json import org.apache.spark.sql.catalyst.InternalRow diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala similarity index 98% rename from sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala index b8fd3b9cc150..cd68bd667c5c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala @@ -15,7 +15,7 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.json +package org.apache.spark.sql.execution.datasources.json import java.io.ByteArrayOutputStream @@ -27,7 +27,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.json.JacksonUtils.nextUntil +import org.apache.spark.sql.execution.datasources.json.JacksonUtils.nextUntil import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonUtils.scala similarity index 95% rename from sql/core/src/main/scala/org/apache/spark/sql/json/JacksonUtils.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonUtils.scala index fde96852ce68..005546f37dda 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonUtils.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.json +package org.apache.spark.sql.execution.datasources.json import com.fasterxml.jackson.core.{JsonParser, JsonToken} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystReadSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystReadSupport.scala similarity index 99% rename from sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystReadSupport.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystReadSupport.scala index 975fec101d9c..4049795ed3ba 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystReadSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystReadSupport.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.parquet +package org.apache.spark.sql.execution.datasources.parquet import java.util.{Map => JMap} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRecordMaterializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRecordMaterializer.scala similarity index 96% rename from sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRecordMaterializer.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRecordMaterializer.scala index 84f1dccfeb78..ed9e0aa65977 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRecordMaterializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRecordMaterializer.scala @@ -15,7 +15,7 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.parquet +package org.apache.spark.sql.execution.datasources.parquet import org.apache.parquet.io.api.{GroupConverter, RecordMaterializer} import org.apache.parquet.schema.MessageType diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRowConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala similarity index 99% rename from sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRowConverter.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala index 4fe8a39f20ab..3542dfbae129 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRowConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.parquet +package org.apache.spark.sql.execution.datasources.parquet import java.math.{BigDecimal, BigInteger} import java.nio.ByteOrder diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala similarity index 99% rename from sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala index b12149dcf1c9..a3fc74cf7929 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.parquet +package org.apache.spark.sql.execution.datasources.parquet import scala.collection.JavaConversions._ @@ -25,7 +25,7 @@ import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName._ import org.apache.parquet.schema.Type.Repetition._ import org.apache.parquet.schema._ -import org.apache.spark.sql.parquet.CatalystSchemaConverter.{MAX_PRECISION_FOR_INT32, MAX_PRECISION_FOR_INT64, maxPrecisionForBytes} +import org.apache.spark.sql.execution.datasources.parquet.CatalystSchemaConverter.{MAX_PRECISION_FOR_INT32, MAX_PRECISION_FOR_INT64, maxPrecisionForBytes} import org.apache.spark.sql.types._ import org.apache.spark.sql.{AnalysisException, SQLConf} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/DirectParquetOutputCommitter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/DirectParquetOutputCommitter.scala similarity index 98% rename from sql/core/src/main/scala/org/apache/spark/sql/parquet/DirectParquetOutputCommitter.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/DirectParquetOutputCommitter.scala index 1551afd7b7bf..2c6b914328b6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/DirectParquetOutputCommitter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/DirectParquetOutputCommitter.scala @@ -15,7 +15,7 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.parquet +package org.apache.spark.sql.execution.datasources.parquet import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetConverter.scala similarity index 96% rename from sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetConverter.scala index 6ed3580af072..ccd7ebf319af 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetConverter.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.parquet +package org.apache.spark.sql.execution.datasources.parquet import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types.{MapData, ArrayData} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala similarity index 99% rename from sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala index d57b789f5c1c..9e2e232f5016 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.parquet +package org.apache.spark.sql.execution.datasources.parquet import java.io.Serializable import java.nio.ByteBuffer diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala similarity index 99% rename from sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala index b6db71b5b8a6..4086a139bed7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala @@ -15,7 +15,7 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.parquet +package org.apache.spark.sql.execution.datasources.parquet import java.net.URI import java.util.logging.{Level, Logger => JLogger} @@ -51,7 +51,7 @@ import org.apache.spark.util.{SerializableConfiguration, Utils} private[sql] class DefaultSource extends HadoopFsRelationProvider with DataSourceRegister { - def format(): String = "parquet" + override def shortName(): String = "parquet" override def createRelation( sqlContext: SQLContext, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTableSupport.scala similarity index 99% rename from sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTableSupport.scala index 9cd0250f9c51..3191cf3d121b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTableSupport.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.parquet +package org.apache.spark.sql.execution.datasources.parquet import java.math.BigInteger import java.nio.{ByteBuffer, ByteOrder} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypesConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTypesConverter.scala similarity index 99% rename from sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypesConverter.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTypesConverter.scala index 3854f5bd39fb..019db34fc666 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypesConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTypesConverter.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.parquet +package org.apache.spark.sql.execution.datasources.parquet import java.io.IOException diff --git a/sql/core/src/main/scala/org/apache/spark/sql/metric/SQLMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala similarity index 77% rename from sql/core/src/main/scala/org/apache/spark/sql/metric/SQLMetrics.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala index 3b907e5da789..1b51a5e5c8a8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/metric/SQLMetrics.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.metric +package org.apache.spark.sql.execution.metric import org.apache.spark.{Accumulable, AccumulableParam, SparkContext} @@ -93,22 +93,6 @@ private[sql] class LongSQLMetric private[metric](name: String) } } -/** - * A specialized int Accumulable to avoid boxing and unboxing when using Accumulator's - * `+=` and `add`. 
- */ -private[sql] class IntSQLMetric private[metric](name: String) - extends SQLMetric[IntSQLMetricValue, Int](name, IntSQLMetricParam) { - - override def +=(term: Int): Unit = { - localValue.add(term) - } - - override def add(term: Int): Unit = { - localValue.add(term) - } -} - private object LongSQLMetricParam extends SQLMetricParam[LongSQLMetricValue, Long] { override def addAccumulator(r: LongSQLMetricValue, t: Long): LongSQLMetricValue = r.add(t) @@ -121,26 +105,8 @@ private object LongSQLMetricParam extends SQLMetricParam[LongSQLMetricValue, Lon override def zero: LongSQLMetricValue = new LongSQLMetricValue(0L) } -private object IntSQLMetricParam extends SQLMetricParam[IntSQLMetricValue, Int] { - - override def addAccumulator(r: IntSQLMetricValue, t: Int): IntSQLMetricValue = r.add(t) - - override def addInPlace(r1: IntSQLMetricValue, r2: IntSQLMetricValue): IntSQLMetricValue = - r1.add(r2.value) - - override def zero(initialValue: IntSQLMetricValue): IntSQLMetricValue = zero - - override def zero: IntSQLMetricValue = new IntSQLMetricValue(0) -} - private[sql] object SQLMetrics { - def createIntMetric(sc: SparkContext, name: String): IntSQLMetric = { - val acc = new IntSQLMetric(name) - sc.cleaner.foreach(_.registerAccumulatorForCleanup(acc)) - acc - } - def createLongMetric(sc: SparkContext, name: String): LongSQLMetric = { val acc = new LongSQLMetric(name) sc.cleaner.foreach(_.registerAccumulatorForCleanup(acc)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/package.scala index 66237f8f1314..28fa231e722d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/package.scala @@ -18,12 +18,6 @@ package org.apache.spark.sql /** - * :: DeveloperApi :: - * An execution engine for relational query plans that runs on top Spark and returns RDDs. - * - * Note that the operators in this package are created automatically by a query planner using a - * [[SQLContext]] and are not intended to be used directly by end users of Spark SQL. They are - * documented here in order to make it easier for others to understand the performance - * characteristics of query plans that are generated by Spark SQL. + * The physical execution component of Spark SQL. Note that this is a private package. */ package object execution diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ui/AllExecutionsPage.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala similarity index 99% rename from sql/core/src/main/scala/org/apache/spark/sql/ui/AllExecutionsPage.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala index cb7ca60b2fe4..49646a99d68c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/ui/AllExecutionsPage.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala @@ -15,7 +15,7 @@ * limitations under the License. 
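With the int variant removed, a minimal sketch of the remaining long-typed metric, assuming a SparkContext named sc is in scope.

// Creates an accumulator-backed metric and registers it for cleanup.
val numRows = SQLMetrics.createLongMetric(sc, "number of output rows")
numRows += 10L  // LongSQLMetric keeps the specialized += / add to avoid boxing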
*/ -package org.apache.spark.sql.ui +package org.apache.spark.sql.execution.ui import javax.servlet.http.HttpServletRequest diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ui/ExecutionPage.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala similarity index 99% rename from sql/core/src/main/scala/org/apache/spark/sql/ui/ExecutionPage.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala index 52ddf99e9266..f0b56c2eb7a5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/ui/ExecutionPage.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.ui +package org.apache.spark.sql.execution.ui import javax.servlet.http.HttpServletRequest diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ui/SQLListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala similarity index 98% rename from sql/core/src/main/scala/org/apache/spark/sql/ui/SQLListener.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala index 2fd4fc658d06..0b9bad987c48 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/ui/SQLListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.ui +package org.apache.spark.sql.execution.ui import scala.collection.mutable @@ -26,7 +26,7 @@ import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler._ import org.apache.spark.sql.SQLContext import org.apache.spark.sql.execution.SQLExecution -import org.apache.spark.sql.metric.{SQLMetricParam, SQLMetricValue} +import org.apache.spark.sql.execution.metric.{SQLMetricParam, SQLMetricValue} private[sql] class SQLListener(sqlContext: SQLContext) extends SparkListener with Logging { @@ -51,17 +51,14 @@ private[sql] class SQLListener(sqlContext: SQLContext) extends SparkListener wit private val completedExecutions = mutable.ListBuffer[SQLExecutionUIData]() - @VisibleForTesting def executionIdToData: Map[Long, SQLExecutionUIData] = synchronized { _executionIdToData.toMap } - @VisibleForTesting def jobIdToExecutionId: Map[Long, Long] = synchronized { _jobIdToExecutionId.toMap } - @VisibleForTesting def stageIdToStageMetrics: Map[Long, SQLStageMetrics] = synchronized { _stageIdToStageMetrics.toMap } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ui/SQLTab.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLTab.scala similarity index 90% rename from sql/core/src/main/scala/org/apache/spark/sql/ui/SQLTab.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLTab.scala index 3bba0afaf14e..0b0867f67eb6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/ui/SQLTab.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLTab.scala @@ -15,7 +15,7 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.ui +package org.apache.spark.sql.execution.ui import java.util.concurrent.atomic.AtomicInteger @@ -38,12 +38,12 @@ private[sql] class SQLTab(sqlContext: SQLContext, sparkUI: SparkUI) private[sql] object SQLTab { - private val STATIC_RESOURCE_DIR = "org/apache/spark/sql/ui/static" + private val STATIC_RESOURCE_DIR = "org/apache/spark/sql/execution/ui/static" private val nextTabId = new AtomicInteger(0) private def nextTabName: String = { val nextId = nextTabId.getAndIncrement() - if (nextId == 0) "SQL" else s"SQL${nextId}" + if (nextId == 0) "SQL" else s"SQL$nextId" } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ui/SparkPlanGraph.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala similarity index 97% rename from sql/core/src/main/scala/org/apache/spark/sql/ui/SparkPlanGraph.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala index 1ba50b95becc..ae3d752dde34 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/ui/SparkPlanGraph.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala @@ -15,14 +15,14 @@ * limitations under the License. */ -package org.apache.spark.sql.ui +package org.apache.spark.sql.execution.ui import java.util.concurrent.atomic.AtomicLong import scala.collection.mutable import org.apache.spark.sql.execution.SparkPlan -import org.apache.spark.sql.metric.{SQLMetricParam, SQLMetricValue} +import org.apache.spark.sql.execution.metric.{SQLMetricParam, SQLMetricValue} /** * A graph used for storing information of an executionPlan of DataFrame. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcUtils.scala deleted file mode 100644 index cc918c237192..000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcUtils.scala +++ /dev/null @@ -1,52 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.jdbc - -import java.sql.{Connection, DriverManager} -import java.util.Properties - -import scala.util.Try - -/** - * Util functions for JDBC tables. - */ -private[sql] object JdbcUtils { - - /** - * Establishes a JDBC connection. - */ - def createConnection(url: String, connectionProperties: Properties): Connection = { - DriverManager.getConnection(url, connectionProperties) - } - - /** - * Returns true if the table already exists in the JDBC database. - */ - def tableExists(conn: Connection, table: String): Boolean = { - // Somewhat hacky, but there isn't a good way to identify whether a table exists for all - // SQL database systems, considering "table" could also include the database name. 
- Try(conn.prepareStatement(s"SELECT 1 FROM $table LIMIT 1").executeQuery().next()).isSuccess - } - - /** - * Drops a table from the JDBC database. - */ - def dropTable(conn: Connection, table: String): Unit = { - conn.prepareStatement(s"DROP TABLE $table").executeUpdate() - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala deleted file mode 100644 index 035e0510080f..000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala +++ /dev/null @@ -1,250 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql - -import java.sql.{Connection, Driver, DriverManager, DriverPropertyInfo, PreparedStatement, SQLFeatureNotSupportedException} -import java.util.Properties - -import scala.collection.mutable - -import org.apache.spark.Logging -import org.apache.spark.sql.types._ -import org.apache.spark.util.Utils - -package object jdbc { - private[sql] object JDBCWriteDetails extends Logging { - /** - * Returns a PreparedStatement that inserts a row into table via conn. - */ - def insertStatement(conn: Connection, table: String, rddSchema: StructType): - PreparedStatement = { - val sql = new StringBuilder(s"INSERT INTO $table VALUES (") - var fieldsLeft = rddSchema.fields.length - while (fieldsLeft > 0) { - sql.append("?") - if (fieldsLeft > 1) sql.append(", ") else sql.append(")") - fieldsLeft = fieldsLeft - 1 - } - conn.prepareStatement(sql.toString) - } - - /** - * Saves a partition of a DataFrame to the JDBC database. This is done in - * a single database transaction in order to avoid repeatedly inserting - * data as much as possible. - * - * It is still theoretically possible for rows in a DataFrame to be - * inserted into the database more than once if a stage somehow fails after - * the commit occurs but before the stage can return successfully. - * - * This is not a closure inside saveTable() because apparently cosmetic - * implementation changes elsewhere might easily render such a closure - * non-Serializable. Instead, we explicitly close over all variables that - * are used. - */ - def savePartition( - getConnection: () => Connection, - table: String, - iterator: Iterator[Row], - rddSchema: StructType, - nullTypes: Array[Int]): Iterator[Byte] = { - val conn = getConnection() - var committed = false - try { - conn.setAutoCommit(false) // Everything in the same db transaction. 
- val stmt = insertStatement(conn, table, rddSchema) - try { - while (iterator.hasNext) { - val row = iterator.next() - val numFields = rddSchema.fields.length - var i = 0 - while (i < numFields) { - if (row.isNullAt(i)) { - stmt.setNull(i + 1, nullTypes(i)) - } else { - rddSchema.fields(i).dataType match { - case IntegerType => stmt.setInt(i + 1, row.getInt(i)) - case LongType => stmt.setLong(i + 1, row.getLong(i)) - case DoubleType => stmt.setDouble(i + 1, row.getDouble(i)) - case FloatType => stmt.setFloat(i + 1, row.getFloat(i)) - case ShortType => stmt.setInt(i + 1, row.getShort(i)) - case ByteType => stmt.setInt(i + 1, row.getByte(i)) - case BooleanType => stmt.setBoolean(i + 1, row.getBoolean(i)) - case StringType => stmt.setString(i + 1, row.getString(i)) - case BinaryType => stmt.setBytes(i + 1, row.getAs[Array[Byte]](i)) - case TimestampType => stmt.setTimestamp(i + 1, row.getAs[java.sql.Timestamp](i)) - case DateType => stmt.setDate(i + 1, row.getAs[java.sql.Date](i)) - case t: DecimalType => stmt.setBigDecimal(i + 1, row.getDecimal(i)) - case _ => throw new IllegalArgumentException( - s"Can't translate non-null value for field $i") - } - } - i = i + 1 - } - stmt.executeUpdate() - } - } finally { - stmt.close() - } - conn.commit() - committed = true - } finally { - if (!committed) { - // The stage must fail. We got here through an exception path, so - // let the exception through unless rollback() or close() want to - // tell the user about another problem. - conn.rollback() - conn.close() - } else { - // The stage must succeed. We cannot propagate any exception close() might throw. - try { - conn.close() - } catch { - case e: Exception => logWarning("Transaction succeeded, but closing failed", e) - } - } - } - Array[Byte]().iterator - } - - /** - * Compute the schema string for this RDD. - */ - def schemaString(df: DataFrame, url: String): String = { - val sb = new StringBuilder() - val dialect = JdbcDialects.get(url) - df.schema.fields foreach { field => { - val name = field.name - val typ: String = - dialect.getJDBCType(field.dataType).map(_.databaseTypeDefinition).getOrElse( - field.dataType match { - case IntegerType => "INTEGER" - case LongType => "BIGINT" - case DoubleType => "DOUBLE PRECISION" - case FloatType => "REAL" - case ShortType => "INTEGER" - case ByteType => "BYTE" - case BooleanType => "BIT(1)" - case StringType => "TEXT" - case BinaryType => "BLOB" - case TimestampType => "TIMESTAMP" - case DateType => "DATE" - case t: DecimalType => s"DECIMAL(${t.precision}},${t.scale}})" - case _ => throw new IllegalArgumentException(s"Don't know how to save $field to JDBC") - }) - val nullable = if (field.nullable) "" else "NOT NULL" - sb.append(s", $name $typ $nullable") - }} - if (sb.length < 2) "" else sb.substring(2) - } - - /** - * Saves the RDD to the database in a single transaction. 
- */ - def saveTable( - df: DataFrame, - url: String, - table: String, - properties: Properties = new Properties()) { - val dialect = JdbcDialects.get(url) - val nullTypes: Array[Int] = df.schema.fields.map { field => - dialect.getJDBCType(field.dataType).map(_.jdbcNullType).getOrElse( - field.dataType match { - case IntegerType => java.sql.Types.INTEGER - case LongType => java.sql.Types.BIGINT - case DoubleType => java.sql.Types.DOUBLE - case FloatType => java.sql.Types.REAL - case ShortType => java.sql.Types.INTEGER - case ByteType => java.sql.Types.INTEGER - case BooleanType => java.sql.Types.BIT - case StringType => java.sql.Types.CLOB - case BinaryType => java.sql.Types.BLOB - case TimestampType => java.sql.Types.TIMESTAMP - case DateType => java.sql.Types.DATE - case t: DecimalType => java.sql.Types.DECIMAL - case _ => throw new IllegalArgumentException( - s"Can't translate null value for field $field") - }) - } - - val rddSchema = df.schema - val driver: String = DriverRegistry.getDriverClassName(url) - val getConnection: () => Connection = JDBCRDD.getConnector(driver, url, properties) - df.foreachPartition { iterator => - JDBCWriteDetails.savePartition(getConnection, table, iterator, rddSchema, nullTypes) - } - } - - } - - private [sql] class DriverWrapper(val wrapped: Driver) extends Driver { - override def acceptsURL(url: String): Boolean = wrapped.acceptsURL(url) - - override def jdbcCompliant(): Boolean = wrapped.jdbcCompliant() - - override def getPropertyInfo(url: String, info: Properties): Array[DriverPropertyInfo] = { - wrapped.getPropertyInfo(url, info) - } - - override def getMinorVersion: Int = wrapped.getMinorVersion - - def getParentLogger: java.util.logging.Logger = - throw new SQLFeatureNotSupportedException( - s"${this.getClass().getName}.getParentLogger is not yet implemented.") - - override def connect(url: String, info: Properties): Connection = wrapped.connect(url, info) - - override def getMajorVersion: Int = wrapped.getMajorVersion - } - - /** - * java.sql.DriverManager is always loaded by bootstrap classloader, - * so it can't load JDBC drivers accessible by Spark ClassLoader. - * - * To solve the problem, drivers from user-supplied jars are wrapped - * into thin wrapper. 
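The comment above states the constraint this wrapper exists for: java.sql.DriverManager is loaded by the bootstrap classloader, so it cannot see driver classes that arrive in user-supplied jars unless they are re-registered through a delegating java.sql.Driver. A condensed sketch of that registration step, reusing the DriverWrapper class defined just above; the driver class name is only an example, and the plain context classloader stands in for Spark's own classloader utilities:

{{{
import java.sql.{Driver, DriverManager}

object DriverRegistrationSketch {
  // Load the driver through the application/context classloader, then hand
  // DriverManager a delegating wrapper so getDriver(url) can find it later.
  def register(className: String = "org.postgresql.Driver"): Unit = {
    val cls = Thread.currentThread().getContextClassLoader.loadClass(className)
    if (cls.getClassLoader != null) {                    // null means bootstrap-loaded: already visible
      DriverManager.registerDriver(new DriverWrapper(cls.newInstance().asInstanceOf[Driver]))
    }
  }
}
}}}

The DriverRegistry that follows in the diff adds the bookkeeping this sketch omits: a synchronized map so each class is wrapped and registered at most once.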
- */ - private [sql] object DriverRegistry extends Logging { - - private val wrapperMap: mutable.Map[String, DriverWrapper] = mutable.Map.empty - - def register(className: String): Unit = { - val cls = Utils.getContextOrSparkClassLoader.loadClass(className) - if (cls.getClassLoader == null) { - logTrace(s"$className has been loaded with bootstrap ClassLoader, wrapper is not required") - } else if (wrapperMap.get(className).isDefined) { - logTrace(s"Wrapper for $className already exists") - } else { - synchronized { - if (wrapperMap.get(className).isEmpty) { - val wrapper = new DriverWrapper(cls.newInstance().asInstanceOf[Driver]) - DriverManager.registerDriver(wrapper) - wrapperMap(className) = wrapper - logTrace(s"Wrapper for $className registered") - } - } - } - } - - def getDriverClassName(url: String): String = DriverManager.getDriver(url) match { - case wrapper: DriverWrapper => wrapper.wrapped.getClass.getCanonicalName - case driver => driver.getClass.getCanonicalName - } - } - -} // package object jdbc diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index 6bcabbab4f77..2f8417a48d32 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -43,19 +43,24 @@ import org.apache.spark.util.SerializableConfiguration * This allows users to give the data source alias as the format type over the fully qualified * class name. * - * ex: parquet.DefaultSource.format = "parquet". - * * A new instance of this class with be instantiated each time a DDL call is made. + * + * @since 1.5.0 */ @DeveloperApi trait DataSourceRegister { /** * The string that represents the format that this data source provider uses. This is - * overridden by children to provide a nice alias for the data source, - * ex: override def format(): String = "parquet" + * overridden by children to provide a nice alias for the data source. 
For example: + * + * {{{ + * override def format(): String = "parquet" + * }}} + * + * @since 1.5.0 */ - def format(): String + def shortName(): String } /** diff --git a/sql/core/src/test/gen-java/org/apache/spark/sql/parquet/test/avro/CompatibilityTest.java b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/CompatibilityTest.java similarity index 93% rename from sql/core/src/test/gen-java/org/apache/spark/sql/parquet/test/avro/CompatibilityTest.java rename to sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/CompatibilityTest.java index daec65a5bbe5..70dec1a9d3c9 100644 --- a/sql/core/src/test/gen-java/org/apache/spark/sql/parquet/test/avro/CompatibilityTest.java +++ b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/CompatibilityTest.java @@ -3,7 +3,7 @@ * * DO NOT EDIT DIRECTLY */ -package org.apache.spark.sql.parquet.test.avro; +package org.apache.spark.sql.execution.datasources.parquet.test.avro; @SuppressWarnings("all") @org.apache.avro.specific.AvroGenerated @@ -12,6 +12,6 @@ public interface CompatibilityTest { @SuppressWarnings("all") public interface Callback extends CompatibilityTest { - public static final org.apache.avro.Protocol PROTOCOL = org.apache.spark.sql.parquet.test.avro.CompatibilityTest.PROTOCOL; + public static final org.apache.avro.Protocol PROTOCOL = org.apache.spark.sql.execution.datasources.parquet.test.avro.CompatibilityTest.PROTOCOL; } } \ No newline at end of file diff --git a/sql/core/src/test/gen-java/org/apache/spark/sql/parquet/test/avro/Nested.java b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/Nested.java similarity index 78% rename from sql/core/src/test/gen-java/org/apache/spark/sql/parquet/test/avro/Nested.java rename to sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/Nested.java index 051f1ee90386..a0a406bcd10c 100644 --- a/sql/core/src/test/gen-java/org/apache/spark/sql/parquet/test/avro/Nested.java +++ b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/Nested.java @@ -3,7 +3,7 @@ * * DO NOT EDIT DIRECTLY */ -package org.apache.spark.sql.parquet.test.avro; +package org.apache.spark.sql.execution.datasources.parquet.test.avro; @SuppressWarnings("all") @org.apache.avro.specific.AvroGenerated public class Nested extends org.apache.avro.specific.SpecificRecordBase implements org.apache.avro.specific.SpecificRecord { @@ -77,18 +77,18 @@ public void setNestedStringColumn(java.lang.String value) { } /** Creates a new Nested RecordBuilder */ - public static org.apache.spark.sql.parquet.test.avro.Nested.Builder newBuilder() { - return new org.apache.spark.sql.parquet.test.avro.Nested.Builder(); + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.Nested.Builder newBuilder() { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.Nested.Builder(); } /** Creates a new Nested RecordBuilder by copying an existing Builder */ - public static org.apache.spark.sql.parquet.test.avro.Nested.Builder newBuilder(org.apache.spark.sql.parquet.test.avro.Nested.Builder other) { - return new org.apache.spark.sql.parquet.test.avro.Nested.Builder(other); + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.Nested.Builder newBuilder(org.apache.spark.sql.execution.datasources.parquet.test.avro.Nested.Builder other) { + return new 
org.apache.spark.sql.execution.datasources.parquet.test.avro.Nested.Builder(other); } /** Creates a new Nested RecordBuilder by copying an existing Nested instance */ - public static org.apache.spark.sql.parquet.test.avro.Nested.Builder newBuilder(org.apache.spark.sql.parquet.test.avro.Nested other) { - return new org.apache.spark.sql.parquet.test.avro.Nested.Builder(other); + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.Nested.Builder newBuilder(org.apache.spark.sql.execution.datasources.parquet.test.avro.Nested other) { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.Nested.Builder(other); } /** @@ -102,11 +102,11 @@ public static class Builder extends org.apache.avro.specific.SpecificRecordBuild /** Creates a new Builder */ private Builder() { - super(org.apache.spark.sql.parquet.test.avro.Nested.SCHEMA$); + super(org.apache.spark.sql.execution.datasources.parquet.test.avro.Nested.SCHEMA$); } /** Creates a Builder by copying an existing Builder */ - private Builder(org.apache.spark.sql.parquet.test.avro.Nested.Builder other) { + private Builder(org.apache.spark.sql.execution.datasources.parquet.test.avro.Nested.Builder other) { super(other); if (isValidValue(fields()[0], other.nested_ints_column)) { this.nested_ints_column = data().deepCopy(fields()[0].schema(), other.nested_ints_column); @@ -119,8 +119,8 @@ private Builder(org.apache.spark.sql.parquet.test.avro.Nested.Builder other) { } /** Creates a Builder by copying an existing Nested instance */ - private Builder(org.apache.spark.sql.parquet.test.avro.Nested other) { - super(org.apache.spark.sql.parquet.test.avro.Nested.SCHEMA$); + private Builder(org.apache.spark.sql.execution.datasources.parquet.test.avro.Nested other) { + super(org.apache.spark.sql.execution.datasources.parquet.test.avro.Nested.SCHEMA$); if (isValidValue(fields()[0], other.nested_ints_column)) { this.nested_ints_column = data().deepCopy(fields()[0].schema(), other.nested_ints_column); fieldSetFlags()[0] = true; @@ -137,7 +137,7 @@ public java.util.List getNestedIntsColumn() { } /** Sets the value of the 'nested_ints_column' field */ - public org.apache.spark.sql.parquet.test.avro.Nested.Builder setNestedIntsColumn(java.util.List value) { + public org.apache.spark.sql.execution.datasources.parquet.test.avro.Nested.Builder setNestedIntsColumn(java.util.List value) { validate(fields()[0], value); this.nested_ints_column = value; fieldSetFlags()[0] = true; @@ -150,7 +150,7 @@ public boolean hasNestedIntsColumn() { } /** Clears the value of the 'nested_ints_column' field */ - public org.apache.spark.sql.parquet.test.avro.Nested.Builder clearNestedIntsColumn() { + public org.apache.spark.sql.execution.datasources.parquet.test.avro.Nested.Builder clearNestedIntsColumn() { nested_ints_column = null; fieldSetFlags()[0] = false; return this; @@ -162,7 +162,7 @@ public java.lang.String getNestedStringColumn() { } /** Sets the value of the 'nested_string_column' field */ - public org.apache.spark.sql.parquet.test.avro.Nested.Builder setNestedStringColumn(java.lang.String value) { + public org.apache.spark.sql.execution.datasources.parquet.test.avro.Nested.Builder setNestedStringColumn(java.lang.String value) { validate(fields()[1], value); this.nested_string_column = value; fieldSetFlags()[1] = true; @@ -175,7 +175,7 @@ public boolean hasNestedStringColumn() { } /** Clears the value of the 'nested_string_column' field */ - public org.apache.spark.sql.parquet.test.avro.Nested.Builder clearNestedStringColumn() { + public 
org.apache.spark.sql.execution.datasources.parquet.test.avro.Nested.Builder clearNestedStringColumn() { nested_string_column = null; fieldSetFlags()[1] = false; return this; diff --git a/sql/core/src/test/gen-java/org/apache/spark/sql/parquet/test/avro/ParquetAvroCompat.java b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/ParquetAvroCompat.java similarity index 83% rename from sql/core/src/test/gen-java/org/apache/spark/sql/parquet/test/avro/ParquetAvroCompat.java rename to sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/ParquetAvroCompat.java index 354c9d73cca3..6198b00b1e3c 100644 --- a/sql/core/src/test/gen-java/org/apache/spark/sql/parquet/test/avro/ParquetAvroCompat.java +++ b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/ParquetAvroCompat.java @@ -3,7 +3,7 @@ * * DO NOT EDIT DIRECTLY */ -package org.apache.spark.sql.parquet.test.avro; +package org.apache.spark.sql.execution.datasources.parquet.test.avro; @SuppressWarnings("all") @org.apache.avro.specific.AvroGenerated public class ParquetAvroCompat extends org.apache.avro.specific.SpecificRecordBase implements org.apache.avro.specific.SpecificRecord { @@ -25,7 +25,7 @@ public class ParquetAvroCompat extends org.apache.avro.specific.SpecificRecordBa @Deprecated public java.lang.String maybe_string_column; @Deprecated public java.util.List strings_column; @Deprecated public java.util.Map string_to_int_column; - @Deprecated public java.util.Map> complex_column; + @Deprecated public java.util.Map> complex_column; /** * Default constructor. Note that this does not initialize fields @@ -37,7 +37,7 @@ public ParquetAvroCompat() {} /** * All-args constructor. */ - public ParquetAvroCompat(java.lang.Boolean bool_column, java.lang.Integer int_column, java.lang.Long long_column, java.lang.Float float_column, java.lang.Double double_column, java.nio.ByteBuffer binary_column, java.lang.String string_column, java.lang.Boolean maybe_bool_column, java.lang.Integer maybe_int_column, java.lang.Long maybe_long_column, java.lang.Float maybe_float_column, java.lang.Double maybe_double_column, java.nio.ByteBuffer maybe_binary_column, java.lang.String maybe_string_column, java.util.List strings_column, java.util.Map string_to_int_column, java.util.Map> complex_column) { + public ParquetAvroCompat(java.lang.Boolean bool_column, java.lang.Integer int_column, java.lang.Long long_column, java.lang.Float float_column, java.lang.Double double_column, java.nio.ByteBuffer binary_column, java.lang.String string_column, java.lang.Boolean maybe_bool_column, java.lang.Integer maybe_int_column, java.lang.Long maybe_long_column, java.lang.Float maybe_float_column, java.lang.Double maybe_double_column, java.nio.ByteBuffer maybe_binary_column, java.lang.String maybe_string_column, java.util.List strings_column, java.util.Map string_to_int_column, java.util.Map> complex_column) { this.bool_column = bool_column; this.int_column = int_column; this.long_column = long_column; @@ -101,7 +101,7 @@ public void put(int field$, java.lang.Object value$) { case 13: maybe_string_column = (java.lang.String)value$; break; case 14: strings_column = (java.util.List)value$; break; case 15: string_to_int_column = (java.util.Map)value$; break; - case 16: complex_column = (java.util.Map>)value$; break; + case 16: complex_column = (java.util.Map>)value$; break; default: throw new org.apache.avro.AvroRuntimeException("Bad index"); } } @@ -349,7 +349,7 @@ public void 
setStringToIntColumn(java.util.Map> getComplexColumn() { + public java.util.Map> getComplexColumn() { return complex_column; } @@ -357,23 +357,23 @@ public java.util.Map> value) { + public void setComplexColumn(java.util.Map> value) { this.complex_column = value; } /** Creates a new ParquetAvroCompat RecordBuilder */ - public static org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder newBuilder() { - return new org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder(); + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder newBuilder() { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder(); } /** Creates a new ParquetAvroCompat RecordBuilder by copying an existing Builder */ - public static org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder newBuilder(org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder other) { - return new org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder(other); + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder newBuilder(org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder other) { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder(other); } /** Creates a new ParquetAvroCompat RecordBuilder by copying an existing ParquetAvroCompat instance */ - public static org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder newBuilder(org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat other) { - return new org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder(other); + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder newBuilder(org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat other) { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder(other); } /** @@ -398,15 +398,15 @@ public static class Builder extends org.apache.avro.specific.SpecificRecordBuild private java.lang.String maybe_string_column; private java.util.List strings_column; private java.util.Map string_to_int_column; - private java.util.Map> complex_column; + private java.util.Map> complex_column; /** Creates a new Builder */ private Builder() { - super(org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.SCHEMA$); + super(org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.SCHEMA$); } /** Creates a Builder by copying an existing Builder */ - private Builder(org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder other) { + private Builder(org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder other) { super(other); if (isValidValue(fields()[0], other.bool_column)) { this.bool_column = data().deepCopy(fields()[0].schema(), other.bool_column); @@ -479,8 +479,8 @@ private Builder(org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder } /** Creates a Builder by copying an existing ParquetAvroCompat instance */ - private Builder(org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat other) { - super(org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.SCHEMA$); + private Builder(org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat other) { + super(org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.SCHEMA$); if (isValidValue(fields()[0], 
other.bool_column)) { this.bool_column = data().deepCopy(fields()[0].schema(), other.bool_column); fieldSetFlags()[0] = true; @@ -557,7 +557,7 @@ public java.lang.Boolean getBoolColumn() { } /** Sets the value of the 'bool_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder setBoolColumn(boolean value) { + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder setBoolColumn(boolean value) { validate(fields()[0], value); this.bool_column = value; fieldSetFlags()[0] = true; @@ -570,7 +570,7 @@ public boolean hasBoolColumn() { } /** Clears the value of the 'bool_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder clearBoolColumn() { + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder clearBoolColumn() { fieldSetFlags()[0] = false; return this; } @@ -581,7 +581,7 @@ public java.lang.Integer getIntColumn() { } /** Sets the value of the 'int_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder setIntColumn(int value) { + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder setIntColumn(int value) { validate(fields()[1], value); this.int_column = value; fieldSetFlags()[1] = true; @@ -594,7 +594,7 @@ public boolean hasIntColumn() { } /** Clears the value of the 'int_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder clearIntColumn() { + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder clearIntColumn() { fieldSetFlags()[1] = false; return this; } @@ -605,7 +605,7 @@ public java.lang.Long getLongColumn() { } /** Sets the value of the 'long_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder setLongColumn(long value) { + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder setLongColumn(long value) { validate(fields()[2], value); this.long_column = value; fieldSetFlags()[2] = true; @@ -618,7 +618,7 @@ public boolean hasLongColumn() { } /** Clears the value of the 'long_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder clearLongColumn() { + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder clearLongColumn() { fieldSetFlags()[2] = false; return this; } @@ -629,7 +629,7 @@ public java.lang.Float getFloatColumn() { } /** Sets the value of the 'float_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder setFloatColumn(float value) { + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder setFloatColumn(float value) { validate(fields()[3], value); this.float_column = value; fieldSetFlags()[3] = true; @@ -642,7 +642,7 @@ public boolean hasFloatColumn() { } /** Clears the value of the 'float_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder clearFloatColumn() { + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder clearFloatColumn() { fieldSetFlags()[3] = false; return this; } @@ -653,7 +653,7 @@ public java.lang.Double getDoubleColumn() { } /** Sets the value of the 'double_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder setDoubleColumn(double value) { + public 
org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder setDoubleColumn(double value) { validate(fields()[4], value); this.double_column = value; fieldSetFlags()[4] = true; @@ -666,7 +666,7 @@ public boolean hasDoubleColumn() { } /** Clears the value of the 'double_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder clearDoubleColumn() { + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder clearDoubleColumn() { fieldSetFlags()[4] = false; return this; } @@ -677,7 +677,7 @@ public java.nio.ByteBuffer getBinaryColumn() { } /** Sets the value of the 'binary_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder setBinaryColumn(java.nio.ByteBuffer value) { + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder setBinaryColumn(java.nio.ByteBuffer value) { validate(fields()[5], value); this.binary_column = value; fieldSetFlags()[5] = true; @@ -690,7 +690,7 @@ public boolean hasBinaryColumn() { } /** Clears the value of the 'binary_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder clearBinaryColumn() { + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder clearBinaryColumn() { binary_column = null; fieldSetFlags()[5] = false; return this; @@ -702,7 +702,7 @@ public java.lang.String getStringColumn() { } /** Sets the value of the 'string_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder setStringColumn(java.lang.String value) { + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder setStringColumn(java.lang.String value) { validate(fields()[6], value); this.string_column = value; fieldSetFlags()[6] = true; @@ -715,7 +715,7 @@ public boolean hasStringColumn() { } /** Clears the value of the 'string_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder clearStringColumn() { + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder clearStringColumn() { string_column = null; fieldSetFlags()[6] = false; return this; @@ -727,7 +727,7 @@ public java.lang.Boolean getMaybeBoolColumn() { } /** Sets the value of the 'maybe_bool_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder setMaybeBoolColumn(java.lang.Boolean value) { + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder setMaybeBoolColumn(java.lang.Boolean value) { validate(fields()[7], value); this.maybe_bool_column = value; fieldSetFlags()[7] = true; @@ -740,7 +740,7 @@ public boolean hasMaybeBoolColumn() { } /** Clears the value of the 'maybe_bool_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder clearMaybeBoolColumn() { + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder clearMaybeBoolColumn() { maybe_bool_column = null; fieldSetFlags()[7] = false; return this; @@ -752,7 +752,7 @@ public java.lang.Integer getMaybeIntColumn() { } /** Sets the value of the 'maybe_int_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder setMaybeIntColumn(java.lang.Integer value) { + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder setMaybeIntColumn(java.lang.Integer value) { validate(fields()[8], 
value); this.maybe_int_column = value; fieldSetFlags()[8] = true; @@ -765,7 +765,7 @@ public boolean hasMaybeIntColumn() { } /** Clears the value of the 'maybe_int_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder clearMaybeIntColumn() { + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder clearMaybeIntColumn() { maybe_int_column = null; fieldSetFlags()[8] = false; return this; @@ -777,7 +777,7 @@ public java.lang.Long getMaybeLongColumn() { } /** Sets the value of the 'maybe_long_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder setMaybeLongColumn(java.lang.Long value) { + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder setMaybeLongColumn(java.lang.Long value) { validate(fields()[9], value); this.maybe_long_column = value; fieldSetFlags()[9] = true; @@ -790,7 +790,7 @@ public boolean hasMaybeLongColumn() { } /** Clears the value of the 'maybe_long_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder clearMaybeLongColumn() { + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder clearMaybeLongColumn() { maybe_long_column = null; fieldSetFlags()[9] = false; return this; @@ -802,7 +802,7 @@ public java.lang.Float getMaybeFloatColumn() { } /** Sets the value of the 'maybe_float_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder setMaybeFloatColumn(java.lang.Float value) { + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder setMaybeFloatColumn(java.lang.Float value) { validate(fields()[10], value); this.maybe_float_column = value; fieldSetFlags()[10] = true; @@ -815,7 +815,7 @@ public boolean hasMaybeFloatColumn() { } /** Clears the value of the 'maybe_float_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder clearMaybeFloatColumn() { + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder clearMaybeFloatColumn() { maybe_float_column = null; fieldSetFlags()[10] = false; return this; @@ -827,7 +827,7 @@ public java.lang.Double getMaybeDoubleColumn() { } /** Sets the value of the 'maybe_double_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder setMaybeDoubleColumn(java.lang.Double value) { + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder setMaybeDoubleColumn(java.lang.Double value) { validate(fields()[11], value); this.maybe_double_column = value; fieldSetFlags()[11] = true; @@ -840,7 +840,7 @@ public boolean hasMaybeDoubleColumn() { } /** Clears the value of the 'maybe_double_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder clearMaybeDoubleColumn() { + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder clearMaybeDoubleColumn() { maybe_double_column = null; fieldSetFlags()[11] = false; return this; @@ -852,7 +852,7 @@ public java.nio.ByteBuffer getMaybeBinaryColumn() { } /** Sets the value of the 'maybe_binary_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder setMaybeBinaryColumn(java.nio.ByteBuffer value) { + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder setMaybeBinaryColumn(java.nio.ByteBuffer value) { validate(fields()[12], value); 
this.maybe_binary_column = value; fieldSetFlags()[12] = true; @@ -865,7 +865,7 @@ public boolean hasMaybeBinaryColumn() { } /** Clears the value of the 'maybe_binary_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder clearMaybeBinaryColumn() { + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder clearMaybeBinaryColumn() { maybe_binary_column = null; fieldSetFlags()[12] = false; return this; @@ -877,7 +877,7 @@ public java.lang.String getMaybeStringColumn() { } /** Sets the value of the 'maybe_string_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder setMaybeStringColumn(java.lang.String value) { + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder setMaybeStringColumn(java.lang.String value) { validate(fields()[13], value); this.maybe_string_column = value; fieldSetFlags()[13] = true; @@ -890,7 +890,7 @@ public boolean hasMaybeStringColumn() { } /** Clears the value of the 'maybe_string_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder clearMaybeStringColumn() { + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder clearMaybeStringColumn() { maybe_string_column = null; fieldSetFlags()[13] = false; return this; @@ -902,7 +902,7 @@ public java.util.List getStringsColumn() { } /** Sets the value of the 'strings_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder setStringsColumn(java.util.List value) { + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder setStringsColumn(java.util.List value) { validate(fields()[14], value); this.strings_column = value; fieldSetFlags()[14] = true; @@ -915,7 +915,7 @@ public boolean hasStringsColumn() { } /** Clears the value of the 'strings_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder clearStringsColumn() { + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder clearStringsColumn() { strings_column = null; fieldSetFlags()[14] = false; return this; @@ -927,7 +927,7 @@ public java.util.Map getStringToIntColumn() } /** Sets the value of the 'string_to_int_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder setStringToIntColumn(java.util.Map value) { + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder setStringToIntColumn(java.util.Map value) { validate(fields()[15], value); this.string_to_int_column = value; fieldSetFlags()[15] = true; @@ -940,19 +940,19 @@ public boolean hasStringToIntColumn() { } /** Clears the value of the 'string_to_int_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder clearStringToIntColumn() { + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder clearStringToIntColumn() { string_to_int_column = null; fieldSetFlags()[15] = false; return this; } /** Gets the value of the 'complex_column' field */ - public java.util.Map> getComplexColumn() { + public java.util.Map> getComplexColumn() { return complex_column; } /** Sets the value of the 'complex_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder setComplexColumn(java.util.Map> value) { + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder 
setComplexColumn(java.util.Map> value) { validate(fields()[16], value); this.complex_column = value; fieldSetFlags()[16] = true; @@ -965,7 +965,7 @@ public boolean hasComplexColumn() { } /** Clears the value of the 'complex_column' field */ - public org.apache.spark.sql.parquet.test.avro.ParquetAvroCompat.Builder clearComplexColumn() { + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder clearComplexColumn() { complex_column = null; fieldSetFlags()[16] = false; return this; @@ -991,7 +991,7 @@ public ParquetAvroCompat build() { record.maybe_string_column = fieldSetFlags()[13] ? this.maybe_string_column : (java.lang.String) defaultValue(fields()[13]); record.strings_column = fieldSetFlags()[14] ? this.strings_column : (java.util.List) defaultValue(fields()[14]); record.string_to_int_column = fieldSetFlags()[15] ? this.string_to_int_column : (java.util.Map) defaultValue(fields()[15]); - record.complex_column = fieldSetFlags()[16] ? this.complex_column : (java.util.Map>) defaultValue(fields()[16]); + record.complex_column = fieldSetFlags()[16] ? this.complex_column : (java.util.Map>) defaultValue(fields()[16]); return record; } catch (Exception e) { throw new org.apache.avro.AvroRuntimeException(e); diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index c49f256be550..adbd95197d7c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -25,8 +25,8 @@ import scala.util.Random import org.apache.spark.sql.catalyst.plans.logical.OneRowRelation import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.functions._ -import org.apache.spark.sql.json.JSONRelation -import org.apache.spark.sql.parquet.ParquetRelation +import org.apache.spark.sql.execution.datasources.json.JSONRelation +import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation import org.apache.spark.sql.types._ import org.apache.spark.sql.test.{ExamplePointUDT, ExamplePoint, SQLTestUtils} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala similarity index 99% rename from sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 92022ff23d2c..73d562189781 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -15,7 +15,7 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.json +package org.apache.spark.sql.execution.datasources.json import java.io.{File, StringWriter} import java.sql.{Date, Timestamp} @@ -28,7 +28,7 @@ import org.apache.spark.sql.{SQLContext, QueryTest, Row, SQLConf} import org.apache.spark.sql.TestData._ import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.datasources.{ResolvedDataSource, LogicalRelation} -import org.apache.spark.sql.json.InferSchema.compatibleType +import org.apache.spark.sql.execution.datasources.json.InferSchema.compatibleType import org.apache.spark.sql.types._ import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.util.Utils diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala similarity index 99% rename from sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala index 369df5653060..6b62c9a003df 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.json +package org.apache.spark.sql.execution.datasources.json import org.apache.spark.rdd.RDD import org.apache.spark.sql.SQLContext diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetAvroCompatibilitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAvroCompatibilitySuite.scala similarity index 96% rename from sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetAvroCompatibilitySuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAvroCompatibilitySuite.scala index bfa427349ff6..4d9c07bb7a57 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetAvroCompatibilitySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAvroCompatibilitySuite.scala @@ -15,7 +15,7 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.parquet +package org.apache.spark.sql.execution.datasources.parquet import java.nio.ByteBuffer import java.util.{List => JList, Map => JMap} @@ -25,7 +25,7 @@ import scala.collection.JavaConversions._ import org.apache.hadoop.fs.Path import org.apache.parquet.avro.AvroParquetWriter -import org.apache.spark.sql.parquet.test.avro.{Nested, ParquetAvroCompat} +import org.apache.spark.sql.execution.datasources.parquet.test.avro.{Nested, ParquetAvroCompat} import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.{Row, SQLContext} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetCompatibilityTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala similarity index 97% rename from sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetCompatibilityTest.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala index 57478931cd50..68f35b1f3aa8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetCompatibilityTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.parquet +package org.apache.spark.sql.execution.datasources.parquet import java.io.File import scala.collection.JavaConversions._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala similarity index 99% rename from sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index b6a7c4fbddbd..7dd9680d8cd6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.parquet +package org.apache.spark.sql.execution.datasources.parquet import org.apache.parquet.filter2.predicate.Operators._ import org.apache.parquet.filter2.predicate.{FilterPredicate, Operators} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala similarity index 99% rename from sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala index b415da5b8c13..ee925afe0850 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.parquet +package org.apache.spark.sql.execution.datasources.parquet import scala.collection.JavaConversions._ import scala.reflect.ClassTag @@ -373,7 +373,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest { // _temporary should be missing if direct output committer works. 
try { configuration.set("spark.sql.parquet.output.committer.class", - "org.apache.spark.sql.parquet.DirectParquetOutputCommitter") + classOf[DirectParquetOutputCommitter].getCanonicalName) sqlContext.udf.register("div0", (x: Int) => x / 0) withTempPath { dir => intercept[org.apache.spark.SparkException] { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala similarity index 99% rename from sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala index 2eef10189f11..73152de24475 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.parquet +package org.apache.spark.sql.execution.datasources.parquet import java.io.File import java.math.BigInteger diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala similarity index 99% rename from sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala index 5c65a8ec57f0..5e6d9c1cd44a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.parquet +package org.apache.spark.sql.execution.datasources.parquet import java.io.File diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala similarity index 99% rename from sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala index 4a0b3b60f419..8f06de7ce7c4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala @@ -15,7 +15,7 @@ * limitations under the License. 
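In the ParquetIOSuite hunk above, the output committer is no longer configured with a hard-coded "org.apache.spark.sql.parquet.DirectParquetOutputCommitter" string but with classOf[DirectParquetOutputCommitter].getCanonicalName, so the config value tracks the class through package moves like the one in this patch, and a stale name surfaces as a compile error rather than a silent runtime fallback. The idiom in isolation; the package declaration is an assumption that both the suite and the committer now live in the relocated execution.datasources.parquet package:

{{{
package org.apache.spark.sql.execution.datasources.parquet   // assumed post-move package

import org.apache.hadoop.conf.Configuration

object CommitterConfigSketch {
  // classOf resolves at compile time, so a later rename or move of the committer
  // breaks this line loudly instead of leaving a stale string literal behind.
  def configure(configuration: Configuration): Unit = {
    configuration.set(
      "spark.sql.parquet.output.committer.class",
      classOf[DirectParquetOutputCommitter].getCanonicalName)
  }
}
}}}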
*/ -package org.apache.spark.sql.parquet +package org.apache.spark.sql.execution.datasources.parquet import scala.reflect.ClassTag import scala.reflect.runtime.universe.TypeTag diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala similarity index 98% rename from sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetTest.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala index 64e94056f209..3c6e54db4bca 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.parquet +package org.apache.spark.sql.execution.datasources.parquet import java.io.File diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetThriftCompatibilitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetThriftCompatibilitySuite.scala similarity index 98% rename from sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetThriftCompatibilitySuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetThriftCompatibilitySuite.scala index 1c532d78790d..92b1d822172d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetThriftCompatibilitySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetThriftCompatibilitySuite.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.parquet +package org.apache.spark.sql.execution.datasources.parquet import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.{Row, SQLContext} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala similarity index 92% rename from sql/core/src/test/scala/org/apache/spark/sql/metric/SQLMetricsSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index d22160f5384f..953284c98b20 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.metric +package org.apache.spark.sql.execution.metric import java.io.{ByteArrayInputStream, ByteArrayOutputStream} @@ -41,16 +41,6 @@ class SQLMetricsSuite extends SparkFunSuite { } } - test("IntSQLMetric should not box Int") { - val l = SQLMetrics.createIntMetric(TestSQLContext.sparkContext, "Int") - val f = () => { l += 1 } - BoxingFinder.getClassReader(f.getClass).foreach { cl => - val boxingFinder = new BoxingFinder() - cl.accept(boxingFinder, 0) - assert(boxingFinder.boxingInvokes.isEmpty, s"Found boxing: ${boxingFinder.boxingInvokes}") - } - } - test("Normal accumulator should do boxing") { // We need this test to make sure BoxingFinder works. 
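The suite changes further down this patch (CreateTableAsSelectSuite switching to USING json, DDLSourceLoadSuite, ResolvedDataSourceSuite) all exercise the DataSourceRegister rename from format() to shortName(): a provider advertises a short alias, and lookupDataSource accepts either that alias or a fully qualified class name, including the pre-move org.apache.spark.sql.{json,parquet,jdbc} names. A minimal provider written against the renamed trait; the class name, alias and single-column schema are illustrative, mirroring the FakeSource test classes below:

{{{
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.sources.{BaseRelation, DataSourceRegister, RelationProvider}
import org.apache.spark.sql.types.{StringType, StructField, StructType}

// A relation provider registered under a short alias. With a matching
// META-INF/services/org.apache.spark.sql.sources.DataSourceRegister entry it can
// be loaded as read.format("demo") instead of its fully qualified class name.
class DemoSource extends RelationProvider with DataSourceRegister {

  override def shortName(): String = "demo"            // the alias passed to .format(...)

  override def createRelation(
      ctx: SQLContext,
      parameters: Map[String, String]): BaseRelation = new BaseRelation {
    override def sqlContext: SQLContext = ctx
    override def schema: StructType =
      StructType(Seq(StructField("stringType", StringType, nullable = false)))
  }
}
}}}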
val l = TestSQLContext.sparkContext.accumulator(0L) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ui/SQLListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala similarity index 99% rename from sql/core/src/test/scala/org/apache/spark/sql/ui/SQLListenerSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala index 69a561e16aa1..41dd1896c15d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ui/SQLListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala @@ -15,13 +15,13 @@ * limitations under the License. */ -package org.apache.spark.sql.ui +package org.apache.spark.sql.execution.ui import java.util.Properties import org.apache.spark.{SparkException, SparkContext, SparkConf, SparkFunSuite} import org.apache.spark.executor.TaskMetrics -import org.apache.spark.sql.metric.LongSQLMetricValue +import org.apache.spark.sql.execution.metric.LongSQLMetricValue import org.apache.spark.scheduler._ import org.apache.spark.sql.{DataFrame, SQLContext} import org.apache.spark.sql.execution.SQLExecution diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala index 1907e643c85d..562c27906704 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala @@ -51,7 +51,7 @@ class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll { sql( s""" |CREATE TEMPORARY TABLE jsonTable - |USING org.apache.spark.sql.json.DefaultSource + |USING json |OPTIONS ( | path '${path.toString}' |) AS @@ -75,7 +75,7 @@ class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll { sql( s""" |CREATE TEMPORARY TABLE jsonTable - |USING org.apache.spark.sql.json.DefaultSource + |USING json |OPTIONS ( | path '${path.toString}' |) AS @@ -92,7 +92,7 @@ class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll { sql( s""" |CREATE TEMPORARY TABLE jsonTable - |USING org.apache.spark.sql.json.DefaultSource + |USING json |OPTIONS ( | path '${path.toString}' |) AS @@ -107,7 +107,7 @@ class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll { sql( s""" |CREATE TEMPORARY TABLE IF NOT EXISTS jsonTable - |USING org.apache.spark.sql.json.DefaultSource + |USING json |OPTIONS ( | path '${path.toString}' |) AS @@ -122,7 +122,7 @@ class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll { sql( s""" |CREATE TEMPORARY TABLE jsonTable - |USING org.apache.spark.sql.json.DefaultSource + |USING json |OPTIONS ( | path '${path.toString}' |) AS @@ -139,7 +139,7 @@ class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll { sql( s""" |CREATE TEMPORARY TABLE jsonTable - |USING org.apache.spark.sql.json.DefaultSource + |USING json |OPTIONS ( | path '${path.toString}' |) AS @@ -158,7 +158,7 @@ class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll { sql( s""" |CREATE TEMPORARY TABLE IF NOT EXISTS jsonTable - |USING org.apache.spark.sql.json.DefaultSource + |USING json |OPTIONS ( | path '${path.toString}' |) AS @@ -175,7 +175,7 @@ class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll { sql( s""" |CREATE TEMPORARY TABLE jsonTable (a int, b string) - |USING org.apache.spark.sql.json.DefaultSource + |USING json 
|OPTIONS ( | path '${path.toString}' |) AS @@ -188,7 +188,7 @@ class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll { sql( s""" |CREATE TEMPORARY TABLE jsonTable - |USING org.apache.spark.sql.json.DefaultSource + |USING json |OPTIONS ( | path '${path.toString}' |) AS @@ -199,7 +199,7 @@ class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll { sql( s""" |CREATE TEMPORARY TABLE jsonTable - |USING org.apache.spark.sql.json.DefaultSource + |USING json |OPTIONS ( | path '${path.toString}' |) AS diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLSourceLoadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLSourceLoadSuite.scala index 1a4d41b02ca6..392da0b0826b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLSourceLoadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLSourceLoadSuite.scala @@ -20,9 +20,37 @@ package org.apache.spark.sql.sources import org.apache.spark.sql.SQLContext import org.apache.spark.sql.types.{StringType, StructField, StructType} + +// please note that the META-INF/services had to be modified for the test directory for this to work +class DDLSourceLoadSuite extends DataSourceTest { + + test("data sources with the same name") { + intercept[RuntimeException] { + caseInsensitiveContext.read.format("Fluet da Bomb").load() + } + } + + test("load data source from format alias") { + caseInsensitiveContext.read.format("gathering quorum").load().schema == + StructType(Seq(StructField("stringType", StringType, nullable = false))) + } + + test("specify full classname with duplicate formats") { + caseInsensitiveContext.read.format("org.apache.spark.sql.sources.FakeSourceOne") + .load().schema == StructType(Seq(StructField("stringType", StringType, nullable = false))) + } + + test("should fail to load ORC without HiveContext") { + intercept[ClassNotFoundException] { + caseInsensitiveContext.read.format("orc").load() + } + } +} + + class FakeSourceOne extends RelationProvider with DataSourceRegister { - def format(): String = "Fluet da Bomb" + def shortName(): String = "Fluet da Bomb" override def createRelation(cont: SQLContext, param: Map[String, String]): BaseRelation = new BaseRelation { @@ -35,7 +63,7 @@ class FakeSourceOne extends RelationProvider with DataSourceRegister { class FakeSourceTwo extends RelationProvider with DataSourceRegister { - def format(): String = "Fluet da Bomb" + def shortName(): String = "Fluet da Bomb" override def createRelation(cont: SQLContext, param: Map[String, String]): BaseRelation = new BaseRelation { @@ -48,7 +76,7 @@ class FakeSourceTwo extends RelationProvider with DataSourceRegister { class FakeSourceThree extends RelationProvider with DataSourceRegister { - def format(): String = "gathering quorum" + def shortName(): String = "gathering quorum" override def createRelation(cont: SQLContext, param: Map[String, String]): BaseRelation = new BaseRelation { @@ -58,28 +86,3 @@ class FakeSourceThree extends RelationProvider with DataSourceRegister { StructType(Seq(StructField("stringType", StringType, nullable = false))) } } -// please note that the META-INF/services had to be modified for the test directory for this to work -class DDLSourceLoadSuite extends DataSourceTest { - - test("data sources with the same name") { - intercept[RuntimeException] { - caseInsensitiveContext.read.format("Fluet da Bomb").load() - } - } - - test("load data source from format alias") { - caseInsensitiveContext.read.format("gathering 
quorum").load().schema == - StructType(Seq(StructField("stringType", StringType, nullable = false))) - } - - test("specify full classname with duplicate formats") { - caseInsensitiveContext.read.format("org.apache.spark.sql.sources.FakeSourceOne") - .load().schema == StructType(Seq(StructField("stringType", StringType, nullable = false))) - } - - test("Loading Orc") { - intercept[ClassNotFoundException] { - caseInsensitiveContext.read.format("orc").load() - } - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala index 3cbf5467b253..27d1cd92fca1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala @@ -22,14 +22,39 @@ import org.apache.spark.sql.execution.datasources.ResolvedDataSource class ResolvedDataSourceSuite extends SparkFunSuite { - test("builtin sources") { - assert(ResolvedDataSource.lookupDataSource("jdbc") === - classOf[org.apache.spark.sql.jdbc.DefaultSource]) + test("jdbc") { + assert( + ResolvedDataSource.lookupDataSource("jdbc") === + classOf[org.apache.spark.sql.execution.datasources.jdbc.DefaultSource]) + assert( + ResolvedDataSource.lookupDataSource("org.apache.spark.sql.execution.datasources.jdbc") === + classOf[org.apache.spark.sql.execution.datasources.jdbc.DefaultSource]) + assert( + ResolvedDataSource.lookupDataSource("org.apache.spark.sql.jdbc") === + classOf[org.apache.spark.sql.execution.datasources.jdbc.DefaultSource]) + } - assert(ResolvedDataSource.lookupDataSource("json") === - classOf[org.apache.spark.sql.json.DefaultSource]) + test("json") { + assert( + ResolvedDataSource.lookupDataSource("json") === + classOf[org.apache.spark.sql.execution.datasources.json.DefaultSource]) + assert( + ResolvedDataSource.lookupDataSource("org.apache.spark.sql.execution.datasources.json") === + classOf[org.apache.spark.sql.execution.datasources.json.DefaultSource]) + assert( + ResolvedDataSource.lookupDataSource("org.apache.spark.sql.json") === + classOf[org.apache.spark.sql.execution.datasources.json.DefaultSource]) + } - assert(ResolvedDataSource.lookupDataSource("parquet") === - classOf[org.apache.spark.sql.parquet.DefaultSource]) + test("parquet") { + assert( + ResolvedDataSource.lookupDataSource("parquet") === + classOf[org.apache.spark.sql.execution.datasources.parquet.DefaultSource]) + assert( + ResolvedDataSource.lookupDataSource("org.apache.spark.sql.execution.datasources.parquet") === + classOf[org.apache.spark.sql.execution.datasources.parquet.DefaultSource]) + assert( + ResolvedDataSource.lookupDataSource("org.apache.spark.sql.parquet") === + classOf[org.apache.spark.sql.execution.datasources.parquet.DefaultSource]) } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 7198a32df4a0..ac9aaed19d56 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -41,7 +41,7 @@ import org.apache.spark.sql.catalyst.{InternalRow, SqlParser, TableIdentifier} import org.apache.spark.sql.execution.datasources import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, LogicalRelation, Partition => ParquetPartition, PartitionSpec, ResolvedDataSource} import 
org.apache.spark.sql.hive.client._ -import org.apache.spark.sql.parquet.ParquetRelation +import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ import org.apache.spark.sql.{AnalysisException, SQLContext, SaveMode} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala index 0c344c63fde3..9f4f8b5789af 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala @@ -32,7 +32,6 @@ import org.apache.hadoop.mapreduce.lib.input.FileInputFormat import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext} import org.apache.spark.Logging -import org.apache.spark.annotation.DeveloperApi import org.apache.spark.mapred.SparkHadoopMapRedUtil import org.apache.spark.rdd.{HadoopRDD, RDD} import org.apache.spark.sql.catalyst.InternalRow @@ -49,9 +48,9 @@ import scala.collection.JavaConversions._ private[sql] class DefaultSource extends HadoopFsRelationProvider with DataSourceRegister { - def format(): String = "orc" + override def shortName(): String = "orc" - def createRelation( + override def createRelation( sqlContext: SQLContext, paths: Array[String], dataSchema: Option[StructType], @@ -144,7 +143,6 @@ private[orc] class OrcOutputWriter( } } -@DeveloperApi private[sql] class OrcRelation( override val paths: Array[String], maybeDataSchema: Option[StructType], diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala index a45c2d957278..1fa005d5f9a1 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.hive import org.apache.spark.sql.hive.test.TestHive -import org.apache.spark.sql.parquet.ParquetTest +import org.apache.spark.sql.execution.datasources.parquet.ParquetTest import org.apache.spark.sql.{QueryTest, Row} case class Cases(lower: String, UPPER: String) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index b73d6665755d..7f36a483a396 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.hive.client.{HiveTable, ManagedTable} import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.sql.hive.test.TestHive.implicits._ -import org.apache.spark.sql.parquet.ParquetRelation +import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ import org.apache.spark.util.Utils diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala index f00d3754c364..80eb9f122ad9 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala @@ -20,7 +20,7 @@ package 
org.apache.spark.sql.hive import org.apache.hadoop.hive.conf.HiveConf import org.apache.spark.sql.hive.test.TestHive -import org.apache.spark.sql.parquet.ParquetCompatibilityTest +import org.apache.spark.sql.execution.datasources.parquet.ParquetCompatibilityTest import org.apache.spark.sql.{Row, SQLConf, SQLContext} class ParquetHiveCompatibilitySuite extends ParquetCompatibilityTest { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 2fa7ae3fa2e1..79a136ae6f61 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.sql.hive.test.TestHive.implicits._ import org.apache.spark.sql.hive.{HiveContext, HiveQLDialect, MetastoreRelation} -import org.apache.spark.sql.parquet.ParquetRelation +import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.CalendarInterval diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala index c4bc60086f6e..50f02432dacc 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.hive.execution.HiveTableScan import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.sql.hive.test.TestHive.implicits._ -import org.apache.spark.sql.parquet.ParquetRelation +import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ import org.apache.spark.util.Utils diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala index d280543a071d..cb4cedddbfdd 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala @@ -23,12 +23,12 @@ import com.google.common.io.Files import org.apache.hadoop.fs.Path import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.sql.{AnalysisException, SaveMode, parquet} +import org.apache.spark.sql.{AnalysisException, SaveMode} import org.apache.spark.sql.types.{IntegerType, StructField, StructType} class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest { - override val dataSourceName: String = classOf[parquet.DefaultSource].getCanonicalName + override val dataSourceName: String = "parquet" import sqlContext._ import sqlContext.implicits._ diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala index 1813cc33226d..48c37a1fa102 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala 
@@ -53,7 +53,7 @@ class SimpleTextHadoopFsRelationSuite extends HadoopFsRelationTest { class JsonHadoopFsRelationSuite extends HadoopFsRelationTest { override val dataSourceName: String = - classOf[org.apache.spark.sql.json.DefaultSource].getCanonicalName + classOf[org.apache.spark.sql.execution.datasources.json.DefaultSource].getCanonicalName import sqlContext._ From fe2fb7fb7189d183a4273ad27514af4b6b461f26 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 10 Aug 2015 13:52:18 -0700 Subject: [PATCH 0151/1824] [SPARK-9620] [SQL] generated UnsafeProjection should support many columns or large exressions Currently, generated UnsafeProjection can reach 64k byte code limit of Java. This patch will split the generated expressions into multiple functions, to avoid the limitation. After this patch, we can work well with table that have up to 64k columns (hit max number of constants limit in Java), it should be enough in practice. cc rxin Author: Davies Liu Closes #8044 from davies/wider_table and squashes the following commits: 9192e6c [Davies Liu] fix generated safe projection d1ef81a [Davies Liu] fix failed tests 737b3d3 [Davies Liu] Merge branch 'master' of github.com:apache/spark into wider_table ffcd132 [Davies Liu] address comments 1b95be4 [Davies Liu] put the generated class into sql package 77ed72d [Davies Liu] address comments 4518e17 [Davies Liu] Merge branch 'master' of github.com:apache/spark into wider_table 75ccd01 [Davies Liu] Merge branch 'master' of github.com:apache/spark into wider_table 495e932 [Davies Liu] support wider table with more than 1k columns for generated projections --- .../expressions/codegen/CodeGenerator.scala | 48 ++++++- .../codegen/GenerateMutableProjection.scala | 43 +----- .../codegen/GenerateSafeProjection.scala | 52 ++------ .../codegen/GenerateUnsafeProjection.scala | 122 +++++++++--------- .../codegen/GenerateUnsafeRowJoiner.scala | 2 +- .../codegen/GeneratedProjectionSuite.scala | 82 ++++++++++++ 6 files changed, 207 insertions(+), 142 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 7b41c9a3f3b8..c21f4d626a74 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions.codegen import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer import scala.language.existentials import com.google.common.cache.{CacheBuilder, CacheLoader} @@ -265,6 +266,45 @@ class CodeGenContext { def isPrimitiveType(jt: String): Boolean = primitiveTypes.contains(jt) def isPrimitiveType(dt: DataType): Boolean = isPrimitiveType(javaType(dt)) + + /** + * Splits the generated code of expressions into multiple functions, because function has + * 64kb code size limit in JVM + * + * @param row the variable name of row that is used by expressions + */ + def splitExpressions(row: String, expressions: Seq[String]): String = { + val blocks = new ArrayBuffer[String]() + val blockBuilder = new StringBuilder() + for (code <- expressions) { + // We can't know how many byte code will be generated, so use the number of bytes as limit + if 
(blockBuilder.length > 64 * 1000) { + blocks.append(blockBuilder.toString()) + blockBuilder.clear() + } + blockBuilder.append(code) + } + blocks.append(blockBuilder.toString()) + + if (blocks.length == 1) { + // inline execution if only one block + blocks.head + } else { + val apply = freshName("apply") + val functions = blocks.zipWithIndex.map { case (body, i) => + val name = s"${apply}_$i" + val code = s""" + |private void $name(InternalRow $row) { + | $body + |} + """.stripMargin + addNewFunction(name, code) + name + } + + functions.map(name => s"$name($row);").mkString("\n") + } + } } /** @@ -289,15 +329,15 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin protected def declareMutableStates(ctx: CodeGenContext): String = { ctx.mutableStates.map { case (javaType, variableName, _) => s"private $javaType $variableName;" - }.mkString + }.mkString("\n") } protected def initMutableStates(ctx: CodeGenContext): String = { - ctx.mutableStates.map(_._3).mkString + ctx.mutableStates.map(_._3).mkString("\n") } protected def declareAddedFunctions(ctx: CodeGenContext): String = { - ctx.addedFuntions.map { case (funcName, funcCode) => funcCode }.mkString + ctx.addedFuntions.map { case (funcName, funcCode) => funcCode }.mkString("\n") } /** @@ -328,6 +368,8 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin private[this] def doCompile(code: String): GeneratedClass = { val evaluator = new ClassBodyEvaluator() evaluator.setParentClassLoader(getClass.getClassLoader) + // Cannot be under package codegen, or fail with java.lang.InstantiationException + evaluator.setClassName("org.apache.spark.sql.catalyst.expressions.GeneratedClass") evaluator.setDefaultImports(Array( classOf[PlatformDependent].getName, classOf[InternalRow].getName, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index ac58423cd884..b4d4df8934bd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -40,7 +40,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu protected def create(expressions: Seq[Expression]): (() => MutableProjection) = { val ctx = newCodeGenContext() - val projectionCode = expressions.zipWithIndex.map { + val projectionCodes = expressions.zipWithIndex.map { case (NoOp, _) => "" case (e, i) => val evaluationCode = e.gen(ctx) @@ -65,49 +65,21 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu """ } } - // collect projections into blocks as function has 64kb codesize limit in JVM - val projectionBlocks = new ArrayBuffer[String]() - val blockBuilder = new StringBuilder() - for (projection <- projectionCode) { - if (blockBuilder.length > 16 * 1000) { - projectionBlocks.append(blockBuilder.toString()) - blockBuilder.clear() - } - blockBuilder.append(projection) - } - projectionBlocks.append(blockBuilder.toString()) - - val (projectionFuns, projectionCalls) = { - // inline execution if codesize limit was not broken - if (projectionBlocks.length == 1) { - ("", projectionBlocks.head) - } else { - ( - projectionBlocks.zipWithIndex.map { case (body, i) => - s""" - |private void apply$i(InternalRow i) { - | $body - |} - 
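Stripped of the code-generation context, the chunking idea behind splitExpressions fits in a few lines. The following is a standalone sketch with hypothetical names, using source length as a rough proxy for bytecode size just as the real method does:

import scala.collection.mutable.ArrayBuffer

object SplitSketch {
  // Groups generated statements into blocks of bounded source size and wraps
  // each block in its own helper method, so no single generated method can
  // exceed the JVM's 64KB bytecode limit. Returns (helper definitions, body),
  // where the body is either the inlined single block or the helper calls.
  def split(row: String, statements: Seq[String], limit: Int = 64 * 1000): (Seq[String], String) = {
    val blocks = ArrayBuffer.empty[String]
    val current = new StringBuilder
    for (stmt <- statements) {
      if (current.length > limit) {
        blocks += current.toString
        current.clear()
      }
      current.append(stmt).append("\n")
    }
    blocks += current.toString

    if (blocks.length == 1) {
      (Nil, blocks.head) // small enough: no helpers, inline the code
    } else {
      val defs = blocks.zipWithIndex.map { case (body, i) =>
        s"private void apply_$i(InternalRow $row) {\n$body}"
      }
      val calls = blocks.indices.map(i => s"apply_$i($row);").mkString("\n")
      (defs.toList, calls)
    }
  }
}

In the generator itself each helper body is registered through addNewFunction and only the sequence of calls is inlined into apply(), which is what keeps every generated method under the limit.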
""".stripMargin - }.mkString, - projectionBlocks.indices.map(i => s"apply$i(i);").mkString("\n") - ) - } - } + val allProjections = ctx.splitExpressions("i", projectionCodes) val code = s""" public Object generate($exprType[] expr) { - return new SpecificProjection(expr); + return new SpecificMutableProjection(expr); } - class SpecificProjection extends ${classOf[BaseMutableProjection].getName} { + class SpecificMutableProjection extends ${classOf[BaseMutableProjection].getName} { private $exprType[] expressions; private $mutableRowType mutableRow; ${declareMutableStates(ctx)} ${declareAddedFunctions(ctx)} - public SpecificProjection($exprType[] expr) { + public SpecificMutableProjection($exprType[] expr) { expressions = expr; mutableRow = new $genericMutableRowType(${expressions.size}); ${initMutableStates(ctx)} @@ -123,12 +95,9 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu return (InternalRow) mutableRow; } - $projectionFuns - public Object apply(Object _i) { InternalRow i = (InternalRow) _i; - $projectionCalls - + $allProjections return mutableRow; } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala index ef08ddf041af..7ad352d7ce3e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.catalyst.expressions.codegen -import scala.collection.mutable.ArrayBuffer - import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp import org.apache.spark.sql.types._ @@ -43,6 +41,9 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] val tmp = ctx.freshName("tmp") val output = ctx.freshName("safeRow") val values = ctx.freshName("values") + // These expressions could be splitted into multiple functions + ctx.addMutableState("Object[]", values, s"this.$values = null;") + val rowClass = classOf[GenericInternalRow].getName val fieldWriters = schema.map(_.dataType).zipWithIndex.map { case (dt, i) => @@ -53,12 +54,12 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] $values[$i] = ${converter.primitive}; } """ - }.mkString("\n") - + } + val allFields = ctx.splitExpressions(tmp, fieldWriters) val code = s""" final InternalRow $tmp = $input; - final Object[] $values = new Object[${schema.length}]; - $fieldWriters + this.$values = new Object[${schema.length}]; + $allFields final InternalRow $output = new $rowClass($values); """ @@ -128,7 +129,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] protected def create(expressions: Seq[Expression]): Projection = { val ctx = newCodeGenContext() - val projectionCode = expressions.zipWithIndex.map { + val expressionCodes = expressions.zipWithIndex.map { case (NoOp, _) => "" case (e, i) => val evaluationCode = e.gen(ctx) @@ -143,36 +144,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] } """ } - // collect projections into blocks as function has 64kb codesize limit in JVM - val projectionBlocks = new ArrayBuffer[String]() - val blockBuilder = new StringBuilder() - for (projection <- projectionCode) { - if (blockBuilder.length > 16 * 1000) { - 
projectionBlocks.append(blockBuilder.toString()) - blockBuilder.clear() - } - blockBuilder.append(projection) - } - projectionBlocks.append(blockBuilder.toString()) - - val (projectionFuns, projectionCalls) = { - // inline it if we have only one block - if (projectionBlocks.length == 1) { - ("", projectionBlocks.head) - } else { - ( - projectionBlocks.zipWithIndex.map { case (body, i) => - s""" - |private void apply$i(InternalRow i) { - | $body - |} - """.stripMargin - }.mkString, - projectionBlocks.indices.map(i => s"apply$i(i);").mkString("\n") - ) - } - } - + val allExpressions = ctx.splitExpressions("i", expressionCodes) val code = s""" public Object generate($exprType[] expr) { return new SpecificSafeProjection(expr); @@ -183,6 +155,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] private $exprType[] expressions; private $mutableRowType mutableRow; ${declareMutableStates(ctx)} + ${declareAddedFunctions(ctx)} public SpecificSafeProjection($exprType[] expr) { expressions = expr; @@ -190,12 +163,9 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] ${initMutableStates(ctx)} } - $projectionFuns - public Object apply(Object _i) { InternalRow i = (InternalRow) _i; - $projectionCalls - + $allExpressions return mutableRow; } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index d8912df694a1..29f6a7b98175 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.catalyst.expressions.codegen import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.PlatformDependent /** * Generates a [[Projection]] that returns an [[UnsafeRow]]. @@ -41,8 +40,6 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro private val ArrayWriter = classOf[UnsafeRowWriters.ArrayWriter].getName private val MapWriter = classOf[UnsafeRowWriters.MapWriter].getName - private val PlatformDependent = classOf[PlatformDependent].getName - /** Returns true iff we support this data type. */ def canSupport(dataType: DataType): Boolean = dataType match { case NullType => true @@ -56,19 +53,19 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro def genAdditionalSize(dt: DataType, ev: GeneratedExpressionCode): String = dt match { case t: DecimalType if t.precision > Decimal.MAX_LONG_DIGITS => - s" + $DecimalWriter.getSize(${ev.primitive})" + s"$DecimalWriter.getSize(${ev.primitive})" case StringType => - s" + (${ev.isNull} ? 0 : $StringWriter.getSize(${ev.primitive}))" + s"${ev.isNull} ? 0 : $StringWriter.getSize(${ev.primitive})" case BinaryType => - s" + (${ev.isNull} ? 0 : $BinaryWriter.getSize(${ev.primitive}))" + s"${ev.isNull} ? 0 : $BinaryWriter.getSize(${ev.primitive})" case CalendarIntervalType => - s" + (${ev.isNull} ? 0 : 16)" + s"${ev.isNull} ? 0 : 16" case _: StructType => - s" + (${ev.isNull} ? 0 : $StructWriter.getSize(${ev.primitive}))" + s"${ev.isNull} ? 0 : $StructWriter.getSize(${ev.primitive})" case _: ArrayType => - s" + (${ev.isNull} ? 0 : $ArrayWriter.getSize(${ev.primitive}))" + s"${ev.isNull} ? 
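The reason the temporary values array is promoted to a mutable state field is visibility: once the field writers are split across helper methods, a local variable declared in apply() would be out of scope for them. A hand-written sketch of the resulting class shape, with a plain Array[Any] standing in for the generated row types:

class SafeProjectionShape {
  // Shared by apply() and every split-off writer; a local in apply()
  // could not be seen from the helper methods.
  private var values: Array[Any] = _

  private def writeFirstHalf(input: Array[Any]): Unit = { values(0) = input(0) }
  private def writeSecondHalf(input: Array[Any]): Unit = { values(1) = input(1) }

  def apply(input: Array[Any]): Array[Any] = {
    values = new Array[Any](2)
    writeFirstHalf(input)
    writeSecondHalf(input)
    values
  }
}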
0 : $ArrayWriter.getSize(${ev.primitive})" case _: MapType => - s" + (${ev.isNull} ? 0 : $MapWriter.getSize(${ev.primitive}))" + s"${ev.isNull} ? 0 : $MapWriter.getSize(${ev.primitive})" case _ => "" } @@ -125,64 +122,69 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro */ private def createCodeForStruct( ctx: CodeGenContext, + row: String, inputs: Seq[GeneratedExpressionCode], inputTypes: Seq[DataType]): GeneratedExpressionCode = { + val fixedSize = 8 * inputTypes.length + UnsafeRow.calculateBitSetWidthInBytes(inputTypes.length) + val output = ctx.freshName("convertedStruct") - ctx.addMutableState("UnsafeRow", output, s"$output = new UnsafeRow();") + ctx.addMutableState("UnsafeRow", output, s"this.$output = new UnsafeRow();") val buffer = ctx.freshName("buffer") - ctx.addMutableState("byte[]", buffer, s"$buffer = new byte[64];") - val numBytes = ctx.freshName("numBytes") + ctx.addMutableState("byte[]", buffer, s"this.$buffer = new byte[$fixedSize];") val cursor = ctx.freshName("cursor") + ctx.addMutableState("int", cursor, s"this.$cursor = 0;") + val tmp = ctx.freshName("tmpBuffer") - val convertedFields = inputTypes.zip(inputs).map { case (dt, input) => - createConvertCode(ctx, input, dt) - } - - val fixedSize = 8 * inputTypes.length + UnsafeRow.calculateBitSetWidthInBytes(inputTypes.length) - val additionalSize = inputTypes.zip(convertedFields).map { case (dt, ev) => - genAdditionalSize(dt, ev) - }.mkString("") - - val fieldWriters = inputTypes.zip(convertedFields).zipWithIndex.map { case ((dt, ev), i) => - val update = genFieldWriter(ctx, dt, ev, output, i, cursor) - if (dt.isInstanceOf[DecimalType]) { - // Can't call setNullAt() for DecimalType + val convertedFields = inputTypes.zip(inputs).zipWithIndex.map { case ((dt, input), i) => + val ev = createConvertCode(ctx, input, dt) + val growBuffer = if (!UnsafeRow.isFixedLength(dt)) { + val numBytes = ctx.freshName("numBytes") s""" + int $numBytes = $cursor + (${genAdditionalSize(dt, ev)}); + if ($buffer.length < $numBytes) { + // This will not happen frequently, because the buffer is re-used. 
+ byte[] $tmp = new byte[$numBytes * 2]; + PlatformDependent.copyMemory($buffer, PlatformDependent.BYTE_ARRAY_OFFSET, + $tmp, PlatformDependent.BYTE_ARRAY_OFFSET, $buffer.length); + $buffer = $tmp; + } + $output.pointTo($buffer, PlatformDependent.BYTE_ARRAY_OFFSET, + ${inputTypes.length}, $numBytes); + """ + } else { + "" + } + val update = dt match { + case dt: DecimalType if dt.precision > Decimal.MAX_LONG_DIGITS => + // Can't call setNullAt() for DecimalType + s""" if (${ev.isNull}) { - $cursor += $DecimalWriter.write($output, $i, $cursor, null); + $cursor += $DecimalWriter.write($output, $i, $cursor, null); } else { - $update; + ${genFieldWriter(ctx, dt, ev, output, i, cursor)}; } """ - } else { - s""" + case _ => + s""" if (${ev.isNull}) { $output.setNullAt($i); } else { - $update; + ${genFieldWriter(ctx, dt, ev, output, i, cursor)}; } """ } - }.mkString("\n") + s""" + ${ev.code} + $growBuffer + $update + """ + } val code = s""" - ${convertedFields.map(_.code).mkString("\n")} - - final int $numBytes = $fixedSize $additionalSize; - if ($numBytes > $buffer.length) { - $buffer = new byte[$numBytes]; - } - - $output.pointTo( - $buffer, - $PlatformDependent.BYTE_ARRAY_OFFSET, - ${inputTypes.length}, - $numBytes); - - int $cursor = $fixedSize; - - $fieldWriters + $cursor = $fixedSize; + $output.pointTo($buffer, PlatformDependent.BYTE_ARRAY_OFFSET, ${inputTypes.length}, $cursor); + ${ctx.splitExpressions(row, convertedFields)} """ GeneratedExpressionCode(code, "false", output) } @@ -265,17 +267,17 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro // Should we do word align? val elementSize = elementType.defaultSize s""" - $PlatformDependent.UNSAFE.put${ctx.primitiveTypeName(elementType)}( + PlatformDependent.UNSAFE.put${ctx.primitiveTypeName(elementType)}( $buffer, - $PlatformDependent.BYTE_ARRAY_OFFSET + $cursor, + PlatformDependent.BYTE_ARRAY_OFFSET + $cursor, ${convertedElement.primitive}); $cursor += $elementSize; """ case t: DecimalType if t.precision <= Decimal.MAX_LONG_DIGITS => s""" - $PlatformDependent.UNSAFE.putLong( + PlatformDependent.UNSAFE.putLong( $buffer, - $PlatformDependent.BYTE_ARRAY_OFFSET + $cursor, + PlatformDependent.BYTE_ARRAY_OFFSET + $cursor, ${convertedElement.primitive}.toUnscaledLong()); $cursor += 8; """ @@ -284,7 +286,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro s""" $cursor += $writer.write( $buffer, - $PlatformDependent.BYTE_ARRAY_OFFSET + $cursor, + PlatformDependent.BYTE_ARRAY_OFFSET + $cursor, $elements[$index]); """ } @@ -318,14 +320,14 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro for (int $index = 0; $index < $numElements; $index++) { if ($checkNull) { // If element is null, write the negative value address into offset region. 
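The grow-and-copy branch above follows a standard amortized-resizing policy: reuse one buffer across rows and only reallocate when a row's size estimate exceeds the current capacity. In plain Scala the same policy looks roughly like this sketch:

object BufferGrowthSketch {
  // Doubling the allocation amortizes future copies; the existing bytes are
  // copied over so the row object can keep pointing at a valid buffer.
  def ensureCapacity(buffer: Array[Byte], neededBytes: Int): Array[Byte] = {
    if (neededBytes <= buffer.length) {
      buffer // common case: the reused buffer is already big enough
    } else {
      val grown = new Array[Byte](neededBytes * 2)
      System.arraycopy(buffer, 0, grown, 0, buffer.length)
      grown
    }
  }
}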
- $PlatformDependent.UNSAFE.putInt( + PlatformDependent.UNSAFE.putInt( $buffer, - $PlatformDependent.BYTE_ARRAY_OFFSET + 4 * $index, + PlatformDependent.BYTE_ARRAY_OFFSET + 4 * $index, -$cursor); } else { - $PlatformDependent.UNSAFE.putInt( + PlatformDependent.UNSAFE.putInt( $buffer, - $PlatformDependent.BYTE_ARRAY_OFFSET + 4 * $index, + PlatformDependent.BYTE_ARRAY_OFFSET + 4 * $index, $cursor); $writeElement @@ -334,7 +336,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro $output.pointTo( $buffer, - $PlatformDependent.BYTE_ARRAY_OFFSET, + PlatformDependent.BYTE_ARRAY_OFFSET, $numElements, $numBytes); } @@ -400,7 +402,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val fieldIsNull = s"$tmp.isNullAt($i)" GeneratedExpressionCode("", fieldIsNull, getFieldCode) } - val converter = createCodeForStruct(ctx, fieldEvals, fieldTypes) + val converter = createCodeForStruct(ctx, tmp, fieldEvals, fieldTypes) val code = s""" ${input.code} UnsafeRow $output = null; @@ -427,7 +429,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro def createCode(ctx: CodeGenContext, expressions: Seq[Expression]): GeneratedExpressionCode = { val exprEvals = expressions.map(e => e.gen(ctx)) val exprTypes = expressions.map(_.dataType) - createCodeForStruct(ctx, exprEvals, exprTypes) + createCodeForStruct(ctx, "i", exprEvals, exprTypes) } protected def canonicalize(in: Seq[Expression]): Seq[Expression] = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala index 30b51dd83fa9..8aaa5b430004 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala @@ -155,7 +155,7 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U |$putLong(buf, $cursor, $getLong(buf, $cursor) + ($shift << 32)); """.stripMargin } - }.mkString + }.mkString("\n") // ------------------------ Finally, put everything together --------------------------- // val code = s""" diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala new file mode 100644 index 000000000000..8c7ee8720f7b --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.codegen + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types.{StringType, IntegerType, StructField, StructType} +import org.apache.spark.unsafe.types.UTF8String + +/** + * A test suite for generated projections + */ +class GeneratedProjectionSuite extends SparkFunSuite { + + test("generated projections on wider table") { + val N = 1000 + val wideRow1 = new GenericInternalRow((1 to N).toArray[Any]) + val schema1 = StructType((1 to N).map(i => StructField("", IntegerType))) + val wideRow2 = new GenericInternalRow( + (1 to N).map(i => UTF8String.fromString(i.toString)).toArray[Any]) + val schema2 = StructType((1 to N).map(i => StructField("", StringType))) + val joined = new JoinedRow(wideRow1, wideRow2) + val joinedSchema = StructType(schema1 ++ schema2) + val nested = new JoinedRow(InternalRow(joined, joined), joined) + val nestedSchema = StructType( + Seq(StructField("", joinedSchema), StructField("", joinedSchema)) ++ joinedSchema) + + // test generated UnsafeProjection + val unsafeProj = UnsafeProjection.create(nestedSchema) + val unsafe: UnsafeRow = unsafeProj(nested) + (0 until N).foreach { i => + val s = UTF8String.fromString((i + 1).toString) + assert(i + 1 === unsafe.getInt(i + 2)) + assert(s === unsafe.getUTF8String(i + 2 + N)) + assert(i + 1 === unsafe.getStruct(0, N * 2).getInt(i)) + assert(s === unsafe.getStruct(0, N * 2).getUTF8String(i + N)) + assert(i + 1 === unsafe.getStruct(1, N * 2).getInt(i)) + assert(s === unsafe.getStruct(1, N * 2).getUTF8String(i + N)) + } + + // test generated SafeProjection + val safeProj = FromUnsafeProjection(nestedSchema) + val result = safeProj(unsafe) + // Can't compare GenericInternalRow with JoinedRow directly + (0 until N).foreach { i => + val r = i + 1 + val s = UTF8String.fromString((i + 1).toString) + assert(r === result.getInt(i + 2)) + assert(s === result.getUTF8String(i + 2 + N)) + assert(r === result.getStruct(0, N * 2).getInt(i)) + assert(s === result.getStruct(0, N * 2).getUTF8String(i + N)) + assert(r === result.getStruct(1, N * 2).getInt(i)) + assert(s === result.getStruct(1, N * 2).getUTF8String(i + N)) + } + + // test generated MutableProjection + val exprs = nestedSchema.fields.zipWithIndex.map { case (f, i) => + BoundReference(i, f.dataType, true) + } + val mutableProj = GenerateMutableProjection.generate(exprs)() + val row1 = mutableProj(result) + assert(result === row1) + val row2 = mutableProj(result) + assert(result === row2) + } +} From c4fd2a242228ee101904770446e3f37d49e39b76 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 10 Aug 2015 13:55:11 -0700 Subject: [PATCH 0152/1824] [SPARK-9759] [SQL] improve decimal.times() and cast(int, decimalType) This patch optimize two things: 1. passing MathContext to JavaBigDecimal.multiply/divide/reminder to do right rounding, because java.math.BigDecimal.apply(MathContext) is expensive 2. 
Cast integer/short/byte to decimal directly (without double) This two optimizations could speed up the end-to-end time of a aggregation (SUM(short * decimal(5, 2)) 75% (from 19s -> 10.8s) Author: Davies Liu Closes #8052 from davies/optimize_decimal and squashes the following commits: 225efad [Davies Liu] improve decimal.times() and cast(int, decimalType) --- .../spark/sql/catalyst/expressions/Cast.scala | 42 +++++++------------ .../org/apache/spark/sql/types/Decimal.scala | 12 +++--- 2 files changed, 22 insertions(+), 32 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 946c5a9c04f1..616b9e0e65b7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -155,7 +155,7 @@ case class Cast(child: Expression, dataType: DataType) case ByteType => buildCast[Byte](_, _ != 0) case DecimalType() => - buildCast[Decimal](_, _ != Decimal.ZERO) + buildCast[Decimal](_, !_.isZero) case DoubleType => buildCast[Double](_, _ != 0) case FloatType => @@ -315,13 +315,13 @@ case class Cast(child: Expression, dataType: DataType) case TimestampType => // Note that we lose precision here. buildCast[Long](_, t => changePrecision(Decimal(timestampToDouble(t)), target)) - case DecimalType() => + case dt: DecimalType => b => changePrecision(b.asInstanceOf[Decimal].clone(), target) - case LongType => - b => changePrecision(Decimal(b.asInstanceOf[Long]), target) - case x: NumericType => // All other numeric types can be represented precisely as Doubles + case t: IntegralType => + b => changePrecision(Decimal(t.integral.asInstanceOf[Integral[Any]].toLong(b)), target) + case x: FractionalType => b => try { - changePrecision(Decimal(x.numeric.asInstanceOf[Numeric[Any]].toDouble(b)), target) + changePrecision(Decimal(x.fractional.asInstanceOf[Fractional[Any]].toDouble(b)), target) } catch { case _: NumberFormatException => null } @@ -534,10 +534,7 @@ case class Cast(child: Expression, dataType: DataType) (c, evPrim, evNull) => s""" try { - org.apache.spark.sql.types.Decimal tmpDecimal = - new org.apache.spark.sql.types.Decimal().set( - new scala.math.BigDecimal( - new java.math.BigDecimal($c.toString()))); + Decimal tmpDecimal = Decimal.apply(new java.math.BigDecimal($c.toString())); ${changePrecision("tmpDecimal", target, evPrim, evNull)} } catch (java.lang.NumberFormatException e) { $evNull = true; @@ -546,12 +543,7 @@ case class Cast(child: Expression, dataType: DataType) case BooleanType => (c, evPrim, evNull) => s""" - org.apache.spark.sql.types.Decimal tmpDecimal = null; - if ($c) { - tmpDecimal = new org.apache.spark.sql.types.Decimal().set(1); - } else { - tmpDecimal = new org.apache.spark.sql.types.Decimal().set(0); - } + Decimal tmpDecimal = $c ? Decimal.apply(1) : Decimal.apply(0); ${changePrecision("tmpDecimal", target, evPrim, evNull)} """ case DateType => @@ -561,32 +553,28 @@ case class Cast(child: Expression, dataType: DataType) // Note that we lose precision here. 
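Both optimizations described in the commit message can be illustrated with plain BigDecimal arithmetic; the snippet below is a hedged sketch of the idea, not the exact before/after code from Decimal and Cast:

import java.math.{BigDecimal => JBigDecimal, MathContext}

object DecimalOptimizationSketch {
  private val mc = MathContext.DECIMAL128

  // (1) Round as part of the multiply rather than building MathContext-aware
  //     operands and rounding separately; one pass, same rounded result.
  private val a = new JBigDecimal("123456789.123456789")
  private val b = new JBigDecimal("987654321.987654321")
  val roundedInOnePass: JBigDecimal  = a.multiply(b, mc)
  val roundedAfterwards: JBigDecimal = a.multiply(b).round(mc)

  // (2) Integral values become decimals through Long, staying exact and
  //     skipping the Double-based path previously shared by all numeric types.
  val fromLong: BigDecimal   = BigDecimal(42L)
  val fromDouble: BigDecimal = BigDecimal(42.toDouble)
}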
(c, evPrim, evNull) => s""" - org.apache.spark.sql.types.Decimal tmpDecimal = - new org.apache.spark.sql.types.Decimal().set( - scala.math.BigDecimal.valueOf(${timestampToDoubleCode(c)})); + Decimal tmpDecimal = Decimal.apply( + scala.math.BigDecimal.valueOf(${timestampToDoubleCode(c)})); ${changePrecision("tmpDecimal", target, evPrim, evNull)} """ case DecimalType() => (c, evPrim, evNull) => s""" - org.apache.spark.sql.types.Decimal tmpDecimal = $c.clone(); + Decimal tmpDecimal = $c.clone(); ${changePrecision("tmpDecimal", target, evPrim, evNull)} """ - case LongType => + case x: IntegralType => (c, evPrim, evNull) => s""" - org.apache.spark.sql.types.Decimal tmpDecimal = - new org.apache.spark.sql.types.Decimal().set($c); + Decimal tmpDecimal = Decimal.apply((long) $c); ${changePrecision("tmpDecimal", target, evPrim, evNull)} """ - case x: NumericType => + case x: FractionalType => // All other numeric types can be represented precisely as Doubles (c, evPrim, evNull) => s""" try { - org.apache.spark.sql.types.Decimal tmpDecimal = - new org.apache.spark.sql.types.Decimal().set( - scala.math.BigDecimal.valueOf((double) $c)); + Decimal tmpDecimal = Decimal.apply(scala.math.BigDecimal.valueOf((double) $c)); ${changePrecision("tmpDecimal", target, evPrim, evNull)} } catch (java.lang.NumberFormatException e) { $evNull = true; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala index 624c3f3d7fa9..d95805c24521 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -139,9 +139,9 @@ final class Decimal extends Ordered[Decimal] with Serializable { def toBigDecimal: BigDecimal = { if (decimalVal.ne(null)) { - decimalVal(MATH_CONTEXT) + decimalVal } else { - BigDecimal(longVal, _scale)(MATH_CONTEXT) + BigDecimal(longVal, _scale) } } @@ -280,13 +280,15 @@ final class Decimal extends Ordered[Decimal] with Serializable { } // HiveTypeCoercion will take care of the precision, scale of result - def * (that: Decimal): Decimal = Decimal(toBigDecimal * that.toBigDecimal) + def * (that: Decimal): Decimal = + Decimal(toJavaBigDecimal.multiply(that.toJavaBigDecimal, MATH_CONTEXT)) def / (that: Decimal): Decimal = - if (that.isZero) null else Decimal(toBigDecimal / that.toBigDecimal) + if (that.isZero) null else Decimal(toJavaBigDecimal.divide(that.toJavaBigDecimal, MATH_CONTEXT)) def % (that: Decimal): Decimal = - if (that.isZero) null else Decimal(toBigDecimal % that.toBigDecimal) + if (that.isZero) null + else Decimal(toJavaBigDecimal.remainder(that.toJavaBigDecimal, MATH_CONTEXT)) def remainder(that: Decimal): Decimal = this % that From 853809e948e7c5092643587a30738115b6591a59 Mon Sep 17 00:00:00 2001 From: Prabeesh K Date: Mon, 10 Aug 2015 16:33:23 -0700 Subject: [PATCH 0153/1824] [SPARK-5155] [PYSPARK] [STREAMING] Mqtt streaming support in Python This PR is based on #4229, thanks prabeesh. 
Closes #4229 Author: Prabeesh K Author: zsxwing Author: prabs Author: Prabeesh K Closes #7833 from zsxwing/pr4229 and squashes the following commits: 9570bec [zsxwing] Fix the variable name and check null in finally 4a9c79e [zsxwing] Fix pom.xml indentation abf5f18 [zsxwing] Merge branch 'master' into pr4229 935615c [zsxwing] Fix the flaky MQTT tests 47278c5 [zsxwing] Include the project class files 478f844 [zsxwing] Add unpack 5f8a1d4 [zsxwing] Make the maven build generate the test jar for Python MQTT tests 734db99 [zsxwing] Merge branch 'master' into pr4229 126608a [Prabeesh K] address the comments b90b709 [Prabeesh K] Merge pull request #1 from zsxwing/pr4229 d07f454 [zsxwing] Register StreamingListerner before starting StreamingContext; Revert unncessary changes; fix the python unit test a6747cb [Prabeesh K] wait for starting the receiver before publishing data 87fc677 [Prabeesh K] address the comments: 97244ec [zsxwing] Make sbt build the assembly test jar for streaming mqtt 80474d1 [Prabeesh K] fix 1f0cfe9 [Prabeesh K] python style fix e1ee016 [Prabeesh K] scala style fix a5a8f9f [Prabeesh K] added Python test 9767d82 [Prabeesh K] implemented Python-friendly class a11968b [Prabeesh K] fixed python style 795ec27 [Prabeesh K] address comments ee387ae [Prabeesh K] Fix assembly jar location of mqtt-assembly 3f4df12 [Prabeesh K] updated version b34c3c1 [prabs] adress comments 3aa7fff [prabs] Added Python streaming mqtt word count example b7d42ff [prabs] Mqtt streaming support in Python --- dev/run-tests.py | 2 + dev/sparktestsupport/modules.py | 2 + docs/streaming-programming-guide.md | 2 +- .../main/python/streaming/mqtt_wordcount.py | 58 +++++++++ external/mqtt-assembly/pom.xml | 102 +++++++++++++++ external/mqtt/pom.xml | 28 +++++ external/mqtt/src/main/assembly/assembly.xml | 44 +++++++ .../spark/streaming/mqtt/MQTTUtils.scala | 16 +++ .../streaming/mqtt/MQTTStreamSuite.scala | 118 +++--------------- .../spark/streaming/mqtt/MQTTTestUtils.scala | 111 ++++++++++++++++ pom.xml | 1 + project/SparkBuild.scala | 12 +- python/pyspark/streaming/mqtt.py | 72 +++++++++++ python/pyspark/streaming/tests.py | 106 +++++++++++++++- 14 files changed, 565 insertions(+), 109 deletions(-) create mode 100644 examples/src/main/python/streaming/mqtt_wordcount.py create mode 100644 external/mqtt-assembly/pom.xml create mode 100644 external/mqtt/src/main/assembly/assembly.xml create mode 100644 external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTTestUtils.scala create mode 100644 python/pyspark/streaming/mqtt.py diff --git a/dev/run-tests.py b/dev/run-tests.py index d1852b95bb29..f689425ee40b 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -303,6 +303,8 @@ def build_spark_sbt(hadoop_version): "assembly/assembly", "streaming-kafka-assembly/assembly", "streaming-flume-assembly/assembly", + "streaming-mqtt-assembly/assembly", + "streaming-mqtt/test:assembly", "streaming-kinesis-asl-assembly/assembly"] profiles_and_goals = build_profiles + sbt_goals diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index a9717ff9569c..d82c0cca37bc 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -181,6 +181,7 @@ def contains_file(self, filename): dependencies=[streaming], source_file_regexes=[ "external/mqtt", + "external/mqtt-assembly", ], sbt_test_goals=[ "streaming-mqtt/test", @@ -306,6 +307,7 @@ def contains_file(self, filename): streaming, streaming_kafka, streaming_flume_assembly, + streaming_mqtt, streaming_kinesis_asl ], 
source_file_regexes=[ diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index dbfdb619f89e..c59d936b43c8 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -683,7 +683,7 @@ for Java, and [StreamingContext](api/python/pyspark.streaming.html#pyspark.strea {:.no_toc} Python API As of Spark {{site.SPARK_VERSION_SHORT}}, -out of these sources, *only* Kafka and Flume are available in the Python API. We will add more advanced sources in the Python API in future. +out of these sources, *only* Kafka, Flume and MQTT are available in the Python API. We will add more advanced sources in the Python API in future. This category of sources require interfacing with external non-Spark libraries, some of them with complex dependencies (e.g., Kafka and Flume). Hence, to minimize issues related to version conflicts diff --git a/examples/src/main/python/streaming/mqtt_wordcount.py b/examples/src/main/python/streaming/mqtt_wordcount.py new file mode 100644 index 000000000000..617ce5ea6775 --- /dev/null +++ b/examples/src/main/python/streaming/mqtt_wordcount.py @@ -0,0 +1,58 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +""" + A sample wordcount with MqttStream stream + Usage: mqtt_wordcount.py + + To run this in your local machine, you need to setup a MQTT broker and publisher first, + Mosquitto is one of the open source MQTT Brokers, see + http://mosquitto.org/ + Eclipse paho project provides number of clients and utilities for working with MQTT, see + http://www.eclipse.org/paho/#getting-started + + and then run the example + `$ bin/spark-submit --jars external/mqtt-assembly/target/scala-*/\ + spark-streaming-mqtt-assembly-*.jar examples/src/main/python/streaming/mqtt_wordcount.py \ + tcp://localhost:1883 foo` +""" + +import sys + +from pyspark import SparkContext +from pyspark.streaming import StreamingContext +from pyspark.streaming.mqtt import MQTTUtils + +if __name__ == "__main__": + if len(sys.argv) != 3: + print >> sys.stderr, "Usage: mqtt_wordcount.py " + exit(-1) + + sc = SparkContext(appName="PythonStreamingMQTTWordCount") + ssc = StreamingContext(sc, 1) + + brokerUrl = sys.argv[1] + topic = sys.argv[2] + + lines = MQTTUtils.createStream(ssc, brokerUrl, topic) + counts = lines.flatMap(lambda line: line.split(" ")) \ + .map(lambda word: (word, 1)) \ + .reduceByKey(lambda a, b: a+b) + counts.pprint() + + ssc.start() + ssc.awaitTermination() diff --git a/external/mqtt-assembly/pom.xml b/external/mqtt-assembly/pom.xml new file mode 100644 index 000000000000..9c94473053d9 --- /dev/null +++ b/external/mqtt-assembly/pom.xml @@ -0,0 +1,102 @@ + + + + + 4.0.0 + + org.apache.spark + spark-parent_2.10 + 1.5.0-SNAPSHOT + ../../pom.xml + + + org.apache.spark + spark-streaming-mqtt-assembly_2.10 + jar + Spark Project External MQTT Assembly + http://spark.apache.org/ + + + streaming-mqtt-assembly + + + + + org.apache.spark + spark-streaming-mqtt_${scala.binary.version} + ${project.version} + + + org.apache.spark + spark-streaming_${scala.binary.version} + ${project.version} + provided + + + + + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes + + + org.apache.maven.plugins + maven-shade-plugin + + false + ${project.build.directory}/scala-${scala.binary.version}/spark-streaming-mqtt-assembly-${project.version}.jar + + + *:* + + + + + *:* + + META-INF/*.SF + META-INF/*.DSA + META-INF/*.RSA + + + + + + + package + + shade + + + + + + reference.conf + + + log4j.properties + + + + + + + + + + + diff --git a/external/mqtt/pom.xml b/external/mqtt/pom.xml index 0e41e5781784..69b309876a0d 100644 --- a/external/mqtt/pom.xml +++ b/external/mqtt/pom.xml @@ -78,5 +78,33 @@ target/scala-${scala.binary.version}/classes target/scala-${scala.binary.version}/test-classes + + + + + org.apache.maven.plugins + maven-assembly-plugin + + + test-jar-with-dependencies + package + + single + + + + spark-streaming-mqtt-test-${project.version} + ${project.build.directory}/scala-${scala.binary.version}/ + false + + false + + src/main/assembly/assembly.xml + + + + + + diff --git a/external/mqtt/src/main/assembly/assembly.xml b/external/mqtt/src/main/assembly/assembly.xml new file mode 100644 index 000000000000..ecab5b360eb3 --- /dev/null +++ b/external/mqtt/src/main/assembly/assembly.xml @@ -0,0 +1,44 @@ + + + test-jar-with-dependencies + + jar + + false + + + + ${project.build.directory}/scala-${scala.binary.version}/test-classes + / + + + + + + true + test + true + + org.apache.hadoop:*:jar + org.apache.zookeeper:*:jar + org.apache.avro:*:jar + + + + + diff --git a/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTUtils.scala 
b/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTUtils.scala index 1142d0f56ba3..38a1114863d1 100644 --- a/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTUtils.scala +++ b/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTUtils.scala @@ -74,3 +74,19 @@ object MQTTUtils { createStream(jssc.ssc, brokerUrl, topic, storageLevel) } } + +/** + * This is a helper class that wraps the methods in MQTTUtils into more Python-friendly class and + * function so that it can be easily instantiated and called from Python's MQTTUtils. + */ +private class MQTTUtilsPythonHelper { + + def createStream( + jssc: JavaStreamingContext, + brokerUrl: String, + topic: String, + storageLevel: StorageLevel + ): JavaDStream[String] = { + MQTTUtils.createStream(jssc, brokerUrl, topic, storageLevel) + } +} diff --git a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala index c4bf5aa7869b..a6a9249db8ed 100644 --- a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala +++ b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala @@ -17,46 +17,30 @@ package org.apache.spark.streaming.mqtt -import java.net.{URI, ServerSocket} -import java.util.concurrent.CountDownLatch -import java.util.concurrent.TimeUnit - import scala.concurrent.duration._ import scala.language.postfixOps -import org.apache.activemq.broker.{TransportConnector, BrokerService} -import org.apache.commons.lang3.RandomUtils -import org.eclipse.paho.client.mqttv3._ -import org.eclipse.paho.client.mqttv3.persist.MqttDefaultFilePersistence - import org.scalatest.BeforeAndAfter import org.scalatest.concurrent.Eventually -import org.apache.spark.streaming.{Milliseconds, StreamingContext} -import org.apache.spark.storage.StorageLevel -import org.apache.spark.streaming.dstream.ReceiverInputDStream -import org.apache.spark.streaming.scheduler.StreamingListener -import org.apache.spark.streaming.scheduler.StreamingListenerReceiverStarted import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.util.Utils +import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming.{Milliseconds, StreamingContext} class MQTTStreamSuite extends SparkFunSuite with Eventually with BeforeAndAfter { private val batchDuration = Milliseconds(500) private val master = "local[2]" private val framework = this.getClass.getSimpleName - private val freePort = findFreePort() - private val brokerUri = "//localhost:" + freePort private val topic = "def" - private val persistenceDir = Utils.createTempDir() private var ssc: StreamingContext = _ - private var broker: BrokerService = _ - private var connector: TransportConnector = _ + private var mqttTestUtils: MQTTTestUtils = _ before { ssc = new StreamingContext(master, framework, batchDuration) - setupMQTT() + mqttTestUtils = new MQTTTestUtils + mqttTestUtils.setup() } after { @@ -64,14 +48,17 @@ class MQTTStreamSuite extends SparkFunSuite with Eventually with BeforeAndAfter ssc.stop() ssc = null } - Utils.deleteRecursively(persistenceDir) - tearDownMQTT() + if (mqttTestUtils != null) { + mqttTestUtils.teardown() + mqttTestUtils = null + } } test("mqtt input stream") { val sendMessage = "MQTT demo for spark streaming" - val receiveStream = - MQTTUtils.createStream(ssc, "tcp:" + brokerUri, topic, StorageLevel.MEMORY_ONLY) + val receiveStream = MQTTUtils.createStream(ssc, "tcp://" + 
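For comparison with the Python word count added in this patch, a minimal Scala driver against the same MQTTUtils API looks as follows; the broker URL and topic are placeholders:

import org.apache.spark.SparkConf
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.{Seconds, StreamingContext}
import org.apache.spark.streaming.mqtt.MQTTUtils

object MQTTWordCountSketch {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("ScalaStreamingMQTTWordCount")
    val ssc = new StreamingContext(conf, Seconds(1))

    // Same shape as the Python example: subscribe, split, count, print.
    val lines = MQTTUtils.createStream(ssc, "tcp://localhost:1883", "foo",
      StorageLevel.MEMORY_AND_DISK_SER_2)
    lines.flatMap(_.split(" "))
      .map(word => (word, 1))
      .reduceByKey(_ + _)
      .print()

    ssc.start()
    ssc.awaitTermination()
  }
}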
mqttTestUtils.brokerUri, topic, + StorageLevel.MEMORY_ONLY) + @volatile var receiveMessage: List[String] = List() receiveStream.foreachRDD { rdd => if (rdd.collect.length > 0) { @@ -79,89 +66,14 @@ class MQTTStreamSuite extends SparkFunSuite with Eventually with BeforeAndAfter receiveMessage } } - ssc.start() - // wait for the receiver to start before publishing data, or we risk failing - // the test nondeterministically. See SPARK-4631 - waitForReceiverToStart() + ssc.start() - publishData(sendMessage) + // Retry it because we don't know when the receiver will start. eventually(timeout(10000 milliseconds), interval(100 milliseconds)) { + mqttTestUtils.publishData(topic, sendMessage) assert(sendMessage.equals(receiveMessage(0))) } ssc.stop() } - - private def setupMQTT() { - broker = new BrokerService() - broker.setDataDirectoryFile(Utils.createTempDir()) - connector = new TransportConnector() - connector.setName("mqtt") - connector.setUri(new URI("mqtt:" + brokerUri)) - broker.addConnector(connector) - broker.start() - } - - private def tearDownMQTT() { - if (broker != null) { - broker.stop() - broker = null - } - if (connector != null) { - connector.stop() - connector = null - } - } - - private def findFreePort(): Int = { - val candidatePort = RandomUtils.nextInt(1024, 65536) - Utils.startServiceOnPort(candidatePort, (trialPort: Int) => { - val socket = new ServerSocket(trialPort) - socket.close() - (null, trialPort) - }, new SparkConf())._2 - } - - def publishData(data: String): Unit = { - var client: MqttClient = null - try { - val persistence = new MqttDefaultFilePersistence(persistenceDir.getAbsolutePath) - client = new MqttClient("tcp:" + brokerUri, MqttClient.generateClientId(), persistence) - client.connect() - if (client.isConnected) { - val msgTopic = client.getTopic(topic) - val message = new MqttMessage(data.getBytes("utf-8")) - message.setQos(1) - message.setRetained(true) - - for (i <- 0 to 10) { - try { - msgTopic.publish(message) - } catch { - case e: MqttException if e.getReasonCode == MqttException.REASON_CODE_MAX_INFLIGHT => - // wait for Spark streaming to consume something from the message queue - Thread.sleep(50) - } - } - } - } finally { - client.disconnect() - client.close() - client = null - } - } - - /** - * Block until at least one receiver has started or timeout occurs. - */ - private def waitForReceiverToStart() = { - val latch = new CountDownLatch(1) - ssc.addStreamingListener(new StreamingListener { - override def onReceiverStarted(receiverStarted: StreamingListenerReceiverStarted) { - latch.countDown() - } - }) - - assert(latch.await(10, TimeUnit.SECONDS), "Timeout waiting for receiver to start.") - } } diff --git a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTTestUtils.scala b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTTestUtils.scala new file mode 100644 index 000000000000..1a371b700882 --- /dev/null +++ b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTTestUtils.scala @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.mqtt + +import java.net.{ServerSocket, URI} + +import scala.language.postfixOps + +import com.google.common.base.Charsets.UTF_8 +import org.apache.activemq.broker.{BrokerService, TransportConnector} +import org.apache.commons.lang3.RandomUtils +import org.eclipse.paho.client.mqttv3._ +import org.eclipse.paho.client.mqttv3.persist.MqttDefaultFilePersistence + +import org.apache.spark.util.Utils +import org.apache.spark.{Logging, SparkConf} + +/** + * Share codes for Scala and Python unit tests + */ +private class MQTTTestUtils extends Logging { + + private val persistenceDir = Utils.createTempDir() + private val brokerHost = "localhost" + private val brokerPort = findFreePort() + + private var broker: BrokerService = _ + private var connector: TransportConnector = _ + + def brokerUri: String = { + s"$brokerHost:$brokerPort" + } + + def setup(): Unit = { + broker = new BrokerService() + broker.setDataDirectoryFile(Utils.createTempDir()) + connector = new TransportConnector() + connector.setName("mqtt") + connector.setUri(new URI("mqtt://" + brokerUri)) + broker.addConnector(connector) + broker.start() + } + + def teardown(): Unit = { + if (broker != null) { + broker.stop() + broker = null + } + if (connector != null) { + connector.stop() + connector = null + } + Utils.deleteRecursively(persistenceDir) + } + + private def findFreePort(): Int = { + val candidatePort = RandomUtils.nextInt(1024, 65536) + Utils.startServiceOnPort(candidatePort, (trialPort: Int) => { + val socket = new ServerSocket(trialPort) + socket.close() + (null, trialPort) + }, new SparkConf())._2 + } + + def publishData(topic: String, data: String): Unit = { + var client: MqttClient = null + try { + val persistence = new MqttDefaultFilePersistence(persistenceDir.getAbsolutePath) + client = new MqttClient("tcp://" + brokerUri, MqttClient.generateClientId(), persistence) + client.connect() + if (client.isConnected) { + val msgTopic = client.getTopic(topic) + val message = new MqttMessage(data.getBytes(UTF_8)) + message.setQos(1) + message.setRetained(true) + + for (i <- 0 to 10) { + try { + msgTopic.publish(message) + } catch { + case e: MqttException if e.getReasonCode == MqttException.REASON_CODE_MAX_INFLIGHT => + // wait for Spark streaming to consume something from the message queue + Thread.sleep(50) + } + } + } + } finally { + if (client != null) { + client.disconnect() + client.close() + client = null + } + } + } + +} diff --git a/pom.xml b/pom.xml index 2bcc55b040a2..8942836a7da1 100644 --- a/pom.xml +++ b/pom.xml @@ -104,6 +104,7 @@ external/flume-sink external/flume-assembly external/mqtt + external/mqtt-assembly external/zeromq examples repl diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 9a33baa7c6ce..41a85fa9de77 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -45,8 +45,8 @@ object BuildCommons { sparkKinesisAsl) = Seq("yarn", "yarn-stable", "java8-tests", "ganglia-lgpl", "kinesis-asl").map(ProjectRef(buildLocation, _)) - val assemblyProjects@Seq(assembly, examples, networkYarn, streamingFlumeAssembly, 
streamingKafkaAssembly, streamingKinesisAslAssembly) = - Seq("assembly", "examples", "network-yarn", "streaming-flume-assembly", "streaming-kafka-assembly", "streaming-kinesis-asl-assembly") + val assemblyProjects@Seq(assembly, examples, networkYarn, streamingFlumeAssembly, streamingKafkaAssembly, streamingMqttAssembly, streamingKinesisAslAssembly) = + Seq("assembly", "examples", "network-yarn", "streaming-flume-assembly", "streaming-kafka-assembly", "streaming-mqtt-assembly", "streaming-kinesis-asl-assembly") .map(ProjectRef(buildLocation, _)) val tools = ProjectRef(buildLocation, "tools") @@ -212,6 +212,9 @@ object SparkBuild extends PomBuild { /* Enable Assembly for all assembly projects */ assemblyProjects.foreach(enable(Assembly.settings)) + /* Enable Assembly for streamingMqtt test */ + enable(inConfig(Test)(Assembly.settings))(streamingMqtt) + /* Package pyspark artifacts in a separate zip file for YARN. */ enable(PySparkAssembly.settings)(assembly) @@ -382,13 +385,16 @@ object Assembly { .getOrElse(SbtPomKeys.effectivePom.value.getProperties.get("hadoop.version").asInstanceOf[String]) }, jarName in assembly <<= (version, moduleName, hadoopVersion) map { (v, mName, hv) => - if (mName.contains("streaming-flume-assembly") || mName.contains("streaming-kafka-assembly") || mName.contains("streaming-kinesis-asl-assembly")) { + if (mName.contains("streaming-flume-assembly") || mName.contains("streaming-kafka-assembly") || mName.contains("streaming-mqtt-assembly") || mName.contains("streaming-kinesis-asl-assembly")) { // This must match the same name used in maven (see external/kafka-assembly/pom.xml) s"${mName}-${v}.jar" } else { s"${mName}-${v}-hadoop${hv}.jar" } }, + jarName in (Test, assembly) <<= (version, moduleName, hadoopVersion) map { (v, mName, hv) => + s"${mName}-test-${v}.jar" + }, mergeStrategy in assembly := { case PathList("org", "datanucleus", xs @ _*) => MergeStrategy.discard case m if m.toLowerCase.endsWith("manifest.mf") => MergeStrategy.discard diff --git a/python/pyspark/streaming/mqtt.py b/python/pyspark/streaming/mqtt.py new file mode 100644 index 000000000000..f06598971c54 --- /dev/null +++ b/python/pyspark/streaming/mqtt.py @@ -0,0 +1,72 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from py4j.java_gateway import Py4JJavaError + +from pyspark.storagelevel import StorageLevel +from pyspark.serializers import UTF8Deserializer +from pyspark.streaming import DStream + +__all__ = ['MQTTUtils'] + + +class MQTTUtils(object): + + @staticmethod + def createStream(ssc, brokerUrl, topic, + storageLevel=StorageLevel.MEMORY_AND_DISK_SER_2): + """ + Create an input stream that pulls messages from a Mqtt Broker. 
+ :param ssc: StreamingContext object + :param brokerUrl: Url of remote mqtt publisher + :param topic: topic name to subscribe to + :param storageLevel: RDD storage level. + :return: A DStream object + """ + jlevel = ssc._sc._getJavaStorageLevel(storageLevel) + + try: + helperClass = ssc._jvm.java.lang.Thread.currentThread().getContextClassLoader() \ + .loadClass("org.apache.spark.streaming.mqtt.MQTTUtilsPythonHelper") + helper = helperClass.newInstance() + jstream = helper.createStream(ssc._jssc, brokerUrl, topic, jlevel) + except Py4JJavaError as e: + if 'ClassNotFoundException' in str(e.java_exception): + MQTTUtils._printErrorMsg(ssc.sparkContext) + raise e + + return DStream(jstream, ssc, UTF8Deserializer()) + + @staticmethod + def _printErrorMsg(sc): + print(""" +________________________________________________________________________________________________ + + Spark Streaming's MQTT libraries not found in class path. Try one of the following. + + 1. Include the MQTT library and its dependencies with in the + spark-submit command as + + $ bin/spark-submit --packages org.apache.spark:spark-streaming-mqtt:%s ... + + 2. Download the JAR of the artifact from Maven Central http://search.maven.org/, + Group Id = org.apache.spark, Artifact Id = spark-streaming-mqtt-assembly, Version = %s. + Then, include the jar in the spark-submit command as + + $ bin/spark-submit --jars ... +________________________________________________________________________________________________ +""" % (sc.version, sc.version)) diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 5cd544b2144e..66ae3345f468 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -40,6 +40,7 @@ from pyspark.streaming.context import StreamingContext from pyspark.streaming.kafka import Broker, KafkaUtils, OffsetRange, TopicAndPartition from pyspark.streaming.flume import FlumeUtils +from pyspark.streaming.mqtt import MQTTUtils from pyspark.streaming.kinesis import KinesisUtils, InitialPositionInStream @@ -893,6 +894,68 @@ def test_flume_polling_multiple_hosts(self): self._testMultipleTimes(self._testFlumePollingMultipleHosts) +class MQTTStreamTests(PySparkStreamingTestCase): + timeout = 20 # seconds + duration = 1 + + def setUp(self): + super(MQTTStreamTests, self).setUp() + + MQTTTestUtilsClz = self.ssc._jvm.java.lang.Thread.currentThread().getContextClassLoader() \ + .loadClass("org.apache.spark.streaming.mqtt.MQTTTestUtils") + self._MQTTTestUtils = MQTTTestUtilsClz.newInstance() + self._MQTTTestUtils.setup() + + def tearDown(self): + if self._MQTTTestUtils is not None: + self._MQTTTestUtils.teardown() + self._MQTTTestUtils = None + + super(MQTTStreamTests, self).tearDown() + + def _randomTopic(self): + return "topic-%d" % random.randint(0, 10000) + + def _startContext(self, topic): + # Start the StreamingContext and also collect the result + stream = MQTTUtils.createStream(self.ssc, "tcp://" + self._MQTTTestUtils.brokerUri(), topic) + result = [] + + def getOutput(_, rdd): + for data in rdd.collect(): + result.append(data) + + stream.foreachRDD(getOutput) + self.ssc.start() + return result + + def test_mqtt_stream(self): + """Test the Python MQTT stream API.""" + sendData = "MQTT demo for spark streaming" + topic = self._randomTopic() + result = self._startContext(topic) + + def retry(): + self._MQTTTestUtils.publishData(topic, sendData) + # Because "publishData" sends duplicate messages, here we should use > 0 + self.assertTrue(len(result) > 0) + 
self.assertEqual(sendData, result[0]) + + # Retry it because we don't know when the receiver will start. + self._retry_or_timeout(retry) + + def _retry_or_timeout(self, test_func): + start_time = time.time() + while True: + try: + test_func() + break + except: + if time.time() - start_time > self.timeout: + raise + time.sleep(0.01) + + class KinesisStreamTests(PySparkStreamingTestCase): def test_kinesis_stream_api(self): @@ -985,7 +1048,42 @@ def search_flume_assembly_jar(): "'build/mvn package' before running this test") elif len(jars) > 1: raise Exception(("Found multiple Spark Streaming Flume assembly JARs in %s; please " - "remove all but one") % flume_assembly_dir) + "remove all but one") % flume_assembly_dir) + else: + return jars[0] + + +def search_mqtt_assembly_jar(): + SPARK_HOME = os.environ["SPARK_HOME"] + mqtt_assembly_dir = os.path.join(SPARK_HOME, "external/mqtt-assembly") + jars = glob.glob( + os.path.join(mqtt_assembly_dir, "target/scala-*/spark-streaming-mqtt-assembly-*.jar")) + if not jars: + raise Exception( + ("Failed to find Spark Streaming MQTT assembly jar in %s. " % mqtt_assembly_dir) + + "You need to build Spark with " + "'build/sbt assembly/assembly streaming-mqtt-assembly/assembly' or " + "'build/mvn package' before running this test") + elif len(jars) > 1: + raise Exception(("Found multiple Spark Streaming MQTT assembly JARs in %s; please " + "remove all but one") % mqtt_assembly_dir) + else: + return jars[0] + + +def search_mqtt_test_jar(): + SPARK_HOME = os.environ["SPARK_HOME"] + mqtt_test_dir = os.path.join(SPARK_HOME, "external/mqtt") + jars = glob.glob( + os.path.join(mqtt_test_dir, "target/scala-*/spark-streaming-mqtt-test-*.jar")) + if not jars: + raise Exception( + ("Failed to find Spark Streaming MQTT test jar in %s. " % mqtt_test_dir) + + "You need to build Spark with " + "'build/sbt assembly/assembly streaming-mqtt/test:assembly'") + elif len(jars) > 1: + raise Exception(("Found multiple Spark Streaming MQTT test JARs in %s; please " + "remove all but one") % mqtt_test_dir) else: return jars[0] @@ -1012,8 +1110,12 @@ def search_kinesis_asl_assembly_jar(): if __name__ == "__main__": kafka_assembly_jar = search_kafka_assembly_jar() flume_assembly_jar = search_flume_assembly_jar() + mqtt_assembly_jar = search_mqtt_assembly_jar() + mqtt_test_jar = search_mqtt_test_jar() kinesis_asl_assembly_jar = search_kinesis_asl_assembly_jar() - jars = "%s,%s,%s" % (kafka_assembly_jar, flume_assembly_jar, kinesis_asl_assembly_jar) + + jars = "%s,%s,%s,%s,%s" % (kafka_assembly_jar, flume_assembly_jar, kinesis_asl_assembly_jar, + mqtt_assembly_jar, mqtt_test_jar) os.environ["PYSPARK_SUBMIT_ARGS"] = "--jars %s pyspark-shell" % jars unittest.main() From 3c9802d9400bea802984456683b2736a450ee17e Mon Sep 17 00:00:00 2001 From: Hao Zhu Date: Mon, 10 Aug 2015 17:17:22 -0700 Subject: [PATCH 0154/1824] [SPARK-9801] [STREAMING] Check if file exists before deleting temporary files. Spark streaming deletes the temp file and backup files without checking if they exist or not Author: Hao Zhu Closes #8082 from viadea/master and squashes the following commits: 242d05f [Hao Zhu] [SPARK-9801][Streaming]No need to check the existence of those files fd143f2 [Hao Zhu] [SPARK-9801][Streaming]Check if backupFile exists before deleting backupFile files. 
087daf0 [Hao Zhu] SPARK-9801 --- .../scala/org/apache/spark/streaming/Checkpoint.scala | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala index 2780d5b6adbc..6f6b449accc3 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala @@ -192,7 +192,9 @@ class CheckpointWriter( + "'") // Write checkpoint to temp file - fs.delete(tempFile, true) // just in case it exists + if (fs.exists(tempFile)) { + fs.delete(tempFile, true) // just in case it exists + } val fos = fs.create(tempFile) Utils.tryWithSafeFinally { fos.write(bytes) @@ -203,7 +205,9 @@ class CheckpointWriter( // If the checkpoint file exists, back it up // If the backup exists as well, just delete it, otherwise rename will fail if (fs.exists(checkpointFile)) { - fs.delete(backupFile, true) // just in case it exists + if (fs.exists(backupFile)){ + fs.delete(backupFile, true) // just in case it exists + } if (!fs.rename(checkpointFile, backupFile)) { logWarning("Could not rename " + checkpointFile + " to " + backupFile) } From 071bbad5db1096a548c886762b611a8484a52753 Mon Sep 17 00:00:00 2001 From: Damian Guy Date: Tue, 11 Aug 2015 12:46:33 +0800 Subject: [PATCH 0155/1824] [SPARK-9340] [SQL] Fixes converting unannotated Parquet lists This PR is inspired by #8063 authored by dguy. Especially, testing Parquet files added here are all taken from that PR. **Committer who merges this PR should attribute it to "Damian Guy ".** ---- SPARK-6776 and SPARK-6777 followed `parquet-avro` to implement backwards-compatibility rules defined in `parquet-format` spec. However, both Spark SQL and `parquet-avro` neglected the following statement in `parquet-format`: > This does not affect repeated fields that are not annotated: A repeated field that is neither contained by a `LIST`- or `MAP`-annotated group nor annotated by `LIST` or `MAP` should be interpreted as a required list of required elements where the element type is the type of the field. One of the consequences is that, Parquet files generated by `parquet-protobuf` containing unannotated repeated fields are not correctly converted to Catalyst arrays. This PR fixes this issue by 1. Handling unannotated repeated fields in `CatalystSchemaConverter`. 2. Converting this kind of special repeated fields to Catalyst arrays in `CatalystRowConverter`. Two special converters, `RepeatedPrimitiveConverter` and `RepeatedGroupConverter`, are added. They delegate actual conversion work to a child `elementConverter` and accumulates elements in an `ArrayBuffer`. Two extra methods, `start()` and `end()`, are added to `ParentContainerUpdater`. So that they can be used to initialize new `ArrayBuffer`s for unannotated repeated fields, and propagate converted array values to upstream. 
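To make the intended mapping concrete, here is a small self-contained Scala sketch (illustrative only, not part of the patch; the message and field names are invented). It shows the schema translation described above and implemented in the CatalystSchemaConverter change further down: an unannotated repeated primitive field is now read as a required array of required elements.

    import org.apache.spark.sql.types._

    object UnannotatedRepeatedListExample {
      // Parquet schema as parquet-protobuf would write it: the repeated field
      // carries no LIST or MAP annotation.
      val parquetSchemaText: String =
        """message ProtoRecord {
          |  repeated int32 values;
          |}""".stripMargin

      // With this patch, Spark SQL reads `values` as a required list of required
      // int32 elements instead of failing with
      // "REPEATED not supported outside LIST or MAP".
      val expectedSparkSchema: StructType = StructType(Seq(
        StructField("values", ArrayType(IntegerType, containsNull = false), nullable = false)))
    }
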
Author: Cheng Lian Closes #8070 from liancheng/spark-9340/unannotated-parquet-list and squashes the following commits: ace6df7 [Cheng Lian] Moves ParquetProtobufCompatibilitySuite f1c7bfd [Cheng Lian] Updates .rat-excludes 420ad2b [Cheng Lian] Fixes converting unannotated Parquet lists --- .rat-excludes | 1 + .../parquet/CatalystRowConverter.scala | 151 ++++++++++++++---- .../parquet/CatalystSchemaConverter.scala | 7 +- .../resources/nested-array-struct.parquet | Bin 0 -> 775 bytes .../test/resources/old-repeated-int.parquet | Bin 0 -> 389 bytes .../resources/old-repeated-message.parquet | Bin 0 -> 600 bytes .../src/test/resources/old-repeated.parquet | Bin 0 -> 432 bytes .../parquet-thrift-compat.snappy.parquet | Bin .../resources/proto-repeated-string.parquet | Bin 0 -> 411 bytes .../resources/proto-repeated-struct.parquet | Bin 0 -> 608 bytes .../proto-struct-with-array-many.parquet | Bin 0 -> 802 bytes .../resources/proto-struct-with-array.parquet | Bin 0 -> 1576 bytes .../ParquetProtobufCompatibilitySuite.scala | 91 +++++++++++ .../parquet/ParquetSchemaSuite.scala | 30 ++++ 14 files changed, 247 insertions(+), 33 deletions(-) create mode 100644 sql/core/src/test/resources/nested-array-struct.parquet create mode 100644 sql/core/src/test/resources/old-repeated-int.parquet create mode 100644 sql/core/src/test/resources/old-repeated-message.parquet create mode 100644 sql/core/src/test/resources/old-repeated.parquet mode change 100755 => 100644 sql/core/src/test/resources/parquet-thrift-compat.snappy.parquet create mode 100644 sql/core/src/test/resources/proto-repeated-string.parquet create mode 100644 sql/core/src/test/resources/proto-repeated-struct.parquet create mode 100644 sql/core/src/test/resources/proto-struct-with-array-many.parquet create mode 100644 sql/core/src/test/resources/proto-struct-with-array.parquet create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetProtobufCompatibilitySuite.scala diff --git a/.rat-excludes b/.rat-excludes index 72771465846b..9165872b9fb2 100644 --- a/.rat-excludes +++ b/.rat-excludes @@ -94,3 +94,4 @@ INDEX gen-java.* .*avpr org.apache.spark.sql.sources.DataSourceRegister +.*parquet diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala index 3542dfbae129..ab5a6ddd41cf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala @@ -21,11 +21,11 @@ import java.math.{BigDecimal, BigInteger} import java.nio.ByteOrder import scala.collection.JavaConversions._ -import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import org.apache.parquet.column.Dictionary import org.apache.parquet.io.api.{Binary, Converter, GroupConverter, PrimitiveConverter} +import org.apache.parquet.schema.OriginalType.LIST import org.apache.parquet.schema.Type.Repetition import org.apache.parquet.schema.{GroupType, PrimitiveType, Type} @@ -42,6 +42,12 @@ import org.apache.spark.unsafe.types.UTF8String * values to an [[ArrayBuffer]]. 
*/ private[parquet] trait ParentContainerUpdater { + /** Called before a record field is being converted */ + def start(): Unit = () + + /** Called after a record field is being converted */ + def end(): Unit = () + def set(value: Any): Unit = () def setBoolean(value: Boolean): Unit = set(value) def setByte(value: Byte): Unit = set(value) @@ -55,6 +61,32 @@ private[parquet] trait ParentContainerUpdater { /** A no-op updater used for root converter (who doesn't have a parent). */ private[parquet] object NoopUpdater extends ParentContainerUpdater +private[parquet] trait HasParentContainerUpdater { + def updater: ParentContainerUpdater +} + +/** + * A convenient converter class for Parquet group types with an [[HasParentContainerUpdater]]. + */ +private[parquet] abstract class CatalystGroupConverter(val updater: ParentContainerUpdater) + extends GroupConverter with HasParentContainerUpdater + +/** + * Parquet converter for Parquet primitive types. Note that not all Spark SQL atomic types + * are handled by this converter. Parquet primitive types are only a subset of those of Spark + * SQL. For example, BYTE, SHORT, and INT in Spark SQL are all covered by INT32 in Parquet. + */ +private[parquet] class CatalystPrimitiveConverter(val updater: ParentContainerUpdater) + extends PrimitiveConverter with HasParentContainerUpdater { + + override def addBoolean(value: Boolean): Unit = updater.setBoolean(value) + override def addInt(value: Int): Unit = updater.setInt(value) + override def addLong(value: Long): Unit = updater.setLong(value) + override def addFloat(value: Float): Unit = updater.setFloat(value) + override def addDouble(value: Double): Unit = updater.setDouble(value) + override def addBinary(value: Binary): Unit = updater.set(value.getBytes) +} + /** * A [[CatalystRowConverter]] is used to convert Parquet "structs" into Spark SQL [[InternalRow]]s. * Since any Parquet record is also a struct, this converter can also be used as root converter. @@ -70,7 +102,7 @@ private[parquet] class CatalystRowConverter( parquetType: GroupType, catalystType: StructType, updater: ParentContainerUpdater) - extends GroupConverter { + extends CatalystGroupConverter(updater) { /** * Updater used together with field converters within a [[CatalystRowConverter]]. It propagates @@ -89,13 +121,11 @@ private[parquet] class CatalystRowConverter( /** * Represents the converted row object once an entire Parquet record is converted. - * - * @todo Uses [[UnsafeRow]] for better performance. */ val currentRow = new SpecificMutableRow(catalystType.map(_.dataType)) // Converters for each field. 
- private val fieldConverters: Array[Converter] = { + private val fieldConverters: Array[Converter with HasParentContainerUpdater] = { parquetType.getFields.zip(catalystType).zipWithIndex.map { case ((parquetFieldType, catalystField), ordinal) => // Converted field value should be set to the `ordinal`-th cell of `currentRow` @@ -105,11 +135,19 @@ private[parquet] class CatalystRowConverter( override def getConverter(fieldIndex: Int): Converter = fieldConverters(fieldIndex) - override def end(): Unit = updater.set(currentRow) + override def end(): Unit = { + var i = 0 + while (i < currentRow.numFields) { + fieldConverters(i).updater.end() + i += 1 + } + updater.set(currentRow) + } override def start(): Unit = { var i = 0 while (i < currentRow.numFields) { + fieldConverters(i).updater.start() currentRow.setNullAt(i) i += 1 } @@ -122,20 +160,20 @@ private[parquet] class CatalystRowConverter( private def newConverter( parquetType: Type, catalystType: DataType, - updater: ParentContainerUpdater): Converter = { + updater: ParentContainerUpdater): Converter with HasParentContainerUpdater = { catalystType match { case BooleanType | IntegerType | LongType | FloatType | DoubleType | BinaryType => new CatalystPrimitiveConverter(updater) case ByteType => - new PrimitiveConverter { + new CatalystPrimitiveConverter(updater) { override def addInt(value: Int): Unit = updater.setByte(value.asInstanceOf[ByteType#InternalType]) } case ShortType => - new PrimitiveConverter { + new CatalystPrimitiveConverter(updater) { override def addInt(value: Int): Unit = updater.setShort(value.asInstanceOf[ShortType#InternalType]) } @@ -148,7 +186,7 @@ private[parquet] class CatalystRowConverter( case TimestampType => // TODO Implements `TIMESTAMP_MICROS` once parquet-mr has that. - new PrimitiveConverter { + new CatalystPrimitiveConverter(updater) { // Converts nanosecond timestamps stored as INT96 override def addBinary(value: Binary): Unit = { assert( @@ -164,13 +202,23 @@ private[parquet] class CatalystRowConverter( } case DateType => - new PrimitiveConverter { + new CatalystPrimitiveConverter(updater) { override def addInt(value: Int): Unit = { // DateType is not specialized in `SpecificMutableRow`, have to box it here. updater.set(value.asInstanceOf[DateType#InternalType]) } } + // A repeated field that is neither contained by a `LIST`- or `MAP`-annotated group nor + // annotated by `LIST` or `MAP` should be interpreted as a required list of required + // elements where the element type is the type of the field. + case t: ArrayType if parquetType.getOriginalType != LIST => + if (parquetType.isPrimitive) { + new RepeatedPrimitiveConverter(parquetType, t.elementType, updater) + } else { + new RepeatedGroupConverter(parquetType, t.elementType, updater) + } + case t: ArrayType => new CatalystArrayConverter(parquetType.asGroupType(), t, updater) @@ -195,27 +243,11 @@ private[parquet] class CatalystRowConverter( } } - /** - * Parquet converter for Parquet primitive types. Note that not all Spark SQL atomic types - * are handled by this converter. Parquet primitive types are only a subset of those of Spark - * SQL. For example, BYTE, SHORT, and INT in Spark SQL are all covered by INT32 in Parquet. 
- */ - private final class CatalystPrimitiveConverter(updater: ParentContainerUpdater) - extends PrimitiveConverter { - - override def addBoolean(value: Boolean): Unit = updater.setBoolean(value) - override def addInt(value: Int): Unit = updater.setInt(value) - override def addLong(value: Long): Unit = updater.setLong(value) - override def addFloat(value: Float): Unit = updater.setFloat(value) - override def addDouble(value: Double): Unit = updater.setDouble(value) - override def addBinary(value: Binary): Unit = updater.set(value.getBytes) - } - /** * Parquet converter for strings. A dictionary is used to minimize string decoding cost. */ private final class CatalystStringConverter(updater: ParentContainerUpdater) - extends PrimitiveConverter { + extends CatalystPrimitiveConverter(updater) { private var expandedDictionary: Array[UTF8String] = null @@ -242,7 +274,7 @@ private[parquet] class CatalystRowConverter( private final class CatalystDecimalConverter( decimalType: DecimalType, updater: ParentContainerUpdater) - extends PrimitiveConverter { + extends CatalystPrimitiveConverter(updater) { // Converts decimals stored as INT32 override def addInt(value: Int): Unit = { @@ -306,7 +338,7 @@ private[parquet] class CatalystRowConverter( parquetSchema: GroupType, catalystSchema: ArrayType, updater: ParentContainerUpdater) - extends GroupConverter { + extends CatalystGroupConverter(updater) { private var currentArray: ArrayBuffer[Any] = _ @@ -383,7 +415,7 @@ private[parquet] class CatalystRowConverter( parquetType: GroupType, catalystType: MapType, updater: ParentContainerUpdater) - extends GroupConverter { + extends CatalystGroupConverter(updater) { private var currentKeys: ArrayBuffer[Any] = _ private var currentValues: ArrayBuffer[Any] = _ @@ -446,4 +478,61 @@ private[parquet] class CatalystRowConverter( } } } + + private trait RepeatedConverter { + private var currentArray: ArrayBuffer[Any] = _ + + protected def newArrayUpdater(updater: ParentContainerUpdater) = new ParentContainerUpdater { + override def start(): Unit = currentArray = ArrayBuffer.empty[Any] + override def end(): Unit = updater.set(new GenericArrayData(currentArray.toArray)) + override def set(value: Any): Unit = currentArray += value + } + } + + /** + * A primitive converter for converting unannotated repeated primitive values to required arrays + * of required primitives values. 
+ */ + private final class RepeatedPrimitiveConverter( + parquetType: Type, + catalystType: DataType, + parentUpdater: ParentContainerUpdater) + extends PrimitiveConverter with RepeatedConverter with HasParentContainerUpdater { + + val updater: ParentContainerUpdater = newArrayUpdater(parentUpdater) + + private val elementConverter: PrimitiveConverter = + newConverter(parquetType, catalystType, updater).asPrimitiveConverter() + + override def addBoolean(value: Boolean): Unit = elementConverter.addBoolean(value) + override def addInt(value: Int): Unit = elementConverter.addInt(value) + override def addLong(value: Long): Unit = elementConverter.addLong(value) + override def addFloat(value: Float): Unit = elementConverter.addFloat(value) + override def addDouble(value: Double): Unit = elementConverter.addDouble(value) + override def addBinary(value: Binary): Unit = elementConverter.addBinary(value) + + override def setDictionary(dict: Dictionary): Unit = elementConverter.setDictionary(dict) + override def hasDictionarySupport: Boolean = elementConverter.hasDictionarySupport + override def addValueFromDictionary(id: Int): Unit = elementConverter.addValueFromDictionary(id) + } + + /** + * A group converter for converting unannotated repeated group values to required arrays of + * required struct values. + */ + private final class RepeatedGroupConverter( + parquetType: Type, + catalystType: DataType, + parentUpdater: ParentContainerUpdater) + extends GroupConverter with HasParentContainerUpdater with RepeatedConverter { + + val updater: ParentContainerUpdater = newArrayUpdater(parentUpdater) + + private val elementConverter: GroupConverter = + newConverter(parquetType, catalystType, updater).asGroupConverter() + + override def getConverter(field: Int): Converter = elementConverter.getConverter(field) + override def end(): Unit = elementConverter.end() + override def start(): Unit = elementConverter.start() + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala index a3fc74cf7929..275646e8181a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala @@ -100,8 +100,11 @@ private[parquet] class CatalystSchemaConverter( StructField(field.getName, convertField(field), nullable = false) case REPEATED => - throw new AnalysisException( - s"REPEATED not supported outside LIST or MAP. Type: $field") + // A repeated field that is neither contained by a `LIST`- or `MAP`-annotated group nor + // annotated by `LIST` or `MAP` should be interpreted as a required list of required + // elements where the element type is the type of the field. 
+ val arrayType = ArrayType(convertField(field), containsNull = false) + StructField(field.getName, arrayType, nullable = false) } } diff --git a/sql/core/src/test/resources/nested-array-struct.parquet b/sql/core/src/test/resources/nested-array-struct.parquet new file mode 100644 index 0000000000000000000000000000000000000000..41a43fa35d39685e56ba4849a16cba4bb1aa86ae GIT binary patch literal 775 zcmaKr-%G+!6vvO#=8u;i;*JSE3{j~lAyWtuV%0Q3*U(Y)B(q&>u(@?NBZ;2+Px=dc z?6I>s%N6u+cQ4=bIp1@3cBjdsBLbvCDhGte15a`Q8~~)V;d2WY3V@LYX{-@GMj#!6 z`;fvdgDblto20oxMhs#hdKzt*4*3w}iuUE6PW?b*Zs1NAv%1WfvAnT@2NhLn_L#fy zHFWX={q7*H z93@^0*O&w#e53@lJ`hFEV2=wL)V**Pb(8vc%<=-4iJz&t;n22J{%<(t!px$!DZLaV zDaOCYR1UR;Go`F89pTwFrqpgr1NlrDOs+J&f2GO;)PtpmW%OH3neOEccXHoWyO`6Bl5({+ea14&qL7DtETw`(h_426$BxCYApN S1!5siKXe$p;U(Ab7x)6M)yB5~ literal 0 HcmV?d00001 diff --git a/sql/core/src/test/resources/old-repeated-int.parquet b/sql/core/src/test/resources/old-repeated-int.parquet new file mode 100644 index 0000000000000000000000000000000000000000..520922f73ebb75950c4e65dec6689af817cb7a33 GIT binary patch literal 389 zcmZWmO-sW-5FMkeB_3tN1_Ca@_7p=~Z@EPbSg5juTs)Ocvz0>9#NEw7ivQhdDcI1< z%;U}1booNUUY~PehCwzvumZho_zD!@TVtKoQHivd-PzgS{O4oWwfk)ZsDnB!ByvMUB7gt@Rk50{GMw}6DWn-w!#F0CGZwP` zh81XVct9peJU!6=O6yx`ZywS;EGVOwOOIsCr3p)d)y(vgv`4bce)5D7UiKlH0l6Ga0+m-EijTt zM!X)Rd#pSj~EkCao0il-K=62)db~64ZC7r8d1AwIhzFkoG;e&AbpZW#pJ&{5H literal 0 HcmV?d00001 diff --git a/sql/core/src/test/resources/old-repeated.parquet b/sql/core/src/test/resources/old-repeated.parquet new file mode 100644 index 0000000000000000000000000000000000000000..213f1a90291b30a8a3161b51c38f008f3ae9f6e5 GIT binary patch literal 432 zcmZWm!D@p*5ZxNF!5+)X3PMGioUA16Ei?y9g$B|h;-#ms>Ld--Xm{5`DgF13A<&42 znSH!BJNtMWhsm50I-@h68VC$(I7}ZALYRJm-NGUo*2p;a%Z@xEJgH{;FE=Sj6^mNc zS-TAqXn-pyRtNP8Qt};84d*60yAuBru{7JUo$1)Y6%&KlJ(Uv6u!JS1zOP)cw zaM$5ewB9699EEB0jJ*18aDDn7N1N4K`fzXlnuJ~V?c^nwk}Yeo3wXox4+#3Y!pMU2 V+-`?%2{TWZ?kYh(G4~k%>JK8=aDe~- literal 0 HcmV?d00001 diff --git a/sql/core/src/test/resources/parquet-thrift-compat.snappy.parquet b/sql/core/src/test/resources/parquet-thrift-compat.snappy.parquet old mode 100755 new mode 100644 diff --git a/sql/core/src/test/resources/proto-repeated-string.parquet b/sql/core/src/test/resources/proto-repeated-string.parquet new file mode 100644 index 0000000000000000000000000000000000000000..8a7eea601d0164b1177ab90efb7116c6e4c900da GIT binary patch literal 411 zcmY*W!D_-l5S^$E62wc{umKMtR4+{f-b!vM4Q)Y6&|G?w#H^aKansF;gi?C#*Yq1Z zv6e;{_C4Mk<_)t^FrN}2UmBK6hDddy19SkO`+9soFOY8;=b|A8A$itAvJoQdBBnKK zKy>)#|`4$W^3YtqL)WN5pTmWh1ZGv$>{f|s#sCG%1VN%LJ&FyD4siH@<(8PDu@ z!?sWEU$)ao`yyr1x2MQ?k}~ewv*0eAE$3kr261?gx~fYY8oxy0auLs;o*#@41L)=X h7Au}q6}>(e6`sLs-{PvZ8BpWYeN#xd)c_*=mLHLcZu|fM literal 0 HcmV?d00001 diff --git a/sql/core/src/test/resources/proto-repeated-struct.parquet b/sql/core/src/test/resources/proto-repeated-struct.parquet new file mode 100644 index 0000000000000000000000000000000000000000..c29eee35c350e84efa2818febb2528309a8ac3ea GIT binary patch literal 608 zcma))-D<)x6vvP811a8(bSZdI%9Js*tjca=2puciK%r=F1_P}cr%>B2jf^q&4ts<> z%r1PaC9^8^%A4e$li&GFTzg<)z+K#J;DQh(TmnDm&bFDCfsEak0$H6=|yp$CW-$_F@l={DK5j1GEpn8)DX!>A+15G`Fpg} zMZREE-l#~cYPa=r6<4%c3AD?tzx2bP7Sypiu9pGoi(^0p+XD*$Y;s4$HpQOVL*dnD3MsZWzL<}ml5ZY`6p``6p3uzOR6cOQizQ3hI@;TGm z8hy+Rm>Hl!&aiC8oEI4i7`u<#dC`r5%q<5r!9eDg1IFy*c2RU=AalzBO)!wT<$y7e zS0C?=T^uJ)6ePiDIW^oM?BO`}o-pLWHOS&h@*H8x zD56?ZuNu`Fl-0Tj)YHv=x(@J>C=j)+#mj%Z_4kgdp}D_<3b zc(o7;z363$6CCr9}+-E 
j<-Z;KUL2!lIhl~NDb+dIHUN;6ire!D{E~S&5gg5S?rsql%Icnq5|qgD{N<#x?oCY2!n{ZAEKv64iDOJsH{F!~)4uB-v0( z@BJM;_owufA5^-l4@Z{ekVAVh>zOxi-n`LDMyq>_0q^0x8b74l?^SjJrp%q&06h8-y4iMdSJ@MDH4c~HjX3j($=&sN1 zW|q&!OYxG3d&~^8@dlzhDa$1b0`rz(6tkBD*J153G=T1;gzF$B0g1WSKnPOy6M$~KQ|W5 z5A#DO!$$Zca>TK`;67WBib&?m76{4rqTmNgwH)UCc)*uPhjcg;fc!=Tfl{N?GyS_6 z3+tZPeSOS=k#BjS>(f75Q`2EhwX*hMsK_@Kv&ZT;SydBky3mDR6_J}cL*_TtV}7>H zA+wumr}b9v46coS`}(TY;qmaR$9wg^82X@n)jvIvzps*~J`|Fl Date: Mon, 10 Aug 2015 22:04:41 -0700 Subject: [PATCH 0156/1824] [SPARK-9729] [SPARK-9363] [SQL] Use sort merge join for left and right outer join This patch adds a new `SortMergeOuterJoin` operator that performs left and right outer joins using sort merge join. It also refactors `SortMergeJoin` in order to improve performance and code clarity. Along the way, I also performed a couple pieces of minor cleanup and optimization: - Rename the `HashJoin` physical planner rule to `EquiJoinSelection`, since it's also used for non-hash joins. - Rewrite the comment at the top of `HashJoin` to better explain the precedence for choosing join operators. - Update `JoinSuite` to use `SqlTestUtils.withConf` for changing SQLConf settings. This patch incorporates several ideas from adrian-wang's patch, #5717. Closes #5717. [Review on Reviewable](https://reviewable.io/reviews/apache/spark/7904) Author: Josh Rosen Author: Daoyuan Wang Closes #7904 from JoshRosen/outer-join-smj and squashes 1 commits. --- .../sql/catalyst/expressions/JoinedRow.scala | 6 +- .../org/apache/spark/sql/SQLContext.scala | 2 +- .../spark/sql/execution/RowIterator.scala | 93 +++++ .../spark/sql/execution/SparkStrategies.scala | 45 ++- .../joins/BroadcastNestedLoopJoin.scala | 5 +- .../sql/execution/joins/SortMergeJoin.scala | 331 +++++++++++++----- .../execution/joins/SortMergeOuterJoin.scala | 251 +++++++++++++ .../org/apache/spark/sql/JoinSuite.scala | 132 ++++--- .../sql/execution/joins/InnerJoinSuite.scala | 180 ++++++++++ .../sql/execution/joins/OuterJoinSuite.scala | 310 +++++++++++----- .../sql/execution/joins/SemiJoinSuite.scala | 125 ++++--- .../apache/spark/sql/test/SQLTestUtils.scala | 2 +- .../apache/spark/sql/hive/HiveContext.scala | 2 +- 13 files changed, 1165 insertions(+), 319 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/RowIterator.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/JoinedRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/JoinedRow.scala index b76757c93523..d3560df0792e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/JoinedRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/JoinedRow.scala @@ -37,20 +37,20 @@ class JoinedRow extends InternalRow { } /** Updates this JoinedRow to used point at two new base rows. Returns itself. */ - def apply(r1: InternalRow, r2: InternalRow): InternalRow = { + def apply(r1: InternalRow, r2: InternalRow): JoinedRow = { row1 = r1 row2 = r2 this } /** Updates this JoinedRow by updating its left base row. Returns itself. */ - def withLeft(newLeft: InternalRow): InternalRow = { + def withLeft(newLeft: InternalRow): JoinedRow = { row1 = newLeft this } /** Updates this JoinedRow by updating its right base row. Returns itself. 
*/ - def withRight(newRight: InternalRow): InternalRow = { + def withRight(newRight: InternalRow): JoinedRow = { row2 = newRight this } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index f73bb0488c98..4bf00b3399e7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -873,7 +873,7 @@ class SQLContext(@transient val sparkContext: SparkContext) HashAggregation :: Aggregation :: LeftSemiJoin :: - HashJoin :: + EquiJoinSelection :: InMemoryScans :: BasicOperators :: CartesianProduct :: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/RowIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/RowIterator.scala new file mode 100644 index 000000000000..7462dbc4eba3 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/RowIterator.scala @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import java.util.NoSuchElementException + +import org.apache.spark.sql.catalyst.InternalRow + +/** + * An internal iterator interface which presents a more restrictive API than + * [[scala.collection.Iterator]]. + * + * One major departure from the Scala iterator API is the fusing of the `hasNext()` and `next()` + * calls: Scala's iterator allows users to call `hasNext()` without immediately advancing the + * iterator to consume the next row, whereas RowIterator combines these calls into a single + * [[advanceNext()]] method. + */ +private[sql] abstract class RowIterator { + /** + * Advance this iterator by a single row. Returns `false` if this iterator has no more rows + * and `true` otherwise. If this returns `true`, then the new row can be retrieved by calling + * [[getRow]]. + */ + def advanceNext(): Boolean + + /** + * Retrieve the row from this iterator. This method is idempotent. It is illegal to call this + * method after [[advanceNext()]] has returned `false`. + */ + def getRow: InternalRow + + /** + * Convert this RowIterator into a [[scala.collection.Iterator]]. 
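A rough standalone sketch of the fused-iteration contract described above (plain Strings and invented names, not Spark's RowIterator, which is private to sql): a single advanceNext() call both tests for and moves to the next row, and getRow stays valid and idempotent until the next advance.

    object RowIteratorSketch {
      // Same shape of API as RowIterator, but over Strings for illustration.
      abstract class SimpleRowIterator {
        def advanceNext(): Boolean // move to the next row; false when exhausted
        def getRow: String         // idempotent; only valid after advanceNext() returned true
      }

      // Consuming such an iterator takes exactly one call per row, so there is no
      // hidden state to keep in sync between a hasNext probe and the next() that
      // actually consumes the row.
      def drain(iter: SimpleRowIterator): List[String] = {
        val buf = scala.collection.mutable.ListBuffer.empty[String]
        while (iter.advanceNext()) {
          buf += iter.getRow
        }
        buf.toList
      }
    }
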
+ */ + def toScala: Iterator[InternalRow] = new RowIteratorToScala(this) +} + +object RowIterator { + def fromScala(scalaIter: Iterator[InternalRow]): RowIterator = { + scalaIter match { + case wrappedRowIter: RowIteratorToScala => wrappedRowIter.rowIter + case _ => new RowIteratorFromScala(scalaIter) + } + } +} + +private final class RowIteratorToScala(val rowIter: RowIterator) extends Iterator[InternalRow] { + private [this] var hasNextWasCalled: Boolean = false + private [this] var _hasNext: Boolean = false + override def hasNext: Boolean = { + // Idempotency: + if (!hasNextWasCalled) { + _hasNext = rowIter.advanceNext() + hasNextWasCalled = true + } + _hasNext + } + override def next(): InternalRow = { + if (!hasNext) throw new NoSuchElementException + hasNextWasCalled = false + rowIter.getRow + } +} + +private final class RowIteratorFromScala(scalaIter: Iterator[InternalRow]) extends RowIterator { + private[this] var _next: InternalRow = null + override def advanceNext(): Boolean = { + if (scalaIter.hasNext) { + _next = scalaIter.next() + true + } else { + _next = null + false + } + } + override def getRow: InternalRow = _next + override def toScala: Iterator[InternalRow] = scalaIter +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index c4b9b5acea4d..1fc870d44b57 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -63,19 +63,23 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } /** - * Uses the ExtractEquiJoinKeys pattern to find joins where at least some of the predicates can be - * evaluated by matching hash keys. + * Uses the [[ExtractEquiJoinKeys]] pattern to find joins where at least some of the predicates + * can be evaluated by matching join keys. * - * This strategy applies a simple optimization based on the estimates of the physical sizes of - * the two join sides. When planning a [[joins.BroadcastHashJoin]], if one side has an - * estimated physical size smaller than the user-settable threshold - * [[org.apache.spark.sql.SQLConf.AUTO_BROADCASTJOIN_THRESHOLD]], the planner would mark it as the - * ''build'' relation and mark the other relation as the ''stream'' side. The build table will be - * ''broadcasted'' to all of the executors involved in the join, as a - * [[org.apache.spark.broadcast.Broadcast]] object. If both estimates exceed the threshold, they - * will instead be used to decide the build side in a [[joins.ShuffledHashJoin]]. + * Join implementations are chosen with the following precedence: + * + * - Broadcast: if one side of the join has an estimated physical size that is smaller than the + * user-configurable [[org.apache.spark.sql.SQLConf.AUTO_BROADCASTJOIN_THRESHOLD]] threshold + * or if that side has an explicit broadcast hint (e.g. the user applied the + * [[org.apache.spark.sql.functions.broadcast()]] function to a DataFrame), then that side + * of the join will be broadcasted and the other side will be streamed, with no shuffling + * performed. If both sides of the join are eligible to be broadcasted then the + * - Sort merge: if the matching join keys are sortable and + * [[org.apache.spark.sql.SQLConf.SORTMERGE_JOIN]] is enabled (default), then sort merge join + * will be used. + * - Hash: will be chosen if neither of the above optimizations apply to this join. 
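For reference, a minimal sketch of the broadcast hint mentioned in the first bullet (the table and join column names are invented; broadcast is the public helper in org.apache.spark.sql.functions):

    import org.apache.spark.sql.DataFrame
    import org.apache.spark.sql.functions.broadcast

    object BroadcastHintSketch {
      // Marking `small` with the hint makes it the build side regardless of size
      // estimates, so `large` is streamed without a shuffle.
      def hintedJoin(large: DataFrame, small: DataFrame): DataFrame =
        large.join(broadcast(small), "id")
    }
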
*/ - object HashJoin extends Strategy with PredicateHelper { + object EquiJoinSelection extends Strategy with PredicateHelper { private[this] def makeBroadcastHashJoin( leftKeys: Seq[Expression], @@ -90,14 +94,15 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + + // --- Inner joins -------------------------------------------------------------------------- + case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, CanBroadcast(right)) => makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, joins.BuildRight) case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, CanBroadcast(left), right) => makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, joins.BuildLeft) - // If the sort merge join option is set, we want to use sort merge join prior to hashjoin - // for now let's support inner join first, then add outer join case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) if sqlContext.conf.sortMergeJoinEnabled && RowOrdering.isOrderable(leftKeys) => val mergeJoin = @@ -115,6 +120,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { leftKeys, rightKeys, buildSide, planLater(left), planLater(right)) condition.map(Filter(_, hashJoin)).getOrElse(hashJoin) :: Nil + // --- Outer joins -------------------------------------------------------------------------- + case ExtractEquiJoinKeys( LeftOuter, leftKeys, rightKeys, condition, left, CanBroadcast(right)) => joins.BroadcastHashOuterJoin( @@ -125,10 +132,22 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { joins.BroadcastHashOuterJoin( leftKeys, rightKeys, RightOuter, condition, planLater(left), planLater(right)) :: Nil + case ExtractEquiJoinKeys(LeftOuter, leftKeys, rightKeys, condition, left, right) + if sqlContext.conf.sortMergeJoinEnabled && RowOrdering.isOrderable(leftKeys) => + joins.SortMergeOuterJoin( + leftKeys, rightKeys, LeftOuter, condition, planLater(left), planLater(right)) :: Nil + + case ExtractEquiJoinKeys(RightOuter, leftKeys, rightKeys, condition, left, right) + if sqlContext.conf.sortMergeJoinEnabled && RowOrdering.isOrderable(leftKeys) => + joins.SortMergeOuterJoin( + leftKeys, rightKeys, RightOuter, condition, planLater(left), planLater(right)) :: Nil + case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right) => joins.ShuffledHashOuterJoin( leftKeys, rightKeys, joinType, condition, planLater(left), planLater(right)) :: Nil + // --- Cases where this strategy does not apply --------------------------------------------- + case _ => Nil } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala index 23aebf4b068b..017a44b9ca86 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala @@ -65,8 +65,9 @@ case class BroadcastNestedLoopJoin( left.output.map(_.withNullability(true)) ++ right.output case FullOuter => left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true)) - case _ => - left.output ++ right.output + case x => + throw new IllegalArgumentException( + s"BroadcastNestedLoopJoin should not take $x as the JoinType") } } diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala index 4ae23c186cf7..6d656ea2849a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala @@ -17,15 +17,14 @@ package org.apache.spark.sql.execution.joins -import java.util.NoSuchElementException +import scala.collection.mutable.ArrayBuffer import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} -import org.apache.spark.util.collection.CompactBuffer +import org.apache.spark.sql.execution.{BinaryNode, RowIterator, SparkPlan} /** * :: DeveloperApi :: @@ -38,8 +37,6 @@ case class SortMergeJoin( left: SparkPlan, right: SparkPlan) extends BinaryNode { - override protected[sql] val trackNumOfRowsEnabled = true - override def output: Seq[Attribute] = left.output ++ right.output override def outputPartitioning: Partitioning = @@ -56,117 +53,265 @@ case class SortMergeJoin( @transient protected lazy val leftKeyGenerator = newProjection(leftKeys, left.output) @transient protected lazy val rightKeyGenerator = newProjection(rightKeys, right.output) + protected[this] def isUnsafeMode: Boolean = { + (codegenEnabled && unsafeEnabled + && UnsafeProjection.canSupport(leftKeys) + && UnsafeProjection.canSupport(rightKeys) + && UnsafeProjection.canSupport(schema)) + } + + override def outputsUnsafeRows: Boolean = isUnsafeMode + override def canProcessUnsafeRows: Boolean = isUnsafeMode + override def canProcessSafeRows: Boolean = !isUnsafeMode + private def requiredOrders(keys: Seq[Expression]): Seq[SortOrder] = { // This must be ascending in order to agree with the `keyOrdering` defined in `doExecute()`. keys.map(SortOrder(_, Ascending)) } protected override def doExecute(): RDD[InternalRow] = { - val leftResults = left.execute().map(_.copy()) - val rightResults = right.execute().map(_.copy()) - - leftResults.zipPartitions(rightResults) { (leftIter, rightIter) => - new Iterator[InternalRow] { + left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) => + new RowIterator { // An ordering that can be used to compare keys from both sides. private[this] val keyOrdering = newNaturalAscendingOrdering(leftKeys.map(_.dataType)) - // Mutable per row objects. 
+ private[this] var currentLeftRow: InternalRow = _ + private[this] var currentRightMatches: ArrayBuffer[InternalRow] = _ + private[this] var currentMatchIdx: Int = -1 + private[this] val smjScanner = new SortMergeJoinScanner( + leftKeyGenerator, + rightKeyGenerator, + keyOrdering, + RowIterator.fromScala(leftIter), + RowIterator.fromScala(rightIter) + ) private[this] val joinRow = new JoinedRow - private[this] var leftElement: InternalRow = _ - private[this] var rightElement: InternalRow = _ - private[this] var leftKey: InternalRow = _ - private[this] var rightKey: InternalRow = _ - private[this] var rightMatches: CompactBuffer[InternalRow] = _ - private[this] var rightPosition: Int = -1 - private[this] var stop: Boolean = false - private[this] var matchKey: InternalRow = _ - - // initialize iterator - initialize() - - override final def hasNext: Boolean = nextMatchingPair() - - override final def next(): InternalRow = { - if (hasNext) { - // we are using the buffered right rows and run down left iterator - val joinedRow = joinRow(leftElement, rightMatches(rightPosition)) - rightPosition += 1 - if (rightPosition >= rightMatches.size) { - rightPosition = 0 - fetchLeft() - if (leftElement == null || keyOrdering.compare(leftKey, matchKey) != 0) { - stop = false - rightMatches = null - } - } - joinedRow + private[this] val resultProjection: (InternalRow) => InternalRow = { + if (isUnsafeMode) { + UnsafeProjection.create(schema) } else { - // no more result - throw new NoSuchElementException + identity[InternalRow] } } - private def fetchLeft() = { - if (leftIter.hasNext) { - leftElement = leftIter.next() - leftKey = leftKeyGenerator(leftElement) - } else { - leftElement = null + override def advanceNext(): Boolean = { + if (currentMatchIdx == -1 || currentMatchIdx == currentRightMatches.length) { + if (smjScanner.findNextInnerJoinRows()) { + currentRightMatches = smjScanner.getBufferedMatches + currentLeftRow = smjScanner.getStreamedRow + currentMatchIdx = 0 + } else { + currentRightMatches = null + currentLeftRow = null + currentMatchIdx = -1 + } } - } - - private def fetchRight() = { - if (rightIter.hasNext) { - rightElement = rightIter.next() - rightKey = rightKeyGenerator(rightElement) + if (currentLeftRow != null) { + joinRow(currentLeftRow, currentRightMatches(currentMatchIdx)) + currentMatchIdx += 1 + true } else { - rightElement = null + false } } - private def initialize() = { - fetchLeft() - fetchRight() + override def getRow: InternalRow = resultProjection(joinRow) + }.toScala + } + } +} + +/** + * Helper class that is used to implement [[SortMergeJoin]] and [[SortMergeOuterJoin]]. + * + * To perform an inner (outer) join, users of this class call [[findNextInnerJoinRows()]] + * ([[findNextOuterJoinRows()]]), which returns `true` if a result has been produced and `false` + * otherwise. If a result has been produced, then the caller may call [[getStreamedRow]] to return + * the matching row from the streamed input and may call [[getBufferedMatches]] to return the + * sequence of matching rows from the buffered input (in the case of an outer join, this will return + * an empty sequence if there are no matches from the buffered input). For efficiency, both of these + * methods return mutable objects which are re-used across calls to the `findNext*JoinRows()` + * methods. + * + * @param streamedKeyGenerator a projection that produces join keys from the streamed input. + * @param bufferedKeyGenerator a projection that produces join keys from the buffered input. 
+ * @param keyOrdering an ordering which can be used to compare join keys. + * @param streamedIter an input whose rows will be streamed. + * @param bufferedIter an input whose rows will be buffered to construct sequences of rows that + * have the same join key. + */ +private[joins] class SortMergeJoinScanner( + streamedKeyGenerator: Projection, + bufferedKeyGenerator: Projection, + keyOrdering: Ordering[InternalRow], + streamedIter: RowIterator, + bufferedIter: RowIterator) { + private[this] var streamedRow: InternalRow = _ + private[this] var streamedRowKey: InternalRow = _ + private[this] var bufferedRow: InternalRow = _ + // Note: this is guaranteed to never have any null columns: + private[this] var bufferedRowKey: InternalRow = _ + /** + * The join key for the rows buffered in `bufferedMatches`, or null if `bufferedMatches` is empty + */ + private[this] var matchJoinKey: InternalRow = _ + /** Buffered rows from the buffered side of the join. This is empty if there are no matches. */ + private[this] val bufferedMatches: ArrayBuffer[InternalRow] = new ArrayBuffer[InternalRow] + + // Initialization (note: do _not_ want to advance streamed here). + advancedBufferedToRowWithNullFreeJoinKey() + + // --- Public methods --------------------------------------------------------------------------- + + def getStreamedRow: InternalRow = streamedRow + + def getBufferedMatches: ArrayBuffer[InternalRow] = bufferedMatches + + /** + * Advances both input iterators, stopping when we have found rows with matching join keys. + * @return true if matching rows have been found and false otherwise. If this returns true, then + * [[getStreamedRow]] and [[getBufferedMatches]] can be called to construct the join + * results. + */ + final def findNextInnerJoinRows(): Boolean = { + while (advancedStreamed() && streamedRowKey.anyNull) { + // Advance the streamed side of the join until we find the next row whose join key contains + // no nulls or we hit the end of the streamed iterator. + } + if (streamedRow == null) { + // We have consumed the entire streamed iterator, so there can be no more matches. + matchJoinKey = null + bufferedMatches.clear() + false + } else if (matchJoinKey != null && keyOrdering.compare(streamedRowKey, matchJoinKey) == 0) { + // The new streamed row has the same join key as the previous row, so return the same matches. + true + } else if (bufferedRow == null) { + // The streamed row's join key does not match the current batch of buffered rows and there are + // no more rows to read from the buffered iterator, so there can be no more matches. + matchJoinKey = null + bufferedMatches.clear() + false + } else { + // Advance both the streamed and buffered iterators to find the next pair of matching rows. + var comp = keyOrdering.compare(streamedRowKey, bufferedRowKey) + do { + if (streamedRowKey.anyNull) { + advancedStreamed() + } else { + assert(!bufferedRowKey.anyNull) + comp = keyOrdering.compare(streamedRowKey, bufferedRowKey) + if (comp > 0) advancedBufferedToRowWithNullFreeJoinKey() + else if (comp < 0) advancedStreamed() } + } while (streamedRow != null && bufferedRow != null && comp != 0) + if (streamedRow == null || bufferedRow == null) { + // We have either hit the end of one of the iterators, so there can be no more matches. + matchJoinKey = null + bufferedMatches.clear() + false + } else { + // The streamed row's join key matches the current buffered row's join, so walk through the + // buffered iterator to buffer the rest of the matching rows. 
+ assert(comp == 0) + bufferMatchingRows() + true + } + } + } - /** - * Searches the right iterator for the next rows that have matches in left side, and store - * them in a buffer. - * - * @return true if the search is successful, and false if the right iterator runs out of - * tuples. - */ - private def nextMatchingPair(): Boolean = { - if (!stop && rightElement != null) { - // run both side to get the first match pair - while (!stop && leftElement != null && rightElement != null) { - val comparing = keyOrdering.compare(leftKey, rightKey) - // for inner join, we need to filter those null keys - stop = comparing == 0 && !leftKey.anyNull - if (comparing > 0 || rightKey.anyNull) { - fetchRight() - } else if (comparing < 0 || leftKey.anyNull) { - fetchLeft() - } - } - rightMatches = new CompactBuffer[InternalRow]() - if (stop) { - stop = false - // iterate the right side to buffer all rows that matches - // as the records should be ordered, exit when we meet the first that not match - while (!stop && rightElement != null) { - rightMatches += rightElement - fetchRight() - stop = keyOrdering.compare(leftKey, rightKey) != 0 - } - if (rightMatches.size > 0) { - rightPosition = 0 - matchKey = leftKey - } - } + /** + * Advances the streamed input iterator and buffers all rows from the buffered input that + * have matching keys. + * @return true if the streamed iterator returned a row, false otherwise. If this returns true, + * then [getStreamedRow and [[getBufferedMatches]] can be called to produce the outer + * join results. + */ + final def findNextOuterJoinRows(): Boolean = { + if (!advancedStreamed()) { + // We have consumed the entire streamed iterator, so there can be no more matches. + matchJoinKey = null + bufferedMatches.clear() + false + } else { + if (matchJoinKey != null && keyOrdering.compare(streamedRowKey, matchJoinKey) == 0) { + // Matches the current group, so do nothing. + } else { + // The streamed row does not match the current group. + matchJoinKey = null + bufferedMatches.clear() + if (bufferedRow != null && !streamedRowKey.anyNull) { + // The buffered iterator could still contain matching rows, so we'll need to walk through + // it until we either find matches or pass where they would be found. + var comp = 1 + do { + comp = keyOrdering.compare(streamedRowKey, bufferedRowKey) + } while (comp > 0 && advancedBufferedToRowWithNullFreeJoinKey()) + if (comp == 0) { + // We have found matches, so buffer them (this updates matchJoinKey) + bufferMatchingRows() + } else { + // We have overshot the position where the row would be found, hence no matches. } - rightMatches != null && rightMatches.size > 0 } } + // If there is a streamed input then we always return true + true } } + + // --- Private methods -------------------------------------------------------------------------- + + /** + * Advance the streamed iterator and compute the new row's join key. + * @return true if the streamed iterator returned a row and false otherwise. + */ + private def advancedStreamed(): Boolean = { + if (streamedIter.advanceNext()) { + streamedRow = streamedIter.getRow + streamedRowKey = streamedKeyGenerator(streamedRow) + true + } else { + streamedRow = null + streamedRowKey = null + false + } + } + + /** + * Advance the buffered iterator until we find a row with join key that does not contain nulls. + * @return true if the buffered iterator returned a row and false otherwise. 
+ */ + private def advancedBufferedToRowWithNullFreeJoinKey(): Boolean = { + var foundRow: Boolean = false + while (!foundRow && bufferedIter.advanceNext()) { + bufferedRow = bufferedIter.getRow + bufferedRowKey = bufferedKeyGenerator(bufferedRow) + foundRow = !bufferedRowKey.anyNull + } + if (!foundRow) { + bufferedRow = null + bufferedRowKey = null + false + } else { + true + } + } + + /** + * Called when the streamed and buffered join keys match in order to buffer the matching rows. + */ + private def bufferMatchingRows(): Unit = { + assert(streamedRowKey != null) + assert(!streamedRowKey.anyNull) + assert(bufferedRowKey != null) + assert(!bufferedRowKey.anyNull) + assert(keyOrdering.compare(streamedRowKey, bufferedRowKey) == 0) + // This join key may have been produced by a mutable projection, so we need to make a copy: + matchJoinKey = streamedRowKey.copy() + bufferedMatches.clear() + do { + bufferedMatches += bufferedRow.copy() // need to copy mutable rows before buffering them + advancedBufferedToRowWithNullFreeJoinKey() + } while (bufferedRow != null && keyOrdering.compare(streamedRowKey, bufferedRowKey) == 0) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala new file mode 100644 index 000000000000..5326966b07a6 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala @@ -0,0 +1,251 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.joins + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.{JoinType, LeftOuter, RightOuter} +import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.execution.{BinaryNode, RowIterator, SparkPlan} + +/** + * :: DeveloperApi :: + * Performs an sort merge outer join of two child relations. + * + * Note: this does not support full outer join yet; see SPARK-9730 for progress on this. 
+ */ +@DeveloperApi +case class SortMergeOuterJoin( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + joinType: JoinType, + condition: Option[Expression], + left: SparkPlan, + right: SparkPlan) extends BinaryNode { + + override def output: Seq[Attribute] = { + joinType match { + case LeftOuter => + left.output ++ right.output.map(_.withNullability(true)) + case RightOuter => + left.output.map(_.withNullability(true)) ++ right.output + case x => + throw new IllegalArgumentException( + s"${getClass.getSimpleName} should not take $x as the JoinType") + } + } + + override def outputPartitioning: Partitioning = joinType match { + // For left and right outer joins, the output is partitioned by the streamed input's join keys. + case LeftOuter => left.outputPartitioning + case RightOuter => right.outputPartitioning + case x => + throw new IllegalArgumentException( + s"${getClass.getSimpleName} should not take $x as the JoinType") + } + + override def outputOrdering: Seq[SortOrder] = joinType match { + // For left and right outer joins, the output is ordered by the streamed input's join keys. + case LeftOuter => requiredOrders(leftKeys) + case RightOuter => requiredOrders(rightKeys) + case x => throw new IllegalArgumentException( + s"SortMergeOuterJoin should not take $x as the JoinType") + } + + override def requiredChildDistribution: Seq[Distribution] = + ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil + + override def requiredChildOrdering: Seq[Seq[SortOrder]] = + requiredOrders(leftKeys) :: requiredOrders(rightKeys) :: Nil + + private def requiredOrders(keys: Seq[Expression]): Seq[SortOrder] = { + // This must be ascending in order to agree with the `keyOrdering` defined in `doExecute()`. + keys.map(SortOrder(_, Ascending)) + } + + private def isUnsafeMode: Boolean = { + (codegenEnabled && unsafeEnabled + && UnsafeProjection.canSupport(leftKeys) + && UnsafeProjection.canSupport(rightKeys) + && UnsafeProjection.canSupport(schema)) + } + + override def outputsUnsafeRows: Boolean = isUnsafeMode + override def canProcessUnsafeRows: Boolean = isUnsafeMode + override def canProcessSafeRows: Boolean = !isUnsafeMode + + private def createLeftKeyGenerator(): Projection = { + if (isUnsafeMode) { + UnsafeProjection.create(leftKeys, left.output) + } else { + newProjection(leftKeys, left.output) + } + } + + private def createRightKeyGenerator(): Projection = { + if (isUnsafeMode) { + UnsafeProjection.create(rightKeys, right.output) + } else { + newProjection(rightKeys, right.output) + } + } + + override def doExecute(): RDD[InternalRow] = { + left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) => + // An ordering that can be used to compare keys from both sides. 
+ val keyOrdering = newNaturalAscendingOrdering(leftKeys.map(_.dataType)) + val boundCondition: (InternalRow) => Boolean = { + condition.map { cond => + newPredicate(cond, left.output ++ right.output) + }.getOrElse { + (r: InternalRow) => true + } + } + val resultProj: InternalRow => InternalRow = { + if (isUnsafeMode) { + UnsafeProjection.create(schema) + } else { + identity[InternalRow] + } + } + + joinType match { + case LeftOuter => + val smjScanner = new SortMergeJoinScanner( + streamedKeyGenerator = createLeftKeyGenerator(), + bufferedKeyGenerator = createRightKeyGenerator(), + keyOrdering, + streamedIter = RowIterator.fromScala(leftIter), + bufferedIter = RowIterator.fromScala(rightIter) + ) + val rightNullRow = new GenericInternalRow(right.output.length) + new LeftOuterIterator(smjScanner, rightNullRow, boundCondition, resultProj).toScala + + case RightOuter => + val smjScanner = new SortMergeJoinScanner( + streamedKeyGenerator = createRightKeyGenerator(), + bufferedKeyGenerator = createLeftKeyGenerator(), + keyOrdering, + streamedIter = RowIterator.fromScala(rightIter), + bufferedIter = RowIterator.fromScala(leftIter) + ) + val leftNullRow = new GenericInternalRow(left.output.length) + new RightOuterIterator(smjScanner, leftNullRow, boundCondition, resultProj).toScala + + case x => + throw new IllegalArgumentException( + s"SortMergeOuterJoin should not take $x as the JoinType") + } + } + } +} + + +private class LeftOuterIterator( + smjScanner: SortMergeJoinScanner, + rightNullRow: InternalRow, + boundCondition: InternalRow => Boolean, + resultProj: InternalRow => InternalRow + ) extends RowIterator { + private[this] val joinedRow: JoinedRow = new JoinedRow() + private[this] var rightIdx: Int = 0 + assert(smjScanner.getBufferedMatches.length == 0) + + private def advanceLeft(): Boolean = { + rightIdx = 0 + if (smjScanner.findNextOuterJoinRows()) { + joinedRow.withLeft(smjScanner.getStreamedRow) + if (smjScanner.getBufferedMatches.isEmpty) { + // There are no matching right rows, so return nulls for the right row + joinedRow.withRight(rightNullRow) + } else { + // Find the next row from the right input that satisfied the bound condition + if (!advanceRightUntilBoundConditionSatisfied()) { + joinedRow.withRight(rightNullRow) + } + } + true + } else { + // Left input has been exhausted + false + } + } + + private def advanceRightUntilBoundConditionSatisfied(): Boolean = { + var foundMatch: Boolean = false + while (!foundMatch && rightIdx < smjScanner.getBufferedMatches.length) { + foundMatch = boundCondition(joinedRow.withRight(smjScanner.getBufferedMatches(rightIdx))) + rightIdx += 1 + } + foundMatch + } + + override def advanceNext(): Boolean = { + advanceRightUntilBoundConditionSatisfied() || advanceLeft() + } + + override def getRow: InternalRow = resultProj(joinedRow) +} + +private class RightOuterIterator( + smjScanner: SortMergeJoinScanner, + leftNullRow: InternalRow, + boundCondition: InternalRow => Boolean, + resultProj: InternalRow => InternalRow + ) extends RowIterator { + private[this] val joinedRow: JoinedRow = new JoinedRow() + private[this] var leftIdx: Int = 0 + assert(smjScanner.getBufferedMatches.length == 0) + + private def advanceRight(): Boolean = { + leftIdx = 0 + if (smjScanner.findNextOuterJoinRows()) { + joinedRow.withRight(smjScanner.getStreamedRow) + if (smjScanner.getBufferedMatches.isEmpty) { + // There are no matching left rows, so return nulls for the left row + joinedRow.withLeft(leftNullRow) + } else { + // Find the next row from the left input that 
satisfied the bound condition + if (!advanceLeftUntilBoundConditionSatisfied()) { + joinedRow.withLeft(leftNullRow) + } + } + true + } else { + // Right input has been exhausted + false + } + } + + private def advanceLeftUntilBoundConditionSatisfied(): Boolean = { + var foundMatch: Boolean = false + while (!foundMatch && leftIdx < smjScanner.getBufferedMatches.length) { + foundMatch = boundCondition(joinedRow.withLeft(smjScanner.getBufferedMatches(leftIdx))) + leftIdx += 1 + } + foundMatch + } + + override def advanceNext(): Boolean = { + advanceLeftUntilBoundConditionSatisfied() || advanceRight() + } + + override def getRow: InternalRow = resultProj(joinedRow) +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 5bef1d896603..ae07eaf91c87 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -22,13 +22,14 @@ import org.scalatest.BeforeAndAfterEach import org.apache.spark.sql.TestData._ import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.execution.joins._ -import org.apache.spark.sql.types.BinaryType +import org.apache.spark.sql.test.SQLTestUtils -class JoinSuite extends QueryTest with BeforeAndAfterEach { +class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach { // Ensures tables are loaded. TestData + override def sqlContext: SQLContext = org.apache.spark.sql.test.TestSQLContext lazy val ctx = org.apache.spark.sql.test.TestSQLContext import ctx.implicits._ import ctx.logicalPlanToSparkQuery @@ -37,7 +38,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { val x = testData2.as("x") val y = testData2.as("y") val join = x.join(y, $"x.a" === $"y.a", "inner").queryExecution.optimizedPlan - val planned = ctx.planner.HashJoin(join) + val planned = ctx.planner.EquiJoinSelection(join) assert(planned.size === 1) } @@ -55,6 +56,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { case j: BroadcastNestedLoopJoin => j case j: BroadcastLeftSemiJoinHash => j case j: SortMergeJoin => j + case j: SortMergeOuterJoin => j } assert(operators.size === 1) @@ -66,7 +68,6 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { test("join operator selection") { ctx.cacheManager.clearCache() - val SORTMERGEJOIN_ENABLED: Boolean = ctx.conf.sortMergeJoinEnabled Seq( ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[LeftSemiJoinHash]), ("SELECT * FROM testData LEFT SEMI JOIN testData2", classOf[LeftSemiJoinBNL]), @@ -83,11 +84,11 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { ("SELECT * FROM testData JOIN testData2 ON key = a", classOf[SortMergeJoin]), ("SELECT * FROM testData JOIN testData2 ON key = a and key = 2", classOf[SortMergeJoin]), ("SELECT * FROM testData JOIN testData2 ON key = a where key = 2", classOf[SortMergeJoin]), - ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[ShuffledHashOuterJoin]), + ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[SortMergeOuterJoin]), ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2", - classOf[ShuffledHashOuterJoin]), + classOf[SortMergeOuterJoin]), ("SELECT * FROM testData right join testData2 ON key = a and key = 2", - classOf[ShuffledHashOuterJoin]), + classOf[SortMergeOuterJoin]), ("SELECT * FROM testData full outer join testData2 ON key = a", classOf[ShuffledHashOuterJoin]), ("SELECT * FROM testData left 
JOIN testData2 ON (key * a != key + a)", @@ -97,82 +98,75 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { ("SELECT * FROM testData full JOIN testData2 ON (key * a != key + a)", classOf[BroadcastNestedLoopJoin]) ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } - try { - ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, true) + withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "false") { Seq( - ("SELECT * FROM testData JOIN testData2 ON key = a", classOf[SortMergeJoin]), - ("SELECT * FROM testData JOIN testData2 ON key = a and key = 2", classOf[SortMergeJoin]), - ("SELECT * FROM testData JOIN testData2 ON key = a where key = 2", classOf[SortMergeJoin]) + ("SELECT * FROM testData JOIN testData2 ON key = a", classOf[ShuffledHashJoin]), + ("SELECT * FROM testData JOIN testData2 ON key = a and key = 2", + classOf[ShuffledHashJoin]), + ("SELECT * FROM testData JOIN testData2 ON key = a where key = 2", + classOf[ShuffledHashJoin]), + ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[ShuffledHashOuterJoin]), + ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2", + classOf[ShuffledHashOuterJoin]), + ("SELECT * FROM testData right join testData2 ON key = a and key = 2", + classOf[ShuffledHashOuterJoin]), + ("SELECT * FROM testData full outer join testData2 ON key = a", + classOf[ShuffledHashOuterJoin]) ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } - } finally { - ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, SORTMERGEJOIN_ENABLED) } } test("SortMergeJoin shouldn't work on unsortable columns") { - val SORTMERGEJOIN_ENABLED: Boolean = ctx.conf.sortMergeJoinEnabled - try { - ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, true) + withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "true") { Seq( ("SELECT * FROM arrayData JOIN complexData ON data = a", classOf[ShuffledHashJoin]) ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } - } finally { - ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, SORTMERGEJOIN_ENABLED) } } test("broadcasted hash join operator selection") { ctx.cacheManager.clearCache() ctx.sql("CACHE TABLE testData") - - val SORTMERGEJOIN_ENABLED: Boolean = ctx.conf.sortMergeJoinEnabled - Seq( - ("SELECT * FROM testData join testData2 ON key = a", classOf[BroadcastHashJoin]), - ("SELECT * FROM testData join testData2 ON key = a and key = 2", classOf[BroadcastHashJoin]), - ("SELECT * FROM testData join testData2 ON key = a where key = 2", - classOf[BroadcastHashJoin]) - ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } - try { - ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, true) - Seq( - ("SELECT * FROM testData join testData2 ON key = a", classOf[BroadcastHashJoin]), - ("SELECT * FROM testData join testData2 ON key = a and key = 2", - classOf[BroadcastHashJoin]), - ("SELECT * FROM testData join testData2 ON key = a where key = 2", - classOf[BroadcastHashJoin]) - ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } - } finally { - ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, SORTMERGEJOIN_ENABLED) + for (sortMergeJoinEnabled <- Seq(true, false)) { + withClue(s"sortMergeJoinEnabled=$sortMergeJoinEnabled") { + withSQLConf(SQLConf.SORTMERGE_JOIN.key -> s"$sortMergeJoinEnabled") { + Seq( + ("SELECT * FROM testData join testData2 ON key = a", + classOf[BroadcastHashJoin]), + ("SELECT * FROM testData join testData2 ON key = a and key = 2", + classOf[BroadcastHashJoin]), + ("SELECT * FROM testData join testData2 ON key = a where key = 2", + classOf[BroadcastHashJoin]) + ).foreach { case (query, 
joinClass) => assertJoin(query, joinClass) } + } + } } - ctx.sql("UNCACHE TABLE testData") } test("broadcasted hash outer join operator selection") { ctx.cacheManager.clearCache() ctx.sql("CACHE TABLE testData") - - val SORTMERGEJOIN_ENABLED: Boolean = ctx.conf.sortMergeJoinEnabled - Seq( - ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[ShuffledHashOuterJoin]), - ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2", - classOf[BroadcastHashOuterJoin]), - ("SELECT * FROM testData right join testData2 ON key = a and key = 2", - classOf[BroadcastHashOuterJoin]) - ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } - try { - ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, true) + withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "true") { Seq( - ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[ShuffledHashOuterJoin]), + ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", + classOf[SortMergeOuterJoin]), + ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2", + classOf[BroadcastHashOuterJoin]), + ("SELECT * FROM testData right join testData2 ON key = a and key = 2", + classOf[BroadcastHashOuterJoin]) + ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } + } + withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "false") { + Seq( + ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", + classOf[ShuffledHashOuterJoin]), ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2", classOf[BroadcastHashOuterJoin]), ("SELECT * FROM testData right join testData2 ON key = a and key = 2", classOf[BroadcastHashOuterJoin]) ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } - } finally { - ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, SORTMERGEJOIN_ENABLED) } - ctx.sql("UNCACHE TABLE testData") } @@ -180,7 +174,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { val x = testData2.as("x") val y = testData2.as("y") val join = x.join(y, ($"x.a" === $"y.a") && ($"x.b" === $"y.b")).queryExecution.optimizedPlan - val planned = ctx.planner.HashJoin(join) + val planned = ctx.planner.EquiJoinSelection(join) assert(planned.size === 1) } @@ -457,25 +451,24 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { test("broadcasted left semi join operator selection") { ctx.cacheManager.clearCache() ctx.sql("CACHE TABLE testData") - val tmp = ctx.conf.autoBroadcastJoinThreshold - ctx.sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key}=1000000000") - Seq( - ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", - classOf[BroadcastLeftSemiJoinHash]) - ).foreach { - case (query, joinClass) => assertJoin(query, joinClass) + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1000000000") { + Seq( + ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", + classOf[BroadcastLeftSemiJoinHash]) + ).foreach { + case (query, joinClass) => assertJoin(query, joinClass) + } } - ctx.sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key}=-1") - - Seq( - ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[LeftSemiJoinHash]) - ).foreach { - case (query, joinClass) => assertJoin(query, joinClass) + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { + Seq( + ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[LeftSemiJoinHash]) + ).foreach { + case (query, joinClass) => assertJoin(query, joinClass) + } } - ctx.setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, tmp) ctx.sql("UNCACHE TABLE testData") } @@ -488,6 +481,5 @@ class JoinSuite 
extends QueryTest with BeforeAndAfterEach { Row(2, 2) :: Row(3, 1) :: Row(3, 2) :: Nil) - } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala new file mode 100644 index 000000000000..ddff7cebcc17 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala @@ -0,0 +1,180 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.joins + +import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys +import org.apache.spark.sql.catalyst.plans.Inner +import org.apache.spark.sql.catalyst.plans.logical.Join +import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.types.{IntegerType, StringType, StructType} +import org.apache.spark.sql.{SQLConf, execution, Row, DataFrame} +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.execution._ + +class InnerJoinSuite extends SparkPlanTest with SQLTestUtils { + + private def testInnerJoin( + testName: String, + leftRows: DataFrame, + rightRows: DataFrame, + condition: Expression, + expectedAnswer: Seq[Product]): Unit = { + val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition)) + ExtractEquiJoinKeys.unapply(join).foreach { + case (joinType, leftKeys, rightKeys, boundCondition, leftChild, rightChild) => + + def makeBroadcastHashJoin(left: SparkPlan, right: SparkPlan, side: BuildSide) = { + val broadcastHashJoin = + execution.joins.BroadcastHashJoin(leftKeys, rightKeys, side, left, right) + boundCondition.map(Filter(_, broadcastHashJoin)).getOrElse(broadcastHashJoin) + } + + def makeShuffledHashJoin(left: SparkPlan, right: SparkPlan, side: BuildSide) = { + val shuffledHashJoin = + execution.joins.ShuffledHashJoin(leftKeys, rightKeys, side, left, right) + val filteredJoin = + boundCondition.map(Filter(_, shuffledHashJoin)).getOrElse(shuffledHashJoin) + EnsureRequirements(sqlContext).apply(filteredJoin) + } + + def makeSortMergeJoin(left: SparkPlan, right: SparkPlan) = { + val sortMergeJoin = + execution.joins.SortMergeJoin(leftKeys, rightKeys, left, right) + val filteredJoin = boundCondition.map(Filter(_, sortMergeJoin)).getOrElse(sortMergeJoin) + EnsureRequirements(sqlContext).apply(filteredJoin) + } + + test(s"$testName using BroadcastHashJoin (build=left)") { + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + makeBroadcastHashJoin(left, right, joins.BuildLeft), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } + } + + test(s"$testName using BroadcastHashJoin (build=right)") { + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> 
"1") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + makeBroadcastHashJoin(left, right, joins.BuildRight), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } + } + + test(s"$testName using ShuffledHashJoin (build=left)") { + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + makeShuffledHashJoin(left, right, joins.BuildLeft), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } + } + + test(s"$testName using ShuffledHashJoin (build=right)") { + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + makeShuffledHashJoin(left, right, joins.BuildRight), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } + } + + test(s"$testName using SortMergeJoin") { + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + makeSortMergeJoin(left, right), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } + } + } + } + + { + val upperCaseData = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(Seq( + Row(1, "A"), + Row(2, "B"), + Row(3, "C"), + Row(4, "D"), + Row(5, "E"), + Row(6, "F"), + Row(null, "G") + )), new StructType().add("N", IntegerType).add("L", StringType)) + + val lowerCaseData = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(Seq( + Row(1, "a"), + Row(2, "b"), + Row(3, "c"), + Row(4, "d"), + Row(null, "e") + )), new StructType().add("n", IntegerType).add("l", StringType)) + + testInnerJoin( + "inner join, one match per row", + upperCaseData, + lowerCaseData, + (upperCaseData.col("N") === lowerCaseData.col("n")).expr, + Seq( + (1, "A", 1, "a"), + (2, "B", 2, "b"), + (3, "C", 3, "c"), + (4, "D", 4, "d") + ) + ) + } + + private val testData2 = Seq( + (1, 1), + (1, 2), + (2, 1), + (2, 2), + (3, 1), + (3, 2) + ).toDF("a", "b") + + { + val left = testData2.where("a = 1") + val right = testData2.where("a = 1") + testInnerJoin( + "inner join, multiple matches", + left, + right, + (left.col("a") === right.col("a")).expr, + Seq( + (1, 1, 1, 1), + (1, 1, 1, 2), + (1, 2, 1, 1), + (1, 2, 1, 2) + ) + ) + } + + { + val left = testData2.where("a = 1") + val right = testData2.where("a = 2") + testInnerJoin( + "inner join, no matches", + left, + right, + (left.col("a") === right.col("a")).expr, + Seq.empty + ) + } + +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala index 2c27da596bc4..e16f5e39aa2f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala @@ -1,89 +1,221 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.joins - -import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.expressions.{Expression, LessThan} -import org.apache.spark.sql.catalyst.plans.{FullOuter, LeftOuter, RightOuter} -import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest} - -class OuterJoinSuite extends SparkPlanTest { - - val left = Seq( - (1, 2.0), - (2, 1.0), - (3, 3.0) - ).toDF("a", "b") - - val right = Seq( - (2, 3.0), - (3, 2.0), - (4, 1.0) - ).toDF("c", "d") - - val leftKeys: List[Expression] = 'a :: Nil - val rightKeys: List[Expression] = 'c :: Nil - val condition = Some(LessThan('b, 'd)) - - test("shuffled hash outer join") { - checkAnswer2(left, right, (left: SparkPlan, right: SparkPlan) => - ShuffledHashOuterJoin(leftKeys, rightKeys, LeftOuter, condition, left, right), - Seq( - (1, 2.0, null, null), - (2, 1.0, 2, 3.0), - (3, 3.0, null, null) - ).map(Row.fromTuple)) - - checkAnswer2(left, right, (left: SparkPlan, right: SparkPlan) => - ShuffledHashOuterJoin(leftKeys, rightKeys, RightOuter, condition, left, right), - Seq( - (2, 1.0, 2, 3.0), - (null, null, 3, 2.0), - (null, null, 4, 1.0) - ).map(Row.fromTuple)) - - checkAnswer2(left, right, (left: SparkPlan, right: SparkPlan) => - ShuffledHashOuterJoin(leftKeys, rightKeys, FullOuter, condition, left, right), - Seq( - (1, 2.0, null, null), - (2, 1.0, 2, 3.0), - (3, 3.0, null, null), - (null, null, 3, 2.0), - (null, null, 4, 1.0) - ).map(Row.fromTuple)) - } - - test("broadcast hash outer join") { - checkAnswer2(left, right, (left: SparkPlan, right: SparkPlan) => - BroadcastHashOuterJoin(leftKeys, rightKeys, LeftOuter, condition, left, right), - Seq( - (1, 2.0, null, null), - (2, 1.0, 2, 3.0), - (3, 3.0, null, null) - ).map(Row.fromTuple)) - - checkAnswer2(left, right, (left: SparkPlan, right: SparkPlan) => - BroadcastHashOuterJoin(leftKeys, rightKeys, RightOuter, condition, left, right), - Seq( - (2, 1.0, 2, 3.0), - (null, null, 3, 2.0), - (null, null, 4, 1.0) - ).map(Row.fromTuple)) - } -} +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.joins + +import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys +import org.apache.spark.sql.catalyst.plans.logical.Join +import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.types.{IntegerType, DoubleType, StructType} +import org.apache.spark.sql.{SQLConf, DataFrame, Row} +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.execution.{EnsureRequirements, joins, SparkPlan, SparkPlanTest} + +class OuterJoinSuite extends SparkPlanTest with SQLTestUtils { + + private def testOuterJoin( + testName: String, + leftRows: DataFrame, + rightRows: DataFrame, + joinType: JoinType, + condition: Expression, + expectedAnswer: Seq[Product]): Unit = { + val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition)) + ExtractEquiJoinKeys.unapply(join).foreach { + case (_, leftKeys, rightKeys, boundCondition, leftChild, rightChild) => + test(s"$testName using ShuffledHashOuterJoin") { + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + EnsureRequirements(sqlContext).apply( + ShuffledHashOuterJoin(leftKeys, rightKeys, joinType, boundCondition, left, right)), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } + } + + if (joinType != FullOuter) { + test(s"$testName using BroadcastHashOuterJoin") { + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + BroadcastHashOuterJoin(leftKeys, rightKeys, joinType, boundCondition, left, right), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } + } + + test(s"$testName using SortMergeOuterJoin") { + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + EnsureRequirements(sqlContext).apply( + SortMergeOuterJoin(leftKeys, rightKeys, joinType, boundCondition, left, right)), + expectedAnswer.map(Row.fromTuple), + sortAnswers = false) + } + } + } + } + + test(s"$testName using BroadcastNestedLoopJoin (build=left)") { + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + joins.BroadcastNestedLoopJoin(left, right, joins.BuildLeft, joinType, Some(condition)), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } + } + + test(s"$testName using BroadcastNestedLoopJoin (build=right)") { + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + joins.BroadcastNestedLoopJoin(left, right, joins.BuildRight, joinType, Some(condition)), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } + } + } + + val left = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(Seq( + Row(1, 2.0), + Row(2, 100.0), + Row(2, 1.0), // This row is duplicated to ensure that we will have multiple buffered matches + Row(2, 1.0), + Row(3, 3.0), + Row(5, 1.0), + Row(6, 6.0), + Row(null, null) + )), new StructType().add("a", IntegerType).add("b", DoubleType)) + + val right = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(Seq( + Row(0, 0.0), + Row(2, 3.0), // This row is duplicated to ensure that we will have multiple buffered matches + Row(2, -1.0), + Row(2, -1.0), + Row(2, 3.0), + Row(3, 2.0), + Row(4, 1.0), + Row(5, 3.0), + Row(7, 7.0), + Row(null, null) + )), new StructType().add("c", 
IntegerType).add("d", DoubleType)) + + val condition = { + And( + (left.col("a") === right.col("c")).expr, + LessThan(left.col("b").expr, right.col("d").expr)) + } + + // --- Basic outer joins ------------------------------------------------------------------------ + + testOuterJoin( + "basic left outer join", + left, + right, + LeftOuter, + condition, + Seq( + (null, null, null, null), + (1, 2.0, null, null), + (2, 100.0, null, null), + (2, 1.0, 2, 3.0), + (2, 1.0, 2, 3.0), + (2, 1.0, 2, 3.0), + (2, 1.0, 2, 3.0), + (3, 3.0, null, null), + (5, 1.0, 5, 3.0), + (6, 6.0, null, null) + ) + ) + + testOuterJoin( + "basic right outer join", + left, + right, + RightOuter, + condition, + Seq( + (null, null, null, null), + (null, null, 0, 0.0), + (2, 1.0, 2, 3.0), + (2, 1.0, 2, 3.0), + (null, null, 2, -1.0), + (null, null, 2, -1.0), + (2, 1.0, 2, 3.0), + (2, 1.0, 2, 3.0), + (null, null, 3, 2.0), + (null, null, 4, 1.0), + (5, 1.0, 5, 3.0), + (null, null, 7, 7.0) + ) + ) + + testOuterJoin( + "basic full outer join", + left, + right, + FullOuter, + condition, + Seq( + (1, 2.0, null, null), + (null, null, 2, -1.0), + (null, null, 2, -1.0), + (2, 100.0, null, null), + (2, 1.0, 2, 3.0), + (2, 1.0, 2, 3.0), + (2, 1.0, 2, 3.0), + (2, 1.0, 2, 3.0), + (3, 3.0, null, null), + (5, 1.0, 5, 3.0), + (6, 6.0, null, null), + (null, null, 0, 0.0), + (null, null, 3, 2.0), + (null, null, 4, 1.0), + (null, null, 7, 7.0), + (null, null, null, null), + (null, null, null, null) + ) + ) + + // --- Both inputs empty ------------------------------------------------------------------------ + + testOuterJoin( + "left outer join with both inputs empty", + left.filter("false"), + right.filter("false"), + LeftOuter, + condition, + Seq.empty + ) + + testOuterJoin( + "right outer join with both inputs empty", + left.filter("false"), + right.filter("false"), + RightOuter, + condition, + Seq.empty + ) + + testOuterJoin( + "full outer join with both inputs empty", + left.filter("false"), + right.filter("false"), + FullOuter, + condition, + Seq.empty + ) +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala index 927e85a7db3d..4503ed251fcb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala @@ -17,58 +17,91 @@ package org.apache.spark.sql.execution.joins -import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.expressions.{LessThan, Expression} -import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest} +import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys +import org.apache.spark.sql.catalyst.plans.Inner +import org.apache.spark.sql.catalyst.plans.logical.Join +import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.types.{DoubleType, IntegerType, StructType} +import org.apache.spark.sql.{SQLConf, DataFrame, Row} +import org.apache.spark.sql.catalyst.expressions.{And, LessThan, Expression} +import org.apache.spark.sql.execution.{EnsureRequirements, SparkPlan, SparkPlanTest} +class SemiJoinSuite extends SparkPlanTest with SQLTestUtils { -class SemiJoinSuite extends SparkPlanTest{ - val left = Seq( - (1, 2.0), - (1, 2.0), - (2, 1.0), - (2, 1.0), - (3, 3.0) - ).toDF("a", "b") + private def testLeftSemiJoin( + testName: String, + leftRows: DataFrame, + rightRows: DataFrame, + 
condition: Expression, + expectedAnswer: Seq[Product]): Unit = { + val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition)) + ExtractEquiJoinKeys.unapply(join).foreach { + case (joinType, leftKeys, rightKeys, boundCondition, leftChild, rightChild) => + test(s"$testName using LeftSemiJoinHash") { + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + EnsureRequirements(left.sqlContext).apply( + LeftSemiJoinHash(leftKeys, rightKeys, left, right, boundCondition)), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } + } - val right = Seq( - (2, 3.0), - (2, 3.0), - (3, 2.0), - (4, 1.0) - ).toDF("c", "d") + test(s"$testName using BroadcastLeftSemiJoinHash") { + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + BroadcastLeftSemiJoinHash(leftKeys, rightKeys, left, right, boundCondition), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } + } + } - val leftKeys: List[Expression] = 'a :: Nil - val rightKeys: List[Expression] = 'c :: Nil - val condition = Some(LessThan('b, 'd)) - - test("left semi join hash") { - checkAnswer2(left, right, (left: SparkPlan, right: SparkPlan) => - LeftSemiJoinHash(leftKeys, rightKeys, left, right, condition), - Seq( - (2, 1.0), - (2, 1.0) - ).map(Row.fromTuple)) + test(s"$testName using LeftSemiJoinBNL") { + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + LeftSemiJoinBNL(left, right, Some(condition)), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } + } } - test("left semi join BNL") { - checkAnswer2(left, right, (left: SparkPlan, right: SparkPlan) => - LeftSemiJoinBNL(left, right, condition), - Seq( - (1, 2.0), - (1, 2.0), - (2, 1.0), - (2, 1.0) - ).map(Row.fromTuple)) - } + val left = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(Seq( + Row(1, 2.0), + Row(1, 2.0), + Row(2, 1.0), + Row(2, 1.0), + Row(3, 3.0), + Row(null, null), + Row(null, 5.0), + Row(6, null) + )), new StructType().add("a", IntegerType).add("b", DoubleType)) - test("broadcast left semi join hash") { - checkAnswer2(left, right, (left: SparkPlan, right: SparkPlan) => - BroadcastLeftSemiJoinHash(leftKeys, rightKeys, left, right, condition), - Seq( - (2, 1.0), - (2, 1.0) - ).map(Row.fromTuple)) + val right = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(Seq( + Row(2, 3.0), + Row(2, 3.0), + Row(3, 2.0), + Row(4, 1.0), + Row(null, null), + Row(null, 5.0), + Row(6, null) + )), new StructType().add("c", IntegerType).add("d", DoubleType)) + + val condition = { + And( + (left.col("a") === right.col("c")).expr, + LessThan(left.col("b").expr, right.col("d").expr)) } + + testLeftSemiJoin( + "basic test", + left, + right, + condition, + Seq( + (2, 1.0), + (2, 1.0) + ) + ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index 4c11acdab9ec..106669558977 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.SQLContext import org.apache.spark.util.Utils trait SQLTestUtils { this: SparkFunSuite => - def sqlContext: SQLContext + protected def sqlContext: SQLContext protected def configuration = sqlContext.sparkContext.hadoopConfiguration 
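
The test suite hunks above all follow one refactoring: hand-rolled bookkeeping (save the current value of SQLConf.SORTMERGE_JOIN or AUTO_BROADCASTJOIN_THRESHOLD, set it, and restore it in a finally block or a trailing SET statement) is replaced by the withSQLConf helper the suites now mix in from SQLTestUtils, which scopes a configuration override to a single block. The helper's body is not part of this diff, so the Scala sketch below only illustrates the save/override/run/restore pattern it presumably encapsulates; the object name, the explicit SQLContext parameter, and the getAllConfs-based restore are assumptions made for illustration, not the actual SQLTestUtils code.

    import org.apache.spark.sql.SQLContext

    // Sketch of the pattern behind withSQLConf, not the real helper: apply the given
    // (key, value) overrides, run the block, then put the previous values back even if
    // the block throws, so one test cannot leak configuration into the next.
    object WithSQLConfSketch {
      def withSQLConf(ctx: SQLContext)(pairs: (String, String)*)(body: => Unit): Unit = {
        // Remember what was set before the override.
        val previous: Seq[(String, Option[String])] =
          pairs.map { case (key, _) => key -> ctx.getAllConfs.get(key) }
        pairs.foreach { case (key, value) => ctx.setConf(key, value) }
        try body
        finally previous.foreach {
          case (key, Some(old)) => ctx.setConf(key, old) // restore the saved value
          case (_, None)        => ()  // the real helper also clears keys that had no previous value
        }
      }
    }

With such a helper, a test like the join operator selection cases above can write withSQLConf(ctx)(SQLConf.SORTMERGE_JOIN.key -> "false") { ... } and rely on the finally path to restore the previous setting, which is exactly what the removed try/finally and SET statements were doing by hand.
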
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 567d7fa12ff1..f17177a771c3 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -531,7 +531,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging { HashAggregation, Aggregation, LeftSemiJoin, - HashJoin, + EquiJoinSelection, BasicOperators, CartesianProduct, BroadcastNestedLoopJoin From 0f90d6055e5bea9ceb1d454db84f4aa1d59b284d Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Mon, 10 Aug 2015 23:41:53 -0700 Subject: [PATCH 0157/1824] [SPARK-9640] [STREAMING] [TEST] Do not run Python Kinesis tests when the Kinesis assembly JAR has not been generated Author: Tathagata Das Closes #7961 from tdas/SPARK-9640 and squashes the following commits: 974ce19 [Tathagata Das] Undo changes related to SPARK-9727 004ae26 [Tathagata Das] style fixes 9bbb97d [Tathagata Das] Minor style fies e6a677e [Tathagata Das] Merge remote-tracking branch 'apache-github/master' into SPARK-9640 ca90719 [Tathagata Das] Removed extra line ba9cfc7 [Tathagata Das] Improved kinesis test selection logic 88d59bd [Tathagata Das] updated test modules 871fcc8 [Tathagata Das] Fixed SparkBuild 94be631 [Tathagata Das] Fixed style b858196 [Tathagata Das] Fixed conditions and few other things based on PR comments. e292e64 [Tathagata Das] Added filters for Kinesis python tests --- python/pyspark/streaming/tests.py | 56 ++++++++++++++++++++++++------- 1 file changed, 44 insertions(+), 12 deletions(-) diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 66ae3345f468..f0ed415f9712 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -971,8 +971,10 @@ def test_kinesis_stream_api(self): "awsAccessKey", "awsSecretKey") def test_kinesis_stream(self): - if os.environ.get('ENABLE_KINESIS_TESTS') != '1': - print("Skip test_kinesis_stream") + if not are_kinesis_tests_enabled: + sys.stderr.write( + "Skipped test_kinesis_stream (enable by setting environment variable %s=1" + % kinesis_test_environ_var) return import random @@ -1013,6 +1015,7 @@ def get_output(_, rdd): traceback.print_exc() raise finally: + self.ssc.stop(False) kinesisTestUtils.deleteStream() kinesisTestUtils.deleteDynamoDBTable(kinesisAppName) @@ -1027,7 +1030,7 @@ def search_kafka_assembly_jar(): ("Failed to find Spark Streaming kafka assembly jar in %s. " % kafka_assembly_dir) + "You need to build Spark with " "'build/sbt assembly/assembly streaming-kafka-assembly/assembly' or " - "'build/mvn package' before running this test") + "'build/mvn package' before running this test.") elif len(jars) > 1: raise Exception(("Found multiple Spark Streaming Kafka assembly JARs in %s; please " "remove all but one") % kafka_assembly_dir) @@ -1045,7 +1048,7 @@ def search_flume_assembly_jar(): ("Failed to find Spark Streaming Flume assembly jar in %s. 
" % flume_assembly_dir) + "You need to build Spark with " "'build/sbt assembly/assembly streaming-flume-assembly/assembly' or " - "'build/mvn package' before running this test") + "'build/mvn package' before running this test.") elif len(jars) > 1: raise Exception(("Found multiple Spark Streaming Flume assembly JARs in %s; please " "remove all but one") % flume_assembly_dir) @@ -1095,11 +1098,7 @@ def search_kinesis_asl_assembly_jar(): os.path.join(kinesis_asl_assembly_dir, "target/scala-*/spark-streaming-kinesis-asl-assembly-*.jar")) if not jars: - raise Exception( - ("Failed to find Spark Streaming Kinesis ASL assembly jar in %s. " % - kinesis_asl_assembly_dir) + "You need to build Spark with " - "'build/sbt -Pkinesis-asl assembly/assembly streaming-kinesis-asl-assembly/assembly' " - "or 'build/mvn -Pkinesis-asl package' before running this test") + return None elif len(jars) > 1: raise Exception(("Found multiple Spark Streaming Kinesis ASL assembly JARs in %s; please " "remove all but one") % kinesis_asl_assembly_dir) @@ -1107,6 +1106,10 @@ def search_kinesis_asl_assembly_jar(): return jars[0] +# Must be same as the variable and condition defined in KinesisTestUtils.scala +kinesis_test_environ_var = "ENABLE_KINESIS_TESTS" +are_kinesis_tests_enabled = os.environ.get(kinesis_test_environ_var) == '1' + if __name__ == "__main__": kafka_assembly_jar = search_kafka_assembly_jar() flume_assembly_jar = search_flume_assembly_jar() @@ -1114,8 +1117,37 @@ def search_kinesis_asl_assembly_jar(): mqtt_test_jar = search_mqtt_test_jar() kinesis_asl_assembly_jar = search_kinesis_asl_assembly_jar() - jars = "%s,%s,%s,%s,%s" % (kafka_assembly_jar, flume_assembly_jar, kinesis_asl_assembly_jar, - mqtt_assembly_jar, mqtt_test_jar) + if kinesis_asl_assembly_jar is None: + kinesis_jar_present = False + jars = "%s,%s,%s,%s" % (kafka_assembly_jar, flume_assembly_jar, mqtt_assembly_jar, + mqtt_test_jar) + else: + kinesis_jar_present = True + jars = "%s,%s,%s,%s,%s" % (kafka_assembly_jar, flume_assembly_jar, mqtt_assembly_jar, + mqtt_test_jar, kinesis_asl_assembly_jar) os.environ["PYSPARK_SUBMIT_ARGS"] = "--jars %s pyspark-shell" % jars - unittest.main() + testcases = [BasicOperationTests, WindowFunctionTests, StreamingContextTests, + CheckpointTests, KafkaStreamTests, FlumeStreamTests, FlumePollingStreamTests] + + if kinesis_jar_present is True: + testcases.append(KinesisStreamTests) + elif are_kinesis_tests_enabled is False: + sys.stderr.write("Skipping all Kinesis Python tests as the optional Kinesis project was " + "not compiled with -Pkinesis-asl profile. To run these tests, " + "you need to build Spark with 'build/sbt -Pkinesis-asl assembly/assembly " + "streaming-kinesis-asl-assembly/assembly' or " + "'build/mvn -Pkinesis-asl package' before running this test.") + else: + raise Exception( + ("Failed to find Spark Streaming Kinesis assembly jar in %s. 
" + % kinesis_asl_assembly_dir) + + "You need to build Spark with 'build/sbt -Pkinesis-asl " + "assembly/assembly streaming-kinesis-asl-assembly/assembly'" + "or 'build/mvn -Pkinesis-asl package' before running this test.") + + sys.stderr.write("Running tests: %s \n" % (str(testcases))) + for testcase in testcases: + sys.stderr.write("[Running %s]\n" % (testcase)) + tests = unittest.TestLoader().loadTestsFromTestCase(testcase) + unittest.TextTestRunner(verbosity=2).run(tests) From 55752d88321925da815823f968128832de6fdbbb Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 11 Aug 2015 01:08:30 -0700 Subject: [PATCH 0158/1824] [SPARK-9810] [BUILD] Remove individual commit messages from the squash commit message For more information, please see the JIRA ticket and the associated dev list discussion. https://issues.apache.org/jira/browse/SPARK-9810 http://apache-spark-developers-list.1001551.n3.nabble.com/discuss-Removing-individual-commit-messages-from-the-squash-commit-message-td13295.html Author: Reynold Xin Closes #8091 from rxin/SPARK-9810. --- dev/merge_spark_pr.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/dev/merge_spark_pr.py b/dev/merge_spark_pr.py index ad4b76695c9f..b9bdec3d7086 100755 --- a/dev/merge_spark_pr.py +++ b/dev/merge_spark_pr.py @@ -159,11 +159,7 @@ def merge_pr(pr_num, target_ref, title, body, pr_repo_desc): merge_message_flags += ["-m", message] # The string "Closes #%s" string is required for GitHub to correctly close the PR - merge_message_flags += [ - "-m", - "Closes #%s from %s and squashes the following commits:" % (pr_num, pr_repo_desc)] - for c in commits: - merge_message_flags += ["-m", c] + merge_message_flags += ["-m", "Closes #%s from %s." % (pr_num, pr_repo_desc)] run_cmd(['git', 'commit', '--author="%s"' % primary_author] + merge_message_flags) From 600031ebe27473d8fffe6ea436c2149223b82896 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 11 Aug 2015 02:41:03 -0700 Subject: [PATCH 0159/1824] [SPARK-9727] [STREAMING] [BUILD] Updated streaming kinesis SBT project name to be more consistent Author: Tathagata Das Closes #8092 from tdas/SPARK-9727 and squashes the following commits: b1b01fd [Tathagata Das] Updated streaming kinesis project name --- dev/sparktestsupport/modules.py | 4 ++-- extras/kinesis-asl/pom.xml | 2 +- project/SparkBuild.scala | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index d82c0cca37bc..346452f3174e 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -134,7 +134,7 @@ def contains_file(self, filename): # files in streaming_kinesis_asl are changed, so that if Kinesis experiences an outage, we don't # fail other PRs. 
streaming_kinesis_asl = Module( - name="kinesis-asl", + name="streaming-kinesis-asl", dependencies=[], source_file_regexes=[ "extras/kinesis-asl/", @@ -147,7 +147,7 @@ def contains_file(self, filename): "ENABLE_KINESIS_TESTS": "1" }, sbt_test_goals=[ - "kinesis-asl/test", + "streaming-kinesis-asl/test", ] ) diff --git a/extras/kinesis-asl/pom.xml b/extras/kinesis-asl/pom.xml index c242e7a57b9a..521b53e230c4 100644 --- a/extras/kinesis-asl/pom.xml +++ b/extras/kinesis-asl/pom.xml @@ -31,7 +31,7 @@ Spark Kinesis Integration - kinesis-asl + streaming-kinesis-asl diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 41a85fa9de77..cad7067ade8c 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -42,8 +42,8 @@ object BuildCommons { "streaming-zeromq", "launcher", "unsafe").map(ProjectRef(buildLocation, _)) val optionallyEnabledProjects@Seq(yarn, yarnStable, java8Tests, sparkGangliaLgpl, - sparkKinesisAsl) = Seq("yarn", "yarn-stable", "java8-tests", "ganglia-lgpl", - "kinesis-asl").map(ProjectRef(buildLocation, _)) + streamingKinesisAsl) = Seq("yarn", "yarn-stable", "java8-tests", "ganglia-lgpl", + "streaming-kinesis-asl").map(ProjectRef(buildLocation, _)) val assemblyProjects@Seq(assembly, examples, networkYarn, streamingFlumeAssembly, streamingKafkaAssembly, streamingMqttAssembly, streamingKinesisAslAssembly) = Seq("assembly", "examples", "network-yarn", "streaming-flume-assembly", "streaming-kafka-assembly", "streaming-mqtt-assembly", "streaming-kinesis-asl-assembly") From d378396f86f625f006738d87fe5dbc2ff8fd913d Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 11 Aug 2015 08:41:06 -0700 Subject: [PATCH 0160/1824] [SPARK-9815] Rename PlatformDependent.UNSAFE -> Platform. PlatformDependent.UNSAFE is way too verbose. Author: Reynold Xin Closes #8094 from rxin/SPARK-9815 and squashes the following commits: 229b603 [Reynold Xin] [SPARK-9815] Rename PlatformDependent.UNSAFE -> Platform. 
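
The shape of the change is visible in any of the hunks below: reads, writes, and copies that previously went through PlatformDependent.UNSAFE.getInt/putInt or PlatformDependent.copyMemory now call the equivalent static methods on Platform (Platform.getInt, Platform.putInt, Platform.copyMemory, Platform.BYTE_ARRAY_OFFSET, and so on). The Scala sketch below is illustration only: it writes a length-prefixed record into a page the way UnsafeShuffleExternalSorter.insertRecord does after the rename; the object and method names in the sketch are invented, only the Platform calls themselves come from the patch.

    import org.apache.spark.unsafe.Platform

    // Illustration of the post-rename API: a 4-byte length prefix followed by the record
    // bytes, copied from an on-heap byte array into a page at the given cursor.
    object PlatformRenameSketch {
      def writeLengthPrefixedRecord(record: Array[Byte], pageBase: AnyRef, cursor: Long): Long = {
        Platform.putInt(pageBase, cursor, record.length)          // length prefix
        Platform.copyMemory(record, Platform.BYTE_ARRAY_OFFSET,   // source: byte[] payload
          pageBase, cursor + 4, record.length)                    // destination: page
        cursor + 4 + record.length                                // new cursor position
      }
    }

Before this patch the same two calls were spelled PlatformDependent.UNSAFE.putInt(...) and PlatformDependent.copyMemory(...), which is the verbosity the commit message refers to.
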
--- .../serializer/DummySerializerInstance.java | 6 +- .../unsafe/UnsafeShuffleExternalSorter.java | 22 +-- .../shuffle/unsafe/UnsafeShuffleWriter.java | 4 +- .../spark/unsafe/map/BytesToBytesMap.java | 20 +-- .../unsafe/sort/PrefixComparators.java | 5 +- .../unsafe/sort/UnsafeExternalSorter.java | 22 +-- .../unsafe/sort/UnsafeInMemorySorter.java | 4 +- .../unsafe/sort/UnsafeSorterSpillReader.java | 4 +- .../unsafe/sort/UnsafeSorterSpillWriter.java | 6 +- .../UnsafeShuffleInMemorySorterSuite.java | 20 +-- .../map/AbstractBytesToBytesMapSuite.java | 94 +++++----- .../sort/UnsafeExternalSorterSuite.java | 20 +-- .../sort/UnsafeInMemorySorterSuite.java | 20 +-- .../catalyst/expressions/UnsafeArrayData.java | 51 ++---- .../catalyst/expressions/UnsafeReaders.java | 8 +- .../sql/catalyst/expressions/UnsafeRow.java | 108 +++++------ .../expressions/UnsafeRowWriters.java | 41 ++--- .../catalyst/expressions/UnsafeWriters.java | 43 ++--- .../execution/UnsafeExternalRowSorter.java | 4 +- .../expressions/codegen/CodeGenerator.scala | 4 +- .../codegen/GenerateUnsafeProjection.scala | 32 ++-- .../codegen/GenerateUnsafeRowJoiner.scala | 16 +- .../expressions/stringOperations.scala | 4 +- .../GenerateUnsafeRowJoinerBitsetSuite.scala | 4 +- .../UnsafeFixedWidthAggregationMap.java | 4 +- .../sql/execution/UnsafeKVExternalSorter.java | 4 +- .../sql/execution/UnsafeRowSerializer.scala | 6 +- .../sql/execution/joins/HashedRelation.scala | 13 +- .../org/apache/spark/sql/UnsafeRowSuite.scala | 4 +- .../{PlatformDependent.java => Platform.java} | 170 ++++++++---------- .../spark/unsafe/array/ByteArrayMethods.java | 14 +- .../apache/spark/unsafe/array/LongArray.java | 6 +- .../spark/unsafe/bitset/BitSetMethods.java | 19 +- .../spark/unsafe/hash/Murmur3_x86_32.java | 4 +- .../spark/unsafe/memory/MemoryBlock.java | 4 +- .../unsafe/memory/UnsafeMemoryAllocator.java | 6 +- .../apache/spark/unsafe/types/ByteArray.java | 10 +- .../apache/spark/unsafe/types/UTF8String.java | 30 ++-- .../unsafe/hash/Murmur3_x86_32Suite.java | 14 +- 39 files changed, 371 insertions(+), 499 deletions(-) rename unsafe/src/main/java/org/apache/spark/unsafe/{PlatformDependent.java => Platform.java} (55%) diff --git a/core/src/main/java/org/apache/spark/serializer/DummySerializerInstance.java b/core/src/main/java/org/apache/spark/serializer/DummySerializerInstance.java index 0399abc63c23..0e58bb4f7101 100644 --- a/core/src/main/java/org/apache/spark/serializer/DummySerializerInstance.java +++ b/core/src/main/java/org/apache/spark/serializer/DummySerializerInstance.java @@ -25,7 +25,7 @@ import scala.reflect.ClassTag; import org.apache.spark.annotation.Private; -import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.Platform; /** * Unfortunately, we need a serializer instance in order to construct a DiskBlockObjectWriter. 
@@ -49,7 +49,7 @@ public void flush() { try { s.flush(); } catch (IOException e) { - PlatformDependent.throwException(e); + Platform.throwException(e); } } @@ -64,7 +64,7 @@ public void close() { try { s.close(); } catch (IOException e) { - PlatformDependent.throwException(e); + Platform.throwException(e); } } }; diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java index 925b60a14588..3d1ef0c48adc 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java @@ -37,7 +37,7 @@ import org.apache.spark.storage.BlockManager; import org.apache.spark.storage.DiskBlockObjectWriter; import org.apache.spark.storage.TempShuffleBlockId; -import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.unsafe.memory.TaskMemoryManager; @@ -211,16 +211,12 @@ private void writeSortedFile(boolean isLastFile) throws IOException { final long recordPointer = sortedRecords.packedRecordPointer.getRecordPointer(); final Object recordPage = taskMemoryManager.getPage(recordPointer); final long recordOffsetInPage = taskMemoryManager.getOffsetInPage(recordPointer); - int dataRemaining = PlatformDependent.UNSAFE.getInt(recordPage, recordOffsetInPage); + int dataRemaining = Platform.getInt(recordPage, recordOffsetInPage); long recordReadPosition = recordOffsetInPage + 4; // skip over record length while (dataRemaining > 0) { final int toTransfer = Math.min(DISK_WRITE_BUFFER_SIZE, dataRemaining); - PlatformDependent.copyMemory( - recordPage, - recordReadPosition, - writeBuffer, - PlatformDependent.BYTE_ARRAY_OFFSET, - toTransfer); + Platform.copyMemory( + recordPage, recordReadPosition, writeBuffer, Platform.BYTE_ARRAY_OFFSET, toTransfer); writer.write(writeBuffer, 0, toTransfer); recordReadPosition += toTransfer; dataRemaining -= toTransfer; @@ -447,14 +443,10 @@ public void insertRecord( final long recordAddress = taskMemoryManager.encodePageNumberAndOffset(dataPage, dataPagePosition); - PlatformDependent.UNSAFE.putInt(dataPageBaseObject, dataPagePosition, lengthInBytes); + Platform.putInt(dataPageBaseObject, dataPagePosition, lengthInBytes); dataPagePosition += 4; - PlatformDependent.copyMemory( - recordBaseObject, - recordBaseOffset, - dataPageBaseObject, - dataPagePosition, - lengthInBytes); + Platform.copyMemory( + recordBaseObject, recordBaseOffset, dataPageBaseObject, dataPagePosition, lengthInBytes); assert(inMemSorter != null); inMemSorter.insertRecord(recordAddress, partitionId); } diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java index 02084f9122e0..2389c28b2839 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java @@ -53,7 +53,7 @@ import org.apache.spark.shuffle.ShuffleWriter; import org.apache.spark.storage.BlockManager; import org.apache.spark.storage.TimeTrackingOutputStream; -import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.memory.TaskMemoryManager; @Private @@ -244,7 +244,7 @@ void 
insertRecordIntoSorter(Product2 record) throws IOException { assert (serializedRecordSize > 0); sorter.insertRecord( - serBuffer.getBuf(), PlatformDependent.BYTE_ARRAY_OFFSET, serializedRecordSize, partitionId); + serBuffer.getBuf(), Platform.BYTE_ARRAY_OFFSET, serializedRecordSize, partitionId); } @VisibleForTesting diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java index 7f79cd13aab4..85b46ec8bfae 100644 --- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java +++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java @@ -270,10 +270,10 @@ public boolean hasNext() { @Override public Location next() { - int totalLength = PlatformDependent.UNSAFE.getInt(pageBaseObject, offsetInPage); + int totalLength = Platform.getInt(pageBaseObject, offsetInPage); if (totalLength == END_OF_PAGE_MARKER) { advanceToNextPage(); - totalLength = PlatformDependent.UNSAFE.getInt(pageBaseObject, offsetInPage); + totalLength = Platform.getInt(pageBaseObject, offsetInPage); } loc.with(currentPage, offsetInPage); offsetInPage += 4 + totalLength; @@ -402,9 +402,9 @@ private void updateAddressesAndSizes(long fullKeyAddress) { private void updateAddressesAndSizes(final Object page, final long offsetInPage) { long position = offsetInPage; - final int totalLength = PlatformDependent.UNSAFE.getInt(page, position); + final int totalLength = Platform.getInt(page, position); position += 4; - keyLength = PlatformDependent.UNSAFE.getInt(page, position); + keyLength = Platform.getInt(page, position); position += 4; valueLength = totalLength - keyLength - 4; @@ -572,7 +572,7 @@ public boolean putNewKey( // There wasn't enough space in the current page, so write an end-of-page marker: final Object pageBaseObject = currentDataPage.getBaseObject(); final long lengthOffsetInPage = currentDataPage.getBaseOffset() + pageCursor; - PlatformDependent.UNSAFE.putInt(pageBaseObject, lengthOffsetInPage, END_OF_PAGE_MARKER); + Platform.putInt(pageBaseObject, lengthOffsetInPage, END_OF_PAGE_MARKER); } final long memoryGranted = shuffleMemoryManager.tryToAcquire(pageSizeBytes); if (memoryGranted != pageSizeBytes) { @@ -608,21 +608,21 @@ public boolean putNewKey( final long valueDataOffsetInPage = insertCursor; insertCursor += valueLengthBytes; // word used to store the value size - PlatformDependent.UNSAFE.putInt(dataPageBaseObject, recordOffset, + Platform.putInt(dataPageBaseObject, recordOffset, keyLengthBytes + valueLengthBytes + 4); - PlatformDependent.UNSAFE.putInt(dataPageBaseObject, keyLengthOffset, keyLengthBytes); + Platform.putInt(dataPageBaseObject, keyLengthOffset, keyLengthBytes); // Copy the key - PlatformDependent.copyMemory( + Platform.copyMemory( keyBaseObject, keyBaseOffset, dataPageBaseObject, keyDataOffsetInPage, keyLengthBytes); // Copy the value - PlatformDependent.copyMemory(valueBaseObject, valueBaseOffset, dataPageBaseObject, + Platform.copyMemory(valueBaseObject, valueBaseOffset, dataPageBaseObject, valueDataOffsetInPage, valueLengthBytes); // --- Update bookeeping data structures ----------------------------------------------------- if (useOverflowPage) { // Store the end-of-page marker at the end of the data page - PlatformDependent.UNSAFE.putInt(dataPageBaseObject, insertCursor, END_OF_PAGE_MARKER); + Platform.putInt(dataPageBaseObject, insertCursor, END_OF_PAGE_MARKER); } else { pageCursor += requiredSize; } diff --git 
a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java index 5e002ae1b756..71b76d5ddfaa 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java @@ -20,10 +20,9 @@ import com.google.common.primitives.UnsignedLongs; import org.apache.spark.annotation.Private; -import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.types.UTF8String; import org.apache.spark.util.Utils; -import static org.apache.spark.unsafe.PlatformDependent.BYTE_ARRAY_OFFSET; @Private public class PrefixComparators { @@ -73,7 +72,7 @@ public static long computePrefix(byte[] bytes) { final int minLen = Math.min(bytes.length, 8); long p = 0; for (int i = 0; i < minLen; ++i) { - p |= (128L + PlatformDependent.UNSAFE.getByte(bytes, BYTE_ARRAY_OFFSET + i)) + p |= (128L + Platform.getByte(bytes, Platform.BYTE_ARRAY_OFFSET + i)) << (56 - 8 * i); } return p; diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java index 5ebbf9b068fd..9601aafe5546 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -35,7 +35,7 @@ import org.apache.spark.shuffle.ShuffleMemoryManager; import org.apache.spark.storage.BlockManager; import org.apache.spark.unsafe.array.ByteArrayMethods; -import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.unsafe.memory.TaskMemoryManager; import org.apache.spark.util.Utils; @@ -427,14 +427,10 @@ public void insertRecord( final long recordAddress = taskMemoryManager.encodePageNumberAndOffset(dataPage, dataPagePosition); - PlatformDependent.UNSAFE.putInt(dataPageBaseObject, dataPagePosition, lengthInBytes); + Platform.putInt(dataPageBaseObject, dataPagePosition, lengthInBytes); dataPagePosition += 4; - PlatformDependent.copyMemory( - recordBaseObject, - recordBaseOffset, - dataPageBaseObject, - dataPagePosition, - lengthInBytes); + Platform.copyMemory( + recordBaseObject, recordBaseOffset, dataPageBaseObject, dataPagePosition, lengthInBytes); assert(inMemSorter != null); inMemSorter.insertRecord(recordAddress, prefix); } @@ -493,18 +489,16 @@ public void insertKVRecord( final long recordAddress = taskMemoryManager.encodePageNumberAndOffset(dataPage, dataPagePosition); - PlatformDependent.UNSAFE.putInt(dataPageBaseObject, dataPagePosition, keyLen + valueLen + 4); + Platform.putInt(dataPageBaseObject, dataPagePosition, keyLen + valueLen + 4); dataPagePosition += 4; - PlatformDependent.UNSAFE.putInt(dataPageBaseObject, dataPagePosition, keyLen); + Platform.putInt(dataPageBaseObject, dataPagePosition, keyLen); dataPagePosition += 4; - PlatformDependent.copyMemory( - keyBaseObj, keyOffset, dataPageBaseObject, dataPagePosition, keyLen); + Platform.copyMemory(keyBaseObj, keyOffset, dataPageBaseObject, dataPagePosition, keyLen); dataPagePosition += keyLen; - PlatformDependent.copyMemory( - valueBaseObj, valueOffset, dataPageBaseObject, dataPagePosition, valueLen); + Platform.copyMemory(valueBaseObj, valueOffset, 
dataPageBaseObject, dataPagePosition, valueLen); assert(inMemSorter != null); inMemSorter.insertRecord(recordAddress, prefix); diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java index 1e4b8a116e11..f7787e1019c2 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java @@ -19,7 +19,7 @@ import java.util.Comparator; -import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.Platform; import org.apache.spark.util.collection.Sorter; import org.apache.spark.unsafe.memory.TaskMemoryManager; @@ -164,7 +164,7 @@ public void loadNext() { final long recordPointer = sortBuffer[position]; baseObject = memoryManager.getPage(recordPointer); baseOffset = memoryManager.getOffsetInPage(recordPointer) + 4; // Skip over record length - recordLength = PlatformDependent.UNSAFE.getInt(baseObject, baseOffset - 4); + recordLength = Platform.getInt(baseObject, baseOffset - 4); keyPrefix = sortBuffer[position + 1]; position += 2; } diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java index ca1ccedc93c8..4989b05d63e2 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java @@ -23,7 +23,7 @@ import org.apache.spark.storage.BlockId; import org.apache.spark.storage.BlockManager; -import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.Platform; /** * Reads spill files written by {@link UnsafeSorterSpillWriter} (see that class for a description @@ -42,7 +42,7 @@ final class UnsafeSorterSpillReader extends UnsafeSorterIterator { private byte[] arr = new byte[1024 * 1024]; private Object baseObject = arr; - private final long baseOffset = PlatformDependent.BYTE_ARRAY_OFFSET; + private final long baseOffset = Platform.BYTE_ARRAY_OFFSET; public UnsafeSorterSpillReader( BlockManager blockManager, diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java index 44cf6c756d7c..e59a84ff8d11 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java @@ -28,7 +28,7 @@ import org.apache.spark.storage.BlockManager; import org.apache.spark.storage.DiskBlockObjectWriter; import org.apache.spark.storage.TempLocalBlockId; -import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.Platform; /** * Spills a list of sorted records to disk. 
Spill files have the following format: @@ -117,11 +117,11 @@ public void write( long recordReadPosition = baseOffset; while (dataRemaining > 0) { final int toTransfer = Math.min(freeSpaceInWriteBuffer, dataRemaining); - PlatformDependent.copyMemory( + Platform.copyMemory( baseObject, recordReadPosition, writeBuffer, - PlatformDependent.BYTE_ARRAY_OFFSET + (DISK_WRITE_BUFFER_SIZE - freeSpaceInWriteBuffer), + Platform.BYTE_ARRAY_OFFSET + (DISK_WRITE_BUFFER_SIZE - freeSpaceInWriteBuffer), toTransfer); writer.write(writeBuffer, 0, (DISK_WRITE_BUFFER_SIZE - freeSpaceInWriteBuffer) + toTransfer); recordReadPosition += toTransfer; diff --git a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorterSuite.java index 8fa72597db24..40fefe2c9d14 100644 --- a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorterSuite.java @@ -24,7 +24,7 @@ import org.junit.Test; import org.apache.spark.HashPartitioner; -import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.memory.ExecutorMemoryManager; import org.apache.spark.unsafe.memory.MemoryAllocator; import org.apache.spark.unsafe.memory.MemoryBlock; @@ -34,11 +34,7 @@ public class UnsafeShuffleInMemorySorterSuite { private static String getStringFromDataPage(Object baseObject, long baseOffset, int strLength) { final byte[] strBytes = new byte[strLength]; - PlatformDependent.copyMemory( - baseObject, - baseOffset, - strBytes, - PlatformDependent.BYTE_ARRAY_OFFSET, strLength); + Platform.copyMemory(baseObject, baseOffset, strBytes, Platform.BYTE_ARRAY_OFFSET, strLength); return new String(strBytes); } @@ -74,14 +70,10 @@ public void testBasicSorting() throws Exception { for (String str : dataToSort) { final long recordAddress = memoryManager.encodePageNumberAndOffset(dataPage, position); final byte[] strBytes = str.getBytes("utf-8"); - PlatformDependent.UNSAFE.putInt(baseObject, position, strBytes.length); + Platform.putInt(baseObject, position, strBytes.length); position += 4; - PlatformDependent.copyMemory( - strBytes, - PlatformDependent.BYTE_ARRAY_OFFSET, - baseObject, - position, - strBytes.length); + Platform.copyMemory( + strBytes, Platform.BYTE_ARRAY_OFFSET, baseObject, position, strBytes.length); position += strBytes.length; sorter.insertRecord(recordAddress, hashPartitioner.getPartition(str)); } @@ -98,7 +90,7 @@ public void testBasicSorting() throws Exception { Assert.assertTrue("Partition id " + partitionId + " should be >= prev id " + prevPartitionId, partitionId >= prevPartitionId); final long recordAddress = iter.packedRecordPointer.getRecordPointer(); - final int recordLength = PlatformDependent.UNSAFE.getInt( + final int recordLength = Platform.getInt( memoryManager.getPage(recordAddress), memoryManager.getOffsetInPage(recordAddress)); final String str = getStringFromDataPage( memoryManager.getPage(recordAddress), diff --git a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java index e56a3f0b6d12..1a79c20c3524 100644 --- a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java +++ b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java @@ -32,9 +32,7 @@ import 
org.apache.spark.shuffle.ShuffleMemoryManager; import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.memory.*; -import org.apache.spark.unsafe.PlatformDependent; -import static org.apache.spark.unsafe.PlatformDependent.BYTE_ARRAY_OFFSET; -import static org.apache.spark.unsafe.PlatformDependent.LONG_ARRAY_OFFSET; +import org.apache.spark.unsafe.Platform; public abstract class AbstractBytesToBytesMapSuite { @@ -80,13 +78,8 @@ public void tearDown() { private static byte[] getByteArray(MemoryLocation loc, int size) { final byte[] arr = new byte[size]; - PlatformDependent.copyMemory( - loc.getBaseObject(), - loc.getBaseOffset(), - arr, - BYTE_ARRAY_OFFSET, - size - ); + Platform.copyMemory( + loc.getBaseObject(), loc.getBaseOffset(), arr, Platform.BYTE_ARRAY_OFFSET, size); return arr; } @@ -108,7 +101,7 @@ private static boolean arrayEquals( long actualLengthBytes) { return (actualLengthBytes == expected.length) && ByteArrayMethods.arrayEquals( expected, - BYTE_ARRAY_OFFSET, + Platform.BYTE_ARRAY_OFFSET, actualAddr.getBaseObject(), actualAddr.getBaseOffset(), expected.length @@ -124,7 +117,7 @@ public void emptyMap() { final int keyLengthInWords = 10; final int keyLengthInBytes = keyLengthInWords * 8; final byte[] key = getRandomByteArray(keyLengthInWords); - Assert.assertFalse(map.lookup(key, BYTE_ARRAY_OFFSET, keyLengthInBytes).isDefined()); + Assert.assertFalse(map.lookup(key, Platform.BYTE_ARRAY_OFFSET, keyLengthInBytes).isDefined()); Assert.assertFalse(map.iterator().hasNext()); } finally { map.free(); @@ -141,14 +134,14 @@ public void setAndRetrieveAKey() { final byte[] valueData = getRandomByteArray(recordLengthWords); try { final BytesToBytesMap.Location loc = - map.lookup(keyData, BYTE_ARRAY_OFFSET, recordLengthBytes); + map.lookup(keyData, Platform.BYTE_ARRAY_OFFSET, recordLengthBytes); Assert.assertFalse(loc.isDefined()); Assert.assertTrue(loc.putNewKey( keyData, - BYTE_ARRAY_OFFSET, + Platform.BYTE_ARRAY_OFFSET, recordLengthBytes, valueData, - BYTE_ARRAY_OFFSET, + Platform.BYTE_ARRAY_OFFSET, recordLengthBytes )); // After storing the key and value, the other location methods should return results that @@ -159,7 +152,8 @@ public void setAndRetrieveAKey() { Assert.assertArrayEquals(valueData, getByteArray(loc.getValueAddress(), recordLengthBytes)); // After calling lookup() the location should still point to the correct data. 
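The lookup and putNewKey calls this suite drives store each entry in a data page as a 4-byte total length, a 4-byte key length, the key bytes, then the value bytes, with the value length recovered as totalLength - keyLength - 4 (see the map code above). A rough standalone sketch of that layout, using java.nio.ByteBuffer in place of Platform and raw memory pages:

import java.nio.ByteBuffer;

public class MapRecordLayoutSketch {

  // Append one entry to a "page": [total length][key length][key bytes][value bytes].
  static int writeRecord(ByteBuffer page, byte[] key, byte[] value) {
    int recordOffset = page.position();
    page.putInt(key.length + value.length + 4); // total length word (covers the key-length word)
    page.putInt(key.length);                    // key length word
    page.put(key);
    page.put(value);
    return recordOffset;
  }

  // Recover the value bytes from a record written by writeRecord.
  static byte[] readValue(ByteBuffer page, int recordOffset) {
    int totalLength = page.getInt(recordOffset);
    int keyLength = page.getInt(recordOffset + 4);
    byte[] value = new byte[totalLength - keyLength - 4];
    for (int i = 0; i < value.length; i++) {
      value[i] = page.get(recordOffset + 8 + keyLength + i);
    }
    return value;
  }

  public static void main(String[] args) {
    ByteBuffer page = ByteBuffer.allocate(64);
    int offset = writeRecord(page, new byte[] {1, 2, 3}, new byte[] {9, 9});
    System.out.println(readValue(page, offset).length); // 2
  }
}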
- Assert.assertTrue(map.lookup(keyData, BYTE_ARRAY_OFFSET, recordLengthBytes).isDefined()); + Assert.assertTrue( + map.lookup(keyData, Platform.BYTE_ARRAY_OFFSET, recordLengthBytes).isDefined()); Assert.assertEquals(recordLengthBytes, loc.getKeyLength()); Assert.assertEquals(recordLengthBytes, loc.getValueLength()); Assert.assertArrayEquals(keyData, getByteArray(loc.getKeyAddress(), recordLengthBytes)); @@ -168,10 +162,10 @@ public void setAndRetrieveAKey() { try { Assert.assertTrue(loc.putNewKey( keyData, - BYTE_ARRAY_OFFSET, + Platform.BYTE_ARRAY_OFFSET, recordLengthBytes, valueData, - BYTE_ARRAY_OFFSET, + Platform.BYTE_ARRAY_OFFSET, recordLengthBytes )); Assert.fail("Should not be able to set a new value for a key"); @@ -191,25 +185,25 @@ private void iteratorTestBase(boolean destructive) throws Exception { for (long i = 0; i < size; i++) { final long[] value = new long[] { i }; final BytesToBytesMap.Location loc = - map.lookup(value, PlatformDependent.LONG_ARRAY_OFFSET, 8); + map.lookup(value, Platform.LONG_ARRAY_OFFSET, 8); Assert.assertFalse(loc.isDefined()); // Ensure that we store some zero-length keys if (i % 5 == 0) { Assert.assertTrue(loc.putNewKey( null, - PlatformDependent.LONG_ARRAY_OFFSET, + Platform.LONG_ARRAY_OFFSET, 0, value, - PlatformDependent.LONG_ARRAY_OFFSET, + Platform.LONG_ARRAY_OFFSET, 8 )); } else { Assert.assertTrue(loc.putNewKey( value, - PlatformDependent.LONG_ARRAY_OFFSET, + Platform.LONG_ARRAY_OFFSET, 8, value, - PlatformDependent.LONG_ARRAY_OFFSET, + Platform.LONG_ARRAY_OFFSET, 8 )); } @@ -228,14 +222,13 @@ private void iteratorTestBase(boolean destructive) throws Exception { Assert.assertTrue(loc.isDefined()); final MemoryLocation keyAddress = loc.getKeyAddress(); final MemoryLocation valueAddress = loc.getValueAddress(); - final long value = PlatformDependent.UNSAFE.getLong( + final long value = Platform.getLong( valueAddress.getBaseObject(), valueAddress.getBaseOffset()); final long keyLength = loc.getKeyLength(); if (keyLength == 0) { Assert.assertTrue("value " + value + " was not divisible by 5", value % 5 == 0); } else { - final long key = PlatformDependent.UNSAFE.getLong( - keyAddress.getBaseObject(), keyAddress.getBaseOffset()); + final long key = Platform.getLong(keyAddress.getBaseObject(), keyAddress.getBaseOffset()); Assert.assertEquals(value, key); } valuesSeen.set((int) value); @@ -284,16 +277,16 @@ public void iteratingOverDataPagesWithWastedSpace() throws Exception { final long[] value = new long[] { i, i, i, i, i }; // 5 * 8 = 40 bytes final BytesToBytesMap.Location loc = map.lookup( key, - LONG_ARRAY_OFFSET, + Platform.LONG_ARRAY_OFFSET, KEY_LENGTH ); Assert.assertFalse(loc.isDefined()); Assert.assertTrue(loc.putNewKey( key, - LONG_ARRAY_OFFSET, + Platform.LONG_ARRAY_OFFSET, KEY_LENGTH, value, - LONG_ARRAY_OFFSET, + Platform.LONG_ARRAY_OFFSET, VALUE_LENGTH )); } @@ -308,18 +301,18 @@ public void iteratingOverDataPagesWithWastedSpace() throws Exception { Assert.assertTrue(loc.isDefined()); Assert.assertEquals(KEY_LENGTH, loc.getKeyLength()); Assert.assertEquals(VALUE_LENGTH, loc.getValueLength()); - PlatformDependent.copyMemory( + Platform.copyMemory( loc.getKeyAddress().getBaseObject(), loc.getKeyAddress().getBaseOffset(), key, - LONG_ARRAY_OFFSET, + Platform.LONG_ARRAY_OFFSET, KEY_LENGTH ); - PlatformDependent.copyMemory( + Platform.copyMemory( loc.getValueAddress().getBaseObject(), loc.getValueAddress().getBaseOffset(), value, - LONG_ARRAY_OFFSET, + Platform.LONG_ARRAY_OFFSET, VALUE_LENGTH ); for (long j : key) { @@ -354,16 +347,16 @@ 
public void randomizedStressTest() { expected.put(ByteBuffer.wrap(key), value); final BytesToBytesMap.Location loc = map.lookup( key, - BYTE_ARRAY_OFFSET, + Platform.BYTE_ARRAY_OFFSET, key.length ); Assert.assertFalse(loc.isDefined()); Assert.assertTrue(loc.putNewKey( key, - BYTE_ARRAY_OFFSET, + Platform.BYTE_ARRAY_OFFSET, key.length, value, - BYTE_ARRAY_OFFSET, + Platform.BYTE_ARRAY_OFFSET, value.length )); // After calling putNewKey, the following should be true, even before calling @@ -379,7 +372,8 @@ public void randomizedStressTest() { for (Map.Entry entry : expected.entrySet()) { final byte[] key = entry.getKey().array(); final byte[] value = entry.getValue(); - final BytesToBytesMap.Location loc = map.lookup(key, BYTE_ARRAY_OFFSET, key.length); + final BytesToBytesMap.Location loc = + map.lookup(key, Platform.BYTE_ARRAY_OFFSET, key.length); Assert.assertTrue(loc.isDefined()); Assert.assertTrue(arrayEquals(key, loc.getKeyAddress(), loc.getKeyLength())); Assert.assertTrue(arrayEquals(value, loc.getValueAddress(), loc.getValueLength())); @@ -405,16 +399,16 @@ public void randomizedTestWithRecordsLargerThanPageSize() { expected.put(ByteBuffer.wrap(key), value); final BytesToBytesMap.Location loc = map.lookup( key, - BYTE_ARRAY_OFFSET, + Platform.BYTE_ARRAY_OFFSET, key.length ); Assert.assertFalse(loc.isDefined()); Assert.assertTrue(loc.putNewKey( key, - BYTE_ARRAY_OFFSET, + Platform.BYTE_ARRAY_OFFSET, key.length, value, - BYTE_ARRAY_OFFSET, + Platform.BYTE_ARRAY_OFFSET, value.length )); // After calling putNewKey, the following should be true, even before calling @@ -429,7 +423,8 @@ public void randomizedTestWithRecordsLargerThanPageSize() { for (Map.Entry entry : expected.entrySet()) { final byte[] key = entry.getKey().array(); final byte[] value = entry.getValue(); - final BytesToBytesMap.Location loc = map.lookup(key, BYTE_ARRAY_OFFSET, key.length); + final BytesToBytesMap.Location loc = + map.lookup(key, Platform.BYTE_ARRAY_OFFSET, key.length); Assert.assertTrue(loc.isDefined()); Assert.assertTrue(arrayEquals(key, loc.getKeyAddress(), loc.getKeyLength())); Assert.assertTrue(arrayEquals(value, loc.getValueAddress(), loc.getValueLength())); @@ -447,12 +442,10 @@ public void failureToAllocateFirstPage() { try { final long[] emptyArray = new long[0]; final BytesToBytesMap.Location loc = - map.lookup(emptyArray, PlatformDependent.LONG_ARRAY_OFFSET, 0); + map.lookup(emptyArray, Platform.LONG_ARRAY_OFFSET, 0); Assert.assertFalse(loc.isDefined()); Assert.assertFalse(loc.putNewKey( - emptyArray, LONG_ARRAY_OFFSET, 0, - emptyArray, LONG_ARRAY_OFFSET, 0 - )); + emptyArray, Platform.LONG_ARRAY_OFFSET, 0, emptyArray, Platform.LONG_ARRAY_OFFSET, 0)); } finally { map.free(); } @@ -468,8 +461,9 @@ public void failureToGrow() { int i; for (i = 0; i < 1024; i++) { final long[] arr = new long[]{i}; - final BytesToBytesMap.Location loc = map.lookup(arr, PlatformDependent.LONG_ARRAY_OFFSET, 8); - success = loc.putNewKey(arr, LONG_ARRAY_OFFSET, 8, arr, LONG_ARRAY_OFFSET, 8); + final BytesToBytesMap.Location loc = map.lookup(arr, Platform.LONG_ARRAY_OFFSET, 8); + success = + loc.putNewKey(arr, Platform.LONG_ARRAY_OFFSET, 8, arr, Platform.LONG_ARRAY_OFFSET, 8); if (!success) { break; } @@ -541,12 +535,12 @@ public void testPeakMemoryUsed() { try { for (long i = 0; i < numRecordsPerPage * 10; i++) { final long[] value = new long[]{i}; - map.lookup(value, PlatformDependent.LONG_ARRAY_OFFSET, 8).putNewKey( + map.lookup(value, Platform.LONG_ARRAY_OFFSET, 8).putNewKey( value, - 
PlatformDependent.LONG_ARRAY_OFFSET, + Platform.LONG_ARRAY_OFFSET, 8, value, - PlatformDependent.LONG_ARRAY_OFFSET, + Platform.LONG_ARRAY_OFFSET, 8); newPeakMemory = map.getPeakMemoryUsedBytes(); if (i % numRecordsPerPage == 0) { diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java index 83049b8a21fc..445a37b83e98 100644 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java @@ -49,7 +49,7 @@ import org.apache.spark.serializer.SerializerInstance; import org.apache.spark.shuffle.ShuffleMemoryManager; import org.apache.spark.storage.*; -import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.memory.ExecutorMemoryManager; import org.apache.spark.unsafe.memory.MemoryAllocator; import org.apache.spark.unsafe.memory.TaskMemoryManager; @@ -166,14 +166,14 @@ private void assertSpillFilesWereCleanedUp() { private static void insertNumber(UnsafeExternalSorter sorter, int value) throws Exception { final int[] arr = new int[]{ value }; - sorter.insertRecord(arr, PlatformDependent.INT_ARRAY_OFFSET, 4, value); + sorter.insertRecord(arr, Platform.INT_ARRAY_OFFSET, 4, value); } private static void insertRecord( UnsafeExternalSorter sorter, int[] record, long prefix) throws IOException { - sorter.insertRecord(record, PlatformDependent.INT_ARRAY_OFFSET, record.length * 4, prefix); + sorter.insertRecord(record, Platform.INT_ARRAY_OFFSET, record.length * 4, prefix); } private UnsafeExternalSorter newSorter() throws IOException { @@ -205,7 +205,7 @@ public void testSortingOnlyByPrefix() throws Exception { iter.loadNext(); assertEquals(i, iter.getKeyPrefix()); assertEquals(4, iter.getRecordLength()); - assertEquals(i, PlatformDependent.UNSAFE.getInt(iter.getBaseObject(), iter.getBaseOffset())); + assertEquals(i, Platform.getInt(iter.getBaseObject(), iter.getBaseOffset())); } sorter.cleanupResources(); @@ -253,7 +253,7 @@ public void spillingOccursInResponseToMemoryPressure() throws Exception { iter.loadNext(); assertEquals(i, iter.getKeyPrefix()); assertEquals(4, iter.getRecordLength()); - assertEquals(i, PlatformDependent.UNSAFE.getInt(iter.getBaseObject(), iter.getBaseOffset())); + assertEquals(i, Platform.getInt(iter.getBaseObject(), iter.getBaseOffset())); i++; } sorter.cleanupResources(); @@ -265,7 +265,7 @@ public void testFillingPage() throws Exception { final UnsafeExternalSorter sorter = newSorter(); byte[] record = new byte[16]; while (sorter.getNumberOfAllocatedPages() < 2) { - sorter.insertRecord(record, PlatformDependent.BYTE_ARRAY_OFFSET, record.length, 0); + sorter.insertRecord(record, Platform.BYTE_ARRAY_OFFSET, record.length, 0); } sorter.cleanupResources(); assertSpillFilesWereCleanedUp(); @@ -292,25 +292,25 @@ public void sortingRecordsThatExceedPageSize() throws Exception { iter.loadNext(); assertEquals(123, iter.getKeyPrefix()); assertEquals(smallRecord.length * 4, iter.getRecordLength()); - assertEquals(123, PlatformDependent.UNSAFE.getInt(iter.getBaseObject(), iter.getBaseOffset())); + assertEquals(123, Platform.getInt(iter.getBaseObject(), iter.getBaseOffset())); // Small record assertTrue(iter.hasNext()); iter.loadNext(); assertEquals(123, iter.getKeyPrefix()); assertEquals(smallRecord.length * 4, iter.getRecordLength()); - 
assertEquals(123, PlatformDependent.UNSAFE.getInt(iter.getBaseObject(), iter.getBaseOffset())); + assertEquals(123, Platform.getInt(iter.getBaseObject(), iter.getBaseOffset())); // Large record assertTrue(iter.hasNext()); iter.loadNext(); assertEquals(456, iter.getKeyPrefix()); assertEquals(largeRecord.length * 4, iter.getRecordLength()); - assertEquals(456, PlatformDependent.UNSAFE.getInt(iter.getBaseObject(), iter.getBaseOffset())); + assertEquals(456, Platform.getInt(iter.getBaseObject(), iter.getBaseOffset())); // Large record assertTrue(iter.hasNext()); iter.loadNext(); assertEquals(456, iter.getKeyPrefix()); assertEquals(largeRecord.length * 4, iter.getRecordLength()); - assertEquals(456, PlatformDependent.UNSAFE.getInt(iter.getBaseObject(), iter.getBaseOffset())); + assertEquals(456, Platform.getInt(iter.getBaseObject(), iter.getBaseOffset())); assertFalse(iter.hasNext()); sorter.cleanupResources(); diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java index 909500930539..778e813df6b5 100644 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java @@ -26,7 +26,7 @@ import static org.mockito.Mockito.mock; import org.apache.spark.HashPartitioner; -import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.memory.ExecutorMemoryManager; import org.apache.spark.unsafe.memory.MemoryAllocator; import org.apache.spark.unsafe.memory.MemoryBlock; @@ -36,11 +36,7 @@ public class UnsafeInMemorySorterSuite { private static String getStringFromDataPage(Object baseObject, long baseOffset, int length) { final byte[] strBytes = new byte[length]; - PlatformDependent.copyMemory( - baseObject, - baseOffset, - strBytes, - PlatformDependent.BYTE_ARRAY_OFFSET, length); + Platform.copyMemory(baseObject, baseOffset, strBytes, Platform.BYTE_ARRAY_OFFSET, length); return new String(strBytes); } @@ -76,14 +72,10 @@ public void testSortingOnlyByIntegerPrefix() throws Exception { long position = dataPage.getBaseOffset(); for (String str : dataToSort) { final byte[] strBytes = str.getBytes("utf-8"); - PlatformDependent.UNSAFE.putInt(baseObject, position, strBytes.length); + Platform.putInt(baseObject, position, strBytes.length); position += 4; - PlatformDependent.copyMemory( - strBytes, - PlatformDependent.BYTE_ARRAY_OFFSET, - baseObject, - position, - strBytes.length); + Platform.copyMemory( + strBytes, Platform.BYTE_ARRAY_OFFSET, baseObject, position, strBytes.length); position += strBytes.length; } // Since the key fits within the 8-byte prefix, we don't need to do any record comparison, so @@ -113,7 +105,7 @@ public int compare(long prefix1, long prefix2) { position = dataPage.getBaseOffset(); for (int i = 0; i < dataToSort.length; i++) { // position now points to the start of a record (which holds its length). 
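The sort buffer that UnsafeInMemorySorter.loadNext() walks above is a long[] of (record pointer, key prefix) pairs: the pointer is read from one slot, the prefix from the next, and the position advances by two. A small illustrative sketch of that pairing, with made-up pointer values and plain library sorting standing in for the specialized prefix sorter:

import java.util.Arrays;

public class PointerPrefixBufferSketch {

  public static void main(String[] args) {
    // Pairs of (record pointer, key prefix), two long slots per record; the pointer values
    // here are invented for illustration, not real page addresses.
    long[][] entries = {
        {1001L, 30L},
        {1002L, 10L},
        {1003L, 20L},
    };

    // Ordering by prefix alone is enough when the key fits within the 8-byte prefix, which is
    // the case the test above exercises; otherwise a record comparator breaks ties.
    Arrays.sort(entries, (a, b) -> Long.compare(a[1], b[1]));

    // Iterate the way loadNext() does: pointer first, then prefix, one pair per record.
    for (long[] entry : entries) {
      System.out.println("pointer=" + entry[0] + " prefix=" + entry[1]);
    }
  }
}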
- final int recordLength = PlatformDependent.UNSAFE.getInt(baseObject, position); + final int recordLength = Platform.getInt(baseObject, position); final long address = memoryManager.encodePageNumberAndOffset(dataPage, position); final String str = getStringFromDataPage(baseObject, position + 4, recordLength); final int partitionId = hashPartitioner.getPartition(str); diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java index 0374846d7167..501dff090313 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java @@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.types.*; -import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.hash.Murmur3_x86_32; import org.apache.spark.unsafe.types.CalendarInterval; @@ -59,7 +59,7 @@ public class UnsafeArrayData extends ArrayData { private int sizeInBytes; private int getElementOffset(int ordinal) { - return PlatformDependent.UNSAFE.getInt(baseObject, baseOffset + ordinal * 4L); + return Platform.getInt(baseObject, baseOffset + ordinal * 4L); } private int getElementSize(int offset, int ordinal) { @@ -157,7 +157,7 @@ public boolean getBoolean(int ordinal) { assertIndexIsValid(ordinal); final int offset = getElementOffset(ordinal); if (offset < 0) return false; - return PlatformDependent.UNSAFE.getBoolean(baseObject, baseOffset + offset); + return Platform.getBoolean(baseObject, baseOffset + offset); } @Override @@ -165,7 +165,7 @@ public byte getByte(int ordinal) { assertIndexIsValid(ordinal); final int offset = getElementOffset(ordinal); if (offset < 0) return 0; - return PlatformDependent.UNSAFE.getByte(baseObject, baseOffset + offset); + return Platform.getByte(baseObject, baseOffset + offset); } @Override @@ -173,7 +173,7 @@ public short getShort(int ordinal) { assertIndexIsValid(ordinal); final int offset = getElementOffset(ordinal); if (offset < 0) return 0; - return PlatformDependent.UNSAFE.getShort(baseObject, baseOffset + offset); + return Platform.getShort(baseObject, baseOffset + offset); } @Override @@ -181,7 +181,7 @@ public int getInt(int ordinal) { assertIndexIsValid(ordinal); final int offset = getElementOffset(ordinal); if (offset < 0) return 0; - return PlatformDependent.UNSAFE.getInt(baseObject, baseOffset + offset); + return Platform.getInt(baseObject, baseOffset + offset); } @Override @@ -189,7 +189,7 @@ public long getLong(int ordinal) { assertIndexIsValid(ordinal); final int offset = getElementOffset(ordinal); if (offset < 0) return 0; - return PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + offset); + return Platform.getLong(baseObject, baseOffset + offset); } @Override @@ -197,7 +197,7 @@ public float getFloat(int ordinal) { assertIndexIsValid(ordinal); final int offset = getElementOffset(ordinal); if (offset < 0) return 0; - return PlatformDependent.UNSAFE.getFloat(baseObject, baseOffset + offset); + return Platform.getFloat(baseObject, baseOffset + offset); } @Override @@ -205,7 +205,7 @@ public double getDouble(int ordinal) { assertIndexIsValid(ordinal); final int offset = getElementOffset(ordinal); if (offset < 0) return 0; - return 
PlatformDependent.UNSAFE.getDouble(baseObject, baseOffset + offset); + return Platform.getDouble(baseObject, baseOffset + offset); } @Override @@ -215,7 +215,7 @@ public Decimal getDecimal(int ordinal, int precision, int scale) { if (offset < 0) return null; if (precision <= Decimal.MAX_LONG_DIGITS()) { - final long value = PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + offset); + final long value = Platform.getLong(baseObject, baseOffset + offset); return Decimal.apply(value, precision, scale); } else { final byte[] bytes = getBinary(ordinal); @@ -241,12 +241,7 @@ public byte[] getBinary(int ordinal) { if (offset < 0) return null; final int size = getElementSize(offset, ordinal); final byte[] bytes = new byte[size]; - PlatformDependent.copyMemory( - baseObject, - baseOffset + offset, - bytes, - PlatformDependent.BYTE_ARRAY_OFFSET, - size); + Platform.copyMemory(baseObject, baseOffset + offset, bytes, Platform.BYTE_ARRAY_OFFSET, size); return bytes; } @@ -255,9 +250,8 @@ public CalendarInterval getInterval(int ordinal) { assertIndexIsValid(ordinal); final int offset = getElementOffset(ordinal); if (offset < 0) return null; - final int months = (int) PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + offset); - final long microseconds = - PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + offset + 8); + final int months = (int) Platform.getLong(baseObject, baseOffset + offset); + final long microseconds = Platform.getLong(baseObject, baseOffset + offset + 8); return new CalendarInterval(months, microseconds); } @@ -307,27 +301,16 @@ public boolean equals(Object other) { } public void writeToMemory(Object target, long targetOffset) { - PlatformDependent.copyMemory( - baseObject, - baseOffset, - target, - targetOffset, - sizeInBytes - ); + Platform.copyMemory(baseObject, baseOffset, target, targetOffset, sizeInBytes); } @Override public UnsafeArrayData copy() { UnsafeArrayData arrayCopy = new UnsafeArrayData(); final byte[] arrayDataCopy = new byte[sizeInBytes]; - PlatformDependent.copyMemory( - baseObject, - baseOffset, - arrayDataCopy, - PlatformDependent.BYTE_ARRAY_OFFSET, - sizeInBytes - ); - arrayCopy.pointTo(arrayDataCopy, PlatformDependent.BYTE_ARRAY_OFFSET, numElements, sizeInBytes); + Platform.copyMemory( + baseObject, baseOffset, arrayDataCopy, Platform.BYTE_ARRAY_OFFSET, sizeInBytes); + arrayCopy.pointTo(arrayDataCopy, Platform.BYTE_ARRAY_OFFSET, numElements, sizeInBytes); return arrayCopy; } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeReaders.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeReaders.java index b521b703389d..7b03185a30e3 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeReaders.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeReaders.java @@ -17,13 +17,13 @@ package org.apache.spark.sql.catalyst.expressions; -import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.Platform; public class UnsafeReaders { public static UnsafeArrayData readArray(Object baseObject, long baseOffset, int numBytes) { // Read the number of elements from first 4 bytes. - final int numElements = PlatformDependent.UNSAFE.getInt(baseObject, baseOffset); + final int numElements = Platform.getInt(baseObject, baseOffset); final UnsafeArrayData array = new UnsafeArrayData(); // Skip the first 4 bytes. 
array.pointTo(baseObject, baseOffset + 4, numElements, numBytes - 4); @@ -32,9 +32,9 @@ public static UnsafeArrayData readArray(Object baseObject, long baseOffset, int public static UnsafeMapData readMap(Object baseObject, long baseOffset, int numBytes) { // Read the number of elements from first 4 bytes. - final int numElements = PlatformDependent.UNSAFE.getInt(baseObject, baseOffset); + final int numElements = Platform.getInt(baseObject, baseOffset); // Read the numBytes of key array in second 4 bytes. - final int keyArraySize = PlatformDependent.UNSAFE.getInt(baseObject, baseOffset + 4); + final int keyArraySize = Platform.getInt(baseObject, baseOffset + 4); final int valueArraySize = numBytes - 8 - keyArraySize; final UnsafeArrayData keyArray = new UnsafeArrayData(); diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index e829acb6285f..7fd94772090d 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -27,7 +27,7 @@ import java.util.Set; import org.apache.spark.sql.types.*; -import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.bitset.BitSetMethods; import org.apache.spark.unsafe.hash.Murmur3_x86_32; @@ -169,7 +169,7 @@ public void pointTo(Object baseObject, long baseOffset, int numFields, int sizeI * @param sizeInBytes the number of bytes valid in the byte array */ public void pointTo(byte[] buf, int numFields, int sizeInBytes) { - pointTo(buf, PlatformDependent.BYTE_ARRAY_OFFSET, numFields, sizeInBytes); + pointTo(buf, Platform.BYTE_ARRAY_OFFSET, numFields, sizeInBytes); } @Override @@ -179,7 +179,7 @@ public void setNullAt(int i) { // To preserve row equality, zero out the value when setting the column to null. // Since this row does does not currently support updates to variable-length values, we don't // have to worry about zeroing out that data. 
- PlatformDependent.UNSAFE.putLong(baseObject, getFieldOffset(i), 0); + Platform.putLong(baseObject, getFieldOffset(i), 0); } @Override @@ -191,14 +191,14 @@ public void update(int ordinal, Object value) { public void setInt(int ordinal, int value) { assertIndexIsValid(ordinal); setNotNullAt(ordinal); - PlatformDependent.UNSAFE.putInt(baseObject, getFieldOffset(ordinal), value); + Platform.putInt(baseObject, getFieldOffset(ordinal), value); } @Override public void setLong(int ordinal, long value) { assertIndexIsValid(ordinal); setNotNullAt(ordinal); - PlatformDependent.UNSAFE.putLong(baseObject, getFieldOffset(ordinal), value); + Platform.putLong(baseObject, getFieldOffset(ordinal), value); } @Override @@ -208,28 +208,28 @@ public void setDouble(int ordinal, double value) { if (Double.isNaN(value)) { value = Double.NaN; } - PlatformDependent.UNSAFE.putDouble(baseObject, getFieldOffset(ordinal), value); + Platform.putDouble(baseObject, getFieldOffset(ordinal), value); } @Override public void setBoolean(int ordinal, boolean value) { assertIndexIsValid(ordinal); setNotNullAt(ordinal); - PlatformDependent.UNSAFE.putBoolean(baseObject, getFieldOffset(ordinal), value); + Platform.putBoolean(baseObject, getFieldOffset(ordinal), value); } @Override public void setShort(int ordinal, short value) { assertIndexIsValid(ordinal); setNotNullAt(ordinal); - PlatformDependent.UNSAFE.putShort(baseObject, getFieldOffset(ordinal), value); + Platform.putShort(baseObject, getFieldOffset(ordinal), value); } @Override public void setByte(int ordinal, byte value) { assertIndexIsValid(ordinal); setNotNullAt(ordinal); - PlatformDependent.UNSAFE.putByte(baseObject, getFieldOffset(ordinal), value); + Platform.putByte(baseObject, getFieldOffset(ordinal), value); } @Override @@ -239,7 +239,7 @@ public void setFloat(int ordinal, float value) { if (Float.isNaN(value)) { value = Float.NaN; } - PlatformDependent.UNSAFE.putFloat(baseObject, getFieldOffset(ordinal), value); + Platform.putFloat(baseObject, getFieldOffset(ordinal), value); } /** @@ -263,23 +263,23 @@ public void setDecimal(int ordinal, Decimal value, int precision) { long cursor = getLong(ordinal) >>> 32; assert cursor > 0 : "invalid cursor " + cursor; // zero-out the bytes - PlatformDependent.UNSAFE.putLong(baseObject, baseOffset + cursor, 0L); - PlatformDependent.UNSAFE.putLong(baseObject, baseOffset + cursor + 8, 0L); + Platform.putLong(baseObject, baseOffset + cursor, 0L); + Platform.putLong(baseObject, baseOffset + cursor + 8, 0L); if (value == null) { setNullAt(ordinal); // keep the offset for future update - PlatformDependent.UNSAFE.putLong(baseObject, getFieldOffset(ordinal), cursor << 32); + Platform.putLong(baseObject, getFieldOffset(ordinal), cursor << 32); } else { final BigInteger integer = value.toJavaBigDecimal().unscaledValue(); - final int[] mag = (int[]) PlatformDependent.UNSAFE.getObjectVolatile(integer, - PlatformDependent.BIG_INTEGER_MAG_OFFSET); + final int[] mag = (int[]) Platform.getObjectVolatile(integer, + Platform.BIG_INTEGER_MAG_OFFSET); assert(mag.length <= 4); // Write the bytes to the variable length portion. 
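For variable-length fields, the fixed-width slot read and written above holds a single long that packs the field's offset into the row in its upper 32 bits and, depending on the type, either the byte size or the decimal signum and magnitude length in the lower bits; getBinary below splits it apart the same way. A tiny sketch of the plain offset/size form, using only standard Java:

public class OffsetAndSizeSketch {

  // Pack a field's offset (upper 32 bits) and size (lower 32 bits) into one long slot.
  static long pack(int offset, int size) {
    return (((long) offset) << 32) | (size & 0xFFFFFFFFL);
  }

  static int offset(long packed) {
    return (int) (packed >> 32);
  }

  static int size(long packed) {
    return (int) (packed & ((1L << 32) - 1));
  }

  public static void main(String[] args) {
    long slot = pack(64, 12);                            // field starts 64 bytes in, 12 bytes long
    System.out.println(offset(slot) + " " + size(slot)); // 64 12
  }
}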
- PlatformDependent.copyMemory(mag, PlatformDependent.INT_ARRAY_OFFSET, - baseObject, baseOffset + cursor, mag.length * 4); + Platform.copyMemory( + mag, Platform.INT_ARRAY_OFFSET, baseObject, baseOffset + cursor, mag.length * 4); setLong(ordinal, (cursor << 32) | ((long) (((integer.signum() + 1) << 8) + mag.length))); } } @@ -336,43 +336,43 @@ public boolean isNullAt(int ordinal) { @Override public boolean getBoolean(int ordinal) { assertIndexIsValid(ordinal); - return PlatformDependent.UNSAFE.getBoolean(baseObject, getFieldOffset(ordinal)); + return Platform.getBoolean(baseObject, getFieldOffset(ordinal)); } @Override public byte getByte(int ordinal) { assertIndexIsValid(ordinal); - return PlatformDependent.UNSAFE.getByte(baseObject, getFieldOffset(ordinal)); + return Platform.getByte(baseObject, getFieldOffset(ordinal)); } @Override public short getShort(int ordinal) { assertIndexIsValid(ordinal); - return PlatformDependent.UNSAFE.getShort(baseObject, getFieldOffset(ordinal)); + return Platform.getShort(baseObject, getFieldOffset(ordinal)); } @Override public int getInt(int ordinal) { assertIndexIsValid(ordinal); - return PlatformDependent.UNSAFE.getInt(baseObject, getFieldOffset(ordinal)); + return Platform.getInt(baseObject, getFieldOffset(ordinal)); } @Override public long getLong(int ordinal) { assertIndexIsValid(ordinal); - return PlatformDependent.UNSAFE.getLong(baseObject, getFieldOffset(ordinal)); + return Platform.getLong(baseObject, getFieldOffset(ordinal)); } @Override public float getFloat(int ordinal) { assertIndexIsValid(ordinal); - return PlatformDependent.UNSAFE.getFloat(baseObject, getFieldOffset(ordinal)); + return Platform.getFloat(baseObject, getFieldOffset(ordinal)); } @Override public double getDouble(int ordinal) { assertIndexIsValid(ordinal); - return PlatformDependent.UNSAFE.getDouble(baseObject, getFieldOffset(ordinal)); + return Platform.getDouble(baseObject, getFieldOffset(ordinal)); } private static byte[] EMPTY = new byte[0]; @@ -391,13 +391,13 @@ public Decimal getDecimal(int ordinal, int precision, int scale) { assert signum >=0 && signum <= 2 : "invalid signum " + signum; int size = (int) (offsetAndSize & 0xff); int[] mag = new int[size]; - PlatformDependent.copyMemory(baseObject, baseOffset + offset, - mag, PlatformDependent.INT_ARRAY_OFFSET, size * 4); + Platform.copyMemory( + baseObject, baseOffset + offset, mag, Platform.INT_ARRAY_OFFSET, size * 4); // create a BigInteger using signum and mag BigInteger v = new BigInteger(0, EMPTY); // create the initial object - PlatformDependent.UNSAFE.putInt(v, PlatformDependent.BIG_INTEGER_SIGNUM_OFFSET, signum - 1); - PlatformDependent.UNSAFE.putObjectVolatile(v, PlatformDependent.BIG_INTEGER_MAG_OFFSET, mag); + Platform.putInt(v, Platform.BIG_INTEGER_SIGNUM_OFFSET, signum - 1); + Platform.putObjectVolatile(v, Platform.BIG_INTEGER_MAG_OFFSET, mag); return Decimal.apply(new BigDecimal(v, scale), precision, scale); } } @@ -420,11 +420,11 @@ public byte[] getBinary(int ordinal) { final int offset = (int) (offsetAndSize >> 32); final int size = (int) (offsetAndSize & ((1L << 32) - 1)); final byte[] bytes = new byte[size]; - PlatformDependent.copyMemory( + Platform.copyMemory( baseObject, baseOffset + offset, bytes, - PlatformDependent.BYTE_ARRAY_OFFSET, + Platform.BYTE_ARRAY_OFFSET, size ); return bytes; @@ -438,9 +438,8 @@ public CalendarInterval getInterval(int ordinal) { } else { final long offsetAndSize = getLong(ordinal); final int offset = (int) (offsetAndSize >> 32); - final int months = (int) 
PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + offset); - final long microseconds = - PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + offset + 8); + final int months = (int) Platform.getLong(baseObject, baseOffset + offset); + final long microseconds = Platform.getLong(baseObject, baseOffset + offset + 8); return new CalendarInterval(months, microseconds); } } @@ -491,14 +490,14 @@ public MapData getMap(int ordinal) { public UnsafeRow copy() { UnsafeRow rowCopy = new UnsafeRow(); final byte[] rowDataCopy = new byte[sizeInBytes]; - PlatformDependent.copyMemory( + Platform.copyMemory( baseObject, baseOffset, rowDataCopy, - PlatformDependent.BYTE_ARRAY_OFFSET, + Platform.BYTE_ARRAY_OFFSET, sizeInBytes ); - rowCopy.pointTo(rowDataCopy, PlatformDependent.BYTE_ARRAY_OFFSET, numFields, sizeInBytes); + rowCopy.pointTo(rowDataCopy, Platform.BYTE_ARRAY_OFFSET, numFields, sizeInBytes); return rowCopy; } @@ -518,18 +517,13 @@ public static UnsafeRow createFromByteArray(int numBytes, int numFields) { */ public void copyFrom(UnsafeRow row) { // copyFrom is only available for UnsafeRow created from byte array. - assert (baseObject instanceof byte[]) && baseOffset == PlatformDependent.BYTE_ARRAY_OFFSET; + assert (baseObject instanceof byte[]) && baseOffset == Platform.BYTE_ARRAY_OFFSET; if (row.sizeInBytes > this.sizeInBytes) { // resize the underlying byte[] if it's not large enough. this.baseObject = new byte[row.sizeInBytes]; } - PlatformDependent.copyMemory( - row.baseObject, - row.baseOffset, - this.baseObject, - this.baseOffset, - row.sizeInBytes - ); + Platform.copyMemory( + row.baseObject, row.baseOffset, this.baseObject, this.baseOffset, row.sizeInBytes); // update the sizeInBytes. this.sizeInBytes = row.sizeInBytes; } @@ -544,19 +538,15 @@ public void copyFrom(UnsafeRow row) { */ public void writeToStream(OutputStream out, byte[] writeBuffer) throws IOException { if (baseObject instanceof byte[]) { - int offsetInByteArray = (int) (PlatformDependent.BYTE_ARRAY_OFFSET - baseOffset); + int offsetInByteArray = (int) (Platform.BYTE_ARRAY_OFFSET - baseOffset); out.write((byte[]) baseObject, offsetInByteArray, sizeInBytes); } else { int dataRemaining = sizeInBytes; long rowReadPosition = baseOffset; while (dataRemaining > 0) { int toTransfer = Math.min(writeBuffer.length, dataRemaining); - PlatformDependent.copyMemory( - baseObject, - rowReadPosition, - writeBuffer, - PlatformDependent.BYTE_ARRAY_OFFSET, - toTransfer); + Platform.copyMemory( + baseObject, rowReadPosition, writeBuffer, Platform.BYTE_ARRAY_OFFSET, toTransfer); out.write(writeBuffer, 0, toTransfer); rowReadPosition += toTransfer; dataRemaining -= toTransfer; @@ -584,13 +574,12 @@ public boolean equals(Object other) { * Returns the underlying bytes for this UnsafeRow. 
*/ public byte[] getBytes() { - if (baseObject instanceof byte[] && baseOffset == PlatformDependent.BYTE_ARRAY_OFFSET + if (baseObject instanceof byte[] && baseOffset == Platform.BYTE_ARRAY_OFFSET && (((byte[]) baseObject).length == sizeInBytes)) { return (byte[]) baseObject; } else { byte[] bytes = new byte[sizeInBytes]; - PlatformDependent.copyMemory(baseObject, baseOffset, bytes, - PlatformDependent.BYTE_ARRAY_OFFSET, sizeInBytes); + Platform.copyMemory(baseObject, baseOffset, bytes, Platform.BYTE_ARRAY_OFFSET, sizeInBytes); return bytes; } } @@ -600,8 +589,7 @@ public byte[] getBytes() { public String toString() { StringBuilder build = new StringBuilder("["); for (int i = 0; i < sizeInBytes; i += 8) { - build.append(java.lang.Long.toHexString( - PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + i))); + build.append(java.lang.Long.toHexString(Platform.getLong(baseObject, baseOffset + i))); build.append(','); } build.append(']'); @@ -619,12 +607,6 @@ public boolean anyNull() { * bytes in this string. */ public void writeToMemory(Object target, long targetOffset) { - PlatformDependent.copyMemory( - baseObject, - baseOffset, - target, - targetOffset, - sizeInBytes - ); + Platform.copyMemory(baseObject, baseOffset, target, targetOffset, sizeInBytes); } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java index 28e7ec0a0f12..005351f0883e 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java @@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.types.Decimal; -import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.types.ByteArray; import org.apache.spark.unsafe.types.CalendarInterval; @@ -58,27 +58,27 @@ public static int write(UnsafeRow target, int ordinal, int cursor, Decimal input final Object base = target.getBaseObject(); final long offset = target.getBaseOffset() + cursor; // zero-out the bytes - PlatformDependent.UNSAFE.putLong(base, offset, 0L); - PlatformDependent.UNSAFE.putLong(base, offset + 8, 0L); + Platform.putLong(base, offset, 0L); + Platform.putLong(base, offset + 8, 0L); if (input == null) { target.setNullAt(ordinal); // keep the offset and length for update int fieldOffset = UnsafeRow.calculateBitSetWidthInBytes(target.numFields()) + ordinal * 8; - PlatformDependent.UNSAFE.putLong(base, target.getBaseOffset() + fieldOffset, + Platform.putLong(base, target.getBaseOffset() + fieldOffset, ((long) cursor) << 32); return SIZE; } final BigInteger integer = input.toJavaBigDecimal().unscaledValue(); int signum = integer.signum() + 1; - final int[] mag = (int[]) PlatformDependent.UNSAFE.getObjectVolatile(integer, - PlatformDependent.BIG_INTEGER_MAG_OFFSET); + final int[] mag = (int[]) Platform.getObjectVolatile( + integer, Platform.BIG_INTEGER_MAG_OFFSET); assert(mag.length <= 4); // Write the bytes to the variable length portion. - PlatformDependent.copyMemory(mag, PlatformDependent.INT_ARRAY_OFFSET, - base, target.getBaseOffset() + cursor, mag.length * 4); + Platform.copyMemory( + mag, Platform.INT_ARRAY_OFFSET, base, target.getBaseOffset() + cursor, mag.length * 4); // Set the fixed length portion. 
target.setLong(ordinal, (((long) cursor) << 32) | ((long) ((signum << 8) + mag.length))); @@ -99,8 +99,7 @@ public static int write(UnsafeRow target, int ordinal, int cursor, UTF8String in // zero-out the padding bytes if ((numBytes & 0x07) > 0) { - PlatformDependent.UNSAFE.putLong( - target.getBaseObject(), offset + ((numBytes >> 3) << 3), 0L); + Platform.putLong(target.getBaseObject(), offset + ((numBytes >> 3) << 3), 0L); } // Write the bytes to the variable length portion. @@ -125,8 +124,7 @@ public static int write(UnsafeRow target, int ordinal, int cursor, byte[] input) // zero-out the padding bytes if ((numBytes & 0x07) > 0) { - PlatformDependent.UNSAFE.putLong( - target.getBaseObject(), offset + ((numBytes >> 3) << 3), 0L); + Platform.putLong(target.getBaseObject(), offset + ((numBytes >> 3) << 3), 0L); } // Write the bytes to the variable length portion. @@ -167,8 +165,7 @@ public static int write(UnsafeRow target, int ordinal, int cursor, InternalRow i // zero-out the padding bytes if ((numBytes & 0x07) > 0) { - PlatformDependent.UNSAFE.putLong( - target.getBaseObject(), offset + ((numBytes >> 3) << 3), 0L); + Platform.putLong(target.getBaseObject(), offset + ((numBytes >> 3) << 3), 0L); } // Write the bytes to the variable length portion. @@ -191,8 +188,8 @@ public static int write(UnsafeRow target, int ordinal, int cursor, CalendarInter final long offset = target.getBaseOffset() + cursor; // Write the months and microseconds fields of Interval to the variable length portion. - PlatformDependent.UNSAFE.putLong(target.getBaseObject(), offset, input.months); - PlatformDependent.UNSAFE.putLong(target.getBaseObject(), offset + 8, input.microseconds); + Platform.putLong(target.getBaseObject(), offset, input.months); + Platform.putLong(target.getBaseObject(), offset + 8, input.microseconds); // Set the fixed length portion. target.setLong(ordinal, ((long) cursor) << 32); @@ -212,12 +209,11 @@ public static int write(UnsafeRow target, int ordinal, int cursor, UnsafeArrayDa final long offset = target.getBaseOffset() + cursor; // write the number of elements into first 4 bytes. - PlatformDependent.UNSAFE.putInt(target.getBaseObject(), offset, input.numElements()); + Platform.putInt(target.getBaseObject(), offset, input.numElements()); // zero-out the padding bytes if ((numBytes & 0x07) > 0) { - PlatformDependent.UNSAFE.putLong( - target.getBaseObject(), offset + ((numBytes >> 3) << 3), 0L); + Platform.putLong(target.getBaseObject(), offset + ((numBytes >> 3) << 3), 0L); } // Write the bytes to the variable length portion. @@ -247,14 +243,13 @@ public static int write(UnsafeRow target, int ordinal, int cursor, UnsafeMapData final int numBytes = 4 + 4 + keysNumBytes + valuesNumBytes; // write the number of elements into first 4 bytes. - PlatformDependent.UNSAFE.putInt(target.getBaseObject(), offset, input.numElements()); + Platform.putInt(target.getBaseObject(), offset, input.numElements()); // write the numBytes of key array into second 4 bytes. - PlatformDependent.UNSAFE.putInt(target.getBaseObject(), offset + 4, keysNumBytes); + Platform.putInt(target.getBaseObject(), offset + 4, keysNumBytes); // zero-out the padding bytes if ((numBytes & 0x07) > 0) { - PlatformDependent.UNSAFE.putLong( - target.getBaseObject(), offset + ((numBytes >> 3) << 3), 0L); + Platform.putLong(target.getBaseObject(), offset + ((numBytes >> 3) << 3), 0L); } // Write the bytes of key array to the variable length portion. 
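A pattern repeated through these writers is rounding every variable-length field up to an 8-byte boundary and zeroing the last partial word at offset (numBytes >> 3) << 3, presumably for the same reason setNullAt zeroes its slot above: rows with equal contents stay byte-for-byte equal. A small sketch of that rounding; the body of getRoundedSize in the UnsafeWriters diff that follows is not shown, so the round-up formula here is an assumption:

public class WordAlignSketch {

  // One plausible way to round a byte count up to the next multiple of 8 (assumed, since the
  // diff shows getRoundedSize being declared and called but not its body).
  static int roundUpToWord(int numBytes) {
    return ((numBytes + 7) >> 3) << 3;
  }

  // Offset of the last (possibly partial) word, matching the (numBytes >> 3) << 3 expression
  // the writers use when zeroing padding bytes.
  static int lastWordOffset(int numBytes) {
    return (numBytes >> 3) << 3;
  }

  public static void main(String[] args) {
    System.out.println(roundUpToWord(13));  // 16
    System.out.println(lastWordOffset(13)); // 8
  }
}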
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeWriters.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeWriters.java index 0e8e405d055d..cd83695fca03 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeWriters.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeWriters.java @@ -18,8 +18,7 @@ package org.apache.spark.sql.catalyst.expressions; import org.apache.spark.sql.types.Decimal; -import org.apache.spark.unsafe.PlatformDependent; -import org.apache.spark.unsafe.array.ByteArrayMethods; +import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; @@ -36,17 +35,11 @@ public static void writeToMemory( // zero-out the padding bytes // if ((numBytes & 0x07) > 0) { -// PlatformDependent.UNSAFE.putLong(targetObject, targetOffset + ((numBytes >> 3) << 3), 0L); +// Platform.putLong(targetObject, targetOffset + ((numBytes >> 3) << 3), 0L); // } // Write the UnsafeData to the target memory. - PlatformDependent.copyMemory( - inputObject, - inputOffset, - targetObject, - targetOffset, - numBytes - ); + Platform.copyMemory(inputObject, inputOffset, targetObject, targetOffset, numBytes); } public static int getRoundedSize(int size) { @@ -68,16 +61,11 @@ public static int write(Object targetObject, long targetOffset, Decimal input) { assert(numBytes <= 16); // zero-out the bytes - PlatformDependent.UNSAFE.putLong(targetObject, targetOffset, 0L); - PlatformDependent.UNSAFE.putLong(targetObject, targetOffset + 8, 0L); + Platform.putLong(targetObject, targetOffset, 0L); + Platform.putLong(targetObject, targetOffset + 8, 0L); // Write the bytes to the variable length portion. - PlatformDependent.copyMemory(bytes, - PlatformDependent.BYTE_ARRAY_OFFSET, - targetObject, - targetOffset, - numBytes); - + Platform.copyMemory(bytes, Platform.BYTE_ARRAY_OFFSET, targetObject, targetOffset, numBytes); return 16; } } @@ -111,8 +99,7 @@ public static int write(Object targetObject, long targetOffset, byte[] input) { final int numBytes = input.length; // Write the bytes to the variable length portion. - writeToMemory(input, PlatformDependent.BYTE_ARRAY_OFFSET, - targetObject, targetOffset, numBytes); + writeToMemory(input, Platform.BYTE_ARRAY_OFFSET, targetObject, targetOffset, numBytes); return getRoundedSize(numBytes); } @@ -144,11 +131,9 @@ public static int getSize(UnsafeRow input) { } public static int write(Object targetObject, long targetOffset, CalendarInterval input) { - // Write the months and microseconds fields of Interval to the variable length portion. - PlatformDependent.UNSAFE.putLong(targetObject, targetOffset, input.months); - PlatformDependent.UNSAFE.putLong(targetObject, targetOffset + 8, input.microseconds); - + Platform.putLong(targetObject, targetOffset, input.months); + Platform.putLong(targetObject, targetOffset + 8, input.microseconds); return 16; } } @@ -165,11 +150,11 @@ public static int write(Object targetObject, long targetOffset, UnsafeArrayData final int numBytes = input.getSizeInBytes(); // write the number of elements into first 4 bytes. - PlatformDependent.UNSAFE.putInt(targetObject, targetOffset, input.numElements()); + Platform.putInt(targetObject, targetOffset, input.numElements()); // Write the bytes to the variable length portion. 
- writeToMemory(input.getBaseObject(), input.getBaseOffset(), - targetObject, targetOffset + 4, numBytes); + writeToMemory( + input.getBaseObject(), input.getBaseOffset(), targetObject, targetOffset + 4, numBytes); return getRoundedSize(numBytes + 4); } @@ -190,9 +175,9 @@ public static int write(Object targetObject, long targetOffset, UnsafeMapData in final int numBytes = 4 + 4 + keysNumBytes + valuesNumBytes; // write the number of elements into first 4 bytes. - PlatformDependent.UNSAFE.putInt(targetObject, targetOffset, input.numElements()); + Platform.putInt(targetObject, targetOffset, input.numElements()); // write the numBytes of key array into second 4 bytes. - PlatformDependent.UNSAFE.putInt(targetObject, targetOffset + 4, keysNumBytes); + Platform.putInt(targetObject, targetOffset + 4, keysNumBytes); // Write the bytes of key array to the variable length portion. writeToMemory(keyArray.getBaseObject(), keyArray.getBaseOffset(), diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java index a5ae2b973652..1d27182912c8 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java @@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.expressions.UnsafeProjection; import org.apache.spark.sql.catalyst.expressions.UnsafeRow; import org.apache.spark.sql.types.StructType; -import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.Platform; import org.apache.spark.util.collection.unsafe.sort.PrefixComparator; import org.apache.spark.util.collection.unsafe.sort.RecordComparator; import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter; @@ -157,7 +157,7 @@ public UnsafeRow next() { cleanupResources(); // Scala iterators don't declare any checked exceptions, so we need to use this hack // to re-throw the exception: - PlatformDependent.throwException(e); + Platform.throwException(e); } throw new RuntimeException("Exception should have been re-thrown in next()"); }; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index c21f4d626a74..bf96248feaef 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -28,7 +28,7 @@ import org.apache.spark.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.PlatformDependent +import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.types._ @@ -371,7 +371,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin // Cannot be under package codegen, or fail with java.lang.InstantiationException evaluator.setClassName("org.apache.spark.sql.catalyst.expressions.GeneratedClass") evaluator.setDefaultImports(Array( - classOf[PlatformDependent].getName, + classOf[Platform].getName, classOf[InternalRow].getName, classOf[UnsafeRow].getName, classOf[UTF8String].getName, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 29f6a7b98175..b2fb91385079 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -145,12 +145,11 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro if ($buffer.length < $numBytes) { // This will not happen frequently, because the buffer is re-used. byte[] $tmp = new byte[$numBytes * 2]; - PlatformDependent.copyMemory($buffer, PlatformDependent.BYTE_ARRAY_OFFSET, - $tmp, PlatformDependent.BYTE_ARRAY_OFFSET, $buffer.length); + Platform.copyMemory($buffer, Platform.BYTE_ARRAY_OFFSET, + $tmp, Platform.BYTE_ARRAY_OFFSET, $buffer.length); $buffer = $tmp; } - $output.pointTo($buffer, PlatformDependent.BYTE_ARRAY_OFFSET, - ${inputTypes.length}, $numBytes); + $output.pointTo($buffer, Platform.BYTE_ARRAY_OFFSET, ${inputTypes.length}, $numBytes); """ } else { "" @@ -183,7 +182,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val code = s""" $cursor = $fixedSize; - $output.pointTo($buffer, PlatformDependent.BYTE_ARRAY_OFFSET, ${inputTypes.length}, $cursor); + $output.pointTo($buffer, Platform.BYTE_ARRAY_OFFSET, ${inputTypes.length}, $cursor); ${ctx.splitExpressions(row, convertedFields)} """ GeneratedExpressionCode(code, "false", output) @@ -267,17 +266,17 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro // Should we do word align? val elementSize = elementType.defaultSize s""" - PlatformDependent.UNSAFE.put${ctx.primitiveTypeName(elementType)}( + Platform.put${ctx.primitiveTypeName(elementType)}( $buffer, - PlatformDependent.BYTE_ARRAY_OFFSET + $cursor, + Platform.BYTE_ARRAY_OFFSET + $cursor, ${convertedElement.primitive}); $cursor += $elementSize; """ case t: DecimalType if t.precision <= Decimal.MAX_LONG_DIGITS => s""" - PlatformDependent.UNSAFE.putLong( + Platform.putLong( $buffer, - PlatformDependent.BYTE_ARRAY_OFFSET + $cursor, + Platform.BYTE_ARRAY_OFFSET + $cursor, ${convertedElement.primitive}.toUnscaledLong()); $cursor += 8; """ @@ -286,7 +285,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro s""" $cursor += $writer.write( $buffer, - PlatformDependent.BYTE_ARRAY_OFFSET + $cursor, + Platform.BYTE_ARRAY_OFFSET + $cursor, $elements[$index]); """ } @@ -320,23 +319,16 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro for (int $index = 0; $index < $numElements; $index++) { if ($checkNull) { // If element is null, write the negative value address into offset region. 
- PlatformDependent.UNSAFE.putInt( - $buffer, - PlatformDependent.BYTE_ARRAY_OFFSET + 4 * $index, - -$cursor); + Platform.putInt($buffer, Platform.BYTE_ARRAY_OFFSET + 4 * $index, -$cursor); } else { - PlatformDependent.UNSAFE.putInt( - $buffer, - PlatformDependent.BYTE_ARRAY_OFFSET + 4 * $index, - $cursor); - + Platform.putInt($buffer, Platform.BYTE_ARRAY_OFFSET + 4 * $index, $cursor); $writeElement } } $output.pointTo( $buffer, - PlatformDependent.BYTE_ARRAY_OFFSET, + Platform.BYTE_ARRAY_OFFSET, $numElements, $numBytes); } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala index 8aaa5b430004..da91ff29537b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions.codegen import org.apache.spark.sql.catalyst.expressions.{UnsafeRow, Attribute} import org.apache.spark.sql.types.StructType -import org.apache.spark.unsafe.PlatformDependent +import org.apache.spark.unsafe.Platform abstract class UnsafeRowJoiner { @@ -52,9 +52,9 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U } def create(schema1: StructType, schema2: StructType): UnsafeRowJoiner = { - val offset = PlatformDependent.BYTE_ARRAY_OFFSET - val getLong = "PlatformDependent.UNSAFE.getLong" - val putLong = "PlatformDependent.UNSAFE.putLong" + val offset = Platform.BYTE_ARRAY_OFFSET + val getLong = "Platform.getLong" + val putLong = "Platform.putLong" val bitset1Words = (schema1.size + 63) / 64 val bitset2Words = (schema2.size + 63) / 64 @@ -96,7 +96,7 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U var cursor = offset + outputBitsetWords * 8 val copyFixedLengthRow1 = s""" |// Copy fixed length data for row1 - |PlatformDependent.copyMemory( + |Platform.copyMemory( | obj1, offset1 + ${bitset1Words * 8}, | buf, $cursor, | ${schema1.size * 8}); @@ -106,7 +106,7 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U // --------------------- copy fixed length portion from row 2 ----------------------- // val copyFixedLengthRow2 = s""" |// Copy fixed length data for row2 - |PlatformDependent.copyMemory( + |Platform.copyMemory( | obj2, offset2 + ${bitset2Words * 8}, | buf, $cursor, | ${schema2.size * 8}); @@ -118,7 +118,7 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U val copyVariableLengthRow1 = s""" |// Copy variable length data for row1 |long numBytesVariableRow1 = row1.getSizeInBytes() - $numBytesBitsetAndFixedRow1; - |PlatformDependent.copyMemory( + |Platform.copyMemory( | obj1, offset1 + ${(bitset1Words + schema1.size) * 8}, | buf, $cursor, | numBytesVariableRow1); @@ -129,7 +129,7 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U val copyVariableLengthRow2 = s""" |// Copy variable length data for row2 |long numBytesVariableRow2 = row2.getSizeInBytes() - $numBytesBitsetAndFixedRow2; - |PlatformDependent.copyMemory( + |Platform.copyMemory( | obj2, offset2 + ${(bitset2Words + schema2.size) * 8}, | buf, $cursor + numBytesVariableRow1, | numBytesVariableRow2); diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index 76666bd6b3d2..134f1aa2af9a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -1013,7 +1013,7 @@ case class Decode(bin: Expression, charset: Expression) try { ${ev.primitive} = UTF8String.fromString(new String($bytes, $charset.toString())); } catch (java.io.UnsupportedEncodingException e) { - org.apache.spark.unsafe.PlatformDependent.throwException(e); + org.apache.spark.unsafe.Platform.throwException(e); } """) } @@ -1043,7 +1043,7 @@ case class Encode(value: Expression, charset: Expression) try { ${ev.primitive} = $string.toString().getBytes($charset.toString()); } catch (java.io.UnsupportedEncodingException e) { - org.apache.spark.unsafe.PlatformDependent.throwException(e); + org.apache.spark.unsafe.Platform.throwException(e); }""") } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerBitsetSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerBitsetSuite.scala index aff1bee99faa..796d60032e1a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerBitsetSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerBitsetSuite.scala @@ -22,7 +22,7 @@ import scala.util.Random import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.PlatformDependent +import org.apache.spark.unsafe.Platform /** * A test suite for the bitset portion of the row concatenation. @@ -96,7 +96,7 @@ class GenerateUnsafeRowJoinerBitsetSuite extends SparkFunSuite { // This way we can test the joiner when the input UnsafeRows are not the entire arrays. 
val offset = numFields * 8 val buf = new Array[Byte](sizeInBytes + offset) - row.pointTo(buf, PlatformDependent.BYTE_ARRAY_OFFSET + offset, numFields, sizeInBytes) + row.pointTo(buf, Platform.BYTE_ARRAY_OFFSET + offset, numFields, sizeInBytes) row } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java index 00218f213054..5cce41d5a756 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java @@ -27,7 +27,7 @@ import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; import org.apache.spark.unsafe.KVIterator; -import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.map.BytesToBytesMap; import org.apache.spark.unsafe.memory.MemoryLocation; import org.apache.spark.unsafe.memory.TaskMemoryManager; @@ -138,7 +138,7 @@ public UnsafeRow getAggregationBufferFromUnsafeRow(UnsafeRow unsafeGroupingKeyRo unsafeGroupingKeyRow.getBaseOffset(), unsafeGroupingKeyRow.getSizeInBytes(), emptyAggregationBuffer, - PlatformDependent.BYTE_ARRAY_OFFSET, + Platform.BYTE_ARRAY_OFFSET, emptyAggregationBuffer.length ); if (!putSucceeded) { diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java index 69d6784713a2..7db6b7ff50f2 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java @@ -31,7 +31,7 @@ import org.apache.spark.sql.types.StructType; import org.apache.spark.storage.BlockManager; import org.apache.spark.unsafe.KVIterator; -import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.map.BytesToBytesMap; import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.unsafe.memory.TaskMemoryManager; @@ -225,7 +225,7 @@ public boolean next() throws IOException { int recordLen = underlying.getRecordLength(); // Note that recordLen = keyLen + valueLen + 4 bytes (for the keyLen itself) - int keyLen = PlatformDependent.UNSAFE.getInt(baseObj, recordOffset); + int keyLen = Platform.getInt(baseObj, recordOffset); int valueLen = recordLen - keyLen - 4; key.pointTo(baseObj, recordOffset + 4, numKeyFields, keyLen); value.pointTo(baseObj, recordOffset + 4 + keyLen, numValueFields, valueLen); diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala index 6c7e5cacc99e..3860c4bba9a9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala @@ -26,7 +26,7 @@ import com.google.common.io.ByteStreams import org.apache.spark.serializer.{SerializationStream, DeserializationStream, SerializerInstance, Serializer} import org.apache.spark.sql.catalyst.expressions.UnsafeRow -import org.apache.spark.unsafe.PlatformDependent +import org.apache.spark.unsafe.Platform /** * Serializer for serializing [[UnsafeRow]]s during shuffle. 
Since UnsafeRows are already stored as @@ -116,7 +116,7 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInst rowBuffer = new Array[Byte](rowSize) } ByteStreams.readFully(dIn, rowBuffer, 0, rowSize) - row.pointTo(rowBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, numFields, rowSize) + row.pointTo(rowBuffer, Platform.BYTE_ARRAY_OFFSET, numFields, rowSize) rowSize = dIn.readInt() // read the next row's size if (rowSize == EOF) { // We are returning the last row in this stream val _rowTuple = rowTuple @@ -150,7 +150,7 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInst rowBuffer = new Array[Byte](rowSize) } ByteStreams.readFully(dIn, rowBuffer, 0, rowSize) - row.pointTo(rowBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, numFields, rowSize) + row.pointTo(rowBuffer, Platform.BYTE_ARRAY_OFFSET, numFields, rowSize) row.asInstanceOf[T] } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index 953abf409f22..63d35d0f0262 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -25,7 +25,7 @@ import org.apache.spark.shuffle.ShuffleMemoryManager import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.SparkSqlSerializer -import org.apache.spark.unsafe.PlatformDependent +import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.map.BytesToBytesMap import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager} import org.apache.spark.util.Utils @@ -218,8 +218,8 @@ private[joins] final class UnsafeHashedRelation( var offset = loc.getValueAddress.getBaseOffset val last = loc.getValueAddress.getBaseOffset + loc.getValueLength while (offset < last) { - val numFields = PlatformDependent.UNSAFE.getInt(base, offset) - val sizeInBytes = PlatformDependent.UNSAFE.getInt(base, offset + 4) + val numFields = Platform.getInt(base, offset) + val sizeInBytes = Platform.getInt(base, offset + 4) offset += 8 val row = new UnsafeRow @@ -314,10 +314,11 @@ private[joins] final class UnsafeHashedRelation( in.readFully(valuesBuffer, 0, valuesSize) // put it into binary map - val loc = binaryMap.lookup(keyBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, keySize) + val loc = binaryMap.lookup(keyBuffer, Platform.BYTE_ARRAY_OFFSET, keySize) assert(!loc.isDefined, "Duplicated key found!") - val putSuceeded = loc.putNewKey(keyBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, keySize, - valuesBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, valuesSize) + val putSuceeded = loc.putNewKey( + keyBuffer, Platform.BYTE_ARRAY_OFFSET, keySize, + valuesBuffer, Platform.BYTE_ARRAY_OFFSET, valuesSize) if (!putSuceeded) { throw new IOException("Could not allocate memory to grow BytesToBytesMap") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala index 89bad1bfdab0..219435dff5bc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala @@ -23,7 +23,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{UnsafeRow, UnsafeProjection} import 
org.apache.spark.sql.types._ -import org.apache.spark.unsafe.PlatformDependent +import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.memory.MemoryAllocator import org.apache.spark.unsafe.types.UTF8String @@ -51,7 +51,7 @@ class UnsafeRowSuite extends SparkFunSuite { val bytesFromOffheapRow: Array[Byte] = { val offheapRowPage = MemoryAllocator.UNSAFE.allocate(arrayBackedUnsafeRow.getSizeInBytes) try { - PlatformDependent.copyMemory( + Platform.copyMemory( arrayBackedUnsafeRow.getBaseObject, arrayBackedUnsafeRow.getBaseOffset, offheapRowPage.getBaseObject, diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/PlatformDependent.java b/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java similarity index 55% rename from unsafe/src/main/java/org/apache/spark/unsafe/PlatformDependent.java rename to unsafe/src/main/java/org/apache/spark/unsafe/Platform.java index b2de2a2590f0..18343efdc343 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/PlatformDependent.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java @@ -22,103 +22,111 @@ import sun.misc.Unsafe; -public final class PlatformDependent { +public final class Platform { - /** - * Facade in front of {@link sun.misc.Unsafe}, used to avoid directly exposing Unsafe outside of - * this package. This also lets us avoid accidental use of deprecated methods. - */ - public static final class UNSAFE { - - private UNSAFE() { } + private static final Unsafe _UNSAFE; - public static int getInt(Object object, long offset) { - return _UNSAFE.getInt(object, offset); - } + public static final int BYTE_ARRAY_OFFSET; - public static void putInt(Object object, long offset, int value) { - _UNSAFE.putInt(object, offset, value); - } + public static final int INT_ARRAY_OFFSET; - public static boolean getBoolean(Object object, long offset) { - return _UNSAFE.getBoolean(object, offset); - } + public static final int LONG_ARRAY_OFFSET; - public static void putBoolean(Object object, long offset, boolean value) { - _UNSAFE.putBoolean(object, offset, value); - } + public static final int DOUBLE_ARRAY_OFFSET; - public static byte getByte(Object object, long offset) { - return _UNSAFE.getByte(object, offset); - } + // Support for resetting final fields while deserializing + public static final long BIG_INTEGER_SIGNUM_OFFSET; + public static final long BIG_INTEGER_MAG_OFFSET; - public static void putByte(Object object, long offset, byte value) { - _UNSAFE.putByte(object, offset, value); - } + public static int getInt(Object object, long offset) { + return _UNSAFE.getInt(object, offset); + } - public static short getShort(Object object, long offset) { - return _UNSAFE.getShort(object, offset); - } + public static void putInt(Object object, long offset, int value) { + _UNSAFE.putInt(object, offset, value); + } - public static void putShort(Object object, long offset, short value) { - _UNSAFE.putShort(object, offset, value); - } + public static boolean getBoolean(Object object, long offset) { + return _UNSAFE.getBoolean(object, offset); + } - public static long getLong(Object object, long offset) { - return _UNSAFE.getLong(object, offset); - } + public static void putBoolean(Object object, long offset, boolean value) { + _UNSAFE.putBoolean(object, offset, value); + } - public static void putLong(Object object, long offset, long value) { - _UNSAFE.putLong(object, offset, value); - } + public static byte getByte(Object object, long offset) { + return _UNSAFE.getByte(object, offset); + } - public static float getFloat(Object 
object, long offset) { - return _UNSAFE.getFloat(object, offset); - } + public static void putByte(Object object, long offset, byte value) { + _UNSAFE.putByte(object, offset, value); + } - public static void putFloat(Object object, long offset, float value) { - _UNSAFE.putFloat(object, offset, value); - } + public static short getShort(Object object, long offset) { + return _UNSAFE.getShort(object, offset); + } - public static double getDouble(Object object, long offset) { - return _UNSAFE.getDouble(object, offset); - } + public static void putShort(Object object, long offset, short value) { + _UNSAFE.putShort(object, offset, value); + } - public static void putDouble(Object object, long offset, double value) { - _UNSAFE.putDouble(object, offset, value); - } + public static long getLong(Object object, long offset) { + return _UNSAFE.getLong(object, offset); + } - public static Object getObjectVolatile(Object object, long offset) { - return _UNSAFE.getObjectVolatile(object, offset); - } + public static void putLong(Object object, long offset, long value) { + _UNSAFE.putLong(object, offset, value); + } - public static void putObjectVolatile(Object object, long offset, Object value) { - _UNSAFE.putObjectVolatile(object, offset, value); - } + public static float getFloat(Object object, long offset) { + return _UNSAFE.getFloat(object, offset); + } - public static long allocateMemory(long size) { - return _UNSAFE.allocateMemory(size); - } + public static void putFloat(Object object, long offset, float value) { + _UNSAFE.putFloat(object, offset, value); + } - public static void freeMemory(long address) { - _UNSAFE.freeMemory(address); - } + public static double getDouble(Object object, long offset) { + return _UNSAFE.getDouble(object, offset); + } + public static void putDouble(Object object, long offset, double value) { + _UNSAFE.putDouble(object, offset, value); } - private static final Unsafe _UNSAFE; + public static Object getObjectVolatile(Object object, long offset) { + return _UNSAFE.getObjectVolatile(object, offset); + } - public static final int BYTE_ARRAY_OFFSET; + public static void putObjectVolatile(Object object, long offset, Object value) { + _UNSAFE.putObjectVolatile(object, offset, value); + } - public static final int INT_ARRAY_OFFSET; + public static long allocateMemory(long size) { + return _UNSAFE.allocateMemory(size); + } - public static final int LONG_ARRAY_OFFSET; + public static void freeMemory(long address) { + _UNSAFE.freeMemory(address); + } - public static final int DOUBLE_ARRAY_OFFSET; + public static void copyMemory( + Object src, long srcOffset, Object dst, long dstOffset, long length) { + while (length > 0) { + long size = Math.min(length, UNSAFE_COPY_THRESHOLD); + _UNSAFE.copyMemory(src, srcOffset, dst, dstOffset, size); + length -= size; + srcOffset += size; + dstOffset += size; + } + } - // Support for resetting final fields while deserializing - public static final long BIG_INTEGER_SIGNUM_OFFSET; - public static final long BIG_INTEGER_MAG_OFFSET; + /** + * Raises an exception bypassing compiler checks for checked exceptions. 
+ */ + public static void throwException(Throwable t) { + _UNSAFE.throwException(t); + } /** * Limits the number of bytes to copy per {@link Unsafe#copyMemory(long, long, long)} to @@ -162,26 +170,4 @@ public static void freeMemory(long address) { BIG_INTEGER_MAG_OFFSET = 0; } } - - static public void copyMemory( - Object src, - long srcOffset, - Object dst, - long dstOffset, - long length) { - while (length > 0) { - long size = Math.min(length, UNSAFE_COPY_THRESHOLD); - _UNSAFE.copyMemory(src, srcOffset, dst, dstOffset, size); - length -= size; - srcOffset += size; - dstOffset += size; - } - } - - /** - * Raises an exception bypassing compiler checks for checked exceptions. - */ - public static void throwException(Throwable t) { - _UNSAFE.throwException(t); - } } diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java b/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java index 70b81ce015dd..cf42877bf9fd 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java @@ -17,7 +17,7 @@ package org.apache.spark.unsafe.array; -import static org.apache.spark.unsafe.PlatformDependent.*; +import org.apache.spark.unsafe.Platform; public class ByteArrayMethods { @@ -45,20 +45,18 @@ public static int roundNumberOfBytesToNearestWord(int numBytes) { * @return true if the arrays are equal, false otherwise */ public static boolean arrayEquals( - Object leftBase, - long leftOffset, - Object rightBase, - long rightOffset, - final long length) { + Object leftBase, long leftOffset, Object rightBase, long rightOffset, final long length) { int i = 0; while (i <= length - 8) { - if (UNSAFE.getLong(leftBase, leftOffset + i) != UNSAFE.getLong(rightBase, rightOffset + i)) { + if (Platform.getLong(leftBase, leftOffset + i) != + Platform.getLong(rightBase, rightOffset + i)) { return false; } i += 8; } while (i < length) { - if (UNSAFE.getByte(leftBase, leftOffset + i) != UNSAFE.getByte(rightBase, rightOffset + i)) { + if (Platform.getByte(leftBase, leftOffset + i) != + Platform.getByte(rightBase, rightOffset + i)) { return false; } i += 1; diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java b/unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java index 18d1f0d2d7eb..74105050e419 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java @@ -17,7 +17,7 @@ package org.apache.spark.unsafe.array; -import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.memory.MemoryBlock; /** @@ -64,7 +64,7 @@ public long size() { public void set(int index, long value) { assert index >= 0 : "index (" + index + ") should >= 0"; assert index < length : "index (" + index + ") should < length (" + length + ")"; - PlatformDependent.UNSAFE.putLong(baseObj, baseOffset + index * WIDTH, value); + Platform.putLong(baseObj, baseOffset + index * WIDTH, value); } /** @@ -73,6 +73,6 @@ public void set(int index, long value) { public long get(int index) { assert index >= 0 : "index (" + index + ") should >= 0"; assert index < length : "index (" + index + ") should < length (" + length + ")"; - return PlatformDependent.UNSAFE.getLong(baseObj, baseOffset + index * WIDTH); + return Platform.getLong(baseObj, baseOffset + index * WIDTH); } } diff --git 
a/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSetMethods.java b/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSetMethods.java index 27462c7fa5e6..7857bf66a72a 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSetMethods.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSetMethods.java @@ -17,7 +17,7 @@ package org.apache.spark.unsafe.bitset; -import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.Platform; /** * Methods for working with fixed-size uncompressed bitsets. @@ -41,8 +41,8 @@ public static void set(Object baseObject, long baseOffset, int index) { assert index >= 0 : "index (" + index + ") should >= 0"; final long mask = 1L << (index & 0x3f); // mod 64 and shift final long wordOffset = baseOffset + (index >> 6) * WORD_SIZE; - final long word = PlatformDependent.UNSAFE.getLong(baseObject, wordOffset); - PlatformDependent.UNSAFE.putLong(baseObject, wordOffset, word | mask); + final long word = Platform.getLong(baseObject, wordOffset); + Platform.putLong(baseObject, wordOffset, word | mask); } /** @@ -52,8 +52,8 @@ public static void unset(Object baseObject, long baseOffset, int index) { assert index >= 0 : "index (" + index + ") should >= 0"; final long mask = 1L << (index & 0x3f); // mod 64 and shift final long wordOffset = baseOffset + (index >> 6) * WORD_SIZE; - final long word = PlatformDependent.UNSAFE.getLong(baseObject, wordOffset); - PlatformDependent.UNSAFE.putLong(baseObject, wordOffset, word & ~mask); + final long word = Platform.getLong(baseObject, wordOffset); + Platform.putLong(baseObject, wordOffset, word & ~mask); } /** @@ -63,7 +63,7 @@ public static boolean isSet(Object baseObject, long baseOffset, int index) { assert index >= 0 : "index (" + index + ") should >= 0"; final long mask = 1L << (index & 0x3f); // mod 64 and shift final long wordOffset = baseOffset + (index >> 6) * WORD_SIZE; - final long word = PlatformDependent.UNSAFE.getLong(baseObject, wordOffset); + final long word = Platform.getLong(baseObject, wordOffset); return (word & mask) != 0; } @@ -73,7 +73,7 @@ public static boolean isSet(Object baseObject, long baseOffset, int index) { public static boolean anySet(Object baseObject, long baseOffset, long bitSetWidthInWords) { long addr = baseOffset; for (int i = 0; i < bitSetWidthInWords; i++, addr += WORD_SIZE) { - if (PlatformDependent.UNSAFE.getLong(baseObject, addr) != 0) { + if (Platform.getLong(baseObject, addr) != 0) { return true; } } @@ -109,8 +109,7 @@ public static int nextSetBit( // Try to find the next set bit in the current word final int subIndex = fromIndex & 0x3f; - long word = - PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + wi * WORD_SIZE) >> subIndex; + long word = Platform.getLong(baseObject, baseOffset + wi * WORD_SIZE) >> subIndex; if (word != 0) { return (wi << 6) + subIndex + java.lang.Long.numberOfTrailingZeros(word); } @@ -118,7 +117,7 @@ public static int nextSetBit( // Find the next set bit in the rest of the words wi += 1; while (wi < bitsetSizeInWords) { - word = PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + wi * WORD_SIZE); + word = Platform.getLong(baseObject, baseOffset + wi * WORD_SIZE); if (word != 0) { return (wi << 6) + java.lang.Long.numberOfTrailingZeros(word); } diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java b/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java index 61f483ced321..4276f25c2165 100644 --- 
a/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java @@ -17,7 +17,7 @@ package org.apache.spark.unsafe.hash; -import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.Platform; /** * 32-bit Murmur3 hasher. This is based on Guava's Murmur3_32HashFunction. @@ -53,7 +53,7 @@ public static int hashUnsafeWords(Object base, long offset, int lengthInBytes, i assert (lengthInBytes % 8 == 0): "lengthInBytes must be a multiple of 8 (word-aligned)"; int h1 = seed; for (int i = 0; i < lengthInBytes; i += 4) { - int halfWord = PlatformDependent.UNSAFE.getInt(base, offset + i); + int halfWord = Platform.getInt(base, offset + i); int k1 = mixK1(halfWord); h1 = mixH1(h1, k1); } diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java index 91be46ba21ff..dd7582083437 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java @@ -19,7 +19,7 @@ import javax.annotation.Nullable; -import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.Platform; /** * A consecutive block of memory, starting at a {@link MemoryLocation} with a fixed size. @@ -50,6 +50,6 @@ public long size() { * Creates a memory block pointing to the memory used by the long array. */ public static MemoryBlock fromLongArray(final long[] array) { - return new MemoryBlock(array, PlatformDependent.LONG_ARRAY_OFFSET, array.length * 8); + return new MemoryBlock(array, Platform.LONG_ARRAY_OFFSET, array.length * 8); } } diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java index 62f4459696c2..cda7826c8c99 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java @@ -17,7 +17,7 @@ package org.apache.spark.unsafe.memory; -import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.Platform; /** * A simple {@link MemoryAllocator} that uses {@code Unsafe} to allocate off-heap memory. 
@@ -29,7 +29,7 @@ public MemoryBlock allocate(long size) throws OutOfMemoryError { if (size % 8 != 0) { throw new IllegalArgumentException("Size " + size + " was not a multiple of 8"); } - long address = PlatformDependent.UNSAFE.allocateMemory(size); + long address = Platform.allocateMemory(size); return new MemoryBlock(null, address, size); } @@ -37,6 +37,6 @@ public MemoryBlock allocate(long size) throws OutOfMemoryError { public void free(MemoryBlock memory) { assert (memory.obj == null) : "baseObject not null; are you trying to use the off-heap allocator to free on-heap memory?"; - PlatformDependent.UNSAFE.freeMemory(memory.offset); + Platform.freeMemory(memory.offset); } } diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java index 69b0e206cef1..c08c9c73d239 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java @@ -17,7 +17,7 @@ package org.apache.spark.unsafe.types; -import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.Platform; public class ByteArray { @@ -27,12 +27,6 @@ public class ByteArray { * hold all the bytes in this string. */ public static void writeToMemory(byte[] src, Object target, long targetOffset) { - PlatformDependent.copyMemory( - src, - PlatformDependent.BYTE_ARRAY_OFFSET, - target, - targetOffset, - src.length - ); + Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET, target, targetOffset, src.length); } } diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index d1014426c0f4..667c00900f2c 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -24,10 +24,10 @@ import java.util.Arrays; import java.util.Map; -import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.ByteArrayMethods; -import static org.apache.spark.unsafe.PlatformDependent.*; +import static org.apache.spark.unsafe.Platform.*; /** @@ -133,13 +133,7 @@ protected UTF8String(Object base, long offset, int numBytes) { * bytes in this string. 
*/ public void writeToMemory(Object target, long targetOffset) { - PlatformDependent.copyMemory( - base, - offset, - target, - targetOffset, - numBytes - ); + Platform.copyMemory(base, offset, target, targetOffset, numBytes); } /** @@ -183,12 +177,12 @@ public long getPrefix() { long mask = 0; if (isLittleEndian) { if (numBytes >= 8) { - p = PlatformDependent.UNSAFE.getLong(base, offset); + p = Platform.getLong(base, offset); } else if (numBytes > 4) { - p = PlatformDependent.UNSAFE.getLong(base, offset); + p = Platform.getLong(base, offset); mask = (1L << (8 - numBytes) * 8) - 1; } else if (numBytes > 0) { - p = (long) PlatformDependent.UNSAFE.getInt(base, offset); + p = (long) Platform.getInt(base, offset); mask = (1L << (8 - numBytes) * 8) - 1; } else { p = 0; @@ -197,12 +191,12 @@ public long getPrefix() { } else { // byteOrder == ByteOrder.BIG_ENDIAN if (numBytes >= 8) { - p = PlatformDependent.UNSAFE.getLong(base, offset); + p = Platform.getLong(base, offset); } else if (numBytes > 4) { - p = PlatformDependent.UNSAFE.getLong(base, offset); + p = Platform.getLong(base, offset); mask = (1L << (8 - numBytes) * 8) - 1; } else if (numBytes > 0) { - p = ((long) PlatformDependent.UNSAFE.getInt(base, offset)) << 32; + p = ((long) Platform.getInt(base, offset)) << 32; mask = (1L << (8 - numBytes) * 8) - 1; } else { p = 0; @@ -293,7 +287,7 @@ public boolean contains(final UTF8String substring) { * Returns the byte at position `i`. */ private byte getByte(int i) { - return UNSAFE.getByte(base, offset + i); + return Platform.getByte(base, offset + i); } private boolean matchAt(final UTF8String s, int pos) { @@ -769,7 +763,7 @@ public static UTF8String concatWs(UTF8String separator, UTF8String... inputs) { int len = inputs[i].numBytes; copyMemory( inputs[i].base, inputs[i].offset, - result, PlatformDependent.BYTE_ARRAY_OFFSET + offset, + result, BYTE_ARRAY_OFFSET + offset, len); offset += len; @@ -778,7 +772,7 @@ public static UTF8String concatWs(UTF8String separator, UTF8String... inputs) { if (j < numInputs) { copyMemory( separator.base, separator.offset, - result, PlatformDependent.BYTE_ARRAY_OFFSET + offset, + result, BYTE_ARRAY_OFFSET + offset, separator.numBytes); offset += separator.numBytes; } diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java b/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java index 3b9175835229..2f8cb132ac8b 100644 --- a/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java +++ b/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java @@ -22,7 +22,7 @@ import java.util.Set; import junit.framework.Assert; -import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.Platform; import org.junit.Test; /** @@ -83,11 +83,11 @@ public void randomizedStressTestBytes() { rand.nextBytes(bytes); Assert.assertEquals( - hasher.hashUnsafeWords(bytes, PlatformDependent.BYTE_ARRAY_OFFSET, byteArrSize), - hasher.hashUnsafeWords(bytes, PlatformDependent.BYTE_ARRAY_OFFSET, byteArrSize)); + hasher.hashUnsafeWords(bytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize), + hasher.hashUnsafeWords(bytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize)); hashcodes.add(hasher.hashUnsafeWords( - bytes, PlatformDependent.BYTE_ARRAY_OFFSET, byteArrSize)); + bytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize)); } // A very loose bound. 
@@ -106,11 +106,11 @@ public void randomizedStressTestPaddedStrings() { System.arraycopy(strBytes, 0, paddedBytes, 0, strBytes.length); Assert.assertEquals( - hasher.hashUnsafeWords(paddedBytes, PlatformDependent.BYTE_ARRAY_OFFSET, byteArrSize), - hasher.hashUnsafeWords(paddedBytes, PlatformDependent.BYTE_ARRAY_OFFSET, byteArrSize)); + hasher.hashUnsafeWords(paddedBytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize), + hasher.hashUnsafeWords(paddedBytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize)); hashcodes.add(hasher.hashUnsafeWords( - paddedBytes, PlatformDependent.BYTE_ARRAY_OFFSET, byteArrSize)); + paddedBytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize)); } // A very loose bound. From dfe347d2cae3eb05d7539aaf72db3d309e711213 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 11 Aug 2015 08:52:15 -0700 Subject: [PATCH 0161/1824] [SPARK-9785] [SQL] HashPartitioning compatibility should consider expression ordering HashPartitioning compatibility is currently defined w.r.t the _set_ of expressions, but the ordering of those expressions matters when computing hash codes; this could lead to incorrect answers if we mistakenly avoided a shuffle based on the assumption that HashPartitionings with the same expressions in different orders will produce equivalent row hashcodes. The first commit adds a regression test which illustrates this problem. The fix for this is simple: make `HashPartitioning.compatibleWith` and `HashPartitioning.guarantees` sensitive to the expression ordering (i.e. do not perform set comparison). Author: Josh Rosen Closes #8074 from JoshRosen/hashpartitioning-compatiblewith-fixes and squashes the following commits: b61412f [Josh Rosen] Demonstrate that I haven't cheated in my fix 0b4d7d9 [Josh Rosen] Update so that clusteringSet is only used in satisfies(). 
dc9c9d7 [Josh Rosen] Add failing regression test for SPARK-9785 --- .../plans/physical/partitioning.scala | 15 ++--- .../sql/catalyst/PartitioningSuite.scala | 55 +++++++++++++++++++ 2 files changed, 60 insertions(+), 10 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/PartitioningSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index 5a89a90b735a..5ac3f1f5b0ca 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -216,26 +216,23 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) override def nullable: Boolean = false override def dataType: DataType = IntegerType - lazy val clusteringSet = expressions.toSet - override def satisfies(required: Distribution): Boolean = required match { case UnspecifiedDistribution => true case ClusteredDistribution(requiredClustering) => - clusteringSet.subsetOf(requiredClustering.toSet) + expressions.toSet.subsetOf(requiredClustering.toSet) case _ => false } override def compatibleWith(other: Partitioning): Boolean = other match { - case o: HashPartitioning => - this.clusteringSet == o.clusteringSet && this.numPartitions == o.numPartitions + case o: HashPartitioning => this == o case _ => false } override def guarantees(other: Partitioning): Boolean = other match { - case o: HashPartitioning => - this.clusteringSet == o.clusteringSet && this.numPartitions == o.numPartitions + case o: HashPartitioning => this == o case _ => false } + } /** @@ -257,15 +254,13 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int) override def nullable: Boolean = false override def dataType: DataType = IntegerType - private[this] lazy val clusteringSet = ordering.map(_.child).toSet - override def satisfies(required: Distribution): Boolean = required match { case UnspecifiedDistribution => true case OrderedDistribution(requiredOrdering) => val minSize = Seq(requiredOrdering.size, ordering.size).min requiredOrdering.take(minSize) == ordering.take(minSize) case ClusteredDistribution(requiredClustering) => - clusteringSet.subsetOf(requiredClustering.toSet) + ordering.map(_.child).toSet.subsetOf(requiredClustering.toSet) case _ => false } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/PartitioningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/PartitioningSuite.scala new file mode 100644 index 000000000000..5b802ccc637d --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/PartitioningSuite.scala @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.expressions.{InterpretedMutableProjection, Literal} +import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, HashPartitioning} + +class PartitioningSuite extends SparkFunSuite { + test("HashPartitioning compatibility should be sensitive to expression ordering (SPARK-9785)") { + val expressions = Seq(Literal(2), Literal(3)) + // Consider two HashPartitionings that have the same _set_ of hash expressions but which are + // created with different orderings of those expressions: + val partitioningA = HashPartitioning(expressions, 100) + val partitioningB = HashPartitioning(expressions.reverse, 100) + // These partitionings are not considered equal: + assert(partitioningA != partitioningB) + // However, they both satisfy the same clustered distribution: + val distribution = ClusteredDistribution(expressions) + assert(partitioningA.satisfies(distribution)) + assert(partitioningB.satisfies(distribution)) + // These partitionings compute different hashcodes for the same input row: + def computeHashCode(partitioning: HashPartitioning): Int = { + val hashExprProj = new InterpretedMutableProjection(partitioning.expressions, Seq.empty) + hashExprProj.apply(InternalRow.empty).hashCode() + } + assert(computeHashCode(partitioningA) != computeHashCode(partitioningB)) + // Thus, these partitionings are incompatible: + assert(!partitioningA.compatibleWith(partitioningB)) + assert(!partitioningB.compatibleWith(partitioningA)) + assert(!partitioningA.guarantees(partitioningB)) + assert(!partitioningB.guarantees(partitioningA)) + + // Just to be sure that we haven't cheated by having these methods always return false, + // check that identical partitionings are still compatible with and guarantee each other: + assert(partitioningA === partitioningA) + assert(partitioningA.guarantees(partitioningA)) + assert(partitioningA.compatibleWith(partitioningA)) + } +} From bce72797f3499f14455722600b0d0898d4fd87c9 Mon Sep 17 00:00:00 2001 From: Jeff Zhang Date: Tue, 11 Aug 2015 10:42:17 -0700 Subject: [PATCH 0162/1824] Fix comment error API is updated but its doc comment is not updated. Author: Jeff Zhang Closes #8097 from zjffdu/dev. --- core/src/main/scala/org/apache/spark/SparkContext.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 9ced44131b0d..6aafb4c5644d 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -866,7 +866,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * }}} * * Do - * `val rdd = sparkContext.dataStreamFiles("hdfs://a-hdfs-path")`, + * `val rdd = sparkContext.binaryFiles("hdfs://a-hdfs-path")`, * * then `rdd` contains * {{{ From 8cad854ef6a2066de5adffcca6b79a205ccfd5f3 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Tue, 11 Aug 2015 11:01:59 -0700 Subject: [PATCH 0163/1824] [SPARK-8345] [ML] Add an SQL node as a feature transformer Implements the transforms which are defined by SQL statement. Currently we only support SQL syntax like 'SELECT ... FROM __THIS__' where '__THIS__' represents the underlying table of the input dataset. 
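As a usage sketch (mirroring the test suite added later in this patch; the `sqlContext` handle and the column names id, v1, v2 are illustrative), the transformer is configured with a statement that references '__THIS__' and is then applied like any other ML Transformer:

    import org.apache.spark.ml.feature.SQLTransformer

    // Assumes an existing sqlContext (e.g. in spark-shell); the data mirrors the new test suite.
    val original = sqlContext.createDataFrame(
      Seq((0, 1.0, 3.0), (2, 2.0, 5.0))).toDF("id", "v1", "v2")
    val sqlTrans = new SQLTransformer().setStatement(
      "SELECT *, (v1 + v2) AS v3, (v1 * v2) AS v4 FROM __THIS__")
    // '__THIS__' is substituted with a temp table backed by the input DataFrame,
    // so the result carries id, v1, v2 plus the derived v3 and v4 columns.
    val result = sqlTrans.transform(original)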
Author: Yanbo Liang Closes #7465 from yanboliang/spark-8345 and squashes the following commits: b403fcb [Yanbo Liang] address comments 0d4bb15 [Yanbo Liang] a better transformSchema() implementation 51eb9e7 [Yanbo Liang] Add an SQL node as a feature transformer --- .../spark/ml/feature/SQLTransformer.scala | 72 +++++++++++++++++++ .../ml/feature/SQLTransformerSuite.scala | 44 ++++++++++++ 2 files changed, 116 insertions(+) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala new file mode 100644 index 000000000000..95e430563873 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.apache.spark.SparkContext +import org.apache.spark.annotation.Experimental +import org.apache.spark.ml.param.{ParamMap, Param} +import org.apache.spark.ml.Transformer +import org.apache.spark.ml.util.Identifiable +import org.apache.spark.sql.{SQLContext, DataFrame, Row} +import org.apache.spark.sql.types.StructType + +/** + * :: Experimental :: + * Implements the transforms which are defined by SQL statement. + * Currently we only support SQL syntax like 'SELECT ... FROM __THIS__' + * where '__THIS__' represents the underlying table of the input dataset. + */ +@Experimental +class SQLTransformer (override val uid: String) extends Transformer { + + def this() = this(Identifiable.randomUID("sql")) + + /** + * SQL statement parameter. The statement is provided in string form. 
+ * @group param + */ + final val statement: Param[String] = new Param[String](this, "statement", "SQL statement") + + /** @group setParam */ + def setStatement(value: String): this.type = set(statement, value) + + /** @group getParam */ + def getStatement: String = $(statement) + + private val tableIdentifier: String = "__THIS__" + + override def transform(dataset: DataFrame): DataFrame = { + val tableName = Identifiable.randomUID(uid) + dataset.registerTempTable(tableName) + val realStatement = $(statement).replace(tableIdentifier, tableName) + val outputDF = dataset.sqlContext.sql(realStatement) + outputDF + } + + override def transformSchema(schema: StructType): StructType = { + val sc = SparkContext.getOrCreate() + val sqlContext = SQLContext.getOrCreate(sc) + val dummyRDD = sc.parallelize(Seq(Row.empty)) + val dummyDF = sqlContext.createDataFrame(dummyRDD, schema) + dummyDF.registerTempTable(tableIdentifier) + val outputSchema = sqlContext.sql($(statement)).schema + outputSchema + } + + override def copy(extra: ParamMap): SQLTransformer = defaultCopy(extra) +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala new file mode 100644 index 000000000000..d19052881ae4 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.feature + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.mllib.util.MLlibTestSparkContext + +class SQLTransformerSuite extends SparkFunSuite with MLlibTestSparkContext { + + test("params") { + ParamsSuite.checkParams(new SQLTransformer()) + } + + test("transform numeric data") { + val original = sqlContext.createDataFrame( + Seq((0, 1.0, 3.0), (2, 2.0, 5.0))).toDF("id", "v1", "v2") + val sqlTrans = new SQLTransformer().setStatement( + "SELECT *, (v1 + v2) AS v3, (v1 * v2) AS v4 FROM __THIS__") + val result = sqlTrans.transform(original) + val resultSchema = sqlTrans.transformSchema(original.schema) + val expected = sqlContext.createDataFrame( + Seq((0, 1.0, 3.0, 4.0, 3.0), (2, 2.0, 5.0, 7.0, 10.0))) + .toDF("id", "v1", "v2", "v3", "v4") + assert(result.schema.toString == resultSchema.toString) + assert(resultSchema == expected.schema) + assert(result.collect().toSeq == expected.collect().toSeq) + } +} From dbd778d84d094ca142bc08c351478595b280bc2a Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Tue, 11 Aug 2015 11:33:36 -0700 Subject: [PATCH 0164/1824] [SPARK-8764] [ML] string indexer should take option to handle unseen values As a precursor to adding a public constructor add an option to handle unseen values by skipping rather than throwing an exception (default remains throwing an exception), Author: Holden Karau Closes #7266 from holdenk/SPARK-8764-string-indexer-should-take-option-to-handle-unseen-values and squashes the following commits: 38a4de9 [Holden Karau] fix long line 045bf22 [Holden Karau] Add a second b entry so b gets 0 for sure 81dd312 [Holden Karau] Update the docs for handleInvalid param to be more descriptive 7f37f6e [Holden Karau] remove extra space (scala style) 414e249 [Holden Karau] And switch to using handleInvalid instead of skipInvalid 1e53f9b [Holden Karau] update the param (codegen side) 7a22215 [Holden Karau] fix typo 100a39b [Holden Karau] Merge in master aa5b093 [Holden Karau] Since we filter we should never go down this code path if getSkipInvalid is true 75ffa69 [Holden Karau] Remove extra newline d69ef5e [Holden Karau] Add a test b5734be [Holden Karau] Add support for unseen labels afecd4e [Holden Karau] Add a param to skip invalid entries. --- .../spark/ml/feature/StringIndexer.scala | 26 ++++++++++++--- .../ml/param/shared/SharedParamsCodeGen.scala | 4 +++ .../spark/ml/param/shared/sharedParams.scala | 15 +++++++++ .../spark/ml/feature/StringIndexerSuite.scala | 32 +++++++++++++++++++ 4 files changed, 73 insertions(+), 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index ebfa97253235..e4485eb03840 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -33,7 +33,8 @@ import org.apache.spark.util.collection.OpenHashMap /** * Base trait for [[StringIndexer]] and [[StringIndexerModel]]. */ -private[feature] trait StringIndexerBase extends Params with HasInputCol with HasOutputCol { +private[feature] trait StringIndexerBase extends Params with HasInputCol with HasOutputCol + with HasHandleInvalid { /** Validates and transforms the input schema. 
*/ protected def validateAndTransformSchema(schema: StructType): StructType = { @@ -65,13 +66,16 @@ class StringIndexer(override val uid: String) extends Estimator[StringIndexerMod def this() = this(Identifiable.randomUID("strIdx")) + /** @group setParam */ + def setHandleInvalid(value: String): this.type = set(handleInvalid, value) + setDefault(handleInvalid, "error") + /** @group setParam */ def setInputCol(value: String): this.type = set(inputCol, value) /** @group setParam */ def setOutputCol(value: String): this.type = set(outputCol, value) - // TODO: handle unseen labels override def fit(dataset: DataFrame): StringIndexerModel = { val counts = dataset.select(col($(inputCol)).cast(StringType)) @@ -111,6 +115,10 @@ class StringIndexerModel private[ml] ( map } + /** @group setParam */ + def setHandleInvalid(value: String): this.type = set(handleInvalid, value) + setDefault(handleInvalid, "error") + /** @group setParam */ def setInputCol(value: String): this.type = set(inputCol, value) @@ -128,14 +136,24 @@ class StringIndexerModel private[ml] ( if (labelToIndex.contains(label)) { labelToIndex(label) } else { - // TODO: handle unseen labels throw new SparkException(s"Unseen label: $label.") } } + val outputColName = $(outputCol) val metadata = NominalAttribute.defaultAttr .withName(outputColName).withValues(labels).toMetadata() - dataset.select(col("*"), + // If we are skipping invalid records, filter them out. + val filteredDataset = (getHandleInvalid) match { + case "skip" => { + val filterer = udf { label: String => + labelToIndex.contains(label) + } + dataset.where(filterer(dataset($(inputCol)))) + } + case _ => dataset + } + filteredDataset.select(col("*"), indexer(dataset($(inputCol)).cast(StringType)).as(outputColName, metadata)) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala index a97c8059b8d4..da4c07683039 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala @@ -59,6 +59,10 @@ private[shared] object SharedParamsCodeGen { ParamDesc[Int]("checkpointInterval", "checkpoint interval (>= 1)", isValid = "ParamValidators.gtEq(1)"), ParamDesc[Boolean]("fitIntercept", "whether to fit an intercept term", Some("true")), + ParamDesc[String]("handleInvalid", "how to handle invalid entries. Options are skip (which " + + "will filter out rows with bad values), or error (which will throw an errror). More " + + "options may be added later.", + isValid = "ParamValidators.inArray(Array(\"skip\", \"error\"))"), ParamDesc[Boolean]("standardization", "whether to standardize the training features" + " before fitting the model.", Some("true")), ParamDesc[Long]("seed", "random seed", Some("this.getClass.getName.hashCode.toLong")), diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala index f332630c32f1..23e2b6cc4399 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala @@ -247,6 +247,21 @@ private[ml] trait HasFitIntercept extends Params { final def getFitIntercept: Boolean = $(fitIntercept) } +/** + * Trait for shared param handleInvalid. 
+ */ +private[ml] trait HasHandleInvalid extends Params { + + /** + * Param for how to handle invalid entries. Options are skip (which will filter out rows with bad values), or error (which will throw an errror). More options may be added later.. + * @group param + */ + final val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "how to handle invalid entries. Options are skip (which will filter out rows with bad values), or error (which will throw an errror). More options may be added later.", ParamValidators.inArray(Array("skip", "error"))) + + /** @group getParam */ + final def getHandleInvalid: String = $(handleInvalid) +} + /** * Trait for shared param standardization (default: true). */ diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala index d0295a0fe2fc..b111036087e6 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.ml.feature +import org.apache.spark.SparkException import org.apache.spark.SparkFunSuite import org.apache.spark.ml.attribute.{Attribute, NominalAttribute} import org.apache.spark.ml.param.ParamsSuite @@ -62,6 +63,37 @@ class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext { reversed2.collect().map(r => (r.getInt(0), r.getString(1))).toSet) } + test("StringIndexerUnseen") { + val data = sc.parallelize(Seq((0, "a"), (1, "b"), (4, "b")), 2) + val data2 = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c")), 2) + val df = sqlContext.createDataFrame(data).toDF("id", "label") + val df2 = sqlContext.createDataFrame(data2).toDF("id", "label") + val indexer = new StringIndexer() + .setInputCol("label") + .setOutputCol("labelIndex") + .fit(df) + // Verify we throw by default with unseen values + intercept[SparkException] { + indexer.transform(df2).collect() + } + val indexerSkipInvalid = new StringIndexer() + .setInputCol("label") + .setOutputCol("labelIndex") + .setHandleInvalid("skip") + .fit(df) + // Verify that we skip the c record + val transformed = indexerSkipInvalid.transform(df2) + val attr = Attribute.fromStructField(transformed.schema("labelIndex")) + .asInstanceOf[NominalAttribute] + assert(attr.values.get === Array("b", "a")) + val output = transformed.select("id", "labelIndex").map { r => + (r.getInt(0), r.getDouble(1)) + }.collect().toSet + // a -> 1, b -> 0 + val expected = Set((0, 1.0), (1, 0.0)) + assert(output === expected) + } + test("StringIndexer with a numeric input column") { val data = sc.parallelize(Seq((0, 100), (1, 200), (2, 300), (3, 100), (4, 100), (5, 300)), 2) val df = sqlContext.createDataFrame(data).toDF("id", "label") From 5b8bb1b213b8738f563fcd00747604410fbb3087 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 11 Aug 2015 12:02:28 -0700 Subject: [PATCH 0165/1824] [SPARK-9572] [STREAMING] [PYSPARK] Added StreamingContext.getActiveOrCreate() in Python Author: Tathagata Das Closes #8080 from tdas/SPARK-9572 and squashes the following commits: 64a231d [Tathagata Das] Fix based on comments 741a0d0 [Tathagata Das] Fixed style f4f094c [Tathagata Das] Tweaked test 9afcdbe [Tathagata Das] Merge remote-tracking branch 'apache-github/master' into SPARK-9572 e21488d [Tathagata Das] Minor update 1a371d9 [Tathagata Das] Addressed comments. 
60479da [Tathagata Das] Fixed indent 9c2da9c [Tathagata Das] Fixed bugs b5bd32c [Tathagata Das] Merge remote-tracking branch 'apache-github/master' into SPARK-9572 b55b348 [Tathagata Das] Removed prints 5781728 [Tathagata Das] Fix style issues b711214 [Tathagata Das] Reverted run-tests.py 643b59d [Tathagata Das] Revert unnecessary change 150e58c [Tathagata Das] Added StreamingContext.getActiveOrCreate() in Python --- python/pyspark/streaming/context.py | 57 +++++++++++- python/pyspark/streaming/tests.py | 133 +++++++++++++++++++++++++--- python/run-tests.py | 2 +- 3 files changed, 177 insertions(+), 15 deletions(-) diff --git a/python/pyspark/streaming/context.py b/python/pyspark/streaming/context.py index ac5ba69e8dbb..e3ba70e4e5e8 100644 --- a/python/pyspark/streaming/context.py +++ b/python/pyspark/streaming/context.py @@ -86,6 +86,9 @@ class StreamingContext(object): """ _transformerSerializer = None + # Reference to a currently active StreamingContext + _activeContext = None + def __init__(self, sparkContext, batchDuration=None, jssc=None): """ Create a new StreamingContext. @@ -142,10 +145,10 @@ def getOrCreate(cls, checkpointPath, setupFunc): Either recreate a StreamingContext from checkpoint data or create a new StreamingContext. If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be recreated from the checkpoint data. If the data does not exist, then the provided setupFunc - will be used to create a JavaStreamingContext. + will be used to create a new context. - @param checkpointPath: Checkpoint directory used in an earlier JavaStreamingContext program - @param setupFunc: Function to create a new JavaStreamingContext and setup DStreams + @param checkpointPath: Checkpoint directory used in an earlier streaming program + @param setupFunc: Function to create a new context and setup DStreams """ # TODO: support checkpoint in HDFS if not os.path.exists(checkpointPath) or not os.listdir(checkpointPath): @@ -170,6 +173,52 @@ def getOrCreate(cls, checkpointPath, setupFunc): cls._transformerSerializer.ctx = sc return StreamingContext(sc, None, jssc) + @classmethod + def getActive(cls): + """ + Return either the currently active StreamingContext (i.e., if there is a context started + but not stopped) or None. + """ + activePythonContext = cls._activeContext + if activePythonContext is not None: + # Verify that the current running Java StreamingContext is active and is the same one + # backing the supposedly active Python context + activePythonContextJavaId = activePythonContext._jssc.ssc().hashCode() + activeJvmContextOption = activePythonContext._jvm.StreamingContext.getActive() + + if activeJvmContextOption.isEmpty(): + cls._activeContext = None + elif activeJvmContextOption.get().hashCode() != activePythonContextJavaId: + cls._activeContext = None + raise Exception("JVM's active JavaStreamingContext is not the JavaStreamingContext " + "backing the action Python StreamingContext. This is unexpected.") + return cls._activeContext + + @classmethod + def getActiveOrCreate(cls, checkpointPath, setupFunc): + """ + Either return the active StreamingContext (i.e. currently started but not stopped), + or recreate a StreamingContext from checkpoint data or create a new StreamingContext + using the provided setupFunc function. If the checkpointPath is None or does not contain + valid checkpoint data, then setupFunc will be called to create a new context and setup + DStreams. + + @param checkpointPath: Checkpoint directory used in an earlier streaming program. 
Can be + None if the intention is to always create a new context when there + is no active context. + @param setupFunc: Function to create a new JavaStreamingContext and setup DStreams + """ + + if setupFunc is None: + raise Exception("setupFunc cannot be None") + activeContext = cls.getActive() + if activeContext is not None: + return activeContext + elif checkpointPath is not None: + return cls.getOrCreate(checkpointPath, setupFunc) + else: + return setupFunc() + @property def sparkContext(self): """ @@ -182,6 +231,7 @@ def start(self): Start the execution of the streams. """ self._jssc.start() + StreamingContext._activeContext = self def awaitTermination(self, timeout=None): """ @@ -212,6 +262,7 @@ def stop(self, stopSparkContext=True, stopGraceFully=False): of all received data to be completed """ self._jssc.stop(stopSparkContext, stopGraceFully) + StreamingContext._activeContext = None if stopSparkContext: self._sc.stop() diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index f0ed415f9712..6108c845c1ef 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -24,6 +24,7 @@ import tempfile import random import struct +import shutil from functools import reduce if sys.version_info[:2] <= (2, 6): @@ -59,12 +60,21 @@ def setUpClass(cls): @classmethod def tearDownClass(cls): cls.sc.stop() + # Clean up in the JVM just in case there has been some issues in Python API + jSparkContextOption = SparkContext._jvm.SparkContext.get() + if jSparkContextOption.nonEmpty(): + jSparkContextOption.get().stop() def setUp(self): self.ssc = StreamingContext(self.sc, self.duration) def tearDown(self): - self.ssc.stop(False) + if self.ssc is not None: + self.ssc.stop(False) + # Clean up in the JVM just in case there has been some issues in Python API + jStreamingContextOption = StreamingContext._jvm.SparkContext.getActive() + if jStreamingContextOption.nonEmpty(): + jStreamingContextOption.get().stop(False) def wait_for(self, result, n): start_time = time.time() @@ -442,6 +452,7 @@ def test_reduce_by_invalid_window(self): class StreamingContextTests(PySparkStreamingTestCase): duration = 0.1 + setupCalled = False def _add_input_stream(self): inputs = [range(1, x) for x in range(101)] @@ -515,10 +526,85 @@ def func(rdds): self.assertEqual([2, 3, 1], self._take(dstream, 3)) + def test_get_active(self): + self.assertEqual(StreamingContext.getActive(), None) + + # Verify that getActive() returns the active context + self.ssc.queueStream([[1]]).foreachRDD(lambda rdd: rdd.count()) + self.ssc.start() + self.assertEqual(StreamingContext.getActive(), self.ssc) + + # Verify that getActive() returns None + self.ssc.stop(False) + self.assertEqual(StreamingContext.getActive(), None) + + # Verify that if the Java context is stopped, then getActive() returns None + self.ssc = StreamingContext(self.sc, self.duration) + self.ssc.queueStream([[1]]).foreachRDD(lambda rdd: rdd.count()) + self.ssc.start() + self.assertEqual(StreamingContext.getActive(), self.ssc) + self.ssc._jssc.stop(False) + self.assertEqual(StreamingContext.getActive(), None) + + def test_get_active_or_create(self): + # Test StreamingContext.getActiveOrCreate() without checkpoint data + # See CheckpointTests for tests with checkpoint data + self.ssc = None + self.assertEqual(StreamingContext.getActive(), None) + + def setupFunc(): + ssc = StreamingContext(self.sc, self.duration) + ssc.queueStream([[1]]).foreachRDD(lambda rdd: rdd.count()) + self.setupCalled = True + return ssc + + # Verify that 
getActiveOrCreate() (w/o checkpoint) calls setupFunc when no context is active + self.setupCalled = False + self.ssc = StreamingContext.getActiveOrCreate(None, setupFunc) + self.assertTrue(self.setupCalled) + + # Verify that getActiveOrCreate() retuns active context and does not call the setupFunc + self.ssc.start() + self.setupCalled = False + self.assertEqual(StreamingContext.getActiveOrCreate(None, setupFunc), self.ssc) + self.assertFalse(self.setupCalled) + + # Verify that getActiveOrCreate() calls setupFunc after active context is stopped + self.ssc.stop(False) + self.setupCalled = False + self.ssc = StreamingContext.getActiveOrCreate(None, setupFunc) + self.assertTrue(self.setupCalled) + + # Verify that if the Java context is stopped, then getActive() returns None + self.ssc = StreamingContext(self.sc, self.duration) + self.ssc.queueStream([[1]]).foreachRDD(lambda rdd: rdd.count()) + self.ssc.start() + self.assertEqual(StreamingContext.getActive(), self.ssc) + self.ssc._jssc.stop(False) + self.setupCalled = False + self.ssc = StreamingContext.getActiveOrCreate(None, setupFunc) + self.assertTrue(self.setupCalled) + class CheckpointTests(unittest.TestCase): - def test_get_or_create(self): + setupCalled = False + + @staticmethod + def tearDownClass(): + # Clean up in the JVM just in case there has been some issues in Python API + jStreamingContextOption = StreamingContext._jvm.SparkContext.getActive() + if jStreamingContextOption.nonEmpty(): + jStreamingContextOption.get().stop() + jSparkContextOption = SparkContext._jvm.SparkContext.get() + if jSparkContextOption.nonEmpty(): + jSparkContextOption.get().stop() + + def tearDown(self): + if self.ssc is not None: + self.ssc.stop(True) + + def test_get_or_create_and_get_active_or_create(self): inputd = tempfile.mkdtemp() outputd = tempfile.mkdtemp() + "/" @@ -533,11 +619,12 @@ def setup(): wc = dstream.updateStateByKey(updater) wc.map(lambda x: "%s,%d" % x).saveAsTextFiles(outputd + "test") wc.checkpoint(.5) + self.setupCalled = True return ssc cpd = tempfile.mkdtemp("test_streaming_cps") - ssc = StreamingContext.getOrCreate(cpd, setup) - ssc.start() + self.ssc = StreamingContext.getOrCreate(cpd, setup) + self.ssc.start() def check_output(n): while not os.listdir(outputd): @@ -552,7 +639,7 @@ def check_output(n): # not finished time.sleep(0.01) continue - ordd = ssc.sparkContext.textFile(p).map(lambda line: line.split(",")) + ordd = self.ssc.sparkContext.textFile(p).map(lambda line: line.split(",")) d = ordd.values().map(int).collect() if not d: time.sleep(0.01) @@ -568,13 +655,37 @@ def check_output(n): check_output(1) check_output(2) - ssc.stop(True, True) + # Verify the getOrCreate() recovers from checkpoint files + self.ssc.stop(True, True) time.sleep(1) - ssc = StreamingContext.getOrCreate(cpd, setup) - ssc.start() + self.setupCalled = False + self.ssc = StreamingContext.getOrCreate(cpd, setup) + self.assertFalse(self.setupCalled) + self.ssc.start() check_output(3) - ssc.stop(True, True) + + # Verify the getActiveOrCreate() recovers from checkpoint files + self.ssc.stop(True, True) + time.sleep(1) + self.setupCalled = False + self.ssc = StreamingContext.getActiveOrCreate(cpd, setup) + self.assertFalse(self.setupCalled) + self.ssc.start() + check_output(4) + + # Verify that getActiveOrCreate() returns active context + self.setupCalled = False + self.assertEquals(StreamingContext.getActiveOrCreate(cpd, setup), self.ssc) + self.assertFalse(self.setupCalled) + + # Verify that getActiveOrCreate() calls setup() in absence of checkpoint files 
+ self.ssc.stop(True, True) + shutil.rmtree(cpd) # delete checkpoint directory + self.setupCalled = False + self.ssc = StreamingContext.getActiveOrCreate(cpd, setup) + self.assertTrue(self.setupCalled) + self.ssc.stop(True, True) class KafkaStreamTests(PySparkStreamingTestCase): @@ -1134,7 +1245,7 @@ def search_kinesis_asl_assembly_jar(): testcases.append(KinesisStreamTests) elif are_kinesis_tests_enabled is False: sys.stderr.write("Skipping all Kinesis Python tests as the optional Kinesis project was " - "not compiled with -Pkinesis-asl profile. To run these tests, " + "not compiled into a JAR. To run these tests, " "you need to build Spark with 'build/sbt -Pkinesis-asl assembly/assembly " "streaming-kinesis-asl-assembly/assembly' or " "'build/mvn -Pkinesis-asl package' before running this test.") @@ -1150,4 +1261,4 @@ for testcase in testcases: sys.stderr.write("[Running %s]\n" % (testcase)) tests = unittest.TestLoader().loadTestsFromTestCase(testcase) - unittest.TextTestRunner(verbosity=2).run(tests) + unittest.TextTestRunner(verbosity=3).run(tests) diff --git a/python/run-tests.py b/python/run-tests.py index cc560779373b..fd56c7ab6e0e 100755 --- a/python/run-tests.py +++ b/python/run-tests.py @@ -158,7 +158,7 @@ def main(): else: log_level = logging.INFO logging.basicConfig(stream=sys.stdout, level=log_level, format="%(message)s") - LOGGER.info("Running PySpark tests. Output is in python/%s", LOG_FILE) + LOGGER.info("Running PySpark tests. Output is in %s", LOG_FILE) if os.path.exists(LOG_FILE): os.remove(LOG_FILE) python_execs = opts.python_executables.split(',') From 5831294a7a8fa2524133c5d718cbc8187d2b0620 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Tue, 11 Aug 2015 12:39:13 -0700 Subject: [PATCH 0166/1824] [SPARK-9646] [SQL] Add metrics for all join and aggregate operators This PR added metrics for all join and aggregate operators. However, I found the metrics may be confusing in the following two cases: 1. The iterator is not totally consumed, so the metric values will be lower than the actual row counts. 2. Recreating the iterators will make the metric values look bigger than the size of the input source, as in `CartesianProduct`.
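To make those caveats concrete, the following is a minimal sketch (not part of this patch) of the pattern the diff applies to each operator: declare LongSQLMetrics in the operator's metrics map via SQLMetrics.createLongMetric, look them up with longMetric(), and bump them as rows flow through the iterators. The ExampleFilter operator below is a hypothetical stand-in used only for illustration; the real operators changed by this patch are the joins and aggregates listed in the diffstat.

    // Hypothetical operator, for illustration only; it mirrors the pattern used in this diff.
    import org.apache.spark.rdd.RDD
    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
    import org.apache.spark.sql.execution.{SparkPlan, UnaryNode}
    import org.apache.spark.sql.execution.metric.SQLMetrics

    case class ExampleFilter(condition: Expression, child: SparkPlan) extends UnaryNode {
      override def output: Seq[Attribute] = child.output

      // Register the metrics so they show up for this operator in the SQL UI.
      override private[sql] lazy val metrics = Map(
        "numInputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of input rows"),
        "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))

      protected override def doExecute(): RDD[InternalRow] = {
        val numInputRows = longMetric("numInputRows")
        val numOutputRows = longMetric("numOutputRows")
        child.execute().mapPartitions { iter =>
          val predicate = newPredicate(condition, child.output)
          iter.filter { row =>
            // Counted only when a downstream operator actually pulls the row, so a
            // partially consumed iterator under-reports, and an operator that
            // re-creates its iterators counts the same input more than once.
            numInputRows += 1
            val keep = predicate(row)
            if (keep) numOutputRows += 1
            keep
          }
        }
      }
    }

Pulling the metrics into local vals before mapPartitions keeps the closure from capturing the whole operator, which is the same convention the diff follows in Aggregate and the join operators.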
Author: zsxwing Closes #8060 from zsxwing/sql-metrics and squashes the following commits: 40f3fc1 [zsxwing] Mark LongSQLMetric private[metric] to avoid using incorrectly and leak memory b1b9071 [zsxwing] Merge branch 'master' into sql-metrics 4bef25a [zsxwing] Add metrics for SortMergeOuterJoin 95ccfc6 [zsxwing] Merge branch 'master' into sql-metrics 67cb4dd [zsxwing] Add metrics for Project and TungstenProject; remove metrics from PhysicalRDD and LocalTableScan 0eb47d4 [zsxwing] Merge branch 'master' into sql-metrics dd9d932 [zsxwing] Avoid creating new Iterators 589ea26 [zsxwing] Add metrics for all join and aggregate operators --- .../spark/sql/execution/Aggregate.scala | 11 + .../spark/sql/execution/ExistingRDD.scala | 2 - .../spark/sql/execution/LocalTableScan.scala | 2 - .../spark/sql/execution/SparkPlan.scala | 25 +- .../aggregate/SortBasedAggregate.scala | 12 +- .../SortBasedAggregationIterator.scala | 18 +- .../aggregate/TungstenAggregate.scala | 12 +- .../TungstenAggregationIterator.scala | 11 +- .../spark/sql/execution/basicOperators.scala | 36 +- .../execution/joins/BroadcastHashJoin.scala | 30 +- .../joins/BroadcastHashOuterJoin.scala | 40 +- .../joins/BroadcastLeftSemiJoinHash.scala | 24 +- .../joins/BroadcastNestedLoopJoin.scala | 27 +- .../execution/joins/CartesianProduct.scala | 25 +- .../spark/sql/execution/joins/HashJoin.scala | 7 +- .../sql/execution/joins/HashOuterJoin.scala | 30 +- .../sql/execution/joins/HashSemiJoin.scala | 23 +- .../sql/execution/joins/HashedRelation.scala | 8 +- .../sql/execution/joins/LeftSemiJoinBNL.scala | 19 +- .../execution/joins/LeftSemiJoinHash.scala | 18 +- .../execution/joins/ShuffledHashJoin.scala | 16 +- .../joins/ShuffledHashOuterJoin.scala | 29 +- .../sql/execution/joins/SortMergeJoin.scala | 21 +- .../execution/joins/SortMergeOuterJoin.scala | 38 +- .../sql/execution/metric/SQLMetrics.scala | 6 + .../execution/joins/HashedRelationSuite.scala | 14 +- .../execution/metric/SQLMetricsSuite.scala | 450 +++++++++++++++++- 27 files changed, 847 insertions(+), 107 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala index e8c6a0f8f801..f3b6a3a5f4a3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.execution.metric.SQLMetrics /** * :: DeveloperApi :: @@ -45,6 +46,10 @@ case class Aggregate( child: SparkPlan) extends UnaryNode { + override private[sql] lazy val metrics = Map( + "numInputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of input rows"), + "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) + override def requiredChildDistribution: List[Distribution] = { if (partial) { UnspecifiedDistribution :: Nil @@ -121,12 +126,15 @@ case class Aggregate( } protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") { + val numInputRows = longMetric("numInputRows") + val numOutputRows = longMetric("numOutputRows") if (groupingExpressions.isEmpty) { child.execute().mapPartitions { iter => val buffer = newAggregateBuffer() var currentRow: InternalRow = null while (iter.hasNext) { currentRow = iter.next() + 
numInputRows += 1 var i = 0 while (i < buffer.length) { buffer(i).update(currentRow) @@ -142,6 +150,7 @@ case class Aggregate( i += 1 } + numOutputRows += 1 Iterator(resultProjection(aggregateResults)) } } else { @@ -152,6 +161,7 @@ case class Aggregate( var currentRow: InternalRow = null while (iter.hasNext) { currentRow = iter.next() + numInputRows += 1 val currentGroup = groupingProjection(currentRow) var currentBuffer = hashTable.get(currentGroup) if (currentBuffer == null) { @@ -180,6 +190,7 @@ case class Aggregate( val currentEntry = hashTableIter.next() val currentGroup = currentEntry.getKey val currentBuffer = currentEntry.getValue + numOutputRows += 1 var i = 0 while (i < currentBuffer.length) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index cae7ca5cbdc8..abb60cf12e3a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -99,8 +99,6 @@ private[sql] case class PhysicalRDD( rdd: RDD[InternalRow], extraInformation: String) extends LeafNode { - override protected[sql] val trackNumOfRowsEnabled = true - protected override def doExecute(): RDD[InternalRow] = rdd override def simpleString: String = "Scan " + extraInformation + output.mkString("[", ",", "]") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala index 858dd85fd1fa..34e926e4582b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala @@ -30,8 +30,6 @@ private[sql] case class LocalTableScan( output: Seq[Attribute], rows: Seq[InternalRow]) extends LeafNode { - override protected[sql] val trackNumOfRowsEnabled = true - private lazy val rdd = sqlContext.sparkContext.parallelize(rows) protected override def doExecute(): RDD[InternalRow] = rdd diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 9ba5cf2d2b39..72f5450510a1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -80,23 +80,10 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ super.makeCopy(newArgs) } - /** - * Whether track the number of rows output by this SparkPlan - */ - protected[sql] def trackNumOfRowsEnabled: Boolean = false - - private lazy val defaultMetrics: Map[String, SQLMetric[_, _]] = - if (trackNumOfRowsEnabled) { - Map("numRows" -> SQLMetrics.createLongMetric(sparkContext, "number of rows")) - } - else { - Map.empty - } - /** * Return all metrics containing metrics of this SparkPlan. */ - private[sql] def metrics: Map[String, SQLMetric[_, _]] = defaultMetrics + private[sql] def metrics: Map[String, SQLMetric[_, _]] = Map.empty /** * Return a LongSQLMetric according to the name. 
@@ -150,15 +137,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ } RDDOperationScope.withScope(sparkContext, nodeName, false, true) { prepare() - if (trackNumOfRowsEnabled) { - val numRows = longMetric("numRows") - doExecute().map { row => - numRows += 1 - row - } - } else { - doExecute() - } + doExecute() } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala index ad428ad663f3..ab26f9c58aa2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.physical.{UnspecifiedDistribution, ClusteredDistribution, AllTuples, Distribution} import org.apache.spark.sql.execution.{UnsafeFixedWidthAggregationMap, SparkPlan, UnaryNode} +import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.types.StructType case class SortBasedAggregate( @@ -38,6 +39,10 @@ case class SortBasedAggregate( child: SparkPlan) extends UnaryNode { + override private[sql] lazy val metrics = Map( + "numInputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of input rows"), + "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) + override def outputsUnsafeRows: Boolean = false override def canProcessUnsafeRows: Boolean = false @@ -63,6 +68,8 @@ case class SortBasedAggregate( } protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") { + val numInputRows = longMetric("numInputRows") + val numOutputRows = longMetric("numOutputRows") child.execute().mapPartitions { iter => // Because the constructor of an aggregation iterator will read at least the first row, // we need to get the value of iter.hasNext first. @@ -84,10 +91,13 @@ case class SortBasedAggregate( newProjection _, child.output, iter, - outputsUnsafeRows) + outputsUnsafeRows, + numInputRows, + numOutputRows) if (!hasInput && groupingExpressions.isEmpty) { // There is no input and there is no grouping expressions. // We need to output a single row as the output. 
+ numOutputRows += 1 Iterator[InternalRow](outputIter.outputForEmptyGroupingKeyWithoutInput()) } else { outputIter diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala index 67ebafde25ad..73d50e07cf0b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression2, AggregateFunction2} +import org.apache.spark.sql.execution.metric.LongSQLMetric import org.apache.spark.unsafe.KVIterator /** @@ -37,7 +38,9 @@ class SortBasedAggregationIterator( initialInputBufferOffset: Int, resultExpressions: Seq[NamedExpression], newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection), - outputsUnsafeRows: Boolean) + outputsUnsafeRows: Boolean, + numInputRows: LongSQLMetric, + numOutputRows: LongSQLMetric) extends AggregationIterator( groupingKeyAttributes, valueAttributes, @@ -103,6 +106,7 @@ class SortBasedAggregationIterator( // Get the grouping key. val groupingKey = inputKVIterator.getKey val currentRow = inputKVIterator.getValue + numInputRows += 1 // Check if the current row belongs the current input row. if (currentGroupingKey == groupingKey) { @@ -137,7 +141,7 @@ class SortBasedAggregationIterator( val outputRow = generateOutput(currentGroupingKey, sortBasedAggregationBuffer) // Initialize buffer values for the next group. initializeBuffer(sortBasedAggregationBuffer) - + numOutputRows += 1 outputRow } else { // no more result @@ -151,7 +155,7 @@ class SortBasedAggregationIterator( nextGroupingKey = inputKVIterator.getKey().copy() firstRowInNextGroup = inputKVIterator.getValue().copy() - + numInputRows += 1 sortedInputHasNewGroup = true } else { // This inputIter is empty. 
@@ -181,7 +185,9 @@ object SortBasedAggregationIterator { newProjection: (Seq[Expression], Seq[Attribute]) => Projection, inputAttributes: Seq[Attribute], inputIter: Iterator[InternalRow], - outputsUnsafeRows: Boolean): SortBasedAggregationIterator = { + outputsUnsafeRows: Boolean, + numInputRows: LongSQLMetric, + numOutputRows: LongSQLMetric): SortBasedAggregationIterator = { val kvIterator = if (UnsafeProjection.canSupport(groupingExprs)) { AggregationIterator.unsafeKVIterator( groupingExprs, @@ -202,7 +208,9 @@ object SortBasedAggregationIterator { initialInputBufferOffset, resultExpressions, newMutableProjection, - outputsUnsafeRows) + outputsUnsafeRows, + numInputRows, + numOutputRows) } // scalastyle:on } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala index 1694794a53d9..6b5935a7ce29 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression2 import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.{UnspecifiedDistribution, ClusteredDistribution, AllTuples, Distribution} import org.apache.spark.sql.execution.{UnaryNode, SparkPlan} +import org.apache.spark.sql.execution.metric.SQLMetrics case class TungstenAggregate( requiredChildDistributionExpressions: Option[Seq[Expression]], @@ -35,6 +36,10 @@ case class TungstenAggregate( child: SparkPlan) extends UnaryNode { + override private[sql] lazy val metrics = Map( + "numInputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of input rows"), + "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) + override def outputsUnsafeRows: Boolean = true override def canProcessUnsafeRows: Boolean = true @@ -61,6 +66,8 @@ case class TungstenAggregate( } protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") { + val numInputRows = longMetric("numInputRows") + val numOutputRows = longMetric("numOutputRows") child.execute().mapPartitions { iter => val hasInput = iter.hasNext if (!hasInput && groupingExpressions.nonEmpty) { @@ -78,9 +85,12 @@ case class TungstenAggregate( newMutableProjection, child.output, iter, - testFallbackStartsAt) + testFallbackStartsAt, + numInputRows, + numOutputRows) if (!hasInput && groupingExpressions.isEmpty) { + numOutputRows += 1 Iterator.single[UnsafeRow](aggregationIterator.outputForEmptyGroupingKeyWithoutInput()) } else { aggregationIterator diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala index 32160906c3bc..1f383dd04482 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.{UnsafeKVExternalSorter, UnsafeFixedWidthAggregationMap} +import 
org.apache.spark.sql.execution.metric.LongSQLMetric import org.apache.spark.sql.types.StructType /** @@ -83,7 +84,9 @@ class TungstenAggregationIterator( newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection), originalInputAttributes: Seq[Attribute], inputIter: Iterator[InternalRow], - testFallbackStartsAt: Option[Int]) + testFallbackStartsAt: Option[Int], + numInputRows: LongSQLMetric, + numOutputRows: LongSQLMetric) extends Iterator[UnsafeRow] with Logging { /////////////////////////////////////////////////////////////////////////// @@ -352,6 +355,7 @@ class TungstenAggregationIterator( private def processInputs(): Unit = { while (!sortBased && inputIter.hasNext) { val newInput = inputIter.next() + numInputRows += 1 val groupingKey = groupProjection.apply(newInput) val buffer: UnsafeRow = hashMap.getAggregationBufferFromUnsafeRow(groupingKey) if (buffer == null) { @@ -371,6 +375,7 @@ class TungstenAggregationIterator( var i = 0 while (!sortBased && inputIter.hasNext) { val newInput = inputIter.next() + numInputRows += 1 val groupingKey = groupProjection.apply(newInput) val buffer: UnsafeRow = if (i < fallbackStartsAt) { hashMap.getAggregationBufferFromUnsafeRow(groupingKey) @@ -439,6 +444,7 @@ class TungstenAggregationIterator( // Process the rest of input rows. while (inputIter.hasNext) { val newInput = inputIter.next() + numInputRows += 1 val groupingKey = groupProjection.apply(newInput) buffer.copyFrom(initialAggregationBuffer) processRow(buffer, newInput) @@ -462,6 +468,7 @@ class TungstenAggregationIterator( // Insert the rest of input rows. while (inputIter.hasNext) { val newInput = inputIter.next() + numInputRows += 1 val groupingKey = groupProjection.apply(newInput) bufferExtractor(newInput) externalSorter.insertKV(groupingKey, buffer) @@ -657,7 +664,7 @@ class TungstenAggregationIterator( TaskContext.get().internalMetricsToAccumulators( InternalAccumulator.PEAK_EXECUTION_MEMORY).add(peakMemory) } - + numOutputRows += 1 res } else { // no more result diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index bf2de244c8e4..247c900baae9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -41,11 +41,20 @@ import org.apache.spark.{HashPartitioner, SparkEnv} case class Project(projectList: Seq[NamedExpression], child: SparkPlan) extends UnaryNode { override def output: Seq[Attribute] = projectList.map(_.toAttribute) + override private[sql] lazy val metrics = Map( + "numRows" -> SQLMetrics.createLongMetric(sparkContext, "number of rows")) + @transient lazy val buildProjection = newMutableProjection(projectList, child.output) - protected override def doExecute(): RDD[InternalRow] = child.execute().mapPartitions { iter => - val reusableProjection = buildProjection() - iter.map(reusableProjection) + protected override def doExecute(): RDD[InternalRow] = { + val numRows = longMetric("numRows") + child.execute().mapPartitions { iter => + val reusableProjection = buildProjection() + iter.map { row => + numRows += 1 + reusableProjection(row) + } + } } override def outputOrdering: Seq[SortOrder] = child.outputOrdering @@ -57,19 +66,28 @@ case class Project(projectList: Seq[NamedExpression], child: SparkPlan) extends */ case class TungstenProject(projectList: Seq[NamedExpression], child: SparkPlan) extends UnaryNode { + override 
private[sql] lazy val metrics = Map( + "numRows" -> SQLMetrics.createLongMetric(sparkContext, "number of rows")) + override def outputsUnsafeRows: Boolean = true override def canProcessUnsafeRows: Boolean = true override def canProcessSafeRows: Boolean = true override def output: Seq[Attribute] = projectList.map(_.toAttribute) - protected override def doExecute(): RDD[InternalRow] = child.execute().mapPartitions { iter => - this.transformAllExpressions { - case CreateStruct(children) => CreateStructUnsafe(children) - case CreateNamedStruct(children) => CreateNamedStructUnsafe(children) + protected override def doExecute(): RDD[InternalRow] = { + val numRows = longMetric("numRows") + child.execute().mapPartitions { iter => + this.transformAllExpressions { + case CreateStruct(children) => CreateStructUnsafe(children) + case CreateNamedStruct(children) => CreateNamedStructUnsafe(children) + } + val project = UnsafeProjection.create(projectList, child.output) + iter.map { row => + numRows += 1 + project(row) + } } - val project = UnsafeProjection.create(projectList, child.output) - iter.map(project) } override def outputOrdering: Seq[SortOrder] = child.outputOrdering diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala index f7a68e4f5d44..2e108cb81451 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.plans.physical.{Distribution, Partitioning, UnspecifiedDistribution} import org.apache.spark.sql.execution.{BinaryNode, SQLExecution, SparkPlan} +import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.util.ThreadUtils import org.apache.spark.{InternalAccumulator, TaskContext} @@ -45,7 +46,10 @@ case class BroadcastHashJoin( right: SparkPlan) extends BinaryNode with HashJoin { - override protected[sql] val trackNumOfRowsEnabled = true + override private[sql] lazy val metrics = Map( + "numLeftRows" -> SQLMetrics.createLongMetric(sparkContext, "number of left rows"), + "numRightRows" -> SQLMetrics.createLongMetric(sparkContext, "number of right rows"), + "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) val timeout: Duration = { val timeoutValue = sqlContext.conf.broadcastTimeout @@ -65,6 +69,11 @@ case class BroadcastHashJoin( // for the same query. @transient private lazy val broadcastFuture = { + val numBuildRows = buildSide match { + case BuildLeft => longMetric("numLeftRows") + case BuildRight => longMetric("numRightRows") + } + // broadcastFuture is used in "doExecute". Therefore we can get the execution id correctly here. 
val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) future { @@ -73,8 +82,15 @@ case class BroadcastHashJoin( SQLExecution.withExecutionId(sparkContext, executionId) { // Note that we use .execute().collect() because we don't want to convert data to Scala // types - val input: Array[InternalRow] = buildPlan.execute().map(_.copy()).collect() - val hashed = HashedRelation(input.iterator, buildSideKeyGenerator, input.size) + val input: Array[InternalRow] = buildPlan.execute().map { row => + numBuildRows += 1 + row.copy() + }.collect() + // The following line doesn't run in a job so we cannot track the metric value. However, we + // have already tracked it in the above lines. So here we can use + // `SQLMetrics.nullLongMetric` to ignore it. + val hashed = HashedRelation( + input.iterator, SQLMetrics.nullLongMetric, buildSideKeyGenerator, input.size) sparkContext.broadcast(hashed) } }(BroadcastHashJoin.broadcastHashJoinExecutionContext) @@ -85,6 +101,12 @@ case class BroadcastHashJoin( } protected override def doExecute(): RDD[InternalRow] = { + val numStreamedRows = buildSide match { + case BuildLeft => longMetric("numRightRows") + case BuildRight => longMetric("numLeftRows") + } + val numOutputRows = longMetric("numOutputRows") + val broadcastRelation = Await.result(broadcastFuture, timeout) streamedPlan.execute().mapPartitions { streamedIter => @@ -95,7 +117,7 @@ case class BroadcastHashJoin( InternalAccumulator.PEAK_EXECUTION_MEMORY).add(unsafe.getUnsafeSize) case _ => } - hashJoin(streamedIter, hashedRelation) + hashJoin(streamedIter, numStreamedRows, hashedRelation, numOutputRows) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala index a3626de49aea..69a8b95eaa7e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.{Distribution, Partitioning, UnspecifiedDistribution} import org.apache.spark.sql.catalyst.plans.{JoinType, LeftOuter, RightOuter} import org.apache.spark.sql.execution.{BinaryNode, SQLExecution, SparkPlan} +import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.{InternalAccumulator, TaskContext} /** @@ -45,6 +46,11 @@ case class BroadcastHashOuterJoin( left: SparkPlan, right: SparkPlan) extends BinaryNode with HashOuterJoin { + override private[sql] lazy val metrics = Map( + "numLeftRows" -> SQLMetrics.createLongMetric(sparkContext, "number of left rows"), + "numRightRows" -> SQLMetrics.createLongMetric(sparkContext, "number of right rows"), + "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) + val timeout = { val timeoutValue = sqlContext.conf.broadcastTimeout if (timeoutValue < 0) { @@ -63,6 +69,14 @@ case class BroadcastHashOuterJoin( // for the same query. @transient private lazy val broadcastFuture = { + val numBuildRows = joinType match { + case RightOuter => longMetric("numLeftRows") + case LeftOuter => longMetric("numRightRows") + case x => + throw new IllegalArgumentException( + s"HashOuterJoin should not take $x as the JoinType") + } + // broadcastFuture is used in "doExecute". Therefore we can get the execution id correctly here. 
val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) future { @@ -71,8 +85,15 @@ case class BroadcastHashOuterJoin( SQLExecution.withExecutionId(sparkContext, executionId) { // Note that we use .execute().collect() because we don't want to convert data to Scala // types - val input: Array[InternalRow] = buildPlan.execute().map(_.copy()).collect() - val hashed = HashedRelation(input.iterator, buildKeyGenerator, input.size) + val input: Array[InternalRow] = buildPlan.execute().map { row => + numBuildRows += 1 + row.copy() + }.collect() + // The following line doesn't run in a job so we cannot track the metric value. However, we + // have already tracked it in the above lines. So here we can use + // `SQLMetrics.nullLongMetric` to ignore it. + val hashed = HashedRelation( + input.iterator, SQLMetrics.nullLongMetric, buildKeyGenerator, input.size) sparkContext.broadcast(hashed) } }(BroadcastHashJoin.broadcastHashJoinExecutionContext) @@ -83,6 +104,15 @@ case class BroadcastHashOuterJoin( } override def doExecute(): RDD[InternalRow] = { + val numStreamedRows = joinType match { + case RightOuter => longMetric("numRightRows") + case LeftOuter => longMetric("numLeftRows") + case x => + throw new IllegalArgumentException( + s"HashOuterJoin should not take $x as the JoinType") + } + val numOutputRows = longMetric("numOutputRows") + val broadcastRelation = Await.result(broadcastFuture, timeout) streamedPlan.execute().mapPartitions { streamedIter => @@ -101,16 +131,18 @@ case class BroadcastHashOuterJoin( joinType match { case LeftOuter => streamedIter.flatMap(currentRow => { + numStreamedRows += 1 val rowKey = keyGenerator(currentRow) joinedRow.withLeft(currentRow) - leftOuterIterator(rowKey, joinedRow, hashTable.get(rowKey), resultProj) + leftOuterIterator(rowKey, joinedRow, hashTable.get(rowKey), resultProj, numOutputRows) }) case RightOuter => streamedIter.flatMap(currentRow => { + numStreamedRows += 1 val rowKey = keyGenerator(currentRow) joinedRow.withRight(currentRow) - rightOuterIterator(rowKey, hashTable.get(rowKey), joinedRow, resultProj) + rightOuterIterator(rowKey, hashTable.get(rowKey), joinedRow, resultProj, numOutputRows) }) case x => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala index 5bd06fbdca60..78a8c16c62bc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala @@ -23,6 +23,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} +import org.apache.spark.sql.execution.metric.SQLMetrics /** * :: DeveloperApi :: @@ -37,18 +38,31 @@ case class BroadcastLeftSemiJoinHash( right: SparkPlan, condition: Option[Expression]) extends BinaryNode with HashSemiJoin { + override private[sql] lazy val metrics = Map( + "numLeftRows" -> SQLMetrics.createLongMetric(sparkContext, "number of left rows"), + "numRightRows" -> SQLMetrics.createLongMetric(sparkContext, "number of right rows"), + "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) + protected override def doExecute(): RDD[InternalRow] = { - val input = right.execute().map(_.copy()).collect() + val numLeftRows = longMetric("numLeftRows") + val 
numRightRows = longMetric("numRightRows") + val numOutputRows = longMetric("numOutputRows") + + val input = right.execute().map { row => + numRightRows += 1 + row.copy() + }.collect() if (condition.isEmpty) { - val hashSet = buildKeyHashSet(input.toIterator) + val hashSet = buildKeyHashSet(input.toIterator, SQLMetrics.nullLongMetric) val broadcastedRelation = sparkContext.broadcast(hashSet) left.execute().mapPartitions { streamIter => - hashSemiJoin(streamIter, broadcastedRelation.value) + hashSemiJoin(streamIter, numLeftRows, broadcastedRelation.value, numOutputRows) } } else { - val hashRelation = HashedRelation(input.toIterator, rightKeyGenerator, input.size) + val hashRelation = + HashedRelation(input.toIterator, SQLMetrics.nullLongMetric, rightKeyGenerator, input.size) val broadcastedRelation = sparkContext.broadcast(hashRelation) left.execute().mapPartitions { streamIter => @@ -59,7 +73,7 @@ case class BroadcastLeftSemiJoinHash( InternalAccumulator.PEAK_EXECUTION_MEMORY).add(unsafe.getUnsafeSize) case _ => } - hashSemiJoin(streamIter, hashedRelation) + hashSemiJoin(streamIter, numLeftRows, hashedRelation, numOutputRows) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala index 017a44b9ca86..28c88b1b03d0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.catalyst.plans.{FullOuter, JoinType, LeftOuter, RightOuter} import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} +import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.util.collection.CompactBuffer /** @@ -38,6 +39,11 @@ case class BroadcastNestedLoopJoin( condition: Option[Expression]) extends BinaryNode { // TODO: Override requiredChildDistribution. + override private[sql] lazy val metrics = Map( + "numLeftRows" -> SQLMetrics.createLongMetric(sparkContext, "number of left rows"), + "numRightRows" -> SQLMetrics.createLongMetric(sparkContext, "number of right rows"), + "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) + /** BuildRight means the right relation <=> the broadcast relation. */ private val (streamed, broadcast) = buildSide match { case BuildRight => (left, right) @@ -75,9 +81,17 @@ case class BroadcastNestedLoopJoin( newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output) protected override def doExecute(): RDD[InternalRow] = { + val (numStreamedRows, numBuildRows) = buildSide match { + case BuildRight => (longMetric("numLeftRows"), longMetric("numRightRows")) + case BuildLeft => (longMetric("numRightRows"), longMetric("numLeftRows")) + } + val numOutputRows = longMetric("numOutputRows") + val broadcastedRelation = - sparkContext.broadcast(broadcast.execute().map(_.copy()) - .collect().toIndexedSeq) + sparkContext.broadcast(broadcast.execute().map { row => + numBuildRows += 1 + row.copy() + }.collect().toIndexedSeq) /** All rows that either match both-way, or rows from streamed joined with nulls. 
*/ val matchesOrStreamedRowsWithNulls = streamed.execute().mapPartitions { streamedIter => @@ -94,6 +108,7 @@ case class BroadcastNestedLoopJoin( streamedIter.foreach { streamedRow => var i = 0 var streamRowMatched = false + numStreamedRows += 1 while (i < broadcastedRelation.value.size) { val broadcastedRow = broadcastedRelation.value(i) @@ -162,6 +177,12 @@ case class BroadcastNestedLoopJoin( // TODO: Breaks lineage. sparkContext.union( - matchesOrStreamedRowsWithNulls.flatMap(_._1), sparkContext.makeRDD(broadcastRowsWithNulls)) + matchesOrStreamedRowsWithNulls.flatMap(_._1), + sparkContext.makeRDD(broadcastRowsWithNulls) + ).map { row => + // `broadcastRowsWithNulls` doesn't run in a job so that we have to track numOutputRows here. + numOutputRows += 1 + row + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala index 261b4724159f..2115f4070228 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala @@ -22,6 +22,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, JoinedRow} import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} +import org.apache.spark.sql.execution.metric.SQLMetrics /** * :: DeveloperApi :: @@ -30,13 +31,31 @@ import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} case class CartesianProduct(left: SparkPlan, right: SparkPlan) extends BinaryNode { override def output: Seq[Attribute] = left.output ++ right.output + override private[sql] lazy val metrics = Map( + "numLeftRows" -> SQLMetrics.createLongMetric(sparkContext, "number of left rows"), + "numRightRows" -> SQLMetrics.createLongMetric(sparkContext, "number of right rows"), + "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) + protected override def doExecute(): RDD[InternalRow] = { - val leftResults = left.execute().map(_.copy()) - val rightResults = right.execute().map(_.copy()) + val numLeftRows = longMetric("numLeftRows") + val numRightRows = longMetric("numRightRows") + val numOutputRows = longMetric("numOutputRows") + + val leftResults = left.execute().map { row => + numLeftRows += 1 + row.copy() + } + val rightResults = right.execute().map { row => + numRightRows += 1 + row.copy() + } leftResults.cartesian(rightResults).mapPartitions { iter => val joinedRow = new JoinedRow - iter.map(r => joinedRow(r._1, r._2)) + iter.map { r => + numOutputRows += 1 + joinedRow(r._1, r._2) + } } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala index 22d46d1c3e3b..7ce4a517838c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.joins import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.metric.LongSQLMetric trait HashJoin { @@ -69,7 +70,9 @@ trait HashJoin { protected def hashJoin( streamIter: Iterator[InternalRow], - hashedRelation: HashedRelation): Iterator[InternalRow] = + numStreamRows: 
LongSQLMetric, + hashedRelation: HashedRelation, + numOutputRows: LongSQLMetric): Iterator[InternalRow] = { new Iterator[InternalRow] { private[this] var currentStreamedRow: InternalRow = _ @@ -98,6 +101,7 @@ trait HashJoin { case BuildLeft => joinRow(currentHashMatches(currentMatchPosition), currentStreamedRow) } currentMatchPosition += 1 + numOutputRows += 1 resultProjection(ret) } @@ -113,6 +117,7 @@ trait HashJoin { while (currentHashMatches == null && streamIter.hasNext) { currentStreamedRow = streamIter.next() + numStreamRows += 1 val key = joinKeys(currentStreamedRow) if (!key.anyNull) { currentHashMatches = hashedRelation.get(key) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala index 701bd3cd8637..66903347c88c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.metric.LongSQLMetric import org.apache.spark.util.collection.CompactBuffer @DeveloperApi @@ -114,22 +115,28 @@ trait HashOuterJoin { key: InternalRow, joinedRow: JoinedRow, rightIter: Iterable[InternalRow], - resultProjection: InternalRow => InternalRow): Iterator[InternalRow] = { + resultProjection: InternalRow => InternalRow, + numOutputRows: LongSQLMetric): Iterator[InternalRow] = { val ret: Iterable[InternalRow] = { if (!key.anyNull) { val temp = if (rightIter != null) { rightIter.collect { - case r if boundCondition(joinedRow.withRight(r)) => resultProjection(joinedRow).copy() + case r if boundCondition(joinedRow.withRight(r)) => { + numOutputRows += 1 + resultProjection(joinedRow).copy() + } } } else { List.empty } if (temp.isEmpty) { + numOutputRows += 1 resultProjection(joinedRow.withRight(rightNullRow)) :: Nil } else { temp } } else { + numOutputRows += 1 resultProjection(joinedRow.withRight(rightNullRow)) :: Nil } } @@ -140,22 +147,28 @@ trait HashOuterJoin { key: InternalRow, leftIter: Iterable[InternalRow], joinedRow: JoinedRow, - resultProjection: InternalRow => InternalRow): Iterator[InternalRow] = { + resultProjection: InternalRow => InternalRow, + numOutputRows: LongSQLMetric): Iterator[InternalRow] = { val ret: Iterable[InternalRow] = { if (!key.anyNull) { val temp = if (leftIter != null) { leftIter.collect { - case l if boundCondition(joinedRow.withLeft(l)) => resultProjection(joinedRow).copy() + case l if boundCondition(joinedRow.withLeft(l)) => { + numOutputRows += 1 + resultProjection(joinedRow).copy() + } } } else { List.empty } if (temp.isEmpty) { + numOutputRows += 1 resultProjection(joinedRow.withLeft(leftNullRow)) :: Nil } else { temp } } else { + numOutputRows += 1 resultProjection(joinedRow.withLeft(leftNullRow)) :: Nil } } @@ -164,7 +177,7 @@ trait HashOuterJoin { protected[this] def fullOuterIterator( key: InternalRow, leftIter: Iterable[InternalRow], rightIter: Iterable[InternalRow], - joinedRow: JoinedRow): Iterator[InternalRow] = { + joinedRow: JoinedRow, numOutputRows: LongSQLMetric): Iterator[InternalRow] = { if (!key.anyNull) { // Store the positions of records in right, if one of its associated row satisfy // the join condition. 
@@ -177,6 +190,7 @@ trait HashOuterJoin { // append them directly case (r, idx) if boundCondition(joinedRow.withRight(r)) => + numOutputRows += 1 matched = true // if the row satisfy the join condition, add its index into the matched set rightMatchedSet.add(idx) @@ -189,6 +203,7 @@ trait HashOuterJoin { // as we don't know whether we need to append it until finish iterating all // of the records in right side. // If we didn't get any proper row, then append a single row with empty right. + numOutputRows += 1 joinedRow.withRight(rightNullRow).copy() }) } ++ rightIter.zipWithIndex.collect { @@ -197,12 +212,15 @@ trait HashOuterJoin { // Re-visiting the records in right, and append additional row with empty left, if its not // in the matched set. case (r, idx) if !rightMatchedSet.contains(idx) => + numOutputRows += 1 joinedRow(leftNullRow, r).copy() } } else { leftIter.iterator.map[InternalRow] { l => + numOutputRows += 1 joinedRow(l, rightNullRow).copy() } ++ rightIter.iterator.map[InternalRow] { r => + numOutputRows += 1 joinedRow(leftNullRow, r).copy() } } @@ -211,10 +229,12 @@ trait HashOuterJoin { // This is only used by FullOuter protected[this] def buildHashTable( iter: Iterator[InternalRow], + numIterRows: LongSQLMetric, keyGenerator: Projection): JavaHashMap[InternalRow, CompactBuffer[InternalRow]] = { val hashTable = new JavaHashMap[InternalRow, CompactBuffer[InternalRow]]() while (iter.hasNext) { val currentRow = iter.next() + numIterRows += 1 val rowKey = keyGenerator(currentRow) var existingMatchList = hashTable.get(rowKey) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala index 82dd6eb7e7ed..beb141ade616 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.joins import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.metric.LongSQLMetric trait HashSemiJoin { @@ -61,13 +62,15 @@ trait HashSemiJoin { @transient private lazy val boundCondition = newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output) - protected def buildKeyHashSet(buildIter: Iterator[InternalRow]): java.util.Set[InternalRow] = { + protected def buildKeyHashSet( + buildIter: Iterator[InternalRow], numBuildRows: LongSQLMetric): java.util.Set[InternalRow] = { val hashSet = new java.util.HashSet[InternalRow]() // Create a Hash set of buildKeys val rightKey = rightKeyGenerator while (buildIter.hasNext) { val currentRow = buildIter.next() + numBuildRows += 1 val rowKey = rightKey(currentRow) if (!rowKey.anyNull) { val keyExists = hashSet.contains(rowKey) @@ -82,25 +85,35 @@ trait HashSemiJoin { protected def hashSemiJoin( streamIter: Iterator[InternalRow], - hashSet: java.util.Set[InternalRow]): Iterator[InternalRow] = { + numStreamRows: LongSQLMetric, + hashSet: java.util.Set[InternalRow], + numOutputRows: LongSQLMetric): Iterator[InternalRow] = { val joinKeys = leftKeyGenerator streamIter.filter(current => { + numStreamRows += 1 val key = joinKeys(current) - !key.anyNull && hashSet.contains(key) + val r = !key.anyNull && hashSet.contains(key) + if (r) numOutputRows += 1 + r }) } protected def hashSemiJoin( streamIter: Iterator[InternalRow], - hashedRelation: 
HashedRelation): Iterator[InternalRow] = { + numStreamRows: LongSQLMetric, + hashedRelation: HashedRelation, + numOutputRows: LongSQLMetric): Iterator[InternalRow] = { val joinKeys = leftKeyGenerator val joinedRow = new JoinedRow streamIter.filter { current => + numStreamRows += 1 val key = joinKeys(current) lazy val rowBuffer = hashedRelation.get(key) - !key.anyNull && rowBuffer != null && rowBuffer.exists { + val r = !key.anyNull && rowBuffer != null && rowBuffer.exists { (row: InternalRow) => boundCondition(joinedRow(current, row)) } + if (r) numOutputRows += 1 + r } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index 63d35d0f0262..c1bc7947aa39 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -25,6 +25,7 @@ import org.apache.spark.shuffle.ShuffleMemoryManager import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.SparkSqlSerializer +import org.apache.spark.sql.execution.metric.LongSQLMetric import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.map.BytesToBytesMap import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager} @@ -112,11 +113,13 @@ private[joins] object HashedRelation { def apply( input: Iterator[InternalRow], + numInputRows: LongSQLMetric, keyGenerator: Projection, sizeEstimate: Int = 64): HashedRelation = { if (keyGenerator.isInstanceOf[UnsafeProjection]) { - return UnsafeHashedRelation(input, keyGenerator.asInstanceOf[UnsafeProjection], sizeEstimate) + return UnsafeHashedRelation( + input, numInputRows, keyGenerator.asInstanceOf[UnsafeProjection], sizeEstimate) } // TODO: Use Spark's HashMap implementation. 
@@ -130,6 +133,7 @@ private[joins] object HashedRelation { // Create a mapping of buildKeys -> rows while (input.hasNext) { currentRow = input.next() + numInputRows += 1 val rowKey = keyGenerator(currentRow) if (!rowKey.anyNull) { val existingMatchList = hashTable.get(rowKey) @@ -331,6 +335,7 @@ private[joins] object UnsafeHashedRelation { def apply( input: Iterator[InternalRow], + numInputRows: LongSQLMetric, keyGenerator: UnsafeProjection, sizeEstimate: Int): HashedRelation = { @@ -340,6 +345,7 @@ private[joins] object UnsafeHashedRelation { // Create a mapping of buildKeys -> rows while (input.hasNext) { val unsafeRow = input.next().asInstanceOf[UnsafeRow] + numInputRows += 1 val rowKey = keyGenerator(unsafeRow) if (!rowKey.anyNull) { val existingMatchList = hashTable.get(rowKey) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala index 4443455ef11f..ad6362542f2f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} +import org.apache.spark.sql.execution.metric.SQLMetrics /** * :: DeveloperApi :: @@ -35,6 +36,11 @@ case class LeftSemiJoinBNL( extends BinaryNode { // TODO: Override requiredChildDistribution. + override private[sql] lazy val metrics = Map( + "numLeftRows" -> SQLMetrics.createLongMetric(sparkContext, "number of left rows"), + "numRightRows" -> SQLMetrics.createLongMetric(sparkContext, "number of right rows"), + "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) + override def outputPartitioning: Partitioning = streamed.outputPartitioning override def output: Seq[Attribute] = left.output @@ -52,13 +58,21 @@ case class LeftSemiJoinBNL( newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output) protected override def doExecute(): RDD[InternalRow] = { + val numLeftRows = longMetric("numLeftRows") + val numRightRows = longMetric("numRightRows") + val numOutputRows = longMetric("numOutputRows") + val broadcastedRelation = - sparkContext.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq) + sparkContext.broadcast(broadcast.execute().map { row => + numRightRows += 1 + row.copy() + }.collect().toIndexedSeq) streamed.execute().mapPartitions { streamedIter => val joinedRow = new JoinedRow streamedIter.filter(streamedRow => { + numLeftRows += 1 var i = 0 var matched = false @@ -69,6 +83,9 @@ case class LeftSemiJoinBNL( } i += 1 } + if (matched) { + numOutputRows += 1 + } matched }) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala index 68ccd34d8ed9..18808adaac63 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, Distribution, ClusteredDistribution} import 
org.apache.spark.sql.execution.{BinaryNode, SparkPlan} +import org.apache.spark.sql.execution.metric.SQLMetrics /** * :: DeveloperApi :: @@ -37,19 +38,28 @@ case class LeftSemiJoinHash( right: SparkPlan, condition: Option[Expression]) extends BinaryNode with HashSemiJoin { + override private[sql] lazy val metrics = Map( + "numLeftRows" -> SQLMetrics.createLongMetric(sparkContext, "number of left rows"), + "numRightRows" -> SQLMetrics.createLongMetric(sparkContext, "number of right rows"), + "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) + override def outputPartitioning: Partitioning = left.outputPartitioning override def requiredChildDistribution: Seq[Distribution] = ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil protected override def doExecute(): RDD[InternalRow] = { + val numLeftRows = longMetric("numLeftRows") + val numRightRows = longMetric("numRightRows") + val numOutputRows = longMetric("numOutputRows") + right.execute().zipPartitions(left.execute()) { (buildIter, streamIter) => if (condition.isEmpty) { - val hashSet = buildKeyHashSet(buildIter) - hashSemiJoin(streamIter, hashSet) + val hashSet = buildKeyHashSet(buildIter, numRightRows) + hashSemiJoin(streamIter, numLeftRows, hashSet, numOutputRows) } else { - val hashRelation = HashedRelation(buildIter, rightKeyGenerator) - hashSemiJoin(streamIter, hashRelation) + val hashRelation = HashedRelation(buildIter, numRightRows, rightKeyGenerator) + hashSemiJoin(streamIter, numLeftRows, hashRelation, numOutputRows) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala index c923dc837c44..fc8c9439a6f0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} +import org.apache.spark.sql.execution.metric.SQLMetrics /** * :: DeveloperApi :: @@ -38,7 +39,10 @@ case class ShuffledHashJoin( right: SparkPlan) extends BinaryNode with HashJoin { - override protected[sql] val trackNumOfRowsEnabled = true + override private[sql] lazy val metrics = Map( + "numLeftRows" -> SQLMetrics.createLongMetric(sparkContext, "number of left rows"), + "numRightRows" -> SQLMetrics.createLongMetric(sparkContext, "number of right rows"), + "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) override def outputPartitioning: Partitioning = PartitioningCollection(Seq(left.outputPartitioning, right.outputPartitioning)) @@ -47,9 +51,15 @@ case class ShuffledHashJoin( ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil protected override def doExecute(): RDD[InternalRow] = { + val (numBuildRows, numStreamedRows) = buildSide match { + case BuildLeft => (longMetric("numLeftRows"), longMetric("numRightRows")) + case BuildRight => (longMetric("numRightRows"), longMetric("numLeftRows")) + } + val numOutputRows = longMetric("numOutputRows") + buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) => - val hashed = HashedRelation(buildIter, buildSideKeyGenerator) - hashJoin(streamIter, hashed) + val hashed = 
HashedRelation(buildIter, numBuildRows, buildSideKeyGenerator) + hashJoin(streamIter, numStreamedRows, hashed, numOutputRows) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala index 6a8c35efca8f..ed282f98b7d7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.catalyst.plans.{FullOuter, JoinType, LeftOuter, RightOuter} import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} +import org.apache.spark.sql.execution.metric.SQLMetrics /** * :: DeveloperApi :: @@ -41,6 +42,11 @@ case class ShuffledHashOuterJoin( left: SparkPlan, right: SparkPlan) extends BinaryNode with HashOuterJoin { + override private[sql] lazy val metrics = Map( + "numLeftRows" -> SQLMetrics.createLongMetric(sparkContext, "number of left rows"), + "numRightRows" -> SQLMetrics.createLongMetric(sparkContext, "number of right rows"), + "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) + override def requiredChildDistribution: Seq[Distribution] = ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil @@ -53,39 +59,48 @@ case class ShuffledHashOuterJoin( } protected override def doExecute(): RDD[InternalRow] = { + val numLeftRows = longMetric("numLeftRows") + val numRightRows = longMetric("numRightRows") + val numOutputRows = longMetric("numOutputRows") + val joinedRow = new JoinedRow() left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) => // TODO this probably can be replaced by external sort (sort merged join?) 
joinType match { case LeftOuter => - val hashed = HashedRelation(rightIter, buildKeyGenerator) + val hashed = HashedRelation(rightIter, numRightRows, buildKeyGenerator) val keyGenerator = streamedKeyGenerator val resultProj = resultProjection leftIter.flatMap( currentRow => { + numLeftRows += 1 val rowKey = keyGenerator(currentRow) joinedRow.withLeft(currentRow) - leftOuterIterator(rowKey, joinedRow, hashed.get(rowKey), resultProj) + leftOuterIterator(rowKey, joinedRow, hashed.get(rowKey), resultProj, numOutputRows) }) case RightOuter => - val hashed = HashedRelation(leftIter, buildKeyGenerator) + val hashed = HashedRelation(leftIter, numLeftRows, buildKeyGenerator) val keyGenerator = streamedKeyGenerator val resultProj = resultProjection rightIter.flatMap ( currentRow => { + numRightRows += 1 val rowKey = keyGenerator(currentRow) joinedRow.withRight(currentRow) - rightOuterIterator(rowKey, hashed.get(rowKey), joinedRow, resultProj) + rightOuterIterator(rowKey, hashed.get(rowKey), joinedRow, resultProj, numOutputRows) }) case FullOuter => // TODO(davies): use UnsafeRow - val leftHashTable = buildHashTable(leftIter, newProjection(leftKeys, left.output)) - val rightHashTable = buildHashTable(rightIter, newProjection(rightKeys, right.output)) + val leftHashTable = + buildHashTable(leftIter, numLeftRows, newProjection(leftKeys, left.output)) + val rightHashTable = + buildHashTable(rightIter, numRightRows, newProjection(rightKeys, right.output)) (leftHashTable.keySet ++ rightHashTable.keySet).iterator.flatMap { key => fullOuterIterator(key, leftHashTable.getOrElse(key, EMPTY_LIST), rightHashTable.getOrElse(key, EMPTY_LIST), - joinedRow) + joinedRow, + numOutputRows) } case x => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala index 6d656ea2849a..6b7322671d6b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.{BinaryNode, RowIterator, SparkPlan} +import org.apache.spark.sql.execution.metric.{LongSQLMetric, SQLMetrics} /** * :: DeveloperApi :: @@ -37,6 +38,11 @@ case class SortMergeJoin( left: SparkPlan, right: SparkPlan) extends BinaryNode { + override private[sql] lazy val metrics = Map( + "numLeftRows" -> SQLMetrics.createLongMetric(sparkContext, "number of left rows"), + "numRightRows" -> SQLMetrics.createLongMetric(sparkContext, "number of right rows"), + "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) + override def output: Seq[Attribute] = left.output ++ right.output override def outputPartitioning: Partitioning = @@ -70,6 +76,10 @@ case class SortMergeJoin( } protected override def doExecute(): RDD[InternalRow] = { + val numLeftRows = longMetric("numLeftRows") + val numRightRows = longMetric("numRightRows") + val numOutputRows = longMetric("numOutputRows") + left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) => new RowIterator { // An ordering that can be used to compare keys from both sides. 
@@ -82,7 +92,9 @@ case class SortMergeJoin( rightKeyGenerator, keyOrdering, RowIterator.fromScala(leftIter), - RowIterator.fromScala(rightIter) + numLeftRows, + RowIterator.fromScala(rightIter), + numRightRows ) private[this] val joinRow = new JoinedRow private[this] val resultProjection: (InternalRow) => InternalRow = { @@ -108,6 +120,7 @@ case class SortMergeJoin( if (currentLeftRow != null) { joinRow(currentLeftRow, currentRightMatches(currentMatchIdx)) currentMatchIdx += 1 + numOutputRows += 1 true } else { false @@ -144,7 +157,9 @@ private[joins] class SortMergeJoinScanner( bufferedKeyGenerator: Projection, keyOrdering: Ordering[InternalRow], streamedIter: RowIterator, - bufferedIter: RowIterator) { + numStreamedRows: LongSQLMetric, + bufferedIter: RowIterator, + numBufferedRows: LongSQLMetric) { private[this] var streamedRow: InternalRow = _ private[this] var streamedRowKey: InternalRow = _ private[this] var bufferedRow: InternalRow = _ @@ -269,6 +284,7 @@ private[joins] class SortMergeJoinScanner( if (streamedIter.advanceNext()) { streamedRow = streamedIter.getRow streamedRowKey = streamedKeyGenerator(streamedRow) + numStreamedRows += 1 true } else { streamedRow = null @@ -286,6 +302,7 @@ private[joins] class SortMergeJoinScanner( while (!foundRow && bufferedIter.advanceNext()) { bufferedRow = bufferedIter.getRow bufferedRowKey = bufferedKeyGenerator(bufferedRow) + numBufferedRows += 1 foundRow = !bufferedRowKey.anyNull } if (!foundRow) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala index 5326966b07a6..dea9e5e580a1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.{JoinType, LeftOuter, RightOuter} import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.{BinaryNode, RowIterator, SparkPlan} +import org.apache.spark.sql.execution.metric.{LongSQLMetric, SQLMetrics} /** * :: DeveloperApi :: @@ -40,6 +41,11 @@ case class SortMergeOuterJoin( left: SparkPlan, right: SparkPlan) extends BinaryNode { + override private[sql] lazy val metrics = Map( + "numLeftRows" -> SQLMetrics.createLongMetric(sparkContext, "number of left rows"), + "numRightRows" -> SQLMetrics.createLongMetric(sparkContext, "number of right rows"), + "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) + override def output: Seq[Attribute] = { joinType match { case LeftOuter => @@ -108,6 +114,10 @@ case class SortMergeOuterJoin( } override def doExecute(): RDD[InternalRow] = { + val numLeftRows = longMetric("numLeftRows") + val numRightRows = longMetric("numRightRows") + val numOutputRows = longMetric("numOutputRows") + left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) => // An ordering that can be used to compare keys from both sides. 
val keyOrdering = newNaturalAscendingOrdering(leftKeys.map(_.dataType)) @@ -133,10 +143,13 @@ case class SortMergeOuterJoin( bufferedKeyGenerator = createRightKeyGenerator(), keyOrdering, streamedIter = RowIterator.fromScala(leftIter), - bufferedIter = RowIterator.fromScala(rightIter) + numLeftRows, + bufferedIter = RowIterator.fromScala(rightIter), + numRightRows ) val rightNullRow = new GenericInternalRow(right.output.length) - new LeftOuterIterator(smjScanner, rightNullRow, boundCondition, resultProj).toScala + new LeftOuterIterator( + smjScanner, rightNullRow, boundCondition, resultProj, numOutputRows).toScala case RightOuter => val smjScanner = new SortMergeJoinScanner( @@ -144,10 +157,13 @@ case class SortMergeOuterJoin( bufferedKeyGenerator = createLeftKeyGenerator(), keyOrdering, streamedIter = RowIterator.fromScala(rightIter), - bufferedIter = RowIterator.fromScala(leftIter) + numRightRows, + bufferedIter = RowIterator.fromScala(leftIter), + numLeftRows ) val leftNullRow = new GenericInternalRow(left.output.length) - new RightOuterIterator(smjScanner, leftNullRow, boundCondition, resultProj).toScala + new RightOuterIterator( + smjScanner, leftNullRow, boundCondition, resultProj, numOutputRows).toScala case x => throw new IllegalArgumentException( @@ -162,7 +178,8 @@ private class LeftOuterIterator( smjScanner: SortMergeJoinScanner, rightNullRow: InternalRow, boundCondition: InternalRow => Boolean, - resultProj: InternalRow => InternalRow + resultProj: InternalRow => InternalRow, + numRows: LongSQLMetric ) extends RowIterator { private[this] val joinedRow: JoinedRow = new JoinedRow() private[this] var rightIdx: Int = 0 @@ -198,7 +215,9 @@ private class LeftOuterIterator( } override def advanceNext(): Boolean = { - advanceRightUntilBoundConditionSatisfied() || advanceLeft() + val r = advanceRightUntilBoundConditionSatisfied() || advanceLeft() + if (r) numRows += 1 + r } override def getRow: InternalRow = resultProj(joinedRow) @@ -208,7 +227,8 @@ private class RightOuterIterator( smjScanner: SortMergeJoinScanner, leftNullRow: InternalRow, boundCondition: InternalRow => Boolean, - resultProj: InternalRow => InternalRow + resultProj: InternalRow => InternalRow, + numRows: LongSQLMetric ) extends RowIterator { private[this] val joinedRow: JoinedRow = new JoinedRow() private[this] var leftIdx: Int = 0 @@ -244,7 +264,9 @@ private class RightOuterIterator( } override def advanceNext(): Boolean = { - advanceLeftUntilBoundConditionSatisfied() || advanceRight() + val r = advanceLeftUntilBoundConditionSatisfied() || advanceRight() + if (r) numRows += 1 + r } override def getRow: InternalRow = resultProj(joinedRow) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala index 1b51a5e5c8a8..7a2a98ec18cb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala @@ -112,4 +112,10 @@ private[sql] object SQLMetrics { sc.cleaner.foreach(_.registerAccumulatorForCleanup(acc)) acc } + + /** + * A metric that its value will be ignored. Use this one when we need a metric parameter but don't + * care about the value. 
+ */ + val nullLongMetric = new LongSQLMetric("null") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala index 8b1a9b21a96b..a1fa2c3864bd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala @@ -22,6 +22,8 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.execution.metric.SQLMetrics +import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.types.{IntegerType, StructField, StructType} import org.apache.spark.util.collection.CompactBuffer @@ -35,7 +37,8 @@ class HashedRelationSuite extends SparkFunSuite { test("GeneralHashedRelation") { val data = Array(InternalRow(0), InternalRow(1), InternalRow(2), InternalRow(2)) - val hashed = HashedRelation(data.iterator, keyProjection) + val numDataRows = SQLMetrics.createLongMetric(TestSQLContext.sparkContext, "data") + val hashed = HashedRelation(data.iterator, numDataRows, keyProjection) assert(hashed.isInstanceOf[GeneralHashedRelation]) assert(hashed.get(data(0)) === CompactBuffer[InternalRow](data(0))) @@ -45,11 +48,13 @@ class HashedRelationSuite extends SparkFunSuite { val data2 = CompactBuffer[InternalRow](data(2)) data2 += data(2) assert(hashed.get(data(2)) === data2) + assert(numDataRows.value.value === data.length) } test("UniqueKeyHashedRelation") { val data = Array(InternalRow(0), InternalRow(1), InternalRow(2)) - val hashed = HashedRelation(data.iterator, keyProjection) + val numDataRows = SQLMetrics.createLongMetric(TestSQLContext.sparkContext, "data") + val hashed = HashedRelation(data.iterator, numDataRows, keyProjection) assert(hashed.isInstanceOf[UniqueKeyHashedRelation]) assert(hashed.get(data(0)) === CompactBuffer[InternalRow](data(0))) @@ -62,17 +67,19 @@ class HashedRelationSuite extends SparkFunSuite { assert(uniqHashed.getValue(data(1)) === data(1)) assert(uniqHashed.getValue(data(2)) === data(2)) assert(uniqHashed.getValue(InternalRow(10)) === null) + assert(numDataRows.value.value === data.length) } test("UnsafeHashedRelation") { val schema = StructType(StructField("a", IntegerType, true) :: Nil) val data = Array(InternalRow(0), InternalRow(1), InternalRow(2), InternalRow(2)) + val numDataRows = SQLMetrics.createLongMetric(TestSQLContext.sparkContext, "data") val toUnsafe = UnsafeProjection.create(schema) val unsafeData = data.map(toUnsafe(_).copy()).toArray val buildKey = Seq(BoundReference(0, IntegerType, false)) val keyGenerator = UnsafeProjection.create(buildKey) - val hashed = UnsafeHashedRelation(unsafeData.iterator, keyGenerator, 1) + val hashed = UnsafeHashedRelation(unsafeData.iterator, numDataRows, keyGenerator, 1) assert(hashed.isInstanceOf[UnsafeHashedRelation]) assert(hashed.get(unsafeData(0)) === CompactBuffer[InternalRow](unsafeData(0))) @@ -94,5 +101,6 @@ class HashedRelationSuite extends SparkFunSuite { assert(hashed2.get(unsafeData(1)) === CompactBuffer[InternalRow](unsafeData(1))) assert(hashed2.get(toUnsafe(InternalRow(10))) === null) assert(hashed2.get(unsafeData(2)) === data2) + assert(numDataRows.value.value === data.length) } } diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index 953284c98b20..7383d3f8fe02 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -25,15 +25,24 @@ import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm._ import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm.Opcodes._ import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql._ +import org.apache.spark.sql.execution.ui.SparkPlanGraph +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.{SQLTestUtils, TestSQLContext} import org.apache.spark.util.Utils +class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils { -class SQLMetricsSuite extends SparkFunSuite { + override val sqlContext = TestSQLContext + + import sqlContext.implicits._ test("LongSQLMetric should not box Long") { val l = SQLMetrics.createLongMetric(TestSQLContext.sparkContext, "long") - val f = () => { l += 1L } + val f = () => { + l += 1L + l.add(1L) + } BoxingFinder.getClassReader(f.getClass).foreach { cl => val boxingFinder = new BoxingFinder() cl.accept(boxingFinder, 0) @@ -51,6 +60,441 @@ class SQLMetricsSuite extends SparkFunSuite { assert(boxingFinder.boxingInvokes.nonEmpty, "Found find boxing in this test") } } + + /** + * Call `df.collect()` and verify if the collected metrics are same as "expectedMetrics". + * + * @param df `DataFrame` to run + * @param expectedNumOfJobs number of jobs that will run + * @param expectedMetrics the expected metrics. The format is + * `nodeId -> (operatorName, metric name -> metric value)`. + */ + private def testSparkPlanMetrics( + df: DataFrame, + expectedNumOfJobs: Int, + expectedMetrics: Map[Long, (String, Map[String, Any])]): Unit = { + val previousExecutionIds = TestSQLContext.listener.executionIdToData.keySet + df.collect() + TestSQLContext.sparkContext.listenerBus.waitUntilEmpty(10000) + val executionIds = TestSQLContext.listener.executionIdToData.keySet.diff(previousExecutionIds) + assert(executionIds.size === 1) + val executionId = executionIds.head + val jobs = TestSQLContext.listener.getExecution(executionId).get.jobs + // Use "<=" because there is a race condition that we may miss some jobs + // TODO Change it to "=" once we fix the race condition that missing the JobStarted event. + assert(jobs.size <= expectedNumOfJobs) + if (jobs.size == expectedNumOfJobs) { + // If we can track all jobs, check the metric values + val metricValues = TestSQLContext.listener.getExecutionMetrics(executionId) + val actualMetrics = SparkPlanGraph(df.queryExecution.executedPlan).nodes.filter { node => + expectedMetrics.contains(node.id) + }.map { node => + val nodeMetrics = node.metrics.map { metric => + val metricValue = metricValues(metric.accumulatorId) + (metric.name, metricValue) + }.toMap + (node.id, node.name -> nodeMetrics) + }.toMap + assert(expectedMetrics === actualMetrics) + } else { + // TODO Remove this "else" once we fix the race condition that missing the JobStarted event. + // Since we cannot track all jobs, the metric values could be wrong and we should not check + // them. 
+ logWarning("Due to a race condition, we miss some jobs and cannot verify the metric values") + } + } + + test("Project metrics") { + withSQLConf( + SQLConf.UNSAFE_ENABLED.key -> "false", + SQLConf.CODEGEN_ENABLED.key -> "false", + SQLConf.TUNGSTEN_ENABLED.key -> "false") { + // Assume the execution plan is + // PhysicalRDD(nodeId = 1) -> Project(nodeId = 0) + val df = TestData.person.select('name) + testSparkPlanMetrics(df, 1, Map( + 0L ->("Project", Map( + "number of rows" -> 2L))) + ) + } + } + + test("TungstenProject metrics") { + withSQLConf( + SQLConf.UNSAFE_ENABLED.key -> "true", + SQLConf.CODEGEN_ENABLED.key -> "true", + SQLConf.TUNGSTEN_ENABLED.key -> "true") { + // Assume the execution plan is + // PhysicalRDD(nodeId = 1) -> TungstenProject(nodeId = 0) + val df = TestData.person.select('name) + testSparkPlanMetrics(df, 1, Map( + 0L ->("TungstenProject", Map( + "number of rows" -> 2L))) + ) + } + } + + test("Filter metrics") { + // Assume the execution plan is + // PhysicalRDD(nodeId = 1) -> Filter(nodeId = 0) + val df = TestData.person.filter('age < 25) + testSparkPlanMetrics(df, 1, Map( + 0L -> ("Filter", Map( + "number of input rows" -> 2L, + "number of output rows" -> 1L))) + ) + } + + test("Aggregate metrics") { + withSQLConf( + SQLConf.UNSAFE_ENABLED.key -> "false", + SQLConf.CODEGEN_ENABLED.key -> "false", + SQLConf.TUNGSTEN_ENABLED.key -> "false") { + // Assume the execution plan is + // ... -> Aggregate(nodeId = 2) -> TungstenExchange(nodeId = 1) -> Aggregate(nodeId = 0) + val df = TestData.testData2.groupBy().count() // 2 partitions + testSparkPlanMetrics(df, 1, Map( + 2L -> ("Aggregate", Map( + "number of input rows" -> 6L, + "number of output rows" -> 2L)), + 0L -> ("Aggregate", Map( + "number of input rows" -> 2L, + "number of output rows" -> 1L))) + ) + + // 2 partitions and each partition contains 2 keys + val df2 = TestData.testData2.groupBy('a).count() + testSparkPlanMetrics(df2, 1, Map( + 2L -> ("Aggregate", Map( + "number of input rows" -> 6L, + "number of output rows" -> 4L)), + 0L -> ("Aggregate", Map( + "number of input rows" -> 4L, + "number of output rows" -> 3L))) + ) + } + } + + test("SortBasedAggregate metrics") { + // Because SortBasedAggregate may skip different rows if the number of partitions is different, + // this test should use the deterministic number of partitions. + withSQLConf( + SQLConf.UNSAFE_ENABLED.key -> "false", + SQLConf.CODEGEN_ENABLED.key -> "true", + SQLConf.TUNGSTEN_ENABLED.key -> "true") { + // Assume the execution plan is + // ... -> SortBasedAggregate(nodeId = 2) -> TungstenExchange(nodeId = 1) -> + // SortBasedAggregate(nodeId = 0) + val df = TestData.testData2.groupBy().count() // 2 partitions + testSparkPlanMetrics(df, 1, Map( + 2L -> ("SortBasedAggregate", Map( + "number of input rows" -> 6L, + "number of output rows" -> 2L)), + 0L -> ("SortBasedAggregate", Map( + "number of input rows" -> 2L, + "number of output rows" -> 1L))) + ) + + // Assume the execution plan is + // ... 
-> SortBasedAggregate(nodeId = 3) -> TungstenExchange(nodeId = 2) + // -> ExternalSort(nodeId = 1)-> SortBasedAggregate(nodeId = 0) + // 2 partitions and each partition contains 2 keys + val df2 = TestData.testData2.groupBy('a).count() + testSparkPlanMetrics(df2, 1, Map( + 3L -> ("SortBasedAggregate", Map( + "number of input rows" -> 6L, + "number of output rows" -> 4L)), + 0L -> ("SortBasedAggregate", Map( + "number of input rows" -> 4L, + "number of output rows" -> 3L))) + ) + } + } + + test("TungstenAggregate metrics") { + withSQLConf( + SQLConf.UNSAFE_ENABLED.key -> "true", + SQLConf.CODEGEN_ENABLED.key -> "true", + SQLConf.TUNGSTEN_ENABLED.key -> "true") { + // Assume the execution plan is + // ... -> TungstenAggregate(nodeId = 2) -> Exchange(nodeId = 1) + // -> TungstenAggregate(nodeId = 0) + val df = TestData.testData2.groupBy().count() // 2 partitions + testSparkPlanMetrics(df, 1, Map( + 2L -> ("TungstenAggregate", Map( + "number of input rows" -> 6L, + "number of output rows" -> 2L)), + 0L -> ("TungstenAggregate", Map( + "number of input rows" -> 2L, + "number of output rows" -> 1L))) + ) + + // 2 partitions and each partition contains 2 keys + val df2 = TestData.testData2.groupBy('a).count() + testSparkPlanMetrics(df2, 1, Map( + 2L -> ("TungstenAggregate", Map( + "number of input rows" -> 6L, + "number of output rows" -> 4L)), + 0L -> ("TungstenAggregate", Map( + "number of input rows" -> 4L, + "number of output rows" -> 3L))) + ) + } + } + + test("SortMergeJoin metrics") { + // Because SortMergeJoin may skip different rows if the number of partitions is different, this + // test should use the deterministic number of partitions. + withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "true") { + val testDataForJoin = TestData.testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2) + testDataForJoin.registerTempTable("testDataForJoin") + withTempTable("testDataForJoin") { + // Assume the execution plan is + // ... -> SortMergeJoin(nodeId = 1) -> TungstenProject(nodeId = 0) + val df = sqlContext.sql( + "SELECT * FROM testData2 JOIN testDataForJoin ON testData2.a = testDataForJoin.a") + testSparkPlanMetrics(df, 1, Map( + 1L -> ("SortMergeJoin", Map( + // It's 4 because we only read 3 rows in the first partition and 1 row in the second one + "number of left rows" -> 4L, + "number of right rows" -> 2L, + "number of output rows" -> 4L))) + ) + } + } + } + + test("SortMergeOuterJoin metrics") { + // Because SortMergeOuterJoin may skip different rows if the number of partitions is different, + // this test should use the deterministic number of partitions. + withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "true") { + val testDataForJoin = TestData.testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2) + testDataForJoin.registerTempTable("testDataForJoin") + withTempTable("testDataForJoin") { + // Assume the execution plan is + // ... 
-> SortMergeOuterJoin(nodeId = 1) -> TungstenProject(nodeId = 0) + val df = sqlContext.sql( + "SELECT * FROM testData2 left JOIN testDataForJoin ON testData2.a = testDataForJoin.a") + testSparkPlanMetrics(df, 1, Map( + 1L -> ("SortMergeOuterJoin", Map( + // It's 4 because we only read 3 rows in the first partition and 1 row in the second one + "number of left rows" -> 6L, + "number of right rows" -> 2L, + "number of output rows" -> 8L))) + ) + + val df2 = sqlContext.sql( + "SELECT * FROM testDataForJoin right JOIN testData2 ON testData2.a = testDataForJoin.a") + testSparkPlanMetrics(df2, 1, Map( + 1L -> ("SortMergeOuterJoin", Map( + // It's 4 because we only read 3 rows in the first partition and 1 row in the second one + "number of left rows" -> 2L, + "number of right rows" -> 6L, + "number of output rows" -> 8L))) + ) + } + } + } + + test("BroadcastHashJoin metrics") { + withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "false") { + val df1 = Seq((1, "1"), (2, "2")).toDF("key", "value") + val df2 = Seq((1, "1"), (2, "2"), (3, "3"), (4, "4")).toDF("key", "value") + // Assume the execution plan is + // ... -> BroadcastHashJoin(nodeId = 1) -> TungstenProject(nodeId = 0) + val df = df1.join(broadcast(df2), "key") + testSparkPlanMetrics(df, 2, Map( + 1L -> ("BroadcastHashJoin", Map( + "number of left rows" -> 2L, + "number of right rows" -> 4L, + "number of output rows" -> 2L))) + ) + } + } + + test("ShuffledHashJoin metrics") { + withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "false") { + val testDataForJoin = TestData.testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2) + testDataForJoin.registerTempTable("testDataForJoin") + withTempTable("testDataForJoin") { + // Assume the execution plan is + // ... -> ShuffledHashJoin(nodeId = 1) -> TungstenProject(nodeId = 0) + val df = sqlContext.sql( + "SELECT * FROM testData2 JOIN testDataForJoin ON testData2.a = testDataForJoin.a") + testSparkPlanMetrics(df, 1, Map( + 1L -> ("ShuffledHashJoin", Map( + "number of left rows" -> 6L, + "number of right rows" -> 2L, + "number of output rows" -> 4L))) + ) + } + } + } + + test("ShuffledHashOuterJoin metrics") { + withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "false", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "0") { + val df1 = Seq((1, "a"), (1, "b"), (4, "c")).toDF("key", "value") + val df2 = Seq((1, "a"), (1, "b"), (2, "c"), (3, "d")).toDF("key2", "value") + // Assume the execution plan is + // ... -> ShuffledHashOuterJoin(nodeId = 0) + val df = df1.join(df2, $"key" === $"key2", "left_outer") + testSparkPlanMetrics(df, 1, Map( + 0L -> ("ShuffledHashOuterJoin", Map( + "number of left rows" -> 3L, + "number of right rows" -> 4L, + "number of output rows" -> 5L))) + ) + + val df3 = df1.join(df2, $"key" === $"key2", "right_outer") + testSparkPlanMetrics(df3, 1, Map( + 0L -> ("ShuffledHashOuterJoin", Map( + "number of left rows" -> 3L, + "number of right rows" -> 4L, + "number of output rows" -> 6L))) + ) + + val df4 = df1.join(df2, $"key" === $"key2", "outer") + testSparkPlanMetrics(df4, 1, Map( + 0L -> ("ShuffledHashOuterJoin", Map( + "number of left rows" -> 3L, + "number of right rows" -> 4L, + "number of output rows" -> 7L))) + ) + } + } + + test("BroadcastHashOuterJoin metrics") { + withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "false") { + val df1 = Seq((1, "a"), (1, "b"), (4, "c")).toDF("key", "value") + val df2 = Seq((1, "a"), (1, "b"), (2, "c"), (3, "d")).toDF("key2", "value") + // Assume the execution plan is + // ... 
-> BroadcastHashOuterJoin(nodeId = 0) + val df = df1.join(broadcast(df2), $"key" === $"key2", "left_outer") + testSparkPlanMetrics(df, 2, Map( + 0L -> ("BroadcastHashOuterJoin", Map( + "number of left rows" -> 3L, + "number of right rows" -> 4L, + "number of output rows" -> 5L))) + ) + + val df3 = df1.join(broadcast(df2), $"key" === $"key2", "right_outer") + testSparkPlanMetrics(df3, 2, Map( + 0L -> ("BroadcastHashOuterJoin", Map( + "number of left rows" -> 3L, + "number of right rows" -> 4L, + "number of output rows" -> 6L))) + ) + } + } + + test("BroadcastNestedLoopJoin metrics") { + withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "true") { + val testDataForJoin = TestData.testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2) + testDataForJoin.registerTempTable("testDataForJoin") + withTempTable("testDataForJoin") { + // Assume the execution plan is + // ... -> BroadcastNestedLoopJoin(nodeId = 1) -> TungstenProject(nodeId = 0) + val df = sqlContext.sql( + "SELECT * FROM testData2 left JOIN testDataForJoin ON " + + "testData2.a * testDataForJoin.a != testData2.a + testDataForJoin.a") + testSparkPlanMetrics(df, 3, Map( + 1L -> ("BroadcastNestedLoopJoin", Map( + "number of left rows" -> 12L, // left needs to be scanned twice + "number of right rows" -> 2L, + "number of output rows" -> 12L))) + ) + } + } + } + + test("BroadcastLeftSemiJoinHash metrics") { + withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "false") { + val df1 = Seq((1, "1"), (2, "2")).toDF("key", "value") + val df2 = Seq((1, "1"), (2, "2"), (3, "3"), (4, "4")).toDF("key2", "value") + // Assume the execution plan is + // ... -> BroadcastLeftSemiJoinHash(nodeId = 0) + val df = df1.join(broadcast(df2), $"key" === $"key2", "leftsemi") + testSparkPlanMetrics(df, 2, Map( + 0L -> ("BroadcastLeftSemiJoinHash", Map( + "number of left rows" -> 2L, + "number of right rows" -> 4L, + "number of output rows" -> 2L))) + ) + } + } + + test("LeftSemiJoinHash metrics") { + withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "true", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "0") { + val df1 = Seq((1, "1"), (2, "2")).toDF("key", "value") + val df2 = Seq((1, "1"), (2, "2"), (3, "3"), (4, "4")).toDF("key2", "value") + // Assume the execution plan is + // ... -> LeftSemiJoinHash(nodeId = 0) + val df = df1.join(df2, $"key" === $"key2", "leftsemi") + testSparkPlanMetrics(df, 1, Map( + 0L -> ("LeftSemiJoinHash", Map( + "number of left rows" -> 2L, + "number of right rows" -> 4L, + "number of output rows" -> 2L))) + ) + } + } + + test("LeftSemiJoinBNL metrics") { + withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "false") { + val df1 = Seq((1, "1"), (2, "2")).toDF("key", "value") + val df2 = Seq((1, "1"), (2, "2"), (3, "3"), (4, "4")).toDF("key2", "value") + // Assume the execution plan is + // ... -> LeftSemiJoinBNL(nodeId = 0) + val df = df1.join(df2, $"key" < $"key2", "leftsemi") + testSparkPlanMetrics(df, 2, Map( + 0L -> ("LeftSemiJoinBNL", Map( + "number of left rows" -> 2L, + "number of right rows" -> 4L, + "number of output rows" -> 2L))) + ) + } + } + + test("CartesianProduct metrics") { + val testDataForJoin = TestData.testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2) + testDataForJoin.registerTempTable("testDataForJoin") + withTempTable("testDataForJoin") { + // Assume the execution plan is + // ... 
-> CartesianProduct(nodeId = 1) -> TungstenProject(nodeId = 0) + val df = sqlContext.sql( + "SELECT * FROM testData2 JOIN testDataForJoin") + testSparkPlanMetrics(df, 1, Map( + 1L -> ("CartesianProduct", Map( + "number of left rows" -> 12L, // left needs to be scanned twice + "number of right rows" -> 12L, // right is read 6 times + "number of output rows" -> 12L))) + ) + } + } + + test("save metrics") { + withTempPath { file => + val previousExecutionIds = TestSQLContext.listener.executionIdToData.keySet + // Assume the execution plan is + // PhysicalRDD(nodeId = 0) + TestData.person.select('name).write.format("json").save(file.getAbsolutePath) + TestSQLContext.sparkContext.listenerBus.waitUntilEmpty(10000) + val executionIds = TestSQLContext.listener.executionIdToData.keySet.diff(previousExecutionIds) + assert(executionIds.size === 1) + val executionId = executionIds.head + val jobs = TestSQLContext.listener.getExecution(executionId).get.jobs + // Use "<=" because there is a race condition that we may miss some jobs + // TODO Change "<=" to "=" once we fix the race condition that missing the JobStarted event. + assert(jobs.size <= 1) + val metricValues = TestSQLContext.listener.getExecutionMetrics(executionId) + // Because "save" will create a new DataFrame internally, we cannot get the real metric id. + // However, we still can check the value. + assert(metricValues.values.toSeq === Seq(2L)) + } + } + } private case class MethodIdentifier[T](cls: Class[T], name: String, desc: String) From 520ad44b17f72e6465bf990f64b4e289f8a83447 Mon Sep 17 00:00:00 2001 From: Feynman Liang Date: Tue, 11 Aug 2015 12:49:47 -0700 Subject: [PATCH 0167/1824] [SPARK-9750] [MLLIB] Improve equals on SparseMatrix and DenseMatrix Adds unit test for `equals` on `mllib.linalg.Matrix` class and `equals` to both `SparseMatrix` and `DenseMatrix`. Supports equality testing between `SparseMatrix` and `DenseMatrix`. 
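For illustration, a minimal usage sketch mirroring the test added below (the matrix values are arbitrary):

import org.apache.spark.mllib.linalg.{DenseMatrix, Matrices}

val dm = Matrices.dense(2, 2, Array(0.0, 1.0, 2.0, 3.0))
val sm = dm.asInstanceOf[DenseMatrix].toSparse

// Equality is decided by the matrix entries, not by the storage format.
assert(sm == dm)
// A transpose has different entries at (0, 1) and (1, 0), so it is not equal.
assert(dm != dm.transpose)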
mengxr Author: Feynman Liang Closes #8042 from feynmanliang/SPARK-9750 and squashes the following commits: bb70d5e [Feynman Liang] Breeze compare for dense matrices as well, in case other is sparse ab6f3c8 [Feynman Liang] Sparse matrix compare for equals 22782df [Feynman Liang] Add equality based on matrix semantics, not representation 78f9426 [Feynman Liang] Add casts 43d28fa [Feynman Liang] Fix failing test 6416fa0 [Feynman Liang] Add failing sparse matrix equals tests --- .../apache/spark/mllib/linalg/Matrices.scala | 8 ++++++-- .../spark/mllib/linalg/MatricesSuite.scala | 18 ++++++++++++++++++ 2 files changed, 24 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala index 1c858348bf20..1139ce36d50b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala @@ -257,8 +257,7 @@ class DenseMatrix( this(numRows, numCols, values, false) override def equals(o: Any): Boolean = o match { - case m: DenseMatrix => - m.numRows == numRows && m.numCols == numCols && Arrays.equals(toArray, m.toArray) + case m: Matrix => toBreeze == m.toBreeze case _ => false } @@ -519,6 +518,11 @@ class SparseMatrix( rowIndices: Array[Int], values: Array[Double]) = this(numRows, numCols, colPtrs, rowIndices, values, false) + override def equals(o: Any): Boolean = o match { + case m: Matrix => toBreeze == m.toBreeze + case _ => false + } + private[mllib] def toBreeze: BM[Double] = { if (!isTransposed) { new BSM[Double](values, numRows, numCols, colPtrs, rowIndices) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala index a270ba2562db..bfd6d5495f5e 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala @@ -74,6 +74,24 @@ class MatricesSuite extends SparkFunSuite { } } + test("equals") { + val dm1 = Matrices.dense(2, 2, Array(0.0, 1.0, 2.0, 3.0)) + assert(dm1 === dm1) + assert(dm1 !== dm1.transpose) + + val dm2 = Matrices.dense(2, 2, Array(0.0, 2.0, 1.0, 3.0)) + assert(dm1 === dm2.transpose) + + val sm1 = dm1.asInstanceOf[DenseMatrix].toSparse + assert(sm1 === sm1) + assert(sm1 === dm1) + assert(sm1 !== sm1.transpose) + + val sm2 = dm2.asInstanceOf[DenseMatrix].toSparse + assert(sm1 === sm2.transpose) + assert(sm1 === dm2.transpose) + } + test("matrix copies are deep copies") { val m = 3 val n = 2 From 2a3be4ddf9d9527353f07ea0ab204ce17dbcba9a Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Tue, 11 Aug 2015 14:02:23 -0700 Subject: [PATCH 0168/1824] [SPARK-7726] Add import so Scaladoc doesn't fail. This is another import needed so Scala 2.11 doc generation doesn't fail. See SPARK-7726 for more detail. I tested this locally and the 2.11 install goes from failing to succeeding with this patch. Author: Patrick Wendell Closes #8095 from pwendell/scaladoc. 
--- .../spark/network/shuffle/protocol/mesos/RegisterDriver.java | 3 +++ 1 file changed, 3 insertions(+) diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/mesos/RegisterDriver.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/mesos/RegisterDriver.java index 1c28fc1dff24..94a61d6caadc 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/mesos/RegisterDriver.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/mesos/RegisterDriver.java @@ -23,6 +23,9 @@ import org.apache.spark.network.protocol.Encoders; import org.apache.spark.network.shuffle.protocol.BlockTransferMessage; +// Needed by ScalaDoc. See SPARK-7726 +import static org.apache.spark.network.shuffle.protocol.BlockTransferMessage.Type; + /** * A message sent from the driver to register with the MesosExternalShuffleService. */ From 00c02728a6c6c4282c389ca90641dd78dd5e3d32 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Tue, 11 Aug 2015 14:04:09 -0700 Subject: [PATCH 0169/1824] [SPARK-9814] [SQL] EqualNotNull not passing to data sources MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Author: hyukjinkwon Author: 권혁진 Closes #8096 from HyukjinKwon/master. --- .../sql/execution/datasources/DataSourceStrategy.scala | 5 +++++ .../scala/org/apache/spark/sql/sources/filters.scala | 9 +++++++++ .../org/apache/spark/sql/sources/FilteredScanSuite.scala | 1 + 3 files changed, 15 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 78a4acdf4b1b..2a4c40db8bb6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -349,6 +349,11 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { case expressions.EqualTo(Literal(v, _), a: Attribute) => Some(sources.EqualTo(a.name, v)) + case expressions.EqualNullSafe(a: Attribute, Literal(v, _)) => + Some(sources.EqualNullSafe(a.name, v)) + case expressions.EqualNullSafe(Literal(v, _), a: Attribute) => + Some(sources.EqualNullSafe(a.name, v)) + case expressions.GreaterThan(a: Attribute, Literal(v, _)) => Some(sources.GreaterThan(a.name, v)) case expressions.GreaterThan(Literal(v, _), a: Attribute) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala index 4d942e4f9287..3780cbbcc963 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala @@ -36,6 +36,15 @@ abstract class Filter */ case class EqualTo(attribute: String, value: Any) extends Filter +/** + * Performs equality comparison, similar to [[EqualTo]]. However, this differs from [[EqualTo]] + * in that it returns `true` (rather than NULL) if both inputs are NULL, and `false` + * (rather than NULL) if one of the input is NULL and the other is not NULL. + * + * @since 1.5.0 + */ +case class EqualNullSafe(attribute: String, value: Any) extends Filter + /** * A filter that evaluates to `true` iff the attribute evaluates to a value * greater than `value`.
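The doc comment added above for `EqualNullSafe` amounts to a three-case rule. A standalone sketch of that rule in plain Scala, with `Option` standing in for SQL NULL (this is not Spark code, just the semantics spelled out):

// Null-safe equality: NULL <=> NULL is true, NULL <=> x is false, otherwise ordinary equality.
def eqNullSafe[T](left: Option[T], right: Option[T]): Boolean = (left, right) match {
  case (None, None)       => true
  case (Some(l), Some(r)) => l == r
  case _                  => false
}

assert(eqNullSafe(None, None))       // EqualTo would evaluate to NULL here
assert(!eqNullSafe(Some(1), None))   // EqualTo would evaluate to NULL here as well
assert(eqNullSafe(Some(1), Some(1)))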
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala index 81b3a0f0c5b3..5ef365797eac 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala @@ -56,6 +56,7 @@ case class SimpleFilteredScan(from: Int, to: Int)(@transient val sqlContext: SQL // Predicate test on integer column def translateFilterOnA(filter: Filter): Int => Boolean = filter match { case EqualTo("a", v) => (a: Int) => a == v + case EqualNullSafe("a", v) => (a: Int) => a == v case LessThan("a", v: Int) => (a: Int) => a < v case LessThanOrEqual("a", v: Int) => (a: Int) => a <= v case GreaterThan("a", v: Int) => (a: Int) => a > v From f16bc68dfb25c7b746ae031a57840ace9bafa87f Mon Sep 17 00:00:00 2001 From: zsxwing Date: Tue, 11 Aug 2015 14:06:23 -0700 Subject: [PATCH 0170/1824] [SPARK-9824] [CORE] Fix the issue that InternalAccumulator leaks WeakReference `InternalAccumulator.create` doesn't call `registerAccumulatorForCleanup` to register itself with ContextCleaner, so `WeakReference`s for these accumulators in `Accumulators.originals` won't be removed. This PR added `registerAccumulatorForCleanup` for internal accumulators to avoid the memory leak. Author: zsxwing Closes #8108 from zsxwing/internal-accumulators-leak. --- .../scala/org/apache/spark/Accumulators.scala | 22 +++++++++++-------- .../org/apache/spark/scheduler/Stage.scala | 2 +- .../org/apache/spark/AccumulatorSuite.scala | 3 ++- 3 files changed, 16 insertions(+), 11 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/Accumulators.scala b/core/src/main/scala/org/apache/spark/Accumulators.scala index 064246dfa7fc..c39c8667d013 100644 --- a/core/src/main/scala/org/apache/spark/Accumulators.scala +++ b/core/src/main/scala/org/apache/spark/Accumulators.scala @@ -382,14 +382,18 @@ private[spark] object InternalAccumulator { * add to the same set of accumulators. We do this to report the distribution of accumulator * values across all tasks within each stage. */ - def create(): Seq[Accumulator[Long]] = { - Seq( - // Execution memory refers to the memory used by internal data structures created - // during shuffles, aggregations and joins. The value of this accumulator should be - // approximately the sum of the peak sizes across all such data structures created - // in this task. For SQL jobs, this only tracks all unsafe operators and ExternalSort. - new Accumulator( - 0L, AccumulatorParam.LongAccumulatorParam, Some(PEAK_EXECUTION_MEMORY), internal = true) - ) ++ maybeTestAccumulator.toSeq + def create(sc: SparkContext): Seq[Accumulator[Long]] = { + val internalAccumulators = Seq( + // Execution memory refers to the memory used by internal data structures created + // during shuffles, aggregations and joins. The value of this accumulator should be + // approximately the sum of the peak sizes across all such data structures created + // in this task. For SQL jobs, this only tracks all unsafe operators and ExternalSort. 
+ new Accumulator( + 0L, AccumulatorParam.LongAccumulatorParam, Some(PEAK_EXECUTION_MEMORY), internal = true) + ) ++ maybeTestAccumulator.toSeq + internalAccumulators.foreach { accumulator => + sc.cleaner.foreach(_.registerAccumulatorForCleanup(accumulator)) + } + internalAccumulators } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala index de05ee256dbf..1cf06856ffbc 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala @@ -81,7 +81,7 @@ private[spark] abstract class Stage( * accumulators here again will override partial values from the finished tasks. */ def resetInternalAccumulators(): Unit = { - _internalAccumulators = InternalAccumulator.create() + _internalAccumulators = InternalAccumulator.create(rdd.sparkContext) } /** diff --git a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala index 48f549575f4d..0eb2293a9d06 100644 --- a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala @@ -160,7 +160,8 @@ class AccumulatorSuite extends SparkFunSuite with Matchers with LocalSparkContex } test("internal accumulators in TaskContext") { - val accums = InternalAccumulator.create() + sc = new SparkContext("local", "test") + val accums = InternalAccumulator.create(sc) val taskContext = new TaskContextImpl(0, 0, 0, 0, null, null, accums) val internalMetricsToAccums = taskContext.internalMetricsToAccumulators val collectedInternalAccums = taskContext.collectInternalAccumulators() From 423cdfd83d7fd02a4f8cf3e714db913fd3f9ca09 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Tue, 11 Aug 2015 14:08:09 -0700 Subject: [PATCH 0171/1824] Closes #1290 Closes #4934 From be3e27164133025db860781bd5cdd3ca233edd21 Mon Sep 17 00:00:00 2001 From: Feynman Liang Date: Tue, 11 Aug 2015 14:21:53 -0700 Subject: [PATCH 0172/1824] [SPARK-9788] [MLLIB] Fix LDA Binary Compatibility MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 1. Add "asymmetricDocConcentration" and revert docConcentration changes. If the (internal) doc concentration vector is a single value, "getDocConcentration" returns it. If it is a constant vector, getDocConcentration returns the first item, and fails otherwise. 2. Give `LDAModel.gammaShape` a default value in `LDAModel` concrete class constructors. jkbradley Author: Feynman Liang Closes #8077 from feynmanliang/SPARK-9788 and squashes the following commits: 6b07bc8 [Feynman Liang] Code review changes 9d6a71e [Feynman Liang] Add asymmetricAlpha alias bf4e685 [Feynman Liang] Asymmetric docConcentration 4cab972 [Feynman Liang] Default gammaShape --- .../apache/spark/mllib/clustering/LDA.scala | 27 ++++++++++++++++-- .../spark/mllib/clustering/LDAModel.scala | 11 ++++---- .../spark/mllib/clustering/LDAOptimizer.scala | 28 +++++++++---------- .../spark/mllib/clustering/LDASuite.scala | 4 +-- 4 files changed, 46 insertions(+), 24 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala index ab124e6d77c5..0fc9b1ac4d71 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala @@ -79,7 +79,24 @@ class LDA private ( * * This is the parameter to a Dirichlet distribution.
*/ - def getDocConcentration: Vector = this.docConcentration + def getAsymmetricDocConcentration: Vector = this.docConcentration + + /** + * Concentration parameter (commonly named "alpha") for the prior placed on documents' + * distributions over topics ("theta"). + * + * This method assumes the Dirichlet distribution is symmetric and can be described by a single + * [[Double]] parameter. It should fail if docConcentration is asymmetric. + */ + def getDocConcentration: Double = { + val parameter = docConcentration(0) + if (docConcentration.size == 1) { + parameter + } else { + require(docConcentration.toArray.forall(_ == parameter)) + parameter + } + } /** * Concentration parameter (commonly named "alpha") for the prior placed on documents' @@ -106,18 +123,22 @@ class LDA private ( * [[https://github.com/Blei-Lab/onlineldavb]]. */ def setDocConcentration(docConcentration: Vector): this.type = { + require(docConcentration.size > 0, "docConcentration must have > 0 elements") this.docConcentration = docConcentration this } - /** Replicates Double to create a symmetric prior */ + /** Replicates a [[Double]] docConcentration to create a symmetric prior. */ def setDocConcentration(docConcentration: Double): this.type = { this.docConcentration = Vectors.dense(docConcentration) this } + /** Alias for [[getAsymmetricDocConcentration]] */ + def getAsymmetricAlpha: Vector = getAsymmetricDocConcentration + /** Alias for [[getDocConcentration]] */ - def getAlpha: Vector = getDocConcentration + def getAlpha: Double = getDocConcentration /** Alias for [[setDocConcentration()]] */ def setAlpha(alpha: Vector): this.type = setDocConcentration(alpha) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala index 33babda69bbb..5dc637ebdc13 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala @@ -27,7 +27,6 @@ import org.json4s.jackson.JsonMethods._ import org.apache.spark.SparkContext import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaPairRDD -import org.apache.spark.broadcast.Broadcast import org.apache.spark.graphx.{Edge, EdgeContext, Graph, VertexId} import org.apache.spark.mllib.linalg.{Matrices, Matrix, Vector, Vectors} import org.apache.spark.mllib.util.{Loader, Saveable} @@ -190,7 +189,8 @@ class LocalLDAModel private[clustering] ( val topics: Matrix, override val docConcentration: Vector, override val topicConcentration: Double, - override protected[clustering] val gammaShape: Double) extends LDAModel with Serializable { + override protected[clustering] val gammaShape: Double = 100) + extends LDAModel with Serializable { override def k: Int = topics.numCols @@ -455,8 +455,9 @@ class DistributedLDAModel private[clustering] ( val vocabSize: Int, override val docConcentration: Vector, override val topicConcentration: Double, - override protected[clustering] val gammaShape: Double, - private[spark] val iterationTimes: Array[Double]) extends LDAModel { + private[spark] val iterationTimes: Array[Double], + override protected[clustering] val gammaShape: Double = 100) + extends LDAModel { import LDA._ @@ -756,7 +757,7 @@ object DistributedLDAModel extends Loader[DistributedLDAModel] { val graph: Graph[LDA.TopicCounts, LDA.TokenCount] = Graph(vertices, edges) new DistributedLDAModel(graph, globalTopicTotals, globalTopicTotals.length, vocabSize, - docConcentration, 
topicConcentration, gammaShape, iterationTimes) + docConcentration, topicConcentration, iterationTimes, gammaShape) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala index afba2866c704..a0008f9c99ad 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala @@ -95,10 +95,8 @@ final class EMLDAOptimizer extends LDAOptimizer { * Compute bipartite term/doc graph. */ override private[clustering] def initialize(docs: RDD[(Long, Vector)], lda: LDA): LDAOptimizer = { - val docConcentration = lda.getDocConcentration(0) - require({ - lda.getDocConcentration.toArray.forall(_ == docConcentration) - }, "EMLDAOptimizer currently only supports symmetric document-topic priors") + // EMLDAOptimizer currently only supports symmetric document-topic priors + val docConcentration = lda.getDocConcentration val topicConcentration = lda.getTopicConcentration val k = lda.getK @@ -209,11 +207,11 @@ final class EMLDAOptimizer extends LDAOptimizer { override private[clustering] def getLDAModel(iterationTimes: Array[Double]): LDAModel = { require(graph != null, "graph is null, EMLDAOptimizer not initialized.") this.graphCheckpointer.deleteAllCheckpoints() - // This assumes gammaShape = 100 in OnlineLDAOptimizer to ensure equivalence in LDAModel.toLocal - // conversion + // The constructor's default arguments assume gammaShape = 100 to ensure equivalence in + // LDAModel.toLocal conversion new DistributedLDAModel(this.graph, this.globalTopicTotals, this.k, this.vocabSize, Vectors.dense(Array.fill(this.k)(this.docConcentration)), this.topicConcentration, - 100, iterationTimes) + iterationTimes) } } @@ -378,18 +376,20 @@ final class OnlineLDAOptimizer extends LDAOptimizer { this.k = lda.getK this.corpusSize = docs.count() this.vocabSize = docs.first()._2.size - this.alpha = if (lda.getDocConcentration.size == 1) { - if (lda.getDocConcentration(0) == -1) Vectors.dense(Array.fill(k)(1.0 / k)) + this.alpha = if (lda.getAsymmetricDocConcentration.size == 1) { + if (lda.getAsymmetricDocConcentration(0) == -1) Vectors.dense(Array.fill(k)(1.0 / k)) else { - require(lda.getDocConcentration(0) >= 0, s"all entries in alpha must be >=0, got: $alpha") - Vectors.dense(Array.fill(k)(lda.getDocConcentration(0))) + require(lda.getAsymmetricDocConcentration(0) >= 0, + s"all entries in alpha must be >=0, got: $alpha") + Vectors.dense(Array.fill(k)(lda.getAsymmetricDocConcentration(0))) } } else { - require(lda.getDocConcentration.size == k, s"alpha must have length k, got: $alpha") - lda.getDocConcentration.foreachActive { case (_, x) => + require(lda.getAsymmetricDocConcentration.size == k, + s"alpha must have length k, got: $alpha") + lda.getAsymmetricDocConcentration.foreachActive { case (_, x) => require(x >= 0, s"all entries in alpha must be >= 0, got: $alpha") } - lda.getDocConcentration + lda.getAsymmetricDocConcentration } this.eta = if (lda.getTopicConcentration == -1) 1.0 / k else lda.getTopicConcentration this.randomGenerator = new Random(lda.getSeed) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala index fdc2554ab853..ce6a8eb8e8c4 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala @@ -160,8 
+160,8 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext { test("setter alias") { val lda = new LDA().setAlpha(2.0).setBeta(3.0) - assert(lda.getAlpha.toArray.forall(_ === 2.0)) - assert(lda.getDocConcentration.toArray.forall(_ === 2.0)) + assert(lda.getAsymmetricAlpha.toArray.forall(_ === 2.0)) + assert(lda.getAsymmetricDocConcentration.toArray.forall(_ === 2.0)) assert(lda.getBeta === 3.0) assert(lda.getTopicConcentration === 3.0) } From 017b5de07ef6cff249e984a2ab781c520249ac76 Mon Sep 17 00:00:00 2001 From: Sudhakar Thota Date: Tue, 11 Aug 2015 14:31:51 -0700 Subject: [PATCH 0173/1824] [SPARK-8925] [MLLIB] Add @since tags to mllib.util Went thru the history of changes the file MLUtils.scala and picked up the version that the change went in. Author: Sudhakar Thota Author: Sudhakar Thota Closes #7436 from sthota2014/SPARK-8925_thotas. --- .../org/apache/spark/mllib/util/MLUtils.scala | 22 ++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala index 7c5cfa7bd84c..26eb84a8dc0b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala @@ -64,6 +64,7 @@ object MLUtils { * feature dimensions. * @param minPartitions min number of partitions * @return labeled data stored as an RDD[LabeledPoint] + * @since 1.0.0 */ def loadLibSVMFile( sc: SparkContext, @@ -113,7 +114,10 @@ object MLUtils { } // Convenient methods for `loadLibSVMFile`. - + + /** + * @since 1.0.0 + */ @deprecated("use method without multiclass argument, which no longer has effect", "1.1.0") def loadLibSVMFile( sc: SparkContext, @@ -126,6 +130,7 @@ object MLUtils { /** * Loads labeled data in the LIBSVM format into an RDD[LabeledPoint], with the default number of * partitions. + * @since 1.0.0 */ def loadLibSVMFile( sc: SparkContext, @@ -133,6 +138,9 @@ object MLUtils { numFeatures: Int): RDD[LabeledPoint] = loadLibSVMFile(sc, path, numFeatures, sc.defaultMinPartitions) + /** + * @since 1.0.0 + */ @deprecated("use method without multiclass argument, which no longer has effect", "1.1.0") def loadLibSVMFile( sc: SparkContext, @@ -141,6 +149,9 @@ object MLUtils { numFeatures: Int): RDD[LabeledPoint] = loadLibSVMFile(sc, path, numFeatures) + /** + * @since 1.0.0 + */ @deprecated("use method without multiclass argument, which no longer has effect", "1.1.0") def loadLibSVMFile( sc: SparkContext, @@ -151,6 +162,7 @@ object MLUtils { /** * Loads binary labeled data in the LIBSVM format into an RDD[LabeledPoint], with number of * features determined automatically and the default number of partitions. + * @since 1.0.0 */ def loadLibSVMFile(sc: SparkContext, path: String): RDD[LabeledPoint] = loadLibSVMFile(sc, path, -1) @@ -181,12 +193,14 @@ object MLUtils { * @param path file or directory path in any Hadoop-supported file system URI * @param minPartitions min number of partitions * @return vectors stored as an RDD[Vector] + * @since 1.1.0 */ def loadVectors(sc: SparkContext, path: String, minPartitions: Int): RDD[Vector] = sc.textFile(path, minPartitions).map(Vectors.parse) /** * Loads vectors saved using `RDD[Vector].saveAsTextFile` with the default number of partitions. 
+ * @since 1.1.0 */ def loadVectors(sc: SparkContext, path: String): RDD[Vector] = sc.textFile(path, sc.defaultMinPartitions).map(Vectors.parse) @@ -197,6 +211,7 @@ object MLUtils { * @param path file or directory path in any Hadoop-supported file system URI * @param minPartitions min number of partitions * @return labeled points stored as an RDD[LabeledPoint] + * @since 1.1.0 */ def loadLabeledPoints(sc: SparkContext, path: String, minPartitions: Int): RDD[LabeledPoint] = sc.textFile(path, minPartitions).map(LabeledPoint.parse) @@ -204,6 +219,7 @@ object MLUtils { /** * Loads labeled points saved using `RDD[LabeledPoint].saveAsTextFile` with the default number of * partitions. + * @since 1.1.0 */ def loadLabeledPoints(sc: SparkContext, dir: String): RDD[LabeledPoint] = loadLabeledPoints(sc, dir, sc.defaultMinPartitions) @@ -220,6 +236,7 @@ object MLUtils { * * @deprecated Should use [[org.apache.spark.rdd.RDD#saveAsTextFile]] for saving and * [[org.apache.spark.mllib.util.MLUtils#loadLabeledPoints]] for loading. + * @since 1.0.0 */ @deprecated("Should use MLUtils.loadLabeledPoints instead.", "1.0.1") def loadLabeledData(sc: SparkContext, dir: String): RDD[LabeledPoint] = { @@ -241,6 +258,7 @@ object MLUtils { * * @deprecated Should use [[org.apache.spark.rdd.RDD#saveAsTextFile]] for saving and * [[org.apache.spark.mllib.util.MLUtils#loadLabeledPoints]] for loading. + * @since 1.0.0 */ @deprecated("Should use RDD[LabeledPoint].saveAsTextFile instead.", "1.0.1") def saveLabeledData(data: RDD[LabeledPoint], dir: String) { @@ -253,6 +271,7 @@ object MLUtils { * Return a k element array of pairs of RDDs with the first element of each pair * containing the training data, a complement of the validation data and the second * element, the validation data, containing a unique 1/kth of the data. Where k=numFolds. + * @since 1.0.0 */ @Experimental def kFold[T: ClassTag](rdd: RDD[T], numFolds: Int, seed: Int): Array[(RDD[T], RDD[T])] = { @@ -268,6 +287,7 @@ object MLUtils { /** * Returns a new vector with `1.0` (bias) appended to the input vector. + * @since 1.0.0 */ def appendBias(vector: Vector): Vector = { vector match { From 736af95bd0c41723d455246b634a0fb68b38a7c7 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Tue, 11 Aug 2015 14:52:52 -0700 Subject: [PATCH 0174/1824] [HOTFIX] Fix style error caused by 017b5de --- mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala index 26eb84a8dc0b..11ed23176fc1 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala @@ -114,7 +114,7 @@ object MLUtils { } // Convenient methods for `loadLibSVMFile`. - + /** * @since 1.0.0 */ From 5a5bbc29961630d649d4bd4acd5d19eb537b5fd0 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Tue, 11 Aug 2015 16:33:08 -0700 Subject: [PATCH 0175/1824] [SPARK-9074] [LAUNCHER] Allow arbitrary Spark args to be set. This change allows any Spark argument to be added to the app to be started using SparkLauncher. Known arguments are properly validated, while unknown arguments are allowed so that the library can launch newer Spark versions (in case SPARK_HOME points at one). Author: Marcelo Vanzin Closes #7975 from vanzin/SPARK-9074 and squashes the following commits: b5e451a [Marcelo Vanzin] [SPARK-9074] [launcher] Allow arbitrary Spark args to be set. 
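For readers skimming the patch below, a minimal usage sketch of the new addSparkArg methods; the snippet is illustrative only, written in Scala against the Java launcher API, and the jar path, class name, and argument values are placeholders rather than anything taken from the patch.

import org.apache.spark.launcher.SparkLauncher

object LauncherSketch {
  def main(args: Array[String]): Unit = {
    val launcher = new SparkLauncher()
      .setAppResource("/path/to/app.jar")     // placeholder application jar
      .setMainClass("com.example.MyApp")      // placeholder main class
      .setMaster("local[2]")
      .addSparkArg("--verbose")               // known no-value argument: validated
      .addSparkArg("--driver-memory", "2g")   // known argument with a value: validated
      .addSparkArg("--future-flag")           // unknown argument: allowed for forward compatibility
    val app = launcher.launch()
    app.waitFor()
  }
}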
--- .../apache/spark/launcher/SparkLauncher.java | 101 +++++++++++++++++- .../launcher/SparkSubmitCommandBuilder.java | 2 +- .../spark/launcher/SparkLauncherSuite.java | 50 +++++++++ 3 files changed, 150 insertions(+), 3 deletions(-) diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java b/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java index c0f89c923069..03c9358bc865 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java @@ -20,12 +20,13 @@ import java.io.File; import java.io.IOException; import java.util.ArrayList; +import java.util.Arrays; import java.util.List; import java.util.Map; import static org.apache.spark.launcher.CommandBuilderUtils.*; -/** +/** * Launcher for Spark applications. *

    * Use this class to start Spark applications programmatically. The class uses a builder pattern @@ -57,7 +58,8 @@ public class SparkLauncher { /** Configuration key for the number of executor CPU cores. */ public static final String EXECUTOR_CORES = "spark.executor.cores"; - private final SparkSubmitCommandBuilder builder; + // Visible for testing. + final SparkSubmitCommandBuilder builder; public SparkLauncher() { this(null); @@ -187,6 +189,73 @@ public SparkLauncher setMainClass(String mainClass) { return this; } + /** + * Adds a no-value argument to the Spark invocation. If the argument is known, this method + * validates whether the argument is indeed a no-value argument, and throws an exception + * otherwise. + *

    + * Use this method with caution. It is possible to create an invalid Spark command by passing + * unknown arguments to this method, since those are allowed for forward compatibility. + * + * @param arg Argument to add. + * @return This launcher. + */ + public SparkLauncher addSparkArg(String arg) { + SparkSubmitOptionParser validator = new ArgumentValidator(false); + validator.parse(Arrays.asList(arg)); + builder.sparkArgs.add(arg); + return this; + } + + /** + * Adds an argument with a value to the Spark invocation. If the argument name corresponds to + * a known argument, the code validates that the argument actually expects a value, and throws + * an exception otherwise. + *

    + * It is safe to add arguments modified by other methods in this class (such as + * {@link #setMaster(String)} - the last invocation will be the one to take effect. + *

    + * Use this method with caution. It is possible to create an invalid Spark command by passing + * unknown arguments to this method, since those are allowed for forward compatibility. + * + * @param name Name of argument to add. + * @param value Value of the argument. + * @return This launcher. + */ + public SparkLauncher addSparkArg(String name, String value) { + SparkSubmitOptionParser validator = new ArgumentValidator(true); + if (validator.MASTER.equals(name)) { + setMaster(value); + } else if (validator.PROPERTIES_FILE.equals(name)) { + setPropertiesFile(value); + } else if (validator.CONF.equals(name)) { + String[] vals = value.split("=", 2); + setConf(vals[0], vals[1]); + } else if (validator.CLASS.equals(name)) { + setMainClass(value); + } else if (validator.JARS.equals(name)) { + builder.jars.clear(); + for (String jar : value.split(",")) { + addJar(jar); + } + } else if (validator.FILES.equals(name)) { + builder.files.clear(); + for (String file : value.split(",")) { + addFile(file); + } + } else if (validator.PY_FILES.equals(name)) { + builder.pyFiles.clear(); + for (String file : value.split(",")) { + addPyFile(file); + } + } else { + validator.parse(Arrays.asList(name, value)); + builder.sparkArgs.add(name); + builder.sparkArgs.add(value); + } + return this; + } + /** * Adds command line arguments for the application. * @@ -277,4 +346,32 @@ public Process launch() throws IOException { return pb.start(); } + private static class ArgumentValidator extends SparkSubmitOptionParser { + + private final boolean hasValue; + + ArgumentValidator(boolean hasValue) { + this.hasValue = hasValue; + } + + @Override + protected boolean handle(String opt, String value) { + if (value == null && hasValue) { + throw new IllegalArgumentException(String.format("'%s' does not expect a value.", opt)); + } + return true; + } + + @Override + protected boolean handleUnknown(String opt) { + // Do not fail on unknown arguments, to support future arguments added to SparkSubmit. + return true; + } + + protected void handleExtraArgs(List extra) { + // No op. 
+ } + + }; + } diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java index 87c43aa9980e..4f354cedee66 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java @@ -76,7 +76,7 @@ class SparkSubmitCommandBuilder extends AbstractCommandBuilder { "spark-internal"); } - private final List sparkArgs; + final List sparkArgs; private final boolean printHelp; /** diff --git a/launcher/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java b/launcher/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java index 252d5abae1ca..d0c26dd05679 100644 --- a/launcher/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java +++ b/launcher/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java @@ -20,6 +20,7 @@ import java.io.BufferedReader; import java.io.InputStream; import java.io.InputStreamReader; +import java.util.Arrays; import java.util.HashMap; import java.util.Map; @@ -35,8 +36,54 @@ public class SparkLauncherSuite { private static final Logger LOG = LoggerFactory.getLogger(SparkLauncherSuite.class); + @Test + public void testSparkArgumentHandling() throws Exception { + SparkLauncher launcher = new SparkLauncher() + .setSparkHome(System.getProperty("spark.test.home")); + SparkSubmitOptionParser opts = new SparkSubmitOptionParser(); + + launcher.addSparkArg(opts.HELP); + try { + launcher.addSparkArg(opts.PROXY_USER); + fail("Expected IllegalArgumentException."); + } catch (IllegalArgumentException e) { + // Expected. + } + + launcher.addSparkArg(opts.PROXY_USER, "someUser"); + try { + launcher.addSparkArg(opts.HELP, "someValue"); + fail("Expected IllegalArgumentException."); + } catch (IllegalArgumentException e) { + // Expected. 
+ } + + launcher.addSparkArg("--future-argument"); + launcher.addSparkArg("--future-argument", "someValue"); + + launcher.addSparkArg(opts.MASTER, "myMaster"); + assertEquals("myMaster", launcher.builder.master); + + launcher.addJar("foo"); + launcher.addSparkArg(opts.JARS, "bar"); + assertEquals(Arrays.asList("bar"), launcher.builder.jars); + + launcher.addFile("foo"); + launcher.addSparkArg(opts.FILES, "bar"); + assertEquals(Arrays.asList("bar"), launcher.builder.files); + + launcher.addPyFile("foo"); + launcher.addSparkArg(opts.PY_FILES, "bar"); + assertEquals(Arrays.asList("bar"), launcher.builder.pyFiles); + + launcher.setConf("spark.foo", "foo"); + launcher.addSparkArg(opts.CONF, "spark.foo=bar"); + assertEquals("bar", launcher.builder.conf.get("spark.foo")); + } + @Test public void testChildProcLauncher() throws Exception { + SparkSubmitOptionParser opts = new SparkSubmitOptionParser(); Map env = new HashMap(); env.put("SPARK_PRINT_LAUNCH_COMMAND", "1"); @@ -44,9 +91,12 @@ public void testChildProcLauncher() throws Exception { .setSparkHome(System.getProperty("spark.test.home")) .setMaster("local") .setAppResource("spark-internal") + .addSparkArg(opts.CONF, + String.format("%s=-Dfoo=ShouldBeOverriddenBelow", SparkLauncher.DRIVER_EXTRA_JAVA_OPTIONS)) .setConf(SparkLauncher.DRIVER_EXTRA_JAVA_OPTIONS, "-Dfoo=bar -Dtest.name=-testChildProcLauncher") .setConf(SparkLauncher.DRIVER_EXTRA_CLASSPATH, System.getProperty("java.class.path")) + .addSparkArg(opts.CLASS, "ShouldBeOverriddenBelow") .setMainClass(SparkLauncherTestApp.class.getName()) .addAppArgs("proc"); final Process app = launcher.launch(); From afa757c98c537965007cad4c61c436887f3ac6a6 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 11 Aug 2015 18:08:49 -0700 Subject: [PATCH 0176/1824] [SPARK-9849] [SQL] DirectParquetOutputCommitter qualified name should be backward compatible DirectParquetOutputCommitter was moved in SPARK-9763. However, users can explicitly set the class as a config option, so we must be able to resolve the old committer qualified name. Author: Reynold Xin Closes #8114 from rxin/SPARK-9849. 
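As context for the change below, a sketch of the compatibility path it preserves. It assumes a spark-shell-style SparkContext named sc, the output path is a placeholder, and this is not a test taken from the patch.

import org.apache.spark.sql.SQLContext

val sqlContext = new SQLContext(sc)

// Old, pre-SPARK-9763 qualified name that may still appear in user configs.
sqlContext.setConf("spark.sql.parquet.output.committer.class",
  "org.apache.spark.sql.parquet.DirectParquetOutputCommitter")

// ParquetRelation now rewrites that value to the relocated
// org.apache.spark.sql.execution.datasources.parquet.DirectParquetOutputCommitter
// before the committer class is looked up, so the write still resolves a valid committer.
sqlContext.range(0, 10).write.parquet("/tmp/spark-9849-sketch")   // placeholder path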
--- .../datasources/parquet/ParquetRelation.scala | 7 +++++ .../datasources/parquet/ParquetIOSuite.scala | 27 ++++++++++++++++++- 2 files changed, 33 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala index 4086a139bed7..c71c69b6e80b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala @@ -209,6 +209,13 @@ private[sql] class ParquetRelation( override def prepareJobForWrite(job: Job): OutputWriterFactory = { val conf = ContextUtil.getConfiguration(job) + // SPARK-9849 DirectParquetOutputCommitter qualified name should be backward compatible + val committerClassname = conf.get(SQLConf.PARQUET_OUTPUT_COMMITTER_CLASS.key) + if (committerClassname == "org.apache.spark.sql.parquet.DirectParquetOutputCommitter") { + conf.set(SQLConf.PARQUET_OUTPUT_COMMITTER_CLASS.key, + classOf[DirectParquetOutputCommitter].getCanonicalName) + } + val committerClass = conf.getClass( SQLConf.PARQUET_OUTPUT_COMMITTER_CLASS.key, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala index ee925afe0850..cb166349fdb2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala @@ -390,7 +390,32 @@ class ParquetIOSuite extends QueryTest with ParquetTest { } } - test("SPARK-8121: spark.sql.parquet.output.committer.class shouldn't be overriden") { + test("SPARK-9849 DirectParquetOutputCommitter qualified name should be backward compatible") { + val clonedConf = new Configuration(configuration) + + // Write to a parquet file and let it fail. + // _temporary should be missing if direct output committer works. + try { + configuration.set("spark.sql.parquet.output.committer.class", + "org.apache.spark.sql.parquet.DirectParquetOutputCommitter") + sqlContext.udf.register("div0", (x: Int) => x / 0) + withTempPath { dir => + intercept[org.apache.spark.SparkException] { + sqlContext.sql("select div0(1)").write.parquet(dir.getCanonicalPath) + } + val path = new Path(dir.getCanonicalPath, "_temporary") + val fs = path.getFileSystem(configuration) + assert(!fs.exists(path)) + } + } finally { + // Hadoop 1 doesn't have `Configuration.unset` + configuration.clear() + clonedConf.foreach(entry => configuration.set(entry.getKey, entry.getValue)) + } + } + + + test("SPARK-8121: spark.sql.parquet.output.committer.class shouldn't be overridden") { withTempPath { dir => val clonedConf = new Configuration(configuration) From ca8f70e9d473d2c81866f3c330cc6545c33bdac7 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Tue, 11 Aug 2015 20:46:58 -0700 Subject: [PATCH 0177/1824] [SPARK-9649] Fix flaky test MasterSuite again - disable REST The REST server is not actually used in most tests and so we can disable it. It is a source of flakiness because it tries to bind to a specific port in vain. There was also some code that avoided the shuffle service in tests. This is actually not necessary because the shuffle service is already off by default. Author: Andrew Or Closes #8084 from andrewor14/fix-master-suite-again. 
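The flag added to the builds below is an ordinary Spark configuration, so an individual suite could achieve the same effect programmatically. This is a sketch, not code from the patch, and the app name is made up.

import org.apache.spark.SparkConf

// The standalone Master checks spark.master.rest.enabled before starting its REST
// submission server, so setting it to false avoids binding that extra port in tests.
val conf = new SparkConf()
  .setMaster("local")
  .setAppName("rest-disabled-sketch")
  .set("spark.master.rest.enabled", "false")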
--- pom.xml | 1 + project/SparkBuild.scala | 1 + 2 files changed, 2 insertions(+) diff --git a/pom.xml b/pom.xml index 8942836a7da1..cfd7d32563f2 100644 --- a/pom.xml +++ b/pom.xml @@ -1895,6 +1895,7 @@ ${project.build.directory}/tmp ${spark.test.home} 1 + false false false true diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index cad7067ade8c..74f815f941d5 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -546,6 +546,7 @@ object TestSettings { javaOptions in Test += "-Dspark.test.home=" + sparkHome, javaOptions in Test += "-Dspark.testing=1", javaOptions in Test += "-Dspark.port.maxRetries=100", + javaOptions in Test += "-Dspark.master.rest.enabled=false", javaOptions in Test += "-Dspark.ui.enabled=false", javaOptions in Test += "-Dspark.ui.showConsoleProgress=false", javaOptions in Test += "-Dspark.driver.allowMultipleContexts=true", From 3ef0f32928fc383ad3edd5ad167212aeb9eba6e1 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Tue, 11 Aug 2015 21:16:48 -0700 Subject: [PATCH 0178/1824] [SPARK-1517] Refactor release scripts to facilitate nightly publishing This update contains some code changes to the release scripts that allow easier nightly publishing. I've been using these new scripts on Jenkins for cutting and publishing nightly snapshots for the last month or so, and it has been going well. I'd like to get them merged back upstream so this can be maintained by the community. The main changes are: 1. Separates the release tagging from various build possibilities for an already tagged release (`release-tag.sh` and `release-build.sh`). 2. Allow for injecting credentials through the environment, including GPG keys. This is then paired with secure key injection in Jenkins. 3. Support for copying build results to a remote directory, and also "rotating" results, e.g. the ability to keep the last N copies of binary or doc builds. I'm happy if anyone wants to take a look at this - it's not user facing but an internal utility used for generating releases. Author: Patrick Wendell Closes #7411 from pwendell/release-script-updates and squashes the following commits: 74f9beb [Patrick Wendell] Moving maven build command to a variable 233ce85 [Patrick Wendell] [SPARK-1517] Refactor release scripts to facilitate nightly publishing --- dev/create-release/create-release.sh | 267 ---------------------- dev/create-release/release-build.sh | 321 +++++++++++++++++++++++++++ dev/create-release/release-tag.sh | 79 +++++++ 3 files changed, 400 insertions(+), 267 deletions(-) delete mode 100755 dev/create-release/create-release.sh create mode 100755 dev/create-release/release-build.sh create mode 100755 dev/create-release/release-tag.sh diff --git a/dev/create-release/create-release.sh b/dev/create-release/create-release.sh deleted file mode 100755 index 4311c8c9e4ca..000000000000 --- a/dev/create-release/create-release.sh +++ /dev/null @@ -1,267 +0,0 @@ -#!/usr/bin/env bash - -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. 
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -# Quick-and-dirty automation of making maven and binary releases. Not robust at all. -# Publishes releases to Maven and packages/copies binary release artifacts. -# Expects to be run in a totally empty directory. -# -# Options: -# --skip-create-release Assume the desired release tag already exists -# --skip-publish Do not publish to Maven central -# --skip-package Do not package and upload binary artifacts -# Would be nice to add: -# - Send output to stderr and have useful logging in stdout - -# Note: The following variables must be set before use! -ASF_USERNAME=${ASF_USERNAME:-pwendell} -ASF_PASSWORD=${ASF_PASSWORD:-XXX} -GPG_PASSPHRASE=${GPG_PASSPHRASE:-XXX} -GIT_BRANCH=${GIT_BRANCH:-branch-1.0} -RELEASE_VERSION=${RELEASE_VERSION:-1.2.0} -# Allows publishing under a different version identifier than -# was present in the actual release sources (e.g. rc-X) -PUBLISH_VERSION=${PUBLISH_VERSION:-$RELEASE_VERSION} -NEXT_VERSION=${NEXT_VERSION:-1.2.1} -RC_NAME=${RC_NAME:-rc2} - -M2_REPO=~/.m2/repository -SPARK_REPO=$M2_REPO/org/apache/spark -NEXUS_ROOT=https://repository.apache.org/service/local/staging -NEXUS_PROFILE=d63f592e7eac0 # Profile for Spark staging uploads - -if [ -z "$JAVA_HOME" ]; then - echo "Error: JAVA_HOME is not set, cannot proceed." - exit -1 -fi -JAVA_7_HOME=${JAVA_7_HOME:-$JAVA_HOME} - -set -e - -GIT_TAG=v$RELEASE_VERSION-$RC_NAME - -if [[ ! "$@" =~ --skip-create-release ]]; then - echo "Creating release commit and publishing to Apache repository" - # Artifact publishing - git clone https://$ASF_USERNAME:$ASF_PASSWORD@git-wip-us.apache.org/repos/asf/spark.git \ - -b $GIT_BRANCH - pushd spark - export MAVEN_OPTS="-Xmx3g -XX:MaxPermSize=1g -XX:ReservedCodeCacheSize=1g" - - # Create release commits and push them to github - # NOTE: This is done "eagerly" i.e. we don't check if we can succesfully build - # or before we coin the release commit. This helps avoid races where - # other people add commits to this branch while we are in the middle of building. - cur_ver="${RELEASE_VERSION}-SNAPSHOT" - rel_ver="${RELEASE_VERSION}" - next_ver="${NEXT_VERSION}-SNAPSHOT" - - old="^\( \{2,4\}\)${cur_ver}<\/version>$" - new="\1${rel_ver}<\/version>" - find . -name pom.xml | grep -v dev | xargs -I {} sed -i \ - -e "s/${old}/${new}/" {} - find . -name package.scala | grep -v dev | xargs -I {} sed -i \ - -e "s/${old}/${new}/" {} - - git commit -a -m "Preparing Spark release $GIT_TAG" - echo "Creating tag $GIT_TAG at the head of $GIT_BRANCH" - git tag $GIT_TAG - - old="^\( \{2,4\}\)${rel_ver}<\/version>$" - new="\1${next_ver}<\/version>" - find . -name pom.xml | grep -v dev | xargs -I {} sed -i \ - -e "s/$old/$new/" {} - find . -name package.scala | grep -v dev | xargs -I {} sed -i \ - -e "s/${old}/${new}/" {} - git commit -a -m "Preparing development version $next_ver" - git push origin $GIT_TAG - git push origin HEAD:$GIT_BRANCH - popd - rm -rf spark -fi - -if [[ ! 
"$@" =~ --skip-publish ]]; then - git clone https://$ASF_USERNAME:$ASF_PASSWORD@git-wip-us.apache.org/repos/asf/spark.git - pushd spark - git checkout --force $GIT_TAG - - # Substitute in case published version is different than released - old="^\( \{2,4\}\)${RELEASE_VERSION}<\/version>$" - new="\1${PUBLISH_VERSION}<\/version>" - find . -name pom.xml | grep -v dev | xargs -I {} sed -i \ - -e "s/${old}/${new}/" {} - - # Using Nexus API documented here: - # https://support.sonatype.com/entries/39720203-Uploading-to-a-Staging-Repository-via-REST-API - echo "Creating Nexus staging repository" - repo_request="Apache Spark $GIT_TAG (published as $PUBLISH_VERSION)" - out=$(curl -X POST -d "$repo_request" -u $ASF_USERNAME:$ASF_PASSWORD \ - -H "Content-Type:application/xml" -v \ - $NEXUS_ROOT/profiles/$NEXUS_PROFILE/start) - staged_repo_id=$(echo $out | sed -e "s/.*\(orgapachespark-[0-9]\{4\}\).*/\1/") - echo "Created Nexus staging repository: $staged_repo_id" - - rm -rf $SPARK_REPO - - build/mvn -DskipTests -Pyarn -Phive \ - -Phive-thriftserver -Phadoop-2.2 -Pspark-ganglia-lgpl -Pkinesis-asl \ - clean install - - ./dev/change-scala-version.sh 2.11 - - build/mvn -DskipTests -Pyarn -Phive \ - -Dscala-2.11 -Phadoop-2.2 -Pspark-ganglia-lgpl -Pkinesis-asl \ - clean install - - ./dev/change-scala-version.sh 2.10 - - pushd $SPARK_REPO - - # Remove any extra files generated during install - find . -type f |grep -v \.jar |grep -v \.pom | xargs rm - - echo "Creating hash and signature files" - for file in $(find . -type f) - do - echo $GPG_PASSPHRASE | gpg --passphrase-fd 0 --output $file.asc --detach-sig --armour $file; - if [ $(command -v md5) ]; then - # Available on OS X; -q to keep only hash - md5 -q $file > $file.md5 - else - # Available on Linux; cut to keep only hash - md5sum $file | cut -f1 -d' ' > $file.md5 - fi - shasum -a 1 $file | cut -f1 -d' ' > $file.sha1 - done - - nexus_upload=$NEXUS_ROOT/deployByRepositoryId/$staged_repo_id - echo "Uplading files to $nexus_upload" - for file in $(find . -type f) - do - # strip leading ./ - file_short=$(echo $file | sed -e "s/\.\///") - dest_url="$nexus_upload/org/apache/spark/$file_short" - echo " Uploading $file_short" - curl -u $ASF_USERNAME:$ASF_PASSWORD --upload-file $file_short $dest_url - done - - echo "Closing nexus staging repository" - repo_request="$staged_repo_idApache Spark $GIT_TAG (published as $PUBLISH_VERSION)" - out=$(curl -X POST -d "$repo_request" -u $ASF_USERNAME:$ASF_PASSWORD \ - -H "Content-Type:application/xml" -v \ - $NEXUS_ROOT/profiles/$NEXUS_PROFILE/finish) - echo "Closed Nexus staging repository: $staged_repo_id" - - popd - popd - rm -rf spark -fi - -if [[ ! "$@" =~ --skip-package ]]; then - # Source and binary tarballs - echo "Packaging release tarballs" - git clone https://git-wip-us.apache.org/repos/asf/spark.git - cd spark - git checkout --force $GIT_TAG - release_hash=`git rev-parse HEAD` - - rm .gitignore - rm -rf .git - cd .. 
- - cp -r spark spark-$RELEASE_VERSION - tar cvzf spark-$RELEASE_VERSION.tgz spark-$RELEASE_VERSION - echo $GPG_PASSPHRASE | gpg --passphrase-fd 0 --armour --output spark-$RELEASE_VERSION.tgz.asc \ - --detach-sig spark-$RELEASE_VERSION.tgz - echo $GPG_PASSPHRASE | gpg --passphrase-fd 0 --print-md MD5 spark-$RELEASE_VERSION.tgz > \ - spark-$RELEASE_VERSION.tgz.md5 - echo $GPG_PASSPHRASE | gpg --passphrase-fd 0 --print-md SHA512 spark-$RELEASE_VERSION.tgz > \ - spark-$RELEASE_VERSION.tgz.sha - rm -rf spark-$RELEASE_VERSION - - # Updated for each binary build - make_binary_release() { - NAME=$1 - FLAGS=$2 - ZINC_PORT=$3 - cp -r spark spark-$RELEASE_VERSION-bin-$NAME - - cd spark-$RELEASE_VERSION-bin-$NAME - - # TODO There should probably be a flag to make-distribution to allow 2.11 support - if [[ $FLAGS == *scala-2.11* ]]; then - ./dev/change-scala-version.sh 2.11 - fi - - export ZINC_PORT=$ZINC_PORT - echo "Creating distribution: $NAME ($FLAGS)" - ./make-distribution.sh --name $NAME --tgz $FLAGS -DzincPort=$ZINC_PORT 2>&1 > \ - ../binary-release-$NAME.log - cd .. - cp spark-$RELEASE_VERSION-bin-$NAME/spark-$RELEASE_VERSION-bin-$NAME.tgz . - - echo $GPG_PASSPHRASE | gpg --passphrase-fd 0 --armour \ - --output spark-$RELEASE_VERSION-bin-$NAME.tgz.asc \ - --detach-sig spark-$RELEASE_VERSION-bin-$NAME.tgz - echo $GPG_PASSPHRASE | gpg --passphrase-fd 0 --print-md \ - MD5 spark-$RELEASE_VERSION-bin-$NAME.tgz > \ - spark-$RELEASE_VERSION-bin-$NAME.tgz.md5 - echo $GPG_PASSPHRASE | gpg --passphrase-fd 0 --print-md \ - SHA512 spark-$RELEASE_VERSION-bin-$NAME.tgz > \ - spark-$RELEASE_VERSION-bin-$NAME.tgz.sha - } - - # We increment the Zinc port each time to avoid OOM's and other craziness if multiple builds - # share the same Zinc server. - make_binary_release "hadoop1" "-Psparkr -Phadoop-1 -Phive -Phive-thriftserver" "3030" & - make_binary_release "hadoop1-scala2.11" "-Psparkr -Phadoop-1 -Phive -Dscala-2.11" "3031" & - make_binary_release "cdh4" "-Psparkr -Phadoop-1 -Phive -Phive-thriftserver -Dhadoop.version=2.0.0-mr1-cdh4.2.0" "3032" & - make_binary_release "hadoop2.3" "-Psparkr -Phadoop-2.3 -Phive -Phive-thriftserver -Pyarn" "3033" & - make_binary_release "hadoop2.4" "-Psparkr -Phadoop-2.4 -Phive -Phive-thriftserver -Pyarn" "3034" & - make_binary_release "mapr3" "-Pmapr3 -Psparkr -Phive -Phive-thriftserver" "3035" & - make_binary_release "mapr4" "-Pmapr4 -Psparkr -Pyarn -Phive -Phive-thriftserver" "3036" & - make_binary_release "hadoop2.4-without-hive" "-Psparkr -Phadoop-2.4 -Pyarn" "3037" & - wait - rm -rf spark-$RELEASE_VERSION-bin-*/ - - # Copy data - echo "Copying release tarballs" - rc_folder=spark-$RELEASE_VERSION-$RC_NAME - ssh $ASF_USERNAME@people.apache.org \ - mkdir /home/$ASF_USERNAME/public_html/$rc_folder - scp spark-* \ - $ASF_USERNAME@people.apache.org:/home/$ASF_USERNAME/public_html/$rc_folder/ - - # Docs - cd spark - sbt/sbt clean - cd docs - # Compile docs with Java 7 to use nicer format - JAVA_HOME="$JAVA_7_HOME" PRODUCTION=1 RELEASE_VERSION="$RELEASE_VERSION" jekyll build - echo "Copying release documentation" - rc_docs_folder=${rc_folder}-docs - ssh $ASF_USERNAME@people.apache.org \ - mkdir /home/$ASF_USERNAME/public_html/$rc_docs_folder - rsync -r _site/* $ASF_USERNAME@people.apache.org:/home/$ASF_USERNAME/public_html/$rc_docs_folder - - echo "Release $RELEASE_VERSION completed:" - echo "Git tag:\t $GIT_TAG" - echo "Release commit:\t $release_hash" - echo "Binary location:\t http://people.apache.org/~$ASF_USERNAME/$rc_folder" - echo "Doc location:\t 
http://people.apache.org/~$ASF_USERNAME/$rc_docs_folder" -fi diff --git a/dev/create-release/release-build.sh b/dev/create-release/release-build.sh new file mode 100755 index 000000000000..399c73e7bf6b --- /dev/null +++ b/dev/create-release/release-build.sh @@ -0,0 +1,321 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +function exit_with_usage { + cat << EOF +usage: release-build.sh +Creates build deliverables from a Spark commit. + +Top level targets are + package: Create binary packages and copy them to people.apache + docs: Build docs and copy them to people.apache + publish-snapshot: Publish snapshot release to Apache snapshots + publish-release: Publish a release to Apache release repo + +All other inputs are environment variables + +GIT_REF - Release tag or commit to build from +SPARK_VERSION - Release identifier used when publishing +SPARK_PACKAGE_VERSION - Release identifier in top level package directory +REMOTE_PARENT_DIR - Parent in which to create doc or release builds. +REMOTE_PARENT_MAX_LENGTH - If set, parent directory will be cleaned to only + have this number of subdirectories (by deleting old ones). WARNING: This deletes data. 
+ +ASF_USERNAME - Username of ASF committer account +ASF_PASSWORD - Password of ASF committer account +ASF_RSA_KEY - RSA private key file for ASF committer account + +GPG_KEY - GPG key used to sign release artifacts +GPG_PASSPHRASE - Passphrase for GPG key +EOF + exit 1 +} + +set -e + +if [ $# -eq 0 ]; then + exit_with_usage +fi + +if [[ $@ == *"help"* ]]; then + exit_with_usage +fi + +for env in ASF_USERNAME ASF_RSA_KEY GPG_PASSPHRASE GPG_KEY; do + if [ -z "${!env}" ]; then + echo "ERROR: $env must be set to run this script" + exit_with_usage + fi +done + +# Commit ref to checkout when building +GIT_REF=${GIT_REF:-master} + +# Destination directory parent on remote server +REMOTE_PARENT_DIR=${REMOTE_PARENT_DIR:-/home/$ASF_USERNAME/public_html} + +SSH="ssh -o StrictHostKeyChecking=no -i $ASF_RSA_KEY" +GPG="gpg --no-tty --batch" +NEXUS_ROOT=https://repository.apache.org/service/local/staging +NEXUS_PROFILE=d63f592e7eac0 # Profile for Spark staging uploads +BASE_DIR=$(pwd) + +MVN="build/mvn --force" +PUBLISH_PROFILES="-Pyarn -Phive -Phadoop-2.2" +PUBLISH_PROFILES="$PUBLISH_PROFILES -Pspark-ganglia-lgpl -Pkinesis-asl" + +rm -rf spark +git clone https://git-wip-us.apache.org/repos/asf/spark.git +cd spark +git checkout $GIT_REF +git_hash=`git rev-parse --short HEAD` +echo "Checked out Spark git hash $git_hash" + +if [ -z "$SPARK_VERSION" ]; then + SPARK_VERSION=$($MVN help:evaluate -Dexpression=project.version \ + | grep -v INFO | grep -v WARNING | grep -v Download) +fi + +if [ -z "$SPARK_PACKAGE_VERSION" ]; then + SPARK_PACKAGE_VERSION="${SPARK_VERSION}-$(date +%Y_%m_%d_%H_%M)-${git_hash}" +fi + +DEST_DIR_NAME="spark-$SPARK_PACKAGE_VERSION" +USER_HOST="$ASF_USERNAME@people.apache.org" + +rm .gitignore +rm -rf .git +cd .. + +if [ -n "$REMOTE_PARENT_MAX_LENGTH" ]; then + old_dirs=$($SSH $USER_HOST ls -t $REMOTE_PARENT_DIR | tail -n +$REMOTE_PARENT_MAX_LENGTH) + for old_dir in $old_dirs; do + echo "Removing directory: $old_dir" + $SSH $USER_HOST rm -r $REMOTE_PARENT_DIR/$old_dir + done +fi + +if [[ "$1" == "package" ]]; then + # Source and binary tarballs + echo "Packaging release tarballs" + cp -r spark spark-$SPARK_VERSION + tar cvzf spark-$SPARK_VERSION.tgz spark-$SPARK_VERSION + echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --armour --output spark-$SPARK_VERSION.tgz.asc \ + --detach-sig spark-$SPARK_VERSION.tgz + echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --print-md MD5 spark-$SPARK_VERSION.tgz > \ + spark-$SPARK_VERSION.tgz.md5 + echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --print-md \ + SHA512 spark-$SPARK_VERSION.tgz > spark-$SPARK_VERSION.tgz.sha + rm -rf spark-$SPARK_VERSION + + # Updated for each binary build + make_binary_release() { + NAME=$1 + FLAGS=$2 + ZINC_PORT=$3 + cp -r spark spark-$SPARK_VERSION-bin-$NAME + + cd spark-$SPARK_VERSION-bin-$NAME + + # TODO There should probably be a flag to make-distribution to allow 2.11 support + if [[ $FLAGS == *scala-2.11* ]]; then + ./dev/change-scala-version.sh 2.11 + fi + + export ZINC_PORT=$ZINC_PORT + echo "Creating distribution: $NAME ($FLAGS)" + ./make-distribution.sh --name $NAME --tgz $FLAGS -DzincPort=$ZINC_PORT 2>&1 > \ + ../binary-release-$NAME.log + cd .. + cp spark-$SPARK_VERSION-bin-$NAME/spark-$SPARK_VERSION-bin-$NAME.tgz . 
+ + echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --armour \ + --output spark-$SPARK_VERSION-bin-$NAME.tgz.asc \ + --detach-sig spark-$SPARK_VERSION-bin-$NAME.tgz + echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --print-md \ + MD5 spark-$SPARK_VERSION-bin-$NAME.tgz > \ + spark-$SPARK_VERSION-bin-$NAME.tgz.md5 + echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --print-md \ + SHA512 spark-$SPARK_VERSION-bin-$NAME.tgz > \ + spark-$SPARK_VERSION-bin-$NAME.tgz.sha + } + + # TODO: Check exit codes of children here: + # http://stackoverflow.com/questions/1570262/shell-get-exit-code-of-background-process + + # We increment the Zinc port each time to avoid OOM's and other craziness if multiple builds + # share the same Zinc server. + make_binary_release "hadoop1" "-Psparkr -Phadoop-1 -Phive -Phive-thriftserver" "3030" & + make_binary_release "hadoop1-scala2.11" "-Psparkr -Phadoop-1 -Phive -Dscala-2.11" "3031" & + make_binary_release "cdh4" "-Psparkr -Phadoop-1 -Phive -Phive-thriftserver -Dhadoop.version=2.0.0-mr1-cdh4.2.0" "3032" & + make_binary_release "hadoop2.3" "-Psparkr -Phadoop-2.3 -Phive -Phive-thriftserver -Pyarn" "3033" & + make_binary_release "hadoop2.4" "-Psparkr -Phadoop-2.4 -Phive -Phive-thriftserver -Pyarn" "3034" & + make_binary_release "hadoop2.6" "-Psparkr -Phadoop-2.6 -Phive -Phive-thriftserver -Pyarn" "3034" & + make_binary_release "hadoop2.4-without-hive" "-Psparkr -Phadoop-2.4 -Pyarn" "3037" & + make_binary_release "without-hadoop" "-Psparkr -Phadoop-provided -Pyarn" "3038" & + wait + rm -rf spark-$SPARK_VERSION-bin-*/ + + # Copy data + dest_dir="$REMOTE_PARENT_DIR/${DEST_DIR_NAME}-bin" + echo "Copying release tarballs to $dest_dir" + $SSH $USER_HOST mkdir $dest_dir + rsync -e "$SSH" spark-* $USER_HOST:$dest_dir + echo "Linking /latest to $dest_dir" + $SSH $USER_HOST rm -f "$REMOTE_PARENT_DIR/latest" + $SSH $USER_HOST ln -s $dest_dir "$REMOTE_PARENT_DIR/latest" + exit 0 +fi + +if [[ "$1" == "docs" ]]; then + # Documentation + cd spark + echo "Building Spark docs" + dest_dir="$REMOTE_PARENT_DIR/${DEST_DIR_NAME}-docs" + cd docs + # Compile docs with Java 7 to use nicer format + # TODO: Make configurable to add this: PRODUCTION=1 + PRODUCTION=1 RELEASE_VERSION="$SPARK_VERSION" jekyll build + echo "Copying release documentation to $dest_dir" + $SSH $USER_HOST mkdir $dest_dir + echo "Linking /latest to $dest_dir" + $SSH $USER_HOST rm -f "$REMOTE_PARENT_DIR/latest" + $SSH $USER_HOST ln -s $dest_dir "$REMOTE_PARENT_DIR/latest" + rsync -e "$SSH" -r _site/* $USER_HOST:$dest_dir + cd .. + exit 0 +fi + +if [[ "$1" == "publish-snapshot" ]]; then + cd spark + # Publish Spark to Maven release repo + echo "Deploying Spark SNAPSHOT at '$GIT_REF' ($git_hash)" + echo "Publish version is $SPARK_VERSION" + if [[ ! 
$SPARK_VERSION == *"SNAPSHOT"* ]]; then + echo "ERROR: Snapshots must have a version containing SNAPSHOT" + echo "ERROR: You gave version '$SPARK_VERSION'" + exit 1 + fi + # Coerce the requested version + $MVN versions:set -DnewVersion=$SPARK_VERSION + tmp_settings="tmp-settings.xml" + echo "" > $tmp_settings + echo "apache.snapshots.https$ASF_USERNAME" >> $tmp_settings + echo "$ASF_PASSWORD" >> $tmp_settings + echo "" >> $tmp_settings + + # Generate random point for Zinc + export ZINC_PORT=$(python -S -c "import random; print random.randrange(3030,4030)") + + $MVN -DzincPort=$ZINC_PORT --settings $tmp_settings -DskipTests $PUBLISH_PROFILES \ + -Phive-thriftserver deploy + ./dev/change-scala-version.sh 2.10 + $MVN -DzincPort=$ZINC_PORT -Dscala-2.11 --settings $tmp_settings \ + -DskipTests $PUBLISH_PROFILES deploy + + # Clean-up Zinc nailgun process + /usr/sbin/lsof -P |grep $ZINC_PORT | grep LISTEN | awk '{ print $2; }' | xargs kill + + rm $tmp_settings + cd .. + exit 0 +fi + +if [[ "$1" == "publish-release" ]]; then + cd spark + # Publish Spark to Maven release repo + echo "Publishing Spark checkout at '$GIT_REF' ($git_hash)" + echo "Publish version is $SPARK_VERSION" + # Coerce the requested version + $MVN versions:set -DnewVersion=$SPARK_VERSION + + # Using Nexus API documented here: + # https://support.sonatype.com/entries/39720203-Uploading-to-a-Staging-Repository-via-REST-API + echo "Creating Nexus staging repository" + repo_request="Apache Spark $SPARK_VERSION (commit $git_hash)" + out=$(curl -X POST -d "$repo_request" -u $ASF_USERNAME:$ASF_PASSWORD \ + -H "Content-Type:application/xml" -v \ + $NEXUS_ROOT/profiles/$NEXUS_PROFILE/start) + staged_repo_id=$(echo $out | sed -e "s/.*\(orgapachespark-[0-9]\{4\}\).*/\1/") + echo "Created Nexus staging repository: $staged_repo_id" + + tmp_repo=$(mktemp -d spark-repo-XXXXX) + + # Generate random point for Zinc + export ZINC_PORT=$(python -S -c "import random; print random.randrange(3030,4030)") + + $MVN -DzincPort=$ZINC_PORT -Dmaven.repo.local=$tmp_repo -DskipTests $PUBLISH_PROFILES \ + -Phive-thriftserver clean install + + ./dev/change-scala-version.sh 2.11 + + $MVN -DzincPort=$ZINC_PORT -Dmaven.repo.local=$tmp_repo -Dscala-2.11 \ + -DskipTests $PUBLISH_PROFILES clean install + + # Clean-up Zinc nailgun process + /usr/sbin/lsof -P |grep $ZINC_PORT | grep LISTEN | awk '{ print $2; }' | xargs kill + + ./dev/change-version-to-2.10.sh + + pushd $tmp_repo/org/apache/spark + + # Remove any extra files generated during install + find . -type f |grep -v \.jar |grep -v \.pom | xargs rm + + echo "Creating hash and signature files" + for file in $(find . -type f) + do + echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --output $file.asc \ + --detach-sig --armour $file; + if [ $(command -v md5) ]; then + # Available on OS X; -q to keep only hash + md5 -q $file > $file.md5 + else + # Available on Linux; cut to keep only hash + md5sum $file | cut -f1 -d' ' > $file.md5 + fi + sha1sum $file | cut -f1 -d' ' > $file.sha1 + done + + nexus_upload=$NEXUS_ROOT/deployByRepositoryId/$staged_repo_id + echo "Uplading files to $nexus_upload" + for file in $(find . 
-type f) + do + # strip leading ./ + file_short=$(echo $file | sed -e "s/\.\///") + dest_url="$nexus_upload/org/apache/spark/$file_short" + echo " Uploading $file_short" + curl -u $ASF_USERNAME:$ASF_PASSWORD --upload-file $file_short $dest_url + done + + echo "Closing nexus staging repository" + repo_request="$staged_repo_idApache Spark $SPARK_VERSION (commit $git_hash)" + out=$(curl -X POST -d "$repo_request" -u $ASF_USERNAME:$ASF_PASSWORD \ + -H "Content-Type:application/xml" -v \ + $NEXUS_ROOT/profiles/$NEXUS_PROFILE/finish) + echo "Closed Nexus staging repository: $staged_repo_id" + popd + rm -rf $tmp_repo + cd .. + exit 0 +fi + +cd .. +rm -rf spark +echo "ERROR: expects to be called with 'package', 'docs', 'publish-release' or 'publish-snapshot'" diff --git a/dev/create-release/release-tag.sh b/dev/create-release/release-tag.sh new file mode 100755 index 000000000000..b0a3374becc6 --- /dev/null +++ b/dev/create-release/release-tag.sh @@ -0,0 +1,79 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +function exit_with_usage { + cat << EOF +usage: tag-release.sh +Tags a Spark release on a particular branch. + +Inputs are specified with the following environment variables: +ASF_USERNAME - Apache Username +ASF_PASSWORD - Apache Password +GIT_NAME - Name to use with git +GIT_EMAIL - E-mail address to use with git +GIT_BRANCH - Git branch on which to make release +RELEASE_VERSION - Version used in pom files for release +RELEASE_TAG - Name of release tag +NEXT_VERSION - Development version after release +EOF + exit 1 +} + +set -e + +if [[ $@ == *"help"* ]]; then + exit_with_usage +fi + +for env in ASF_USERNAME ASF_PASSWORD RELEASE_VERSION RELEASE_TAG NEXT_VERSION GIT_EMAIL GIT_NAME GIT_BRANCH; do + if [ -z "${!env}" ]; then + echo "$env must be set to run this script" + exit 1 + fi +done + +ASF_SPARK_REPO="git-wip-us.apache.org/repos/asf/spark.git" +MVN="build/mvn --force" + +rm -rf spark +git clone https://$ASF_USERNAME:$ASF_PASSWORD@$ASF_SPARK_REPO -b $GIT_BRANCH +cd spark + +git config user.name "$GIT_NAME" +git config user.email $GIT_EMAIL + +# Create release version +$MVN versions:set -DnewVersion=$RELEASE_VERSION | grep -v "no value" # silence logs +git commit -a -m "Preparing Spark release $RELEASE_TAG" +echo "Creating tag $RELEASE_TAG at the head of $GIT_BRANCH" +git tag $RELEASE_TAG + +# TODO: It would be nice to do some verifications here +# i.e. check whether ec2 scripts have the new version + +# Create next version +$MVN versions:set -DnewVersion=$NEXT_VERSION | grep -v "no value" # silence logs +git commit -a -m "Preparing development version $NEXT_VERSION" + +# Push changes +git push origin $RELEASE_TAG +git push origin HEAD:$GIT_BRANCH + +cd .. 
+rm -rf spark From 74a293f4537c6982345166f8883538f81d850872 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Tue, 11 Aug 2015 21:26:03 -0700 Subject: [PATCH 0179/1824] [SPARK-9713] [ML] Document SparkR MLlib glm() integration in Spark 1.5 This documents the use of R model formulae in the SparkR guide. Also fixes some bugs in the R api doc. mengxr Author: Eric Liang Closes #8085 from ericl/docs. --- R/pkg/R/generics.R | 4 ++-- R/pkg/R/mllib.R | 8 ++++---- docs/sparkr.md | 37 ++++++++++++++++++++++++++++++++++++- 3 files changed, 42 insertions(+), 7 deletions(-) diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index c43b947129e8..379a78b1d833 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -535,8 +535,8 @@ setGeneric("showDF", function(x,...) { standardGeneric("showDF") }) #' @export setGeneric("summarize", function(x,...) { standardGeneric("summarize") }) -##' rdname summary -##' @export +#' @rdname summary +#' @export setGeneric("summary", function(x, ...) { standardGeneric("summary") }) # @rdname tojson diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R index b524d1fd8749..cea3d760d05f 100644 --- a/R/pkg/R/mllib.R +++ b/R/pkg/R/mllib.R @@ -56,10 +56,10 @@ setMethod("glm", signature(formula = "formula", family = "ANY", data = "DataFram #' #' Makes predictions from a model produced by glm(), similarly to R's predict(). #' -#' @param model A fitted MLlib model +#' @param object A fitted MLlib model #' @param newData DataFrame for testing #' @return DataFrame containing predicted values -#' @rdname glm +#' @rdname predict #' @export #' @examples #'\dontrun{ @@ -76,10 +76,10 @@ setMethod("predict", signature(object = "PipelineModel"), #' #' Returns the summary of a model produced by glm(), similarly to R's summary(). #' -#' @param model A fitted MLlib model +#' @param x A fitted MLlib model #' @return a list with a 'coefficient' component, which is the matrix of coefficients. See #' summary.glm for more information. -#' @rdname glm +#' @rdname summary #' @export #' @examples #'\dontrun{ diff --git a/docs/sparkr.md b/docs/sparkr.md index 4385a4eeacd5..7139d16b4a06 100644 --- a/docs/sparkr.md +++ b/docs/sparkr.md @@ -11,7 +11,8 @@ title: SparkR (R on Spark) SparkR is an R package that provides a light-weight frontend to use Apache Spark from R. In Spark {{site.SPARK_VERSION}}, SparkR provides a distributed data frame implementation that supports operations like selection, filtering, aggregation etc. (similar to R data frames, -[dplyr](https://github.com/hadley/dplyr)) but on large datasets. +[dplyr](https://github.com/hadley/dplyr)) but on large datasets. SparkR also supports distributed +machine learning using MLlib. # SparkR DataFrames @@ -230,3 +231,37 @@ head(teenagers) {% endhighlight %}

    + +# Machine Learning + +SparkR allows the fitting of generalized linear models over DataFrames using the [glm()](api/R/glm.html) function. Under the hood, SparkR uses MLlib to train a model of the specified family. Currently the gaussian and binomial families are supported. We support a subset of the available R formula operators for model fitting, including '~', '.', '+', and '-'. The example below shows the use of building a gaussian GLM model using SparkR. + +
    +{% highlight r %} +# Create the DataFrame +df <- createDataFrame(sqlContext, iris) + +# Fit a linear model over the dataset. +model <- glm(Sepal_Length ~ Sepal_Width + Species, data = df, family = "gaussian") + +# Model coefficients are returned in a similar format to R's native glm(). +summary(model) +##$coefficients +## Estimate +##(Intercept) 2.2513930 +##Sepal_Width 0.8035609 +##Species_versicolor 1.4587432 +##Species_virginica 1.9468169 + +# Make predictions based on the model. +predictions <- predict(model, newData = df) +head(select(predictions, "Sepal_Length", "prediction")) +## Sepal_Length prediction +##1 5.1 5.063856 +##2 4.9 4.662076 +##3 4.7 4.822788 +##4 4.6 4.742432 +##5 5.0 5.144212 +##6 5.4 5.385281 +{% endhighlight %} +
    From c3e9a120e33159fb45cd99f3a55fc5cf16cd7c6c Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 11 Aug 2015 22:45:18 -0700 Subject: [PATCH 0180/1824] [SPARK-9831] [SQL] fix serialization with empty broadcast Author: Davies Liu Closes #8117 from davies/fix_serialization and squashes the following commits: d21ac71 [Davies Liu] fix serialization with empty broadcast --- .../sql/execution/joins/HashedRelation.scala | 2 +- .../execution/joins/HashedRelationSuite.scala | 17 +++++++++++++++++ 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index c1bc7947aa39..076afe6e4e96 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -299,7 +299,7 @@ private[joins] final class UnsafeHashedRelation( binaryMap = new BytesToBytesMap( taskMemoryManager, shuffleMemoryManager, - nKeys * 2, // reduce hash collision + (nKeys * 1.5 + 1).toInt, // reduce hash collision pageSizeBytes) var i = 0 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala index a1fa2c3864bd..c635b2d51f46 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala @@ -103,4 +103,21 @@ class HashedRelationSuite extends SparkFunSuite { assert(hashed2.get(unsafeData(2)) === data2) assert(numDataRows.value.value === data.length) } + + test("test serialization empty hash map") { + val os = new ByteArrayOutputStream() + val out = new ObjectOutputStream(os) + val hashed = new UnsafeHashedRelation( + new java.util.HashMap[UnsafeRow, CompactBuffer[UnsafeRow]]) + hashed.writeExternal(out) + out.flush() + val in = new ObjectInputStream(new ByteArrayInputStream(os.toByteArray)) + val hashed2 = new UnsafeHashedRelation() + hashed2.readExternal(in) + + val schema = StructType(StructField("a", IntegerType, true) :: Nil) + val toUnsafe = UnsafeProjection.create(schema) + val row = toUnsafe(InternalRow(0)) + assert(hashed2.get(row) === null) + } } From b1581ac28840a4d2209ef8bb5c9f8700b4c1b286 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 11 Aug 2015 22:46:59 -0700 Subject: [PATCH 0181/1824] [SPARK-9854] [SQL] RuleExecutor.timeMap should be thread-safe `RuleExecutor.timeMap` is currently a non-thread-safe mutable HashMap; this can lead to infinite loops if multiple threads are concurrently modifying the map. I believe that this is responsible for some hangs that I've observed in HiveQuerySuite. This patch addresses this by using a Guava `AtomicLongMap`. Author: Josh Rosen Closes #8120 from JoshRosen/rule-executor-time-map-fix. 
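As background, Guava's `AtomicLongMap` keeps one atomically updated counter per key, which is what makes the rewrite below safe under concurrent rule execution. A minimal, self-contained sketch (not part of the patch; it assumes only Guava on the classpath, and the rule name is made up):

```scala
import scala.collection.JavaConverters._

import com.google.common.util.concurrent.AtomicLongMap

object AtomicLongMapSketch {
  def main(args: Array[String]): Unit = {
    // One counter per rule name; addAndGet is an atomic read-modify-write,
    // so concurrent callers never lose updates the way an unsynchronized
    // mutable.HashMap can.
    val timings = AtomicLongMap.create[String]()

    val threads = (1 to 4).map { _ =>
      new Thread(new Runnable {
        override def run(): Unit = {
          var i = 0
          while (i < 100000) {
            timings.addAndGet("SomeRule", 1L)
            i += 1
          }
        }
      })
    }
    threads.foreach(_.start())
    threads.foreach(_.join())

    // asMap() exposes a live java.util.Map view of the counters.
    timings.asMap().asScala.foreach { case (rule, total) =>
      println(s"$rule -> $total") // SomeRule -> 400000, every run
    }
  }
}
```

Running the same loop against an unsynchronized `mutable.HashMap` can drop updates or misbehave in the way described above, whereas the total here is always exactly 400000.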
--- .../spark/sql/catalyst/rules/RuleExecutor.scala | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala index 8b824511a79d..f80d2a93241d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala @@ -17,22 +17,25 @@ package org.apache.spark.sql.catalyst.rules +import scala.collection.JavaConverters._ + +import com.google.common.util.concurrent.AtomicLongMap + import org.apache.spark.Logging import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.catalyst.util.sideBySide -import scala.collection.mutable - object RuleExecutor { - protected val timeMap = new mutable.HashMap[String, Long].withDefault(_ => 0) + protected val timeMap = AtomicLongMap.create[String]() /** Resets statistics about time spent running specific rules */ def resetTime(): Unit = timeMap.clear() /** Dump statistics about time spent running specific rules. */ def dumpTimeSpent(): String = { - val maxSize = timeMap.keys.map(_.toString.length).max - timeMap.toSeq.sortBy(_._2).reverseMap { case (k, v) => + val map = timeMap.asMap().asScala + val maxSize = map.keys.map(_.toString.length).max + map.toSeq.sortBy(_._2).reverseMap { case (k, v) => s"${k.padTo(maxSize, " ").mkString} $v" }.mkString("\n") } @@ -79,7 +82,7 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging { val startTime = System.nanoTime() val result = rule(plan) val runTime = System.nanoTime() - startTime - RuleExecutor.timeMap(rule.ruleName) = RuleExecutor.timeMap(rule.ruleName) + runTime + RuleExecutor.timeMap.addAndGet(rule.ruleName, runTime) if (!result.fastEquals(plan)) { logTrace( From b85f9a242a12e8096e331fa77d5ebd16e93c844d Mon Sep 17 00:00:00 2001 From: xutingjun Date: Tue, 11 Aug 2015 23:19:35 -0700 Subject: [PATCH 0182/1824] [SPARK-8366] maxNumExecutorsNeeded should properly handle failed tasks Author: xutingjun Author: meiyoula <1039320815@qq.com> Closes #6817 from XuTingjun/SPARK-8366. 
--- .../spark/ExecutorAllocationManager.scala | 22 ++++++++++++------- .../ExecutorAllocationManagerSuite.scala | 22 +++++++++++++++++-- 2 files changed, 34 insertions(+), 10 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala index 1877aaf2cac5..b93536e6536e 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala @@ -599,14 +599,8 @@ private[spark] class ExecutorAllocationManager( // If this is the last pending task, mark the scheduler queue as empty stageIdToTaskIndices.getOrElseUpdate(stageId, new mutable.HashSet[Int]) += taskIndex - val numTasksScheduled = stageIdToTaskIndices(stageId).size - val numTasksTotal = stageIdToNumTasks.getOrElse(stageId, -1) - if (numTasksScheduled == numTasksTotal) { - // No more pending tasks for this stage - stageIdToNumTasks -= stageId - if (stageIdToNumTasks.isEmpty) { - allocationManager.onSchedulerQueueEmpty() - } + if (totalPendingTasks() == 0) { + allocationManager.onSchedulerQueueEmpty() } // Mark the executor on which this task is scheduled as busy @@ -618,6 +612,8 @@ private[spark] class ExecutorAllocationManager( override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = { val executorId = taskEnd.taskInfo.executorId val taskId = taskEnd.taskInfo.taskId + val taskIndex = taskEnd.taskInfo.index + val stageId = taskEnd.stageId allocationManager.synchronized { numRunningTasks -= 1 // If the executor is no longer running any scheduled tasks, mark it as idle @@ -628,6 +624,16 @@ private[spark] class ExecutorAllocationManager( allocationManager.onExecutorIdle(executorId) } } + + // If the task failed, we expect it to be resubmitted later. 
To ensure we have + // enough resources to run the resubmitted task, we need to mark the scheduler + // as backlogged again if it's not already marked as such (SPARK-8366) + if (taskEnd.reason != Success) { + if (totalPendingTasks() == 0) { + allocationManager.onSchedulerBacklogged() + } + stageIdToTaskIndices.get(stageId).foreach { _.remove(taskIndex) } + } } } diff --git a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala index 34caca892891..f374f97f8744 100644 --- a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala @@ -206,8 +206,8 @@ class ExecutorAllocationManagerSuite val task2Info = createTaskInfo(1, 0, "executor-1") sc.listenerBus.postToAll(SparkListenerTaskStart(2, 0, task2Info)) - sc.listenerBus.postToAll(SparkListenerTaskEnd(2, 0, null, null, task1Info, null)) - sc.listenerBus.postToAll(SparkListenerTaskEnd(2, 0, null, null, task2Info, null)) + sc.listenerBus.postToAll(SparkListenerTaskEnd(2, 0, null, Success, task1Info, null)) + sc.listenerBus.postToAll(SparkListenerTaskEnd(2, 0, null, Success, task2Info, null)) assert(adjustRequestedExecutors(manager) === -1) } @@ -787,6 +787,24 @@ class ExecutorAllocationManagerSuite Map("host2" -> 1, "host3" -> 2, "host4" -> 1, "host5" -> 2)) } + test("SPARK-8366: maxNumExecutorsNeeded should properly handle failed tasks") { + sc = createSparkContext() + val manager = sc.executorAllocationManager.get + assert(maxNumExecutorsNeeded(manager) === 0) + + sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(0, 1))) + assert(maxNumExecutorsNeeded(manager) === 1) + + val taskInfo = createTaskInfo(1, 1, "executor-1") + sc.listenerBus.postToAll(SparkListenerTaskStart(0, 0, taskInfo)) + assert(maxNumExecutorsNeeded(manager) === 1) + + // If the task is failed, we expect it to be resubmitted later. + val taskEndReason = ExceptionFailure(null, null, null, null, null) + sc.listenerBus.postToAll(SparkListenerTaskEnd(0, 0, null, taskEndReason, taskInfo, null)) + assert(maxNumExecutorsNeeded(manager) === 1) + } + private def createSparkContext( minExecutors: Int = 1, maxExecutors: Int = 5, From a807fcbe50b2ce18751d80d39e9d21842f7da32a Mon Sep 17 00:00:00 2001 From: Rohit Agarwal Date: Tue, 11 Aug 2015 23:20:39 -0700 Subject: [PATCH 0183/1824] [SPARK-9806] [WEB UI] Don't share ReplayListenerBus between multiple applications Author: Rohit Agarwal Closes #8088 from mindprince/SPARK-9806. 
--- .../org/apache/spark/deploy/history/FsHistoryProvider.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index e3060ac3fa1a..53c18ca3ff50 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -272,9 +272,9 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) * Replay the log files in the list and merge the list of old applications with new ones */ private def mergeApplicationListing(logs: Seq[FileStatus]): Unit = { - val bus = new ReplayListenerBus() val newAttempts = logs.flatMap { fileStatus => try { + val bus = new ReplayListenerBus() val res = replay(fileStatus, bus) res match { case Some(r) => logDebug(s"Application log ${r.logPath} loaded successfully.") From 4e3f4b934f74e8c7c06f4940d6381343f9fd4918 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Tue, 11 Aug 2015 23:23:17 -0700 Subject: [PATCH 0184/1824] [SPARK-9829] [WEBUI] Display the update value for peak execution memory The peak execution memory is not correct because it shows the sum of finished tasks' values when a task finishes. This PR fixes it by using the update value rather than the accumulator value. Author: zsxwing Closes #8121 from zsxwing/SPARK-9829. --- core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 0c94204df653..fb4556b83685 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -860,7 +860,7 @@ private[ui] class TaskDataSource( } val peakExecutionMemoryUsed = taskInternalAccumulables .find { acc => acc.name == InternalAccumulator.PEAK_EXECUTION_MEMORY } - .map { acc => acc.value.toLong } + .map { acc => acc.update.getOrElse("0").toLong } .getOrElse(0L) val maybeInput = metrics.flatMap(_.inputMetrics) From bab89232854de7554e88f29cab76f1a1c349edc1 Mon Sep 17 00:00:00 2001 From: Carson Wang Date: Tue, 11 Aug 2015 23:25:02 -0700 Subject: [PATCH 0185/1824] [SPARK-9426] [WEBUI] Job page DAG visualization is not shown To reproduce the issue, go to the stage page and click DAG Visualization once, then go to the job page to show the job DAG visualization. You will only see the first stage of the job. Root cause: the java script use local storage to remember your selection. Once you click the stage DAG visualization, the local storage set `expand-dag-viz-arrow-stage` to true. When you go to the job page, the js checks `expand-dag-viz-arrow-stage` in the local storage first and will try to show stage DAG visualization on the job page. To fix this, I set an id to the DAG span to differ job page and stage page. In the js code, we check the id and local storage together to make sure we show the correct DAG visualization. Author: Carson Wang Closes #8104 from carsonwang/SPARK-9426. 
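Conceptually, the Scala side of this fix boils down to giving the toggle element a page-specific id. The sketch below is an illustration only (the method name and element text are placeholders, not the actual UIUtils code); the "job-dag-viz"/"stage-dag-viz" ids are the ones the patched JavaScript below queries before consulting the shared localStorage flag:

```scala
object DagVizIdSketch {
  // Sketch only: the real logic lives in UIUtils.showDagViz. The ids match
  // the $("#job-dag-viz") / $("#stage-dag-viz") lookups in spark-dag-viz.js.
  def dagVizToggle(forJob: Boolean): scala.xml.Elem =
    <span id={if (forJob) "job-dag-viz" else "stage-dag-viz"}>DAG Visualization</span>

  def main(args: Array[String]): Unit = {
    println(dagVizToggle(forJob = true))   // <span id="job-dag-viz">...</span>
    println(dagVizToggle(forJob = false))  // <span id="stage-dag-viz">...</span>
  }
}
```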
--- .../resources/org/apache/spark/ui/static/spark-dag-viz.js | 8 ++++---- core/src/main/scala/org/apache/spark/ui/UIUtils.scala | 3 ++- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js b/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js index 4a893bc0189a..83dbea40b63f 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js +++ b/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js @@ -109,13 +109,13 @@ function toggleDagViz(forJob) { } $(function (){ - if (window.localStorage.getItem(expandDagVizArrowKey(false)) == "true") { + if ($("#stage-dag-viz").length && + window.localStorage.getItem(expandDagVizArrowKey(false)) == "true") { // Set it to false so that the click function can revert it window.localStorage.setItem(expandDagVizArrowKey(false), "false"); toggleDagViz(false); - } - - if (window.localStorage.getItem(expandDagVizArrowKey(true)) == "true") { + } else if ($("#job-dag-viz").length && + window.localStorage.getItem(expandDagVizArrowKey(true)) == "true") { // Set it to false so that the click function can revert it window.localStorage.setItem(expandDagVizArrowKey(true), "false"); toggleDagViz(true); diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala index 718aea7e1dc2..f2da41772410 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala @@ -352,7 +352,8 @@ private[spark] object UIUtils extends Logging { */ private def showDagViz(graphs: Seq[RDDOperationGraph], forJob: Boolean): Seq[Node] = {
    - + From 5c99d8bf98cbf7f568345d02a814fc318cbfca75 Mon Sep 17 00:00:00 2001 From: Timothy Chen Date: Tue, 11 Aug 2015 23:26:33 -0700 Subject: [PATCH 0186/1824] [SPARK-8798] [MESOS] Allow additional uris to be fetched with mesos Some users like to download additional files in their sandbox that they can refer to from their spark program, or even later mount these files to another directory. Author: Timothy Chen Closes #7195 from tnachen/mesos_files. --- .../cluster/mesos/CoarseMesosSchedulerBackend.scala | 5 +++++ .../scheduler/cluster/mesos/MesosClusterScheduler.scala | 3 +++ .../scheduler/cluster/mesos/MesosSchedulerBackend.scala | 5 +++++ .../scheduler/cluster/mesos/MesosSchedulerUtils.scala | 6 ++++++ docs/running-on-mesos.md | 8 ++++++++ 5 files changed, 27 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala index 15a0915708c7..d6e1e9e5bebc 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala @@ -194,6 +194,11 @@ private[spark] class CoarseMesosSchedulerBackend( s" --app-id $appId") command.addUris(CommandInfo.URI.newBuilder().setValue(uri.get)) } + + conf.getOption("spark.mesos.uris").map { uris => + setupUris(uris, command) + } + command.build() } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala index f078547e7135..64ec2b8e3db1 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala @@ -403,6 +403,9 @@ private[spark] class MesosClusterScheduler( } builder.setValue(s"$executable $cmdOptions $jar $appArguments") builder.setEnvironment(envBuilder.build()) + conf.getOption("spark.mesos.uris").map { uris => + setupUris(uris, builder) + } builder.build() } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala index 3f63ec1c5832..5c20606d5871 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala @@ -133,6 +133,11 @@ private[spark] class MesosSchedulerBackend( builder.addAllResources(usedCpuResources) builder.addAllResources(usedMemResources) + + sc.conf.getOption("spark.mesos.uris").map { uris => + setupUris(uris, command) + } + val executorInfo = builder .setExecutorId(ExecutorID.newBuilder().setValue(execId).build()) .setCommand(command) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala index c04920e4f587..5b854aa5c275 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala @@ -331,4 +331,10 @@ private[mesos] trait MesosSchedulerUtils extends Logging { sc.executorMemory } + def setupUris(uris: String, builder: CommandInfo.Builder): Unit = { + 
uris.split(",").foreach { uri => + builder.addUris(CommandInfo.URI.newBuilder().setValue(uri.trim())) + } + } + } diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index debdd2adf22d..55e6d4e83a72 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -306,6 +306,14 @@ See the [configuration page](configuration.html) for information on Spark config the final overhead will be this value. + + spark.mesos.uris + (none) + + A list of URIs to be downloaded to the sandbox when driver or executor is launched by Mesos. + This applies to both coarse-grain and fine-grain mode. + + spark.mesos.principal Framework principal to authenticate to Mesos From 741a29f98945538a475579ccc974cd42c1613be4 Mon Sep 17 00:00:00 2001 From: Timothy Chen Date: Tue, 11 Aug 2015 23:33:22 -0700 Subject: [PATCH 0187/1824] [SPARK-9575] [MESOS] Add docuemntation around Mesos shuffle service. andrewor14 Author: Timothy Chen Closes #7907 from tnachen/mesos_shuffle. --- docs/running-on-mesos.md | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index 55e6d4e83a72..cfd219ab02e2 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -216,6 +216,20 @@ node. Please refer to [Hadoop on Mesos](https://github.com/mesos/hadoop). In either case, HDFS runs separately from Hadoop MapReduce, without being scheduled through Mesos. +# Dynamic Resource Allocation with Mesos + +Mesos supports dynamic allocation only with coarse grain mode, which can resize the number of executors based on statistics +of the application. While dynamic allocation supports both scaling up and scaling down the number of executors, the coarse grain scheduler only supports scaling down +since it is already designed to run one executor per slave with the configured amount of resources. However, after scaling down the number of executors the coarse grain scheduler +can scale back up to the same amount of executors when Spark signals more executors are needed. + +Users that like to utilize this feature should launch the Mesos Shuffle Service that +provides shuffle data cleanup functionality on top of the Shuffle Service since Mesos doesn't yet support notifying another framework's +termination. To launch/stop the Mesos Shuffle Service please use the provided sbin/start-mesos-shuffle-service.sh and sbin/stop-mesos-shuffle-service.sh +scripts accordingly. + +The Shuffle Service is expected to be running on each slave node that will run Spark executors. One way to easily achieve this with Mesos +is to launch the Shuffle Service with Marathon with a unique host constraint. # Configuration From 9d0822455ddc8d765440d58c463367a4d67ef456 Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Wed, 12 Aug 2015 19:54:00 +0800 Subject: [PATCH 0188/1824] [SPARK-9182] [SQL] Filters are not passed through to jdbc source This PR fixes unable to push filter down to JDBC source caused by `Cast` during pattern matching. While we are comparing columns of different type, there's a big chance we need a cast on the column, therefore not match the pattern directly on Attribute and would fail to push down. Author: Yijie Shen Closes #8049 from yjshen/jdbc_pushdown. 
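The pattern-matching gap is easier to see on a toy expression tree. The case classes below are made up for illustration and are not Catalyst's, and the literal-side type conversion performed by the real patch is ignored; the point is that a translator matching only a bare attribute misses the same comparison once the column is wrapped in a `Cast`, which is exactly the gap the extra cases in the diff below close.

```scala
// Toy expression tree for illustration only -- not Catalyst's classes.
sealed trait Expr
case class Attr(name: String) extends Expr
case class Lit(value: Any) extends Expr
case class Cast(child: Expr) extends Expr
case class EqualTo(left: Expr, right: Expr) extends Expr

object PushDownSketch {
  // Returns the (column, value) pair that could be pushed to the source, if any.
  def translate(e: Expr): Option[(String, Any)] = e match {
    case EqualTo(Attr(name), Lit(v)) => Some(name -> v)
    // Without this extra case, a comparison whose column the analyzer wrapped
    // in a Cast (e.g. an INT column compared to a string literal) falls
    // through to None and the filter silently stays on the Spark side.
    case EqualTo(Cast(Attr(name)), Lit(v)) => Some(name -> v)
    case _ => None
  }

  def main(args: Array[String]): Unit = {
    println(translate(EqualTo(Attr("A"), Lit(15))))          // Some((A,15))
    println(translate(EqualTo(Cast(Attr("A")), Lit("15"))))  // matched only thanks to the Cast case
  }
}
```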
--- .../datasources/DataSourceStrategy.scala | 30 ++++++++++++++-- .../execution/datasources/jdbc/JDBCRDD.scala | 2 +- .../org/apache/spark/sql/jdbc/JDBCSuite.scala | 34 +++++++++++++++++++ 3 files changed, 63 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 2a4c40db8bb6..9eea2b038253 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -20,13 +20,13 @@ package org.apache.spark.sql.execution.datasources import org.apache.spark.{Logging, TaskContext} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.rdd.{MapPartitionsRDD, RDD, UnionRDD} -import org.apache.spark.sql.catalyst.{InternalRow, expressions} +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, expressions} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.sources._ -import org.apache.spark.sql.types.{StringType, StructType} +import org.apache.spark.sql.types.{TimestampType, DateType, StringType, StructType} import org.apache.spark.sql.{SaveMode, Strategy, execution, sources, _} import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.{SerializableConfiguration, Utils} @@ -343,11 +343,17 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { * and convert them. */ protected[sql] def selectFilters(filters: Seq[Expression]) = { + import CatalystTypeConverters._ + def translate(predicate: Expression): Option[Filter] = predicate match { case expressions.EqualTo(a: Attribute, Literal(v, _)) => Some(sources.EqualTo(a.name, v)) case expressions.EqualTo(Literal(v, _), a: Attribute) => Some(sources.EqualTo(a.name, v)) + case expressions.EqualTo(Cast(a: Attribute, _), l: Literal) => + Some(sources.EqualTo(a.name, convertToScala(Cast(l, a.dataType).eval(), a.dataType))) + case expressions.EqualTo(l: Literal, Cast(a: Attribute, _)) => + Some(sources.EqualTo(a.name, convertToScala(Cast(l, a.dataType).eval(), a.dataType))) case expressions.EqualNullSafe(a: Attribute, Literal(v, _)) => Some(sources.EqualNullSafe(a.name, v)) @@ -358,21 +364,41 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { Some(sources.GreaterThan(a.name, v)) case expressions.GreaterThan(Literal(v, _), a: Attribute) => Some(sources.LessThan(a.name, v)) + case expressions.GreaterThan(Cast(a: Attribute, _), l: Literal) => + Some(sources.GreaterThan(a.name, convertToScala(Cast(l, a.dataType).eval(), a.dataType))) + case expressions.GreaterThan(l: Literal, Cast(a: Attribute, _)) => + Some(sources.LessThan(a.name, convertToScala(Cast(l, a.dataType).eval(), a.dataType))) case expressions.LessThan(a: Attribute, Literal(v, _)) => Some(sources.LessThan(a.name, v)) case expressions.LessThan(Literal(v, _), a: Attribute) => Some(sources.GreaterThan(a.name, v)) + case expressions.LessThan(Cast(a: Attribute, _), l: Literal) => + Some(sources.LessThan(a.name, convertToScala(Cast(l, a.dataType).eval(), a.dataType))) + case expressions.LessThan(l: Literal, Cast(a: Attribute, _)) => + Some(sources.GreaterThan(a.name, convertToScala(Cast(l, 
a.dataType).eval(), a.dataType))) case expressions.GreaterThanOrEqual(a: Attribute, Literal(v, _)) => Some(sources.GreaterThanOrEqual(a.name, v)) case expressions.GreaterThanOrEqual(Literal(v, _), a: Attribute) => Some(sources.LessThanOrEqual(a.name, v)) + case expressions.GreaterThanOrEqual(Cast(a: Attribute, _), l: Literal) => + Some(sources.GreaterThanOrEqual(a.name, + convertToScala(Cast(l, a.dataType).eval(), a.dataType))) + case expressions.GreaterThanOrEqual(l: Literal, Cast(a: Attribute, _)) => + Some(sources.LessThanOrEqual(a.name, + convertToScala(Cast(l, a.dataType).eval(), a.dataType))) case expressions.LessThanOrEqual(a: Attribute, Literal(v, _)) => Some(sources.LessThanOrEqual(a.name, v)) case expressions.LessThanOrEqual(Literal(v, _), a: Attribute) => Some(sources.GreaterThanOrEqual(a.name, v)) + case expressions.LessThanOrEqual(Cast(a: Attribute, _), l: Literal) => + Some(sources.LessThanOrEqual(a.name, + convertToScala(Cast(l, a.dataType).eval(), a.dataType))) + case expressions.LessThanOrEqual(l: Literal, Cast(a: Attribute, _)) => + Some(sources.GreaterThanOrEqual(a.name, + convertToScala(Cast(l, a.dataType).eval(), a.dataType))) case expressions.InSet(a: Attribute, set) => Some(sources.In(a.name, set.toArray)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala index 8eab6a0adccc..281943e23fcf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala @@ -284,7 +284,7 @@ private[sql] class JDBCRDD( /** * `filters`, but as a WHERE clause suitable for injection into a SQL query. */ - private val filterWhereClause: String = { + val filterWhereClause: String = { val filterStrings = filters map compileFilter filter (_ != null) if (filterStrings.size > 0) { val sb = new StringBuilder("WHERE ") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 42f2449afb0f..b9cfae51e809 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -25,6 +25,8 @@ import org.h2.jdbc.JdbcSQLException import org.scalatest.BeforeAndAfter import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.execution.datasources.jdbc.JDBCRDD +import org.apache.spark.sql.execution.PhysicalRDD import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -148,6 +150,18 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter { |OPTIONS (url '$url', dbtable 'TEST.FLTTYPES', user 'testUser', password 'testPass') """.stripMargin.replaceAll("\n", " ")) + conn.prepareStatement("create table test.decimals (a DECIMAL(7, 2), b DECIMAL(4, 0))"). 
+ executeUpdate() + conn.prepareStatement("insert into test.decimals values (12345.67, 1234)").executeUpdate() + conn.prepareStatement("insert into test.decimals values (34567.89, 1428)").executeUpdate() + conn.commit() + sql( + s""" + |CREATE TEMPORARY TABLE decimals + |USING org.apache.spark.sql.jdbc + |OPTIONS (url '$url', dbtable 'TEST.DECIMALS', user 'testUser', password 'testPass') + """.stripMargin.replaceAll("\n", " ")) + conn.prepareStatement( s""" |create table test.nulltypes (a INT, b BOOLEAN, c TINYINT, d BINARY(20), e VARCHAR(20), @@ -445,4 +459,24 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter { assert(agg.getCatalystType(1, "", 1, null) === Some(StringType)) } + test("SPARK-9182: filters are not passed through to jdbc source") { + def checkPushedFilter(query: String, filterStr: String): Unit = { + val rddOpt = sql(query).queryExecution.executedPlan.collectFirst { + case PhysicalRDD(_, rdd: JDBCRDD, _) => rdd + } + assert(rddOpt.isDefined) + val pushedFilterStr = rddOpt.get.filterWhereClause + assert(pushedFilterStr.contains(filterStr), + s"Expected to push [$filterStr], actually we pushed [$pushedFilterStr]") + } + + checkPushedFilter("select * from foobar where NAME = 'fred'", "NAME = 'fred'") + checkPushedFilter("select * from inttypes where A > '15'", "A > 15") + checkPushedFilter("select * from inttypes where C <= 20", "C <= 20") + checkPushedFilter("select * from decimals where A > 1000", "A > 1000.00") + checkPushedFilter("select * from decimals where A > 1000 AND A < 2000", + "A > 1000.00 AND A < 2000.00") + checkPushedFilter("select * from decimals where A = 2000 AND B > 20", "A = 2000.00 AND B > 20") + checkPushedFilter("select * from timetypes where B > '1998-09-10'", "B > 1998-09-10") + } } From 3ecb3794302dc12d0989f8d725483b2cc37762cf Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Wed, 12 Aug 2015 20:01:34 +0800 Subject: [PATCH 0189/1824] [SPARK-9407] [SQL] Relaxes Parquet ValidTypeMap to allow ENUM predicates to be pushed down This PR adds a hacky workaround for PARQUET-201, and should be removed once we upgrade to parquet-mr 1.8.1 or higher versions. In Parquet, not all types of columns can be used for filter push-down optimization. The set of valid column types is controlled by `ValidTypeMap`. Unfortunately, in parquet-mr 1.7.0 and prior versions, this limitation is too strict, and doesn't allow `BINARY (ENUM)` columns to be pushed down. On the other hand, `BINARY (ENUM)` is commonly seen in Parquet files written by libraries like `parquet-avro`. This restriction is problematic for Spark SQL, because Spark SQL doesn't have a type that maps to Parquet `BINARY (ENUM)` directly, and always converts `BINARY (ENUM)` to Catalyst `StringType`. Thus, a predicate involving a `BINARY (ENUM)` is recognized as one involving a string field instead and can be pushed down by the query optimizer. Such predicates are actually perfectly legal except that it fails the `ValidTypeMap` check. The workaround added here is relaxing `ValidTypeMap` to include `BINARY (ENUM)`. I also took the chance to simplify `ParquetCompatibilityTest` a little bit when adding regression test. Author: Cheng Lian Closes #8107 from liancheng/spark-9407/parquet-enum-filter-push-down. 
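The workaround relies on nothing more than standard Java reflection. A self-contained sketch of that mechanism on made-up stand-in classes (`Descriptor` and `Registry` below are not parquet-mr's `ValidTypeMap`):

```scala
// Stand-ins for third-party code we cannot modify: a type with a private
// constructor and a registry with a private add method.
class Descriptor private (val name: String)

object Registry {
  private val entries = scala.collection.mutable.Buffer.empty[String]
  private def add(name: String): Unit = entries += name
  def dump(): List[String] = entries.toList
}

object ReflectionSketch {
  def main(args: Array[String]): Unit = {
    // 1. Grab the private constructor and force it accessible.
    val ctor = classOf[Descriptor].getDeclaredConstructor(classOf[String])
    ctor.setAccessible(true)
    val descriptor = ctor.newInstance("BINARY (ENUM)")

    // 2. Locate the private method by name and invoke it reflectively.
    val addMethod = Registry.getClass.getDeclaredMethods.find(_.getName == "add").get
    addMethod.setAccessible(true)
    addMethod.invoke(Registry, descriptor.name)

    println(Registry.dump()) // List(BINARY (ENUM))
  }
}
```

This is the same setAccessible/invoke pattern the diff below applies to `ValidTypeMap`, which is why the commit message flags it as a temporary hack to drop once parquet-mr 1.8.1+ is available.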
--- .../datasources/parquet/ParquetFilters.scala | 38 ++++- .../datasources/parquet/ParquetRelation.scala | 2 +- sql/core/src/test/README.md | 16 +- sql/core/src/test/avro/parquet-compat.avdl | 13 +- sql/core/src/test/avro/parquet-compat.avpr | 13 +- .../parquet/test/avro/CompatibilityTest.java | 2 +- .../datasources/parquet/test/avro/Nested.java | 4 +- .../parquet/test/avro/ParquetAvroCompat.java | 4 +- .../parquet/test/avro/ParquetEnum.java | 142 ++++++++++++++++++ .../datasources/parquet/test/avro/Suit.java | 13 ++ .../ParquetAvroCompatibilitySuite.scala | 105 +++++++------ .../parquet/ParquetCompatibilityTest.scala | 33 +--- .../test/scripts/{gen-code.sh => gen-avro.sh} | 13 +- sql/core/src/test/scripts/gen-thrift.sh | 27 ++++ .../src/test/thrift/parquet-compat.thrift | 2 +- .../hive/ParquetHiveCompatibilitySuite.scala | 83 +++++----- 16 files changed, 374 insertions(+), 136 deletions(-) create mode 100644 sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/ParquetEnum.java create mode 100644 sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/Suit.java rename sql/core/src/test/scripts/{gen-code.sh => gen-avro.sh} (76%) create mode 100755 sql/core/src/test/scripts/gen-thrift.sh diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala index 9e2e232f5016..63915e0a2865 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala @@ -25,9 +25,10 @@ import org.apache.hadoop.conf.Configuration import org.apache.parquet.filter2.compat.FilterCompat import org.apache.parquet.filter2.compat.FilterCompat._ import org.apache.parquet.filter2.predicate.FilterApi._ -import org.apache.parquet.filter2.predicate.{FilterApi, FilterPredicate, Statistics} -import org.apache.parquet.filter2.predicate.UserDefinedPredicate +import org.apache.parquet.filter2.predicate._ import org.apache.parquet.io.api.Binary +import org.apache.parquet.schema.OriginalType +import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName import org.apache.spark.SparkEnv import org.apache.spark.sql.catalyst.expressions._ @@ -197,6 +198,8 @@ private[sql] object ParquetFilters { def createFilter(schema: StructType, predicate: sources.Filter): Option[FilterPredicate] = { val dataTypeOf = schema.map(f => f.name -> f.dataType).toMap + relaxParquetValidTypeMap + // NOTE: // // For any comparison operator `cmp`, both `a cmp NULL` and `NULL cmp a` evaluate to `NULL`, @@ -239,6 +242,37 @@ private[sql] object ParquetFilters { } } + // !! HACK ALERT !! + // + // This lazy val is a workaround for PARQUET-201, and should be removed once we upgrade to + // parquet-mr 1.8.1 or higher versions. + // + // In Parquet, not all types of columns can be used for filter push-down optimization. The set + // of valid column types is controlled by `ValidTypeMap`. Unfortunately, in parquet-mr 1.7.0 and + // prior versions, the limitation is too strict, and doesn't allow `BINARY (ENUM)` columns to be + // pushed down. + // + // This restriction is problematic for Spark SQL, because Spark SQL doesn't have a type that maps + // to Parquet original type `ENUM` directly, and always converts `ENUM` to `StringType`. 
Thus, + // a predicate involving a `ENUM` field can be pushed-down as a string column, which is perfectly + // legal except that it fails the `ValidTypeMap` check. + // + // Here we add `BINARY (ENUM)` into `ValidTypeMap` lazily via reflection to workaround this issue. + private lazy val relaxParquetValidTypeMap: Unit = { + val constructor = Class + .forName(classOf[ValidTypeMap].getCanonicalName + "$FullTypeDescriptor") + .getDeclaredConstructor(classOf[PrimitiveTypeName], classOf[OriginalType]) + + constructor.setAccessible(true) + val enumTypeDescriptor = constructor + .newInstance(PrimitiveTypeName.BINARY, OriginalType.ENUM) + .asInstanceOf[AnyRef] + + val addMethod = classOf[ValidTypeMap].getDeclaredMethods.find(_.getName == "add").get + addMethod.setAccessible(true) + addMethod.invoke(null, classOf[Binary], enumTypeDescriptor) + } + /** * Converts Catalyst predicate expressions to Parquet filter predicates. * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala index c71c69b6e80b..52fac18ba187 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala @@ -678,7 +678,7 @@ private[sql] object ParquetRelation extends Logging { val followParquetFormatSpec = sqlContext.conf.followParquetFormatSpec val serializedConf = new SerializableConfiguration(sqlContext.sparkContext.hadoopConfiguration) - // HACK ALERT: + // !! HACK ALERT !! // // Parquet requires `FileStatus`es to read footers. Here we try to send cached `FileStatus`es // to executor side to avoid fetching them again. However, `FileStatus` is not `Serializable` diff --git a/sql/core/src/test/README.md b/sql/core/src/test/README.md index 3dd9861b4896..421c2ea4f7ae 100644 --- a/sql/core/src/test/README.md +++ b/sql/core/src/test/README.md @@ -6,23 +6,19 @@ The following directories and files are used for Parquet compatibility tests: . ├── README.md # This file ├── avro -│   ├── parquet-compat.avdl # Testing Avro IDL -│   └── parquet-compat.avpr # !! NO TOUCH !! Protocol file generated from parquet-compat.avdl +│   ├── *.avdl # Testing Avro IDL(s) +│   └── *.avpr # !! NO TOUCH !! Protocol files generated from Avro IDL(s) ├── gen-java # !! NO TOUCH !! Generated Java code ├── scripts -│   └── gen-code.sh # Script used to generate Java code for Thrift and Avro +│   ├── gen-avro.sh # Script used to generate Java code for Avro +│   └── gen-thrift.sh # Script used to generate Java code for Thrift └── thrift - └── parquet-compat.thrift # Testing Thrift schema + └── *.thrift # Testing Thrift schema(s) ``` -Generated Java code are used in the following test suites: - -- `org.apache.spark.sql.parquet.ParquetAvroCompatibilitySuite` -- `org.apache.spark.sql.parquet.ParquetThriftCompatibilitySuite` - To avoid code generation during build time, Java code generated from testing Thrift schema and Avro IDL are also checked in. -When updating the testing Thrift schema and Avro IDL, please run `gen-code.sh` to update all the generated Java code. +When updating the testing Thrift schema and Avro IDL, please run `gen-avro.sh` and `gen-thrift.sh` accordingly to update generated Java code. 
## Prerequisites diff --git a/sql/core/src/test/avro/parquet-compat.avdl b/sql/core/src/test/avro/parquet-compat.avdl index 24729f6143e6..8070d0a9170a 100644 --- a/sql/core/src/test/avro/parquet-compat.avdl +++ b/sql/core/src/test/avro/parquet-compat.avdl @@ -16,8 +16,19 @@ */ // This is a test protocol for testing parquet-avro compatibility. -@namespace("org.apache.spark.sql.parquet.test.avro") +@namespace("org.apache.spark.sql.execution.datasources.parquet.test.avro") protocol CompatibilityTest { + enum Suit { + SPADES, + HEARTS, + DIAMONDS, + CLUBS + } + + record ParquetEnum { + Suit suit; + } + record Nested { array nested_ints_column; string nested_string_column; diff --git a/sql/core/src/test/avro/parquet-compat.avpr b/sql/core/src/test/avro/parquet-compat.avpr index a83b7c990dd2..060391765034 100644 --- a/sql/core/src/test/avro/parquet-compat.avpr +++ b/sql/core/src/test/avro/parquet-compat.avpr @@ -1,7 +1,18 @@ { "protocol" : "CompatibilityTest", - "namespace" : "org.apache.spark.sql.parquet.test.avro", + "namespace" : "org.apache.spark.sql.execution.datasources.parquet.test.avro", "types" : [ { + "type" : "enum", + "name" : "Suit", + "symbols" : [ "SPADES", "HEARTS", "DIAMONDS", "CLUBS" ] + }, { + "type" : "record", + "name" : "ParquetEnum", + "fields" : [ { + "name" : "suit", + "type" : "Suit" + } ] + }, { "type" : "record", "name" : "Nested", "fields" : [ { diff --git a/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/CompatibilityTest.java b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/CompatibilityTest.java index 70dec1a9d3c9..2368323cb36b 100644 --- a/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/CompatibilityTest.java +++ b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/CompatibilityTest.java @@ -8,7 +8,7 @@ @SuppressWarnings("all") @org.apache.avro.specific.AvroGenerated public interface CompatibilityTest { - public static final org.apache.avro.Protocol PROTOCOL = 
org.apache.avro.Protocol.parse("{\"protocol\":\"CompatibilityTest\",\"namespace\":\"org.apache.spark.sql.parquet.test.avro\",\"types\":[{\"type\":\"record\",\"name\":\"Nested\",\"fields\":[{\"name\":\"nested_ints_column\",\"type\":{\"type\":\"array\",\"items\":\"int\"}},{\"name\":\"nested_string_column\",\"type\":{\"type\":\"string\",\"avro.java.string\":\"String\"}}]},{\"type\":\"record\",\"name\":\"ParquetAvroCompat\",\"fields\":[{\"name\":\"bool_column\",\"type\":\"boolean\"},{\"name\":\"int_column\",\"type\":\"int\"},{\"name\":\"long_column\",\"type\":\"long\"},{\"name\":\"float_column\",\"type\":\"float\"},{\"name\":\"double_column\",\"type\":\"double\"},{\"name\":\"binary_column\",\"type\":\"bytes\"},{\"name\":\"string_column\",\"type\":{\"type\":\"string\",\"avro.java.string\":\"String\"}},{\"name\":\"maybe_bool_column\",\"type\":[\"null\",\"boolean\"]},{\"name\":\"maybe_int_column\",\"type\":[\"null\",\"int\"]},{\"name\":\"maybe_long_column\",\"type\":[\"null\",\"long\"]},{\"name\":\"maybe_float_column\",\"type\":[\"null\",\"float\"]},{\"name\":\"maybe_double_column\",\"type\":[\"null\",\"double\"]},{\"name\":\"maybe_binary_column\",\"type\":[\"null\",\"bytes\"]},{\"name\":\"maybe_string_column\",\"type\":[\"null\",{\"type\":\"string\",\"avro.java.string\":\"String\"}]},{\"name\":\"strings_column\",\"type\":{\"type\":\"array\",\"items\":{\"type\":\"string\",\"avro.java.string\":\"String\"}}},{\"name\":\"string_to_int_column\",\"type\":{\"type\":\"map\",\"values\":\"int\",\"avro.java.string\":\"String\"}},{\"name\":\"complex_column\",\"type\":{\"type\":\"map\",\"values\":{\"type\":\"array\",\"items\":\"Nested\"},\"avro.java.string\":\"String\"}}]}],\"messages\":{}}"); + public static final org.apache.avro.Protocol PROTOCOL = org.apache.avro.Protocol.parse("{\"protocol\":\"CompatibilityTest\",\"namespace\":\"org.apache.spark.sql.execution.datasources.parquet.test.avro\",\"types\":[{\"type\":\"enum\",\"name\":\"Suit\",\"symbols\":[\"SPADES\",\"HEARTS\",\"DIAMONDS\",\"CLUBS\"]},{\"type\":\"record\",\"name\":\"ParquetEnum\",\"fields\":[{\"name\":\"suit\",\"type\":\"Suit\"}]},{\"type\":\"record\",\"name\":\"Nested\",\"fields\":[{\"name\":\"nested_ints_column\",\"type\":{\"type\":\"array\",\"items\":\"int\"}},{\"name\":\"nested_string_column\",\"type\":{\"type\":\"string\",\"avro.java.string\":\"String\"}}]},{\"type\":\"record\",\"name\":\"ParquetAvroCompat\",\"fields\":[{\"name\":\"bool_column\",\"type\":\"boolean\"},{\"name\":\"int_column\",\"type\":\"int\"},{\"name\":\"long_column\",\"type\":\"long\"},{\"name\":\"float_column\",\"type\":\"float\"},{\"name\":\"double_column\",\"type\":\"double\"},{\"name\":\"binary_column\",\"type\":\"bytes\"},{\"name\":\"string_column\",\"type\":{\"type\":\"string\",\"avro.java.string\":\"String\"}},{\"name\":\"maybe_bool_column\",\"type\":[\"null\",\"boolean\"]},{\"name\":\"maybe_int_column\",\"type\":[\"null\",\"int\"]},{\"name\":\"maybe_long_column\",\"type\":[\"null\",\"long\"]},{\"name\":\"maybe_float_column\",\"type\":[\"null\",\"float\"]},{\"name\":\"maybe_double_column\",\"type\":[\"null\",\"double\"]},{\"name\":\"maybe_binary_column\",\"type\":[\"null\",\"bytes\"]},{\"name\":\"maybe_string_column\",\"type\":[\"null\",{\"type\":\"string\",\"avro.java.string\":\"String\"}]},{\"name\":\"strings_column\",\"type\":{\"type\":\"array\",\"items\":{\"type\":\"string\",\"avro.java.string\":\"String\"}}},{\"name\":\"string_to_int_column\",\"type\":{\"type\":\"map\",\"values\":\"int\",\"avro.java.string\":\"String\"}},{\"name\":\"complex_column\",\"type\":
{\"type\":\"map\",\"values\":{\"type\":\"array\",\"items\":\"Nested\"},\"avro.java.string\":\"String\"}}]}],\"messages\":{}}"); @SuppressWarnings("all") public interface Callback extends CompatibilityTest { diff --git a/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/Nested.java b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/Nested.java index a0a406bcd10c..a7bf4841919c 100644 --- a/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/Nested.java +++ b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/Nested.java @@ -3,11 +3,11 @@ * * DO NOT EDIT DIRECTLY */ -package org.apache.spark.sql.execution.datasources.parquet.test.avro; +package org.apache.spark.sql.execution.datasources.parquet.test.avro; @SuppressWarnings("all") @org.apache.avro.specific.AvroGenerated public class Nested extends org.apache.avro.specific.SpecificRecordBase implements org.apache.avro.specific.SpecificRecord { - public static final org.apache.avro.Schema SCHEMA$ = new org.apache.avro.Schema.Parser().parse("{\"type\":\"record\",\"name\":\"Nested\",\"namespace\":\"org.apache.spark.sql.parquet.test.avro\",\"fields\":[{\"name\":\"nested_ints_column\",\"type\":{\"type\":\"array\",\"items\":\"int\"}},{\"name\":\"nested_string_column\",\"type\":{\"type\":\"string\",\"avro.java.string\":\"String\"}}]}"); + public static final org.apache.avro.Schema SCHEMA$ = new org.apache.avro.Schema.Parser().parse("{\"type\":\"record\",\"name\":\"Nested\",\"namespace\":\"org.apache.spark.sql.execution.datasources.parquet.test.avro\",\"fields\":[{\"name\":\"nested_ints_column\",\"type\":{\"type\":\"array\",\"items\":\"int\"}},{\"name\":\"nested_string_column\",\"type\":{\"type\":\"string\",\"avro.java.string\":\"String\"}}]}"); public static org.apache.avro.Schema getClassSchema() { return SCHEMA$; } @Deprecated public java.util.List nested_ints_column; @Deprecated public java.lang.String nested_string_column; diff --git a/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/ParquetAvroCompat.java b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/ParquetAvroCompat.java index 6198b00b1e3c..681cacbd12c7 100644 --- a/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/ParquetAvroCompat.java +++ b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/ParquetAvroCompat.java @@ -3,11 +3,11 @@ * * DO NOT EDIT DIRECTLY */ -package org.apache.spark.sql.execution.datasources.parquet.test.avro; +package org.apache.spark.sql.execution.datasources.parquet.test.avro; @SuppressWarnings("all") @org.apache.avro.specific.AvroGenerated public class ParquetAvroCompat extends org.apache.avro.specific.SpecificRecordBase implements org.apache.avro.specific.SpecificRecord { - public static final org.apache.avro.Schema SCHEMA$ = new 
org.apache.avro.Schema.Parser().parse("{\"type\":\"record\",\"name\":\"ParquetAvroCompat\",\"namespace\":\"org.apache.spark.sql.parquet.test.avro\",\"fields\":[{\"name\":\"bool_column\",\"type\":\"boolean\"},{\"name\":\"int_column\",\"type\":\"int\"},{\"name\":\"long_column\",\"type\":\"long\"},{\"name\":\"float_column\",\"type\":\"float\"},{\"name\":\"double_column\",\"type\":\"double\"},{\"name\":\"binary_column\",\"type\":\"bytes\"},{\"name\":\"string_column\",\"type\":{\"type\":\"string\",\"avro.java.string\":\"String\"}},{\"name\":\"maybe_bool_column\",\"type\":[\"null\",\"boolean\"]},{\"name\":\"maybe_int_column\",\"type\":[\"null\",\"int\"]},{\"name\":\"maybe_long_column\",\"type\":[\"null\",\"long\"]},{\"name\":\"maybe_float_column\",\"type\":[\"null\",\"float\"]},{\"name\":\"maybe_double_column\",\"type\":[\"null\",\"double\"]},{\"name\":\"maybe_binary_column\",\"type\":[\"null\",\"bytes\"]},{\"name\":\"maybe_string_column\",\"type\":[\"null\",{\"type\":\"string\",\"avro.java.string\":\"String\"}]},{\"name\":\"strings_column\",\"type\":{\"type\":\"array\",\"items\":{\"type\":\"string\",\"avro.java.string\":\"String\"}}},{\"name\":\"string_to_int_column\",\"type\":{\"type\":\"map\",\"values\":\"int\",\"avro.java.string\":\"String\"}},{\"name\":\"complex_column\",\"type\":{\"type\":\"map\",\"values\":{\"type\":\"array\",\"items\":{\"type\":\"record\",\"name\":\"Nested\",\"fields\":[{\"name\":\"nested_ints_column\",\"type\":{\"type\":\"array\",\"items\":\"int\"}},{\"name\":\"nested_string_column\",\"type\":{\"type\":\"string\",\"avro.java.string\":\"String\"}}]}},\"avro.java.string\":\"String\"}}]}"); + public static final org.apache.avro.Schema SCHEMA$ = new org.apache.avro.Schema.Parser().parse("{\"type\":\"record\",\"name\":\"ParquetAvroCompat\",\"namespace\":\"org.apache.spark.sql.execution.datasources.parquet.test.avro\",\"fields\":[{\"name\":\"bool_column\",\"type\":\"boolean\"},{\"name\":\"int_column\",\"type\":\"int\"},{\"name\":\"long_column\",\"type\":\"long\"},{\"name\":\"float_column\",\"type\":\"float\"},{\"name\":\"double_column\",\"type\":\"double\"},{\"name\":\"binary_column\",\"type\":\"bytes\"},{\"name\":\"string_column\",\"type\":{\"type\":\"string\",\"avro.java.string\":\"String\"}},{\"name\":\"maybe_bool_column\",\"type\":[\"null\",\"boolean\"]},{\"name\":\"maybe_int_column\",\"type\":[\"null\",\"int\"]},{\"name\":\"maybe_long_column\",\"type\":[\"null\",\"long\"]},{\"name\":\"maybe_float_column\",\"type\":[\"null\",\"float\"]},{\"name\":\"maybe_double_column\",\"type\":[\"null\",\"double\"]},{\"name\":\"maybe_binary_column\",\"type\":[\"null\",\"bytes\"]},{\"name\":\"maybe_string_column\",\"type\":[\"null\",{\"type\":\"string\",\"avro.java.string\":\"String\"}]},{\"name\":\"strings_column\",\"type\":{\"type\":\"array\",\"items\":{\"type\":\"string\",\"avro.java.string\":\"String\"}}},{\"name\":\"string_to_int_column\",\"type\":{\"type\":\"map\",\"values\":\"int\",\"avro.java.string\":\"String\"}},{\"name\":\"complex_column\",\"type\":{\"type\":\"map\",\"values\":{\"type\":\"array\",\"items\":{\"type\":\"record\",\"name\":\"Nested\",\"fields\":[{\"name\":\"nested_ints_column\",\"type\":{\"type\":\"array\",\"items\":\"int\"}},{\"name\":\"nested_string_column\",\"type\":{\"type\":\"string\",\"avro.java.string\":\"String\"}}]}},\"avro.java.string\":\"String\"}}]}"); public static org.apache.avro.Schema getClassSchema() { return SCHEMA$; } @Deprecated public boolean bool_column; @Deprecated public int int_column; diff --git 
a/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/ParquetEnum.java b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/ParquetEnum.java new file mode 100644 index 000000000000..05fefe4cee75 --- /dev/null +++ b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/ParquetEnum.java @@ -0,0 +1,142 @@ +/** + * Autogenerated by Avro + * + * DO NOT EDIT DIRECTLY + */ +package org.apache.spark.sql.execution.datasources.parquet.test.avro; +@SuppressWarnings("all") +@org.apache.avro.specific.AvroGenerated +public class ParquetEnum extends org.apache.avro.specific.SpecificRecordBase implements org.apache.avro.specific.SpecificRecord { + public static final org.apache.avro.Schema SCHEMA$ = new org.apache.avro.Schema.Parser().parse("{\"type\":\"record\",\"name\":\"ParquetEnum\",\"namespace\":\"org.apache.spark.sql.execution.datasources.parquet.test.avro\",\"fields\":[{\"name\":\"suit\",\"type\":{\"type\":\"enum\",\"name\":\"Suit\",\"symbols\":[\"SPADES\",\"HEARTS\",\"DIAMONDS\",\"CLUBS\"]}}]}"); + public static org.apache.avro.Schema getClassSchema() { return SCHEMA$; } + @Deprecated public org.apache.spark.sql.execution.datasources.parquet.test.avro.Suit suit; + + /** + * Default constructor. Note that this does not initialize fields + * to their default values from the schema. If that is desired then + * one should use newBuilder(). + */ + public ParquetEnum() {} + + /** + * All-args constructor. + */ + public ParquetEnum(org.apache.spark.sql.execution.datasources.parquet.test.avro.Suit suit) { + this.suit = suit; + } + + public org.apache.avro.Schema getSchema() { return SCHEMA$; } + // Used by DatumWriter. Applications should not call. + public java.lang.Object get(int field$) { + switch (field$) { + case 0: return suit; + default: throw new org.apache.avro.AvroRuntimeException("Bad index"); + } + } + // Used by DatumReader. Applications should not call. + @SuppressWarnings(value="unchecked") + public void put(int field$, java.lang.Object value$) { + switch (field$) { + case 0: suit = (org.apache.spark.sql.execution.datasources.parquet.test.avro.Suit)value$; break; + default: throw new org.apache.avro.AvroRuntimeException("Bad index"); + } + } + + /** + * Gets the value of the 'suit' field. + */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.Suit getSuit() { + return suit; + } + + /** + * Sets the value of the 'suit' field. + * @param value the value to set. 
+ */ + public void setSuit(org.apache.spark.sql.execution.datasources.parquet.test.avro.Suit value) { + this.suit = value; + } + + /** Creates a new ParquetEnum RecordBuilder */ + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetEnum.Builder newBuilder() { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetEnum.Builder(); + } + + /** Creates a new ParquetEnum RecordBuilder by copying an existing Builder */ + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetEnum.Builder newBuilder(org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetEnum.Builder other) { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetEnum.Builder(other); + } + + /** Creates a new ParquetEnum RecordBuilder by copying an existing ParquetEnum instance */ + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetEnum.Builder newBuilder(org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetEnum other) { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetEnum.Builder(other); + } + + /** + * RecordBuilder for ParquetEnum instances. + */ + public static class Builder extends org.apache.avro.specific.SpecificRecordBuilderBase + implements org.apache.avro.data.RecordBuilder { + + private org.apache.spark.sql.execution.datasources.parquet.test.avro.Suit suit; + + /** Creates a new Builder */ + private Builder() { + super(org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetEnum.SCHEMA$); + } + + /** Creates a Builder by copying an existing Builder */ + private Builder(org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetEnum.Builder other) { + super(other); + if (isValidValue(fields()[0], other.suit)) { + this.suit = data().deepCopy(fields()[0].schema(), other.suit); + fieldSetFlags()[0] = true; + } + } + + /** Creates a Builder by copying an existing ParquetEnum instance */ + private Builder(org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetEnum other) { + super(org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetEnum.SCHEMA$); + if (isValidValue(fields()[0], other.suit)) { + this.suit = data().deepCopy(fields()[0].schema(), other.suit); + fieldSetFlags()[0] = true; + } + } + + /** Gets the value of the 'suit' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.Suit getSuit() { + return suit; + } + + /** Sets the value of the 'suit' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetEnum.Builder setSuit(org.apache.spark.sql.execution.datasources.parquet.test.avro.Suit value) { + validate(fields()[0], value); + this.suit = value; + fieldSetFlags()[0] = true; + return this; + } + + /** Checks whether the 'suit' field has been set */ + public boolean hasSuit() { + return fieldSetFlags()[0]; + } + + /** Clears the value of the 'suit' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetEnum.Builder clearSuit() { + suit = null; + fieldSetFlags()[0] = false; + return this; + } + + @Override + public ParquetEnum build() { + try { + ParquetEnum record = new ParquetEnum(); + record.suit = fieldSetFlags()[0] ? 
this.suit : (org.apache.spark.sql.execution.datasources.parquet.test.avro.Suit) defaultValue(fields()[0]); + return record; + } catch (Exception e) { + throw new org.apache.avro.AvroRuntimeException(e); + } + } + } +} diff --git a/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/Suit.java b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/Suit.java new file mode 100644 index 000000000000..00711a0c2a26 --- /dev/null +++ b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/Suit.java @@ -0,0 +1,13 @@ +/** + * Autogenerated by Avro + * + * DO NOT EDIT DIRECTLY + */ +package org.apache.spark.sql.execution.datasources.parquet.test.avro; +@SuppressWarnings("all") +@org.apache.avro.specific.AvroGenerated +public enum Suit { + SPADES, HEARTS, DIAMONDS, CLUBS ; + public static final org.apache.avro.Schema SCHEMA$ = new org.apache.avro.Schema.Parser().parse("{\"type\":\"enum\",\"name\":\"Suit\",\"namespace\":\"org.apache.spark.sql.execution.datasources.parquet.test.avro\",\"symbols\":[\"SPADES\",\"HEARTS\",\"DIAMONDS\",\"CLUBS\"]}"); + public static org.apache.avro.Schema getClassSchema() { return SCHEMA$; } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAvroCompatibilitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAvroCompatibilitySuite.scala index 4d9c07bb7a57..866a975ad540 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAvroCompatibilitySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAvroCompatibilitySuite.scala @@ -22,10 +22,12 @@ import java.util.{List => JList, Map => JMap} import scala.collection.JavaConversions._ +import org.apache.avro.Schema +import org.apache.avro.generic.IndexedRecord import org.apache.hadoop.fs.Path import org.apache.parquet.avro.AvroParquetWriter -import org.apache.spark.sql.execution.datasources.parquet.test.avro.{Nested, ParquetAvroCompat} +import org.apache.spark.sql.execution.datasources.parquet.test.avro.{Nested, ParquetAvroCompat, ParquetEnum, Suit} import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.{Row, SQLContext} @@ -34,52 +36,55 @@ class ParquetAvroCompatibilitySuite extends ParquetCompatibilityTest { override val sqlContext: SQLContext = TestSQLContext - override protected def beforeAll(): Unit = { - super.beforeAll() - - val writer = - new AvroParquetWriter[ParquetAvroCompat]( - new Path(parquetStore.getCanonicalPath), - ParquetAvroCompat.getClassSchema) - - (0 until 10).foreach(i => writer.write(makeParquetAvroCompat(i))) - writer.close() + private def withWriter[T <: IndexedRecord] + (path: String, schema: Schema) + (f: AvroParquetWriter[T] => Unit) = { + val writer = new AvroParquetWriter[T](new Path(path), schema) + try f(writer) finally writer.close() } test("Read Parquet file generated by parquet-avro") { - logInfo( - s"""Schema of the Parquet file written by parquet-avro: - |${readParquetSchema(parquetStore.getCanonicalPath)} + withTempPath { dir => + val path = dir.getCanonicalPath + + withWriter[ParquetAvroCompat](path, ParquetAvroCompat.getClassSchema) { writer => + (0 until 10).foreach(i => writer.write(makeParquetAvroCompat(i))) + } + + logInfo( + s"""Schema of the Parquet file written by parquet-avro: + |${readParquetSchema(path)} """.stripMargin) - 
checkAnswer(sqlContext.read.parquet(parquetStore.getCanonicalPath), (0 until 10).map { i => - def nullable[T <: AnyRef]: ( => T) => T = makeNullable[T](i) - - Row( - i % 2 == 0, - i, - i.toLong * 10, - i.toFloat + 0.1f, - i.toDouble + 0.2d, - s"val_$i".getBytes, - s"val_$i", - - nullable(i % 2 == 0: java.lang.Boolean), - nullable(i: Integer), - nullable(i.toLong: java.lang.Long), - nullable(i.toFloat + 0.1f: java.lang.Float), - nullable(i.toDouble + 0.2d: java.lang.Double), - nullable(s"val_$i".getBytes), - nullable(s"val_$i"), - - Seq.tabulate(3)(n => s"arr_${i + n}"), - Seq.tabulate(3)(n => n.toString -> (i + n: Integer)).toMap, - Seq.tabulate(3) { n => - (i + n).toString -> Seq.tabulate(3) { m => - Row(Seq.tabulate(3)(j => i + j + m), s"val_${i + m}") - } - }.toMap) - }) + checkAnswer(sqlContext.read.parquet(path), (0 until 10).map { i => + def nullable[T <: AnyRef]: ( => T) => T = makeNullable[T](i) + + Row( + i % 2 == 0, + i, + i.toLong * 10, + i.toFloat + 0.1f, + i.toDouble + 0.2d, + s"val_$i".getBytes, + s"val_$i", + + nullable(i % 2 == 0: java.lang.Boolean), + nullable(i: Integer), + nullable(i.toLong: java.lang.Long), + nullable(i.toFloat + 0.1f: java.lang.Float), + nullable(i.toDouble + 0.2d: java.lang.Double), + nullable(s"val_$i".getBytes), + nullable(s"val_$i"), + + Seq.tabulate(3)(n => s"arr_${i + n}"), + Seq.tabulate(3)(n => n.toString -> (i + n: Integer)).toMap, + Seq.tabulate(3) { n => + (i + n).toString -> Seq.tabulate(3) { m => + Row(Seq.tabulate(3)(j => i + j + m), s"val_${i + m}") + } + }.toMap) + }) + } } def makeParquetAvroCompat(i: Int): ParquetAvroCompat = { @@ -122,4 +127,20 @@ class ParquetAvroCompatibilitySuite extends ParquetCompatibilityTest { .build() } + + test("SPARK-9407 Don't push down predicates involving Parquet ENUM columns") { + import sqlContext.implicits._ + + withTempPath { dir => + val path = dir.getCanonicalPath + + withWriter[ParquetEnum](path, ParquetEnum.getClassSchema) { writer => + (0 until 4).foreach { i => + writer.write(ParquetEnum.newBuilder().setSuit(Suit.values.apply(i)).build()) + } + } + + checkAnswer(sqlContext.read.parquet(path).filter('suit === "SPADES"), Row("SPADES")) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala index 68f35b1f3aa8..0ea64aa2a509 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala @@ -16,45 +16,28 @@ */ package org.apache.spark.sql.execution.datasources.parquet -import java.io.File import scala.collection.JavaConversions._ -import org.apache.hadoop.fs.Path +import org.apache.hadoop.fs.{Path, PathFilter} import org.apache.parquet.hadoop.ParquetFileReader import org.apache.parquet.schema.MessageType import org.scalatest.BeforeAndAfterAll import org.apache.spark.sql.QueryTest -import org.apache.spark.util.Utils abstract class ParquetCompatibilityTest extends QueryTest with ParquetTest with BeforeAndAfterAll { - protected var parquetStore: File = _ - - /** - * Optional path to a staging subdirectory which may be created during query processing - * (Hive does this). - * Parquet files under this directory will be ignored in [[readParquetSchema()]] - * @return an optional staging directory to ignore when scanning for parquet files. 
- */ - protected def stagingDir: Option[String] = None - - override protected def beforeAll(): Unit = { - parquetStore = Utils.createTempDir(namePrefix = "parquet-compat_") - parquetStore.delete() - } - - override protected def afterAll(): Unit = { - Utils.deleteRecursively(parquetStore) + def readParquetSchema(path: String): MessageType = { + readParquetSchema(path, { path => !path.getName.startsWith("_") }) } - def readParquetSchema(path: String): MessageType = { + def readParquetSchema(path: String, pathFilter: Path => Boolean): MessageType = { val fsPath = new Path(path) val fs = fsPath.getFileSystem(configuration) - val parquetFiles = fs.listStatus(fsPath).toSeq.filterNot { status => - status.getPath.getName.startsWith("_") || - stagingDir.map(status.getPath.getName.startsWith).getOrElse(false) - } + val parquetFiles = fs.listStatus(fsPath, new PathFilter { + override def accept(path: Path): Boolean = pathFilter(path) + }).toSeq + val footers = ParquetFileReader.readAllFootersInParallel(configuration, parquetFiles, true) footers.head.getParquetMetadata.getFileMetaData.getSchema } diff --git a/sql/core/src/test/scripts/gen-code.sh b/sql/core/src/test/scripts/gen-avro.sh similarity index 76% rename from sql/core/src/test/scripts/gen-code.sh rename to sql/core/src/test/scripts/gen-avro.sh index 5d8d8ad08555..48174b287fd7 100755 --- a/sql/core/src/test/scripts/gen-code.sh +++ b/sql/core/src/test/scripts/gen-avro.sh @@ -22,10 +22,9 @@ cd - rm -rf $BASEDIR/gen-java mkdir -p $BASEDIR/gen-java -thrift\ - --gen java\ - -out $BASEDIR/gen-java\ - $BASEDIR/thrift/parquet-compat.thrift - -avro-tools idl $BASEDIR/avro/parquet-compat.avdl > $BASEDIR/avro/parquet-compat.avpr -avro-tools compile -string protocol $BASEDIR/avro/parquet-compat.avpr $BASEDIR/gen-java +for input in `ls $BASEDIR/avro/*.avdl`; do + filename=$(basename "$input") + filename="${filename%.*}" + avro-tools idl $input> $BASEDIR/avro/${filename}.avpr + avro-tools compile -string protocol $BASEDIR/avro/${filename}.avpr $BASEDIR/gen-java +done diff --git a/sql/core/src/test/scripts/gen-thrift.sh b/sql/core/src/test/scripts/gen-thrift.sh new file mode 100755 index 000000000000..ada432c68ab9 --- /dev/null +++ b/sql/core/src/test/scripts/gen-thrift.sh @@ -0,0 +1,27 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +cd $(dirname $0)/.. 
+BASEDIR=`pwd` +cd - + +rm -rf $BASEDIR/gen-java +mkdir -p $BASEDIR/gen-java + +for input in `ls $BASEDIR/thrift/*.thrift`; do + thrift --gen java -out $BASEDIR/gen-java $input +done diff --git a/sql/core/src/test/thrift/parquet-compat.thrift b/sql/core/src/test/thrift/parquet-compat.thrift index fa5ed8c62306..98bf778aec5d 100644 --- a/sql/core/src/test/thrift/parquet-compat.thrift +++ b/sql/core/src/test/thrift/parquet-compat.thrift @@ -15,7 +15,7 @@ * limitations under the License. */ -namespace java org.apache.spark.sql.parquet.test.thrift +namespace java org.apache.spark.sql.execution.datasources.parquet.test.thrift enum Suit { SPADES, diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala index 80eb9f122ad9..251e0324bfa5 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala @@ -32,53 +32,54 @@ class ParquetHiveCompatibilitySuite extends ParquetCompatibilityTest { * Set the staging directory (and hence path to ignore Parquet files under) * to that set by [[HiveConf.ConfVars.STAGINGDIR]]. */ - override val stagingDir: Option[String] = - Some(new HiveConf().getVar(HiveConf.ConfVars.STAGINGDIR)) + private val stagingDir = new HiveConf().getVar(HiveConf.ConfVars.STAGINGDIR) - override protected def beforeAll(): Unit = { - super.beforeAll() + test("Read Parquet file generated by parquet-hive") { + withTable("parquet_compat") { + withTempPath { dir => + val path = dir.getCanonicalPath - withSQLConf(HiveContext.CONVERT_METASTORE_PARQUET.key -> "false") { - withTempTable("data") { - sqlContext.sql( - s"""CREATE TABLE parquet_compat( - | bool_column BOOLEAN, - | byte_column TINYINT, - | short_column SMALLINT, - | int_column INT, - | long_column BIGINT, - | float_column FLOAT, - | double_column DOUBLE, - | - | strings_column ARRAY, - | int_to_string_column MAP - |) - |STORED AS PARQUET - |LOCATION '${parquetStore.getCanonicalPath}' - """.stripMargin) + withSQLConf(HiveContext.CONVERT_METASTORE_PARQUET.key -> "false") { + withTempTable("data") { + sqlContext.sql( + s"""CREATE TABLE parquet_compat( + | bool_column BOOLEAN, + | byte_column TINYINT, + | short_column SMALLINT, + | int_column INT, + | long_column BIGINT, + | float_column FLOAT, + | double_column DOUBLE, + | + | strings_column ARRAY, + | int_to_string_column MAP + |) + |STORED AS PARQUET + |LOCATION '$path' + """.stripMargin) - val schema = sqlContext.table("parquet_compat").schema - val rowRDD = sqlContext.sparkContext.parallelize(makeRows).coalesce(1) - sqlContext.createDataFrame(rowRDD, schema).registerTempTable("data") - sqlContext.sql("INSERT INTO TABLE parquet_compat SELECT * FROM data") - } - } - } + val schema = sqlContext.table("parquet_compat").schema + val rowRDD = sqlContext.sparkContext.parallelize(makeRows).coalesce(1) + sqlContext.createDataFrame(rowRDD, schema).registerTempTable("data") + sqlContext.sql("INSERT INTO TABLE parquet_compat SELECT * FROM data") + } + } - override protected def afterAll(): Unit = { - sqlContext.sql("DROP TABLE parquet_compat") - } + val schema = readParquetSchema(path, { path => + !path.getName.startsWith("_") && !path.getName.startsWith(stagingDir) + }) - test("Read Parquet file generated by parquet-hive") { - logInfo( - s"""Schema of the Parquet file written by parquet-hive: - 
|${readParquetSchema(parquetStore.getCanonicalPath)} - """.stripMargin) + logInfo( + s"""Schema of the Parquet file written by parquet-hive: + |$schema + """.stripMargin) - // Unfortunately parquet-hive doesn't add `UTF8` annotation to BINARY when writing strings. - // Have to assume all BINARY values are strings here. - withSQLConf(SQLConf.PARQUET_BINARY_AS_STRING.key -> "true") { - checkAnswer(sqlContext.read.parquet(parquetStore.getCanonicalPath), makeRows) + // Unfortunately parquet-hive doesn't add `UTF8` annotation to BINARY when writing strings. + // Have to assume all BINARY values are strings here. + withSQLConf(SQLConf.PARQUET_BINARY_AS_STRING.key -> "true") { + checkAnswer(sqlContext.read.parquet(path), makeRows) + } + } } } From 2e680668f7b6fc158aa068aedd19c1878ecf759e Mon Sep 17 00:00:00 2001 From: Tom White Date: Wed, 12 Aug 2015 10:06:27 -0500 Subject: [PATCH 0190/1824] [SPARK-8625] [CORE] Propagate user exceptions in tasks back to driver This allows clients to retrieve the original exception from the cause field of the SparkException that is thrown by the driver. If the original exception is not in fact Serializable then it will not be returned, but the message and stacktrace will be. (All Java Throwables implement the Serializable interface, but this is no guarantee that a particular implementation can actually be serialized.) Author: Tom White Closes #7014 from tomwhite/propagate-user-exceptions. --- .../org/apache/spark/TaskEndReason.scala | 44 ++++++++++++- .../org/apache/spark/executor/Executor.scala | 14 +++- .../apache/spark/scheduler/DAGScheduler.scala | 44 ++++++++----- .../spark/scheduler/DAGSchedulerEvent.scala | 3 +- .../spark/scheduler/TaskSetManager.scala | 12 ++-- .../org/apache/spark/util/JsonProtocol.scala | 2 +- .../ExecutorAllocationManagerSuite.scala | 2 +- .../scala/org/apache/spark/FailureSuite.scala | 66 ++++++++++++++++++- .../spark/scheduler/DAGSchedulerSuite.scala | 2 +- .../spark/scheduler/TaskSetManagerSuite.scala | 5 +- .../ui/jobs/JobProgressListenerSuite.scala | 2 +- .../apache/spark/util/JsonProtocolSuite.scala | 3 +- 12 files changed, 165 insertions(+), 34 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/TaskEndReason.scala b/core/src/main/scala/org/apache/spark/TaskEndReason.scala index 48fd3e7e23d5..934d00dc708b 100644 --- a/core/src/main/scala/org/apache/spark/TaskEndReason.scala +++ b/core/src/main/scala/org/apache/spark/TaskEndReason.scala @@ -17,6 +17,8 @@ package org.apache.spark +import java.io.{IOException, ObjectInputStream, ObjectOutputStream} + import org.apache.spark.annotation.DeveloperApi import org.apache.spark.executor.TaskMetrics import org.apache.spark.storage.BlockManagerId @@ -90,6 +92,10 @@ case class FetchFailed( * * `fullStackTrace` is a better representation of the stack trace because it contains the whole * stack trace including the exception and its causes + * + * `exception` is the actual exception that caused the task to fail. It may be `None` in + * the case that the exception is not in fact serializable. If a task fails more than + * once (due to retries), `exception` is that one that caused the last failure. 
*/ @DeveloperApi case class ExceptionFailure( @@ -97,11 +103,26 @@ case class ExceptionFailure( description: String, stackTrace: Array[StackTraceElement], fullStackTrace: String, - metrics: Option[TaskMetrics]) + metrics: Option[TaskMetrics], + private val exceptionWrapper: Option[ThrowableSerializationWrapper]) extends TaskFailedReason { + /** + * `preserveCause` is used to keep the exception itself so it is available to the + * driver. This may be set to `false` in the event that the exception is not in fact + * serializable. + */ + private[spark] def this(e: Throwable, metrics: Option[TaskMetrics], preserveCause: Boolean) { + this(e.getClass.getName, e.getMessage, e.getStackTrace, Utils.exceptionString(e), metrics, + if (preserveCause) Some(new ThrowableSerializationWrapper(e)) else None) + } + private[spark] def this(e: Throwable, metrics: Option[TaskMetrics]) { - this(e.getClass.getName, e.getMessage, e.getStackTrace, Utils.exceptionString(e), metrics) + this(e, metrics, preserveCause = true) + } + + def exception: Option[Throwable] = exceptionWrapper.flatMap { + (w: ThrowableSerializationWrapper) => Option(w.exception) } override def toErrorString: String = @@ -127,6 +148,25 @@ case class ExceptionFailure( } } +/** + * A class for recovering from exceptions when deserializing a Throwable that was + * thrown in user task code. If the Throwable cannot be deserialized it will be null, + * but the stacktrace and message will be preserved correctly in SparkException. + */ +private[spark] class ThrowableSerializationWrapper(var exception: Throwable) extends + Serializable with Logging { + private def writeObject(out: ObjectOutputStream): Unit = { + out.writeObject(exception) + } + private def readObject(in: ObjectInputStream): Unit = { + try { + exception = in.readObject().asInstanceOf[Throwable] + } catch { + case e : Exception => log.warn("Task exception could not be deserialized", e) + } + } +} + /** * :: DeveloperApi :: * The task finished successfully, but the result was lost from the executor's block manager before diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 5d78a9dc8885..42a85e42ea2b 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -17,7 +17,7 @@ package org.apache.spark.executor -import java.io.File +import java.io.{File, NotSerializableException} import java.lang.management.ManagementFactory import java.net.URL import java.nio.ByteBuffer @@ -305,8 +305,16 @@ private[spark] class Executor( m } } - val taskEndReason = new ExceptionFailure(t, metrics) - execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(taskEndReason)) + val serializedTaskEndReason = { + try { + ser.serialize(new ExceptionFailure(t, metrics)) + } catch { + case _: NotSerializableException => + // t is not serializable so just send the stacktrace + ser.serialize(new ExceptionFailure(t, metrics, false)) + } + } + execBackend.statusUpdate(taskId, TaskState.FAILED, serializedTaskEndReason) // Don't forcibly exit unless the exception was inherently fatal, to avoid // stopping other tasks unnecessarily. 
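Taken together, the new `exception` field on ExceptionFailure and the serialization fallback in Executor mean that driver-side code can now inspect the original task failure through the cause of the SparkException it catches. A minimal sketch of that usage follows, assuming a local master; the application name, job body, and exception class are illustrative only and are not part of this patch.

    import org.apache.spark.{SparkContext, SparkException}

    val sc = new SparkContext("local[1,2]", "cause-example")  // 1 core, 2 task attempts
    try {
      sc.parallelize(1 to 3).foreach { _ =>
        throw new IllegalArgumentException("user failure")    // a serializable user exception
      }
    } catch {
      case e: SparkException =>
        // With this change, getCause carries the original exception when it could be
        // serialized; otherwise it is null and only the message/stacktrace survive.
        Option(e.getCause).foreach(cause => println(cause.getMessage))
    } finally {
      sc.stop()
    }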
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index bb489c6b6e98..7ab5ccf50adb 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -200,8 +200,8 @@ class DAGScheduler( // Called by TaskScheduler to cancel an entire TaskSet due to either repeated failures or // cancellation of the job itself. - def taskSetFailed(taskSet: TaskSet, reason: String): Unit = { - eventProcessLoop.post(TaskSetFailed(taskSet, reason)) + def taskSetFailed(taskSet: TaskSet, reason: String, exception: Option[Throwable]): Unit = { + eventProcessLoop.post(TaskSetFailed(taskSet, reason, exception)) } private[scheduler] @@ -677,8 +677,11 @@ class DAGScheduler( submitWaitingStages() } - private[scheduler] def handleTaskSetFailed(taskSet: TaskSet, reason: String) { - stageIdToStage.get(taskSet.stageId).foreach {abortStage(_, reason) } + private[scheduler] def handleTaskSetFailed( + taskSet: TaskSet, + reason: String, + exception: Option[Throwable]): Unit = { + stageIdToStage.get(taskSet.stageId).foreach { abortStage(_, reason, exception) } submitWaitingStages() } @@ -762,7 +765,7 @@ class DAGScheduler( } } } else { - abortStage(stage, "No active job for stage " + stage.id) + abortStage(stage, "No active job for stage " + stage.id, None) } } @@ -816,7 +819,7 @@ class DAGScheduler( case NonFatal(e) => stage.makeNewStageAttempt(partitionsToCompute.size) listenerBus.post(SparkListenerStageSubmitted(stage.latestInfo, properties)) - abortStage(stage, s"Task creation failed: $e\n${e.getStackTraceString}") + abortStage(stage, s"Task creation failed: $e\n${e.getStackTraceString}", Some(e)) runningStages -= stage return } @@ -845,13 +848,13 @@ class DAGScheduler( } catch { // In the case of a failure during serialization, abort the stage. case e: NotSerializableException => - abortStage(stage, "Task not serializable: " + e.toString) + abortStage(stage, "Task not serializable: " + e.toString, Some(e)) runningStages -= stage // Abort execution return case NonFatal(e) => - abortStage(stage, s"Task serialization failed: $e\n${e.getStackTraceString}") + abortStage(stage, s"Task serialization failed: $e\n${e.getStackTraceString}", Some(e)) runningStages -= stage return } @@ -878,7 +881,7 @@ class DAGScheduler( } } catch { case NonFatal(e) => - abortStage(stage, s"Task creation failed: $e\n${e.getStackTraceString}") + abortStage(stage, s"Task creation failed: $e\n${e.getStackTraceString}", Some(e)) runningStages -= stage return } @@ -1098,7 +1101,8 @@ class DAGScheduler( } if (disallowStageRetryForTest) { - abortStage(failedStage, "Fetch failure will not retry stage due to testing config") + abortStage(failedStage, "Fetch failure will not retry stage due to testing config", + None) } else if (failedStages.isEmpty) { // Don't schedule an event to resubmit failed stages if failed isn't empty, because // in that case the event will already have been scheduled. 
@@ -1126,7 +1130,7 @@ class DAGScheduler( case commitDenied: TaskCommitDenied => // Do nothing here, left up to the TaskScheduler to decide how to handle denied commits - case ExceptionFailure(className, description, stackTrace, fullStackTrace, metrics) => + case exceptionFailure: ExceptionFailure => // Do nothing here, left up to the TaskScheduler to decide how to handle user failures case TaskResultLost => @@ -1235,7 +1239,10 @@ class DAGScheduler( * Aborts all jobs depending on a particular Stage. This is called in response to a task set * being canceled by the TaskScheduler. Use taskSetFailed() to inject this event from outside. */ - private[scheduler] def abortStage(failedStage: Stage, reason: String) { + private[scheduler] def abortStage( + failedStage: Stage, + reason: String, + exception: Option[Throwable]): Unit = { if (!stageIdToStage.contains(failedStage.id)) { // Skip all the actions if the stage has been removed. return @@ -1244,7 +1251,7 @@ class DAGScheduler( activeJobs.filter(job => stageDependsOn(job.finalStage, failedStage)).toSeq failedStage.latestInfo.completionTime = Some(clock.getTimeMillis()) for (job <- dependentJobs) { - failJobAndIndependentStages(job, s"Job aborted due to stage failure: $reason") + failJobAndIndependentStages(job, s"Job aborted due to stage failure: $reason", exception) } if (dependentJobs.isEmpty) { logInfo("Ignoring failure of " + failedStage + " because all jobs depending on it are done") @@ -1252,8 +1259,11 @@ class DAGScheduler( } /** Fails a job and all stages that are only used by that job, and cleans up relevant state. */ - private def failJobAndIndependentStages(job: ActiveJob, failureReason: String) { - val error = new SparkException(failureReason) + private def failJobAndIndependentStages( + job: ActiveJob, + failureReason: String, + exception: Option[Throwable] = None): Unit = { + val error = new SparkException(failureReason, exception.getOrElse(null)) var ableToCancelStages = true val shouldInterruptThread = @@ -1462,8 +1472,8 @@ private[scheduler] class DAGSchedulerEventProcessLoop(dagScheduler: DAGScheduler case completion @ CompletionEvent(task, reason, _, _, taskInfo, taskMetrics) => dagScheduler.handleTaskCompletion(completion) - case TaskSetFailed(taskSet, reason) => - dagScheduler.handleTaskSetFailed(taskSet, reason) + case TaskSetFailed(taskSet, reason, exception) => + dagScheduler.handleTaskSetFailed(taskSet, reason, exception) case ResubmitFailedStages => dagScheduler.resubmitFailedStages() diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala index a213d419cf03..f72a52e85dc1 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala @@ -73,6 +73,7 @@ private[scheduler] case class ExecutorAdded(execId: String, host: String) extend private[scheduler] case class ExecutorLost(execId: String) extends DAGSchedulerEvent private[scheduler] -case class TaskSetFailed(taskSet: TaskSet, reason: String) extends DAGSchedulerEvent +case class TaskSetFailed(taskSet: TaskSet, reason: String, exception: Option[Throwable]) + extends DAGSchedulerEvent private[scheduler] case object ResubmitFailedStages extends DAGSchedulerEvent diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 82455b0426a5..818b95d67f6b 100644 --- 
a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -662,7 +662,7 @@ private[spark] class TaskSetManager( val failureReason = s"Lost task ${info.id} in stage ${taskSet.id} (TID $tid, ${info.host}): " + reason.asInstanceOf[TaskFailedReason].toErrorString - reason match { + val failureException: Option[Throwable] = reason match { case fetchFailed: FetchFailed => logWarning(failureReason) if (!successful(index)) { @@ -671,6 +671,7 @@ private[spark] class TaskSetManager( } // Not adding to failed executors for FetchFailed. isZombie = true + None case ef: ExceptionFailure => taskMetrics = ef.metrics.orNull @@ -706,12 +707,15 @@ private[spark] class TaskSetManager( s"Lost task ${info.id} in stage ${taskSet.id} (TID $tid) on executor ${info.host}: " + s"${ef.className} (${ef.description}) [duplicate $dupCount]") } + ef.exception case e: TaskFailedReason => // TaskResultLost, TaskKilled, and others logWarning(failureReason) + None case e: TaskEndReason => logError("Unknown TaskEndReason: " + e) + None } // always add to failed executors failedExecutors.getOrElseUpdate(index, new HashMap[String, Long]()). @@ -728,16 +732,16 @@ private[spark] class TaskSetManager( logError("Task %d in stage %s failed %d times; aborting job".format( index, taskSet.id, maxTaskFailures)) abort("Task %d in stage %s failed %d times, most recent failure: %s\nDriver stacktrace:" - .format(index, taskSet.id, maxTaskFailures, failureReason)) + .format(index, taskSet.id, maxTaskFailures, failureReason), failureException) return } } maybeFinishTaskSet() } - def abort(message: String): Unit = sched.synchronized { + def abort(message: String, exception: Option[Throwable] = None): Unit = sched.synchronized { // TODO: Kill running tasks if we were not terminated due to a Mesos error - sched.dagScheduler.taskSetFailed(taskSet, message) + sched.dagScheduler.taskSetFailed(taskSet, message, exception) isZombie = true maybeFinishTaskSet() } diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index c600319d9ddb..cbc94fd6d54d 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -790,7 +790,7 @@ private[spark] object JsonProtocol { val fullStackTrace = Utils.jsonOption(json \ "Full Stack Trace"). map(_.extract[String]).orNull val metrics = Utils.jsonOption(json \ "Metrics").map(taskMetricsFromJson) - ExceptionFailure(className, description, stackTrace, fullStackTrace, metrics) + ExceptionFailure(className, description, stackTrace, fullStackTrace, metrics, None) case `taskResultLost` => TaskResultLost case `taskKilled` => TaskKilled case `executorLostFailure` => diff --git a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala index f374f97f8744..116f027a0f98 100644 --- a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala @@ -800,7 +800,7 @@ class ExecutorAllocationManagerSuite assert(maxNumExecutorsNeeded(manager) === 1) // If the task is failed, we expect it to be resubmitted later. 
- val taskEndReason = ExceptionFailure(null, null, null, null, null) + val taskEndReason = ExceptionFailure(null, null, null, null, null, None) sc.listenerBus.postToAll(SparkListenerTaskEnd(0, 0, null, taskEndReason, taskInfo, null)) assert(maxNumExecutorsNeeded(manager) === 1) } diff --git a/core/src/test/scala/org/apache/spark/FailureSuite.scala b/core/src/test/scala/org/apache/spark/FailureSuite.scala index 69cb4b44cf7e..aa50a49c5023 100644 --- a/core/src/test/scala/org/apache/spark/FailureSuite.scala +++ b/core/src/test/scala/org/apache/spark/FailureSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark import org.apache.spark.util.NonSerializable -import java.io.NotSerializableException +import java.io.{IOException, NotSerializableException, ObjectInputStream} // Common state shared by FailureSuite-launched tasks. We use a global object // for this because any local variables used in the task closures will rightfully @@ -166,5 +166,69 @@ class FailureSuite extends SparkFunSuite with LocalSparkContext { assert(thrownDueToMemoryLeak.getMessage.contains("memory leak")) } + // Run a 3-task map job in which task 1 always fails with a exception message that + // depends on the failure number, and check that we get the last failure. + test("last failure cause is sent back to driver") { + sc = new SparkContext("local[1,2]", "test") + val data = sc.makeRDD(1 to 3, 3).map { x => + FailureSuiteState.synchronized { + FailureSuiteState.tasksRun += 1 + if (x == 3) { + FailureSuiteState.tasksFailed += 1 + throw new UserException("oops", + new IllegalArgumentException("failed=" + FailureSuiteState.tasksFailed)) + } + } + x * x + } + val thrown = intercept[SparkException] { + data.collect() + } + FailureSuiteState.synchronized { + assert(FailureSuiteState.tasksRun === 4) + } + assert(thrown.getClass === classOf[SparkException]) + assert(thrown.getCause.getClass === classOf[UserException]) + assert(thrown.getCause.getMessage === "oops") + assert(thrown.getCause.getCause.getClass === classOf[IllegalArgumentException]) + assert(thrown.getCause.getCause.getMessage === "failed=2") + FailureSuiteState.clear() + } + + test("failure cause stacktrace is sent back to driver if exception is not serializable") { + sc = new SparkContext("local", "test") + val thrown = intercept[SparkException] { + sc.makeRDD(1 to 3).foreach { _ => throw new NonSerializableUserException } + } + assert(thrown.getClass === classOf[SparkException]) + assert(thrown.getCause === null) + assert(thrown.getMessage.contains("NonSerializableUserException")) + FailureSuiteState.clear() + } + + test("failure cause stacktrace is sent back to driver if exception is not deserializable") { + sc = new SparkContext("local", "test") + val thrown = intercept[SparkException] { + sc.makeRDD(1 to 3).foreach { _ => throw new NonDeserializableUserException } + } + assert(thrown.getClass === classOf[SparkException]) + assert(thrown.getCause === null) + assert(thrown.getMessage.contains("NonDeserializableUserException")) + FailureSuiteState.clear() + } + // TODO: Need to add tests with shuffle fetch failures. 
} + +class UserException(message: String, cause: Throwable) + extends RuntimeException(message, cause) + +class NonSerializableUserException extends RuntimeException { + val nonSerializableInstanceVariable = new NonSerializable +} + +class NonDeserializableUserException extends RuntimeException { + private def readObject(in: ObjectInputStream): Unit = { + throw new IOException("Intentional exception during deserialization.") + } +} diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 86dff8fb577d..b0ca49cbea4f 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -242,7 +242,7 @@ class DAGSchedulerSuite /** Sends TaskSetFailed to the scheduler. */ private def failed(taskSet: TaskSet, message: String) { - runEvent(TaskSetFailed(taskSet, message)) + runEvent(TaskSetFailed(taskSet, message, None)) } /** Sends JobCancelled to the DAG scheduler. */ diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index f7cc4bb61d57..edbdb485c5ea 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -48,7 +48,10 @@ class FakeDAGScheduler(sc: SparkContext, taskScheduler: FakeTaskScheduler) override def executorLost(execId: String) {} - override def taskSetFailed(taskSet: TaskSet, reason: String) { + override def taskSetFailed( + taskSet: TaskSet, + reason: String, + exception: Option[Throwable]): Unit = { taskScheduler.taskSetsFailed += taskSet.id } } diff --git a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala index 56f7b9cf1f35..b140387d309f 100644 --- a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala @@ -240,7 +240,7 @@ class JobProgressListenerSuite extends SparkFunSuite with LocalSparkContext with val taskFailedReasons = Seq( Resubmitted, new FetchFailed(null, 0, 0, 0, "ignored"), - ExceptionFailure("Exception", "description", null, null, None), + ExceptionFailure("Exception", "description", null, null, None, None), TaskResultLost, TaskKilled, ExecutorLostFailure("0"), diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index dde95f377843..343a4139b0ca 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -163,7 +163,8 @@ class JsonProtocolSuite extends SparkFunSuite { } test("ExceptionFailure backward compatibility") { - val exceptionFailure = ExceptionFailure("To be", "or not to be", stackTrace, null, None) + val exceptionFailure = ExceptionFailure("To be", "or not to be", stackTrace, null, + None, None) val oldEvent = JsonProtocol.taskEndReasonToJson(exceptionFailure) .removeField({ _._1 == "Full Stack Trace" }) assertEquals(exceptionFailure, JsonProtocol.taskEndReasonFromJson(oldEvent)) From be5d1912076c2ffd21ec88611e53d3b3c59b7ecc Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Wed, 12 Aug 2015 09:24:50 -0700 Subject: [PATCH 0191/1824] [SPARK-9795] Dynamic 
allocation: avoid double counting when killing same executor twice This is based on KaiXinXiaoLei's changes in #7716. The issue is that when someone calls `sc.killExecutor("1")` on the same executor twice quickly, then the executor target will be adjusted downwards by 2 instead of 1 even though we're only actually killing one executor. In certain cases where we don't adjust the target back upwards quickly, we'll end up with jobs hanging. This is a common danger because there are many places where this is called: - `HeartbeatReceiver` kills an executor that has not been sending heartbeats - `ExecutorAllocationManager` kills an executor that has been idle - The user code might call this, which may interfere with the previous callers While it's not clear whether this fixes SPARK-9745, fixing this potential race condition seems like a strict improvement. I've added a regression test to illustrate the issue. Author: Andrew Or Closes #8078 from andrewor14/da-double-kill. --- .../CoarseGrainedSchedulerBackend.scala | 11 ++++++---- .../StandaloneDynamicAllocationSuite.scala | 20 +++++++++++++++++++ 2 files changed, 27 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 6acf8a9a5e9b..5730a87f960a 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -422,16 +422,19 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp logWarning(s"Executor to kill $id does not exist!") } + // If an executor is already pending to be removed, do not kill it again (SPARK-9795) + val executorsToKill = knownExecutors.filter { id => !executorsPendingToRemove.contains(id) } + executorsPendingToRemove ++= executorsToKill + // If we do not wish to replace the executors we kill, sync the target number of executors // with the cluster manager to avoid allocating new ones. When computing the new target, // take into account executors that are pending to be added or removed. 
if (!replace) { - doRequestTotalExecutors(numExistingExecutors + numPendingExecutors - - executorsPendingToRemove.size - knownExecutors.size) + doRequestTotalExecutors( + numExistingExecutors + numPendingExecutors - executorsPendingToRemove.size) } - executorsPendingToRemove ++= knownExecutors - doKillExecutors(knownExecutors) + doKillExecutors(executorsToKill) } /** diff --git a/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala b/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala index 08c41a897a86..1f2a0f0d309c 100644 --- a/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala @@ -283,6 +283,26 @@ class StandaloneDynamicAllocationSuite assert(master.apps.head.getExecutorLimit === 1000) } + test("kill the same executor twice (SPARK-9795)") { + sc = new SparkContext(appConf) + val appId = sc.applicationId + assert(master.apps.size === 1) + assert(master.apps.head.id === appId) + assert(master.apps.head.executors.size === 2) + assert(master.apps.head.getExecutorLimit === Int.MaxValue) + // sync executors between the Master and the driver, needed because + // the driver refuses to kill executors it does not know about + syncExecutors(sc) + // kill the same executor twice + val executors = getExecutorIds(sc) + assert(executors.size === 2) + assert(sc.killExecutor(executors.head)) + assert(sc.killExecutor(executors.head)) + assert(master.apps.head.executors.size === 1) + // The limit should not be lowered twice + assert(master.apps.head.getExecutorLimit === 1) + } + // =============================== // | Utility methods for testing | // =============================== From 66d87c1d76bea2b81993156ac1fa7dad6c312ebf Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Wed, 12 Aug 2015 09:35:32 -0700 Subject: [PATCH 0192/1824] [SPARK-7583] [MLLIB] User guide update for RegexTokenizer jira: https://issues.apache.org/jira/browse/SPARK-7583 User guide update for RegexTokenizer Author: Yuhao Yang Closes #7828 from hhbyyh/regexTokenizerDoc. --- docs/ml-features.md | 41 ++++++++++++++++++++++++++++++----------- 1 file changed, 30 insertions(+), 11 deletions(-) diff --git a/docs/ml-features.md b/docs/ml-features.md index fa0ad1f00ab1..cec2cbe67340 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -217,21 +217,32 @@ for feature in result.select("result").take(3): [Tokenization](http://en.wikipedia.org/wiki/Lexical_analysis#Tokenization) is the process of taking text (such as a sentence) and breaking it into individual terms (usually words). A simple [Tokenizer](api/scala/index.html#org.apache.spark.ml.feature.Tokenizer) class provides this functionality. The example below shows how to split sentences into sequences of words. -Note: A more advanced tokenizer is provided via [RegexTokenizer](api/scala/index.html#org.apache.spark.ml.feature.RegexTokenizer). +[RegexTokenizer](api/scala/index.html#org.apache.spark.ml.feature.RegexTokenizer) allows more + advanced tokenization based on regular expression (regex) matching. + By default, the parameter "pattern" (regex, default: \\s+) is used as delimiters to split the input text. + Alternatively, users can set parameter "gaps" to false indicating the regex "pattern" denotes + "tokens" rather than splitting gaps, and find all matching occurrences as the tokenization result.
    {% highlight scala %} -import org.apache.spark.ml.feature.Tokenizer +import org.apache.spark.ml.feature.{Tokenizer, RegexTokenizer} val sentenceDataFrame = sqlContext.createDataFrame(Seq( (0, "Hi I heard about Spark"), - (0, "I wish Java could use case classes"), - (1, "Logistic regression models are neat") + (1, "I wish Java could use case classes"), + (2, "Logistic,regression,models,are,neat") )).toDF("label", "sentence") val tokenizer = new Tokenizer().setInputCol("sentence").setOutputCol("words") -val wordsDataFrame = tokenizer.transform(sentenceDataFrame) -wordsDataFrame.select("words", "label").take(3).foreach(println) +val regexTokenizer = new RegexTokenizer() + .setInputCol("sentence") + .setOutputCol("words") + .setPattern("\\W") // alternatively .setPattern("\\w+").setGaps(false) + +val tokenized = tokenizer.transform(sentenceDataFrame) +tokenized.select("words", "label").take(3).foreach(println) +val regexTokenized = regexTokenizer.transform(sentenceDataFrame) +regexTokenized.select("words", "label").take(3).foreach(println) {% endhighlight %}
    @@ -240,6 +251,7 @@ wordsDataFrame.select("words", "label").take(3).foreach(println) import com.google.common.collect.Lists; import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.RegexTokenizer; import org.apache.spark.ml.feature.Tokenizer; import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.sql.DataFrame; @@ -252,8 +264,8 @@ import org.apache.spark.sql.types.StructType; JavaRDD jrdd = jsc.parallelize(Lists.newArrayList( RowFactory.create(0, "Hi I heard about Spark"), - RowFactory.create(0, "I wish Java could use case classes"), - RowFactory.create(1, "Logistic regression models are neat") + RowFactory.create(1, "I wish Java could use case classes"), + RowFactory.create(2, "Logistic,regression,models,are,neat") )); StructType schema = new StructType(new StructField[]{ new StructField("label", DataTypes.DoubleType, false, Metadata.empty()), @@ -267,22 +279,29 @@ for (Row r : wordsDataFrame.select("words", "label").take(3)) { for (String word : words) System.out.print(word + " "); System.out.println(); } + +RegexTokenizer regexTokenizer = new RegexTokenizer() + .setInputCol("sentence") + .setOutputCol("words") + .setPattern("\\W"); // alternatively .setPattern("\\w+").setGaps(false); {% endhighlight %}
    {% highlight python %} -from pyspark.ml.feature import Tokenizer +from pyspark.ml.feature import Tokenizer, RegexTokenizer sentenceDataFrame = sqlContext.createDataFrame([ (0, "Hi I heard about Spark"), - (0, "I wish Java could use case classes"), - (1, "Logistic regression models are neat") + (1, "I wish Java could use case classes"), + (2, "Logistic,regression,models,are,neat") ], ["label", "sentence"]) tokenizer = Tokenizer(inputCol="sentence", outputCol="words") wordsDataFrame = tokenizer.transform(sentenceDataFrame) for words_label in wordsDataFrame.select("words", "label").take(3): print(words_label) +regexTokenizer = RegexTokenizer(inputCol="sentence", outputCol="words", pattern="\\W") +# alternatively, pattern="\\w+", gaps(False) {% endhighlight %}
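The splitting-versus-matching distinction described in the new prose can be made concrete with a short Scala sketch. It only restates the two configurations the guide already mentions (pattern "\\W" used as a delimiter versus "\\w+" with gaps set to false); the variable names are illustrative and the snippet is not part of the patch.

    import org.apache.spark.ml.feature.RegexTokenizer

    // gaps = true (the default): the pattern is a delimiter, so "\\W" splits the
    // sentence on every non-word character.
    val splitOnDelimiters = new RegexTokenizer()
      .setInputCol("sentence")
      .setOutputCol("words")
      .setPattern("\\W")

    // gaps = false: the pattern describes the tokens themselves, so "\\w+" keeps
    // every maximal run of word characters as one token.
    val matchTokens = new RegexTokenizer()
      .setInputCol("sentence")
      .setOutputCol("words")
      .setPattern("\\w+")
      .setGaps(false)

    // For "Logistic,regression,models,are,neat" both configurations should yield the
    // same five word tokens, whereas the plain Tokenizer, which splits on whitespace
    // only, would keep the whole comma-separated string as a single token.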
    From e0110792ef71ebfd3727b970346a2e13695990a4 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Wed, 12 Aug 2015 10:08:35 -0700 Subject: [PATCH 0193/1824] [SPARK-9747] [SQL] Avoid starving an unsafe operator in aggregation This is the sister patch to #8011, but for aggregation. In a nutshell: create the `TungstenAggregationIterator` before computing the parent partition. Internally this creates a `BytesToBytesMap` which acquires a page in the constructor as of this patch. This ensures that the aggregation operator is not starved since we reserve at least 1 page in advance. rxin yhuai Author: Andrew Or Closes #8038 from andrewor14/unsafe-starve-memory-agg. --- .../spark/unsafe/map/BytesToBytesMap.java | 34 +++++-- .../unsafe/sort/UnsafeExternalSorter.java | 9 +- .../map/AbstractBytesToBytesMapSuite.java | 11 ++- .../UnsafeFixedWidthAggregationMap.java | 7 ++ .../aggregate/TungstenAggregate.scala | 72 +++++++++------ .../TungstenAggregationIterator.scala | 88 +++++++++++-------- .../TungstenAggregationIteratorSuite.scala | 56 ++++++++++++ 7 files changed, 201 insertions(+), 76 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java index 85b46ec8bfae..87ed47e88c4e 100644 --- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java +++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java @@ -193,6 +193,11 @@ public BytesToBytesMap( TaskMemoryManager.MAXIMUM_PAGE_SIZE_BYTES); } allocate(initialCapacity); + + // Acquire a new page as soon as we construct the map to ensure that we have at least + // one page to work with. Otherwise, other operators in the same task may starve this + // map (SPARK-9747). + acquireNewPage(); } public BytesToBytesMap( @@ -574,16 +579,9 @@ public boolean putNewKey( final long lengthOffsetInPage = currentDataPage.getBaseOffset() + pageCursor; Platform.putInt(pageBaseObject, lengthOffsetInPage, END_OF_PAGE_MARKER); } - final long memoryGranted = shuffleMemoryManager.tryToAcquire(pageSizeBytes); - if (memoryGranted != pageSizeBytes) { - shuffleMemoryManager.release(memoryGranted); - logger.debug("Failed to acquire {} bytes of memory", pageSizeBytes); + if (!acquireNewPage()) { return false; } - MemoryBlock newPage = taskMemoryManager.allocatePage(pageSizeBytes); - dataPages.add(newPage); - pageCursor = 0; - currentDataPage = newPage; dataPage = currentDataPage; dataPageBaseObject = currentDataPage.getBaseObject(); dataPageInsertOffset = currentDataPage.getBaseOffset(); @@ -642,6 +640,24 @@ public boolean putNewKey( } } + /** + * Acquire a new page from the {@link ShuffleMemoryManager}. + * @return whether there is enough space to allocate the new page. + */ + private boolean acquireNewPage() { + final long memoryGranted = shuffleMemoryManager.tryToAcquire(pageSizeBytes); + if (memoryGranted != pageSizeBytes) { + shuffleMemoryManager.release(memoryGranted); + logger.debug("Failed to acquire {} bytes of memory", pageSizeBytes); + return false; + } + MemoryBlock newPage = taskMemoryManager.allocatePage(pageSizeBytes); + dataPages.add(newPage); + pageCursor = 0; + currentDataPage = newPage; + return true; + } + /** * Allocate new data structures for this map. When calling this outside of the constructor, * make sure to keep references to the old data structures so that you can free them. 
@@ -748,7 +764,7 @@ public long getNumHashCollisions() { } @VisibleForTesting - int getNumDataPages() { + public int getNumDataPages() { return dataPages.size(); } diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java index 9601aafe5546..fc364e0a895b 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -132,16 +132,15 @@ private UnsafeExternalSorter( if (existingInMemorySorter == null) { initializeForWriting(); + // Acquire a new page as soon as we construct the sorter to ensure that we have at + // least one page to work with. Otherwise, other operators in the same task may starve + // this sorter (SPARK-9709). We don't need to do this if we already have an existing sorter. + acquireNewPage(); } else { this.isInMemSorterExternal = true; this.inMemSorter = existingInMemorySorter; } - // Acquire a new page as soon as we construct the sorter to ensure that we have at - // least one page to work with. Otherwise, other operators in the same task may starve - // this sorter (SPARK-9709). - acquireNewPage(); - // Register a cleanup task with TaskContext to ensure that memory is guaranteed to be freed at // the end of the task. This is necessary to avoid memory leaks in when the downstream operator // does not fully consume the sorter's output (e.g. sort followed by limit). diff --git a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java index 1a79c20c3524..ab480b60adae 100644 --- a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java +++ b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java @@ -543,7 +543,7 @@ public void testPeakMemoryUsed() { Platform.LONG_ARRAY_OFFSET, 8); newPeakMemory = map.getPeakMemoryUsedBytes(); - if (i % numRecordsPerPage == 0) { + if (i % numRecordsPerPage == 0 && i > 0) { // We allocated a new page for this record, so peak memory should change assertEquals(previousPeakMemory + pageSizeBytes, newPeakMemory); } else { @@ -561,4 +561,13 @@ public void testPeakMemoryUsed() { map.free(); } } + + @Test + public void testAcquirePageInConstructor() { + final BytesToBytesMap map = new BytesToBytesMap( + taskMemoryManager, shuffleMemoryManager, 1, PAGE_SIZE_BYTES); + assertEquals(1, map.getNumDataPages()); + map.free(); + } + } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java index 5cce41d5a756..09511ff35f78 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java @@ -19,6 +19,8 @@ import java.io.IOException; +import com.google.common.annotations.VisibleForTesting; + import org.apache.spark.SparkEnv; import org.apache.spark.shuffle.ShuffleMemoryManager; import org.apache.spark.sql.catalyst.InternalRow; @@ -220,6 +222,11 @@ public long getPeakMemoryUsedBytes() { return map.getPeakMemoryUsedBytes(); } + @VisibleForTesting + public int getNumDataPages() { + return map.getNumDataPages(); + } + /** * Free the memory associated with this map. 
This is idempotent and can be called multiple times. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala index 6b5935a7ce29..c40ca973796a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala @@ -17,12 +17,13 @@ package org.apache.spark.sql.execution.aggregate -import org.apache.spark.rdd.RDD +import org.apache.spark.TaskContext +import org.apache.spark.rdd.{MapPartitionsWithPreparationRDD, RDD} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression2 import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.physical.{UnspecifiedDistribution, ClusteredDistribution, AllTuples, Distribution} +import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.{UnaryNode, SparkPlan} import org.apache.spark.sql.execution.metric.SQLMetrics @@ -68,35 +69,56 @@ case class TungstenAggregate( protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") { val numInputRows = longMetric("numInputRows") val numOutputRows = longMetric("numOutputRows") - child.execute().mapPartitions { iter => - val hasInput = iter.hasNext - if (!hasInput && groupingExpressions.nonEmpty) { - // This is a grouped aggregate and the input iterator is empty, - // so return an empty iterator. - Iterator.empty.asInstanceOf[Iterator[UnsafeRow]] - } else { - val aggregationIterator = - new TungstenAggregationIterator( - groupingExpressions, - nonCompleteAggregateExpressions, - completeAggregateExpressions, - initialInputBufferOffset, - resultExpressions, - newMutableProjection, - child.output, - iter, - testFallbackStartsAt, - numInputRows, - numOutputRows) - - if (!hasInput && groupingExpressions.isEmpty) { + + /** + * Set up the underlying unsafe data structures used before computing the parent partition. + * This makes sure our iterator is not starved by other operators in the same task. + */ + def preparePartition(): TungstenAggregationIterator = { + new TungstenAggregationIterator( + groupingExpressions, + nonCompleteAggregateExpressions, + completeAggregateExpressions, + initialInputBufferOffset, + resultExpressions, + newMutableProjection, + child.output, + testFallbackStartsAt, + numInputRows, + numOutputRows) + } + + /** Compute a partition using the iterator already set up previously. */ + def executePartition( + context: TaskContext, + partitionIndex: Int, + aggregationIterator: TungstenAggregationIterator, + parentIterator: Iterator[InternalRow]): Iterator[UnsafeRow] = { + val hasInput = parentIterator.hasNext + if (!hasInput) { + // We're not using the underlying map, so we just can free it here + aggregationIterator.free() + if (groupingExpressions.isEmpty) { numOutputRows += 1 Iterator.single[UnsafeRow](aggregationIterator.outputForEmptyGroupingKeyWithoutInput()) } else { - aggregationIterator + // This is a grouped aggregate and the input iterator is empty, + // so return an empty iterator. 
+ Iterator[UnsafeRow]() } + } else { + aggregationIterator.start(parentIterator) + aggregationIterator } } + + // Note: we need to set up the iterator in each partition before computing the + // parent partition, so we cannot simply use `mapPartitions` here (SPARK-9747). + val resultRdd = { + new MapPartitionsWithPreparationRDD[UnsafeRow, InternalRow, TungstenAggregationIterator]( + child.execute(), preparePartition, executePartition, preservesPartitioning = true) + } + resultRdd.asInstanceOf[RDD[InternalRow]] } override def simpleString: String = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala index 1f383dd04482..af7e0fcedbe4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala @@ -72,8 +72,6 @@ import org.apache.spark.sql.types.StructType * the function used to create mutable projections. * @param originalInputAttributes * attributes of representing input rows from `inputIter`. - * @param inputIter - * the iterator containing input [[UnsafeRow]]s. */ class TungstenAggregationIterator( groupingExpressions: Seq[NamedExpression], @@ -83,12 +81,14 @@ class TungstenAggregationIterator( resultExpressions: Seq[NamedExpression], newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection), originalInputAttributes: Seq[Attribute], - inputIter: Iterator[InternalRow], testFallbackStartsAt: Option[Int], numInputRows: LongSQLMetric, numOutputRows: LongSQLMetric) extends Iterator[UnsafeRow] with Logging { + // The parent partition iterator, to be initialized later in `start` + private[this] var inputIter: Iterator[InternalRow] = null + /////////////////////////////////////////////////////////////////////////// // Part 1: Initializing aggregate functions. /////////////////////////////////////////////////////////////////////////// @@ -348,11 +348,15 @@ class TungstenAggregationIterator( false // disable tracking of performance metrics ) + // Exposed for testing + private[aggregate] def getHashMap: UnsafeFixedWidthAggregationMap = hashMap + // The function used to read and process input rows. When processing input rows, // it first uses hash-based aggregation by putting groups and their buffers in // hashMap. If we could not allocate more memory for the map, we switch to // sort-based aggregation (by calling switchToSortBasedAggregation). private def processInputs(): Unit = { + assert(inputIter != null, "attempted to process input when iterator was null") while (!sortBased && inputIter.hasNext) { val newInput = inputIter.next() numInputRows += 1 @@ -372,6 +376,7 @@ class TungstenAggregationIterator( // that it switch to sort-based aggregation after `fallbackStartsAt` input rows have // been processed. private def processInputsWithControlledFallback(fallbackStartsAt: Int): Unit = { + assert(inputIter != null, "attempted to process input when iterator was null") var i = 0 while (!sortBased && inputIter.hasNext) { val newInput = inputIter.next() @@ -412,6 +417,7 @@ class TungstenAggregationIterator( * Switch to sort-based aggregation when the hash-based approach is unable to acquire memory. 
*/ private def switchToSortBasedAggregation(firstKey: UnsafeRow, firstInput: InternalRow): Unit = { + assert(inputIter != null, "attempted to process input when iterator was null") logInfo("falling back to sort based aggregation.") // Step 1: Get the ExternalSorter containing sorted entries of the map. externalSorter = hashMap.destructAndCreateExternalSorter() @@ -431,6 +437,11 @@ class TungstenAggregationIterator( case _ => false } + // Note: Since we spill the sorter's contents immediately after creating it, we must insert + // something into the sorter here to ensure that we acquire at least a page of memory. + // This is done through `externalSorter.insertKV`, which will trigger the page allocation. + // Otherwise, children operators may steal the window of opportunity and starve our sorter. + if (needsProcess) { // First, we create a buffer. val buffer = createNewAggregationBuffer() @@ -588,27 +599,33 @@ class TungstenAggregationIterator( // have not switched to sort-based aggregation. /////////////////////////////////////////////////////////////////////////// - // Starts to process input rows. - testFallbackStartsAt match { - case None => - processInputs() - case Some(fallbackStartsAt) => - // This is the testing path. processInputsWithControlledFallback is same as processInputs - // except that it switches to sort-based aggregation after `fallbackStartsAt` input rows - // have been processed. - processInputsWithControlledFallback(fallbackStartsAt) - } + /** + * Start processing input rows. + * Only after this method is called will this iterator be non-empty. + */ + def start(parentIter: Iterator[InternalRow]): Unit = { + inputIter = parentIter + testFallbackStartsAt match { + case None => + processInputs() + case Some(fallbackStartsAt) => + // This is the testing path. processInputsWithControlledFallback is same as processInputs + // except that it switches to sort-based aggregation after `fallbackStartsAt` input rows + // have been processed. + processInputsWithControlledFallback(fallbackStartsAt) + } - // If we did not switch to sort-based aggregation in processInputs, - // we pre-load the first key-value pair from the map (to make hasNext idempotent). - if (!sortBased) { - // First, set aggregationBufferMapIterator. - aggregationBufferMapIterator = hashMap.iterator() - // Pre-load the first key-value pair from the aggregationBufferMapIterator. - mapIteratorHasNext = aggregationBufferMapIterator.next() - // If the map is empty, we just free it. - if (!mapIteratorHasNext) { - hashMap.free() + // If we did not switch to sort-based aggregation in processInputs, + // we pre-load the first key-value pair from the map (to make hasNext idempotent). + if (!sortBased) { + // First, set aggregationBufferMapIterator. + aggregationBufferMapIterator = hashMap.iterator() + // Pre-load the first key-value pair from the aggregationBufferMapIterator. + mapIteratorHasNext = aggregationBufferMapIterator.next() + // If the map is empty, we just free it. + if (!mapIteratorHasNext) { + hashMap.free() + } } } @@ -673,21 +690,20 @@ class TungstenAggregationIterator( } /////////////////////////////////////////////////////////////////////////// - // Part 8: A utility function used to generate a output row when there is no - // input and there is no grouping expression. + // Part 8: Utility functions /////////////////////////////////////////////////////////////////////////// + /** + * Generate a output row when there is no input and there is no grouping expression. 
+ */ def outputForEmptyGroupingKeyWithoutInput(): UnsafeRow = { - if (groupingExpressions.isEmpty) { - sortBasedAggregationBuffer.copyFrom(initialAggregationBuffer) - // We create a output row and copy it. So, we can free the map. - val resultCopy = - generateOutput(UnsafeRow.createFromByteArray(0, 0), sortBasedAggregationBuffer).copy() - hashMap.free() - resultCopy - } else { - throw new IllegalStateException( - "This method should not be called when groupingExpressions is not empty.") - } + assert(groupingExpressions.isEmpty) + assert(inputIter == null) + generateOutput(UnsafeRow.createFromByteArray(0, 0), initialAggregationBuffer) + } + + /** Free memory used in the underlying map. */ + def free(): Unit = { + hashMap.free() } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala new file mode 100644 index 000000000000..ac22c2f3c0a5 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.aggregate + +import org.apache.spark._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.InterpretedMutableProjection +import org.apache.spark.sql.execution.metric.SQLMetrics +import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.unsafe.memory.TaskMemoryManager + +class TungstenAggregationIteratorSuite extends SparkFunSuite { + + test("memory acquired on construction") { + // set up environment + val ctx = TestSQLContext + + val taskMemoryManager = new TaskMemoryManager(SparkEnv.get.executorMemoryManager) + val taskContext = new TaskContextImpl(0, 0, 0, 0, taskMemoryManager, null, Seq.empty) + TaskContext.setTaskContext(taskContext) + + // Assert that a page is allocated before processing starts + var iter: TungstenAggregationIterator = null + try { + val newMutableProjection = (expr: Seq[Expression], schema: Seq[Attribute]) => { + () => new InterpretedMutableProjection(expr, schema) + } + val dummyAccum = SQLMetrics.createLongMetric(ctx.sparkContext, "dummy") + iter = new TungstenAggregationIterator(Seq.empty, Seq.empty, Seq.empty, 0, + Seq.empty, newMutableProjection, Seq.empty, None, dummyAccum, dummyAccum) + val numPages = iter.getHashMap.getNumDataPages + assert(numPages === 1) + } finally { + // Clean up + if (iter != null) { + iter.free() + } + TaskContext.unset() + } + } +} From 57ec27dd7784ce15a2ece8a6c8ac7bd5fd25aea2 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Wed, 12 Aug 2015 10:38:30 -0700 Subject: [PATCH 0194/1824] [SPARK-9804] [HIVE] Use correct value for isSrcLocal parameter. If the correct parameter is not provided, Hive will run into an error because it calls methods that are specific to the local filesystem to copy the data. Author: Marcelo Vanzin Closes #8086 from vanzin/SPARK-9804. 
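For context, a minimal standalone sketch of the check this patch relies on (only Hadoop's FileSystem API is assumed here; the real helper, isSrcLocal, is added to Shim_v0_14 in the diff below). The idea is that the source counts as local exactly when the filesystem owning the load path is the local filesystem, instead of always passing JBoolean.TRUE:

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, Path}

object IsSrcLocalSketch {
  // Mirrors the shim helper introduced below: compare the URI of the filesystem
  // that owns `path` with the URI of the local filesystem (file:///).
  def isSrcLocal(path: Path, conf: Configuration): Boolean = {
    val localFs = FileSystem.getLocal(conf)
    val pathFs = FileSystem.get(path.toUri(), conf)
    localFs.getUri() == pathFs.getUri()
  }

  def main(args: Array[String]): Unit = {
    val conf = new Configuration()
    // A file:// path resolves to the local filesystem, so this prints true.
    println(isSrcLocal(new Path("file:///tmp/part-00000"), conf))
    // An hdfs:// path would resolve to a DistributedFileSystem and print false
    // (that case needs hadoop-hdfs on the classpath, so it is only noted here).
  }
}

With the flag derived this way, Hive only falls back to its local-filesystem copy methods when the source really is local, which avoids the error described above.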
--- .../org/apache/spark/sql/hive/client/HiveShim.scala | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala index 6e826ce55220..8fc8935b1dc3 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala @@ -25,7 +25,7 @@ import java.util.concurrent.TimeUnit import scala.collection.JavaConversions._ -import org.apache.hadoop.fs.Path +import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.ql.Driver import org.apache.hadoop.hive.ql.metadata.{Hive, Partition, Table} @@ -429,7 +429,7 @@ private[client] class Shim_v0_14 extends Shim_v0_13 { isSkewedStoreAsSubdir: Boolean): Unit = { loadPartitionMethod.invoke(hive, loadPath, tableName, partSpec, replace: JBoolean, holdDDLTime: JBoolean, inheritTableSpecs: JBoolean, isSkewedStoreAsSubdir: JBoolean, - JBoolean.TRUE, JBoolean.FALSE) + isSrcLocal(loadPath, hive.getConf()): JBoolean, JBoolean.FALSE) } override def loadTable( @@ -439,7 +439,7 @@ private[client] class Shim_v0_14 extends Shim_v0_13 { replace: Boolean, holdDDLTime: Boolean): Unit = { loadTableMethod.invoke(hive, loadPath, tableName, replace: JBoolean, holdDDLTime: JBoolean, - JBoolean.TRUE, JBoolean.FALSE, JBoolean.FALSE) + isSrcLocal(loadPath, hive.getConf()): JBoolean, JBoolean.FALSE, JBoolean.FALSE) } override def loadDynamicPartitions( @@ -461,6 +461,13 @@ private[client] class Shim_v0_14 extends Shim_v0_13 { HiveConf.ConfVars.METASTORE_CLIENT_CONNECT_RETRY_DELAY, TimeUnit.MILLISECONDS).asInstanceOf[Long] } + + protected def isSrcLocal(path: Path, conf: HiveConf): Boolean = { + val localFs = FileSystem.getLocal(conf) + val pathFs = FileSystem.get(path.toUri(), conf) + localFs.getUri() == pathFs.getUri() + } + } private[client] class Shim_v1_0 extends Shim_v0_14 { From 70fe558867ccb4bcff6ec673438b03608bb02252 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Wed, 12 Aug 2015 10:48:52 -0700 Subject: [PATCH 0195/1824] [SPARK-9847] [ML] Modified copyValues to distinguish between default, explicit param values From JIRA: Currently, Params.copyValues copies default parameter values to the paramMap of the target instance, rather than the defaultParamMap. It should copy to the defaultParamMap because explicitly setting a parameter can change the semantics. This issue arose in SPARK-9789, where 2 params "threshold" and "thresholds" for LogisticRegression can have mutually exclusive values. If thresholds is set, then fit() will copy the default value of threshold as well, easily resulting in inconsistent settings for the 2 params. CC: mengxr Author: Joseph K. Bradley Closes #8115 from jkbradley/copyvalues-fix. --- .../org/apache/spark/ml/param/params.scala | 19 ++++++++++++++++--- .../apache/spark/ml/param/ParamsSuite.scala | 8 ++++++++ 2 files changed, 24 insertions(+), 3 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index d68f5ff0053c..91c0a5631319 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -559,13 +559,26 @@ trait Params extends Identifiable with Serializable { /** * Copies param values from this instance to another instance for params shared by them. 
- * @param to the target instance - * @param extra extra params to be copied + * + * This handles default Params and explicitly set Params separately. + * Default Params are copied from and to [[defaultParamMap]], and explicitly set Params are + * copied from and to [[paramMap]]. + * Warning: This implicitly assumes that this [[Params]] instance and the target instance + * share the same set of default Params. + * + * @param to the target instance, which should work with the same set of default Params as this + * source instance + * @param extra extra params to be copied to the target's [[paramMap]] * @return the target instance with param values copied */ protected def copyValues[T <: Params](to: T, extra: ParamMap = ParamMap.empty): T = { - val map = extractParamMap(extra) + val map = paramMap ++ extra params.foreach { param => + // copy default Params + if (defaultParamMap.contains(param) && to.hasParam(param.name)) { + to.defaultParamMap.put(to.getParam(param.name), defaultParamMap(param)) + } + // copy explicitly set Params if (map.contains(param) && to.hasParam(param.name)) { to.set(param.name, map(param)) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala index 050d4170ea01..be95638d8168 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala @@ -200,6 +200,14 @@ class ParamsSuite extends SparkFunSuite { val inArray = ParamValidators.inArray[Int](Array(1, 2)) assert(inArray(1) && inArray(2) && !inArray(0)) } + + test("Params.copyValues") { + val t = new TestParams() + val t2 = t.copy(ParamMap.empty) + assert(!t2.isSet(t2.maxIter)) + val t3 = t.copy(ParamMap(t.maxIter -> 20)) + assert(t3.isSet(t3.maxIter)) + } } object ParamsSuite extends SparkFunSuite { From 60103ecd3d9c92709a5878be7ebd57012813ab48 Mon Sep 17 00:00:00 2001 From: Brennan Ashton Date: Wed, 12 Aug 2015 11:57:30 -0700 Subject: [PATCH 0196/1824] [SPARK-9726] [PYTHON] PySpark DF join no longer accepts on=None rxin First pull request for Spark so let me know if I am missing anything The contribution is my original work and I license the work to the project under the project's open source license. Author: Brennan Ashton Closes #8016 from btashton/patch-1. --- python/pyspark/sql/dataframe.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 47d5a6a43a84..09647ff6d074 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -566,8 +566,7 @@ def join(self, other, on=None, how=None): if on is None or len(on) == 0: jdf = self._jdf.join(other._jdf) - - if isinstance(on[0], basestring): + elif isinstance(on[0], basestring): jdf = self._jdf.join(other._jdf, self._jseq(on)) else: assert isinstance(on[0], Column), "on should be Column or list of Column" From 762bacc16ac5e74c8b05a7c1e3e367d1d1633cef Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Wed, 12 Aug 2015 13:24:18 -0700 Subject: [PATCH 0197/1824] [SPARK-9766] [ML] [PySpark] check and add miss docs for PySpark ML Check and add miss docs for PySpark ML (this issue only check miss docs for o.a.s.ml not o.a.s.mllib). Author: Yanbo Liang Closes #8059 from yanboliang/SPARK-9766. 
--- python/pyspark/ml/classification.py | 12 ++++++++++-- python/pyspark/ml/clustering.py | 4 +++- python/pyspark/ml/evaluation.py | 3 ++- python/pyspark/ml/feature.py | 9 +++++---- 4 files changed, 20 insertions(+), 8 deletions(-) diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 5978d8f4d3a0..6702dce5545a 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -34,6 +34,7 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti HasRegParam, HasTol, HasProbabilityCol, HasRawPredictionCol): """ Logistic regression. + Currently, this class only supports binary classification. >>> from pyspark.sql import Row >>> from pyspark.mllib.linalg import Vectors @@ -96,8 +97,8 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred # is an L2 penalty. For alpha = 1, it is an L1 penalty. self.elasticNetParam = \ Param(self, "elasticNetParam", - "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty " + - "is an L2 penalty. For alpha = 1, it is an L1 penalty.") + "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, " + + "the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.") #: param for whether to fit an intercept term. self.fitIntercept = Param(self, "fitIntercept", "whether to fit an intercept term.") #: param for threshold in binary classification prediction, in range [0, 1]. @@ -656,6 +657,13 @@ class NaiveBayes(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, H HasRawPredictionCol): """ Naive Bayes Classifiers. + It supports both Multinomial and Bernoulli NB. Multinomial NB + (`http://nlp.stanford.edu/IR-book/html/htmledition/naive-bayes-text-classification-1.html`) + can handle finitely supported discrete data. For example, by converting documents into + TF-IDF vectors, it can be used for document classification. By making every vector a + binary (0/1) data, it can also be used as Bernoulli NB + (`http://nlp.stanford.edu/IR-book/html/htmledition/the-bernoulli-model-1.html`). + The input feature values must be nonnegative. >>> from pyspark.sql import Row >>> from pyspark.mllib.linalg import Vectors diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py index b5e9b6549d9f..48338713a29e 100644 --- a/python/pyspark/ml/clustering.py +++ b/python/pyspark/ml/clustering.py @@ -37,7 +37,9 @@ def clusterCenters(self): @inherit_doc class KMeans(JavaEstimator, HasFeaturesCol, HasMaxIter, HasSeed): """ - K-means Clustering + K-means clustering with support for multiple parallel runs and a k-means++ like initialization + mode (the k-means|| algorithm by Bahmani et al). When multiple concurrent runs are requested, + they are executed together with joint passes over the data for efficiency. 
>>> from pyspark.mllib.linalg import Vectors >>> data = [(Vectors.dense([0.0, 0.0]),), (Vectors.dense([1.0, 1.0]),), diff --git a/python/pyspark/ml/evaluation.py b/python/pyspark/ml/evaluation.py index 06e809352225..2734092575ad 100644 --- a/python/pyspark/ml/evaluation.py +++ b/python/pyspark/ml/evaluation.py @@ -23,7 +23,8 @@ from pyspark.ml.util import keyword_only from pyspark.mllib.common import inherit_doc -__all__ = ['Evaluator', 'BinaryClassificationEvaluator', 'RegressionEvaluator'] +__all__ = ['Evaluator', 'BinaryClassificationEvaluator', 'RegressionEvaluator', + 'MulticlassClassificationEvaluator'] @inherit_doc diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index cb4dfa21298c..535d55326646 100644 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -26,10 +26,11 @@ from pyspark.mllib.common import inherit_doc from pyspark.mllib.linalg import _convert_to_vector -__all__ = ['Binarizer', 'HashingTF', 'IDF', 'IDFModel', 'NGram', 'Normalizer', 'OneHotEncoder', - 'PolynomialExpansion', 'RegexTokenizer', 'StandardScaler', 'StandardScalerModel', - 'StringIndexer', 'StringIndexerModel', 'Tokenizer', 'VectorAssembler', 'VectorIndexer', - 'Word2Vec', 'Word2VecModel', 'PCA', 'PCAModel', 'RFormula', 'RFormulaModel'] +__all__ = ['Binarizer', 'Bucketizer', 'HashingTF', 'IDF', 'IDFModel', 'NGram', 'Normalizer', + 'OneHotEncoder', 'PolynomialExpansion', 'RegexTokenizer', 'StandardScaler', + 'StandardScalerModel', 'StringIndexer', 'StringIndexerModel', 'Tokenizer', + 'VectorAssembler', 'VectorIndexer', 'Word2Vec', 'Word2VecModel', 'PCA', + 'PCAModel', 'RFormula', 'RFormulaModel'] @inherit_doc From 551def5d6972440365bd7436d484a67138d9a8f3 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Wed, 12 Aug 2015 14:27:13 -0700 Subject: [PATCH 0198/1824] [SPARK-9789] [ML] Added logreg threshold param back Reinstated LogisticRegression.threshold Param for binary compatibility. Param thresholds overrides threshold, if set. CC: mengxr dbtsai feynmanliang Author: Joseph K. Bradley Closes #8079 from jkbradley/logreg-reinstate-threshold. 
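To make the new precedence rules concrete, here is a spark-shell style sketch of the setter/getter interplay this patch specifies; the expected values follow directly from the getThreshold/getThresholds definitions in the diff below:

import org.apache.spark.ml.classification.LogisticRegression

val lr = new LogisticRegression()
lr.getThreshold                     // 0.5, the reinstated default

lr.setThreshold(0.75)               // binary shorthand; clears any user-set thresholds
lr.getThresholds                    // Array(0.25, 0.75), derived as (1 - t, t)

lr.setThresholds(Array(0.3, 0.7))   // per-class form; clears any user-set threshold
lr.getThreshold                     // 1.0 / (1.0 + 0.3 / 0.7), i.e. about 0.7

// Setting both params explicitly (e.g. through a ParamMap passed to fit) is only
// allowed if they are equivalent; otherwise checkThresholdConsistency throws an
// IllegalArgumentException, as exercised in the updated LogisticRegressionSuite below.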
--- .../classification/LogisticRegression.scala | 127 ++++++++++++++---- .../ml/param/shared/SharedParamsCodeGen.scala | 4 +- .../spark/ml/param/shared/sharedParams.scala | 6 +- .../JavaLogisticRegressionSuite.java | 7 +- .../LogisticRegressionSuite.scala | 33 +++-- python/pyspark/ml/classification.py | 98 +++++++++----- 6 files changed, 199 insertions(+), 76 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index f55134d25885..5bcd7117b668 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -34,8 +34,7 @@ import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, Row, SQLContext} -import org.apache.spark.sql.functions.{col, udf} +import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.storage.StorageLevel /** @@ -43,44 +42,115 @@ import org.apache.spark.storage.StorageLevel */ private[classification] trait LogisticRegressionParams extends ProbabilisticClassifierParams with HasRegParam with HasElasticNetParam with HasMaxIter with HasFitIntercept with HasTol - with HasStandardization { + with HasStandardization with HasThreshold { /** - * Version of setThresholds() for binary classification, available for backwards - * compatibility. + * Set threshold in binary classification, in range [0, 1]. * - * Calling this with threshold p will effectively call `setThresholds(Array(1-p, p))`. + * If the estimated probability of class label 1 is > threshold, then predict 1, else 0. + * A high threshold encourages the model to predict 0 more often; + * a low threshold encourages the model to predict 1 more often. + * + * Note: Calling this with threshold p is equivalent to calling `setThresholds(Array(1-p, p))`. + * When [[setThreshold()]] is called, any user-set value for [[thresholds]] will be cleared. + * If both [[threshold]] and [[thresholds]] are set in a ParamMap, then they must be + * equivalent. + * + * Default is 0.5. + * @group setParam + */ + def setThreshold(value: Double): this.type = { + if (isSet(thresholds)) clear(thresholds) + set(threshold, value) + } + + /** + * Get threshold for binary classification. + * + * If [[threshold]] is set, returns that value. + * Otherwise, if [[thresholds]] is set with length 2 (i.e., binary classification), + * this returns the equivalent threshold: {{{1 / (1 + thresholds(0) / thresholds(1))}}}. + * Otherwise, returns [[threshold]] default value. + * + * @group getParam + * @throws IllegalArgumentException if [[thresholds]] is set to an array of length other than 2. + */ + override def getThreshold: Double = { + checkThresholdConsistency() + if (isSet(thresholds)) { + val ts = $(thresholds) + require(ts.length == 2, "Logistic Regression getThreshold only applies to" + + " binary classification, but thresholds has length != 2. thresholds: " + ts.mkString(",")) + 1.0 / (1.0 + ts(0) / ts(1)) + } else { + $(threshold) + } + } + + /** + * Set thresholds in multiclass (or binary) classification to adjust the probability of + * predicting each class. Array must have length equal to the number of classes, with values >= 0. 
+ * The class with largest value p/t is predicted, where p is the original probability of that + * class and t is the class' threshold. + * + * Note: When [[setThresholds()]] is called, any user-set value for [[threshold]] will be cleared. + * If both [[threshold]] and [[thresholds]] are set in a ParamMap, then they must be + * equivalent. * - * Default is effectively 0.5. * @group setParam */ - def setThreshold(value: Double): this.type = set(thresholds, Array(1.0 - value, value)) + def setThresholds(value: Array[Double]): this.type = { + if (isSet(threshold)) clear(threshold) + set(thresholds, value) + } /** - * Version of [[getThresholds()]] for binary classification, available for backwards - * compatibility. + * Get thresholds for binary or multiclass classification. + * + * If [[thresholds]] is set, return its value. + * Otherwise, if [[threshold]] is set, return the equivalent thresholds for binary + * classification: (1-threshold, threshold). + * If neither are set, throw an exception. * - * Param thresholds must have length 2 (or not be specified). - * This returns {{{1 / (1 + thresholds(0) / thresholds(1))}}}. * @group getParam */ - def getThreshold: Double = { - if (isDefined(thresholds)) { - val thresholdValues = $(thresholds) - assert(thresholdValues.length == 2, "Logistic Regression getThreshold only applies to" + - " binary classification, but thresholds has length != 2." + - s" thresholds: ${thresholdValues.mkString(",")}") - 1.0 / (1.0 + thresholdValues(0) / thresholdValues(1)) + override def getThresholds: Array[Double] = { + checkThresholdConsistency() + if (!isSet(thresholds) && isSet(threshold)) { + val t = $(threshold) + Array(1-t, t) } else { - 0.5 + $(thresholds) + } + } + + /** + * If [[threshold]] and [[thresholds]] are both set, ensures they are consistent. + * @throws IllegalArgumentException if [[threshold]] and [[thresholds]] are not equivalent + */ + protected def checkThresholdConsistency(): Unit = { + if (isSet(threshold) && isSet(thresholds)) { + val ts = $(thresholds) + require(ts.length == 2, "Logistic Regression found inconsistent values for threshold and" + + s" thresholds. Param threshold is set (${$(threshold)}), indicating binary" + + s" classification, but Param thresholds is set with length ${ts.length}." + + " Clear one Param value to fix this problem.") + val t = 1.0 / (1.0 + ts(0) / ts(1)) + require(math.abs($(threshold) - t) < 1E-5, "Logistic Regression getThreshold found" + + s" inconsistent values for threshold (${$(threshold)}) and thresholds (equivalent to $t)") } } + + override def validateParams(): Unit = { + checkThresholdConsistency() + } } /** * :: Experimental :: * Logistic regression. - * Currently, this class only supports binary classification. + * Currently, this class only supports binary classification. It will support multiclass + * in the future. */ @Experimental class LogisticRegression(override val uid: String) @@ -128,7 +198,7 @@ class LogisticRegression(override val uid: String) * Whether to fit an intercept term. * Default is true. * @group setParam - * */ + */ def setFitIntercept(value: Boolean): this.type = set(fitIntercept, value) setDefault(fitIntercept -> true) @@ -140,7 +210,7 @@ class LogisticRegression(override val uid: String) * is applied. In R's GLMNET package, the default behavior is true as well. * Default is true. 
* @group setParam - * */ + */ def setStandardization(value: Boolean): this.type = set(standardization, value) setDefault(standardization -> true) @@ -148,6 +218,10 @@ class LogisticRegression(override val uid: String) override def getThreshold: Double = super.getThreshold + override def setThresholds(value: Array[Double]): this.type = super.setThresholds(value) + + override def getThresholds: Array[Double] = super.getThresholds + override protected def train(dataset: DataFrame): LogisticRegressionModel = { // Extract columns from data. If dataset is persisted, do not persist oldDataset. val instances = extractLabeledPoints(dataset).map { @@ -314,6 +388,10 @@ class LogisticRegressionModel private[ml] ( override def getThreshold: Double = super.getThreshold + override def setThresholds(value: Array[Double]): this.type = super.setThresholds(value) + + override def getThresholds: Array[Double] = super.getThresholds + /** Margin (rawPrediction) for class label 1. For binary classification only. */ private val margin: Vector => Double = (features) => { BLAS.dot(features, weights) + intercept @@ -364,6 +442,7 @@ class LogisticRegressionModel private[ml] ( * The behavior of this can be adjusted using [[thresholds]]. */ override protected def predict(features: Vector): Double = { + // Note: We should use getThreshold instead of $(threshold) since getThreshold is overridden. if (score(features) > getThreshold) 1 else 0 } @@ -393,6 +472,7 @@ class LogisticRegressionModel private[ml] ( } override protected def raw2prediction(rawPrediction: Vector): Double = { + // Note: We should use getThreshold instead of $(threshold) since getThreshold is overridden. val t = getThreshold val rawThreshold = if (t == 0.0) { Double.NegativeInfinity @@ -405,6 +485,7 @@ class LogisticRegressionModel private[ml] ( } override protected def probability2prediction(probability: Vector): Double = { + // Note: We should use getThreshold instead of $(threshold) since getThreshold is overridden. if (probability(1) > getThreshold) 1 else 0 } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala index da4c07683039..9e12f1856a94 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala @@ -45,14 +45,14 @@ private[shared] object SharedParamsCodeGen { " These probabilities should be treated as confidences, not precise probabilities.", Some("\"probability\"")), ParamDesc[Double]("threshold", - "threshold in binary classification prediction, in range [0, 1]", + "threshold in binary classification prediction, in range [0, 1]", Some("0.5"), isValid = "ParamValidators.inRange(0, 1)", finalMethods = false), ParamDesc[Array[Double]]("thresholds", "Thresholds in multi-class classification" + " to adjust the probability of predicting each class." + " Array must have length equal to the number of classes, with values >= 0." 
+ " The class with largest value p/t is predicted, where p is the original probability" + " of that class and t is the class' threshold.", - isValid = "(t: Array[Double]) => t.forall(_ >= 0)"), + isValid = "(t: Array[Double]) => t.forall(_ >= 0)", finalMethods = false), ParamDesc[String]("inputCol", "input column name"), ParamDesc[Array[String]]("inputCols", "input column names"), ParamDesc[String]("outputCol", "output column name", Some("uid + \"__output\"")), diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala index 23e2b6cc4399..a17d4ea960a9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala @@ -139,7 +139,7 @@ private[ml] trait HasProbabilityCol extends Params { } /** - * Trait for shared param threshold. + * Trait for shared param threshold (default: 0.5). */ private[ml] trait HasThreshold extends Params { @@ -149,6 +149,8 @@ private[ml] trait HasThreshold extends Params { */ final val threshold: DoubleParam = new DoubleParam(this, "threshold", "threshold in binary classification prediction, in range [0, 1]", ParamValidators.inRange(0, 1)) + setDefault(threshold, 0.5) + /** @group getParam */ def getThreshold: Double = $(threshold) } @@ -165,7 +167,7 @@ private[ml] trait HasThresholds extends Params { final val thresholds: DoubleArrayParam = new DoubleArrayParam(this, "thresholds", "Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold.", (t: Array[Double]) => t.forall(_ >= 0)) /** @group getParam */ - final def getThresholds: Array[Double] = $(thresholds) + def getThresholds: Array[Double] = $(thresholds) } /** diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java index 7e9aa383728f..618b95b9bd12 100644 --- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java @@ -100,9 +100,7 @@ public void logisticRegressionWithSetters() { assert(r.getDouble(0) == 0.0); } // Call transform with params, and check that the params worked. - double[] thresholds = {1.0, 0.0}; - model.transform( - dataset, model.thresholds().w(thresholds), model.probabilityCol().w("myProb")) + model.transform(dataset, model.threshold().w(0.0), model.probabilityCol().w("myProb")) .registerTempTable("predNotAllZero"); DataFrame predNotAllZero = jsql.sql("SELECT prediction, myProb FROM predNotAllZero"); boolean foundNonZero = false; @@ -112,9 +110,8 @@ public void logisticRegressionWithSetters() { assert(foundNonZero); // Call fit() with new params, and check as many params as we can. 
- double[] thresholds2 = {0.6, 0.4}; LogisticRegressionModel model2 = lr.fit(dataset, lr.maxIter().w(5), lr.regParam().w(0.1), - lr.thresholds().w(thresholds2), lr.probabilityCol().w("theProb")); + lr.threshold().w(0.4), lr.probabilityCol().w("theProb")); LogisticRegression parent2 = (LogisticRegression) model2.parent(); assert(parent2.getMaxIter() == 5); assert(parent2.getRegParam() == 0.1); diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index 8c3d4590f5ae..e354e161c6de 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -94,12 +94,13 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { test("setThreshold, getThreshold") { val lr = new LogisticRegression // default - withClue("LogisticRegression should not have thresholds set by default") { - intercept[java.util.NoSuchElementException] { + assert(lr.getThreshold === 0.5, "LogisticRegression.threshold should default to 0.5") + withClue("LogisticRegression should not have thresholds set by default.") { + intercept[java.util.NoSuchElementException] { // Note: The exception type may change in future lr.getThresholds } } - // Set via thresholds. + // Set via threshold. // Intuition: Large threshold or large thresholds(1) makes class 0 more likely. lr.setThreshold(1.0) assert(lr.getThresholds === Array(0.0, 1.0)) @@ -107,10 +108,26 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { assert(lr.getThresholds === Array(1.0, 0.0)) lr.setThreshold(0.5) assert(lr.getThresholds === Array(0.5, 0.5)) - // Test getThreshold - lr.setThresholds(Array(0.3, 0.7)) + // Set via thresholds + val lr2 = new LogisticRegression + lr2.setThresholds(Array(0.3, 0.7)) val expectedThreshold = 1.0 / (1.0 + 0.3 / 0.7) - assert(lr.getThreshold ~== expectedThreshold relTol 1E-7) + assert(lr2.getThreshold ~== expectedThreshold relTol 1E-7) + // thresholds and threshold must be consistent + lr2.setThresholds(Array(0.1, 0.2, 0.3)) + withClue("getThreshold should throw error if thresholds has length != 2.") { + intercept[IllegalArgumentException] { + lr2.getThreshold + } + } + // thresholds and threshold must be consistent: values + withClue("fit with ParamMap should throw error if threshold, thresholds do not match.") { + intercept[IllegalArgumentException] { + val lr2model = lr2.fit(dataset, + lr2.thresholds -> Array(0.3, 0.7), lr2.threshold -> (expectedThreshold / 2.0)) + lr2model.getThreshold + } + } } test("logistic regression doesn't fit intercept when fitIntercept is off") { @@ -145,7 +162,7 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { s" ${predAllZero.count(_ === 0)} of ${dataset.count()} were 0.") // Call transform with params, and check that the params worked. val predNotAllZero = - model.transform(dataset, model.thresholds -> Array(1.0, 0.0), + model.transform(dataset, model.threshold -> 0.0, model.probabilityCol -> "myProb") .select("prediction", "myProb") .collect() @@ -153,8 +170,8 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { assert(predNotAllZero.exists(_ !== 0.0)) // Call fit() with new params, and check as many params as we can. 
+ lr.setThresholds(Array(0.6, 0.4)) val model2 = lr.fit(dataset, lr.maxIter -> 5, lr.regParam -> 0.1, - lr.thresholds -> Array(0.6, 0.4), lr.probabilityCol -> "theProb") val parent2 = model2.parent.asInstanceOf[LogisticRegression] assert(parent2.getMaxIter === 5) diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 6702dce5545a..83f808efc3bf 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -76,19 +76,21 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti " Array must have length equal to the number of classes, with values >= 0." + " The class with largest value p/t is predicted, where p is the original" + " probability of that class and t is the class' threshold.") + threshold = Param(Params._dummy(), "threshold", + "Threshold in binary classification prediction, in range [0, 1]." + + " If threshold and thresholds are both set, they must match.") @keyword_only def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, - threshold=None, thresholds=None, + threshold=0.5, thresholds=None, probabilityCol="probability", rawPredictionCol="rawPrediction"): """ __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \ - threshold=None, thresholds=None, \ + threshold=0.5, thresholds=None, \ probabilityCol="probability", rawPredictionCol="rawPrediction") - Param thresholds overrides Param threshold; threshold is provided - for backwards compatibility and only applies to binary classification. + If the threshold and thresholds Params are both set, they must be equivalent. """ super(LogisticRegression, self).__init__() self._java_obj = self._new_java_obj( @@ -101,7 +103,11 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred "the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.") #: param for whether to fit an intercept term. self.fitIntercept = Param(self, "fitIntercept", "whether to fit an intercept term.") - #: param for threshold in binary classification prediction, in range [0, 1]. + #: param for threshold in binary classification, in range [0, 1]. + self.threshold = Param(self, "threshold", + "Threshold in binary classification prediction, in range [0, 1]." 
+ + " If threshold and thresholds are both set, they must match.") + #: param for thresholds or cutoffs in binary or multiclass classification self.thresholds = \ Param(self, "thresholds", "Thresholds in multi-class classification" + @@ -110,29 +116,28 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred " The class with largest value p/t is predicted, where p is the original" + " probability of that class and t is the class' threshold.") self._setDefault(maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1E-6, - fitIntercept=True) + fitIntercept=True, threshold=0.5) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) + self._checkThresholdConsistency() @keyword_only def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, - threshold=None, thresholds=None, + threshold=0.5, thresholds=None, probabilityCol="probability", rawPredictionCol="rawPrediction"): """ setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \ - threshold=None, thresholds=None, \ + threshold=0.5, thresholds=None, \ probabilityCol="probability", rawPredictionCol="rawPrediction") Sets params for logistic regression. - Param thresholds overrides Param threshold; threshold is provided - for backwards compatibility and only applies to binary classification. + If the threshold and thresholds Params are both set, they must be equivalent. """ - # Under the hood we use thresholds so translate threshold to thresholds if applicable - if thresholds is None and threshold is not None: - kwargs[thresholds] = [1-threshold, threshold] kwargs = self.setParams._input_kwargs - return self._set(**kwargs) + self._set(**kwargs) + self._checkThresholdConsistency() + return self def _create_model(self, java_model): return LogisticRegressionModel(java_model) @@ -165,44 +170,65 @@ def getFitIntercept(self): def setThreshold(self, value): """ - Sets the value of :py:attr:`thresholds` using [1-value, value]. + Sets the value of :py:attr:`threshold`. + Clears value of :py:attr:`thresholds` if it has been set. + """ + self._paramMap[self.threshold] = value + if self.isSet(self.thresholds): + del self._paramMap[self.thresholds] + return self - >>> lr = LogisticRegression() - >>> lr.getThreshold() - 0.5 - >>> lr.setThreshold(0.6) - LogisticRegression_... - >>> abs(lr.getThreshold() - 0.6) < 1e-5 - True + def getThreshold(self): + """ + Gets the value of threshold or its default value. """ - return self.setThresholds([1-value, value]) + self._checkThresholdConsistency() + if self.isSet(self.thresholds): + ts = self.getOrDefault(self.thresholds) + if len(ts) != 2: + raise ValueError("Logistic Regression getThreshold only applies to" + + " binary classification, but thresholds has length != 2." + + " thresholds: " + ",".join(ts)) + return 1.0/(1.0 + ts[0]/ts[1]) + else: + return self.getOrDefault(self.threshold) def setThresholds(self, value): """ Sets the value of :py:attr:`thresholds`. + Clears value of :py:attr:`threshold` if it has been set. """ self._paramMap[self.thresholds] = value + if self.isSet(self.threshold): + del self._paramMap[self.threshold] return self def getThresholds(self): """ - Gets the value of thresholds or its default value. + If :py:attr:`thresholds` is set, return its value. 
+ Otherwise, if :py:attr:`threshold` is set, return the equivalent thresholds for binary + classification: (1-threshold, threshold). + If neither are set, throw an error. """ - return self.getOrDefault(self.thresholds) + self._checkThresholdConsistency() + if not self.isSet(self.thresholds) and self.isSet(self.threshold): + t = self.getOrDefault(self.threshold) + return [1.0-t, t] + else: + return self.getOrDefault(self.thresholds) - def getThreshold(self): - """ - Gets the value of threshold or its default value. - """ - if self.isDefined(self.thresholds): - thresholds = self.getOrDefault(self.thresholds) - if len(thresholds) != 2: + def _checkThresholdConsistency(self): + if self.isSet(self.threshold) and self.isSet(self.thresholds): + ts = self.getParam(self.thresholds) + if len(ts) != 2: raise ValueError("Logistic Regression getThreshold only applies to" + " binary classification, but thresholds has length != 2." + - " thresholds: " + ",".join(thresholds)) - return 1.0/(1.0+thresholds[0]/thresholds[1]) - else: - return 0.5 + " thresholds: " + ",".join(ts)) + t = 1.0/(1.0 + ts[0]/ts[1]) + t2 = self.getParam(self.threshold) + if abs(t2 - t) >= 1E-5: + raise ValueError("Logistic Regression getThreshold found inconsistent values for" + + " threshold (%g) and thresholds (equivalent to %g)" % (t2, t)) class LogisticRegressionModel(JavaModel): From 6f60298b1d7aa97268a42eca1e3b4851a7e88cb5 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Wed, 12 Aug 2015 14:28:23 -0700 Subject: [PATCH 0199/1824] [SPARK-8967] [DOC] add Since annotation Add `Since` as a Scala annotation. The benefit is that we can use it without having explicit JavaDoc. This is useful for inherited methods. The limitation is that is doesn't show up in the generated Java API documentation. This might be fixed by modifying genjavadoc. I think we could leave it as a TODO. This is how the generated Scala doc looks: `since` JavaDoc tag: ![screen shot 2015-08-11 at 10 00 37 pm](https://cloud.githubusercontent.com/assets/829644/9230761/fa72865c-40d8-11e5-807e-0f3c815c5acd.png) `Since` annotation: ![screen shot 2015-08-11 at 10 00 28 pm](https://cloud.githubusercontent.com/assets/829644/9230764/0041d7f4-40d9-11e5-8124-c3f3e5d5b31f.png) rxin Author: Xiangrui Meng Closes #8131 from mengxr/SPARK-8967. --- .../org/apache/spark/annotation/Since.scala | 28 +++++++++++++++++++ 1 file changed, 28 insertions(+) create mode 100644 core/src/main/scala/org/apache/spark/annotation/Since.scala diff --git a/core/src/main/scala/org/apache/spark/annotation/Since.scala b/core/src/main/scala/org/apache/spark/annotation/Since.scala new file mode 100644 index 000000000000..fa59393c2247 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/annotation/Since.scala @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.annotation + +import scala.annotation.StaticAnnotation + +/** + * A Scala annotation that specifies the Spark version when a definition was added. + * Different from the `@since` tag in JavaDoc, this annotation does not require explicit JavaDoc and + * hence works for overridden methods that inherit API documentation directly from parents. + * The limitation is that it does not show up in the generated Java API documentation. + */ +private[spark] class Since(version: String) extends StaticAnnotation From a17384fa343628cec44437da5b80b9403ecd5838 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 12 Aug 2015 15:27:52 -0700 Subject: [PATCH 0200/1824] [SPARK-9907] [SQL] Python crc32 is mistakenly calling md5 Author: Reynold Xin Closes #8138 from rxin/SPARK-9907. --- python/pyspark/sql/functions.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 95f46044d324..e98979533f90 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -885,10 +885,10 @@ def crc32(col): returns the value as a bigint. >>> sqlContext.createDataFrame([('ABC',)], ['a']).select(crc32('a').alias('crc32')).collect() - [Row(crc32=u'902fbdd2b1df0c4f70b4a5d23525e932')] + [Row(crc32=2743272264)] """ sc = SparkContext._active_spark_context - return Column(sc._jvm.functions.md5(_to_java_column(col))) + return Column(sc._jvm.functions.crc32(_to_java_column(col))) @ignore_unicode_prefix From 738f353988dbf02704bd63f5e35d94402c59ed79 Mon Sep 17 00:00:00 2001 From: Niranjan Padmanabhan Date: Wed, 12 Aug 2015 16:10:21 -0700 Subject: [PATCH 0201/1824] [SPARK-9092] Fixed incompatibility when both num-executors and dynamic... MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit … allocation are set. Now, dynamic allocation is set to false when num-executors is explicitly specified as an argument. Consequently, executorAllocationManager in not initialized in the SparkContext. Author: Niranjan Padmanabhan Closes #7657 from neurons/SPARK-9092. 
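As a quick illustration of the new behavior, a spark-shell style sketch (the local helper name dynamicAllocationWins exists only for this example; it mirrors the Utils.isDynamicAllocationEnabled check added in the diff below):

import org.apache.spark.SparkConf

// Mirrors Utils.isDynamicAllocationEnabled: dynamic allocation only stays in effect
// when spark.executor.instances has not been set to a non-zero value.
def dynamicAllocationWins(conf: SparkConf): Boolean =
  conf.contains("spark.dynamicAllocation.enabled") &&
    conf.getInt("spark.executor.instances", 0) == 0

val conf = new SparkConf()
  .setAppName("test")
  .setMaster("local")
  .set("spark.dynamicAllocation.enabled", "true")
  .set("spark.executor.instances", "6")

dynamicAllocationWins(conf)   // false: the explicit executor count takes precedence,
                              // so SparkContext will not create an ExecutorAllocationManager

This matches the new SparkContextSuite test below, which asserts that such a configuration no longer throws and leaves sc.executorAllocationManager empty.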
--- .../scala/org/apache/spark/SparkConf.scala | 19 +++++++++++++++++++ .../scala/org/apache/spark/SparkContext.scala | 6 +++++- .../org/apache/spark/deploy/SparkSubmit.scala | 4 ++-- .../scala/org/apache/spark/util/Utils.scala | 11 +++++++++++ .../org/apache/spark/SparkContextSuite.scala | 8 ++++++++ .../spark/deploy/SparkSubmitSuite.scala | 1 - docs/running-on-yarn.md | 2 +- .../spark/deploy/yarn/ApplicationMaster.scala | 4 ++-- .../yarn/ApplicationMasterArguments.scala | 5 ----- .../org/apache/spark/deploy/yarn/Client.scala | 5 ++++- .../spark/deploy/yarn/ClientArguments.scala | 8 +------- .../spark/deploy/yarn/YarnAllocator.scala | 9 ++++++++- .../cluster/YarnClientSchedulerBackend.scala | 3 --- .../deploy/yarn/YarnAllocatorSuite.scala | 5 +++-- 14 files changed, 64 insertions(+), 26 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index 8ff154fb5e33..b344b5e173d6 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -389,6 +389,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { val driverOptsKey = "spark.driver.extraJavaOptions" val driverClassPathKey = "spark.driver.extraClassPath" val driverLibraryPathKey = "spark.driver.extraLibraryPath" + val sparkExecutorInstances = "spark.executor.instances" // Used by Yarn in 1.1 and before sys.props.get("spark.driver.libraryPath").foreach { value => @@ -476,6 +477,24 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { } } } + + if (!contains(sparkExecutorInstances)) { + sys.env.get("SPARK_WORKER_INSTANCES").foreach { value => + val warning = + s""" + |SPARK_WORKER_INSTANCES was detected (set to '$value'). + |This is deprecated in Spark 1.0+. + | + |Please instead use: + | - ./spark-submit with --num-executors to specify the number of executors + | - Or set SPARK_EXECUTOR_INSTANCES + | - spark.executor.instances to configure the number of instances in the spark config. + """.stripMargin + logWarning(warning) + + set("spark.executor.instances", value) + } + } } /** diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 6aafb4c5644d..207a0c1bffeb 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -528,7 +528,11 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli } // Optionally scale number of executors dynamically based on workload. Exposed for testing. 
- val dynamicAllocationEnabled = _conf.getBoolean("spark.dynamicAllocation.enabled", false) + val dynamicAllocationEnabled = Utils.isDynamicAllocationEnabled(_conf) + if (!dynamicAllocationEnabled && _conf.getBoolean("spark.dynamicAllocation.enabled", false)) { + logInfo("Dynamic Allocation and num executors both set, thus dynamic allocation disabled.") + } + _executorAllocationManager = if (dynamicAllocationEnabled) { Some(new ExecutorAllocationManager(this, listenerBus, _conf)) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 7ac6cbce4cd1..02fa3088eded 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -422,7 +422,8 @@ object SparkSubmit { // Yarn client only OptionAssigner(args.queue, YARN, CLIENT, sysProp = "spark.yarn.queue"), - OptionAssigner(args.numExecutors, YARN, CLIENT, sysProp = "spark.executor.instances"), + OptionAssigner(args.numExecutors, YARN, ALL_DEPLOY_MODES, + sysProp = "spark.executor.instances"), OptionAssigner(args.files, YARN, CLIENT, sysProp = "spark.yarn.dist.files"), OptionAssigner(args.archives, YARN, CLIENT, sysProp = "spark.yarn.dist.archives"), OptionAssigner(args.principal, YARN, CLIENT, sysProp = "spark.yarn.principal"), @@ -433,7 +434,6 @@ object SparkSubmit { OptionAssigner(args.driverMemory, YARN, CLUSTER, clOption = "--driver-memory"), OptionAssigner(args.driverCores, YARN, CLUSTER, clOption = "--driver-cores"), OptionAssigner(args.queue, YARN, CLUSTER, clOption = "--queue"), - OptionAssigner(args.numExecutors, YARN, CLUSTER, clOption = "--num-executors"), OptionAssigner(args.executorMemory, YARN, CLUSTER, clOption = "--executor-memory"), OptionAssigner(args.executorCores, YARN, CLUSTER, clOption = "--executor-cores"), OptionAssigner(args.files, YARN, CLUSTER, clOption = "--files"), diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index c4012d0e83f7..a90d8541366f 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -2286,6 +2286,17 @@ private[spark] object Utils extends Logging { isInDirectory(parent, child.getParentFile) } + /** + * Return whether dynamic allocation is enabled in the given conf + * Dynamic allocation and explicitly setting the number of executors are inherently + * incompatible. In environments where dynamic allocation is turned on by default, + * the latter should override the former (SPARK-9092). 
+ */ + def isDynamicAllocationEnabled(conf: SparkConf): Boolean = { + conf.contains("spark.dynamicAllocation.enabled") && + conf.getInt("spark.executor.instances", 0) == 0 + } + } private [util] class SparkShutdownHookManager { diff --git a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala index 5c57940fa5f7..d4f2ea87650a 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala @@ -285,4 +285,12 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext { } } + test("No exception when both num-executors and dynamic allocation set.") { + noException should be thrownBy { + sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local") + .set("spark.dynamicAllocation.enabled", "true").set("spark.executor.instances", "6")) + assert(sc.executorAllocationManager.isEmpty) + assert(sc.getConf.getInt("spark.executor.instances", 0) === 6) + } + } } diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 757e0ce3d278..2456c5d0d49b 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -159,7 +159,6 @@ class SparkSubmitSuite childArgsStr should include ("--executor-cores 5") childArgsStr should include ("--arg arg1 --arg arg2") childArgsStr should include ("--queue thequeue") - childArgsStr should include ("--num-executors 6") childArgsStr should include regex ("--jar .*thejar.jar") childArgsStr should include regex ("--addJars .*one.jar,.*two.jar,.*three.jar") childArgsStr should include regex ("--files .*file1.txt,.*file2.txt") diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index cac08a91b97d..ec32c419b7c5 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -199,7 +199,7 @@ If you need a reference to the proper location to put log files in the YARN so t spark.executor.instances 2 - The number of executors. Note that this property is incompatible with spark.dynamicAllocation.enabled. + The number of executors. Note that this property is incompatible with spark.dynamicAllocation.enabled. If both spark.dynamicAllocation.enabled and spark.executor.instances are specified, dynamic allocation is turned off and the specified number of spark.executor.instances is used. 
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 1d67b3ebb51b..e19940d8d664 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -64,7 +64,8 @@ private[spark] class ApplicationMaster( // Default to numExecutors * 2, with minimum of 3 private val maxNumExecutorFailures = sparkConf.getInt("spark.yarn.max.executor.failures", - sparkConf.getInt("spark.yarn.max.worker.failures", math.max(args.numExecutors * 2, 3))) + sparkConf.getInt("spark.yarn.max.worker.failures", + math.max(sparkConf.getInt("spark.executor.instances", 0) * 2, 3))) @volatile private var exitCode = 0 @volatile private var unregistered = false @@ -493,7 +494,6 @@ private[spark] class ApplicationMaster( */ private def startUserApplication(): Thread = { logInfo("Starting the user application in a separate Thread") - System.setProperty("spark.executor.instances", args.numExecutors.toString) val classpath = Client.getUserClasspath(sparkConf) val urls = classpath.map { entry => diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala index 37f793763367..b08412414aa1 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala @@ -29,7 +29,6 @@ class ApplicationMasterArguments(val args: Array[String]) { var userArgs: Seq[String] = Nil var executorMemory = 1024 var executorCores = 1 - var numExecutors = DEFAULT_NUMBER_EXECUTORS var propertiesFile: String = null parseArgs(args.toList) @@ -63,10 +62,6 @@ class ApplicationMasterArguments(val args: Array[String]) { userArgsBuffer += value args = tail - case ("--num-workers" | "--num-executors") :: IntParam(value) :: tail => - numExecutors = value - args = tail - case ("--worker-memory" | "--executor-memory") :: MemoryParam(value) :: tail => executorMemory = value args = tail diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index b4ba3f022160..6d63ddaf1585 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -751,7 +751,6 @@ private[spark] class Client( userArgs ++ Seq( "--executor-memory", args.executorMemory.toString + "m", "--executor-cores", args.executorCores.toString, - "--num-executors ", args.numExecutors.toString, "--properties-file", buildPath(YarnSparkHadoopUtil.expandEnvironment(Environment.PWD), LOCALIZED_CONF_DIR, SPARK_CONF_FILE)) @@ -960,6 +959,10 @@ object Client extends Logging { val sparkConf = new SparkConf val args = new ClientArguments(argStrings, sparkConf) + // to maintain backwards-compatibility + if (!Utils.isDynamicAllocationEnabled(sparkConf)) { + sparkConf.setIfMissing("spark.executor.instances", args.numExecutors.toString) + } new Client(args, sparkConf).run() } diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala index 20d63d40cf60..4f42ffefa77f 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala @@ 
-53,8 +53,7 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) private val amMemOverheadKey = "spark.yarn.am.memoryOverhead" private val driverCoresKey = "spark.driver.cores" private val amCoresKey = "spark.yarn.am.cores" - private val isDynamicAllocationEnabled = - sparkConf.getBoolean("spark.dynamicAllocation.enabled", false) + private val isDynamicAllocationEnabled = Utils.isDynamicAllocationEnabled(sparkConf) parseArgs(args.toList) loadEnvironmentArgs() @@ -196,11 +195,6 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) if (args(0) == "--num-workers") { println("--num-workers is deprecated. Use --num-executors instead.") } - // Dynamic allocation is not compatible with this option - if (isDynamicAllocationEnabled) { - throw new IllegalArgumentException("Explicitly setting the number " + - "of executors is not compatible with spark.dynamicAllocation.enabled!") - } numExecutors = value args = tail diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala index 59caa787b6e2..ccf753e69f4b 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala @@ -21,6 +21,8 @@ import java.util.Collections import java.util.concurrent._ import java.util.regex.Pattern +import org.apache.spark.util.Utils + import scala.collection.JavaConversions._ import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} @@ -86,7 +88,12 @@ private[yarn] class YarnAllocator( private var executorIdCounter = 0 @volatile private var numExecutorsFailed = 0 - @volatile private var targetNumExecutors = args.numExecutors + @volatile private var targetNumExecutors = + if (Utils.isDynamicAllocationEnabled(sparkConf)) { + sparkConf.getInt("spark.dynamicAllocation.initialExecutors", 0) + } else { + sparkConf.getInt("spark.executor.instances", YarnSparkHadoopUtil.DEFAULT_NUMBER_EXECUTORS) + } // Keep track of which container is running which executor to remove the executors later // Visible for testing. 
diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala index d225061fcd1b..d06d95140438 100644 --- a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala +++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala @@ -81,8 +81,6 @@ private[spark] class YarnClientSchedulerBackend( // List of (target Client argument, environment variable, Spark property) val optionTuples = List( - ("--num-executors", "SPARK_WORKER_INSTANCES", "spark.executor.instances"), - ("--num-executors", "SPARK_EXECUTOR_INSTANCES", "spark.executor.instances"), ("--executor-memory", "SPARK_WORKER_MEMORY", "spark.executor.memory"), ("--executor-memory", "SPARK_EXECUTOR_MEMORY", "spark.executor.memory"), ("--executor-cores", "SPARK_WORKER_CORES", "spark.executor.cores"), @@ -92,7 +90,6 @@ private[spark] class YarnClientSchedulerBackend( ) // Warn against the following deprecated environment variables: env var -> suggestion val deprecatedEnvVars = Map( - "SPARK_WORKER_INSTANCES" -> "SPARK_WORKER_INSTANCES or --num-executors through spark-submit", "SPARK_WORKER_MEMORY" -> "SPARK_EXECUTOR_MEMORY or --executor-memory through spark-submit", "SPARK_WORKER_CORES" -> "SPARK_EXECUTOR_CORES or --executor-cores through spark-submit") optionTuples.foreach { case (optionName, envVar, sparkProp) => diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala index 58318bf9bcc0..5d05f514adde 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala @@ -87,16 +87,17 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter def createAllocator(maxExecutors: Int = 5): YarnAllocator = { val args = Array( - "--num-executors", s"$maxExecutors", "--executor-cores", "5", "--executor-memory", "2048", "--jar", "somejar.jar", "--class", "SomeClass") + val sparkConfClone = sparkConf.clone() + sparkConfClone.set("spark.executor.instances", maxExecutors.toString) new YarnAllocator( "not used", mock(classOf[RpcEndpointRef]), conf, - sparkConf, + sparkConfClone, rmClient, appAttemptId, new ApplicationMasterArguments(args), From ab7e721cfec63155641e81e72b4ad43cf6a7d4c7 Mon Sep 17 00:00:00 2001 From: Michel Lemay Date: Wed, 12 Aug 2015 16:17:58 -0700 Subject: [PATCH 0202/1824] [SPARK-9826] [CORE] Fix cannot use custom classes in log4j.properties Refactor Utils class and create ShutdownHookManager. NOTE: Wasn't able to run /dev/run-tests on windows machine. 
Manual tests were conducted locally using custom log4j.properties file with Redis appender and logstash formatter (bundled in the fat-jar submitted to spark) ex: log4j.rootCategory=WARN,console,redis log4j.appender.console=org.apache.log4j.ConsoleAppender log4j.appender.console.target=System.err log4j.appender.console.layout=org.apache.log4j.PatternLayout log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n log4j.logger.org.eclipse.jetty=WARN log4j.logger.org.eclipse.jetty.util.component.AbstractLifeCycle=ERROR log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=INFO log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=INFO log4j.logger.org.apache.spark.graphx.Pregel=INFO log4j.appender.redis=com.ryantenney.log4j.FailoverRedisAppender log4j.appender.redis.endpoints=hostname:port log4j.appender.redis.key=mykey log4j.appender.redis.alwaysBatch=false log4j.appender.redis.layout=net.logstash.log4j.JSONEventLayoutV1 Author: michellemay Closes #8109 from michellemay/SPARK-9826. --- .../scala/org/apache/spark/SparkContext.scala | 5 +- .../spark/deploy/history/HistoryServer.scala | 4 +- .../spark/deploy/worker/ExecutorRunner.scala | 7 +- .../org/apache/spark/rdd/HadoopRDD.scala | 4 +- .../org/apache/spark/rdd/NewHadoopRDD.scala | 4 +- .../apache/spark/rdd/SqlNewHadoopRDD.scala | 4 +- .../spark/storage/DiskBlockManager.scala | 10 +- .../spark/storage/TachyonBlockManager.scala | 6 +- .../spark/util/ShutdownHookManager.scala | 266 ++++++++++++++++++ .../util/SparkUncaughtExceptionHandler.scala | 2 +- .../scala/org/apache/spark/util/Utils.scala | 222 +-------------- .../hive/thriftserver/HiveThriftServer2.scala | 4 +- .../hive/thriftserver/SparkSQLCLIDriver.scala | 4 +- .../apache/spark/sql/hive/test/TestHive.scala | 4 +- .../spark/streaming/StreamingContext.scala | 8 +- .../spark/deploy/yarn/ApplicationMaster.scala | 5 +- 16 files changed, 307 insertions(+), 252 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 207a0c1bffeb..2e01a9a18c78 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -563,7 +563,8 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli // Make sure the context is stopped if the user forgets about it. This avoids leaving // unfinished event logs around after the JVM exits cleanly. It doesn't help if the JVM // is killed, though. 
- _shutdownHookRef = Utils.addShutdownHook(Utils.SPARK_CONTEXT_SHUTDOWN_PRIORITY) { () => + _shutdownHookRef = ShutdownHookManager.addShutdownHook( + ShutdownHookManager.SPARK_CONTEXT_SHUTDOWN_PRIORITY) { () => logInfo("Invoking stop() from shutdown hook") stop() } @@ -1671,7 +1672,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli return } if (_shutdownHookRef != null) { - Utils.removeShutdownHook(_shutdownHookRef) + ShutdownHookManager.removeShutdownHook(_shutdownHookRef) } Utils.tryLogNonFatalError { diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala index a076a9c3f984..d4f327cc588f 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala @@ -30,7 +30,7 @@ import org.apache.spark.status.api.v1.{ApiRootResource, ApplicationInfo, Applica UIRoot} import org.apache.spark.ui.{SparkUI, UIUtils, WebUI} import org.apache.spark.ui.JettyUtils._ -import org.apache.spark.util.{SignalLogger, Utils} +import org.apache.spark.util.{ShutdownHookManager, SignalLogger, Utils} /** * A web server that renders SparkUIs of completed applications. @@ -238,7 +238,7 @@ object HistoryServer extends Logging { val server = new HistoryServer(conf, provider, securityManager, port) server.bind() - Utils.addShutdownHook { () => server.stop() } + ShutdownHookManager.addShutdownHook { () => server.stop() } // Wait until the end of the world... or if the HistoryServer process is manually stopped while(true) { Thread.sleep(Int.MaxValue) } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala index 29a504228557..ab3fea475c2a 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala @@ -28,7 +28,7 @@ import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.{SecurityManager, SparkConf, Logging} import org.apache.spark.deploy.{ApplicationDescription, ExecutorState} import org.apache.spark.deploy.DeployMessages.ExecutorStateChanged -import org.apache.spark.util.Utils +import org.apache.spark.util.{ShutdownHookManager, Utils} import org.apache.spark.util.logging.FileAppender /** @@ -70,7 +70,8 @@ private[deploy] class ExecutorRunner( } workerThread.start() // Shutdown hook that kills actors on shutdown. 
- shutdownHook = Utils.addShutdownHook { () => killProcess(Some("Worker shutting down")) } + shutdownHook = ShutdownHookManager.addShutdownHook { () => + killProcess(Some("Worker shutting down")) } } /** @@ -102,7 +103,7 @@ private[deploy] class ExecutorRunner( workerThread = null state = ExecutorState.KILLED try { - Utils.removeShutdownHook(shutdownHook) + ShutdownHookManager.removeShutdownHook(shutdownHook) } catch { case e: IllegalStateException => None } diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index f1c17369cb48..e1f8719eead0 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -44,7 +44,7 @@ import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.executor.DataReadMethod import org.apache.spark.rdd.HadoopRDD.HadoopMapPartitionsWithSplitRDD -import org.apache.spark.util.{SerializableConfiguration, NextIterator, Utils} +import org.apache.spark.util.{SerializableConfiguration, ShutdownHookManager, NextIterator, Utils} import org.apache.spark.scheduler.{HostTaskLocation, HDFSCacheTaskLocation} import org.apache.spark.storage.StorageLevel @@ -274,7 +274,7 @@ class HadoopRDD[K, V]( } } catch { case e: Exception => { - if (!Utils.inShutdown()) { + if (!ShutdownHookManager.inShutdown()) { logWarning("Exception in RecordReader.close()", e) } } diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index f83a051f5da1..6a9c004d65cf 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -33,7 +33,7 @@ import org.apache.spark._ import org.apache.spark.executor.DataReadMethod import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil import org.apache.spark.rdd.NewHadoopRDD.NewHadoopMapPartitionsWithSplitRDD -import org.apache.spark.util.{SerializableConfiguration, Utils} +import org.apache.spark.util.{SerializableConfiguration, ShutdownHookManager, Utils} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.storage.StorageLevel @@ -186,7 +186,7 @@ class NewHadoopRDD[K, V]( } } catch { case e: Exception => { - if (!Utils.inShutdown()) { + if (!ShutdownHookManager.inShutdown()) { logWarning("Exception in RecordReader.close()", e) } } diff --git a/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala index 6a95e44c57fe..fa3fecc80cb6 100644 --- a/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala @@ -33,7 +33,7 @@ import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.{Partition => SparkPartition, _} import org.apache.spark.storage.StorageLevel -import org.apache.spark.util.{SerializableConfiguration, Utils} +import org.apache.spark.util.{SerializableConfiguration, ShutdownHookManager, Utils} private[spark] class SqlNewHadoopPartition( @@ -212,7 +212,7 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag]( } } catch { case e: Exception => - if (!Utils.inShutdown()) { + if (!ShutdownHookManager.inShutdown()) { logWarning("Exception in RecordReader.close()", e) } } diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala 
b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala index 56a33d5ca7d6..3f8d26e1d4ca 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala @@ -22,7 +22,7 @@ import java.io.{IOException, File} import org.apache.spark.{SparkConf, Logging} import org.apache.spark.executor.ExecutorExitCode -import org.apache.spark.util.Utils +import org.apache.spark.util.{ShutdownHookManager, Utils} /** * Creates and maintains the logical mapping between logical blocks and physical on-disk @@ -144,7 +144,7 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon } private def addShutdownHook(): AnyRef = { - Utils.addShutdownHook(Utils.TEMP_DIR_SHUTDOWN_PRIORITY + 1) { () => + ShutdownHookManager.addShutdownHook(ShutdownHookManager.TEMP_DIR_SHUTDOWN_PRIORITY + 1) { () => logInfo("Shutdown hook called") DiskBlockManager.this.doStop() } @@ -154,7 +154,7 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon private[spark] def stop() { // Remove the shutdown hook. It causes memory leaks if we leave it around. try { - Utils.removeShutdownHook(shutdownHook) + ShutdownHookManager.removeShutdownHook(shutdownHook) } catch { case e: Exception => logError(s"Exception while removing shutdown hook.", e) @@ -168,7 +168,9 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon localDirs.foreach { localDir => if (localDir.isDirectory() && localDir.exists()) { try { - if (!Utils.hasRootAsShutdownDeleteDir(localDir)) Utils.deleteRecursively(localDir) + if (!ShutdownHookManager.hasRootAsShutdownDeleteDir(localDir)) { + Utils.deleteRecursively(localDir) + } } catch { case e: Exception => logError(s"Exception while deleting local spark dir: $localDir", e) diff --git a/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala index ebad5bc5ab28..22878783fca6 100644 --- a/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala @@ -32,7 +32,7 @@ import tachyon.TachyonURI import org.apache.spark.Logging import org.apache.spark.executor.ExecutorExitCode -import org.apache.spark.util.Utils +import org.apache.spark.util.{ShutdownHookManager, Utils} /** @@ -80,7 +80,7 @@ private[spark] class TachyonBlockManager() extends ExternalBlockManager with Log // in order to avoid having really large inodes at the top level in Tachyon. 
tachyonDirs = createTachyonDirs() subDirs = Array.fill(tachyonDirs.length)(new Array[TachyonFile](subDirsPerTachyonDir)) - tachyonDirs.foreach(tachyonDir => Utils.registerShutdownDeleteDir(tachyonDir)) + tachyonDirs.foreach(tachyonDir => ShutdownHookManager.registerShutdownDeleteDir(tachyonDir)) } override def toString: String = {"ExternalBlockStore-Tachyon"} @@ -240,7 +240,7 @@ private[spark] class TachyonBlockManager() extends ExternalBlockManager with Log logDebug("Shutdown hook called") tachyonDirs.foreach { tachyonDir => try { - if (!Utils.hasRootAsShutdownDeleteDir(tachyonDir)) { + if (!ShutdownHookManager.hasRootAsShutdownDeleteDir(tachyonDir)) { Utils.deleteRecursively(tachyonDir, client) } } catch { diff --git a/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala b/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala new file mode 100644 index 000000000000..61ff9b89ec1c --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala @@ -0,0 +1,266 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import java.io.File +import java.util.PriorityQueue + +import scala.util.{Failure, Success, Try} +import tachyon.client.TachyonFile + +import org.apache.hadoop.fs.FileSystem +import org.apache.spark.Logging + +/** + * Various utility methods used by Spark. + */ +private[spark] object ShutdownHookManager extends Logging { + val DEFAULT_SHUTDOWN_PRIORITY = 100 + + /** + * The shutdown priority of the SparkContext instance. This is lower than the default + * priority, so that by default hooks are run before the context is shut down. + */ + val SPARK_CONTEXT_SHUTDOWN_PRIORITY = 50 + + /** + * The shutdown priority of temp directory must be lower than the SparkContext shutdown + * priority. Otherwise cleaning the temp directories while Spark jobs are running can + * throw undesirable errors at the time of shutdown. 
+ */ + val TEMP_DIR_SHUTDOWN_PRIORITY = 25 + + private lazy val shutdownHooks = { + val manager = new SparkShutdownHookManager() + manager.install() + manager + } + + private val shutdownDeletePaths = new scala.collection.mutable.HashSet[String]() + private val shutdownDeleteTachyonPaths = new scala.collection.mutable.HashSet[String]() + + // Add a shutdown hook to delete the temp dirs when the JVM exits + addShutdownHook(TEMP_DIR_SHUTDOWN_PRIORITY) { () => + logInfo("Shutdown hook called") + shutdownDeletePaths.foreach { dirPath => + try { + logInfo("Deleting directory " + dirPath) + Utils.deleteRecursively(new File(dirPath)) + } catch { + case e: Exception => logError(s"Exception while deleting Spark temp dir: $dirPath", e) + } + } + } + + // Register the path to be deleted via shutdown hook + def registerShutdownDeleteDir(file: File) { + val absolutePath = file.getAbsolutePath() + shutdownDeletePaths.synchronized { + shutdownDeletePaths += absolutePath + } + } + + // Register the tachyon path to be deleted via shutdown hook + def registerShutdownDeleteDir(tachyonfile: TachyonFile) { + val absolutePath = tachyonfile.getPath() + shutdownDeleteTachyonPaths.synchronized { + shutdownDeleteTachyonPaths += absolutePath + } + } + + // Remove the path to be deleted via shutdown hook + def removeShutdownDeleteDir(file: File) { + val absolutePath = file.getAbsolutePath() + shutdownDeletePaths.synchronized { + shutdownDeletePaths.remove(absolutePath) + } + } + + // Remove the tachyon path to be deleted via shutdown hook + def removeShutdownDeleteDir(tachyonfile: TachyonFile) { + val absolutePath = tachyonfile.getPath() + shutdownDeleteTachyonPaths.synchronized { + shutdownDeleteTachyonPaths.remove(absolutePath) + } + } + + // Is the path already registered to be deleted via a shutdown hook ? + def hasShutdownDeleteDir(file: File): Boolean = { + val absolutePath = file.getAbsolutePath() + shutdownDeletePaths.synchronized { + shutdownDeletePaths.contains(absolutePath) + } + } + + // Is the path already registered to be deleted via a shutdown hook ? + def hasShutdownDeleteTachyonDir(file: TachyonFile): Boolean = { + val absolutePath = file.getPath() + shutdownDeleteTachyonPaths.synchronized { + shutdownDeleteTachyonPaths.contains(absolutePath) + } + } + + // Note: if file is child of some registered path, while not equal to it, then return true; + // else false. This is to ensure that two shutdown hooks do not try to delete each others + // paths - resulting in IOException and incomplete cleanup. + def hasRootAsShutdownDeleteDir(file: File): Boolean = { + val absolutePath = file.getAbsolutePath() + val retval = shutdownDeletePaths.synchronized { + shutdownDeletePaths.exists { path => + !absolutePath.equals(path) && absolutePath.startsWith(path) + } + } + if (retval) { + logInfo("path = " + file + ", already present as root for deletion.") + } + retval + } + + // Note: if file is child of some registered path, while not equal to it, then return true; + // else false. This is to ensure that two shutdown hooks do not try to delete each others + // paths - resulting in Exception and incomplete cleanup. 
+ def hasRootAsShutdownDeleteDir(file: TachyonFile): Boolean = { + val absolutePath = file.getPath() + val retval = shutdownDeleteTachyonPaths.synchronized { + shutdownDeleteTachyonPaths.exists { path => + !absolutePath.equals(path) && absolutePath.startsWith(path) + } + } + if (retval) { + logInfo("path = " + file + ", already present as root for deletion.") + } + retval + } + + /** + * Detect whether this thread might be executing a shutdown hook. Will always return true if + * the current thread is a running a shutdown hook but may spuriously return true otherwise (e.g. + * if System.exit was just called by a concurrent thread). + * + * Currently, this detects whether the JVM is shutting down by Runtime#addShutdownHook throwing + * an IllegalStateException. + */ + def inShutdown(): Boolean = { + try { + val hook = new Thread { + override def run() {} + } + Runtime.getRuntime.addShutdownHook(hook) + Runtime.getRuntime.removeShutdownHook(hook) + } catch { + case ise: IllegalStateException => return true + } + false + } + + /** + * Adds a shutdown hook with default priority. + * + * @param hook The code to run during shutdown. + * @return A handle that can be used to unregister the shutdown hook. + */ + def addShutdownHook(hook: () => Unit): AnyRef = { + addShutdownHook(DEFAULT_SHUTDOWN_PRIORITY)(hook) + } + + /** + * Adds a shutdown hook with the given priority. Hooks with lower priority values run + * first. + * + * @param hook The code to run during shutdown. + * @return A handle that can be used to unregister the shutdown hook. + */ + def addShutdownHook(priority: Int)(hook: () => Unit): AnyRef = { + shutdownHooks.add(priority, hook) + } + + /** + * Remove a previously installed shutdown hook. + * + * @param ref A handle returned by `addShutdownHook`. + * @return Whether the hook was removed. + */ + def removeShutdownHook(ref: AnyRef): Boolean = { + shutdownHooks.remove(ref) + } + +} + +private [util] class SparkShutdownHookManager { + + private val hooks = new PriorityQueue[SparkShutdownHook]() + private var shuttingDown = false + + /** + * Install a hook to run at shutdown and run all registered hooks in order. Hadoop 1.x does not + * have `ShutdownHookManager`, so in that case we just use the JVM's `Runtime` object and hope for + * the best. 
+ */ + def install(): Unit = { + val hookTask = new Runnable() { + override def run(): Unit = runAll() + } + Try(Utils.classForName("org.apache.hadoop.util.ShutdownHookManager")) match { + case Success(shmClass) => + val fsPriority = classOf[FileSystem].getField("SHUTDOWN_HOOK_PRIORITY").get() + .asInstanceOf[Int] + val shm = shmClass.getMethod("get").invoke(null) + shm.getClass().getMethod("addShutdownHook", classOf[Runnable], classOf[Int]) + .invoke(shm, hookTask, Integer.valueOf(fsPriority + 30)) + + case Failure(_) => + Runtime.getRuntime.addShutdownHook(new Thread(hookTask, "Spark Shutdown Hook")); + } + } + + def runAll(): Unit = synchronized { + shuttingDown = true + while (!hooks.isEmpty()) { + Try(Utils.logUncaughtExceptions(hooks.poll().run())) + } + } + + def add(priority: Int, hook: () => Unit): AnyRef = synchronized { + checkState() + val hookRef = new SparkShutdownHook(priority, hook) + hooks.add(hookRef) + hookRef + } + + def remove(ref: AnyRef): Boolean = synchronized { + hooks.remove(ref) + } + + private def checkState(): Unit = { + if (shuttingDown) { + throw new IllegalStateException("Shutdown hooks cannot be modified during shutdown.") + } + } + +} + +private class SparkShutdownHook(private val priority: Int, hook: () => Unit) + extends Comparable[SparkShutdownHook] { + + override def compareTo(other: SparkShutdownHook): Int = { + other.priority - priority + } + + def run(): Unit = hook() + +} diff --git a/core/src/main/scala/org/apache/spark/util/SparkUncaughtExceptionHandler.scala b/core/src/main/scala/org/apache/spark/util/SparkUncaughtExceptionHandler.scala index ad3db1fbb57e..724818724733 100644 --- a/core/src/main/scala/org/apache/spark/util/SparkUncaughtExceptionHandler.scala +++ b/core/src/main/scala/org/apache/spark/util/SparkUncaughtExceptionHandler.scala @@ -33,7 +33,7 @@ private[spark] object SparkUncaughtExceptionHandler // We may have been called from a shutdown hook. If so, we must not call System.exit(). // (If we do, we will deadlock.) - if (!Utils.inShutdown()) { + if (!ShutdownHookManager.inShutdown()) { if (exception.isInstanceOf[OutOfMemoryError]) { System.exit(SparkExitCode.OOM) } else { diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index a90d8541366f..f2abf227dc12 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -21,7 +21,7 @@ import java.io._ import java.lang.management.ManagementFactory import java.net._ import java.nio.ByteBuffer -import java.util.{PriorityQueue, Properties, Locale, Random, UUID} +import java.util.{Properties, Locale, Random, UUID} import java.util.concurrent._ import javax.net.ssl.HttpsURLConnection @@ -65,21 +65,6 @@ private[spark] object CallSite { private[spark] object Utils extends Logging { val random = new Random() - val DEFAULT_SHUTDOWN_PRIORITY = 100 - - /** - * The shutdown priority of the SparkContext instance. This is lower than the default - * priority, so that by default hooks are run before the context is shut down. - */ - val SPARK_CONTEXT_SHUTDOWN_PRIORITY = 50 - - /** - * The shutdown priority of temp directory must be lower than the SparkContext shutdown - * priority. Otherwise cleaning the temp directories while Spark jobs are running can - * throw undesirable errors at the time of shutdown. 
- */ - val TEMP_DIR_SHUTDOWN_PRIORITY = 25 - /** * Define a default value for driver memory here since this value is referenced across the code * base and nearly all files already use Utils.scala @@ -90,9 +75,6 @@ private[spark] object Utils extends Logging { @volatile private var localRootDirs: Array[String] = null - private val shutdownHooks = new SparkShutdownHookManager() - shutdownHooks.install() - /** Serialize an object using Java serialization */ def serialize[T](o: T): Array[Byte] = { val bos = new ByteArrayOutputStream() @@ -205,86 +187,6 @@ private[spark] object Utils extends Logging { } } - private val shutdownDeletePaths = new scala.collection.mutable.HashSet[String]() - private val shutdownDeleteTachyonPaths = new scala.collection.mutable.HashSet[String]() - - // Add a shutdown hook to delete the temp dirs when the JVM exits - addShutdownHook(TEMP_DIR_SHUTDOWN_PRIORITY) { () => - logInfo("Shutdown hook called") - shutdownDeletePaths.foreach { dirPath => - try { - logInfo("Deleting directory " + dirPath) - Utils.deleteRecursively(new File(dirPath)) - } catch { - case e: Exception => logError(s"Exception while deleting Spark temp dir: $dirPath", e) - } - } - } - - // Register the path to be deleted via shutdown hook - def registerShutdownDeleteDir(file: File) { - val absolutePath = file.getAbsolutePath() - shutdownDeletePaths.synchronized { - shutdownDeletePaths += absolutePath - } - } - - // Register the tachyon path to be deleted via shutdown hook - def registerShutdownDeleteDir(tachyonfile: TachyonFile) { - val absolutePath = tachyonfile.getPath() - shutdownDeleteTachyonPaths.synchronized { - shutdownDeleteTachyonPaths += absolutePath - } - } - - // Is the path already registered to be deleted via a shutdown hook ? - def hasShutdownDeleteDir(file: File): Boolean = { - val absolutePath = file.getAbsolutePath() - shutdownDeletePaths.synchronized { - shutdownDeletePaths.contains(absolutePath) - } - } - - // Is the path already registered to be deleted via a shutdown hook ? - def hasShutdownDeleteTachyonDir(file: TachyonFile): Boolean = { - val absolutePath = file.getPath() - shutdownDeleteTachyonPaths.synchronized { - shutdownDeleteTachyonPaths.contains(absolutePath) - } - } - - // Note: if file is child of some registered path, while not equal to it, then return true; - // else false. This is to ensure that two shutdown hooks do not try to delete each others - // paths - resulting in IOException and incomplete cleanup. - def hasRootAsShutdownDeleteDir(file: File): Boolean = { - val absolutePath = file.getAbsolutePath() - val retval = shutdownDeletePaths.synchronized { - shutdownDeletePaths.exists { path => - !absolutePath.equals(path) && absolutePath.startsWith(path) - } - } - if (retval) { - logInfo("path = " + file + ", already present as root for deletion.") - } - retval - } - - // Note: if file is child of some registered path, while not equal to it, then return true; - // else false. This is to ensure that two shutdown hooks do not try to delete each others - // paths - resulting in Exception and incomplete cleanup. - def hasRootAsShutdownDeleteDir(file: TachyonFile): Boolean = { - val absolutePath = file.getPath() - val retval = shutdownDeleteTachyonPaths.synchronized { - shutdownDeleteTachyonPaths.exists { path => - !absolutePath.equals(path) && absolutePath.startsWith(path) - } - } - if (retval) { - logInfo("path = " + file + ", already present as root for deletion.") - } - retval - } - /** * JDK equivalent of `chmod 700 file`. 
* @@ -333,7 +235,7 @@ private[spark] object Utils extends Logging { root: String = System.getProperty("java.io.tmpdir"), namePrefix: String = "spark"): File = { val dir = createDirectory(root, namePrefix) - registerShutdownDeleteDir(dir) + ShutdownHookManager.registerShutdownDeleteDir(dir) dir } @@ -973,9 +875,7 @@ private[spark] object Utils extends Logging { if (savedIOException != null) { throw savedIOException } - shutdownDeletePaths.synchronized { - shutdownDeletePaths.remove(file.getAbsolutePath) - } + ShutdownHookManager.removeShutdownDeleteDir(file) } } finally { if (!file.delete()) { @@ -1478,27 +1378,6 @@ private[spark] object Utils extends Logging { serializer.deserialize[T](serializer.serialize(value)) } - /** - * Detect whether this thread might be executing a shutdown hook. Will always return true if - * the current thread is a running a shutdown hook but may spuriously return true otherwise (e.g. - * if System.exit was just called by a concurrent thread). - * - * Currently, this detects whether the JVM is shutting down by Runtime#addShutdownHook throwing - * an IllegalStateException. - */ - def inShutdown(): Boolean = { - try { - val hook = new Thread { - override def run() {} - } - Runtime.getRuntime.addShutdownHook(hook) - Runtime.getRuntime.removeShutdownHook(hook) - } catch { - case ise: IllegalStateException => return true - } - false - } - private def isSpace(c: Char): Boolean = { " \t\r\n".indexOf(c) != -1 } @@ -2221,37 +2100,6 @@ private[spark] object Utils extends Logging { msg.startsWith(BACKUP_STANDALONE_MASTER_PREFIX) } - /** - * Adds a shutdown hook with default priority. - * - * @param hook The code to run during shutdown. - * @return A handle that can be used to unregister the shutdown hook. - */ - def addShutdownHook(hook: () => Unit): AnyRef = { - addShutdownHook(DEFAULT_SHUTDOWN_PRIORITY)(hook) - } - - /** - * Adds a shutdown hook with the given priority. Hooks with lower priority values run - * first. - * - * @param hook The code to run during shutdown. - * @return A handle that can be used to unregister the shutdown hook. - */ - def addShutdownHook(priority: Int)(hook: () => Unit): AnyRef = { - shutdownHooks.add(priority, hook) - } - - /** - * Remove a previously installed shutdown hook. - * - * @param ref A handle returned by `addShutdownHook`. - * @return Whether the hook was removed. - */ - def removeShutdownHook(ref: AnyRef): Boolean = { - shutdownHooks.remove(ref) - } - /** * To avoid calling `Utils.getCallSite` for every single RDD we create in the body, * set a dummy call site that RDDs use instead. This is for performance optimization. @@ -2299,70 +2147,6 @@ private[spark] object Utils extends Logging { } -private [util] class SparkShutdownHookManager { - - private val hooks = new PriorityQueue[SparkShutdownHook]() - private var shuttingDown = false - - /** - * Install a hook to run at shutdown and run all registered hooks in order. Hadoop 1.x does not - * have `ShutdownHookManager`, so in that case we just use the JVM's `Runtime` object and hope for - * the best. 
- */ - def install(): Unit = { - val hookTask = new Runnable() { - override def run(): Unit = runAll() - } - Try(Utils.classForName("org.apache.hadoop.util.ShutdownHookManager")) match { - case Success(shmClass) => - val fsPriority = classOf[FileSystem].getField("SHUTDOWN_HOOK_PRIORITY").get() - .asInstanceOf[Int] - val shm = shmClass.getMethod("get").invoke(null) - shm.getClass().getMethod("addShutdownHook", classOf[Runnable], classOf[Int]) - .invoke(shm, hookTask, Integer.valueOf(fsPriority + 30)) - - case Failure(_) => - Runtime.getRuntime.addShutdownHook(new Thread(hookTask, "Spark Shutdown Hook")); - } - } - - def runAll(): Unit = synchronized { - shuttingDown = true - while (!hooks.isEmpty()) { - Try(Utils.logUncaughtExceptions(hooks.poll().run())) - } - } - - def add(priority: Int, hook: () => Unit): AnyRef = synchronized { - checkState() - val hookRef = new SparkShutdownHook(priority, hook) - hooks.add(hookRef) - hookRef - } - - def remove(ref: AnyRef): Boolean = synchronized { - hooks.remove(ref) - } - - private def checkState(): Unit = { - if (shuttingDown) { - throw new IllegalStateException("Shutdown hooks cannot be modified during shutdown.") - } - } - -} - -private class SparkShutdownHook(private val priority: Int, hook: () => Unit) - extends Comparable[SparkShutdownHook] { - - override def compareTo(other: SparkShutdownHook): Int = { - other.priority - priority - } - - def run(): Unit = hook() - -} - /** * A utility class to redirect the child process's stdout or stderr. */ diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala index 9c047347cb58..2c9fa595b2da 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala @@ -35,7 +35,7 @@ import org.apache.spark.sql.SQLConf import org.apache.spark.sql.hive.HiveContext import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._ import org.apache.spark.sql.hive.thriftserver.ui.ThriftServerTab -import org.apache.spark.util.Utils +import org.apache.spark.util.{ShutdownHookManager, Utils} import org.apache.spark.{Logging, SparkContext} @@ -76,7 +76,7 @@ object HiveThriftServer2 extends Logging { logInfo("Starting SparkContext") SparkSQLEnv.init() - Utils.addShutdownHook { () => + ShutdownHookManager.addShutdownHook { () => SparkSQLEnv.stop() uiTab.foreach(_.detach()) } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala index d3886142b388..7799704c819d 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala @@ -39,7 +39,7 @@ import org.apache.thrift.transport.TSocket import org.apache.spark.Logging import org.apache.spark.sql.hive.HiveContext -import org.apache.spark.util.Utils +import org.apache.spark.util.{ShutdownHookManager, Utils} /** * This code doesn't support remote connections in Hive 1.2+, as the underlying CliDriver @@ -114,7 +114,7 @@ private[hive] object SparkSQLCLIDriver extends Logging { SessionState.start(sessionState) // Clean up after we exit - Utils.addShutdownHook { () => 
SparkSQLEnv.stop() } + ShutdownHookManager.addShutdownHook { () => SparkSQLEnv.stop() } val remoteMode = isRemoteMode(sessionState) // "-h" option has been passed, so connect to Hive thrift server. diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index 296cc5c5e0b0..4eae699ac3b5 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -36,7 +36,7 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.CacheTableCommand import org.apache.spark.sql.hive._ import org.apache.spark.sql.hive.execution.HiveNativeCommand -import org.apache.spark.util.Utils +import org.apache.spark.util.{ShutdownHookManager, Utils} import org.apache.spark.{SparkConf, SparkContext} /* Implicit conversions */ @@ -154,7 +154,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { val hiveFilesTemp = File.createTempFile("catalystHiveFiles", "") hiveFilesTemp.delete() hiveFilesTemp.mkdir() - Utils.registerShutdownDeleteDir(hiveFilesTemp) + ShutdownHookManager.registerShutdownDeleteDir(hiveFilesTemp) val inRepoTests = if (System.getProperty("user.dir").endsWith("sql" + File.separator + "hive")) { new File("src" + File.separator + "test" + File.separator + "resources" + File.separator) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala index 177e710ace54..b496d1f341a0 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala @@ -44,7 +44,7 @@ import org.apache.spark.streaming.dstream._ import org.apache.spark.streaming.receiver.{ActorReceiver, ActorSupervisorStrategy, Receiver} import org.apache.spark.streaming.scheduler.{JobScheduler, StreamingListener} import org.apache.spark.streaming.ui.{StreamingJobProgressListener, StreamingTab} -import org.apache.spark.util.{CallSite, Utils} +import org.apache.spark.util.{CallSite, ShutdownHookManager, Utils} /** * Main entry point for Spark Streaming functionality. 
It provides methods used to create @@ -604,7 +604,7 @@ class StreamingContext private[streaming] ( } StreamingContext.setActiveContext(this) } - shutdownHookRef = Utils.addShutdownHook( + shutdownHookRef = ShutdownHookManager.addShutdownHook( StreamingContext.SHUTDOWN_HOOK_PRIORITY)(stopOnShutdown) // Registering Streaming Metrics at the start of the StreamingContext assert(env.metricsSystem != null) @@ -691,7 +691,7 @@ class StreamingContext private[streaming] ( StreamingContext.setActiveContext(null) waiter.notifyStop() if (shutdownHookRef != null) { - Utils.removeShutdownHook(shutdownHookRef) + ShutdownHookManager.removeShutdownHook(shutdownHookRef) } logInfo("StreamingContext stopped successfully") } @@ -725,7 +725,7 @@ object StreamingContext extends Logging { */ private val ACTIVATION_LOCK = new Object() - private val SHUTDOWN_HOOK_PRIORITY = Utils.SPARK_CONTEXT_SHUTDOWN_PRIORITY + 1 + private val SHUTDOWN_HOOK_PRIORITY = ShutdownHookManager.SPARK_CONTEXT_SHUTDOWN_PRIORITY + 1 private val activeContext = new AtomicReference[StreamingContext](null) diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index e19940d8d664..6a8ddb37b29e 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -112,7 +112,8 @@ private[spark] class ApplicationMaster( val fs = FileSystem.get(yarnConf) // This shutdown hook should run *after* the SparkContext is shut down. - Utils.addShutdownHook(Utils.SPARK_CONTEXT_SHUTDOWN_PRIORITY - 1) { () => + val priority = ShutdownHookManager.SPARK_CONTEXT_SHUTDOWN_PRIORITY - 1 + ShutdownHookManager.addShutdownHook(priority) { () => val maxAppAttempts = client.getMaxRegAttempts(sparkConf, yarnConf) val isLastAttempt = client.getAttemptId().getAttemptId() >= maxAppAttempts @@ -199,7 +200,7 @@ private[spark] class ApplicationMaster( final def finish(status: FinalApplicationStatus, code: Int, msg: String = null): Unit = { synchronized { if (!finished) { - val inShutdown = Utils.inShutdown() + val inShutdown = ShutdownHookManager.inShutdown() logInfo(s"Final app status: $status, exitCode: $code" + Option(msg).map(msg => s", (reason: $msg)").getOrElse("")) exitCode = code From 7035d880a0cf06910c19b4afd49645124c620f14 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Wed, 12 Aug 2015 16:45:15 -0700 Subject: [PATCH 0203/1824] [SPARK-9894] [SQL] Json writer should handle MapData. https://issues.apache.org/jira/browse/SPARK-9894 Author: Yin Huai Closes #8137 from yhuai/jsonMapData. 
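A minimal sketch of the round trip this change enables, mirroring the new test added below. The app name and output path are illustrative, and the code targets the 1.5-era SQLContext API.

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.sql.{Row, SQLContext}
    import org.apache.spark.sql.types._

    object JsonMapDataRoundTrip {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(new SparkConf().setAppName("json-map-data").setMaster("local[2]"))
        val sqlContext = new SQLContext(sc)

        // Schema with an array column and a map-of-struct column. Before this change the
        // writer pattern-matched on scala.collection.Map, which never matches the internal
        // MapData representation, so map columns could not be written out as JSON.
        val schema = new StructType()
          .add("array", ArrayType(LongType))
          .add("map", MapType(StringType, new StructType().add("innerField", LongType)))

        val data = Seq(
          Row(Seq(1L, 2L, 3L), Map("m1" -> Row(4L))),
          Row(Seq(5L, 6L, 7L), Map("m2" -> Row(10L))))
        val df = sqlContext.createDataFrame(sc.parallelize(data), schema)

        // Write the map column out as a JSON object and read it back with the same schema.
        val out = "/tmp/spark-9894-example"   // illustrative path
        df.write.json(out)
        sqlContext.read.schema(schema).json(out).show()

        sc.stop()
      }
    }
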
--- .../datasources/json/JacksonGenerator.scala | 10 +-- .../sources/JsonHadoopFsRelationSuite.scala | 78 +++++++++++++++++++ .../SimpleTextHadoopFsRelationSuite.scala | 30 ------- 3 files changed, 83 insertions(+), 35 deletions(-) create mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/sources/JsonHadoopFsRelationSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonGenerator.scala index 37c2b5a296c1..99ac7730bd1c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonGenerator.scala @@ -107,12 +107,12 @@ private[sql] object JacksonGenerator { v.foreach(ty, (_, value) => valWriter(ty, value)) gen.writeEndArray() - case (MapType(kv, vv, _), v: Map[_, _]) => + case (MapType(kt, vt, _), v: MapData) => gen.writeStartObject() - v.foreach { p => - gen.writeFieldName(p._1.toString) - valWriter(vv, p._2) - } + v.foreach(kt, vt, { (k, v) => + gen.writeFieldName(k.toString) + valWriter(vt, v) + }) gen.writeEndObject() case (StructType(ty), v: InternalRow) => diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/JsonHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/JsonHadoopFsRelationSuite.scala new file mode 100644 index 000000000000..ed6d512ab36f --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/JsonHadoopFsRelationSuite.scala @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.sources + +import org.apache.hadoop.fs.Path + +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.sql.Row +import org.apache.spark.sql.types._ + +class JsonHadoopFsRelationSuite extends HadoopFsRelationTest { + override val dataSourceName: String = "json" + + import sqlContext._ + + test("save()/load() - partitioned table - simple queries - partition columns in data") { + withTempDir { file => + val basePath = new Path(file.getCanonicalPath) + val fs = basePath.getFileSystem(SparkHadoopUtil.get.conf) + val qualifiedBasePath = fs.makeQualified(basePath) + + for (p1 <- 1 to 2; p2 <- Seq("foo", "bar")) { + val partitionDir = new Path(qualifiedBasePath, s"p1=$p1/p2=$p2") + sparkContext + .parallelize(for (i <- 1 to 3) yield s"""{"a":$i,"b":"val_$i"}""") + .saveAsTextFile(partitionDir.toString) + } + + val dataSchemaWithPartition = + StructType(dataSchema.fields :+ StructField("p1", IntegerType, nullable = true)) + + checkQueries( + read.format(dataSourceName) + .option("dataSchema", dataSchemaWithPartition.json) + .load(file.getCanonicalPath)) + } + } + + test("SPARK-9894: save complex types to JSON") { + withTempDir { file => + file.delete() + + val schema = + new StructType() + .add("array", ArrayType(LongType)) + .add("map", MapType(StringType, new StructType().add("innerField", LongType))) + + val data = + Row(Seq(1L, 2L, 3L), Map("m1" -> Row(4L))) :: + Row(Seq(5L, 6L, 7L), Map("m2" -> Row(10L))) :: Nil + val df = createDataFrame(sparkContext.parallelize(data), schema) + + // Write the data out. + df.write.format(dataSourceName).save(file.getCanonicalPath) + + // Read it back and check the result. + checkAnswer( + read.format(dataSourceName).schema(schema).load(file.getCanonicalPath), + df + ) + } + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala index 48c37a1fa102..e8975e5f5cd0 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala @@ -50,33 +50,3 @@ class SimpleTextHadoopFsRelationSuite extends HadoopFsRelationTest { } } } - -class JsonHadoopFsRelationSuite extends HadoopFsRelationTest { - override val dataSourceName: String = - classOf[org.apache.spark.sql.execution.datasources.json.DefaultSource].getCanonicalName - - import sqlContext._ - - test("save()/load() - partitioned table - simple queries - partition columns in data") { - withTempDir { file => - val basePath = new Path(file.getCanonicalPath) - val fs = basePath.getFileSystem(SparkHadoopUtil.get.conf) - val qualifiedBasePath = fs.makeQualified(basePath) - - for (p1 <- 1 to 2; p2 <- Seq("foo", "bar")) { - val partitionDir = new Path(qualifiedBasePath, s"p1=$p1/p2=$p2") - sparkContext - .parallelize(for (i <- 1 to 3) yield s"""{"a":$i,"b":"val_$i"}""") - .saveAsTextFile(partitionDir.toString) - } - - val dataSchemaWithPartition = - StructType(dataSchema.fields :+ StructField("p1", IntegerType, nullable = true)) - - checkQueries( - read.format(dataSourceName) - .option("dataSchema", dataSchemaWithPartition.json) - .load(file.getCanonicalPath)) - } - } -} From caa14d9dc9e2eb1102052b22445b63b0e004e3c7 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Wed, 12 Aug 2015 16:53:47 -0700 Subject: [PATCH 0204/1824] [SPARK-9913] [MLLIB] LDAUtils should be private feynmanliang Author: 
Xiangrui Meng Closes #8142 from mengxr/SPARK-9913. --- .../main/scala/org/apache/spark/mllib/clustering/LDAUtils.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAUtils.scala index f7e5ce1665fe..a9ba7b60bad0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAUtils.scala @@ -22,7 +22,7 @@ import breeze.numerics._ /** * Utility methods for LDA. */ -object LDAUtils { +private[clustering] object LDAUtils { /** * Log Sum Exp with overflow protection using the identity: * For any a: \log \sum_{n=1}^N \exp\{x_n\} = a + \log \sum_{n=1}^N \exp\{x_n - a\} From 6e409bc1357f49de2efdfc4226d074b943fb1153 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Wed, 12 Aug 2015 16:54:45 -0700 Subject: [PATCH 0205/1824] [SPARK-9909] [ML] [TRIVIAL] move weightCol to shared params As per the TODO move weightCol to Shared Params. Author: Holden Karau Closes #8144 from holdenk/SPARK-9909-move-weightCol-toSharedParams. --- .../ml/param/shared/SharedParamsCodeGen.scala | 4 +++- .../spark/ml/param/shared/sharedParams.scala | 15 +++++++++++++++ .../spark/ml/regression/IsotonicRegression.scala | 16 ++-------------- 3 files changed, 20 insertions(+), 15 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala index 9e12f1856a94..8c16c6149b40 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala @@ -70,7 +70,9 @@ private[shared] object SharedParamsCodeGen { " For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.", isValid = "ParamValidators.inRange(0, 1)"), ParamDesc[Double]("tol", "the convergence tolerance for iterative algorithms"), - ParamDesc[Double]("stepSize", "Step size to be used for each iteration of optimization.")) + ParamDesc[Double]("stepSize", "Step size to be used for each iteration of optimization."), + ParamDesc[String]("weightCol", "weight column name. If this is not set or empty, we treat " + + "all instance weights as 1.0.")) val code = genSharedParams(params) val file = "src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala" diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala index a17d4ea960a9..c26768953e3d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala @@ -342,4 +342,19 @@ private[ml] trait HasStepSize extends Params { /** @group getParam */ final def getStepSize: Double = $(stepSize) } + +/** + * Trait for shared param weightCol. + */ +private[ml] trait HasWeightCol extends Params { + + /** + * Param for weight column name. If this is not set or empty, we treat all instance weights as 1.0.. + * @group param + */ + final val weightCol: Param[String] = new Param[String](this, "weightCol", "weight column name. 
If this is not set or empty, we treat all instance weights as 1.0.") + + /** @group getParam */ + final def getWeightCol: String = $(weightCol) +} // scalastyle:on diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala index f570590960a6..0f33bae30e62 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala @@ -21,7 +21,7 @@ import org.apache.spark.Logging import org.apache.spark.annotation.Experimental import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param._ -import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasLabelCol, HasPredictionCol} +import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasLabelCol, HasPredictionCol, HasWeightCol} import org.apache.spark.ml.util.{Identifiable, SchemaUtils} import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors} import org.apache.spark.mllib.regression.{IsotonicRegression => MLlibIsotonicRegression, IsotonicRegressionModel => MLlibIsotonicRegressionModel} @@ -35,19 +35,7 @@ import org.apache.spark.storage.StorageLevel * Params for isotonic regression. */ private[regression] trait IsotonicRegressionBase extends Params with HasFeaturesCol - with HasLabelCol with HasPredictionCol with Logging { - - /** - * Param for weight column name (default: none). - * @group param - */ - // TODO: Move weightCol to sharedParams. - final val weightCol: Param[String] = - new Param[String](this, "weightCol", - "weight column name. If this is not set or empty, we treat all instance weights as 1.0.") - - /** @group getParam */ - final def getWeightCol: String = $(weightCol) + with HasLabelCol with HasPredictionCol with HasWeightCol with Logging { /** * Param for whether the output sequence should be isotonic/increasing (true) or From e6aef55766d0e2a48e0f9cb6eda0e31a71b962f3 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Wed, 12 Aug 2015 17:04:31 -0700 Subject: [PATCH 0206/1824] [SPARK-9912] [MLLIB] QRDecomposition should use QType and RType for type names instead of UType and VType hhbyyh Author: Xiangrui Meng Closes #8140 from mengxr/SPARK-9912. --- .../apache/spark/mllib/linalg/SingularValueDecomposition.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/SingularValueDecomposition.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/SingularValueDecomposition.scala index b416d50a5631..cff5dbeee3e5 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/SingularValueDecomposition.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/SingularValueDecomposition.scala @@ -31,5 +31,5 @@ case class SingularValueDecomposition[UType, VType](U: UType, s: Vector, V: VTyp * Represents QR factors. */ @Experimental -case class QRDecomposition[UType, VType](Q: UType, R: VType) +case class QRDecomposition[QType, RType](Q: QType, R: RType) From fc1c7fd66e64ccea53b31cd2fbb98bc6d307329c Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Wed, 12 Aug 2015 17:06:12 -0700 Subject: [PATCH 0207/1824] [SPARK-9915] [ML] stopWords should use StringArrayParam hhbyyh Author: Xiangrui Meng Closes #8141 from mengxr/SPARK-9915. 
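A short usage sketch of the param this change retypes, runnable in spark-shell against the 1.5-era API; the column names, stop words, and sample rows are illustrative.

    import org.apache.spark.ml.feature.StopWordsRemover

    val remover = new StopWordsRemover()
      .setInputCol("rawTokens")
      .setOutputCol("filteredTokens")
      .setStopWords(Array("a", "an", "the"))   // stopWords is now a StringArrayParam

    val df = sqlContext.createDataFrame(Seq(
      (0, Seq("the", "quick", "brown", "fox")),
      (1, Seq("an", "idiomatic", "example"))
    )).toDF("id", "rawTokens")

    remover.transform(df).show()

Declaring stopWords as a StringArrayParam keeps it consistent with the other array-typed params rather than leaving it as a generic Param[Array[String]].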
--- .../org/apache/spark/ml/feature/StopWordsRemover.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala index 3cc41424460f..5d77ea08db65 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala @@ -19,12 +19,12 @@ package org.apache.spark.ml.feature import org.apache.spark.annotation.Experimental import org.apache.spark.ml.Transformer +import org.apache.spark.ml.param.{BooleanParam, ParamMap, StringArrayParam} import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} -import org.apache.spark.ml.param.{ParamMap, BooleanParam, Param} import org.apache.spark.ml.util.Identifiable import org.apache.spark.sql.DataFrame -import org.apache.spark.sql.types.{StringType, StructField, ArrayType, StructType} import org.apache.spark.sql.functions.{col, udf} +import org.apache.spark.sql.types.{ArrayType, StringType, StructField, StructType} /** * stop words list @@ -100,7 +100,7 @@ class StopWordsRemover(override val uid: String) * the stop words set to be filtered out * @group param */ - val stopWords: Param[Array[String]] = new Param(this, "stopWords", "stop words") + val stopWords: StringArrayParam = new StringArrayParam(this, "stopWords", "stop words") /** @group setParam */ def setStopWords(value: Array[String]): this.type = set(stopWords, value) From 660e6dcff8125b83cc73dbe00c90cbe58744bc66 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Wed, 12 Aug 2015 17:07:29 -0700 Subject: [PATCH 0208/1824] [SPARK-9449] [SQL] Include MetastoreRelation's inputFiles Author: Michael Armbrust Closes #8119 from marmbrus/metastoreInputFiles. 
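A brief sketch of the behavior this change extends, runnable in spark-shell with Hive support; the paths and table name are illustrative.

    // After this change DataFrame.inputFiles is driven by the FileRelation interface:
    // a HadoopFsRelation reports the individual files it will scan...
    val parquetDF = sqlContext.read.parquet("/tmp/events.parquet")
    parquetDF.inputFiles.foreach(println)

    // ...and a metastore-backed table reports its partition locations, or the table's
    // own location when it is unpartitioned.
    val hiveDF = sqlContext.table("web_logs")
    hiveDF.inputFiles.foreach(println)
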
--- .../org/apache/spark/sql/DataFrame.scala | 10 ++++--- .../spark/sql/execution/FileRelation.scala | 28 +++++++++++++++++++ .../apache/spark/sql/sources/interfaces.scala | 6 ++-- .../org/apache/spark/sql/DataFrameSuite.scala | 26 +++++++++-------- .../spark/sql/hive/HiveMetastoreCatalog.scala | 16 +++++++++-- 5 files changed, 66 insertions(+), 20 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/FileRelation.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 27b994f1f0ca..c466d9e6cb34 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -34,10 +34,10 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical.{Filter, _} +import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.{Inner, JoinType} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection, SqlParser} -import org.apache.spark.sql.execution.{EvaluatePython, ExplainCommand, LogicalRDD, SQLExecution} +import org.apache.spark.sql.execution.{EvaluatePython, ExplainCommand, FileRelation, LogicalRDD, SQLExecution} import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, LogicalRelation} import org.apache.spark.sql.execution.datasources.json.JacksonGenerator import org.apache.spark.sql.sources.HadoopFsRelation @@ -1560,8 +1560,10 @@ class DataFrame private[sql]( */ def inputFiles: Array[String] = { val files: Seq[String] = logicalPlan.collect { - case LogicalRelation(fsBasedRelation: HadoopFsRelation) => - fsBasedRelation.paths.toSeq + case LogicalRelation(fsBasedRelation: FileRelation) => + fsBasedRelation.inputFiles + case fr: FileRelation => + fr.inputFiles }.flatten files.toSet.toArray } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/FileRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/FileRelation.scala new file mode 100644 index 000000000000..7a2a9eed5807 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/FileRelation.scala @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +/** + * An interface for relations that are backed by files. When a class implements this interface, + * the list of paths that it returns will be returned to a user who calls `inputPaths` on any + * DataFrame that queries this relation. 
+ */ +private[sql] trait FileRelation { + /** Returns the list of files that will be read when scanning this relation. */ + def inputFiles: Array[String] +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index 2f8417a48d32..b3b326fe612c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -31,7 +31,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection -import org.apache.spark.sql.execution.RDDConversions +import org.apache.spark.sql.execution.{FileRelation, RDDConversions} import org.apache.spark.sql.execution.datasources.{PartitioningUtils, PartitionSpec, Partition} import org.apache.spark.sql.types.StructType import org.apache.spark.sql._ @@ -406,7 +406,7 @@ abstract class OutputWriter { */ @Experimental abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[PartitionSpec]) - extends BaseRelation with Logging { + extends BaseRelation with FileRelation with Logging { override def toString: String = getClass.getSimpleName + paths.mkString("[", ",", "]") @@ -516,6 +516,8 @@ abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[Partitio */ def paths: Array[String] + override def inputFiles: Array[String] = cachedLeafStatuses().map(_.getPath.toString).toArray + /** * Partition columns. Can be either defined by [[userDefinedPartitionColumns]] or automatically * discovered. Note that they should always be nullable. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index adbd95197d7c..2feec29955bc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -485,21 +485,23 @@ class DataFrameSuite extends QueryTest with SQLTestUtils { } test("inputFiles") { - val fakeRelation1 = new ParquetRelation(Array("/my/path", "/my/other/path"), - Some(testData.schema), None, Map.empty)(sqlContext) - val df1 = DataFrame(sqlContext, LogicalRelation(fakeRelation1)) - assert(df1.inputFiles.toSet == fakeRelation1.paths.toSet) + withTempDir { dir => + val df = Seq((1, 22)).toDF("a", "b") - val fakeRelation2 = new JSONRelation( - None, 1, Some(testData.schema), None, None, Array("/json/path"))(sqlContext) - val df2 = DataFrame(sqlContext, LogicalRelation(fakeRelation2)) - assert(df2.inputFiles.toSet == fakeRelation2.paths.toSet) + val parquetDir = new File(dir, "parquet").getCanonicalPath + df.write.parquet(parquetDir) + val parquetDF = sqlContext.read.parquet(parquetDir) + assert(parquetDF.inputFiles.nonEmpty) - val unionDF = df1.unionAll(df2) - assert(unionDF.inputFiles.toSet == fakeRelation1.paths.toSet ++ fakeRelation2.paths) + val jsonDir = new File(dir, "json").getCanonicalPath + df.write.json(jsonDir) + val jsonDF = sqlContext.read.json(jsonDir) + assert(parquetDF.inputFiles.nonEmpty) - val filtered = df1.filter("false").unionAll(df2.intersect(df2)) - assert(filtered.inputFiles.toSet == fakeRelation1.paths.toSet ++ fakeRelation2.paths) + val unioned = jsonDF.unionAll(parquetDF).inputFiles.sorted + val allFiles = (jsonDF.inputFiles ++ parquetDF.inputFiles).toSet.toArray.sorted + 
assert(unioned === allFiles) + } } ignore("show") { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index ac9aaed19d56..5e5497837a39 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -38,7 +38,7 @@ import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.catalyst.{InternalRow, SqlParser, TableIdentifier} -import org.apache.spark.sql.execution.datasources +import org.apache.spark.sql.execution.{FileRelation, datasources} import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, LogicalRelation, Partition => ParquetPartition, PartitionSpec, ResolvedDataSource} import org.apache.spark.sql.hive.client._ import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation @@ -739,7 +739,7 @@ private[hive] case class MetastoreRelation (databaseName: String, tableName: String, alias: Option[String]) (val table: HiveTable) (@transient sqlContext: SQLContext) - extends LeafNode with MultiInstanceRelation { + extends LeafNode with MultiInstanceRelation with FileRelation { override def equals(other: Any): Boolean = other match { case relation: MetastoreRelation => @@ -888,6 +888,18 @@ private[hive] case class MetastoreRelation /** An attribute map for determining the ordinal for non-partition columns. */ val columnOrdinals = AttributeMap(attributes.zipWithIndex) + override def inputFiles: Array[String] = { + val partLocations = table.getPartitions(Nil).map(_.storage.location).toArray + if (partLocations.nonEmpty) { + partLocations + } else { + Array( + table.location.getOrElse( + sys.error(s"Could not get the location of ${table.qualifiedName}."))) + } + } + + override def newInstance(): MetastoreRelation = { MetastoreRelation(databaseName, tableName, alias)(table)(sqlContext) } From 8ce60963cb0928058ef7b6e29ff94eb69d1143af Mon Sep 17 00:00:00 2001 From: cody koeninger Date: Wed, 12 Aug 2015 17:44:16 -0700 Subject: [PATCH 0209/1824] =?UTF-8?q?[SPARK-9780]=20[STREAMING]=20[KAFKA]?= =?UTF-8?q?=20prevent=20NPE=20if=20KafkaRDD=20instantiation=20=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …fails Author: cody koeninger Closes #8133 from koeninger/SPARK-9780 and squashes the following commits: 406259d [cody koeninger] [SPARK-9780][Streaming][Kafka] prevent NPE if KafkaRDD instantiation fails --- .../scala/org/apache/spark/streaming/kafka/KafkaRDD.scala | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala index 1a9d78c0d4f5..ea5f842c6caf 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala @@ -197,7 +197,11 @@ class KafkaRDD[ .dropWhile(_.offset < requestOffset) } - override def close(): Unit = consumer.close() + override def close(): Unit = { + if (consumer != null) { + consumer.close() + } + } override def getNext(): R = { if (iter == null || !iter.hasNext) { From 0d1d146c220f0d47d0e62b368d5b94d3bd9dd197 Mon Sep 17 00:00:00 2001 From: Rohit Agarwal Date: Wed, 12 Aug 2015 
17:48:43 -0700 Subject: [PATCH 0210/1824] [SPARK-9724] [WEB UI] Avoid unnecessary redirects in the Spark Web UI. Author: Rohit Agarwal Closes #8014 from mindprince/SPARK-9724 and squashes the following commits: a7af5ff [Rohit Agarwal] [SPARK-9724] [WEB UI] Inline attachPrefix and attachPrefixForRedirect. Fix logic of attachPrefix 8a977cd [Rohit Agarwal] [SPARK-9724] [WEB UI] Address review comments: Remove unneeded code, update scaladoc. b257844 [Rohit Agarwal] [SPARK-9724] [WEB UI] Avoid unnecessary redirects in the Spark Web UI. --- .../main/scala/org/apache/spark/ui/JettyUtils.scala | 13 ++++++------- .../main/scala/org/apache/spark/ui/SparkUI.scala | 4 ++-- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala index c8356467fab8..779c0ba08359 100644 --- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala @@ -106,7 +106,11 @@ private[spark] object JettyUtils extends Logging { path: String, servlet: HttpServlet, basePath: String): ServletContextHandler = { - val prefixedPath = attachPrefix(basePath, path) + val prefixedPath = if (basePath == "" && path == "/") { + path + } else { + (basePath + path).stripSuffix("/") + } val contextHandler = new ServletContextHandler val holder = new ServletHolder(servlet) contextHandler.setContextPath(prefixedPath) @@ -121,7 +125,7 @@ private[spark] object JettyUtils extends Logging { beforeRedirect: HttpServletRequest => Unit = x => (), basePath: String = "", httpMethods: Set[String] = Set("GET")): ServletContextHandler = { - val prefixedDestPath = attachPrefix(basePath, destPath) + val prefixedDestPath = basePath + destPath val servlet = new HttpServlet { override def doGet(request: HttpServletRequest, response: HttpServletResponse): Unit = { if (httpMethods.contains("GET")) { @@ -246,11 +250,6 @@ private[spark] object JettyUtils extends Logging { val (server, boundPort) = Utils.startServiceOnPort[Server](port, connect, conf, serverName) ServerInfo(server, boundPort, collection) } - - /** Attach a prefix to the given path, but avoid returning an empty path */ - private def attachPrefix(basePath: String, relativePath: String): String = { - if (basePath == "") relativePath else (basePath + relativePath).stripSuffix("/") - } } private[spark] case class ServerInfo( diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala index 3788916cf39b..d8b90568b7b9 100644 --- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala @@ -64,11 +64,11 @@ private[spark] class SparkUI private ( attachTab(new EnvironmentTab(this)) attachTab(new ExecutorsTab(this)) attachHandler(createStaticHandler(SparkUI.STATIC_RESOURCE_DIR, "/static")) - attachHandler(createRedirectHandler("/", "/jobs", basePath = basePath)) + attachHandler(createRedirectHandler("/", "/jobs/", basePath = basePath)) attachHandler(ApiRootResource.getServletHandler(this)) // This should be POST only, but, the YARN AM proxy won't proxy POSTs attachHandler(createRedirectHandler( - "/stages/stage/kill", "/stages", stagesTab.handleKillRequest, + "/stages/stage/kill", "/stages/", stagesTab.handleKillRequest, httpMethods = Set("GET", "POST"))) } initialize() From f4bc01f1f33a93e6affe5c8a3e33ffbd92d03f38 Mon Sep 17 00:00:00 2001 From: Yu ISHIKAWA Date: Wed, 12 Aug 2015 18:33:27 -0700 Subject: [PATCH 
0211/1824] [SPARK-9855] [SPARKR] Add expression functions into SparkR whose params are simple I added lots of expression functions for SparkR. This PR includes only functions whose params are only `(Column)` or `(Column, Column)`. And I think we need to improve how to test those functions. However, it would be better to work on another issue. ## Diff Summary - Add lots of functions in `functions.R` and their generic in `generic.R` - Add aliases for `ceiling` and `sign` - Move expression functions from `column.R` to `functions.R` - Modify `rdname` from `column` to `functions` I haven't supported `not` function, because the name has a collesion with `testthat` package. I didn't think of the way to define it. ## New Supported Functions ``` approxCountDistinct ascii base64 bin bitwiseNOT ceil (alias: ceiling) crc32 dayofmonth dayofyear explode factorial hex hour initcap isNaN last_day length log2 ltrim md5 minute month negate quarter reverse round rtrim second sha1 signum (alias: sign) size soundex to_date trim unbase64 unhex weekofyear year datediff levenshtein months_between nanvl pmod ``` ## JIRA [[SPARK-9855] Add expression functions into SparkR whose params are simple - ASF JIRA](https://issues.apache.org/jira/browse/SPARK-9855) Author: Yu ISHIKAWA Closes #8123 from yu-iskw/SPARK-9855. --- R/pkg/DESCRIPTION | 1 + R/pkg/R/column.R | 81 -------------- R/pkg/R/functions.R | 123 ++++++++++++++++++++ R/pkg/R/generics.R | 185 +++++++++++++++++++++++++++++-- R/pkg/inst/tests/test_sparkSQL.R | 21 ++-- 5 files changed, 309 insertions(+), 102 deletions(-) create mode 100644 R/pkg/R/functions.R diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION index 4949d86d20c9..83e64897216b 100644 --- a/R/pkg/DESCRIPTION +++ b/R/pkg/DESCRIPTION @@ -29,6 +29,7 @@ Collate: 'client.R' 'context.R' 'deserialize.R' + 'functions.R' 'mllib.R' 'serialize.R' 'sparkR.R' diff --git a/R/pkg/R/column.R b/R/pkg/R/column.R index eeaf9f193b72..328f595d0805 100644 --- a/R/pkg/R/column.R +++ b/R/pkg/R/column.R @@ -60,12 +60,6 @@ operators <- list( ) column_functions1 <- c("asc", "desc", "isNull", "isNotNull") column_functions2 <- c("like", "rlike", "startsWith", "endsWith", "getField", "getItem", "contains") -functions <- c("min", "max", "sum", "avg", "mean", "count", "abs", "sqrt", - "first", "last", "lower", "upper", "sumDistinct", - "acos", "asin", "atan", "cbrt", "ceiling", "cos", "cosh", "exp", - "expm1", "floor", "log", "log10", "log1p", "rint", "sign", - "sin", "sinh", "tan", "tanh", "toDegrees", "toRadians") -binary_mathfunctions <- c("atan2", "hypot") createOperator <- function(op) { setMethod(op, @@ -111,33 +105,6 @@ createColumnFunction2 <- function(name) { }) } -createStaticFunction <- function(name) { - setMethod(name, - signature(x = "Column"), - function(x) { - if (name == "ceiling") { - name <- "ceil" - } - if (name == "sign") { - name <- "signum" - } - jc <- callJStatic("org.apache.spark.sql.functions", name, x@jc) - column(jc) - }) -} - -createBinaryMathfunctions <- function(name) { - setMethod(name, - signature(y = "Column"), - function(y, x) { - if (class(x) == "Column") { - x <- x@jc - } - jc <- callJStatic("org.apache.spark.sql.functions", name, y@jc, x) - column(jc) - }) -} - createMethods <- function() { for (op in names(operators)) { createOperator(op) @@ -148,12 +115,6 @@ createMethods <- function() { for (name in column_functions2) { createColumnFunction2(name) } - for (x in functions) { - createStaticFunction(x) - } - for (name in binary_mathfunctions) { - createBinaryMathfunctions(name) - } } 
createMethods() @@ -242,45 +203,3 @@ setMethod("%in%", jc <- callJMethod(x@jc, "in", table) return(column(jc)) }) - -#' Approx Count Distinct -#' -#' @rdname column -#' @return the approximate number of distinct items in a group. -setMethod("approxCountDistinct", - signature(x = "Column"), - function(x, rsd = 0.95) { - jc <- callJStatic("org.apache.spark.sql.functions", "approxCountDistinct", x@jc, rsd) - column(jc) - }) - -#' Count Distinct -#' -#' @rdname column -#' @return the number of distinct items in a group. -setMethod("countDistinct", - signature(x = "Column"), - function(x, ...) { - jcol <- lapply(list(...), function (x) { - x@jc - }) - jc <- callJStatic("org.apache.spark.sql.functions", "countDistinct", x@jc, - listToSeq(jcol)) - column(jc) - }) - -#' @rdname column -#' @aliases countDistinct -setMethod("n_distinct", - signature(x = "Column"), - function(x, ...) { - countDistinct(x, ...) - }) - -#' @rdname column -#' @aliases count -setMethod("n", - signature(x = "Column"), - function(x) { - count(x) - }) diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R new file mode 100644 index 000000000000..a15d2d5da534 --- /dev/null +++ b/R/pkg/R/functions.R @@ -0,0 +1,123 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +#' @include generics.R column.R +NULL + +#' @title S4 expression functions for DataFrame column(s) +#' @description These are expression functions on DataFrame columns + +functions1 <- c( + "abs", "acos", "approxCountDistinct", "ascii", "asin", "atan", + "avg", "base64", "bin", "bitwiseNOT", "cbrt", "ceil", "cos", "cosh", "count", + "crc32", "dayofmonth", "dayofyear", "exp", "explode", "expm1", "factorial", + "first", "floor", "hex", "hour", "initcap", "isNaN", "last", "last_day", + "length", "log", "log10", "log1p", "log2", "lower", "ltrim", "max", "md5", + "mean", "min", "minute", "month", "negate", "quarter", "reverse", + "rint", "round", "rtrim", "second", "sha1", "signum", "sin", "sinh", "size", + "soundex", "sqrt", "sum", "sumDistinct", "tan", "tanh", "toDegrees", + "toRadians", "to_date", "trim", "unbase64", "unhex", "upper", "weekofyear", + "year") +functions2 <- c( + "atan2", "datediff", "hypot", "levenshtein", "months_between", "nanvl", "pmod") + +createFunction1 <- function(name) { + setMethod(name, + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", name, x@jc) + column(jc) + }) +} + +createFunction2 <- function(name) { + setMethod(name, + signature(y = "Column"), + function(y, x) { + if (class(x) == "Column") { + x <- x@jc + } + jc <- callJStatic("org.apache.spark.sql.functions", name, y@jc, x) + column(jc) + }) +} + +createFunctions <- function() { + for (name in functions1) { + createFunction1(name) + } + for (name in functions2) { + createFunction2(name) + } +} + +createFunctions() + +#' Approx Count Distinct +#' +#' @rdname functions +#' @return the approximate number of distinct items in a group. +setMethod("approxCountDistinct", + signature(x = "Column"), + function(x, rsd = 0.95) { + jc <- callJStatic("org.apache.spark.sql.functions", "approxCountDistinct", x@jc, rsd) + column(jc) + }) + +#' Count Distinct +#' +#' @rdname functions +#' @return the number of distinct items in a group. +setMethod("countDistinct", + signature(x = "Column"), + function(x, ...) { + jcol <- lapply(list(...), function (x) { + x@jc + }) + jc <- callJStatic("org.apache.spark.sql.functions", "countDistinct", x@jc, + listToSeq(jcol)) + column(jc) + }) + +#' @rdname functions +#' @aliases ceil +setMethod("ceiling", + signature(x = "Column"), + function(x) { + ceil(x) + }) + +#' @rdname functions +#' @aliases signum +setMethod("sign", signature(x = "Column"), + function(x) { + signum(x) + }) + +#' @rdname functions +#' @aliases countDistinct +setMethod("n_distinct", signature(x = "Column"), + function(x, ...) { + countDistinct(x, ...) + }) + +#' @rdname functions +#' @aliases count +setMethod("n", signature(x = "Column"), + function(x) { + count(x) + }) diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 379a78b1d833..f11e7fcb6a07 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -575,10 +575,6 @@ setGeneric("approxCountDistinct", function(x, ...) { standardGeneric("approxCoun #' @export setGeneric("asc", function(x) { standardGeneric("asc") }) -#' @rdname column -#' @export -setGeneric("avg", function(x, ...) 
{ standardGeneric("avg") }) - #' @rdname column #' @export setGeneric("between", function(x, bounds) { standardGeneric("between") }) @@ -587,13 +583,10 @@ setGeneric("between", function(x, bounds) { standardGeneric("between") }) #' @export setGeneric("cast", function(x, dataType) { standardGeneric("cast") }) -#' @rdname column -#' @export -setGeneric("cbrt", function(x) { standardGeneric("cbrt") }) - #' @rdname column #' @export setGeneric("contains", function(x, ...) { standardGeneric("contains") }) + #' @rdname column #' @export setGeneric("countDistinct", function(x, ...) { standardGeneric("countDistinct") }) @@ -658,22 +651,190 @@ setGeneric("rlike", function(x, ...) { standardGeneric("rlike") }) #' @export setGeneric("startsWith", function(x, ...) { standardGeneric("startsWith") }) -#' @rdname column + +###################### Expression Function Methods ########################## + +#' @rdname functions +#' @export +setGeneric("ascii", function(x) { standardGeneric("ascii") }) + +#' @rdname functions +#' @export +setGeneric("avg", function(x, ...) { standardGeneric("avg") }) + +#' @rdname functions +#' @export +setGeneric("base64", function(x) { standardGeneric("base64") }) + +#' @rdname functions +#' @export +setGeneric("bin", function(x) { standardGeneric("bin") }) + +#' @rdname functions +#' @export +setGeneric("bitwiseNOT", function(x) { standardGeneric("bitwiseNOT") }) + +#' @rdname functions +#' @export +setGeneric("cbrt", function(x) { standardGeneric("cbrt") }) + +#' @rdname functions +#' @export +setGeneric("ceil", function(x) { standardGeneric("ceil") }) + +#' @rdname functions +#' @export +setGeneric("crc32", function(x) { standardGeneric("crc32") }) + +#' @rdname functions +#' @export +setGeneric("datediff", function(y, x) { standardGeneric("datediff") }) + +#' @rdname functions +#' @export +setGeneric("dayofmonth", function(x) { standardGeneric("dayofmonth") }) + +#' @rdname functions +#' @export +setGeneric("dayofyear", function(x) { standardGeneric("dayofyear") }) + +#' @rdname functions +#' @export +setGeneric("explode", function(x) { standardGeneric("explode") }) + +#' @rdname functions +#' @export +setGeneric("hex", function(x) { standardGeneric("hex") }) + +#' @rdname functions +#' @export +setGeneric("hour", function(x) { standardGeneric("hour") }) + +#' @rdname functions +#' @export +setGeneric("initcap", function(x) { standardGeneric("initcap") }) + +#' @rdname functions +#' @export +setGeneric("isNaN", function(x) { standardGeneric("isNaN") }) + +#' @rdname functions +#' @export +setGeneric("last_day", function(x) { standardGeneric("last_day") }) + +#' @rdname functions +#' @export +setGeneric("levenshtein", function(y, x) { standardGeneric("levenshtein") }) + +#' @rdname functions +#' @export +setGeneric("lower", function(x) { standardGeneric("lower") }) + +#' @rdname functions +#' @export +setGeneric("ltrim", function(x) { standardGeneric("ltrim") }) + +#' @rdname functions +#' @export +setGeneric("md5", function(x) { standardGeneric("md5") }) + +#' @rdname functions +#' @export +setGeneric("minute", function(x) { standardGeneric("minute") }) + +#' @rdname functions +#' @export +setGeneric("month", function(x) { standardGeneric("month") }) + +#' @rdname functions +#' @export +setGeneric("months_between", function(y, x) { standardGeneric("months_between") }) + +#' @rdname functions +#' @export +setGeneric("nanvl", function(y, x) { standardGeneric("nanvl") }) + +#' @rdname functions +#' @export +setGeneric("negate", function(x) { standardGeneric("negate") }) + 
+#' @rdname functions +#' @export +setGeneric("pmod", function(y, x) { standardGeneric("pmod") }) + +#' @rdname functions +#' @export +setGeneric("quarter", function(x) { standardGeneric("quarter") }) + +#' @rdname functions +#' @export +setGeneric("reverse", function(x) { standardGeneric("reverse") }) + +#' @rdname functions +#' @export +setGeneric("rtrim", function(x) { standardGeneric("rtrim") }) + +#' @rdname functions +#' @export +setGeneric("second", function(x) { standardGeneric("second") }) + +#' @rdname functions +#' @export +setGeneric("sha1", function(x) { standardGeneric("sha1") }) + +#' @rdname functions +#' @export +setGeneric("signum", function(x) { standardGeneric("signum") }) + +#' @rdname functions +#' @export +setGeneric("size", function(x) { standardGeneric("size") }) + +#' @rdname functions +#' @export +setGeneric("soundex", function(x) { standardGeneric("soundex") }) + +#' @rdname functions #' @export setGeneric("sumDistinct", function(x) { standardGeneric("sumDistinct") }) -#' @rdname column +#' @rdname functions #' @export setGeneric("toDegrees", function(x) { standardGeneric("toDegrees") }) -#' @rdname column +#' @rdname functions #' @export setGeneric("toRadians", function(x) { standardGeneric("toRadians") }) -#' @rdname column +#' @rdname functions +#' @export +setGeneric("to_date", function(x) { standardGeneric("to_date") }) + +#' @rdname functions +#' @export +setGeneric("trim", function(x) { standardGeneric("trim") }) + +#' @rdname functions +#' @export +setGeneric("unbase64", function(x) { standardGeneric("unbase64") }) + +#' @rdname functions +#' @export +setGeneric("unhex", function(x) { standardGeneric("unhex") }) + +#' @rdname functions #' @export setGeneric("upper", function(x) { standardGeneric("upper") }) +#' @rdname functions +#' @export +setGeneric("weekofyear", function(x) { standardGeneric("weekofyear") }) + +#' @rdname functions +#' @export +setGeneric("year", function(x) { standardGeneric("year") }) + + #' @rdname glm #' @export setGeneric("glm") diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index 7377fc8f1ca9..e6d3b21ff825 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -640,15 +640,18 @@ test_that("column operators", { test_that("column functions", { c <- SparkR:::col("a") - c2 <- min(c) + max(c) + sum(c) + avg(c) + count(c) + abs(c) + sqrt(c) - c3 <- lower(c) + upper(c) + first(c) + last(c) - c4 <- approxCountDistinct(c) + countDistinct(c) + cast(c, "string") - c5 <- n(c) + n_distinct(c) - c5 <- acos(c) + asin(c) + atan(c) + cbrt(c) - c6 <- ceiling(c) + cos(c) + cosh(c) + exp(c) + expm1(c) - c7 <- floor(c) + log(c) + log10(c) + log1p(c) + rint(c) - c8 <- sign(c) + sin(c) + sinh(c) + tan(c) + tanh(c) - c9 <- toDegrees(c) + toRadians(c) + c1 <- abs(c) + acos(c) + approxCountDistinct(c) + ascii(c) + asin(c) + atan(c) + c2 <- avg(c) + base64(c) + bin(c) + bitwiseNOT(c) + cbrt(c) + ceil(c) + cos(c) + c3 <- cosh(c) + count(c) + crc32(c) + dayofmonth(c) + dayofyear(c) + exp(c) + c4 <- explode(c) + expm1(c) + factorial(c) + first(c) + floor(c) + hex(c) + c5 <- hour(c) + initcap(c) + isNaN(c) + last(c) + last_day(c) + length(c) + c6 <- log(c) + (c) + log1p(c) + log2(c) + lower(c) + ltrim(c) + max(c) + md5(c) + c7 <- mean(c) + min(c) + minute(c) + month(c) + negate(c) + quarter(c) + c8 <- reverse(c) + rint(c) + round(c) + rtrim(c) + second(c) + sha1(c) + c9 <- signum(c) + sin(c) + sinh(c) + size(c) + soundex(c) + sqrt(c) + sum(c) + c10 <- sumDistinct(c) + tan(c) + tanh(c) + 
toDegrees(c) + toRadians(c) + c11 <- to_date(c) + trim(c) + unbase64(c) + unhex(c) + upper(c) + weekofyear(c) + c12 <- year(c) df <- jsonFile(sqlContext, jsonPath) df2 <- select(df, between(df$age, c(20, 30)), between(df$age, c(10, 20))) From 7b13ed27c1296cf76d0946e400f3449c335c8471 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 12 Aug 2015 18:52:11 -0700 Subject: [PATCH 0212/1824] [SPARK-9870] Disable driver UI and Master REST server in SparkSubmitSuite I think that we should pass additional configuration flags to disable the driver UI and Master REST server in SparkSubmitSuite and HiveSparkSubmitSuite. This might cut down on port-contention-related flakiness in Jenkins. Author: Josh Rosen Closes #8124 from JoshRosen/disable-ui-in-sparksubmitsuite. --- .../org/apache/spark/deploy/SparkSubmitSuite.scala | 7 +++++++ .../apache/spark/sql/hive/HiveSparkSubmitSuite.scala | 10 +++++++++- 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 2456c5d0d49b..1110ca6051a4 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -324,6 +324,8 @@ class SparkSubmitSuite "--class", SimpleApplicationTest.getClass.getName.stripSuffix("$"), "--name", "testApp", "--master", "local", + "--conf", "spark.ui.enabled=false", + "--conf", "spark.master.rest.enabled=false", unusedJar.toString) runSparkSubmit(args) } @@ -337,6 +339,8 @@ class SparkSubmitSuite "--class", JarCreationTest.getClass.getName.stripSuffix("$"), "--name", "testApp", "--master", "local-cluster[2,1,1024]", + "--conf", "spark.ui.enabled=false", + "--conf", "spark.master.rest.enabled=false", "--jars", jarsString, unusedJar.toString, "SparkSubmitClassA", "SparkSubmitClassB") runSparkSubmit(args) @@ -355,6 +359,7 @@ class SparkSubmitSuite "--packages", Seq(main, dep).mkString(","), "--repositories", repo, "--conf", "spark.ui.enabled=false", + "--conf", "spark.master.rest.enabled=false", unusedJar.toString, "my.great.lib.MyLib", "my.great.dep.MyLib") runSparkSubmit(args) @@ -500,6 +505,8 @@ class SparkSubmitSuite "--master", "local", "--conf", "spark.driver.extraClassPath=" + systemJar, "--conf", "spark.driver.userClassPathFirst=true", + "--conf", "spark.ui.enabled=false", + "--conf", "spark.master.rest.enabled=false", userJar.toString) runSparkSubmit(args) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala index b8d41065d3f0..1e1972d1ac35 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala @@ -57,6 +57,8 @@ class HiveSparkSubmitSuite "--class", SparkSubmitClassLoaderTest.getClass.getName.stripSuffix("$"), "--name", "SparkSubmitClassLoaderTest", "--master", "local-cluster[2,1,1024]", + "--conf", "spark.ui.enabled=false", + "--conf", "spark.master.rest.enabled=false", "--jars", jarsString, unusedJar.toString, "SparkSubmitClassA", "SparkSubmitClassB") runSparkSubmit(args) @@ -68,6 +70,8 @@ class HiveSparkSubmitSuite "--class", SparkSQLConfTest.getClass.getName.stripSuffix("$"), "--name", "SparkSQLConfTest", "--master", "local-cluster[2,1,1024]", + "--conf", "spark.ui.enabled=false", + "--conf", "spark.master.rest.enabled=false", unusedJar.toString) 
runSparkSubmit(args) } @@ -79,7 +83,11 @@ class HiveSparkSubmitSuite // the HiveContext code mistakenly overrides the class loader that contains user classes. // For more detail, see sql/hive/src/test/resources/regression-test-SPARK-8489/*scala. val testJar = "sql/hive/src/test/resources/regression-test-SPARK-8489/test.jar" - val args = Seq("--class", "Main", testJar) + val args = Seq( + "--conf", "spark.ui.enabled=false", + "--conf", "spark.master.rest.enabled=false", + "--class", "Main", + testJar) runSparkSubmit(args) } From 7c35746c916cf0019367850e75a080d7e739dba0 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 12 Aug 2015 20:02:55 -0700 Subject: [PATCH 0213/1824] [SPARK-9827] [SQL] fix fd leak in UnsafeRowSerializer Currently, UnsafeRowSerializer does not close the InputStream, which causes an fd leak if the InputStream has an open fd in it. TODO: the fd could still be leaked if any items in the stream are not consumed; currently it relies on GC to close the fd in this case. cc JoshRosen Author: Davies Liu Closes #8116 from davies/fd_leak. 
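The pattern behind the fix, sketched in simplified form below: close the underlying stream as soon as the EOF marker is read rather than waiting for GC. This is only an illustration, not the actual UnsafeRowSerializer code; the record layout is assumed to be length-prefixed byte arrays with -1 as EOF.

```
import java.io.DataInputStream

// Simplified sketch of closing on the EOF marker instead of relying on GC.
def records(dIn: DataInputStream): Iterator[Array[Byte]] = new Iterator[Array[Byte]] {
  private var nextSize: Int = dIn.readInt()
  if (nextSize == -1) dIn.close()          // empty stream: release the fd immediately

  override def hasNext: Boolean = nextSize != -1

  override def next(): Array[Byte] = {
    val buf = new Array[Byte](nextSize)
    dIn.readFully(buf)
    nextSize = dIn.readInt()               // read the next record's size
    if (nextSize == -1) dIn.close()        // last record consumed: close eagerly
    buf
  }
}
```

As the commit message notes, the descriptor can still leak if the iterator is never fully drained.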
class UnsafeRowSerializerSuite extends SparkFunSuite { private def toUnsafeRow(row: Row, schema: Array[DataType]): UnsafeRow = { @@ -52,8 +64,8 @@ class UnsafeRowSerializerSuite extends SparkFunSuite { serializerStream.writeValue(unsafeRow) } serializerStream.close() - val deserializerIter = serializer.deserializeStream( - new ByteArrayInputStream(baos.toByteArray)).asKeyValueIterator + val input = new ClosableByteArrayInputStream(baos.toByteArray) + val deserializerIter = serializer.deserializeStream(input).asKeyValueIterator for (expectedRow <- unsafeRows) { val actualRow = deserializerIter.next().asInstanceOf[(Integer, UnsafeRow)]._2 assert(expectedRow.getSizeInBytes === actualRow.getSizeInBytes) @@ -61,5 +73,18 @@ class UnsafeRowSerializerSuite extends SparkFunSuite { assert(expectedRow.getInt(1) === actualRow.getInt(1)) } assert(!deserializerIter.hasNext) + assert(input.closed) + } + + test("close empty input stream") { + val baos = new ByteArrayOutputStream() + val dout = new DataOutputStream(baos) + dout.writeInt(-1) // EOF + dout.flush() + val input = new ClosableByteArrayInputStream(baos.toByteArray) + val serializer = new UnsafeRowSerializer(numFields = 2).newInstance() + val deserializerIter = serializer.deserializeStream(input).asKeyValueIterator + assert(!deserializerIter.hasNext) + assert(input.closed) } } From 4413d0855aaba5cb00f737dc6934a0b92d9bc05d Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Wed, 12 Aug 2015 20:03:55 -0700 Subject: [PATCH 0214/1824] [SPARK-9908] [SQL] When spark.sql.tungsten.enabled is false, broadcast join does not work https://issues.apache.org/jira/browse/SPARK-9908 Author: Yin Huai Closes #8149 from yhuai/SPARK-9908. --- .../apache/spark/sql/execution/joins/HashedRelation.scala | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index 076afe6e4e96..bb333b4d5ed1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -66,7 +66,8 @@ private[joins] final class GeneralHashedRelation( private var hashTable: JavaHashMap[InternalRow, CompactBuffer[InternalRow]]) extends HashedRelation with Externalizable { - private def this() = this(null) // Needed for serialization + // Needed for serialization (it is public to make Java serialization work) + def this() = this(null) override def get(key: InternalRow): Seq[InternalRow] = hashTable.get(key) @@ -88,7 +89,8 @@ private[joins] final class UniqueKeyHashedRelation(private var hashTable: JavaHashMap[InternalRow, InternalRow]) extends HashedRelation with Externalizable { - private def this() = this(null) // Needed for serialization + // Needed for serialization (it is public to make Java serialization work) + def this() = this(null) override def get(key: InternalRow): Seq[InternalRow] = { val v = hashTable.get(key) From d2d5e7fe2df582e1c866334b3014d7cb351f5b70 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Wed, 12 Aug 2015 20:43:36 -0700 Subject: [PATCH 0215/1824] [SPARK-9704] [ML] Made ProbabilisticClassifier, Identifiable, VectorUDT public APIs Made ProbabilisticClassifier, Identifiable, VectorUDT public. All are annotated as DeveloperApi. CC: mengxr EronWright Author: Joseph K. Bradley Closes #8004 from jkbradley/ml-api-public-items and squashes the following commits: 7ebefda [Joseph K. 
Bradley] update per code review 7ff0768 [Joseph K. Bradley] attepting to add mima fix 756d84c [Joseph K. Bradley] VectorUDT annotated as AlphaComponent ae7767d [Joseph K. Bradley] added another warning 94fd553 [Joseph K. Bradley] Made ProbabilisticClassifier, Identifiable, VectorUDT public APIs --- .../classification/ProbabilisticClassifier.scala | 4 ++-- .../org/apache/spark/ml/util/Identifiable.scala | 16 ++++++++++++++-- .../org/apache/spark/mllib/linalg/Vectors.scala | 10 ++++------ project/MimaExcludes.scala | 4 ++++ 4 files changed, 24 insertions(+), 10 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala index 1e50a895a9a0..fdd1851ae550 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala @@ -50,7 +50,7 @@ private[classification] trait ProbabilisticClassifierParams * @tparam M Concrete Model type */ @DeveloperApi -private[spark] abstract class ProbabilisticClassifier[ +abstract class ProbabilisticClassifier[ FeaturesType, E <: ProbabilisticClassifier[FeaturesType, E, M], M <: ProbabilisticClassificationModel[FeaturesType, M]] @@ -74,7 +74,7 @@ private[spark] abstract class ProbabilisticClassifier[ * @tparam M Concrete Model type */ @DeveloperApi -private[spark] abstract class ProbabilisticClassificationModel[ +abstract class ProbabilisticClassificationModel[ FeaturesType, M <: ProbabilisticClassificationModel[FeaturesType, M]] extends ClassificationModel[FeaturesType, M] with ProbabilisticClassifierParams { diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/Identifiable.scala b/mllib/src/main/scala/org/apache/spark/ml/util/Identifiable.scala index ddd34a54503a..bd213e7362e9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/Identifiable.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/Identifiable.scala @@ -19,11 +19,19 @@ package org.apache.spark.ml.util import java.util.UUID +import org.apache.spark.annotation.DeveloperApi + /** + * :: DeveloperApi :: + * * Trait for an object with an immutable unique ID that identifies itself and its derivatives. + * + * WARNING: There have not yet been final discussions on this API, so it may be broken in future + * releases. */ -private[spark] trait Identifiable { +@DeveloperApi +trait Identifiable { /** * An immutable unique ID for the object and its derivatives. @@ -33,7 +41,11 @@ private[spark] trait Identifiable { override def toString: String = uid } -private[spark] object Identifiable { +/** + * :: DeveloperApi :: + */ +@DeveloperApi +object Identifiable { /** * Returns a random UID that concatenates the given prefix, "_", and 12 random hex chars. 
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index 86c461fa9163..df15d985c814 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -26,7 +26,7 @@ import scala.collection.JavaConverters._ import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, Vector => BV} import org.apache.spark.SparkException -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.AlphaComponent import org.apache.spark.mllib.util.NumericParser import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.GenericMutableRow @@ -159,15 +159,13 @@ sealed trait Vector extends Serializable { } /** - * :: DeveloperApi :: + * :: AlphaComponent :: * * User-defined type for [[Vector]] which allows easy interaction with SQL * via [[org.apache.spark.sql.DataFrame]]. - * - * NOTE: This is currently private[spark] but will be made public later once it is stabilized. */ -@DeveloperApi -private[spark] class VectorUDT extends UserDefinedType[Vector] { +@AlphaComponent +class VectorUDT extends UserDefinedType[Vector] { override def sqlType: StructType = { // type: 0 = sparse, 1 = dense diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 90261ca3d61a..784f83c10e02 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -178,6 +178,10 @@ object MimaExcludes { // SPARK-4751 Dynamic allocation for standalone mode ProblemFilters.exclude[MissingMethodProblem]( "org.apache.spark.SparkContext.supportDynamicAllocation") + ) ++ Seq( + // SPARK-9704 Made ProbabilisticClassifier, Identifiable, VectorUDT public APIs + ProblemFilters.exclude[IncompatibleResultTypeProblem]( + "org.apache.spark.mllib.linalg.VectorUDT.serialize") ) case v if v.startsWith("1.4") => From d7053bea985679c514b3add029631ea23e1730ce Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Wed, 12 Aug 2015 20:44:40 -0700 Subject: [PATCH 0216/1824] [SPARK-9903] [MLLIB] skip local processing in PrefixSpan if there are no small prefixes There exists a chance that the prefixes keep growing to the maximum pattern length. Then the final local processing step becomes unnecessary. feynmanliang Author: Xiangrui Meng Closes #8136 from mengxr/SPARK-9903. --- .../apache/spark/mllib/fpm/PrefixSpan.scala | 37 +++++++++++-------- 1 file changed, 21 insertions(+), 16 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala index ad6715b52f33..dc4ae1d0b69e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala @@ -282,25 +282,30 @@ object PrefixSpan extends Logging { largePrefixes = newLargePrefixes } - // Switch to local processing. - val bcSmallPrefixes = sc.broadcast(smallPrefixes) - val distributedFreqPattern = postfixes.flatMap { postfix => - bcSmallPrefixes.value.values.map { prefix => - (prefix.id, postfix.project(prefix).compressed) - }.filter(_._2.nonEmpty) - }.groupByKey().flatMap { case (id, projPostfixes) => - val prefix = bcSmallPrefixes.value(id) - val localPrefixSpan = new LocalPrefixSpan(minCount, maxPatternLength - prefix.length) - // TODO: We collect projected postfixes into memory. We should also compare the performance - // TODO: of keeping them on shuffle files. 
- localPrefixSpan.run(projPostfixes.toArray).map { case (pattern, count) => - (prefix.items ++ pattern, count) + var freqPatterns = sc.parallelize(localFreqPatterns, 1) + + val numSmallPrefixes = smallPrefixes.size + logInfo(s"number of small prefixes for local processing: $numSmallPrefixes") + if (numSmallPrefixes > 0) { + // Switch to local processing. + val bcSmallPrefixes = sc.broadcast(smallPrefixes) + val distributedFreqPattern = postfixes.flatMap { postfix => + bcSmallPrefixes.value.values.map { prefix => + (prefix.id, postfix.project(prefix).compressed) + }.filter(_._2.nonEmpty) + }.groupByKey().flatMap { case (id, projPostfixes) => + val prefix = bcSmallPrefixes.value(id) + val localPrefixSpan = new LocalPrefixSpan(minCount, maxPatternLength - prefix.length) + // TODO: We collect projected postfixes into memory. We should also compare the performance + // TODO: of keeping them on shuffle files. + localPrefixSpan.run(projPostfixes.toArray).map { case (pattern, count) => + (prefix.items ++ pattern, count) + } } + // Union local frequent patterns and distributed ones. + freqPatterns = freqPatterns ++ distributedFreqPattern } - // Union local frequent patterns and distributed ones. - val freqPatterns = (sc.parallelize(localFreqPatterns, 1) ++ distributedFreqPattern) - .persist(StorageLevel.MEMORY_AND_DISK) freqPatterns } From 2fb4901b71cee65d40a43e61e3f4411c30cdefc3 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Wed, 12 Aug 2015 20:59:38 -0700 Subject: [PATCH 0217/1824] [SPARK-9916] [BUILD] [SPARKR] removed left-over sparkr.zip copy/create commands from codebase sparkr.zip is now built by SparkSubmit on a need-to-build basis. cc shivaram Author: Burak Yavuz Closes #8147 from brkyvz/make-dist-fix. --- R/install-dev.bat | 5 ----- make-distribution.sh | 1 - 2 files changed, 6 deletions(-) diff --git a/R/install-dev.bat b/R/install-dev.bat index f32670b67de9..008a5c668bc4 100644 --- a/R/install-dev.bat +++ b/R/install-dev.bat @@ -25,8 +25,3 @@ set SPARK_HOME=%~dp0.. 
MKDIR %SPARK_HOME%\R\lib R.exe CMD INSTALL --library="%SPARK_HOME%\R\lib" %SPARK_HOME%\R\pkg\ - -rem Zip the SparkR package so that it can be distributed to worker nodes on YARN -pushd %SPARK_HOME%\R\lib -%JAVA_HOME%\bin\jar.exe cfM "%SPARK_HOME%\R\lib\sparkr.zip" SparkR -popd diff --git a/make-distribution.sh b/make-distribution.sh index 4789b0e09cc8..247a81341e4a 100755 --- a/make-distribution.sh +++ b/make-distribution.sh @@ -219,7 +219,6 @@ cp -r "$SPARK_HOME/ec2" "$DISTDIR" if [ -d "$SPARK_HOME"/R/lib/SparkR ]; then mkdir -p "$DISTDIR"/R/lib cp -r "$SPARK_HOME/R/lib/SparkR" "$DISTDIR"/R/lib - cp "$SPARK_HOME/R/lib/sparkr.zip" "$DISTDIR"/R/lib fi # Download and copy in tachyon, if requested From 2278219054314f1d31ffc358a59aa5067f9f5de9 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Wed, 12 Aug 2015 21:24:15 -0700 Subject: [PATCH 0218/1824] [SPARK-9920] [SQL] The simpleString of TungstenAggregate does not show its output https://issues.apache.org/jira/browse/SPARK-9920 Taking `sqlContext.sql("select i, sum(j1) as sum from testAgg group by i").explain()` as an example, the output of our current master is ``` == Physical Plan == TungstenAggregate(key=[i#0], value=[(sum(cast(j1#1 as bigint)),mode=Final,isDistinct=false)] TungstenExchange hashpartitioning(i#0) TungstenAggregate(key=[i#0], value=[(sum(cast(j1#1 as bigint)),mode=Partial,isDistinct=false)] Scan ParquetRelation[file:/user/hive/warehouse/testagg][i#0,j1#1] ``` With this PR, the output will be ``` == Physical Plan == TungstenAggregate(key=[i#0], functions=[(sum(cast(j1#1 as bigint)),mode=Final,isDistinct=false)], output=[i#0,sum#18L]) TungstenExchange hashpartitioning(i#0) TungstenAggregate(key=[i#0], functions=[(sum(cast(j1#1 as bigint)),mode=Partial,isDistinct=false)], output=[i#0,currentSum#22L]) Scan ParquetRelation[file:/user/hive/warehouse/testagg][i#0,j1#1] ``` Author: Yin Huai Closes #8150 from yhuai/SPARK-9920. 
--- .../spark/sql/execution/aggregate/SortBasedAggregate.scala | 6 +++++- .../spark/sql/execution/aggregate/TungstenAggregate.scala | 7 ++++--- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala index ab26f9c58aa2..f4c14a9b3556 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala @@ -108,6 +108,10 @@ case class SortBasedAggregate( override def simpleString: String = { val allAggregateExpressions = nonCompleteAggregateExpressions ++ completeAggregateExpressions - s"""SortBasedAggregate ${groupingExpressions} ${allAggregateExpressions}""" + + val keyString = groupingExpressions.mkString("[", ",", "]") + val functionString = allAggregateExpressions.mkString("[", ",", "]") + val outputString = output.mkString("[", ",", "]") + s"SortBasedAggregate(key=$keyString, functions=$functionString, output=$outputString)" } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala index c40ca973796a..99f51ba5b693 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala @@ -127,11 +127,12 @@ case class TungstenAggregate( testFallbackStartsAt match { case None => val keyString = groupingExpressions.mkString("[", ",", "]") - val valueString = allAggregateExpressions.mkString("[", ",", "]") - s"TungstenAggregate(key=$keyString, value=$valueString" + val functionString = allAggregateExpressions.mkString("[", ",", "]") + val outputString = output.mkString("[", ",", "]") + s"TungstenAggregate(key=$keyString, functions=$functionString, output=$outputString)" case Some(fallbackStartsAt) => s"TungstenAggregateWithControlledFallback $groupingExpressions " + - s"$allAggregateExpressions fallbackStartsAt=$fallbackStartsAt" + s"$allAggregateExpressions $resultExpressions fallbackStartsAt=$fallbackStartsAt" } } } From a8ab2634c1eee143a4deaf309204df8add727f9e Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 12 Aug 2015 21:26:00 -0700 Subject: [PATCH 0219/1824] [SPARK-9832] [SQL] add a thread-safe lookup for BytesToBytseMap This patch add a thread-safe lookup for BytesToBytseMap, and use that in broadcasted HashedRelation. Author: Davies Liu Closes #8151 from davies/safeLookup. 
--- .../spark/unsafe/map/BytesToBytesMap.java | 30 ++++++++++++++----- .../sql/execution/joins/HashedRelation.scala | 6 ++-- 2 files changed, 26 insertions(+), 10 deletions(-) diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java index 87ed47e88c4e..5f3a4fcf4d58 100644 --- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java +++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java @@ -17,25 +17,24 @@ package org.apache.spark.unsafe.map; -import java.lang.Override; -import java.lang.UnsupportedOperationException; +import javax.annotation.Nullable; import java.util.Iterator; import java.util.LinkedList; import java.util.List; -import javax.annotation.Nullable; - import com.google.common.annotations.VisibleForTesting; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.apache.spark.shuffle.ShuffleMemoryManager; -import org.apache.spark.unsafe.*; +import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.array.LongArray; import org.apache.spark.unsafe.bitset.BitSet; import org.apache.spark.unsafe.hash.Murmur3_x86_32; -import org.apache.spark.unsafe.memory.*; +import org.apache.spark.unsafe.memory.MemoryBlock; +import org.apache.spark.unsafe.memory.MemoryLocation; +import org.apache.spark.unsafe.memory.TaskMemoryManager; /** * An append-only hash map where keys and values are contiguous regions of bytes. @@ -328,6 +327,20 @@ public Location lookup( Object keyBaseObject, long keyBaseOffset, int keyRowLengthBytes) { + safeLookup(keyBaseObject, keyBaseOffset, keyRowLengthBytes, loc); + return loc; + } + + /** + * Looks up a key, and saves the result in provided `loc`. + * + * This is a thread-safe version of `lookup`, could be used by multiple threads. + */ + public void safeLookup( + Object keyBaseObject, + long keyBaseOffset, + int keyRowLengthBytes, + Location loc) { assert(bitset != null); assert(longArray != null); @@ -343,7 +356,8 @@ public Location lookup( } if (!bitset.isSet(pos)) { // This is a new key. 
- return loc.with(pos, hashcode, false); + loc.with(pos, hashcode, false); + return; } else { long stored = longArray.get(pos * 2 + 1); if ((int) (stored) == hashcode) { @@ -361,7 +375,7 @@ public Location lookup( keyRowLengthBytes ); if (areEqual) { - return loc; + return; } else { if (enablePerfMetrics) { numHashCollisions++; diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index bb333b4d5ed1..ea02076b41a6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -215,8 +215,10 @@ private[joins] final class UnsafeHashedRelation( if (binaryMap != null) { // Used in Broadcast join - val loc = binaryMap.lookup(unsafeKey.getBaseObject, unsafeKey.getBaseOffset, - unsafeKey.getSizeInBytes) + val map = binaryMap // avoid the compiler error + val loc = new map.Location // this could be allocated in stack + binaryMap.safeLookup(unsafeKey.getBaseObject, unsafeKey.getBaseOffset, + unsafeKey.getSizeInBytes, loc) if (loc.isDefined) { val buffer = CompactBuffer[UnsafeRow]() From 5fc058a1fc5d83ad53feec936475484aef3800b3 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Wed, 12 Aug 2015 21:33:38 -0700 Subject: [PATCH 0220/1824] [SPARK-9917] [ML] add getMin/getMax and doc for originalMin/origianlMax in MinMaxScaler hhbyyh Author: Xiangrui Meng Closes #8145 from mengxr/SPARK-9917. --- .../org/apache/spark/ml/feature/MinMaxScaler.scala | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala index b30adf3df48d..9a473dd23772 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala @@ -41,6 +41,9 @@ private[feature] trait MinMaxScalerParams extends Params with HasInputCol with H val min: DoubleParam = new DoubleParam(this, "min", "lower bound of the output feature range") + /** @group getParam */ + def getMin: Double = $(min) + /** * upper bound after transformation, shared by all features * Default: 1.0 @@ -49,6 +52,9 @@ private[feature] trait MinMaxScalerParams extends Params with HasInputCol with H val max: DoubleParam = new DoubleParam(this, "max", "upper bound of the output feature range") + /** @group getParam */ + def getMax: Double = $(max) + /** Validates and transforms the input schema. */ protected def validateAndTransformSchema(schema: StructType): StructType = { val inputType = schema($(inputCol)).dataType @@ -115,6 +121,9 @@ class MinMaxScaler(override val uid: String) * :: Experimental :: * Model fitted by [[MinMaxScaler]]. * + * @param originalMin min value for each original column during fitting + * @param originalMax max value for each original column during fitting + * * TODO: The transformer does not yet set the metadata in the output column (SPARK-8529). 
*/ @Experimental @@ -136,7 +145,6 @@ class MinMaxScalerModel private[ml] ( /** @group setParam */ def setMax(value: Double): this.type = set(max, value) - override def transform(dataset: DataFrame): DataFrame = { val originalRange = (originalMax.toBreeze - originalMin.toBreeze).toArray val minArray = originalMin.toArray From df543892122342b97e5137b266959ba97589b3ef Mon Sep 17 00:00:00 2001 From: "shikai.tang" Date: Wed, 12 Aug 2015 21:53:15 -0700 Subject: [PATCH 0221/1824] [SPARK-8922] [DOCUMENTATION, MLLIB] Add @since tags to mllib.evaluation Author: shikai.tang Closes #7429 from mosessky/master. --- .../BinaryClassificationMetrics.scala | 32 ++++++++++++++++--- .../mllib/evaluation/MulticlassMetrics.scala | 9 ++++++ .../mllib/evaluation/MultilabelMetrics.scala | 4 +++ .../mllib/evaluation/RankingMetrics.scala | 4 +++ .../mllib/evaluation/RegressionMetrics.scala | 6 ++++ 5 files changed, 50 insertions(+), 5 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala index c1d1a224817e..486741edd6f5 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala @@ -41,6 +41,7 @@ import org.apache.spark.sql.DataFrame * of bins may not exactly equal numBins. The last bin in each partition may * be smaller as a result, meaning there may be an extra sample at * partition boundaries. + * @since 1.3.0 */ @Experimental class BinaryClassificationMetrics( @@ -51,6 +52,7 @@ class BinaryClassificationMetrics( /** * Defaults `numBins` to 0. + * @since 1.0.0 */ def this(scoreAndLabels: RDD[(Double, Double)]) = this(scoreAndLabels, 0) @@ -61,12 +63,18 @@ class BinaryClassificationMetrics( private[mllib] def this(scoreAndLabels: DataFrame) = this(scoreAndLabels.map(r => (r.getDouble(0), r.getDouble(1)))) - /** Unpersist intermediate RDDs used in the computation. */ + /** + * Unpersist intermediate RDDs used in the computation. + * @since 1.0.0 + */ def unpersist() { cumulativeCounts.unpersist() } - /** Returns thresholds in descending order. */ + /** + * Returns thresholds in descending order. + * @since 1.0.0 + */ def thresholds(): RDD[Double] = cumulativeCounts.map(_._1) /** @@ -74,6 +82,7 @@ class BinaryClassificationMetrics( * which is an RDD of (false positive rate, true positive rate) * with (0.0, 0.0) prepended and (1.0, 1.0) appended to it. * @see http://en.wikipedia.org/wiki/Receiver_operating_characteristic + * @since 1.0.0 */ def roc(): RDD[(Double, Double)] = { val rocCurve = createCurve(FalsePositiveRate, Recall) @@ -85,6 +94,7 @@ class BinaryClassificationMetrics( /** * Computes the area under the receiver operating characteristic (ROC) curve. + * @since 1.0.0 */ def areaUnderROC(): Double = AreaUnderCurve.of(roc()) @@ -92,6 +102,7 @@ class BinaryClassificationMetrics( * Returns the precision-recall curve, which is an RDD of (recall, precision), * NOT (precision, recall), with (0.0, 1.0) prepended to it. * @see http://en.wikipedia.org/wiki/Precision_and_recall + * @since 1.0.0 */ def pr(): RDD[(Double, Double)] = { val prCurve = createCurve(Recall, Precision) @@ -102,6 +113,7 @@ class BinaryClassificationMetrics( /** * Computes the area under the precision-recall curve. 
+ * @since 1.0.0 */ def areaUnderPR(): Double = AreaUnderCurve.of(pr()) @@ -110,16 +122,26 @@ class BinaryClassificationMetrics( * @param beta the beta factor in F-Measure computation. * @return an RDD of (threshold, F-Measure) pairs. * @see http://en.wikipedia.org/wiki/F1_score + * @since 1.0.0 */ def fMeasureByThreshold(beta: Double): RDD[(Double, Double)] = createCurve(FMeasure(beta)) - /** Returns the (threshold, F-Measure) curve with beta = 1.0. */ + /** + * Returns the (threshold, F-Measure) curve with beta = 1.0. + * @since 1.0.0 + */ def fMeasureByThreshold(): RDD[(Double, Double)] = fMeasureByThreshold(1.0) - /** Returns the (threshold, precision) curve. */ + /** + * Returns the (threshold, precision) curve. + * @since 1.0.0 + */ def precisionByThreshold(): RDD[(Double, Double)] = createCurve(Precision) - /** Returns the (threshold, recall) curve. */ + /** + * Returns the (threshold, recall) curve. + * @since 1.0.0 + */ def recallByThreshold(): RDD[(Double, Double)] = createCurve(Recall) private lazy val ( diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala index 4628dc569091..dddfa3ea5b80 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala @@ -30,6 +30,7 @@ import org.apache.spark.sql.DataFrame * Evaluator for multiclass classification. * * @param predictionAndLabels an RDD of (prediction, label) pairs. + * @since 1.1.0 */ @Experimental class MulticlassMetrics(predictionAndLabels: RDD[(Double, Double)]) { @@ -64,6 +65,7 @@ class MulticlassMetrics(predictionAndLabels: RDD[(Double, Double)]) { * predicted classes are in columns, * they are ordered by class label ascending, * as in "labels" + * @since 1.1.0 */ def confusionMatrix: Matrix = { val n = labels.size @@ -83,12 +85,14 @@ class MulticlassMetrics(predictionAndLabels: RDD[(Double, Double)]) { /** * Returns true positive rate for a given label (category) * @param label the label. + * @since 1.1.0 */ def truePositiveRate(label: Double): Double = recall(label) /** * Returns false positive rate for a given label (category) * @param label the label. + * @since 1.1.0 */ def falsePositiveRate(label: Double): Double = { val fp = fpByClass.getOrElse(label, 0) @@ -98,6 +102,7 @@ class MulticlassMetrics(predictionAndLabels: RDD[(Double, Double)]) { /** * Returns precision for a given label (category) * @param label the label. + * @since 1.1.0 */ def precision(label: Double): Double = { val tp = tpByClass(label) @@ -108,6 +113,7 @@ class MulticlassMetrics(predictionAndLabels: RDD[(Double, Double)]) { /** * Returns recall for a given label (category) * @param label the label. + * @since 1.1.0 */ def recall(label: Double): Double = tpByClass(label).toDouble / labelCountByClass(label) @@ -115,6 +121,7 @@ class MulticlassMetrics(predictionAndLabels: RDD[(Double, Double)]) { * Returns f-measure for a given label (category) * @param label the label. * @param beta the beta parameter. + * @since 1.1.0 */ def fMeasure(label: Double, beta: Double): Double = { val p = precision(label) @@ -126,6 +133,7 @@ class MulticlassMetrics(predictionAndLabels: RDD[(Double, Double)]) { /** * Returns f1-measure for a given label (category) * @param label the label. 
+ * @since 1.1.0 */ def fMeasure(label: Double): Double = fMeasure(label, 1.0) @@ -179,6 +187,7 @@ class MulticlassMetrics(predictionAndLabels: RDD[(Double, Double)]) { /** * Returns weighted averaged f-measure * @param beta the beta parameter. + * @since 1.1.0 */ def weightedFMeasure(beta: Double): Double = labelCountByClass.map { case (category, count) => fMeasure(category, beta) * count.toDouble / labelCount diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala index bf6eb1d5bd2a..77cb1e09bdbb 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.DataFrame * Evaluator for multilabel classification. * @param predictionAndLabels an RDD of (predictions, labels) pairs, * both are non-null Arrays, each with unique elements. + * @since 1.2.0 */ class MultilabelMetrics(predictionAndLabels: RDD[(Array[Double], Array[Double])]) { @@ -103,6 +104,7 @@ class MultilabelMetrics(predictionAndLabels: RDD[(Array[Double], Array[Double])] /** * Returns precision for a given label (category) * @param label the label. + * @since 1.2.0 */ def precision(label: Double): Double = { val tp = tpPerClass(label) @@ -113,6 +115,7 @@ class MultilabelMetrics(predictionAndLabels: RDD[(Array[Double], Array[Double])] /** * Returns recall for a given label (category) * @param label the label. + * @since 1.2.0 */ def recall(label: Double): Double = { val tp = tpPerClass(label) @@ -123,6 +126,7 @@ class MultilabelMetrics(predictionAndLabels: RDD[(Array[Double], Array[Double])] /** * Returns f1-measure for a given label (category) * @param label the label. + * @since 1.2.0 */ def f1Measure(label: Double): Double = { val p = precision(label) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala index 5b5a2a1450f7..063fbed8cdee 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala @@ -34,6 +34,7 @@ import org.apache.spark.rdd.RDD * Java users should use [[RankingMetrics$.of]] to create a [[RankingMetrics]] instance. * * @param predictionAndLabels an RDD of (predicted ranking, ground truth set) pairs. + * @since 1.2.0 */ @Experimental class RankingMetrics[T: ClassTag](predictionAndLabels: RDD[(Array[T], Array[T])]) @@ -55,6 +56,7 @@ class RankingMetrics[T: ClassTag](predictionAndLabels: RDD[(Array[T], Array[T])] * * @param k the position to compute the truncated precision, must be positive * @return the average precision at the first k ranking positions + * @since 1.2.0 */ def precisionAt(k: Int): Double = { require(k > 0, "ranking position k should be positive") @@ -124,6 +126,7 @@ class RankingMetrics[T: ClassTag](predictionAndLabels: RDD[(Array[T], Array[T])] * * @param k the position to compute the truncated ndcg, must be positive * @return the average ndcg at the first k ranking positions + * @since 1.2.0 */ def ndcgAt(k: Int): Double = { require(k > 0, "ranking position k should be positive") @@ -162,6 +165,7 @@ object RankingMetrics { /** * Creates a [[RankingMetrics]] instance (for Java users). 
* @param predictionAndLabels a JavaRDD of (predicted ranking, ground truth set) pairs + * @since 1.4.0 */ def of[E, T <: jl.Iterable[E]](predictionAndLabels: JavaRDD[(T, T)]): RankingMetrics[E] = { implicit val tag = JavaSparkContext.fakeClassTag[E] diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala index 408847afa800..54dfd8c09949 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.DataFrame * Evaluator for regression. * * @param predictionAndObservations an RDD of (prediction, observation) pairs. + * @since 1.2.0 */ @Experimental class RegressionMetrics(predictionAndObservations: RDD[(Double, Double)]) extends Logging { @@ -66,6 +67,7 @@ class RegressionMetrics(predictionAndObservations: RDD[(Double, Double)]) extend * Returns the variance explained by regression. * explainedVariance = \sum_i (\hat{y_i} - \bar{y})^2 / n * @see [[https://en.wikipedia.org/wiki/Fraction_of_variance_unexplained]] + * @since 1.2.0 */ def explainedVariance: Double = { SSreg / summary.count @@ -74,6 +76,7 @@ class RegressionMetrics(predictionAndObservations: RDD[(Double, Double)]) extend /** * Returns the mean absolute error, which is a risk function corresponding to the * expected value of the absolute error loss or l1-norm loss. + * @since 1.2.0 */ def meanAbsoluteError: Double = { summary.normL1(1) / summary.count @@ -82,6 +85,7 @@ class RegressionMetrics(predictionAndObservations: RDD[(Double, Double)]) extend /** * Returns the mean squared error, which is a risk function corresponding to the * expected value of the squared error loss or quadratic loss. + * @since 1.2.0 */ def meanSquaredError: Double = { SSerr / summary.count @@ -90,6 +94,7 @@ class RegressionMetrics(predictionAndObservations: RDD[(Double, Double)]) extend /** * Returns the root mean squared error, which is defined as the square root of * the mean squared error. + * @since 1.2.0 */ def rootMeanSquaredError: Double = { math.sqrt(this.meanSquaredError) @@ -98,6 +103,7 @@ class RegressionMetrics(predictionAndObservations: RDD[(Double, Double)]) extend /** * Returns R^2^, the unadjusted coefficient of determination. * @see [[http://en.wikipedia.org/wiki/Coefficient_of_determination]] + * @since 1.2.0 */ def r2: Double = { 1 - SSerr / SStot From d7eb371eb6369a34e58a09179efe058c4101de9e Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Wed, 12 Aug 2015 22:30:33 -0700 Subject: [PATCH 0222/1824] [SPARK-9914] [ML] define setters explicitly for Java and use setParam group in RFormula The problem with defining setters in the base class is that it doesn't return the correct type in Java. 
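To make the Java interop issue concrete, here is a minimal sketch with hypothetical names (not code from this patch): a `this.type` setter declared in a shared params trait erases to the trait's own type in the JVM signature, so Java callers see the trait, lose the concrete class, and cannot chain further calls. Defining the setters on the concrete class, as this patch does for RFormula, keeps the Java-visible return type correct.

    // Minimal sketch, hypothetical names; illustrates only the erasure problem.
    trait ScalerParamsSketch {
      // In Scala the return type is `this.type`, but the erased JVM signature is
      //   ScalerParamsSketch setInputCol(String)
      def setInputCol(value: String): this.type = this
    }
    class ScalerSketch extends ScalerParamsSketch {
      def setScale(value: Double): this.type = this
    }
    // From Scala, chaining still works: new ScalerSketch().setInputCol("a").setScale(2.0)
    // From Java, setInputCol returns ScalerParamsSketch, which has no setScale, so the
    // chained call does not compile. Declaring the setters on the concrete class avoids this.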
ericl Author: Xiangrui Meng Closes #8143 from mengxr/SPARK-9914 and squashes the following commits: d36c887 [Xiangrui Meng] remove setters from model a49021b [Xiangrui Meng] define setters explicitly for Java and use setParam group --- .../scala/org/apache/spark/ml/feature/RFormula.scala | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala index d5360c9217ea..a752dacd72d9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala @@ -33,11 +33,6 @@ import org.apache.spark.sql.types._ * Base trait for [[RFormula]] and [[RFormulaModel]]. */ private[feature] trait RFormulaBase extends HasFeaturesCol with HasLabelCol { - /** @group getParam */ - def setFeaturesCol(value: String): this.type = set(featuresCol, value) - - /** @group getParam */ - def setLabelCol(value: String): this.type = set(labelCol, value) protected def hasLabelCol(schema: StructType): Boolean = { schema.map(_.name).contains($(labelCol)) @@ -71,6 +66,12 @@ class RFormula(override val uid: String) extends Estimator[RFormulaModel] with R /** @group getParam */ def getFormula: String = $(formula) + /** @group setParam */ + def setFeaturesCol(value: String): this.type = set(featuresCol, value) + + /** @group setParam */ + def setLabelCol(value: String): this.type = set(labelCol, value) + /** Whether the formula specifies fitting an intercept. */ private[ml] def hasIntercept: Boolean = { require(isDefined(formula), "Formula must be defined first.") From d0b18919d16e6a2f19159516bd2767b60b595279 Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Thu, 13 Aug 2015 13:33:39 +0800 Subject: [PATCH 0223/1824] [SPARK-9927] [SQL] Revert 8049 since it's pushing wrong filter down I made a mistake in #8049 by casting literal value to attribute's data type, which would cause simply truncate the literal value and push a wrong filter down. JIRA: https://issues.apache.org/jira/browse/SPARK-9927 Author: Yijie Shen Closes #8157 from yjshen/rever8049. 
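To make the truncation problem concrete, here is an illustrative case (a sketch, not necessarily the exact query from the JIRA): when a predicate compares a cast of a column with a literal, rewriting the literal into the column's type can silently change its value, so the filter handed to the data source no longer means the same thing as the original predicate.

    // Illustrative sketch only. Original predicate: CAST(a AS DOUBLE) = 1000.5, with a: INT.
    // The reverted rule rewrote the literal into a's data type before pushing it down:
    val literal   = 1000.5
    val truncated = literal.toInt   // 1000 -- the fractional part is silently dropped
    // The source would then be asked for `a = 1000`, a predicate that no row of the
    // original query should satisfy (no integer equals 1000.5). After this revert, casts
    // around the attribute are simply not translated, and Spark evaluates such predicates
    // itself instead of pushing a wrong filter to the source.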
--- .../datasources/DataSourceStrategy.scala | 30 ++-------------- .../execution/datasources/jdbc/JDBCRDD.scala | 2 +- .../org/apache/spark/sql/jdbc/JDBCSuite.scala | 35 ------------------- 3 files changed, 3 insertions(+), 64 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 9eea2b038253..2a4c40db8bb6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -20,13 +20,13 @@ package org.apache.spark.sql.execution.datasources import org.apache.spark.{Logging, TaskContext} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.rdd.{MapPartitionsRDD, RDD, UnionRDD} -import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, expressions} +import org.apache.spark.sql.catalyst.{InternalRow, expressions} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.sources._ -import org.apache.spark.sql.types.{TimestampType, DateType, StringType, StructType} +import org.apache.spark.sql.types.{StringType, StructType} import org.apache.spark.sql.{SaveMode, Strategy, execution, sources, _} import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.{SerializableConfiguration, Utils} @@ -343,17 +343,11 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { * and convert them. */ protected[sql] def selectFilters(filters: Seq[Expression]) = { - import CatalystTypeConverters._ - def translate(predicate: Expression): Option[Filter] = predicate match { case expressions.EqualTo(a: Attribute, Literal(v, _)) => Some(sources.EqualTo(a.name, v)) case expressions.EqualTo(Literal(v, _), a: Attribute) => Some(sources.EqualTo(a.name, v)) - case expressions.EqualTo(Cast(a: Attribute, _), l: Literal) => - Some(sources.EqualTo(a.name, convertToScala(Cast(l, a.dataType).eval(), a.dataType))) - case expressions.EqualTo(l: Literal, Cast(a: Attribute, _)) => - Some(sources.EqualTo(a.name, convertToScala(Cast(l, a.dataType).eval(), a.dataType))) case expressions.EqualNullSafe(a: Attribute, Literal(v, _)) => Some(sources.EqualNullSafe(a.name, v)) @@ -364,41 +358,21 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { Some(sources.GreaterThan(a.name, v)) case expressions.GreaterThan(Literal(v, _), a: Attribute) => Some(sources.LessThan(a.name, v)) - case expressions.GreaterThan(Cast(a: Attribute, _), l: Literal) => - Some(sources.GreaterThan(a.name, convertToScala(Cast(l, a.dataType).eval(), a.dataType))) - case expressions.GreaterThan(l: Literal, Cast(a: Attribute, _)) => - Some(sources.LessThan(a.name, convertToScala(Cast(l, a.dataType).eval(), a.dataType))) case expressions.LessThan(a: Attribute, Literal(v, _)) => Some(sources.LessThan(a.name, v)) case expressions.LessThan(Literal(v, _), a: Attribute) => Some(sources.GreaterThan(a.name, v)) - case expressions.LessThan(Cast(a: Attribute, _), l: Literal) => - Some(sources.LessThan(a.name, convertToScala(Cast(l, a.dataType).eval(), a.dataType))) - case expressions.LessThan(l: Literal, Cast(a: Attribute, _)) => - Some(sources.GreaterThan(a.name, convertToScala(Cast(l, 
a.dataType).eval(), a.dataType))) case expressions.GreaterThanOrEqual(a: Attribute, Literal(v, _)) => Some(sources.GreaterThanOrEqual(a.name, v)) case expressions.GreaterThanOrEqual(Literal(v, _), a: Attribute) => Some(sources.LessThanOrEqual(a.name, v)) - case expressions.GreaterThanOrEqual(Cast(a: Attribute, _), l: Literal) => - Some(sources.GreaterThanOrEqual(a.name, - convertToScala(Cast(l, a.dataType).eval(), a.dataType))) - case expressions.GreaterThanOrEqual(l: Literal, Cast(a: Attribute, _)) => - Some(sources.LessThanOrEqual(a.name, - convertToScala(Cast(l, a.dataType).eval(), a.dataType))) case expressions.LessThanOrEqual(a: Attribute, Literal(v, _)) => Some(sources.LessThanOrEqual(a.name, v)) case expressions.LessThanOrEqual(Literal(v, _), a: Attribute) => Some(sources.GreaterThanOrEqual(a.name, v)) - case expressions.LessThanOrEqual(Cast(a: Attribute, _), l: Literal) => - Some(sources.LessThanOrEqual(a.name, - convertToScala(Cast(l, a.dataType).eval(), a.dataType))) - case expressions.LessThanOrEqual(l: Literal, Cast(a: Attribute, _)) => - Some(sources.GreaterThanOrEqual(a.name, - convertToScala(Cast(l, a.dataType).eval(), a.dataType))) case expressions.InSet(a: Attribute, set) => Some(sources.In(a.name, set.toArray)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala index 281943e23fcf..8eab6a0adccc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala @@ -284,7 +284,7 @@ private[sql] class JDBCRDD( /** * `filters`, but as a WHERE clause suitable for injection into a SQL query. */ - val filterWhereClause: String = { + private val filterWhereClause: String = { val filterStrings = filters map compileFilter filter (_ != null) if (filterStrings.size > 0) { val sb = new StringBuilder("WHERE ") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index b9cfae51e809..e4dcf4c75d20 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -25,8 +25,6 @@ import org.h2.jdbc.JdbcSQLException import org.scalatest.BeforeAndAfter import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.execution.datasources.jdbc.JDBCRDD -import org.apache.spark.sql.execution.PhysicalRDD import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -150,18 +148,6 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter { |OPTIONS (url '$url', dbtable 'TEST.FLTTYPES', user 'testUser', password 'testPass') """.stripMargin.replaceAll("\n", " ")) - conn.prepareStatement("create table test.decimals (a DECIMAL(7, 2), b DECIMAL(4, 0))"). 
- executeUpdate() - conn.prepareStatement("insert into test.decimals values (12345.67, 1234)").executeUpdate() - conn.prepareStatement("insert into test.decimals values (34567.89, 1428)").executeUpdate() - conn.commit() - sql( - s""" - |CREATE TEMPORARY TABLE decimals - |USING org.apache.spark.sql.jdbc - |OPTIONS (url '$url', dbtable 'TEST.DECIMALS', user 'testUser', password 'testPass') - """.stripMargin.replaceAll("\n", " ")) - conn.prepareStatement( s""" |create table test.nulltypes (a INT, b BOOLEAN, c TINYINT, d BINARY(20), e VARCHAR(20), @@ -458,25 +444,4 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter { assert(agg.getCatalystType(0, "", 1, null) === Some(LongType)) assert(agg.getCatalystType(1, "", 1, null) === Some(StringType)) } - - test("SPARK-9182: filters are not passed through to jdbc source") { - def checkPushedFilter(query: String, filterStr: String): Unit = { - val rddOpt = sql(query).queryExecution.executedPlan.collectFirst { - case PhysicalRDD(_, rdd: JDBCRDD, _) => rdd - } - assert(rddOpt.isDefined) - val pushedFilterStr = rddOpt.get.filterWhereClause - assert(pushedFilterStr.contains(filterStr), - s"Expected to push [$filterStr], actually we pushed [$pushedFilterStr]") - } - - checkPushedFilter("select * from foobar where NAME = 'fred'", "NAME = 'fred'") - checkPushedFilter("select * from inttypes where A > '15'", "A > 15") - checkPushedFilter("select * from inttypes where C <= 20", "C <= 20") - checkPushedFilter("select * from decimals where A > 1000", "A > 1000.00") - checkPushedFilter("select * from decimals where A > 1000 AND A < 2000", - "A > 1000.00 AND A < 2000.00") - checkPushedFilter("select * from decimals where A = 2000 AND B > 20", "A = 2000.00 AND B > 20") - checkPushedFilter("select * from timetypes where B > '1998-09-10'", "B > 1998-09-10") - } } From 68f99571492f67596b3656e9f076deeb96616f4a Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Wed, 12 Aug 2015 23:04:59 -0700 Subject: [PATCH 0224/1824] [SPARK-9918] [MLLIB] remove runs from k-means and rename epsilon to tol This requires some discussion. I'm not sure whether `runs` is a useful parameter. It certainly complicates the implementation. We might want to optimize the k-means implementation with block matrix operations. In this case, having `runs` may not be worth the trade-off. Also it increases the communication cost in a single job, which might cause other issues. This PR also renames `epsilon` to `tol` to have consistent naming among algorithms. The Python constructor is updated to include all parameters. 
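For reference, the user-facing change to the spark.ml API looks roughly like this (a sketch mirroring the test updates further down, not additional functionality):

    // Before: new KMeans().setK(2).setRuns(7).setEpsilon(1e-3)
    // After this patch the pipeline API drops `runs` and uses `tol` for the convergence
    // threshold, matching the naming used by other spark.ml algorithms:
    import org.apache.spark.ml.clustering.KMeans

    val kmeans = new KMeans()
      .setK(2)
      .setSeed(1L)
      .setTol(1e-3)
    // Internally the wrapper still forwards tol to the mllib implementation via
    // setEpsilon, and the algorithm is now always executed with a single run.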
jkbradley yu-iskw Author: Xiangrui Meng Closes #8148 from mengxr/SPARK-9918 and squashes the following commits: 149b9e5 [Xiangrui Meng] fix constructor in Python and rename epsilon to tol 3cc15b3 [Xiangrui Meng] fix test and change initStep to initSteps in python a0a0274 [Xiangrui Meng] remove runs from k-means in the pipeline API --- .../apache/spark/ml/clustering/KMeans.scala | 51 +++------------ .../spark/ml/clustering/KMeansSuite.scala | 12 +--- python/pyspark/ml/clustering.py | 63 ++++--------------- 3 files changed, 26 insertions(+), 100 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index dc192add6ca1..47a18cdb31b5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -18,8 +18,8 @@ package org.apache.spark.ml.clustering import org.apache.spark.annotation.Experimental -import org.apache.spark.ml.param.{Param, Params, IntParam, DoubleParam, ParamMap} -import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasMaxIter, HasPredictionCol, HasSeed} +import org.apache.spark.ml.param.{Param, Params, IntParam, ParamMap} +import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util.{Identifiable, SchemaUtils} import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans, KMeansModel => MLlibKMeansModel} @@ -27,14 +27,13 @@ import org.apache.spark.mllib.linalg.{Vector, VectorUDT} import org.apache.spark.sql.functions.{col, udf} import org.apache.spark.sql.types.{IntegerType, StructType} import org.apache.spark.sql.{DataFrame, Row} -import org.apache.spark.util.Utils /** * Common params for KMeans and KMeansModel */ -private[clustering] trait KMeansParams - extends Params with HasMaxIter with HasFeaturesCol with HasSeed with HasPredictionCol { +private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFeaturesCol + with HasSeed with HasPredictionCol with HasTol { /** * Set the number of clusters to create (k). Must be > 1. Default: 2. @@ -45,31 +44,6 @@ private[clustering] trait KMeansParams /** @group getParam */ def getK: Int = $(k) - /** - * Param the number of runs of the algorithm to execute in parallel. We initialize the algorithm - * this many times with random starting conditions (configured by the initialization mode), then - * return the best clustering found over any run. Must be >= 1. Default: 1. - * @group param - */ - final val runs = new IntParam(this, "runs", - "number of runs of the algorithm to execute in parallel", (value: Int) => value >= 1) - - /** @group getParam */ - def getRuns: Int = $(runs) - - /** - * Param the distance threshold within which we've consider centers to have converged. - * If all centers move less than this Euclidean distance, we stop iterating one run. - * Must be >= 0.0. Default: 1e-4 - * @group param - */ - final val epsilon = new DoubleParam(this, "epsilon", - "distance threshold within which we've consider centers to have converge", - (value: Double) => value >= 0.0) - - /** @group getParam */ - def getEpsilon: Double = $(epsilon) - /** * Param for the initialization algorithm. 
This can be either "random" to choose random points as * initial cluster centers, or "k-means||" to use a parallel variant of k-means++ @@ -136,9 +110,9 @@ class KMeansModel private[ml] ( /** * :: Experimental :: - * K-means clustering with support for multiple parallel runs and a k-means++ like initialization - * mode (the k-means|| algorithm by Bahmani et al). When multiple concurrent runs are requested, - * they are executed together with joint passes over the data for efficiency. + * K-means clustering with support for k-means|| initialization proposed by Bahmani et al. + * + * @see [[http://dx.doi.org/10.14778/2180912.2180915 Bahmani et al., Scalable k-means++.]] */ @Experimental class KMeans(override val uid: String) extends Estimator[KMeansModel] with KMeansParams { @@ -146,10 +120,9 @@ class KMeans(override val uid: String) extends Estimator[KMeansModel] with KMean setDefault( k -> 2, maxIter -> 20, - runs -> 1, initMode -> MLlibKMeans.K_MEANS_PARALLEL, initSteps -> 5, - epsilon -> 1e-4) + tol -> 1e-4) override def copy(extra: ParamMap): KMeans = defaultCopy(extra) @@ -174,10 +147,7 @@ class KMeans(override val uid: String) extends Estimator[KMeansModel] with KMean def setMaxIter(value: Int): this.type = set(maxIter, value) /** @group setParam */ - def setRuns(value: Int): this.type = set(runs, value) - - /** @group setParam */ - def setEpsilon(value: Double): this.type = set(epsilon, value) + def setTol(value: Double): this.type = set(tol, value) /** @group setParam */ def setSeed(value: Long): this.type = set(seed, value) @@ -191,8 +161,7 @@ class KMeans(override val uid: String) extends Estimator[KMeansModel] with KMean .setInitializationSteps($(initSteps)) .setMaxIterations($(maxIter)) .setSeed($(seed)) - .setEpsilon($(epsilon)) - .setRuns($(runs)) + .setEpsilon($(tol)) val parentModel = algo.run(rdd) val model = new KMeansModel(uid, parentModel) copyValues(model) diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala index 1f15ac02f400..688b0e31f91d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala @@ -52,10 +52,9 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext { assert(kmeans.getFeaturesCol === "features") assert(kmeans.getPredictionCol === "prediction") assert(kmeans.getMaxIter === 20) - assert(kmeans.getRuns === 1) assert(kmeans.getInitMode === MLlibKMeans.K_MEANS_PARALLEL) assert(kmeans.getInitSteps === 5) - assert(kmeans.getEpsilon === 1e-4) + assert(kmeans.getTol === 1e-4) } test("set parameters") { @@ -64,21 +63,19 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext { .setFeaturesCol("test_feature") .setPredictionCol("test_prediction") .setMaxIter(33) - .setRuns(7) .setInitMode(MLlibKMeans.RANDOM) .setInitSteps(3) .setSeed(123) - .setEpsilon(1e-3) + .setTol(1e-3) assert(kmeans.getK === 9) assert(kmeans.getFeaturesCol === "test_feature") assert(kmeans.getPredictionCol === "test_prediction") assert(kmeans.getMaxIter === 33) - assert(kmeans.getRuns === 7) assert(kmeans.getInitMode === MLlibKMeans.RANDOM) assert(kmeans.getInitSteps === 3) assert(kmeans.getSeed === 123) - assert(kmeans.getEpsilon === 1e-3) + assert(kmeans.getTol === 1e-3) } test("parameters validation") { @@ -91,9 +88,6 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext { intercept[IllegalArgumentException] { new KMeans().setInitSteps(0) } - 
intercept[IllegalArgumentException] { - new KMeans().setRuns(0) - } } test("fit & transform") { diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py index 48338713a29e..cb4c16e25a7a 100644 --- a/python/pyspark/ml/clustering.py +++ b/python/pyspark/ml/clustering.py @@ -19,7 +19,6 @@ from pyspark.ml.wrapper import JavaEstimator, JavaModel from pyspark.ml.param.shared import * from pyspark.mllib.common import inherit_doc -from pyspark.mllib.linalg import _convert_to_vector __all__ = ['KMeans', 'KMeansModel'] @@ -35,7 +34,7 @@ def clusterCenters(self): @inherit_doc -class KMeans(JavaEstimator, HasFeaturesCol, HasMaxIter, HasSeed): +class KMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol, HasSeed): """ K-means clustering with support for multiple parallel runs and a k-means++ like initialization mode (the k-means|| algorithm by Bahmani et al). When multiple concurrent runs are requested, @@ -45,7 +44,7 @@ class KMeans(JavaEstimator, HasFeaturesCol, HasMaxIter, HasSeed): >>> data = [(Vectors.dense([0.0, 0.0]),), (Vectors.dense([1.0, 1.0]),), ... (Vectors.dense([9.0, 8.0]),), (Vectors.dense([8.0, 9.0]),)] >>> df = sqlContext.createDataFrame(data, ["features"]) - >>> kmeans = KMeans().setK(2).setSeed(1).setFeaturesCol("features") + >>> kmeans = KMeans(k=2, seed=1) >>> model = kmeans.fit(df) >>> centers = model.clusterCenters() >>> len(centers) @@ -60,10 +59,6 @@ class KMeans(JavaEstimator, HasFeaturesCol, HasMaxIter, HasSeed): # a placeholder to make it appear in the generated doc k = Param(Params._dummy(), "k", "number of clusters to create") - epsilon = Param(Params._dummy(), "epsilon", - "distance threshold within which " + - "we've consider centers to have converged") - runs = Param(Params._dummy(), "runs", "number of runs of the algorithm to execute in parallel") initMode = Param(Params._dummy(), "initMode", "the initialization algorithm. This can be either \"random\" to " + "choose random points as initial cluster centers, or \"k-means||\" " + @@ -71,21 +66,21 @@ class KMeans(JavaEstimator, HasFeaturesCol, HasMaxIter, HasSeed): initSteps = Param(Params._dummy(), "initSteps", "steps for k-means initialization mode") @keyword_only - def __init__(self, k=2, maxIter=20, runs=1, epsilon=1e-4, initMode="k-means||", initStep=5): + def __init__(self, featuresCol="features", predictionCol="prediction", k=2, + initMode="k-means||", initSteps=5, tol=1e-4, maxIter=20, seed=None): + """ + __init__(self, featuresCol="features", predictionCol="prediction", k=2, \ + initMode="k-means||", initSteps=5, tol=1e-4, maxIter=20, seed=None) + """ super(KMeans, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.clustering.KMeans", self.uid) self.k = Param(self, "k", "number of clusters to create") - self.epsilon = Param(self, "epsilon", - "distance threshold within which " + - "we've consider centers to have converged") - self.runs = Param(self, "runs", "number of runs of the algorithm to execute in parallel") - self.seed = Param(self, "seed", "random seed") self.initMode = Param(self, "initMode", "the initialization algorithm. 
This can be either \"random\" to " + "choose random points as initial cluster centers, or \"k-means||\" " + "to use a parallel variant of k-means++") self.initSteps = Param(self, "initSteps", "steps for k-means initialization mode") - self._setDefault(k=2, maxIter=20, runs=1, epsilon=1e-4, initMode="k-means||", initSteps=5) + self._setDefault(k=2, initMode="k-means||", initSteps=5, tol=1e-4, maxIter=20) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @@ -93,9 +88,11 @@ def _create_model(self, java_model): return KMeansModel(java_model) @keyword_only - def setParams(self, k=2, maxIter=20, runs=1, epsilon=1e-4, initMode="k-means||", initSteps=5): + def setParams(self, featuresCol="features", predictionCol="prediction", k=2, + initMode="k-means||", initSteps=5, tol=1e-4, maxIter=20, seed=None): """ - setParams(self, k=2, maxIter=20, runs=1, epsilon=1e-4, initMode="k-means||", initSteps=5): + setParams(self, featuresCol="features", predictionCol="prediction", k=2, \ + initMode="k-means||", initSteps=5, tol=1e-4, maxIter=20, seed=None) Sets params for KMeans. """ @@ -119,40 +116,6 @@ def getK(self): """ return self.getOrDefault(self.k) - def setEpsilon(self, value): - """ - Sets the value of :py:attr:`epsilon`. - - >>> algo = KMeans().setEpsilon(1e-5) - >>> abs(algo.getEpsilon() - 1e-5) < 1e-5 - True - """ - self._paramMap[self.epsilon] = value - return self - - def getEpsilon(self): - """ - Gets the value of `epsilon` - """ - return self.getOrDefault(self.epsilon) - - def setRuns(self, value): - """ - Sets the value of :py:attr:`runs`. - - >>> algo = KMeans().setRuns(10) - >>> algo.getRuns() - 10 - """ - self._paramMap[self.runs] = value - return self - - def getRuns(self): - """ - Gets the value of `runs` - """ - return self.getOrDefault(self.runs) - def setInitMode(self, value): """ Sets the value of :py:attr:`initMode`. From 84a27916a62980c8fcb0977c3a7fdb73c0bd5812 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Thu, 13 Aug 2015 15:08:57 +0800 Subject: [PATCH 0225/1824] [SPARK-9885] [SQL] Also pass barrierPrefixes and sharedPrefixes to IsolatedClientLoader when hiveMetastoreJars is set to maven. https://issues.apache.org/jira/browse/SPARK-9885 cc marmbrus liancheng Author: Yin Huai Closes #8158 from yhuai/classloaderMaven. --- .../scala/org/apache/spark/sql/hive/HiveContext.scala | 6 +++++- .../spark/sql/hive/client/IsolatedClientLoader.scala | 11 +++++++++-- 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index f17177a771c3..17762649fd70 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -231,7 +231,11 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging { // TODO: Support for loading the jars from an already downloaded location. logInfo( s"Initializing HiveMetastoreConnection version $hiveMetastoreVersion using maven.") - IsolatedClientLoader.forVersion(hiveMetastoreVersion, allConfig) + IsolatedClientLoader.forVersion( + version = hiveMetastoreVersion, + config = allConfig, + barrierPrefixes = hiveMetastoreBarrierPrefixes, + sharedPrefixes = hiveMetastoreSharedPrefixes) } else { // Convert to files and expand any directories. 
val jars = diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala index a7d5a991948d..785603750841 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala @@ -42,11 +42,18 @@ private[hive] object IsolatedClientLoader { def forVersion( version: String, config: Map[String, String] = Map.empty, - ivyPath: Option[String] = None): IsolatedClientLoader = synchronized { + ivyPath: Option[String] = None, + sharedPrefixes: Seq[String] = Seq.empty, + barrierPrefixes: Seq[String] = Seq.empty): IsolatedClientLoader = synchronized { val resolvedVersion = hiveVersion(version) val files = resolvedVersions.getOrElseUpdate(resolvedVersion, downloadVersion(resolvedVersion, ivyPath)) - new IsolatedClientLoader(hiveVersion(version), files, config) + new IsolatedClientLoader( + version = hiveVersion(version), + execJars = files, + config = config, + sharedPrefixes = sharedPrefixes, + barrierPrefixes = barrierPrefixes) } def hiveVersion(version: String): HiveVersion = version match { From 69930310115501f0de094fe6f5c6c60dade342bd Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Thu, 13 Aug 2015 16:16:50 +0800 Subject: [PATCH 0226/1824] [SPARK-9757] [SQL] Fixes persistence of Parquet relation with decimal column PR #7967 enables us to save data source relations to metastore in Hive compatible format when possible. But it fails to persist Parquet relations with decimal column(s) to Hive metastore of versions lower than 1.2.0. This is because `ParquetHiveSerDe` in Hive versions prior to 1.2.0 doesn't support decimal. This PR checks for this case and falls back to Spark SQL specific metastore table format. Author: Yin Huai Author: Cheng Lian Closes #8130 from liancheng/spark-9757/old-hive-parquet-decimal. 
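The detection added below hinges on the new `DataType.existsRecursively` helper plus the metastore client version. A minimal sketch of the decimal check follows (note that existsRecursively is private[spark], so this is written as it would run inside Spark itself):

    import org.apache.spark.sql.types._

    // Nested decimals are caught too, because existsRecursively walks the whole type tree.
    val schema = StructType(Seq(
      StructField("id", LongType),
      StructField("nested", StructType(Seq(
        StructField("price", DecimalType(10, 3)))))))

    val hasDecimalFields = schema.existsRecursively(_.isInstanceOf[DecimalType])  // true
    // HiveMetastoreCatalog combines this flag with two others: whether the chosen SerDe is
    // Parquet, and whether the metastore client is at least Hive 1.2.0 (ParquetHiveSerDe
    // gained decimal support in HIVE-6384). Only when all three hold -- Parquet SerDe,
    // decimal fields, and an older Hive -- does it fall back to the Spark SQL specific
    // metastore table format instead of the Hive-compatible one.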
--- .../apache/spark/sql/types/ArrayType.scala | 6 +- .../org/apache/spark/sql/types/DataType.scala | 5 ++ .../org/apache/spark/sql/types/MapType.scala | 6 +- .../apache/spark/sql/types/StructType.scala | 8 ++- .../spark/sql/types/DataTypeSuite.scala | 24 +++++++ .../spark/sql/hive/HiveMetastoreCatalog.scala | 39 ++++++++--- .../sql/hive/client/ClientInterface.scala | 3 + .../spark/sql/hive/client/ClientWrapper.scala | 2 +- .../spark/sql/hive/client/package.scala | 2 +- .../sql/hive/HiveMetastoreCatalogSuite.scala | 17 +++-- .../spark/sql/hive/HiveSparkSubmitSuite.scala | 68 +++++++++++++++++-- 11 files changed, 150 insertions(+), 30 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala index 5094058164b2..5770f59b5307 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala @@ -75,6 +75,10 @@ case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataT override def simpleString: String = s"array<${elementType.simpleString}>" - private[spark] override def asNullable: ArrayType = + override private[spark] def asNullable: ArrayType = ArrayType(elementType.asNullable, containsNull = true) + + override private[spark] def existsRecursively(f: (DataType) => Boolean): Boolean = { + f(this) || elementType.existsRecursively(f) + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala index f4428c2e8b20..7bcd623b3f33 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -77,6 +77,11 @@ abstract class DataType extends AbstractDataType { */ private[spark] def asNullable: DataType + /** + * Returns true if any `DataType` of this DataType tree satisfies the given function `f`. 
+ */ + private[spark] def existsRecursively(f: (DataType) => Boolean): Boolean = f(this) + override private[sql] def defaultConcreteType: DataType = this override private[sql] def acceptsType(other: DataType): Boolean = sameType(other) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala index ac34b642827c..00461e529ca0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala @@ -62,8 +62,12 @@ case class MapType( override def simpleString: String = s"map<${keyType.simpleString},${valueType.simpleString}>" - private[spark] override def asNullable: MapType = + override private[spark] def asNullable: MapType = MapType(keyType.asNullable, valueType.asNullable, valueContainsNull = true) + + override private[spark] def existsRecursively(f: (DataType) => Boolean): Boolean = { + f(this) || keyType.existsRecursively(f) || valueType.existsRecursively(f) + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index 9cbc207538d4..d8968ef80639 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -24,7 +24,7 @@ import org.json4s.JsonDSL._ import org.apache.spark.SparkException import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.sql.catalyst.expressions.{InterpretedOrdering, AttributeReference, Attribute, InterpretedOrdering$} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, InterpretedOrdering} /** @@ -292,7 +292,7 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru private[sql] def merge(that: StructType): StructType = StructType.merge(this, that).asInstanceOf[StructType] - private[spark] override def asNullable: StructType = { + override private[spark] def asNullable: StructType = { val newFields = fields.map { case StructField(name, dataType, nullable, metadata) => StructField(name, dataType.asNullable, nullable = true, metadata) @@ -301,6 +301,10 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru StructType(newFields) } + override private[spark] def existsRecursively(f: (DataType) => Boolean): Boolean = { + f(this) || fields.exists(field => field.dataType.existsRecursively(f)) + } + private[sql] val interpretedOrdering = InterpretedOrdering.forSchema(this.fields.map(_.dataType)) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala index 88b221cd81d7..706ecd29d135 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala @@ -170,6 +170,30 @@ class DataTypeSuite extends SparkFunSuite { } } + test("existsRecursively") { + val struct = StructType( + StructField("a", LongType) :: + StructField("b", FloatType) :: Nil) + assert(struct.existsRecursively(_.isInstanceOf[LongType])) + assert(struct.existsRecursively(_.isInstanceOf[StructType])) + assert(!struct.existsRecursively(_.isInstanceOf[IntegerType])) + + val mapType = MapType(struct, StringType) + assert(mapType.existsRecursively(_.isInstanceOf[LongType])) + 
assert(mapType.existsRecursively(_.isInstanceOf[StructType])) + assert(mapType.existsRecursively(_.isInstanceOf[StringType])) + assert(mapType.existsRecursively(_.isInstanceOf[MapType])) + assert(!mapType.existsRecursively(_.isInstanceOf[IntegerType])) + + val arrayType = ArrayType(mapType) + assert(arrayType.existsRecursively(_.isInstanceOf[LongType])) + assert(arrayType.existsRecursively(_.isInstanceOf[StructType])) + assert(arrayType.existsRecursively(_.isInstanceOf[StringType])) + assert(arrayType.existsRecursively(_.isInstanceOf[MapType])) + assert(arrayType.existsRecursively(_.isInstanceOf[ArrayType])) + assert(!arrayType.existsRecursively(_.isInstanceOf[IntegerType])) + } + def checkDataTypeJsonRepr(dataType: DataType): Unit = { test(s"JSON - $dataType") { assert(DataType.fromJson(dataType.json) === dataType) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 5e5497837a39..6770462bb0ad 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -33,15 +33,14 @@ import org.apache.hadoop.hive.ql.plan.TableDesc import org.apache.spark.Logging import org.apache.spark.sql.catalyst.analysis.{Catalog, MultiInstanceRelation, OverrideCatalog} import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.catalyst.{InternalRow, SqlParser, TableIdentifier} -import org.apache.spark.sql.execution.{FileRelation, datasources} +import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, LogicalRelation, Partition => ParquetPartition, PartitionSpec, ResolvedDataSource} +import org.apache.spark.sql.execution.{FileRelation, datasources} import org.apache.spark.sql.hive.client._ -import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ import org.apache.spark.sql.{AnalysisException, SQLContext, SaveMode} @@ -86,9 +85,9 @@ private[hive] object HiveSerDe { serde = Option("org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe"))) val key = source.toLowerCase match { - case _ if source.startsWith("org.apache.spark.sql.parquet") => "parquet" - case _ if source.startsWith("org.apache.spark.sql.orc") => "orc" - case _ => source.toLowerCase + case s if s.startsWith("org.apache.spark.sql.parquet") => "parquet" + case s if s.startsWith("org.apache.spark.sql.orc") => "orc" + case s => s } serdeMap.get(key) @@ -309,11 +308,31 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive val hiveTable = (maybeSerDe, dataSource.relation) match { case (Some(serde), relation: HadoopFsRelation) if relation.paths.length == 1 && relation.partitionColumns.isEmpty => - logInfo { - "Persisting data source relation with a single input path into Hive metastore in Hive " + - s"compatible format. Input path: ${relation.paths.head}" + // Hive ParquetSerDe doesn't support decimal type until 1.2.0. 
+ val isParquetSerDe = serde.inputFormat.exists(_.toLowerCase.contains("parquet")) + val hasDecimalFields = relation.schema.existsRecursively(_.isInstanceOf[DecimalType]) + + val hiveParquetSupportsDecimal = client.version match { + case org.apache.spark.sql.hive.client.hive.v1_2 => true + case _ => false + } + + if (isParquetSerDe && !hiveParquetSupportsDecimal && hasDecimalFields) { + // If Hive version is below 1.2.0, we cannot save Hive compatible schema to + // metastore when the file format is Parquet and the schema has DecimalType. + logWarning { + "Persisting Parquet relation with decimal field(s) into Hive metastore in Spark SQL " + + "specific format, which is NOT compatible with Hive. Because ParquetHiveSerDe in " + + s"Hive ${client.version.fullVersion} doesn't support decimal type. See HIVE-6384." + } + newSparkSQLSpecificMetastoreTable() + } else { + logInfo { + "Persisting data source relation with a single input path into Hive metastore in " + + s"Hive compatible format. Input path: ${relation.paths.head}" + } + newHiveCompatibleMetastoreTable(relation, serde) } - newHiveCompatibleMetastoreTable(relation, serde) case (Some(serde), relation: HadoopFsRelation) if relation.partitionColumns.nonEmpty => logWarning { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala index a82e152dcda2..3811c152a7ae 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala @@ -88,6 +88,9 @@ private[hive] case class HiveTable( */ private[hive] trait ClientInterface { + /** Returns the Hive Version of this client. */ + def version: HiveVersion + /** Returns the configuration for the given key in the current session. */ def getConf(key: String, defaultValue: String): String diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala index 3d05b583cf9e..f49c97de8ff4 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala @@ -58,7 +58,7 @@ import org.apache.spark.util.{CircularBuffer, Utils} * this ClientWrapper. 
*/ private[hive] class ClientWrapper( - version: HiveVersion, + override val version: HiveVersion, config: Map[String, String], initClassLoader: ClassLoader) extends ClientInterface diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala index 0503691a4424..b1b8439efa01 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala @@ -25,7 +25,7 @@ package object client { val exclusions: Seq[String] = Nil) // scalastyle:off - private[client] object hive { + private[hive] object hive { case object v12 extends HiveVersion("0.12.0") case object v13 extends HiveVersion("0.13.1") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala index 332c3ec0c28b..59e65ff97b8e 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala @@ -19,13 +19,13 @@ package org.apache.spark.sql.hive import java.io.File -import org.apache.spark.sql.hive.client.{ExternalTable, HiveColumn, ManagedTable} +import org.apache.spark.sql.hive.client.{ExternalTable, ManagedTable} import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.sql.hive.test.TestHive.implicits._ import org.apache.spark.sql.sources.DataSourceTest import org.apache.spark.sql.test.{ExamplePointUDT, SQLTestUtils} -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{DecimalType, StringType, StructType} import org.apache.spark.sql.{Row, SaveMode} import org.apache.spark.{Logging, SparkFunSuite} @@ -55,7 +55,10 @@ class HiveMetastoreCatalogSuite extends SparkFunSuite with Logging { class DataSourceWithHiveMetastoreCatalogSuite extends DataSourceTest with SQLTestUtils { override val sqlContext = TestHive - private val testDF = (1 to 2).map(i => (i, s"val_$i")).toDF("d1", "d2").coalesce(1) + private val testDF = range(1, 3).select( + ('id + 0.1) cast DecimalType(10, 3) as 'd1, + 'id cast StringType as 'd2 + ).coalesce(1) Seq( "parquet" -> ( @@ -88,10 +91,10 @@ class DataSourceWithHiveMetastoreCatalogSuite extends DataSourceTest with SQLTes val columns = hiveTable.schema assert(columns.map(_.name) === Seq("d1", "d2")) - assert(columns.map(_.hiveType) === Seq("int", "string")) + assert(columns.map(_.hiveType) === Seq("decimal(10,3)", "string")) checkAnswer(table("t"), testDF) - assert(runSqlHive("SELECT * FROM t") === Seq("1\tval_1", "2\tval_2")) + assert(runSqlHive("SELECT * FROM t") === Seq("1.1\t1", "2.1\t2")) } } @@ -117,10 +120,10 @@ class DataSourceWithHiveMetastoreCatalogSuite extends DataSourceTest with SQLTes val columns = hiveTable.schema assert(columns.map(_.name) === Seq("d1", "d2")) - assert(columns.map(_.hiveType) === Seq("int", "string")) + assert(columns.map(_.hiveType) === Seq("decimal(10,3)", "string")) checkAnswer(table("t"), testDF) - assert(runSqlHive("SELECT * FROM t") === Seq("1\tval_1", "2\tval_2")) + assert(runSqlHive("SELECT * FROM t") === Seq("1.1\t1", "2.1\t2")) } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala index 1e1972d1ac35..0c2964611446 100644 --- 
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala @@ -20,16 +20,18 @@ package org.apache.spark.sql.hive import java.io.File import scala.collection.mutable.ArrayBuffer -import scala.sys.process.{ProcessLogger, Process} +import scala.sys.process.{Process, ProcessLogger} +import org.scalatest.Matchers +import org.scalatest.concurrent.Timeouts import org.scalatest.exceptions.TestFailedDueToTimeoutException +import org.scalatest.time.SpanSugar._ import org.apache.spark._ +import org.apache.spark.sql.QueryTest import org.apache.spark.sql.hive.test.{TestHive, TestHiveContext} +import org.apache.spark.sql.types.DecimalType import org.apache.spark.util.{ResetSystemProperties, Utils} -import org.scalatest.Matchers -import org.scalatest.concurrent.Timeouts -import org.scalatest.time.SpanSugar._ /** * This suite tests spark-submit with applications using HiveContext. @@ -50,8 +52,8 @@ class HiveSparkSubmitSuite val unusedJar = TestUtils.createJarWithClasses(Seq.empty) val jar1 = TestUtils.createJarWithClasses(Seq("SparkSubmitClassA")) val jar2 = TestUtils.createJarWithClasses(Seq("SparkSubmitClassB")) - val jar3 = TestHive.getHiveFile("hive-contrib-0.13.1.jar").getCanonicalPath() - val jar4 = TestHive.getHiveFile("hive-hcatalog-core-0.13.1.jar").getCanonicalPath() + val jar3 = TestHive.getHiveFile("hive-contrib-0.13.1.jar").getCanonicalPath + val jar4 = TestHive.getHiveFile("hive-hcatalog-core-0.13.1.jar").getCanonicalPath val jarsString = Seq(jar1, jar2, jar3, jar4).map(j => j.toString).mkString(",") val args = Seq( "--class", SparkSubmitClassLoaderTest.getClass.getName.stripSuffix("$"), @@ -91,6 +93,16 @@ class HiveSparkSubmitSuite runSparkSubmit(args) } + test("SPARK-9757 Persist Parquet relation with decimal column") { + val unusedJar = TestUtils.createJarWithClasses(Seq.empty) + val args = Seq( + "--class", SPARK_9757.getClass.getName.stripSuffix("$"), + "--name", "SparkSQLConfTest", + "--master", "local-cluster[2,1,1024]", + unusedJar.toString) + runSparkSubmit(args) + } + // NOTE: This is an expensive operation in terms of time (10 seconds+). Use sparingly. // This is copied from org.apache.spark.deploy.SparkSubmitSuite private def runSparkSubmit(args: Seq[String]): Unit = { @@ -213,7 +225,7 @@ object SparkSQLConfTest extends Logging { // before spark.sql.hive.metastore.jars get set, we will see the following exception: // Exception in thread "main" java.lang.IllegalArgumentException: Builtin jars can only // be used when hive execution version == hive metastore version. - // Execution: 0.13.1 != Metastore: 0.12. Specify a vaild path to the correct hive jars + // Execution: 0.13.1 != Metastore: 0.12. Specify a valid path to the correct hive jars // using $HIVE_METASTORE_JARS or change spark.sql.hive.metastore.version to 0.13.1. 
val conf = new SparkConf() { override def getAll: Array[(String, String)] = { @@ -239,3 +251,45 @@ object SparkSQLConfTest extends Logging { sc.stop() } } + +object SPARK_9757 extends QueryTest with Logging { + def main(args: Array[String]): Unit = { + Utils.configTestLog4j("INFO") + + val sparkContext = new SparkContext( + new SparkConf() + .set("spark.sql.hive.metastore.version", "0.13.1") + .set("spark.sql.hive.metastore.jars", "maven")) + + val hiveContext = new TestHiveContext(sparkContext) + import hiveContext.implicits._ + import org.apache.spark.sql.functions._ + + val dir = Utils.createTempDir() + dir.delete() + + try { + { + val df = + hiveContext + .range(10) + .select(('id + 0.1) cast DecimalType(10, 3) as 'dec) + df.write.option("path", dir.getCanonicalPath).mode("overwrite").saveAsTable("t") + checkAnswer(hiveContext.table("t"), df) + } + + { + val df = + hiveContext + .range(10) + .select(callUDF("struct", ('id + 0.2) cast DecimalType(10, 3)) as 'dec_struct) + df.write.option("path", dir.getCanonicalPath).mode("overwrite").saveAsTable("t") + checkAnswer(hiveContext.table("t"), df) + } + } finally { + dir.delete() + hiveContext.sql("DROP TABLE t") + sparkContext.stop() + } + } +} From 2932e25da4532de9e86b01d08bce0cb680874e70 Mon Sep 17 00:00:00 2001 From: lewuathe Date: Thu, 13 Aug 2015 09:17:19 -0700 Subject: [PATCH 0227/1824] [SPARK-9073] [ML] spark.ml Models copy() should call setParent when there is a parent Copied ML models must have the same parent of original ones Author: lewuathe Author: Lewuathe Closes #7447 from Lewuathe/SPARK-9073. --- .../examples/ml/JavaDeveloperApiExample.java | 3 +- .../examples/ml/DeveloperApiExample.scala | 2 +- .../scala/org/apache/spark/ml/Pipeline.scala | 2 +- .../DecisionTreeClassifier.scala | 1 + .../ml/classification/GBTClassifier.scala | 2 +- .../classification/LogisticRegression.scala | 2 +- .../spark/ml/classification/OneVsRest.scala | 2 +- .../RandomForestClassifier.scala | 1 + .../apache/spark/ml/feature/Bucketizer.scala | 4 ++- .../org/apache/spark/ml/feature/IDF.scala | 2 +- .../spark/ml/feature/MinMaxScaler.scala | 2 +- .../org/apache/spark/ml/feature/PCA.scala | 2 +- .../spark/ml/feature/StandardScaler.scala | 2 +- .../spark/ml/feature/StringIndexer.scala | 2 +- .../spark/ml/feature/VectorIndexer.scala | 2 +- .../apache/spark/ml/feature/Word2Vec.scala | 2 +- .../apache/spark/ml/recommendation/ALS.scala | 2 +- .../ml/regression/DecisionTreeRegressor.scala | 2 +- .../spark/ml/regression/GBTRegressor.scala | 2 +- .../ml/regression/LinearRegression.scala | 2 +- .../ml/regression/RandomForestRegressor.scala | 2 +- .../spark/ml/tuning/CrossValidator.scala | 2 +- .../org/apache/spark/ml/PipelineSuite.scala | 3 ++ .../DecisionTreeClassifierSuite.scala | 4 +++ .../classification/GBTClassifierSuite.scala | 4 +++ .../LogisticRegressionSuite.scala | 4 +++ .../ml/classification/OneVsRestSuite.scala | 6 +++- .../RandomForestClassifierSuite.scala | 4 +++ .../spark/ml/feature/BucketizerSuite.scala | 1 + .../spark/ml/feature/MinMaxScalerSuite.scala | 4 +++ .../apache/spark/ml/feature/PCASuite.scala | 4 +++ .../spark/ml/feature/StringIndexerSuite.scala | 5 ++++ .../spark/ml/feature/VectorIndexerSuite.scala | 5 ++++ .../spark/ml/feature/Word2VecSuite.scala | 4 +++ .../spark/ml/recommendation/ALSSuite.scala | 4 +++ .../DecisionTreeRegressorSuite.scala | 11 +++++++ .../ml/regression/GBTRegressorSuite.scala | 5 ++++ .../ml/regression/LinearRegressionSuite.scala | 5 ++++ .../RandomForestRegressorSuite.scala | 7 ++++- 
.../spark/ml/tuning/CrossValidatorSuite.scala | 5 ++++ .../apache/spark/ml/util/MLTestingUtils.scala | 30 +++++++++++++++++++ 41 files changed, 138 insertions(+), 22 deletions(-) create mode 100644 mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java index 9df26ffca577..3f1fe900b000 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java @@ -230,6 +230,7 @@ public Vector predictRaw(Vector features) { */ @Override public MyJavaLogisticRegressionModel copy(ParamMap extra) { - return copyValues(new MyJavaLogisticRegressionModel(uid(), weights_), extra); + return copyValues(new MyJavaLogisticRegressionModel(uid(), weights_), extra) + .setParent(parent()); } } diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala index 78f31b4ffe56..340c3559b15e 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala @@ -179,7 +179,7 @@ private class MyLogisticRegressionModel( * This is used for the default implementation of [[transform()]]. */ override def copy(extra: ParamMap): MyLogisticRegressionModel = { - copyValues(new MyLogisticRegressionModel(uid, weights), extra) + copyValues(new MyLogisticRegressionModel(uid, weights), extra).setParent(parent) } } // scalastyle:on println diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala index aef2c019d287..a3e59401c5cf 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala @@ -198,6 +198,6 @@ class PipelineModel private[ml] ( } override def copy(extra: ParamMap): PipelineModel = { - new PipelineModel(uid, stages.map(_.copy(extra))) + new PipelineModel(uid, stages.map(_.copy(extra))).setParent(parent) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala index 29598f3f05c2..6f70b96b17ec 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala @@ -141,6 +141,7 @@ final class DecisionTreeClassificationModel private[ml] ( override def copy(extra: ParamMap): DecisionTreeClassificationModel = { copyValues(new DecisionTreeClassificationModel(uid, rootNode, numClasses), extra) + .setParent(parent) } override def toString: String = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index c3891a959926..3073a2a61ce8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -196,7 +196,7 @@ final class GBTClassificationModel( } override def copy(extra: ParamMap): GBTClassificationModel = { - copyValues(new GBTClassificationModel(uid, _trees, _treeWeights), extra) + copyValues(new 
GBTClassificationModel(uid, _trees, _treeWeights), extra).setParent(parent) } override def toString: String = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 5bcd7117b668..21fbe38ca823 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -468,7 +468,7 @@ class LogisticRegressionModel private[ml] ( } override def copy(extra: ParamMap): LogisticRegressionModel = { - copyValues(new LogisticRegressionModel(uid, weights, intercept), extra) + copyValues(new LogisticRegressionModel(uid, weights, intercept), extra).setParent(parent) } override protected def raw2prediction(rawPrediction: Vector): Double = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala index 1741f19dc911..1132d8046df6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala @@ -138,7 +138,7 @@ final class OneVsRestModel private[ml] ( override def copy(extra: ParamMap): OneVsRestModel = { val copied = new OneVsRestModel( uid, labelMetadata, models.map(_.copy(extra).asInstanceOf[ClassificationModel[_, _]])) - copyValues(copied, extra) + copyValues(copied, extra).setParent(parent) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index 156050aaf7a4..11a6d7246833 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -189,6 +189,7 @@ final class RandomForestClassificationModel private[ml] ( override def copy(extra: ParamMap): RandomForestClassificationModel = { copyValues(new RandomForestClassificationModel(uid, _trees, numFeatures, numClasses), extra) + .setParent(parent) } override def toString: String = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala index 67e4785bc355..cfca494dcf46 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala @@ -90,7 +90,9 @@ final class Bucketizer(override val uid: String) SchemaUtils.appendColumn(schema, prepOutputField(schema)) } - override def copy(extra: ParamMap): Bucketizer = defaultCopy(extra) + override def copy(extra: ParamMap): Bucketizer = { + defaultCopy[Bucketizer](extra).setParent(parent) + } } private[feature] object Bucketizer { diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala index ecde80810580..938447447a0a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala @@ -114,6 +114,6 @@ class IDFModel private[ml] ( override def copy(extra: ParamMap): IDFModel = { val copied = new IDFModel(uid, idfModel) - copyValues(copied, extra) + copyValues(copied, extra).setParent(parent) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala 
b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala index 9a473dd23772..1b494ec8b172 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala @@ -173,6 +173,6 @@ class MinMaxScalerModel private[ml] ( override def copy(extra: ParamMap): MinMaxScalerModel = { val copied = new MinMaxScalerModel(uid, originalMin, originalMax) - copyValues(copied, extra) + copyValues(copied, extra).setParent(parent) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala index 2d3bb680cf30..539084704b65 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala @@ -125,6 +125,6 @@ class PCAModel private[ml] ( override def copy(extra: ParamMap): PCAModel = { val copied = new PCAModel(uid, pcaModel) - copyValues(copied, extra) + copyValues(copied, extra).setParent(parent) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala index 72b545e5db3e..f6d0b0c0e9e7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala @@ -136,6 +136,6 @@ class StandardScalerModel private[ml] ( override def copy(extra: ParamMap): StandardScalerModel = { val copied = new StandardScalerModel(uid, scaler) - copyValues(copied, extra) + copyValues(copied, extra).setParent(parent) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index e4485eb03840..9e4b0f0add61 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -168,7 +168,7 @@ class StringIndexerModel private[ml] ( override def copy(extra: ParamMap): StringIndexerModel = { val copied = new StringIndexerModel(uid, labels) - copyValues(copied, extra) + copyValues(copied, extra).setParent(parent) } /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala index c73bdccdef5f..6875aefe065b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala @@ -405,6 +405,6 @@ class VectorIndexerModel private[ml] ( override def copy(extra: ParamMap): VectorIndexerModel = { val copied = new VectorIndexerModel(uid, numFeatures, categoryMaps) - copyValues(copied, extra) + copyValues(copied, extra).setParent(parent) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala index 29acc3eb5865..5af775a4159a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala @@ -221,6 +221,6 @@ class Word2VecModel private[ml] ( override def copy(extra: ParamMap): Word2VecModel = { val copied = new Word2VecModel(uid, wordVectors) - copyValues(copied, extra) + copyValues(copied, extra).setParent(parent) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index 
2e44cd4cc6a2..7db8ad8d2791 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -219,7 +219,7 @@ class ALSModel private[ml] ( override def copy(extra: ParamMap): ALSModel = { val copied = new ALSModel(uid, rank, userFactors, itemFactors) - copyValues(copied, extra) + copyValues(copied, extra).setParent(parent) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala index dc94a1401454..a2bcd67401d0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala @@ -114,7 +114,7 @@ final class DecisionTreeRegressionModel private[ml] ( } override def copy(extra: ParamMap): DecisionTreeRegressionModel = { - copyValues(new DecisionTreeRegressionModel(uid, rootNode), extra) + copyValues(new DecisionTreeRegressionModel(uid, rootNode), extra).setParent(parent) } override def toString: String = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala index 5633bc320273..b66e61f37dd5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala @@ -185,7 +185,7 @@ final class GBTRegressionModel( } override def copy(extra: ParamMap): GBTRegressionModel = { - copyValues(new GBTRegressionModel(uid, _trees, _treeWeights), extra) + copyValues(new GBTRegressionModel(uid, _trees, _treeWeights), extra).setParent(parent) } override def toString: String = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 92d819bad865..884003eb3852 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -312,7 +312,7 @@ class LinearRegressionModel private[ml] ( override def copy(extra: ParamMap): LinearRegressionModel = { val newModel = copyValues(new LinearRegressionModel(uid, weights, intercept)) if (trainingSummary.isDefined) newModel.setSummary(trainingSummary.get) - newModel + newModel.setParent(parent) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala index db75c0d26392..2f36da371f57 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala @@ -151,7 +151,7 @@ final class RandomForestRegressionModel private[ml] ( } override def copy(extra: ParamMap): RandomForestRegressionModel = { - copyValues(new RandomForestRegressionModel(uid, _trees, numFeatures), extra) + copyValues(new RandomForestRegressionModel(uid, _trees, numFeatures), extra).setParent(parent) } override def toString: String = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index f979319cc4b5..4792eb0f0a28 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ 
b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -160,6 +160,6 @@ class CrossValidatorModel private[ml] ( uid, bestModel.copy(extra).asInstanceOf[Model[_]], avgMetrics.clone()) - copyValues(copied, extra) + copyValues(copied, extra).setParent(parent) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala index 63d2fa31c749..1f2c9b75b617 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala @@ -26,6 +26,7 @@ import org.scalatest.mock.MockitoSugar.mock import org.apache.spark.SparkFunSuite import org.apache.spark.ml.feature.HashingTF import org.apache.spark.ml.param.ParamMap +import org.apache.spark.ml.util.MLTestingUtils import org.apache.spark.sql.DataFrame class PipelineSuite extends SparkFunSuite { @@ -65,6 +66,8 @@ class PipelineSuite extends SparkFunSuite { .setStages(Array(estimator0, transformer1, estimator2, transformer3)) val pipelineModel = pipeline.fit(dataset0) + MLTestingUtils.checkCopy(pipelineModel) + assert(pipelineModel.stages.length === 4) assert(pipelineModel.stages(0).eq(model0)) assert(pipelineModel.stages(1).eq(transformer1)) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala index c7bbf1ce07a2..4b7c5d3f23d2 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala @@ -21,6 +21,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.ml.impl.TreeTests import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.tree.LeafNode +import org.apache.spark.ml.util.MLTestingUtils import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree, DecisionTreeSuite => OldDecisionTreeSuite} @@ -244,6 +245,9 @@ class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkConte val newData: DataFrame = TreeTests.setMetadata(rdd, categoricalFeatures, numClasses) val newTree = dt.fit(newData) + // copied model must have the same parent. 
+ MLTestingUtils.checkCopy(newTree) + val predictions = newTree.transform(newData) .select(newTree.getPredictionCol, newTree.getRawPredictionCol, newTree.getProbabilityCol) .collect() diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala index d4b5896c12c0..e3909bccaa5c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala @@ -22,6 +22,7 @@ import org.apache.spark.ml.impl.TreeTests import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.regression.DecisionTreeRegressionModel import org.apache.spark.ml.tree.LeafNode +import org.apache.spark.ml.util.MLTestingUtils import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT} import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} @@ -92,6 +93,9 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext { .setCheckpointInterval(2) val model = gbt.fit(df) + // copied model must have the same parent. + MLTestingUtils.checkCopy(model) + sc.checkpointDir = None Utils.deleteRecursively(tempDir) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index e354e161c6de..cce39f382f73 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.ml.classification import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.MLTestingUtils import org.apache.spark.mllib.classification.LogisticRegressionSuite._ import org.apache.spark.mllib.linalg.{Vectors, Vector} import org.apache.spark.mllib.util.MLlibTestSparkContext @@ -135,6 +136,9 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { lr.setFitIntercept(false) val model = lr.fit(dataset) assert(model.intercept === 0.0) + + // copied model must have the same parent. 
+ MLTestingUtils.checkCopy(model) } test("logistic regression with setters") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala index bd8e819f6926..977f0e0b70c1 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala @@ -21,7 +21,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.ml.attribute.NominalAttribute import org.apache.spark.ml.feature.StringIndexer import org.apache.spark.ml.param.{ParamMap, ParamsSuite} -import org.apache.spark.ml.util.MetadataUtils +import org.apache.spark.ml.util.{MLTestingUtils, MetadataUtils} import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS import org.apache.spark.mllib.classification.LogisticRegressionSuite._ import org.apache.spark.mllib.evaluation.MulticlassMetrics @@ -70,6 +70,10 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext { assert(ova.getLabelCol === "label") assert(ova.getPredictionCol === "prediction") val ovaModel = ova.fit(dataset) + + // copied model must have the same parent. + MLTestingUtils.checkCopy(ovaModel) + assert(ovaModel.models.size === numClasses) val transformedDataset = ovaModel.transform(dataset) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala index 6ca4b5aa5fde..b4403ec30049 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala @@ -21,6 +21,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.ml.impl.TreeTests import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.tree.LeafNode +import org.apache.spark.ml.util.MLTestingUtils import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => OldRandomForest} @@ -135,6 +136,9 @@ class RandomForestClassifierSuite extends SparkFunSuite with MLlibTestSparkConte val df: DataFrame = TreeTests.setMetadata(rdd, categoricalFeatures, numClasses) val model = rf.fit(df) + // copied model must have the same parent. 
+ MLTestingUtils.checkCopy(model) + val predictions = model.transform(df) .select(rf.getPredictionCol, rf.getRawPredictionCol, rf.getProbabilityCol) .collect() diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala index ec85e0d151e0..0eba34fda622 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala @@ -21,6 +21,7 @@ import scala.util.Random import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.MLTestingUtils import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala index c452054bec92..c04dda41eea3 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.util.MLTestingUtils import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{Row, SQLContext} @@ -51,6 +52,9 @@ class MinMaxScalerSuite extends SparkFunSuite with MLlibTestSparkContext { .foreach { case Row(vector1: Vector, vector2: Vector) => assert(vector1.equals(vector2), "Transformed vector is different with expected.") } + + // copied model must have the same parent. + MLTestingUtils.checkCopy(model) } test("MinMaxScaler arguments max must be larger than min") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala index d0ae36b28c7a..30c500f87a76 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.MLTestingUtils import org.apache.spark.mllib.linalg.distributed.RowMatrix import org.apache.spark.mllib.linalg.{Vector, Vectors, DenseMatrix, Matrices} import org.apache.spark.mllib.util.MLlibTestSparkContext @@ -56,6 +57,9 @@ class PCASuite extends SparkFunSuite with MLlibTestSparkContext { .setK(3) .fit(df) + // copied model must have the same parent. 
+ MLTestingUtils.checkCopy(pca) + pca.transform(df).select("pca_features", "expected").collect().foreach { case Row(x: Vector, y: Vector) => assert(x ~== y absTol 1e-5, "Transformed vector is different with expected vector.") diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala index b111036087e6..2d24914cb91f 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala @@ -21,6 +21,7 @@ import org.apache.spark.SparkException import org.apache.spark.SparkFunSuite import org.apache.spark.ml.attribute.{Attribute, NominalAttribute} import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.MLTestingUtils import org.apache.spark.mllib.util.MLlibTestSparkContext class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext { @@ -38,6 +39,10 @@ class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext { .setInputCol("label") .setOutputCol("labelIndex") .fit(df) + + // copied model must have the same parent. + MLTestingUtils.checkCopy(indexer) + val transformed = indexer.transform(df) val attr = Attribute.fromStructField(transformed.schema("labelIndex")) .asInstanceOf[NominalAttribute] diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala index 03120c828ca9..8cb0a2cf14d3 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala @@ -22,6 +22,7 @@ import scala.beans.{BeanInfo, BeanProperty} import org.apache.spark.{Logging, SparkException, SparkFunSuite} import org.apache.spark.ml.attribute._ import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.MLTestingUtils import org.apache.spark.mllib.linalg.{SparseVector, Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD @@ -109,6 +110,10 @@ class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext with L test("Throws error when given RDDs with different size vectors") { val vectorIndexer = getIndexer val model = vectorIndexer.fit(densePoints1) // vectors of length 3 + + // copied model must have the same parent. + MLTestingUtils.checkCopy(model) + model.transform(densePoints1) // should work model.transform(sparsePoints1) // should work intercept[SparkException] { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala index adcda0e623b2..a2e46f202995 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.MLTestingUtils import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ @@ -62,6 +63,9 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext { .setSeed(42L) .fit(docDF) + // copied model must have the same parent. 
+ MLTestingUtils.checkCopy(model) + model.transform(docDF).select("result", "expected").collect().foreach { case Row(vector1: Vector, vector2: Vector) => assert(vector1 ~== vector2 absTol 1E-5, "Transformed vector is different with expected.") diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala index 2e5cfe7027eb..eadc80e0e62b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala @@ -28,6 +28,7 @@ import com.github.fommil.netlib.BLAS.{getInstance => blas} import org.apache.spark.{Logging, SparkException, SparkFunSuite} import org.apache.spark.ml.recommendation.ALS._ +import org.apache.spark.ml.util.MLTestingUtils import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ @@ -374,6 +375,9 @@ class ALSSuite extends SparkFunSuite with MLlibTestSparkContext with Logging { } logInfo(s"Test RMSE is $rmse.") assert(rmse < targetRMSE) + + // copied model must have the same parent. + MLTestingUtils.checkCopy(model) } test("exact rank-1 matrix") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala index 33aa9d0d6234..b092bcd6a7e8 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.ml.regression import org.apache.spark.SparkFunSuite import org.apache.spark.ml.impl.TreeTests +import org.apache.spark.ml.util.MLTestingUtils import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree, DecisionTreeSuite => OldDecisionTreeSuite} @@ -61,6 +62,16 @@ class DecisionTreeRegressorSuite extends SparkFunSuite with MLlibTestSparkContex compareAPIs(categoricalDataPointsRDD, dt, categoricalFeatures) } + test("copied model must have the same parent") { + val categoricalFeatures = Map(0 -> 2, 1-> 2) + val df = TreeTests.setMetadata(categoricalDataPointsRDD, categoricalFeatures, numClasses = 0) + val model = new DecisionTreeRegressor() + .setImpurity("variance") + .setMaxDepth(2) + .setMaxBins(8).fit(df) + MLTestingUtils.checkCopy(model) + } + ///////////////////////////////////////////////////////////////////////////// // Tests of model save/load ///////////////////////////////////////////////////////////////////////////// diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala index dbdce0c9dea5..a68197b59193 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.ml.regression import org.apache.spark.SparkFunSuite import org.apache.spark.ml.impl.TreeTests +import org.apache.spark.ml.util.MLTestingUtils import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT} @@ -82,6 +83,9 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext { 
.setMaxDepth(2) .setMaxIter(2) val model = gbt.fit(df) + + // copied model must have the same parent. + MLTestingUtils.checkCopy(model) val preds = model.transform(df) val predictions = preds.select("prediction").map(_.getDouble(0)) // Checks based on SPARK-8736 (to ensure it is not doing classification) @@ -104,6 +108,7 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext { sc.checkpointDir = None Utils.deleteRecursively(tempDir) + } // TODO: Reinstate test once runWithValidation is implemented SPARK-7132 diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index 21ad8225bd9f..2aaee71ecc73 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.ml.regression import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.MLTestingUtils import org.apache.spark.mllib.linalg.{DenseVector, Vectors} import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext} import org.apache.spark.mllib.util.TestingUtils._ @@ -72,6 +73,10 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { assert(lir.getFitIntercept) assert(lir.getStandardization) val model = lir.fit(dataset) + + // copied model must have the same parent. + MLTestingUtils.checkCopy(model) + model.transform(dataset) .select("label", "prediction") .collect() diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala index 992ce9562434..7b1b3f11481d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.ml.regression import org.apache.spark.SparkFunSuite import org.apache.spark.ml.impl.TreeTests +import org.apache.spark.ml.util.MLTestingUtils import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => OldRandomForest} @@ -91,7 +92,11 @@ class RandomForestRegressorSuite extends SparkFunSuite with MLlibTestSparkContex val categoricalFeatures = Map.empty[Int, Int] val df: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0) - val importances = rf.fit(df).featureImportances + val model = rf.fit(df) + + // copied model must have the same parent. 
+ MLTestingUtils.checkCopy(model) + val importances = model.featureImportances val mostImportantFeature = importances.argmax assert(mostImportantFeature === 1) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala index db64511a7605..aaca08bb61a4 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.ml.tuning import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.util.MLTestingUtils import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.classification.LogisticRegression import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator, RegressionEvaluator} @@ -53,6 +54,10 @@ class CrossValidatorSuite extends SparkFunSuite with MLlibTestSparkContext { .setEvaluator(eval) .setNumFolds(3) val cvModel = cv.fit(dataset) + + // copied model must have the same paren. + MLTestingUtils.checkCopy(cvModel) + val parent = cvModel.bestModel.parent.asInstanceOf[LogisticRegression] assert(parent.getRegParam === 0.001) assert(parent.getMaxIter === 10) diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala new file mode 100644 index 000000000000..d290cc9b06e7 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.util + +import org.apache.spark.ml.Model +import org.apache.spark.ml.param.ParamMap + +object MLTestingUtils { + def checkCopy(model: Model[_]): Unit = { + val copied = model.copy(ParamMap.empty) + .asInstanceOf[Model[_]] + assert(copied.parent.uid == model.parent.uid) + assert(copied.parent == model.parent) + } +} From 7a539ef3b1792764f866fa88c84c78ad59903f21 Mon Sep 17 00:00:00 2001 From: Rosstin Date: Thu, 13 Aug 2015 09:18:39 -0700 Subject: [PATCH 0228/1824] [SPARK-8965] [DOCS] Add ml-guide Python Example: Estimator, Transformer, and Param Added ml-guide Python Example: Estimator, Transformer, and Param /docs/_site/ml-guide.html Author: Rosstin Closes #8081 from Rosstin/SPARK-8965. --- docs/ml-guide.md | 68 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 68 insertions(+) diff --git a/docs/ml-guide.md b/docs/ml-guide.md index b6ca50e98db0..a03ab4356a41 100644 --- a/docs/ml-guide.md +++ b/docs/ml-guide.md @@ -355,6 +355,74 @@ jsc.stop(); {% endhighlight %} +
    +{% highlight python %} +from pyspark import SparkContext +from pyspark.mllib.regression import LabeledPoint +from pyspark.ml.classification import LogisticRegression +from pyspark.ml.param import Param, Params +from pyspark.sql import Row, SQLContext + +sc = SparkContext(appName="SimpleParamsExample") +sqlContext = SQLContext(sc) + +# Prepare training data. +# We use LabeledPoint. +# Spark SQL can convert RDDs of LabeledPoints into DataFrames. +training = sc.parallelize([LabeledPoint(1.0, [0.0, 1.1, 0.1]), + LabeledPoint(0.0, [2.0, 1.0, -1.0]), + LabeledPoint(0.0, [2.0, 1.3, 1.0]), + LabeledPoint(1.0, [0.0, 1.2, -0.5])]) + +# Create a LogisticRegression instance. This instance is an Estimator. +lr = LogisticRegression(maxIter=10, regParam=0.01) +# Print out the parameters, documentation, and any default values. +print "LogisticRegression parameters:\n" + lr.explainParams() + "\n" + +# Learn a LogisticRegression model. This uses the parameters stored in lr. +model1 = lr.fit(training.toDF()) + +# Since model1 is a Model (i.e., a transformer produced by an Estimator), +# we can view the parameters it used during fit(). +# This prints the parameter (name: value) pairs, where names are unique IDs for this +# LogisticRegression instance. +print "Model 1 was fit using parameters: " +print model1.extractParamMap() + +# We may alternatively specify parameters using a Python dictionary as a paramMap +paramMap = {lr.maxIter: 20} +paramMap[lr.maxIter] = 30 # Specify 1 Param, overwriting the original maxIter. +paramMap.update({lr.regParam: 0.1, lr.threshold: 0.55}) # Specify multiple Params. + +# You can combine paramMaps, which are python dictionaries. +paramMap2 = {lr.probabilityCol: "myProbability"} # Change output column name +paramMapCombined = paramMap.copy() +paramMapCombined.update(paramMap2) + +# Now learn a new model using the paramMapCombined parameters. +# paramMapCombined overrides all parameters set earlier via lr.set* methods. +model2 = lr.fit(training.toDF(), paramMapCombined) +print "Model 2 was fit using parameters: " +print model2.extractParamMap() + +# Prepare test data +test = sc.parallelize([LabeledPoint(1.0, [-1.0, 1.5, 1.3]), + LabeledPoint(0.0, [ 3.0, 2.0, -0.1]), + LabeledPoint(1.0, [ 0.0, 2.2, -1.5])]) + +# Make predictions on test data using the Transformer.transform() method. +# LogisticRegression.transform will only use the 'features' column. +# Note that model2.transform() outputs a "myProbability" column instead of the usual +# 'probability' column since we renamed the lr.probabilityCol parameter previously. +prediction = model2.transform(test.toDF()) +selected = prediction.select("features", "label", "myProbability", "prediction") +for row in selected.collect(): + print row + +sc.stop() +{% endhighlight %} +
    + ## Example: Pipeline From 4b70798c96b0a784b85fda461426ec60f609be12 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Thu, 13 Aug 2015 09:31:14 -0700 Subject: [PATCH 0229/1824] [MINOR] [ML] change MultilayerPerceptronClassifierModel to MultilayerPerceptronClassificationModel To follow the naming rule of ML, change `MultilayerPerceptronClassifierModel` to `MultilayerPerceptronClassificationModel` like `DecisionTreeClassificationModel`, `GBTClassificationModel` and so on. Author: Yanbo Liang Closes #8164 from yanboliang/mlp-name. --- .../MultilayerPerceptronClassifier.scala | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala index 8cd2103d7d5e..c15456188658 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala @@ -131,7 +131,7 @@ private object LabelConverter { */ @Experimental class MultilayerPerceptronClassifier(override val uid: String) - extends Predictor[Vector, MultilayerPerceptronClassifier, MultilayerPerceptronClassifierModel] + extends Predictor[Vector, MultilayerPerceptronClassifier, MultilayerPerceptronClassificationModel] with MultilayerPerceptronParams { def this() = this(Identifiable.randomUID("mlpc")) @@ -146,7 +146,7 @@ class MultilayerPerceptronClassifier(override val uid: String) * @param dataset Training dataset * @return Fitted model */ - override protected def train(dataset: DataFrame): MultilayerPerceptronClassifierModel = { + override protected def train(dataset: DataFrame): MultilayerPerceptronClassificationModel = { val myLayers = $(layers) val labels = myLayers.last val lpData = extractLabeledPoints(dataset) @@ -156,13 +156,13 @@ class MultilayerPerceptronClassifier(override val uid: String) FeedForwardTrainer.LBFGSOptimizer.setConvergenceTol($(tol)).setNumIterations($(maxIter)) FeedForwardTrainer.setStackSize($(blockSize)) val mlpModel = FeedForwardTrainer.train(data) - new MultilayerPerceptronClassifierModel(uid, myLayers, mlpModel.weights()) + new MultilayerPerceptronClassificationModel(uid, myLayers, mlpModel.weights()) } } /** * :: Experimental :: - * Classifier model based on the Multilayer Perceptron. + * Classification model based on the Multilayer Perceptron. * Each layer has sigmoid activation function, output layer has softmax. 
* @param uid uid * @param layers array of layer sizes including input and output layers @@ -170,11 +170,11 @@ class MultilayerPerceptronClassifier(override val uid: String) * @return prediction model */ @Experimental -class MultilayerPerceptronClassifierModel private[ml] ( +class MultilayerPerceptronClassificationModel private[ml] ( override val uid: String, layers: Array[Int], weights: Vector) - extends PredictionModel[Vector, MultilayerPerceptronClassifierModel] + extends PredictionModel[Vector, MultilayerPerceptronClassificationModel] with Serializable { private val mlpModel = FeedForwardTopology.multiLayerPerceptron(layers, true).getInstance(weights) @@ -187,7 +187,7 @@ class MultilayerPerceptronClassifierModel private[ml] ( LabelConverter.decodeLabel(mlpModel.predict(features)) } - override def copy(extra: ParamMap): MultilayerPerceptronClassifierModel = { - copyValues(new MultilayerPerceptronClassifierModel(uid, layers, weights), extra) + override def copy(extra: ParamMap): MultilayerPerceptronClassificationModel = { + copyValues(new MultilayerPerceptronClassificationModel(uid, layers, weights), extra) } } From 65fec798ce52ca6b8b0fe14b78a16712778ad04c Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Thu, 13 Aug 2015 10:16:40 -0700 Subject: [PATCH 0230/1824] [MINOR] [DOC] fix mllib pydoc warnings Switch to correct Sphinx syntax. MechCoder Author: Xiangrui Meng Closes #8169 from mengxr/mllib-pydoc-fix. --- python/pyspark/mllib/regression.py | 14 ++++++++++---- python/pyspark/mllib/util.py | 1 + 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py index 5b7afc15ddfb..41946e3674fb 100644 --- a/python/pyspark/mllib/regression.py +++ b/python/pyspark/mllib/regression.py @@ -207,8 +207,10 @@ def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, Train a linear regression model using Stochastic Gradient Descent (SGD). This solves the least squares regression formulation - f(weights) = 1/n ||A weights-y||^2^ - (which is the mean squared error). + + f(weights) = 1/(2n) ||A weights - y||^2, + + which is the mean squared error. Here the data matrix has n rows, and the input RDD holds the set of rows of A, each with its corresponding right hand side label y. See also the documentation for the precise formulation. @@ -334,7 +336,9 @@ def train(cls, data, iterations=100, step=1.0, regParam=0.01, Stochastic Gradient Descent. This solves the l1-regularized least squares regression formulation - f(weights) = 1/2n ||A weights-y||^2^ + regParam ||weights||_1 + + f(weights) = 1/(2n) ||A weights - y||^2 + regParam ||weights||_1. + Here the data matrix has n rows, and the input RDD holds the set of rows of A, each with its corresponding right hand side label y. See also the documentation for the precise formulation. @@ -451,7 +455,9 @@ def train(cls, data, iterations=100, step=1.0, regParam=0.01, Stochastic Gradient Descent. This solves the l2-regularized least squares regression formulation - f(weights) = 1/2n ||A weights-y||^2^ + regParam/2 ||weights||^2^ + + f(weights) = 1/(2n) ||A weights - y||^2 + regParam/2 ||weights||^2. + Here the data matrix has n rows, and the input RDD holds the set of rows of A, each with its corresponding right hand side label y. See also the documentation for the precise formulation. 
diff --git a/python/pyspark/mllib/util.py b/python/pyspark/mllib/util.py index 916de2d6fcdb..10a1e4b3eb0f 100644 --- a/python/pyspark/mllib/util.py +++ b/python/pyspark/mllib/util.py @@ -300,6 +300,7 @@ def generateLinearInput(intercept, weights, xMean, xVariance, :param: seed Random Seed :param: eps Used to scale the noise. If eps is set high, the amount of gaussian noise added is more. + Returns a list of LabeledPoints of length nPoints """ weights = [float(weight) for weight in weights] From 8815ba2f674dbb18eb499216df9942b411e10daa Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Thu, 13 Aug 2015 11:31:10 -0700 Subject: [PATCH 0231/1824] [SPARK-9649] Fix MasterSuite, third time's a charm This particular test did not load the default configurations, so it continued to start the REST server, which causes port bind exceptions. --- .../test/scala/org/apache/spark/deploy/master/MasterSuite.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala index 20d0201a364a..242bf4b5566e 100644 --- a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala @@ -40,6 +40,7 @@ class MasterSuite extends SparkFunSuite with Matchers with Eventually with Priva conf.set("spark.deploy.recoveryMode", "CUSTOM") conf.set("spark.deploy.recoveryMode.factory", classOf[CustomRecoveryModeFactory].getCanonicalName) + conf.set("spark.master.rest.enabled", "false") val instantiationAttempts = CustomRecoveryModeFactory.instantiationAttempts From 864de8eaf4b6ad5c9099f6f29e251c56b029f631 Mon Sep 17 00:00:00 2001 From: MechCoder Date: Thu, 13 Aug 2015 13:42:35 -0700 Subject: [PATCH 0232/1824] [SPARK-9661] [MLLIB] [ML] Java compatibility I skimmed through the docs for various instances of Object and replaced them with Java-compatible versions of the same. 1. Some methods in LDAModel. 2. runMiniBatchSGD 3. kolmogorovSmirnovTest Author: MechCoder Closes #8126 from MechCoder/java_incop.
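The recurring pattern behind these Java-friendly additions is a thin overload that accepts Java collection types and delegates to the existing RDD-based Scala method. A minimal sketch of that pattern is shown below; `MyModel` and `score` are hypothetical names used only for illustration, not identifiers from this patch, and the delegation trick (casting the boxed-key JavaPairRDD back to a primitive-keyed RDD) is the same one the diffs below apply in LDAModel and Statistics.

import org.apache.spark.api.java.JavaPairRDD
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.rdd.RDD

// Hypothetical model class, used only to illustrate the wrapper pattern.
class MyModel extends Serializable {
  // Existing Scala API, keyed by primitive Long.
  def score(docs: RDD[(Long, Vector)]): Double = docs.count().toDouble // placeholder logic

  // Java-friendly overload: accept boxed java.lang.Long keys via JavaPairRDD,
  // then cast back to the primitive-keyed RDD expected by the Scala method.
  def score(docs: JavaPairRDD[java.lang.Long, Vector]): Double =
    score(docs.rdd.asInstanceOf[RDD[(Long, Vector)]])
}

In the patch itself, `LocalLDAModel.logLikelihood`, `logPerplexity`, `topicDistributions`, and `Statistics.chiSqTest`/`kolmogorovSmirnovTest` gain overloads of exactly this shape, as the diffs below show.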
--- .../spark/mllib/clustering/LDAModel.scala | 27 +++++++++++++++++-- .../apache/spark/mllib/stat/Statistics.scala | 16 ++++++++++- .../spark/mllib/clustering/JavaLDASuite.java | 24 +++++++++++++++++ .../spark/mllib/stat/JavaStatisticsSuite.java | 22 +++++++++++++++ .../spark/mllib/clustering/LDASuite.scala | 13 +++++++++ 5 files changed, 99 insertions(+), 3 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala index 5dc637ebdc13..f31949f13a4c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala @@ -26,7 +26,7 @@ import org.json4s.jackson.JsonMethods._ import org.apache.spark.SparkContext import org.apache.spark.annotation.Experimental -import org.apache.spark.api.java.JavaPairRDD +import org.apache.spark.api.java.{JavaPairRDD, JavaRDD} import org.apache.spark.graphx.{Edge, EdgeContext, Graph, VertexId} import org.apache.spark.mllib.linalg.{Matrices, Matrix, Vector, Vectors} import org.apache.spark.mllib.util.{Loader, Saveable} @@ -228,6 +228,11 @@ class LocalLDAModel private[clustering] ( docConcentration, topicConcentration, topicsMatrix.toBreeze.toDenseMatrix, gammaShape, k, vocabSize) + /** Java-friendly version of [[logLikelihood]] */ + def logLikelihood(documents: JavaPairRDD[java.lang.Long, Vector]): Double = { + logLikelihood(documents.rdd.asInstanceOf[RDD[(Long, Vector)]]) + } + /** * Calculate an upper bound bound on perplexity. (Lower is better.) * See Equation (16) in original Online LDA paper. @@ -242,6 +247,11 @@ class LocalLDAModel private[clustering] ( -logLikelihood(documents) / corpusTokenCount } + /** Java-friendly version of [[logPerplexity]] */ + def logPerplexity(documents: JavaPairRDD[java.lang.Long, Vector]): Double = { + logPerplexity(documents.rdd.asInstanceOf[RDD[(Long, Vector)]]) + } + /** * Estimate the variational likelihood bound of from `documents`: * log p(documents) >= E_q[log p(documents)] - E_q[log q(documents)] @@ -341,8 +351,14 @@ class LocalLDAModel private[clustering] ( } } -} + /** Java-friendly version of [[topicDistributions]] */ + def topicDistributions( + documents: JavaPairRDD[java.lang.Long, Vector]): JavaPairRDD[java.lang.Long, Vector] = { + val distributions = topicDistributions(documents.rdd.asInstanceOf[RDD[(Long, Vector)]]) + JavaPairRDD.fromRDD(distributions.asInstanceOf[RDD[(java.lang.Long, Vector)]]) + } +} @Experimental object LocalLDAModel extends Loader[LocalLDAModel] { @@ -657,6 +673,13 @@ class DistributedLDAModel private[clustering] ( } } + /** Java-friendly version of [[topTopicsPerDocument]] */ + def javaTopTopicsPerDocument( + k: Int): JavaRDD[(java.lang.Long, Array[Int], Array[java.lang.Double])] = { + val topics = topTopicsPerDocument(k) + topics.asInstanceOf[RDD[(java.lang.Long, Array[Int], Array[java.lang.Double])]].toJavaRDD() + } + // TODO: // override def topicDistributions(documents: RDD[(Long, Vector)]): RDD[(Long, Vector)] = ??? 
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala index f84502919e38..24fe48cb8f71 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala @@ -20,7 +20,7 @@ package org.apache.spark.mllib.stat import scala.annotation.varargs import org.apache.spark.annotation.Experimental -import org.apache.spark.api.java.JavaRDD +import org.apache.spark.api.java.{JavaRDD, JavaDoubleRDD} import org.apache.spark.mllib.linalg.distributed.RowMatrix import org.apache.spark.mllib.linalg.{Matrix, Vector} import org.apache.spark.mllib.regression.LabeledPoint @@ -178,6 +178,9 @@ object Statistics { ChiSqTest.chiSquaredFeatures(data) } + /** Java-friendly version of [[chiSqTest()]] */ + def chiSqTest(data: JavaRDD[LabeledPoint]): Array[ChiSqTestResult] = chiSqTest(data.rdd) + /** * Conduct the two-sided Kolmogorov-Smirnov (KS) test for data sampled from a * continuous distribution. By comparing the largest difference between the empirical cumulative @@ -212,4 +215,15 @@ object Statistics { : KolmogorovSmirnovTestResult = { KolmogorovSmirnovTest.testOneSample(data, distName, params: _*) } + + /** Java-friendly version of [[kolmogorovSmirnovTest()]] */ + @varargs + def kolmogorovSmirnovTest( + data: JavaDoubleRDD, + distName: String, + params: java.lang.Double*): KolmogorovSmirnovTestResult = { + val javaParams = params.asInstanceOf[Seq[Double]] + KolmogorovSmirnovTest.testOneSample(data.rdd.asInstanceOf[RDD[Double]], + distName, javaParams: _*) + } } diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java index d272a42c8576..427be9430d82 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java @@ -124,6 +124,10 @@ public Boolean call(Tuple2 tuple2) { } }); assertEquals(topicDistributions.count(), nonEmptyCorpus.count()); + + // Check: javaTopTopicsPerDocuments + JavaRDD> topTopics = + model.javaTopTopicsPerDocument(3); } @Test @@ -160,11 +164,31 @@ public void OnlineOptimizerCompatibility() { assertEquals(roundedLocalTopicSummary.length, k); } + @Test + public void localLdaMethods() { + JavaRDD> docs = sc.parallelize(toyData, 2); + JavaPairRDD pairedDocs = JavaPairRDD.fromJavaRDD(docs); + + // check: topicDistributions + assertEquals(toyModel.topicDistributions(pairedDocs).count(), pairedDocs.count()); + + // check: logPerplexity + double logPerplexity = toyModel.logPerplexity(pairedDocs); + + // check: logLikelihood. 
+ ArrayList> docsSingleWord = new ArrayList>(); + docsSingleWord.add(new Tuple2(Long.valueOf(0), Vectors.dense(1.0, 0.0, 0.0))); + JavaPairRDD single = JavaPairRDD.fromJavaRDD(sc.parallelize(docsSingleWord)); + double logLikelihood = toyModel.logLikelihood(single); + } + private static int tinyK = LDASuite$.MODULE$.tinyK(); private static int tinyVocabSize = LDASuite$.MODULE$.tinyVocabSize(); private static Matrix tinyTopics = LDASuite$.MODULE$.tinyTopics(); private static Tuple2[] tinyTopicDescription = LDASuite$.MODULE$.tinyTopicDescription(); private JavaPairRDD corpus; + private LocalLDAModel toyModel = LDASuite$.MODULE$.toyModel(); + private ArrayList> toyData = LDASuite$.MODULE$.javaToyData(); } diff --git a/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java b/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java index 62f7f26b7c98..eb4e3698624b 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java @@ -27,7 +27,12 @@ import static org.junit.Assert.assertEquals; import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaDoubleRDD; import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.stat.test.ChiSqTestResult; +import org.apache.spark.mllib.stat.test.KolmogorovSmirnovTestResult; public class JavaStatisticsSuite implements Serializable { private transient JavaSparkContext sc; @@ -53,4 +58,21 @@ public void testCorr() { // Check default method assertEquals(corr1, corr2); } + + @Test + public void kolmogorovSmirnovTest() { + JavaDoubleRDD data = sc.parallelizeDoubles(Lists.newArrayList(0.2, 1.0, -1.0, 2.0)); + KolmogorovSmirnovTestResult testResult1 = Statistics.kolmogorovSmirnovTest(data, "norm"); + KolmogorovSmirnovTestResult testResult2 = Statistics.kolmogorovSmirnovTest( + data, "norm", 0.0, 1.0); + } + + @Test + public void chiSqTest() { + JavaRDD data = sc.parallelize(Lists.newArrayList( + new LabeledPoint(0.0, Vectors.dense(0.1, 2.3)), + new LabeledPoint(1.0, Vectors.dense(1.5, 5.1)), + new LabeledPoint(0.0, Vectors.dense(2.4, 8.1)))); + ChiSqTestResult[] testResults = Statistics.chiSqTest(data); + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala index ce6a8eb8e8c4..926185e90bcf 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.mllib.clustering +import java.util.{ArrayList => JArrayList} + import breeze.linalg.{DenseMatrix => BDM, argtopk, max, argmax} import org.apache.spark.SparkFunSuite @@ -575,6 +577,17 @@ private[clustering] object LDASuite { Vectors.sparse(6, Array(4, 5), Array(1, 1)) ).zipWithIndex.map { case (wordCounts, docId) => (docId.toLong, wordCounts) } + /** Used in the Java Test Suite */ + def javaToyData: JArrayList[(java.lang.Long, Vector)] = { + val javaData = new JArrayList[(java.lang.Long, Vector)] + var i = 0 + while (i < toyData.size) { + javaData.add((toyData(i)._1, toyData(i)._2)) + i += 1 + } + javaData + } + def toyModel: LocalLDAModel = { val k = 2 val vocabSize = 6 From a8d2f4c5f92210a09c846711bd7cc89a43e2edd2 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 13 Aug 2015 14:03:55 -0700 Subject: 
[PATCH 0233/1824] [SPARK-9942] [PYSPARK] [SQL] ignore exceptions while trying to import pandas If pandas is broken (it can't be imported, or importing it raises exceptions other than ImportError), pyspark can't be imported either; we should ignore all such exceptions. Author: Davies Liu Closes #8173 from davies/fix_pandas. --- python/pyspark/sql/context.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index 917de24f3536..0ef46c44644a 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -39,7 +39,7 @@ try: import pandas has_pandas = True -except ImportError: +except Exception: has_pandas = False __all__ = ["SQLContext", "HiveContext", "UDFRegistration"] From c2520f501a200cf794bbe5dc9c385100f518d020 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Thu, 13 Aug 2015 16:07:03 -0700 Subject: [PATCH 0234/1824] [SPARK-9935] [SQL] EqualNullSafe not processed in ORC https://issues.apache.org/jira/browse/SPARK-9935 Author: hyukjinkwon Closes #8163 from HyukjinKwon/master. --- .../scala/org/apache/spark/sql/hive/orc/OrcFilters.scala | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala index 86142e5d66f3..b3d9f7f71a27 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala @@ -107,6 +107,11 @@ private[orc] object OrcFilters extends Logging { .filter(isSearchableLiteral) .map(builder.equals(attribute, _)) + case EqualNullSafe(attribute, value) => + Option(value) + .filter(isSearchableLiteral) + .map(builder.nullSafeEquals(attribute, _)) + case LessThan(attribute, value) => Option(value) .filter(isSearchableLiteral) From 6c5858bc65c8a8602422b46bfa9cf0a1fb296b88 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Thu, 13 Aug 2015 16:52:17 -0700 Subject: [PATCH 0235/1824] [SPARK-9922] [ML] rename StringIndexerInverse to IndexToString What `StringIndexerInverse` does is not strictly associated with `StringIndexer`, and the name does not clearly describe the transformation. Renaming to `IndexToString` might be better. ~~I also changed `invert` to `inverse` without arguments. `inputCol` and `outputCol` could be set after.~~ I also removed `invert`. jkbradley holdenk Author: Xiangrui Meng Closes #8152 from mengxr/SPARK-9922.
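Not part of the patch: a minimal usage sketch of the renamed pair after this change lands, assuming the 1.5-era `spark.ml` API and an existing `sqlContext`; the toy data and column names are illustrative only, not taken from the diff below.

```scala
import org.apache.spark.ml.feature.{IndexToString, StringIndexer}

// Toy frame with a string column to index (illustrative data).
val df = sqlContext.createDataFrame(Seq(
  (0, "a"), (1, "b"), (2, "c"), (3, "a")
)).toDF("id", "category")

// Fit a StringIndexer: most frequent label gets index 0.
val indexerModel = new StringIndexer()
  .setInputCol("category")
  .setOutputCol("categoryIndex")
  .fit(df)
val indexed = indexerModel.transform(df)

// IndexToString maps indices back to strings, either from the column's ML
// attributes or, as here, from an explicitly supplied label array.
val converter = new IndexToString()
  .setInputCol("categoryIndex")
  .setOutputCol("originalCategory")
  .setLabels(indexerModel.labels)

converter.transform(indexed)
  .select("id", "categoryIndex", "originalCategory")
  .show()
```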
--- .../spark/ml/feature/StringIndexer.scala | 34 +++++-------- .../spark/ml/feature/StringIndexerSuite.scala | 50 +++++++++++++------ 2 files changed, 48 insertions(+), 36 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index 9e4b0f0add61..9f6e7b6b6b27 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -24,7 +24,7 @@ import org.apache.spark.ml.attribute.{Attribute, NominalAttribute} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.Transformer -import org.apache.spark.ml.util.{Identifiable, MetadataUtils} +import org.apache.spark.ml.util.Identifiable import org.apache.spark.sql.DataFrame import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{DoubleType, NumericType, StringType, StructType} @@ -59,6 +59,8 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha * If the input column is numeric, we cast it to string and index the string values. * The indices are in [0, numLabels), ordered by label frequencies. * So the most frequent label gets index 0. + * + * @see [[IndexToString]] for the inverse transformation */ @Experimental class StringIndexer(override val uid: String) extends Estimator[StringIndexerModel] @@ -170,34 +172,24 @@ class StringIndexerModel private[ml] ( val copied = new StringIndexerModel(uid, labels) copyValues(copied, extra).setParent(parent) } - - /** - * Return a model to perform the inverse transformation. - * Note: By default we keep the original columns during this transformation, so the inverse - * should only be used on new columns such as predicted labels. - */ - def invert(inputCol: String, outputCol: String): StringIndexerInverse = { - new StringIndexerInverse() - .setInputCol(inputCol) - .setOutputCol(outputCol) - .setLabels(labels) - } } /** * :: Experimental :: - * Transform a provided column back to the original input types using either the metadata - * on the input column, or if provided using the labels supplied by the user. - * Note: By default we keep the original columns during this transformation, - * so the inverse should only be used on new columns such as predicted labels. + * A [[Transformer]] that maps a column of string indices back to a new column of corresponding + * string values using either the ML attributes of the input column, or if provided using the labels + * supplied by the user. + * All original columns are kept during transformation. 
+ * + * @see [[StringIndexer]] for converting strings into indices */ @Experimental -class StringIndexerInverse private[ml] ( +class IndexToString private[ml] ( override val uid: String) extends Transformer with HasInputCol with HasOutputCol { def this() = - this(Identifiable.randomUID("strIdxInv")) + this(Identifiable.randomUID("idxToStr")) /** @group setParam */ def setInputCol(value: String): this.type = set(inputCol, value) @@ -257,7 +249,7 @@ class StringIndexerInverse private[ml] ( } val indexer = udf { index: Double => val idx = index.toInt - if (0 <= idx && idx < values.size) { + if (0 <= idx && idx < values.length) { values(idx) } else { throw new SparkException(s"Unseen index: $index ??") @@ -268,7 +260,7 @@ class StringIndexerInverse private[ml] ( indexer(dataset($(inputCol)).cast(DoubleType)).as(outputColName)) } - override def copy(extra: ParamMap): StringIndexerInverse = { + override def copy(extra: ParamMap): IndexToString = { defaultCopy(extra) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala index 2d24914cb91f..fa918ce64877 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala @@ -17,12 +17,13 @@ package org.apache.spark.ml.feature -import org.apache.spark.SparkException -import org.apache.spark.SparkFunSuite +import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.attribute.{Attribute, NominalAttribute} import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.MLTestingUtils import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.Row +import org.apache.spark.sql.functions.col class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext { @@ -53,19 +54,6 @@ class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext { // a -> 0, b -> 2, c -> 1 val expected = Set((0, 0.0), (1, 2.0), (2, 1.0), (3, 0.0), (4, 0.0), (5, 1.0)) assert(output === expected) - // convert reverse our transform - val reversed = indexer.invert("labelIndex", "label2") - .transform(transformed) - .select("id", "label2") - assert(df.collect().map(r => (r.getInt(0), r.getString(1))).toSet === - reversed.collect().map(r => (r.getInt(0), r.getString(1))).toSet) - // Check invert using only metadata - val inverse2 = new StringIndexerInverse() - .setInputCol("labelIndex") - .setOutputCol("label2") - val reversed2 = inverse2.transform(transformed).select("id", "label2") - assert(df.collect().map(r => (r.getInt(0), r.getString(1))).toSet === - reversed2.collect().map(r => (r.getInt(0), r.getString(1))).toSet) } test("StringIndexerUnseen") { @@ -125,4 +113,36 @@ class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext { val df = sqlContext.range(0L, 10L) assert(indexerModel.transform(df).eq(df)) } + + test("IndexToString params") { + val idxToStr = new IndexToString() + ParamsSuite.checkParams(idxToStr) + } + + test("IndexToString.transform") { + val labels = Array("a", "b", "c") + val df0 = sqlContext.createDataFrame(Seq( + (0, "a"), (1, "b"), (2, "c"), (0, "a") + )).toDF("index", "expected") + + val idxToStr0 = new IndexToString() + .setInputCol("index") + .setOutputCol("actual") + .setLabels(labels) + idxToStr0.transform(df0).select("actual", "expected").collect().foreach { + case Row(actual, expected) => + assert(actual === expected) + } + + val attr = 
NominalAttribute.defaultAttr.withValues(labels) + val df1 = df0.select(col("index").as("indexWithAttr", attr.toMetadata()), col("expected")) + + val idxToStr1 = new IndexToString() + .setInputCol("indexWithAttr") + .setOutputCol("actual") + idxToStr1.transform(df1).select("actual", "expected").collect().foreach { + case Row(actual, expected) => + assert(actual === expected) + } + } } From 693949ba4096c01a0b41da2542ff316823464a16 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 13 Aug 2015 17:33:37 -0700 Subject: [PATCH 0236/1824] [SPARK-8976] [PYSPARK] fix open mode in python3 This bug only happens on Python 3 and Windows. I tested this manually with Python 3 and with the Python daemon disabled; there is no unit test yet. Author: Davies Liu Closes #8181 from davies/open_mode. --- python/pyspark/worker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 93df9002be37..42c2f8b75933 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -146,5 +146,5 @@ def process(): java_port = int(sys.stdin.readline()) sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock.connect(("127.0.0.1", java_port)) - sock_file = sock.makefile("a+", 65536) + sock_file = sock.makefile("rwb", 65536) main(sock_file, sock_file) From c50f97dafd2d5bf5a8351efcc1c8d3e2b87efc72 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 13 Aug 2015 17:35:11 -0700 Subject: [PATCH 0237/1824] [SPARK-9943] [SQL] deserialized UnsafeHashedRelation should be serializable When free memory in an executor runs low, cached broadcast objects need to be serialized to disk, but currently a deserialized UnsafeHashedRelation can't be serialized again and fails with an NPE. This PR fixes that. cc rxin Author: Davies Liu Closes #8174 from davies/serialize_hashed.
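Not part of the patch: a rough sketch of the write/read/write round trip the fix is about. It mirrors, but is not, the regression test added in the diff below; `relation` stands for any Java-serializable object such as a broadcast hashed relation, and all names here are illustrative.

```scala
import java.io._

// Serialize an object, deserialize it, then serialize the deserialized copy again.
// Before this fix, the second pass on a deserialized UnsafeHashedRelation hit an NPE.
def serializeTwice(relation: java.io.Serializable): (Array[Byte], Array[Byte]) = {
  def toBytes(o: AnyRef): Array[Byte] = {
    val bos = new ByteArrayOutputStream()
    val oos = new ObjectOutputStream(bos)
    oos.writeObject(o)
    oos.flush()
    bos.toByteArray
  }
  val firstPass = toBytes(relation)                    // serialize the freshly built object
  val reread = new ObjectInputStream(
    new ByteArrayInputStream(firstPass)).readObject()  // deserialize it
  (firstPass, toBytes(reread))                         // re-serialize the deserialized copy
}

// Usage (illustrative): both passes should succeed, and the actual suite additionally
// asserts the two byte arrays are equal, relying on a stable map iteration order.
// val (a, b) = serializeTwice(hashedRelation)
// assert(java.util.Arrays.equals(a, b))
```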
--- .../sql/execution/joins/HashedRelation.scala | 93 ++++++++++++------- .../execution/joins/HashedRelationSuite.scala | 14 +++ 2 files changed, 74 insertions(+), 33 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index ea02076b41a6..6c0196c21a0d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.execution.SparkSqlSerializer import org.apache.spark.sql.execution.metric.LongSQLMetric import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.map.BytesToBytesMap -import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager} +import org.apache.spark.unsafe.memory.{MemoryLocation, ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager} import org.apache.spark.util.Utils import org.apache.spark.util.collection.CompactBuffer import org.apache.spark.{SparkConf, SparkEnv} @@ -247,40 +247,67 @@ private[joins] final class UnsafeHashedRelation( } override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException { - out.writeInt(hashTable.size()) - - val iter = hashTable.entrySet().iterator() - while (iter.hasNext) { - val entry = iter.next() - val key = entry.getKey - val values = entry.getValue - - // write all the values as single byte array - var totalSize = 0L - var i = 0 - while (i < values.length) { - totalSize += values(i).getSizeInBytes + 4 + 4 - i += 1 + if (binaryMap != null) { + // This could happen when a cached broadcast object need to be dumped into disk to free memory + out.writeInt(binaryMap.numElements()) + + var buffer = new Array[Byte](64) + def write(addr: MemoryLocation, length: Int): Unit = { + if (buffer.length < length) { + buffer = new Array[Byte](length) + } + Platform.copyMemory(addr.getBaseObject, addr.getBaseOffset, + buffer, Platform.BYTE_ARRAY_OFFSET, length) + out.write(buffer, 0, length) } - assert(totalSize < Integer.MAX_VALUE, "values are too big") - - // [key size] [values size] [key bytes] [values bytes] - out.writeInt(key.getSizeInBytes) - out.writeInt(totalSize.toInt) - out.write(key.getBytes) - i = 0 - while (i < values.length) { - // [num of fields] [num of bytes] [row bytes] - // write the integer in native order, so they can be read by UNSAFE.getInt() - if (ByteOrder.nativeOrder() == ByteOrder.BIG_ENDIAN) { - out.writeInt(values(i).numFields()) - out.writeInt(values(i).getSizeInBytes) - } else { - out.writeInt(Integer.reverseBytes(values(i).numFields())) - out.writeInt(Integer.reverseBytes(values(i).getSizeInBytes)) + + val iter = binaryMap.iterator() + while (iter.hasNext) { + val loc = iter.next() + // [key size] [values size] [key bytes] [values bytes] + out.writeInt(loc.getKeyLength) + out.writeInt(loc.getValueLength) + write(loc.getKeyAddress, loc.getKeyLength) + write(loc.getValueAddress, loc.getValueLength) + } + + } else { + assert(hashTable != null) + out.writeInt(hashTable.size()) + + val iter = hashTable.entrySet().iterator() + while (iter.hasNext) { + val entry = iter.next() + val key = entry.getKey + val values = entry.getValue + + // write all the values as single byte array + var totalSize = 0L + var i = 0 + while (i < values.length) { + totalSize += values(i).getSizeInBytes + 4 + 4 + i += 1 + } + assert(totalSize < Integer.MAX_VALUE, "values are too big") + + // 
[key size] [values size] [key bytes] [values bytes] + out.writeInt(key.getSizeInBytes) + out.writeInt(totalSize.toInt) + out.write(key.getBytes) + i = 0 + while (i < values.length) { + // [num of fields] [num of bytes] [row bytes] + // write the integer in native order, so they can be read by UNSAFE.getInt() + if (ByteOrder.nativeOrder() == ByteOrder.BIG_ENDIAN) { + out.writeInt(values(i).numFields()) + out.writeInt(values(i).getSizeInBytes) + } else { + out.writeInt(Integer.reverseBytes(values(i).numFields())) + out.writeInt(Integer.reverseBytes(values(i).getSizeInBytes)) + } + out.write(values(i).getBytes) + i += 1 } - out.write(values(i).getBytes) - i += 1 } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala index c635b2d51f46..d33a967093ca 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala @@ -102,6 +102,14 @@ class HashedRelationSuite extends SparkFunSuite { assert(hashed2.get(toUnsafe(InternalRow(10))) === null) assert(hashed2.get(unsafeData(2)) === data2) assert(numDataRows.value.value === data.length) + + val os2 = new ByteArrayOutputStream() + val out2 = new ObjectOutputStream(os2) + hashed2.asInstanceOf[UnsafeHashedRelation].writeExternal(out2) + out2.flush() + // This depends on that the order of items in BytesToBytesMap.iterator() is exactly the same + // as they are inserted + assert(java.util.Arrays.equals(os2.toByteArray, os.toByteArray)) } test("test serialization empty hash map") { @@ -119,5 +127,11 @@ class HashedRelationSuite extends SparkFunSuite { val toUnsafe = UnsafeProjection.create(schema) val row = toUnsafe(InternalRow(0)) assert(hashed2.get(row) === null) + + val os2 = new ByteArrayOutputStream() + val out2 = new ObjectOutputStream(os2) + hashed2.writeExternal(out2) + out2.flush() + assert(java.util.Arrays.equals(os2.toByteArray, os.toByteArray)) } } From 8187b3ae477e2b2987ae9acc5368d57b1d5653b2 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Thu, 13 Aug 2015 17:42:01 -0700 Subject: [PATCH 0238/1824] [SPARK-9580] [SQL] Replace singletons in SQL tests A fundamental limitation of the existing SQL tests is that *there is simply no way to create your own `SparkContext`*. This is a serious limitation because the user may wish to use a different master or config. As a case in point, `BroadcastJoinSuite` is entirely commented out because there is no way to make it pass with the existing infrastructure. This patch removes the singletons `TestSQLContext` and `TestData`, and instead introduces a `SharedSQLContext` that starts a context per suite. Unfortunately the singletons were so ingrained in the SQL tests that this patch necessarily needed to touch *all* the SQL test files. [Review on Reviewable](https://reviewable.io/reviews/apache/spark/8111) Author: Andrew Or Closes #8111 from andrewor14/sql-tests-refactor. 
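Not part of the patch: a rough sketch of the per-suite context pattern this change introduces, assuming ScalaTest's `BeforeAndAfterAll`. The real trait is `SharedSQLContext` in the diff below; the trait name, member names, and configuration here are illustrative, and the sketch omits details such as test-data loading.

```scala
package org.apache.spark.sql.test  // SQLImplicits is private[sql], so the sketch lives under org.apache.spark.sql

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.{SQLContext, SQLImplicits}
import org.scalatest.{BeforeAndAfterAll, Suite}

// Each suite gets its own SparkContext/SQLContext, so it can choose its own master and conf,
// instead of sharing a global TestSQLContext singleton.
trait SharedContextSketch extends BeforeAndAfterAll { self: Suite =>

  @transient private var _ctx: SQLContext = _
  protected def ctx: SQLContext = _ctx

  // Suites can `import testImplicits._` for toDF, $"col", etc.
  protected object testImplicits extends SQLImplicits {
    protected override def _sqlContext: SQLContext = ctx
  }

  override protected def beforeAll(): Unit = {
    super.beforeAll()
    val conf = new SparkConf().setMaster("local[2]").setAppName(getClass.getSimpleName)
    _ctx = new SQLContext(new SparkContext(conf))
  }

  override protected def afterAll(): Unit = {
    try {
      if (_ctx != null) {
        _ctx.sparkContext.stop()
        _ctx = null
      }
    } finally {
      super.afterAll()
    }
  }
}
```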
--- project/MimaExcludes.scala | 10 + project/SparkBuild.scala | 16 +- .../analysis/AnalysisErrorSuite.scala | 6 +- .../org/apache/spark/sql/SQLContext.scala | 97 +----- .../org/apache/spark/sql/SQLImplicits.scala | 123 ++++++++ .../spark/sql/JavaApplySchemaSuite.java | 10 +- .../apache/spark/sql/JavaDataFrameSuite.java | 39 +-- .../org/apache/spark/sql/JavaUDFSuite.java | 10 +- .../spark/sql/sources/JavaSaveLoadSuite.java | 15 +- .../apache/spark/sql/CachedTableSuite.scala | 14 +- .../spark/sql/ColumnExpressionSuite.scala | 37 +-- .../spark/sql/DataFrameAggregateSuite.scala | 10 +- .../spark/sql/DataFrameFunctionsSuite.scala | 18 +- .../spark/sql/DataFrameImplicitsSuite.scala | 6 +- .../apache/spark/sql/DataFrameJoinSuite.scala | 10 +- .../spark/sql/DataFrameNaFunctionsSuite.scala | 6 +- .../apache/spark/sql/DataFrameStatSuite.scala | 19 +- .../org/apache/spark/sql/DataFrameSuite.scala | 12 +- .../spark/sql/DataFrameTungstenSuite.scala | 8 +- .../apache/spark/sql/DateFunctionsSuite.scala | 13 +- .../org/apache/spark/sql/JoinSuite.scala | 47 ++- .../apache/spark/sql/JsonFunctionsSuite.scala | 6 +- .../apache/spark/sql/ListTablesSuite.scala | 15 +- .../spark/sql/MathExpressionsSuite.scala | 28 +- .../org/apache/spark/sql/QueryTest.scala | 6 - .../scala/org/apache/spark/sql/RowSuite.scala | 7 +- .../org/apache/spark/sql/SQLConfSuite.scala | 19 +- .../apache/spark/sql/SQLContextSuite.scala | 13 +- .../org/apache/spark/sql/SQLQuerySuite.scala | 29 +- .../sql/ScalaReflectionRelationSuite.scala | 17 +- .../apache/spark/sql/SerializationSuite.scala | 9 +- .../spark/sql/StringFunctionsSuite.scala | 7 +- .../scala/org/apache/spark/sql/TestData.scala | 197 ------------ .../scala/org/apache/spark/sql/UDFSuite.scala | 39 ++- .../spark/sql/UserDefinedTypeSuite.scala | 9 +- .../columnar/InMemoryColumnarQuerySuite.scala | 14 +- .../columnar/PartitionBatchPruningSuite.scala | 31 +- .../spark/sql/execution/ExchangeSuite.scala | 3 +- .../spark/sql/execution/PlannerSuite.scala | 39 +-- .../execution/RowFormatConvertersSuite.scala | 22 +- .../spark/sql/execution/SortSuite.scala | 3 +- .../spark/sql/execution/SparkPlanTest.scala | 36 +-- .../sql/execution/TungstenSortSuite.scala | 21 +- .../UnsafeFixedWidthAggregationMapSuite.scala | 14 +- .../UnsafeKVExternalSorterSuite.scala | 7 +- .../TungstenAggregationIteratorSuite.scala | 7 +- .../datasources/json/JsonSuite.scala | 42 ++- .../datasources/json/TestJsonData.scala | 40 ++- .../ParquetAvroCompatibilitySuite.scala | 14 +- .../parquet/ParquetCompatibilityTest.scala | 10 +- .../parquet/ParquetFilterSuite.scala | 8 +- .../datasources/parquet/ParquetIOSuite.scala | 6 +- .../ParquetPartitionDiscoverySuite.scala | 12 +- .../ParquetProtobufCompatibilitySuite.scala | 7 +- .../parquet/ParquetQuerySuite.scala | 43 ++- .../parquet/ParquetSchemaSuite.scala | 6 +- .../datasources/parquet/ParquetTest.scala | 15 +- .../ParquetThriftCompatibilitySuite.scala | 8 +- .../sql/execution/debug/DebuggingSuite.scala | 6 +- .../execution/joins/HashedRelationSuite.scala | 10 +- .../sql/execution/joins/InnerJoinSuite.scala | 248 ++++++++------- .../sql/execution/joins/OuterJoinSuite.scala | 125 ++++---- .../sql/execution/joins/SemiJoinSuite.scala | 113 +++---- .../execution/metric/SQLMetricsSuite.scala | 62 ++-- .../sql/execution/ui/SQLListenerSuite.scala | 14 +- .../org/apache/spark/sql/jdbc/JDBCSuite.scala | 9 +- .../spark/sql/jdbc/JDBCWriteSuite.scala | 20 +- .../sources/CreateTableAsSelectSuite.scala | 18 +- .../sql/sources/DDLSourceLoadSuite.scala | 3 +- 
.../spark/sql/sources/DDLTestSuite.scala | 11 +- .../spark/sql/sources/DataSourceTest.scala | 17 +- .../spark/sql/sources/FilteredScanSuite.scala | 9 +- .../spark/sql/sources/InsertSuite.scala | 30 +- .../sql/sources/PartitionedWriteSuite.scala | 14 +- .../spark/sql/sources/PrunedScanSuite.scala | 11 +- .../spark/sql/sources/SaveLoadSuite.scala | 26 +- .../spark/sql/sources/TableScanSuite.scala | 15 +- .../apache/spark/sql/test/SQLTestData.scala | 290 ++++++++++++++++++ .../apache/spark/sql/test/SQLTestUtils.scala | 92 +++++- .../spark/sql/test/SharedSQLContext.scala | 68 ++++ .../spark/sql/test/TestSQLContext.scala | 46 ++- .../hive/thriftserver/UISeleniumSuite.scala | 2 - .../sql/hive/HiveMetastoreCatalogSuite.scala | 5 +- .../spark/sql/hive/HiveParquetSuite.scala | 11 +- .../sql/hive/MetastoreDataSourcesSuite.scala | 4 +- .../spark/sql/hive/MultiDatabaseSuite.scala | 5 +- .../hive/ParquetHiveCompatibilitySuite.scala | 3 +- .../org/apache/spark/sql/hive/UDFSuite.scala | 4 +- .../execution/AggregationQuerySuite.scala | 9 +- .../sql/hive/execution/HiveExplainSuite.scala | 4 +- .../sql/hive/execution/SQLQuerySuite.scala | 3 +- .../execution/ScriptTransformationSuite.scala | 3 +- .../apache/spark/sql/hive/orc/OrcTest.scala | 4 +- .../apache/spark/sql/hive/parquetSuites.scala | 3 +- .../CommitFailureTestRelationSuite.scala | 6 +- .../sql/sources/hadoopFsRelationSuites.scala | 5 +- 96 files changed, 1460 insertions(+), 1203 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/TestData.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala rename sql/core/src/{main => test}/scala/org/apache/spark/sql/test/TestSQLContext.scala (54%) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 784f83c10e02..88745dc086a0 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -178,6 +178,16 @@ object MimaExcludes { // SPARK-4751 Dynamic allocation for standalone mode ProblemFilters.exclude[MissingMethodProblem]( "org.apache.spark.SparkContext.supportDynamicAllocation") + ) ++ Seq( + // SPARK-9580: Remove SQL test singletons + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.test.LocalSQLContext$SQLSession"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.test.LocalSQLContext"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.test.TestSQLContext"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.test.TestSQLContext$") ) ++ Seq( // SPARK-9704 Made ProbabilisticClassifier, Identifiable, VectorUDT public APIs ProblemFilters.exclude[IncompatibleResultTypeProblem]( diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 74f815f941d5..04e0d49b178c 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -319,6 +319,8 @@ object SQL { lazy val settings = Seq( initialCommands in console := """ + |import org.apache.spark.SparkContext + |import org.apache.spark.sql.SQLContext |import org.apache.spark.sql.catalyst.analysis._ |import org.apache.spark.sql.catalyst.dsl._ |import org.apache.spark.sql.catalyst.errors._ @@ -328,9 +330,14 @@ object SQL { |import org.apache.spark.sql.catalyst.util._ |import org.apache.spark.sql.execution |import org.apache.spark.sql.functions._ - |import org.apache.spark.sql.test.TestSQLContext._ - 
|import org.apache.spark.sql.types._""".stripMargin, - cleanupCommands in console := "sparkContext.stop()" + |import org.apache.spark.sql.types._ + | + |val sc = new SparkContext("local[*]", "dev-shell") + |val sqlContext = new SQLContext(sc) + |import sqlContext.implicits._ + |import sqlContext._ + """.stripMargin, + cleanupCommands in console := "sc.stop()" ) } @@ -340,8 +347,6 @@ object Hive { javaOptions += "-XX:MaxPermSize=256m", // Specially disable assertions since some Hive tests fail them javaOptions in Test := (javaOptions in Test).value.filterNot(_ == "-ea"), - // Multiple queries rely on the TestHive singleton. See comments there for more details. - parallelExecution in Test := false, // Supporting all SerDes requires us to depend on deprecated APIs, so we turn off the warnings // only for this subproject. scalacOptions <<= scalacOptions map { currentOpts: Seq[String] => @@ -349,6 +354,7 @@ object Hive { }, initialCommands in console := """ + |import org.apache.spark.SparkContext |import org.apache.spark.sql.catalyst.analysis._ |import org.apache.spark.sql.catalyst.dsl._ |import org.apache.spark.sql.catalyst.errors._ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index 63b475b6366c..f60d11c988ef 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -17,14 +17,10 @@ package org.apache.spark.sql.catalyst.analysis -import org.scalatest.BeforeAndAfter - -import org.apache.spark.SparkFunSuite import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.Inner -import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.types._ @@ -42,7 +38,7 @@ case class UnresolvedTestPlan() extends LeafNode { override def output: Seq[Attribute] = Nil } -class AnalysisErrorSuite extends AnalysisTest with BeforeAndAfter { +class AnalysisErrorSuite extends AnalysisTest { import TestRelations._ def errorTest( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 4bf00b3399e7..53de10d5fa9d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -23,7 +23,6 @@ import java.util.concurrent.atomic.AtomicReference import scala.collection.JavaConversions._ import scala.collection.immutable -import scala.language.implicitConversions import scala.reflect.runtime.universe.TypeTag import scala.util.control.NonFatal @@ -41,10 +40,9 @@ import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.catalyst.{InternalRow, ParserDialect, _} import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.datasources._ +import org.apache.spark.sql.execution.ui.{SQLListener, SQLTab} import org.apache.spark.sql.sources.BaseRelation import org.apache.spark.sql.types._ -import org.apache.spark.sql.execution.ui.{SQLListener, SQLTab} -import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils /** @@ -334,97 +332,10 @@ class 
SQLContext(@transient val sparkContext: SparkContext) * @since 1.3.0 */ @Experimental - object implicits extends Serializable { - // scalastyle:on - - /** - * Converts $"col name" into an [[Column]]. - * @since 1.3.0 - */ - implicit class StringToColumn(val sc: StringContext) { - def $(args: Any*): ColumnName = { - new ColumnName(sc.s(args: _*)) - } - } - - /** - * An implicit conversion that turns a Scala `Symbol` into a [[Column]]. - * @since 1.3.0 - */ - implicit def symbolToColumn(s: Symbol): ColumnName = new ColumnName(s.name) - - /** - * Creates a DataFrame from an RDD of case classes or tuples. - * @since 1.3.0 - */ - implicit def rddToDataFrameHolder[A <: Product : TypeTag](rdd: RDD[A]): DataFrameHolder = { - DataFrameHolder(self.createDataFrame(rdd)) - } - - /** - * Creates a DataFrame from a local Seq of Product. - * @since 1.3.0 - */ - implicit def localSeqToDataFrameHolder[A <: Product : TypeTag](data: Seq[A]): DataFrameHolder = - { - DataFrameHolder(self.createDataFrame(data)) - } - - // Do NOT add more implicit conversions. They are likely to break source compatibility by - // making existing implicit conversions ambiguous. In particular, RDD[Double] is dangerous - // because of [[DoubleRDDFunctions]]. - - /** - * Creates a single column DataFrame from an RDD[Int]. - * @since 1.3.0 - */ - implicit def intRddToDataFrameHolder(data: RDD[Int]): DataFrameHolder = { - val dataType = IntegerType - val rows = data.mapPartitions { iter => - val row = new SpecificMutableRow(dataType :: Nil) - iter.map { v => - row.setInt(0, v) - row: InternalRow - } - } - DataFrameHolder( - self.internalCreateDataFrame(rows, StructType(StructField("_1", dataType) :: Nil))) - } - - /** - * Creates a single column DataFrame from an RDD[Long]. - * @since 1.3.0 - */ - implicit def longRddToDataFrameHolder(data: RDD[Long]): DataFrameHolder = { - val dataType = LongType - val rows = data.mapPartitions { iter => - val row = new SpecificMutableRow(dataType :: Nil) - iter.map { v => - row.setLong(0, v) - row: InternalRow - } - } - DataFrameHolder( - self.internalCreateDataFrame(rows, StructType(StructField("_1", dataType) :: Nil))) - } - - /** - * Creates a single column DataFrame from an RDD[String]. - * @since 1.3.0 - */ - implicit def stringRddToDataFrameHolder(data: RDD[String]): DataFrameHolder = { - val dataType = StringType - val rows = data.mapPartitions { iter => - val row = new SpecificMutableRow(dataType :: Nil) - iter.map { v => - row.update(0, UTF8String.fromString(v)) - row: InternalRow - } - } - DataFrameHolder( - self.internalCreateDataFrame(rows, StructType(StructField("_1", dataType) :: Nil))) - } + object implicits extends SQLImplicits with Serializable { + protected override def _sqlContext: SQLContext = self } + // scalastyle:on /** * :: Experimental :: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala new file mode 100644 index 000000000000..5f82372700f2 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import scala.language.implicitConversions +import scala.reflect.runtime.universe.TypeTag + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.types._ +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.SpecificMutableRow +import org.apache.spark.sql.types.StructField +import org.apache.spark.unsafe.types.UTF8String + +/** + * A collection of implicit methods for converting common Scala objects into [[DataFrame]]s. + */ +private[sql] abstract class SQLImplicits { + protected def _sqlContext: SQLContext + + /** + * Converts $"col name" into an [[Column]]. + * @since 1.3.0 + */ + implicit class StringToColumn(val sc: StringContext) { + def $(args: Any*): ColumnName = { + new ColumnName(sc.s(args: _*)) + } + } + + /** + * An implicit conversion that turns a Scala `Symbol` into a [[Column]]. + * @since 1.3.0 + */ + implicit def symbolToColumn(s: Symbol): ColumnName = new ColumnName(s.name) + + /** + * Creates a DataFrame from an RDD of case classes or tuples. + * @since 1.3.0 + */ + implicit def rddToDataFrameHolder[A <: Product : TypeTag](rdd: RDD[A]): DataFrameHolder = { + DataFrameHolder(_sqlContext.createDataFrame(rdd)) + } + + /** + * Creates a DataFrame from a local Seq of Product. + * @since 1.3.0 + */ + implicit def localSeqToDataFrameHolder[A <: Product : TypeTag](data: Seq[A]): DataFrameHolder = + { + DataFrameHolder(_sqlContext.createDataFrame(data)) + } + + // Do NOT add more implicit conversions. They are likely to break source compatibility by + // making existing implicit conversions ambiguous. In particular, RDD[Double] is dangerous + // because of [[DoubleRDDFunctions]]. + + /** + * Creates a single column DataFrame from an RDD[Int]. + * @since 1.3.0 + */ + implicit def intRddToDataFrameHolder(data: RDD[Int]): DataFrameHolder = { + val dataType = IntegerType + val rows = data.mapPartitions { iter => + val row = new SpecificMutableRow(dataType :: Nil) + iter.map { v => + row.setInt(0, v) + row: InternalRow + } + } + DataFrameHolder( + _sqlContext.internalCreateDataFrame(rows, StructType(StructField("_1", dataType) :: Nil))) + } + + /** + * Creates a single column DataFrame from an RDD[Long]. + * @since 1.3.0 + */ + implicit def longRddToDataFrameHolder(data: RDD[Long]): DataFrameHolder = { + val dataType = LongType + val rows = data.mapPartitions { iter => + val row = new SpecificMutableRow(dataType :: Nil) + iter.map { v => + row.setLong(0, v) + row: InternalRow + } + } + DataFrameHolder( + _sqlContext.internalCreateDataFrame(rows, StructType(StructField("_1", dataType) :: Nil))) + } + + /** + * Creates a single column DataFrame from an RDD[String]. 
+ * @since 1.3.0 + */ + implicit def stringRddToDataFrameHolder(data: RDD[String]): DataFrameHolder = { + val dataType = StringType + val rows = data.mapPartitions { iter => + val row = new SpecificMutableRow(dataType :: Nil) + iter.map { v => + row.update(0, UTF8String.fromString(v)) + row: InternalRow + } + } + DataFrameHolder( + _sqlContext.internalCreateDataFrame(rows, StructType(StructField("_1", dataType) :: Nil))) + } +} diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java index e912eb835d16..bf693c7c393f 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java @@ -27,6 +27,7 @@ import org.junit.Before; import org.junit.Test; +import org.apache.spark.SparkContext; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.Function; @@ -34,7 +35,6 @@ import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; import org.apache.spark.sql.SQLContext; -import org.apache.spark.sql.test.TestSQLContext$; import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; @@ -48,14 +48,16 @@ public class JavaApplySchemaSuite implements Serializable { @Before public void setUp() { - sqlContext = TestSQLContext$.MODULE$; - javaCtx = new JavaSparkContext(sqlContext.sparkContext()); + SparkContext context = new SparkContext("local[*]", "testing"); + javaCtx = new JavaSparkContext(context); + sqlContext = new SQLContext(context); } @After public void tearDown() { - javaCtx = null; + sqlContext.sparkContext().stop(); sqlContext = null; + javaCtx = null; } public static class Person implements Serializable { diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index 7302361ab9fd..7abdd3db8034 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -17,44 +17,45 @@ package test.org.apache.spark.sql; +import java.io.Serializable; +import java.util.Arrays; +import java.util.Comparator; +import java.util.List; +import java.util.Map; + +import scala.collection.JavaConversions; +import scala.collection.Seq; + import com.google.common.collect.ImmutableMap; import com.google.common.primitives.Ints; +import org.junit.*; +import org.apache.spark.SparkContext; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.sql.*; +import static org.apache.spark.sql.functions.*; import org.apache.spark.sql.test.TestSQLContext; -import org.apache.spark.sql.test.TestSQLContext$; import org.apache.spark.sql.types.*; -import org.junit.*; - -import scala.collection.JavaConversions; -import scala.collection.Seq; - -import java.io.Serializable; -import java.util.Arrays; -import java.util.Comparator; -import java.util.List; -import java.util.Map; - -import static org.apache.spark.sql.functions.*; public class JavaDataFrameSuite { private transient JavaSparkContext jsc; - private transient SQLContext context; + private transient TestSQLContext context; @Before public void setUp() { // Trigger static initializer of TestData - TestData$.MODULE$.testData(); - jsc = new 
JavaSparkContext(TestSQLContext.sparkContext()); - context = TestSQLContext$.MODULE$; + SparkContext sc = new SparkContext("local[*]", "testing"); + jsc = new JavaSparkContext(sc); + context = new TestSQLContext(sc); + context.loadTestData(); } @After public void tearDown() { - jsc = null; + context.sparkContext().stop(); context = null; + jsc = null; } @Test @@ -230,7 +231,7 @@ public void testCovariance() { @Test public void testSampleBy() { - DataFrame df = context.range(0, 100).select(col("id").mod(3).as("key")); + DataFrame df = context.range(0, 100, 1, 2).select(col("id").mod(3).as("key")); DataFrame sampled = df.stat().sampleBy("key", ImmutableMap.of(0, 0.1, 1, 0.2), 0L); Row[] actual = sampled.groupBy("key").count().orderBy("key").collect(); Row[] expected = new Row[] {RowFactory.create(0, 5), RowFactory.create(1, 8)}; diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java index 79d92734ff37..bb02b58cca9b 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java @@ -23,12 +23,12 @@ import org.junit.Before; import org.junit.Test; +import org.apache.spark.SparkContext; import org.apache.spark.sql.Row; import org.apache.spark.sql.SQLContext; import org.apache.spark.sql.api.java.UDF1; import org.apache.spark.sql.api.java.UDF2; import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.sql.test.TestSQLContext$; import org.apache.spark.sql.types.DataTypes; // The test suite itself is Serializable so that anonymous Function implementations can be @@ -40,12 +40,16 @@ public class JavaUDFSuite implements Serializable { @Before public void setUp() { - sqlContext = TestSQLContext$.MODULE$; - sc = new JavaSparkContext(sqlContext.sparkContext()); + SparkContext _sc = new SparkContext("local[*]", "testing"); + sqlContext = new SQLContext(_sc); + sc = new JavaSparkContext(_sc); } @After public void tearDown() { + sqlContext.sparkContext().stop(); + sqlContext = null; + sc = null; } @SuppressWarnings("unchecked") diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java index 2706e01bd28a..6f9e7f68dc39 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java @@ -21,13 +21,14 @@ import java.io.IOException; import java.util.*; +import org.junit.After; import org.junit.Assert; import org.junit.Before; import org.junit.Test; +import org.apache.spark.SparkContext; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.sql.test.TestSQLContext$; import org.apache.spark.sql.*; import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.StructField; @@ -52,8 +53,9 @@ private void checkAnswer(DataFrame actual, List expected) { @Before public void setUp() throws IOException { - sqlContext = TestSQLContext$.MODULE$; - sc = new JavaSparkContext(sqlContext.sparkContext()); + SparkContext _sc = new SparkContext("local[*]", "testing"); + sqlContext = new SQLContext(_sc); + sc = new JavaSparkContext(_sc); originalDefaultSource = sqlContext.conf().defaultDataSourceName(); path = @@ -71,6 +73,13 @@ public void setUp() throws IOException { df.registerTempTable("jsonTable"); } + @After + public void 
tearDown() { + sqlContext.sparkContext().stop(); + sqlContext = null; + sc = null; + } + @Test public void saveAndLoad() { Map options = new HashMap(); diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index a88df91b1001..af7590c3d3c1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -18,24 +18,20 @@ package org.apache.spark.sql import scala.concurrent.duration._ -import scala.language.{implicitConversions, postfixOps} +import scala.language.postfixOps import org.scalatest.concurrent.Eventually._ import org.apache.spark.Accumulators -import org.apache.spark.sql.TestData._ import org.apache.spark.sql.columnar._ import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.storage.{StorageLevel, RDDBlockId} -case class BigData(s: String) +private case class BigData(s: String) -class CachedTableSuite extends QueryTest { - TestData // Load test tables. - - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ - import ctx.sql +class CachedTableSuite extends QueryTest with SharedSQLContext { + import testImplicits._ def rddIdOf(tableName: String): Int = { val executedPlan = ctx.table(tableName).queryExecution.executedPlan diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index 6a09a3b72c08..ee74e3e83da5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -21,16 +21,20 @@ import org.scalatest.Matchers._ import org.apache.spark.sql.execution.{Project, TungstenProject} import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ -import org.apache.spark.sql.test.SQLTestUtils -class ColumnExpressionSuite extends QueryTest with SQLTestUtils { - import org.apache.spark.sql.TestData._ +class ColumnExpressionSuite extends QueryTest with SharedSQLContext { + import testImplicits._ - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ - - override def sqlContext(): SQLContext = ctx + private lazy val booleanData = { + ctx.createDataFrame(ctx.sparkContext.parallelize( + Row(false, false) :: + Row(false, true) :: + Row(true, false) :: + Row(true, true) :: Nil), + StructType(Seq(StructField("a", BooleanType), StructField("b", BooleanType)))) + } test("column names with space") { val df = Seq((1, "a")).toDF("name with space", "name.with.dot") @@ -258,7 +262,7 @@ class ColumnExpressionSuite extends QueryTest with SQLTestUtils { nullStrings.collect().toSeq.filter(r => r.getString(1) eq null)) checkAnswer( - ctx.sql("select isnull(null), isnull(1)"), + sql("select isnull(null), isnull(1)"), Row(true, false)) } @@ -268,7 +272,7 @@ class ColumnExpressionSuite extends QueryTest with SQLTestUtils { nullStrings.collect().toSeq.filter(r => r.getString(1) ne null)) checkAnswer( - ctx.sql("select isnotnull(null), isnotnull('a')"), + sql("select isnotnull(null), isnotnull('a')"), Row(false, true)) } @@ -289,7 +293,7 @@ class ColumnExpressionSuite extends QueryTest with SQLTestUtils { Row(true, true) :: Row(true, true) :: Row(false, false) :: Row(false, false) :: Nil) checkAnswer( - ctx.sql("select 
isnan(15), isnan('invalid')"), + sql("select isnan(15), isnan('invalid')"), Row(false, false)) } @@ -309,7 +313,7 @@ class ColumnExpressionSuite extends QueryTest with SQLTestUtils { ) testData.registerTempTable("t") checkAnswer( - ctx.sql( + sql( "select nanvl(a, 5), nanvl(b, 10), nanvl(10, b), nanvl(c, null), nanvl(d, 10), " + " nanvl(b, e), nanvl(e, f) from t"), Row(null, 3.0, 10.0, null, Double.PositiveInfinity, 3.0, 1.0) @@ -433,13 +437,6 @@ class ColumnExpressionSuite extends QueryTest with SQLTestUtils { } } - val booleanData = ctx.createDataFrame(ctx.sparkContext.parallelize( - Row(false, false) :: - Row(false, true) :: - Row(true, false) :: - Row(true, true) :: Nil), - StructType(Seq(StructField("a", BooleanType), StructField("b", BooleanType)))) - test("&&") { checkAnswer( booleanData.filter($"a" && true), @@ -523,7 +520,7 @@ class ColumnExpressionSuite extends QueryTest with SQLTestUtils { ) checkAnswer( - ctx.sql("SELECT upper('aB'), ucase('cDe')"), + sql("SELECT upper('aB'), ucase('cDe')"), Row("AB", "CDE")) } @@ -544,7 +541,7 @@ class ColumnExpressionSuite extends QueryTest with SQLTestUtils { ) checkAnswer( - ctx.sql("SELECT lower('aB'), lcase('cDe')"), + sql("SELECT lower('aB'), lcase('cDe')"), Row("ab", "cde")) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index f9cff7440a76..72cf7aab0b97 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -17,15 +17,13 @@ package org.apache.spark.sql -import org.apache.spark.sql.TestData._ import org.apache.spark.sql.functions._ -import org.apache.spark.sql.types.{BinaryType, DecimalType} +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.DecimalType -class DataFrameAggregateSuite extends QueryTest { - - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ +class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { + import testImplicits._ test("groupBy") { checkAnswer( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 03116a374f3b..9d965258e389 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -17,17 +17,15 @@ package org.apache.spark.sql -import org.apache.spark.sql.TestData._ import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ /** * Test suite for functions in [[org.apache.spark.sql.functions]]. 
*/ -class DataFrameFunctionsSuite extends QueryTest { - - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ +class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { + import testImplicits._ test("array with column name") { val df = Seq((0, 1)).toDF("a", "b") @@ -119,11 +117,11 @@ class DataFrameFunctionsSuite extends QueryTest { test("constant functions") { checkAnswer( - ctx.sql("SELECT E()"), + sql("SELECT E()"), Row(scala.math.E) ) checkAnswer( - ctx.sql("SELECT PI()"), + sql("SELECT PI()"), Row(scala.math.Pi) ) } @@ -153,7 +151,7 @@ class DataFrameFunctionsSuite extends QueryTest { test("nvl function") { checkAnswer( - ctx.sql("SELECT nvl(null, 'x'), nvl('y', 'x'), nvl(null, null)"), + sql("SELECT nvl(null, 'x'), nvl('y', 'x'), nvl(null, null)"), Row("x", "y", null)) } @@ -222,7 +220,7 @@ class DataFrameFunctionsSuite extends QueryTest { Row(-1) ) checkAnswer( - ctx.sql("SELECT least(a, 2) as l from testData2 order by l"), + sql("SELECT least(a, 2) as l from testData2 order by l"), Seq(Row(1), Row(1), Row(2), Row(2), Row(2), Row(2)) ) } @@ -233,7 +231,7 @@ class DataFrameFunctionsSuite extends QueryTest { Row(3) ) checkAnswer( - ctx.sql("SELECT greatest(a, 2) as g from testData2 order by g"), + sql("SELECT greatest(a, 2) as g from testData2 order by g"), Seq(Row(2), Row(2), Row(2), Row(2), Row(3), Row(3)) ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala index fbb30706a494..e5d7d63441a6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala @@ -17,10 +17,10 @@ package org.apache.spark.sql -class DataFrameImplicitsSuite extends QueryTest { +import org.apache.spark.sql.test.SharedSQLContext - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ +class DataFrameImplicitsSuite extends QueryTest with SharedSQLContext { + import testImplicits._ test("RDD of tuples") { checkAnswer( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala index e1c6c706242d..e2716d7841d8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala @@ -17,14 +17,12 @@ package org.apache.spark.sql -import org.apache.spark.sql.TestData._ import org.apache.spark.sql.execution.joins.BroadcastHashJoin import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.SharedSQLContext -class DataFrameJoinSuite extends QueryTest { - - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ +class DataFrameJoinSuite extends QueryTest with SharedSQLContext { + import testImplicits._ test("join - join using") { val df = Seq(1, 2, 3).map(i => (i, i.toString)).toDF("int", "str") @@ -59,7 +57,7 @@ class DataFrameJoinSuite extends QueryTest { checkAnswer( df1.join(df2, $"df1.key" === $"df2.key"), - ctx.sql("SELECT a.key, b.key FROM testData a JOIN testData b ON a.key = b.key") + sql("SELECT a.key, b.key FROM testData a JOIN testData b ON a.key = b.key") .collect().toSeq) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala index dbe3b44ee2c7..cdaa14ac8078 
100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala @@ -19,11 +19,11 @@ package org.apache.spark.sql import scala.collection.JavaConversions._ +import org.apache.spark.sql.test.SharedSQLContext -class DataFrameNaFunctionsSuite extends QueryTest { - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ +class DataFrameNaFunctionsSuite extends QueryTest with SharedSQLContext { + import testImplicits._ def createDF(): DataFrame = { Seq[(String, java.lang.Integer, java.lang.Double)]( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index 8f5984e4a8ce..28bdd6f83b68 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -19,20 +19,17 @@ package org.apache.spark.sql import java.util.Random -import org.scalatest.Matchers._ - import org.apache.spark.sql.functions.col +import org.apache.spark.sql.test.SharedSQLContext -class DataFrameStatSuite extends QueryTest { - - private val sqlCtx = org.apache.spark.sql.test.TestSQLContext - import sqlCtx.implicits._ +class DataFrameStatSuite extends QueryTest with SharedSQLContext { + import testImplicits._ private def toLetter(i: Int): String = (i + 97).toChar.toString test("sample with replacement") { val n = 100 - val data = sqlCtx.sparkContext.parallelize(1 to n, 2).toDF("id") + val data = ctx.sparkContext.parallelize(1 to n, 2).toDF("id") checkAnswer( data.sample(withReplacement = true, 0.05, seed = 13), Seq(5, 10, 52, 73).map(Row(_)) @@ -41,7 +38,7 @@ class DataFrameStatSuite extends QueryTest { test("sample without replacement") { val n = 100 - val data = sqlCtx.sparkContext.parallelize(1 to n, 2).toDF("id") + val data = ctx.sparkContext.parallelize(1 to n, 2).toDF("id") checkAnswer( data.sample(withReplacement = false, 0.05, seed = 13), Seq(16, 23, 88, 100).map(Row(_)) @@ -50,7 +47,7 @@ class DataFrameStatSuite extends QueryTest { test("randomSplit") { val n = 600 - val data = sqlCtx.sparkContext.parallelize(1 to n, 2).toDF("id") + val data = ctx.sparkContext.parallelize(1 to n, 2).toDF("id") for (seed <- 1 to 5) { val splits = data.randomSplit(Array[Double](1, 2, 3), seed) assert(splits.length == 3, "wrong number of splits") @@ -167,7 +164,7 @@ class DataFrameStatSuite extends QueryTest { } test("Frequent Items 2") { - val rows = sqlCtx.sparkContext.parallelize(Seq.empty[Int], 4) + val rows = ctx.sparkContext.parallelize(Seq.empty[Int], 4) // this is a regression test, where when merging partitions, we omitted values with higher // counts than those that existed in the map when the map was full. 
This test should also fail // if anything like SPARK-9614 is observed once again @@ -185,7 +182,7 @@ class DataFrameStatSuite extends QueryTest { } test("sampleBy") { - val df = sqlCtx.range(0, 100).select((col("id") % 3).as("key")) + val df = ctx.range(0, 100).select((col("id") % 3).as("key")) val sampled = df.stat.sampleBy("key", Map(0 -> 0.1, 1 -> 0.2), 0L) checkAnswer( sampled.groupBy("key").count().orderBy("key"), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 2feec29955bc..10bfa9b64f00 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -23,18 +23,12 @@ import scala.language.postfixOps import scala.util.Random import org.apache.spark.sql.catalyst.plans.logical.OneRowRelation -import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.functions._ -import org.apache.spark.sql.execution.datasources.json.JSONRelation -import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation import org.apache.spark.sql.types._ -import org.apache.spark.sql.test.{ExamplePointUDT, ExamplePoint, SQLTestUtils} +import org.apache.spark.sql.test.{ExamplePointUDT, ExamplePoint, SharedSQLContext} -class DataFrameSuite extends QueryTest with SQLTestUtils { - import org.apache.spark.sql.TestData._ - - lazy val sqlContext = org.apache.spark.sql.test.TestSQLContext - import sqlContext.implicits._ +class DataFrameSuite extends QueryTest with SharedSQLContext { + import testImplicits._ test("analysis error should be eagerly reported") { // Eager analysis. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala index bf8ef9a97bc6..77907e91363e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql import org.apache.spark.sql.functions._ -import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ /** @@ -27,10 +27,8 @@ import org.apache.spark.sql.types._ * This is here for now so I can make sure Tungsten project is tested without refactoring existing * end-to-end test infra. In the long run this should just go away. 
*/ -class DataFrameTungstenSuite extends QueryTest with SQLTestUtils { - - override lazy val sqlContext: SQLContext = org.apache.spark.sql.test.TestSQLContext - import sqlContext.implicits._ +class DataFrameTungstenSuite extends QueryTest with SharedSQLContext { + import testImplicits._ test("test simple types") { withSQLConf(SQLConf.UNSAFE_ENABLED.key -> "true") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala index 17897caf952a..9080c53c491a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala @@ -22,19 +22,18 @@ import java.text.SimpleDateFormat import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.unsafe.types.CalendarInterval -class DateFunctionsSuite extends QueryTest { - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - - import ctx.implicits._ +class DateFunctionsSuite extends QueryTest with SharedSQLContext { + import testImplicits._ test("function current_date") { val df1 = Seq((1, 2), (3, 1)).toDF("a", "b") val d0 = DateTimeUtils.millisToDays(System.currentTimeMillis()) val d1 = DateTimeUtils.fromJavaDate(df1.select(current_date()).collect().head.getDate(0)) val d2 = DateTimeUtils.fromJavaDate( - ctx.sql("""SELECT CURRENT_DATE()""").collect().head.getDate(0)) + sql("""SELECT CURRENT_DATE()""").collect().head.getDate(0)) val d3 = DateTimeUtils.millisToDays(System.currentTimeMillis()) assert(d0 <= d1 && d1 <= d2 && d2 <= d3 && d3 - d0 <= 1) } @@ -44,9 +43,9 @@ class DateFunctionsSuite extends QueryTest { val df1 = Seq((1, 2), (3, 1)).toDF("a", "b") checkAnswer(df1.select(countDistinct(current_timestamp())), Row(1)) // Execution in one query should return the same value - checkAnswer(ctx.sql("""SELECT CURRENT_TIMESTAMP() = CURRENT_TIMESTAMP()"""), + checkAnswer(sql("""SELECT CURRENT_TIMESTAMP() = CURRENT_TIMESTAMP()"""), Row(true)) - assert(math.abs(ctx.sql("""SELECT CURRENT_TIMESTAMP()""").collect().head.getTimestamp( + assert(math.abs(sql("""SELECT CURRENT_TIMESTAMP()""").collect().head.getTimestamp( 0).getTime - System.currentTimeMillis()) < 5000) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index ae07eaf91c87..f5c5046a8ed8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -17,22 +17,15 @@ package org.apache.spark.sql -import org.scalatest.BeforeAndAfterEach - -import org.apache.spark.sql.TestData._ import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.execution.joins._ -import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.test.SharedSQLContext -class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach { - // Ensures tables are loaded. 
- TestData +class JoinSuite extends QueryTest with SharedSQLContext { + import testImplicits._ - override def sqlContext: SQLContext = org.apache.spark.sql.test.TestSQLContext - lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ - import ctx.logicalPlanToSparkQuery + setupTestData() test("equi-join is hash-join") { val x = testData2.as("x") @@ -43,7 +36,7 @@ class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach { } def assertJoin(sqlString: String, c: Class[_]): Any = { - val df = ctx.sql(sqlString) + val df = sql(sqlString) val physical = df.queryExecution.sparkPlan val operators = physical.collect { case j: ShuffledHashJoin => j @@ -126,7 +119,7 @@ class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach { test("broadcasted hash join operator selection") { ctx.cacheManager.clearCache() - ctx.sql("CACHE TABLE testData") + sql("CACHE TABLE testData") for (sortMergeJoinEnabled <- Seq(true, false)) { withClue(s"sortMergeJoinEnabled=$sortMergeJoinEnabled") { withSQLConf(SQLConf.SORTMERGE_JOIN.key -> s"$sortMergeJoinEnabled") { @@ -141,12 +134,12 @@ class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach { } } } - ctx.sql("UNCACHE TABLE testData") + sql("UNCACHE TABLE testData") } test("broadcasted hash outer join operator selection") { ctx.cacheManager.clearCache() - ctx.sql("CACHE TABLE testData") + sql("CACHE TABLE testData") withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "true") { Seq( ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", @@ -167,7 +160,7 @@ class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach { classOf[BroadcastHashOuterJoin]) ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } } - ctx.sql("UNCACHE TABLE testData") + sql("UNCACHE TABLE testData") } test("multiple-key equi-join is hash-join") { @@ -279,7 +272,7 @@ class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach { // Make sure we are choosing left.outputPartitioning as the // outputPartitioning for the outer join operator. checkAnswer( - ctx.sql( + sql( """ |SELECT l.N, count(*) |FROM upperCaseData l LEFT OUTER JOIN allNulls r ON (l.N = r.a) @@ -293,7 +286,7 @@ class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach { Row(6, 1) :: Nil) checkAnswer( - ctx.sql( + sql( """ |SELECT r.a, count(*) |FROM upperCaseData l LEFT OUTER JOIN allNulls r ON (l.N = r.a) @@ -339,7 +332,7 @@ class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach { // Make sure we are choosing right.outputPartitioning as the // outputPartitioning for the outer join operator. checkAnswer( - ctx.sql( + sql( """ |SELECT l.a, count(*) |FROM allNulls l RIGHT OUTER JOIN upperCaseData r ON (l.a = r.N) @@ -348,7 +341,7 @@ class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach { Row(null, 6)) checkAnswer( - ctx.sql( + sql( """ |SELECT r.N, count(*) |FROM allNulls l RIGHT OUTER JOIN upperCaseData r ON (l.a = r.N) @@ -400,7 +393,7 @@ class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach { // Make sure we are UnknownPartitioning as the outputPartitioning for the outer join operator. 
checkAnswer( - ctx.sql( + sql( """ |SELECT l.a, count(*) |FROM allNulls l FULL OUTER JOIN upperCaseData r ON (l.a = r.N) @@ -409,7 +402,7 @@ class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach { Row(null, 10)) checkAnswer( - ctx.sql( + sql( """ |SELECT r.N, count(*) |FROM allNulls l FULL OUTER JOIN upperCaseData r ON (l.a = r.N) @@ -424,7 +417,7 @@ class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach { Row(null, 4) :: Nil) checkAnswer( - ctx.sql( + sql( """ |SELECT l.N, count(*) |FROM upperCaseData l FULL OUTER JOIN allNulls r ON (l.N = r.a) @@ -439,7 +432,7 @@ class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach { Row(null, 4) :: Nil) checkAnswer( - ctx.sql( + sql( """ |SELECT r.a, count(*) |FROM upperCaseData l FULL OUTER JOIN allNulls r ON (l.N = r.a) @@ -450,7 +443,7 @@ class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach { test("broadcasted left semi join operator selection") { ctx.cacheManager.clearCache() - ctx.sql("CACHE TABLE testData") + sql("CACHE TABLE testData") withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1000000000") { Seq( @@ -469,11 +462,11 @@ class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach { } } - ctx.sql("UNCACHE TABLE testData") + sql("UNCACHE TABLE testData") } test("left semi join") { - val df = ctx.sql("SELECT * FROM testData2 LEFT SEMI JOIN testData ON key = a") + val df = sql("SELECT * FROM testData2 LEFT SEMI JOIN testData ON key = a") checkAnswer(df, Row(1, 1) :: Row(1, 2) :: diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala index 71c26a6f8d36..045fea82e4c8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala @@ -17,10 +17,10 @@ package org.apache.spark.sql -class JsonFunctionsSuite extends QueryTest { +import org.apache.spark.sql.test.SharedSQLContext - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ +class JsonFunctionsSuite extends QueryTest with SharedSQLContext { + import testImplicits._ test("function get_json_object") { val df: DataFrame = Seq(("""{"name": "alice", "age": 5}""", "")).toDF("a", "b") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala index 2089660c52bf..babf8835d254 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala @@ -19,12 +19,11 @@ package org.apache.spark.sql import org.scalatest.BeforeAndAfter +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{BooleanType, StringType, StructField, StructType} -class ListTablesSuite extends QueryTest with BeforeAndAfter { - - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ +class ListTablesSuite extends QueryTest with BeforeAndAfter with SharedSQLContext { + import testImplicits._ private lazy val df = (1 to 10).map(i => (i, s"str$i")).toDF("key", "value") @@ -42,7 +41,7 @@ class ListTablesSuite extends QueryTest with BeforeAndAfter { Row("ListTablesSuiteTable", true)) checkAnswer( - ctx.sql("SHOW tables").filter("tableName = 'ListTablesSuiteTable'"), + sql("SHOW tables").filter("tableName = 'ListTablesSuiteTable'"), Row("ListTablesSuiteTable", true)) 
ctx.catalog.unregisterTable(Seq("ListTablesSuiteTable")) @@ -55,7 +54,7 @@ class ListTablesSuite extends QueryTest with BeforeAndAfter { Row("ListTablesSuiteTable", true)) checkAnswer( - ctx.sql("show TABLES in DB").filter("tableName = 'ListTablesSuiteTable'"), + sql("show TABLES in DB").filter("tableName = 'ListTablesSuiteTable'"), Row("ListTablesSuiteTable", true)) ctx.catalog.unregisterTable(Seq("ListTablesSuiteTable")) @@ -67,13 +66,13 @@ class ListTablesSuite extends QueryTest with BeforeAndAfter { StructField("tableName", StringType, false) :: StructField("isTemporary", BooleanType, false) :: Nil) - Seq(ctx.tables(), ctx.sql("SHOW TABLes")).foreach { + Seq(ctx.tables(), sql("SHOW TABLes")).foreach { case tableDF => assert(expectedSchema === tableDF.schema) tableDF.registerTempTable("tables") checkAnswer( - ctx.sql( + sql( "SELECT isTemporary, tableName from tables WHERE tableName = 'ListTablesSuiteTable'"), Row(true, "ListTablesSuiteTable") ) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala index 8cf2ef5957d8..30289c3c1d09 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala @@ -19,18 +19,16 @@ package org.apache.spark.sql import org.apache.spark.sql.functions._ import org.apache.spark.sql.functions.{log => logarithm} +import org.apache.spark.sql.test.SharedSQLContext private object MathExpressionsTestData { case class DoubleData(a: java.lang.Double, b: java.lang.Double) case class NullDoubles(a: java.lang.Double) } -class MathExpressionsSuite extends QueryTest { - +class MathExpressionsSuite extends QueryTest with SharedSQLContext { import MathExpressionsTestData._ - - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ + import testImplicits._ private lazy val doubleData = (1 to 10).map(i => DoubleData(i * 0.2 - 1, i * -0.2 + 1)).toDF() @@ -149,7 +147,7 @@ class MathExpressionsSuite extends QueryTest { test("toDegrees") { testOneToOneMathFunction(toDegrees, math.toDegrees) checkAnswer( - ctx.sql("SELECT degrees(0), degrees(1), degrees(1.5)"), + sql("SELECT degrees(0), degrees(1), degrees(1.5)"), Seq((1, 2)).toDF().select(toDegrees(lit(0)), toDegrees(lit(1)), toDegrees(lit(1.5))) ) } @@ -157,7 +155,7 @@ class MathExpressionsSuite extends QueryTest { test("toRadians") { testOneToOneMathFunction(toRadians, math.toRadians) checkAnswer( - ctx.sql("SELECT radians(0), radians(1), radians(1.5)"), + sql("SELECT radians(0), radians(1), radians(1.5)"), Seq((1, 2)).toDF().select(toRadians(lit(0)), toRadians(lit(1)), toRadians(lit(1.5))) ) } @@ -169,7 +167,7 @@ class MathExpressionsSuite extends QueryTest { test("ceil and ceiling") { testOneToOneMathFunction(ceil, math.ceil) checkAnswer( - ctx.sql("SELECT ceiling(0), ceiling(1), ceiling(1.5)"), + sql("SELECT ceiling(0), ceiling(1), ceiling(1.5)"), Row(0.0, 1.0, 2.0)) } @@ -214,7 +212,7 @@ class MathExpressionsSuite extends QueryTest { val pi = 3.1415 checkAnswer( - ctx.sql(s"SELECT round($pi, -3), round($pi, -2), round($pi, -1), " + + sql(s"SELECT round($pi, -3), round($pi, -2), round($pi, -1), " + s"round($pi, 0), round($pi, 1), round($pi, 2), round($pi, 3)"), Seq(Row(BigDecimal("0E3"), BigDecimal("0E2"), BigDecimal("0E1"), BigDecimal(3), BigDecimal("3.1"), BigDecimal("3.14"), BigDecimal("3.142"))) @@ -233,7 +231,7 @@ class MathExpressionsSuite extends QueryTest { 
testOneToOneMathFunction[Double](signum, math.signum) checkAnswer( - ctx.sql("SELECT sign(10), signum(-11)"), + sql("SELECT sign(10), signum(-11)"), Row(1, -1)) } @@ -241,7 +239,7 @@ class MathExpressionsSuite extends QueryTest { testTwoToOneMathFunction(pow, pow, math.pow) checkAnswer( - ctx.sql("SELECT pow(1, 2), power(2, 1)"), + sql("SELECT pow(1, 2), power(2, 1)"), Seq((1, 2)).toDF().select(pow(lit(1), lit(2)), pow(lit(2), lit(1))) ) } @@ -280,7 +278,7 @@ class MathExpressionsSuite extends QueryTest { test("log / ln") { testOneToOneNonNegativeMathFunction(org.apache.spark.sql.functions.log, math.log) checkAnswer( - ctx.sql("SELECT ln(0), ln(1), ln(1.5)"), + sql("SELECT ln(0), ln(1), ln(1.5)"), Seq((1, 2)).toDF().select(logarithm(lit(0)), logarithm(lit(1)), logarithm(lit(1.5))) ) } @@ -375,7 +373,7 @@ class MathExpressionsSuite extends QueryTest { df.select(log2("b") + log2("a")), Row(1)) - checkAnswer(ctx.sql("SELECT LOG2(8), LOG2(null)"), Row(3, null)) + checkAnswer(sql("SELECT LOG2(8), LOG2(null)"), Row(3, null)) } test("sqrt") { @@ -384,13 +382,13 @@ class MathExpressionsSuite extends QueryTest { df.select(sqrt("a"), sqrt("b")), Row(1.0, 2.0)) - checkAnswer(ctx.sql("SELECT SQRT(4.0), SQRT(null)"), Row(2.0, null)) + checkAnswer(sql("SELECT SQRT(4.0), SQRT(null)"), Row(2.0, null)) checkAnswer(df.selectExpr("sqrt(a)", "sqrt(b)", "sqrt(null)"), Row(1.0, 2.0, null)) } test("negative") { checkAnswer( - ctx.sql("SELECT negative(1), negative(0), negative(-1)"), + sql("SELECT negative(1), negative(0), negative(-1)"), Row(-1, 0, 1)) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index 98ba3c99283a..4adcefb7dc4b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -71,12 +71,6 @@ class QueryTest extends PlanTest { checkAnswer(df, expectedAnswer.collect()) } - def sqlTest(sqlString: String, expectedAnswer: Seq[Row])(implicit sqlContext: SQLContext) { - test(sqlString) { - checkAnswer(sqlContext.sql(sqlString), expectedAnswer) - } - } - /** * Asserts that a given [[DataFrame]] will be executed using the given number of cached results. 
*/ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala index 8a679c7865d6..795d4e983f27 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala @@ -20,13 +20,12 @@ package org.apache.spark.sql import org.apache.spark.SparkFunSuite import org.apache.spark.sql.execution.SparkSqlSerializer import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, SpecificMutableRow} +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String -class RowSuite extends SparkFunSuite { - - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ +class RowSuite extends SparkFunSuite with SharedSQLContext { + import testImplicits._ test("create row") { val expected = new GenericMutableRow(4) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala index 75791e9d53c2..7699adadd9cc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala @@ -17,11 +17,10 @@ package org.apache.spark.sql +import org.apache.spark.sql.test.SharedSQLContext -class SQLConfSuite extends QueryTest { - - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext +class SQLConfSuite extends QueryTest with SharedSQLContext { private val testKey = "test.key.0" private val testVal = "test.val.0" @@ -52,21 +51,21 @@ class SQLConfSuite extends QueryTest { test("parse SQL set commands") { ctx.conf.clear() - ctx.sql(s"set $testKey=$testVal") + sql(s"set $testKey=$testVal") assert(ctx.getConf(testKey, testVal + "_") === testVal) assert(ctx.getConf(testKey, testVal + "_") === testVal) - ctx.sql("set some.property=20") + sql("set some.property=20") assert(ctx.getConf("some.property", "0") === "20") - ctx.sql("set some.property = 40") + sql("set some.property = 40") assert(ctx.getConf("some.property", "0") === "40") val key = "spark.sql.key" val vs = "val0,val_1,val2.3,my_table" - ctx.sql(s"set $key=$vs") + sql(s"set $key=$vs") assert(ctx.getConf(key, "0") === vs) - ctx.sql(s"set $key=") + sql(s"set $key=") assert(ctx.getConf(key, "0") === "") ctx.conf.clear() @@ -74,14 +73,14 @@ class SQLConfSuite extends QueryTest { test("deprecated property") { ctx.conf.clear() - ctx.sql(s"set ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS}=10") + sql(s"set ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS}=10") assert(ctx.conf.numShufflePartitions === 10) } test("invalid conf value") { ctx.conf.clear() val e = intercept[IllegalArgumentException] { - ctx.sql(s"set ${SQLConf.CASE_SENSITIVE.key}=10") + sql(s"set ${SQLConf.CASE_SENSITIVE.key}=10") } assert(e.getMessage === s"${SQLConf.CASE_SENSITIVE.key} should be boolean, but was 10") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala index c8d8796568a4..007be1295077 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala @@ -17,16 +17,17 @@ package org.apache.spark.sql -import org.scalatest.BeforeAndAfterAll - import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.test.SharedSQLContext -class SQLContextSuite extends SparkFunSuite with BeforeAndAfterAll { - - private lazy 
val ctx = org.apache.spark.sql.test.TestSQLContext +class SQLContextSuite extends SparkFunSuite with SharedSQLContext { override def afterAll(): Unit = { - SQLContext.setLastInstantiatedContext(ctx) + try { + SQLContext.setLastInstantiatedContext(ctx) + } finally { + super.afterAll() + } } test("getOrCreate instantiates SQLContext") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index b14ef9bab90c..8c2c328f8191 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -19,28 +19,23 @@ package org.apache.spark.sql import java.sql.Timestamp -import org.scalatest.BeforeAndAfterAll - import org.apache.spark.AccumulatorSuite import org.apache.spark.sql.catalyst.analysis.FunctionRegistry import org.apache.spark.sql.catalyst.DefaultParserDialect import org.apache.spark.sql.catalyst.errors.DialectException import org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.functions._ -import org.apache.spark.sql.TestData._ -import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.test.SQLTestData._ import org.apache.spark.sql.types._ /** A SQL Dialect for testing purpose, and it can not be nested type */ class MyDialect extends DefaultParserDialect -class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { - // Make sure the tables are loaded. - TestData +class SQLQuerySuite extends QueryTest with SharedSQLContext { + import testImplicits._ - val sqlContext = org.apache.spark.sql.test.TestSQLContext - import sqlContext.implicits._ - import sqlContext.sql + setupTestData() test("having clause") { Seq(("one", 1), ("two", 2), ("three", 3), ("one", 5)).toDF("k", "v").registerTempTable("hav") @@ -60,7 +55,8 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { } test("show functions") { - checkAnswer(sql("SHOW functions"), FunctionRegistry.builtin.listFunction().sorted.map(Row(_))) + checkAnswer(sql("SHOW functions"), + FunctionRegistry.builtin.listFunction().sorted.map(Row(_))) } test("describe functions") { @@ -178,7 +174,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { val df = Seq(Tuple1(1), Tuple1(2), Tuple1(3)).toDF("index") // we except the id is materialized once - val idUDF = udf(() => UUID.randomUUID().toString) + val idUDF = org.apache.spark.sql.functions.udf(() => UUID.randomUUID().toString) val dfWithId = df.withColumn("id", idUDF()) // Make a new DataFrame (actually the same reference to the old one) @@ -712,9 +708,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { checkAnswer( sql( - """ - |SELECT COUNT(a), COUNT(b), COUNT(1), COUNT(DISTINCT a), COUNT(DISTINCT b) FROM testData3 - """.stripMargin), + "SELECT COUNT(a), COUNT(b), COUNT(1), COUNT(DISTINCT a), COUNT(DISTINCT b) FROM testData3"), Row(2, 1, 2, 2, 1)) } @@ -1161,7 +1155,8 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { validateMetadata(sql("SELECT * FROM personWithMeta")) validateMetadata(sql("SELECT id, name FROM personWithMeta")) validateMetadata(sql("SELECT * FROM personWithMeta JOIN salary ON id = personId")) - validateMetadata(sql("SELECT name, salary FROM personWithMeta JOIN salary ON id = personId")) + validateMetadata(sql( + "SELECT name, salary FROM personWithMeta JOIN salary ON id = personId")) 
} test("SPARK-3371 Renaming a function expression with group by gives error") { @@ -1627,7 +1622,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { .toDF("num", "str") df.registerTempTable("1one") - checkAnswer(sqlContext.sql("select count(num) from 1one"), Row(10)) + checkAnswer(sql("select count(num) from 1one"), Row(10)) sqlContext.dropTempTable("1one") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala index ab6d3dd96d27..295f02f9a7b5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql import java.sql.{Date, Timestamp} import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.test.SharedSQLContext case class ReflectData( stringField: String, @@ -71,17 +72,15 @@ case class ComplexReflectData( mapFieldContainsNull: Map[Int, Option[Long]], dataField: Data) -class ScalaReflectionRelationSuite extends SparkFunSuite { - - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ +class ScalaReflectionRelationSuite extends SparkFunSuite with SharedSQLContext { + import testImplicits._ test("query case class RDD") { val data = ReflectData("a", 1, 1L, 1.toFloat, 1.toDouble, 1.toShort, 1.toByte, true, new java.math.BigDecimal(1), Date.valueOf("1970-01-01"), new Timestamp(12345), Seq(1, 2, 3)) Seq(data).toDF().registerTempTable("reflectData") - assert(ctx.sql("SELECT * FROM reflectData").collect().head === + assert(sql("SELECT * FROM reflectData").collect().head === Row("a", 1, 1L, 1.toFloat, 1.toDouble, 1.toShort, 1.toByte, true, new java.math.BigDecimal(1), Date.valueOf("1970-01-01"), new Timestamp(12345), Seq(1, 2, 3))) @@ -91,7 +90,7 @@ class ScalaReflectionRelationSuite extends SparkFunSuite { val data = NullReflectData(null, null, null, null, null, null, null) Seq(data).toDF().registerTempTable("reflectNullData") - assert(ctx.sql("SELECT * FROM reflectNullData").collect().head === + assert(sql("SELECT * FROM reflectNullData").collect().head === Row.fromSeq(Seq.fill(7)(null))) } @@ -99,7 +98,7 @@ class ScalaReflectionRelationSuite extends SparkFunSuite { val data = OptionalReflectData(None, None, None, None, None, None, None) Seq(data).toDF().registerTempTable("reflectOptionalData") - assert(ctx.sql("SELECT * FROM reflectOptionalData").collect().head === + assert(sql("SELECT * FROM reflectOptionalData").collect().head === Row.fromSeq(Seq.fill(7)(null))) } @@ -107,7 +106,7 @@ class ScalaReflectionRelationSuite extends SparkFunSuite { test("query binary data") { Seq(ReflectBinary(Array[Byte](1))).toDF().registerTempTable("reflectBinary") - val result = ctx.sql("SELECT data FROM reflectBinary") + val result = sql("SELECT data FROM reflectBinary") .collect().head(0).asInstanceOf[Array[Byte]] assert(result.toSeq === Seq[Byte](1)) } @@ -126,7 +125,7 @@ class ScalaReflectionRelationSuite extends SparkFunSuite { Nested(None, "abc"))) Seq(data).toDF().registerTempTable("reflectComplexData") - assert(ctx.sql("SELECT * FROM reflectComplexData").collect().head === + assert(sql("SELECT * FROM reflectComplexData").collect().head === Row( Seq(1, 2, 3), Seq(1, 2, null), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala index 
e55c9e460b79..45d0ee4a8e74 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala @@ -19,13 +19,12 @@ package org.apache.spark.sql import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.serializer.JavaSerializer +import org.apache.spark.sql.test.SharedSQLContext -class SerializationSuite extends SparkFunSuite { - - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext +class SerializationSuite extends SparkFunSuite with SharedSQLContext { test("[SPARK-5235] SQLContext should be serializable") { - val sqlContext = new SQLContext(ctx.sparkContext) - new JavaSerializer(new SparkConf()).newInstance().serialize(sqlContext) + val _sqlContext = new SQLContext(sqlContext.sparkContext) + new JavaSerializer(new SparkConf()).newInstance().serialize(_sqlContext) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala index ca298b243441..cc95eede005d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala @@ -18,13 +18,12 @@ package org.apache.spark.sql import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.Decimal -class StringFunctionsSuite extends QueryTest { - - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ +class StringFunctionsSuite extends QueryTest with SharedSQLContext { + import testImplicits._ test("string concat") { val df = Seq[(String, String, String)](("a", "b", null)).toDF("a", "b", "c") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala deleted file mode 100644 index bd9729c431f3..000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala +++ /dev/null @@ -1,197 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql - -import org.apache.spark.sql.test.TestSQLContext.implicits._ -import org.apache.spark.sql.test._ - - -case class TestData(key: Int, value: String) - -object TestData { - val testData = TestSQLContext.sparkContext.parallelize( - (1 to 100).map(i => TestData(i, i.toString))).toDF() - testData.registerTempTable("testData") - - val negativeData = TestSQLContext.sparkContext.parallelize( - (1 to 100).map(i => TestData(-i, (-i).toString))).toDF() - negativeData.registerTempTable("negativeData") - - case class LargeAndSmallInts(a: Int, b: Int) - val largeAndSmallInts = - TestSQLContext.sparkContext.parallelize( - LargeAndSmallInts(2147483644, 1) :: - LargeAndSmallInts(1, 2) :: - LargeAndSmallInts(2147483645, 1) :: - LargeAndSmallInts(2, 2) :: - LargeAndSmallInts(2147483646, 1) :: - LargeAndSmallInts(3, 2) :: Nil).toDF() - largeAndSmallInts.registerTempTable("largeAndSmallInts") - - case class TestData2(a: Int, b: Int) - val testData2 = - TestSQLContext.sparkContext.parallelize( - TestData2(1, 1) :: - TestData2(1, 2) :: - TestData2(2, 1) :: - TestData2(2, 2) :: - TestData2(3, 1) :: - TestData2(3, 2) :: Nil, 2).toDF() - testData2.registerTempTable("testData2") - - case class DecimalData(a: BigDecimal, b: BigDecimal) - - val decimalData = - TestSQLContext.sparkContext.parallelize( - DecimalData(1, 1) :: - DecimalData(1, 2) :: - DecimalData(2, 1) :: - DecimalData(2, 2) :: - DecimalData(3, 1) :: - DecimalData(3, 2) :: Nil).toDF() - decimalData.registerTempTable("decimalData") - - case class BinaryData(a: Array[Byte], b: Int) - val binaryData = - TestSQLContext.sparkContext.parallelize( - BinaryData("12".getBytes(), 1) :: - BinaryData("22".getBytes(), 5) :: - BinaryData("122".getBytes(), 3) :: - BinaryData("121".getBytes(), 2) :: - BinaryData("123".getBytes(), 4) :: Nil).toDF() - binaryData.registerTempTable("binaryData") - - case class TestData3(a: Int, b: Option[Int]) - val testData3 = - TestSQLContext.sparkContext.parallelize( - TestData3(1, None) :: - TestData3(2, Some(2)) :: Nil).toDF() - testData3.registerTempTable("testData3") - - case class UpperCaseData(N: Int, L: String) - val upperCaseData = - TestSQLContext.sparkContext.parallelize( - UpperCaseData(1, "A") :: - UpperCaseData(2, "B") :: - UpperCaseData(3, "C") :: - UpperCaseData(4, "D") :: - UpperCaseData(5, "E") :: - UpperCaseData(6, "F") :: Nil).toDF() - upperCaseData.registerTempTable("upperCaseData") - - case class LowerCaseData(n: Int, l: String) - val lowerCaseData = - TestSQLContext.sparkContext.parallelize( - LowerCaseData(1, "a") :: - LowerCaseData(2, "b") :: - LowerCaseData(3, "c") :: - LowerCaseData(4, "d") :: Nil).toDF() - lowerCaseData.registerTempTable("lowerCaseData") - - case class ArrayData(data: Seq[Int], nestedData: Seq[Seq[Int]]) - val arrayData = - TestSQLContext.sparkContext.parallelize( - ArrayData(Seq(1, 2, 3), Seq(Seq(1, 2, 3))) :: - ArrayData(Seq(2, 3, 4), Seq(Seq(2, 3, 4))) :: Nil) - arrayData.toDF().registerTempTable("arrayData") - - case class MapData(data: scala.collection.Map[Int, String]) - val mapData = - TestSQLContext.sparkContext.parallelize( - MapData(Map(1 -> "a1", 2 -> "b1", 3 -> "c1", 4 -> "d1", 5 -> "e1")) :: - MapData(Map(1 -> "a2", 2 -> "b2", 3 -> "c2", 4 -> "d2")) :: - MapData(Map(1 -> "a3", 2 -> "b3", 3 -> "c3")) :: - MapData(Map(1 -> "a4", 2 -> "b4")) :: - MapData(Map(1 -> "a5")) :: Nil) - mapData.toDF().registerTempTable("mapData") - - case class StringData(s: String) - val repeatedData = - 
TestSQLContext.sparkContext.parallelize(List.fill(2)(StringData("test"))) - repeatedData.toDF().registerTempTable("repeatedData") - - val nullableRepeatedData = - TestSQLContext.sparkContext.parallelize( - List.fill(2)(StringData(null)) ++ - List.fill(2)(StringData("test"))) - nullableRepeatedData.toDF().registerTempTable("nullableRepeatedData") - - case class NullInts(a: Integer) - val nullInts = - TestSQLContext.sparkContext.parallelize( - NullInts(1) :: - NullInts(2) :: - NullInts(3) :: - NullInts(null) :: Nil - ).toDF() - nullInts.registerTempTable("nullInts") - - val allNulls = - TestSQLContext.sparkContext.parallelize( - NullInts(null) :: - NullInts(null) :: - NullInts(null) :: - NullInts(null) :: Nil).toDF() - allNulls.registerTempTable("allNulls") - - case class NullStrings(n: Int, s: String) - val nullStrings = - TestSQLContext.sparkContext.parallelize( - NullStrings(1, "abc") :: - NullStrings(2, "ABC") :: - NullStrings(3, null) :: Nil).toDF() - nullStrings.registerTempTable("nullStrings") - - case class TableName(tableName: String) - TestSQLContext - .sparkContext - .parallelize(TableName("test") :: Nil) - .toDF() - .registerTempTable("tableName") - - val unparsedStrings = - TestSQLContext.sparkContext.parallelize( - "1, A1, true, null" :: - "2, B2, false, null" :: - "3, C3, true, null" :: - "4, D4, true, 2147483644" :: Nil) - - case class IntField(i: Int) - // An RDD with 4 elements and 8 partitions - val withEmptyParts = TestSQLContext.sparkContext.parallelize((1 to 4).map(IntField), 8) - withEmptyParts.toDF().registerTempTable("withEmptyParts") - - case class Person(id: Int, name: String, age: Int) - case class Salary(personId: Int, salary: Double) - val person = TestSQLContext.sparkContext.parallelize( - Person(0, "mike", 30) :: - Person(1, "jim", 20) :: Nil).toDF() - person.registerTempTable("person") - val salary = TestSQLContext.sparkContext.parallelize( - Salary(0, 2000.0) :: - Salary(1, 1000.0) :: Nil).toDF() - salary.registerTempTable("salary") - - case class ComplexData(m: Map[String, Int], s: TestData, a: Seq[Int], b: Boolean) - val complexData = - TestSQLContext.sparkContext.parallelize( - ComplexData(Map("1" -> 1), TestData(1, "1"), Seq(1, 1, 1), true) - :: ComplexData(Map("2" -> 2), TestData(2, "2"), Seq(2, 2, 2), false) - :: Nil).toDF() - complexData.registerTempTable("complexData") -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index 183dc3407b3a..eb275af101e2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -17,16 +17,13 @@ package org.apache.spark.sql -import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.test.SQLTestData._ -case class FunctionResult(f1: String, f2: String) +private case class FunctionResult(f1: String, f2: String) -class UDFSuite extends QueryTest with SQLTestUtils { - - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ - - override def sqlContext(): SQLContext = ctx +class UDFSuite extends QueryTest with SharedSQLContext { + import testImplicits._ test("built-in fixed arity expressions") { val df = ctx.emptyDataFrame @@ -57,7 +54,7 @@ class UDFSuite extends QueryTest with SQLTestUtils { test("SPARK-8003 spark_partition_id") { val df = Seq((1, "Tearing down the walls that divide us")).toDF("id", "saying") df.registerTempTable("tmp_table") - 
checkAnswer(ctx.sql("select spark_partition_id() from tmp_table").toDF(), Row(0)) + checkAnswer(sql("select spark_partition_id() from tmp_table").toDF(), Row(0)) ctx.dropTempTable("tmp_table") } @@ -66,9 +63,9 @@ class UDFSuite extends QueryTest with SQLTestUtils { val data = ctx.sparkContext.parallelize(0 to 10, 2).toDF("id") data.write.parquet(dir.getCanonicalPath) ctx.read.parquet(dir.getCanonicalPath).registerTempTable("test_table") - val answer = ctx.sql("select input_file_name() from test_table").head().getString(0) + val answer = sql("select input_file_name() from test_table").head().getString(0) assert(answer.contains(dir.getCanonicalPath)) - assert(ctx.sql("select input_file_name() from test_table").distinct().collect().length >= 2) + assert(sql("select input_file_name() from test_table").distinct().collect().length >= 2) ctx.dropTempTable("test_table") } } @@ -91,17 +88,17 @@ class UDFSuite extends QueryTest with SQLTestUtils { test("Simple UDF") { ctx.udf.register("strLenScala", (_: String).length) - assert(ctx.sql("SELECT strLenScala('test')").head().getInt(0) === 4) + assert(sql("SELECT strLenScala('test')").head().getInt(0) === 4) } test("ZeroArgument UDF") { ctx.udf.register("random0", () => { Math.random()}) - assert(ctx.sql("SELECT random0()").head().getDouble(0) >= 0.0) + assert(sql("SELECT random0()").head().getDouble(0) >= 0.0) } test("TwoArgument UDF") { ctx.udf.register("strLenScala", (_: String).length + (_: Int)) - assert(ctx.sql("SELECT strLenScala('test', 1)").head().getInt(0) === 5) + assert(sql("SELECT strLenScala('test', 1)").head().getInt(0) === 5) } test("UDF in a WHERE") { @@ -112,7 +109,7 @@ class UDFSuite extends QueryTest with SQLTestUtils { df.registerTempTable("integerData") val result = - ctx.sql("SELECT * FROM integerData WHERE oneArgFilter(key)") + sql("SELECT * FROM integerData WHERE oneArgFilter(key)") assert(result.count() === 20) } @@ -124,7 +121,7 @@ class UDFSuite extends QueryTest with SQLTestUtils { df.registerTempTable("groupData") val result = - ctx.sql( + sql( """ | SELECT g, SUM(v) as s | FROM groupData @@ -143,7 +140,7 @@ class UDFSuite extends QueryTest with SQLTestUtils { df.registerTempTable("groupData") val result = - ctx.sql( + sql( """ | SELECT SUM(v) | FROM groupData @@ -163,7 +160,7 @@ class UDFSuite extends QueryTest with SQLTestUtils { df.registerTempTable("groupData") val result = - ctx.sql( + sql( """ | SELECT timesHundred(SUM(v)) as v100 | FROM groupData @@ -178,7 +175,7 @@ class UDFSuite extends QueryTest with SQLTestUtils { ctx.udf.register("returnStruct", (f1: String, f2: String) => FunctionResult(f1, f2)) val result = - ctx.sql("SELECT returnStruct('test', 'test2') as ret") + sql("SELECT returnStruct('test', 'test2') as ret") .select($"ret.f1").head().getString(0) assert(result === "test") } @@ -186,12 +183,12 @@ class UDFSuite extends QueryTest with SQLTestUtils { test("udf that is transformed") { ctx.udf.register("makeStruct", (x: Int, y: Int) => (x, y)) // 1 + 1 is constant folded causing a transformation. - assert(ctx.sql("SELECT makeStruct(1 + 1, 2)").first().getAs[Row](0) === Row(2, 2)) + assert(sql("SELECT makeStruct(1 + 1, 2)").first().getAs[Row](0) === Row(2, 2)) } test("type coercion for udf inputs") { ctx.udf.register("intExpected", (x: Int) => x) // pass a decimal to intExpected. 
- assert(ctx.sql("SELECT intExpected(1.0)").head().getInt(0) === 1) + assert(sql("SELECT intExpected(1.0)").head().getInt(0) === 1) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala index 9181222f6922..b6d279ae4726 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala @@ -24,6 +24,7 @@ import com.clearspring.analytics.stream.cardinality.HyperLogLog import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.expressions.{OpenHashSetUDT, HyperLogLogUDT} import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.util.Utils import org.apache.spark.util.collection.OpenHashSet @@ -66,10 +67,8 @@ private[sql] class MyDenseVectorUDT extends UserDefinedType[MyDenseVector] { private[spark] override def asNullable: MyDenseVectorUDT = this } -class UserDefinedTypeSuite extends QueryTest { - - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ +class UserDefinedTypeSuite extends QueryTest with SharedSQLContext { + import testImplicits._ private lazy val pointsRDD = Seq( MyLabeledPoint(1.0, new MyDenseVector(Array(0.1, 1.0))), @@ -94,7 +93,7 @@ class UserDefinedTypeSuite extends QueryTest { ctx.udf.register("testType", (d: MyDenseVector) => d.isInstanceOf[MyDenseVector]) pointsRDD.registerTempTable("points") checkAnswer( - ctx.sql("SELECT testType(features) from points"), + sql("SELECT testType(features) from points"), Seq(Row(true), Row(true))) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala index 9bca4e7e660d..952637c5f9cb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala @@ -19,18 +19,16 @@ package org.apache.spark.sql.columnar import java.sql.{Date, Timestamp} -import org.apache.spark.sql.TestData._ +import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.test.SQLTestData._ import org.apache.spark.sql.types._ -import org.apache.spark.sql.{QueryTest, Row, TestData} import org.apache.spark.storage.StorageLevel.MEMORY_ONLY -class InMemoryColumnarQuerySuite extends QueryTest { - // Make sure the tables are loaded. 
- TestData +class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { + import testImplicits._ - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ - import ctx.{logicalPlanToSparkQuery, sql} + setupTestData() test("simple columnar query") { val plan = ctx.executePlan(testData.logicalPlan).executedPlan diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala index 2c0879927a12..ab2644eb4581 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala @@ -17,20 +17,19 @@ package org.apache.spark.sql.columnar -import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} - import org.apache.spark.SparkFunSuite import org.apache.spark.sql._ +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.test.SQLTestData._ -class PartitionBatchPruningSuite extends SparkFunSuite with BeforeAndAfterAll with BeforeAndAfter { - - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ +class PartitionBatchPruningSuite extends SparkFunSuite with SharedSQLContext { + import testImplicits._ private lazy val originalColumnBatchSize = ctx.conf.columnBatchSize private lazy val originalInMemoryPartitionPruning = ctx.conf.inMemoryPartitionPruning override protected def beforeAll(): Unit = { + super.beforeAll() // Make a table with 5 partitions, 2 batches per partition, 10 elements per batch ctx.setConf(SQLConf.COLUMN_BATCH_SIZE, 10) @@ -44,19 +43,17 @@ class PartitionBatchPruningSuite extends SparkFunSuite with BeforeAndAfterAll wi ctx.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, true) // Enable in-memory table scan accumulators ctx.setConf("spark.sql.inMemoryTableScanStatistics.enable", "true") - } - - override protected def afterAll(): Unit = { - ctx.setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize) - ctx.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning) - } - - before { ctx.cacheTable("pruningData") } - after { - ctx.uncacheTable("pruningData") + override protected def afterAll(): Unit = { + try { + ctx.setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize) + ctx.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning) + ctx.uncacheTable("pruningData") + } finally { + super.afterAll() + } } // Comparisons @@ -110,7 +107,7 @@ class PartitionBatchPruningSuite extends SparkFunSuite with BeforeAndAfterAll wi expectedQueryResult: => Seq[Int]): Unit = { test(query) { - val df = ctx.sql(query) + val df = sql(query) val queryExecution = df.queryExecution assertResult(expectedQueryResult.toArray, s"Wrong query result: $queryExecution") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala index 79e903c2bbd4..8998f5111124 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala @@ -19,8 +19,9 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.plans.physical.SinglePartition +import org.apache.spark.sql.test.SharedSQLContext -class ExchangeSuite extends SparkPlanTest { +class ExchangeSuite extends SparkPlanTest 
with SharedSQLContext { test("shuffling UnsafeRows in exchange") { val input = (1 to 1000).map(Tuple1.apply) checkAnswer( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 5582caa0d366..937a10854353 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution import org.apache.spark.SparkFunSuite import org.apache.spark.rdd.RDD -import org.apache.spark.sql.TestData._ +import org.apache.spark.sql.{execution, Row, SQLConf} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Literal, SortOrder} import org.apache.spark.sql.catalyst.plans._ @@ -27,19 +27,18 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, ShuffledHashJoin} import org.apache.spark.sql.functions._ -import org.apache.spark.sql.test.{SQLTestUtils, TestSQLContext} -import org.apache.spark.sql.test.TestSQLContext._ -import org.apache.spark.sql.test.TestSQLContext.implicits._ -import org.apache.spark.sql.test.TestSQLContext.planner._ +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ -import org.apache.spark.sql.{SQLContext, Row, SQLConf, execution} -class PlannerSuite extends SparkFunSuite with SQLTestUtils { +class PlannerSuite extends SparkFunSuite with SharedSQLContext { + import testImplicits._ - override def sqlContext: SQLContext = TestSQLContext + setupTestData() private def testPartialAggregationPlan(query: LogicalPlan): Unit = { + val _ctx = ctx + import _ctx.planner._ val plannedOption = HashAggregation(query).headOption.orElse(Aggregation(query).headOption) val planned = plannedOption.getOrElse( @@ -54,6 +53,8 @@ class PlannerSuite extends SparkFunSuite with SQLTestUtils { } test("unions are collapsed") { + val _ctx = ctx + import _ctx.planner._ val query = testData.unionAll(testData).unionAll(testData).logicalPlan val planned = BasicOperators(query).head val logicalUnions = query collect { case u: logical.Union => u } @@ -81,14 +82,14 @@ class PlannerSuite extends SparkFunSuite with SQLTestUtils { test("sizeInBytes estimation of limit operator for broadcast hash join optimization") { def checkPlan(fieldTypes: Seq[DataType], newThreshold: Int): Unit = { - setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, newThreshold) + ctx.setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, newThreshold) val fields = fieldTypes.zipWithIndex.map { case (dataType, index) => StructField(s"c${index}", dataType, true) } :+ StructField("key", IntegerType, true) val schema = StructType(fields) val row = Row.fromSeq(Seq.fill(fields.size)(null)) - val rowRDD = org.apache.spark.sql.test.TestSQLContext.sparkContext.parallelize(row :: Nil) - createDataFrame(rowRDD, schema).registerTempTable("testLimit") + val rowRDD = ctx.sparkContext.parallelize(row :: Nil) + ctx.createDataFrame(rowRDD, schema).registerTempTable("testLimit") val planned = sql( """ @@ -102,10 +103,10 @@ class PlannerSuite extends SparkFunSuite with SQLTestUtils { assert(broadcastHashJoins.size === 1, "Should use broadcast hash join") assert(shuffledHashJoins.isEmpty, "Should not use shuffled hash join") - dropTempTable("testLimit") + ctx.dropTempTable("testLimit") } - val 
origThreshold = conf.autoBroadcastJoinThreshold + val origThreshold = ctx.conf.autoBroadcastJoinThreshold val simpleTypes = NullType :: @@ -137,18 +138,18 @@ class PlannerSuite extends SparkFunSuite with SQLTestUtils { checkPlan(complexTypes, newThreshold = 901617) - setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, origThreshold) + ctx.setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, origThreshold) } test("InMemoryRelation statistics propagation") { - val origThreshold = conf.autoBroadcastJoinThreshold - setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, 81920) + val origThreshold = ctx.conf.autoBroadcastJoinThreshold + ctx.setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, 81920) testData.limit(3).registerTempTable("tiny") sql("CACHE TABLE tiny") val a = testData.as("a") - val b = table("tiny").as("b") + val b = ctx.table("tiny").as("b") val planned = a.join(b, $"a.key" === $"b.key").queryExecution.executedPlan val broadcastHashJoins = planned.collect { case join: BroadcastHashJoin => join } @@ -157,12 +158,12 @@ class PlannerSuite extends SparkFunSuite with SQLTestUtils { assert(broadcastHashJoins.size === 1, "Should use broadcast hash join") assert(shuffledHashJoins.isEmpty, "Should not use shuffled hash join") - setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, origThreshold) + ctx.setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, origThreshold) } test("efficient limit -> project -> sort") { val query = testData.sort('key).select('value).limit(2).logicalPlan - val planned = planner.TakeOrderedAndProject(query) + val planned = ctx.planner.TakeOrderedAndProject(query) assert(planned.head.isInstanceOf[execution.TakeOrderedAndProject]) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala index dd08e9025a92..ef6ad59b71fb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala @@ -21,11 +21,11 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Attribute, Literal, IsNull} -import org.apache.spark.sql.test.TestSQLContext -import org.apache.spark.sql.types.{GenericArrayData, ArrayType, StructType, StringType} +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.{GenericArrayData, ArrayType, StringType} import org.apache.spark.unsafe.types.UTF8String -class RowFormatConvertersSuite extends SparkPlanTest { +class RowFormatConvertersSuite extends SparkPlanTest with SharedSQLContext { private def getConverters(plan: SparkPlan): Seq[SparkPlan] = plan.collect { case c: ConvertToUnsafe => c @@ -39,20 +39,20 @@ class RowFormatConvertersSuite extends SparkPlanTest { test("planner should insert unsafe->safe conversions when required") { val plan = Limit(10, outputsUnsafe) - val preparedPlan = TestSQLContext.prepareForExecution.execute(plan) + val preparedPlan = ctx.prepareForExecution.execute(plan) assert(preparedPlan.children.head.isInstanceOf[ConvertToSafe]) } test("filter can process unsafe rows") { val plan = Filter(IsNull(IsNull(Literal(1))), outputsUnsafe) - val preparedPlan = TestSQLContext.prepareForExecution.execute(plan) + val preparedPlan = ctx.prepareForExecution.execute(plan) assert(getConverters(preparedPlan).size === 1) assert(preparedPlan.outputsUnsafeRows) } test("filter can 
process safe rows") { val plan = Filter(IsNull(IsNull(Literal(1))), outputsSafe) - val preparedPlan = TestSQLContext.prepareForExecution.execute(plan) + val preparedPlan = ctx.prepareForExecution.execute(plan) assert(getConverters(preparedPlan).isEmpty) assert(!preparedPlan.outputsUnsafeRows) } @@ -67,33 +67,33 @@ class RowFormatConvertersSuite extends SparkPlanTest { test("union requires all of its input rows' formats to agree") { val plan = Union(Seq(outputsSafe, outputsUnsafe)) assert(plan.canProcessSafeRows && plan.canProcessUnsafeRows) - val preparedPlan = TestSQLContext.prepareForExecution.execute(plan) + val preparedPlan = ctx.prepareForExecution.execute(plan) assert(preparedPlan.outputsUnsafeRows) } test("union can process safe rows") { val plan = Union(Seq(outputsSafe, outputsSafe)) - val preparedPlan = TestSQLContext.prepareForExecution.execute(plan) + val preparedPlan = ctx.prepareForExecution.execute(plan) assert(!preparedPlan.outputsUnsafeRows) } test("union can process unsafe rows") { val plan = Union(Seq(outputsUnsafe, outputsUnsafe)) - val preparedPlan = TestSQLContext.prepareForExecution.execute(plan) + val preparedPlan = ctx.prepareForExecution.execute(plan) assert(preparedPlan.outputsUnsafeRows) } test("round trip with ConvertToUnsafe and ConvertToSafe") { val input = Seq(("hello", 1), ("world", 2)) checkAnswer( - TestSQLContext.createDataFrame(input), + ctx.createDataFrame(input), plan => ConvertToSafe(ConvertToUnsafe(plan)), input.map(Row.fromTuple) ) } test("SPARK-9683: copy UTF8String when convert unsafe array/map to safe") { - SparkPlan.currentContext.set(TestSQLContext) + SparkPlan.currentContext.set(ctx) val schema = ArrayType(StringType) val rows = (1 to 100).map { i => InternalRow(new GenericArrayData(Array[Any](UTF8String.fromString(i.toString)))) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala index a2c10fdaf6cd..8fa77b0fcb7b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala @@ -19,8 +19,9 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.test.SharedSQLContext -class SortSuite extends SparkPlanTest { +class SortSuite extends SparkPlanTest with SharedSQLContext { // This test was originally added as an example of how to use [[SparkPlanTest]]; // it's not designed to be a comprehensive test of ExternalSort. 
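
For orientation, the test suites touched above all follow the same migration pattern: drop the hand-built TestSQLContext reference, mix in the SharedSQLContext trait referenced throughout these hunks, and pull `sql`, `ctx`, and `testImplicits` from that trait (calling `setupTestData()` only when the suite reads the shared test tables). The following is a minimal sketch of a migrated suite, not taken from the patch itself; the suite name and the "example" table are hypothetical, and the sketch assumes SharedSQLContext exposes the members used in the hunks above.

import org.apache.spark.sql.{QueryTest, Row}
import org.apache.spark.sql.test.SharedSQLContext

class ExampleFunctionsSuite extends QueryTest with SharedSQLContext {
  // Implicits (toDF, $"col", etc.) now come from the shared trait instead of a per-suite context.
  import testImplicits._

  // Registers the shared SQLTestData temp tables; only needed when the suite queries them.
  setupTestData()

  test("count over a temp table") {
    val df = Seq((1, "a"), (2, "b"), (3, "c")).toDF("key", "value")
    df.registerTempTable("example")
    // `sql` and `ctx` are provided by SharedSQLContext, so no explicit TestSQLContext reference remains.
    checkAnswer(sql("SELECT count(key) FROM example"), Row(3))
    ctx.dropTempTable("example")
  }
}
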
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala index f46855edfe0d..3a87f374d94b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala @@ -17,29 +17,27 @@ package org.apache.spark.sql.execution -import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute -import org.apache.spark.sql.catalyst.util._ -import org.apache.spark.sql.test.TestSQLContext -import org.apache.spark.sql.{SQLContext, DataFrame, DataFrameHolder, Row} - import scala.language.implicitConversions import scala.reflect.runtime.universe.TypeTag import scala.util.control.NonFatal +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.{DataFrame, DataFrameHolder, Row, SQLContext} +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute +import org.apache.spark.sql.catalyst.util._ + /** * Base class for writing tests for individual physical operators. For an example of how this * class's test helper methods can be used, see [[SortSuite]]. */ -class SparkPlanTest extends SparkFunSuite { - - protected def sqlContext: SQLContext = TestSQLContext +private[sql] abstract class SparkPlanTest extends SparkFunSuite { + protected def _sqlContext: SQLContext /** * Creates a DataFrame from a local Seq of Product. */ implicit def localSeqToDataFrameHolder[A <: Product : TypeTag](data: Seq[A]): DataFrameHolder = { - sqlContext.implicits.localSeqToDataFrameHolder(data) + _sqlContext.implicits.localSeqToDataFrameHolder(data) } /** @@ -100,7 +98,7 @@ class SparkPlanTest extends SparkFunSuite { planFunction: Seq[SparkPlan] => SparkPlan, expectedAnswer: Seq[Row], sortAnswers: Boolean = true): Unit = { - SparkPlanTest.checkAnswer(input, planFunction, expectedAnswer, sortAnswers, sqlContext) match { + SparkPlanTest.checkAnswer(input, planFunction, expectedAnswer, sortAnswers, _sqlContext) match { case Some(errorMessage) => fail(errorMessage) case None => } @@ -124,7 +122,7 @@ class SparkPlanTest extends SparkFunSuite { expectedPlanFunction: SparkPlan => SparkPlan, sortAnswers: Boolean = true): Unit = { SparkPlanTest.checkAnswer( - input, planFunction, expectedPlanFunction, sortAnswers, sqlContext) match { + input, planFunction, expectedPlanFunction, sortAnswers, _sqlContext) match { case Some(errorMessage) => fail(errorMessage) case None => } @@ -151,13 +149,13 @@ object SparkPlanTest { planFunction: SparkPlan => SparkPlan, expectedPlanFunction: SparkPlan => SparkPlan, sortAnswers: Boolean, - sqlContext: SQLContext): Option[String] = { + _sqlContext: SQLContext): Option[String] = { val outputPlan = planFunction(input.queryExecution.sparkPlan) val expectedOutputPlan = expectedPlanFunction(input.queryExecution.sparkPlan) val expectedAnswer: Seq[Row] = try { - executePlan(expectedOutputPlan, sqlContext) + executePlan(expectedOutputPlan, _sqlContext) } catch { case NonFatal(e) => val errorMessage = @@ -172,7 +170,7 @@ object SparkPlanTest { } val actualAnswer: Seq[Row] = try { - executePlan(outputPlan, sqlContext) + executePlan(outputPlan, _sqlContext) } catch { case NonFatal(e) => val errorMessage = @@ -212,12 +210,12 @@ object SparkPlanTest { planFunction: Seq[SparkPlan] => SparkPlan, expectedAnswer: Seq[Row], sortAnswers: Boolean, - sqlContext: SQLContext): Option[String] = { + _sqlContext: SQLContext): Option[String] = { val outputPlan = 
planFunction(input.map(_.queryExecution.sparkPlan)) val sparkAnswer: Seq[Row] = try { - executePlan(outputPlan, sqlContext) + executePlan(outputPlan, _sqlContext) } catch { case NonFatal(e) => val errorMessage = @@ -280,10 +278,10 @@ object SparkPlanTest { } } - private def executePlan(outputPlan: SparkPlan, sqlContext: SQLContext): Seq[Row] = { + private def executePlan(outputPlan: SparkPlan, _sqlContext: SQLContext): Seq[Row] = { // A very simple resolver to make writing tests easier. In contrast to the real resolver // this is always case sensitive and does not try to handle scoping or complex type resolution. - val resolvedPlan = sqlContext.prepareForExecution.execute( + val resolvedPlan = _sqlContext.prepareForExecution.execute( outputPlan transform { case plan: SparkPlan => val inputMap = plan.children.flatMap(_.output).map(a => (a.name, a)).toMap diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala index 88bce0e319f9..3158458edb83 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala @@ -19,25 +19,28 @@ package org.apache.spark.sql.execution import scala.util.Random -import org.scalatest.BeforeAndAfterAll - import org.apache.spark.AccumulatorSuite import org.apache.spark.sql.{RandomDataGenerator, Row, SQLConf} import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ /** * A test suite that generates randomized data to test the [[TungstenSort]] operator. */ -class TungstenSortSuite extends SparkPlanTest with BeforeAndAfterAll { +class TungstenSortSuite extends SparkPlanTest with SharedSQLContext { override def beforeAll(): Unit = { - TestSQLContext.conf.setConf(SQLConf.CODEGEN_ENABLED, true) + super.beforeAll() + ctx.conf.setConf(SQLConf.CODEGEN_ENABLED, true) } override def afterAll(): Unit = { - TestSQLContext.conf.setConf(SQLConf.CODEGEN_ENABLED, SQLConf.CODEGEN_ENABLED.defaultValue.get) + try { + ctx.conf.setConf(SQLConf.CODEGEN_ENABLED, SQLConf.CODEGEN_ENABLED.defaultValue.get) + } finally { + super.afterAll() + } } test("sort followed by limit") { @@ -61,7 +64,7 @@ class TungstenSortSuite extends SparkPlanTest with BeforeAndAfterAll { } test("sorting updates peak execution memory") { - val sc = TestSQLContext.sparkContext + val sc = ctx.sparkContext AccumulatorSuite.verifyPeakExecutionMemorySet(sc, "unsafe external sort") { checkThatPlansAgree( (1 to 100).map(v => Tuple1(v)).toDF("a"), @@ -80,8 +83,8 @@ class TungstenSortSuite extends SparkPlanTest with BeforeAndAfterAll { ) { test(s"sorting on $dataType with nullable=$nullable, sortOrder=$sortOrder") { val inputData = Seq.fill(1000)(randomDataGenerator()) - val inputDf = TestSQLContext.createDataFrame( - TestSQLContext.sparkContext.parallelize(Random.shuffle(inputData).map(v => Row(v))), + val inputDf = ctx.createDataFrame( + ctx.sparkContext.parallelize(Random.shuffle(inputData).map(v => Row(v))), StructType(StructField("a", dataType, nullable = true) :: Nil) ) assert(TungstenSort.supportsSchema(inputDf.schema)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala index e03473041c3e..d1f0b2b1fc52 100644 
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala @@ -26,7 +26,7 @@ import org.scalatest.Matchers import org.apache.spark.sql.catalyst.expressions.{UnsafeRow, UnsafeProjection} import org.apache.spark.{TaskContextImpl, TaskContext, SparkFunSuite} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager} import org.apache.spark.unsafe.types.UTF8String @@ -36,7 +36,10 @@ import org.apache.spark.unsafe.types.UTF8String * * Use [[testWithMemoryLeakDetection]] rather than [[test]] to construct test cases. */ -class UnsafeFixedWidthAggregationMapSuite extends SparkFunSuite with Matchers { +class UnsafeFixedWidthAggregationMapSuite + extends SparkFunSuite + with Matchers + with SharedSQLContext { import UnsafeFixedWidthAggregationMap._ @@ -171,9 +174,6 @@ class UnsafeFixedWidthAggregationMapSuite extends SparkFunSuite with Matchers { } testWithMemoryLeakDetection("test external sorting") { - // Calling this make sure we have block manager and everything else setup. - TestSQLContext - // Memory consumption in the beginning of the task. val initialMemoryConsumption = shuffleMemoryManager.getMemoryConsumptionForThisTask() @@ -233,8 +233,6 @@ class UnsafeFixedWidthAggregationMapSuite extends SparkFunSuite with Matchers { } testWithMemoryLeakDetection("test external sorting with an empty map") { - // Calling this make sure we have block manager and everything else setup. - TestSQLContext val map = new UnsafeFixedWidthAggregationMap( emptyAggregationBuffer, @@ -282,8 +280,6 @@ class UnsafeFixedWidthAggregationMapSuite extends SparkFunSuite with Matchers { } testWithMemoryLeakDetection("test external sorting with empty records") { - // Calling this make sure we have block manager and everything else setup. - TestSQLContext // Memory consumption in the beginning of the task. val initialMemoryConsumption = shuffleMemoryManager.getMemoryConsumptionForThisTask() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala index a9515a03acf2..d3be568a8758 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala @@ -23,15 +23,14 @@ import org.apache.spark._ import org.apache.spark.sql.{RandomDataGenerator, Row} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions.{InterpretedOrdering, UnsafeRow, UnsafeProjection} -import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager} /** * Test suite for [[UnsafeKVExternalSorter]], with randomly generated test data. 
*/ -class UnsafeKVExternalSorterSuite extends SparkFunSuite { - +class UnsafeKVExternalSorterSuite extends SparkFunSuite with SharedSQLContext { private val keyTypes = Seq(IntegerType, FloatType, DoubleType, StringType) private val valueTypes = Seq(IntegerType, FloatType, DoubleType, StringType) @@ -109,8 +108,6 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite { inputData: Seq[(InternalRow, InternalRow)], pageSize: Long, spill: Boolean): Unit = { - // Calling this make sure we have block manager and everything else setup. - TestSQLContext val taskMemMgr = new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)) val shuffleMemMgr = new TestShuffleMemoryManager diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala index ac22c2f3c0a5..5fdb82b06772 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala @@ -21,15 +21,12 @@ import org.apache.spark._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.InterpretedMutableProjection import org.apache.spark.sql.execution.metric.SQLMetrics -import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.unsafe.memory.TaskMemoryManager -class TungstenAggregationIteratorSuite extends SparkFunSuite { +class TungstenAggregationIteratorSuite extends SparkFunSuite with SharedSQLContext { test("memory acquired on construction") { - // set up environment - val ctx = TestSQLContext - val taskMemoryManager = new TaskMemoryManager(SparkEnv.get.executorMemoryManager) val taskContext = new TaskContextImpl(0, 0, 0, 0, taskMemoryManager, null, Seq.empty) TaskContext.setTaskContext(taskContext) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 73d562189781..1174b27732f2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -24,22 +24,16 @@ import com.fasterxml.jackson.core.JsonFactory import org.apache.spark.rdd.RDD import org.scalactic.Tolerance._ -import org.apache.spark.sql.{SQLContext, QueryTest, Row, SQLConf} -import org.apache.spark.sql.TestData._ +import org.apache.spark.sql.{QueryTest, Row, SQLConf} import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.datasources.{ResolvedDataSource, LogicalRelation} import org.apache.spark.sql.execution.datasources.json.InferSchema.compatibleType +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ -import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.util.Utils -class JsonSuite extends QueryTest with SQLTestUtils with TestJsonData { - - protected lazy val ctx = org.apache.spark.sql.test.TestSQLContext - override def sqlContext: SQLContext = ctx // used by SQLTestUtils - - import ctx.sql - import ctx.implicits._ +class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { + import testImplicits._ test("Type promotion") { def checkTypePromotion(expected: Any, actual: Any) 
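The JsonSuite change above is representative of the broader migration in this patch: suites stop pinning the TestSQLContext singleton and instead mix in SharedSQLContext, which supplies ctx, sql and testImplicits. A small illustrative sketch, with a hypothetical suite name and column names:

import org.apache.spark.sql.{QueryTest, Row}
import org.apache.spark.sql.test.SharedSQLContext

class ExampleSharedContextSuite extends QueryTest with SharedSQLContext {
  import testImplicits._   // implicits come from the shared context, not from a singleton

  test("DataFrame built through the shared context") {
    val df = Seq((1, "a"), (2, "b"), (3, "c")).toDF("id", "name")
    checkAnswer(df.filter($"id" > 1).select("name"), Row("b") :: Row("c") :: Nil)
  }
}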
{ @@ -596,7 +590,8 @@ class JsonSuite extends QueryTest with SQLTestUtils with TestJsonData { val schema = StructType(StructField("a", LongType, true) :: Nil) val logicalRelation = - ctx.read.schema(schema).json(path).queryExecution.analyzed.asInstanceOf[LogicalRelation] + ctx.read.schema(schema).json(path) + .queryExecution.analyzed.asInstanceOf[LogicalRelation] val relationWithSchema = logicalRelation.relation.asInstanceOf[JSONRelation] assert(relationWithSchema.paths === Array(path)) assert(relationWithSchema.schema === schema) @@ -1040,31 +1035,29 @@ class JsonSuite extends QueryTest with SQLTestUtils with TestJsonData { } test("JSONRelation equality test") { - val context = org.apache.spark.sql.test.TestSQLContext - val relation0 = new JSONRelation( Some(empty), 1.0, Some(StructType(StructField("a", IntegerType, true) :: Nil)), - None, None)(context) + None, None)(ctx) val logicalRelation0 = LogicalRelation(relation0) val relation1 = new JSONRelation( Some(singleRow), 1.0, Some(StructType(StructField("a", IntegerType, true) :: Nil)), - None, None)(context) + None, None)(ctx) val logicalRelation1 = LogicalRelation(relation1) val relation2 = new JSONRelation( Some(singleRow), 0.5, Some(StructType(StructField("a", IntegerType, true) :: Nil)), - None, None)(context) + None, None)(ctx) val logicalRelation2 = LogicalRelation(relation2) val relation3 = new JSONRelation( Some(singleRow), 1.0, Some(StructType(StructField("b", IntegerType, true) :: Nil)), - None, None)(context) + None, None)(ctx) val logicalRelation3 = LogicalRelation(relation3) assert(relation0 !== relation1) @@ -1089,14 +1082,14 @@ class JsonSuite extends QueryTest with SQLTestUtils with TestJsonData { .map(i => s"""{"a": 1, "b": "str$i"}""").saveAsTextFile(path) val d1 = ResolvedDataSource( - context, + ctx, userSpecifiedSchema = None, partitionColumns = Array.empty[String], provider = classOf[DefaultSource].getCanonicalName, options = Map("path" -> path)) val d2 = ResolvedDataSource( - context, + ctx, userSpecifiedSchema = None, partitionColumns = Array.empty[String], provider = classOf[DefaultSource].getCanonicalName, @@ -1162,11 +1155,12 @@ class JsonSuite extends QueryTest with SQLTestUtils with TestJsonData { "abd") ctx.read.json(root.getAbsolutePath).registerTempTable("test_myjson_with_part") - checkAnswer( - sql("SELECT count(a) FROM test_myjson_with_part where d1 = 1 and col1='abc'"), Row(4)) - checkAnswer( - sql("SELECT count(a) FROM test_myjson_with_part where d1 = 1 and col1='abd'"), Row(5)) - checkAnswer(sql("SELECT count(a) FROM test_myjson_with_part where d1 = 1"), Row(9)) + checkAnswer(sql( + "SELECT count(a) FROM test_myjson_with_part where d1 = 1 and col1='abc'"), Row(4)) + checkAnswer(sql( + "SELECT count(a) FROM test_myjson_with_part where d1 = 1 and col1='abd'"), Row(5)) + checkAnswer(sql( + "SELECT count(a) FROM test_myjson_with_part where d1 = 1"), Row(9)) }) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala index 6b62c9a003df..2864181cf91d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala @@ -20,12 +20,11 @@ package org.apache.spark.sql.execution.datasources.json import org.apache.spark.rdd.RDD import org.apache.spark.sql.SQLContext -trait TestJsonData { - - protected def ctx: SQLContext +private[json] 
trait TestJsonData { + protected def _sqlContext: SQLContext def primitiveFieldAndType: RDD[String] = - ctx.sparkContext.parallelize( + _sqlContext.sparkContext.parallelize( """{"string":"this is a simple string.", "integer":10, "long":21474836470, @@ -36,7 +35,7 @@ trait TestJsonData { }""" :: Nil) def primitiveFieldValueTypeConflict: RDD[String] = - ctx.sparkContext.parallelize( + _sqlContext.sparkContext.parallelize( """{"num_num_1":11, "num_num_2":null, "num_num_3": 1.1, "num_bool":true, "num_str":13.1, "str_bool":"str1"}""" :: """{"num_num_1":null, "num_num_2":21474836470.9, "num_num_3": null, @@ -47,14 +46,14 @@ trait TestJsonData { "num_bool":null, "num_str":92233720368547758070, "str_bool":null}""" :: Nil) def jsonNullStruct: RDD[String] = - ctx.sparkContext.parallelize( + _sqlContext.sparkContext.parallelize( """{"nullstr":"","ip":"27.31.100.29","headers":{"Host":"1.abc.com","Charset":"UTF-8"}}""" :: """{"nullstr":"","ip":"27.31.100.29","headers":{}}""" :: """{"nullstr":"","ip":"27.31.100.29","headers":""}""" :: """{"nullstr":null,"ip":"27.31.100.29","headers":null}""" :: Nil) def complexFieldValueTypeConflict: RDD[String] = - ctx.sparkContext.parallelize( + _sqlContext.sparkContext.parallelize( """{"num_struct":11, "str_array":[1, 2, 3], "array":[], "struct_array":[], "struct": {}}""" :: """{"num_struct":{"field":false}, "str_array":null, @@ -65,14 +64,14 @@ trait TestJsonData { "array":[7], "struct_array":{"field": true}, "struct": {"field": "str"}}""" :: Nil) def arrayElementTypeConflict: RDD[String] = - ctx.sparkContext.parallelize( + _sqlContext.sparkContext.parallelize( """{"array1": [1, 1.1, true, null, [], {}, [2,3,4], {"field":"str"}], "array2": [{"field":214748364700}, {"field":1}]}""" :: """{"array3": [{"field":"str"}, {"field":1}]}""" :: """{"array3": [1, 2, 3]}""" :: Nil) def missingFields: RDD[String] = - ctx.sparkContext.parallelize( + _sqlContext.sparkContext.parallelize( """{"a":true}""" :: """{"b":21474836470}""" :: """{"c":[33, 44]}""" :: @@ -80,7 +79,7 @@ trait TestJsonData { """{"e":"str"}""" :: Nil) def complexFieldAndType1: RDD[String] = - ctx.sparkContext.parallelize( + _sqlContext.sparkContext.parallelize( """{"struct":{"field1": true, "field2": 92233720368547758070}, "structWithArrayFields":{"field1":[4, 5, 6], "field2":["str1", "str2"]}, "arrayOfString":["str1", "str2"], @@ -96,7 +95,7 @@ trait TestJsonData { }""" :: Nil) def complexFieldAndType2: RDD[String] = - ctx.sparkContext.parallelize( + _sqlContext.sparkContext.parallelize( """{"arrayOfStruct":[{"field1": true, "field2": "str1"}, {"field1": false}, {"field3": null}], "complexArrayOfStruct": [ { @@ -150,7 +149,7 @@ trait TestJsonData { }""" :: Nil) def mapType1: RDD[String] = - ctx.sparkContext.parallelize( + _sqlContext.sparkContext.parallelize( """{"map": {"a": 1}}""" :: """{"map": {"b": 2}}""" :: """{"map": {"c": 3}}""" :: @@ -158,7 +157,7 @@ trait TestJsonData { """{"map": {"e": null}}""" :: Nil) def mapType2: RDD[String] = - ctx.sparkContext.parallelize( + _sqlContext.sparkContext.parallelize( """{"map": {"a": {"field1": [1, 2, 3, null]}}}""" :: """{"map": {"b": {"field2": 2}}}""" :: """{"map": {"c": {"field1": [], "field2": 4}}}""" :: @@ -167,21 +166,21 @@ trait TestJsonData { """{"map": {"f": {"field1": null}}}""" :: Nil) def nullsInArrays: RDD[String] = - ctx.sparkContext.parallelize( + _sqlContext.sparkContext.parallelize( """{"field1":[[null], [[["Test"]]]]}""" :: """{"field2":[null, [{"Test":1}]]}""" :: """{"field3":[[null], [{"Test":"2"}]]}""" :: """{"field4":[[null, [1,2,3]]]}""" :: 
Nil) def jsonArray: RDD[String] = - ctx.sparkContext.parallelize( + _sqlContext.sparkContext.parallelize( """[{"a":"str_a_1"}]""" :: """[{"a":"str_a_2"}, {"b":"str_b_3"}]""" :: """{"b":"str_b_4", "a":"str_a_4", "c":"str_c_4"}""" :: """[]""" :: Nil) def corruptRecords: RDD[String] = - ctx.sparkContext.parallelize( + _sqlContext.sparkContext.parallelize( """{""" :: """""" :: """{"a":1, b:2}""" :: @@ -190,7 +189,7 @@ trait TestJsonData { """]""" :: Nil) def emptyRecords: RDD[String] = - ctx.sparkContext.parallelize( + _sqlContext.sparkContext.parallelize( """{""" :: """""" :: """{"a": {}}""" :: @@ -198,9 +197,8 @@ trait TestJsonData { """{"b": [{"c": {}}]}""" :: """]""" :: Nil) - lazy val singleRow: RDD[String] = - ctx.sparkContext.parallelize( - """{"a":123}""" :: Nil) - def empty: RDD[String] = ctx.sparkContext.parallelize(Seq[String]()) + lazy val singleRow: RDD[String] = _sqlContext.sparkContext.parallelize("""{"a":123}""" :: Nil) + + def empty: RDD[String] = _sqlContext.sparkContext.parallelize(Seq[String]()) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAvroCompatibilitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAvroCompatibilitySuite.scala index 866a975ad540..82d40e2b61a1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAvroCompatibilitySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAvroCompatibilitySuite.scala @@ -27,18 +27,16 @@ import org.apache.avro.generic.IndexedRecord import org.apache.hadoop.fs.Path import org.apache.parquet.avro.AvroParquetWriter -import org.apache.spark.sql.execution.datasources.parquet.test.avro.{Nested, ParquetAvroCompat, ParquetEnum, Suit} -import org.apache.spark.sql.test.TestSQLContext -import org.apache.spark.sql.{Row, SQLContext} +import org.apache.spark.sql.Row +import org.apache.spark.sql.execution.datasources.parquet.test.avro._ +import org.apache.spark.sql.test.SharedSQLContext -class ParquetAvroCompatibilitySuite extends ParquetCompatibilityTest { +class ParquetAvroCompatibilitySuite extends ParquetCompatibilityTest with SharedSQLContext { import ParquetCompatibilityTest._ - override val sqlContext: SQLContext = TestSQLContext - private def withWriter[T <: IndexedRecord] (path: String, schema: Schema) - (f: AvroParquetWriter[T] => Unit) = { + (f: AvroParquetWriter[T] => Unit): Unit = { val writer = new AvroParquetWriter[T](new Path(path), schema) try f(writer) finally writer.close() } @@ -129,7 +127,7 @@ class ParquetAvroCompatibilitySuite extends ParquetCompatibilityTest { } test("SPARK-9407 Don't push down predicates involving Parquet ENUM columns") { - import sqlContext.implicits._ + import testImplicits._ withTempPath { dir => val path = dir.getCanonicalPath diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala index 0ea64aa2a509..b3406729fcc5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala @@ -22,16 +22,18 @@ import scala.collection.JavaConversions._ import org.apache.hadoop.fs.{Path, PathFilter} import org.apache.parquet.hadoop.ParquetFileReader import org.apache.parquet.schema.MessageType 
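TestJsonData above, and ParquetTest further down, now receive their context through a protected _sqlContext member instead of reaching for a global test context, so a helper trait can be mixed into any suite that supplies one. A minimal sketch of that shape, with a hypothetical trait name and data:

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SQLContext

trait ExampleTestData {
  // the mixing-in suite (for example one extending SharedSQLContext) supplies the context
  protected def _sqlContext: SQLContext

  def exampleRecords: RDD[String] =
    _sqlContext.sparkContext.parallelize("""{"a": 1}""" :: """{"a": 2}""" :: Nil)
}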
-import org.scalatest.BeforeAndAfterAll import org.apache.spark.sql.QueryTest -abstract class ParquetCompatibilityTest extends QueryTest with ParquetTest with BeforeAndAfterAll { - def readParquetSchema(path: String): MessageType = { +/** + * Helper class for testing Parquet compatibility. + */ +private[sql] abstract class ParquetCompatibilityTest extends QueryTest with ParquetTest { + protected def readParquetSchema(path: String): MessageType = { readParquetSchema(path, { path => !path.getName.startsWith("_") }) } - def readParquetSchema(path: String, pathFilter: Path => Boolean): MessageType = { + protected def readParquetSchema(path: String, pathFilter: Path => Boolean): MessageType = { val fsPath = new Path(path) val fs = fsPath.getFileSystem(configuration) val parquetFiles = fs.listStatus(fsPath, new PathFilter { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index 7dd9680d8cd6..5b4e568bb983 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -20,12 +20,13 @@ package org.apache.spark.sql.execution.datasources.parquet import org.apache.parquet.filter2.predicate.Operators._ import org.apache.parquet.filter2.predicate.{FilterPredicate, Operators} +import org.apache.spark.sql.{Column, DataFrame, QueryTest, Row, SQLConf} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.execution.datasources.LogicalRelation +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ -import org.apache.spark.sql.{Column, DataFrame, QueryTest, Row, SQLConf} /** * A test suite that tests Parquet filter2 API based filter pushdown optimization. @@ -39,8 +40,7 @@ import org.apache.spark.sql.{Column, DataFrame, QueryTest, Row, SQLConf} * 2. `Tuple1(Option(x))` is used together with `AnyVal` types like `Int` to ensure the inferred * data type is nullable. 
*/ -class ParquetFilterSuite extends QueryTest with ParquetTest { - lazy val sqlContext = org.apache.spark.sql.test.TestSQLContext +class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContext { private def checkFilterPredicate( df: DataFrame, @@ -301,7 +301,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest { } test("SPARK-6554: don't push down predicates which reference partition columns") { - import sqlContext.implicits._ + import testImplicits._ withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true") { withTempPath { dir => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala index cb166349fdb2..d819f3ab5e6a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala @@ -37,6 +37,7 @@ import org.apache.spark.SparkException import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ // Write support class for nested groups: ParquetWriter initializes GroupWriteSupport @@ -62,9 +63,8 @@ private[parquet] class TestGroupWriteSupport(schema: MessageType) extends WriteS /** * A test suite that tests basic Parquet I/O. */ -class ParquetIOSuite extends QueryTest with ParquetTest { - lazy val sqlContext = org.apache.spark.sql.test.TestSQLContext - import sqlContext.implicits._ +class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { + import testImplicits._ /** * Writes `data` to a Parquet file, reads it back and check file contents. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala index 73152de24475..ed8bafb10c60 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala @@ -26,13 +26,13 @@ import scala.collection.mutable.ArrayBuffer import com.google.common.io.Files import org.apache.hadoop.fs.Path +import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.execution.datasources.{LogicalRelation, PartitionSpec, Partition, PartitioningUtils} +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ -import org.apache.spark.sql._ import org.apache.spark.unsafe.types.UTF8String -import PartitioningUtils._ // The data where the partitioning key exists only in the directory structure. 
case class ParquetData(intField: Int, stringField: String) @@ -40,11 +40,9 @@ case class ParquetData(intField: Int, stringField: String) // The data that also includes the partitioning key case class ParquetDataWithKey(intField: Int, pi: Int, stringField: String, ps: String) -class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { - - override lazy val sqlContext: SQLContext = org.apache.spark.sql.test.TestSQLContext - import sqlContext.implicits._ - import sqlContext.sql +class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with SharedSQLContext { + import PartitioningUtils._ + import testImplicits._ val defaultPartitionName = "__HIVE_DEFAULT_PARTITION__" diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetProtobufCompatibilitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetProtobufCompatibilitySuite.scala index 981334cf771c..b290429c2a02 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetProtobufCompatibilitySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetProtobufCompatibilitySuite.scala @@ -17,11 +17,10 @@ package org.apache.spark.sql.execution.datasources.parquet -import org.apache.spark.sql.test.TestSQLContext -import org.apache.spark.sql.{DataFrame, Row, SQLContext} +import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.test.SharedSQLContext -class ParquetProtobufCompatibilitySuite extends ParquetCompatibilityTest { - override def sqlContext: SQLContext = TestSQLContext +class ParquetProtobufCompatibilitySuite extends ParquetCompatibilityTest with SharedSQLContext { private def readParquetProtobufFile(name: String): DataFrame = { val url = Thread.currentThread().getContextClassLoader.getResource(name) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala index 5e6d9c1cd44a..e2f2a8c74478 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala @@ -21,16 +21,15 @@ import java.io.File import org.apache.hadoop.fs.Path -import org.apache.spark.sql.types._ import org.apache.spark.sql.{QueryTest, Row, SQLConf} +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types._ import org.apache.spark.util.Utils /** * A test suite that tests various Parquet queries. 
*/ -class ParquetQuerySuite extends QueryTest with ParquetTest { - lazy val sqlContext = org.apache.spark.sql.test.TestSQLContext - import sqlContext.sql +class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext { test("simple select queries") { withParquetTable((0 until 10).map(i => (i, i.toString)), "t") { @@ -41,22 +40,22 @@ class ParquetQuerySuite extends QueryTest with ParquetTest { test("appending") { val data = (0 until 10).map(i => (i, i.toString)) - sqlContext.createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp") + ctx.createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp") withParquetTable(data, "t") { sql("INSERT INTO TABLE t SELECT * FROM tmp") - checkAnswer(sqlContext.table("t"), (data ++ data).map(Row.fromTuple)) + checkAnswer(ctx.table("t"), (data ++ data).map(Row.fromTuple)) } - sqlContext.catalog.unregisterTable(Seq("tmp")) + ctx.catalog.unregisterTable(Seq("tmp")) } test("overwriting") { val data = (0 until 10).map(i => (i, i.toString)) - sqlContext.createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp") + ctx.createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp") withParquetTable(data, "t") { sql("INSERT OVERWRITE TABLE t SELECT * FROM tmp") - checkAnswer(sqlContext.table("t"), data.map(Row.fromTuple)) + checkAnswer(ctx.table("t"), data.map(Row.fromTuple)) } - sqlContext.catalog.unregisterTable(Seq("tmp")) + ctx.catalog.unregisterTable(Seq("tmp")) } test("self-join") { @@ -119,9 +118,9 @@ class ParquetQuerySuite extends QueryTest with ParquetTest { val schema = StructType(List(StructField("d", DecimalType(18, 0), false), StructField("time", TimestampType, false)).toArray) withTempPath { file => - val df = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(data), schema) + val df = ctx.createDataFrame(ctx.sparkContext.parallelize(data), schema) df.write.parquet(file.getCanonicalPath) - val df2 = sqlContext.read.parquet(file.getCanonicalPath) + val df2 = ctx.read.parquet(file.getCanonicalPath) checkAnswer(df2, df.collect().toSeq) } } @@ -130,12 +129,12 @@ class ParquetQuerySuite extends QueryTest with ParquetTest { def testSchemaMerging(expectedColumnNumber: Int): Unit = { withTempDir { dir => val basePath = dir.getCanonicalPath - sqlContext.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString) - sqlContext.range(0, 10).toDF("b").write.parquet(new Path(basePath, "foo=2").toString) + ctx.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString) + ctx.range(0, 10).toDF("b").write.parquet(new Path(basePath, "foo=2").toString) // delete summary files, so if we don't merge part-files, one column will not be included. 
Utils.deleteRecursively(new File(basePath + "/foo=1/_metadata")) Utils.deleteRecursively(new File(basePath + "/foo=1/_common_metadata")) - assert(sqlContext.read.parquet(basePath).columns.length === expectedColumnNumber) + assert(ctx.read.parquet(basePath).columns.length === expectedColumnNumber) } } @@ -154,9 +153,9 @@ class ParquetQuerySuite extends QueryTest with ParquetTest { def testSchemaMerging(expectedColumnNumber: Int): Unit = { withTempDir { dir => val basePath = dir.getCanonicalPath - sqlContext.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString) - sqlContext.range(0, 10).toDF("b").write.parquet(new Path(basePath, "foo=2").toString) - assert(sqlContext.read.parquet(basePath).columns.length === expectedColumnNumber) + ctx.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString) + ctx.range(0, 10).toDF("b").write.parquet(new Path(basePath, "foo=2").toString) + assert(ctx.read.parquet(basePath).columns.length === expectedColumnNumber) } } @@ -172,19 +171,19 @@ class ParquetQuerySuite extends QueryTest with ParquetTest { test("SPARK-8990 DataFrameReader.parquet() should respect user specified options") { withTempPath { dir => val basePath = dir.getCanonicalPath - sqlContext.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString) - sqlContext.range(0, 10).toDF("b").write.parquet(new Path(basePath, "foo=a").toString) + ctx.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString) + ctx.range(0, 10).toDF("b").write.parquet(new Path(basePath, "foo=a").toString) // Disables the global SQL option for schema merging withSQLConf(SQLConf.PARQUET_SCHEMA_MERGING_ENABLED.key -> "false") { assertResult(2) { // Disables schema merging via data source option - sqlContext.read.option("mergeSchema", "false").parquet(basePath).columns.length + ctx.read.option("mergeSchema", "false").parquet(basePath).columns.length } assertResult(3) { // Enables schema merging via data source option - sqlContext.read.option("mergeSchema", "true").parquet(basePath).columns.length + ctx.read.option("mergeSchema", "true").parquet(basePath).columns.length } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala index 971f71e27bfc..9dcbc1a047be 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala @@ -22,13 +22,11 @@ import scala.reflect.runtime.universe.TypeTag import org.apache.parquet.schema.MessageTypeParser -import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.ScalaReflection -import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ -abstract class ParquetSchemaTest extends SparkFunSuite with ParquetTest { - val sqlContext = TestSQLContext +abstract class ParquetSchemaTest extends ParquetTest with SharedSQLContext { /** * Checks whether the reflected Parquet message type for product type `T` conforms `messageType`. 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala index 3c6e54db4bca..5dbc7d1630f2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala @@ -22,9 +22,8 @@ import java.io.File import scala.reflect.ClassTag import scala.reflect.runtime.universe.TypeTag -import org.apache.spark.SparkFunSuite import org.apache.spark.sql.test.SQLTestUtils -import org.apache.spark.sql.{DataFrame, SaveMode} +import org.apache.spark.sql.{DataFrame, SaveMode, SQLContext} /** * A helper trait that provides convenient facilities for Parquet testing. @@ -33,7 +32,9 @@ import org.apache.spark.sql.{DataFrame, SaveMode} * convenient to use tuples rather than special case classes when writing test cases/suites. * Especially, `Tuple1.apply` can be used to easily wrap a single type/value. */ -private[sql] trait ParquetTest extends SQLTestUtils { this: SparkFunSuite => +private[sql] trait ParquetTest extends SQLTestUtils { + protected def _sqlContext: SQLContext + /** * Writes `data` to a Parquet file, which is then passed to `f` and will be deleted after `f` * returns. @@ -42,7 +43,7 @@ private[sql] trait ParquetTest extends SQLTestUtils { this: SparkFunSuite => (data: Seq[T]) (f: String => Unit): Unit = { withTempPath { file => - sqlContext.createDataFrame(data).write.parquet(file.getCanonicalPath) + _sqlContext.createDataFrame(data).write.parquet(file.getCanonicalPath) f(file.getCanonicalPath) } } @@ -54,7 +55,7 @@ private[sql] trait ParquetTest extends SQLTestUtils { this: SparkFunSuite => protected def withParquetDataFrame[T <: Product: ClassTag: TypeTag] (data: Seq[T]) (f: DataFrame => Unit): Unit = { - withParquetFile(data)(path => f(sqlContext.read.parquet(path))) + withParquetFile(data)(path => f(_sqlContext.read.parquet(path))) } /** @@ -66,14 +67,14 @@ private[sql] trait ParquetTest extends SQLTestUtils { this: SparkFunSuite => (data: Seq[T], tableName: String) (f: => Unit): Unit = { withParquetDataFrame(data) { df => - sqlContext.registerDataFrameAsTable(df, tableName) + _sqlContext.registerDataFrameAsTable(df, tableName) withTempTable(tableName)(f) } } protected def makeParquetFile[T <: Product: ClassTag: TypeTag]( data: Seq[T], path: File): Unit = { - sqlContext.createDataFrame(data).write.mode(SaveMode.Overwrite).parquet(path.getCanonicalPath) + _sqlContext.createDataFrame(data).write.mode(SaveMode.Overwrite).parquet(path.getCanonicalPath) } protected def makeParquetFile[T <: Product: ClassTag: TypeTag]( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetThriftCompatibilitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetThriftCompatibilitySuite.scala index 92b1d822172d..b789c5a106e5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetThriftCompatibilitySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetThriftCompatibilitySuite.scala @@ -17,14 +17,12 @@ package org.apache.spark.sql.execution.datasources.parquet -import org.apache.spark.sql.test.TestSQLContext -import org.apache.spark.sql.{Row, SQLContext} +import org.apache.spark.sql.Row +import org.apache.spark.sql.test.SharedSQLContext -class ParquetThriftCompatibilitySuite extends 
ParquetCompatibilityTest { +class ParquetThriftCompatibilitySuite extends ParquetCompatibilityTest with SharedSQLContext { import ParquetCompatibilityTest._ - override val sqlContext: SQLContext = TestSQLContext - private val parquetFilePath = Thread.currentThread().getContextClassLoader.getResource("parquet-thrift-compat.snappy.parquet") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala index 239deb797384..22189477d277 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala @@ -18,10 +18,10 @@ package org.apache.spark.sql.execution.debug import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.TestData._ -import org.apache.spark.sql.test.TestSQLContext._ +import org.apache.spark.sql.test.SharedSQLContext + +class DebuggingSuite extends SparkFunSuite with SharedSQLContext { -class DebuggingSuite extends SparkFunSuite { test("DataFrame.debug()") { testData.debug() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala index d33a967093ca..4c9187a9a710 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala @@ -23,12 +23,12 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.metric.SQLMetrics -import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{IntegerType, StructField, StructType} import org.apache.spark.util.collection.CompactBuffer -class HashedRelationSuite extends SparkFunSuite { +class HashedRelationSuite extends SparkFunSuite with SharedSQLContext { // Key is simply the record itself private val keyProjection = new Projection { @@ -37,7 +37,7 @@ class HashedRelationSuite extends SparkFunSuite { test("GeneralHashedRelation") { val data = Array(InternalRow(0), InternalRow(1), InternalRow(2), InternalRow(2)) - val numDataRows = SQLMetrics.createLongMetric(TestSQLContext.sparkContext, "data") + val numDataRows = SQLMetrics.createLongMetric(ctx.sparkContext, "data") val hashed = HashedRelation(data.iterator, numDataRows, keyProjection) assert(hashed.isInstanceOf[GeneralHashedRelation]) @@ -53,7 +53,7 @@ class HashedRelationSuite extends SparkFunSuite { test("UniqueKeyHashedRelation") { val data = Array(InternalRow(0), InternalRow(1), InternalRow(2)) - val numDataRows = SQLMetrics.createLongMetric(TestSQLContext.sparkContext, "data") + val numDataRows = SQLMetrics.createLongMetric(ctx.sparkContext, "data") val hashed = HashedRelation(data.iterator, numDataRows, keyProjection) assert(hashed.isInstanceOf[UniqueKeyHashedRelation]) @@ -73,7 +73,7 @@ class HashedRelationSuite extends SparkFunSuite { test("UnsafeHashedRelation") { val schema = StructType(StructField("a", IntegerType, true) :: Nil) val data = Array(InternalRow(0), InternalRow(1), InternalRow(2), InternalRow(2)) - val numDataRows = SQLMetrics.createLongMetric(TestSQLContext.sparkContext, "data") + val numDataRows = SQLMetrics.createLongMetric(ctx.sparkContext, "data") val toUnsafe = 
UnsafeProjection.create(schema) val unsafeData = data.map(toUnsafe(_).copy()).toArray diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala index ddff7cebcc17..cc649b9bd4c4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala @@ -17,97 +17,19 @@ package org.apache.spark.sql.execution.joins +import org.apache.spark.sql.{DataFrame, execution, Row, SQLConf} +import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys import org.apache.spark.sql.catalyst.plans.Inner import org.apache.spark.sql.catalyst.plans.logical.Join -import org.apache.spark.sql.test.SQLTestUtils -import org.apache.spark.sql.types.{IntegerType, StringType, StructType} -import org.apache.spark.sql.{SQLConf, execution, Row, DataFrame} -import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.execution._ +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.{IntegerType, StringType, StructType} -class InnerJoinSuite extends SparkPlanTest with SQLTestUtils { - - private def testInnerJoin( - testName: String, - leftRows: DataFrame, - rightRows: DataFrame, - condition: Expression, - expectedAnswer: Seq[Product]): Unit = { - val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition)) - ExtractEquiJoinKeys.unapply(join).foreach { - case (joinType, leftKeys, rightKeys, boundCondition, leftChild, rightChild) => - - def makeBroadcastHashJoin(left: SparkPlan, right: SparkPlan, side: BuildSide) = { - val broadcastHashJoin = - execution.joins.BroadcastHashJoin(leftKeys, rightKeys, side, left, right) - boundCondition.map(Filter(_, broadcastHashJoin)).getOrElse(broadcastHashJoin) - } - - def makeShuffledHashJoin(left: SparkPlan, right: SparkPlan, side: BuildSide) = { - val shuffledHashJoin = - execution.joins.ShuffledHashJoin(leftKeys, rightKeys, side, left, right) - val filteredJoin = - boundCondition.map(Filter(_, shuffledHashJoin)).getOrElse(shuffledHashJoin) - EnsureRequirements(sqlContext).apply(filteredJoin) - } - - def makeSortMergeJoin(left: SparkPlan, right: SparkPlan) = { - val sortMergeJoin = - execution.joins.SortMergeJoin(leftKeys, rightKeys, left, right) - val filteredJoin = boundCondition.map(Filter(_, sortMergeJoin)).getOrElse(sortMergeJoin) - EnsureRequirements(sqlContext).apply(filteredJoin) - } - - test(s"$testName using BroadcastHashJoin (build=left)") { - withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { - checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => - makeBroadcastHashJoin(left, right, joins.BuildLeft), - expectedAnswer.map(Row.fromTuple), - sortAnswers = true) - } - } - - test(s"$testName using BroadcastHashJoin (build=right)") { - withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { - checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => - makeBroadcastHashJoin(left, right, joins.BuildRight), - expectedAnswer.map(Row.fromTuple), - sortAnswers = true) - } - } - - test(s"$testName using ShuffledHashJoin (build=left)") { - withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { - checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => - makeShuffledHashJoin(left, right, joins.BuildLeft), - expectedAnswer.map(Row.fromTuple), - sortAnswers = true) - } - } 
- - test(s"$testName using ShuffledHashJoin (build=right)") { - withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { - checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => - makeShuffledHashJoin(left, right, joins.BuildRight), - expectedAnswer.map(Row.fromTuple), - sortAnswers = true) - } - } +class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { - test(s"$testName using SortMergeJoin") { - withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { - checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => - makeSortMergeJoin(left, right), - expectedAnswer.map(Row.fromTuple), - sortAnswers = true) - } - } - } - } - - { - val upperCaseData = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(Seq( + private lazy val myUpperCaseData = ctx.createDataFrame( + ctx.sparkContext.parallelize(Seq( Row(1, "A"), Row(2, "B"), Row(3, "C"), @@ -117,7 +39,8 @@ class InnerJoinSuite extends SparkPlanTest with SQLTestUtils { Row(null, "G") )), new StructType().add("N", IntegerType).add("L", StringType)) - val lowerCaseData = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(Seq( + private lazy val myLowerCaseData = ctx.createDataFrame( + ctx.sparkContext.parallelize(Seq( Row(1, "a"), Row(2, "b"), Row(3, "c"), @@ -125,21 +48,7 @@ class InnerJoinSuite extends SparkPlanTest with SQLTestUtils { Row(null, "e") )), new StructType().add("n", IntegerType).add("l", StringType)) - testInnerJoin( - "inner join, one match per row", - upperCaseData, - lowerCaseData, - (upperCaseData.col("N") === lowerCaseData.col("n")).expr, - Seq( - (1, "A", 1, "a"), - (2, "B", 2, "b"), - (3, "C", 3, "c"), - (4, "D", 4, "d") - ) - ) - } - - private val testData2 = Seq( + private lazy val myTestData = Seq( (1, 1), (1, 2), (2, 1), @@ -148,14 +57,139 @@ class InnerJoinSuite extends SparkPlanTest with SQLTestUtils { (3, 2) ).toDF("a", "b") + // Note: the input dataframes and expression must be evaluated lazily because + // the SQLContext should be used only within a test to keep SQL tests stable + private def testInnerJoin( + testName: String, + leftRows: => DataFrame, + rightRows: => DataFrame, + condition: () => Expression, + expectedAnswer: Seq[Product]): Unit = { + + def extractJoinParts(): Option[ExtractEquiJoinKeys.ReturnType] = { + val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition())) + ExtractEquiJoinKeys.unapply(join) + } + + def makeBroadcastHashJoin( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + boundCondition: Option[Expression], + leftPlan: SparkPlan, + rightPlan: SparkPlan, + side: BuildSide) = { + val broadcastHashJoin = + execution.joins.BroadcastHashJoin(leftKeys, rightKeys, side, leftPlan, rightPlan) + boundCondition.map(Filter(_, broadcastHashJoin)).getOrElse(broadcastHashJoin) + } + + def makeShuffledHashJoin( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + boundCondition: Option[Expression], + leftPlan: SparkPlan, + rightPlan: SparkPlan, + side: BuildSide) = { + val shuffledHashJoin = + execution.joins.ShuffledHashJoin(leftKeys, rightKeys, side, leftPlan, rightPlan) + val filteredJoin = + boundCondition.map(Filter(_, shuffledHashJoin)).getOrElse(shuffledHashJoin) + EnsureRequirements(sqlContext).apply(filteredJoin) + } + + def makeSortMergeJoin( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + boundCondition: Option[Expression], + leftPlan: SparkPlan, + rightPlan: SparkPlan) = { + val sortMergeJoin = + execution.joins.SortMergeJoin(leftKeys, rightKeys, leftPlan, rightPlan) + val 
filteredJoin = boundCondition.map(Filter(_, sortMergeJoin)).getOrElse(sortMergeJoin) + EnsureRequirements(sqlContext).apply(filteredJoin) + } + + test(s"$testName using BroadcastHashJoin (build=left)") { + extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) => + makeBroadcastHashJoin( + leftKeys, rightKeys, boundCondition, leftPlan, rightPlan, joins.BuildLeft), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } + } + } + + test(s"$testName using BroadcastHashJoin (build=right)") { + extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) => + makeBroadcastHashJoin( + leftKeys, rightKeys, boundCondition, leftPlan, rightPlan, joins.BuildRight), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } + } + } + + test(s"$testName using ShuffledHashJoin (build=left)") { + extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) => + makeShuffledHashJoin( + leftKeys, rightKeys, boundCondition, leftPlan, rightPlan, joins.BuildLeft), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } + } + } + + test(s"$testName using ShuffledHashJoin (build=right)") { + extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) => + makeShuffledHashJoin( + leftKeys, rightKeys, boundCondition, leftPlan, rightPlan, joins.BuildRight), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } + } + } + + test(s"$testName using SortMergeJoin") { + extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) => + makeSortMergeJoin(leftKeys, rightKeys, boundCondition, leftPlan, rightPlan), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } + } + } + } + + testInnerJoin( + "inner join, one match per row", + myUpperCaseData, + myLowerCaseData, + () => (myUpperCaseData.col("N") === myLowerCaseData.col("n")).expr, + Seq( + (1, "A", 1, "a"), + (2, "B", 2, "b"), + (3, "C", 3, "c"), + (4, "D", 4, "d") + ) + ) + { - val left = testData2.where("a = 1") - val right = testData2.where("a = 1") + lazy val left = myTestData.where("a = 1") + lazy val right = myTestData.where("a = 1") testInnerJoin( "inner join, multiple matches", left, right, - (left.col("a") === right.col("a")).expr, + () => (left.col("a") === right.col("a")).expr, Seq( (1, 1, 1, 1), (1, 1, 1, 2), @@ -166,13 +200,13 @@ class InnerJoinSuite extends SparkPlanTest with SQLTestUtils { } { - val left = testData2.where("a = 1") - val right = testData2.where("a = 2") + lazy val left = myTestData.where("a = 1") + lazy val right = myTestData.where("a = 2") testInnerJoin( "inner join, no matches", left, right, - (left.col("a") === right.col("a")).expr, + () => (left.col("a") === right.col("a")).expr, Seq.empty ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala 
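The InnerJoinSuite rewrite that ends here, and the OuterJoinSuite and SemiJoinSuite rewrites that follow, all take the same shape: input DataFrames become private lazy vals, the test helper takes them by name (and the condition as a thunk), and the equi-join keys are extracted inside each registered test rather than at class-construction time. A condensed excerpt of that shape, not a standalone program; names other than ExtractEquiJoinKeys, Join and Inner are illustrative:

private def testExampleJoin(
    testName: String,
    leftRows: => DataFrame,            // by-name: not evaluated until a test actually runs
    rightRows: => DataFrame,
    condition: () => Expression,       // a thunk, for the same reason
    expectedAnswer: Seq[Product]): Unit = {

  // build the logical join lazily and pull the keys/condition out per test invocation
  def extractJoinParts(): Option[ExtractEquiJoinKeys.ReturnType] = {
    val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition()))
    ExtractEquiJoinKeys.unapply(join)
  }

  test(s"$testName using some physical join operator") {
    extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) =>
      // construct the physical operator from the extracted parts and compare its
      // output against expectedAnswer, as the suites above do with checkAnswer2
    }
  }
}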
b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala index e16f5e39aa2f..a1a617d7b739 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala @@ -17,28 +17,65 @@ package org.apache.spark.sql.execution.joins +import org.apache.spark.sql.{DataFrame, Row, SQLConf} import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys +import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.Join -import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.catalyst.expressions.{And, Expression, LessThan} +import org.apache.spark.sql.execution.{EnsureRequirements, SparkPlan, SparkPlanTest} +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{IntegerType, DoubleType, StructType} -import org.apache.spark.sql.{SQLConf, DataFrame, Row} -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.execution.{EnsureRequirements, joins, SparkPlan, SparkPlanTest} -class OuterJoinSuite extends SparkPlanTest with SQLTestUtils { +class OuterJoinSuite extends SparkPlanTest with SharedSQLContext { + + private lazy val left = ctx.createDataFrame( + ctx.sparkContext.parallelize(Seq( + Row(1, 2.0), + Row(2, 100.0), + Row(2, 1.0), // This row is duplicated to ensure that we will have multiple buffered matches + Row(2, 1.0), + Row(3, 3.0), + Row(5, 1.0), + Row(6, 6.0), + Row(null, null) + )), new StructType().add("a", IntegerType).add("b", DoubleType)) + + private lazy val right = ctx.createDataFrame( + ctx.sparkContext.parallelize(Seq( + Row(0, 0.0), + Row(2, 3.0), // This row is duplicated to ensure that we will have multiple buffered matches + Row(2, -1.0), + Row(2, -1.0), + Row(2, 3.0), + Row(3, 2.0), + Row(4, 1.0), + Row(5, 3.0), + Row(7, 7.0), + Row(null, null) + )), new StructType().add("c", IntegerType).add("d", DoubleType)) + + private lazy val condition = { + And((left.col("a") === right.col("c")).expr, + LessThan(left.col("b").expr, right.col("d").expr)) + } + // Note: the input dataframes and expression must be evaluated lazily because + // the SQLContext should be used only within a test to keep SQL tests stable private def testOuterJoin( testName: String, - leftRows: DataFrame, - rightRows: DataFrame, + leftRows: => DataFrame, + rightRows: => DataFrame, joinType: JoinType, - condition: Expression, + condition: => Expression, expectedAnswer: Seq[Product]): Unit = { - val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition)) - ExtractEquiJoinKeys.unapply(join).foreach { - case (_, leftKeys, rightKeys, boundCondition, leftChild, rightChild) => - test(s"$testName using ShuffledHashOuterJoin") { + + def extractJoinParts(): Option[ExtractEquiJoinKeys.ReturnType] = { + val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition)) + ExtractEquiJoinKeys.unapply(join) + } + + test(s"$testName using ShuffledHashOuterJoin") { + extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => EnsureRequirements(sqlContext).apply( @@ -46,19 +83,23 @@ class OuterJoinSuite extends SparkPlanTest with SQLTestUtils { expectedAnswer.map(Row.fromTuple), sortAnswers = true) } - } + } + } - if (joinType != FullOuter) { - 
test(s"$testName using BroadcastHashOuterJoin") { + if (joinType != FullOuter) { + test(s"$testName using BroadcastHashOuterJoin") { + extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => BroadcastHashOuterJoin(leftKeys, rightKeys, joinType, boundCondition, left, right), expectedAnswer.map(Row.fromTuple), sortAnswers = true) } - } + } + } - test(s"$testName using SortMergeOuterJoin") { + test(s"$testName using SortMergeOuterJoin") { + extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => EnsureRequirements(sqlContext).apply( @@ -66,57 +107,9 @@ class OuterJoinSuite extends SparkPlanTest with SQLTestUtils { expectedAnswer.map(Row.fromTuple), sortAnswers = false) } - } } - } - - test(s"$testName using BroadcastNestedLoopJoin (build=left)") { - withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { - checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => - joins.BroadcastNestedLoopJoin(left, right, joins.BuildLeft, joinType, Some(condition)), - expectedAnswer.map(Row.fromTuple), - sortAnswers = true) } } - - test(s"$testName using BroadcastNestedLoopJoin (build=right)") { - withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { - checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => - joins.BroadcastNestedLoopJoin(left, right, joins.BuildRight, joinType, Some(condition)), - expectedAnswer.map(Row.fromTuple), - sortAnswers = true) - } - } - } - - val left = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(Seq( - Row(1, 2.0), - Row(2, 100.0), - Row(2, 1.0), // This row is duplicated to ensure that we will have multiple buffered matches - Row(2, 1.0), - Row(3, 3.0), - Row(5, 1.0), - Row(6, 6.0), - Row(null, null) - )), new StructType().add("a", IntegerType).add("b", DoubleType)) - - val right = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(Seq( - Row(0, 0.0), - Row(2, 3.0), // This row is duplicated to ensure that we will have multiple buffered matches - Row(2, -1.0), - Row(2, -1.0), - Row(2, 3.0), - Row(3, 2.0), - Row(4, 1.0), - Row(5, 3.0), - Row(7, 7.0), - Row(null, null) - )), new StructType().add("c", IntegerType).add("d", DoubleType)) - - val condition = { - And( - (left.col("a") === right.col("c")).expr, - LessThan(left.col("b").expr, right.col("d").expr)) } // --- Basic outer joins ------------------------------------------------------------------------ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala index 4503ed251fcb..baa86e320d98 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala @@ -17,44 +17,80 @@ package org.apache.spark.sql.execution.joins +import org.apache.spark.sql.{SQLConf, DataFrame, Row} import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys import org.apache.spark.sql.catalyst.plans.Inner import org.apache.spark.sql.catalyst.plans.logical.Join -import org.apache.spark.sql.test.SQLTestUtils -import org.apache.spark.sql.types.{DoubleType, IntegerType, StructType} -import org.apache.spark.sql.{SQLConf, DataFrame, Row} import 
org.apache.spark.sql.catalyst.expressions.{And, LessThan, Expression} import org.apache.spark.sql.execution.{EnsureRequirements, SparkPlan, SparkPlanTest} +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.{DoubleType, IntegerType, StructType} + +class SemiJoinSuite extends SparkPlanTest with SharedSQLContext { -class SemiJoinSuite extends SparkPlanTest with SQLTestUtils { + private lazy val left = ctx.createDataFrame( + ctx.sparkContext.parallelize(Seq( + Row(1, 2.0), + Row(1, 2.0), + Row(2, 1.0), + Row(2, 1.0), + Row(3, 3.0), + Row(null, null), + Row(null, 5.0), + Row(6, null) + )), new StructType().add("a", IntegerType).add("b", DoubleType)) + private lazy val right = ctx.createDataFrame( + ctx.sparkContext.parallelize(Seq( + Row(2, 3.0), + Row(2, 3.0), + Row(3, 2.0), + Row(4, 1.0), + Row(null, null), + Row(null, 5.0), + Row(6, null) + )), new StructType().add("c", IntegerType).add("d", DoubleType)) + + private lazy val condition = { + And((left.col("a") === right.col("c")).expr, + LessThan(left.col("b").expr, right.col("d").expr)) + } + + // Note: the input dataframes and expression must be evaluated lazily because + // the SQLContext should be used only within a test to keep SQL tests stable private def testLeftSemiJoin( testName: String, - leftRows: DataFrame, - rightRows: DataFrame, - condition: Expression, + leftRows: => DataFrame, + rightRows: => DataFrame, + condition: => Expression, expectedAnswer: Seq[Product]): Unit = { - val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition)) - ExtractEquiJoinKeys.unapply(join).foreach { - case (joinType, leftKeys, rightKeys, boundCondition, leftChild, rightChild) => - test(s"$testName using LeftSemiJoinHash") { - withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { - checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => - EnsureRequirements(left.sqlContext).apply( - LeftSemiJoinHash(leftKeys, rightKeys, left, right, boundCondition)), - expectedAnswer.map(Row.fromTuple), - sortAnswers = true) - } + + def extractJoinParts(): Option[ExtractEquiJoinKeys.ReturnType] = { + val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition)) + ExtractEquiJoinKeys.unapply(join) + } + + test(s"$testName using LeftSemiJoinHash") { + extractJoinParts().foreach { case (joinType, leftKeys, rightKeys, boundCondition, _, _) => + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + EnsureRequirements(left.sqlContext).apply( + LeftSemiJoinHash(leftKeys, rightKeys, left, right, boundCondition)), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) } + } + } - test(s"$testName using BroadcastLeftSemiJoinHash") { - withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { - checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => - BroadcastLeftSemiJoinHash(leftKeys, rightKeys, left, right, boundCondition), - expectedAnswer.map(Row.fromTuple), - sortAnswers = true) - } + test(s"$testName using BroadcastLeftSemiJoinHash") { + extractJoinParts().foreach { case (joinType, leftKeys, rightKeys, boundCondition, _, _) => + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + BroadcastLeftSemiJoinHash(leftKeys, rightKeys, left, right, boundCondition), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) } + } } test(s"$testName using LeftSemiJoinBNL") { @@ -67,33 +103,6 @@ class SemiJoinSuite 
extends SparkPlanTest with SQLTestUtils { } } - val left = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(Seq( - Row(1, 2.0), - Row(1, 2.0), - Row(2, 1.0), - Row(2, 1.0), - Row(3, 3.0), - Row(null, null), - Row(null, 5.0), - Row(6, null) - )), new StructType().add("a", IntegerType).add("b", DoubleType)) - - val right = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(Seq( - Row(2, 3.0), - Row(2, 3.0), - Row(3, 2.0), - Row(4, 1.0), - Row(null, null), - Row(null, 5.0), - Row(6, null) - )), new StructType().add("c", IntegerType).add("d", DoubleType)) - - val condition = { - And( - (left.col("a") === right.col("c")).expr, - LessThan(left.col("b").expr, right.col("d").expr)) - } - testLeftSemiJoin( "basic test", left, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index 7383d3f8fe02..80006bf077fe 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -28,17 +28,15 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql._ import org.apache.spark.sql.execution.ui.SparkPlanGraph import org.apache.spark.sql.functions._ -import org.apache.spark.sql.test.{SQLTestUtils, TestSQLContext} +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.util.Utils -class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils { - override val sqlContext = TestSQLContext - - import sqlContext.implicits._ +class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { + import testImplicits._ test("LongSQLMetric should not box Long") { - val l = SQLMetrics.createLongMetric(TestSQLContext.sparkContext, "long") + val l = SQLMetrics.createLongMetric(ctx.sparkContext, "long") val f = () => { l += 1L l.add(1L) @@ -52,7 +50,7 @@ class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils { test("Normal accumulator should do boxing") { // We need this test to make sure BoxingFinder works. - val l = TestSQLContext.sparkContext.accumulator(0L) + val l = ctx.sparkContext.accumulator(0L) val f = () => { l += 1L } BoxingFinder.getClassReader(f.getClass).foreach { cl => val boxingFinder = new BoxingFinder() @@ -73,19 +71,19 @@ class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils { df: DataFrame, expectedNumOfJobs: Int, expectedMetrics: Map[Long, (String, Map[String, Any])]): Unit = { - val previousExecutionIds = TestSQLContext.listener.executionIdToData.keySet + val previousExecutionIds = ctx.listener.executionIdToData.keySet df.collect() - TestSQLContext.sparkContext.listenerBus.waitUntilEmpty(10000) - val executionIds = TestSQLContext.listener.executionIdToData.keySet.diff(previousExecutionIds) + ctx.sparkContext.listenerBus.waitUntilEmpty(10000) + val executionIds = ctx.listener.executionIdToData.keySet.diff(previousExecutionIds) assert(executionIds.size === 1) val executionId = executionIds.head - val jobs = TestSQLContext.listener.getExecution(executionId).get.jobs + val jobs = ctx.listener.getExecution(executionId).get.jobs // Use "<=" because there is a race condition that we may miss some jobs // TODO Change it to "=" once we fix the race condition that missing the JobStarted event. 
assert(jobs.size <= expectedNumOfJobs) if (jobs.size == expectedNumOfJobs) { // If we can track all jobs, check the metric values - val metricValues = TestSQLContext.listener.getExecutionMetrics(executionId) + val metricValues = ctx.listener.getExecutionMetrics(executionId) val actualMetrics = SparkPlanGraph(df.queryExecution.executedPlan).nodes.filter { node => expectedMetrics.contains(node.id) }.map { node => @@ -111,7 +109,7 @@ class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils { SQLConf.TUNGSTEN_ENABLED.key -> "false") { // Assume the execution plan is // PhysicalRDD(nodeId = 1) -> Project(nodeId = 0) - val df = TestData.person.select('name) + val df = person.select('name) testSparkPlanMetrics(df, 1, Map( 0L ->("Project", Map( "number of rows" -> 2L))) @@ -126,7 +124,7 @@ class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils { SQLConf.TUNGSTEN_ENABLED.key -> "true") { // Assume the execution plan is // PhysicalRDD(nodeId = 1) -> TungstenProject(nodeId = 0) - val df = TestData.person.select('name) + val df = person.select('name) testSparkPlanMetrics(df, 1, Map( 0L ->("TungstenProject", Map( "number of rows" -> 2L))) @@ -137,7 +135,7 @@ class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils { test("Filter metrics") { // Assume the execution plan is // PhysicalRDD(nodeId = 1) -> Filter(nodeId = 0) - val df = TestData.person.filter('age < 25) + val df = person.filter('age < 25) testSparkPlanMetrics(df, 1, Map( 0L -> ("Filter", Map( "number of input rows" -> 2L, @@ -152,7 +150,7 @@ class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils { SQLConf.TUNGSTEN_ENABLED.key -> "false") { // Assume the execution plan is // ... -> Aggregate(nodeId = 2) -> TungstenExchange(nodeId = 1) -> Aggregate(nodeId = 0) - val df = TestData.testData2.groupBy().count() // 2 partitions + val df = testData2.groupBy().count() // 2 partitions testSparkPlanMetrics(df, 1, Map( 2L -> ("Aggregate", Map( "number of input rows" -> 6L, @@ -163,7 +161,7 @@ class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils { ) // 2 partitions and each partition contains 2 keys - val df2 = TestData.testData2.groupBy('a).count() + val df2 = testData2.groupBy('a).count() testSparkPlanMetrics(df2, 1, Map( 2L -> ("Aggregate", Map( "number of input rows" -> 6L, @@ -185,7 +183,7 @@ class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils { // Assume the execution plan is // ... -> SortBasedAggregate(nodeId = 2) -> TungstenExchange(nodeId = 1) -> // SortBasedAggregate(nodeId = 0) - val df = TestData.testData2.groupBy().count() // 2 partitions + val df = testData2.groupBy().count() // 2 partitions testSparkPlanMetrics(df, 1, Map( 2L -> ("SortBasedAggregate", Map( "number of input rows" -> 6L, @@ -199,7 +197,7 @@ class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils { // ... -> SortBasedAggregate(nodeId = 3) -> TungstenExchange(nodeId = 2) // -> ExternalSort(nodeId = 1)-> SortBasedAggregate(nodeId = 0) // 2 partitions and each partition contains 2 keys - val df2 = TestData.testData2.groupBy('a).count() + val df2 = testData2.groupBy('a).count() testSparkPlanMetrics(df2, 1, Map( 3L -> ("SortBasedAggregate", Map( "number of input rows" -> 6L, @@ -219,7 +217,7 @@ class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils { // Assume the execution plan is // ... 
-> TungstenAggregate(nodeId = 2) -> Exchange(nodeId = 1) // -> TungstenAggregate(nodeId = 0) - val df = TestData.testData2.groupBy().count() // 2 partitions + val df = testData2.groupBy().count() // 2 partitions testSparkPlanMetrics(df, 1, Map( 2L -> ("TungstenAggregate", Map( "number of input rows" -> 6L, @@ -230,7 +228,7 @@ class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils { ) // 2 partitions and each partition contains 2 keys - val df2 = TestData.testData2.groupBy('a).count() + val df2 = testData2.groupBy('a).count() testSparkPlanMetrics(df2, 1, Map( 2L -> ("TungstenAggregate", Map( "number of input rows" -> 6L, @@ -246,7 +244,7 @@ class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils { // Because SortMergeJoin may skip different rows if the number of partitions is different, this // test should use the deterministic number of partitions. withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "true") { - val testDataForJoin = TestData.testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2) + val testDataForJoin = testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2) testDataForJoin.registerTempTable("testDataForJoin") withTempTable("testDataForJoin") { // Assume the execution plan is @@ -268,7 +266,7 @@ class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils { // Because SortMergeOuterJoin may skip different rows if the number of partitions is different, // this test should use the deterministic number of partitions. withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "true") { - val testDataForJoin = TestData.testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2) + val testDataForJoin = testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2) testDataForJoin.registerTempTable("testDataForJoin") withTempTable("testDataForJoin") { // Assume the execution plan is @@ -314,7 +312,7 @@ class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils { test("ShuffledHashJoin metrics") { withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "false") { - val testDataForJoin = TestData.testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2) + val testDataForJoin = testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2) testDataForJoin.registerTempTable("testDataForJoin") withTempTable("testDataForJoin") { // Assume the execution plan is @@ -390,7 +388,7 @@ class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils { test("BroadcastNestedLoopJoin metrics") { withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "true") { - val testDataForJoin = TestData.testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2) + val testDataForJoin = testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2) testDataForJoin.registerTempTable("testDataForJoin") withTempTable("testDataForJoin") { // Assume the execution plan is @@ -458,7 +456,7 @@ class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils { } test("CartesianProduct metrics") { - val testDataForJoin = TestData.testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2) + val testDataForJoin = testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2) testDataForJoin.registerTempTable("testDataForJoin") withTempTable("testDataForJoin") { // Assume the execution plan is @@ -476,19 +474,19 @@ class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils { test("save metrics") { withTempPath { file => - val previousExecutionIds = TestSQLContext.listener.executionIdToData.keySet + val previousExecutionIds = ctx.listener.executionIdToData.keySet // Assume the execution plan is // PhysicalRDD(nodeId = 0) - 
TestData.person.select('name).write.format("json").save(file.getAbsolutePath) - TestSQLContext.sparkContext.listenerBus.waitUntilEmpty(10000) - val executionIds = TestSQLContext.listener.executionIdToData.keySet.diff(previousExecutionIds) + person.select('name).write.format("json").save(file.getAbsolutePath) + ctx.sparkContext.listenerBus.waitUntilEmpty(10000) + val executionIds = ctx.listener.executionIdToData.keySet.diff(previousExecutionIds) assert(executionIds.size === 1) val executionId = executionIds.head - val jobs = TestSQLContext.listener.getExecution(executionId).get.jobs + val jobs = ctx.listener.getExecution(executionId).get.jobs // Use "<=" because there is a race condition that we may miss some jobs // TODO Change "<=" to "=" once we fix the race condition that missing the JobStarted event. assert(jobs.size <= 1) - val metricValues = TestSQLContext.listener.getExecutionMetrics(executionId) + val metricValues = ctx.listener.getExecutionMetrics(executionId) // Because "save" will create a new DataFrame internally, we cannot get the real metric id. // However, we still can check the value. assert(metricValues.values.toSeq === Seq(2L)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala index 41dd1896c15d..80d1e8895694 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala @@ -25,12 +25,12 @@ import org.apache.spark.sql.execution.metric.LongSQLMetricValue import org.apache.spark.scheduler._ import org.apache.spark.sql.{DataFrame, SQLContext} import org.apache.spark.sql.execution.SQLExecution -import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.test.SharedSQLContext -class SQLListenerSuite extends SparkFunSuite { +class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { + import testImplicits._ private def createTestDataFrame: DataFrame = { - import TestSQLContext.implicits._ Seq( (1, 1), (2, 2) @@ -74,7 +74,7 @@ class SQLListenerSuite extends SparkFunSuite { } test("basic") { - val listener = new SQLListener(TestSQLContext) + val listener = new SQLListener(ctx) val executionId = 0 val df = createTestDataFrame val accumulatorIds = @@ -212,7 +212,7 @@ class SQLListenerSuite extends SparkFunSuite { } test("onExecutionEnd happens before onJobEnd(JobSucceeded)") { - val listener = new SQLListener(TestSQLContext) + val listener = new SQLListener(ctx) val executionId = 0 val df = createTestDataFrame listener.onExecutionStart( @@ -241,7 +241,7 @@ class SQLListenerSuite extends SparkFunSuite { } test("onExecutionEnd happens before multiple onJobEnd(JobSucceeded)s") { - val listener = new SQLListener(TestSQLContext) + val listener = new SQLListener(ctx) val executionId = 0 val df = createTestDataFrame listener.onExecutionStart( @@ -281,7 +281,7 @@ class SQLListenerSuite extends SparkFunSuite { } test("onExecutionEnd happens before onJobEnd(JobFailed)") { - val listener = new SQLListener(TestSQLContext) + val listener = new SQLListener(ctx) val executionId = 0 val df = createTestDataFrame listener.onExecutionStart( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index e4dcf4c75d20..0edac0848c3b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -25,10 +25,13 @@ import org.h2.jdbc.JdbcSQLException import org.scalatest.BeforeAndAfter import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.util.Utils -class JDBCSuite extends SparkFunSuite with BeforeAndAfter { +class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext { + import testImplicits._ + val url = "jdbc:h2:mem:testdb0" val urlWithUserAndPass = "jdbc:h2:mem:testdb0;user=testUser;password=testPass" var conn: java.sql.Connection = null @@ -42,10 +45,6 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter { Some(StringType) } - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ - import ctx.sql - before { Utils.classForName("org.h2.Driver") // Extra properties that will be specified for our database. We need these to test diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala index 84b52ca2c733..5dc3a2c07b8c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala @@ -23,11 +23,13 @@ import java.util.Properties import org.scalatest.BeforeAndAfter import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.{SaveMode, Row} +import org.apache.spark.sql.{Row, SaveMode} +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.util.Utils -class JDBCWriteSuite extends SparkFunSuite with BeforeAndAfter { +class JDBCWriteSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext { + val url = "jdbc:h2:mem:testdb2" var conn: java.sql.Connection = null val url1 = "jdbc:h2:mem:testdb3" @@ -37,10 +39,6 @@ class JDBCWriteSuite extends SparkFunSuite with BeforeAndAfter { properties.setProperty("password", "testPass") properties.setProperty("rowId", "false") - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ - import ctx.sql - before { Utils.classForName("org.h2.Driver") conn = DriverManager.getConnection(url) @@ -58,14 +56,14 @@ class JDBCWriteSuite extends SparkFunSuite with BeforeAndAfter { "create table test.people1 (name TEXT(32) NOT NULL, theid INTEGER NOT NULL)").executeUpdate() conn1.commit() - ctx.sql( + sql( s""" |CREATE TEMPORARY TABLE PEOPLE |USING org.apache.spark.sql.jdbc |OPTIONS (url '$url1', dbtable 'TEST.PEOPLE', user 'testUser', password 'testPass') """.stripMargin.replaceAll("\n", " ")) - ctx.sql( + sql( s""" |CREATE TEMPORARY TABLE PEOPLE1 |USING org.apache.spark.sql.jdbc @@ -144,14 +142,14 @@ class JDBCWriteSuite extends SparkFunSuite with BeforeAndAfter { } test("INSERT to JDBC Datasource") { - ctx.sql("INSERT INTO TABLE PEOPLE1 SELECT * FROM PEOPLE") + sql("INSERT INTO TABLE PEOPLE1 SELECT * FROM PEOPLE") assert(2 === ctx.read.jdbc(url1, "TEST.PEOPLE1", properties).count) assert(2 === ctx.read.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length) } test("INSERT to JDBC Datasource with overwrite") { - ctx.sql("INSERT INTO TABLE PEOPLE1 SELECT * FROM PEOPLE") - ctx.sql("INSERT OVERWRITE TABLE PEOPLE1 SELECT * FROM PEOPLE") + sql("INSERT INTO TABLE PEOPLE1 SELECT * FROM PEOPLE") + sql("INSERT OVERWRITE TABLE PEOPLE1 SELECT * FROM PEOPLE") assert(2 === ctx.read.jdbc(url1, "TEST.PEOPLE1", properties).count) assert(2 === 
ctx.read.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala index 562c27906704..9bc3f6bcf6fc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala @@ -19,28 +19,32 @@ package org.apache.spark.sql.sources import java.io.{File, IOException} -import org.scalatest.BeforeAndAfterAll +import org.scalatest.BeforeAndAfter import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.execution.datasources.DDLException +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.util.Utils -class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll { - - import caseInsensitiveContext.sql +class CreateTableAsSelectSuite extends DataSourceTest with SharedSQLContext with BeforeAndAfter { + protected override lazy val sql = caseInsensitiveContext.sql _ private lazy val sparkContext = caseInsensitiveContext.sparkContext - - var path: File = null + private var path: File = null override def beforeAll(): Unit = { + super.beforeAll() path = Utils.createTempDir() val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}""")) caseInsensitiveContext.read.json(rdd).registerTempTable("jt") } override def afterAll(): Unit = { - caseInsensitiveContext.dropTempTable("jt") + try { + caseInsensitiveContext.dropTempTable("jt") + } finally { + super.afterAll() + } } after { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLSourceLoadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLSourceLoadSuite.scala index 392da0b0826b..853707c036c9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLSourceLoadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLSourceLoadSuite.scala @@ -18,11 +18,12 @@ package org.apache.spark.sql.sources import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{StringType, StructField, StructType} // please note that the META-INF/services had to be modified for the test directory for this to work -class DDLSourceLoadSuite extends DataSourceTest { +class DDLSourceLoadSuite extends DataSourceTest with SharedSQLContext { test("data sources with the same name") { intercept[RuntimeException] { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala index 84855ce45e91..5f8514e1a241 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.sources import org.apache.spark.rdd.RDD import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -68,10 +69,12 @@ case class SimpleDDLScan(from: Int, to: Int, table: String)(@transient val sqlCo } } -class DDLTestSuite extends DataSourceTest { +class DDLTestSuite extends DataSourceTest with SharedSQLContext { + protected override lazy val sql = caseInsensitiveContext.sql _ - before { - caseInsensitiveContext.sql( + override def 
beforeAll(): Unit = { + super.beforeAll() + sql( """ |CREATE TEMPORARY TABLE ddlPeople |USING org.apache.spark.sql.sources.DDLScanSource @@ -105,7 +108,7 @@ class DDLTestSuite extends DataSourceTest { )) test("SPARK-7686 DescribeCommand should have correct physical plan output attributes") { - val attributes = caseInsensitiveContext.sql("describe ddlPeople") + val attributes = sql("describe ddlPeople") .queryExecution.executedPlan.output assert(attributes.map(_.name) === Seq("col_name", "data_type", "comment")) assert(attributes.map(_.dataType).toSet === Set(StringType)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala index 00cc7d5ea580..d74d29fb0beb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala @@ -17,18 +17,23 @@ package org.apache.spark.sql.sources -import org.scalatest.BeforeAndAfter - import org.apache.spark.sql._ -import org.apache.spark.sql.test.TestSQLContext -abstract class DataSourceTest extends QueryTest with BeforeAndAfter { +private[sql] abstract class DataSourceTest extends QueryTest { + protected def _sqlContext: SQLContext + // We want to test some edge cases. - protected implicit lazy val caseInsensitiveContext = { - val ctx = new SQLContext(TestSQLContext.sparkContext) + protected lazy val caseInsensitiveContext: SQLContext = { + val ctx = new SQLContext(_sqlContext.sparkContext) ctx.setConf(SQLConf.CASE_SENSITIVE, false) ctx } + protected def sqlTest(sqlString: String, expectedAnswer: Seq[Row]) { + test(sqlString) { + checkAnswer(caseInsensitiveContext.sql(sqlString), expectedAnswer) + } + } + } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala index 5ef365797eac..c81c3d398280 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala @@ -21,6 +21,7 @@ import scala.language.existentials import org.apache.spark.rdd.RDD import org.apache.spark.sql._ +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ @@ -96,11 +97,11 @@ object FiltersPushed { var list: Seq[Filter] = Nil } -class FilteredScanSuite extends DataSourceTest { +class FilteredScanSuite extends DataSourceTest with SharedSQLContext { + protected override lazy val sql = caseInsensitiveContext.sql _ - import caseInsensitiveContext.sql - - before { + override def beforeAll(): Unit = { + super.beforeAll() sql( """ |CREATE TEMPORARY TABLE oneToTenFiltered diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala index cdbfaf6455fe..78bd3e558296 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala @@ -19,20 +19,17 @@ package org.apache.spark.sql.sources import java.io.File -import org.scalatest.BeforeAndAfterAll - import org.apache.spark.sql.{SaveMode, AnalysisException, Row} +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.util.Utils -class InsertSuite extends DataSourceTest with BeforeAndAfterAll { - - import caseInsensitiveContext.sql - +class InsertSuite extends DataSourceTest with 
SharedSQLContext { + protected override lazy val sql = caseInsensitiveContext.sql _ private lazy val sparkContext = caseInsensitiveContext.sparkContext - - var path: File = null + private var path: File = null override def beforeAll(): Unit = { + super.beforeAll() path = Utils.createTempDir() val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str$i"}""")) caseInsensitiveContext.read.json(rdd).registerTempTable("jt") @@ -47,9 +44,13 @@ class InsertSuite extends DataSourceTest with BeforeAndAfterAll { } override def afterAll(): Unit = { - caseInsensitiveContext.dropTempTable("jsonTable") - caseInsensitiveContext.dropTempTable("jt") - Utils.deleteRecursively(path) + try { + caseInsensitiveContext.dropTempTable("jsonTable") + caseInsensitiveContext.dropTempTable("jt") + Utils.deleteRecursively(path) + } finally { + super.afterAll() + } } test("Simple INSERT OVERWRITE a JSONRelation") { @@ -221,9 +222,10 @@ class InsertSuite extends DataSourceTest with BeforeAndAfterAll { sql("SELECT a * 2 FROM jsonTable"), (1 to 10).map(i => Row(i * 2)).toSeq) - assertCached(sql("SELECT x.a, y.a FROM jsonTable x JOIN jsonTable y ON x.a = y.a + 1"), 2) - checkAnswer( - sql("SELECT x.a, y.a FROM jsonTable x JOIN jsonTable y ON x.a = y.a + 1"), + assertCached(sql( + "SELECT x.a, y.a FROM jsonTable x JOIN jsonTable y ON x.a = y.a + 1"), 2) + checkAnswer(sql( + "SELECT x.a, y.a FROM jsonTable x JOIN jsonTable y ON x.a = y.a + 1"), (2 to 10).map(i => Row(i, i - 1)).toSeq) // Insert overwrite and keep the same schema. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala index c86ddd7c83e5..79b6e9b45c00 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala @@ -19,21 +19,21 @@ package org.apache.spark.sql.sources import org.apache.spark.sql.{Row, QueryTest} import org.apache.spark.sql.functions._ -import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.util.Utils -class PartitionedWriteSuite extends QueryTest { - import TestSQLContext.implicits._ +class PartitionedWriteSuite extends QueryTest with SharedSQLContext { + import testImplicits._ test("write many partitions") { val path = Utils.createTempDir() path.delete() - val df = TestSQLContext.range(100).select($"id", lit(1).as("data")) + val df = ctx.range(100).select($"id", lit(1).as("data")) df.write.partitionBy("id").save(path.getCanonicalPath) checkAnswer( - TestSQLContext.read.load(path.getCanonicalPath), + ctx.read.load(path.getCanonicalPath), (0 to 99).map(Row(1, _)).toSeq) Utils.deleteRecursively(path) @@ -43,12 +43,12 @@ class PartitionedWriteSuite extends QueryTest { val path = Utils.createTempDir() path.delete() - val base = TestSQLContext.range(100) + val base = ctx.range(100) val df = base.unionAll(base).select($"id", lit(1).as("data")) df.write.partitionBy("id").save(path.getCanonicalPath) checkAnswer( - TestSQLContext.read.load(path.getCanonicalPath), + ctx.read.load(path.getCanonicalPath), (0 to 99).map(Row(1, _)).toSeq ++ (0 to 99).map(Row(1, _)).toSeq) Utils.deleteRecursively(path) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala index 0d5183444af7..a89c5f8007e7 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala @@ -21,6 +21,7 @@ import scala.language.existentials import org.apache.spark.rdd.RDD import org.apache.spark.sql._ +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ class PrunedScanSource extends RelationProvider { @@ -51,10 +52,12 @@ case class SimplePrunedScan(from: Int, to: Int)(@transient val sqlContext: SQLCo } } -class PrunedScanSuite extends DataSourceTest { +class PrunedScanSuite extends DataSourceTest with SharedSQLContext { + protected override lazy val sql = caseInsensitiveContext.sql _ - before { - caseInsensitiveContext.sql( + override def beforeAll(): Unit = { + super.beforeAll() + sql( """ |CREATE TEMPORARY TABLE oneToTenPruned |USING org.apache.spark.sql.sources.PrunedScanSource @@ -114,7 +117,7 @@ class PrunedScanSuite extends DataSourceTest { def testPruning(sqlString: String, expectedColumns: String*): Unit = { test(s"Columns output ${expectedColumns.mkString(",")}: $sqlString") { - val queryExecution = caseInsensitiveContext.sql(sqlString).queryExecution + val queryExecution = sql(sqlString).queryExecution val rawPlan = queryExecution.executedPlan.collect { case p: execution.PhysicalRDD => p } match { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala index 31730a3d3f8d..f18546b4c2d9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala @@ -19,25 +19,22 @@ package org.apache.spark.sql.sources import java.io.File -import org.scalatest.BeforeAndAfterAll +import org.scalatest.BeforeAndAfter import org.apache.spark.sql.{AnalysisException, SaveMode, SQLConf, DataFrame} +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.util.Utils -class SaveLoadSuite extends DataSourceTest with BeforeAndAfterAll { - - import caseInsensitiveContext.sql - +class SaveLoadSuite extends DataSourceTest with SharedSQLContext with BeforeAndAfter { + protected override lazy val sql = caseInsensitiveContext.sql _ private lazy val sparkContext = caseInsensitiveContext.sparkContext - - var originalDefaultSource: String = null - - var path: File = null - - var df: DataFrame = null + private var originalDefaultSource: String = null + private var path: File = null + private var df: DataFrame = null override def beforeAll(): Unit = { + super.beforeAll() originalDefaultSource = caseInsensitiveContext.conf.defaultDataSourceName path = Utils.createTempDir() @@ -49,11 +46,14 @@ class SaveLoadSuite extends DataSourceTest with BeforeAndAfterAll { } override def afterAll(): Unit = { - caseInsensitiveContext.conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, originalDefaultSource) + try { + caseInsensitiveContext.conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, originalDefaultSource) + } finally { + super.afterAll() + } } after { - caseInsensitiveContext.conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, originalDefaultSource) Utils.deleteRecursively(path) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala index e34e0956d1fd..12af8068c398 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala @@ -22,6 +22,7 @@ import java.sql.{Date, Timestamp} import org.apache.spark.rdd.RDD import org.apache.spark.sql._ +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ class DefaultSource extends SimpleScanSource @@ -95,8 +96,8 @@ case class AllDataTypesScan( } } -class TableScanSuite extends DataSourceTest { - import caseInsensitiveContext.sql +class TableScanSuite extends DataSourceTest with SharedSQLContext { + protected override lazy val sql = caseInsensitiveContext.sql _ private lazy val tableWithSchemaExpected = (1 to 10).map { i => Row( @@ -122,7 +123,8 @@ class TableScanSuite extends DataSourceTest { Row(Seq(s"str_$i", s"str_${i + 1}"), Row(Seq(Date.valueOf(s"1970-01-${i + 1}"))))) }.toSeq - before { + override def beforeAll(): Unit = { + super.beforeAll() sql( """ |CREATE TEMPORARY TABLE oneToTen @@ -303,9 +305,10 @@ class TableScanSuite extends DataSourceTest { sql("SELECT i * 2 FROM oneToTen"), (1 to 10).map(i => Row(i * 2)).toSeq) - assertCached(sql("SELECT a.i, b.i FROM oneToTen a JOIN oneToTen b ON a.i = b.i + 1"), 2) - checkAnswer( - sql("SELECT a.i, b.i FROM oneToTen a JOIN oneToTen b ON a.i = b.i + 1"), + assertCached(sql( + "SELECT a.i, b.i FROM oneToTen a JOIN oneToTen b ON a.i = b.i + 1"), 2) + checkAnswer(sql( + "SELECT a.i, b.i FROM oneToTen a JOIN oneToTen b ON a.i = b.i + 1"), (2 to 10).map(i => Row(i, i - 1)).toSeq) // Verify uncaching diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala new file mode 100644 index 000000000000..1374a97476ca --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala @@ -0,0 +1,290 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.test + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{DataFrame, SQLContext, SQLImplicits} + +/** + * A collection of sample data used in SQL tests. + */ +private[sql] trait SQLTestData { self => + protected def _sqlContext: SQLContext + + // Helper object to import SQL implicits without a concrete SQLContext + private object internalImplicits extends SQLImplicits { + protected override def _sqlContext: SQLContext = self._sqlContext + } + + import internalImplicits._ + import SQLTestData._ + + // Note: all test data should be lazy because the SQLContext is not set up yet. 
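(The laziness noted in the comment above is load-bearing: `_sqlContext` is only wired up once the concrete suite's `beforeAll()` has created the shared context, so an eager `val` here would dereference a null context while the suite class is still being constructed. A minimal sketch of that pattern in isolation — the trait and table names below are illustrative, not part of this patch:)

import org.apache.spark.sql.{DataFrame, SQLContext}

trait LazyFixtureData {
  // Supplied later by the concrete suite (e.g. once beforeAll() has built the context).
  protected def _sqlContext: SQLContext

  // Safe: nothing touches _sqlContext until the first test actually references `users`.
  protected lazy val users: DataFrame = {
    val df = _sqlContext.createDataFrame(Seq((1, "a"), (2, "b"))).toDF("id", "name")
    df.registerTempTable("users")
    df
  }
}
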
+ + protected lazy val testData: DataFrame = { + val df = _sqlContext.sparkContext.parallelize( + (1 to 100).map(i => TestData(i, i.toString))).toDF() + df.registerTempTable("testData") + df + } + + protected lazy val testData2: DataFrame = { + val df = _sqlContext.sparkContext.parallelize( + TestData2(1, 1) :: + TestData2(1, 2) :: + TestData2(2, 1) :: + TestData2(2, 2) :: + TestData2(3, 1) :: + TestData2(3, 2) :: Nil, 2).toDF() + df.registerTempTable("testData2") + df + } + + protected lazy val testData3: DataFrame = { + val df = _sqlContext.sparkContext.parallelize( + TestData3(1, None) :: + TestData3(2, Some(2)) :: Nil).toDF() + df.registerTempTable("testData3") + df + } + + protected lazy val negativeData: DataFrame = { + val df = _sqlContext.sparkContext.parallelize( + (1 to 100).map(i => TestData(-i, (-i).toString))).toDF() + df.registerTempTable("negativeData") + df + } + + protected lazy val largeAndSmallInts: DataFrame = { + val df = _sqlContext.sparkContext.parallelize( + LargeAndSmallInts(2147483644, 1) :: + LargeAndSmallInts(1, 2) :: + LargeAndSmallInts(2147483645, 1) :: + LargeAndSmallInts(2, 2) :: + LargeAndSmallInts(2147483646, 1) :: + LargeAndSmallInts(3, 2) :: Nil).toDF() + df.registerTempTable("largeAndSmallInts") + df + } + + protected lazy val decimalData: DataFrame = { + val df = _sqlContext.sparkContext.parallelize( + DecimalData(1, 1) :: + DecimalData(1, 2) :: + DecimalData(2, 1) :: + DecimalData(2, 2) :: + DecimalData(3, 1) :: + DecimalData(3, 2) :: Nil).toDF() + df.registerTempTable("decimalData") + df + } + + protected lazy val binaryData: DataFrame = { + val df = _sqlContext.sparkContext.parallelize( + BinaryData("12".getBytes, 1) :: + BinaryData("22".getBytes, 5) :: + BinaryData("122".getBytes, 3) :: + BinaryData("121".getBytes, 2) :: + BinaryData("123".getBytes, 4) :: Nil).toDF() + df.registerTempTable("binaryData") + df + } + + protected lazy val upperCaseData: DataFrame = { + val df = _sqlContext.sparkContext.parallelize( + UpperCaseData(1, "A") :: + UpperCaseData(2, "B") :: + UpperCaseData(3, "C") :: + UpperCaseData(4, "D") :: + UpperCaseData(5, "E") :: + UpperCaseData(6, "F") :: Nil).toDF() + df.registerTempTable("upperCaseData") + df + } + + protected lazy val lowerCaseData: DataFrame = { + val df = _sqlContext.sparkContext.parallelize( + LowerCaseData(1, "a") :: + LowerCaseData(2, "b") :: + LowerCaseData(3, "c") :: + LowerCaseData(4, "d") :: Nil).toDF() + df.registerTempTable("lowerCaseData") + df + } + + protected lazy val arrayData: RDD[ArrayData] = { + val rdd = _sqlContext.sparkContext.parallelize( + ArrayData(Seq(1, 2, 3), Seq(Seq(1, 2, 3))) :: + ArrayData(Seq(2, 3, 4), Seq(Seq(2, 3, 4))) :: Nil) + rdd.toDF().registerTempTable("arrayData") + rdd + } + + protected lazy val mapData: RDD[MapData] = { + val rdd = _sqlContext.sparkContext.parallelize( + MapData(Map(1 -> "a1", 2 -> "b1", 3 -> "c1", 4 -> "d1", 5 -> "e1")) :: + MapData(Map(1 -> "a2", 2 -> "b2", 3 -> "c2", 4 -> "d2")) :: + MapData(Map(1 -> "a3", 2 -> "b3", 3 -> "c3")) :: + MapData(Map(1 -> "a4", 2 -> "b4")) :: + MapData(Map(1 -> "a5")) :: Nil) + rdd.toDF().registerTempTable("mapData") + rdd + } + + protected lazy val repeatedData: RDD[StringData] = { + val rdd = _sqlContext.sparkContext.parallelize(List.fill(2)(StringData("test"))) + rdd.toDF().registerTempTable("repeatedData") + rdd + } + + protected lazy val nullableRepeatedData: RDD[StringData] = { + val rdd = _sqlContext.sparkContext.parallelize( + List.fill(2)(StringData(null)) ++ + List.fill(2)(StringData("test"))) + 
rdd.toDF().registerTempTable("nullableRepeatedData") + rdd + } + + protected lazy val nullInts: DataFrame = { + val df = _sqlContext.sparkContext.parallelize( + NullInts(1) :: + NullInts(2) :: + NullInts(3) :: + NullInts(null) :: Nil).toDF() + df.registerTempTable("nullInts") + df + } + + protected lazy val allNulls: DataFrame = { + val df = _sqlContext.sparkContext.parallelize( + NullInts(null) :: + NullInts(null) :: + NullInts(null) :: + NullInts(null) :: Nil).toDF() + df.registerTempTable("allNulls") + df + } + + protected lazy val nullStrings: DataFrame = { + val df = _sqlContext.sparkContext.parallelize( + NullStrings(1, "abc") :: + NullStrings(2, "ABC") :: + NullStrings(3, null) :: Nil).toDF() + df.registerTempTable("nullStrings") + df + } + + protected lazy val tableName: DataFrame = { + val df = _sqlContext.sparkContext.parallelize(TableName("test") :: Nil).toDF() + df.registerTempTable("tableName") + df + } + + protected lazy val unparsedStrings: RDD[String] = { + _sqlContext.sparkContext.parallelize( + "1, A1, true, null" :: + "2, B2, false, null" :: + "3, C3, true, null" :: + "4, D4, true, 2147483644" :: Nil) + } + + // An RDD with 4 elements and 8 partitions + protected lazy val withEmptyParts: RDD[IntField] = { + val rdd = _sqlContext.sparkContext.parallelize((1 to 4).map(IntField), 8) + rdd.toDF().registerTempTable("withEmptyParts") + rdd + } + + protected lazy val person: DataFrame = { + val df = _sqlContext.sparkContext.parallelize( + Person(0, "mike", 30) :: + Person(1, "jim", 20) :: Nil).toDF() + df.registerTempTable("person") + df + } + + protected lazy val salary: DataFrame = { + val df = _sqlContext.sparkContext.parallelize( + Salary(0, 2000.0) :: + Salary(1, 1000.0) :: Nil).toDF() + df.registerTempTable("salary") + df + } + + protected lazy val complexData: DataFrame = { + val df = _sqlContext.sparkContext.parallelize( + ComplexData(Map("1" -> 1), TestData(1, "1"), Seq(1, 1, 1), true) :: + ComplexData(Map("2" -> 2), TestData(2, "2"), Seq(2, 2, 2), false) :: + Nil).toDF() + df.registerTempTable("complexData") + df + } + + /** + * Initialize all test data such that all temp tables are properly registered. + */ + def loadTestData(): Unit = { + assert(_sqlContext != null, "attempted to initialize test data before SQLContext.") + testData + testData2 + testData3 + negativeData + largeAndSmallInts + decimalData + binaryData + upperCaseData + lowerCaseData + arrayData + mapData + repeatedData + nullableRepeatedData + nullInts + allNulls + nullStrings + tableName + unparsedStrings + withEmptyParts + person + salary + complexData + } +} + +/** + * Case classes used in test data. 
+ */ +private[sql] object SQLTestData { + case class TestData(key: Int, value: String) + case class TestData2(a: Int, b: Int) + case class TestData3(a: Int, b: Option[Int]) + case class LargeAndSmallInts(a: Int, b: Int) + case class DecimalData(a: BigDecimal, b: BigDecimal) + case class BinaryData(a: Array[Byte], b: Int) + case class UpperCaseData(N: Int, L: String) + case class LowerCaseData(n: Int, l: String) + case class ArrayData(data: Seq[Int], nestedData: Seq[Seq[Int]]) + case class MapData(data: scala.collection.Map[Int, String]) + case class StringData(s: String) + case class IntField(i: Int) + case class NullInts(a: Integer) + case class NullStrings(n: Int, s: String) + case class TableName(tableName: String) + case class Person(id: Int, name: String, age: Int) + case class Salary(personId: Int, salary: Double) + case class ComplexData(m: Map[String, Int], s: TestData, a: Seq[Int], b: Boolean) +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index 106669558977..cdd691e03589 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -21,15 +21,71 @@ import java.io.File import java.util.UUID import scala.util.Try +import scala.language.implicitConversions + +import org.apache.hadoop.conf.Configuration +import org.scalatest.BeforeAndAfterAll import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.{DataFrame, SQLContext, SQLImplicits} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.util.Utils -trait SQLTestUtils { this: SparkFunSuite => - protected def sqlContext: SQLContext +/** + * Helper trait that should be extended by all SQL test suites. + * + * This allows subclasses to plugin a custom [[SQLContext]]. It comes with test data + * prepared in advance as well as all implicit conversions used extensively by dataframes. + * To use implicit methods, import `testImplicits._` instead of through the [[SQLContext]]. + * + * Subclasses should *not* create [[SQLContext]]s in the test suite constructor, which is + * prone to leaving multiple overlapping [[org.apache.spark.SparkContext]]s in the same JVM. + */ +private[sql] trait SQLTestUtils + extends SparkFunSuite + with BeforeAndAfterAll + with SQLTestData { self => + + protected def _sqlContext: SQLContext + + // Whether to materialize all test data before the first test is run + private var loadTestDataBeforeTests = false + + // Shorthand for running a query using our SQLContext + protected lazy val sql = _sqlContext.sql _ + + /** + * A helper object for importing SQL implicits. + * + * Note that the alternative of importing `sqlContext.implicits._` is not possible here. + * This is because we create the [[SQLContext]] immediately before the first test is run, + * but the implicits import is needed in the constructor. + */ + protected object testImplicits extends SQLImplicits { + protected override def _sqlContext: SQLContext = self._sqlContext + } + + /** + * Materialize the test data immediately after the [[SQLContext]] is set up. + * This is necessary if the data is accessed by name but not through direct reference. 
+ */ + protected def setupTestData(): Unit = { + loadTestDataBeforeTests = true + } - protected def configuration = sqlContext.sparkContext.hadoopConfiguration + protected override def beforeAll(): Unit = { + super.beforeAll() + if (loadTestDataBeforeTests) { + loadTestData() + } + } + + /** + * The Hadoop configuration used by the active [[SQLContext]]. + */ + protected def configuration: Configuration = { + _sqlContext.sparkContext.hadoopConfiguration + } /** * Sets all SQL configurations specified in `pairs`, calls `f`, and then restore all SQL @@ -39,12 +95,12 @@ trait SQLTestUtils { this: SparkFunSuite => */ protected def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = { val (keys, values) = pairs.unzip - val currentValues = keys.map(key => Try(sqlContext.conf.getConfString(key)).toOption) - (keys, values).zipped.foreach(sqlContext.conf.setConfString) + val currentValues = keys.map(key => Try(_sqlContext.conf.getConfString(key)).toOption) + (keys, values).zipped.foreach(_sqlContext.conf.setConfString) try f finally { keys.zip(currentValues).foreach { - case (key, Some(value)) => sqlContext.conf.setConfString(key, value) - case (key, None) => sqlContext.conf.unsetConf(key) + case (key, Some(value)) => _sqlContext.conf.setConfString(key, value) + case (key, None) => _sqlContext.conf.unsetConf(key) } } } @@ -76,7 +132,7 @@ trait SQLTestUtils { this: SparkFunSuite => * Drops temporary table `tableName` after calling `f`. */ protected def withTempTable(tableNames: String*)(f: => Unit): Unit = { - try f finally tableNames.foreach(sqlContext.dropTempTable) + try f finally tableNames.foreach(_sqlContext.dropTempTable) } /** @@ -85,7 +141,7 @@ trait SQLTestUtils { this: SparkFunSuite => protected def withTable(tableNames: String*)(f: => Unit): Unit = { try f finally { tableNames.foreach { name => - sqlContext.sql(s"DROP TABLE IF EXISTS $name") + _sqlContext.sql(s"DROP TABLE IF EXISTS $name") } } } @@ -98,12 +154,12 @@ trait SQLTestUtils { this: SparkFunSuite => val dbName = s"db_${UUID.randomUUID().toString.replace('-', '_')}" try { - sqlContext.sql(s"CREATE DATABASE $dbName") + _sqlContext.sql(s"CREATE DATABASE $dbName") } catch { case cause: Throwable => fail("Failed to create temporary database", cause) } - try f(dbName) finally sqlContext.sql(s"DROP DATABASE $dbName CASCADE") + try f(dbName) finally _sqlContext.sql(s"DROP DATABASE $dbName CASCADE") } /** @@ -111,7 +167,15 @@ trait SQLTestUtils { this: SparkFunSuite => * `f` returns. */ protected def activateDatabase(db: String)(f: => Unit): Unit = { - sqlContext.sql(s"USE $db") - try f finally sqlContext.sql(s"USE default") + _sqlContext.sql(s"USE $db") + try f finally _sqlContext.sql(s"USE default") + } + + /** + * Turn a logical plan into a [[DataFrame]]. This should be removed once we have an easier + * way to construct [[DataFrame]] directly out of local data without relying on implicits. + */ + protected implicit def logicalPlanToSparkQuery(plan: LogicalPlan): DataFrame = { + DataFrame(_sqlContext, plan) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala new file mode 100644 index 000000000000..3cfd822e2a74 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.test + +import org.apache.spark.sql.SQLContext + + +/** + * Helper trait for SQL test suites where all tests share a single [[TestSQLContext]]. + */ +private[sql] trait SharedSQLContext extends SQLTestUtils { + + /** + * The [[TestSQLContext]] to use for all tests in this suite. + * + * By default, the underlying [[org.apache.spark.SparkContext]] will be run in local + * mode with the default test configurations. + */ + private var _ctx: TestSQLContext = null + + /** + * The [[TestSQLContext]] to use for all tests in this suite. + */ + protected def ctx: TestSQLContext = _ctx + protected def sqlContext: TestSQLContext = _ctx + protected override def _sqlContext: SQLContext = _ctx + + /** + * Initialize the [[TestSQLContext]]. + */ + protected override def beforeAll(): Unit = { + if (_ctx == null) { + _ctx = new TestSQLContext + } + // Ensure we have initialized the context before calling parent code + super.beforeAll() + } + + /** + * Stop the underlying [[org.apache.spark.SparkContext]], if any. + */ + protected override def afterAll(): Unit = { + try { + if (_ctx != null) { + _ctx.sparkContext.stop() + _ctx = null + } + } finally { + super.afterAll() + } + } + +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala similarity index 54% rename from sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala rename to sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala index b3a4231da91c..92ef2f7d74ba 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala @@ -17,40 +17,36 @@ package org.apache.spark.sql.test -import scala.language.implicitConversions - import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.sql.{DataFrame, SQLConf, SQLContext} -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan - -/** A SQLContext that can be used for local testing. */ -class LocalSQLContext - extends SQLContext( - new SparkContext("local[2]", "TestSQLContext", new SparkConf() - .set("spark.sql.testkey", "true") - // SPARK-8910 - .set("spark.ui.enabled", "false"))) { - - override protected[sql] def createSession(): SQLSession = { - new this.SQLSession() +import org.apache.spark.sql.{SQLConf, SQLContext} + + +/** + * A special [[SQLContext]] prepared for testing. 
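(Read together, the pieces introduced above — SQLTestData, SQLTestUtils, and SharedSQLContext — are what let the migrated suites in this patch drop their hand-rolled TestSQLContext references. As a rough sketch of how a consumer suite is expected to look; the suite name and queries below are illustrative and not taken from the patch:)

package org.apache.spark.sql

import org.apache.spark.sql.test.SharedSQLContext

// Hypothetical suite, shown only to illustrate the intended usage of the new traits.
class ExampleQuerySuite extends QueryTest with SharedSQLContext {
  import testImplicits._

  // Needed because the test below refers to "testData" by table name rather than
  // through the lazy val, so it must be registered before the first test runs.
  setupTestData()

  test("shared context and test data are available") {
    // testData is registered by loadTestData() with 100 rows.
    checkAnswer(sql("SELECT count(*) FROM testData"), Row(100L))
    // testImplicits provides the $"..." column syntax; ctx is the shared TestSQLContext.
    checkAnswer(ctx.range(3).select($"id" + 1), Seq(Row(1L), Row(2L), Row(3L)))
  }
}
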
+ */ +private[sql] class TestSQLContext(sc: SparkContext) extends SQLContext(sc) { self => + + def this() { + this(new SparkContext("local[2]", "test-sql-context", + new SparkConf().set("spark.sql.testkey", "true"))) } + // Use fewer partitions to speed up testing + protected[sql] override def createSession(): SQLSession = new this.SQLSession() + + /** A special [[SQLSession]] that uses fewer shuffle partitions than normal. */ protected[sql] class SQLSession extends super.SQLSession { protected[sql] override lazy val conf: SQLConf = new SQLConf { - /** Fewer partitions to speed up testing. */ override def numShufflePartitions: Int = this.getConf(SQLConf.SHUFFLE_PARTITIONS, 5) } } - /** - * Turn a logical plan into a [[DataFrame]]. This should be removed once we have an easier way to - * construct [[DataFrame]] directly out of local data without relying on implicits. - */ - protected[sql] implicit def logicalPlanToSparkQuery(plan: LogicalPlan): DataFrame = { - DataFrame(this, plan) + // Needed for Java tests + def loadTestData(): Unit = { + testData.loadTestData() } + private object testData extends SQLTestData { + protected override def _sqlContext: SQLContext = self + } } - -object TestSQLContext extends LocalSQLContext - diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/UISeleniumSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/UISeleniumSuite.scala index 806240e6de45..bf431cd6b026 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/UISeleniumSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/UISeleniumSuite.scala @@ -27,7 +27,6 @@ import org.scalatest.concurrent.Eventually._ import org.scalatest.selenium.WebBrowser import org.scalatest.time.SpanSugar._ -import org.apache.spark.sql.hive.HiveContext import org.apache.spark.ui.SparkUICssErrorHandler class UISeleniumSuite @@ -36,7 +35,6 @@ class UISeleniumSuite implicit var webDriver: WebDriver = _ var server: HiveThriftServer2 = _ - var hc: HiveContext = _ val uiPort = 20000 + Random.nextInt(10000) override def mode: ServerMode.Value = ServerMode.binary diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala index 59e65ff97b8e..574624d501f2 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.hive.test.TestHive.implicits._ import org.apache.spark.sql.sources.DataSourceTest import org.apache.spark.sql.test.{ExamplePointUDT, SQLTestUtils} import org.apache.spark.sql.types.{DecimalType, StringType, StructType} -import org.apache.spark.sql.{Row, SaveMode} +import org.apache.spark.sql.{Row, SaveMode, SQLContext} import org.apache.spark.{Logging, SparkFunSuite} @@ -53,7 +53,8 @@ class HiveMetastoreCatalogSuite extends SparkFunSuite with Logging { } class DataSourceWithHiveMetastoreCatalogSuite extends DataSourceTest with SQLTestUtils { - override val sqlContext = TestHive + override def _sqlContext: SQLContext = TestHive + import testImplicits._ private val testDF = range(1, 3).select( ('id + 0.1) cast DecimalType(10, 3) as 'd1, diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala index 
1fa005d5f9a1..fe0db5228de1 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala @@ -19,14 +19,13 @@ package org.apache.spark.sql.hive import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.execution.datasources.parquet.ParquetTest -import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.{QueryTest, Row, SQLContext} case class Cases(lower: String, UPPER: String) class HiveParquetSuite extends QueryTest with ParquetTest { - val sqlContext = TestHive - - import sqlContext._ + private val ctx = TestHive + override def _sqlContext: SQLContext = ctx test("Case insensitive attribute names") { withParquetTable((1 to 4).map(i => Cases(i.toString, i.toString)), "cases") { @@ -54,7 +53,7 @@ class HiveParquetSuite extends QueryTest with ParquetTest { test("Converting Hive to Parquet Table via saveAsParquetFile") { withTempPath { dir => sql("SELECT * FROM src").write.parquet(dir.getCanonicalPath) - read.parquet(dir.getCanonicalPath).registerTempTable("p") + ctx.read.parquet(dir.getCanonicalPath).registerTempTable("p") withTempTable("p") { checkAnswer( sql("SELECT * FROM src ORDER BY key"), @@ -67,7 +66,7 @@ class HiveParquetSuite extends QueryTest with ParquetTest { withParquetTable((1 to 10).map(i => (i, s"val_$i")), "t") { withTempPath { file => sql("SELECT * FROM t LIMIT 1").write.parquet(file.getCanonicalPath) - read.parquet(file.getCanonicalPath).registerTempTable("p") + ctx.read.parquet(file.getCanonicalPath).registerTempTable("p") withTempTable("p") { // let's do three overwrites for good measure sql("INSERT OVERWRITE TABLE p SELECT * FROM t") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index 7f36a483a396..20a50586d520 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -22,7 +22,6 @@ import java.io.{IOException, File} import scala.collection.mutable.ArrayBuffer import org.apache.hadoop.fs.Path -import org.apache.hadoop.mapred.InvalidInputException import org.scalatest.BeforeAndAfterAll import org.apache.spark.Logging @@ -42,7 +41,8 @@ import org.apache.spark.util.Utils */ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with BeforeAndAfterAll with Logging { - override val sqlContext = TestHive + override def _sqlContext: SQLContext = TestHive + private val sqlContext = _sqlContext var jsonFilePath: String = _ diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala index 73852f13ad20..417e8b07917c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala @@ -22,9 +22,8 @@ import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.{QueryTest, SQLContext, SaveMode} class MultiDatabaseSuite extends QueryTest with SQLTestUtils { - override val sqlContext: SQLContext = TestHive - - import sqlContext.sql + override val _sqlContext: SQLContext = TestHive + private val sqlContext = _sqlContext private val df = sqlContext.range(10).coalesce(1) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala 
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala index 251e0324bfa5..13452e71a1b3 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala @@ -26,7 +26,8 @@ import org.apache.spark.sql.{Row, SQLConf, SQLContext} class ParquetHiveCompatibilitySuite extends ParquetCompatibilityTest { import ParquetCompatibilityTest.makeNullable - override val sqlContext: SQLContext = TestHive + override def _sqlContext: SQLContext = TestHive + private val sqlContext = _sqlContext /** * Set the staging directory (and hence path to ignore Parquet files under) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala index 9b3ede43ee2d..7ee1c8d13aa3 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala @@ -17,14 +17,12 @@ package org.apache.spark.sql.hive -import org.apache.spark.sql.{Row, QueryTest} +import org.apache.spark.sql.QueryTest case class FunctionResult(f1: String, f2: String) class UDFSuite extends QueryTest { - private lazy val ctx = org.apache.spark.sql.hive.test.TestHive - import ctx.implicits._ test("UDF case insensitive") { ctx.udf.register("random0", () => { Math.random() }) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala index 7b5aa4763fd9..a312f8495824 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala @@ -17,17 +17,18 @@ package org.apache.spark.sql.hive.execution +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.sql._ import org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} -import org.apache.spark.sql._ -import org.scalatest.BeforeAndAfterAll import _root_.test.org.apache.spark.sql.hive.aggregate.{MyDoubleAvg, MyDoubleSum} abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with BeforeAndAfterAll { - - override val sqlContext = TestHive + override def _sqlContext: SQLContext = TestHive + protected val sqlContext = _sqlContext import sqlContext.implicits._ var originalUseAggregate2: Boolean = _ diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala index 44c5b80392fa..11d7a872dff0 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala @@ -26,8 +26,8 @@ import org.apache.spark.sql.test.SQLTestUtils * A set of tests that validates support for Hive Explain command. 
*/ class HiveExplainSuite extends QueryTest with SQLTestUtils { - - def sqlContext: SQLContext = TestHive + override def _sqlContext: SQLContext = TestHive + private val sqlContext = _sqlContext test("explain extended command") { checkExistence(sql(" explain select * from src where key=123 "), true, diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 79a136ae6f61..8b8f520776e7 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -66,7 +66,8 @@ class MyDialect extends DefaultParserDialect * valid, but Hive currently cannot execute it. */ class SQLQuerySuite extends QueryTest with SQLTestUtils { - override def sqlContext: SQLContext = TestHive + override def _sqlContext: SQLContext = TestHive + private val sqlContext = _sqlContext test("UDTF") { sql(s"ADD JAR ${TestHive.getHiveFile("TestUDTF.jar").getCanonicalPath()}") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala index 0875232aede3..9aca40f15ac1 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala @@ -31,7 +31,8 @@ import org.apache.spark.sql.types.StringType class ScriptTransformationSuite extends SparkPlanTest { - override def sqlContext: SQLContext = TestHive + override def _sqlContext: SQLContext = TestHive + private val sqlContext = _sqlContext private val noSerdeIOSchema = HiveScriptIOSchema( inputRowFormat = Seq.empty, diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala index 145965388da0..f7ba20ff41d8 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala @@ -27,8 +27,8 @@ import org.apache.spark.sql._ import org.apache.spark.sql.test.SQLTestUtils private[sql] trait OrcTest extends SQLTestUtils { this: SparkFunSuite => - lazy val sqlContext = org.apache.spark.sql.hive.test.TestHive - + protected override def _sqlContext: SQLContext = org.apache.spark.sql.hive.test.TestHive + protected val sqlContext = _sqlContext import sqlContext.implicits._ import sqlContext.sparkContext diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala index 50f02432dacc..34d3434569f5 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala @@ -685,7 +685,8 @@ class ParquetSourceSuite extends ParquetPartitioningTest { * A collection of tests for parquet data with various forms of partitioning. 
*/ abstract class ParquetPartitioningTest extends QueryTest with SQLTestUtils with BeforeAndAfterAll { - override def sqlContext: SQLContext = TestHive + override def _sqlContext: SQLContext = TestHive + protected val sqlContext = _sqlContext var partitionedTableDir: File = null var normalTableDir: File = null diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestRelationSuite.scala index e976125b3706..b4640b161628 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestRelationSuite.scala @@ -18,14 +18,16 @@ package org.apache.spark.sql.sources import org.apache.hadoop.fs.Path -import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.sql.SQLContext import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.test.SQLTestUtils class CommitFailureTestRelationSuite extends SparkFunSuite with SQLTestUtils { - override val sqlContext = TestHive + override def _sqlContext: SQLContext = TestHive + private val sqlContext = _sqlContext // When committing a task, `CommitFailureTestSource` throws an exception for testing purpose. val dataSourceName: String = classOf[CommitFailureTestSource].getCanonicalName diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala index 2a69d331b6e5..af445626fbe4 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala @@ -34,9 +34,8 @@ import org.apache.spark.sql.types._ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils { - override lazy val sqlContext: SQLContext = TestHive - - import sqlContext.sql + override def _sqlContext: SQLContext = TestHive + protected val sqlContext = _sqlContext import sqlContext.implicits._ val dataSourceName: String From bd35385d53a6b039e0241e3e73092b8b0a8e455a Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 13 Aug 2015 21:12:59 -0700 Subject: [PATCH 0239/1824] [SPARK-9945] [SQL] pageSize should be calculated from executor.memory Currently, pageSize of TungstenSort is calculated from driver.memory, it should use executor.memory instead. Also, in the worst case, the safeFactor could be 4 (because of rounding), increase it to 16. cc rxin Author: Davies Liu Closes #8175 from davies/page_size. 
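Note for reviewers: a rough, self-contained sketch of the heuristic this patch tunes, with the rounding interaction spelled out (nextPowerOf2 below is a local stand-in for ByteArrayMethods.nextPowerOf2, and the 4GB/8-core figures are only illustrative):

    object PageSizeSketch {
      // Round up to the next power of two (stand-in for ByteArrayMethods.nextPowerOf2).
      private def nextPowerOf2(n: Long): Long = {
        val highBit = java.lang.Long.highestOneBit(n)
        if (highBit == n) n else highBit << 1
      }

      def defaultPageSize(executorMemory: Long, numCores: Int): Long = {
        val minPageSize = 1L * 1024 * 1024   // 1MB
        val maxPageSize = 64L * minPageSize  // 64MB
        val cores = if (numCores > 0) numCores else Runtime.getRuntime.availableProcessors()
        // Rounding up to a power of two can double the page size, so a nominal
        // factor of 16 behaves like 8 in the worst case (and 8 could behave like 4).
        val safetyFactor = 16
        val size = nextPowerOf2(executorMemory / cores / safetyFactor)
        math.min(maxPageSize, math.max(minPageSize, size))
      }

      def main(args: Array[String]): Unit = {
        println(defaultPageSize(4L * 1024 * 1024 * 1024, 8))  // 4GB heap, 8 cores -> 32MB pages
      }
    }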
--- .../org/apache/spark/shuffle/ShuffleMemoryManager.scala | 4 +++- .../main/scala/org/apache/spark/sql/execution/sort.scala | 6 +++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala index 8c3a72644c38..a0d8abc2eecb 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala @@ -175,7 +175,9 @@ private[spark] object ShuffleMemoryManager { val minPageSize = 1L * 1024 * 1024 // 1MB val maxPageSize = 64L * minPageSize // 64MB val cores = if (numCores > 0) numCores else Runtime.getRuntime.availableProcessors() - val safetyFactor = 8 + // Because of rounding to next power of 2, we may have safetyFactor as 8 in worst case + val safetyFactor = 16 + // TODO(davies): don't round to next power of 2 val size = ByteArrayMethods.nextPowerOf2(maxMemory / cores / safetyFactor) val default = math.min(maxPageSize, math.max(minPageSize, size)) conf.getSizeAsBytes("spark.buffer.pageSize", default) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala index e31693047012..40ef7c3b5353 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala @@ -17,15 +17,15 @@ package org.apache.spark.sql.execution -import org.apache.spark.{SparkEnv, InternalAccumulator, TaskContext} import org.apache.spark.rdd.{MapPartitionsWithPreparationRDD, RDD} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.physical.{UnspecifiedDistribution, OrderedDistribution, Distribution} +import org.apache.spark.sql.catalyst.plans.physical.{Distribution, OrderedDistribution, UnspecifiedDistribution} import org.apache.spark.sql.types.StructType import org.apache.spark.util.CompletionIterator import org.apache.spark.util.collection.ExternalSorter +import org.apache.spark.{SparkEnv, InternalAccumulator, TaskContext} //////////////////////////////////////////////////////////////////////////////////////////////////// // This file defines various sort operators. @@ -122,7 +122,6 @@ case class TungstenSort( protected override def doExecute(): RDD[InternalRow] = { val schema = child.schema val childOutput = child.output - val pageSize = SparkEnv.get.shuffleMemoryManager.pageSizeBytes /** * Set up the sorter in each partition before computing the parent partition. @@ -143,6 +142,7 @@ case class TungstenSort( } } + val pageSize = SparkEnv.get.shuffleMemoryManager.pageSizeBytes val sorter = new UnsafeExternalRowSorter( schema, ordering, prefixComparator, prefixComputer, pageSize) if (testSpillFrequency > 0) { From 7c7c7529a16c0e79778e522a3df82a0f1c3762a3 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 13 Aug 2015 22:06:09 -0700 Subject: [PATCH 0240/1824] [MINOR] [SQL] Remove canEqual in Row As `InternalRow` does not extend `Row` now, I think we can remove it. Author: Liang-Chi Hsieh Closes #8170 from viirya/remove_canequal. 
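For context on why the guard is now removable, a toy sketch with made-up classes (not Spark's Row/InternalRow): the canEqual check only earns its keep while one type is a subtype of the other.

    class External(val value: Int) {
      // Needed only while Internal is a subtype of External.
      protected def canEqual(other: Any): Boolean = !other.isInstanceOf[Internal]

      override def equals(o: Any): Boolean = o match {
        case that: External if canEqual(that) => that.value == value
        case _ => false
      }
      override def hashCode(): Int = value
    }

    class Internal(value: Int) extends External(value)

Once Internal stops extending External, an Internal instance no longer matches `case that: External` at all, so the extra guard is dead code and equals can compare fields directly -- which is what this change does for Row.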
--- .../main/scala/org/apache/spark/sql/Row.scala | 21 ------------------- 1 file changed, 21 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala index 40159aaf14d3..ec895af9c303 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala @@ -364,31 +364,10 @@ trait Row extends Serializable { false } - /** - * Returns true if we can check equality for these 2 rows. - * Equality check between external row and internal row is not allowed. - * Here we do this check to prevent call `equals` on external row with internal row. - */ - protected def canEqual(other: Row) = { - // Note that `Row` is not only the interface of external row but also the parent - // of `InternalRow`, so we have to ensure `other` is not a internal row here to prevent - // call `equals` on external row with internal row. - // `InternalRow` overrides canEqual, and these two canEquals together makes sure that - // equality check between external Row and InternalRow will always fail. - // In the future, InternalRow should not extend Row. In that case, we can remove these - // canEqual methods. - !other.isInstanceOf[InternalRow] - } - override def equals(o: Any): Boolean = { if (!o.isInstanceOf[Row]) return false val other = o.asInstanceOf[Row] - if (!canEqual(other)) { - throw new UnsupportedOperationException( - "cannot check equality between external and internal rows") - } - if (other eq null) return false if (length != other.length) { From c8677d73666850b37ff937520e538650632ce304 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Fri, 14 Aug 2015 14:41:53 +0800 Subject: [PATCH 0241/1824] [SPARK-9958] [SQL] Make HiveThriftServer2Listener thread-safe and update the tab name to "JDBC/ODBC Server" This PR fixed the thread-safe issue of HiveThriftServer2Listener, and also changed the tab name to "JDBC/ODBC Server" since it's conflict with the new SQL tab. thriftserver Author: zsxwing Closes #8185 from zsxwing/SPARK-9958. 
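The fix follows a standard pattern; in isolation it looks roughly like the sketch below (StatsListener is a made-up name, not the real HiveThriftServer2Listener): keep the mutable maps private, make every mutator and accessor synchronized, and hand out immutable snapshots so the UI thread never iterates a map the server threads are mutating.

    import scala.collection.mutable

    class StatsListener {
      private val sessions = new mutable.LinkedHashMap[String, Long]
      private var totalRunning = 0

      def onSessionCreated(id: String): Unit = synchronized {
        sessions.put(id, System.currentTimeMillis)
      }

      def onStatementStart(): Unit = synchronized { totalRunning += 1 }

      def getTotalRunning: Int = synchronized { totalRunning }

      // Snapshot copy; callers that need a consistent view across several reads
      // can additionally wrap them in listener.synchronized { ... }, as the
      // render() methods below do.
      def getSessionList: Seq[(String, Long)] = synchronized { sessions.toSeq }
    }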
--- .../hive/thriftserver/HiveThriftServer2.scala | 64 +++++++++++-------- .../thriftserver/ui/ThriftServerPage.scala | 32 +++++----- .../ui/ThriftServerSessionPage.scala | 38 +++++------ .../thriftserver/ui/ThriftServerTab.scala | 4 +- 4 files changed, 78 insertions(+), 60 deletions(-) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala index 2c9fa595b2da..dd9fef9206d0 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala @@ -152,16 +152,26 @@ object HiveThriftServer2 extends Logging { override def onApplicationEnd(applicationEnd: SparkListenerApplicationEnd): Unit = { server.stop() } - var onlineSessionNum: Int = 0 - val sessionList = new mutable.LinkedHashMap[String, SessionInfo] - val executionList = new mutable.LinkedHashMap[String, ExecutionInfo] - val retainedStatements = - conf.getConf(SQLConf.THRIFTSERVER_UI_STATEMENT_LIMIT) - val retainedSessions = - conf.getConf(SQLConf.THRIFTSERVER_UI_SESSION_LIMIT) - var totalRunning = 0 - - override def onJobStart(jobStart: SparkListenerJobStart): Unit = { + private var onlineSessionNum: Int = 0 + private val sessionList = new mutable.LinkedHashMap[String, SessionInfo] + private val executionList = new mutable.LinkedHashMap[String, ExecutionInfo] + private val retainedStatements = conf.getConf(SQLConf.THRIFTSERVER_UI_STATEMENT_LIMIT) + private val retainedSessions = conf.getConf(SQLConf.THRIFTSERVER_UI_SESSION_LIMIT) + private var totalRunning = 0 + + def getOnlineSessionNum: Int = synchronized { onlineSessionNum } + + def getTotalRunning: Int = synchronized { totalRunning } + + def getSessionList: Seq[SessionInfo] = synchronized { sessionList.values.toSeq } + + def getSession(sessionId: String): Option[SessionInfo] = synchronized { + sessionList.get(sessionId) + } + + def getExecutionList: Seq[ExecutionInfo] = synchronized { executionList.values.toSeq } + + override def onJobStart(jobStart: SparkListenerJobStart): Unit = synchronized { for { props <- Option(jobStart.properties) groupId <- Option(props.getProperty(SparkContext.SPARK_JOB_GROUP_ID)) @@ -173,13 +183,15 @@ object HiveThriftServer2 extends Logging { } def onSessionCreated(ip: String, sessionId: String, userName: String = "UNKNOWN"): Unit = { - val info = new SessionInfo(sessionId, System.currentTimeMillis, ip, userName) - sessionList.put(sessionId, info) - onlineSessionNum += 1 - trimSessionIfNecessary() + synchronized { + val info = new SessionInfo(sessionId, System.currentTimeMillis, ip, userName) + sessionList.put(sessionId, info) + onlineSessionNum += 1 + trimSessionIfNecessary() + } } - def onSessionClosed(sessionId: String): Unit = { + def onSessionClosed(sessionId: String): Unit = synchronized { sessionList(sessionId).finishTimestamp = System.currentTimeMillis onlineSessionNum -= 1 trimSessionIfNecessary() @@ -190,7 +202,7 @@ object HiveThriftServer2 extends Logging { sessionId: String, statement: String, groupId: String, - userName: String = "UNKNOWN"): Unit = { + userName: String = "UNKNOWN"): Unit = synchronized { val info = new ExecutionInfo(statement, sessionId, System.currentTimeMillis, userName) info.state = ExecutionState.STARTED executionList.put(id, info) @@ -200,27 +212,29 @@ object HiveThriftServer2 extends Logging { totalRunning += 1 } - 
def onStatementParsed(id: String, executionPlan: String): Unit = { + def onStatementParsed(id: String, executionPlan: String): Unit = synchronized { executionList(id).executePlan = executionPlan executionList(id).state = ExecutionState.COMPILED } def onStatementError(id: String, errorMessage: String, errorTrace: String): Unit = { - executionList(id).finishTimestamp = System.currentTimeMillis - executionList(id).detail = errorMessage - executionList(id).state = ExecutionState.FAILED - totalRunning -= 1 - trimExecutionIfNecessary() + synchronized { + executionList(id).finishTimestamp = System.currentTimeMillis + executionList(id).detail = errorMessage + executionList(id).state = ExecutionState.FAILED + totalRunning -= 1 + trimExecutionIfNecessary() + } } - def onStatementFinish(id: String): Unit = { + def onStatementFinish(id: String): Unit = synchronized { executionList(id).finishTimestamp = System.currentTimeMillis executionList(id).state = ExecutionState.FINISHED totalRunning -= 1 trimExecutionIfNecessary() } - private def trimExecutionIfNecessary() = synchronized { + private def trimExecutionIfNecessary() = { if (executionList.size > retainedStatements) { val toRemove = math.max(retainedStatements / 10, 1) executionList.filter(_._2.finishTimestamp != 0).take(toRemove).foreach { s => @@ -229,7 +243,7 @@ object HiveThriftServer2 extends Logging { } } - private def trimSessionIfNecessary() = synchronized { + private def trimSessionIfNecessary() = { if (sessionList.size > retainedSessions) { val toRemove = math.max(retainedSessions / 10, 1) sessionList.filter(_._2.finishTimestamp != 0).take(toRemove).foreach { s => diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala index 10c83d8b27a2..e990bd06011f 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala @@ -39,14 +39,16 @@ private[ui] class ThriftServerPage(parent: ThriftServerTab) extends WebUIPage("" /** Render the page */ def render(request: HttpServletRequest): Seq[Node] = { val content = - generateBasicStats() ++ -
    <br/> ++ - <h4>
    - {listener.onlineSessionNum} session(s) are online, - running {listener.totalRunning} SQL statement(s) - </h4>
    ++ - generateSessionStatsTable() ++ - generateSQLStatsTable() + listener.synchronized { // make sure all parts in this page are consistent + generateBasicStats() ++ + <br/>
    ++ + <h4>
    + {listener.getOnlineSessionNum} session(s) are online, + running {listener.getTotalRunning} SQL statement(s) + </h4>
    ++ + generateSessionStatsTable() ++ + generateSQLStatsTable() + } UIUtils.headerSparkPage("JDBC/ODBC Server", content, parent, Some(5000)) } @@ -65,11 +67,11 @@ private[ui] class ThriftServerPage(parent: ThriftServerTab) extends WebUIPage("" /** Generate stats of batch statements of the thrift server program */ private def generateSQLStatsTable(): Seq[Node] = { - val numStatement = listener.executionList.size + val numStatement = listener.getExecutionList.size val table = if (numStatement > 0) { val headerRow = Seq("User", "JobID", "GroupID", "Start Time", "Finish Time", "Duration", "Statement", "State", "Detail") - val dataRows = listener.executionList.values + val dataRows = listener.getExecutionList def generateDataRow(info: ExecutionInfo): Seq[Node] = { val jobLink = info.jobId.map { id: String => @@ -136,15 +138,15 @@ private[ui] class ThriftServerPage(parent: ThriftServerTab) extends WebUIPage("" /** Generate stats of batch sessions of the thrift server program */ private def generateSessionStatsTable(): Seq[Node] = { - val numBatches = listener.sessionList.size + val sessionList = listener.getSessionList + val numBatches = sessionList.size val table = if (numBatches > 0) { - val dataRows = - listener.sessionList.values + val dataRows = sessionList val headerRow = Seq("User", "IP", "Session ID", "Start Time", "Finish Time", "Duration", "Total Execute") def generateDataRow(session: SessionInfo): Seq[Node] = { - val sessionLink = "%s/sql/session?id=%s" - .format(UIUtils.prependBaseUri(parent.basePath), session.sessionId) + val sessionLink = "%s/%s/session?id=%s" + .format(UIUtils.prependBaseUri(parent.basePath), parent.prefix, session.sessionId) {session.userName} {session.ip} diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala index 3b01afa603ce..af16cb31df18 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala @@ -40,21 +40,22 @@ private[ui] class ThriftServerSessionPage(parent: ThriftServerTab) def render(request: HttpServletRequest): Seq[Node] = { val parameterId = request.getParameter("id") require(parameterId != null && parameterId.nonEmpty, "Missing id parameter") - val sessionStat = listener.sessionList.find(stat => { - stat._1 == parameterId - }).getOrElse(null) - require(sessionStat != null, "Invalid sessionID[" + parameterId + "]") val content = - generateBasicStats() ++ -
    <br/> ++ - <h4>
    - User {sessionStat._2.userName}, - IP {sessionStat._2.ip}, - Session created at {formatDate(sessionStat._2.startTimestamp)}, - Total run {sessionStat._2.totalExecution} SQL - </h4>
    ++ - generateSQLStatsTable(sessionStat._2.sessionId) + listener.synchronized { // make sure all parts in this page are consistent + val sessionStat = listener.getSession(parameterId).getOrElse(null) + require(sessionStat != null, "Invalid sessionID[" + parameterId + "]") + + generateBasicStats() ++ + <br/>
    ++ + <h4>
    + User {sessionStat.userName}, + IP {sessionStat.ip}, + Session created at {formatDate(sessionStat.startTimestamp)}, + Total run {sessionStat.totalExecution} SQL + </h4>
    ++ + generateSQLStatsTable(sessionStat.sessionId) + } UIUtils.headerSparkPage("JDBC/ODBC Session", content, parent, Some(5000)) } @@ -73,13 +74,13 @@ private[ui] class ThriftServerSessionPage(parent: ThriftServerTab) /** Generate stats of batch statements of the thrift server program */ private def generateSQLStatsTable(sessionID: String): Seq[Node] = { - val executionList = listener.executionList - .filter(_._2.sessionId == sessionID) + val executionList = listener.getExecutionList + .filter(_.sessionId == sessionID) val numStatement = executionList.size val table = if (numStatement > 0) { val headerRow = Seq("User", "JobID", "GroupID", "Start Time", "Finish Time", "Duration", "Statement", "State", "Detail") - val dataRows = executionList.values.toSeq.sortBy(_.startTimestamp).reverse + val dataRows = executionList.sortBy(_.startTimestamp).reverse def generateDataRow(info: ExecutionInfo): Seq[Node] = { val jobLink = info.jobId.map { id: String => @@ -146,10 +147,11 @@ private[ui] class ThriftServerSessionPage(parent: ThriftServerTab) /** Generate stats of batch sessions of the thrift server program */ private def generateSessionStatsTable(): Seq[Node] = { - val numBatches = listener.sessionList.size + val sessionList = listener.getSessionList + val numBatches = sessionList.size val table = if (numBatches > 0) { val dataRows = - listener.sessionList.values.toSeq.sortBy(_.startTimestamp).reverse.map ( session => + sessionList.sortBy(_.startTimestamp).reverse.map ( session => Seq( session.userName, session.ip, diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerTab.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerTab.scala index 94fd8a6bb60b..4eabeaa6735e 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerTab.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerTab.scala @@ -27,9 +27,9 @@ import org.apache.spark.{SparkContext, Logging, SparkException} * This assumes the given SparkContext has enabled its SparkUI. */ private[thriftserver] class ThriftServerTab(sparkContext: SparkContext) - extends SparkUITab(getSparkUI(sparkContext), "sql") with Logging { + extends SparkUITab(getSparkUI(sparkContext), "sqlserver") with Logging { - override val name = "SQL" + override val name = "JDBC/ODBC Server" val parent = getSparkUI(sparkContext) val listener = HiveThriftServer2.listener From a0e1abbd010b9e73d472ce12ff1d987678005d32 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Fri, 14 Aug 2015 10:25:11 -0700 Subject: [PATCH 0242/1824] [SPARK-9661] [MLLIB] minor clean-up of SPARK-9661 Some minor clean-ups after SPARK-9661. See my inline comments. MechCoder jkbradley Author: Xiangrui Meng Closes #8190 from mengxr/SPARK-9661-fix. 
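Most of these touch the Java-friendly wrappers; the shape of that pattern, sketched with made-up names (ExampleModel is not a real MLlib class): the Scala API keeps primitive Long/Double, and the java* variant only adapts the element types and delegates rather than re-implementing the logic.

    import org.apache.spark.api.java.JavaRDD
    import org.apache.spark.rdd.RDD

    class ExampleModel(scores: RDD[(Long, Array[Double])]) {
      def topScores(k: Int): RDD[(Long, Array[Double])] =
        scores.map { case (id, arr) => (id, arr.take(k)) }

      // Same logic, Java-friendly element types (mirrors javaTopTopicsPerDocument).
      def javaTopScores(k: Int): JavaRDD[(java.lang.Long, Array[Double])] =
        topScores(k).asInstanceOf[RDD[(java.lang.Long, Array[Double])]].toJavaRDD()
    }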
--- .../spark/mllib/clustering/LDAModel.scala | 5 +-- .../apache/spark/mllib/stat/Statistics.scala | 6 +-- .../spark/mllib/clustering/JavaLDASuite.java | 40 +++++++++++-------- .../spark/mllib/clustering/LDASuite.scala | 2 +- 4 files changed, 28 insertions(+), 25 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala index f31949f13a4c..82f05e4a18ce 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala @@ -674,10 +674,9 @@ class DistributedLDAModel private[clustering] ( } /** Java-friendly version of [[topTopicsPerDocument]] */ - def javaTopTopicsPerDocument( - k: Int): JavaRDD[(java.lang.Long, Array[Int], Array[java.lang.Double])] = { + def javaTopTopicsPerDocument(k: Int): JavaRDD[(java.lang.Long, Array[Int], Array[Double])] = { val topics = topTopicsPerDocument(k) - topics.asInstanceOf[RDD[(java.lang.Long, Array[Int], Array[java.lang.Double])]].toJavaRDD() + topics.asInstanceOf[RDD[(java.lang.Long, Array[Int], Array[Double])]].toJavaRDD() } // TODO: diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala index 24fe48cb8f71..ef8d78607048 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala @@ -221,9 +221,7 @@ object Statistics { def kolmogorovSmirnovTest( data: JavaDoubleRDD, distName: String, - params: java.lang.Double*): KolmogorovSmirnovTestResult = { - val javaParams = params.asInstanceOf[Seq[Double]] - KolmogorovSmirnovTest.testOneSample(data.rdd.asInstanceOf[RDD[Double]], - distName, javaParams: _*) + params: Double*): KolmogorovSmirnovTestResult = { + kolmogorovSmirnovTest(data.rdd.asInstanceOf[RDD[Double]], distName, params: _*) } } diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java index 427be9430d82..6e91cde2eabb 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java @@ -22,12 +22,14 @@ import java.util.Arrays; import scala.Tuple2; +import scala.Tuple3; import org.junit.After; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertArrayEquals; import org.junit.Before; import org.junit.Test; +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; import org.apache.spark.api.java.function.Function; import org.apache.spark.api.java.JavaPairRDD; @@ -44,9 +46,9 @@ public class JavaLDASuite implements Serializable { public void setUp() { sc = new JavaSparkContext("local", "JavaLDA"); ArrayList> tinyCorpus = new ArrayList>(); - for (int i = 0; i < LDASuite$.MODULE$.tinyCorpus().length; i++) { - tinyCorpus.add(new Tuple2((Long)LDASuite$.MODULE$.tinyCorpus()[i]._1(), - LDASuite$.MODULE$.tinyCorpus()[i]._2())); + for (int i = 0; i < LDASuite.tinyCorpus().length; i++) { + tinyCorpus.add(new Tuple2((Long)LDASuite.tinyCorpus()[i]._1(), + LDASuite.tinyCorpus()[i]._2())); } JavaRDD> tmpCorpus = sc.parallelize(tinyCorpus, 2); corpus = JavaPairRDD.fromJavaRDD(tmpCorpus); @@ -60,7 +62,7 @@ public void tearDown() { @Test public void localLDAModel() { - Matrix topics = 
LDASuite$.MODULE$.tinyTopics(); + Matrix topics = LDASuite.tinyTopics(); double[] topicConcentration = new double[topics.numRows()]; Arrays.fill(topicConcentration, 1.0D / topics.numRows()); LocalLDAModel model = new LocalLDAModel(topics, Vectors.dense(topicConcentration), 1D, 100D); @@ -110,8 +112,8 @@ public void distributedLDAModel() { assertEquals(roundedLocalTopicSummary.length, k); // Check: log probabilities - assert(model.logLikelihood() < 0.0); - assert(model.logPrior() < 0.0); + assertTrue(model.logLikelihood() < 0.0); + assertTrue(model.logPrior() < 0.0); // Check: topic distributions JavaPairRDD topicDistributions = model.javaTopicDistributions(); @@ -126,8 +128,12 @@ public Boolean call(Tuple2 tuple2) { assertEquals(topicDistributions.count(), nonEmptyCorpus.count()); // Check: javaTopTopicsPerDocuments - JavaRDD> topTopics = - model.javaTopTopicsPerDocument(3); + Tuple3 topTopics = model.javaTopTopicsPerDocument(3).first(); + Long docId = topTopics._1(); // confirm doc ID type + int[] topicIndices = topTopics._2(); + double[] topicWeights = topTopics._3(); + assertEquals(3, topicIndices.length); + assertEquals(3, topicWeights.length); } @Test @@ -177,18 +183,18 @@ public void localLdaMethods() { // check: logLikelihood. ArrayList> docsSingleWord = new ArrayList>(); - docsSingleWord.add(new Tuple2(Long.valueOf(0), Vectors.dense(1.0, 0.0, 0.0))); + docsSingleWord.add(new Tuple2(0L, Vectors.dense(1.0, 0.0, 0.0))); JavaPairRDD single = JavaPairRDD.fromJavaRDD(sc.parallelize(docsSingleWord)); double logLikelihood = toyModel.logLikelihood(single); } - private static int tinyK = LDASuite$.MODULE$.tinyK(); - private static int tinyVocabSize = LDASuite$.MODULE$.tinyVocabSize(); - private static Matrix tinyTopics = LDASuite$.MODULE$.tinyTopics(); + private static int tinyK = LDASuite.tinyK(); + private static int tinyVocabSize = LDASuite.tinyVocabSize(); + private static Matrix tinyTopics = LDASuite.tinyTopics(); private static Tuple2[] tinyTopicDescription = - LDASuite$.MODULE$.tinyTopicDescription(); + LDASuite.tinyTopicDescription(); private JavaPairRDD corpus; - private LocalLDAModel toyModel = LDASuite$.MODULE$.toyModel(); - private ArrayList> toyData = LDASuite$.MODULE$.javaToyData(); + private LocalLDAModel toyModel = LDASuite.toyModel(); + private ArrayList> toyData = LDASuite.javaToyData(); } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala index 926185e90bcf..99e28499fd31 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala @@ -581,7 +581,7 @@ private[clustering] object LDASuite { def javaToyData: JArrayList[(java.lang.Long, Vector)] = { val javaData = new JArrayList[(java.lang.Long, Vector)] var i = 0 - while (i < toyData.size) { + while (i < toyData.length) { javaData.add((toyData(i)._1, toyData(i)._2)) i += 1 } From 7ecf0c46990c39df8aeddbd64ca33d01824bcc0a Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Fri, 14 Aug 2015 10:48:02 -0700 Subject: [PATCH 0243/1824] [SPARK-9956] [ML] Make trees work with one-category features This modifies DecisionTreeMetadata construction to treat 1-category features as continuous, so that trees do not fail with such features. It is important for the pipelines API, where VectorIndexer can automatically categorize certain features as categorical. 
As stated in the JIRA, this is a temp fix which we can improve upon later by automatically filtering out those features. That will take longer, though, since it will require careful indexing. Targeted for 1.5 and master CC: manishamde mengxr yanboliang Author: Joseph K. Bradley Closes #8187 from jkbradley/tree-1cat. --- .../tree/impl/DecisionTreeMetadata.scala | 27 ++++++++++++------- .../DecisionTreeClassifierSuite.scala | 13 +++++++++ 2 files changed, 30 insertions(+), 10 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala index 9fe264656ede..21ee49c45788 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala @@ -144,21 +144,28 @@ private[spark] object DecisionTreeMetadata extends Logging { val maxCategoriesForUnorderedFeature = ((math.log(maxPossibleBins / 2 + 1) / math.log(2.0)) + 1).floor.toInt strategy.categoricalFeaturesInfo.foreach { case (featureIndex, numCategories) => - // Decide if some categorical features should be treated as unordered features, - // which require 2 * ((1 << numCategories - 1) - 1) bins. - // We do this check with log values to prevent overflows in case numCategories is large. - // The next check is equivalent to: 2 * ((1 << numCategories - 1) - 1) <= maxBins - if (numCategories <= maxCategoriesForUnorderedFeature) { - unorderedFeatures.add(featureIndex) - numBins(featureIndex) = numUnorderedBins(numCategories) - } else { - numBins(featureIndex) = numCategories + // Hack: If a categorical feature has only 1 category, we treat it as continuous. + // TODO(SPARK-9957): Handle this properly by filtering out those features. + if (numCategories > 1) { + // Decide if some categorical features should be treated as unordered features, + // which require 2 * ((1 << numCategories - 1) - 1) bins. + // We do this check with log values to prevent overflows in case numCategories is large. 
+ // The next check is equivalent to: 2 * ((1 << numCategories - 1) - 1) <= maxBins + if (numCategories <= maxCategoriesForUnorderedFeature) { + unorderedFeatures.add(featureIndex) + numBins(featureIndex) = numUnorderedBins(numCategories) + } else { + numBins(featureIndex) = numCategories + } } } } else { // Binary classification or regression strategy.categoricalFeaturesInfo.foreach { case (featureIndex, numCategories) => - numBins(featureIndex) = numCategories + // If a categorical feature has only 1 category, we treat it as continuous: SPARK-9957 + if (numCategories > 1) { + numBins(featureIndex) = numCategories + } } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala index 4b7c5d3f23d2..f680d8d3c4cc 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala @@ -261,6 +261,19 @@ class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkConte } } + test("training with 1-category categorical feature") { + val data = sc.parallelize(Seq( + LabeledPoint(0, Vectors.dense(0, 2, 3)), + LabeledPoint(1, Vectors.dense(0, 3, 1)), + LabeledPoint(0, Vectors.dense(0, 2, 2)), + LabeledPoint(1, Vectors.dense(0, 3, 9)), + LabeledPoint(0, Vectors.dense(0, 2, 6)) + )) + val df = TreeTests.setMetadata(data, Map(0 -> 1), 2) + val dt = new DecisionTreeClassifier().setMaxDepth(3) + val model = dt.fit(df) + } + ///////////////////////////////////////////////////////////////////////////// // Tests of model save/load ///////////////////////////////////////////////////////////////////////////// From a7317ccdc20d001e5b7f5277b0535923468bfbc6 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Fri, 14 Aug 2015 11:22:10 -0700 Subject: [PATCH 0244/1824] [SPARK-8744] [ML] Add a public constructor to StringIndexer It would be helpful to allow users to pass a pre-computed index to create an indexer, rather than always going through StringIndexer to create the model. Author: Holden Karau Closes #7267 from holdenk/SPARK-8744-StringIndexerModel-should-have-public-constructor. --- .../scala/org/apache/spark/ml/feature/StringIndexer.scala | 4 +++- .../org/apache/spark/ml/feature/StringIndexerSuite.scala | 2 ++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index 9f6e7b6b6b27..63475780a6ff 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -102,10 +102,12 @@ class StringIndexer(override val uid: String) extends Estimator[StringIndexerMod * This is a temporary fix for the case when target labels do not exist during prediction. 
*/ @Experimental -class StringIndexerModel private[ml] ( +class StringIndexerModel ( override val uid: String, labels: Array[String]) extends Model[StringIndexerModel] with StringIndexerBase { + def this(labels: Array[String]) = this(Identifiable.randomUID("strIdx"), labels) + private val labelToIndex: OpenHashMap[String, Double] = { val n = labels.length val map = new OpenHashMap[String, Double](n) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala index fa918ce64877..0b4c8ba71ee6 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala @@ -30,7 +30,9 @@ class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext { test("params") { ParamsSuite.checkParams(new StringIndexer) val model = new StringIndexerModel("indexer", Array("a", "b")) + val modelWithoutUid = new StringIndexerModel(Array("a", "b")) ParamsSuite.checkParams(model) + ParamsSuite.checkParams(modelWithoutUid) } test("StringIndexer") { From 34d610be854d2a975d9c1e232d87433b85add6fd Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 14 Aug 2015 12:00:01 -0700 Subject: [PATCH 0245/1824] [SPARK-9929] [SQL] support metadata in withColumn in MLlib sometimes we need to set metadata for the new column, thus we will alias the new column with metadata before call `withColumn` and in `withColumn` we alias this clolumn again. Here I overloaded `withColumn` to allow user set metadata, just like what we did for `Column.as`. Author: Wenchen Fan Closes #8159 from cloud-fan/withColumn. --- .../spark/ml/classification/OneVsRest.scala | 6 +++--- .../apache/spark/ml/feature/Bucketizer.scala | 2 +- .../apache/spark/ml/feature/VectorIndexer.scala | 2 +- .../apache/spark/ml/feature/VectorSlicer.scala | 3 +-- .../scala/org/apache/spark/sql/DataFrame.scala | 17 +++++++++++++++++ 5 files changed, 23 insertions(+), 7 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala index 1132d8046df6..c62e132f5d53 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala @@ -131,7 +131,7 @@ final class OneVsRestModel private[ml] ( // output label and label metadata as prediction aggregatedDataset - .withColumn($(predictionCol), labelUDF(col(accColName)).as($(predictionCol), labelMetadata)) + .withColumn($(predictionCol), labelUDF(col(accColName)), labelMetadata) .drop(accColName) } @@ -203,8 +203,8 @@ final class OneVsRest(override val uid: String) // TODO: use when ... 
otherwise after SPARK-7321 is merged val newLabelMeta = BinaryAttribute.defaultAttr.withName("label").toMetadata() val labelColName = "mc2b$" + index - val labelUDFWithNewMeta = labelUDF(col($(labelCol))).as(labelColName, newLabelMeta) - val trainingDataset = multiclassLabeled.withColumn(labelColName, labelUDFWithNewMeta) + val trainingDataset = + multiclassLabeled.withColumn(labelColName, labelUDF(col($(labelCol))), newLabelMeta) val classifier = getClassifier val paramMap = new ParamMap() paramMap.put(classifier.labelCol -> labelColName) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala index cfca494dcf46..6fdf25b015b0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala @@ -75,7 +75,7 @@ final class Bucketizer(override val uid: String) } val newCol = bucketizer(dataset($(inputCol))) val newField = prepOutputField(dataset.schema) - dataset.withColumn($(outputCol), newCol.as($(outputCol), newField.metadata)) + dataset.withColumn($(outputCol), newCol, newField.metadata) } private def prepOutputField(schema: StructType): StructField = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala index 6875aefe065b..61b925c0fdc0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala @@ -341,7 +341,7 @@ class VectorIndexerModel private[ml] ( val newField = prepOutputField(dataset.schema) val transformUDF = udf { (vector: Vector) => transformFunc(vector) } val newCol = transformUDF(dataset($(inputCol))) - dataset.withColumn($(outputCol), newCol.as($(outputCol), newField.metadata)) + dataset.withColumn($(outputCol), newCol, newField.metadata) } override def transformSchema(schema: StructType): StructType = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala index 772bebeff214..c5c227227079 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala @@ -119,8 +119,7 @@ final class VectorSlicer(override val uid: String) case features: SparseVector => features.slice(inds) } } - dataset.withColumn($(outputCol), - slicer(dataset($(inputCol))).as($(outputCol), outputAttr.toMetadata())) + dataset.withColumn($(outputCol), slicer(dataset($(inputCol))), outputAttr.toMetadata()) } /** Get the feature indices in order: indices, names */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index c466d9e6cb34..cf75e64e884b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -1149,6 +1149,23 @@ class DataFrame private[sql]( } } + /** + * Returns a new [[DataFrame]] by adding a column with metadata. 
+ */ + private[spark] def withColumn(colName: String, col: Column, metadata: Metadata): DataFrame = { + val resolver = sqlContext.analyzer.resolver + val replaced = schema.exists(f => resolver(f.name, colName)) + if (replaced) { + val colNames = schema.map { field => + val name = field.name + if (resolver(name, colName)) col.as(colName, metadata) else Column(name) + } + select(colNames : _*) + } else { + select(Column("*"), col.as(colName, metadata)) + } + } + /** * Returns a new [[DataFrame]] with a column renamed. * This is a no-op if schema doesn't contain existingName. From 57c2d08800a506614b461b5a505a1dd1a28a8908 Mon Sep 17 00:00:00 2001 From: Neelesh Srinivas Salian Date: Fri, 14 Aug 2015 20:03:50 +0100 Subject: [PATCH 0246/1824] [SPARK-9923] [CORE] ShuffleMapStage.numAvailableOutputs should be an Int instead of Long Modified type of ShuffleMapStage.numAvailableOutputs from Long to Int Author: Neelesh Srinivas Salian Closes #8183 from nssalian/SPARK-9923. --- .../main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala index 66c75f325fcd..48d8d8e9c4b7 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala @@ -37,7 +37,7 @@ private[spark] class ShuffleMapStage( override def toString: String = "ShuffleMapStage " + id - var numAvailableOutputs: Long = 0 + var numAvailableOutputs: Int = 0 def isAvailable: Boolean = numAvailableOutputs == numPartitions From 3bc55287220b1248e935bf817d880ff176ad4d3b Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 14 Aug 2015 12:32:35 -0700 Subject: [PATCH 0247/1824] [SPARK-9946] [SPARK-9589] [SQL] fix NPE and thread-safety in TaskMemoryManager Currently, we access the `page.pageNumer` after it's freed, that could be modified by other thread, cause NPE. The same TaskMemoryManager could be used by multiple threads (for example, Python UDF and TransportScript), so it should be thread safe to allocate/free memory/page. The underlying Bitset and HashSet are not thread safe, we should put them inside a synchronized block. cc JoshRosen Author: Davies Liu Closes #8177 from davies/memory_manager. --- .../unsafe/memory/TaskMemoryManager.java | 44 ++++++++++++------- 1 file changed, 28 insertions(+), 16 deletions(-) diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java index 358bb3725015..ca70d7f4a431 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java @@ -144,14 +144,16 @@ public MemoryBlock allocatePage(long size) { public void freePage(MemoryBlock page) { assert (page.pageNumber != -1) : "Called freePage() on memory that wasn't allocated with allocatePage()"; - executorMemoryManager.free(page); + assert(allocatedPages.get(page.pageNumber)); + pageTable[page.pageNumber] = null; synchronized (this) { allocatedPages.clear(page.pageNumber); } - pageTable[page.pageNumber] = null; if (logger.isTraceEnabled()) { logger.trace("Freed page number {} ({} bytes)", page.pageNumber, page.size()); } + // Cannot access a page once it's freed. 
+ executorMemoryManager.free(page); } /** @@ -166,7 +168,9 @@ public void freePage(MemoryBlock page) { public MemoryBlock allocate(long size) throws OutOfMemoryError { assert(size > 0) : "Size must be positive, but got " + size; final MemoryBlock memory = executorMemoryManager.allocate(size); - allocatedNonPageMemory.add(memory); + synchronized(allocatedNonPageMemory) { + allocatedNonPageMemory.add(memory); + } return memory; } @@ -176,8 +180,10 @@ public MemoryBlock allocate(long size) throws OutOfMemoryError { public void free(MemoryBlock memory) { assert (memory.pageNumber == -1) : "Should call freePage() for pages, not free()"; executorMemoryManager.free(memory); - final boolean wasAlreadyRemoved = !allocatedNonPageMemory.remove(memory); - assert (!wasAlreadyRemoved) : "Called free() on memory that was already freed!"; + synchronized(allocatedNonPageMemory) { + final boolean wasAlreadyRemoved = !allocatedNonPageMemory.remove(memory); + assert (!wasAlreadyRemoved) : "Called free() on memory that was already freed!"; + } } /** @@ -223,9 +229,10 @@ public Object getPage(long pagePlusOffsetAddress) { if (inHeap) { final int pageNumber = decodePageNumber(pagePlusOffsetAddress); assert (pageNumber >= 0 && pageNumber < PAGE_TABLE_SIZE); - final Object page = pageTable[pageNumber].getBaseObject(); + final MemoryBlock page = pageTable[pageNumber]; assert (page != null); - return page; + assert (page.getBaseObject() != null); + return page.getBaseObject(); } else { return null; } @@ -244,7 +251,9 @@ public long getOffsetInPage(long pagePlusOffsetAddress) { // converted the absolute address into a relative address. Here, we invert that operation: final int pageNumber = decodePageNumber(pagePlusOffsetAddress); assert (pageNumber >= 0 && pageNumber < PAGE_TABLE_SIZE); - return pageTable[pageNumber].getBaseOffset() + offsetInPage; + final MemoryBlock page = pageTable[pageNumber]; + assert (page != null); + return page.getBaseOffset() + offsetInPage; } } @@ -260,14 +269,17 @@ public long cleanUpAllAllocatedMemory() { freePage(page); } } - final Iterator iter = allocatedNonPageMemory.iterator(); - while (iter.hasNext()) { - final MemoryBlock memory = iter.next(); - freedBytes += memory.size(); - // We don't call free() here because that calls Set.remove, which would lead to a - // ConcurrentModificationException here. - executorMemoryManager.free(memory); - iter.remove(); + + synchronized (allocatedNonPageMemory) { + final Iterator iter = allocatedNonPageMemory.iterator(); + while (iter.hasNext()) { + final MemoryBlock memory = iter.next(); + freedBytes += memory.size(); + // We don't call free() here because that calls Set.remove, which would lead to a + // ConcurrentModificationException here. + executorMemoryManager.free(memory); + iter.remove(); + } } return freedBytes; } From ece00566e4d5f38585f2810bef38e526cae7d25e Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Fri, 14 Aug 2015 12:37:21 -0700 Subject: [PATCH 0248/1824] [SPARK-9561] Re-enable BroadcastJoinSuite We can do this now that SPARK-9580 is resolved. Author: Andrew Or Closes #8208 from andrewor14/reenable-sql-tests. 
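For anyone running the suite locally, a minimal standalone sketch of what it exercises (not part of the suite itself); local-cluster mode matters because only a multi-JVM master forces the broadcast hash relation to actually be serialized.

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.sql.SQLContext
    import org.apache.spark.sql.functions.broadcast

    object BroadcastJoinExample {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(
          new SparkConf().setMaster("local-cluster[2,1,1024]").setAppName("broadcast-join-example"))
        val sqlContext = new SQLContext(sc)

        val df1 = sqlContext.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value")
        val df2 = sqlContext.createDataFrame(Seq((1, "1"), (2, "2"))).toDF("key", "value")

        // Hint that df2 is small enough to ship to every executor,
        // forcing a broadcast hash join instead of a shuffle join.
        df1.join(broadcast(df2), df1("key") === df2("key")).show()

        sc.stop()
      }
    }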
--- .../execution/joins/BroadcastJoinSuite.scala | 145 +++++++++--------- 1 file changed, 69 insertions(+), 76 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala index 0554e11d252b..53a0e53fd771 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala @@ -15,80 +15,73 @@ * limitations under the License. */ -// TODO: uncomment the test here! It is currently failing due to -// bad interaction with org.apache.spark.sql.test.TestSQLContext. +package org.apache.spark.sql.execution.joins -// scalastyle:off -//package org.apache.spark.sql.execution.joins -// -//import scala.reflect.ClassTag -// -//import org.scalatest.BeforeAndAfterAll -// -//import org.apache.spark.{AccumulatorSuite, SparkConf, SparkContext} -//import org.apache.spark.sql.functions._ -//import org.apache.spark.sql.{SQLConf, SQLContext, QueryTest} -// -///** -// * Test various broadcast join operators with unsafe enabled. -// * -// * This needs to be its own suite because [[org.apache.spark.sql.test.TestSQLContext]] runs -// * in local mode, but for tests in this suite we need to run Spark in local-cluster mode. -// * In particular, the use of [[org.apache.spark.unsafe.map.BytesToBytesMap]] in -// * [[org.apache.spark.sql.execution.joins.UnsafeHashedRelation]] is not triggered without -// * serializing the hashed relation, which does not happen in local mode. -// */ -//class BroadcastJoinSuite extends QueryTest with BeforeAndAfterAll { -// private var sc: SparkContext = null -// private var sqlContext: SQLContext = null -// -// /** -// * Create a new [[SQLContext]] running in local-cluster mode with unsafe and codegen enabled. -// */ -// override def beforeAll(): Unit = { -// super.beforeAll() -// val conf = new SparkConf() -// .setMaster("local-cluster[2,1,1024]") -// .setAppName("testing") -// sc = new SparkContext(conf) -// sqlContext = new SQLContext(sc) -// sqlContext.setConf(SQLConf.UNSAFE_ENABLED, true) -// sqlContext.setConf(SQLConf.CODEGEN_ENABLED, true) -// } -// -// override def afterAll(): Unit = { -// sc.stop() -// sc = null -// sqlContext = null -// } -// -// /** -// * Test whether the specified broadcast join updates the peak execution memory accumulator. 
-// */ -// private def testBroadcastJoin[T: ClassTag](name: String, joinType: String): Unit = { -// AccumulatorSuite.verifyPeakExecutionMemorySet(sc, name) { -// val df1 = sqlContext.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value") -// val df2 = sqlContext.createDataFrame(Seq((1, "1"), (2, "2"))).toDF("key", "value") -// // Comparison at the end is for broadcast left semi join -// val joinExpression = df1("key") === df2("key") && df1("value") > df2("value") -// val df3 = df1.join(broadcast(df2), joinExpression, joinType) -// val plan = df3.queryExecution.executedPlan -// assert(plan.collect { case p: T => p }.size === 1) -// plan.executeCollect() -// } -// } -// -// test("unsafe broadcast hash join updates peak execution memory") { -// testBroadcastJoin[BroadcastHashJoin]("unsafe broadcast hash join", "inner") -// } -// -// test("unsafe broadcast hash outer join updates peak execution memory") { -// testBroadcastJoin[BroadcastHashOuterJoin]("unsafe broadcast hash outer join", "left_outer") -// } -// -// test("unsafe broadcast left semi join updates peak execution memory") { -// testBroadcastJoin[BroadcastLeftSemiJoinHash]("unsafe broadcast left semi join", "leftsemi") -// } -// -//} -// scalastyle:on +import scala.reflect.ClassTag + +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.{AccumulatorSuite, SparkConf, SparkContext} +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.{SQLConf, SQLContext, QueryTest} + +/** + * Test various broadcast join operators with unsafe enabled. + * + * Tests in this suite we need to run Spark in local-cluster mode. In particular, the use of + * unsafe map in [[org.apache.spark.sql.execution.joins.UnsafeHashedRelation]] is not triggered + * without serializing the hashed relation, which does not happen in local mode. + */ +class BroadcastJoinSuite extends QueryTest with BeforeAndAfterAll { + private var sc: SparkContext = null + private var sqlContext: SQLContext = null + + /** + * Create a new [[SQLContext]] running in local-cluster mode with unsafe and codegen enabled. + */ + override def beforeAll(): Unit = { + super.beforeAll() + val conf = new SparkConf() + .setMaster("local-cluster[2,1,1024]") + .setAppName("testing") + sc = new SparkContext(conf) + sqlContext = new SQLContext(sc) + sqlContext.setConf(SQLConf.UNSAFE_ENABLED, true) + sqlContext.setConf(SQLConf.CODEGEN_ENABLED, true) + } + + override def afterAll(): Unit = { + sc.stop() + sc = null + sqlContext = null + } + + /** + * Test whether the specified broadcast join updates the peak execution memory accumulator. 
+ */ + private def testBroadcastJoin[T: ClassTag](name: String, joinType: String): Unit = { + AccumulatorSuite.verifyPeakExecutionMemorySet(sc, name) { + val df1 = sqlContext.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value") + val df2 = sqlContext.createDataFrame(Seq((1, "1"), (2, "2"))).toDF("key", "value") + // Comparison at the end is for broadcast left semi join + val joinExpression = df1("key") === df2("key") && df1("value") > df2("value") + val df3 = df1.join(broadcast(df2), joinExpression, joinType) + val plan = df3.queryExecution.executedPlan + assert(plan.collect { case p: T => p }.size === 1) + plan.executeCollect() + } + } + + test("unsafe broadcast hash join updates peak execution memory") { + testBroadcastJoin[BroadcastHashJoin]("unsafe broadcast hash join", "inner") + } + + test("unsafe broadcast hash outer join updates peak execution memory") { + testBroadcastJoin[BroadcastHashOuterJoin]("unsafe broadcast hash outer join", "left_outer") + } + + test("unsafe broadcast left semi join updates peak execution memory") { + testBroadcastJoin[BroadcastLeftSemiJoinHash]("unsafe broadcast left semi join", "leftsemi") + } + +} From ffa05c84fe75663fc33f3d954d1cb1e084ab3280 Mon Sep 17 00:00:00 2001 From: MechCoder Date: Fri, 14 Aug 2015 12:46:05 -0700 Subject: [PATCH 0249/1824] [SPARK-9828] [PYSPARK] Mutable values should not be default arguments Author: MechCoder Closes #8110 from MechCoder/spark-9828. --- python/pyspark/ml/evaluation.py | 4 +++- python/pyspark/ml/param/__init__.py | 26 +++++++++++++++++--------- python/pyspark/ml/pipeline.py | 4 ++-- python/pyspark/ml/tuning.py | 8 ++++++-- python/pyspark/rdd.py | 5 ++++- python/pyspark/sql/readwriter.py | 8 ++++++-- python/pyspark/statcounter.py | 4 +++- python/pyspark/streaming/kafka.py | 12 +++++++++--- 8 files changed, 50 insertions(+), 21 deletions(-) diff --git a/python/pyspark/ml/evaluation.py b/python/pyspark/ml/evaluation.py index 2734092575ad..e23ce053baeb 100644 --- a/python/pyspark/ml/evaluation.py +++ b/python/pyspark/ml/evaluation.py @@ -46,7 +46,7 @@ def _evaluate(self, dataset): """ raise NotImplementedError() - def evaluate(self, dataset, params={}): + def evaluate(self, dataset, params=None): """ Evaluates the output with optional parameters. 
@@ -56,6 +56,8 @@ def evaluate(self, dataset, params={}): params :return: metric """ + if params is None: + params = dict() if isinstance(params, dict): if params: return self.copy(params)._evaluate(dataset) diff --git a/python/pyspark/ml/param/__init__.py b/python/pyspark/ml/param/__init__.py index 7845536161e0..eeeac49b2198 100644 --- a/python/pyspark/ml/param/__init__.py +++ b/python/pyspark/ml/param/__init__.py @@ -60,14 +60,16 @@ class Params(Identifiable): __metaclass__ = ABCMeta - #: internal param map for user-supplied values param map - _paramMap = {} + def __init__(self): + super(Params, self).__init__() + #: internal param map for user-supplied values param map + self._paramMap = {} - #: internal param map for default values - _defaultParamMap = {} + #: internal param map for default values + self._defaultParamMap = {} - #: value returned by :py:func:`params` - _params = None + #: value returned by :py:func:`params` + self._params = None @property def params(self): @@ -155,7 +157,7 @@ def getOrDefault(self, param): else: return self._defaultParamMap[param] - def extractParamMap(self, extra={}): + def extractParamMap(self, extra=None): """ Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into @@ -165,12 +167,14 @@ def extractParamMap(self, extra={}): :param extra: extra param values :return: merged param map """ + if extra is None: + extra = dict() paramMap = self._defaultParamMap.copy() paramMap.update(self._paramMap) paramMap.update(extra) return paramMap - def copy(self, extra={}): + def copy(self, extra=None): """ Creates a copy of this instance with the same uid and some extra params. The default implementation creates a @@ -181,6 +185,8 @@ def copy(self, extra={}): :param extra: Extra parameters to copy to the new instance :return: Copy of this instance """ + if extra is None: + extra = dict() that = copy.copy(self) that._paramMap = self.extractParamMap(extra) return that @@ -233,7 +239,7 @@ def _setDefault(self, **kwargs): self._defaultParamMap[getattr(self, param)] = value return self - def _copyValues(self, to, extra={}): + def _copyValues(self, to, extra=None): """ Copies param values from this instance to another instance for params shared by them. @@ -241,6 +247,8 @@ def _copyValues(self, to, extra={}): :param extra: extra params to be copied :return: the target instance with param values copied """ + if extra is None: + extra = dict() paramMap = self.extractParamMap(extra) for p in self.params: if p in paramMap and to.hasParam(p.name): diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py index 9889f56cac9e..13cf2b0f7bbd 100644 --- a/python/pyspark/ml/pipeline.py +++ b/python/pyspark/ml/pipeline.py @@ -141,7 +141,7 @@ class Pipeline(Estimator): @keyword_only def __init__(self, stages=None): """ - __init__(self, stages=[]) + __init__(self, stages=None) """ if stages is None: stages = [] @@ -170,7 +170,7 @@ def getStages(self): @keyword_only def setParams(self, stages=None): """ - setParams(self, stages=[]) + setParams(self, stages=None) Sets params for Pipeline. 
""" if stages is None: diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py index 0bf988fd72f1..dcfee6a3170a 100644 --- a/python/pyspark/ml/tuning.py +++ b/python/pyspark/ml/tuning.py @@ -227,7 +227,9 @@ def _fit(self, dataset): bestModel = est.fit(dataset, epm[bestIndex]) return CrossValidatorModel(bestModel) - def copy(self, extra={}): + def copy(self, extra=None): + if extra is None: + extra = dict() newCV = Params.copy(self, extra) if self.isSet(self.estimator): newCV.setEstimator(self.getEstimator().copy(extra)) @@ -250,7 +252,7 @@ def __init__(self, bestModel): def _transform(self, dataset): return self.bestModel.transform(dataset) - def copy(self, extra={}): + def copy(self, extra=None): """ Creates a copy of this instance with a randomly generated uid and some extra params. This copies the underlying bestModel, @@ -259,6 +261,8 @@ def copy(self, extra={}): :param extra: Extra parameters to copy to the new instance :return: Copy of this instance """ + if extra is None: + extra = dict() return CrossValidatorModel(self.bestModel.copy(extra)) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index fa8e0a0574a6..9ef60a7e2c84 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -700,7 +700,7 @@ def groupBy(self, f, numPartitions=None): return self.map(lambda x: (f(x), x)).groupByKey(numPartitions) @ignore_unicode_prefix - def pipe(self, command, env={}, checkCode=False): + def pipe(self, command, env=None, checkCode=False): """ Return an RDD created by piping elements to a forked external process. @@ -709,6 +709,9 @@ def pipe(self, command, env={}, checkCode=False): :param checkCode: whether or not to check the return value of the shell command. """ + if env is None: + env = dict() + def func(iterator): pipe = Popen( shlex.split(command), env=env, stdin=PIPE, stdout=PIPE) diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index bf6ac084bbbf..78247c8fa737 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -182,7 +182,7 @@ def orc(self, path): @since(1.4) def jdbc(self, url, table, column=None, lowerBound=None, upperBound=None, numPartitions=None, - predicates=None, properties={}): + predicates=None, properties=None): """ Construct a :class:`DataFrame` representing the database table accessible via JDBC URL `url` named `table` and connection `properties`. @@ -208,6 +208,8 @@ def jdbc(self, url, table, column=None, lowerBound=None, upperBound=None, numPar should be included. :return: a DataFrame """ + if properties is None: + properties = dict() jprop = JavaClass("java.util.Properties", self._sqlContext._sc._gateway._gateway_client)() for k in properties: jprop.setProperty(k, properties[k]) @@ -427,7 +429,7 @@ def orc(self, path, mode=None, partitionBy=None): self._jwrite.orc(path) @since(1.4) - def jdbc(self, url, table, mode=None, properties={}): + def jdbc(self, url, table, mode=None, properties=None): """Saves the content of the :class:`DataFrame` to a external database table via JDBC. .. note:: Don't create too many partitions in parallel on a large cluster;\ @@ -445,6 +447,8 @@ def jdbc(self, url, table, mode=None, properties={}): arbitrary string tag/value. Normally at least a "user" and "password" property should be included. 
""" + if properties is None: + properties = dict() jprop = JavaClass("java.util.Properties", self._sqlContext._sc._gateway._gateway_client)() for k in properties: jprop.setProperty(k, properties[k]) diff --git a/python/pyspark/statcounter.py b/python/pyspark/statcounter.py index 944fa414b0c0..0fee3b209682 100644 --- a/python/pyspark/statcounter.py +++ b/python/pyspark/statcounter.py @@ -30,7 +30,9 @@ class StatCounter(object): - def __init__(self, values=[]): + def __init__(self, values=None): + if values is None: + values = list() self.n = 0 # Running count of our values self.mu = 0.0 # Running mean of our values self.m2 = 0.0 # Running variance numerator (sum of (x - mean)^2) diff --git a/python/pyspark/streaming/kafka.py b/python/pyspark/streaming/kafka.py index 33dd596335b4..dc5b7fd878ae 100644 --- a/python/pyspark/streaming/kafka.py +++ b/python/pyspark/streaming/kafka.py @@ -35,7 +35,7 @@ def utf8_decoder(s): class KafkaUtils(object): @staticmethod - def createStream(ssc, zkQuorum, groupId, topics, kafkaParams={}, + def createStream(ssc, zkQuorum, groupId, topics, kafkaParams=None, storageLevel=StorageLevel.MEMORY_AND_DISK_SER_2, keyDecoder=utf8_decoder, valueDecoder=utf8_decoder): """ @@ -52,6 +52,8 @@ def createStream(ssc, zkQuorum, groupId, topics, kafkaParams={}, :param valueDecoder: A function used to decode value (default is utf8_decoder) :return: A DStream object """ + if kafkaParams is None: + kafkaParams = dict() kafkaParams.update({ "zookeeper.connect": zkQuorum, "group.id": groupId, @@ -77,7 +79,7 @@ def createStream(ssc, zkQuorum, groupId, topics, kafkaParams={}, return stream.map(lambda k_v: (keyDecoder(k_v[0]), valueDecoder(k_v[1]))) @staticmethod - def createDirectStream(ssc, topics, kafkaParams, fromOffsets={}, + def createDirectStream(ssc, topics, kafkaParams, fromOffsets=None, keyDecoder=utf8_decoder, valueDecoder=utf8_decoder): """ .. note:: Experimental @@ -105,6 +107,8 @@ def createDirectStream(ssc, topics, kafkaParams, fromOffsets={}, :param valueDecoder: A function used to decode value (default is utf8_decoder). :return: A DStream object """ + if fromOffsets is None: + fromOffsets = dict() if not isinstance(topics, list): raise TypeError("topics should be list") if not isinstance(kafkaParams, dict): @@ -129,7 +133,7 @@ def createDirectStream(ssc, topics, kafkaParams, fromOffsets={}, return KafkaDStream(stream._jdstream, ssc, stream._jrdd_deserializer) @staticmethod - def createRDD(sc, kafkaParams, offsetRanges, leaders={}, + def createRDD(sc, kafkaParams, offsetRanges, leaders=None, keyDecoder=utf8_decoder, valueDecoder=utf8_decoder): """ .. note:: Experimental @@ -145,6 +149,8 @@ def createRDD(sc, kafkaParams, offsetRanges, leaders={}, :param valueDecoder: A function used to decode value (default is utf8_decoder) :return: A RDD object """ + if leaders is None: + leaders = dict() if not isinstance(kafkaParams, dict): raise TypeError("kafkaParams should be dict") if not isinstance(offsetRanges, list): From 33bae585d4cb25aed2ac32e0d1248f78cc65318b Mon Sep 17 00:00:00 2001 From: Carson Wang Date: Fri, 14 Aug 2015 13:38:25 -0700 Subject: [PATCH 0250/1824] [SPARK-9809] Task crashes because the internal accumulators are not properly initialized When a stage failed and another stage was resubmitted with only part of partitions to compute, all the tasks failed with error message: java.util.NoSuchElementException: key not found: peakExecutionMemory. 
This is because the internal accumulators are not properly initialized for this stage while other codes assume the internal accumulators always exist. Author: Carson Wang Closes #8090 from carsonwang/SPARK-9809. --- .../main/scala/org/apache/spark/scheduler/DAGScheduler.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 7ab5ccf50adb..f1c63d08766c 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -790,9 +790,10 @@ class DAGScheduler( } } + // Create internal accumulators if the stage has no accumulators initialized. // Reset internal accumulators only if this stage is not partially submitted // Otherwise, we may override existing accumulator values from some tasks - if (allPartitions == partitionsToCompute) { + if (stage.internalAccumulators.isEmpty || allPartitions == partitionsToCompute) { stage.resetInternalAccumulators() } From 6518ef63037aa56b541927f99ad26744f91098ce Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Fri, 14 Aug 2015 13:42:53 -0700 Subject: [PATCH 0251/1824] [SPARK-9948] Fix flaky AccumulatorSuite - internal accumulators In these tests, we use a custom listener and we assert on fields in the stage / task completion events. However, these events are posted in a separate thread so they're not guaranteed to be posted in time. This commit fixes this flakiness through a job end registration callback. Author: Andrew Or Closes #8176 from andrewor14/fix-accumulator-suite. --- .../org/apache/spark/AccumulatorSuite.scala | 153 +++++++++++------- 1 file changed, 92 insertions(+), 61 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala index 0eb2293a9d06..5b84acf40be4 100644 --- a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala @@ -182,26 +182,30 @@ class AccumulatorSuite extends SparkFunSuite with Matchers with LocalSparkContex sc = new SparkContext("local", "test") sc.addSparkListener(listener) // Have each task add 1 to the internal accumulator - sc.parallelize(1 to 100, numPartitions).mapPartitions { iter => + val rdd = sc.parallelize(1 to 100, numPartitions).mapPartitions { iter => TaskContext.get().internalMetricsToAccumulators(TEST_ACCUMULATOR) += 1 iter - }.count() - val stageInfos = listener.getCompletedStageInfos - val taskInfos = listener.getCompletedTaskInfos - assert(stageInfos.size === 1) - assert(taskInfos.size === numPartitions) - // The accumulator values should be merged in the stage - val stageAccum = findAccumulableInfo(stageInfos.head.accumulables.values, TEST_ACCUMULATOR) - assert(stageAccum.value.toLong === numPartitions) - // The accumulator should be updated locally on each task - val taskAccumValues = taskInfos.map { taskInfo => - val taskAccum = findAccumulableInfo(taskInfo.accumulables, TEST_ACCUMULATOR) - assert(taskAccum.update.isDefined) - assert(taskAccum.update.get.toLong === 1) - taskAccum.value.toLong } - // Each task should keep track of the partial value on the way, i.e. 1, 2, ... 
numPartitions - assert(taskAccumValues.sorted === (1L to numPartitions).toSeq) + // Register asserts in job completion callback to avoid flakiness + listener.registerJobCompletionCallback { _ => + val stageInfos = listener.getCompletedStageInfos + val taskInfos = listener.getCompletedTaskInfos + assert(stageInfos.size === 1) + assert(taskInfos.size === numPartitions) + // The accumulator values should be merged in the stage + val stageAccum = findAccumulableInfo(stageInfos.head.accumulables.values, TEST_ACCUMULATOR) + assert(stageAccum.value.toLong === numPartitions) + // The accumulator should be updated locally on each task + val taskAccumValues = taskInfos.map { taskInfo => + val taskAccum = findAccumulableInfo(taskInfo.accumulables, TEST_ACCUMULATOR) + assert(taskAccum.update.isDefined) + assert(taskAccum.update.get.toLong === 1) + taskAccum.value.toLong + } + // Each task should keep track of the partial value on the way, i.e. 1, 2, ... numPartitions + assert(taskAccumValues.sorted === (1L to numPartitions).toSeq) + } + rdd.count() } test("internal accumulators in multiple stages") { @@ -211,7 +215,7 @@ class AccumulatorSuite extends SparkFunSuite with Matchers with LocalSparkContex sc.addSparkListener(listener) // Each stage creates its own set of internal accumulators so the // values for the same metric should not be mixed up across stages - sc.parallelize(1 to 100, numPartitions) + val rdd = sc.parallelize(1 to 100, numPartitions) .map { i => (i, i) } .mapPartitions { iter => TaskContext.get().internalMetricsToAccumulators(TEST_ACCUMULATOR) += 1 @@ -227,16 +231,20 @@ class AccumulatorSuite extends SparkFunSuite with Matchers with LocalSparkContex TaskContext.get().internalMetricsToAccumulators(TEST_ACCUMULATOR) += 100 iter } - .count() - // We ran 3 stages, and the accumulator values should be distinct - val stageInfos = listener.getCompletedStageInfos - assert(stageInfos.size === 3) - val firstStageAccum = findAccumulableInfo(stageInfos(0).accumulables.values, TEST_ACCUMULATOR) - val secondStageAccum = findAccumulableInfo(stageInfos(1).accumulables.values, TEST_ACCUMULATOR) - val thirdStageAccum = findAccumulableInfo(stageInfos(2).accumulables.values, TEST_ACCUMULATOR) - assert(firstStageAccum.value.toLong === numPartitions) - assert(secondStageAccum.value.toLong === numPartitions * 10) - assert(thirdStageAccum.value.toLong === numPartitions * 2 * 100) + // Register asserts in job completion callback to avoid flakiness + listener.registerJobCompletionCallback { _ => + // We ran 3 stages, and the accumulator values should be distinct + val stageInfos = listener.getCompletedStageInfos + assert(stageInfos.size === 3) + val (firstStageAccum, secondStageAccum, thirdStageAccum) = + (findAccumulableInfo(stageInfos(0).accumulables.values, TEST_ACCUMULATOR), + findAccumulableInfo(stageInfos(1).accumulables.values, TEST_ACCUMULATOR), + findAccumulableInfo(stageInfos(2).accumulables.values, TEST_ACCUMULATOR)) + assert(firstStageAccum.value.toLong === numPartitions) + assert(secondStageAccum.value.toLong === numPartitions * 10) + assert(thirdStageAccum.value.toLong === numPartitions * 2 * 100) + } + rdd.count() } test("internal accumulators in fully resubmitted stages") { @@ -268,7 +276,7 @@ class AccumulatorSuite extends SparkFunSuite with Matchers with LocalSparkContex // This says use 1 core and retry tasks up to 2 times sc = new SparkContext("local[1, 2]", "test") sc.addSparkListener(listener) - sc.parallelize(1 to 100, numPartitions).mapPartitionsWithIndex { case (i, iter) => + val rdd 
= sc.parallelize(1 to 100, numPartitions).mapPartitionsWithIndex { case (i, iter) => val taskContext = TaskContext.get() taskContext.internalMetricsToAccumulators(TEST_ACCUMULATOR) += 1 // Fail the first attempts of a subset of the tasks @@ -276,28 +284,32 @@ class AccumulatorSuite extends SparkFunSuite with Matchers with LocalSparkContex throw new Exception("Failing a task intentionally.") } iter - }.count() - val stageInfos = listener.getCompletedStageInfos - val taskInfos = listener.getCompletedTaskInfos - assert(stageInfos.size === 1) - assert(taskInfos.size === numPartitions + numFailedPartitions) - val stageAccum = findAccumulableInfo(stageInfos.head.accumulables.values, TEST_ACCUMULATOR) - // We should not double count values in the merged accumulator - assert(stageAccum.value.toLong === numPartitions) - val taskAccumValues = taskInfos.flatMap { taskInfo => - if (!taskInfo.failed) { - // If a task succeeded, its update value should always be 1 - val taskAccum = findAccumulableInfo(taskInfo.accumulables, TEST_ACCUMULATOR) - assert(taskAccum.update.isDefined) - assert(taskAccum.update.get.toLong === 1) - Some(taskAccum.value.toLong) - } else { - // If a task failed, we should not get its accumulator values - assert(taskInfo.accumulables.isEmpty) - None + } + // Register asserts in job completion callback to avoid flakiness + listener.registerJobCompletionCallback { _ => + val stageInfos = listener.getCompletedStageInfos + val taskInfos = listener.getCompletedTaskInfos + assert(stageInfos.size === 1) + assert(taskInfos.size === numPartitions + numFailedPartitions) + val stageAccum = findAccumulableInfo(stageInfos.head.accumulables.values, TEST_ACCUMULATOR) + // We should not double count values in the merged accumulator + assert(stageAccum.value.toLong === numPartitions) + val taskAccumValues = taskInfos.flatMap { taskInfo => + if (!taskInfo.failed) { + // If a task succeeded, its update value should always be 1 + val taskAccum = findAccumulableInfo(taskInfo.accumulables, TEST_ACCUMULATOR) + assert(taskAccum.update.isDefined) + assert(taskAccum.update.get.toLong === 1) + Some(taskAccum.value.toLong) + } else { + // If a task failed, we should not get its accumulator values + assert(taskInfo.accumulables.isEmpty) + None + } } + assert(taskAccumValues.sorted === (1L to numPartitions).toSeq) } - assert(taskAccumValues.sorted === (1L to numPartitions).toSeq) + rdd.count() } } @@ -313,20 +325,27 @@ private[spark] object AccumulatorSuite { testName: String)(testBody: => Unit): Unit = { val listener = new SaveInfoListener sc.addSparkListener(listener) - // Verify that the accumulator does not already exist + // Register asserts in job completion callback to avoid flakiness + listener.registerJobCompletionCallback { jobId => + if (jobId == 0) { + // The first job is a dummy one to verify that the accumulator does not already exist + val accums = listener.getCompletedStageInfos.flatMap(_.accumulables.values) + assert(!accums.exists(_.name == InternalAccumulator.PEAK_EXECUTION_MEMORY)) + } else { + // In the subsequent jobs, verify that peak execution memory is updated + val accum = listener.getCompletedStageInfos + .flatMap(_.accumulables.values) + .find(_.name == InternalAccumulator.PEAK_EXECUTION_MEMORY) + .getOrElse { + throw new TestFailedException( + s"peak execution memory accumulator not set in '$testName'", 0) + } + assert(accum.value.toLong > 0) + } + } + // Run the jobs sc.parallelize(1 to 10).count() - val accums = listener.getCompletedStageInfos.flatMap(_.accumulables.values) - 
assert(!accums.exists(_.name == InternalAccumulator.PEAK_EXECUTION_MEMORY)) testBody - // Verify that peak execution memory is updated - val accum = listener.getCompletedStageInfos - .flatMap(_.accumulables.values) - .find(_.name == InternalAccumulator.PEAK_EXECUTION_MEMORY) - .getOrElse { - throw new TestFailedException( - s"peak execution memory accumulator not set in '$testName'", 0) - } - assert(accum.value.toLong > 0) } } @@ -336,10 +355,22 @@ private[spark] object AccumulatorSuite { private class SaveInfoListener extends SparkListener { private val completedStageInfos: ArrayBuffer[StageInfo] = new ArrayBuffer[StageInfo] private val completedTaskInfos: ArrayBuffer[TaskInfo] = new ArrayBuffer[TaskInfo] + private var jobCompletionCallback: (Int => Unit) = null // parameter is job ID def getCompletedStageInfos: Seq[StageInfo] = completedStageInfos.toArray.toSeq def getCompletedTaskInfos: Seq[TaskInfo] = completedTaskInfos.toArray.toSeq + /** Register a callback to be called on job end. */ + def registerJobCompletionCallback(callback: (Int => Unit)): Unit = { + jobCompletionCallback = callback + } + + override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = { + if (jobCompletionCallback != null) { + jobCompletionCallback(jobEnd.jobId) + } + } + override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = { completedStageInfos += stageCompleted.stageInfo } From 9407baa2a7c26f527f2d043715d313d75bd765bb Mon Sep 17 00:00:00 2001 From: jerryshao Date: Fri, 14 Aug 2015 13:44:38 -0700 Subject: [PATCH 0252/1824] [SPARK-9877] [CORE] Fix StandaloneRestServer NPE when submitting application Detailed exception log can be seen in [SPARK-9877](https://issues.apache.org/jira/browse/SPARK-9877), the problem is when creating `StandaloneRestServer`, `self` (`masterEndpoint`) is null. So this fix is creating `StandaloneRestServer` when `self` is available. Author: jerryshao Closes #8127 from jerryshao/SPARK-9877. 
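The underlying problem is initialization order: the rest server was built in a field initializer, which runs during construction, before the RPC framework has assigned the endpoint's `self` reference; moving the construction into `onStart()` defers it until `self` exists. A minimal, self-contained Scala sketch of that pattern (illustrative names only, not Spark's actual RpcEndpoint or StandaloneRestServer APIs):

    // `self` stands in for the endpoint reference that the framework assigns only
    // after construction; `onStart()` stands in for the post-registration hook.
    abstract class Endpoint {
      @volatile var self: AnyRef = null   // assigned by the framework, not the constructor
      def onStart(): Unit = {}            // invoked once `self` is available
    }

    class RestServer(owner: AnyRef) {
      require(owner != null, "owner endpoint reference must not be null")
    }

    class BadMaster extends Endpoint {
      // Evaluated while the object is being constructed, so `self` is still null here.
      val restServer = new RestServer(self)   // fails, mirroring the reported NPE
    }

    class GoodMaster extends Endpoint {
      // Deferred until onStart(), mirroring the fix in the diff below.
      private var restServer: Option[RestServer] = None
      override def onStart(): Unit = {
        restServer = Some(new RestServer(self))
      }
    }

The same reasoning explains why restServerBoundPort also moves into onStart() in the diff: it depends on a server that can only be created once `self` is available.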
--- .../org/apache/spark/deploy/master/Master.scala | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index 9217202b69a6..26904d39a9be 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -127,14 +127,8 @@ private[deploy] class Master( // Alternative application submission gateway that is stable across Spark versions private val restServerEnabled = conf.getBoolean("spark.master.rest.enabled", true) - private val restServer = - if (restServerEnabled) { - val port = conf.getInt("spark.master.rest.port", 6066) - Some(new StandaloneRestServer(address.host, port, conf, self, masterUrl)) - } else { - None - } - private val restServerBoundPort = restServer.map(_.start()) + private var restServer: Option[StandaloneRestServer] = None + private var restServerBoundPort: Option[Int] = None override def onStart(): Unit = { logInfo("Starting Spark master at " + masterUrl) @@ -148,6 +142,12 @@ private[deploy] class Master( } }, 0, WORKER_TIMEOUT_MS, TimeUnit.MILLISECONDS) + if (restServerEnabled) { + val port = conf.getInt("spark.master.rest.port", 6066) + restServer = Some(new StandaloneRestServer(address.host, port, conf, self, masterUrl)) + } + restServerBoundPort = restServer.map(_.start()) + masterMetricsSystem.registerSource(masterSource) masterMetricsSystem.start() applicationMetricsSystem.start() From 11ed2b180ec86523a94679a8b8132fadb911ccd5 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 14 Aug 2015 13:55:29 -0700 Subject: [PATCH 0253/1824] [SPARK-9978] [PYSPARK] [SQL] fix Window.orderBy and doc of ntile() Author: Davies Liu Closes #8213 from davies/fix_window. --- python/pyspark/sql/functions.py | 7 ++++--- python/pyspark/sql/tests.py | 23 +++++++++++++++++++++++ python/pyspark/sql/window.py | 2 +- 3 files changed, 28 insertions(+), 4 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index e98979533f90..41dfee9f54f7 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -530,9 +530,10 @@ def lead(col, count=1, default=None): @since(1.4) def ntile(n): """ - Window function: returns a group id from 1 to `n` (inclusive) in a round-robin fashion in - a window partition. Fow example, if `n` is 3, the first row will get 1, the second row will - get 2, the third row will get 3, and the fourth row will get 1... + Window function: returns the ntile group id (from 1 to `n` inclusive) + in an ordered window partition. Fow example, if `n` is 4, the first + quarter of the rows will get value 1, the second quarter will get 2, + the third quarter will get 3, and the last quarter will get 4. This is equivalent to the NTILE function in SQL. 
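The Scala API has the same building blocks, so the corrected ntile behaviour can be sketched there as well; everything below is illustrative and assumes an existing `sc` plus a `sqlContext` that supports window functions (a HiveContext in 1.5):

    import org.apache.spark.sql.expressions.Window
    import org.apache.spark.sql.functions.ntile

    import sqlContext.implicits._
    val df = sc.parallelize(1 to 8).map(i => (i, s"v$i")).toDF("key", "value")

    // Order the whole data set without partitionBy -- the case the Python-side
    // Window.orderBy fix and the new test below exercise.
    val w = Window.orderBy("key")

    // ntile(4) over 8 ordered rows: keys 1-2 get 1, 3-4 get 2, 5-6 get 3, 7-8 get 4,
    // matching the reworded docstring above.
    df.select(df("key"), ntile(4).over(w).as("quartile")).show()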
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 38c83c427a74..9b748101b5e5 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -1124,5 +1124,28 @@ def test_window_functions(self): for r, ex in zip(rs, expected): self.assertEqual(tuple(r), ex[:len(r)]) + def test_window_functions_without_partitionBy(self): + df = self.sqlCtx.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"]) + w = Window.orderBy("key", df.value) + from pyspark.sql import functions as F + sel = df.select(df.value, df.key, + F.max("key").over(w.rowsBetween(0, 1)), + F.min("key").over(w.rowsBetween(0, 1)), + F.count("key").over(w.rowsBetween(float('-inf'), float('inf'))), + F.rowNumber().over(w), + F.rank().over(w), + F.denseRank().over(w), + F.ntile(2).over(w)) + rs = sorted(sel.collect()) + expected = [ + ("1", 1, 1, 1, 4, 1, 1, 1, 1), + ("2", 1, 1, 1, 4, 2, 2, 2, 1), + ("2", 1, 2, 1, 4, 3, 2, 2, 2), + ("2", 2, 2, 2, 4, 4, 4, 3, 2) + ] + for r, ex in zip(rs, expected): + self.assertEqual(tuple(r), ex[:len(r)]) + + if __name__ == "__main__": unittest.main() diff --git a/python/pyspark/sql/window.py b/python/pyspark/sql/window.py index c74745c726a0..eaf4d7e98620 100644 --- a/python/pyspark/sql/window.py +++ b/python/pyspark/sql/window.py @@ -64,7 +64,7 @@ def orderBy(*cols): Creates a :class:`WindowSpec` with the partitioning defined. """ sc = SparkContext._active_spark_context - jspec = sc._jvm.org.apache.spark.sql.expressions.Window.partitionBy(_to_java_cols(cols)) + jspec = sc._jvm.org.apache.spark.sql.expressions.Window.orderBy(_to_java_cols(cols)) return WindowSpec(jspec) From 2a6590e510aba3bfc6603d280023128b3f5ac702 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Fri, 14 Aug 2015 14:05:03 -0700 Subject: [PATCH 0254/1824] [SPARK-9981] [ML] Made labels public for StringIndexerModel Also added unit test for integration between StringIndexerModel and IndexToString CC: holdenk We realized we should have left in your unit test (to catch the issue with removing the inverse() method), so this adds it back. mengxr Author: Joseph K. Bradley Closes #8211 from jkbradley/stridx-labels. --- .../spark/ml/feature/StringIndexer.scala | 5 ++++- .../spark/ml/feature/StringIndexerSuite.scala | 18 ++++++++++++++++++ 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index 63475780a6ff..24250e4c4cf9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -97,14 +97,17 @@ class StringIndexer(override val uid: String) extends Estimator[StringIndexerMod /** * :: Experimental :: * Model fitted by [[StringIndexer]]. + * * NOTE: During transformation, if the input column does not exist, * [[StringIndexerModel.transform]] would return the input dataset unmodified. * This is a temporary fix for the case when target labels do not exist during prediction. 
+ * + * @param labels Ordered list of labels, corresponding to indices to be assigned */ @Experimental class StringIndexerModel ( override val uid: String, - labels: Array[String]) extends Model[StringIndexerModel] with StringIndexerBase { + val labels: Array[String]) extends Model[StringIndexerModel] with StringIndexerBase { def this(labels: Array[String]) = this(Identifiable.randomUID("strIdx"), labels) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala index 0b4c8ba71ee6..05e05bdc64bb 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala @@ -147,4 +147,22 @@ class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext { assert(actual === expected) } } + + test("StringIndexer, IndexToString are inverses") { + val data = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")), 2) + val df = sqlContext.createDataFrame(data).toDF("id", "label") + val indexer = new StringIndexer() + .setInputCol("label") + .setOutputCol("labelIndex") + .fit(df) + val transformed = indexer.transform(df) + val idx2str = new IndexToString() + .setInputCol("labelIndex") + .setOutputCol("sameLabel") + .setLabels(indexer.labels) + idx2str.transform(transformed).select("label", "sameLabel").collect().foreach { + case Row(a: String, b: String) => + assert(a === b) + } + } } From 1150a19b188a075166899fdb1e107b2ba1e505d8 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 14 Aug 2015 14:09:46 -0700 Subject: [PATCH 0255/1824] [SPARK-8670] [SQL] Nested columns can't be referenced in pyspark This bug is caused by a wrong column-exist-check in `__getitem__` of pyspark dataframe. `DataFrame.apply` accepts not only top level column names, but also nested column name like `a.b`, so we should remove that check from `__getitem__`. Author: Wenchen Fan Closes #8202 from cloud-fan/nested. 
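The Scala side documents the same behaviour on `DataFrame.apply` (see the doc-only change in the diff below). A quick illustrative sketch, with made-up case classes and an assumed `sc`/`sqlContext`:

    // Illustrative schema; the point is only that "r.a" names a field inside a struct.
    case class Inner(a: Int, b: String)
    case class Record(id: Int, r: Inner)

    import sqlContext.implicits._
    val df = sc.parallelize(Seq(Record(0, Inner(1, "b")))).toDF()

    // Both top-level and nested names resolve through apply(); the PySpark fix simply
    // stops __getitem__ from rejecting "r.a" before it ever reaches this resolution.
    df.select(df("id"), df("r.a"), df("r.b")).show()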
--- python/pyspark/sql/dataframe.py | 2 -- python/pyspark/sql/tests.py | 4 +++- sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala | 2 ++ 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 09647ff6d074..da742d7ce7d1 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -722,8 +722,6 @@ def __getitem__(self, item): [Row(age=5, name=u'Bob')] """ if isinstance(item, basestring): - if item not in self.columns: - raise IndexError("no such column: %s" % item) jc = self._jdf.apply(item) return Column(jc) elif isinstance(item, Column): diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 9b748101b5e5..13cf647b66da 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -770,7 +770,7 @@ def test_access_column(self): self.assertTrue(isinstance(df['key'], Column)) self.assertTrue(isinstance(df[0], Column)) self.assertRaises(IndexError, lambda: df[2]) - self.assertRaises(IndexError, lambda: df["bad_key"]) + self.assertRaises(AnalysisException, lambda: df["bad_key"]) self.assertRaises(TypeError, lambda: df[{}]) def test_column_name_with_non_ascii(self): @@ -794,7 +794,9 @@ def test_field_accessor(self): df = self.sc.parallelize([Row(l=[1], r=Row(a=1, b="b"), d={"k": "v"})]).toDF() self.assertEqual(1, df.select(df.l[0]).first()[0]) self.assertEqual(1, df.select(df.r["a"]).first()[0]) + self.assertEqual(1, df.select(df["r.a"]).first()[0]) self.assertEqual("b", df.select(df.r["b"]).first()[0]) + self.assertEqual("b", df.select(df["r.b"]).first()[0]) self.assertEqual("v", df.select(df.d["k"]).first()[0]) def test_infer_long_type(self): diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index cf75e64e884b..fd0ead440119 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -634,6 +634,7 @@ class DataFrame private[sql]( /** * Selects column based on the column name and return it as a [[Column]]. + * Note that the column name can also reference to a nested column like `a.b`. * @group dfops * @since 1.3.0 */ @@ -641,6 +642,7 @@ class DataFrame private[sql]( /** * Selects column based on the column name and return it as a [[Column]]. + * Note that the column name can also reference to a nested column like `a.b`. * @group dfops * @since 1.3.0 */ From f3bfb711c1742d0915e43bda8230b4d1d22b4190 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Fri, 14 Aug 2015 15:10:01 -0700 Subject: [PATCH 0256/1824] [SPARK-9966] [STREAMING] Handle couple of corner cases in PIDRateEstimator 1. The rate estimator should not estimate any rate when there are no records in the batch, as there is no data to estimate the rate. In the current state, it estimates and set the rate to zero. That is incorrect. 2. The rate estimator should not never set the rate to zero under any circumstances. Otherwise the system will stop receiving data, and stop generating useful estimates (see reason 1). So the fix is to define a parameters that sets a lower bound on the estimated rate, so that the system always receives some data. 
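Put differently: bail out when the batch gives the controller nothing to learn from, and clamp the result to a positive floor rather than zero. A stripped-down sketch of just those two guards (the real estimator, shown in the diff below, also folds in the proportional, integral and derivative terms):

    // Illustrative only; parameter names are ad hoc, not the PIDRateEstimator signature.
    def boundedEstimate(
        numElements: Long,
        processingDelayMs: Long,
        latestRate: Double,
        error: Double,
        minRate: Double): Option[Double] = {
      if (numElements <= 0 || processingDelayMs <= 0) {
        // Corner case 1: an empty (or unprocessed) batch carries no rate information.
        None
      } else {
        // Corner case 2: never fall to zero, or the stream starves and estimation stalls.
        Some(math.max(latestRate - error, minRate))
      }
    }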
Author: Tathagata Das Closes #8199 from tdas/SPARK-9966 and squashes the following commits: 829f793 [Tathagata Das] Fixed unit test and added comments 3a994db [Tathagata Das] Added min rate and updated tests in PIDRateEstimator --- .../scheduler/rate/PIDRateEstimator.scala | 46 ++++++++--- .../scheduler/rate/RateEstimator.scala | 4 +- .../rate/PIDRateEstimatorSuite.scala | 79 ++++++++++++------- 3 files changed, 87 insertions(+), 42 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/PIDRateEstimator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/PIDRateEstimator.scala index 6ae56a68ad88..84a3ca9d74e5 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/PIDRateEstimator.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/PIDRateEstimator.scala @@ -17,6 +17,8 @@ package org.apache.spark.streaming.scheduler.rate +import org.apache.spark.Logging + /** * Implements a proportional-integral-derivative (PID) controller which acts on * the speed of ingestion of elements into Spark Streaming. A PID controller works @@ -26,7 +28,7 @@ package org.apache.spark.streaming.scheduler.rate * * @see https://en.wikipedia.org/wiki/PID_controller * - * @param batchDurationMillis the batch duration, in milliseconds + * @param batchIntervalMillis the batch duration, in milliseconds * @param proportional how much the correction should depend on the current * error. This term usually provides the bulk of correction and should be positive or zero. * A value too large would make the controller overshoot the setpoint, while a small value @@ -39,13 +41,17 @@ package org.apache.spark.streaming.scheduler.rate * of future errors, based on current rate of change. This value should be positive or 0. * This term is not used very often, as it impacts stability of the system. The default * value is 0. + * @param minRate what is the minimum rate that can be estimated. + * This must be greater than zero, so that the system always receives some data for rate + * estimation to work. 
*/ private[streaming] class PIDRateEstimator( batchIntervalMillis: Long, - proportional: Double = 1D, - integral: Double = .2D, - derivative: Double = 0D) - extends RateEstimator { + proportional: Double, + integral: Double, + derivative: Double, + minRate: Double + ) extends RateEstimator with Logging { private var firstRun: Boolean = true private var latestTime: Long = -1L @@ -64,16 +70,23 @@ private[streaming] class PIDRateEstimator( require( derivative >= 0, s"Derivative term $derivative in PIDRateEstimator should be >= 0.") + require( + minRate > 0, + s"Minimum rate in PIDRateEstimator should be > 0") + logInfo(s"Created PIDRateEstimator with proportional = $proportional, integral = $integral, " + + s"derivative = $derivative, min rate = $minRate") - def compute(time: Long, // in milliseconds + def compute( + time: Long, // in milliseconds numElements: Long, processingDelay: Long, // in milliseconds schedulingDelay: Long // in milliseconds ): Option[Double] = { - + logTrace(s"\ntime = $time, # records = $numElements, " + + s"processing time = $processingDelay, scheduling delay = $schedulingDelay") this.synchronized { - if (time > latestTime && processingDelay > 0 && batchIntervalMillis > 0) { + if (time > latestTime && numElements > 0 && processingDelay > 0) { // in seconds, should be close to batchDuration val delaySinceUpdate = (time - latestTime).toDouble / 1000 @@ -104,21 +117,30 @@ private[streaming] class PIDRateEstimator( val newRate = (latestRate - proportional * error - integral * historicalError - - derivative * dError).max(0.0) + derivative * dError).max(minRate) + logTrace(s""" + | latestRate = $latestRate, error = $error + | latestError = $latestError, historicalError = $historicalError + | delaySinceUpdate = $delaySinceUpdate, dError = $dError + """.stripMargin) + latestTime = time if (firstRun) { latestRate = processingRate latestError = 0D firstRun = false - + logTrace("First run, rate estimation skipped") None } else { latestRate = newRate latestError = error - + logTrace(s"New rate = $newRate") Some(newRate) } - } else None + } else { + logTrace("Rate estimation skipped") + None + } } } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/RateEstimator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/RateEstimator.scala index 17ccebc1ed41..d7210f64fcc3 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/RateEstimator.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/RateEstimator.scala @@ -18,7 +18,6 @@ package org.apache.spark.streaming.scheduler.rate import org.apache.spark.SparkConf -import org.apache.spark.SparkException import org.apache.spark.streaming.Duration /** @@ -61,7 +60,8 @@ object RateEstimator { val proportional = conf.getDouble("spark.streaming.backpressure.pid.proportional", 1.0) val integral = conf.getDouble("spark.streaming.backpressure.pid.integral", 0.2) val derived = conf.getDouble("spark.streaming.backpressure.pid.derived", 0.0) - new PIDRateEstimator(batchInterval.milliseconds, proportional, integral, derived) + val minRate = conf.getDouble("spark.streaming.backpressure.pid.minRate", 100) + new PIDRateEstimator(batchInterval.milliseconds, proportional, integral, derived, minRate) case estimator => throw new IllegalArgumentException(s"Unkown rate estimator: $estimator") diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/rate/PIDRateEstimatorSuite.scala 
b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/rate/PIDRateEstimatorSuite.scala index 97c32d8f2d59..a1af95be81c8 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/rate/PIDRateEstimatorSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/rate/PIDRateEstimatorSuite.scala @@ -36,72 +36,89 @@ class PIDRateEstimatorSuite extends SparkFunSuite with Matchers { test("estimator checks ranges") { intercept[IllegalArgumentException] { - new PIDRateEstimator(0, 1, 2, 3) + new PIDRateEstimator(batchIntervalMillis = 0, 1, 2, 3, 10) } intercept[IllegalArgumentException] { - new PIDRateEstimator(100, -1, 2, 3) + new PIDRateEstimator(100, proportional = -1, 2, 3, 10) } intercept[IllegalArgumentException] { - new PIDRateEstimator(100, 0, -1, 3) + new PIDRateEstimator(100, 0, integral = -1, 3, 10) } intercept[IllegalArgumentException] { - new PIDRateEstimator(100, 0, 0, -1) + new PIDRateEstimator(100, 0, 0, derivative = -1, 10) + } + intercept[IllegalArgumentException] { + new PIDRateEstimator(100, 0, 0, 0, minRate = 0) + } + intercept[IllegalArgumentException] { + new PIDRateEstimator(100, 0, 0, 0, minRate = -10) } } - private def createDefaultEstimator: PIDRateEstimator = { - new PIDRateEstimator(20, 1D, 0D, 0D) - } - - test("first bound is None") { - val p = createDefaultEstimator + test("first estimate is None") { + val p = createDefaultEstimator() p.compute(0, 10, 10, 0) should equal(None) } - test("second bound is rate") { - val p = createDefaultEstimator + test("second estimate is not None") { + val p = createDefaultEstimator() p.compute(0, 10, 10, 0) // 1000 elements / s p.compute(10, 10, 10, 0) should equal(Some(1000)) } - test("works even with no time between updates") { - val p = createDefaultEstimator + test("no estimate when no time difference between successive calls") { + val p = createDefaultEstimator() + p.compute(0, 10, 10, 0) + p.compute(time = 10, 10, 10, 0) shouldNot equal(None) + p.compute(time = 10, 10, 10, 0) should equal(None) + } + + test("no estimate when no records in previous batch") { + val p = createDefaultEstimator() p.compute(0, 10, 10, 0) - p.compute(10, 10, 10, 0) - p.compute(10, 10, 10, 0) should equal(None) + p.compute(10, numElements = 0, 10, 0) should equal(None) + p.compute(20, numElements = -10, 10, 0) should equal(None) } - test("bound is never negative") { - val p = new PIDRateEstimator(20, 1D, 1D, 0D) + test("no estimate when there is no processing delay") { + val p = createDefaultEstimator() + p.compute(0, 10, 10, 0) + p.compute(10, 10, processingDelay = 0, 0) should equal(None) + p.compute(20, 10, processingDelay = -10, 0) should equal(None) + } + + test("estimate is never less than min rate") { + val minRate = 5D + val p = new PIDRateEstimator(20, 1D, 1D, 0D, minRate) // prepare a series of batch updates, one every 20ms, 0 processed elements, 2ms of processing // this might point the estimator to try and decrease the bound, but we test it never - // goes below zero, which would be nonsensical. + // goes below the min rate, which would be nonsensical. 
val times = List.tabulate(50)(x => x * 20) // every 20ms - val elements = List.fill(50)(0) // no processing + val elements = List.fill(50)(1) // no processing val proc = List.fill(50)(20) // 20ms of processing val sched = List.fill(50)(100) // strictly positive accumulation val res = for (i <- List.range(0, 50)) yield p.compute(times(i), elements(i), proc(i), sched(i)) res.head should equal(None) - res.tail should equal(List.fill(49)(Some(0D))) + res.tail should equal(List.fill(49)(Some(minRate))) } test("with no accumulated or positive error, |I| > 0, follow the processing speed") { - val p = new PIDRateEstimator(20, 1D, 1D, 0D) + val p = new PIDRateEstimator(20, 1D, 1D, 0D, 10) // prepare a series of batch updates, one every 20ms with an increasing number of processed // elements in each batch, but constant processing time, and no accumulated error. Even though // the integral part is non-zero, the estimated rate should follow only the proportional term val times = List.tabulate(50)(x => x * 20) // every 20ms - val elements = List.tabulate(50)(x => x * 20) // increasing + val elements = List.tabulate(50)(x => (x + 1) * 20) // increasing val proc = List.fill(50)(20) // 20ms of processing val sched = List.fill(50)(0) val res = for (i <- List.range(0, 50)) yield p.compute(times(i), elements(i), proc(i), sched(i)) res.head should equal(None) - res.tail should equal(List.tabulate(50)(x => Some(x * 1000D)).tail) + res.tail should equal(List.tabulate(50)(x => Some((x + 1) * 1000D)).tail) } test("with no accumulated but some positive error, |I| > 0, follow the processing speed") { - val p = new PIDRateEstimator(20, 1D, 1D, 0D) + val p = new PIDRateEstimator(20, 1D, 1D, 0D, 10) // prepare a series of batch updates, one every 20ms with an decreasing number of processed // elements in each batch, but constant processing time, and no accumulated error. 
Even though // the integral part is non-zero, the estimated rate should follow only the proportional term, @@ -116,13 +133,14 @@ class PIDRateEstimatorSuite extends SparkFunSuite with Matchers { } test("with some accumulated and some positive error, |I| > 0, stay below the processing speed") { - val p = new PIDRateEstimator(20, 1D, .01D, 0D) + val minRate = 10D + val p = new PIDRateEstimator(20, 1D, .01D, 0D, minRate) val times = List.tabulate(50)(x => x * 20) // every 20ms val rng = new Random() - val elements = List.tabulate(50)(x => rng.nextInt(1000)) + val elements = List.tabulate(50)(x => rng.nextInt(1000) + 1000) val procDelayMs = 20 val proc = List.fill(50)(procDelayMs) // 20ms of processing - val sched = List.tabulate(50)(x => rng.nextInt(19)) // random wait + val sched = List.tabulate(50)(x => rng.nextInt(19) + 1) // random wait val speeds = elements map ((x) => x.toDouble / procDelayMs * 1000) val res = for (i <- List.range(0, 50)) yield p.compute(times(i), elements(i), proc(i), sched(i)) @@ -131,7 +149,12 @@ class PIDRateEstimatorSuite extends SparkFunSuite with Matchers { res(n) should not be None if (res(n).get > 0 && sched(n) > 0) { res(n).get should be < speeds(n) + res(n).get should be >= minRate } } } + + private def createDefaultEstimator(): PIDRateEstimator = { + new PIDRateEstimator(20, 1D, 0D, 0D, 10) + } } From 18a761ef7a01a4dfa1dd91abe78cd68f2f8fdb67 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Fri, 14 Aug 2015 15:54:14 -0700 Subject: [PATCH 0257/1824] [SPARK-9968] [STREAMING] Reduced time spent within synchronized block to prevent lock starvation When the rate limiter is actually limiting the rate at which data is inserted into the buffer, the synchronized block of BlockGenerator.addData stays blocked for long time. This causes the thread switching the buffer and generating blocks (synchronized with addData) to starve and not generate blocks for seconds. The correct solution is to not block on the rate limiter within the synchronized block for adding data to the buffer. Author: Tathagata Das Closes #8204 from tdas/SPARK-9968 and squashes the following commits: 8cbcc1b [Tathagata Das] Removed unused val a73b645 [Tathagata Das] Reduced time spent within synchronized block --- .../streaming/receiver/BlockGenerator.scala | 40 +++++++++++++++---- 1 file changed, 32 insertions(+), 8 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala index 794dece370b2..300e820d0177 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala @@ -155,10 +155,17 @@ private[streaming] class BlockGenerator( /** * Push a single data item into the buffer. */ - def addData(data: Any): Unit = synchronized { + def addData(data: Any): Unit = { if (state == Active) { waitToPush() - currentBuffer += data + synchronized { + if (state == Active) { + currentBuffer += data + } else { + throw new SparkException( + "Cannot add data as BlockGenerator has not been started or has been stopped") + } + } } else { throw new SparkException( "Cannot add data as BlockGenerator has not been started or has been stopped") @@ -169,11 +176,18 @@ private[streaming] class BlockGenerator( * Push a single data item into the buffer. After buffering the data, the * `BlockGeneratorListener.onAddData` callback will be called. 
*/ - def addDataWithCallback(data: Any, metadata: Any): Unit = synchronized { + def addDataWithCallback(data: Any, metadata: Any): Unit = { if (state == Active) { waitToPush() - currentBuffer += data - listener.onAddData(data, metadata) + synchronized { + if (state == Active) { + currentBuffer += data + listener.onAddData(data, metadata) + } else { + throw new SparkException( + "Cannot add data as BlockGenerator has not been started or has been stopped") + } + } } else { throw new SparkException( "Cannot add data as BlockGenerator has not been started or has been stopped") @@ -185,13 +199,23 @@ private[streaming] class BlockGenerator( * `BlockGeneratorListener.onAddData` callback will be called. Note that all the data items * are atomically added to the buffer, and are hence guaranteed to be present in a single block. */ - def addMultipleDataWithCallback(dataIterator: Iterator[Any], metadata: Any): Unit = synchronized { + def addMultipleDataWithCallback(dataIterator: Iterator[Any], metadata: Any): Unit = { if (state == Active) { + // Unroll iterator into a temp buffer, and wait for pushing in the process + val tempBuffer = new ArrayBuffer[Any] dataIterator.foreach { data => waitToPush() - currentBuffer += data + tempBuffer += data + } + synchronized { + if (state == Active) { + currentBuffer ++= tempBuffer + listener.onAddData(tempBuffer, metadata) + } else { + throw new SparkException( + "Cannot add data as BlockGenerator has not been started or has been stopped") + } } - listener.onAddData(dataIterator, metadata) } else { throw new SparkException( "Cannot add data as BlockGenerator has not been started or has been stopped") From 932b24fd144232fb08184f0bd0a46369ecba164e Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Fri, 14 Aug 2015 17:35:17 -0700 Subject: [PATCH 0258/1824] [SPARK-9949] [SQL] Fix TakeOrderedAndProject's output. https://issues.apache.org/jira/browse/SPARK-9949 Author: Yin Huai Closes #8179 from yhuai/SPARK-9949. 
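The issue: when the planner pushes a projection into the physical TakeOrderedAndProject operator, the operator still reported its child's output attributes rather than the projected ones. A sketch of a query shape that triggers the pushed projection (assumes an existing `sqlContext`, mirroring the planner test added below):

    // Build a small two-column DataFrame, in the same style as the suites above.
    val df = sqlContext.createDataFrame(Seq((1, "a"), (2, "b"), (3, "c"))).toDF("key", "value")

    // sort -> column-reordering select -> limit is planned as TakeOrderedAndProject with
    // a pushed projection; after the fix its output is (value, key), not (key, value).
    val query = df.select("key", "value").sort("key").select("value", "key").limit(2)
    query.explain()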
--- .../spark/sql/execution/basicOperators.scala | 12 ++++++++++- .../spark/sql/execution/PlannerSuite.scala | 20 ++++++++++++++++--- 2 files changed, 28 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 247c900baae9..77b98064a9e1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -237,7 +237,10 @@ case class TakeOrderedAndProject( projectList: Option[Seq[NamedExpression]], child: SparkPlan) extends UnaryNode { - override def output: Seq[Attribute] = child.output + override def output: Seq[Attribute] = { + val projectOutput = projectList.map(_.map(_.toAttribute)) + projectOutput.getOrElse(child.output) + } override def outputPartitioning: Partitioning = SinglePartition @@ -263,6 +266,13 @@ case class TakeOrderedAndProject( protected override def doExecute(): RDD[InternalRow] = sparkContext.makeRDD(collectData(), 1) override def outputOrdering: Seq[SortOrder] = sortOrder + + override def simpleString: String = { + val orderByString = sortOrder.mkString("[", ",", "]") + val outputString = output.mkString("[", ",", "]") + + s"TakeOrderedAndProject(limit=$limit, orderBy=$orderByString, output=$outputString)" + } } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 937a10854353..fad93b014c23 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -162,9 +162,23 @@ class PlannerSuite extends SparkFunSuite with SharedSQLContext { } test("efficient limit -> project -> sort") { - val query = testData.sort('key).select('value).limit(2).logicalPlan - val planned = ctx.planner.TakeOrderedAndProject(query) - assert(planned.head.isInstanceOf[execution.TakeOrderedAndProject]) + { + val query = + testData.select('key, 'value).sort('key).limit(2).logicalPlan + val planned = ctx.planner.TakeOrderedAndProject(query) + assert(planned.head.isInstanceOf[execution.TakeOrderedAndProject]) + assert(planned.head.output === testData.select('key, 'value).logicalPlan.output) + } + + { + // We need to make sure TakeOrderedAndProject's output is correct when we push a project + // into it. + val query = + testData.select('key, 'value).sort('key).select('value, 'key).limit(2).logicalPlan + val planned = ctx.planner.TakeOrderedAndProject(query) + assert(planned.head.isInstanceOf[execution.TakeOrderedAndProject]) + assert(planned.head.output === testData.select('value, 'key).logicalPlan.output) + } } test("PartitioningCollection") { From e5fd60415fbfea2c5c02602f7ddbc999dd058064 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 14 Aug 2015 20:55:32 -0700 Subject: [PATCH 0259/1824] [SPARK-9934] Deprecate NIO ConnectionManager. Deprecate NIO ConnectionManager in Spark 1.5.0, before removing it in Spark 1.6.0. Author: Reynold Xin Closes #8162 from rxin/SPARK-9934. 
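For anyone still selecting the transport explicitly, the relevant knob is the block transfer service setting that the SparkEnv match below reads (spark.shuffle.blockTransferService in the 1.x line); a small hedged example of pinning the non-deprecated default:

    import org.apache.spark.{SparkConf, SparkContext}

    // "netty" has been the default since 1.2; "nio" still works in 1.5 but now logs
    // the deprecation warning added below and is slated for removal in 1.6.0.
    val conf = new SparkConf()
      .setAppName("block-transfer-example")
      .setMaster("local[2]")
      .set("spark.shuffle.blockTransferService", "netty")
    val sc = new SparkContext(conf)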
--- core/src/main/scala/org/apache/spark/SparkEnv.scala | 2 ++ docs/configuration.md | 3 ++- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index a796e7285019..0f1e2e069568 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -331,6 +331,8 @@ object SparkEnv extends Logging { case "netty" => new NettyBlockTransferService(conf, securityManager, numUsableCores) case "nio" => + logWarning("NIO-based block transfer service is deprecated, " + + "and will be removed in Spark 1.6.0.") new NioBlockTransferService(conf, securityManager) } diff --git a/docs/configuration.md b/docs/configuration.md index c60dd16839c0..32147098ae64 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -389,7 +389,8 @@ Apart from these, the following properties are also available, and may be useful Implementation to use for transferring shuffle and cached blocks between executors. There are two implementations available: netty and nio. Netty-based block transfer is intended to be simpler but equally efficient and is the default option - starting in 1.2. + starting in 1.2, and nio block transfer is deprecated in Spark 1.5.0 and will + be removed in Spark 1.6.0. From 37586e5449ff8f892d41f0b6b8fa1de83dd3849e Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 14 Aug 2015 20:56:55 -0700 Subject: [PATCH 0260/1824] [HOTFIX] fix duplicated braces Author: Davies Liu Closes #8219 from davies/fix_typo. --- .../main/scala/org/apache/spark/storage/BlockManager.scala | 2 +- .../scala/org/apache/spark/storage/BlockManagerMaster.scala | 6 +++--- .../main/scala/org/apache/spark/util/ClosureCleaner.scala | 2 +- core/src/main/scala/org/apache/spark/util/Utils.scala | 2 +- .../scala/org/apache/spark/examples/ml/MovieLensALS.scala | 2 +- .../apache/spark/examples/mllib/DecisionTreeRunner.scala | 2 +- .../org/apache/spark/examples/mllib/MovieLensALS.scala | 2 +- .../src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala | 2 +- .../spark/sql/catalyst/analysis/HiveTypeCoercion.scala | 2 +- .../spark/sql/execution/datasources/jdbc/JdbcUtils.scala | 2 +- .../org/apache/spark/sql/execution/ui/SQLListener.scala | 2 +- .../spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala | 2 +- .../apache/spark/streaming/scheduler/InputInfoTracker.scala | 2 +- 13 files changed, 15 insertions(+), 15 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 86493673d958..eedb27942e84 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -222,7 +222,7 @@ private[spark] class BlockManager( return } catch { case e: Exception if i < MAX_ATTEMPTS => - logError(s"Failed to connect to external shuffle server, will retry ${MAX_ATTEMPTS - i}}" + logError(s"Failed to connect to external shuffle server, will retry ${MAX_ATTEMPTS - i}" + s" more times after waiting $SLEEP_TIME_SECS seconds...", e) Thread.sleep(SLEEP_TIME_SECS * 1000) } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala index f70f701494db..2a11f371b9d6 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala +++ 
b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala @@ -103,7 +103,7 @@ class BlockManagerMaster( val future = driverEndpoint.askWithRetry[Future[Seq[Int]]](RemoveRdd(rddId)) future.onFailure { case e: Exception => - logWarning(s"Failed to remove RDD $rddId - ${e.getMessage}}", e) + logWarning(s"Failed to remove RDD $rddId - ${e.getMessage}", e) }(ThreadUtils.sameThread) if (blocking) { timeout.awaitResult(future) @@ -115,7 +115,7 @@ class BlockManagerMaster( val future = driverEndpoint.askWithRetry[Future[Seq[Boolean]]](RemoveShuffle(shuffleId)) future.onFailure { case e: Exception => - logWarning(s"Failed to remove shuffle $shuffleId - ${e.getMessage}}", e) + logWarning(s"Failed to remove shuffle $shuffleId - ${e.getMessage}", e) }(ThreadUtils.sameThread) if (blocking) { timeout.awaitResult(future) @@ -129,7 +129,7 @@ class BlockManagerMaster( future.onFailure { case e: Exception => logWarning(s"Failed to remove broadcast $broadcastId" + - s" with removeFromMaster = $removeFromMaster - ${e.getMessage}}", e) + s" with removeFromMaster = $removeFromMaster - ${e.getMessage}", e) }(ThreadUtils.sameThread) if (blocking) { timeout.awaitResult(future) diff --git a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala index ebead830c646..150d82b3930e 100644 --- a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala +++ b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala @@ -181,7 +181,7 @@ private[spark] object ClosureCleaner extends Logging { return } - logDebug(s"+++ Cleaning closure $func (${func.getClass.getName}}) +++") + logDebug(s"+++ Cleaning closure $func (${func.getClass.getName}) +++") // A list of classes that represents closures enclosed in the given one val innerClasses = getInnerClosureClasses(func) diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index f2abf227dc12..fddc24dbfc23 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -1366,7 +1366,7 @@ private[spark] object Utils extends Logging { file.getAbsolutePath, effectiveStartIndex, effectiveEndIndex)) } sum += fileToLength(file) - logDebug(s"After processing file $file, string built is ${stringBuffer.toString}}") + logDebug(s"After processing file $file, string built is ${stringBuffer.toString}") } stringBuffer.toString } diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala b/examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala index cd411397a4b9..3ae53e57dbdb 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala @@ -76,7 +76,7 @@ object MovieLensALS { .text("path to a MovieLens dataset of movies") .action((x, c) => c.copy(movies = x)) opt[Int]("rank") - .text(s"rank, default: ${defaultParams.rank}}") + .text(s"rank, default: ${defaultParams.rank}") .action((x, c) => c.copy(rank = x)) opt[Int]("maxIter") .text(s"max number of iterations, default: ${defaultParams.maxIter}") diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala index 57ffe3dd2524..cc6bce3cb7c9 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala +++ 
b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala @@ -100,7 +100,7 @@ object DecisionTreeRunner { .action((x, c) => c.copy(numTrees = x)) opt[String]("featureSubsetStrategy") .text(s"feature subset sampling strategy" + - s" (${RandomForest.supportedFeatureSubsetStrategies.mkString(", ")}}), " + + s" (${RandomForest.supportedFeatureSubsetStrategies.mkString(", ")}), " + s"default: ${defaultParams.featureSubsetStrategy}") .action((x, c) => c.copy(featureSubsetStrategy = x)) opt[Double]("fracTest") diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala index e43a6f2864c7..69691ae297f6 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala @@ -55,7 +55,7 @@ object MovieLensALS { val parser = new OptionParser[Params]("MovieLensALS") { head("MovieLensALS: an example app for ALS on MovieLens data.") opt[Int]("rank") - .text(s"rank, default: ${defaultParams.rank}}") + .text(s"rank, default: ${defaultParams.rank}") .action((x, c) => c.copy(rank = x)) opt[Int]("numIterations") .text(s"number of iterations, default: ${defaultParams.numIterations}") diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala index 9029093e0fa0..bbbcc8436b7c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala @@ -469,7 +469,7 @@ private[spark] object BLAS extends Serializable with Logging { require(A.numCols == x.size, s"The columns of A don't match the number of elements of x. A: ${A.numCols}, x: ${x.size}") require(A.numRows == y.size, - s"The rows of A don't match the number of elements of y. A: ${A.numRows}, y:${y.size}}") + s"The rows of A don't match the number of elements of y. A: ${A.numRows}, y:${y.size}") if (alpha == 0.0) { logDebug("gemv: alpha is equal to 0. Returning y.") } else { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 970f3c8282c8..8581d6b496c1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -164,7 +164,7 @@ object HiveTypeCoercion { // Leave the same if the dataTypes match. 
case Some(newType) if a.dataType == newType.dataType => a case Some(newType) => - logDebug(s"Promoting $a to $newType in ${q.simpleString}}") + logDebug(s"Promoting $a to $newType in ${q.simpleString}") newType } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index 039c13bf163c..8ee3b8bda8fc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -170,7 +170,7 @@ object JdbcUtils extends Logging { case BinaryType => "BLOB" case TimestampType => "TIMESTAMP" case DateType => "DATE" - case t: DecimalType => s"DECIMAL(${t.precision}},${t.scale}})" + case t: DecimalType => s"DECIMAL(${t.precision},${t.scale})" case _ => throw new IllegalArgumentException(s"Don't know how to save $field to JDBC") }) val nullable = if (field.nullable) "" else "NOT NULL" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala index 0b9bad987c48..5779c71f64e9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala @@ -162,7 +162,7 @@ private[sql] class SQLListener(sqlContext: SQLContext) extends SparkListener wit // A task of an old stage attempt. Because a new stage is submitted, we can ignore it. } else if (stageAttemptID > stageMetrics.stageAttemptId) { logWarning(s"A task should not have a higher stageAttemptID ($stageAttemptID) then " + - s"what we have seen (${stageMetrics.stageAttemptId}})") + s"what we have seen (${stageMetrics.stageAttemptId})") } else { // TODO We don't know the attemptId. Currently, what we can do is overriding the // accumulator updates. 
However, if there are two same task are running, such as diff --git a/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala b/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala index 31ce8e1ec14d..620b8a36a2ba 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala @@ -84,7 +84,7 @@ class WriteAheadLogBackedBlockRDD[T: ClassTag]( require( blockIds.length == walRecordHandles.length, s"Number of block Ids (${blockIds.length}) must be " + - s" same as number of WAL record handles (${walRecordHandles.length}})") + s" same as number of WAL record handles (${walRecordHandles.length})") require( isBlockIdValid.isEmpty || isBlockIdValid.length == blockIds.length, diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/InputInfoTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/InputInfoTracker.scala index 363c03d431f0..deb15d075975 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/InputInfoTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/InputInfoTracker.scala @@ -66,7 +66,7 @@ private[streaming] class InputInfoTracker(ssc: StreamingContext) extends Logging new mutable.HashMap[Int, StreamInputInfo]()) if (inputInfos.contains(inputInfo.inputStreamId)) { - throw new IllegalStateException(s"Input stream ${inputInfo.inputStreamId}} for batch" + + throw new IllegalStateException(s"Input stream ${inputInfo.inputStreamId} for batch" + s"$batchTime is already added into InputInfoTracker, this is a illegal state") } inputInfos += ((inputInfo.inputStreamId, inputInfo)) From ec29f2034a3306cc0afdc4c160b42c2eefa0897c Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 14 Aug 2015 20:59:54 -0700 Subject: [PATCH 0261/1824] [SPARK-9634] [SPARK-9323] [SQL] cleanup unnecessary Aliases in LogicalPlan at the end of analysis Also alias the ExtractValue instead of wrapping it with UnresolvedAlias when resolve attribute in LogicalPlan, as this alias will be trimmed if it's unnecessary. Based on #7957 without the changes to mllib, but instead maintaining earlier behavior when using `withColumn` on expressions that already have metadata. Author: Wenchen Fan Author: Michael Armbrust Closes #8215 from marmbrus/pr/7957. 
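As a rough illustration of the rule this adds, the sketch below shows alias trimming over a toy expression tree. The types (`Expr`, `Attr`, `Add`, `Alias`) are simplified stand-ins rather than the actual Catalyst classes, and the sketch ignores the `CreateStruct` special case that the real rule has to preserve.

```scala
sealed trait Expr
case class Attr(name: String) extends Expr
case class Add(left: Expr, right: Expr) extends Expr
case class Alias(child: Expr, name: String) extends Expr

// Aliases only carry information at the top level of a project/aggregate list;
// anything nested below can be stripped without changing the result.
def trimAliases(e: Expr): Expr = e match {
  case Alias(child, _) => trimAliases(child)
  case Add(l, r)       => Add(trimAliases(l), trimAliases(r))
  case other           => other
}

def trimNonTopLevelAliases(e: Expr): Expr = e match {
  case Alias(child, name) => Alias(trimAliases(child), name)
  case other              => trimAliases(other)
}

// ((a AS "a+1") + b) AS "col"  simplifies to  (a + b) AS "col"
val cleaned = trimNonTopLevelAliases(
  Alias(Add(Alias(Attr("a"), "a+1"), Attr("b")), "col"))
```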
--- .../sql/catalyst/analysis/Analyzer.scala | 75 ++++++++++++++++--- .../expressions/complexTypeCreator.scala | 2 - .../sql/catalyst/optimizer/Optimizer.scala | 9 ++- .../catalyst/plans/logical/LogicalPlan.scala | 8 +- .../plans/logical/basicOperators.scala | 2 +- .../sql/catalyst/analysis/AnalysisSuite.scala | 17 +++++ .../scala/org/apache/spark/sql/Column.scala | 16 +++- .../spark/sql/ColumnExpressionSuite.scala | 9 +++ .../org/apache/spark/sql/DataFrameSuite.scala | 6 ++ 9 files changed, 120 insertions(+), 24 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index a684dbc3afa4..4bc1c1af40bf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -82,7 +82,9 @@ class Analyzer( HiveTypeCoercion.typeCoercionRules ++ extendedResolutionRules : _*), Batch("Nondeterministic", Once, - PullOutNondeterministic) + PullOutNondeterministic), + Batch("Cleanup", fixedPoint, + CleanupAliases) ) /** @@ -146,8 +148,6 @@ class Analyzer( child match { case _: UnresolvedAttribute => u case ne: NamedExpression => ne - case g: GetStructField => Alias(g, g.field.name)() - case g: GetArrayStructFields => Alias(g, g.field.name)() case g: Generator if g.resolved && g.elementTypes.size > 1 => MultiAlias(g, Nil) case e if !e.resolved => u case other => Alias(other, s"_c$i")() @@ -384,9 +384,7 @@ class Analyzer( case u @ UnresolvedAttribute(nameParts) => // Leave unchanged if resolution fails. Hopefully will be resolved next round. val result = - withPosition(u) { - q.resolveChildren(nameParts, resolver).map(trimUnresolvedAlias).getOrElse(u) - } + withPosition(u) { q.resolveChildren(nameParts, resolver).getOrElse(u) } logDebug(s"Resolving $u to $result") result case UnresolvedExtractValue(child, fieldExpr) if child.resolved => @@ -412,11 +410,6 @@ class Analyzer( exprs.exists(_.collect { case _: Star => true }.nonEmpty) } - private def trimUnresolvedAlias(ne: NamedExpression) = ne match { - case UnresolvedAlias(child) => child - case other => other - } - private def resolveSortOrders(ordering: Seq[SortOrder], plan: LogicalPlan, throws: Boolean) = { ordering.map { order => // Resolve SortOrder in one round. @@ -426,7 +419,7 @@ class Analyzer( try { val newOrder = order transformUp { case u @ UnresolvedAttribute(nameParts) => - plan.resolve(nameParts, resolver).map(trimUnresolvedAlias).getOrElse(u) + plan.resolve(nameParts, resolver).getOrElse(u) case UnresolvedExtractValue(child, fieldName) if child.resolved => ExtractValue(child, fieldName, resolver) } @@ -968,3 +961,61 @@ object EliminateSubQueries extends Rule[LogicalPlan] { case Subquery(_, child) => child } } + +/** + * Cleans up unnecessary Aliases inside the plan. Basically we only need Alias as a top level + * expression in Project(project list) or Aggregate(aggregate expressions) or + * Window(window expressions). + */ +object CleanupAliases extends Rule[LogicalPlan] { + private def trimAliases(e: Expression): Expression = { + var stop = false + e.transformDown { + // CreateStruct is a special case, we need to retain its top level Aliases as they decide the + // name of StructField. We also need to stop transform down this expression, or the Aliases + // under CreateStruct will be mistakenly trimmed. 
+ case c: CreateStruct if !stop => + stop = true + c.copy(children = c.children.map(trimNonTopLevelAliases)) + case c: CreateStructUnsafe if !stop => + stop = true + c.copy(children = c.children.map(trimNonTopLevelAliases)) + case Alias(child, _) if !stop => child + } + } + + def trimNonTopLevelAliases(e: Expression): Expression = e match { + case a: Alias => + Alias(trimAliases(a.child), a.name)(a.exprId, a.qualifiers, a.explicitMetadata) + case other => trimAliases(other) + } + + override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + case Project(projectList, child) => + val cleanedProjectList = + projectList.map(trimNonTopLevelAliases(_).asInstanceOf[NamedExpression]) + Project(cleanedProjectList, child) + + case Aggregate(grouping, aggs, child) => + val cleanedAggs = aggs.map(trimNonTopLevelAliases(_).asInstanceOf[NamedExpression]) + Aggregate(grouping.map(trimAliases), cleanedAggs, child) + + case w @ Window(projectList, windowExprs, partitionSpec, orderSpec, child) => + val cleanedWindowExprs = + windowExprs.map(e => trimNonTopLevelAliases(e).asInstanceOf[NamedExpression]) + Window(projectList, cleanedWindowExprs, partitionSpec.map(trimAliases), + orderSpec.map(trimAliases(_).asInstanceOf[SortOrder]), child) + + case other => + var stop = false + other transformExpressionsDown { + case c: CreateStruct if !stop => + stop = true + c.copy(children = c.children.map(trimNonTopLevelAliases)) + case c: CreateStructUnsafe if !stop => + stop = true + c.copy(children = c.children.map(trimNonTopLevelAliases)) + case Alias(child, _) if !stop => child + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 4a071e663e0d..298aee349927 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -75,8 +75,6 @@ case class CreateStruct(children: Seq[Expression]) extends Expression { override def foldable: Boolean = children.forall(_.foldable) - override lazy val resolved: Boolean = childrenResolved - override lazy val dataType: StructType = { val fields = children.zipWithIndex.map { case (child, idx) => child match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 4ab5ac2c61e3..47b06cae1543 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.optimizer import scala.collection.immutable.HashSet -import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries +import org.apache.spark.sql.catalyst.analysis.{CleanupAliases, EliminateSubQueries} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.Inner import org.apache.spark.sql.catalyst.plans.FullOuter @@ -260,8 +260,11 @@ object ProjectCollapsing extends Rule[LogicalPlan] { val substitutedProjection = projectList1.map(_.transform { case a: Attribute => aliasMap.getOrElse(a, a) }).asInstanceOf[Seq[NamedExpression]] - - Project(substitutedProjection, child) + // collapse 2 projects may introduce unnecessary Aliases, trim them here. 
+ val cleanedProjection = substitutedProjection.map(p => + CleanupAliases.trimNonTopLevelAliases(p).asInstanceOf[NamedExpression] + ) + Project(cleanedProjection, child) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index c290e6acb361..9bb466ac2d29 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -259,13 +259,13 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { // One match, but we also need to extract the requested nested field. case Seq((a, nestedFields)) => // The foldLeft adds ExtractValues for every remaining parts of the identifier, - // and wrap it with UnresolvedAlias which will be removed later. + // and aliased it with the last part of the name. // For example, consider "a.b.c", where "a" is resolved to an existing attribute. - // Then this will add ExtractValue("c", ExtractValue("b", a)), and wrap it as - // UnresolvedAlias(ExtractValue("c", ExtractValue("b", a))). + // Then this will add ExtractValue("c", ExtractValue("b", a)), and alias the final + // expression as "c". val fieldExprs = nestedFields.foldLeft(a: Expression)((expr, fieldName) => ExtractValue(expr, Literal(fieldName), resolver)) - Some(UnresolvedAlias(fieldExprs)) + Some(Alias(fieldExprs, nestedFields.last)()) // No matches. case Seq() => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 7c404722d811..73b8261260ac 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -228,7 +228,7 @@ case class Window( child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = - (projectList ++ windowExpressions).map(_.toAttribute) + projectList ++ windowExpressions.map(_.toAttribute) } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index c944bc69e25b..1e0cc81dae97 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -119,4 +119,21 @@ class AnalysisSuite extends AnalysisTest { Project(testRelation.output :+ projected, testRelation))) checkAnalysis(plan, expected) } + + test("SPARK-9634: cleanup unnecessary Aliases in LogicalPlan") { + val a = testRelation.output.head + var plan = testRelation.select(((a + 1).as("a+1") + 2).as("col")) + var expected = testRelation.select((a + 1 + 2).as("col")) + checkAnalysis(plan, expected) + + plan = testRelation.groupBy(a.as("a1").as("a2"))((min(a).as("min_a") + 1).as("col")) + expected = testRelation.groupBy(a)((min(a) + 1).as("col")) + checkAnalysis(plan, expected) + + // CreateStruct is a special case that we should not trim Alias for it. 
+ plan = testRelation.select(CreateStruct(Seq(a, (a + 1).as("a+1"))).as("col")) + checkAnalysis(plan, plan) + plan = testRelation.select(CreateStructUnsafe(Seq(a, (a + 1).as("a+1"))).as("col")) + checkAnalysis(plan, plan) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 27bd08484734..807bc8c30c12 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -753,10 +753,16 @@ class Column(protected[sql] val expr: Expression) extends Logging { * df.select($"colA".as("colB")) * }}} * + * If the current column has metadata associated with it, this metadata will be propagated + * to the new column. If this not desired, use `as` with explicitly empty metadata. + * * @group expr_ops * @since 1.3.0 */ - def as(alias: String): Column = Alias(expr, alias)() + def as(alias: String): Column = expr match { + case ne: NamedExpression => Alias(expr, alias)(explicitMetadata = Some(ne.metadata)) + case other => Alias(other, alias)() + } /** * (Scala-specific) Assigns the given aliases to the results of a table generating function. @@ -789,10 +795,16 @@ class Column(protected[sql] val expr: Expression) extends Logging { * df.select($"colA".as('colB)) * }}} * + * If the current column has metadata associated with it, this metadata will be propagated + * to the new column. If this not desired, use `as` with explicitly empty metadata. + * * @group expr_ops * @since 1.3.0 */ - def as(alias: Symbol): Column = Alias(expr, alias.name)() + def as(alias: Symbol): Column = expr match { + case ne: NamedExpression => Alias(expr, alias.name)(explicitMetadata = Some(ne.metadata)) + case other => Alias(other, alias.name)() + } /** * Gives the column an alias with metadata. 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index ee74e3e83da5..37738ec5b3c1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql +import org.apache.spark.sql.catalyst.expressions.NamedExpression import org.scalatest.Matchers._ import org.apache.spark.sql.execution.{Project, TungstenProject} @@ -110,6 +111,14 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { assert(df.select(df("a").alias("b")).columns.head === "b") } + test("as propagates metadata") { + val metadata = new MetadataBuilder + metadata.putString("key", "value") + val origCol = $"a".as("b", metadata.build()) + val newCol = origCol.as("c") + assert(newCol.expr.asInstanceOf[NamedExpression].metadata.getString("key") === "value") + } + test("single explode") { val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList") checkAnswer( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 10bfa9b64f00..cf22797752b9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -867,4 +867,10 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { val actual = df.sort(rand(seed)).collect().map(_.getInt(0)) assert(expected === actual) } + + test("SPARK-9323: DataFrame.orderBy should support nested column name") { + val df = sqlContext.read.json(sqlContext.sparkContext.makeRDD( + """{"a": {"b": 1}}""" :: Nil)) + checkAnswer(df.orderBy("a.b"), Row(Row(1))) + } } From 6c4fdbec33af287d24cd0995ecbd7191545d05c9 Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Fri, 14 Aug 2015 21:03:14 -0700 Subject: [PATCH 0262/1824] [SPARK-8887] [SQL] Explicit define which data types can be used as dynamic partition columns This PR enforce dynamic partition column data type requirements by adding analysis rules. JIRA: https://issues.apache.org/jira/browse/SPARK-8887 Author: Yijie Shen Closes #8201 from yjshen/dynamic_partition_columns. 
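The user-facing effect, sketched under the assumption of a spark-shell style `sqlContext`; the column names and output paths below are made up for illustration. Partitioning by an atomic column still works, while partitioning by a complex column is now rejected during analysis instead of producing broken output.

```scala
import org.apache.spark.sql.AnalysisException

val df = sqlContext.range(0, 10)
  .selectExpr("id", "cast(id as string) as name", "array(id) as tags")

// Atomic partition column (LongType): accepted as before.
df.write.partitionBy("id").parquet("/tmp/spark-8887-ok")

// Complex partition column (ArrayType): rejected with an AnalysisException.
try {
  df.write.partitionBy("tags").parquet("/tmp/spark-8887-bad")
} catch {
  case e: AnalysisException => println(s"rejected: ${e.getMessage}")
}
```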
--- .../datasources/PartitioningUtils.scala | 13 +++++++++++++ .../datasources/ResolvedDataSource.scala | 5 ++++- .../execution/datasources/WriterContainer.scala | 2 +- .../spark/sql/execution/datasources/rules.scala | 8 ++++++-- .../sql/sources/hadoopFsRelationSuites.scala | 17 +++++++++++++++++ 5 files changed, 41 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala index 66dfcc308cec..0a2007e15843 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala @@ -26,6 +26,7 @@ import scala.util.Try import org.apache.hadoop.fs.Path import org.apache.hadoop.util.Shell +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Cast, Literal} import org.apache.spark.sql.types._ @@ -270,6 +271,18 @@ private[sql] object PartitioningUtils { private val upCastingOrder: Seq[DataType] = Seq(NullType, IntegerType, LongType, FloatType, DoubleType, StringType) + def validatePartitionColumnDataTypes( + schema: StructType, + partitionColumns: Array[String]): Unit = { + + ResolvedDataSource.partitionColumnsSchema(schema, partitionColumns).foreach { field => + field.dataType match { + case _: AtomicType => // OK + case _ => throw new AnalysisException(s"Cannot use ${field.dataType} for partition column") + } + } + } + /** * Given a collection of [[Literal]]s, resolves possible type conflicts by up-casting "lower" * types. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala index 7770bbd712f0..8fbaf3a3059d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala @@ -143,7 +143,7 @@ object ResolvedDataSource extends Logging { new ResolvedDataSource(clazz, relation) } - private def partitionColumnsSchema( + def partitionColumnsSchema( schema: StructType, partitionColumns: Array[String]): StructType = { StructType(partitionColumns.map { col => @@ -179,6 +179,9 @@ object ResolvedDataSource extends Logging { val fs = path.getFileSystem(sqlContext.sparkContext.hadoopConfiguration) path.makeQualified(fs.getUri, fs.getWorkingDirectory) } + + PartitioningUtils.validatePartitionColumnDataTypes(data.schema, partitionColumns) + val dataSchema = StructType(data.schema.filterNot(f => partitionColumns.contains(f.name))) val r = dataSource.createRelation( sqlContext, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala index 2f11f4042240..d36197e50d44 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala @@ -287,7 +287,7 @@ private[sql] class DynamicPartitionWriterContainer( PartitioningUtils.escapePathName _, StringType, Seq(Cast(c, StringType)), Seq(StringType)) val str = If(IsNull(c), Literal(defaultPartitionName), escaped) val partitionName = Literal(c.name + 
"=") :: str :: Nil - if (i == 0) partitionName else Literal(Path.SEPARATOR_CHAR.toString) :: partitionName + if (i == 0) partitionName else Literal(Path.SEPARATOR) :: partitionName } // Returns the partition path given a partition key. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala index 40ca8bf4095d..9d3d35692ffc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala @@ -116,6 +116,8 @@ private[sql] case class PreWriteCheck(catalog: Catalog) extends (LogicalPlan => // OK } + PartitioningUtils.validatePartitionColumnDataTypes(r.schema, part.keySet.toArray) + // Get all input data source relations of the query. val srcRelations = query.collect { case LogicalRelation(src: BaseRelation) => src @@ -138,10 +140,10 @@ private[sql] case class PreWriteCheck(catalog: Catalog) extends (LogicalPlan => // OK } - case CreateTableUsingAsSelect(tableName, _, _, _, SaveMode.Overwrite, _, query) => + case CreateTableUsingAsSelect(tableName, _, _, partitionColumns, mode, _, query) => // When the SaveMode is Overwrite, we need to check if the table is an input table of // the query. If so, we will throw an AnalysisException to let users know it is not allowed. - if (catalog.tableExists(Seq(tableName))) { + if (mode == SaveMode.Overwrite && catalog.tableExists(Seq(tableName))) { // Need to remove SubQuery operator. EliminateSubQueries(catalog.lookupRelation(Seq(tableName))) match { // Only do the check if the table is a data source table @@ -164,6 +166,8 @@ private[sql] case class PreWriteCheck(catalog: Catalog) extends (LogicalPlan => // OK } + PartitioningUtils.validatePartitionColumnDataTypes(query.schema, partitionColumns) + case _ => // OK } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala index af445626fbe4..8d0d9218ddd6 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.sources +import java.sql.Date + import scala.collection.JavaConversions._ import org.apache.hadoop.conf.Configuration @@ -553,6 +555,21 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils { clonedConf.foreach(entry => configuration.set(entry.getKey, entry.getValue)) } } + + test("SPARK-8887: Explicitly define which data types can be used as dynamic partition columns") { + val df = Seq( + (1, "v1", Array(1, 2, 3), Map("k1" -> "v1"), Tuple2(1, "4")), + (2, "v2", Array(4, 5, 6), Map("k2" -> "v2"), Tuple2(2, "5")), + (3, "v3", Array(7, 8, 9), Map("k3" -> "v3"), Tuple2(3, "6"))).toDF("a", "b", "c", "d", "e") + withTempDir { file => + intercept[AnalysisException] { + df.write.format(dataSourceName).partitionBy("c", "d", "e").save(file.getCanonicalPath) + } + } + intercept[AnalysisException] { + df.write.format(dataSourceName).partitionBy("c", "d", "e").saveAsTable("t") + } + } } // This class is used to test SPARK-8578. 
We should not use any custom output committer when From 609ce3c07d4962a9242e488ad0ed48c183896802 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 14 Aug 2015 21:12:11 -0700 Subject: [PATCH 0263/1824] [SPARK-9984] [SQL] Create local physical operator interface. This pull request creates a new operator interface that is more similar to traditional database query iterators (with open/close/next/get). These local operators are not currently used anywhere, but will become the basis for SPARK-9983 (local physical operators for query execution). cc zsxwing Author: Reynold Xin Closes #8212 from rxin/SPARK-9984. --- .../sql/execution/local/FilterNode.scala | 47 ++++++++++ .../spark/sql/execution/local/LocalNode.scala | 86 +++++++++++++++++++ .../sql/execution/local/ProjectNode.scala | 42 +++++++++ .../sql/execution/local/SeqScanNode.scala | 49 +++++++++++ 4 files changed, 224 insertions(+) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/local/FilterNode.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/local/ProjectNode.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/local/SeqScanNode.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/FilterNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/FilterNode.scala new file mode 100644 index 000000000000..a485a1a1d7ae --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/FilterNode.scala @@ -0,0 +1,47 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/ + +package org.apache.spark.sql.execution.local + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} +import org.apache.spark.sql.catalyst.expressions.codegen.GeneratePredicate + + +case class FilterNode(condition: Expression, child: LocalNode) extends UnaryLocalNode { + + private[this] var predicate: (InternalRow) => Boolean = _ + + override def output: Seq[Attribute] = child.output + + override def open(): Unit = { + child.open() + predicate = GeneratePredicate.generate(condition, child.output) + } + + override def next(): Boolean = { + var found = false + while (child.next() && !found) { + found = predicate.apply(child.get()) + } + found + } + + override def get(): InternalRow = child.get() + + override def close(): Unit = child.close() +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala new file mode 100644 index 000000000000..341c81438e6d --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala @@ -0,0 +1,86 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.execution.local + +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.trees.TreeNode +import org.apache.spark.sql.types.StructType + +/** + * A local physical operator, in the form of an iterator. + * + * Before consuming the iterator, open function must be called. + * After consuming the iterator, close function must be called. + */ +abstract class LocalNode extends TreeNode[LocalNode] { + + def output: Seq[Attribute] + + /** + * Initializes the iterator state. Must be called before calling `next()`. + * + * Implementations of this must also call the `open()` function of its children. + */ + def open(): Unit + + /** + * Advances the iterator to the next tuple. Returns true if there is at least one more tuple. + */ + def next(): Boolean + + /** + * Returns the current tuple. + */ + def get(): InternalRow + + /** + * Closes the iterator and releases all resources. + * + * Implementations of this must also call the `close()` function of its children. + */ + def close(): Unit + + /** + * Returns the content of the iterator from the beginning to the end in the form of a Scala Seq. 
+ */ + def collect(): Seq[Row] = { + val converter = CatalystTypeConverters.createToScalaConverter(StructType.fromAttributes(output)) + val result = new scala.collection.mutable.ArrayBuffer[Row] + open() + while (next()) { + result += converter.apply(get()).asInstanceOf[Row] + } + close() + result + } +} + + +abstract class LeafLocalNode extends LocalNode { + override def children: Seq[LocalNode] = Seq.empty +} + + +abstract class UnaryLocalNode extends LocalNode { + + def child: LocalNode + + override def children: Seq[LocalNode] = Seq(child) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ProjectNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ProjectNode.scala new file mode 100644 index 000000000000..e574d1473cdc --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ProjectNode.scala @@ -0,0 +1,42 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.execution.local + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, Attribute, NamedExpression} + + +case class ProjectNode(projectList: Seq[NamedExpression], child: LocalNode) extends UnaryLocalNode { + + private[this] var project: UnsafeProjection = _ + + override def output: Seq[Attribute] = projectList.map(_.toAttribute) + + override def open(): Unit = { + project = UnsafeProjection.create(projectList, child.output) + child.open() + } + + override def next(): Boolean = child.next() + + override def get(): InternalRow = { + project.apply(child.get()) + } + + override def close(): Unit = child.close() +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SeqScanNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SeqScanNode.scala new file mode 100644 index 000000000000..994de8afa9a0 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SeqScanNode.scala @@ -0,0 +1,49 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/ + +package org.apache.spark.sql.execution.local + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute + +/** + * An operator that scans some local data collection in the form of Scala Seq. + */ +case class SeqScanNode(output: Seq[Attribute], data: Seq[InternalRow]) extends LeafLocalNode { + + private[this] var iterator: Iterator[InternalRow] = _ + private[this] var currentRow: InternalRow = _ + + override def open(): Unit = { + iterator = data.iterator + } + + override def next(): Boolean = { + if (iterator.hasNext) { + currentRow = iterator.next() + true + } else { + false + } + } + + override def get(): InternalRow = currentRow + + override def close(): Unit = { + // Do nothing + } +} From 71a3af8a94f900a26ac7094f22ec1216cab62e15 Mon Sep 17 00:00:00 2001 From: zc he Date: Fri, 14 Aug 2015 21:28:50 -0700 Subject: [PATCH 0264/1824] [SPARK-9960] [GRAPHX] sendMessage type fix in LabelPropagation.scala Author: zc he Closes #8188 from farseer90718/farseer-patch-1. --- .../scala/org/apache/spark/graphx/lib/LabelPropagation.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/LabelPropagation.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/LabelPropagation.scala index 2bcf8684b8b8..a3ad6bed1c99 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/LabelPropagation.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/LabelPropagation.scala @@ -43,7 +43,7 @@ object LabelPropagation { */ def run[VD, ED: ClassTag](graph: Graph[VD, ED], maxSteps: Int): Graph[VertexId, ED] = { val lpaGraph = graph.mapVertices { case (vid, _) => vid } - def sendMessage(e: EdgeTriplet[VertexId, ED]): Iterator[(VertexId, Map[VertexId, VertexId])] = { + def sendMessage(e: EdgeTriplet[VertexId, ED]): Iterator[(VertexId, Map[VertexId, Long])] = { Iterator((e.srcId, Map(e.dstAttr -> 1L)), (e.dstId, Map(e.srcAttr -> 1L))) } def mergeMessage(count1: Map[VertexId, Long], count2: Map[VertexId, Long]) From 7c1e56825b716a7d703dff38254b4739755ac0c4 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 14 Aug 2015 22:30:35 -0700 Subject: [PATCH 0265/1824] [SPARK-9725] [SQL] fix serialization of UTF8String across different JVM The BYTE_ARRAY_OFFSET could be different in JVM with different configurations (for example, different heap size, 24 if heap > 32G, otherwise 16), so offset of UTF8String is not portable, we should handler that during serialization. Author: Davies Liu Closes #8210 from davies/serialize_utf8string. --- .../apache/spark/unsafe/types/UTF8String.java | 31 +++++++++++++++---- 1 file changed, 25 insertions(+), 6 deletions(-) diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index 667c00900f2c..cbcab958c05a 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -18,8 +18,7 @@ package org.apache.spark.unsafe.types; import javax.annotation.Nonnull; -import java.io.Serializable; -import java.io.UnsupportedEncodingException; +import java.io.*; import java.nio.ByteOrder; import java.util.Arrays; import java.util.Map; @@ -38,12 +37,13 @@ *

    * Note: This is not designed for general use cases, should not be used outside SQL. */ -public final class UTF8String implements Comparable, Serializable { +public final class UTF8String implements Comparable, Externalizable { + // These are only updated by readExternal() @Nonnull - private final Object base; - private final long offset; - private final int numBytes; + private Object base; + private long offset; + private int numBytes; public Object getBaseObject() { return base; } public long getBaseOffset() { return offset; } @@ -127,6 +127,11 @@ protected UTF8String(Object base, long offset, int numBytes) { this.numBytes = numBytes; } + // for serialization + public UTF8String() { + this(null, 0, 0); + } + /** * Writes the content of this string into a memory address, identified by an object and an offset. * The target memory address must already been allocated, and have enough space to hold all the @@ -978,4 +983,18 @@ public UTF8String soundex() { } return UTF8String.fromBytes(sx); } + + public void writeExternal(ObjectOutput out) throws IOException { + byte[] bytes = getBytes(); + out.writeInt(bytes.length); + out.write(bytes); + } + + public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { + offset = BYTE_ARRAY_OFFSET; + numBytes = in.readInt(); + base = new byte[numBytes]; + in.readFully((byte[]) base); + } + } From a85fb6c07fdda5c74d53d6373910dcf5db3ff111 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Sat, 15 Aug 2015 10:46:04 +0100 Subject: [PATCH 0266/1824] [SPARK-9980] [BUILD] Fix SBT publishLocal error due to invalid characters in doc Tiny modification to a few comments ```sbt publishLocal``` work again. Author: Herman van Hovell Closes #8209 from hvanhovell/SPARK-9980. --- .../java/org/apache/spark/unsafe/map/BytesToBytesMap.java | 6 +++--- .../apache/spark/examples/ml/JavaDeveloperApiExample.java | 4 ++-- .../examples/streaming/JavaStatefulNetworkWordCount.java | 2 +- .../src/main/java/org/apache/spark/launcher/Main.java | 4 ++-- .../apache/spark/launcher/SparkClassCommandBuilder.java | 2 +- .../java/org/apache/spark/launcher/SparkLauncher.java | 6 +++--- .../apache/spark/launcher/SparkSubmitCommandBuilder.java | 4 ++-- .../apache/spark/launcher/SparkSubmitOptionParser.java | 8 ++++---- .../org/apache/spark/unsafe/memory/TaskMemoryManager.java | 2 +- 9 files changed, 19 insertions(+), 19 deletions(-) diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java index 5f3a4fcf4d58..b24eed3952fd 100644 --- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java +++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java @@ -92,9 +92,9 @@ public final class BytesToBytesMap { /** * The maximum number of keys that BytesToBytesMap supports. The hash table has to be - * power-of-2-sized and its backing Java array can contain at most (1 << 30) elements, since - * that's the largest power-of-2 that's less than Integer.MAX_VALUE. We need two long array - * entries per key, giving us a maximum capacity of (1 << 29). + * power-of-2-sized and its backing Java array can contain at most (1 << 30) elements, + * since that's the largest power-of-2 that's less than Integer.MAX_VALUE. We need two long array + * entries per key, giving us a maximum capacity of (1 << 29). 
*/ @VisibleForTesting static final int MAX_CAPACITY = (1 << 29); diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java index 3f1fe900b000..a377694507d2 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java @@ -124,7 +124,7 @@ public String uid() { /** * Param for max number of iterations - *

    + *

    * NOTE: The usual way to add a parameter to a model or algorithm is to include: * - val myParamName: ParamType * - def getMyParamName @@ -222,7 +222,7 @@ public Vector predictRaw(Vector features) { /** * Create a copy of the model. * The copy is shallow, except for the embedded paramMap, which gets a deep copy. - *

    + *

    * This is used for the defaul implementation of [[transform()]]. * * In Java, we have to make this method public since Java does not understand Scala's protected diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java index 02f58f48b07a..99b63a2590ae 100644 --- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java +++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java @@ -45,7 +45,7 @@ * Usage: JavaStatefulNetworkWordCount * and describe the TCP server that Spark Streaming would connect to receive * data. - *

    + *

    * To run this on your local machine, you need to first run a Netcat server * `$ nc -lk 9999` * and then run the example diff --git a/launcher/src/main/java/org/apache/spark/launcher/Main.java b/launcher/src/main/java/org/apache/spark/launcher/Main.java index 62492f9baf3b..a4e3acc674f3 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/Main.java +++ b/launcher/src/main/java/org/apache/spark/launcher/Main.java @@ -32,7 +32,7 @@ class Main { /** * Usage: Main [class] [class args] - *

    + *

    * This CLI works in two different modes: *

    - * At last, if there are more than `preferredNumExecutors` idle executors (weight = 0), - * returns all idle executors. Otherwise, we only return `preferredNumExecutors` best options - * according to the weights. + * At last, if there are any idle executors (weight = 0), returns all idle executors. + * Otherwise, returns the executors that have the minimum weight. * * * @@ -134,8 +162,7 @@ private[streaming] class ReceiverSchedulingPolicy { receiverId: Int, preferredLocation: Option[String], receiverTrackingInfoMap: Map[Int, ReceiverTrackingInfo], - executors: Seq[String], - preferredNumExecutors: Int = 3): Seq[String] = { + executors: Seq[String]): Seq[String] = { if (executors.isEmpty) { return Seq.empty } @@ -156,15 +183,18 @@ private[streaming] class ReceiverSchedulingPolicy { } }.groupBy(_._1).mapValues(_.map(_._2).sum) // Sum weights for each executor - val idleExecutors = (executors.toSet -- executorWeights.keys).toSeq - if (idleExecutors.size >= preferredNumExecutors) { - // If there are more than `preferredNumExecutors` idle executors, return all of them + val idleExecutors = executors.toSet -- executorWeights.keys + if (idleExecutors.nonEmpty) { scheduledExecutors ++= idleExecutors } else { - // If there are less than `preferredNumExecutors` idle executors, return 3 best options - scheduledExecutors ++= idleExecutors - val sortedExecutors = executorWeights.toSeq.sortBy(_._2).map(_._1) - scheduledExecutors ++= (idleExecutors ++ sortedExecutors).take(preferredNumExecutors) + // There is no idle executor. So select all executors that have the minimum weight. + val sortedExecutors = executorWeights.toSeq.sortBy(_._2) + if (sortedExecutors.nonEmpty) { + val minWeight = sortedExecutors(0)._2 + scheduledExecutors ++= sortedExecutors.takeWhile(_._2 == minWeight).map(_._1) + } else { + // This should not happen since "executors" is not empty + } } scheduledExecutors.toSeq } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala index 30d25a64e307..3d532a675db0 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala @@ -244,8 +244,21 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false } if (isTrackerStopping || isTrackerStopped) { - false - } else if (!scheduleReceiver(streamId).contains(hostPort)) { + return false + } + + val scheduledExecutors = receiverTrackingInfos(streamId).scheduledExecutors + val accetableExecutors = if (scheduledExecutors.nonEmpty) { + // This receiver is registering and it's scheduled by + // ReceiverSchedulingPolicy.scheduleReceivers. So use "scheduledExecutors" to check it. + scheduledExecutors.get + } else { + // This receiver is scheduled by "ReceiverSchedulingPolicy.rescheduleReceiver", so calling + // "ReceiverSchedulingPolicy.rescheduleReceiver" again to check it. 
+ scheduleReceiver(streamId) + } + + if (!accetableExecutors.contains(hostPort)) { // Refuse it since it's scheduled to a wrong executor false } else { @@ -426,12 +439,25 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false startReceiver(receiver, executors) } case RestartReceiver(receiver) => - val scheduledExecutors = schedulingPolicy.rescheduleReceiver( - receiver.streamId, - receiver.preferredLocation, - receiverTrackingInfos, - getExecutors) - updateReceiverScheduledExecutors(receiver.streamId, scheduledExecutors) + // Old scheduled executors minus the ones that are not active any more + val oldScheduledExecutors = getStoredScheduledExecutors(receiver.streamId) + val scheduledExecutors = if (oldScheduledExecutors.nonEmpty) { + // Try global scheduling again + oldScheduledExecutors + } else { + val oldReceiverInfo = receiverTrackingInfos(receiver.streamId) + // Clear "scheduledExecutors" to indicate we are going to do local scheduling + val newReceiverInfo = oldReceiverInfo.copy( + state = ReceiverState.INACTIVE, scheduledExecutors = None) + receiverTrackingInfos(receiver.streamId) = newReceiverInfo + schedulingPolicy.rescheduleReceiver( + receiver.streamId, + receiver.preferredLocation, + receiverTrackingInfos, + getExecutors) + } + // Assume there is one receiver restarting at one time, so we don't need to update + // receiverTrackingInfos startReceiver(receiver, scheduledExecutors) case c: CleanupOldBlocks => receiverTrackingInfos.values.flatMap(_.endpoint).foreach(_.send(c)) @@ -464,6 +490,24 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false context.reply(true) } + /** + * Return the stored scheduled executors that are still alive. + */ + private def getStoredScheduledExecutors(receiverId: Int): Seq[String] = { + if (receiverTrackingInfos.contains(receiverId)) { + val scheduledExecutors = receiverTrackingInfos(receiverId).scheduledExecutors + if (scheduledExecutors.nonEmpty) { + val executors = getExecutors.toSet + // Only return the alive executors + scheduledExecutors.get.filter(executors) + } else { + Nil + } + } else { + Nil + } + } + /** * Start a receiver along with its scheduled executors */ @@ -484,7 +528,23 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false new SerializableConfiguration(ssc.sparkContext.hadoopConfiguration) // Function to start the receiver on the worker node - val startReceiverFunc = new StartReceiverFunc(checkpointDirOption, serializableHadoopConf) + val startReceiverFunc: Iterator[Receiver[_]] => Unit = + (iterator: Iterator[Receiver[_]]) => { + if (!iterator.hasNext) { + throw new SparkException( + "Could not start receiver as object not found.") + } + if (TaskContext.get().attemptNumber() == 0) { + val receiver = iterator.next() + assert(iterator.hasNext == false) + val supervisor = new ReceiverSupervisorImpl( + receiver, SparkEnv.get, serializableHadoopConf.value, checkpointDirOption) + supervisor.start() + supervisor.awaitTermination() + } else { + // It's restarted by TaskScheduler, but we want to reschedule it again. So exit it. + } + } // Create the RDD using the scheduledExecutors to run the receiver in a Spark job val receiverRDD: RDD[Receiver[_]] = @@ -541,31 +601,3 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false } } - -/** - * Function to start the receiver on the worker node. Use a class instead of closure to avoid - * the serialization issue. 
- */ -private[streaming] class StartReceiverFunc( - checkpointDirOption: Option[String], - serializableHadoopConf: SerializableConfiguration) - extends (Iterator[Receiver[_]] => Unit) with Serializable { - - override def apply(iterator: Iterator[Receiver[_]]): Unit = { - if (!iterator.hasNext) { - throw new SparkException( - "Could not start receiver as object not found.") - } - if (TaskContext.get().attemptNumber() == 0) { - val receiver = iterator.next() - assert(iterator.hasNext == false) - val supervisor = new ReceiverSupervisorImpl( - receiver, SparkEnv.get, serializableHadoopConf.value, checkpointDirOption) - supervisor.start() - supervisor.awaitTermination() - } else { - // It's restarted by TaskScheduler, but we want to reschedule it again. So exit it. - } - } - -} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicySuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicySuite.scala index 0418d776ecc9..b2a51d72bac2 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicySuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicySuite.scala @@ -39,7 +39,7 @@ class ReceiverSchedulingPolicySuite extends SparkFunSuite { assert(scheduledExecutors.toSet === Set("host1", "host2")) } - test("rescheduleReceiver: return all idle executors if more than 3 idle executors") { + test("rescheduleReceiver: return all idle executors if there are any idle executors") { val executors = Seq("host1", "host2", "host3", "host4", "host5") // host3 is idle val receiverTrackingInfoMap = Map( @@ -49,16 +49,16 @@ class ReceiverSchedulingPolicySuite extends SparkFunSuite { assert(scheduledExecutors.toSet === Set("host2", "host3", "host4", "host5")) } - test("rescheduleReceiver: return 3 best options if less than 3 idle executors") { + test("rescheduleReceiver: return all executors that have minimum weight if no idle executors") { val executors = Seq("host1", "host2", "host3", "host4", "host5") - // Weights: host1 = 1.5, host2 = 0.5, host3 = 1.0 - // host4 and host5 are idle + // Weights: host1 = 1.5, host2 = 0.5, host3 = 1.0, host4 = 0.5, host5 = 0.5 val receiverTrackingInfoMap = Map( 0 -> ReceiverTrackingInfo(0, ReceiverState.ACTIVE, None, Some("host1")), 1 -> ReceiverTrackingInfo(1, ReceiverState.SCHEDULED, Some(Seq("host2", "host3")), None), - 2 -> ReceiverTrackingInfo(1, ReceiverState.SCHEDULED, Some(Seq("host1", "host3")), None)) + 2 -> ReceiverTrackingInfo(2, ReceiverState.SCHEDULED, Some(Seq("host1", "host3")), None), + 3 -> ReceiverTrackingInfo(4, ReceiverState.SCHEDULED, Some(Seq("host4", "host5")), None)) val scheduledExecutors = receiverSchedulingPolicy.rescheduleReceiver( - 3, None, receiverTrackingInfoMap, executors) + 4, None, receiverTrackingInfoMap, executors) assert(scheduledExecutors.toSet === Set("host2", "host4", "host5")) } @@ -127,4 +127,5 @@ class ReceiverSchedulingPolicySuite extends SparkFunSuite { assert(executors.isEmpty) } } + } From df7041d02d3fd44b08a859f5d77bf6fb726895f0 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Mon, 24 Aug 2015 23:38:32 -0700 Subject: [PATCH 0398/1824] [SPARK-10196] [SQL] Correctly saving decimals in internal rows to JSON. https://issues.apache.org/jira/browse/SPARK-10196 Author: Yin Huai Closes #8408 from yhuai/DecimalJsonSPARK-10196. 
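For reference, a minimal sketch of the round trip the new test exercises, assuming a spark-shell style `sqlContext`; the output path and the values are illustrative. Internal rows carry decimals as `Decimal` rather than `java.math.BigDecimal`, so the JSON writer has to unwrap them via `toJavaBigDecimal` before emitting a number.

```scala
import org.apache.spark.sql.Row
import org.apache.spark.sql.types._

val schema = new StructType().add("decimal", DecimalType(7, 2))
val rows = Seq(
  Row(new java.math.BigDecimal("10.02")),
  Row(new java.math.BigDecimal("20000.99")))
val df = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(rows), schema)

// Write the decimal column out as JSON and read it back with the same schema.
df.write.json("/tmp/spark-10196-decimals")
sqlContext.read.schema(schema).json("/tmp/spark-10196-decimals").show()
```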
--- .../datasources/json/JacksonGenerator.scala | 2 +- .../sources/JsonHadoopFsRelationSuite.scala | 27 +++++++++++++++++++ 2 files changed, 28 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonGenerator.scala index 99ac7730bd1c..330ba907b2ef 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonGenerator.scala @@ -95,7 +95,7 @@ private[sql] object JacksonGenerator { case (FloatType, v: Float) => gen.writeNumber(v) case (DoubleType, v: Double) => gen.writeNumber(v) case (LongType, v: Long) => gen.writeNumber(v) - case (DecimalType(), v: java.math.BigDecimal) => gen.writeNumber(v) + case (DecimalType(), v: Decimal) => gen.writeNumber(v.toJavaBigDecimal) case (ByteType, v: Byte) => gen.writeNumber(v.toInt) case (BinaryType, v: Array[Byte]) => gen.writeBinary(v) case (BooleanType, v: Boolean) => gen.writeBoolean(v) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/JsonHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/JsonHadoopFsRelationSuite.scala index ed6d512ab36f..8ca3a1708519 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/JsonHadoopFsRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/JsonHadoopFsRelationSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.sources +import java.math.BigDecimal + import org.apache.hadoop.fs.Path import org.apache.spark.deploy.SparkHadoopUtil @@ -75,4 +77,29 @@ class JsonHadoopFsRelationSuite extends HadoopFsRelationTest { ) } } + + test("SPARK-10196: save decimal type to JSON") { + withTempDir { file => + file.delete() + + val schema = + new StructType() + .add("decimal", DecimalType(7, 2)) + + val data = + Row(new BigDecimal("10.02")) :: + Row(new BigDecimal("20000.99")) :: + Row(new BigDecimal("10000")) :: Nil + val df = createDataFrame(sparkContext.parallelize(data), schema) + + // Write the data out. + df.write.format(dataSourceName).save(file.getCanonicalPath) + + // Read it back and check the result. + checkAnswer( + read.format(dataSourceName).schema(schema).load(file.getCanonicalPath), + df + ) + } + } } From bf03fe68d62f33dda70dff45c3bda1f57b032dfc Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Tue, 25 Aug 2015 14:58:42 +0800 Subject: [PATCH 0399/1824] [SPARK-10136] [SQL] A more robust fix for SPARK-10136 PR #8341 is a valid fix for SPARK-10136, but it didn't catch the real root cause. The real problem can be rather tricky to explain, and requires audiences to be pretty familiar with parquet-format spec, especially details of `LIST` backwards-compatibility rules. Let me have a try to give an explanation here. The structure of the problematic Parquet schema generated by parquet-avro is something like this: ``` message m { group f (LIST) { // Level 1 repeated group array (LIST) { // Level 2 repeated array; // Level 3 } } } ``` (The schema generated by parquet-thrift is structurally similar, just replace the `array` at level 2 with `f_tuple`, and the other one at level 3 with `f_tuple_tuple`.) This structure consists of two nested legacy 2-level `LIST`-like structures: 1. 
The repeated group type at level 2 is the element type of the outer array defined at level 1. This group should map to a `CatalystArrayConverter.ElementConverter` when building converters. 2. The repeated primitive type at level 3 is the element type of the inner array defined at level 2. This type should also map to a `CatalystArrayConverter.ElementConverter`. The root cause of SPARK-10136 is that the group at level 2 isn't properly recognized as the element type of level 1. Thus, according to the parquet-format spec, the repeated primitive at level 3 is left as a so-called "unannotated repeated primitive type", and is recognized as a required list of required primitive type, so a `RepeatedPrimitiveConverter` instead of a `CatalystArrayConverter.ElementConverter` is created for it. According to the parquet-format spec, an unannotated repeated type shouldn't appear in a `LIST`- or `MAP`-annotated group. PR #8341 fixed this issue by allowing such unannotated repeated types to appear in `LIST`-annotated groups, which is a non-standard, hacky, but valid fix. (I didn't realize this when authoring #8341 though.) As for the reason why level 2 isn't recognized as a list element type, it's because of the following `LIST` backwards-compatibility rule defined in the parquet-format spec: > If the repeated field is a group with one field and is named either `array` or uses the `LIST`-annotated group's name with `_tuple` appended then the repeated type is the element type and elements are required. (The `array` part is for parquet-avro compatibility, while the `_tuple` part is for parquet-thrift.) This rule is implemented in [`CatalystSchemaConverter.isElementType`] [1], but neglected in [`CatalystRowConverter.isElementType`] [2]. This PR delivers a more robust fix by adding this rule in the latter method. Note that parquet-avro 1.7.0 also suffers from this issue. Details can be found at [PARQUET-364] [3]. [1]: https://github.com/apache/spark/blob/85f9a61357994da5023b08b0a8a2eb09388ce7f8/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala#L259-L305 [2]: https://github.com/apache/spark/blob/85f9a61357994da5023b08b0a8a2eb09388ce7f8/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala#L456-L463 [3]: https://issues.apache.org/jira/browse/PARQUET-364 Author: Cheng Lian Closes #8361 from liancheng/spark-10136/proper-version.
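To make the quoted rule concrete, here is a minimal, self-contained sketch of just the name-based check it describes, written against plain strings rather than parquet-mr's `GroupType` API; it is only an illustration of the rule, not the code this patch adds to `CatalystRowConverter.isElementType`.

```scala
// The LIST backwards-compatibility rule quoted above: a repeated group with a
// single field is itself the element type when it is named "array" (parquet-avro)
// or "<parent>_tuple" (parquet-thrift).
def repeatedGroupIsElementType(fieldCount: Int, groupName: String, parentName: String): Boolean =
  fieldCount == 1 && (groupName == "array" || groupName == parentName + "_tuple")

// For the problematic schema above: level 2 is a one-field repeated group named
// "array" under the LIST-annotated group "f", so the rule applies to it.
assert(repeatedGroupIsElementType(fieldCount = 1, groupName = "array", parentName = "f"))
// The parquet-thrift flavour of the same layout:
assert(repeatedGroupIsElementType(fieldCount = 1, groupName = "f_tuple", parentName = "f"))
// A standard 3-level LIST (repeated group named "list" wrapping "element") is not
// matched by this particular rule:
assert(!repeatedGroupIsElementType(fieldCount = 1, groupName = "list", parentName = "f"))
```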
--- .../parquet/CatalystRowConverter.scala | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala index d2c2db51769b..cbf0704c4a9a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala @@ -415,8 +415,9 @@ private[parquet] class CatalystRowConverter( private val elementConverter: Converter = { val repeatedType = parquetSchema.getType(0) val elementType = catalystSchema.elementType + val parentName = parquetSchema.getName - if (isElementType(repeatedType, elementType)) { + if (isElementType(repeatedType, elementType, parentName)) { newConverter(repeatedType, elementType, new ParentContainerUpdater { override def set(value: Any): Unit = currentArray += value }) @@ -453,10 +454,13 @@ private[parquet] class CatalystRowConverter( * @see https://github.com/apache/parquet-format/blob/master/LogicalTypes.md#backward-compatibility-rules */ // scalastyle:on - private def isElementType(parquetRepeatedType: Type, catalystElementType: DataType): Boolean = { + private def isElementType( + parquetRepeatedType: Type, catalystElementType: DataType, parentName: String): Boolean = { (parquetRepeatedType, catalystElementType) match { case (t: PrimitiveType, _) => true case (t: GroupType, _) if t.getFieldCount > 1 => true + case (t: GroupType, _) if t.getFieldCount == 1 && t.getName == "array" => true + case (t: GroupType, _) if t.getFieldCount == 1 && t.getName == parentName + "_tuple" => true case (t: GroupType, StructType(Array(f))) if f.name == t.getFieldName(0) => true case _ => false } @@ -474,15 +478,9 @@ private[parquet] class CatalystRowConverter( override def getConverter(fieldIndex: Int): Converter = converter - override def end(): Unit = { - converter.updater.end() - currentArray += currentElement - } + override def end(): Unit = currentArray += currentElement - override def start(): Unit = { - converter.updater.start() - currentElement = null - } + override def start(): Unit = currentElement = null } } From 82268f07abfa658869df2354ae72f8d6ddd119e8 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 25 Aug 2015 00:04:10 -0700 Subject: [PATCH 0400/1824] [SPARK-9293] [SPARK-9813] Analysis should check that set operations are only performed on tables with equal numbers of columns This patch adds an analyzer rule to ensure that set operations (union, intersect, and except) are only applied to tables with the same number of columns. Without this rule, there are scenarios where invalid queries can return incorrect results instead of failing with error messages; SPARK-9813 provides one example of this problem. In other cases, the invalid query can crash at runtime with extremely confusing exceptions. I also performed a bit of cleanup to refactor some of those logical operators' code into a common `SetOperation` base class. Author: Josh Rosen Closes #7631 from JoshRosen/SPARK-9293. 
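Concretely, the class of query the new rule rejects looks like the sketch below (the data and column names are invented, and it assumes an existing `SparkContext` named `sc`):

```scala
import org.apache.spark.sql.SQLContext

val sqlContext = new SQLContext(sc)   // `sc` is an already-running SparkContext
import sqlContext.implicits._

val twoCols = Seq((1, "a"), (2, "b")).toDF("id", "name")   // two columns
val oneCol  = Seq(Tuple1(3), Tuple1(4)).toDF("id")         // one column

// Previously a plan like this could slip through analysis and misbehave at runtime;
// with the new CheckAnalysis rule it fails fast with an AnalysisException along the
// lines of "Union can only be performed on tables with the same number of columns,
// but the left table has 2 columns and the right has 1".
val bad = twoCols.unionAll(oneCol)
```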
--- .../sql/catalyst/analysis/CheckAnalysis.scala | 6 +++ .../catalyst/analysis/HiveTypeCoercion.scala | 14 +++---- .../plans/logical/basicOperators.scala | 38 +++++++++---------- .../analysis/AnalysisErrorSuite.scala | 18 +++++++++ .../spark/sql/hive/HiveMetastoreCatalog.scala | 2 +- .../hive/execution/InsertIntoHiveTable.scala | 2 +- 6 files changed, 48 insertions(+), 32 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 39f554c137c9..7701fd045104 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -137,6 +137,12 @@ trait CheckAnalysis { } } + case s @ SetOperation(left, right) if left.output.length != right.output.length => + failAnalysis( + s"${s.nodeName} can only be performed on tables with the same number of columns, " + + s"but the left table has ${left.output.length} columns and the right has " + + s"${right.output.length}") + case _ => // Fallbacks to the following checks } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 2cb067f4aac9..a1aa2a2b2c68 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -203,6 +203,7 @@ object HiveTypeCoercion { planName: String, left: LogicalPlan, right: LogicalPlan): (LogicalPlan, LogicalPlan) = { + require(left.output.length == right.output.length) val castedTypes = left.output.zip(right.output).map { case (lhs, rhs) if lhs.dataType != rhs.dataType => @@ -229,15 +230,10 @@ object HiveTypeCoercion { def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case p if p.analyzed => p - case u @ Union(left, right) if u.childrenResolved && !u.resolved => - val (newLeft, newRight) = widenOutputTypes(u.nodeName, left, right) - Union(newLeft, newRight) - case e @ Except(left, right) if e.childrenResolved && !e.resolved => - val (newLeft, newRight) = widenOutputTypes(e.nodeName, left, right) - Except(newLeft, newRight) - case i @ Intersect(left, right) if i.childrenResolved && !i.resolved => - val (newLeft, newRight) = widenOutputTypes(i.nodeName, left, right) - Intersect(newLeft, newRight) + case s @ SetOperation(left, right) if s.childrenResolved + && left.output.length == right.output.length && !s.resolved => + val (newLeft, newRight) = widenOutputTypes(s.nodeName, left, right) + s.makeCopy(Array(newLeft, newRight)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 73b8261260ac..722f69cdca82 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -89,13 +89,21 @@ case class Filter(condition: Expression, child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output } -case class Union(left: LogicalPlan, right: LogicalPlan) extends BinaryNode { +abstract class SetOperation(left: LogicalPlan, right: LogicalPlan) extends 
BinaryNode { // TODO: These aren't really the same attributes as nullability etc might change. - override def output: Seq[Attribute] = left.output + final override def output: Seq[Attribute] = left.output - override lazy val resolved: Boolean = + final override lazy val resolved: Boolean = childrenResolved && - left.output.zip(right.output).forall { case (l, r) => l.dataType == r.dataType } + left.output.length == right.output.length && + left.output.zip(right.output).forall { case (l, r) => l.dataType == r.dataType } +} + +private[sql] object SetOperation { + def unapply(p: SetOperation): Option[(LogicalPlan, LogicalPlan)] = Some((p.left, p.right)) +} + +case class Union(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right) { override def statistics: Statistics = { val sizeInBytes = left.statistics.sizeInBytes + right.statistics.sizeInBytes @@ -103,6 +111,10 @@ case class Union(left: LogicalPlan, right: LogicalPlan) extends BinaryNode { } } +case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right) + +case class Except(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right) + case class Join( left: LogicalPlan, right: LogicalPlan, @@ -142,15 +154,6 @@ case class BroadcastHint(child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output } - -case class Except(left: LogicalPlan, right: LogicalPlan) extends BinaryNode { - override def output: Seq[Attribute] = left.output - - override lazy val resolved: Boolean = - childrenResolved && - left.output.zip(right.output).forall { case (l, r) => l.dataType == r.dataType } -} - case class InsertIntoTable( table: LogicalPlan, partition: Map[String, Option[String]], @@ -160,7 +163,7 @@ case class InsertIntoTable( extends LogicalPlan { override def children: Seq[LogicalPlan] = child :: Nil - override def output: Seq[Attribute] = child.output + override def output: Seq[Attribute] = Seq.empty assert(overwrite || !ifNotExists) override lazy val resolved: Boolean = childrenResolved && child.output.zip(table.output).forall { @@ -440,10 +443,3 @@ case object OneRowRelation extends LeafNode { override def statistics: Statistics = Statistics(sizeInBytes = 1) } -case class Intersect(left: LogicalPlan, right: LogicalPlan) extends BinaryNode { - override def output: Seq[Attribute] = left.output - - override lazy val resolved: Boolean = - childrenResolved && - left.output.zip(right.output).forall { case (l, r) => l.dataType == r.dataType } -} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index 7065adce04bf..fbdd3a7776f5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -145,6 +145,24 @@ class AnalysisErrorSuite extends AnalysisTest { UnresolvedTestPlan(), "unresolved" :: Nil) + errorTest( + "union with unequal number of columns", + testRelation.unionAll(testRelation2), + "union" :: "number of columns" :: testRelation2.output.length.toString :: + testRelation.output.length.toString :: Nil) + + errorTest( + "intersect with unequal number of columns", + testRelation.intersect(testRelation2), + "intersect" :: "number of columns" :: testRelation2.output.length.toString :: + testRelation.output.length.toString :: Nil) + + errorTest( + "except with unequal number of 
columns", + testRelation.except(testRelation2), + "except" :: "number of columns" :: testRelation2.output.length.toString :: + testRelation.output.length.toString :: Nil) + errorTest( "SPARK-9955: correct error message for aggregate", // When parse SQL string, we will wrap aggregate expressions with UnresolvedAlias. diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index bbe8c1911bf8..98d21aa76d64 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -751,7 +751,7 @@ private[hive] case class InsertIntoHiveTable( extends LogicalPlan { override def children: Seq[LogicalPlan] = child :: Nil - override def output: Seq[Attribute] = child.output + override def output: Seq[Attribute] = Seq.empty val numDynamicPartitions = partition.values.count(_.isEmpty) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index 12c667e6e92d..62efda613a17 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -61,7 +61,7 @@ case class InsertIntoHiveTable( serializer } - def output: Seq[Attribute] = child.output + def output: Seq[Attribute] = Seq.empty def saveAsHiveFile( rdd: RDD[InternalRow], From d4549fe58fa0d781e0e891bceff893420cb1d598 Mon Sep 17 00:00:00 2001 From: Yu ISHIKAWA Date: Tue, 25 Aug 2015 00:28:51 -0700 Subject: [PATCH 0401/1824] [SPARK-10214] [SPARKR] [DOCS] Improve SparkR Column, DataFrame API docs cc: shivaram ## Summary - Add name tags to each methods in DataFrame.R and column.R - Replace `rdname column` with `rdname {each_func}`. i.e. alias method : `rdname column` => `rdname alias` ## Generated PDF File https://drive.google.com/file/d/0B9biIZIU47lLNHN2aFpnQXlSeGs/view?usp=sharing ## JIRA [[SPARK-10214] Improve SparkR Column, DataFrame API docs - ASF JIRA](https://issues.apache.org/jira/browse/SPARK-10214) Author: Yu ISHIKAWA Closes #8414 from yu-iskw/SPARK-10214. --- R/pkg/R/DataFrame.R | 101 +++++++++++++++++++++++++++++++++++--------- R/pkg/R/column.R | 40 ++++++++++++------ R/pkg/R/generics.R | 2 +- 3 files changed, 109 insertions(+), 34 deletions(-) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 895603235011..10f3c4ea5986 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -27,9 +27,10 @@ setOldClass("jobj") #' \code{jsonFile}, \code{table} etc. 
#' @rdname DataFrame #' @seealso jsonFile, table +#' @docType class #' -#' @param env An R environment that stores bookkeeping states of the DataFrame -#' @param sdf A Java object reference to the backing Scala DataFrame +#' @slot env An R environment that stores bookkeeping states of the DataFrame +#' @slot sdf A Java object reference to the backing Scala DataFrame #' @export setClass("DataFrame", slots = list(env = "environment", @@ -61,6 +62,7 @@ dataFrame <- function(sdf, isCached = FALSE) { #' @param x A SparkSQL DataFrame #' #' @rdname printSchema +#' @name printSchema #' @export #' @examples #'\dontrun{ @@ -84,6 +86,7 @@ setMethod("printSchema", #' @param x A SparkSQL DataFrame #' #' @rdname schema +#' @name schema #' @export #' @examples #'\dontrun{ @@ -106,6 +109,7 @@ setMethod("schema", #' @param x A SparkSQL DataFrame #' @param extended Logical. If extended is False, explain() only prints the physical plan. #' @rdname explain +#' @name explain #' @export #' @examples #'\dontrun{ @@ -135,6 +139,7 @@ setMethod("explain", #' @param x A SparkSQL DataFrame #' #' @rdname isLocal +#' @name isLocal #' @export #' @examples #'\dontrun{ @@ -158,6 +163,7 @@ setMethod("isLocal", #' @param numRows The number of rows to print. Defaults to 20. #' #' @rdname showDF +#' @name showDF #' @export #' @examples #'\dontrun{ @@ -181,6 +187,7 @@ setMethod("showDF", #' @param x A SparkSQL DataFrame #' #' @rdname show +#' @name show #' @export #' @examples #'\dontrun{ @@ -206,6 +213,7 @@ setMethod("show", "DataFrame", #' @param x A SparkSQL DataFrame #' #' @rdname dtypes +#' @name dtypes #' @export #' @examples #'\dontrun{ @@ -230,6 +238,8 @@ setMethod("dtypes", #' @param x A SparkSQL DataFrame #' #' @rdname columns +#' @name columns +#' @aliases names #' @export #' @examples #'\dontrun{ @@ -248,7 +258,7 @@ setMethod("columns", }) #' @rdname columns -#' @aliases names,DataFrame,function-method +#' @name names setMethod("names", signature(x = "DataFrame"), function(x) { @@ -256,6 +266,7 @@ setMethod("names", }) #' @rdname columns +#' @name names<- setMethod("names<-", signature(x = "DataFrame"), function(x, value) { @@ -273,6 +284,7 @@ setMethod("names<-", #' @param tableName A character vector containing the name of the table #' #' @rdname registerTempTable +#' @name registerTempTable #' @export #' @examples #'\dontrun{ @@ -299,6 +311,7 @@ setMethod("registerTempTable", #' the existing rows in the table. #' #' @rdname insertInto +#' @name insertInto #' @export #' @examples #'\dontrun{ @@ -321,7 +334,8 @@ setMethod("insertInto", #' #' @param x A SparkSQL DataFrame #' -#' @rdname cache-methods +#' @rdname cache +#' @name cache #' @export #' @examples #'\dontrun{ @@ -347,6 +361,7 @@ setMethod("cache", #' #' @param x The DataFrame to persist #' @rdname persist +#' @name persist #' @export #' @examples #'\dontrun{ @@ -372,6 +387,7 @@ setMethod("persist", #' @param x The DataFrame to unpersist #' @param blocking Whether to block until all blocks are deleted #' @rdname unpersist-methods +#' @name unpersist #' @export #' @examples #'\dontrun{ @@ -397,6 +413,7 @@ setMethod("unpersist", #' @param x A SparkSQL DataFrame #' @param numPartitions The number of partitions to use. 
#' @rdname repartition +#' @name repartition #' @export #' @examples #'\dontrun{ @@ -446,6 +463,7 @@ setMethod("toJSON", #' @param x A SparkSQL DataFrame #' @param path The directory where the file is saved #' @rdname saveAsParquetFile +#' @name saveAsParquetFile #' @export #' @examples #'\dontrun{ @@ -467,6 +485,7 @@ setMethod("saveAsParquetFile", #' #' @param x A SparkSQL DataFrame #' @rdname distinct +#' @name distinct #' @export #' @examples #'\dontrun{ @@ -488,7 +507,8 @@ setMethod("distinct", #' @description Returns a new DataFrame containing distinct rows in this DataFrame #' #' @rdname unique -#' @aliases unique +#' @name unique +#' @aliases distinct setMethod("unique", signature(x = "DataFrame"), function(x) { @@ -526,7 +546,7 @@ setMethod("sample", }) #' @rdname sample -#' @aliases sample +#' @name sample_frac setMethod("sample_frac", signature(x = "DataFrame", withReplacement = "logical", fraction = "numeric"), @@ -541,6 +561,8 @@ setMethod("sample_frac", #' @param x A SparkSQL DataFrame #' #' @rdname count +#' @name count +#' @aliases nrow #' @export #' @examples #'\dontrun{ @@ -574,6 +596,7 @@ setMethod("nrow", #' @param x a SparkSQL DataFrame #' #' @rdname ncol +#' @name ncol #' @export #' @examples #'\dontrun{ @@ -593,6 +616,7 @@ setMethod("ncol", #' @param x a SparkSQL DataFrame #' #' @rdname dim +#' @name dim #' @export #' @examples #'\dontrun{ @@ -613,8 +637,8 @@ setMethod("dim", #' @param x A SparkSQL DataFrame #' @param stringsAsFactors (Optional) A logical indicating whether or not string columns #' should be converted to factors. FALSE by default. - -#' @rdname collect-methods +#' @rdname collect +#' @name collect #' @export #' @examples #'\dontrun{ @@ -650,6 +674,7 @@ setMethod("collect", #' @return A new DataFrame containing the number of rows specified. #' #' @rdname limit +#' @name limit #' @export #' @examples #' \dontrun{ @@ -669,6 +694,7 @@ setMethod("limit", #' Take the first NUM rows of a DataFrame and return a the results as a data.frame #' #' @rdname take +#' @name take #' @export #' @examples #'\dontrun{ @@ -696,6 +722,7 @@ setMethod("take", #' @return A data.frame #' #' @rdname head +#' @name head #' @export #' @examples #'\dontrun{ @@ -717,6 +744,7 @@ setMethod("head", #' @param x A SparkSQL DataFrame #' #' @rdname first +#' @name first #' @export #' @examples #'\dontrun{ @@ -732,7 +760,7 @@ setMethod("first", take(x, 1) }) -# toRDD() +# toRDD # # Converts a Spark DataFrame to an RDD while preserving column names. # @@ -769,6 +797,7 @@ setMethod("toRDD", #' @seealso GroupedData #' @aliases group_by #' @rdname groupBy +#' @name groupBy #' @export #' @examples #' \dontrun{ @@ -792,7 +821,7 @@ setMethod("groupBy", }) #' @rdname groupBy -#' @aliases group_by +#' @name group_by setMethod("group_by", signature(x = "DataFrame"), function(x, ...) { @@ -804,7 +833,8 @@ setMethod("group_by", #' Compute aggregates by specifying a list of columns #' #' @param x a DataFrame -#' @rdname DataFrame +#' @rdname agg +#' @name agg #' @aliases summarize #' @export setMethod("agg", @@ -813,8 +843,8 @@ setMethod("agg", agg(groupBy(x), ...) }) -#' @rdname DataFrame -#' @aliases agg +#' @rdname agg +#' @name summarize setMethod("summarize", signature(x = "DataFrame"), function(x, ...) 
{ @@ -890,12 +920,14 @@ getColumn <- function(x, c) { } #' @rdname select +#' @name $ setMethod("$", signature(x = "DataFrame"), function(x, name) { getColumn(x, name) }) #' @rdname select +#' @name $<- setMethod("$<-", signature(x = "DataFrame"), function(x, name, value) { stopifnot(class(value) == "Column" || is.null(value)) @@ -923,6 +955,7 @@ setMethod("$<-", signature(x = "DataFrame"), }) #' @rdname select +#' @name [[ setMethod("[[", signature(x = "DataFrame"), function(x, i) { if (is.numeric(i)) { @@ -933,6 +966,7 @@ setMethod("[[", signature(x = "DataFrame"), }) #' @rdname select +#' @name [ setMethod("[", signature(x = "DataFrame", i = "missing"), function(x, i, j, ...) { if (is.numeric(j)) { @@ -1008,6 +1042,7 @@ setMethod("select", #' @param ... Additional expressions #' @return A DataFrame #' @rdname selectExpr +#' @name selectExpr #' @export #' @examples #'\dontrun{ @@ -1034,6 +1069,8 @@ setMethod("selectExpr", #' @param col A Column expression. #' @return A DataFrame with the new column added. #' @rdname withColumn +#' @name withColumn +#' @aliases mutate #' @export #' @examples #'\dontrun{ @@ -1057,7 +1094,7 @@ setMethod("withColumn", #' @param col a named argument of the form name = col #' @return A new DataFrame with the new columns added. #' @rdname withColumn -#' @aliases withColumn +#' @name mutate #' @export #' @examples #'\dontrun{ @@ -1094,6 +1131,7 @@ setMethod("mutate", #' @param newCol The new column name. #' @return A DataFrame with the column name changed. #' @rdname withColumnRenamed +#' @name withColumnRenamed #' @export #' @examples #'\dontrun{ @@ -1124,6 +1162,7 @@ setMethod("withColumnRenamed", #' @param newCol A named pair of the form new_column_name = existing_column #' @return A DataFrame with the column name changed. #' @rdname withColumnRenamed +#' @name rename #' @aliases withColumnRenamed #' @export #' @examples @@ -1165,6 +1204,8 @@ setClassUnion("characterOrColumn", c("character", "Column")) #' @param ... Additional sorting fields #' @return A DataFrame where all elements are sorted. #' @rdname arrange +#' @name arrange +#' @aliases orderby #' @export #' @examples #'\dontrun{ @@ -1191,7 +1232,7 @@ setMethod("arrange", }) #' @rdname arrange -#' @aliases orderBy,DataFrame,function-method +#' @name orderby setMethod("orderBy", signature(x = "DataFrame", col = "characterOrColumn"), function(x, col) { @@ -1207,6 +1248,7 @@ setMethod("orderBy", #' or a string containing a SQL statement #' @return A DataFrame containing only the rows that meet the condition. #' @rdname filter +#' @name filter #' @export #' @examples #'\dontrun{ @@ -1228,7 +1270,7 @@ setMethod("filter", }) #' @rdname filter -#' @aliases where,DataFrame,function-method +#' @name where setMethod("where", signature(x = "DataFrame", condition = "characterOrColumn"), function(x, condition) { @@ -1247,6 +1289,7 @@ setMethod("where", #' 'inner', 'outer', 'left_outer', 'right_outer', 'semijoin'. The default joinType is "inner". #' @return A DataFrame containing the result of the join operation. #' @rdname join +#' @name join #' @export #' @examples #'\dontrun{ @@ -1279,8 +1322,9 @@ setMethod("join", dataFrame(sdf) }) -#' rdname merge -#' aliases join +#' @rdname merge +#' @name merge +#' @aliases join setMethod("merge", signature(x = "DataFrame", y = "DataFrame"), function(x, y, joinExpr = NULL, joinType = NULL, ...) { @@ -1298,6 +1342,7 @@ setMethod("merge", #' @param y A Spark DataFrame #' @return A DataFrame containing the result of the union. 
#' @rdname unionAll +#' @name unionAll #' @export #' @examples #'\dontrun{ @@ -1319,6 +1364,7 @@ setMethod("unionAll", #' @description Returns a new DataFrame containing rows of all parameters. # #' @rdname rbind +#' @name rbind #' @aliases unionAll setMethod("rbind", signature(... = "DataFrame"), @@ -1339,6 +1385,7 @@ setMethod("rbind", #' @param y A Spark DataFrame #' @return A DataFrame containing the result of the intersect. #' @rdname intersect +#' @name intersect #' @export #' @examples #'\dontrun{ @@ -1364,6 +1411,7 @@ setMethod("intersect", #' @param y A Spark DataFrame #' @return A DataFrame containing the result of the except operation. #' @rdname except +#' @name except #' @export #' @examples #'\dontrun{ @@ -1403,6 +1451,8 @@ setMethod("except", #' @param mode One of 'append', 'overwrite', 'error', 'ignore' #' #' @rdname write.df +#' @name write.df +#' @aliases saveDF #' @export #' @examples #'\dontrun{ @@ -1435,7 +1485,7 @@ setMethod("write.df", }) #' @rdname write.df -#' @aliases saveDF +#' @name saveDF #' @export setMethod("saveDF", signature(df = "DataFrame", path = "character"), @@ -1466,6 +1516,7 @@ setMethod("saveDF", #' @param mode One of 'append', 'overwrite', 'error', 'ignore' #' #' @rdname saveAsTable +#' @name saveAsTable #' @export #' @examples #'\dontrun{ @@ -1505,6 +1556,8 @@ setMethod("saveAsTable", #' @param ... Additional expressions #' @return A DataFrame #' @rdname describe +#' @name describe +#' @aliases summary #' @export #' @examples #'\dontrun{ @@ -1525,6 +1578,7 @@ setMethod("describe", }) #' @rdname describe +#' @name describe setMethod("describe", signature(x = "DataFrame"), function(x) { @@ -1538,7 +1592,7 @@ setMethod("describe", #' @description Computes statistics for numeric columns of the DataFrame #' #' @rdname summary -#' @aliases describe +#' @name summary setMethod("summary", signature(x = "DataFrame"), function(x) { @@ -1562,6 +1616,8 @@ setMethod("summary", #' @return A DataFrame #' #' @rdname nafunctions +#' @name dropna +#' @aliases na.omit #' @export #' @examples #'\dontrun{ @@ -1588,7 +1644,8 @@ setMethod("dropna", dataFrame(sdf) }) -#' @aliases dropna +#' @rdname nafunctions +#' @name na.omit #' @export setMethod("na.omit", signature(x = "DataFrame"), @@ -1615,6 +1672,7 @@ setMethod("na.omit", #' @return A DataFrame #' #' @rdname nafunctions +#' @name fillna #' @export #' @examples #'\dontrun{ @@ -1685,6 +1743,7 @@ setMethod("fillna", #' occurrences will have zero as their counts. 
#' #' @rdname statfunctions +#' @name crosstab #' @export #' @examples #' \dontrun{ diff --git a/R/pkg/R/column.R b/R/pkg/R/column.R index a1f50c383367..4805096f3f9c 100644 --- a/R/pkg/R/column.R +++ b/R/pkg/R/column.R @@ -24,10 +24,9 @@ setOldClass("jobj") #' @title S4 class that represents a DataFrame column #' @description The column class supports unary, binary operations on DataFrame columns - #' @rdname column #' -#' @param jc reference to JVM DataFrame column +#' @slot jc reference to JVM DataFrame column #' @export setClass("Column", slots = list(jc = "jobj")) @@ -46,6 +45,7 @@ col <- function(x) { } #' @rdname show +#' @name show setMethod("show", "Column", function(object) { cat("Column", callJMethod(object@jc, "toString"), "\n") @@ -122,8 +122,11 @@ createMethods() #' alias #' #' Set a new name for a column - -#' @rdname column +#' +#' @rdname alias +#' @name alias +#' @family colum_func +#' @export setMethod("alias", signature(object = "Column"), function(object, data) { @@ -138,7 +141,9 @@ setMethod("alias", #' #' An expression that returns a substring. #' -#' @rdname column +#' @rdname substr +#' @name substr +#' @family colum_func #' #' @param start starting position #' @param stop ending position @@ -152,7 +157,9 @@ setMethod("substr", signature(x = "Column"), #' #' Test if the column is between the lower bound and upper bound, inclusive. #' -#' @rdname column +#' @rdname between +#' @name between +#' @family colum_func #' #' @param bounds lower and upper bounds setMethod("between", signature(x = "Column"), @@ -167,7 +174,9 @@ setMethod("between", signature(x = "Column"), #' Casts the column to a different data type. #' -#' @rdname column +#' @rdname cast +#' @name cast +#' @family colum_func #' #' @examples \dontrun{ #' cast(df$age, "string") @@ -189,11 +198,15 @@ setMethod("cast", #' Match a column with given values. #' -#' @rdname column +#' @rdname match +#' @name %in% +#' @aliases %in% #' @return a matched values as a result of comparing with given values. -#' @examples \dontrun{ -#' filter(df, "age in (10, 30)") -#' where(df, df$age %in% c(10, 30)) +#' @export +#' @examples +#' \dontrun{ +#' filter(df, "age in (10, 30)") +#' where(df, df$age %in% c(10, 30)) #' } setMethod("%in%", signature(x = "Column"), @@ -208,7 +221,10 @@ setMethod("%in%", #' If values in the specified column are null, returns the value. #' Can be used in conjunction with `when` to specify a default value for expressions. #' -#' @rdname column +#' @rdname otherwise +#' @name otherwise +#' @family colum_func +#' @export setMethod("otherwise", signature(x = "Column", value = "ANY"), function(x, value) { diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 610a8c31223c..a829d46c1894 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -441,7 +441,7 @@ setGeneric("filter", function(x, condition) { standardGeneric("filter") }) #' @export setGeneric("group_by", function(x, ...) { standardGeneric("group_by") }) -#' @rdname DataFrame +#' @rdname groupBy #' @export setGeneric("groupBy", function(x, ...) { standardGeneric("groupBy") }) From 57b960bf3706728513f9e089455a533f0244312e Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Tue, 25 Aug 2015 08:32:20 +0100 Subject: [PATCH 0402/1824] [SPARK-6196] [BUILD] Remove MapR profiles in favor of hadoop-provided Follow up to https://github.com/apache/spark/pull/7047 pwendell mentioned that MapR should use `hadoop-provided` now, and indeed the new build script does not produce `mapr3`/`mapr4` artifacts anymore. 
Hence the action seems to be to remove the profiles, which are now not used. CC trystanleftwich Author: Sean Owen Closes #8338 from srowen/SPARK-6196. --- pom.xml | 38 -------------------------------------- 1 file changed, 38 deletions(-) diff --git a/pom.xml b/pom.xml index d5945f2546d3..0716016523ee 100644 --- a/pom.xml +++ b/pom.xml @@ -2386,44 +2386,6 @@ - - mapr3 - - 1.0.3-mapr-3.0.3 - 2.4.1-mapr-1408 - 0.98.4-mapr-1408 - 3.4.5-mapr-1406 - - - - - mapr4 - - 2.4.1-mapr-1408 - 2.4.1-mapr-1408 - 0.98.4-mapr-1408 - 3.4.5-mapr-1406 - - - - org.apache.curator - curator-recipes - ${curator.version} - - - org.apache.zookeeper - zookeeper - - - - - org.apache.zookeeper - zookeeper - 3.4.5-mapr-1406 - - - - hive-thriftserver From 1fc37581a52530bac5d555dbf14927a5780c3b75 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 25 Aug 2015 00:35:51 -0700 Subject: [PATCH 0403/1824] [SPARK-10210] [STREAMING] Filter out non-existent blocks before creating BlockRDD When write ahead log is not enabled, a recovered streaming driver still tries to run jobs using pre-failure block ids, and fails as the block do not exists in-memory any more (and cannot be recovered as receiver WAL is not enabled). This occurs because the driver-side WAL of ReceivedBlockTracker is recovers that past block information, and ReceiveInputDStream creates BlockRDDs even if those blocks do not exist. The solution in this PR is to filter out block ids that do not exist before creating the BlockRDD. In addition, it adds unit tests to verify other logic in ReceiverInputDStream. Author: Tathagata Das Closes #8405 from tdas/SPARK-10210. --- .../dstream/ReceiverInputDStream.scala | 10 +- .../rdd/WriteAheadLogBackedBlockRDD.scala | 2 +- .../streaming/ReceiverInputDStreamSuite.scala | 156 ++++++++++++++++++ 3 files changed, 166 insertions(+), 2 deletions(-) create mode 100644 streaming/src/test/scala/org/apache/spark/streaming/ReceiverInputDStreamSuite.scala diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala index a15800917c6f..6c139f32da31 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala @@ -116,7 +116,15 @@ abstract class ReceiverInputDStream[T: ClassTag](@transient ssc_ : StreamingCont logWarning("Some blocks have Write Ahead Log information; this is unexpected") } } - new BlockRDD[T](ssc.sc, blockIds) + val validBlockIds = blockIds.filter { id => + ssc.sparkContext.env.blockManager.master.contains(id) + } + if (validBlockIds.size != blockIds.size) { + logWarning("Some blocks could not be recovered as they were not found in memory. 
" + + "To prevent such data loss, enabled Write Ahead Log (see programming guide " + + "for more details.") + } + new BlockRDD[T](ssc.sc, validBlockIds) } } else { // If no block is ready now, creating WriteAheadLogBackedBlockRDD or BlockRDD diff --git a/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala b/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala index 620b8a36a2ba..e081ffe46f50 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala @@ -75,7 +75,7 @@ private[streaming] class WriteAheadLogBackedBlockRDD[T: ClassTag]( @transient sc: SparkContext, @transient blockIds: Array[BlockId], - @transient walRecordHandles: Array[WriteAheadLogRecordHandle], + @transient val walRecordHandles: Array[WriteAheadLogRecordHandle], @transient isBlockIdValid: Array[Boolean] = Array.empty, storeInBlockManager: Boolean = false, storageLevel: StorageLevel = StorageLevel.MEMORY_ONLY_SER) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceiverInputDStreamSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverInputDStreamSuite.scala new file mode 100644 index 000000000000..6d388d9624d9 --- /dev/null +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverInputDStreamSuite.scala @@ -0,0 +1,156 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.streaming + +import scala.util.Random + +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.rdd.BlockRDD +import org.apache.spark.storage.{StorageLevel, StreamBlockId} +import org.apache.spark.streaming.dstream.ReceiverInputDStream +import org.apache.spark.streaming.rdd.WriteAheadLogBackedBlockRDD +import org.apache.spark.streaming.receiver.{BlockManagerBasedStoreResult, Receiver, WriteAheadLogBasedStoreResult} +import org.apache.spark.streaming.scheduler.ReceivedBlockInfo +import org.apache.spark.streaming.util.{WriteAheadLogRecordHandle, WriteAheadLogUtils} +import org.apache.spark.{SparkConf, SparkEnv} + +class ReceiverInputDStreamSuite extends TestSuiteBase with BeforeAndAfterAll { + + override def afterAll(): Unit = { + StreamingContext.getActive().map { _.stop() } + } + + testWithoutWAL("createBlockRDD creates empty BlockRDD when no block info") { receiverStream => + val rdd = receiverStream.createBlockRDD(Time(0), Seq.empty) + assert(rdd.isInstanceOf[BlockRDD[_]]) + assert(!rdd.isInstanceOf[WriteAheadLogBackedBlockRDD[_]]) + assert(rdd.isEmpty()) + } + + testWithoutWAL("createBlockRDD creates correct BlockRDD with block info") { receiverStream => + val blockInfos = Seq.fill(5) { createBlockInfo(withWALInfo = false) } + val blockIds = blockInfos.map(_.blockId) + + // Verify that there are some blocks that are present, and some that are not + require(blockIds.forall(blockId => SparkEnv.get.blockManager.master.contains(blockId))) + + val rdd = receiverStream.createBlockRDD(Time(0), blockInfos) + assert(rdd.isInstanceOf[BlockRDD[_]]) + assert(!rdd.isInstanceOf[WriteAheadLogBackedBlockRDD[_]]) + val blockRDD = rdd.asInstanceOf[BlockRDD[_]] + assert(blockRDD.blockIds.toSeq === blockIds) + } + + testWithoutWAL("createBlockRDD filters non-existent blocks before creating BlockRDD") { + receiverStream => + val presentBlockInfos = Seq.fill(2)(createBlockInfo(withWALInfo = false, createBlock = true)) + val absentBlockInfos = Seq.fill(3)(createBlockInfo(withWALInfo = false, createBlock = false)) + val blockInfos = presentBlockInfos ++ absentBlockInfos + val blockIds = blockInfos.map(_.blockId) + + // Verify that there are some blocks that are present, and some that are not + require(blockIds.exists(blockId => SparkEnv.get.blockManager.master.contains(blockId))) + require(blockIds.exists(blockId => !SparkEnv.get.blockManager.master.contains(blockId))) + + val rdd = receiverStream.createBlockRDD(Time(0), blockInfos) + assert(rdd.isInstanceOf[BlockRDD[_]]) + val blockRDD = rdd.asInstanceOf[BlockRDD[_]] + assert(blockRDD.blockIds.toSeq === presentBlockInfos.map { _.blockId}) + } + + testWithWAL("createBlockRDD creates empty WALBackedBlockRDD when no block info") { + receiverStream => + val rdd = receiverStream.createBlockRDD(Time(0), Seq.empty) + assert(rdd.isInstanceOf[WriteAheadLogBackedBlockRDD[_]]) + assert(rdd.isEmpty()) + } + + testWithWAL( + "createBlockRDD creates correct WALBackedBlockRDD with all block info having WAL info") { + receiverStream => + val blockInfos = Seq.fill(5) { createBlockInfo(withWALInfo = true) } + val blockIds = blockInfos.map(_.blockId) + val rdd = receiverStream.createBlockRDD(Time(0), blockInfos) + assert(rdd.isInstanceOf[WriteAheadLogBackedBlockRDD[_]]) + val blockRDD = rdd.asInstanceOf[WriteAheadLogBackedBlockRDD[_]] + assert(blockRDD.blockIds.toSeq === blockIds) + assert(blockRDD.walRecordHandles.toSeq === blockInfos.map { _.walRecordHandleOption.get }) + } + + testWithWAL("createBlockRDD creates BlockRDD when some 
block info dont have WAL info") { + receiverStream => + val blockInfos1 = Seq.fill(2) { createBlockInfo(withWALInfo = true) } + val blockInfos2 = Seq.fill(3) { createBlockInfo(withWALInfo = false) } + val blockInfos = blockInfos1 ++ blockInfos2 + val blockIds = blockInfos.map(_.blockId) + val rdd = receiverStream.createBlockRDD(Time(0), blockInfos) + assert(rdd.isInstanceOf[BlockRDD[_]]) + val blockRDD = rdd.asInstanceOf[BlockRDD[_]] + assert(blockRDD.blockIds.toSeq === blockIds) + } + + + private def testWithoutWAL(msg: String)(body: ReceiverInputDStream[_] => Unit): Unit = { + test(s"Without WAL enabled: $msg") { + runTest(enableWAL = false, body) + } + } + + private def testWithWAL(msg: String)(body: ReceiverInputDStream[_] => Unit): Unit = { + test(s"With WAL enabled: $msg") { + runTest(enableWAL = true, body) + } + } + + private def runTest(enableWAL: Boolean, body: ReceiverInputDStream[_] => Unit): Unit = { + val conf = new SparkConf() + conf.setMaster("local[4]").setAppName("ReceiverInputDStreamSuite") + conf.set(WriteAheadLogUtils.RECEIVER_WAL_ENABLE_CONF_KEY, enableWAL.toString) + require(WriteAheadLogUtils.enableReceiverLog(conf) === enableWAL) + val ssc = new StreamingContext(conf, Seconds(1)) + val receiverStream = new ReceiverInputDStream[Int](ssc) { + override def getReceiver(): Receiver[Int] = null + } + withStreamingContext(ssc) { ssc => + body(receiverStream) + } + } + + /** + * Create a block info for input to the ReceiverInputDStream.createBlockRDD + * @param withWALInfo Create block with WAL info in it + * @param createBlock Actually create the block in the BlockManager + * @return + */ + private def createBlockInfo( + withWALInfo: Boolean, + createBlock: Boolean = true): ReceivedBlockInfo = { + val blockId = new StreamBlockId(0, Random.nextLong()) + if (createBlock) { + SparkEnv.get.blockManager.putSingle(blockId, 1, StorageLevel.MEMORY_ONLY, tellMaster = true) + require(SparkEnv.get.blockManager.master.contains(blockId)) + } + val storeResult = if (withWALInfo) { + new WriteAheadLogBasedStoreResult(blockId, None, new WriteAheadLogRecordHandle { }) + } else { + new BlockManagerBasedStoreResult(blockId, None) + } + new ReceivedBlockInfo(0, None, None, storeResult) + } +} From 2f493f7e3924b769160a16f73cccbebf21973b91 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 25 Aug 2015 16:00:44 +0800 Subject: [PATCH 0404/1824] [SPARK-10177] [SQL] fix reading Timestamp in parquet from Hive We misunderstood the Julian days and nanoseconds of the day in parquet (as TimestampType) from Hive/Impala, they are overlapped, so can't be added together directly. In order to avoid the confusing rounding when do the converting, we use `2440588` as the Julian Day of epoch of unix timestamp (which should be 2440587.5). Author: Davies Liu Author: Cheng Lian Closes #8400 from davies/timestamp_parquet. 
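To make the day/nanosecond bookkeeping concrete: Julian days conventionally start at noon, so the Unix epoch (1970-01-01 00:00:00 UTC) is Julian day 2440587.5, while the Parquet timestamps written by Hive/Impala carry a whole Julian day plus nanoseconds counted from midnight of that day. Rounding the epoch day up to 2440588 keeps the two fields aligned without a half-day correction. Below is a minimal sketch of the forward conversion under that convention (illustrative only, for non-negative timestamps; the real logic is `DateTimeUtils.toJulianDay` in the diff that follows):

```scala
// Constants as in the patch: the epoch day is rounded up from 2440587.5 to 2440588
// so that the nanosecond field measures time since midnight, matching Hive/Impala.
val JULIAN_DAY_OF_EPOCH = 2440588L
val SECONDS_PER_DAY     = 60 * 60 * 24L
val MICROS_PER_SECOND   = 1000L * 1000L

// Microseconds since the Unix epoch -> (Julian day, nanoseconds within that day).
def toJulian(us: Long): (Long, Long) = {
  val seconds = us / MICROS_PER_SECOND
  val day     = seconds / SECONDS_PER_DAY + JULIAN_DAY_OF_EPOCH
  val nanos   = (seconds % SECONDS_PER_DAY) * 1000L * 1000L * 1000L +
                (us % MICROS_PER_SECOND) * 1000L
  (day, nanos)
}

// 2015-06-11 10:10:10.1 UTC is 1434017410100000L microseconds since the epoch:
// it falls on Julian day 2457185, 10h 10m 10.1s (36,610,100,000,000 ns) into the day.
val (day, nanos) = toJulian(1434017410100000L)
```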
--- .../spark/sql/catalyst/util/DateTimeUtils.scala | 7 ++++--- .../sql/catalyst/util/DateTimeUtilsSuite.scala | 13 +++++++++---- .../sql/hive/ParquetHiveCompatibilitySuite.scala | 2 +- 3 files changed, 14 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index 672620460c3c..d652fce3fd9b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -37,7 +37,8 @@ object DateTimeUtils { type SQLTimestamp = Long // see http://stackoverflow.com/questions/466321/convert-unix-timestamp-to-julian - final val JULIAN_DAY_OF_EPOCH = 2440587 // and .5 + // it's 2440587.5, rounding up to compatible with Hive + final val JULIAN_DAY_OF_EPOCH = 2440588 final val SECONDS_PER_DAY = 60 * 60 * 24L final val MICROS_PER_SECOND = 1000L * 1000L final val NANOS_PER_SECOND = MICROS_PER_SECOND * 1000L @@ -183,7 +184,7 @@ object DateTimeUtils { */ def fromJulianDay(day: Int, nanoseconds: Long): SQLTimestamp = { // use Long to avoid rounding errors - val seconds = (day - JULIAN_DAY_OF_EPOCH).toLong * SECONDS_PER_DAY - SECONDS_PER_DAY / 2 + val seconds = (day - JULIAN_DAY_OF_EPOCH).toLong * SECONDS_PER_DAY seconds * MICROS_PER_SECOND + nanoseconds / 1000L } @@ -191,7 +192,7 @@ object DateTimeUtils { * Returns Julian day and nanoseconds in a day from the number of microseconds */ def toJulianDay(us: SQLTimestamp): (Int, Long) = { - val seconds = us / MICROS_PER_SECOND + SECONDS_PER_DAY / 2 + val seconds = us / MICROS_PER_SECOND val day = seconds / SECONDS_PER_DAY + JULIAN_DAY_OF_EPOCH val secondsInDay = seconds % SECONDS_PER_DAY val nanos = (us % MICROS_PER_SECOND) * 1000L diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala index d18fa4df1335..1596bb79fa94 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala @@ -49,13 +49,18 @@ class DateTimeUtilsSuite extends SparkFunSuite { test("us and julian day") { val (d, ns) = toJulianDay(0) assert(d === JULIAN_DAY_OF_EPOCH) - assert(ns === SECONDS_PER_DAY / 2 * NANOS_PER_SECOND) + assert(ns === 0) assert(fromJulianDay(d, ns) == 0L) - val t = new Timestamp(61394778610000L) // (2015, 6, 11, 10, 10, 10, 100) + val t = Timestamp.valueOf("2015-06-11 10:10:10.100") val (d1, ns1) = toJulianDay(fromJavaTimestamp(t)) - val t2 = toJavaTimestamp(fromJulianDay(d1, ns1)) - assert(t.equals(t2)) + val t1 = toJavaTimestamp(fromJulianDay(d1, ns1)) + assert(t.equals(t1)) + + val t2 = Timestamp.valueOf("2015-06-11 20:10:10.100") + val (d2, ns2) = toJulianDay(fromJavaTimestamp(t2)) + val t22 = toJavaTimestamp(fromJulianDay(d2, ns2)) + assert(t2.equals(t22)) } test("SPARK-6785: java date conversion before and after epoch") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala index bc30180cf091..91d7a48208e8 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala @@ -113,7 
+113,7 @@ class ParquetHiveCompatibilitySuite extends ParquetCompatibilityTest with Before "BOOLEAN", "TINYINT", "SMALLINT", "INT", "BIGINT", "FLOAT", "DOUBLE", "STRING") } - ignore("SPARK-10177 timestamp") { + test("SPARK-10177 timestamp") { testParquetHiveCompatibility(Row(Timestamp.valueOf("2015-08-24 00:31:00")), "TIMESTAMP") } From 7bc9a8c6249300ded31ea931c463d0a8f798e193 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 25 Aug 2015 01:06:36 -0700 Subject: [PATCH 0405/1824] [SPARK-10195] [SQL] Data sources Filter should not expose internal types Spark SQL's data sources API exposes Catalyst's internal types through its Filter interfaces. This is a problem because types like UTF8String are not stable developer APIs and should not be exposed to third-parties. This issue caused incompatibilities when upgrading our `spark-redshift` library to work against Spark 1.5.0. To avoid these issues in the future we should only expose public types through these Filter objects. This patch accomplishes this by using CatalystTypeConverters to add the appropriate conversions. Author: Josh Rosen Closes #8403 from JoshRosen/datasources-internal-vs-external-types. --- .../datasources/DataSourceStrategy.scala | 67 ++++++++++--------- .../execution/datasources/jdbc/JDBCRDD.scala | 2 +- .../datasources/parquet/ParquetFilters.scala | 19 +++--- .../spark/sql/sources/FilteredScanSuite.scala | 7 ++ 4 files changed, 54 insertions(+), 41 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 2a4c40db8bb6..6c1ef6a6df88 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -20,7 +20,8 @@ package org.apache.spark.sql.execution.datasources import org.apache.spark.{Logging, TaskContext} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.rdd.{MapPartitionsRDD, RDD, UnionRDD} -import org.apache.spark.sql.catalyst.{InternalRow, expressions} +import org.apache.spark.sql.catalyst.CatalystTypeConverters.convertToScala +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, expressions} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical @@ -344,45 +345,47 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { */ protected[sql] def selectFilters(filters: Seq[Expression]) = { def translate(predicate: Expression): Option[Filter] = predicate match { - case expressions.EqualTo(a: Attribute, Literal(v, _)) => - Some(sources.EqualTo(a.name, v)) - case expressions.EqualTo(Literal(v, _), a: Attribute) => - Some(sources.EqualTo(a.name, v)) - - case expressions.EqualNullSafe(a: Attribute, Literal(v, _)) => - Some(sources.EqualNullSafe(a.name, v)) - case expressions.EqualNullSafe(Literal(v, _), a: Attribute) => - Some(sources.EqualNullSafe(a.name, v)) - - case expressions.GreaterThan(a: Attribute, Literal(v, _)) => - Some(sources.GreaterThan(a.name, v)) - case expressions.GreaterThan(Literal(v, _), a: Attribute) => - Some(sources.LessThan(a.name, v)) - - case expressions.LessThan(a: Attribute, Literal(v, _)) => - Some(sources.LessThan(a.name, v)) - case expressions.LessThan(Literal(v, _), a: Attribute) => - Some(sources.GreaterThan(a.name, v)) - - case 
expressions.GreaterThanOrEqual(a: Attribute, Literal(v, _)) => - Some(sources.GreaterThanOrEqual(a.name, v)) - case expressions.GreaterThanOrEqual(Literal(v, _), a: Attribute) => - Some(sources.LessThanOrEqual(a.name, v)) - - case expressions.LessThanOrEqual(a: Attribute, Literal(v, _)) => - Some(sources.LessThanOrEqual(a.name, v)) - case expressions.LessThanOrEqual(Literal(v, _), a: Attribute) => - Some(sources.GreaterThanOrEqual(a.name, v)) + case expressions.EqualTo(a: Attribute, Literal(v, t)) => + Some(sources.EqualTo(a.name, convertToScala(v, t))) + case expressions.EqualTo(Literal(v, t), a: Attribute) => + Some(sources.EqualTo(a.name, convertToScala(v, t))) + + case expressions.EqualNullSafe(a: Attribute, Literal(v, t)) => + Some(sources.EqualNullSafe(a.name, convertToScala(v, t))) + case expressions.EqualNullSafe(Literal(v, t), a: Attribute) => + Some(sources.EqualNullSafe(a.name, convertToScala(v, t))) + + case expressions.GreaterThan(a: Attribute, Literal(v, t)) => + Some(sources.GreaterThan(a.name, convertToScala(v, t))) + case expressions.GreaterThan(Literal(v, t), a: Attribute) => + Some(sources.LessThan(a.name, convertToScala(v, t))) + + case expressions.LessThan(a: Attribute, Literal(v, t)) => + Some(sources.LessThan(a.name, convertToScala(v, t))) + case expressions.LessThan(Literal(v, t), a: Attribute) => + Some(sources.GreaterThan(a.name, convertToScala(v, t))) + + case expressions.GreaterThanOrEqual(a: Attribute, Literal(v, t)) => + Some(sources.GreaterThanOrEqual(a.name, convertToScala(v, t))) + case expressions.GreaterThanOrEqual(Literal(v, t), a: Attribute) => + Some(sources.LessThanOrEqual(a.name, convertToScala(v, t))) + + case expressions.LessThanOrEqual(a: Attribute, Literal(v, t)) => + Some(sources.LessThanOrEqual(a.name, convertToScala(v, t))) + case expressions.LessThanOrEqual(Literal(v, t), a: Attribute) => + Some(sources.GreaterThanOrEqual(a.name, convertToScala(v, t))) case expressions.InSet(a: Attribute, set) => - Some(sources.In(a.name, set.toArray)) + val toScala = CatalystTypeConverters.createToScalaConverter(a.dataType) + Some(sources.In(a.name, set.toArray.map(toScala))) // Because we only convert In to InSet in Optimizer when there are more than certain // items. So it is possible we still get an In expression here that needs to be pushed // down. case expressions.In(a: Attribute, list) if !list.exists(!_.isInstanceOf[Literal]) => val hSet = list.map(e => e.eval(EmptyRow)) - Some(sources.In(a.name, hSet.toArray)) + val toScala = CatalystTypeConverters.createToScalaConverter(a.dataType) + Some(sources.In(a.name, hSet.toArray.map(toScala))) case expressions.IsNull(a: Attribute) => Some(sources.IsNull(a.name)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala index e537d631f455..730d88b024cb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala @@ -262,7 +262,7 @@ private[sql] class JDBCRDD( * Converts value to SQL expression. 
*/ private def compileValue(value: Any): Any = value match { - case stringValue: UTF8String => s"'${escapeSql(stringValue.toString)}'" + case stringValue: String => s"'${escapeSql(stringValue)}'" case _ => value } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala index c74c8388632f..c6b3fe7900da 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala @@ -32,7 +32,6 @@ import org.apache.spark.SparkEnv import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.sources import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String private[sql] object ParquetFilters { val PARQUET_FILTER_DATA = "org.apache.spark.sql.parquet.row.filter" @@ -65,7 +64,7 @@ private[sql] object ParquetFilters { case StringType => (n: String, v: Any) => FilterApi.eq( binaryColumn(n), - Option(v).map(s => Binary.fromByteArray(s.asInstanceOf[UTF8String].getBytes)).orNull) + Option(v).map(s => Binary.fromByteArray(s.asInstanceOf[String].getBytes("utf-8"))).orNull) case BinaryType => (n: String, v: Any) => FilterApi.eq( binaryColumn(n), @@ -86,7 +85,7 @@ private[sql] object ParquetFilters { case StringType => (n: String, v: Any) => FilterApi.notEq( binaryColumn(n), - Option(v).map(s => Binary.fromByteArray(s.asInstanceOf[UTF8String].getBytes)).orNull) + Option(v).map(s => Binary.fromByteArray(s.asInstanceOf[String].getBytes("utf-8"))).orNull) case BinaryType => (n: String, v: Any) => FilterApi.notEq( binaryColumn(n), @@ -104,7 +103,8 @@ private[sql] object ParquetFilters { (n: String, v: Any) => FilterApi.lt(doubleColumn(n), v.asInstanceOf[java.lang.Double]) case StringType => (n: String, v: Any) => - FilterApi.lt(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[UTF8String].getBytes)) + FilterApi.lt(binaryColumn(n), + Binary.fromByteArray(v.asInstanceOf[String].getBytes("utf-8"))) case BinaryType => (n: String, v: Any) => FilterApi.lt(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[Array[Byte]])) @@ -121,7 +121,8 @@ private[sql] object ParquetFilters { (n: String, v: Any) => FilterApi.ltEq(doubleColumn(n), v.asInstanceOf[java.lang.Double]) case StringType => (n: String, v: Any) => - FilterApi.ltEq(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[UTF8String].getBytes)) + FilterApi.ltEq(binaryColumn(n), + Binary.fromByteArray(v.asInstanceOf[String].getBytes("utf-8"))) case BinaryType => (n: String, v: Any) => FilterApi.ltEq(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[Array[Byte]])) @@ -138,7 +139,8 @@ private[sql] object ParquetFilters { (n: String, v: Any) => FilterApi.gt(doubleColumn(n), v.asInstanceOf[java.lang.Double]) case StringType => (n: String, v: Any) => - FilterApi.gt(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[UTF8String].getBytes)) + FilterApi.gt(binaryColumn(n), + Binary.fromByteArray(v.asInstanceOf[String].getBytes("utf-8"))) case BinaryType => (n: String, v: Any) => FilterApi.gt(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[Array[Byte]])) @@ -155,7 +157,8 @@ private[sql] object ParquetFilters { (n: String, v: Any) => FilterApi.gtEq(doubleColumn(n), v.asInstanceOf[java.lang.Double]) case StringType => (n: String, v: Any) => - FilterApi.gtEq(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[UTF8String].getBytes)) + 
FilterApi.gtEq(binaryColumn(n), + Binary.fromByteArray(v.asInstanceOf[String].getBytes("utf-8"))) case BinaryType => (n: String, v: Any) => FilterApi.gtEq(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[Array[Byte]])) @@ -177,7 +180,7 @@ private[sql] object ParquetFilters { case StringType => (n: String, v: Set[Any]) => FilterApi.userDefined(binaryColumn(n), - SetInFilter(v.map(e => Binary.fromByteArray(e.asInstanceOf[UTF8String].getBytes)))) + SetInFilter(v.map(s => Binary.fromByteArray(s.asInstanceOf[String].getBytes("utf-8"))))) case BinaryType => (n: String, v: Set[Any]) => FilterApi.userDefined(binaryColumn(n), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala index c81c3d398280..68ce37c00077 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.sources import scala.language.existentials import org.apache.spark.rdd.RDD +import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.sql._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ @@ -78,6 +79,9 @@ case class SimpleFilteredScan(from: Int, to: Int)(@transient val sqlContext: SQL case StringStartsWith("c", v) => _.startsWith(v) case StringEndsWith("c", v) => _.endsWith(v) case StringContains("c", v) => _.contains(v) + case EqualTo("c", v: String) => _.equals(v) + case EqualTo("c", v: UTF8String) => sys.error("UTF8String should not appear in filters") + case In("c", values) => (s: String) => values.map(_.asInstanceOf[String]).toSet.contains(s) case _ => (c: String) => true } @@ -237,6 +241,9 @@ class FilteredScanSuite extends DataSourceTest with SharedSQLContext { testPushDown("SELECT a, b, c FROM oneToTenFiltered WHERE c like '%eE%'", 1) testPushDown("SELECT a, b, c FROM oneToTenFiltered WHERE c like '%Ee%'", 0) + testPushDown("SELECT c FROM oneToTenFiltered WHERE c = 'aaaaaAAAAA'", 1) + testPushDown("SELECT c FROM oneToTenFiltered WHERE c IN ('aaaaaAAAAA', 'foo')", 1) + def testPushDown(sqlString: String, expectedCount: Int): Unit = { test(s"PushDown Returns $expectedCount: $sqlString") { val queryExecution = sql(sqlString).queryExecution From 0e6368ffaec1965d0c7f89420e04a974675c7f6e Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Tue, 25 Aug 2015 16:19:34 +0800 Subject: [PATCH 0406/1824] [SPARK-10197] [SQL] Add null check in wrapperFor (inside HiveInspectors). https://issues.apache.org/jira/browse/SPARK-10197 Author: Yin Huai Closes #8407 from yhuai/ORCSPARK-10197. 
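For context, `wrapperFor` returns functions that convert Spark's internal values (e.g. `Decimal`, days-since-epoch ints, microsecond longs) into the Hive writable types the ORC serializer expects; the previous converters dereferenced their input unconditionally, so a null column value could fail with a NullPointerException at write time. The diff below wraps each conversion in a null check. A stand-alone sketch of the same pattern for the decimal case only (the `toHiveDecimal` name is invented):

```scala
import org.apache.hadoop.hive.common.`type`.HiveDecimal
import org.apache.spark.sql.types.Decimal

// Illustrative stand-in for one branch of wrapperFor: guard the conversion so a
// null decimal is passed through as null instead of NPE-ing in toJavaBigDecimal.
val toHiveDecimal: Any => Any = (o: Any) =>
  if (o != null) HiveDecimal.create(o.asInstanceOf[Decimal].toJavaBigDecimal) else null
```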
--- .../spark/sql/hive/HiveInspectors.scala | 29 +++++++++++++++---- .../spark/sql/hive/orc/OrcSourceSuite.scala | 29 +++++++++++++++++++ 2 files changed, 53 insertions(+), 5 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index 9824dad23959..64fffdbf9b02 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -370,17 +370,36 @@ private[hive] trait HiveInspectors { protected def wrapperFor(oi: ObjectInspector, dataType: DataType): Any => Any = oi match { case _: JavaHiveVarcharObjectInspector => (o: Any) => - val s = o.asInstanceOf[UTF8String].toString - new HiveVarchar(s, s.size) + if (o != null) { + val s = o.asInstanceOf[UTF8String].toString + new HiveVarchar(s, s.size) + } else { + null + } case _: JavaHiveDecimalObjectInspector => - (o: Any) => HiveDecimal.create(o.asInstanceOf[Decimal].toJavaBigDecimal) + (o: Any) => + if (o != null) { + HiveDecimal.create(o.asInstanceOf[Decimal].toJavaBigDecimal) + } else { + null + } case _: JavaDateObjectInspector => - (o: Any) => DateTimeUtils.toJavaDate(o.asInstanceOf[Int]) + (o: Any) => + if (o != null) { + DateTimeUtils.toJavaDate(o.asInstanceOf[Int]) + } else { + null + } case _: JavaTimestampObjectInspector => - (o: Any) => DateTimeUtils.toJavaTimestamp(o.asInstanceOf[Long]) + (o: Any) => + if (o != null) { + DateTimeUtils.toJavaTimestamp(o.asInstanceOf[Long]) + } else { + null + } case soi: StandardStructObjectInspector => val schema = dataType.asInstanceOf[StructType] diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala index 82e08caf4645..80c38084f293 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala @@ -121,6 +121,35 @@ abstract class OrcSuite extends QueryTest with BeforeAndAfterAll { sql("SELECT * FROM normal_orc_as_source"), (6 to 10).map(i => Row(i, s"part-$i"))) } + + test("write null values") { + sql("DROP TABLE IF EXISTS orcNullValues") + + val df = sql( + """ + |SELECT + | CAST(null as TINYINT), + | CAST(null as SMALLINT), + | CAST(null as INT), + | CAST(null as BIGINT), + | CAST(null as FLOAT), + | CAST(null as DOUBLE), + | CAST(null as DECIMAL(7,2)), + | CAST(null as TIMESTAMP), + | CAST(null as DATE), + | CAST(null as STRING), + | CAST(null as VARCHAR(10)) + |FROM orc_temp_table limit 1 + """.stripMargin) + + df.write.format("orc").saveAsTable("orcNullValues") + + checkAnswer( + sql("SELECT * FROM orcNullValues"), + Row.fromSeq(Seq.fill(11)(null))) + + sql("DROP TABLE IF EXISTS orcNullValues") + } } class OrcSourceSuite extends OrcSuite { From 5c14890159a5711072bf395f662b2433a389edf9 Mon Sep 17 00:00:00 2001 From: "Zhang, Liye" Date: Tue, 25 Aug 2015 11:48:55 +0100 Subject: [PATCH 0407/1824] [DOC] add missing parameters in SparkContext.scala for scala doc Author: Zhang, Liye Closes #8412 from liyezhang556520/minorDoc. 
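For context, the Scaladoc convention being filled in looks like the standalone sketch below; the signature is simplified and the body is a stub, while the parameter descriptions are taken from the patch itself.

    // Illustration of the @param/@return tags being added; not the real SparkContext API.
    object ScaladocStyleSketch {
      /**
       * Reads whole files from a directory as (path, content) pairs.
       *
       * @param path Directory to the input data files, the path can be comma separated
       *             paths as the list of inputs.
       * @param minPartitions A suggestion value of the minimal splitting number for input data.
       * @return The file paths paired with their full text contents.
       */
      def wholeTextFiles(path: String, minPartitions: Int = 2): Seq[(String, String)] =
        Seq.empty  // stub body; the doc comment above is the point of the sketch
    }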
--- .../scala/org/apache/spark/SparkContext.scala | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 1ddaca8a5ba8..9849aff85d72 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -114,6 +114,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * :: DeveloperApi :: * Alternative constructor for setting preferred locations where Spark will create executors. * + * @param config a [[org.apache.spark.SparkConf]] object specifying other Spark parameters * @param preferredNodeLocationData used in YARN mode to select nodes to launch containers on. * Can be generated using [[org.apache.spark.scheduler.InputFormatInfo.computePreferredLocations]] * from a list of input files or InputFormats for the application. @@ -145,6 +146,9 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * @param jars Collection of JARs to send to the cluster. These can be paths on the local file * system or HDFS, HTTP, HTTPS, or FTP URLs. * @param environment Environment variables to set on worker nodes. + * @param preferredNodeLocationData used in YARN mode to select nodes to launch containers on. + * Can be generated using [[org.apache.spark.scheduler.InputFormatInfo.computePreferredLocations]] + * from a list of input files or InputFormats for the application. */ def this( master: String, @@ -841,6 +845,9 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * @note Small files are preferred, large file is also allowable, but may cause bad performance. * @note On some filesystems, `.../path/*` can be a more efficient way to read all files * in a directory rather than `.../path/` or `.../path` + * + * @param path Directory to the input data files, the path can be comma separated paths as the + * list of inputs. * @param minPartitions A suggestion value of the minimal splitting number for input data. */ def wholeTextFiles( @@ -889,6 +896,9 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * @note Small files are preferred; very large files may cause bad performance. * @note On some filesystems, `.../path/*` can be a more efficient way to read all files * in a directory rather than `.../path/` or `.../path` + * + * @param path Directory to the input data files, the path can be comma separated paths as the + * list of inputs. * @param minPartitions A suggestion value of the minimal splitting number for input data. */ @Experimental @@ -918,8 +928,11 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * '''Note:''' We ensure that the byte array for each record in the resulting RDD * has the provided record length. * - * @param path Directory to the input data files + * @param path Directory to the input data files, the path can be comma separated paths as the + * list of inputs. * @param recordLength The length at which to split the records + * @param conf Configuration for setting up the dataset. + * * @return An RDD of data with values, represented as byte arrays */ @Experimental From 7f1e507bf7e82bff323c5dec3c1ee044687c4173 Mon Sep 17 00:00:00 2001 From: ehnalis Date: Tue, 25 Aug 2015 12:30:06 +0100 Subject: [PATCH 0408/1824] Fixed a typo in DAGScheduler. Author: ehnalis Closes #8308 from ehnalis/master. 
--- .../apache/spark/scheduler/DAGScheduler.scala | 27 ++++++++++++++----- 1 file changed, 20 insertions(+), 7 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 684db6646765..daf9b0f95273 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -152,17 +152,24 @@ class DAGScheduler( // may lead to more delay in scheduling if those locations are busy. private[scheduler] val REDUCER_PREF_LOCS_FRACTION = 0.2 - // Called by TaskScheduler to report task's starting. + /** + * Called by the TaskSetManager to report task's starting. + */ def taskStarted(task: Task[_], taskInfo: TaskInfo) { eventProcessLoop.post(BeginEvent(task, taskInfo)) } - // Called to report that a task has completed and results are being fetched remotely. + /** + * Called by the TaskSetManager to report that a task has completed + * and results are being fetched remotely. + */ def taskGettingResult(taskInfo: TaskInfo) { eventProcessLoop.post(GettingResultEvent(taskInfo)) } - // Called by TaskScheduler to report task completions or failures. + /** + * Called by the TaskSetManager to report task completions or failures. + */ def taskEnded( task: Task[_], reason: TaskEndReason, @@ -188,18 +195,24 @@ class DAGScheduler( BlockManagerHeartbeat(blockManagerId), new RpcTimeout(600 seconds, "BlockManagerHeartbeat")) } - // Called by TaskScheduler when an executor fails. + /** + * Called by TaskScheduler implementation when an executor fails. + */ def executorLost(execId: String): Unit = { eventProcessLoop.post(ExecutorLost(execId)) } - // Called by TaskScheduler when a host is added + /** + * Called by TaskScheduler implementation when a host is added. + */ def executorAdded(execId: String, host: String): Unit = { eventProcessLoop.post(ExecutorAdded(execId, host)) } - // Called by TaskScheduler to cancel an entire TaskSet due to either repeated failures or - // cancellation of the job itself. + /** + * Called by the TaskSetManager to cancel an entire TaskSet due to either repeated failures or + * cancellation of the job itself. + */ def taskSetFailed(taskSet: TaskSet, reason: String, exception: Option[Throwable]): Unit = { eventProcessLoop.post(TaskSetFailed(taskSet, reason, exception)) } From 69c9c177160e32a2fbc9b36ecc52156077fca6fc Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Tue, 25 Aug 2015 12:33:13 +0100 Subject: [PATCH 0409/1824] [SPARK-9613] [CORE] Ban use of JavaConversions and migrate all existing uses to JavaConverters Replace `JavaConversions` implicits with `JavaConverters` Most occurrences I've seen so far are necessary conversions; a few have been avoidable. None are in critical code as far as I see, yet. Author: Sean Owen Closes #8033 from srowen/SPARK-9613. 
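As a standalone sketch of the boundary this migration makes explicit (the collections below are arbitrary examples, not Spark code): with scala.collection.JavaConverters a conversion happens only where .asScala or .asJava is written, whereas the banned JavaConversions implicits convert silently at any Java/Scala call site.

    // Explicit Java <-> Scala collection conversions with JavaConverters; not Spark code.
    import java.util.{ArrayList => JArrayList}
    import scala.collection.JavaConverters._

    object ConvertersSketch {
      def main(args: Array[String]): Unit = {
        val jList = new JArrayList[String]()
        jList.add("spark.master")
        jList.add("spark.app.name")

        // Java -> Scala: the conversion is a visible .asScala call, not an implicit.
        val keys: Seq[String] = jList.asScala
        keys.foreach(println)

        // Scala -> Java: the same explicitness in the other direction.
        val conf = Map("spark.master" -> "local[2]")
        val jMap: java.util.Map[String, String] = conf.asJava
        println(jMap.get("spark.master"))  // prints local[2]
      }
    }

The wrappers returned by asScala and asJava are views over the same underlying collection, so the explicit form does not copy data; the diffstat also shows a scalastyle-config.xml change, presumably so that new JavaConversions imports fail the style check.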
--- .../shuffle/unsafe/UnsafeShuffleWriter.java | 4 +- .../org/apache/spark/MapOutputTracker.scala | 4 +- .../scala/org/apache/spark/SSLOptions.scala | 11 +- .../scala/org/apache/spark/SparkContext.scala | 4 +- .../scala/org/apache/spark/TestUtils.scala | 9 +- .../apache/spark/api/java/JavaHadoopRDD.scala | 4 +- .../spark/api/java/JavaNewHadoopRDD.scala | 4 +- .../apache/spark/api/java/JavaPairRDD.scala | 19 ++- .../apache/spark/api/java/JavaRDDLike.scala | 75 +++++------- .../spark/api/java/JavaSparkContext.scala | 20 ++-- .../spark/api/python/PythonHadoopUtil.scala | 28 ++--- .../apache/spark/api/python/PythonRDD.scala | 26 ++--- .../apache/spark/api/python/PythonUtils.scala | 15 ++- .../api/python/PythonWorkerFactory.scala | 11 +- .../apache/spark/api/python/SerDeUtil.scala | 3 +- .../WriteInputFormatTestDataGenerator.scala | 8 +- .../scala/org/apache/spark/api/r/RRDD.scala | 13 ++- .../scala/org/apache/spark/api/r/RUtils.scala | 5 +- .../scala/org/apache/spark/api/r/SerDe.scala | 4 +- .../spark/broadcast/TorrentBroadcast.scala | 4 +- .../spark/deploy/ExternalShuffleService.scala | 8 +- .../apache/spark/deploy/PythonRunner.scala | 4 +- .../apache/spark/deploy/RPackageUtils.scala | 4 +- .../org/apache/spark/deploy/RRunner.scala | 4 +- .../spark/deploy/SparkCuratorUtil.scala | 4 +- .../apache/spark/deploy/SparkHadoopUtil.scala | 19 +-- .../spark/deploy/SparkSubmitArguments.scala | 6 +- .../master/ZooKeeperPersistenceEngine.scala | 6 +- .../spark/deploy/worker/CommandUtils.scala | 5 +- .../spark/deploy/worker/DriverRunner.scala | 8 +- .../spark/deploy/worker/ExecutorRunner.scala | 7 +- .../apache/spark/deploy/worker/Worker.scala | 1 - .../org/apache/spark/executor/Executor.scala | 6 +- .../spark/executor/ExecutorSource.scala | 4 +- .../spark/executor/MesosExecutorBackend.scala | 6 +- .../spark/input/PortableDataStream.scala | 11 +- .../input/WholeTextFileInputFormat.scala | 8 +- .../spark/launcher/WorkerCommandBuilder.scala | 4 +- .../apache/spark/metrics/MetricsConfig.scala | 22 ++-- .../network/netty/NettyBlockRpcServer.scala | 4 +- .../netty/NettyBlockTransferService.scala | 6 +- .../apache/spark/network/nio/Connection.scala | 4 +- .../spark/partial/GroupedCountEvaluator.scala | 10 +- .../spark/partial/GroupedMeanEvaluator.scala | 10 +- .../spark/partial/GroupedSumEvaluator.scala | 10 +- .../apache/spark/rdd/PairRDDFunctions.scala | 6 +- .../scala/org/apache/spark/rdd/PipedRDD.scala | 6 +- .../org/apache/spark/rdd/SubtractedRDD.scala | 4 +- .../spark/scheduler/InputFormatInfo.scala | 4 +- .../org/apache/spark/scheduler/Pool.scala | 10 +- .../mesos/CoarseMesosSchedulerBackend.scala | 20 ++-- .../mesos/MesosClusterPersistenceEngine.scala | 4 +- .../cluster/mesos/MesosClusterScheduler.scala | 14 +-- .../cluster/mesos/MesosSchedulerBackend.scala | 22 ++-- .../cluster/mesos/MesosSchedulerUtils.scala | 25 ++-- .../spark/serializer/KryoSerializer.scala | 10 +- .../shuffle/FileShuffleBlockResolver.scala | 8 +- .../storage/BlockManagerMasterEndpoint.scala | 8 +- .../org/apache/spark/util/AkkaUtils.scala | 4 +- .../org/apache/spark/util/ListenerBus.scala | 7 +- .../spark/util/MutableURLClassLoader.scala | 2 - .../spark/util/TimeStampedHashMap.scala | 10 +- .../spark/util/TimeStampedHashSet.scala | 4 +- .../scala/org/apache/spark/util/Utils.scala | 20 ++-- .../apache/spark/util/collection/Utils.scala | 4 +- .../java/org/apache/spark/JavaAPISuite.java | 6 +- .../org/apache/spark/SparkConfSuite.scala | 7 +- .../spark/deploy/LogUrlsStandaloneSuite.scala | 1 - .../spark/deploy/RPackageUtilsSuite.scala 
| 8 +- .../deploy/worker/ExecutorRunnerTest.scala | 5 +- .../spark/scheduler/SparkListenerSuite.scala | 9 +- .../mesos/MesosSchedulerBackendSuite.scala | 15 +-- .../serializer/KryoSerializerSuite.scala | 3 +- .../org/apache/spark/ui/UISeleniumSuite.scala | 21 ++-- .../spark/examples/CassandraCQLTest.scala | 15 +-- .../apache/spark/examples/CassandraTest.scala | 6 +- .../spark/examples/DriverSubmissionTest.scala | 6 +- .../pythonconverters/AvroConverters.scala | 16 +-- .../CassandraConverters.scala | 14 ++- .../pythonconverters/HBaseConverters.scala | 5 +- .../streaming/flume/sink/SparkSinkSuite.scala | 4 +- .../streaming/flume/EventTransformer.scala | 4 +- .../streaming/flume/FlumeBatchFetcher.scala | 3 +- .../streaming/flume/FlumeInputDStream.scala | 7 +- .../flume/FlumePollingInputDStream.scala | 6 +- .../streaming/flume/FlumeTestUtils.scala | 10 +- .../spark/streaming/flume/FlumeUtils.scala | 8 +- .../flume/PollingFlumeTestUtils.scala | 16 ++- .../flume/FlumePollingStreamSuite.scala | 8 +- .../streaming/flume/FlumeStreamSuite.scala | 2 +- .../streaming/kafka/KafkaTestUtils.scala | 4 +- .../spark/streaming/kafka/KafkaUtils.scala | 35 +++--- .../spark/streaming/zeromq/ZeroMQUtils.scala | 15 ++- .../kinesis/KinesisBackedBlockRDD.scala | 4 +- .../streaming/kinesis/KinesisReceiver.scala | 4 +- .../streaming/kinesis/KinesisTestUtils.scala | 3 +- .../kinesis/KinesisReceiverSuite.scala | 12 +- .../mllib/util/LinearDataGenerator.scala | 4 +- .../ml/classification/JavaOneVsRestSuite.java | 7 +- .../LogisticRegressionSuite.scala | 4 +- .../spark/mllib/classification/SVMSuite.scala | 4 +- .../optimization/GradientDescentSuite.scala | 4 +- .../spark/mllib/recommendation/ALSSuite.scala | 4 +- project/SparkBuild.scala | 8 +- python/pyspark/sql/column.py | 12 ++ python/pyspark/sql/dataframe.py | 4 +- scalastyle-config.xml | 7 ++ .../main/scala/org/apache/spark/sql/Row.scala | 12 +- .../spark/sql/catalyst/analysis/Catalog.scala | 4 +- .../spark/sql/DataFrameNaFunctions.scala | 8 +- .../apache/spark/sql/DataFrameReader.scala | 4 +- .../apache/spark/sql/DataFrameWriter.scala | 4 +- .../org/apache/spark/sql/GroupedData.scala | 4 +- .../scala/org/apache/spark/sql/SQLConf.scala | 13 ++- .../org/apache/spark/sql/SQLContext.scala | 8 +- .../datasources/ResolvedDataSource.scala | 4 +- .../parquet/CatalystReadSupport.scala | 8 +- .../parquet/CatalystRowConverter.scala | 4 +- .../parquet/CatalystSchemaConverter.scala | 4 +- .../datasources/parquet/ParquetRelation.scala | 13 ++- .../parquet/ParquetTypesConverter.scala | 4 +- .../joins/ShuffledHashOuterJoin.scala | 6 +- .../spark/sql/execution/pythonUDFs.scala | 11 +- .../apache/spark/sql/JavaDataFrameSuite.java | 8 +- .../spark/sql/DataFrameNaFunctionsSuite.scala | 6 +- .../org/apache/spark/sql/QueryTest.scala | 4 +- .../ParquetAvroCompatibilitySuite.scala | 3 +- .../parquet/ParquetCompatibilityTest.scala | 7 +- .../datasources/parquet/ParquetIOSuite.scala | 25 ++-- .../SparkExecuteStatementOperation.scala | 10 +- .../hive/thriftserver/SparkSQLCLIDriver.scala | 16 +-- .../thriftserver/SparkSQLCLIService.scala | 6 +- .../hive/thriftserver/SparkSQLDriver.scala | 14 +-- .../sql/hive/thriftserver/SparkSQLEnv.scala | 4 +- .../apache/spark/sql/hive/HiveContext.scala | 4 +- .../spark/sql/hive/HiveInspectors.scala | 40 +++---- .../spark/sql/hive/HiveMetastoreCatalog.scala | 12 +- .../org/apache/spark/sql/hive/HiveQl.scala | 110 ++++++++++-------- .../org/apache/spark/sql/hive/HiveShim.scala | 5 +- .../spark/sql/hive/client/ClientWrapper.scala | 27 ++--- 
.../spark/sql/hive/client/HiveShim.scala | 14 +-- .../execution/DescribeHiveTableCommand.scala | 8 +- .../sql/hive/execution/HiveTableScan.scala | 9 +- .../hive/execution/InsertIntoHiveTable.scala | 12 +- .../hive/execution/ScriptTransformation.scala | 12 +- .../org/apache/spark/sql/hive/hiveUDFs.scala | 9 +- .../spark/sql/hive/orc/OrcRelation.scala | 8 +- .../apache/spark/sql/hive/test/TestHive.scala | 11 +- .../spark/sql/hive/client/FiltersSuite.scala | 4 +- .../sql/hive/execution/HiveUDFSuite.scala | 29 +++-- .../sql/hive/execution/PruningSuite.scala | 7 +- .../sql/hive/execution/SQLQuerySuite.scala | 4 +- .../sql/sources/hadoopFsRelationSuites.scala | 8 +- .../streaming/api/java/JavaDStreamLike.scala | 12 +- .../streaming/api/java/JavaPairDStream.scala | 28 ++--- .../api/java/JavaStreamingContext.scala | 32 ++--- .../streaming/api/python/PythonDStream.scala | 5 +- .../spark/streaming/receiver/Receiver.scala | 6 +- .../streaming/scheduler/JobScheduler.scala | 4 +- .../scheduler/ReceivedBlockTracker.scala | 4 +- .../util/FileBasedWriteAheadLog.scala | 4 +- .../spark/streaming/JavaTestUtils.scala | 24 ++-- .../streaming/util/WriteAheadLogSuite.scala | 4 +- .../spark/tools/GenerateMIMAIgnore.scala | 6 +- .../org/apache/spark/deploy/yarn/Client.scala | 13 ++- .../spark/deploy/yarn/ExecutorRunnable.scala | 24 ++-- .../spark/deploy/yarn/YarnAllocator.scala | 19 ++- .../spark/deploy/yarn/YarnRMClient.scala | 8 +- .../deploy/yarn/BaseYarnClusterSuite.scala | 6 +- .../spark/deploy/yarn/ClientSuite.scala | 8 +- .../spark/deploy/yarn/YarnClusterSuite.scala | 5 +- 171 files changed, 863 insertions(+), 880 deletions(-) diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java index 2389c28b2839..fdb309e365f6 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java @@ -24,7 +24,7 @@ import scala.Option; import scala.Product2; -import scala.collection.JavaConversions; +import scala.collection.JavaConverters; import scala.collection.immutable.Map; import scala.reflect.ClassTag; import scala.reflect.ClassTag$; @@ -160,7 +160,7 @@ public long getPeakMemoryUsedBytes() { */ @VisibleForTesting public void write(Iterator> records) throws IOException { - write(JavaConversions.asScalaIterator(records)); + write(JavaConverters.asScalaIteratorConverter(records).asScala()); } @Override diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 92218832d256..a38759278385 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -21,8 +21,8 @@ import java.io._ import java.util.concurrent.ConcurrentHashMap import java.util.zip.{GZIPInputStream, GZIPOutputStream} +import scala.collection.JavaConverters._ import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map} -import scala.collection.JavaConversions._ import scala.reflect.ClassTag import org.apache.spark.rpc.{RpcEndpointRef, RpcEnv, RpcCallContext, RpcEndpoint} @@ -398,7 +398,7 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf) */ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTracker(conf) { protected val mapStatuses: Map[Int, Array[MapStatus]] = - new ConcurrentHashMap[Int, Array[MapStatus]] + new ConcurrentHashMap[Int, 
Array[MapStatus]]().asScala } private[spark] object MapOutputTracker extends Logging { diff --git a/core/src/main/scala/org/apache/spark/SSLOptions.scala b/core/src/main/scala/org/apache/spark/SSLOptions.scala index 32df42d57dbd..3b9c885bf97a 100644 --- a/core/src/main/scala/org/apache/spark/SSLOptions.scala +++ b/core/src/main/scala/org/apache/spark/SSLOptions.scala @@ -17,9 +17,11 @@ package org.apache.spark -import java.io.{File, FileInputStream} -import java.security.{KeyStore, NoSuchAlgorithmException} -import javax.net.ssl.{KeyManager, KeyManagerFactory, SSLContext, TrustManager, TrustManagerFactory} +import java.io.File +import java.security.NoSuchAlgorithmException +import javax.net.ssl.SSLContext + +import scala.collection.JavaConverters._ import com.typesafe.config.{Config, ConfigFactory, ConfigValueFactory} import org.eclipse.jetty.util.ssl.SslContextFactory @@ -79,7 +81,6 @@ private[spark] case class SSLOptions( * object. It can be used then to compose the ultimate Akka configuration. */ def createAkkaConfig: Option[Config] = { - import scala.collection.JavaConversions._ if (enabled) { Some(ConfigFactory.empty() .withValue("akka.remote.netty.tcp.security.key-store", @@ -97,7 +98,7 @@ private[spark] case class SSLOptions( .withValue("akka.remote.netty.tcp.security.protocol", ConfigValueFactory.fromAnyRef(protocol.getOrElse(""))) .withValue("akka.remote.netty.tcp.security.enabled-algorithms", - ConfigValueFactory.fromIterable(supportedAlgorithms.toSeq)) + ConfigValueFactory.fromIterable(supportedAlgorithms.asJava)) .withValue("akka.remote.netty.tcp.enable-ssl", ConfigValueFactory.fromAnyRef(true))) } else { diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 9849aff85d72..f3da04a7f55d 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -26,8 +26,8 @@ import java.util.{Arrays, Properties, UUID} import java.util.concurrent.atomic.{AtomicReference, AtomicBoolean, AtomicInteger} import java.util.UUID.randomUUID +import scala.collection.JavaConverters._ import scala.collection.{Map, Set} -import scala.collection.JavaConversions._ import scala.collection.generic.Growable import scala.collection.mutable.HashMap import scala.reflect.{ClassTag, classTag} @@ -1546,7 +1546,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli def getAllPools: Seq[Schedulable] = { assertNotStopped() // TODO(xiajunluan): We should take nested pools into account - taskScheduler.rootPool.schedulableQueue.toSeq + taskScheduler.rootPool.schedulableQueue.asScala.toSeq } /** diff --git a/core/src/main/scala/org/apache/spark/TestUtils.scala b/core/src/main/scala/org/apache/spark/TestUtils.scala index a1ebbecf93b7..888763a3e8eb 100644 --- a/core/src/main/scala/org/apache/spark/TestUtils.scala +++ b/core/src/main/scala/org/apache/spark/TestUtils.scala @@ -19,11 +19,12 @@ package org.apache.spark import java.io.{ByteArrayInputStream, File, FileInputStream, FileOutputStream} import java.net.{URI, URL} +import java.nio.charset.StandardCharsets +import java.util.Arrays import java.util.jar.{JarEntry, JarOutputStream} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ -import com.google.common.base.Charsets.UTF_8 import com.google.common.io.{ByteStreams, Files} import javax.tools.{JavaFileObject, SimpleJavaFileObject, ToolProvider} @@ -71,7 +72,7 @@ private[spark] object TestUtils { files.foreach { 
case (k, v) => val entry = new JarEntry(k) jarStream.putNextEntry(entry) - ByteStreams.copy(new ByteArrayInputStream(v.getBytes(UTF_8)), jarStream) + ByteStreams.copy(new ByteArrayInputStream(v.getBytes(StandardCharsets.UTF_8)), jarStream) } jarStream.close() jarFile.toURI.toURL @@ -125,7 +126,7 @@ private[spark] object TestUtils { } else { Seq() } - compiler.getTask(null, null, null, options, null, Seq(sourceFile)).call() + compiler.getTask(null, null, null, options.asJava, null, Arrays.asList(sourceFile)).call() val fileName = className + ".class" val result = new File(fileName) diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaHadoopRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaHadoopRDD.scala index 0ae0b4ec042e..891bcddeac28 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaHadoopRDD.scala @@ -17,7 +17,7 @@ package org.apache.spark.api.java -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.reflect.ClassTag import org.apache.hadoop.mapred.InputSplit @@ -37,7 +37,7 @@ class JavaHadoopRDD[K, V](rdd: HadoopRDD[K, V]) def mapPartitionsWithInputSplit[R]( f: JFunction2[InputSplit, java.util.Iterator[(K, V)], java.util.Iterator[R]], preservesPartitioning: Boolean = false): JavaRDD[R] = { - new JavaRDD(rdd.mapPartitionsWithInputSplit((a, b) => f.call(a, asJavaIterator(b)), + new JavaRDD(rdd.mapPartitionsWithInputSplit((a, b) => f.call(a, b.asJava).asScala, preservesPartitioning)(fakeClassTag))(fakeClassTag) } } diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaNewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaNewHadoopRDD.scala index ec4f3964d75e..0f49279f3e64 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaNewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaNewHadoopRDD.scala @@ -17,7 +17,7 @@ package org.apache.spark.api.java -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.reflect.ClassTag import org.apache.hadoop.mapreduce.InputSplit @@ -37,7 +37,7 @@ class JavaNewHadoopRDD[K, V](rdd: NewHadoopRDD[K, V]) def mapPartitionsWithInputSplit[R]( f: JFunction2[InputSplit, java.util.Iterator[(K, V)], java.util.Iterator[R]], preservesPartitioning: Boolean = false): JavaRDD[R] = { - new JavaRDD(rdd.mapPartitionsWithInputSplit((a, b) => f.call(a, asJavaIterator(b)), + new JavaRDD(rdd.mapPartitionsWithInputSplit((a, b) => f.call(a, b.asJava).asScala, preservesPartitioning)(fakeClassTag))(fakeClassTag) } } diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala index 8441bb3a3047..fb787979c182 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala @@ -20,7 +20,7 @@ package org.apache.spark.api.java import java.util.{Comparator, List => JList, Map => JMap} import java.lang.{Iterable => JIterable} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.language.implicitConversions import scala.reflect.ClassTag @@ -142,7 +142,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) def sampleByKey(withReplacement: Boolean, fractions: JMap[K, Double], seed: Long): JavaPairRDD[K, V] = - new JavaPairRDD[K, V](rdd.sampleByKey(withReplacement, fractions, seed)) + new JavaPairRDD[K, V](rdd.sampleByKey(withReplacement, 
fractions.asScala, seed)) /** * Return a subset of this RDD sampled by key (via stratified sampling). @@ -173,7 +173,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) def sampleByKeyExact(withReplacement: Boolean, fractions: JMap[K, Double], seed: Long): JavaPairRDD[K, V] = - new JavaPairRDD[K, V](rdd.sampleByKeyExact(withReplacement, fractions, seed)) + new JavaPairRDD[K, V](rdd.sampleByKeyExact(withReplacement, fractions.asScala, seed)) /** * ::Experimental:: @@ -768,7 +768,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) * Return the list of values in the RDD for key `key`. This operation is done efficiently if the * RDD has a known partitioner by only searching the partition that the key maps to. */ - def lookup(key: K): JList[V] = seqAsJavaList(rdd.lookup(key)) + def lookup(key: K): JList[V] = rdd.lookup(key).asJava /** Output the RDD to any Hadoop-supported file system. */ def saveAsHadoopFile[F <: OutputFormat[_, _]]( @@ -987,30 +987,27 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) object JavaPairRDD { private[spark] def groupByResultToJava[K: ClassTag, T](rdd: RDD[(K, Iterable[T])]): RDD[(K, JIterable[T])] = { - rddToPairRDDFunctions(rdd).mapValues(asJavaIterable) + rddToPairRDDFunctions(rdd).mapValues(_.asJava) } private[spark] def cogroupResultToJava[K: ClassTag, V, W]( rdd: RDD[(K, (Iterable[V], Iterable[W]))]): RDD[(K, (JIterable[V], JIterable[W]))] = { - rddToPairRDDFunctions(rdd).mapValues(x => (asJavaIterable(x._1), asJavaIterable(x._2))) + rddToPairRDDFunctions(rdd).mapValues(x => (x._1.asJava, x._2.asJava)) } private[spark] def cogroupResult2ToJava[K: ClassTag, V, W1, W2]( rdd: RDD[(K, (Iterable[V], Iterable[W1], Iterable[W2]))]) : RDD[(K, (JIterable[V], JIterable[W1], JIterable[W2]))] = { - rddToPairRDDFunctions(rdd) - .mapValues(x => (asJavaIterable(x._1), asJavaIterable(x._2), asJavaIterable(x._3))) + rddToPairRDDFunctions(rdd).mapValues(x => (x._1.asJava, x._2.asJava, x._3.asJava)) } private[spark] def cogroupResult3ToJava[K: ClassTag, V, W1, W2, W3]( rdd: RDD[(K, (Iterable[V], Iterable[W1], Iterable[W2], Iterable[W3]))]) : RDD[(K, (JIterable[V], JIterable[W1], JIterable[W2], JIterable[W3]))] = { - rddToPairRDDFunctions(rdd) - .mapValues(x => - (asJavaIterable(x._1), asJavaIterable(x._2), asJavaIterable(x._3), asJavaIterable(x._4))) + rddToPairRDDFunctions(rdd).mapValues(x => (x._1.asJava, x._2.asJava, x._3.asJava, x._4.asJava)) } def fromRDD[K: ClassTag, V: ClassTag](rdd: RDD[(K, V)]): JavaPairRDD[K, V] = { diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala index c582488f16fe..fc817cdd6a3f 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala @@ -21,7 +21,6 @@ import java.{lang => jl} import java.lang.{Iterable => JIterable, Long => JLong} import java.util.{Comparator, List => JList, Iterator => JIterator} -import scala.collection.JavaConversions._ import scala.collection.JavaConverters._ import scala.reflect.ClassTag @@ -59,10 +58,10 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { def rdd: RDD[T] @deprecated("Use partitions() instead.", "1.1.0") - def splits: JList[Partition] = new java.util.ArrayList(rdd.partitions.toSeq) + def splits: JList[Partition] = rdd.partitions.toSeq.asJava /** Set of partitions in this RDD. 
*/ - def partitions: JList[Partition] = new java.util.ArrayList(rdd.partitions.toSeq) + def partitions: JList[Partition] = rdd.partitions.toSeq.asJava /** The partitioner of this RDD. */ def partitioner: Optional[Partitioner] = JavaUtils.optionToOptional(rdd.partitioner) @@ -82,7 +81,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * subclasses of RDD. */ def iterator(split: Partition, taskContext: TaskContext): java.util.Iterator[T] = - asJavaIterator(rdd.iterator(split, taskContext)) + rdd.iterator(split, taskContext).asJava // Transformations (return a new RDD) @@ -99,7 +98,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { def mapPartitionsWithIndex[R]( f: JFunction2[jl.Integer, java.util.Iterator[T], java.util.Iterator[R]], preservesPartitioning: Boolean = false): JavaRDD[R] = - new JavaRDD(rdd.mapPartitionsWithIndex(((a, b) => f(a, asJavaIterator(b))), + new JavaRDD(rdd.mapPartitionsWithIndex((a, b) => f.call(a, b.asJava).asScala, preservesPartitioning)(fakeClassTag))(fakeClassTag) /** @@ -153,7 +152,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { */ def mapPartitions[U](f: FlatMapFunction[java.util.Iterator[T], U]): JavaRDD[U] = { def fn: (Iterator[T]) => Iterator[U] = { - (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator()) + (x: Iterator[T]) => f.call(x.asJava).iterator().asScala } JavaRDD.fromRDD(rdd.mapPartitions(fn)(fakeClassTag[U]))(fakeClassTag[U]) } @@ -164,7 +163,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { def mapPartitions[U](f: FlatMapFunction[java.util.Iterator[T], U], preservesPartitioning: Boolean): JavaRDD[U] = { def fn: (Iterator[T]) => Iterator[U] = { - (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator()) + (x: Iterator[T]) => f.call(x.asJava).iterator().asScala } JavaRDD.fromRDD( rdd.mapPartitions(fn, preservesPartitioning)(fakeClassTag[U]))(fakeClassTag[U]) @@ -175,7 +174,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { */ def mapPartitionsToDouble(f: DoubleFlatMapFunction[java.util.Iterator[T]]): JavaDoubleRDD = { def fn: (Iterator[T]) => Iterator[jl.Double] = { - (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator()) + (x: Iterator[T]) => f.call(x.asJava).iterator().asScala } new JavaDoubleRDD(rdd.mapPartitions(fn).map((x: jl.Double) => x.doubleValue())) } @@ -186,7 +185,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { def mapPartitionsToPair[K2, V2](f: PairFlatMapFunction[java.util.Iterator[T], K2, V2]): JavaPairRDD[K2, V2] = { def fn: (Iterator[T]) => Iterator[(K2, V2)] = { - (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator()) + (x: Iterator[T]) => f.call(x.asJava).iterator().asScala } JavaPairRDD.fromRDD(rdd.mapPartitions(fn))(fakeClassTag[K2], fakeClassTag[V2]) } @@ -197,7 +196,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { def mapPartitionsToDouble(f: DoubleFlatMapFunction[java.util.Iterator[T]], preservesPartitioning: Boolean): JavaDoubleRDD = { def fn: (Iterator[T]) => Iterator[jl.Double] = { - (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator()) + (x: Iterator[T]) => f.call(x.asJava).iterator().asScala } new JavaDoubleRDD(rdd.mapPartitions(fn, preservesPartitioning) .map(x => x.doubleValue())) @@ -209,7 +208,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { def mapPartitionsToPair[K2, V2](f: 
PairFlatMapFunction[java.util.Iterator[T], K2, V2], preservesPartitioning: Boolean): JavaPairRDD[K2, V2] = { def fn: (Iterator[T]) => Iterator[(K2, V2)] = { - (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator()) + (x: Iterator[T]) => f.call(x.asJava).iterator().asScala } JavaPairRDD.fromRDD( rdd.mapPartitions(fn, preservesPartitioning))(fakeClassTag[K2], fakeClassTag[V2]) @@ -219,14 +218,14 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * Applies a function f to each partition of this RDD. */ def foreachPartition(f: VoidFunction[java.util.Iterator[T]]) { - rdd.foreachPartition((x => f.call(asJavaIterator(x)))) + rdd.foreachPartition((x => f.call(x.asJava))) } /** * Return an RDD created by coalescing all elements within each partition into an array. */ def glom(): JavaRDD[JList[T]] = - new JavaRDD(rdd.glom().map(x => new java.util.ArrayList[T](x.toSeq))) + new JavaRDD(rdd.glom().map(_.toSeq.asJava)) /** * Return the Cartesian product of this RDD and another one, that is, the RDD of all pairs of @@ -266,13 +265,13 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * Return an RDD created by piping elements to a forked external process. */ def pipe(command: JList[String]): JavaRDD[String] = - rdd.pipe(asScalaBuffer(command)) + rdd.pipe(command.asScala) /** * Return an RDD created by piping elements to a forked external process. */ def pipe(command: JList[String], env: java.util.Map[String, String]): JavaRDD[String] = - rdd.pipe(asScalaBuffer(command), mapAsScalaMap(env)) + rdd.pipe(command.asScala, env.asScala) /** * Zips this RDD with another one, returning key-value pairs with the first element in each RDD, @@ -294,8 +293,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { other: JavaRDDLike[U, _], f: FlatMapFunction2[java.util.Iterator[T], java.util.Iterator[U], V]): JavaRDD[V] = { def fn: (Iterator[T], Iterator[U]) => Iterator[V] = { - (x: Iterator[T], y: Iterator[U]) => asScalaIterator( - f.call(asJavaIterator(x), asJavaIterator(y)).iterator()) + (x: Iterator[T], y: Iterator[U]) => f.call(x.asJava, y.asJava).iterator().asScala } JavaRDD.fromRDD( rdd.zipPartitions(other.rdd)(fn)(other.classTag, fakeClassTag[V]))(fakeClassTag[V]) @@ -333,22 +331,16 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { /** * Return an array that contains all of the elements in this RDD. */ - def collect(): JList[T] = { - import scala.collection.JavaConversions._ - val arr: java.util.Collection[T] = rdd.collect().toSeq - new java.util.ArrayList(arr) - } + def collect(): JList[T] = + rdd.collect().toSeq.asJava /** * Return an iterator that contains all of the elements in this RDD. * * The iterator will consume as much memory as the largest partition in this RDD. */ - def toLocalIterator(): JIterator[T] = { - import scala.collection.JavaConversions._ - rdd.toLocalIterator - } - + def toLocalIterator(): JIterator[T] = + asJavaIteratorConverter(rdd.toLocalIterator).asJava /** * Return an array that contains all of the elements in this RDD. @@ -363,9 +355,8 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { def collectPartitions(partitionIds: Array[Int]): Array[JList[T]] = { // This is useful for implementing `take` from other language frontends // like Python where the data is serialized. 
- import scala.collection.JavaConversions._ val res = context.runJob(rdd, (it: Iterator[T]) => it.toArray, partitionIds) - res.map(x => new java.util.ArrayList(x.toSeq)).toArray + res.map(_.toSeq.asJava) } /** @@ -489,20 +480,14 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * it will be slow if a lot of partitions are required. In that case, use collect() to get the * whole RDD instead. */ - def take(num: Int): JList[T] = { - import scala.collection.JavaConversions._ - val arr: java.util.Collection[T] = rdd.take(num).toSeq - new java.util.ArrayList(arr) - } + def take(num: Int): JList[T] = + rdd.take(num).toSeq.asJava def takeSample(withReplacement: Boolean, num: Int): JList[T] = takeSample(withReplacement, num, Utils.random.nextLong) - def takeSample(withReplacement: Boolean, num: Int, seed: Long): JList[T] = { - import scala.collection.JavaConversions._ - val arr: java.util.Collection[T] = rdd.takeSample(withReplacement, num, seed).toSeq - new java.util.ArrayList(arr) - } + def takeSample(withReplacement: Boolean, num: Int, seed: Long): JList[T] = + rdd.takeSample(withReplacement, num, seed).toSeq.asJava /** * Return the first element in this RDD. @@ -582,10 +567,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * @return an array of top elements */ def top(num: Int, comp: Comparator[T]): JList[T] = { - import scala.collection.JavaConversions._ - val topElems = rdd.top(num)(Ordering.comparatorToOrdering(comp)) - val arr: java.util.Collection[T] = topElems.toSeq - new java.util.ArrayList(arr) + rdd.top(num)(Ordering.comparatorToOrdering(comp)).toSeq.asJava } /** @@ -607,10 +589,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * @return an array of top elements */ def takeOrdered(num: Int, comp: Comparator[T]): JList[T] = { - import scala.collection.JavaConversions._ - val topElems = rdd.takeOrdered(num)(Ordering.comparatorToOrdering(comp)) - val arr: java.util.Collection[T] = topElems.toSeq - new java.util.ArrayList(arr) + rdd.takeOrdered(num)(Ordering.comparatorToOrdering(comp)).toSeq.asJava } /** @@ -696,7 +675,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * applies a function f to each partition of this RDD. 
*/ def foreachPartitionAsync(f: VoidFunction[java.util.Iterator[T]]): JavaFutureAction[Void] = { - new JavaFutureActionWrapper[Unit, Void](rdd.foreachPartitionAsync(x => f.call(x)), + new JavaFutureActionWrapper[Unit, Void](rdd.foreachPartitionAsync(x => f.call(x.asJava)), { x => null.asInstanceOf[Void] }) } } diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala index 02e49a853c5f..609496ccdfef 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala @@ -21,8 +21,7 @@ import java.io.Closeable import java.util import java.util.{Map => JMap} -import scala.collection.JavaConversions -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.language.implicitConversions import scala.reflect.ClassTag @@ -104,7 +103,7 @@ class JavaSparkContext(val sc: SparkContext) */ def this(master: String, appName: String, sparkHome: String, jars: Array[String], environment: JMap[String, String]) = - this(new SparkContext(master, appName, sparkHome, jars.toSeq, environment, Map())) + this(new SparkContext(master, appName, sparkHome, jars.toSeq, environment.asScala, Map())) private[spark] val env = sc.env @@ -118,7 +117,7 @@ class JavaSparkContext(val sc: SparkContext) def appName: String = sc.appName - def jars: util.List[String] = sc.jars + def jars: util.List[String] = sc.jars.asJava def startTime: java.lang.Long = sc.startTime @@ -142,7 +141,7 @@ class JavaSparkContext(val sc: SparkContext) /** Distribute a local Scala collection to form an RDD. */ def parallelize[T](list: java.util.List[T], numSlices: Int): JavaRDD[T] = { implicit val ctag: ClassTag[T] = fakeClassTag - sc.parallelize(JavaConversions.asScalaBuffer(list), numSlices) + sc.parallelize(list.asScala, numSlices) } /** Get an RDD that has no partitions or elements. */ @@ -161,7 +160,7 @@ class JavaSparkContext(val sc: SparkContext) : JavaPairRDD[K, V] = { implicit val ctagK: ClassTag[K] = fakeClassTag implicit val ctagV: ClassTag[V] = fakeClassTag - JavaPairRDD.fromRDD(sc.parallelize(JavaConversions.asScalaBuffer(list), numSlices)) + JavaPairRDD.fromRDD(sc.parallelize(list.asScala, numSlices)) } /** Distribute a local Scala collection to form an RDD. */ @@ -170,8 +169,7 @@ class JavaSparkContext(val sc: SparkContext) /** Distribute a local Scala collection to form an RDD. */ def parallelizeDoubles(list: java.util.List[java.lang.Double], numSlices: Int): JavaDoubleRDD = - JavaDoubleRDD.fromRDD(sc.parallelize(JavaConversions.asScalaBuffer(list).map(_.doubleValue()), - numSlices)) + JavaDoubleRDD.fromRDD(sc.parallelize(list.asScala.map(_.doubleValue()), numSlices)) /** Distribute a local Scala collection to form an RDD. */ def parallelizeDoubles(list: java.util.List[java.lang.Double]): JavaDoubleRDD = @@ -519,7 +517,7 @@ class JavaSparkContext(val sc: SparkContext) /** Build the union of two or more RDDs. */ override def union[T](first: JavaRDD[T], rest: java.util.List[JavaRDD[T]]): JavaRDD[T] = { - val rdds: Seq[RDD[T]] = (Seq(first) ++ asScalaBuffer(rest)).map(_.rdd) + val rdds: Seq[RDD[T]] = (Seq(first) ++ rest.asScala).map(_.rdd) implicit val ctag: ClassTag[T] = first.classTag sc.union(rdds) } @@ -527,7 +525,7 @@ class JavaSparkContext(val sc: SparkContext) /** Build the union of two or more RDDs. 
*/ override def union[K, V](first: JavaPairRDD[K, V], rest: java.util.List[JavaPairRDD[K, V]]) : JavaPairRDD[K, V] = { - val rdds: Seq[RDD[(K, V)]] = (Seq(first) ++ asScalaBuffer(rest)).map(_.rdd) + val rdds: Seq[RDD[(K, V)]] = (Seq(first) ++ rest.asScala).map(_.rdd) implicit val ctag: ClassTag[(K, V)] = first.classTag implicit val ctagK: ClassTag[K] = first.kClassTag implicit val ctagV: ClassTag[V] = first.vClassTag @@ -536,7 +534,7 @@ class JavaSparkContext(val sc: SparkContext) /** Build the union of two or more RDDs. */ override def union(first: JavaDoubleRDD, rest: java.util.List[JavaDoubleRDD]): JavaDoubleRDD = { - val rdds: Seq[RDD[Double]] = (Seq(first) ++ asScalaBuffer(rest)).map(_.srdd) + val rdds: Seq[RDD[Double]] = (Seq(first) ++ rest.asScala).map(_.srdd) new JavaDoubleRDD(sc.union(rdds)) } diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala b/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala index b959b683d167..a7dfa1d257cf 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala @@ -17,15 +17,17 @@ package org.apache.spark.api.python -import org.apache.spark.broadcast.Broadcast -import org.apache.spark.rdd.RDD -import org.apache.spark.util.{SerializableConfiguration, Utils} -import org.apache.spark.{Logging, SparkException} +import scala.collection.JavaConverters._ +import scala.util.{Failure, Success, Try} + import org.apache.hadoop.conf.Configuration import org.apache.hadoop.io._ -import scala.util.{Failure, Success, Try} -import org.apache.spark.annotation.Experimental +import org.apache.spark.{Logging, SparkException} +import org.apache.spark.annotation.Experimental +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.rdd.RDD +import org.apache.spark.util.{SerializableConfiguration, Utils} /** * :: Experimental :: @@ -68,7 +70,6 @@ private[python] class WritableToJavaConverter( * object representation */ private def convertWritable(writable: Writable): Any = { - import collection.JavaConversions._ writable match { case iw: IntWritable => iw.get() case dw: DoubleWritable => dw.get() @@ -89,9 +90,7 @@ private[python] class WritableToJavaConverter( aw.get().map(convertWritable(_)) case mw: MapWritable => val map = new java.util.HashMap[Any, Any]() - mw.foreach { case (k, v) => - map.put(convertWritable(k), convertWritable(v)) - } + mw.asScala.foreach { case (k, v) => map.put(convertWritable(k), convertWritable(v)) } map case w: Writable => WritableUtils.clone(w, conf.value.value) case other => other @@ -122,7 +121,6 @@ private[python] class JavaToWritableConverter extends Converter[Any, Writable] { * supported out-of-the-box. 
*/ private def convertToWritable(obj: Any): Writable = { - import collection.JavaConversions._ obj match { case i: java.lang.Integer => new IntWritable(i) case d: java.lang.Double => new DoubleWritable(d) @@ -134,7 +132,7 @@ private[python] class JavaToWritableConverter extends Converter[Any, Writable] { case null => NullWritable.get() case map: java.util.Map[_, _] => val mapWritable = new MapWritable() - map.foreach { case (k, v) => + map.asScala.foreach { case (k, v) => mapWritable.put(convertToWritable(k), convertToWritable(v)) } mapWritable @@ -161,9 +159,8 @@ private[python] object PythonHadoopUtil { * Convert a [[java.util.Map]] of properties to a [[org.apache.hadoop.conf.Configuration]] */ def mapToConf(map: java.util.Map[String, String]): Configuration = { - import collection.JavaConversions._ val conf = new Configuration() - map.foreach{ case (k, v) => conf.set(k, v) } + map.asScala.foreach { case (k, v) => conf.set(k, v) } conf } @@ -172,9 +169,8 @@ private[python] object PythonHadoopUtil { * any matching keys in left */ def mergeConfs(left: Configuration, right: Configuration): Configuration = { - import collection.JavaConversions._ val copy = new Configuration(left) - right.iterator().foreach(entry => copy.set(entry.getKey, entry.getValue)) + right.asScala.foreach(entry => copy.set(entry.getKey, entry.getValue)) copy } diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 2a56bf28d702..b4d152b33660 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -21,7 +21,7 @@ import java.io._ import java.net._ import java.util.{Collections, ArrayList => JArrayList, List => JList, Map => JMap} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.mutable import scala.language.existentials @@ -66,11 +66,11 @@ private[spark] class PythonRDD( val env = SparkEnv.get val localdir = env.blockManager.diskBlockManager.localDirs.map( f => f.getPath()).mkString(",") - envVars += ("SPARK_LOCAL_DIRS" -> localdir) // it's also used in monitor thread + envVars.put("SPARK_LOCAL_DIRS", localdir) // it's also used in monitor thread if (reuse_worker) { - envVars += ("SPARK_REUSE_WORKER" -> "1") + envVars.put("SPARK_REUSE_WORKER", "1") } - val worker: Socket = env.createPythonWorker(pythonExec, envVars.toMap) + val worker: Socket = env.createPythonWorker(pythonExec, envVars.asScala.toMap) // Whether is the worker released into idle pool @volatile var released = false @@ -150,7 +150,7 @@ private[spark] class PythonRDD( // Check whether the worker is ready to be re-used. 
if (stream.readInt() == SpecialLengths.END_OF_STREAM) { if (reuse_worker) { - env.releasePythonWorker(pythonExec, envVars.toMap, worker) + env.releasePythonWorker(pythonExec, envVars.asScala.toMap, worker) released = true } } @@ -217,13 +217,13 @@ private[spark] class PythonRDD( // sparkFilesDir PythonRDD.writeUTF(SparkFiles.getRootDirectory, dataOut) // Python includes (*.zip and *.egg files) - dataOut.writeInt(pythonIncludes.length) - for (include <- pythonIncludes) { + dataOut.writeInt(pythonIncludes.size()) + for (include <- pythonIncludes.asScala) { PythonRDD.writeUTF(include, dataOut) } // Broadcast variables val oldBids = PythonRDD.getWorkerBroadcasts(worker) - val newBids = broadcastVars.map(_.id).toSet + val newBids = broadcastVars.asScala.map(_.id).toSet // number of different broadcasts val toRemove = oldBids.diff(newBids) val cnt = toRemove.size + newBids.diff(oldBids).size @@ -233,7 +233,7 @@ private[spark] class PythonRDD( dataOut.writeLong(- bid - 1) // bid >= 0 oldBids.remove(bid) } - for (broadcast <- broadcastVars) { + for (broadcast <- broadcastVars.asScala) { if (!oldBids.contains(broadcast.id)) { // send new broadcast dataOut.writeLong(broadcast.id) @@ -287,7 +287,7 @@ private[spark] class PythonRDD( if (!context.isCompleted) { try { logWarning("Incomplete task interrupted: Attempting to kill Python Worker") - env.destroyPythonWorker(pythonExec, envVars.toMap, worker) + env.destroyPythonWorker(pythonExec, envVars.asScala.toMap, worker) } catch { case e: Exception => logError("Exception when trying to kill worker", e) @@ -358,10 +358,10 @@ private[spark] object PythonRDD extends Logging { type ByteArray = Array[Byte] type UnrolledPartition = Array[ByteArray] val allPartitions: Array[UnrolledPartition] = - sc.runJob(rdd, (x: Iterator[ByteArray]) => x.toArray, partitions) + sc.runJob(rdd, (x: Iterator[ByteArray]) => x.toArray, partitions.asScala) val flattenedPartition: UnrolledPartition = Array.concat(allPartitions: _*) serveIterator(flattenedPartition.iterator, - s"serve RDD ${rdd.id} with partitions ${partitions.mkString(",")}") + s"serve RDD ${rdd.id} with partitions ${partitions.asScala.mkString(",")}") } /** @@ -819,7 +819,7 @@ private class PythonAccumulatorParam(@transient serverHost: String, serverPort: val in = socket.getInputStream val out = new DataOutputStream(new BufferedOutputStream(socket.getOutputStream, bufferSize)) out.writeInt(val2.size) - for (array <- val2) { + for (array <- val2.asScala) { out.writeInt(array.length) out.write(array) } diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala index 90dacaeb9342..31e534f160ee 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala @@ -17,10 +17,10 @@ package org.apache.spark.api.python -import java.io.{File} +import java.io.File import java.util.{List => JList} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer import org.apache.spark.SparkContext @@ -51,7 +51,14 @@ private[spark] object PythonUtils { * Convert list of T into seq of T (for calling API with varargs) */ def toSeq[T](vs: JList[T]): Seq[T] = { - vs.toList.toSeq + vs.asScala + } + + /** + * Convert list of T into a (Scala) List of T + */ + def toList[T](vs: JList[T]): List[T] = { + vs.asScala.toList } /** @@ -65,6 +72,6 @@ private[spark] object PythonUtils { * Convert java map of K, 
V into Map of K, V (for calling API with varargs) */ def toScalaMap[K, V](jm: java.util.Map[K, V]): Map[K, V] = { - jm.toMap + jm.asScala.toMap } } diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala index e314408c067e..7039b734d2e4 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala @@ -19,9 +19,10 @@ package org.apache.spark.api.python import java.io.{DataOutputStream, DataInputStream, InputStream, OutputStreamWriter} import java.net.{InetAddress, ServerSocket, Socket, SocketException} +import java.util.Arrays import scala.collection.mutable -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import org.apache.spark._ import org.apache.spark.util.{RedirectThread, Utils} @@ -108,9 +109,9 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String serverSocket = new ServerSocket(0, 1, InetAddress.getByAddress(Array(127, 0, 0, 1))) // Create and start the worker - val pb = new ProcessBuilder(Seq(pythonExec, "-m", "pyspark.worker")) + val pb = new ProcessBuilder(Arrays.asList(pythonExec, "-m", "pyspark.worker")) val workerEnv = pb.environment() - workerEnv.putAll(envVars) + workerEnv.putAll(envVars.asJava) workerEnv.put("PYTHONPATH", pythonPath) // This is equivalent to setting the -u flag; we use it because ipython doesn't support -u: workerEnv.put("PYTHONUNBUFFERED", "YES") @@ -151,9 +152,9 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String try { // Create and start the daemon - val pb = new ProcessBuilder(Seq(pythonExec, "-m", "pyspark.daemon")) + val pb = new ProcessBuilder(Arrays.asList(pythonExec, "-m", "pyspark.daemon")) val workerEnv = pb.environment() - workerEnv.putAll(envVars) + workerEnv.putAll(envVars.asJava) workerEnv.put("PYTHONPATH", pythonPath) // This is equivalent to setting the -u flag; we use it because ipython doesn't support -u: workerEnv.put("PYTHONUNBUFFERED", "YES") diff --git a/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala b/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala index 1f1debcf84ad..fd27276e70bf 100644 --- a/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala +++ b/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala @@ -22,7 +22,6 @@ import java.util.{ArrayList => JArrayList} import org.apache.spark.api.java.JavaRDD -import scala.collection.JavaConversions._ import scala.collection.JavaConverters._ import scala.collection.mutable import scala.util.Failure @@ -214,7 +213,7 @@ private[spark] object SerDeUtil extends Logging { new AutoBatchedPickler(cleaned) } else { val pickle = new Pickler - cleaned.grouped(batchSize).map(batched => pickle.dumps(seqAsJavaList(batched))) + cleaned.grouped(batchSize).map(batched => pickle.dumps(batched.asJava)) } } } diff --git a/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala b/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala index 8f30ff9202c8..ee1fb056f0d9 100644 --- a/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala +++ b/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala @@ -20,6 +20,8 @@ package org.apache.spark.api.python import java.io.{DataOutput, DataInput} import java.{util => ju} +import 
scala.collection.JavaConverters._ + import com.google.common.base.Charsets.UTF_8 import org.apache.hadoop.io._ @@ -62,10 +64,9 @@ private[python] class TestInputKeyConverter extends Converter[Any, Any] { } private[python] class TestInputValueConverter extends Converter[Any, Any] { - import collection.JavaConversions._ override def convert(obj: Any): ju.List[Double] = { val m = obj.asInstanceOf[MapWritable] - seqAsJavaList(m.keySet.map(w => w.asInstanceOf[DoubleWritable].get()).toSeq) + m.keySet.asScala.map(_.asInstanceOf[DoubleWritable].get()).toSeq.asJava } } @@ -76,9 +77,8 @@ private[python] class TestOutputKeyConverter extends Converter[Any, Any] { } private[python] class TestOutputValueConverter extends Converter[Any, Any] { - import collection.JavaConversions._ override def convert(obj: Any): DoubleWritable = { - new DoubleWritable(obj.asInstanceOf[java.util.Map[Double, _]].keySet().head) + new DoubleWritable(obj.asInstanceOf[java.util.Map[Double, _]].keySet().iterator().next()) } } diff --git a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala index 1cf2824f862e..9d5bbb5d609f 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala @@ -19,9 +19,10 @@ package org.apache.spark.api.r import java.io._ import java.net.{InetAddress, ServerSocket} +import java.util.Arrays import java.util.{Map => JMap} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.io.Source import scala.reflect.ClassTag import scala.util.Try @@ -365,11 +366,11 @@ private[r] object RRDD { sparkConf.setIfMissing("spark.master", "local") } - for ((name, value) <- sparkEnvirMap) { - sparkConf.set(name.asInstanceOf[String], value.asInstanceOf[String]) + for ((name, value) <- sparkEnvirMap.asScala) { + sparkConf.set(name.toString, value.toString) } - for ((name, value) <- sparkExecutorEnvMap) { - sparkConf.setExecutorEnv(name.asInstanceOf[String], value.asInstanceOf[String]) + for ((name, value) <- sparkExecutorEnvMap.asScala) { + sparkConf.setExecutorEnv(name.toString, value.toString) } val jsc = new JavaSparkContext(sparkConf) @@ -395,7 +396,7 @@ private[r] object RRDD { val rOptions = "--vanilla" val rLibDir = RUtils.sparkRPackagePath(isDriver = false) val rExecScript = rLibDir + "/SparkR/worker/" + script - val pb = new ProcessBuilder(List(rCommand, rOptions, rExecScript)) + val pb = new ProcessBuilder(Arrays.asList(rCommand, rOptions, rExecScript)) // Unset the R_TESTS environment variable for workers. // This is set by R CMD check as startup.Rs // (http://svn.r-project.org/R/trunk/src/library/tools/R/testing.R) diff --git a/core/src/main/scala/org/apache/spark/api/r/RUtils.scala b/core/src/main/scala/org/apache/spark/api/r/RUtils.scala index 427b2bc7cbcb..9e807cc52f18 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RUtils.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RUtils.scala @@ -18,8 +18,7 @@ package org.apache.spark.api.r import java.io.File - -import scala.collection.JavaConversions._ +import java.util.Arrays import org.apache.spark.{SparkEnv, SparkException} @@ -68,7 +67,7 @@ private[spark] object RUtils { /** Check if R is installed before running tests that use R commands. 
*/ def isRInstalled: Boolean = { try { - val builder = new ProcessBuilder(Seq("R", "--version")) + val builder = new ProcessBuilder(Arrays.asList("R", "--version")) builder.start().waitFor() == 0 } catch { case e: Exception => false diff --git a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala index 3c89f2447374..dbbbcf40c1e9 100644 --- a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala +++ b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala @@ -20,7 +20,7 @@ package org.apache.spark.api.r import java.io.{DataInputStream, DataOutputStream} import java.sql.{Timestamp, Date, Time} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ /** * Utility functions to serialize, deserialize objects to / from R @@ -165,7 +165,7 @@ private[spark] object SerDe { val valueType = readObjectType(in) readTypedObject(in, valueType) }) - mapAsJavaMap(keys.zip(values).toMap) + keys.zip(values).toMap.asJava } else { new java.util.HashMap[Object, Object]() } diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala index a0c9b5e63c74..7e3764d802fe 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala @@ -20,7 +20,7 @@ package org.apache.spark.broadcast import java.io._ import java.nio.ByteBuffer -import scala.collection.JavaConversions.asJavaEnumeration +import scala.collection.JavaConverters._ import scala.reflect.ClassTag import scala.util.Random @@ -210,7 +210,7 @@ private object TorrentBroadcast extends Logging { compressionCodec: Option[CompressionCodec]): T = { require(blocks.nonEmpty, "Cannot unblockify an empty array of blocks") val is = new SequenceInputStream( - asJavaEnumeration(blocks.iterator.map(block => new ByteBufferInputStream(block)))) + blocks.iterator.map(new ByteBufferInputStream(_)).asJavaEnumeration) val in: InputStream = compressionCodec.map(c => c.compressedInputStream(is)).getOrElse(is) val ser = serializer.newInstance() val serIn = ser.deserializeStream(in) diff --git a/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala b/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala index 22ef701d833b..6840a3ae831f 100644 --- a/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala +++ b/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala @@ -19,13 +19,13 @@ package org.apache.spark.deploy import java.util.concurrent.CountDownLatch -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import org.apache.spark.{Logging, SparkConf, SecurityManager} import org.apache.spark.network.TransportContext import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.network.sasl.SaslServerBootstrap -import org.apache.spark.network.server.TransportServer +import org.apache.spark.network.server.{TransportServerBootstrap, TransportServer} import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler import org.apache.spark.network.util.TransportConf import org.apache.spark.util.Utils @@ -67,13 +67,13 @@ class ExternalShuffleService(sparkConf: SparkConf, securityManager: SecurityMana def start() { require(server == null, "Shuffle server already started") logInfo(s"Starting shuffle service on port $port with useSasl = $useSasl") - val bootstraps = + val bootstraps: 
Seq[TransportServerBootstrap] = if (useSasl) { Seq(new SaslServerBootstrap(transportConf, securityManager)) } else { Nil } - server = transportContext.createServer(port, bootstraps) + server = transportContext.createServer(port, bootstraps.asJava) } /** Clean up all shuffle files associated with an application that has exited. */ diff --git a/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala b/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala index 23d01e9cbb9f..d85327603f64 100644 --- a/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala @@ -21,7 +21,7 @@ import java.net.URI import java.io.File import scala.collection.mutable.ArrayBuffer -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.util.Try import org.apache.spark.SparkUserAppException @@ -71,7 +71,7 @@ object PythonRunner { val pythonPath = PythonUtils.mergePythonPaths(pathElements: _*) // Launch Python process - val builder = new ProcessBuilder(Seq(pythonExec, formattedPythonFile) ++ otherArgs) + val builder = new ProcessBuilder((Seq(pythonExec, formattedPythonFile) ++ otherArgs).asJava) val env = builder.environment() env.put("PYTHONPATH", pythonPath) // This is equivalent to setting the -u flag; we use it because ipython doesn't support -u: diff --git a/core/src/main/scala/org/apache/spark/deploy/RPackageUtils.scala b/core/src/main/scala/org/apache/spark/deploy/RPackageUtils.scala index ed1e97295567..4b28866dcaa7 100644 --- a/core/src/main/scala/org/apache/spark/deploy/RPackageUtils.scala +++ b/core/src/main/scala/org/apache/spark/deploy/RPackageUtils.scala @@ -22,7 +22,7 @@ import java.util.jar.JarFile import java.util.logging.Level import java.util.zip.{ZipEntry, ZipOutputStream} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import com.google.common.io.{ByteStreams, Files} @@ -110,7 +110,7 @@ private[deploy] object RPackageUtils extends Logging { print(s"Building R package with the command: $installCmd", printStream) } try { - val builder = new ProcessBuilder(installCmd) + val builder = new ProcessBuilder(installCmd.asJava) builder.redirectErrorStream(true) val env = builder.environment() env.clear() diff --git a/core/src/main/scala/org/apache/spark/deploy/RRunner.scala b/core/src/main/scala/org/apache/spark/deploy/RRunner.scala index c0cab22fa825..05b954ce3699 100644 --- a/core/src/main/scala/org/apache/spark/deploy/RRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/RRunner.scala @@ -20,7 +20,7 @@ package org.apache.spark.deploy import java.io._ import java.util.concurrent.{Semaphore, TimeUnit} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import org.apache.hadoop.fs.Path @@ -68,7 +68,7 @@ object RRunner { if (initialized.tryAcquire(backendTimeout, TimeUnit.SECONDS)) { // Launch R val returnCode = try { - val builder = new ProcessBuilder(Seq(rCommand, rFileNormalized) ++ otherArgs) + val builder = new ProcessBuilder((Seq(rCommand, rFileNormalized) ++ otherArgs).asJava) val env = builder.environment() env.put("EXISTING_SPARKR_BACKEND_PORT", sparkRBackendPort.toString) val rPackageDir = RUtils.sparkRPackagePath(isDriver = true) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkCuratorUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkCuratorUtil.scala index b8d399354022..8d5e716e6aea 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkCuratorUtil.scala +++ 
b/core/src/main/scala/org/apache/spark/deploy/SparkCuratorUtil.scala @@ -17,7 +17,7 @@ package org.apache.spark.deploy -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import org.apache.curator.framework.{CuratorFramework, CuratorFrameworkFactory} import org.apache.curator.retry.ExponentialBackoffRetry @@ -57,7 +57,7 @@ private[spark] object SparkCuratorUtil extends Logging { def deleteRecursive(zk: CuratorFramework, path: String) { if (zk.checkExists().forPath(path) != null) { - for (child <- zk.getChildren.forPath(path)) { + for (child <- zk.getChildren.forPath(path).asScala) { zk.delete().forPath(path + "/" + child) } zk.delete().forPath(path) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala index dda4216c7efe..f7723ef5bde4 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala @@ -22,7 +22,7 @@ import java.lang.reflect.Method import java.security.PrivilegedExceptionAction import java.util.{Arrays, Comparator} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.concurrent.duration._ import scala.language.postfixOps import scala.util.control.NonFatal @@ -71,7 +71,7 @@ class SparkHadoopUtil extends Logging { } def transferCredentials(source: UserGroupInformation, dest: UserGroupInformation) { - for (token <- source.getTokens()) { + for (token <- source.getTokens.asScala) { dest.addToken(token) } } @@ -175,8 +175,8 @@ class SparkHadoopUtil extends Logging { } private def getFileSystemThreadStatistics(): Seq[AnyRef] = { - val stats = FileSystem.getAllStatistics() - stats.map(Utils.invoke(classOf[Statistics], _, "getThreadStatistics")) + FileSystem.getAllStatistics.asScala.map( + Utils.invoke(classOf[Statistics], _, "getThreadStatistics")) } private def getFileSystemThreadStatisticsMethod(methodName: String): Method = { @@ -306,12 +306,13 @@ class SparkHadoopUtil extends Logging { val renewalInterval = sparkConf.getLong("spark.yarn.token.renewal.interval", (24 hours).toMillis) - credentials.getAllTokens.filter(_.getKind == DelegationTokenIdentifier.HDFS_DELEGATION_KIND) + credentials.getAllTokens.asScala + .filter(_.getKind == DelegationTokenIdentifier.HDFS_DELEGATION_KIND) .map { t => - val identifier = new DelegationTokenIdentifier() - identifier.readFields(new DataInputStream(new ByteArrayInputStream(t.getIdentifier))) - (identifier.getIssueDate + fraction * renewalInterval).toLong - now - }.foldLeft(0L)(math.max) + val identifier = new DelegationTokenIdentifier() + identifier.readFields(new DataInputStream(new ByteArrayInputStream(t.getIdentifier))) + (identifier.getIssueDate + fraction * renewalInterval).toLong - now + }.foldLeft(0L)(math.max) } diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index 3f3c6627c21f..18a1c52ae53f 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -23,7 +23,7 @@ import java.net.URI import java.util.{List => JList} import java.util.jar.JarFile -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.mutable.{ArrayBuffer, HashMap} import scala.io.Source @@ -94,7 +94,7 @@ private[deploy] class SparkSubmitArguments(args: 
Seq[String], env: Map[String, S // Set parameters from command line arguments try { - parse(args.toList) + parse(args.asJava) } catch { case e: IllegalArgumentException => SparkSubmit.printErrorAndExit(e.getMessage()) @@ -458,7 +458,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S } override protected def handleExtraArgs(extra: JList[String]): Unit = { - childArgs ++= extra + childArgs ++= extra.asScala } private def printUsageAndExit(exitCode: Int, unknownParam: Any = null): Unit = { diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala index 563831cc6b8d..540e802420ce 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala @@ -19,7 +19,7 @@ package org.apache.spark.deploy.master import java.nio.ByteBuffer -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.reflect.ClassTag import org.apache.curator.framework.CuratorFramework @@ -49,8 +49,8 @@ private[master] class ZooKeeperPersistenceEngine(conf: SparkConf, val serializer } override def read[T: ClassTag](prefix: String): Seq[T] = { - val file = zk.getChildren.forPath(WORKING_DIR).filter(_.startsWith(prefix)) - file.map(deserializeFromFile[T]).flatten + zk.getChildren.forPath(WORKING_DIR).asScala + .filter(_.startsWith(prefix)).map(deserializeFromFile[T]).flatten } override def close() { diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala b/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala index 45a3f4304543..ce02ee203a4b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala @@ -18,9 +18,8 @@ package org.apache.spark.deploy.worker import java.io.{File, FileOutputStream, InputStream, IOException} -import java.lang.System._ -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.Map import org.apache.spark.Logging @@ -62,7 +61,7 @@ object CommandUtils extends Logging { // SPARK-698: do not call the run.cmd script, as process.destroy() // fails to kill a process tree on Windows val cmd = new WorkerCommandBuilder(sparkHome, memory, command).buildCommand() - cmd.toSeq ++ Seq(command.mainClass) ++ command.arguments + cmd.asScala ++ Seq(command.mainClass) ++ command.arguments } /** diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala index ec51c3d935d8..89159ff5e2b3 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala @@ -19,7 +19,7 @@ package org.apache.spark.deploy.worker import java.io._ -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import com.google.common.base.Charsets.UTF_8 import com.google.common.io.Files @@ -172,8 +172,8 @@ private[deploy] class DriverRunner( CommandUtils.redirectStream(process.getInputStream, stdout) val stderr = new File(baseDir, "stderr") - val header = "Launch Command: %s\n%s\n\n".format( - builder.command.mkString("\"", "\" \"", "\""), "=" * 40) + val formattedCommand = builder.command.asScala.mkString("\"", "\" \"", "\"") + val 
header = "Launch Command: %s\n%s\n\n".format(formattedCommand, "=" * 40) Files.append(header, stderr, UTF_8) CommandUtils.redirectStream(process.getErrorStream, stderr) } @@ -229,6 +229,6 @@ private[deploy] trait ProcessBuilderLike { private[deploy] object ProcessBuilderLike { def apply(processBuilder: ProcessBuilder): ProcessBuilderLike = new ProcessBuilderLike { override def start(): Process = processBuilder.start() - override def command: Seq[String] = processBuilder.command() + override def command: Seq[String] = processBuilder.command().asScala } } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala index ab3fea475c2a..3aef0515cbf6 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala @@ -19,7 +19,7 @@ package org.apache.spark.deploy.worker import java.io._ -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import com.google.common.base.Charsets.UTF_8 import com.google.common.io.Files @@ -129,7 +129,8 @@ private[deploy] class ExecutorRunner( val builder = CommandUtils.buildProcessBuilder(appDesc.command, new SecurityManager(conf), memory, sparkHome.getAbsolutePath, substituteVariables) val command = builder.command() - logInfo("Launch command: " + command.mkString("\"", "\" \"", "\"")) + val formattedCommand = command.asScala.mkString("\"", "\" \"", "\"") + logInfo(s"Launch command: $formattedCommand") builder.directory(executorDir) builder.environment.put("SPARK_EXECUTOR_DIRS", appLocalDirs.mkString(File.pathSeparator)) @@ -145,7 +146,7 @@ private[deploy] class ExecutorRunner( process = builder.start() val header = "Spark Executor Command: %s\n%s\n\n".format( - command.mkString("\"", "\" \"", "\""), "=" * 40) + formattedCommand, "=" * 40) // Redirect its stdout and stderr to files val stdout = new File(executorDir, "stdout") diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index 79b1536d9401..770927c80f7a 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -24,7 +24,6 @@ import java.util.{UUID, Date} import java.util.concurrent._ import java.util.concurrent.{Future => JFuture, ScheduledFuture => JScheduledFuture} -import scala.collection.JavaConversions._ import scala.collection.mutable.{HashMap, HashSet, LinkedHashMap} import scala.concurrent.ExecutionContext import scala.util.Random diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 42a85e42ea2b..c3491bb8b1cf 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -23,7 +23,7 @@ import java.net.URL import java.nio.ByteBuffer import java.util.concurrent.{ConcurrentHashMap, TimeUnit} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.mutable.{ArrayBuffer, HashMap} import scala.util.control.NonFatal @@ -147,7 +147,7 @@ private[spark] class Executor( /** Returns the total amount of time this JVM process has spent in garbage collection. 
*/ private def computeTotalGcTime(): Long = { - ManagementFactory.getGarbageCollectorMXBeans.map(_.getCollectionTime).sum + ManagementFactory.getGarbageCollectorMXBeans.asScala.map(_.getCollectionTime).sum } class TaskRunner( @@ -425,7 +425,7 @@ private[spark] class Executor( val tasksMetrics = new ArrayBuffer[(Long, TaskMetrics)]() val curGCTime = computeTotalGcTime() - for (taskRunner <- runningTasks.values()) { + for (taskRunner <- runningTasks.values().asScala) { if (taskRunner.task != null) { taskRunner.task.metrics.foreach { metrics => metrics.updateShuffleReadMetrics() diff --git a/core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala b/core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala index 293c512f8b70..d16f4a1fc4e3 100644 --- a/core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala +++ b/core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala @@ -19,7 +19,7 @@ package org.apache.spark.executor import java.util.concurrent.ThreadPoolExecutor -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import com.codahale.metrics.{Gauge, MetricRegistry} import org.apache.hadoop.fs.FileSystem @@ -30,7 +30,7 @@ private[spark] class ExecutorSource(threadPool: ThreadPoolExecutor, executorId: String) extends Source { private def fileStats(scheme: String) : Option[FileSystem.Statistics] = - FileSystem.getAllStatistics().find(s => s.getScheme.equals(scheme)) + FileSystem.getAllStatistics.asScala.find(s => s.getScheme.equals(scheme)) private def registerFileSystemStat[T]( scheme: String, name: String, f: FileSystem.Statistics => T, defaultValue: T) = { diff --git a/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala index cfd672e1d8a9..0474fd2ccc12 100644 --- a/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala @@ -19,7 +19,7 @@ package org.apache.spark.executor import java.nio.ByteBuffer -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import org.apache.mesos.protobuf.ByteString import org.apache.mesos.{Executor => MesosExecutor, ExecutorDriver, MesosExecutorDriver} @@ -28,7 +28,7 @@ import org.apache.mesos.Protos.{TaskStatus => MesosTaskStatus, _} import org.apache.spark.{Logging, TaskState, SparkConf, SparkEnv} import org.apache.spark.TaskState.TaskState import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.scheduler.cluster.mesos.{MesosTaskLaunchData} +import org.apache.spark.scheduler.cluster.mesos.MesosTaskLaunchData import org.apache.spark.util.{SignalLogger, Utils} private[spark] class MesosExecutorBackend @@ -55,7 +55,7 @@ private[spark] class MesosExecutorBackend slaveInfo: SlaveInfo) { // Get num cores for this task from ExecutorInfo, created in MesosSchedulerBackend. 
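// A minimal, standalone sketch of the asScala lookup pattern used in this hunk; the
// NamedScalar case class and ResourceLookupSketch object are hypothetical stand-ins
// for the Mesos Resource protobuf, kept to plain java.util types for illustration.
import java.util.Arrays
import scala.collection.JavaConverters._

object ResourceLookupSketch {
  final case class NamedScalar(name: String, value: Double)

  // Explicit asScala view over the Java list, then ordinary Scala combinators.
  def cpusFrom(resources: java.util.List[NamedScalar]): Int =
    resources.asScala
      .find(_.name == "cpus")
      .map(_.value.toInt)
      .getOrElse(0)

  def main(args: Array[String]): Unit = {
    val offer = Arrays.asList(NamedScalar("mem", 1024.0), NamedScalar("cpus", 4.0))
    println(cpusFrom(offer))   // prints 4
  }
}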
- val cpusPerTask = executorInfo.getResourcesList + val cpusPerTask = executorInfo.getResourcesList.asScala .find(_.getName == "cpus") .map(_.getScalar.getValue.toInt) .getOrElse(0) diff --git a/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala b/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala index 6cda7772f77b..a5ad47293f1c 100644 --- a/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala +++ b/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala @@ -19,7 +19,7 @@ package org.apache.spark.input import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import com.google.common.io.ByteStreams import org.apache.hadoop.conf.Configuration @@ -44,12 +44,9 @@ private[spark] abstract class StreamFileInputFormat[T] * which is set through setMaxSplitSize */ def setMinPartitions(context: JobContext, minPartitions: Int) { - val files = listStatus(context) - val totalLen = files.map { file => - if (file.isDir) 0L else file.getLen - }.sum - - val maxSplitSize = Math.ceil(totalLen * 1.0 / files.length).toLong + val files = listStatus(context).asScala + val totalLen = files.map(file => if (file.isDir) 0L else file.getLen).sum + val maxSplitSize = Math.ceil(totalLen * 1.0 / files.size).toLong super.setMaxSplitSize(maxSplitSize) } diff --git a/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala b/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala index aaef7c74eea3..1ba34a11414a 100644 --- a/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala +++ b/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala @@ -17,7 +17,7 @@ package org.apache.spark.input -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import org.apache.hadoop.fs.Path import org.apache.hadoop.mapreduce.InputSplit @@ -52,10 +52,8 @@ private[spark] class WholeTextFileInputFormat * which is set through setMaxSplitSize */ def setMinPartitions(context: JobContext, minPartitions: Int) { - val files = listStatus(context) - val totalLen = files.map { file => - if (file.isDir) 0L else file.getLen - }.sum + val files = listStatus(context).asScala + val totalLen = files.map(file => if (file.isDir) 0L else file.getLen).sum val maxSplitSize = Math.ceil(totalLen * 1.0 / (if (minPartitions == 0) 1 else minPartitions)).toLong super.setMaxSplitSize(maxSplitSize) diff --git a/core/src/main/scala/org/apache/spark/launcher/WorkerCommandBuilder.scala b/core/src/main/scala/org/apache/spark/launcher/WorkerCommandBuilder.scala index 9be98723aed1..0c096656f923 100644 --- a/core/src/main/scala/org/apache/spark/launcher/WorkerCommandBuilder.scala +++ b/core/src/main/scala/org/apache/spark/launcher/WorkerCommandBuilder.scala @@ -20,7 +20,7 @@ package org.apache.spark.launcher import java.io.File import java.util.{HashMap => JHashMap, List => JList, Map => JMap} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import org.apache.spark.deploy.Command @@ -32,7 +32,7 @@ import org.apache.spark.deploy.Command private[spark] class WorkerCommandBuilder(sparkHome: String, memoryMb: Int, command: Command) extends AbstractCommandBuilder { - childEnv.putAll(command.environment) + childEnv.putAll(command.environment.asJava) childEnv.put(CommandBuilderUtils.ENV_SPARK_HOME, sparkHome) override def buildCommand(env: JMap[String, String]): 
JList[String] = { diff --git a/core/src/main/scala/org/apache/spark/metrics/MetricsConfig.scala b/core/src/main/scala/org/apache/spark/metrics/MetricsConfig.scala index d7495551ad23..dd2d325d8703 100644 --- a/core/src/main/scala/org/apache/spark/metrics/MetricsConfig.scala +++ b/core/src/main/scala/org/apache/spark/metrics/MetricsConfig.scala @@ -20,6 +20,7 @@ package org.apache.spark.metrics import java.io.{FileInputStream, InputStream} import java.util.Properties +import scala.collection.JavaConverters._ import scala.collection.mutable import scala.util.matching.Regex @@ -58,25 +59,20 @@ private[spark] class MetricsConfig(conf: SparkConf) extends Logging { propertyCategories = subProperties(properties, INSTANCE_REGEX) if (propertyCategories.contains(DEFAULT_PREFIX)) { - import scala.collection.JavaConversions._ - - val defaultProperty = propertyCategories(DEFAULT_PREFIX) - for { (inst, prop) <- propertyCategories - if (inst != DEFAULT_PREFIX) - (k, v) <- defaultProperty - if (prop.getProperty(k) == null) } { - prop.setProperty(k, v) + val defaultProperty = propertyCategories(DEFAULT_PREFIX).asScala + for((inst, prop) <- propertyCategories if (inst != DEFAULT_PREFIX); + (k, v) <- defaultProperty if (prop.get(k) == null)) { + prop.put(k, v) } } } def subProperties(prop: Properties, regex: Regex): mutable.HashMap[String, Properties] = { val subProperties = new mutable.HashMap[String, Properties] - import scala.collection.JavaConversions._ - prop.foreach { kv => - if (regex.findPrefixOf(kv._1).isDefined) { - val regex(prefix, suffix) = kv._1 - subProperties.getOrElseUpdate(prefix, new Properties).setProperty(suffix, kv._2) + prop.asScala.foreach { kv => + if (regex.findPrefixOf(kv._1.toString).isDefined) { + val regex(prefix, suffix) = kv._1.toString + subProperties.getOrElseUpdate(prefix, new Properties).setProperty(suffix, kv._2.toString) } } subProperties diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala index b089da8596e2..7c170a742fb6 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala @@ -19,7 +19,7 @@ package org.apache.spark.network.netty import java.nio.ByteBuffer -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import org.apache.spark.Logging import org.apache.spark.network.BlockDataManager @@ -55,7 +55,7 @@ class NettyBlockRpcServer( case openBlocks: OpenBlocks => val blocks: Seq[ManagedBuffer] = openBlocks.blockIds.map(BlockId.apply).map(blockManager.getBlockData) - val streamId = streamManager.registerStream(blocks.iterator) + val streamId = streamManager.registerStream(blocks.iterator.asJava) logTrace(s"Registered streamId $streamId with ${blocks.size} buffers") responseContext.onSuccess(new StreamHandle(streamId, blocks.size).toByteArray) diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala index d650d5fe7308..ff8aae9ebe9f 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala @@ -17,7 +17,7 @@ package org.apache.spark.network.netty -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.concurrent.{Future, 
Promise} import org.apache.spark.{SecurityManager, SparkConf} @@ -58,7 +58,7 @@ class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManage securityManager.isSaslEncryptionEnabled())) } transportContext = new TransportContext(transportConf, rpcHandler) - clientFactory = transportContext.createClientFactory(clientBootstrap.toList) + clientFactory = transportContext.createClientFactory(clientBootstrap.toSeq.asJava) server = createServer(serverBootstrap.toList) appId = conf.getAppId logInfo("Server created on " + server.getPort) @@ -67,7 +67,7 @@ class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManage /** Creates and binds the TransportServer, possibly trying multiple ports. */ private def createServer(bootstraps: List[TransportServerBootstrap]): TransportServer = { def startService(port: Int): (TransportServer, Int) = { - val server = transportContext.createServer(port, bootstraps) + val server = transportContext.createServer(port, bootstraps.asJava) (server, server.getPort) } diff --git a/core/src/main/scala/org/apache/spark/network/nio/Connection.scala b/core/src/main/scala/org/apache/spark/network/nio/Connection.scala index 1499da07bb83..8d9ebadaf79d 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/Connection.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/Connection.scala @@ -23,7 +23,7 @@ import java.nio.channels._ import java.util.concurrent.ConcurrentLinkedQueue import java.util.LinkedList -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.mutable.{ArrayBuffer, HashMap} import scala.util.control.NonFatal @@ -145,7 +145,7 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector, } def callOnExceptionCallbacks(e: Throwable) { - onExceptionCallbacks foreach { + onExceptionCallbacks.asScala.foreach { callback => try { callback(this, e) diff --git a/core/src/main/scala/org/apache/spark/partial/GroupedCountEvaluator.scala b/core/src/main/scala/org/apache/spark/partial/GroupedCountEvaluator.scala index 91b07ce3af1b..5afce75680f9 100644 --- a/core/src/main/scala/org/apache/spark/partial/GroupedCountEvaluator.scala +++ b/core/src/main/scala/org/apache/spark/partial/GroupedCountEvaluator.scala @@ -19,7 +19,7 @@ package org.apache.spark.partial import java.util.{HashMap => JHashMap} -import scala.collection.JavaConversions.mapAsScalaMap +import scala.collection.JavaConverters._ import scala.collection.Map import scala.collection.mutable.HashMap import scala.reflect.ClassTag @@ -48,9 +48,9 @@ private[spark] class GroupedCountEvaluator[T : ClassTag](totalOutputs: Int, conf if (outputsMerged == totalOutputs) { val result = new JHashMap[T, BoundedDouble](sums.size) sums.foreach { case (key, sum) => - result(key) = new BoundedDouble(sum, 1.0, sum, sum) + result.put(key, new BoundedDouble(sum, 1.0, sum, sum)) } - result + result.asScala } else if (outputsMerged == 0) { new HashMap[T, BoundedDouble] } else { @@ -64,9 +64,9 @@ private[spark] class GroupedCountEvaluator[T : ClassTag](totalOutputs: Int, conf val stdev = math.sqrt(variance) val low = mean - confFactor * stdev val high = mean + confFactor * stdev - result(key) = new BoundedDouble(mean, confidence, low, high) + result.put(key, new BoundedDouble(mean, confidence, low, high)) } - result + result.asScala } } } diff --git a/core/src/main/scala/org/apache/spark/partial/GroupedMeanEvaluator.scala b/core/src/main/scala/org/apache/spark/partial/GroupedMeanEvaluator.scala index 
af26c3d59ac0..a16404068480 100644 --- a/core/src/main/scala/org/apache/spark/partial/GroupedMeanEvaluator.scala +++ b/core/src/main/scala/org/apache/spark/partial/GroupedMeanEvaluator.scala @@ -19,7 +19,7 @@ package org.apache.spark.partial import java.util.{HashMap => JHashMap} -import scala.collection.JavaConversions.mapAsScalaMap +import scala.collection.JavaConverters._ import scala.collection.Map import scala.collection.mutable.HashMap @@ -55,9 +55,9 @@ private[spark] class GroupedMeanEvaluator[T](totalOutputs: Int, confidence: Doub while (iter.hasNext) { val entry = iter.next() val mean = entry.getValue.mean - result(entry.getKey) = new BoundedDouble(mean, 1.0, mean, mean) + result.put(entry.getKey, new BoundedDouble(mean, 1.0, mean, mean)) } - result + result.asScala } else if (outputsMerged == 0) { new HashMap[T, BoundedDouble] } else { @@ -72,9 +72,9 @@ private[spark] class GroupedMeanEvaluator[T](totalOutputs: Int, confidence: Doub val confFactor = studentTCacher.get(counter.count) val low = mean - confFactor * stdev val high = mean + confFactor * stdev - result(entry.getKey) = new BoundedDouble(mean, confidence, low, high) + result.put(entry.getKey, new BoundedDouble(mean, confidence, low, high)) } - result + result.asScala } } } diff --git a/core/src/main/scala/org/apache/spark/partial/GroupedSumEvaluator.scala b/core/src/main/scala/org/apache/spark/partial/GroupedSumEvaluator.scala index 442fb86227d8..54a1beab3514 100644 --- a/core/src/main/scala/org/apache/spark/partial/GroupedSumEvaluator.scala +++ b/core/src/main/scala/org/apache/spark/partial/GroupedSumEvaluator.scala @@ -19,7 +19,7 @@ package org.apache.spark.partial import java.util.{HashMap => JHashMap} -import scala.collection.JavaConversions.mapAsScalaMap +import scala.collection.JavaConverters._ import scala.collection.Map import scala.collection.mutable.HashMap @@ -55,9 +55,9 @@ private[spark] class GroupedSumEvaluator[T](totalOutputs: Int, confidence: Doubl while (iter.hasNext) { val entry = iter.next() val sum = entry.getValue.sum - result(entry.getKey) = new BoundedDouble(sum, 1.0, sum, sum) + result.put(entry.getKey, new BoundedDouble(sum, 1.0, sum, sum)) } - result + result.asScala } else if (outputsMerged == 0) { new HashMap[T, BoundedDouble] } else { @@ -80,9 +80,9 @@ private[spark] class GroupedSumEvaluator[T](totalOutputs: Int, confidence: Doubl val confFactor = studentTCacher.get(counter.count) val low = sumEstimate - confFactor * sumStdev val high = sumEstimate + confFactor * sumStdev - result(entry.getKey) = new BoundedDouble(sumEstimate, confidence, low, high) + result.put(entry.getKey, new BoundedDouble(sumEstimate, confidence, low, high)) } - result + result.asScala } } } diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index 326fafb230a4..4e5f2e8a5d46 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -22,7 +22,7 @@ import java.text.SimpleDateFormat import java.util.{Date, HashMap => JHashMap} import scala.collection.{Map, mutable} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer import scala.reflect.ClassTag import scala.util.DynamicVariable @@ -312,14 +312,14 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) } : Iterator[JHashMap[K, V]] val mergeMaps = (m1: JHashMap[K, V], m2: JHashMap[K, V]) => { - m2.foreach { pair => + 
m2.asScala.foreach { pair => val old = m1.get(pair._1) m1.put(pair._1, if (old == null) pair._2 else cleanedF(old, pair._2)) } m1 } : JHashMap[K, V] - self.mapPartitions(reducePartition).reduce(mergeMaps) + self.mapPartitions(reducePartition).reduce(mergeMaps).asScala } /** Alias for reduceByKeyLocally */ diff --git a/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala index 3bb9998e1db4..afbe566b7656 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala @@ -23,7 +23,7 @@ import java.io.IOException import java.io.PrintWriter import java.util.StringTokenizer -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.Map import scala.collection.mutable.ArrayBuffer import scala.io.Source @@ -72,7 +72,7 @@ private[spark] class PipedRDD[T: ClassTag]( } override def compute(split: Partition, context: TaskContext): Iterator[String] = { - val pb = new ProcessBuilder(command) + val pb = new ProcessBuilder(command.asJava) // Add the environmental variables to the process. val currentEnvVars = pb.environment() envVars.foreach { case (variable, value) => currentEnvVars.put(variable, value) } @@ -81,7 +81,7 @@ private[spark] class PipedRDD[T: ClassTag]( // so the user code can access the input filename if (split.isInstanceOf[HadoopPartition]) { val hadoopSplit = split.asInstanceOf[HadoopPartition] - currentEnvVars.putAll(hadoopSplit.getPipeEnvVars()) + currentEnvVars.putAll(hadoopSplit.getPipeEnvVars().asJava) } // When spark.worker.separated.working.directory option is turned on, each diff --git a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala index f7cb1791d4ac..9a4fa301b06e 100644 --- a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala @@ -19,7 +19,7 @@ package org.apache.spark.rdd import java.util.{HashMap => JHashMap} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer import scala.reflect.ClassTag @@ -125,7 +125,7 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag]( integrate(0, t => getSeq(t._1) += t._2) // the second dep is rdd2; remove all of its keys integrate(1, t => map.remove(t._1)) - map.iterator.map { t => t._2.iterator.map { (t._1, _) } }.flatten + map.asScala.iterator.map(t => t._2.iterator.map((t._1, _))).flatten } override def clearDependencies() { diff --git a/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala index bac37bfdaa23..0e438ab4366d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala @@ -17,7 +17,7 @@ package org.apache.spark.scheduler -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.immutable.Set import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} @@ -107,7 +107,7 @@ class InputFormatInfo(val configuration: Configuration, val inputFormatClazz: Cl val retval = new ArrayBuffer[SplitInfo]() val list = instance.getSplits(job) - for (split <- list) { + for (split <- list.asScala) { retval ++= SplitInfo.toSplitInfo(inputFormatClazz, path, split) } diff --git 
a/core/src/main/scala/org/apache/spark/scheduler/Pool.scala b/core/src/main/scala/org/apache/spark/scheduler/Pool.scala index 174b73221afc..5821afea9898 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Pool.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Pool.scala @@ -19,7 +19,7 @@ package org.apache.spark.scheduler import java.util.concurrent.{ConcurrentHashMap, ConcurrentLinkedQueue} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer import org.apache.spark.Logging @@ -74,7 +74,7 @@ private[spark] class Pool( if (schedulableNameToSchedulable.containsKey(schedulableName)) { return schedulableNameToSchedulable.get(schedulableName) } - for (schedulable <- schedulableQueue) { + for (schedulable <- schedulableQueue.asScala) { val sched = schedulable.getSchedulableByName(schedulableName) if (sched != null) { return sched @@ -84,12 +84,12 @@ private[spark] class Pool( } override def executorLost(executorId: String, host: String) { - schedulableQueue.foreach(_.executorLost(executorId, host)) + schedulableQueue.asScala.foreach(_.executorLost(executorId, host)) } override def checkSpeculatableTasks(): Boolean = { var shouldRevive = false - for (schedulable <- schedulableQueue) { + for (schedulable <- schedulableQueue.asScala) { shouldRevive |= schedulable.checkSpeculatableTasks() } shouldRevive @@ -98,7 +98,7 @@ private[spark] class Pool( override def getSortedTaskSetQueue: ArrayBuffer[TaskSetManager] = { var sortedTaskSetQueue = new ArrayBuffer[TaskSetManager] val sortedSchedulableQueue = - schedulableQueue.toSeq.sortWith(taskSetSchedulingAlgorithm.comparator) + schedulableQueue.asScala.toSeq.sortWith(taskSetSchedulingAlgorithm.comparator) for (schedulable <- sortedSchedulableQueue) { sortedTaskSetQueue ++= schedulable.getSortedTaskSetQueue } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala index d6e1e9e5bebc..452c32d5411c 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala @@ -21,7 +21,7 @@ import java.io.File import java.util.concurrent.locks.ReentrantLock import java.util.{Collections, List => JList} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.mutable.{HashMap, HashSet} import com.google.common.collect.HashBiMap @@ -233,7 +233,7 @@ private[spark] class CoarseMesosSchedulerBackend( override def resourceOffers(d: SchedulerDriver, offers: JList[Offer]) { stateLock.synchronized { val filters = Filters.newBuilder().setRefuseSeconds(5).build() - for (offer <- offers) { + for (offer <- offers.asScala) { val offerAttributes = toAttributeMap(offer.getAttributesList) val meetsConstraints = matchesAttributeRequirements(slaveOfferConstraints, offerAttributes) val slaveId = offer.getSlaveId.getValue @@ -251,21 +251,21 @@ private[spark] class CoarseMesosSchedulerBackend( val cpusToUse = math.min(cpus, maxCores - totalCoresAcquired) totalCoresAcquired += cpusToUse val taskId = newMesosTaskId() - taskIdToSlaveId(taskId) = slaveId + taskIdToSlaveId.put(taskId, slaveId) slaveIdsWithExecutors += slaveId coresByTaskId(taskId) = cpusToUse // Gather cpu resources from the available resources and use them in the task. 
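// A small runnable sketch of the opposite direction, asJava, analogous to the
// addAllResources(...asJava) calls in this hunk; String.join is used only as a
// generic Java API consumer, and the AsJavaSketch name is hypothetical.
import scala.collection.JavaConverters._

object AsJavaSketch {
  def main(args: Array[String]): Unit = {
    val resourceNames: Seq[String] = Seq("cpus", "mem", "disk")
    // asJava wraps the Scala collection for the Java API; no copy is made.
    val javaView: java.util.List[String] = resourceNames.asJava
    println(String.join(", ", javaView))   // cpus, mem, disk
  }
}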
val (remainingResources, cpuResourcesToUse) = partitionResources(offer.getResourcesList, "cpus", cpusToUse) val (_, memResourcesToUse) = - partitionResources(remainingResources, "mem", calculateTotalMemory(sc)) + partitionResources(remainingResources.asJava, "mem", calculateTotalMemory(sc)) val taskBuilder = MesosTaskInfo.newBuilder() .setTaskId(TaskID.newBuilder().setValue(taskId.toString).build()) .setSlaveId(offer.getSlaveId) .setCommand(createCommand(offer, cpusToUse + extraCoresPerSlave, taskId)) .setName("Task " + taskId) - .addAllResources(cpuResourcesToUse) - .addAllResources(memResourcesToUse) + .addAllResources(cpuResourcesToUse.asJava) + .addAllResources(memResourcesToUse.asJava) sc.conf.getOption("spark.mesos.executor.docker.image").foreach { image => MesosSchedulerBackendUtil @@ -314,9 +314,9 @@ private[spark] class CoarseMesosSchedulerBackend( } if (TaskState.isFinished(TaskState.fromMesos(state))) { - val slaveId = taskIdToSlaveId(taskId) + val slaveId = taskIdToSlaveId.get(taskId) slaveIdsWithExecutors -= slaveId - taskIdToSlaveId -= taskId + taskIdToSlaveId.remove(taskId) // Remove the cores we have remembered for this task, if it's in the hashmap for (cores <- coresByTaskId.get(taskId)) { totalCoresAcquired -= cores @@ -361,7 +361,7 @@ private[spark] class CoarseMesosSchedulerBackend( stateLock.synchronized { if (slaveIdsWithExecutors.contains(slaveId)) { val slaveIdToTaskId = taskIdToSlaveId.inverse() - if (slaveIdToTaskId.contains(slaveId)) { + if (slaveIdToTaskId.containsKey(slaveId)) { val taskId: Int = slaveIdToTaskId.get(slaveId) taskIdToSlaveId.remove(taskId) removeExecutor(sparkExecutorId(slaveId, taskId.toString), reason) @@ -411,7 +411,7 @@ private[spark] class CoarseMesosSchedulerBackend( val slaveIdToTaskId = taskIdToSlaveId.inverse() for (executorId <- executorIds) { val slaveId = executorId.split("/")(0) - if (slaveIdToTaskId.contains(slaveId)) { + if (slaveIdToTaskId.containsKey(slaveId)) { mesosDriver.killTask( TaskID.newBuilder().setValue(slaveIdToTaskId.get(slaveId).toString).build()) pendingRemovedSlaveIds += slaveId diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterPersistenceEngine.scala index 3efc536f1456..e0c547dce6d0 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterPersistenceEngine.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterPersistenceEngine.scala @@ -17,7 +17,7 @@ package org.apache.spark.scheduler.cluster.mesos -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import org.apache.curator.framework.CuratorFramework import org.apache.zookeeper.CreateMode @@ -129,6 +129,6 @@ private[spark] class ZookeeperMesosClusterPersistenceEngine( } override def fetchAll[T](): Iterable[T] = { - zk.getChildren.forPath(WORKING_DIR).map(fetch[T]).flatten + zk.getChildren.forPath(WORKING_DIR).asScala.flatMap(fetch[T]) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala index 1206f184fbc8..07da9242b992 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala @@ -21,7 +21,7 @@ import java.io.File import java.util.concurrent.locks.ReentrantLock 
import java.util.{Collections, Date, List => JList} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.mutable import scala.collection.mutable.ArrayBuffer @@ -350,7 +350,7 @@ private[spark] class MesosClusterScheduler( } // TODO: Page the status updates to avoid trying to reconcile // a large amount of tasks at once. - driver.reconcileTasks(statuses) + driver.reconcileTasks(statuses.toSeq.asJava) } } } @@ -493,10 +493,10 @@ private[spark] class MesosClusterScheduler( } override def resourceOffers(driver: SchedulerDriver, offers: JList[Offer]): Unit = { - val currentOffers = offers.map { o => + val currentOffers = offers.asScala.map(o => new ResourceOffer( o, getResource(o.getResourcesList, "cpus"), getResource(o.getResourcesList, "mem")) - }.toList + ).toList logTrace(s"Received offers from Mesos: \n${currentOffers.mkString("\n")}") val tasks = new mutable.HashMap[OfferID, ArrayBuffer[TaskInfo]]() val currentTime = new Date() @@ -521,10 +521,10 @@ private[spark] class MesosClusterScheduler( currentOffers, tasks) } - tasks.foreach { case (offerId, tasks) => - driver.launchTasks(Collections.singleton(offerId), tasks) + tasks.foreach { case (offerId, taskInfos) => + driver.launchTasks(Collections.singleton(offerId), taskInfos.asJava) } - offers + offers.asScala .filter(o => !tasks.keySet.contains(o.getId)) .foreach(o => driver.declineOffer(o.getId)) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala index 5c20606d5871..2e424054be78 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala @@ -20,7 +20,7 @@ package org.apache.spark.scheduler.cluster.mesos import java.io.File import java.util.{ArrayList => JArrayList, Collections, List => JList} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.mutable.{HashMap, HashSet} import org.apache.mesos.{Scheduler => MScheduler, _} @@ -129,14 +129,12 @@ private[spark] class MesosSchedulerBackend( val (resourcesAfterCpu, usedCpuResources) = partitionResources(availableResources, "cpus", scheduler.CPUS_PER_TASK) val (resourcesAfterMem, usedMemResources) = - partitionResources(resourcesAfterCpu, "mem", calculateTotalMemory(sc)) + partitionResources(resourcesAfterCpu.asJava, "mem", calculateTotalMemory(sc)) - builder.addAllResources(usedCpuResources) - builder.addAllResources(usedMemResources) + builder.addAllResources(usedCpuResources.asJava) + builder.addAllResources(usedMemResources.asJava) - sc.conf.getOption("spark.mesos.uris").map { uris => - setupUris(uris, command) - } + sc.conf.getOption("spark.mesos.uris").foreach(setupUris(_, command)) val executorInfo = builder .setExecutorId(ExecutorID.newBuilder().setValue(execId).build()) @@ -148,7 +146,7 @@ private[spark] class MesosSchedulerBackend( .setupContainerBuilderDockerInfo(image, sc.conf, executorInfo.getContainerBuilder()) } - (executorInfo.build(), resourcesAfterMem) + (executorInfo.build(), resourcesAfterMem.asJava) } /** @@ -193,7 +191,7 @@ private[spark] class MesosSchedulerBackend( private def getTasksSummary(tasks: JArrayList[MesosTaskInfo]): String = { val builder = new StringBuilder - tasks.foreach { t => + tasks.asScala.foreach { t => builder.append("Task id: 
").append(t.getTaskId.getValue).append("\n") .append("Slave id: ").append(t.getSlaveId.getValue).append("\n") .append("Task resources: ").append(t.getResourcesList).append("\n") @@ -211,7 +209,7 @@ private[spark] class MesosSchedulerBackend( override def resourceOffers(d: SchedulerDriver, offers: JList[Offer]) { inClassLoader() { // Fail-fast on offers we know will be rejected - val (usableOffers, unUsableOffers) = offers.partition { o => + val (usableOffers, unUsableOffers) = offers.asScala.partition { o => val mem = getResource(o.getResourcesList, "mem") val cpus = getResource(o.getResourcesList, "cpus") val slaveId = o.getSlaveId.getValue @@ -323,10 +321,10 @@ private[spark] class MesosSchedulerBackend( .setSlaveId(SlaveID.newBuilder().setValue(slaveId).build()) .setExecutor(executorInfo) .setName(task.name) - .addAllResources(cpuResources) + .addAllResources(cpuResources.asJava) .setData(MesosTaskLaunchData(task.serializedTask, task.attemptNumber).toByteString) .build() - (taskInfo, finalResources) + (taskInfo, finalResources.asJava) } override def statusUpdate(d: SchedulerDriver, status: TaskStatus) { diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala index 5b854aa5c275..860c8e097b3b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala @@ -20,7 +20,7 @@ package org.apache.spark.scheduler.cluster.mesos import java.util.{List => JList} import java.util.concurrent.CountDownLatch -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer import scala.util.control.NonFatal @@ -137,7 +137,7 @@ private[mesos] trait MesosSchedulerUtils extends Logging { protected def getResource(res: JList[Resource], name: String): Double = { // A resource can have multiple values in the offer since it can either be from // a specific role or wildcard. 
- res.filter(_.getName == name).map(_.getScalar.getValue).sum + res.asScala.filter(_.getName == name).map(_.getScalar.getValue).sum } protected def markRegistered(): Unit = { @@ -169,7 +169,7 @@ private[mesos] trait MesosSchedulerUtils extends Logging { amountToUse: Double): (List[Resource], List[Resource]) = { var remain = amountToUse var requestedResources = new ArrayBuffer[Resource] - val remainingResources = resources.map { + val remainingResources = resources.asScala.map { case r => { if (remain > 0 && r.getType == Value.Type.SCALAR && @@ -214,7 +214,7 @@ private[mesos] trait MesosSchedulerUtils extends Logging { * @return */ protected def toAttributeMap(offerAttributes: JList[Attribute]): Map[String, GeneratedMessage] = { - offerAttributes.map(attr => { + offerAttributes.asScala.map(attr => { val attrValue = attr.getType match { case Value.Type.SCALAR => attr.getScalar case Value.Type.RANGES => attr.getRanges @@ -253,7 +253,7 @@ private[mesos] trait MesosSchedulerUtils extends Logging { requiredValues.map(_.toLong).exists(offerRange.contains(_)) case Some(offeredValue: Value.Set) => // check if the specified required values is a subset of offered set - requiredValues.subsetOf(offeredValue.getItemList.toSet) + requiredValues.subsetOf(offeredValue.getItemList.asScala.toSet) case Some(textValue: Value.Text) => // check if the specified value is equal, if multiple values are specified // we succeed if any of them match. @@ -299,14 +299,13 @@ private[mesos] trait MesosSchedulerUtils extends Logging { Map() } else { try { - Map() ++ mapAsScalaMap(splitter.split(constraintsVal)).map { - case (k, v) => - if (v == null || v.isEmpty) { - (k, Set[String]()) - } else { - (k, v.split(',').toSet) - } - } + splitter.split(constraintsVal).asScala.toMap.mapValues(v => + if (v == null || v.isEmpty) { + Set[String]() + } else { + v.split(',').toSet + } + ) } catch { case NonFatal(e) => throw new IllegalArgumentException(s"Bad constraint string: $constraintsVal", e) diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index 0ff7562e912c..048a93850727 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -21,6 +21,7 @@ import java.io.{EOFException, IOException, InputStream, OutputStream} import java.nio.ByteBuffer import javax.annotation.Nullable +import scala.collection.JavaConverters._ import scala.reflect.ClassTag import com.esotericsoftware.kryo.{Kryo, KryoException} @@ -373,16 +374,15 @@ private class JavaIterableWrapperSerializer override def read(kryo: Kryo, in: KryoInput, clz: Class[java.lang.Iterable[_]]) : java.lang.Iterable[_] = { kryo.readClassAndObject(in) match { - case scalaIterable: Iterable[_] => - scala.collection.JavaConversions.asJavaIterable(scalaIterable) - case javaIterable: java.lang.Iterable[_] => - javaIterable + case scalaIterable: Iterable[_] => scalaIterable.asJava + case javaIterable: java.lang.Iterable[_] => javaIterable } } } private object JavaIterableWrapperSerializer extends Logging { - // The class returned by asJavaIterable (scala.collection.convert.Wrappers$IterableWrapper). + // The class returned by JavaConverters.asJava + // (scala.collection.convert.Wrappers$IterableWrapper). 
val wrapperClass = scala.collection.convert.WrapAsJava.asJavaIterable(Seq(1)).getClass diff --git a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala index f6a96d81e7aa..c057de9b3f4d 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala @@ -21,7 +21,7 @@ import java.io.File import java.util.concurrent.ConcurrentLinkedQueue import java.util.concurrent.atomic.AtomicInteger -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import org.apache.spark.{Logging, SparkConf, SparkEnv} import org.apache.spark.executor.ShuffleWriteMetrics @@ -210,11 +210,13 @@ private[spark] class FileShuffleBlockResolver(conf: SparkConf) shuffleStates.get(shuffleId) match { case Some(state) => if (consolidateShuffleFiles) { - for (fileGroup <- state.allFileGroups; file <- fileGroup.files) { + for (fileGroup <- state.allFileGroups.asScala; + file <- fileGroup.files) { file.delete() } } else { - for (mapId <- state.completedMapTasks; reduceId <- 0 until state.numBuckets) { + for (mapId <- state.completedMapTasks.asScala; + reduceId <- 0 until state.numBuckets) { val blockId = new ShuffleBlockId(shuffleId, mapId, reduceId) blockManager.diskBlockManager.getFile(blockId).delete() } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala index 6fec5240707a..7db6035553ae 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala @@ -21,7 +21,7 @@ import java.util.{HashMap => JHashMap} import scala.collection.immutable.HashSet import scala.collection.mutable -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.concurrent.{ExecutionContext, Future} import org.apache.spark.rpc.{RpcEndpointRef, RpcEnv, RpcCallContext, ThreadSafeRpcEndpoint} @@ -133,7 +133,7 @@ class BlockManagerMasterEndpoint( // Find all blocks for the given RDD, remove the block from both blockLocations and // the blockManagerInfo that is tracking the blocks. 
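// A compact sketch of the blockLocations.asScala.keys idiom in this hunk: asScala
// gives a live Scala view over a java.util.HashMap, so the keys can be filtered
// without copying the map. MapKeysSketch and the sample keys are illustrative only.
import java.util.{HashMap => JHashMap}
import scala.collection.JavaConverters._

object MapKeysSketch {
  def main(args: Array[String]): Unit = {
    val locations = new JHashMap[String, Int]()
    locations.put("rdd_1_0", 2)
    locations.put("rdd_1_1", 1)
    locations.put("broadcast_0", 1)
    val rddBlocks = locations.asScala.keys.filter(_.startsWith("rdd_")).toSeq.sorted
    println(rddBlocks.mkString(", "))   // rdd_1_0, rdd_1_1
  }
}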
- val blocks = blockLocations.keys.flatMap(_.asRDDId).filter(_.rddId == rddId) + val blocks = blockLocations.asScala.keys.flatMap(_.asRDDId).filter(_.rddId == rddId) blocks.foreach { blockId => val bms: mutable.HashSet[BlockManagerId] = blockLocations.get(blockId) bms.foreach(bm => blockManagerInfo.get(bm).foreach(_.removeBlock(blockId))) @@ -242,7 +242,7 @@ class BlockManagerMasterEndpoint( private def storageStatus: Array[StorageStatus] = { blockManagerInfo.map { case (blockManagerId, info) => - new StorageStatus(blockManagerId, info.maxMem, info.blocks) + new StorageStatus(blockManagerId, info.maxMem, info.blocks.asScala) }.toArray } @@ -292,7 +292,7 @@ class BlockManagerMasterEndpoint( if (askSlaves) { info.slaveEndpoint.ask[Seq[BlockId]](getMatchingBlockIds) } else { - Future { info.blocks.keys.filter(filter).toSeq } + Future { info.blocks.asScala.keys.filter(filter).toSeq } } future } diff --git a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala index 78e7ddc27d1c..1738258a0c79 100644 --- a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala @@ -17,7 +17,7 @@ package org.apache.spark.util -import scala.collection.JavaConversions.mapAsJavaMap +import scala.collection.JavaConverters._ import akka.actor.{ActorRef, ActorSystem, ExtendedActorSystem} import akka.pattern.ask @@ -92,7 +92,7 @@ private[spark] object AkkaUtils extends Logging { val akkaSslConfig = securityManager.akkaSSLOptions.createAkkaConfig .getOrElse(ConfigFactory.empty()) - val akkaConf = ConfigFactory.parseMap(conf.getAkkaConf.toMap[String, String]) + val akkaConf = ConfigFactory.parseMap(conf.getAkkaConf.toMap.asJava) .withFallback(akkaSslConfig).withFallback(ConfigFactory.parseString( s""" |akka.daemonic = on diff --git a/core/src/main/scala/org/apache/spark/util/ListenerBus.scala b/core/src/main/scala/org/apache/spark/util/ListenerBus.scala index a725767d08cc..13cb516b583e 100644 --- a/core/src/main/scala/org/apache/spark/util/ListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/util/ListenerBus.scala @@ -19,12 +19,11 @@ package org.apache.spark.util import java.util.concurrent.CopyOnWriteArrayList -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.reflect.ClassTag import scala.util.control.NonFatal import org.apache.spark.Logging -import org.apache.spark.scheduler.SparkListener /** * An event bus which posts events to its listeners. @@ -46,7 +45,7 @@ private[spark] trait ListenerBus[L <: AnyRef, E] extends Logging { * `postToAll` in the same thread for all events. */ final def postToAll(event: E): Unit = { - // JavaConversions will create a JIterableWrapper if we use some Scala collection functions. + // JavaConverters can create a JIterableWrapper if we use asScala. // However, this method will be called frequently. To avoid the wrapper cost, here we use // Java Iterator directly. 
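// A short sketch of the trade-off the comment above describes: on the hot path the
// Java iterator of the CopyOnWriteArrayList is walked directly, while asScala is
// reserved for infrequent calls where the wrapper allocation is irrelevant.
// ListenerBusSketch and its String-based listeners are hypothetical.
import java.util.concurrent.CopyOnWriteArrayList
import scala.collection.JavaConverters._

object ListenerBusSketch {
  private val listeners = new CopyOnWriteArrayList[String => Unit]()

  def addListener(l: String => Unit): Unit = listeners.add(l)

  // Hot path: no per-event Scala wrapper is allocated.
  def postToAll(event: String): Unit = {
    val iter = listeners.iterator
    while (iter.hasNext) {
      iter.next()(event)
    }
  }

  // Cold path: asScala keeps the occasional query concise.
  def listenerCount: Int = listeners.asScala.size

  def main(args: Array[String]): Unit = {
    addListener(e => println(s"received: $e"))
    postToAll("stage completed")
    println(listenerCount)   // 1
  }
}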
val iter = listeners.iterator @@ -69,7 +68,7 @@ private[spark] trait ListenerBus[L <: AnyRef, E] extends Logging { private[spark] def findListenersByClass[T <: L : ClassTag](): Seq[T] = { val c = implicitly[ClassTag[T]].runtimeClass - listeners.filter(_.getClass == c).map(_.asInstanceOf[T]).toSeq + listeners.asScala.filter(_.getClass == c).map(_.asInstanceOf[T]).toSeq } } diff --git a/core/src/main/scala/org/apache/spark/util/MutableURLClassLoader.scala b/core/src/main/scala/org/apache/spark/util/MutableURLClassLoader.scala index 169489df6c1e..a1c33212cdb2 100644 --- a/core/src/main/scala/org/apache/spark/util/MutableURLClassLoader.scala +++ b/core/src/main/scala/org/apache/spark/util/MutableURLClassLoader.scala @@ -21,8 +21,6 @@ import java.net.{URLClassLoader, URL} import java.util.Enumeration import java.util.concurrent.ConcurrentHashMap -import scala.collection.JavaConversions._ - /** * URL class loader that exposes the `addURL` and `getURLs` methods in URLClassLoader. */ diff --git a/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala b/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala index 8de75ba9a9c9..d7e5143c3095 100644 --- a/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala +++ b/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala @@ -21,7 +21,8 @@ import java.util.Set import java.util.Map.Entry import java.util.concurrent.ConcurrentHashMap -import scala.collection.{JavaConversions, mutable} +import scala.collection.JavaConverters._ +import scala.collection.mutable import org.apache.spark.Logging @@ -50,8 +51,7 @@ private[spark] class TimeStampedHashMap[A, B](updateTimeStampOnGet: Boolean = fa } def iterator: Iterator[(A, B)] = { - val jIterator = getEntrySet.iterator - JavaConversions.asScalaIterator(jIterator).map(kv => (kv.getKey, kv.getValue.value)) + getEntrySet.iterator.asScala.map(kv => (kv.getKey, kv.getValue.value)) } def getEntrySet: Set[Entry[A, TimeStampedValue[B]]] = internalMap.entrySet @@ -90,9 +90,7 @@ private[spark] class TimeStampedHashMap[A, B](updateTimeStampOnGet: Boolean = fa } override def filter(p: ((A, B)) => Boolean): mutable.Map[A, B] = { - JavaConversions.mapAsScalaConcurrentMap(internalMap) - .map { case (k, TimeStampedValue(v, t)) => (k, v) } - .filter(p) + internalMap.asScala.map { case (k, TimeStampedValue(v, t)) => (k, v) }.filter(p) } override def empty: mutable.Map[A, B] = new TimeStampedHashMap[A, B]() diff --git a/core/src/main/scala/org/apache/spark/util/TimeStampedHashSet.scala b/core/src/main/scala/org/apache/spark/util/TimeStampedHashSet.scala index 7cd8f28b12dd..65efeb1f4c19 100644 --- a/core/src/main/scala/org/apache/spark/util/TimeStampedHashSet.scala +++ b/core/src/main/scala/org/apache/spark/util/TimeStampedHashSet.scala @@ -19,7 +19,7 @@ package org.apache.spark.util import java.util.concurrent.ConcurrentHashMap -import scala.collection.JavaConversions +import scala.collection.JavaConverters._ import scala.collection.mutable.Set private[spark] class TimeStampedHashSet[A] extends Set[A] { @@ -31,7 +31,7 @@ private[spark] class TimeStampedHashSet[A] extends Set[A] { def iterator: Iterator[A] = { val jIterator = internalMap.entrySet().iterator() - JavaConversions.asScalaIterator(jIterator).map(_.getKey) + jIterator.asScala.map(_.getKey) } override def + (elem: A): Set[A] = { diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 831331222671..2bab4af2e73a 100644 --- 
a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -25,7 +25,7 @@ import java.util.{Properties, Locale, Random, UUID} import java.util.concurrent._ import javax.net.ssl.HttpsURLConnection -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.Map import scala.collection.mutable.ArrayBuffer import scala.io.Source @@ -748,12 +748,12 @@ private[spark] object Utils extends Logging { // getNetworkInterfaces returns ifs in reverse order compared to ifconfig output order // on unix-like system. On windows, it returns in index order. // It's more proper to pick ip address following system output order. - val activeNetworkIFs = NetworkInterface.getNetworkInterfaces.toList + val activeNetworkIFs = NetworkInterface.getNetworkInterfaces.asScala.toSeq val reOrderedNetworkIFs = if (isWindows) activeNetworkIFs else activeNetworkIFs.reverse for (ni <- reOrderedNetworkIFs) { - val addresses = ni.getInetAddresses.toList - .filterNot(addr => addr.isLinkLocalAddress || addr.isLoopbackAddress) + val addresses = ni.getInetAddresses.asScala + .filterNot(addr => addr.isLinkLocalAddress || addr.isLoopbackAddress).toSeq if (addresses.nonEmpty) { val addr = addresses.find(_.isInstanceOf[Inet4Address]).getOrElse(addresses.head) // because of Inet6Address.toHostName may add interface at the end if it knows about it @@ -1498,10 +1498,8 @@ private[spark] object Utils extends Logging { * properties which have been set explicitly, as well as those for which only a default value * has been defined. */ def getSystemProperties: Map[String, String] = { - val sysProps = for (key <- System.getProperties.stringPropertyNames()) yield - (key, System.getProperty(key)) - - sysProps.toMap + System.getProperties.stringPropertyNames().asScala + .map(key => (key, System.getProperty(key))).toMap } /** @@ -1812,7 +1810,8 @@ private[spark] object Utils extends Logging { try { val properties = new Properties() properties.load(inReader) - properties.stringPropertyNames().map(k => (k, properties(k).trim)).toMap + properties.stringPropertyNames().asScala.map( + k => (k, properties.getProperty(k).trim)).toMap } catch { case e: IOException => throw new SparkException(s"Failed when loading Spark properties from $filename", e) @@ -1941,7 +1940,8 @@ private[spark] object Utils extends Logging { return true } isBindCollision(e.getCause) - case e: MultiException => e.getThrowables.exists(isBindCollision) + case e: MultiException => + e.getThrowables.asScala.exists(isBindCollision) case e: Exception => isBindCollision(e.getCause) case _ => false } diff --git a/core/src/main/scala/org/apache/spark/util/collection/Utils.scala b/core/src/main/scala/org/apache/spark/util/collection/Utils.scala index bdbca00a0062..4939b600dbfb 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/Utils.scala @@ -17,7 +17,7 @@ package org.apache.spark.util.collection -import scala.collection.JavaConversions.{collectionAsScalaIterable, asJavaIterator} +import scala.collection.JavaConverters._ import com.google.common.collect.{Ordering => GuavaOrdering} @@ -34,6 +34,6 @@ private[spark] object Utils { val ordering = new GuavaOrdering[T] { override def compare(l: T, r: T): Int = ord.compare(l, r) } - collectionAsScalaIterable(ordering.leastOf(asJavaIterator(input), num)).iterator + ordering.leastOf(input.asJava, num).iterator.asScala } } diff --git 
a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java index ffe4b4baffb2..ebd3d61ae732 100644 --- a/core/src/test/java/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java @@ -24,10 +24,10 @@ import java.util.*; import java.util.concurrent.*; -import scala.collection.JavaConversions; import scala.Tuple2; import scala.Tuple3; import scala.Tuple4; +import scala.collection.JavaConverters; import com.google.common.collect.ImmutableMap; import com.google.common.collect.Iterables; @@ -1473,7 +1473,9 @@ public Integer call(Integer v1, Integer v2) throws Exception { Assert.assertEquals(expected, results); Partitioner defaultPartitioner = Partitioner.defaultPartitioner( - combinedRDD.rdd(), JavaConversions.asScalaBuffer(Lists.>newArrayList())); + combinedRDD.rdd(), + JavaConverters.collectionAsScalaIterableConverter( + Collections.>emptyList()).asScala().toSeq()); combinedRDD = originalRDD.keyBy(keyFunction) .combineByKey( createCombinerFunction, diff --git a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala index 90cb7da94e88..ff9a92cc0a42 100644 --- a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark import java.util.concurrent.{TimeUnit, Executors} +import scala.collection.JavaConverters._ import scala.concurrent.duration._ import scala.language.postfixOps import scala.util.{Try, Random} @@ -148,7 +149,6 @@ class SparkConfSuite extends SparkFunSuite with LocalSparkContext with ResetSyst } test("Thread safeness - SPARK-5425") { - import scala.collection.JavaConversions._ val executor = Executors.newSingleThreadScheduledExecutor() val sf = executor.scheduleAtFixedRate(new Runnable { override def run(): Unit = @@ -163,8 +163,9 @@ class SparkConfSuite extends SparkFunSuite with LocalSparkContext with ResetSyst } } finally { executor.shutdownNow() - for (key <- System.getProperties.stringPropertyNames() if key.startsWith("spark.5425.")) - System.getProperties.remove(key) + val sysProps = System.getProperties + for (key <- sysProps.stringPropertyNames().asScala if key.startsWith("spark.5425.")) + sysProps.remove(key) } } diff --git a/core/src/test/scala/org/apache/spark/deploy/LogUrlsStandaloneSuite.scala b/core/src/test/scala/org/apache/spark/deploy/LogUrlsStandaloneSuite.scala index cbd2aee10c0e..86eb41dd7e5d 100644 --- a/core/src/test/scala/org/apache/spark/deploy/LogUrlsStandaloneSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/LogUrlsStandaloneSuite.scala @@ -19,7 +19,6 @@ package org.apache.spark.deploy import java.net.URL -import scala.collection.JavaConversions._ import scala.collection.mutable import scala.io.Source diff --git a/core/src/test/scala/org/apache/spark/deploy/RPackageUtilsSuite.scala b/core/src/test/scala/org/apache/spark/deploy/RPackageUtilsSuite.scala index 47a64081e297..1ed4bae3ca21 100644 --- a/core/src/test/scala/org/apache/spark/deploy/RPackageUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/RPackageUtilsSuite.scala @@ -21,14 +21,14 @@ import java.io.{PrintStream, OutputStream, File} import java.net.URI import java.util.jar.Attributes.Name import java.util.jar.{JarFile, Manifest} -import java.util.zip.{ZipEntry, ZipFile} +import java.util.zip.ZipFile -import org.scalatest.BeforeAndAfterEach -import scala.collection.JavaConversions._ +import 
scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer import com.google.common.io.Files import org.apache.commons.io.FileUtils +import org.scalatest.BeforeAndAfterEach import org.apache.spark.SparkFunSuite import org.apache.spark.api.r.RUtils @@ -142,7 +142,7 @@ class RPackageUtilsSuite extends SparkFunSuite with BeforeAndAfterEach { IvyTestUtils.writeFile(fakePackageDir, "DESCRIPTION", "abc") val finalZip = RPackageUtils.zipRLibraries(tempDir, "sparkr.zip") assert(finalZip.exists()) - val entries = new ZipFile(finalZip).entries().toSeq.map(_.getName) + val entries = new ZipFile(finalZip).entries().asScala.map(_.getName).toSeq assert(entries.contains("/test.R")) assert(entries.contains("/SparkR/abc.R")) assert(entries.contains("/SparkR/DESCRIPTION")) diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala b/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala index bed6f3ea6124..98664dc1101e 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala @@ -19,8 +19,6 @@ package org.apache.spark.deploy.worker import java.io.File -import scala.collection.JavaConversions._ - import org.apache.spark.deploy.{ApplicationDescription, Command, ExecutorState} import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} @@ -36,6 +34,7 @@ class ExecutorRunnerTest extends SparkFunSuite { ExecutorState.RUNNING) val builder = CommandUtils.buildProcessBuilder( appDesc.command, new SecurityManager(conf), 512, sparkHome, er.substituteVariables) - assert(builder.command().last === appId) + val builderCommand = builder.command() + assert(builderCommand.get(builderCommand.size() - 1) === appId) } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala index 730535ece787..a9652d7e7d0b 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.scheduler import java.util.concurrent.Semaphore import scala.collection.mutable -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import org.scalatest.Matchers @@ -365,10 +365,9 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match .set("spark.extraListeners", classOf[ListenerThatAcceptsSparkConf].getName + "," + classOf[BasicJobCounter].getName) sc = new SparkContext(conf) - sc.listenerBus.listeners.collect { case x: BasicJobCounter => x}.size should be (1) - sc.listenerBus.listeners.collect { - case x: ListenerThatAcceptsSparkConf => x - }.size should be (1) + sc.listenerBus.listeners.asScala.count(_.isInstanceOf[BasicJobCounter]) should be (1) + sc.listenerBus.listeners.asScala + .count(_.isInstanceOf[ListenerThatAcceptsSparkConf]) should be (1) } /** diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala index 5ed30f64d705..319b3173e7a6 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala @@ -18,10 +18,11 @@ package org.apache.spark.scheduler.cluster.mesos import 
java.nio.ByteBuffer -import java.util +import java.util.Arrays +import java.util.Collection import java.util.Collections -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.mutable import scala.collection.mutable.ArrayBuffer @@ -61,7 +62,7 @@ class MesosSchedulerBackendSuite extends SparkFunSuite with LocalSparkContext wi val mesosSchedulerBackend = new MesosSchedulerBackend(taskScheduler, sc, "master") - val resources = List( + val resources = Arrays.asList( mesosSchedulerBackend.createResource("cpus", 4), mesosSchedulerBackend.createResource("mem", 1024)) // uri is null. @@ -98,7 +99,7 @@ class MesosSchedulerBackendSuite extends SparkFunSuite with LocalSparkContext wi val backend = new MesosSchedulerBackend(taskScheduler, sc, "master") val (execInfo, _) = backend.createExecutorInfo( - List(backend.createResource("cpus", 4)), "mockExecutor") + Arrays.asList(backend.createResource("cpus", 4)), "mockExecutor") assert(execInfo.getContainer.getDocker.getImage.equals("spark/mock")) val portmaps = execInfo.getContainer.getDocker.getPortMappingsList assert(portmaps.get(0).getHostPort.equals(80)) @@ -179,7 +180,7 @@ class MesosSchedulerBackendSuite extends SparkFunSuite with LocalSparkContext wi when(taskScheduler.resourceOffers(expectedWorkerOffers)).thenReturn(Seq(Seq(taskDesc))) when(taskScheduler.CPUS_PER_TASK).thenReturn(2) - val capture = ArgumentCaptor.forClass(classOf[util.Collection[TaskInfo]]) + val capture = ArgumentCaptor.forClass(classOf[Collection[TaskInfo]]) when( driver.launchTasks( Matchers.eq(Collections.singleton(mesosOffers.get(0).getId)), @@ -279,7 +280,7 @@ class MesosSchedulerBackendSuite extends SparkFunSuite with LocalSparkContext wi when(taskScheduler.resourceOffers(expectedWorkerOffers)).thenReturn(Seq(Seq(taskDesc))) when(taskScheduler.CPUS_PER_TASK).thenReturn(1) - val capture = ArgumentCaptor.forClass(classOf[util.Collection[TaskInfo]]) + val capture = ArgumentCaptor.forClass(classOf[Collection[TaskInfo]]) when( driver.launchTasks( Matchers.eq(Collections.singleton(mesosOffers.get(0).getId)), @@ -304,7 +305,7 @@ class MesosSchedulerBackendSuite extends SparkFunSuite with LocalSparkContext wi assert(cpusDev.getName.equals("cpus")) assert(cpusDev.getScalar.getValue.equals(1.0)) assert(cpusDev.getRole.equals("dev")) - val executorResources = taskInfo.getExecutor.getResourcesList + val executorResources = taskInfo.getExecutor.getResourcesList.asScala assert(executorResources.exists { r => r.getName.equals("mem") && r.getScalar.getValue.equals(484.0) && r.getRole.equals("prod") }) diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala index 23a1fdb0f500..8d1c9d17e977 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.serializer import java.io.{ByteArrayInputStream, ByteArrayOutputStream} +import scala.collection.JavaConverters._ import scala.collection.mutable import scala.reflect.ClassTag @@ -173,7 +174,7 @@ class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext { test("asJavaIterable") { // Serialize a collection wrapped by asJavaIterable val ser = new KryoSerializer(conf).newInstance() - val a = ser.serialize(scala.collection.convert.WrapAsJava.asJavaIterable(Seq(12345))) + val a = ser.serialize(Seq(12345).asJava) val b = 
ser.deserialize[java.lang.Iterable[Int]](a) assert(b.iterator().next() === 12345) diff --git a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala index 69888b2694ba..22e30ecaf053 100644 --- a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala @@ -21,7 +21,6 @@ import java.net.{HttpURLConnection, URL} import javax.servlet.http.{HttpServletResponse, HttpServletRequest} import scala.io.Source -import scala.collection.JavaConversions._ import scala.xml.Node import com.gargoylesoftware.htmlunit.DefaultCssErrorHandler @@ -341,15 +340,15 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B // The completed jobs table should have two rows. The first row will be the most recent job: val firstRow = find(cssSelector("tbody tr")).get.underlying val firstRowColumns = firstRow.findElements(By.tagName("td")) - firstRowColumns(0).getText should be ("1") - firstRowColumns(4).getText should be ("1/1 (2 skipped)") - firstRowColumns(5).getText should be ("8/8 (16 skipped)") + firstRowColumns.get(0).getText should be ("1") + firstRowColumns.get(4).getText should be ("1/1 (2 skipped)") + firstRowColumns.get(5).getText should be ("8/8 (16 skipped)") // The second row is the first run of the job, where nothing was skipped: val secondRow = findAll(cssSelector("tbody tr")).toSeq(1).underlying val secondRowColumns = secondRow.findElements(By.tagName("td")) - secondRowColumns(0).getText should be ("0") - secondRowColumns(4).getText should be ("3/3") - secondRowColumns(5).getText should be ("24/24") + secondRowColumns.get(0).getText should be ("0") + secondRowColumns.get(4).getText should be ("3/3") + secondRowColumns.get(5).getText should be ("24/24") } } } @@ -502,8 +501,8 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B for { (row, idx) <- rows.zipWithIndex columns = row.findElements(By.tagName("td")) - id = columns(0).getText() - description = columns(1).getText() + id = columns.get(0).getText() + description = columns.get(1).getText() } { id should be (expJobInfo(idx)._1) description should include (expJobInfo(idx)._2) @@ -547,8 +546,8 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B for { (row, idx) <- rows.zipWithIndex columns = row.findElements(By.tagName("td")) - id = columns(0).getText() - description = columns(1).getText() + id = columns.get(0).getText() + description = columns.get(1).getText() } { id should be (expStageInfo(idx)._1) description should include (expStageInfo(idx)._2) diff --git a/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala b/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala index 36832f51d2ad..fa07c1e5017c 100644 --- a/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala @@ -19,10 +19,7 @@ package org.apache.spark.examples import java.nio.ByteBuffer - -import scala.collection.JavaConversions._ -import scala.collection.mutable.ListBuffer -import scala.collection.immutable.Map +import java.util.Collections import org.apache.cassandra.hadoop.ConfigHelper import org.apache.cassandra.hadoop.cql3.CqlPagingInputFormat @@ -32,7 +29,6 @@ import org.apache.cassandra.utils.ByteBufferUtil import org.apache.hadoop.mapreduce.Job import org.apache.spark.{SparkConf, SparkContext} -import 
org.apache.spark.SparkContext._ /* @@ -121,12 +117,9 @@ object CassandraCQLTest { val casoutputCF = aggregatedRDD.map { case (productId, saleCount) => { - val outColFamKey = Map("prod_id" -> ByteBufferUtil.bytes(productId)) - val outKey: java.util.Map[String, ByteBuffer] = outColFamKey - var outColFamVal = new ListBuffer[ByteBuffer] - outColFamVal += ByteBufferUtil.bytes(saleCount) - val outVal: java.util.List[ByteBuffer] = outColFamVal - (outKey, outVal) + val outKey = Collections.singletonMap("prod_id", ByteBufferUtil.bytes(productId)) + val outVal = Collections.singletonList(ByteBufferUtil.bytes(saleCount)) + (outKey, outVal) } } diff --git a/examples/src/main/scala/org/apache/spark/examples/CassandraTest.scala b/examples/src/main/scala/org/apache/spark/examples/CassandraTest.scala index 96ef3e198e38..2e56d24c60c3 100644 --- a/examples/src/main/scala/org/apache/spark/examples/CassandraTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/CassandraTest.scala @@ -19,10 +19,9 @@ package org.apache.spark.examples import java.nio.ByteBuffer +import java.util.Arrays import java.util.SortedMap -import scala.collection.JavaConversions._ - import org.apache.cassandra.db.IColumn import org.apache.cassandra.hadoop.ColumnFamilyOutputFormat import org.apache.cassandra.hadoop.ConfigHelper @@ -32,7 +31,6 @@ import org.apache.cassandra.utils.ByteBufferUtil import org.apache.hadoop.mapreduce.Job import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.SparkContext._ /* * This example demonstrates using Spark with Cassandra with the New Hadoop API and Cassandra @@ -118,7 +116,7 @@ object CassandraTest { val outputkey = ByteBufferUtil.bytes(word + "-COUNT-" + System.currentTimeMillis) - val mutations: java.util.List[Mutation] = new Mutation() :: new Mutation() :: Nil + val mutations = Arrays.asList(new Mutation(), new Mutation()) mutations.get(0).setColumn_or_supercolumn(new ColumnOrSuperColumn()) mutations.get(0).column_or_supercolumn.setColumn(colWord) mutations.get(1).setColumn_or_supercolumn(new ColumnOrSuperColumn()) diff --git a/examples/src/main/scala/org/apache/spark/examples/DriverSubmissionTest.scala b/examples/src/main/scala/org/apache/spark/examples/DriverSubmissionTest.scala index c42df2b8845d..bec61f3cd429 100644 --- a/examples/src/main/scala/org/apache/spark/examples/DriverSubmissionTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/DriverSubmissionTest.scala @@ -18,7 +18,7 @@ // scalastyle:off println package org.apache.spark.examples -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import org.apache.spark.util.Utils @@ -36,10 +36,10 @@ object DriverSubmissionTest { val properties = Utils.getSystemProperties println("Environment variables containing SPARK_TEST:") - env.filter{case (k, v) => k.contains("SPARK_TEST")}.foreach(println) + env.asScala.filter { case (k, _) => k.contains("SPARK_TEST")}.foreach(println) println("System properties containing spark.test:") - properties.filter{case (k, v) => k.toString.contains("spark.test")}.foreach(println) + properties.filter { case (k, _) => k.toString.contains("spark.test") }.foreach(println) for (i <- 1 until numSecondsToSleep) { println(s"Alive for $i out of $numSecondsToSleep seconds") diff --git a/examples/src/main/scala/org/apache/spark/examples/pythonconverters/AvroConverters.scala b/examples/src/main/scala/org/apache/spark/examples/pythonconverters/AvroConverters.scala index 3ebb112fc069..805184e740f0 100644 --- 
a/examples/src/main/scala/org/apache/spark/examples/pythonconverters/AvroConverters.scala +++ b/examples/src/main/scala/org/apache/spark/examples/pythonconverters/AvroConverters.scala @@ -19,7 +19,7 @@ package org.apache.spark.examples.pythonconverters import java.util.{Collection => JCollection, Map => JMap} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import org.apache.avro.generic.{GenericFixed, IndexedRecord} import org.apache.avro.mapred.AvroWrapper @@ -58,7 +58,7 @@ object AvroConversionUtil extends Serializable { val map = new java.util.HashMap[String, Any] obj match { case record: IndexedRecord => - record.getSchema.getFields.zipWithIndex.foreach { case (f, i) => + record.getSchema.getFields.asScala.zipWithIndex.foreach { case (f, i) => map.put(f.name, fromAvro(record.get(i), f.schema)) } case other => throw new SparkException( @@ -68,9 +68,9 @@ object AvroConversionUtil extends Serializable { } def unpackMap(obj: Any, schema: Schema): JMap[String, Any] = { - obj.asInstanceOf[JMap[_, _]].map { case (key, value) => + obj.asInstanceOf[JMap[_, _]].asScala.map { case (key, value) => (key.toString, fromAvro(value, schema.getValueType)) - } + }.asJava } def unpackFixed(obj: Any, schema: Schema): Array[Byte] = { @@ -91,17 +91,17 @@ object AvroConversionUtil extends Serializable { def unpackArray(obj: Any, schema: Schema): JCollection[Any] = obj match { case c: JCollection[_] => - c.map(fromAvro(_, schema.getElementType)) + c.asScala.map(fromAvro(_, schema.getElementType)).toSeq.asJava case arr: Array[_] if arr.getClass.getComponentType.isPrimitive => - arr.toSeq + arr.toSeq.asJava.asInstanceOf[JCollection[Any]] case arr: Array[_] => - arr.map(fromAvro(_, schema.getElementType)).toSeq + arr.map(fromAvro(_, schema.getElementType)).toSeq.asJava case other => throw new SparkException( s"Unknown ARRAY type ${other.getClass.getName}") } def unpackUnion(obj: Any, schema: Schema): Any = { - schema.getTypes.toList match { + schema.getTypes.asScala.toList match { case List(s) => fromAvro(obj, s) case List(n, s) if n.getType == NULL => fromAvro(obj, s) case List(s, n) if n.getType == NULL => fromAvro(obj, s) diff --git a/examples/src/main/scala/org/apache/spark/examples/pythonconverters/CassandraConverters.scala b/examples/src/main/scala/org/apache/spark/examples/pythonconverters/CassandraConverters.scala index 83feb5703b90..00ce47af4813 100644 --- a/examples/src/main/scala/org/apache/spark/examples/pythonconverters/CassandraConverters.scala +++ b/examples/src/main/scala/org/apache/spark/examples/pythonconverters/CassandraConverters.scala @@ -17,11 +17,13 @@ package org.apache.spark.examples.pythonconverters -import org.apache.spark.api.python.Converter import java.nio.ByteBuffer + +import scala.collection.JavaConverters._ + import org.apache.cassandra.utils.ByteBufferUtil -import collection.JavaConversions._ +import org.apache.spark.api.python.Converter /** * Implementation of [[org.apache.spark.api.python.Converter]] that converts Cassandra @@ -30,7 +32,7 @@ import collection.JavaConversions._ class CassandraCQLKeyConverter extends Converter[Any, java.util.Map[String, Int]] { override def convert(obj: Any): java.util.Map[String, Int] = { val result = obj.asInstanceOf[java.util.Map[String, ByteBuffer]] - mapAsJavaMap(result.mapValues(bb => ByteBufferUtil.toInt(bb))) + result.asScala.mapValues(ByteBufferUtil.toInt).asJava } } @@ -41,7 +43,7 @@ class CassandraCQLKeyConverter extends Converter[Any, java.util.Map[String, Int] class CassandraCQLValueConverter 
extends Converter[Any, java.util.Map[String, String]] { override def convert(obj: Any): java.util.Map[String, String] = { val result = obj.asInstanceOf[java.util.Map[String, ByteBuffer]] - mapAsJavaMap(result.mapValues(bb => ByteBufferUtil.string(bb))) + result.asScala.mapValues(ByteBufferUtil.string).asJava } } @@ -52,7 +54,7 @@ class CassandraCQLValueConverter extends Converter[Any, java.util.Map[String, St class ToCassandraCQLKeyConverter extends Converter[Any, java.util.Map[String, ByteBuffer]] { override def convert(obj: Any): java.util.Map[String, ByteBuffer] = { val input = obj.asInstanceOf[java.util.Map[String, Int]] - mapAsJavaMap(input.mapValues(i => ByteBufferUtil.bytes(i))) + input.asScala.mapValues(ByteBufferUtil.bytes).asJava } } @@ -63,6 +65,6 @@ class ToCassandraCQLKeyConverter extends Converter[Any, java.util.Map[String, By class ToCassandraCQLValueConverter extends Converter[Any, java.util.List[ByteBuffer]] { override def convert(obj: Any): java.util.List[ByteBuffer] = { val input = obj.asInstanceOf[java.util.List[String]] - seqAsJavaList(input.map(s => ByteBufferUtil.bytes(s))) + input.asScala.map(ByteBufferUtil.bytes).asJava } } diff --git a/examples/src/main/scala/org/apache/spark/examples/pythonconverters/HBaseConverters.scala b/examples/src/main/scala/org/apache/spark/examples/pythonconverters/HBaseConverters.scala index 90d48a64106c..0a25ee7ae56f 100644 --- a/examples/src/main/scala/org/apache/spark/examples/pythonconverters/HBaseConverters.scala +++ b/examples/src/main/scala/org/apache/spark/examples/pythonconverters/HBaseConverters.scala @@ -17,7 +17,7 @@ package org.apache.spark.examples.pythonconverters -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.util.parsing.json.JSONObject import org.apache.spark.api.python.Converter @@ -33,7 +33,6 @@ import org.apache.hadoop.hbase.CellUtil */ class HBaseResultToStringConverter extends Converter[Any, String] { override def convert(obj: Any): String = { - import collection.JavaConverters._ val result = obj.asInstanceOf[Result] val output = result.listCells.asScala.map(cell => Map( @@ -77,7 +76,7 @@ class StringToImmutableBytesWritableConverter extends Converter[Any, ImmutableBy */ class StringListToPutConverter extends Converter[Any, Put] { override def convert(obj: Any): Put = { - val output = obj.asInstanceOf[java.util.ArrayList[String]].map(Bytes.toBytes(_)).toArray + val output = obj.asInstanceOf[java.util.ArrayList[String]].asScala.map(Bytes.toBytes).toArray val put = new Put(output(0)) put.add(output(1), output(2), output(3)) } diff --git a/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala b/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala index fa43629d4977..d2654700ea72 100644 --- a/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala +++ b/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala @@ -20,7 +20,7 @@ import java.net.InetSocketAddress import java.util.concurrent.atomic.AtomicInteger import java.util.concurrent.{TimeUnit, CountDownLatch, Executors} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.concurrent.{ExecutionContext, Future} import scala.util.{Failure, Success} @@ -166,7 +166,7 @@ class SparkSinkSuite extends FunSuite { channelContext.put("capacity", channelCapacity.toString) channelContext.put("transactionCapacity", 1000.toString) 
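// Illustrative sketch, not part of the test: Context.putAll expects a java.util.Map, so with
// JavaConverters the Scala map is converted explicitly, e.g.
//   import scala.collection.JavaConverters._
//   val overrides = Map("capacity" -> "2000", "transactionCapacity" -> "2000")
//   val jOverrides: java.util.Map[String, String] = overrides.asJava
// instead of relying on an implicit JavaConversions conversion at the call site.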
channelContext.put("keep-alive", 0.toString) - channelContext.putAll(overrides) + channelContext.putAll(overrides.asJava) channel.setName(scala.util.Random.nextString(10)) channel.configure(channelContext) diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/EventTransformer.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/EventTransformer.scala index 65c49c131518..48df27b26867 100644 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/EventTransformer.scala +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/EventTransformer.scala @@ -19,7 +19,7 @@ package org.apache.spark.streaming.flume import java.io.{ObjectOutput, ObjectInput} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import org.apache.spark.util.Utils import org.apache.spark.Logging @@ -60,7 +60,7 @@ private[streaming] object EventTransformer extends Logging { out.write(body) val numHeaders = headers.size() out.writeInt(numHeaders) - for ((k, v) <- headers) { + for ((k, v) <- headers.asScala) { val keyBuff = Utils.serialize(k.toString) out.writeInt(keyBuff.length) out.write(keyBuff) diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeBatchFetcher.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeBatchFetcher.scala index 88cc2aa3bf02..b9d4e762ca05 100644 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeBatchFetcher.scala +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeBatchFetcher.scala @@ -16,7 +16,6 @@ */ package org.apache.spark.streaming.flume -import scala.collection.JavaConversions._ import scala.collection.mutable.ArrayBuffer import com.google.common.base.Throwables @@ -155,7 +154,7 @@ private[flume] class FlumeBatchFetcher(receiver: FlumePollingReceiver) extends R val buffer = new ArrayBuffer[SparkFlumeEvent](events.size()) var j = 0 while (j < events.size()) { - val event = events(j) + val event = events.get(j) val sparkFlumeEvent = new SparkFlumeEvent() sparkFlumeEvent.event.setBody(event.getBody) sparkFlumeEvent.event.setHeaders(event.getHeaders) diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala index 1e32a365a1ee..2bf99cb3cba1 100644 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala @@ -22,7 +22,7 @@ import java.io.{ObjectInput, ObjectOutput, Externalizable} import java.nio.ByteBuffer import java.util.concurrent.Executors -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.reflect.ClassTag import org.apache.flume.source.avro.AvroSourceProtocol @@ -99,7 +99,7 @@ class SparkFlumeEvent() extends Externalizable { val numHeaders = event.getHeaders.size() out.writeInt(numHeaders) - for ((k, v) <- event.getHeaders) { + for ((k, v) <- event.getHeaders.asScala) { val keyBuff = Utils.serialize(k.toString) out.writeInt(keyBuff.length) out.write(keyBuff) @@ -127,8 +127,7 @@ class FlumeEventServer(receiver : FlumeReceiver) extends AvroSourceProtocol { } override def appendBatch(events : java.util.List[AvroFlumeEvent]) : Status = { - events.foreach (event => - receiver.store(SparkFlumeEvent.fromAvroFlumeEvent(event))) + events.asScala.foreach(event => 
receiver.store(SparkFlumeEvent.fromAvroFlumeEvent(event))) Status.OK } } diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumePollingInputDStream.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumePollingInputDStream.scala index 583e7dca317a..0bc46209b836 100644 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumePollingInputDStream.scala +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumePollingInputDStream.scala @@ -20,7 +20,7 @@ package org.apache.spark.streaming.flume import java.net.InetSocketAddress import java.util.concurrent.{LinkedBlockingQueue, Executors} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.reflect.ClassTag import com.google.common.util.concurrent.ThreadFactoryBuilder @@ -94,9 +94,7 @@ private[streaming] class FlumePollingReceiver( override def onStop(): Unit = { logInfo("Shutting down Flume Polling Receiver") receiverExecutor.shutdownNow() - connections.foreach(connection => { - connection.transceiver.close() - }) + connections.asScala.foreach(_.transceiver.close()) channelFactory.releaseExternalResources() } diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeTestUtils.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeTestUtils.scala index 9d9c3b189415..70018c86f92b 100644 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeTestUtils.scala +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeTestUtils.scala @@ -19,9 +19,9 @@ package org.apache.spark.streaming.flume import java.net.{InetSocketAddress, ServerSocket} import java.nio.ByteBuffer -import java.util.{List => JList} +import java.util.Collections -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import com.google.common.base.Charsets.UTF_8 import org.apache.avro.ipc.NettyTransceiver @@ -59,13 +59,13 @@ private[flume] class FlumeTestUtils { } /** Send data to the flume receiver */ - def writeInput(input: JList[String], enableCompression: Boolean): Unit = { + def writeInput(input: Seq[String], enableCompression: Boolean): Unit = { val testAddress = new InetSocketAddress("localhost", testPort) val inputEvents = input.map { item => val event = new AvroFlumeEvent event.setBody(ByteBuffer.wrap(item.getBytes(UTF_8))) - event.setHeaders(Map[CharSequence, CharSequence]("test" -> "header")) + event.setHeaders(Collections.singletonMap("test", "header")) event } @@ -88,7 +88,7 @@ private[flume] class FlumeTestUtils { } // Send data - val status = client.appendBatch(inputEvents.toList) + val status = client.appendBatch(inputEvents.asJava) if (status != avro.Status.OK) { throw new AssertionError("Sent events unsuccessfully") } diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeUtils.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeUtils.scala index a65a9b921aaf..c719b80aca7e 100644 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeUtils.scala +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeUtils.scala @@ -21,7 +21,7 @@ import java.net.InetSocketAddress import java.io.{DataOutputStream, ByteArrayOutputStream} import java.util.{List => JList, Map => JMap} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import org.apache.spark.api.java.function.PairFunction import org.apache.spark.api.python.PythonRDD @@ 
-268,8 +268,8 @@ private[flume] class FlumeUtilsPythonHelper { maxBatchSize: Int, parallelism: Int ): JavaPairDStream[Array[Byte], Array[Byte]] = { - assert(hosts.length == ports.length) - val addresses = hosts.zip(ports).map { + assert(hosts.size() == ports.size()) + val addresses = hosts.asScala.zip(ports.asScala).map { case (host, port) => new InetSocketAddress(host, port) } val dstream = FlumeUtils.createPollingStream( @@ -286,7 +286,7 @@ private object FlumeUtilsPythonHelper { val output = new DataOutputStream(byteStream) try { output.writeInt(map.size) - map.foreach { kv => + map.asScala.foreach { kv => PythonRDD.writeUTF(kv._1.toString, output) PythonRDD.writeUTF(kv._2.toString, output) } diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/PollingFlumeTestUtils.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/PollingFlumeTestUtils.scala index 91d63d49dbec..a2ab320957db 100644 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/PollingFlumeTestUtils.scala +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/PollingFlumeTestUtils.scala @@ -18,9 +18,8 @@ package org.apache.spark.streaming.flume import java.util.concurrent._ -import java.util.{List => JList, Map => JMap} +import java.util.{Map => JMap, Collections} -import scala.collection.JavaConversions._ import scala.collection.mutable.ArrayBuffer import com.google.common.base.Charsets.UTF_8 @@ -77,7 +76,7 @@ private[flume] class PollingFlumeTestUtils { /** * Start 2 sinks and return the ports */ - def startMultipleSinks(): JList[Int] = { + def startMultipleSinks(): Seq[Int] = { channels.clear() sinks.clear() @@ -138,8 +137,7 @@ private[flume] class PollingFlumeTestUtils { /** * A Python-friendly method to assert the output */ - def assertOutput( - outputHeaders: JList[JMap[String, String]], outputBodies: JList[String]): Unit = { + def assertOutput(outputHeaders: Seq[JMap[String, String]], outputBodies: Seq[String]): Unit = { require(outputHeaders.size == outputBodies.size) val eventSize = outputHeaders.size if (eventSize != totalEventsPerChannel * channels.size) { @@ -149,12 +147,12 @@ private[flume] class PollingFlumeTestUtils { var counter = 0 for (k <- 0 until channels.size; i <- 0 until totalEventsPerChannel) { val eventBodyToVerify = s"${channels(k).getName}-$i" - val eventHeaderToVerify: JMap[String, String] = Map[String, String](s"test-$i" -> "header") + val eventHeaderToVerify: JMap[String, String] = Collections.singletonMap(s"test-$i", "header") var found = false var j = 0 while (j < eventSize && !found) { - if (eventBodyToVerify == outputBodies.get(j) && - eventHeaderToVerify == outputHeaders.get(j)) { + if (eventBodyToVerify == outputBodies(j) && + eventHeaderToVerify == outputHeaders(j)) { found = true counter += 1 } @@ -195,7 +193,7 @@ private[flume] class PollingFlumeTestUtils { tx.begin() for (j <- 0 until eventsPerBatch) { channel.put(EventBuilder.withBody(s"${channel.getName}-$t".getBytes(UTF_8), - Map[String, String](s"test-$t" -> "header"))) + Collections.singletonMap(s"test-$t", "header"))) t += 1 } tx.commit() diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala index d5f9a0aa38f9..ff2fb8eed204 100644 --- a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala +++ 
b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.streaming.flume import java.net.InetSocketAddress -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.mutable.{SynchronizedBuffer, ArrayBuffer} import scala.concurrent.duration._ import scala.language.postfixOps @@ -116,9 +116,9 @@ class FlumePollingStreamSuite extends SparkFunSuite with BeforeAndAfter with Log // The eventually is required to ensure that all data in the batch has been processed. eventually(timeout(10 seconds), interval(100 milliseconds)) { val flattenOutputBuffer = outputBuffer.flatten - val headers = flattenOutputBuffer.map(_.event.getHeaders.map { - case kv => (kv._1.toString, kv._2.toString) - }).map(mapAsJavaMap) + val headers = flattenOutputBuffer.map(_.event.getHeaders.asScala.map { + case (key, value) => (key.toString, value.toString) + }).map(_.asJava) val bodies = flattenOutputBuffer.map(e => new String(e.event.getBody.array(), UTF_8)) utils.assertOutput(headers, bodies) } diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala index 5bc4cdf65306..5ffb60bd602f 100644 --- a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala +++ b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.streaming.flume -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer} import scala.concurrent.duration._ import scala.language.postfixOps diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala index 79a9db4291be..c9fd715d3d55 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala @@ -24,6 +24,7 @@ import java.util.concurrent.TimeoutException import java.util.{Map => JMap, Properties} import scala.annotation.tailrec +import scala.collection.JavaConverters._ import scala.language.postfixOps import scala.util.control.NonFatal @@ -159,8 +160,7 @@ private[kafka] class KafkaTestUtils extends Logging { /** Java-friendly function for sending messages to the Kafka broker */ def sendMessages(topic: String, messageToFreq: JMap[String, JInt]): Unit = { - import scala.collection.JavaConversions._ - sendMessages(topic, Map(messageToFreq.mapValues(_.intValue()).toSeq: _*)) + sendMessages(topic, Map(messageToFreq.asScala.mapValues(_.intValue()).toSeq: _*)) } /** Send the messages to the Kafka broker */ diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala index 388dbb818410..312822207753 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala @@ -20,7 +20,7 @@ package org.apache.spark.streaming.kafka import java.lang.{Integer => JInt, Long => JLong} import java.util.{List => JList, Map => JMap, Set => JSet} -import scala.collection.JavaConversions._ +import 
scala.collection.JavaConverters._ import scala.reflect.ClassTag import kafka.common.TopicAndPartition @@ -96,7 +96,7 @@ object KafkaUtils { groupId: String, topics: JMap[String, JInt] ): JavaPairReceiverInputDStream[String, String] = { - createStream(jssc.ssc, zkQuorum, groupId, Map(topics.mapValues(_.intValue()).toSeq: _*)) + createStream(jssc.ssc, zkQuorum, groupId, Map(topics.asScala.mapValues(_.intValue()).toSeq: _*)) } /** @@ -115,7 +115,7 @@ object KafkaUtils { topics: JMap[String, JInt], storageLevel: StorageLevel ): JavaPairReceiverInputDStream[String, String] = { - createStream(jssc.ssc, zkQuorum, groupId, Map(topics.mapValues(_.intValue()).toSeq: _*), + createStream(jssc.ssc, zkQuorum, groupId, Map(topics.asScala.mapValues(_.intValue()).toSeq: _*), storageLevel) } @@ -149,7 +149,10 @@ object KafkaUtils { implicit val valueCmd: ClassTag[T] = ClassTag(valueDecoderClass) createStream[K, V, U, T]( - jssc.ssc, kafkaParams.toMap, Map(topics.mapValues(_.intValue()).toSeq: _*), storageLevel) + jssc.ssc, + kafkaParams.asScala.toMap, + Map(topics.asScala.mapValues(_.intValue()).toSeq: _*), + storageLevel) } /** get leaders for the given offset ranges, or throw an exception */ @@ -275,7 +278,7 @@ object KafkaUtils { implicit val keyDecoderCmt: ClassTag[KD] = ClassTag(keyDecoderClass) implicit val valueDecoderCmt: ClassTag[VD] = ClassTag(valueDecoderClass) new JavaPairRDD(createRDD[K, V, KD, VD]( - jsc.sc, Map(kafkaParams.toSeq: _*), offsetRanges)) + jsc.sc, Map(kafkaParams.asScala.toSeq: _*), offsetRanges)) } /** @@ -311,9 +314,9 @@ object KafkaUtils { implicit val keyDecoderCmt: ClassTag[KD] = ClassTag(keyDecoderClass) implicit val valueDecoderCmt: ClassTag[VD] = ClassTag(valueDecoderClass) implicit val recordCmt: ClassTag[R] = ClassTag(recordClass) - val leaderMap = Map(leaders.toSeq: _*) + val leaderMap = Map(leaders.asScala.toSeq: _*) createRDD[K, V, KD, VD, R]( - jsc.sc, Map(kafkaParams.toSeq: _*), offsetRanges, leaderMap, messageHandler.call _) + jsc.sc, Map(kafkaParams.asScala.toSeq: _*), offsetRanges, leaderMap, messageHandler.call(_)) } /** @@ -476,8 +479,8 @@ object KafkaUtils { val cleanedHandler = jssc.sparkContext.clean(messageHandler.call _) createDirectStream[K, V, KD, VD, R]( jssc.ssc, - Map(kafkaParams.toSeq: _*), - Map(fromOffsets.mapValues { _.longValue() }.toSeq: _*), + Map(kafkaParams.asScala.toSeq: _*), + Map(fromOffsets.asScala.mapValues(_.longValue()).toSeq: _*), cleanedHandler ) } @@ -531,8 +534,8 @@ object KafkaUtils { implicit val valueDecoderCmt: ClassTag[VD] = ClassTag(valueDecoderClass) createDirectStream[K, V, KD, VD]( jssc.ssc, - Map(kafkaParams.toSeq: _*), - Set(topics.toSeq: _*) + Map(kafkaParams.asScala.toSeq: _*), + Set(topics.asScala.toSeq: _*) ) } } @@ -602,10 +605,10 @@ private[kafka] class KafkaUtilsPythonHelper { ): JavaPairInputDStream[Array[Byte], Array[Byte]] = { if (!fromOffsets.isEmpty) { - import scala.collection.JavaConversions._ - val topicsFromOffsets = fromOffsets.keySet().map(_.topic) - if (topicsFromOffsets != topics.toSet) { - throw new IllegalStateException(s"The specified topics: ${topics.toSet.mkString(" ")} " + + val topicsFromOffsets = fromOffsets.keySet().asScala.map(_.topic) + if (topicsFromOffsets != topics.asScala.toSet) { + throw new IllegalStateException( + s"The specified topics: ${topics.asScala.toSet.mkString(" ")} " + s"do not equal to the topic from offsets: ${topicsFromOffsets.mkString(" ")}") } } @@ -663,6 +666,6 @@ private[kafka] class KafkaUtilsPythonHelper { "with this RDD, please call this method only on a Kafka 
RDD.") val kafkaRDD = kafkaRDDs.head.asInstanceOf[KafkaRDD[_, _, _, _, _]] - kafkaRDD.offsetRanges.toSeq + kafkaRDD.offsetRanges.toSeq.asJava } } diff --git a/external/zeromq/src/main/scala/org/apache/spark/streaming/zeromq/ZeroMQUtils.scala b/external/zeromq/src/main/scala/org/apache/spark/streaming/zeromq/ZeroMQUtils.scala index 0469d0af8864..4ea218eaa4de 100644 --- a/external/zeromq/src/main/scala/org/apache/spark/streaming/zeromq/ZeroMQUtils.scala +++ b/external/zeromq/src/main/scala/org/apache/spark/streaming/zeromq/ZeroMQUtils.scala @@ -18,15 +18,17 @@ package org.apache.spark.streaming.zeromq import scala.reflect.ClassTag -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ + import akka.actor.{Props, SupervisorStrategy} import akka.util.ByteString import akka.zeromq.Subscribe + import org.apache.spark.api.java.function.{Function => JFunction} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.StreamingContext import org.apache.spark.streaming.api.java.{JavaReceiverInputDStream, JavaStreamingContext} -import org.apache.spark.streaming.dstream.{ReceiverInputDStream} +import org.apache.spark.streaming.dstream.ReceiverInputDStream import org.apache.spark.streaming.receiver.ActorSupervisorStrategy object ZeroMQUtils { @@ -75,7 +77,8 @@ object ZeroMQUtils { ): JavaReceiverInputDStream[T] = { implicit val cm: ClassTag[T] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[T]] - val fn = (x: Seq[ByteString]) => bytesToObjects.call(x.map(_.toArray).toArray).toIterator + val fn = + (x: Seq[ByteString]) => bytesToObjects.call(x.map(_.toArray).toArray).iterator().asScala createStream[T](jssc.ssc, publisherUrl, subscribe, fn, storageLevel, supervisorStrategy) } @@ -99,7 +102,8 @@ object ZeroMQUtils { ): JavaReceiverInputDStream[T] = { implicit val cm: ClassTag[T] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[T]] - val fn = (x: Seq[ByteString]) => bytesToObjects.call(x.map(_.toArray).toArray).toIterator + val fn = + (x: Seq[ByteString]) => bytesToObjects.call(x.map(_.toArray).toArray).iterator().asScala createStream[T](jssc.ssc, publisherUrl, subscribe, fn, storageLevel) } @@ -122,7 +126,8 @@ object ZeroMQUtils { ): JavaReceiverInputDStream[T] = { implicit val cm: ClassTag[T] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[T]] - val fn = (x: Seq[ByteString]) => bytesToObjects.call(x.map(_.toArray).toArray).toIterator + val fn = + (x: Seq[ByteString]) => bytesToObjects.call(x.map(_.toArray).toArray).iterator().asScala createStream[T](jssc.ssc, publisherUrl, subscribe, fn) } } diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala index a003ddf325e6..5d32fa699ae5 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala @@ -17,7 +17,7 @@ package org.apache.spark.streaming.kinesis -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.util.control.NonFatal import com.amazonaws.auth.{AWSCredentials, DefaultAWSCredentialsProviderChain} @@ -213,7 +213,7 @@ class KinesisSequenceRangeIterator( s"getting records using shard iterator") { client.getRecords(getRecordsRequest) } - (getRecordsResult.getRecords.iterator(), getRecordsResult.getNextShardIterator) + 
(getRecordsResult.getRecords.iterator().asScala, getRecordsResult.getNextShardIterator) } /** diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala index 22324e821ce9..6e0988c1af8a 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala @@ -18,7 +18,7 @@ package org.apache.spark.streaming.kinesis import java.util.UUID -import scala.collection.JavaConversions.asScalaIterator +import scala.collection.JavaConverters._ import scala.collection.mutable import scala.util.control.NonFatal @@ -202,7 +202,7 @@ private[kinesis] class KinesisReceiver( /** Add records of the given shard to the current block being generated */ private[kinesis] def addRecords(shardId: String, records: java.util.List[Record]): Unit = { if (records.size > 0) { - val dataIterator = records.iterator().map { record => + val dataIterator = records.iterator().asScala.map { record => val byteBuffer = record.getData() val byteArray = new Array[Byte](byteBuffer.remaining()) byteBuffer.get(byteArray) diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala index c8eec13ec7dc..634bf9452107 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala @@ -20,6 +20,7 @@ package org.apache.spark.streaming.kinesis import java.nio.ByteBuffer import java.util.concurrent.TimeUnit +import scala.collection.JavaConverters._ import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import scala.util.{Failure, Random, Success, Try} @@ -115,7 +116,7 @@ private[kinesis] class KinesisTestUtils extends Logging { * Expose a Python friendly API. 
*/ def pushData(testData: java.util.List[Int]): Unit = { - pushData(scala.collection.JavaConversions.asScalaBuffer(testData)) + pushData(testData.asScala) } def deleteStream(): Unit = { diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala index ceb135e0651a..3d136aec2e70 100644 --- a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala +++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala @@ -17,10 +17,10 @@ package org.apache.spark.streaming.kinesis import java.nio.ByteBuffer +import java.nio.charset.StandardCharsets +import java.util.Arrays -import scala.collection.JavaConversions.seqAsJavaList - -import com.amazonaws.services.kinesis.clientlibrary.exceptions.{InvalidStateException, KinesisClientLibDependencyException, ShutdownException, ThrottlingException} +import com.amazonaws.services.kinesis.clientlibrary.exceptions._ import com.amazonaws.services.kinesis.clientlibrary.interfaces.IRecordProcessorCheckpointer import com.amazonaws.services.kinesis.clientlibrary.types.ShutdownReason import com.amazonaws.services.kinesis.model.Record @@ -47,10 +47,10 @@ class KinesisReceiverSuite extends TestSuiteBase with Matchers with BeforeAndAft val someSeqNum = Some(seqNum) val record1 = new Record() - record1.setData(ByteBuffer.wrap("Spark In Action".getBytes())) + record1.setData(ByteBuffer.wrap("Spark In Action".getBytes(StandardCharsets.UTF_8))) val record2 = new Record() - record2.setData(ByteBuffer.wrap("Learning Spark".getBytes())) - val batch = List[Record](record1, record2) + record2.setData(ByteBuffer.wrap("Learning Spark".getBytes(StandardCharsets.UTF_8))) + val batch = Arrays.asList(record1, record2) var receiverMock: KinesisReceiver = _ var checkpointerMock: IRecordProcessorCheckpointer = _ diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala index 87eeb5db05d2..7a1c7796065e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.util -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.util.Random import com.github.fommil.netlib.BLAS.{getInstance => blas} @@ -52,7 +52,7 @@ object LinearDataGenerator { nPoints: Int, seed: Int, eps: Double): java.util.List[LabeledPoint] = { - seqAsJavaList(generateLinearInput(intercept, weights, nPoints, seed, eps)) + generateLinearInput(intercept, weights, nPoints, seed, eps).asJava } /** diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaOneVsRestSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaOneVsRestSuite.java index a1ee55415237..2744e020e9e4 100644 --- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaOneVsRestSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaOneVsRestSuite.java @@ -20,7 +20,7 @@ import java.io.Serializable; import java.util.List; -import static scala.collection.JavaConversions.seqAsJavaList; +import scala.collection.JavaConverters; import org.junit.After; import org.junit.Assert; @@ -55,8 +55,9 @@ public void setUp() { double[] xMean = {5.843, 3.057, 3.758, 1.199}; double[] xVariance = {0.6856, 
0.1899, 3.116, 0.581}; - List points = seqAsJavaList(generateMultinomialLogisticInput( - weights, xMean, xVariance, true, nPoints, 42)); + List points = JavaConverters.asJavaListConverter( + generateMultinomialLogisticInput(weights, xMean, xVariance, true, nPoints, 42) + ).asJava(); datasetRDD = jsc.parallelize(points, 2); dataset = jsql.createDataFrame(datasetRDD, LabeledPoint.class); } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala index 2473510e1351..8d14bb657215 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.classification -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.util.Random import scala.util.control.Breaks._ @@ -38,7 +38,7 @@ object LogisticRegressionSuite { scale: Double, nPoints: Int, seed: Int): java.util.List[LabeledPoint] = { - seqAsJavaList(generateLogisticInput(offset, scale, nPoints, seed)) + generateLogisticInput(offset, scale, nPoints, seed).asJava } // Generate input of the form Y = logistic(offset + scale*X) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala index b1d78cba9e3d..ee3c85d09a46 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.classification -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.util.Random import org.jblas.DoubleMatrix @@ -35,7 +35,7 @@ object SVMSuite { weights: Array[Double], nPoints: Int, seed: Int): java.util.List[LabeledPoint] = { - seqAsJavaList(generateSVMInput(intercept, weights, nPoints, seed)) + generateSVMInput(intercept, weights, nPoints, seed).asJava } // Generate noisy input of the form Y = signum(x.dot(weights) + intercept + noise) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala index 13b754a03943..36ac7d267243 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.optimization -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.util.Random import org.scalatest.Matchers @@ -35,7 +35,7 @@ object GradientDescentSuite { scale: Double, nPoints: Int, seed: Int): java.util.List[LabeledPoint] = { - seqAsJavaList(generateGDInput(offset, scale, nPoints, seed)) + generateGDInput(offset, scale, nPoints, seed).asJava } // Generate input of the form Y = logistic(offset + scale * X) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala index 05b87728d6fd..045135f7f8d6 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala @@ -17,7 +17,7 @@ package 
org.apache.spark.mllib.recommendation -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.math.abs import scala.util.Random @@ -38,7 +38,7 @@ object ALSSuite { negativeWeights: Boolean): (java.util.List[Rating], DoubleMatrix, DoubleMatrix) = { val (sampledRatings, trueRatings, truePrefs) = generateRatings(users, products, features, samplingRate, implicitPrefs) - (seqAsJavaList(sampledRatings), trueRatings, truePrefs) + (sampledRatings.asJava, trueRatings, truePrefs) } def generateRatings( diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 04e0d49b178c..ea52bfd67944 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -18,13 +18,13 @@ import java.io._ import scala.util.Properties -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import sbt._ import sbt.Classpaths.publishTask import sbt.Keys._ import sbtunidoc.Plugin.UnidocKeys.unidocGenjavadocVersion -import com.typesafe.sbt.pom.{loadEffectivePom, PomBuild, SbtPomKeys} +import com.typesafe.sbt.pom.{PomBuild, SbtPomKeys} import net.virtualvoid.sbt.graph.Plugin.graphSettings import spray.revolver.RevolverPlugin._ @@ -120,7 +120,7 @@ object SparkBuild extends PomBuild { case _ => } - override val userPropertiesMap = System.getProperties.toMap + override val userPropertiesMap = System.getProperties.asScala.toMap lazy val MavenCompile = config("m2r") extend(Compile) lazy val publishLocalBoth = TaskKey[Unit]("publish-local", "publish local for m2 and ivy") @@ -559,7 +559,7 @@ object TestSettings { javaOptions in Test += "-Dspark.unsafe.exceptionOnMemoryLeak=true", javaOptions in Test += "-Dsun.io.serialization.extendedDebugInfo=true", javaOptions in Test += "-Dderby.system.durability=test", - javaOptions in Test ++= System.getProperties.filter(_._1 startsWith "spark") + javaOptions in Test ++= System.getProperties.asScala.filter(_._1.startsWith("spark")) .map { case (k,v) => s"-D$k=$v" }.toSeq, javaOptions in Test += "-ea", javaOptions in Test ++= "-Xmx3g -Xss4096k -XX:PermSize=128M -XX:MaxNewSize=256m -XX:MaxPermSize=1g" diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py index 8af8637cf948..0948f9b27cd3 100644 --- a/python/pyspark/sql/column.py +++ b/python/pyspark/sql/column.py @@ -61,6 +61,18 @@ def _to_seq(sc, cols, converter=None): return sc._jvm.PythonUtils.toSeq(cols) +def _to_list(sc, cols, converter=None): + """ + Convert a list of Column (or names) into a JVM (Scala) List of Column. + + An optional `converter` could be used to convert items in `cols` + into JVM Column objects. + """ + if converter: + cols = [converter(c) for c in cols] + return sc._jvm.PythonUtils.toList(cols) + + def _unary_op(name, doc="unary operator"): """ Create a method for given unary operator """ def _(self): diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 025811f51929..e269ef4304f3 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -32,7 +32,7 @@ from pyspark.traceback_utils import SCCallSiteSync from pyspark.sql import since from pyspark.sql.types import _parse_datatype_json_string -from pyspark.sql.column import Column, _to_seq, _to_java_column +from pyspark.sql.column import Column, _to_seq, _to_list, _to_java_column from pyspark.sql.readwriter import DataFrameWriter from pyspark.sql.types import * @@ -494,7 +494,7 @@ def randomSplit(self, weights, seed=None): if w < 0.0: raise ValueError("Weights must be positive. 
Found weight value: %s" % w) seed = seed if seed is not None else random.randint(0, sys.maxsize) - rdd_array = self._jdf.randomSplit(_to_seq(self.sql_ctx._sc, weights), long(seed)) + rdd_array = self._jdf.randomSplit(_to_list(self.sql_ctx._sc, weights), long(seed)) return [DataFrame(rdd, self.sql_ctx) for rdd in rdd_array] @property diff --git a/scalastyle-config.xml b/scalastyle-config.xml index b5e2e882d225..68fdb4141cf2 100644 --- a/scalastyle-config.xml +++ b/scalastyle-config.xml @@ -161,6 +161,13 @@ This file is divided into 3 sections: ]]> + + + JavaConversions + Instead of importing implicits in scala.collection.JavaConversions._, import + scala.collection.JavaConverters._ and use .asScala / .asJava methods + + diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala index ec895af9c303..cfd9cb0e6259 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql +import scala.collection.JavaConverters._ + import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.GenericRow import org.apache.spark.sql.types.StructType @@ -280,9 +282,8 @@ trait Row extends Serializable { * * @throws ClassCastException when data type does not match. */ - def getList[T](i: Int): java.util.List[T] = { - scala.collection.JavaConversions.seqAsJavaList(getSeq[T](i)) - } + def getList[T](i: Int): java.util.List[T] = + getSeq[T](i).asJava /** * Returns the value at position i of map type as a Scala Map. @@ -296,9 +297,8 @@ trait Row extends Serializable { * * @throws ClassCastException when data type does not match. */ - def getJavaMap[K, V](i: Int): java.util.Map[K, V] = { - scala.collection.JavaConversions.mapAsJavaMap(getMap[K, V](i)) - } + def getJavaMap[K, V](i: Int): java.util.Map[K, V] = + getMap[K, V](i).asJava /** * Returns the value at position i of struct type as an [[Row]] object. 
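The new scalastyle rule above only spells out the message; as a minimal standalone sketch (not part of this diff) of the explicit-converter style it asks for, the snippet below shows the `.asScala` / `.asJava` round-trips that hunks such as `Row.getList` / `Row.getJavaMap` switch to. All names here (`ConvertersExample`, the sample collections) are illustrative only.

import java.util.{ArrayList => JArrayList}

import scala.collection.JavaConverters._

object ConvertersExample {
  def main(args: Array[String]): Unit = {
    // Java -> Scala: asScala is an explicit call returning a wrapper view, no copy.
    val javaList = new JArrayList[String]()
    javaList.add("a")
    javaList.add("b")
    val scalaSeq: Seq[String] = javaList.asScala

    // Scala -> Java: asJava wraps a Scala collection for Java callers,
    // which is what the rewritten getList / getJavaMap rely on.
    val javaView: java.util.List[String] = scalaSeq.asJava
    val javaMap: java.util.Map[String, Int] = Map("a" -> 1).asJava

    println(scalaSeq.mkString(","))  // a,b
    println(javaView.size())         // 2
    println(javaMap.get("a"))        // 1
  }
}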
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala index 503c4f4b20f3..4cc9a5520a08 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.analysis import java.util.concurrent.ConcurrentHashMap -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.mutable import scala.collection.mutable.ArrayBuffer @@ -147,7 +147,7 @@ class SimpleCatalog(val conf: CatalystConf) extends Catalog { override def getTables(databaseName: Option[String]): Seq[(String, Boolean)] = { val result = ArrayBuffer.empty[(String, Boolean)] - for (name <- tables.keySet()) { + for (name <- tables.keySet().asScala) { result += ((name, true)) } result diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala index a4fd4cf3b330..77a42c0873a6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql import java.{lang => jl} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import org.apache.spark.annotation.Experimental import org.apache.spark.sql.catalyst.expressions._ @@ -209,7 +209,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { * * @since 1.3.1 */ - def fill(valueMap: java.util.Map[String, Any]): DataFrame = fill0(valueMap.toSeq) + def fill(valueMap: java.util.Map[String, Any]): DataFrame = fill0(valueMap.asScala.toSeq) /** * (Scala-specific) Returns a new [[DataFrame]] that replaces null values. 
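For the java.util.Map overloads in DataFrameNaFunctions above (fill) and below (replace), the patch pairs asScala with toSeq/toMap rather than using asScala alone. A small self-contained sketch of the difference, assuming only the standard library; `FillArgsExample` and the sample keys are invented for illustration: asScala is a live view backed by the Java map, while toMap takes an immutable Scala snapshot.

import java.util.{HashMap => JHashMap}

import scala.collection.JavaConverters._

object FillArgsExample {
  def main(args: Array[String]): Unit = {
    val javaArgs = new JHashMap[String, Any]()
    javaArgs.put("a", "test")
    javaArgs.put("c", 1)

    val view = javaArgs.asScala            // mutable view backed by javaArgs
    val snapshot = javaArgs.asScala.toMap  // immutable copy, detached from javaArgs

    javaArgs.put("d", 2.2)
    println(view.size)      // 3 - the view reflects the later put
    println(snapshot.size)  // 2 - the snapshot does not
  }
}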
@@ -254,7 +254,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { * @since 1.3.1 */ def replace[T](col: String, replacement: java.util.Map[T, T]): DataFrame = { - replace[T](col, replacement.toMap : Map[T, T]) + replace[T](col, replacement.asScala.toMap) } /** @@ -277,7 +277,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { * @since 1.3.1 */ def replace[T](cols: Array[String], replacement: java.util.Map[T, T]): DataFrame = { - replace(cols.toSeq, replacement.toMap) + replace(cols.toSeq, replacement.asScala.toMap) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 6dc7bfe33349..97a8b6518a83 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql import java.util.Properties +import scala.collection.JavaConverters._ + import org.apache.hadoop.fs.Path import org.apache.spark.annotation.Experimental @@ -90,7 +92,7 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging { * @since 1.4.0 */ def options(options: java.util.Map[String, String]): DataFrameReader = { - this.options(scala.collection.JavaConversions.mapAsScalaMap(options)) + this.options(options.asScala) this } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index ce8744b53175..b2a66dd417b4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql import java.util.Properties +import scala.collection.JavaConverters._ + import org.apache.spark.annotation.Experimental import org.apache.spark.sql.catalyst.{SqlParser, TableIdentifier} import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation @@ -109,7 +111,7 @@ final class DataFrameWriter private[sql](df: DataFrame) { * @since 1.4.0 */ def options(options: java.util.Map[String, String]): DataFrameWriter = { - this.options(scala.collection.JavaConversions.mapAsScalaMap(options)) + this.options(options.asScala) this } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala index 99d557b03a03..ee31d83cce42 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.language.implicitConversions import org.apache.spark.annotation.Experimental @@ -188,7 +188,7 @@ class GroupedData protected[sql]( * @since 1.3.0 */ def agg(exprs: java.util.Map[String, String]): DataFrame = { - agg(exprs.toMap) + agg(exprs.asScala.toMap) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index e9de14f02550..e6f7619519e6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql import java.util.Properties import scala.collection.immutable -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import 
org.apache.parquet.hadoop.ParquetOutputCommitter @@ -531,7 +531,7 @@ private[sql] class SQLConf extends Serializable with CatalystConf { /** Set Spark SQL configuration properties. */ def setConf(props: Properties): Unit = settings.synchronized { - props.foreach { case (k, v) => setConfString(k, v) } + props.asScala.foreach { case (k, v) => setConfString(k, v) } } /** Set the given Spark SQL configuration property using a `string` value. */ @@ -601,24 +601,25 @@ private[sql] class SQLConf extends Serializable with CatalystConf { * Return all the configuration properties that have been set (i.e. not the default). * This creates a new copy of the config properties in the form of a Map. */ - def getAllConfs: immutable.Map[String, String] = settings.synchronized { settings.toMap } + def getAllConfs: immutable.Map[String, String] = + settings.synchronized { settings.asScala.toMap } /** * Return all the configuration definitions that have been defined in [[SQLConf]]. Each * definition contains key, defaultValue and doc. */ def getAllDefinedConfs: Seq[(String, String, String)] = sqlConfEntries.synchronized { - sqlConfEntries.values.filter(_.isPublic).map { entry => + sqlConfEntries.values.asScala.filter(_.isPublic).map { entry => (entry.key, entry.defaultValueString, entry.doc) }.toSeq } private[spark] def unsetConf(key: String): Unit = { - settings -= key + settings.remove(key) } private[spark] def unsetConf(entry: SQLConfEntry[_]): Unit = { - settings -= entry.key + settings.remove(entry.key) } private[spark] def clear(): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index a1eea09e0477..4e8414af50b4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -21,7 +21,7 @@ import java.beans.Introspector import java.util.Properties import java.util.concurrent.atomic.AtomicReference -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.immutable import scala.reflect.runtime.universe.TypeTag import scala.util.control.NonFatal @@ -225,7 +225,7 @@ class SQLContext(@transient val sparkContext: SparkContext) conf.setConf(properties) // After we have populated SQLConf, we call setConf to populate other confs in the subclass // (e.g. hiveconf in HiveContext). 
- properties.foreach { + properties.asScala.foreach { case (key, value) => setConf(key, value) } } @@ -567,7 +567,7 @@ class SQLContext(@transient val sparkContext: SparkContext) tableName: String, source: String, options: java.util.Map[String, String]): DataFrame = { - createExternalTable(tableName, source, options.toMap) + createExternalTable(tableName, source, options.asScala.toMap) } /** @@ -612,7 +612,7 @@ class SQLContext(@transient val sparkContext: SparkContext) source: String, schema: StructType, options: java.util.Map[String, String]): DataFrame = { - createExternalTable(tableName, source, schema, options.toMap) + createExternalTable(tableName, source, schema, options.asScala.toMap) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala index 8fbaf3a3059d..011724436621 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.datasources import java.util.ServiceLoader -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.language.{existentials, implicitConversions} import scala.util.{Success, Failure, Try} @@ -55,7 +55,7 @@ object ResolvedDataSource extends Logging { val loader = Utils.getContextOrSparkClassLoader val serviceLoader = ServiceLoader.load(classOf[DataSourceRegister], loader) - serviceLoader.iterator().filter(_.shortName().equalsIgnoreCase(provider)).toList match { + serviceLoader.asScala.filter(_.shortName().equalsIgnoreCase(provider)).toList match { /** the provider format did not match any given registered aliases */ case Nil => Try(loader.loadClass(provider)).orElse(Try(loader.loadClass(provider2))) match { case Success(dataSource) => dataSource diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystReadSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystReadSupport.scala index 3f8353af6e2a..0a6bb44445f6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystReadSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystReadSupport.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.datasources.parquet import java.util.{Map => JMap} -import scala.collection.JavaConversions.{iterableAsScalaIterable, mapAsJavaMap, mapAsScalaMap} +import scala.collection.JavaConverters._ import org.apache.hadoop.conf.Configuration import org.apache.parquet.hadoop.api.ReadSupport.ReadContext @@ -44,7 +44,7 @@ private[parquet] class CatalystReadSupport extends ReadSupport[InternalRow] with val parquetRequestedSchema = readContext.getRequestedSchema val catalystRequestedSchema = - Option(readContext.getReadSupportMetadata).map(_.toMap).flatMap { metadata => + Option(readContext.getReadSupportMetadata).map(_.asScala).flatMap { metadata => metadata // First tries to read requested schema, which may result from projections .get(CatalystReadSupport.SPARK_ROW_REQUESTED_SCHEMA) @@ -123,7 +123,7 @@ private[parquet] class CatalystReadSupport extends ReadSupport[InternalRow] with maybeRequestedSchema.fold(context.getFileSchema) { schemaString => val toParquet = new CatalystSchemaConverter(conf) val fileSchema = 
context.getFileSchema.asGroupType() - val fileFieldNames = fileSchema.getFields.map(_.getName).toSet + val fileFieldNames = fileSchema.getFields.asScala.map(_.getName).toSet StructType // Deserializes the Catalyst schema of requested columns @@ -152,7 +152,7 @@ private[parquet] class CatalystReadSupport extends ReadSupport[InternalRow] with maybeRequestedSchema.map(CatalystReadSupport.SPARK_ROW_REQUESTED_SCHEMA -> _) ++ maybeRowSchema.map(RowWriteSupport.SPARK_ROW_SCHEMA -> _) - new ReadContext(parquetRequestedSchema, metadata) + new ReadContext(parquetRequestedSchema, metadata.asJava) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala index cbf0704c4a9a..f682ca0d8ff4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.datasources.parquet import java.math.{BigDecimal, BigInteger} import java.nio.ByteOrder -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer import org.apache.parquet.column.Dictionary @@ -183,7 +183,7 @@ private[parquet] class CatalystRowConverter( // those missing fields and create converters for them, although values of these fields are // always null. val paddedParquetFields = { - val parquetFields = parquetType.getFields + val parquetFields = parquetType.getFields.asScala val parquetFieldNames = parquetFields.map(_.getName).toSet val missingFields = catalystType.filterNot(f => parquetFieldNames.contains(f.name)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala index 535f0684e97f..be6c0545f5a0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution.datasources.parquet -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import org.apache.hadoop.conf.Configuration import org.apache.parquet.schema.OriginalType._ @@ -82,7 +82,7 @@ private[parquet] class CatalystSchemaConverter( def convert(parquetSchema: MessageType): StructType = convert(parquetSchema.asGroupType()) private def convert(parquetSchema: GroupType): StructType = { - val fields = parquetSchema.getFields.map { field => + val fields = parquetSchema.getFields.asScala.map { field => field.getRepetition match { case OPTIONAL => StructField(field.getName, convertField(field), nullable = true) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala index bbf682aec0f9..64982f37cf87 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala @@ -21,7 +21,7 @@ import java.net.URI import java.util.logging.{Logger => JLogger} 
import java.util.{List => JList} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.mutable import scala.util.{Failure, Try} @@ -336,7 +336,7 @@ private[sql] class ParquetRelation( override def getPartitions: Array[SparkPartition] = { val inputFormat = new ParquetInputFormat[InternalRow] { override def listStatus(jobContext: JobContext): JList[FileStatus] = { - if (cacheMetadata) cachedStatuses else super.listStatus(jobContext) + if (cacheMetadata) cachedStatuses.asJava else super.listStatus(jobContext) } } @@ -344,7 +344,8 @@ private[sql] class ParquetRelation( val rawSplits = inputFormat.getSplits(jobContext) Array.tabulate[SparkPartition](rawSplits.size) { i => - new SqlNewHadoopPartition(id, i, rawSplits(i).asInstanceOf[InputSplit with Writable]) + new SqlNewHadoopPartition( + id, i, rawSplits.get(i).asInstanceOf[InputSplit with Writable]) } } }.asInstanceOf[RDD[Row]] // type erasure hack to pass RDD[InternalRow] as RDD[Row] @@ -588,7 +589,7 @@ private[sql] object ParquetRelation extends Logging { val metadata = footer.getParquetMetadata.getFileMetaData val serializedSchema = metadata .getKeyValueMetaData - .toMap + .asScala.toMap .get(CatalystReadSupport.SPARK_METADATA_KEY) if (serializedSchema.isEmpty) { // Falls back to Parquet schema if no Spark SQL schema found. @@ -745,7 +746,7 @@ private[sql] object ParquetRelation extends Logging { // Reads footers in multi-threaded manner within each task val footers = ParquetFileReader.readAllFootersInParallel( - serializedConf.value, fakeFileStatuses, skipRowGroups) + serializedConf.value, fakeFileStatuses.asJava, skipRowGroups).asScala // Converter used to convert Parquet `MessageType` to Spark SQL `StructType` val converter = @@ -772,7 +773,7 @@ private[sql] object ParquetRelation extends Logging { val fileMetaData = footer.getParquetMetadata.getFileMetaData fileMetaData .getKeyValueMetaData - .toMap + .asScala.toMap .get(CatalystReadSupport.SPARK_METADATA_KEY) .flatMap(deserializeSchemaString) .getOrElse(converter.convert(fileMetaData.getSchema)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTypesConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTypesConverter.scala index 42376ef7a9c1..142301fe87cb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTypesConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTypesConverter.scala @@ -18,8 +18,8 @@ package org.apache.spark.sql.execution.datasources.parquet import java.io.IOException +import java.util.{Collections, Arrays} -import scala.collection.JavaConversions._ import scala.util.Try import org.apache.hadoop.conf.Configuration @@ -107,7 +107,7 @@ private[parquet] object ParquetTypesConverter extends Logging { ParquetFileWriter.writeMetadataFile( conf, path, - new Footer(path, new ParquetMetadata(metaData, Nil)) :: Nil) + Arrays.asList(new Footer(path, new ParquetMetadata(metaData, Collections.emptyList())))) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala index ed282f98b7d7..d800c7456bda 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala @@ -17,7 +17,7 
@@ package org.apache.spark.sql.execution.joins -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD @@ -92,9 +92,9 @@ case class ShuffledHashOuterJoin( case FullOuter => // TODO(davies): use UnsafeRow val leftHashTable = - buildHashTable(leftIter, numLeftRows, newProjection(leftKeys, left.output)) + buildHashTable(leftIter, numLeftRows, newProjection(leftKeys, left.output)).asScala val rightHashTable = - buildHashTable(rightIter, numRightRows, newProjection(rightKeys, right.output)) + buildHashTable(rightIter, numRightRows, newProjection(rightKeys, right.output)).asScala (leftHashTable.keySet ++ rightHashTable.keySet).iterator.flatMap { key => fullOuterIterator(key, leftHashTable.getOrElse(key, EMPTY_LIST), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala index 59f8b079ab33..5a58d846ad80 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution import java.io.OutputStream import java.util.{List => JList, Map => JMap} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import net.razorvine.pickle._ @@ -196,14 +196,15 @@ object EvaluatePython { case (c, BinaryType) if c.getClass.isArray && c.getClass.getComponentType.getName == "byte" => c case (c: java.util.List[_], ArrayType(elementType, _)) => - new GenericArrayData(c.map { e => fromJava(e, elementType)}.toArray) + new GenericArrayData(c.asScala.map { e => fromJava(e, elementType)}.toArray) case (c, ArrayType(elementType, _)) if c.getClass.isArray => new GenericArrayData(c.asInstanceOf[Array[_]].map(e => fromJava(e, elementType))) case (c: java.util.Map[_, _], MapType(keyType, valueType, _)) => - val keys = c.keysIterator.map(fromJava(_, keyType)).toArray - val values = c.valuesIterator.map(fromJava(_, valueType)).toArray + val keyValues = c.asScala.toSeq + val keys = keyValues.map(kv => fromJava(kv._1, keyType)).toArray + val values = keyValues.map(kv => fromJava(kv._2, valueType)).toArray ArrayBasedMapData(keys, values) case (c, StructType(fields)) if c.getClass.isArray => @@ -367,7 +368,7 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child: val pickle = new Unpickler iter.flatMap { pickedResult => val unpickledBatch = pickle.loads(pickedResult) - unpickledBatch.asInstanceOf[java.util.ArrayList[Any]] + unpickledBatch.asInstanceOf[java.util.ArrayList[Any]].asScala } }.mapPartitions { iter => val row = new GenericMutableRow(1) diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index 7abdd3db8034..4867cebf5328 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -23,7 +23,7 @@ import java.util.List; import java.util.Map; -import scala.collection.JavaConversions; +import scala.collection.JavaConverters; import scala.collection.Seq; import com.google.common.collect.ImmutableMap; @@ -96,7 +96,7 @@ public void testVarargMethods() { df.groupBy().agg(countDistinct("key", "value")); df.groupBy().agg(countDistinct(col("key"), col("value"))); df.select(coalesce(col("key"))); - + // 
Varargs with mathfunctions DataFrame df2 = context.table("testData2"); df2.select(exp("a"), exp("b")); @@ -172,7 +172,7 @@ public void testCreateDataFrameFromJavaBeans() { Seq outputBuffer = (Seq) first.getJavaMap(2).get("hello"); Assert.assertArrayEquals( bean.getC().get("hello"), - Ints.toArray(JavaConversions.seqAsJavaList(outputBuffer))); + Ints.toArray(JavaConverters.seqAsJavaListConverter(outputBuffer).asJava())); Seq d = first.getAs(3); Assert.assertEquals(bean.getD().size(), d.length()); for (int i = 0; i < d.length(); i++) { @@ -206,7 +206,7 @@ public void testCrosstab() { count++; } } - + @Test public void testFrequentItems() { DataFrame df = context.table("testData2"); diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala index cdaa14ac8078..329ffb66083b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import org.apache.spark.sql.test.SharedSQLContext @@ -153,11 +153,11 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSQLContext { // Test Java version checkAnswer( - df.na.fill(mapAsJavaMap(Map( + df.na.fill(Map( "a" -> "test", "c" -> 1, "d" -> 2.2 - ))), + ).asJava), Row("test", null, 1, 2.2)) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index 4adcefb7dc4b..3649c2a97b5e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql import java.util.{Locale, TimeZone} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.util._ @@ -145,7 +145,7 @@ object QueryTest { } def checkAnswer(df: DataFrame, expectedAnswer: java.util.List[Row]): String = { - checkAnswer(df, expectedAnswer.toSeq) match { + checkAnswer(df, expectedAnswer.asScala) match { case Some(errorMessage) => errorMessage case None => null } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAvroCompatibilitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAvroCompatibilitySuite.scala index 45db619567a2..bd7cf8c10abe 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAvroCompatibilitySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAvroCompatibilitySuite.scala @@ -20,8 +20,7 @@ package org.apache.spark.sql.execution.datasources.parquet import java.nio.ByteBuffer import java.util.{List => JList, Map => JMap} -import scala.collection.JavaConverters.seqAsJavaListConverter -import scala.collection.JavaConverters.mapAsJavaMapConverter +import scala.collection.JavaConverters._ import org.apache.avro.Schema import org.apache.avro.generic.IndexedRecord diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala index d85c564e3e8d..df68432faeeb 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution.datasources.parquet -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import org.apache.hadoop.fs.{Path, PathFilter} import org.apache.parquet.hadoop.ParquetFileReader @@ -40,8 +40,9 @@ private[sql] abstract class ParquetCompatibilityTest extends QueryTest with Parq override def accept(path: Path): Boolean = pathFilter(path) }).toSeq - val footers = ParquetFileReader.readAllFootersInParallel(configuration, parquetFiles, true) - footers.head.getParquetMetadata.getFileMetaData.getSchema + val footers = + ParquetFileReader.readAllFootersInParallel(configuration, parquetFiles.asJava, true) + footers.iterator().next().getParquetMetadata.getFileMetaData.getSchema } protected def logParquetSchema(path: String): Unit = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala index e6b0a2ea95e3..08d2b9dee99b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala @@ -17,7 +17,9 @@ package org.apache.spark.sql.execution.datasources.parquet -import scala.collection.JavaConversions._ +import java.util.Collections + +import scala.collection.JavaConverters._ import scala.reflect.ClassTag import scala.reflect.runtime.universe.TypeTag @@ -28,7 +30,7 @@ import org.apache.parquet.example.data.simple.SimpleGroup import org.apache.parquet.example.data.{Group, GroupWriter} import org.apache.parquet.hadoop.api.WriteSupport import org.apache.parquet.hadoop.api.WriteSupport.WriteContext -import org.apache.parquet.hadoop.metadata.{CompressionCodecName, FileMetaData, ParquetMetadata} +import org.apache.parquet.hadoop.metadata.{BlockMetaData, CompressionCodecName, FileMetaData, ParquetMetadata} import org.apache.parquet.hadoop.{Footer, ParquetFileWriter, ParquetOutputCommitter, ParquetWriter} import org.apache.parquet.io.api.RecordConsumer import org.apache.parquet.schema.{MessageType, MessageTypeParser} @@ -205,9 +207,8 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { test("compression codec") { def compressionCodecFor(path: String): String = { val codecs = ParquetTypesConverter - .readMetaData(new Path(path), Some(configuration)) - .getBlocks - .flatMap(_.getColumns) + .readMetaData(new Path(path), Some(configuration)).getBlocks.asScala + .flatMap(_.getColumns.asScala) .map(_.getCodec.name()) .distinct @@ -348,14 +349,16 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { """.stripMargin) withTempPath { location => - val extraMetadata = Map(CatalystReadSupport.SPARK_METADATA_KEY -> sparkSchema.toString) + val extraMetadata = Collections.singletonMap( + CatalystReadSupport.SPARK_METADATA_KEY, sparkSchema.toString) val fileMetadata = new FileMetaData(parquetSchema, extraMetadata, "Spark") val path = new Path(location.getCanonicalPath) ParquetFileWriter.writeMetadataFile( sqlContext.sparkContext.hadoopConfiguration, path, - new Footer(path, new ParquetMetadata(fileMetadata, Nil)) :: Nil) + Collections.singletonList( + new Footer(path, new 
ParquetMetadata(fileMetadata, Collections.emptyList())))) assertResult(sqlContext.read.parquet(path.toString).schema) { StructType( @@ -386,7 +389,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { } finally { // Hadoop 1 doesn't have `Configuration.unset` configuration.clear() - clonedConf.foreach(entry => configuration.set(entry.getKey, entry.getValue)) + clonedConf.asScala.foreach(entry => configuration.set(entry.getKey, entry.getValue)) } } @@ -410,7 +413,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { } finally { // Hadoop 1 doesn't have `Configuration.unset` configuration.clear() - clonedConf.foreach(entry => configuration.set(entry.getKey, entry.getValue)) + clonedConf.asScala.foreach(entry => configuration.set(entry.getKey, entry.getValue)) } } @@ -434,7 +437,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { } finally { // Hadoop 1 doesn't have `Configuration.unset` configuration.clear() - clonedConf.foreach(entry => configuration.set(entry.getKey, entry.getValue)) + clonedConf.asScala.foreach(entry => configuration.set(entry.getKey, entry.getValue)) } } } @@ -481,7 +484,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { } finally { // Hadoop 1 doesn't have `Configuration.unset` configuration.clear() - clonedConf.foreach(entry => configuration.set(entry.getKey, entry.getValue)) + clonedConf.asScala.foreach(entry => configuration.set(entry.getKey, entry.getValue)) } } } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala index 02cc7e5efa52..306f98bcb534 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala @@ -20,9 +20,9 @@ package org.apache.spark.sql.hive.thriftserver import java.security.PrivilegedExceptionAction import java.sql.{Date, Timestamp} import java.util.concurrent.RejectedExecutionException -import java.util.{Map => JMap, UUID} +import java.util.{Arrays, Map => JMap, UUID} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.mutable.{ArrayBuffer, Map => SMap} import scala.util.control.NonFatal @@ -126,13 +126,13 @@ private[hive] class SparkExecuteStatementOperation( def getResultSetSchema: TableSchema = { if (result == null || result.queryExecution.analyzed.output.size == 0) { - new TableSchema(new FieldSchema("Result", "string", "") :: Nil) + new TableSchema(Arrays.asList(new FieldSchema("Result", "string", ""))) } else { logInfo(s"Result Schema: ${result.queryExecution.analyzed.output}") val schema = result.queryExecution.analyzed.output.map { attr => new FieldSchema(attr.name, HiveMetastoreTypes.toMetastoreType(attr.dataType), "") } - new TableSchema(schema) + new TableSchema(schema.asJava) } } @@ -298,7 +298,7 @@ private[hive] class SparkExecuteStatementOperation( sqlOperationConf = new HiveConf(sqlOperationConf) // apply overlay query specific settings, if any - getConfOverlay().foreach { case (k, v) => + getConfOverlay().asScala.foreach { case (k, v) => try { sqlOperationConf.verifyAndSet(k, v) } catch { diff --git 
a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala index 7799704c819d..a29df567983b 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala @@ -17,11 +17,11 @@ package org.apache.spark.sql.hive.thriftserver -import scala.collection.JavaConversions._ - import java.io._ import java.util.{ArrayList => JArrayList, Locale} +import scala.collection.JavaConverters._ + import jline.console.ConsoleReader import jline.console.history.FileHistory @@ -101,9 +101,9 @@ private[hive] object SparkSQLCLIDriver extends Logging { // Set all properties specified via command line. val conf: HiveConf = sessionState.getConf - sessionState.cmdProperties.entrySet().foreach { item => - val key = item.getKey.asInstanceOf[String] - val value = item.getValue.asInstanceOf[String] + sessionState.cmdProperties.entrySet().asScala.foreach { item => + val key = item.getKey.toString + val value = item.getValue.toString // We do not propagate metastore options to the execution copy of hive. if (key != "javax.jdo.option.ConnectionURL") { conf.set(key, value) @@ -316,15 +316,15 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging { if (HiveConf.getBoolVar(conf, HiveConf.ConfVars.HIVE_CLI_PRINT_HEADER)) { // Print the column names. - Option(driver.getSchema.getFieldSchemas).map { fields => - out.println(fields.map(_.getName).mkString("\t")) + Option(driver.getSchema.getFieldSchemas).foreach { fields => + out.println(fields.asScala.map(_.getName).mkString("\t")) } } var counter = 0 try { while (!out.checkError() && driver.getResults(res)) { - res.foreach{ l => + res.asScala.foreach { l => counter += 1 out.println(l) } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala index 644165acf70a..5ad8c54f296d 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala @@ -21,6 +21,8 @@ import java.io.IOException import java.util.{List => JList} import javax.security.auth.login.LoginException +import scala.collection.JavaConverters._ + import org.apache.commons.logging.Log import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.shims.Utils @@ -34,8 +36,6 @@ import org.apache.hive.service.{AbstractService, Service, ServiceException} import org.apache.spark.sql.hive.HiveContext import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._ -import scala.collection.JavaConversions._ - private[hive] class SparkSQLCLIService(hiveServer: HiveServer2, hiveContext: HiveContext) extends CLIService(hiveServer) with ReflectedCompositeService { @@ -76,7 +76,7 @@ private[thriftserver] trait ReflectedCompositeService { this: AbstractService => def initCompositeService(hiveConf: HiveConf) { // Emulating `CompositeService.init(hiveConf)` val serviceList = getAncestorField[JList[Service]](this, 2, "serviceList") - serviceList.foreach(_.init(hiveConf)) + serviceList.asScala.foreach(_.init(hiveConf)) // Emulating `AbstractService.init(hiveConf)` invoke(classOf[AbstractService], this, 
"ensureCurrentState", classOf[STATE] -> STATE.NOTINITED) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala index 77272aecf283..2619286afc14 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala @@ -17,7 +17,9 @@ package org.apache.spark.sql.hive.thriftserver -import java.util.{ArrayList => JArrayList, List => JList} +import java.util.{Arrays, ArrayList => JArrayList, List => JList} + +import scala.collection.JavaConverters._ import org.apache.commons.lang3.exception.ExceptionUtils import org.apache.hadoop.hive.metastore.api.{FieldSchema, Schema} @@ -27,8 +29,6 @@ import org.apache.hadoop.hive.ql.processors.CommandProcessorResponse import org.apache.spark.Logging import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes} -import scala.collection.JavaConversions._ - private[hive] class SparkSQLDriver( val context: HiveContext = SparkSQLEnv.hiveContext) extends Driver @@ -43,14 +43,14 @@ private[hive] class SparkSQLDriver( private def getResultSetSchema(query: context.QueryExecution): Schema = { val analyzed = query.analyzed logDebug(s"Result Schema: ${analyzed.output}") - if (analyzed.output.size == 0) { - new Schema(new FieldSchema("Response code", "string", "") :: Nil, null) + if (analyzed.output.isEmpty) { + new Schema(Arrays.asList(new FieldSchema("Response code", "string", "")), null) } else { val fieldSchemas = analyzed.output.map { attr => new FieldSchema(attr.name, HiveMetastoreTypes.toMetastoreType(attr.dataType), "") } - new Schema(fieldSchemas, null) + new Schema(fieldSchemas.asJava, null) } } @@ -79,7 +79,7 @@ private[hive] class SparkSQLDriver( if (hiveResponse == null) { false } else { - res.asInstanceOf[JArrayList[String]].addAll(hiveResponse) + res.asInstanceOf[JArrayList[String]].addAll(hiveResponse.asJava) hiveResponse = null true } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala index 1d41c4613182..bacf6cc458fd 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.hive.thriftserver import java.io.PrintStream -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import org.apache.spark.scheduler.StatsReportListener import org.apache.spark.sql.hive.HiveContext @@ -64,7 +64,7 @@ private[hive] object SparkSQLEnv extends Logging { hiveContext.setConf("spark.sql.hive.version", HiveContext.hiveExecutionVersion) if (log.isDebugEnabled) { - hiveContext.hiveconf.getAllProperties.toSeq.sorted.foreach { case (k, v) => + hiveContext.hiveconf.getAllProperties.asScala.toSeq.sorted.foreach { case (k, v) => logDebug(s"HiveConf var: $k=$v") } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 17cc83087fb1..c0a458fa9ab8 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -22,7 +22,7 
@@ import java.net.{URL, URLClassLoader} import java.sql.Timestamp import java.util.concurrent.TimeUnit -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.mutable.HashMap import scala.language.implicitConversions import scala.concurrent.duration._ @@ -194,7 +194,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging { logInfo("defalt warehouse location is " + defaltWarehouseLocation) // `configure` goes second to override other settings. - val allConfig = metadataConf.iterator.map(e => e.getKey -> e.getValue).toMap ++ configure + val allConfig = metadataConf.asScala.map(e => e.getKey -> e.getValue).toMap ++ configure val isolatedLoader = if (hiveMetastoreJars == "builtin") { if (hiveExecutionVersion != hiveMetastoreVersion) { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index 64fffdbf9b02..cfe2bb05ad89 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.hive +import scala.collection.JavaConverters._ + import org.apache.hadoop.hive.common.`type`.{HiveDecimal, HiveVarchar} import org.apache.hadoop.hive.serde2.objectinspector.primitive._ import org.apache.hadoop.hive.serde2.objectinspector.{StructField => HiveStructField, _} @@ -31,9 +33,6 @@ import org.apache.spark.sql.types._ import org.apache.spark.sql.{AnalysisException, types} import org.apache.spark.unsafe.types.UTF8String -/* Implicit conversions */ -import scala.collection.JavaConversions._ - /** * 1. The Underlying data type in catalyst and in Hive * In catalyst: @@ -290,13 +289,13 @@ private[hive] trait HiveInspectors { DateTimeUtils.fromJavaDate(poi.getWritableConstantValue.get()) case mi: StandardConstantMapObjectInspector => // take the value from the map inspector object, rather than the input data - val map = mi.getWritableConstantValue - val keys = map.keysIterator.map(unwrap(_, mi.getMapKeyObjectInspector)).toArray - val values = map.valuesIterator.map(unwrap(_, mi.getMapValueObjectInspector)).toArray + val keyValues = mi.getWritableConstantValue.asScala.toSeq + val keys = keyValues.map(kv => unwrap(kv._1, mi.getMapKeyObjectInspector)).toArray + val values = keyValues.map(kv => unwrap(kv._2, mi.getMapValueObjectInspector)).toArray ArrayBasedMapData(keys, values) case li: StandardConstantListObjectInspector => // take the value from the list inspector object, rather than the input data - val values = li.getWritableConstantValue + val values = li.getWritableConstantValue.asScala .map(unwrap(_, li.getListElementObjectInspector)) .toArray new GenericArrayData(values) @@ -342,7 +341,7 @@ private[hive] trait HiveInspectors { case li: ListObjectInspector => Option(li.getList(data)) .map { l => - val values = l.map(unwrap(_, li.getListElementObjectInspector)).toArray + val values = l.asScala.map(unwrap(_, li.getListElementObjectInspector)).toArray new GenericArrayData(values) } .orNull @@ -351,15 +350,16 @@ private[hive] trait HiveInspectors { if (map == null) { null } else { - val keys = map.keysIterator.map(unwrap(_, mi.getMapKeyObjectInspector)).toArray - val values = map.valuesIterator.map(unwrap(_, mi.getMapValueObjectInspector)).toArray + val keyValues = map.asScala.toSeq + val keys = keyValues.map(kv => unwrap(kv._1, mi.getMapKeyObjectInspector)).toArray + val values = keyValues.map(kv 
=> unwrap(kv._2, mi.getMapValueObjectInspector)).toArray ArrayBasedMapData(keys, values) } // currently, hive doesn't provide the ConstantStructObjectInspector case si: StructObjectInspector => val allRefs = si.getAllStructFieldRefs - InternalRow.fromSeq( - allRefs.map(r => unwrap(si.getStructFieldData(data, r), r.getFieldObjectInspector))) + InternalRow.fromSeq(allRefs.asScala.map( + r => unwrap(si.getStructFieldData(data, r), r.getFieldObjectInspector))) } @@ -403,14 +403,14 @@ private[hive] trait HiveInspectors { case soi: StandardStructObjectInspector => val schema = dataType.asInstanceOf[StructType] - val wrappers = soi.getAllStructFieldRefs.zip(schema.fields).map { case (ref, field) => - wrapperFor(ref.getFieldObjectInspector, field.dataType) + val wrappers = soi.getAllStructFieldRefs.asScala.zip(schema.fields).map { + case (ref, field) => wrapperFor(ref.getFieldObjectInspector, field.dataType) } (o: Any) => { if (o != null) { val struct = soi.create() val row = o.asInstanceOf[InternalRow] - soi.getAllStructFieldRefs.zip(wrappers).zipWithIndex.foreach { + soi.getAllStructFieldRefs.asScala.zip(wrappers).zipWithIndex.foreach { case ((field, wrapper), i) => soi.setStructFieldData(struct, field, wrapper(row.get(i, schema(i).dataType))) } @@ -537,7 +537,7 @@ private[hive] trait HiveInspectors { // 1. create the pojo (most likely) object val result = x.create() var i = 0 - while (i < fieldRefs.length) { + while (i < fieldRefs.size) { // 2. set the property for the pojo val tpe = structType(i).dataType x.setStructFieldData( @@ -552,9 +552,9 @@ private[hive] trait HiveInspectors { val fieldRefs = x.getAllStructFieldRefs val structType = dataType.asInstanceOf[StructType] val row = a.asInstanceOf[InternalRow] - val result = new java.util.ArrayList[AnyRef](fieldRefs.length) + val result = new java.util.ArrayList[AnyRef](fieldRefs.size) var i = 0 - while (i < fieldRefs.length) { + while (i < fieldRefs.size) { val tpe = structType(i).dataType result.add(wrap(row.get(i, tpe), fieldRefs.get(i).getFieldObjectInspector, tpe)) i += 1 @@ -712,10 +712,10 @@ private[hive] trait HiveInspectors { def inspectorToDataType(inspector: ObjectInspector): DataType = inspector match { case s: StructObjectInspector => - StructType(s.getAllStructFieldRefs.map(f => { + StructType(s.getAllStructFieldRefs.asScala.map(f => types.StructField( f.getFieldName, inspectorToDataType(f.getFieldObjectInspector), nullable = true) - })) + )) case l: ListObjectInspector => ArrayType(inspectorToDataType(l.getListElementObjectInspector)) case m: MapObjectInspector => MapType( diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 98d21aa76d64..b8da0840ae56 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.hive -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.mutable import com.google.common.base.Objects @@ -483,7 +483,7 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive // are empty. 
val partitions = metastoreRelation.getHiveQlPartitions().map { p => val location = p.getLocation - val values = InternalRow.fromSeq(p.getValues.zip(partitionColumnDataTypes).map { + val values = InternalRow.fromSeq(p.getValues.asScala.zip(partitionColumnDataTypes).map { case (rawValue, dataType) => Cast(Literal(rawValue), dataType).eval(null) }) ParquetPartition(values, location) @@ -798,9 +798,9 @@ private[hive] case class MetastoreRelation val sd = new org.apache.hadoop.hive.metastore.api.StorageDescriptor() tTable.setSd(sd) - sd.setCols(table.schema.map(c => new FieldSchema(c.name, c.hiveType, c.comment))) + sd.setCols(table.schema.map(c => new FieldSchema(c.name, c.hiveType, c.comment)).asJava) tTable.setPartitionKeys( - table.partitionColumns.map(c => new FieldSchema(c.name, c.hiveType, c.comment))) + table.partitionColumns.map(c => new FieldSchema(c.name, c.hiveType, c.comment)).asJava) table.location.foreach(sd.setLocation) table.inputFormat.foreach(sd.setInputFormat) @@ -852,11 +852,11 @@ private[hive] case class MetastoreRelation val tPartition = new org.apache.hadoop.hive.metastore.api.Partition tPartition.setDbName(databaseName) tPartition.setTableName(tableName) - tPartition.setValues(p.values) + tPartition.setValues(p.values.asJava) val sd = new org.apache.hadoop.hive.metastore.api.StorageDescriptor() tPartition.setSd(sd) - sd.setCols(table.schema.map(c => new FieldSchema(c.name, c.hiveType, c.comment))) + sd.setCols(table.schema.map(c => new FieldSchema(c.name, c.hiveType, c.comment)).asJava) sd.setLocation(p.storage.location) sd.setInputFormat(p.storage.inputFormat) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index ad33dee555dd..d5cd7e98b526 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -20,6 +20,9 @@ package org.apache.spark.sql.hive import java.sql.Date import java.util.Locale +import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer + import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.hadoop.hive.serde.serdeConstants @@ -48,10 +51,6 @@ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.CalendarInterval import org.apache.spark.util.random.RandomSampler -/* Implicit conversions */ -import scala.collection.JavaConversions._ -import scala.collection.mutable.ArrayBuffer - /** * Used when we need to start parsing the AST before deciding that we are going to pass the command * back for Hive to execute natively. Will be replaced with a native command that contains the @@ -202,7 +201,7 @@ private[hive] object HiveQl extends Logging { * Returns a scala.Seq equivalent to [s] or Nil if [s] is null. */ private def nilIfEmpty[A](s: java.util.List[A]): Seq[A] = - Option(s).map(_.toSeq).getOrElse(Nil) + Option(s).map(_.asScala).getOrElse(Nil) /** * Returns this ASTNode with the text changed to `newText`. 
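The nilIfEmpty change just above is the usual null guard for Java lists returned by the Hive AST API. Below is a standalone sketch of the same idiom, assuming plain JDK types only; `NilIfEmptyExample` and `nilIfNull` are invented names, not Spark code.

import java.util.{Arrays, List => JList}

import scala.collection.JavaConverters._

object NilIfEmptyExample {
  // Same shape as nilIfEmpty above: wrap a possibly-null Java list in Option,
  // convert the non-null case with asScala, and fall back to Nil.
  def nilIfNull[A](s: JList[A]): Seq[A] =
    Option(s).map(_.asScala).getOrElse(Nil)

  def main(args: Array[String]): Unit = {
    println(nilIfNull(Arrays.asList("a", "b")))  // Buffer(a, b)
    println(nilIfNull(null: JList[String]))      // List()
  }
}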
@@ -217,7 +216,7 @@ private[hive] object HiveQl extends Logging { */ def withChildren(newChildren: Seq[ASTNode]): ASTNode = { (1 to n.getChildCount).foreach(_ => n.deleteChild(0)) - n.addChildren(newChildren) + n.addChildren(newChildren.asJava) n } @@ -323,11 +322,11 @@ private[hive] object HiveQl extends Logging { assert(tree.asInstanceOf[ASTNode].getText == "TOK_CREATETABLE", "Only CREATE TABLE supported.") val tableOps = tree.getChildren val colList = - tableOps + tableOps.asScala .find(_.asInstanceOf[ASTNode].getText == "TOK_TABCOLLIST") .getOrElse(sys.error("No columnList!")).getChildren - colList.map(nodeToAttribute) + colList.asScala.map(nodeToAttribute) } /** Extractor for matching Hive's AST Tokens. */ @@ -337,7 +336,7 @@ private[hive] object HiveQl extends Logging { case t: ASTNode => CurrentOrigin.setPosition(t.getLine, t.getCharPositionInLine) Some((t.getText, - Option(t.getChildren).map(_.toList).getOrElse(Nil).asInstanceOf[Seq[ASTNode]])) + Option(t.getChildren).map(_.asScala.toList).getOrElse(Nil).asInstanceOf[Seq[ASTNode]])) case _ => None } } @@ -424,7 +423,9 @@ private[hive] object HiveQl extends Logging { protected def extractDbNameTableName(tableNameParts: Node): (Option[String], String) = { val (db, tableName) = - tableNameParts.getChildren.map { case Token(part, Nil) => cleanIdentifier(part) } match { + tableNameParts.getChildren.asScala.map { + case Token(part, Nil) => cleanIdentifier(part) + } match { case Seq(tableOnly) => (None, tableOnly) case Seq(databaseName, table) => (Some(databaseName), table) } @@ -433,7 +434,9 @@ private[hive] object HiveQl extends Logging { } protected def extractTableIdent(tableNameParts: Node): Seq[String] = { - tableNameParts.getChildren.map { case Token(part, Nil) => cleanIdentifier(part) } match { + tableNameParts.getChildren.asScala.map { + case Token(part, Nil) => cleanIdentifier(part) + } match { case Seq(tableOnly) => Seq(tableOnly) case Seq(databaseName, table) => Seq(databaseName, table) case other => sys.error("Hive only supports tables names like 'tableName' " + @@ -624,7 +627,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C val cols = BaseSemanticAnalyzer.getColumns(list, true) if (cols != null) { tableDesc = tableDesc.copy( - schema = cols.map { field => + schema = cols.asScala.map { field => HiveColumn(field.getName, field.getType, field.getComment) }) } @@ -636,7 +639,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C val cols = BaseSemanticAnalyzer.getColumns(list(0), false) if (cols != null) { tableDesc = tableDesc.copy( - partitionColumns = cols.map { field => + partitionColumns = cols.asScala.map { field => HiveColumn(field.getName, field.getType, field.getComment) }) } @@ -672,7 +675,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C case _ => assert(false) } tableDesc = tableDesc.copy( - serdeProperties = tableDesc.serdeProperties ++ serdeParams) + serdeProperties = tableDesc.serdeProperties ++ serdeParams.asScala) case Token("TOK_TABLELOCATION", child :: Nil) => var location = BaseSemanticAnalyzer.unescapeSQLString(child.getText) location = EximUtil.relativeToAbsolutePath(hiveConf, location) @@ -684,7 +687,8 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C val serdeParams = new java.util.HashMap[String, String]() BaseSemanticAnalyzer.readProps( (child.getChild(1).getChild(0)).asInstanceOf[ASTNode], serdeParams) - tableDesc = tableDesc.copy(serdeProperties = 
tableDesc.serdeProperties ++ serdeParams) + tableDesc = tableDesc.copy( + serdeProperties = tableDesc.serdeProperties ++ serdeParams.asScala) } case Token("TOK_FILEFORMAT_GENERIC", child :: Nil) => child.getText().toLowerCase(Locale.ENGLISH) match { @@ -847,7 +851,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C } val withWhere = whereClause.map { whereNode => - val Seq(whereExpr) = whereNode.getChildren.toSeq + val Seq(whereExpr) = whereNode.getChildren.asScala Filter(nodeToExpr(whereExpr), relations) }.getOrElse(relations) @@ -856,7 +860,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C // Script transformations are expressed as a select clause with a single expression of type // TOK_TRANSFORM - val transformation = select.getChildren.head match { + val transformation = select.getChildren.iterator().next() match { case Token("TOK_SELEXPR", Token("TOK_TRANSFORM", Token("TOK_EXPLIST", inputExprs) :: @@ -925,10 +929,10 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C val withLateralView = lateralViewClause.map { lv => val Token("TOK_SELECT", - Token("TOK_SELEXPR", clauses) :: Nil) = lv.getChildren.head + Token("TOK_SELEXPR", clauses) :: Nil) = lv.getChildren.iterator().next() - val alias = - getClause("TOK_TABALIAS", clauses).getChildren.head.asInstanceOf[ASTNode].getText + val alias = getClause("TOK_TABALIAS", clauses).getChildren.iterator().next() + .asInstanceOf[ASTNode].getText val (generator, attributes) = nodesToGenerator(clauses) Generate( @@ -944,7 +948,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C // (if there is a group by) or a script transformation. val withProject: LogicalPlan = transformation.getOrElse { val selectExpressions = - select.getChildren.flatMap(selExprNodeToExpr).map(UnresolvedAlias(_)).toSeq + select.getChildren.asScala.flatMap(selExprNodeToExpr).map(UnresolvedAlias) Seq( groupByClause.map(e => e match { case Token("TOK_GROUPBY", children) => @@ -973,7 +977,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C // Handle HAVING clause. val withHaving = havingClause.map { h => - val havingExpr = h.getChildren.toSeq match { case Seq(hexpr) => nodeToExpr(hexpr) } + val havingExpr = h.getChildren.asScala match { case Seq(hexpr) => nodeToExpr(hexpr) } // Note that we added a cast to boolean. If the expression itself is already boolean, // the optimizer will get rid of the unnecessary cast. Filter(Cast(havingExpr, BooleanType), withProject) @@ -983,32 +987,42 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C val withDistinct = if (selectDistinctClause.isDefined) Distinct(withHaving) else withHaving - // Handle ORDER BY, SORT BY, DISTRIBETU BY, and CLUSTER BY clause. + // Handle ORDER BY, SORT BY, DISTRIBUTE BY, and CLUSTER BY clause. 
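In the HAVING-clause hunk just above, the Java child list is converted with .asScala and then pattern matched as a Scala Seq instead of relying on an implicit view. A small sketch of the same shape, with strings standing in for AST nodes (the names are only illustrative):

    import scala.collection.JavaConverters._

    object SeqMatchSketch {
      def main(args: Array[String]): Unit = {
        // Stand-in for node.getChildren: a Java list with exactly one element.
        val children: java.util.List[String] = java.util.Collections.singletonList("havingExpr")

        // Convert explicitly, then match; the Buffer returned by .asScala is a Seq,
        // so the Seq extractor applies.
        val only = children.asScala match {
          case Seq(single) => single
          case other       => sys.error(s"expected exactly one child, got $other")
        }
        println(only) // havingExpr
      }
    }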
val withSort = (orderByClause, sortByClause, distributeByClause, clusterByClause) match { case (Some(totalOrdering), None, None, None) => - Sort(totalOrdering.getChildren.map(nodeToSortOrder), true, withDistinct) + Sort(totalOrdering.getChildren.asScala.map(nodeToSortOrder), true, withDistinct) case (None, Some(perPartitionOrdering), None, None) => - Sort(perPartitionOrdering.getChildren.map(nodeToSortOrder), false, withDistinct) + Sort( + perPartitionOrdering.getChildren.asScala.map(nodeToSortOrder), + false, withDistinct) case (None, None, Some(partitionExprs), None) => - RepartitionByExpression(partitionExprs.getChildren.map(nodeToExpr), withDistinct) + RepartitionByExpression( + partitionExprs.getChildren.asScala.map(nodeToExpr), withDistinct) case (None, Some(perPartitionOrdering), Some(partitionExprs), None) => - Sort(perPartitionOrdering.getChildren.map(nodeToSortOrder), false, - RepartitionByExpression(partitionExprs.getChildren.map(nodeToExpr), withDistinct)) + Sort( + perPartitionOrdering.getChildren.asScala.map(nodeToSortOrder), false, + RepartitionByExpression( + partitionExprs.getChildren.asScala.map(nodeToExpr), + withDistinct)) case (None, None, None, Some(clusterExprs)) => - Sort(clusterExprs.getChildren.map(nodeToExpr).map(SortOrder(_, Ascending)), false, - RepartitionByExpression(clusterExprs.getChildren.map(nodeToExpr), withDistinct)) + Sort( + clusterExprs.getChildren.asScala.map(nodeToExpr).map(SortOrder(_, Ascending)), + false, + RepartitionByExpression( + clusterExprs.getChildren.asScala.map(nodeToExpr), + withDistinct)) case (None, None, None, None) => withDistinct case _ => sys.error("Unsupported set of ordering / distribution clauses.") } val withLimit = - limitClause.map(l => nodeToExpr(l.getChildren.head)) + limitClause.map(l => nodeToExpr(l.getChildren.iterator().next())) .map(Limit(_, withSort)) .getOrElse(withSort) // Collect all window specifications defined in the WINDOW clause. 
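The LIMIT hunk above swaps the old implicit `.head` on a Java list for an explicit `.iterator().next()`. The two explicit spellings below are equivalent for a non-empty list; the types are simplified stand-ins:

    import scala.collection.JavaConverters._

    object FirstChildSketch {
      def main(args: Array[String]): Unit = {
        val children: java.util.List[String] = java.util.Arrays.asList("limitExpr", "other")

        // Two explicit replacements for the old implicit `.head` on a Java list:
        val viaIterator = children.iterator().next() // no conversion involved at all
        val viaAsScala  = children.asScala.head      // convert first, then use Seq ops

        assert(viaIterator == viaAsScala)
        println(viaIterator) // limitExpr
      }
    }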
- val windowDefinitions = windowClause.map(_.getChildren.toSeq.collect { + val windowDefinitions = windowClause.map(_.getChildren.asScala.collect { case Token("TOK_WINDOWDEF", Token(windowName, Nil) :: Token("TOK_WINDOWSPEC", spec) :: Nil) => windowName -> nodesToWindowSpecification(spec) @@ -1063,7 +1077,8 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C val Token("TOK_SELECT", Token("TOK_SELEXPR", clauses) :: Nil) = selectClause - val alias = getClause("TOK_TABALIAS", clauses).getChildren.head.asInstanceOf[ASTNode].getText + val alias = getClause("TOK_TABALIAS", clauses).getChildren.iterator().next() + .asInstanceOf[ASTNode].getText val (generator, attributes) = nodesToGenerator(clauses) Generate( @@ -1092,7 +1107,9 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C } val tableIdent = - tableNameParts.getChildren.map{ case Token(part, Nil) => cleanIdentifier(part)} match { + tableNameParts.getChildren.asScala.map { + case Token(part, Nil) => cleanIdentifier(part) + } match { case Seq(tableOnly) => Seq(tableOnly) case Seq(databaseName, table) => Seq(databaseName, table) case other => sys.error("Hive only supports tables names like 'tableName' " + @@ -1139,7 +1156,8 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C val isPreserved = tableOrdinals.map(i => (i - 1 < 0) || joinArgs(i - 1).getText == "PRESERVE") val tables = tableOrdinals.map(i => nodeToRelation(joinArgs(i))) - val joinExpressions = tableOrdinals.map(i => joinArgs(i + 1).getChildren.map(nodeToExpr)) + val joinExpressions = + tableOrdinals.map(i => joinArgs(i + 1).getChildren.asScala.map(nodeToExpr)) val joinConditions = joinExpressions.sliding(2).map { case Seq(c1, c2) => @@ -1164,7 +1182,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C joinType = joinType.remove(joinType.length - 1)) } - val groups = (0 until joinExpressions.head.size).map(i => Coalesce(joinExpressions.map(_(i)))) + val groups = joinExpressions.head.indices.map(i => Coalesce(joinExpressions.map(_(i)))) // Unique join is not really the same as an outer join so we must group together results where // the joinExpressions are the same, taking the First of each value is only okay because the @@ -1229,7 +1247,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C val tableIdent = extractTableIdent(tableNameParts) - val partitionKeys = partitionClause.map(_.getChildren.map { + val partitionKeys = partitionClause.map(_.getChildren.asScala.map { // Parse partitions. We also make keys case insensitive. case Token("TOK_PARTVAL", Token(key, Nil) :: Token(value, Nil) :: Nil) => cleanIdentifier(key.toLowerCase) -> Some(PlanUtils.stripQuotes(value)) @@ -1249,7 +1267,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C val tableIdent = extractTableIdent(tableNameParts) - val partitionKeys = partitionClause.map(_.getChildren.map { + val partitionKeys = partitionClause.map(_.getChildren.asScala.map { // Parse partitions. We also make keys case insensitive. 
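The partition-value hunks above convert the child list with .asScala and then build key/value pairs from it. A compact sketch of turning such a converted list into a Map, using plain (key, value) tuples instead of Hive Token nodes (all names are illustrative):

    import scala.collection.JavaConverters._

    object PartitionSpecSketch {
      def main(args: Array[String]): Unit = {
        // Stand-in for parsed TOK_PARTVAL children: raw (key, value) pairs.
        val partVals: java.util.List[(String, String)] =
          java.util.Arrays.asList("DS" -> "2015-08-03", "HR" -> "12")

        // Convert, normalise the keys to lower case, and collect into a Map.
        val spec: Map[String, Option[String]] =
          partVals.asScala.map { case (k, v) => k.toLowerCase -> Option(v) }.toMap

        println(spec) // Map(ds -> Some(2015-08-03), hr -> Some(12))
      }
    }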
case Token("TOK_PARTVAL", Token(key, Nil) :: Token(value, Nil) :: Nil) => cleanIdentifier(key.toLowerCase) -> Some(PlanUtils.stripQuotes(value)) @@ -1590,18 +1608,18 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C val (partitionByClause :: orderByClause :: sortByClause :: clusterByClause :: Nil) = getClauses( Seq("TOK_DISTRIBUTEBY", "TOK_ORDERBY", "TOK_SORTBY", "TOK_CLUSTERBY"), - partitionAndOrdering.getChildren.toSeq.asInstanceOf[Seq[ASTNode]]) + partitionAndOrdering.getChildren.asScala.asInstanceOf[Seq[ASTNode]]) (partitionByClause, orderByClause.orElse(sortByClause), clusterByClause) match { case (Some(partitionByExpr), Some(orderByExpr), None) => - (partitionByExpr.getChildren.map(nodeToExpr), - orderByExpr.getChildren.map(nodeToSortOrder)) + (partitionByExpr.getChildren.asScala.map(nodeToExpr), + orderByExpr.getChildren.asScala.map(nodeToSortOrder)) case (Some(partitionByExpr), None, None) => - (partitionByExpr.getChildren.map(nodeToExpr), Nil) + (partitionByExpr.getChildren.asScala.map(nodeToExpr), Nil) case (None, Some(orderByExpr), None) => - (Nil, orderByExpr.getChildren.map(nodeToSortOrder)) + (Nil, orderByExpr.getChildren.asScala.map(nodeToSortOrder)) case (None, None, Some(clusterByExpr)) => - val expressions = clusterByExpr.getChildren.map(nodeToExpr) + val expressions = clusterByExpr.getChildren.asScala.map(nodeToExpr) (expressions, expressions.map(SortOrder(_, Ascending))) case _ => throw new NotImplementedError( @@ -1639,7 +1657,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C } rowFrame.orElse(rangeFrame).map { frame => - frame.getChildren.toList match { + frame.getChildren.asScala.toList match { case precedingNode :: followingNode :: Nil => SpecifiedWindowFrame( frameType, @@ -1701,7 +1719,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C case other => sys.error(s"Non ASTNode encountered: $other") } - Option(node.getChildren).map(_.toList).getOrElse(Nil).foreach(dumpTree(_, builder, indent + 1)) + Option(node.getChildren).map(_.asScala).getOrElse(Nil).foreach(dumpTree(_, builder, indent + 1)) builder } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala index 267074f3ad10..004805f3aed0 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala @@ -22,8 +22,7 @@ import java.rmi.server.UID import org.apache.avro.Schema -/* Implicit conversions */ -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.language.implicitConversions import scala.reflect.ClassTag @@ -73,7 +72,7 @@ private[hive] object HiveShim { */ def appendReadColumns(conf: Configuration, ids: Seq[Integer], names: Seq[String]) { if (ids != null && ids.nonEmpty) { - ColumnProjectionUtils.appendReadColumns(conf, ids) + ColumnProjectionUtils.appendReadColumns(conf, ids.asJava) } if (names != null && names.nonEmpty) { appendReadColumnNames(conf, names) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala index f49c97de8ff4..4d1e3ed9198e 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala @@ -21,7 +21,7 @@ import java.io.{File, PrintStream} import 
java.util.{Map => JMap} import javax.annotation.concurrent.GuardedBy -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.language.reflectiveCalls import org.apache.hadoop.fs.Path @@ -305,10 +305,11 @@ private[hive] class ClientWrapper( HiveTable( name = h.getTableName, specifiedDatabase = Option(h.getDbName), - schema = h.getCols.map(f => HiveColumn(f.getName, f.getType, f.getComment)), - partitionColumns = h.getPartCols.map(f => HiveColumn(f.getName, f.getType, f.getComment)), - properties = h.getParameters.toMap, - serdeProperties = h.getTTable.getSd.getSerdeInfo.getParameters.toMap, + schema = h.getCols.asScala.map(f => HiveColumn(f.getName, f.getType, f.getComment)), + partitionColumns = h.getPartCols.asScala.map(f => + HiveColumn(f.getName, f.getType, f.getComment)), + properties = h.getParameters.asScala.toMap, + serdeProperties = h.getTTable.getSd.getSerdeInfo.getParameters.asScala.toMap, tableType = h.getTableType match { case HTableType.MANAGED_TABLE => ManagedTable case HTableType.EXTERNAL_TABLE => ExternalTable @@ -334,9 +335,9 @@ private[hive] class ClientWrapper( private def toQlTable(table: HiveTable): metadata.Table = { val qlTable = new metadata.Table(table.database, table.name) - qlTable.setFields(table.schema.map(c => new FieldSchema(c.name, c.hiveType, c.comment))) + qlTable.setFields(table.schema.map(c => new FieldSchema(c.name, c.hiveType, c.comment)).asJava) qlTable.setPartCols( - table.partitionColumns.map(c => new FieldSchema(c.name, c.hiveType, c.comment))) + table.partitionColumns.map(c => new FieldSchema(c.name, c.hiveType, c.comment)).asJava) table.properties.foreach { case (k, v) => qlTable.setProperty(k, v) } table.serdeProperties.foreach { case (k, v) => qlTable.setSerdeParam(k, v) } @@ -366,13 +367,13 @@ private[hive] class ClientWrapper( private def toHivePartition(partition: metadata.Partition): HivePartition = { val apiPartition = partition.getTPartition HivePartition( - values = Option(apiPartition.getValues).map(_.toSeq).getOrElse(Seq.empty), + values = Option(apiPartition.getValues).map(_.asScala).getOrElse(Seq.empty), storage = HiveStorageDescriptor( location = apiPartition.getSd.getLocation, inputFormat = apiPartition.getSd.getInputFormat, outputFormat = apiPartition.getSd.getOutputFormat, serde = apiPartition.getSd.getSerdeInfo.getSerializationLib, - serdeProperties = apiPartition.getSd.getSerdeInfo.getParameters.toMap)) + serdeProperties = apiPartition.getSd.getSerdeInfo.getParameters.asScala.toMap)) } override def getPartitionOption( @@ -397,7 +398,7 @@ private[hive] class ClientWrapper( } override def listTables(dbName: String): Seq[String] = withHiveState { - client.getAllTables(dbName) + client.getAllTables(dbName).asScala } /** @@ -514,17 +515,17 @@ private[hive] class ClientWrapper( } def reset(): Unit = withHiveState { - client.getAllTables("default").foreach { t => + client.getAllTables("default").asScala.foreach { t => logDebug(s"Deleting table $t") val table = client.getTable("default", t) - client.getIndexes("default", t, 255).foreach { index => + client.getIndexes("default", t, 255).asScala.foreach { index => shim.dropIndex(client, "default", t, index.getIndexName) } if (!table.isIndexTable) { client.dropTable("default", t) } } - client.getAllDatabases.filterNot(_ == "default").foreach { db => + client.getAllDatabases.asScala.filterNot(_ == "default").foreach { db => logDebug(s"Dropping Database: $db") client.dropDatabase(db, true, false, true) } diff --git 
a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala index 8fc8935b1dc3..48bbb21e6c1d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala @@ -23,7 +23,7 @@ import java.net.URI import java.util.{ArrayList => JArrayList, List => JList, Map => JMap, Set => JSet} import java.util.concurrent.TimeUnit -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.hive.conf.HiveConf @@ -201,7 +201,7 @@ private[client] class Shim_v0_12 extends Shim with Logging { setDataLocationMethod.invoke(table, new URI(loc)) override def getAllPartitions(hive: Hive, table: Table): Seq[Partition] = - getAllPartitionsMethod.invoke(hive, table).asInstanceOf[JSet[Partition]].toSeq + getAllPartitionsMethod.invoke(hive, table).asInstanceOf[JSet[Partition]].asScala.toSeq override def getPartitionsByFilter( hive: Hive, @@ -220,7 +220,7 @@ private[client] class Shim_v0_12 extends Shim with Logging { override def getDriverResults(driver: Driver): Seq[String] = { val res = new JArrayList[String]() getDriverResultsMethod.invoke(driver, res) - res.toSeq + res.asScala } override def getMetastoreClientConnectRetryDelayMillis(conf: HiveConf): Long = { @@ -310,7 +310,7 @@ private[client] class Shim_v0_13 extends Shim_v0_12 { setDataLocationMethod.invoke(table, new Path(loc)) override def getAllPartitions(hive: Hive, table: Table): Seq[Partition] = - getAllPartitionsMethod.invoke(hive, table).asInstanceOf[JSet[Partition]].toSeq + getAllPartitionsMethod.invoke(hive, table).asInstanceOf[JSet[Partition]].asScala.toSeq /** * Converts catalyst expression to the format that Hive's getPartitionsByFilter() expects, i.e. @@ -320,7 +320,7 @@ private[client] class Shim_v0_13 extends Shim_v0_12 { */ def convertFilters(table: Table, filters: Seq[Expression]): String = { // hive varchar is treated as catalyst string, but hive varchar can't be pushed down. 
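The getDriverResults hunks above keep the pattern of handing a pre-allocated java.util.ArrayList to a Java method and then reading it back, but make the return conversion explicit with .asScala. A self-contained sketch of that round trip, with a dummy Java-style API in place of Hive's Driver (the names are assumptions for illustration):

    import java.util.{ArrayList => JArrayList}

    import scala.collection.JavaConverters._

    object DriverResultsSketch {
      // Stand-in for a Java API that fills the list it is given.
      def fetchResults(buffer: java.util.List[String]): Unit = {
        buffer.add("row-1")
        buffer.add("row-2")
      }

      def main(args: Array[String]): Unit = {
        val res = new JArrayList[String]()
        fetchResults(res)

        // .asScala wraps the same underlying list; .toList takes an immutable snapshot.
        val rows: Seq[String] = res.asScala.toList
        println(rows) // List(row-1, row-2)
      }
    }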
- val varcharKeys = table.getPartitionKeys + val varcharKeys = table.getPartitionKeys.asScala .filter(col => col.getType.startsWith(serdeConstants.VARCHAR_TYPE_NAME)) .map(col => col.getName).toSet @@ -354,7 +354,7 @@ private[client] class Shim_v0_13 extends Shim_v0_12 { getPartitionsByFilterMethod.invoke(hive, table, filter).asInstanceOf[JArrayList[Partition]] } - partitions.toSeq + partitions.asScala.toSeq } override def getCommandProcessor(token: String, conf: HiveConf): CommandProcessor = @@ -363,7 +363,7 @@ private[client] class Shim_v0_13 extends Shim_v0_12 { override def getDriverResults(driver: Driver): Seq[String] = { val res = new JArrayList[Object]() getDriverResultsMethod.invoke(driver, res) - res.map { r => + res.asScala.map { r => r match { case s: String => s case a: Array[Object] => a(0).asInstanceOf[String] diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/DescribeHiveTableCommand.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/DescribeHiveTableCommand.scala index 5f0ed5393d19..441b6b6033e1 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/DescribeHiveTableCommand.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/DescribeHiveTableCommand.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.hive.execution -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import org.apache.hadoop.hive.metastore.api.FieldSchema @@ -39,8 +39,8 @@ case class DescribeHiveTableCommand( // Trying to mimic the format of Hive's output. But not exactly the same. var results: Seq[(String, String, String)] = Nil - val columns: Seq[FieldSchema] = table.hiveQlTable.getCols - val partitionColumns: Seq[FieldSchema] = table.hiveQlTable.getPartCols + val columns: Seq[FieldSchema] = table.hiveQlTable.getCols.asScala + val partitionColumns: Seq[FieldSchema] = table.hiveQlTable.getPartCols.asScala results ++= columns.map(field => (field.getName, field.getType, field.getComment)) if (partitionColumns.nonEmpty) { val partColumnInfo = @@ -48,7 +48,7 @@ case class DescribeHiveTableCommand( results ++= partColumnInfo ++ Seq(("# Partition Information", "", "")) ++ - Seq((s"# ${output.get(0).name}", output.get(1).name, output.get(2).name)) ++ + Seq((s"# ${output(0).name}", output(1).name, output(2).name)) ++ partColumnInfo } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala index ba7eb15a1c0c..806d2b9b0b7d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.hive.execution -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.ql.metadata.{Partition => HivePartition} @@ -98,7 +98,7 @@ case class HiveTableScan( .asInstanceOf[StructObjectInspector] val columnTypeNames = structOI - .getAllStructFieldRefs + .getAllStructFieldRefs.asScala .map(_.getFieldObjectInspector) .map(TypeInfoUtils.getTypeInfoFromObjectInspector(_).getTypeName) .mkString(",") @@ -118,9 +118,8 @@ case class HiveTableScan( case None => partitions case Some(shouldKeep) => partitions.filter { part => val dataTypes = relation.partitionKeys.map(_.dataType) - val castedValues = for ((value, dataType) <- part.getValues.zip(dataTypes)) yield { 
- castFromString(value, dataType) - } + val castedValues = part.getValues.asScala.zip(dataTypes) + .map { case (value, dataType) => castFromString(value, dataType) } // Only partitioned values are needed here, since the predicate has already been bound to // partition key attribute references. diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index 62efda613a17..58f7fa640e8a 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -19,9 +19,10 @@ package org.apache.spark.sql.hive.execution import java.util +import scala.collection.JavaConverters._ + import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.conf.HiveConf.ConfVars -import org.apache.hadoop.hive.metastore.MetaStoreUtils import org.apache.hadoop.hive.ql.plan.TableDesc import org.apache.hadoop.hive.ql.{Context, ErrorMsg} import org.apache.hadoop.hive.serde2.Serializer @@ -38,8 +39,6 @@ import org.apache.spark.sql.hive.HiveShim.{ShimFileSinkDesc => FileSinkDesc} import org.apache.spark.sql.hive._ import org.apache.spark.sql.types.DataType import org.apache.spark.{SparkException, TaskContext} - -import scala.collection.JavaConversions._ import org.apache.spark.util.SerializableJobConf private[hive] @@ -94,7 +93,8 @@ case class InsertIntoHiveTable( ObjectInspectorCopyOption.JAVA) .asInstanceOf[StructObjectInspector] - val fieldOIs = standardOI.getAllStructFieldRefs.map(_.getFieldObjectInspector).toArray + val fieldOIs = standardOI.getAllStructFieldRefs.asScala + .map(_.getFieldObjectInspector).toArray val dataTypes: Array[DataType] = child.output.map(_.dataType).toArray val wrappers = fieldOIs.zip(dataTypes).map { case (f, dt) => wrapperFor(f, dt)} val outputData = new Array[Any](fieldOIs.length) @@ -198,7 +198,7 @@ case class InsertIntoHiveTable( // loadPartition call orders directories created on the iteration order of the this map val orderedPartitionSpec = new util.LinkedHashMap[String, String]() - table.hiveQlTable.getPartCols().foreach { entry => + table.hiveQlTable.getPartCols.asScala.foreach { entry => orderedPartitionSpec.put(entry.getName, partitionSpec.get(entry.getName).getOrElse("")) } @@ -226,7 +226,7 @@ case class InsertIntoHiveTable( val oldPart = catalog.client.getPartitionOption( catalog.client.getTable(table.databaseName, table.tableName), - partitionSpec) + partitionSpec.asJava) if (oldPart.isEmpty || !ifNotExists) { catalog.client.loadPartition( diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala index ade27454b9d2..c7651daffe36 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala @@ -21,7 +21,7 @@ import java.io._ import java.util.Properties import javax.annotation.Nullable -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.util.control.NonFatal import org.apache.hadoop.hive.serde.serdeConstants @@ -61,7 +61,7 @@ case class ScriptTransformation( protected override def doExecute(): RDD[InternalRow] = { def processIterator(inputIterator: Iterator[InternalRow]): Iterator[InternalRow] = { val cmd = 
List("/bin/bash", "-c", script) - val builder = new ProcessBuilder(cmd) + val builder = new ProcessBuilder(cmd.asJava) val proc = builder.start() val inputStream = proc.getInputStream @@ -172,10 +172,10 @@ case class ScriptTransformation( val fieldList = outputSoi.getAllStructFieldRefs() var i = 0 while (i < dataList.size()) { - if (dataList(i) == null) { + if (dataList.get(i) == null) { mutableRow.setNullAt(i) } else { - mutableRow(i) = unwrap(dataList(i), fieldList(i).getFieldObjectInspector) + mutableRow(i) = unwrap(dataList.get(i), fieldList.get(i).getFieldObjectInspector) } i += 1 } @@ -307,7 +307,7 @@ case class HiveScriptIOSchema ( val serde = initSerDe(serdeClass, columns, columnTypes, inputSerdeProps) val fieldObjectInspectors = columnTypes.map(toInspector) val objectInspector = ObjectInspectorFactory - .getStandardStructObjectInspector(columns, fieldObjectInspectors) + .getStandardStructObjectInspector(columns.asJava, fieldObjectInspectors.asJava) .asInstanceOf[ObjectInspector] (serde, objectInspector) } @@ -342,7 +342,7 @@ case class HiveScriptIOSchema ( propsMap = propsMap + (serdeConstants.LIST_COLUMN_TYPES -> columnTypesNames) val properties = new Properties() - properties.putAll(propsMap) + properties.putAll(propsMap.asJava) serde.initialize(null, properties) serde diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala index 7182246e466a..cad02373e5ba 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.hive import scala.collection.mutable.ArrayBuffer -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.util.Try import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ConstantObjectInspector} @@ -81,8 +81,7 @@ private[hive] class HiveFunctionRegistry(underlying: analysis.FunctionRegistry) /* List all of the registered function names. */ override def listFunction(): Seq[String] = { - val a = FunctionRegistry.getFunctionNames ++ underlying.listFunction() - a.toList.sorted + (FunctionRegistry.getFunctionNames.asScala ++ underlying.listFunction()).toList.sorted } /* Get the class of the registered function by specified name. 
*/ @@ -116,7 +115,7 @@ private[hive] case class HiveSimpleUDF(funcWrapper: HiveFunctionWrapper, childre @transient private lazy val method = - function.getResolver.getEvalMethod(children.map(_.dataType.toTypeInfo)) + function.getResolver.getEvalMethod(children.map(_.dataType.toTypeInfo).asJava) @transient private lazy val arguments = children.map(toInspector).toArray @@ -541,7 +540,7 @@ private[hive] case class HiveGenericUDTF( @transient protected lazy val collector = new UDTFCollector - lazy val elementTypes = outputInspector.getAllStructFieldRefs.map { + lazy val elementTypes = outputInspector.getAllStructFieldRefs.asScala.map { field => (inspectorToDataType(field.getFieldObjectInspector), true) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala index 9f4f8b5789af..1cff5cf9c354 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.hive.orc import java.util.Properties +import scala.collection.JavaConverters._ + import com.google.common.base.Objects import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, Path} @@ -43,9 +45,6 @@ import org.apache.spark.sql.types.StructType import org.apache.spark.sql.{Row, SQLContext} import org.apache.spark.util.SerializableConfiguration -/* Implicit conversions */ -import scala.collection.JavaConversions._ - private[sql] class DefaultSource extends HadoopFsRelationProvider with DataSourceRegister { override def shortName(): String = "orc" @@ -97,7 +96,8 @@ private[orc] class OrcOutputWriter( private val reusableOutputBuffer = new Array[Any](dataSchema.length) // Used to convert Catalyst values into Hadoop `Writable`s. - private val wrappers = structOI.getAllStructFieldRefs.zip(dataSchema.fields.map(_.dataType)) + private val wrappers = structOI.getAllStructFieldRefs.asScala + .zip(dataSchema.fields.map(_.dataType)) .map { case (ref, dt) => wrapperFor(ref.getFieldObjectInspector, dt) }.toArray diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index 4da86636ac10..572eaebe81ac 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.hive.test import java.io.File import java.util.{Set => JavaSet} +import scala.collection.JavaConverters._ import scala.collection.mutable import scala.language.implicitConversions @@ -37,9 +38,6 @@ import org.apache.spark.sql.hive.execution.HiveNativeCommand import org.apache.spark.util.{ShutdownHookManager, Utils} import org.apache.spark.{SparkConf, SparkContext} -/* Implicit conversions */ -import scala.collection.JavaConversions._ - // SPARK-3729: Test key required to check for initialization errors with config. object TestHive extends TestHiveContext( @@ -405,7 +403,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { def reset() { try { // HACK: Hive is too noisy by default. 
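Going the other way, the hunks above (getEvalMethod, getStandardStructObjectInspector, ProcessBuilder, Properties.putAll) pass Scala collections to Java APIs by converting them explicitly with .asJava, while the test hunks further below build small Java lists directly with Arrays.asList and Collections.singletonList. A sketch of those write-side options against a stand-in Java API (the method name is illustrative):

    import java.util.{Arrays, Collections}

    import scala.collection.JavaConverters._

    object ToJavaSketch {
      // Stand-in for a Java API that expects a java.util.List argument.
      def describe(names: java.util.List[String]): String = names.toString

      def main(args: Array[String]): Unit = {
        println(describe(Seq("id", "value").asJava))             // convert an existing Scala Seq
        println(describe(Arrays.asList("id", "value")))          // build a small fixed Java list directly
        println(describe(Collections.singletonList("varchar")))  // exactly one element
      }
    }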
- org.apache.log4j.LogManager.getCurrentLoggers.foreach { log => + org.apache.log4j.LogManager.getCurrentLoggers.asScala.foreach { log => log.asInstanceOf[org.apache.log4j.Logger].setLevel(org.apache.log4j.Level.WARN) } @@ -415,9 +413,8 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { catalog.client.reset() catalog.unregisterAllTables() - FunctionRegistry.getFunctionNames.filterNot(originalUDFs.contains(_)).foreach { udfName => - FunctionRegistry.unregisterTemporaryUDF(udfName) - } + FunctionRegistry.getFunctionNames.asScala.filterNot(originalUDFs.contains(_)). + foreach { udfName => FunctionRegistry.unregisterTemporaryUDF(udfName) } // Some tests corrupt this value on purpose, which breaks the RESET call below. hiveconf.set("fs.default.name", new File(".").toURI.toString) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala index 0efcf80bd4ea..5e7b93d45710 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.hive.client -import scala.collection.JavaConversions._ +import java.util.Collections import org.apache.hadoop.hive.metastore.api.FieldSchema import org.apache.hadoop.hive.serde.serdeConstants @@ -38,7 +38,7 @@ class FiltersSuite extends SparkFunSuite with Logging { private val varCharCol = new FieldSchema() varCharCol.setName("varchar") varCharCol.setType(serdeConstants.VARCHAR_TYPE_NAME) - testTable.setPartCols(varCharCol :: Nil) + testTable.setPartCols(Collections.singletonList(varCharCol)) filterTest("string filter", (a("stringcol", StringType) > Literal("test")) :: Nil, diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala index b03a35132325..9c10ffe1113d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala @@ -18,8 +18,7 @@ package org.apache.spark.sql.hive.execution import java.io.{DataInput, DataOutput} -import java.util -import java.util.Properties +import java.util.{ArrayList, Arrays, Properties} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.hive.ql.udf.generic.{GenericUDAFAverage, GenericUDF} @@ -33,8 +32,6 @@ import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.util.Utils -import scala.collection.JavaConversions._ - case class Fields(f1: Int, f2: Int, f3: Int, f4: Int, f5: Int) // Case classes for the custom UDF's. 
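The TestHive hunk above applies the same decorator to a java.util.Enumeration: log4j's getCurrentLoggers returns an Enumeration, and .asScala exposes it as a single-pass Scala Iterator. A sketch with a locally built Enumeration standing in for the log4j call:

    import scala.collection.JavaConverters._

    object EnumerationSketch {
      def main(args: Array[String]): Unit = {
        // Stand-in for an API such as LogManager.getCurrentLoggers.
        val loggers: java.util.Enumeration[String] =
          java.util.Collections.enumeration(java.util.Arrays.asList("root", "org.apache.spark"))

        // .asScala turns the Enumeration into a Scala Iterator (consumed once).
        loggers.asScala.foreach(name => println(s"would adjust level of $name"))
      }
    }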
@@ -326,11 +323,11 @@ class PairSerDe extends AbstractSerDe { override def getObjectInspector: ObjectInspector = { ObjectInspectorFactory .getStandardStructObjectInspector( - Seq("pair"), - Seq(ObjectInspectorFactory.getStandardStructObjectInspector( - Seq("id", "value"), - Seq(PrimitiveObjectInspectorFactory.javaIntObjectInspector, - PrimitiveObjectInspectorFactory.javaIntObjectInspector)) + Arrays.asList("pair"), + Arrays.asList(ObjectInspectorFactory.getStandardStructObjectInspector( + Arrays.asList("id", "value"), + Arrays.asList(PrimitiveObjectInspectorFactory.javaIntObjectInspector, + PrimitiveObjectInspectorFactory.javaIntObjectInspector)) )) } @@ -343,10 +340,10 @@ class PairSerDe extends AbstractSerDe { override def deserialize(value: Writable): AnyRef = { val pair = value.asInstanceOf[TestPair] - val row = new util.ArrayList[util.ArrayList[AnyRef]] - row.add(new util.ArrayList[AnyRef](2)) - row(0).add(Integer.valueOf(pair.entry._1)) - row(0).add(Integer.valueOf(pair.entry._2)) + val row = new ArrayList[ArrayList[AnyRef]] + row.add(new ArrayList[AnyRef](2)) + row.get(0).add(Integer.valueOf(pair.entry._1)) + row.get(0).add(Integer.valueOf(pair.entry._2)) row } @@ -355,9 +352,9 @@ class PairSerDe extends AbstractSerDe { class PairUDF extends GenericUDF { override def initialize(p1: Array[ObjectInspector]): ObjectInspector = ObjectInspectorFactory.getStandardStructObjectInspector( - Seq("id", "value"), - Seq(PrimitiveObjectInspectorFactory.javaIntObjectInspector, - PrimitiveObjectInspectorFactory.javaIntObjectInspector) + Arrays.asList("id", "value"), + Arrays.asList(PrimitiveObjectInspectorFactory.javaIntObjectInspector, + PrimitiveObjectInspectorFactory.javaIntObjectInspector) ) override def evaluate(args: Array[DeferredObject]): AnyRef = { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala index 3bf8f3ac2048..210d56674541 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala @@ -17,13 +17,12 @@ package org.apache.spark.sql.hive.execution +import scala.collection.JavaConverters._ + import org.scalatest.BeforeAndAfter import org.apache.spark.sql.hive.test.TestHive -/* Implicit conversions */ -import scala.collection.JavaConversions._ - /** * A set of test cases that validate partition and column pruning. 
*/ @@ -161,7 +160,7 @@ class PruningSuite extends HiveComparisonTest with BeforeAndAfter { assert(actualOutputColumns === expectedOutputColumns, "Output columns mismatch") assert(actualScannedColumns === expectedScannedColumns, "Scanned columns mismatch") - val actualPartitions = actualPartValues.map(_.toSeq.mkString(",")).sorted + val actualPartitions = actualPartValues.map(_.asScala.mkString(",")).sorted val expectedPartitions = expectedPartValues.map(_.mkString(",")).sorted assert(actualPartitions === expectedPartitions, "Partitions selected do not match") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 55ecbd5b5f21..1ff1d9a2934c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.hive.execution import java.sql.{Date, Timestamp} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.DefaultParserDialect @@ -164,7 +164,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils { test("show functions") { val allFunctions = (FunctionRegistry.builtin.listFunction().toSet[String] ++ - org.apache.hadoop.hive.ql.exec.FunctionRegistry.getFunctionNames).toList.sorted + org.apache.hadoop.hive.ql.exec.FunctionRegistry.getFunctionNames.asScala).toList.sorted checkAnswer(sql("SHOW functions"), allFunctions.map(Row(_))) checkAnswer(sql("SHOW functions abs"), Row("abs")) checkAnswer(sql("SHOW functions 'abs'"), Row("abs")) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala index 5bbca14bad32..7966b43596e7 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala @@ -17,9 +17,7 @@ package org.apache.spark.sql.sources -import java.sql.Date - -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path @@ -552,7 +550,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils { } finally { // Hadoop 1 doesn't have `Configuration.unset` configuration.clear() - clonedConf.foreach(entry => configuration.set(entry.getKey, entry.getValue)) + clonedConf.asScala.foreach(entry => configuration.set(entry.getKey, entry.getValue)) } } @@ -600,7 +598,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils { } finally { // Hadoop 1 doesn't have `Configuration.unset` configuration.clear() - clonedConf.foreach(entry => configuration.set(entry.getKey, entry.getValue)) + clonedConf.asScala.foreach(entry => configuration.set(entry.getKey, entry.getValue)) sqlContext.sparkContext.conf.set("spark.speculation", speculationEnabled.toString) } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala index 214cd80108b9..edfa474677f1 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala @@ 
-17,11 +17,10 @@ package org.apache.spark.streaming.api.java -import java.util import java.lang.{Long => JLong} import java.util.{List => JList} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.language.implicitConversions import scala.reflect.ClassTag @@ -145,8 +144,7 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T * an array. */ def glom(): JavaDStream[JList[T]] = - new JavaDStream(dstream.glom().map(x => new java.util.ArrayList[T](x.toSeq))) - + new JavaDStream(dstream.glom().map(_.toSeq.asJava)) /** Return the [[org.apache.spark.streaming.StreamingContext]] associated with this DStream */ @@ -191,7 +189,7 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T */ def mapPartitions[U](f: FlatMapFunction[java.util.Iterator[T], U]): JavaDStream[U] = { def fn: (Iterator[T]) => Iterator[U] = { - (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator()) + (x: Iterator[T]) => f.call(x.asJava).iterator().asScala } new JavaDStream(dstream.mapPartitions(fn)(fakeClassTag[U]))(fakeClassTag[U]) } @@ -204,7 +202,7 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T def mapPartitionsToPair[K2, V2](f: PairFlatMapFunction[java.util.Iterator[T], K2, V2]) : JavaPairDStream[K2, V2] = { def fn: (Iterator[T]) => Iterator[(K2, V2)] = { - (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator()) + (x: Iterator[T]) => f.call(x.asJava).iterator().asScala } new JavaPairDStream(dstream.mapPartitions(fn))(fakeClassTag[K2], fakeClassTag[V2]) } @@ -282,7 +280,7 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T * Return all the RDDs between 'fromDuration' to 'toDuration' (both included) */ def slice(fromTime: Time, toTime: Time): JList[R] = { - new util.ArrayList(dstream.slice(fromTime, toTime).map(wrapRDD(_)).toSeq) + dstream.slice(fromTime, toTime).map(wrapRDD).asJava } /** diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala index 26383e420101..e2aec6c2f63e 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala @@ -20,7 +20,7 @@ package org.apache.spark.streaming.api.java import java.lang.{Long => JLong, Iterable => JIterable} import java.util.{List => JList} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.language.implicitConversions import scala.reflect.ClassTag @@ -116,14 +116,14 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( * generate the RDDs with Spark's default number of partitions. */ def groupByKey(): JavaPairDStream[K, JIterable[V]] = - dstream.groupByKey().mapValues(asJavaIterable _) + dstream.groupByKey().mapValues(_.asJava) /** * Return a new DStream by applying `groupByKey` to each RDD. Hash partitioning is used to * generate the RDDs with `numPartitions` partitions. */ def groupByKey(numPartitions: Int): JavaPairDStream[K, JIterable[V]] = - dstream.groupByKey(numPartitions).mapValues(asJavaIterable _) + dstream.groupByKey(numPartitions).mapValues(_.asJava) /** * Return a new DStream by applying `groupByKey` on each RDD of `this` DStream. @@ -132,7 +132,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( * is used to control the partitioning of each RDD. 
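The mapPartitions hunks above bridge Scala and Java iterators in both directions: the Scala iterator is handed to the Java function with .asJava, and the returned java.lang.Iterable is walked back with .iterator().asScala. A small sketch of that bridge, with a dummy Java-style callback in place of the FlatMapFunction (names and types are simplified assumptions):

    import scala.collection.JavaConverters._

    object IteratorBridgeSketch {
      // Stand-in for a Java-side callback that consumes a java.util.Iterator
      // and produces a java.lang.Iterable.
      def javaCallback(it: java.util.Iterator[Int]): java.lang.Iterable[String] = {
        val out = new java.util.ArrayList[String]()
        while (it.hasNext) out.add(s"value-${it.next()}")
        out
      }

      def main(args: Array[String]): Unit = {
        val scalaInput: Iterator[Int] = Iterator(1, 2, 3)

        // Scala -> Java on the way in, Java -> Scala on the way out.
        val result: Iterator[String] = javaCallback(scalaInput.asJava).iterator().asScala
        result.foreach(println)
      }
    }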
*/ def groupByKey(partitioner: Partitioner): JavaPairDStream[K, JIterable[V]] = - dstream.groupByKey(partitioner).mapValues(asJavaIterable _) + dstream.groupByKey(partitioner).mapValues(_.asJava) /** * Return a new DStream by applying `reduceByKey` to each RDD. The values for each key are @@ -197,7 +197,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( * batching interval */ def groupByKeyAndWindow(windowDuration: Duration): JavaPairDStream[K, JIterable[V]] = { - dstream.groupByKeyAndWindow(windowDuration).mapValues(asJavaIterable _) + dstream.groupByKeyAndWindow(windowDuration).mapValues(_.asJava) } /** @@ -212,7 +212,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( */ def groupByKeyAndWindow(windowDuration: Duration, slideDuration: Duration) : JavaPairDStream[K, JIterable[V]] = { - dstream.groupByKeyAndWindow(windowDuration, slideDuration).mapValues(asJavaIterable _) + dstream.groupByKeyAndWindow(windowDuration, slideDuration).mapValues(_.asJava) } /** @@ -228,8 +228,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( */ def groupByKeyAndWindow(windowDuration: Duration, slideDuration: Duration, numPartitions: Int) : JavaPairDStream[K, JIterable[V]] = { - dstream.groupByKeyAndWindow(windowDuration, slideDuration, numPartitions) - .mapValues(asJavaIterable _) + dstream.groupByKeyAndWindow(windowDuration, slideDuration, numPartitions).mapValues(_.asJava) } /** @@ -248,8 +247,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( slideDuration: Duration, partitioner: Partitioner ): JavaPairDStream[K, JIterable[V]] = { - dstream.groupByKeyAndWindow(windowDuration, slideDuration, partitioner) - .mapValues(asJavaIterable _) + dstream.groupByKeyAndWindow(windowDuration, slideDuration, partitioner).mapValues(_.asJava) } /** @@ -431,7 +429,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( private def convertUpdateStateFunction[S](in: JFunction2[JList[V], Optional[S], Optional[S]]): (Seq[V], Option[S]) => Option[S] = { val scalaFunc: (Seq[V], Option[S]) => Option[S] = (values, state) => { - val list: JList[V] = values + val list: JList[V] = values.asJava val scalaState: Optional[S] = JavaUtils.optionToOptional(state) val result: Optional[S] = in.apply(list, scalaState) result.isPresent match { @@ -539,7 +537,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( */ def cogroup[W](other: JavaPairDStream[K, W]): JavaPairDStream[K, (JIterable[V], JIterable[W])] = { implicit val cm: ClassTag[W] = fakeClassTag - dstream.cogroup(other.dstream).mapValues(t => (asJavaIterable(t._1), asJavaIterable((t._2)))) + dstream.cogroup(other.dstream).mapValues(t => (t._1.asJava, t._2.asJava)) } /** @@ -551,8 +549,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( numPartitions: Int ): JavaPairDStream[K, (JIterable[V], JIterable[W])] = { implicit val cm: ClassTag[W] = fakeClassTag - dstream.cogroup(other.dstream, numPartitions) - .mapValues(t => (asJavaIterable(t._1), asJavaIterable((t._2)))) + dstream.cogroup(other.dstream, numPartitions).mapValues(t => (t._1.asJava, t._2.asJava)) } /** @@ -564,8 +561,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( partitioner: Partitioner ): JavaPairDStream[K, (JIterable[V], JIterable[W])] = { implicit val cm: ClassTag[W] = fakeClassTag - dstream.cogroup(other.dstream, partitioner) - .mapValues(t => (asJavaIterable(t._1), asJavaIterable((t._2)))) + dstream.cogroup(other.dstream, partitioner).mapValues(t => (t._1.asJava, t._2.asJava)) } /** diff --git 
a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala index 35cc3ce5cf46..13f371f29603 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala @@ -21,7 +21,7 @@ import java.lang.{Boolean => JBoolean} import java.io.{Closeable, InputStream} import java.util.{List => JList, Map => JMap} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.reflect.ClassTag import akka.actor.{Props, SupervisorStrategy} @@ -115,7 +115,13 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { sparkHome: String, jars: Array[String], environment: JMap[String, String]) = - this(new StreamingContext(master, appName, batchDuration, sparkHome, jars, environment)) + this(new StreamingContext( + master, + appName, + batchDuration, + sparkHome, + jars, + environment.asScala)) /** * Create a JavaStreamingContext using an existing JavaSparkContext. @@ -197,7 +203,7 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { converter: JFunction[InputStream, java.lang.Iterable[T]], storageLevel: StorageLevel) : JavaReceiverInputDStream[T] = { - def fn: (InputStream) => Iterator[T] = (x: InputStream) => converter.call(x).toIterator + def fn: (InputStream) => Iterator[T] = (x: InputStream) => converter.call(x).iterator().asScala implicit val cmt: ClassTag[T] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[T]] ssc.socketStream(hostname, port, fn, storageLevel) @@ -432,7 +438,7 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { implicit val cm: ClassTag[T] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[T]] val sQueue = new scala.collection.mutable.Queue[RDD[T]] - sQueue.enqueue(queue.map(_.rdd).toSeq: _*) + sQueue.enqueue(queue.asScala.map(_.rdd).toSeq: _*) ssc.queueStream(sQueue) } @@ -456,7 +462,7 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { implicit val cm: ClassTag[T] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[T]] val sQueue = new scala.collection.mutable.Queue[RDD[T]] - sQueue.enqueue(queue.map(_.rdd).toSeq: _*) + sQueue.enqueue(queue.asScala.map(_.rdd).toSeq: _*) ssc.queueStream(sQueue, oneAtATime) } @@ -481,7 +487,7 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { implicit val cm: ClassTag[T] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[T]] val sQueue = new scala.collection.mutable.Queue[RDD[T]] - sQueue.enqueue(queue.map(_.rdd).toSeq: _*) + sQueue.enqueue(queue.asScala.map(_.rdd).toSeq: _*) ssc.queueStream(sQueue, oneAtATime, defaultRDD.rdd) } @@ -500,7 +506,7 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { * Create a unified DStream from multiple DStreams of the same type and same slide duration. 
*/ def union[T](first: JavaDStream[T], rest: JList[JavaDStream[T]]): JavaDStream[T] = { - val dstreams: Seq[DStream[T]] = (Seq(first) ++ asScalaBuffer(rest)).map(_.dstream) + val dstreams: Seq[DStream[T]] = (Seq(first) ++ rest.asScala).map(_.dstream) implicit val cm: ClassTag[T] = first.classTag ssc.union(dstreams)(cm) } @@ -512,7 +518,7 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { first: JavaPairDStream[K, V], rest: JList[JavaPairDStream[K, V]] ): JavaPairDStream[K, V] = { - val dstreams: Seq[DStream[(K, V)]] = (Seq(first) ++ asScalaBuffer(rest)).map(_.dstream) + val dstreams: Seq[DStream[(K, V)]] = (Seq(first) ++ rest.asScala).map(_.dstream) implicit val cm: ClassTag[(K, V)] = first.classTag implicit val kcm: ClassTag[K] = first.kManifest implicit val vcm: ClassTag[V] = first.vManifest @@ -534,12 +540,11 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { ): JavaDStream[T] = { implicit val cmt: ClassTag[T] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[T]] - val scalaDStreams = dstreams.map(_.dstream).toSeq val scalaTransformFunc = (rdds: Seq[RDD[_]], time: Time) => { - val jrdds = rdds.map(rdd => JavaRDD.fromRDD[AnyRef](rdd.asInstanceOf[RDD[AnyRef]])).toList + val jrdds = rdds.map(JavaRDD.fromRDD(_)).asJava transformFunc.call(jrdds, time).rdd } - ssc.transform(scalaDStreams, scalaTransformFunc) + ssc.transform(dstreams.asScala.map(_.dstream).toSeq, scalaTransformFunc) } /** @@ -559,12 +564,11 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[K]] implicit val cmv: ClassTag[V] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[V]] - val scalaDStreams = dstreams.map(_.dstream).toSeq val scalaTransformFunc = (rdds: Seq[RDD[_]], time: Time) => { - val jrdds = rdds.map(rdd => JavaRDD.fromRDD[AnyRef](rdd.asInstanceOf[RDD[AnyRef]])).toList + val jrdds = rdds.map(JavaRDD.fromRDD(_)).asJava transformFunc.call(jrdds, time).rdd } - ssc.transform(scalaDStreams, scalaTransformFunc) + ssc.transform(dstreams.asScala.map(_.dstream).toSeq, scalaTransformFunc) } /** diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala index d06401245ff1..2c373640d2fd 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala @@ -20,14 +20,13 @@ package org.apache.spark.streaming.api.python import java.io.{ObjectInputStream, ObjectOutputStream} import java.lang.reflect.Proxy import java.util.{ArrayList => JArrayList, List => JList} -import scala.collection.JavaConversions._ + import scala.collection.JavaConverters._ import scala.language.existentials import py4j.GatewayServer import org.apache.spark.api.java._ -import org.apache.spark.api.python._ import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.{Interval, Duration, Time} @@ -161,7 +160,7 @@ private[python] object PythonDStream { */ def toRDDQueue(rdds: JArrayList[JavaRDD[Array[Byte]]]): java.util.Queue[JavaRDD[Array[Byte]]] = { val queue = new java.util.LinkedList[JavaRDD[Array[Byte]]] - rdds.forall(queue.add(_)) + rdds.asScala.foreach(queue.add) queue } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/Receiver.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/Receiver.scala 
index 554aae0117b2..2252e28f22af 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/Receiver.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/Receiver.scala @@ -20,7 +20,7 @@ package org.apache.spark.streaming.receiver import java.nio.ByteBuffer import scala.collection.mutable.ArrayBuffer -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import org.apache.spark.storage.StorageLevel import org.apache.spark.annotation.DeveloperApi @@ -144,12 +144,12 @@ abstract class Receiver[T](val storageLevel: StorageLevel) extends Serializable * for being used in the corresponding InputDStream. */ def store(dataIterator: java.util.Iterator[T], metadata: Any) { - supervisor.pushIterator(dataIterator, Some(metadata), None) + supervisor.pushIterator(dataIterator.asScala, Some(metadata), None) } /** Store an iterator of received data as a data block into Spark's memory. */ def store(dataIterator: java.util.Iterator[T]) { - supervisor.pushIterator(dataIterator, None, None) + supervisor.pushIterator(dataIterator.asScala, None, None) } /** diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala index 6d4cdc4aa6b1..0cd39594ee92 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala @@ -19,7 +19,7 @@ package org.apache.spark.streaming.scheduler import java.util.concurrent.{ConcurrentHashMap, TimeUnit} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.util.{Failure, Success} import org.apache.spark.Logging @@ -128,7 +128,7 @@ class JobScheduler(val ssc: StreamingContext) extends Logging { } def getPendingTimes(): Seq[Time] = { - jobSets.keySet.toSeq + jobSets.asScala.keys.toSeq } def reportError(msg: String, e: Throwable) { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala index 53b96d51c918..f2711d1355e6 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala @@ -19,6 +19,7 @@ package org.apache.spark.streaming.scheduler import java.nio.ByteBuffer +import scala.collection.JavaConverters._ import scala.collection.mutable import scala.language.implicitConversions @@ -196,8 +197,7 @@ private[streaming] class ReceivedBlockTracker( writeAheadLogOption.foreach { writeAheadLog => logInfo(s"Recovering from write ahead logs in ${checkpointDirOption.get}") - import scala.collection.JavaConversions._ - writeAheadLog.readAll().foreach { byteBuffer => + writeAheadLog.readAll().asScala.foreach { byteBuffer => logTrace("Recovering record " + byteBuffer) Utils.deserialize[ReceivedBlockTrackerLogEvent]( byteBuffer.array, Thread.currentThread().getContextClassLoader) match { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala index fe6328b1ce72..9f4a4d6806ab 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala @@ 
-19,6 +19,7 @@ package org.apache.spark.streaming.util import java.nio.ByteBuffer import java.util.{Iterator => JIterator} +import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer import scala.concurrent.{Await, ExecutionContext, Future} import scala.language.postfixOps @@ -118,7 +119,6 @@ private[streaming] class FileBasedWriteAheadLog( * hence the implementation is kept simple. */ def readAll(): JIterator[ByteBuffer] = synchronized { - import scala.collection.JavaConversions._ val logFilesToRead = pastLogs.map{ _.path} ++ currentLogPath logInfo("Reading from the logs: " + logFilesToRead.mkString("\n")) @@ -126,7 +126,7 @@ private[streaming] class FileBasedWriteAheadLog( logDebug(s"Creating log reader with $file") val reader = new FileBasedWriteAheadLogReader(file, hadoopConf) CompletionIterator[ByteBuffer, Iterator[ByteBuffer]](reader, reader.close _) - } flatMap { x => x } + }.flatten.asJava } /** diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaTestUtils.scala b/streaming/src/test/java/org/apache/spark/streaming/JavaTestUtils.scala index bb80bff6dc2e..57b50bdfd652 100644 --- a/streaming/src/test/java/org/apache/spark/streaming/JavaTestUtils.scala +++ b/streaming/src/test/java/org/apache/spark/streaming/JavaTestUtils.scala @@ -17,16 +17,13 @@ package org.apache.spark.streaming -import scala.collection.mutable.{SynchronizedBuffer, ArrayBuffer} +import java.util.{List => JList} + +import scala.collection.JavaConverters._ import scala.reflect.ClassTag -import java.util.{List => JList} -import org.apache.spark.streaming.api.java.{JavaPairDStream, JavaDStreamLike, JavaDStream, JavaStreamingContext} -import org.apache.spark.streaming._ -import java.util.ArrayList -import collection.JavaConversions._ import org.apache.spark.api.java.JavaRDDLike -import org.apache.spark.streaming.dstream.DStream +import org.apache.spark.streaming.api.java.{JavaDStreamLike, JavaDStream, JavaStreamingContext} /** Exposes streaming test functionality in a Java-friendly way. 
*/ trait JavaTestBase extends TestSuiteBase { @@ -39,7 +36,7 @@ trait JavaTestBase extends TestSuiteBase { ssc: JavaStreamingContext, data: JList[JList[T]], numPartitions: Int) = { - val seqData = data.map(Seq(_:_*)) + val seqData = data.asScala.map(_.asScala) implicit val cm: ClassTag[T] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[T]] @@ -72,9 +69,7 @@ trait JavaTestBase extends TestSuiteBase { implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[V]] ssc.getState() val res = runStreams[V](ssc.ssc, numBatches, numExpectedOutput) - val out = new ArrayList[JList[V]]() - res.map(entry => out.append(new ArrayList[V](entry))) - out + res.map(_.asJava).asJava } /** @@ -90,12 +85,7 @@ trait JavaTestBase extends TestSuiteBase { implicit val cm: ClassTag[V] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[V]] val res = runStreamsWithPartitions[V](ssc.ssc, numBatches, numExpectedOutput) - val out = new ArrayList[JList[JList[V]]]() - res.map{entry => - val lists = entry.map(new ArrayList[V](_)) - out.append(new ArrayList[JList[V]](lists)) - } - out + res.map(entry => entry.map(_.asJava).asJava).asJava } } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala index 325ff7c74c39..5e49fd00769a 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala @@ -20,6 +20,7 @@ import java.io._ import java.nio.ByteBuffer import java.util +import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer import scala.concurrent.duration._ import scala.language.{implicitConversions, postfixOps} @@ -417,9 +418,8 @@ object WriteAheadLogSuite { /** Read all the data in the log file in a directory using the WriteAheadLog class. 
*/ def readDataUsingWriteAheadLog(logDirectory: String): Seq[String] = { - import scala.collection.JavaConversions._ val wal = new FileBasedWriteAheadLog(new SparkConf(), logDirectory, hadoopConf, 1, 1) - val data = wal.readAll().map(byteBufferToString).toSeq + val data = wal.readAll().asScala.map(byteBufferToString).toSeq wal.close() data } diff --git a/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala b/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala index 9418beb6b3e3..a0524cabff2d 100644 --- a/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala +++ b/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala @@ -22,7 +22,7 @@ import java.io.File import java.util.jar.JarFile import scala.collection.mutable -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.reflect.runtime.universe.runtimeMirror import scala.reflect.runtime.{universe => unv} import scala.util.Try @@ -161,7 +161,7 @@ object GenerateMIMAIgnore { val path = packageName.replace('.', '/') val resources = classLoader.getResources(path) - val jars = resources.filter(x => x.getProtocol == "jar") + val jars = resources.asScala.filter(_.getProtocol == "jar") .map(_.getFile.split(":")(1).split("!")(0)).toSeq jars.flatMap(getClassesFromJar(_, path)) @@ -175,7 +175,7 @@ object GenerateMIMAIgnore { private def getClassesFromJar(jarPath: String, packageName: String) = { import scala.collection.mutable val jar = new JarFile(new File(jarPath)) - val enums = jar.entries().map(_.getName).filter(_.startsWith(packageName)) + val enums = jar.entries().asScala.map(_.getName).filter(_.startsWith(packageName)) val classes = mutable.HashSet[Class[_]]() for (entry <- enums if entry.endsWith(".class")) { try { diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index bff585b46cbb..e9a02baafd28 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -25,7 +25,7 @@ import java.security.PrivilegedExceptionAction import java.util.{Properties, UUID} import java.util.zip.{ZipEntry, ZipOutputStream} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, ListBuffer, Map} import scala.reflect.runtime.universe import scala.util.{Try, Success, Failure} @@ -511,7 +511,7 @@ private[spark] class Client( val nns = YarnSparkHadoopUtil.get.getNameNodesToAccess(sparkConf) + stagingDirPath YarnSparkHadoopUtil.get.obtainTokensForNamenodes( nns, hadoopConf, creds, Some(sparkConf.get("spark.yarn.principal"))) - val t = creds.getAllTokens + val t = creds.getAllTokens.asScala .filter(_.getKind == DelegationTokenIdentifier.HDFS_DELEGATION_KIND) .head val newExpiration = t.renew(hadoopConf) @@ -650,8 +650,8 @@ private[spark] class Client( distCacheMgr.setDistArchivesEnv(launchEnv) val amContainer = Records.newRecord(classOf[ContainerLaunchContext]) - amContainer.setLocalResources(localResources) - amContainer.setEnvironment(launchEnv) + amContainer.setLocalResources(localResources.asJava) + amContainer.setEnvironment(launchEnv.asJava) val javaOpts = ListBuffer[String]() @@ -782,7 +782,7 @@ private[spark] class Client( // TODO: it would be nicer to just make sure there are no null commands here val printableCommands = commands.map(s => if (s == null) "null" else s).toList - 
amContainer.setCommands(printableCommands) + amContainer.setCommands(printableCommands.asJava) logDebug("===============================================================================") logDebug("YARN AM launch context:") @@ -797,7 +797,8 @@ private[spark] class Client( // send the acl settings into YARN to control who has access via YARN interfaces val securityManager = new SecurityManager(sparkConf) - amContainer.setApplicationACLs(YarnSparkHadoopUtil.getApplicationAclsForYarn(securityManager)) + amContainer.setApplicationACLs( + YarnSparkHadoopUtil.getApplicationAclsForYarn(securityManager).asJava) setupSecurityToken(amContainer) UserGroupInformation.getCurrentUser().addCredentials(credentials) diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala index 4cc50483a17f..9abd09b3cc7a 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala @@ -20,14 +20,13 @@ package org.apache.spark.deploy.yarn import java.io.File import java.net.URI import java.nio.ByteBuffer +import java.util.Collections -import org.apache.hadoop.fs.Path -import org.apache.hadoop.yarn.api.ApplicationConstants.Environment -import org.apache.spark.util.Utils - -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.mutable.{HashMap, ListBuffer} +import org.apache.hadoop.fs.Path +import org.apache.hadoop.yarn.api.ApplicationConstants.Environment import org.apache.hadoop.conf.Configuration import org.apache.hadoop.io.DataOutputBuffer import org.apache.hadoop.security.UserGroupInformation @@ -40,6 +39,7 @@ import org.apache.hadoop.yarn.util.{ConverterUtils, Records} import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException} import org.apache.spark.network.util.JavaUtils +import org.apache.spark.util.Utils class ExecutorRunnable( container: Container, @@ -74,9 +74,9 @@ class ExecutorRunnable( .asInstanceOf[ContainerLaunchContext] val localResources = prepareLocalResources - ctx.setLocalResources(localResources) + ctx.setLocalResources(localResources.asJava) - ctx.setEnvironment(env) + ctx.setEnvironment(env.asJava) val credentials = UserGroupInformation.getCurrentUser().getCredentials() val dob = new DataOutputBuffer() @@ -96,8 +96,9 @@ class ExecutorRunnable( |=============================================================================== """.stripMargin) - ctx.setCommands(commands) - ctx.setApplicationACLs(YarnSparkHadoopUtil.getApplicationAclsForYarn(securityMgr)) + ctx.setCommands(commands.asJava) + ctx.setApplicationACLs( + YarnSparkHadoopUtil.getApplicationAclsForYarn(securityMgr).asJava) // If external shuffle service is enabled, register with the Yarn shuffle service already // started on the NodeManager and, if authentication is enabled, provide it with our secret @@ -112,7 +113,7 @@ class ExecutorRunnable( // Authentication is not enabled, so just provide dummy metadata ByteBuffer.allocate(0) } - ctx.setServiceData(Map[String, ByteBuffer]("spark_shuffle" -> secretBytes)) + ctx.setServiceData(Collections.singletonMap("spark_shuffle", secretBytes)) } // Send the start request to the ContainerManager @@ -314,7 +315,8 @@ class ExecutorRunnable( env("SPARK_LOG_URL_STDOUT") = s"$baseUrl/stdout?start=-4096" } - System.getenv().filterKeys(_.startsWith("SPARK")).foreach { case (k, v) => env(k) = v } + 
System.getenv().asScala.filterKeys(_.startsWith("SPARK")) + .foreach { case (k, v) => env(k) = v } env } } diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala index ccf753e69f4b..5f897cbcb4e9 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala @@ -21,9 +21,7 @@ import java.util.Collections import java.util.concurrent._ import java.util.regex.Pattern -import org.apache.spark.util.Utils - -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} import com.google.common.util.concurrent.ThreadFactoryBuilder @@ -39,8 +37,8 @@ import org.apache.log4j.{Level, Logger} import org.apache.spark.{Logging, SecurityManager, SparkConf} import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil._ import org.apache.spark.rpc.RpcEndpointRef -import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ +import org.apache.spark.util.Utils /** * YarnAllocator is charged with requesting containers from the YARN ResourceManager and deciding @@ -164,7 +162,7 @@ private[yarn] class YarnAllocator( * Number of container requests at the given location that have not yet been fulfilled. */ private def getNumPendingAtLocation(location: String): Int = - amClient.getMatchingRequests(RM_REQUEST_PRIORITY, location, resource).map(_.size).sum + amClient.getMatchingRequests(RM_REQUEST_PRIORITY, location, resource).asScala.map(_.size).sum /** * Request as many executors from the ResourceManager as needed to reach the desired total. If @@ -231,14 +229,14 @@ private[yarn] class YarnAllocator( numExecutorsRunning, allocateResponse.getAvailableResources)) - handleAllocatedContainers(allocatedContainers) + handleAllocatedContainers(allocatedContainers.asScala) } val completedContainers = allocateResponse.getCompletedContainersStatuses() if (completedContainers.size > 0) { logDebug("Completed %d containers".format(completedContainers.size)) - processCompletedContainers(completedContainers) + processCompletedContainers(completedContainers.asScala) logDebug("Finished processing %d completed containers. Current running executor count: %d." 
.format(completedContainers.size, numExecutorsRunning)) @@ -271,7 +269,7 @@ private[yarn] class YarnAllocator( val request = createContainerRequest(resource, locality.nodes, locality.racks) amClient.addContainerRequest(request) val nodes = request.getNodes - val hostStr = if (nodes == null || nodes.isEmpty) "Any" else nodes.last + val hostStr = if (nodes == null || nodes.isEmpty) "Any" else nodes.asScala.last logInfo(s"Container request (host: $hostStr, capability: $resource)") } } else if (missing < 0) { @@ -280,7 +278,8 @@ private[yarn] class YarnAllocator( val matchingRequests = amClient.getMatchingRequests(RM_REQUEST_PRIORITY, ANY_HOST, resource) if (!matchingRequests.isEmpty) { - matchingRequests.head.take(numToCancel).foreach(amClient.removeContainerRequest) + matchingRequests.iterator().next().asScala + .take(numToCancel).foreach(amClient.removeContainerRequest) } else { logWarning("Expected to find pending requests, but found none.") } @@ -459,7 +458,7 @@ private[yarn] class YarnAllocator( } } - if (allocatedContainerToHostMap.containsKey(containerId)) { + if (allocatedContainerToHostMap.contains(containerId)) { val host = allocatedContainerToHostMap.get(containerId).get val containerSet = allocatedHostToContainersMap.get(host).get diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala index 4999f9c06210..df042bf291de 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala @@ -19,17 +19,15 @@ package org.apache.spark.deploy.yarn import java.util.{List => JList} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.{Map, Set} import scala.util.Try import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.yarn.api.ApplicationConstants import org.apache.hadoop.yarn.api.records._ import org.apache.hadoop.yarn.client.api.AMRMClient import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest import org.apache.hadoop.yarn.conf.YarnConfiguration -import org.apache.hadoop.yarn.util.ConverterUtils import org.apache.hadoop.yarn.webapp.util.WebAppUtils import org.apache.spark.{Logging, SecurityManager, SparkConf} @@ -108,8 +106,8 @@ private[spark] class YarnRMClient(args: ApplicationMasterArguments) extends Logg val method = classOf[WebAppUtils].getMethod("getProxyHostsAndPortsForAmFilter", classOf[Configuration]) val proxies = method.invoke(null, conf).asInstanceOf[JList[String]] - val hosts = proxies.map { proxy => proxy.split(":")(0) } - val uriBases = proxies.map { proxy => prefix + proxy + proxyBase } + val hosts = proxies.asScala.map { proxy => proxy.split(":")(0) } + val uriBases = proxies.asScala.map { proxy => prefix + proxy + proxyBase } Map("PROXY_HOSTS" -> hosts.mkString(","), "PROXY_URI_BASES" -> uriBases.mkString(",")) } catch { case e: NoSuchMethodException => diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala index 128e996b71fe..b4f8049bff57 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala @@ -21,7 +21,7 @@ import java.io.{File, FileOutputStream, OutputStreamWriter} import java.util.Properties import java.util.concurrent.TimeUnit -import scala.collection.JavaConversions._ 
+import scala.collection.JavaConverters._ import com.google.common.base.Charsets.UTF_8 import com.google.common.io.Files @@ -132,7 +132,7 @@ abstract class BaseYarnClusterSuite props.setProperty("spark.driver.extraJavaOptions", "-Dfoo=\"one two three\"") props.setProperty("spark.executor.extraJavaOptions", "-Dfoo=\"one two three\"") - yarnCluster.getConfig().foreach { e => + yarnCluster.getConfig.asScala.foreach { e => props.setProperty("spark.hadoop." + e.getKey(), e.getValue()) } @@ -149,7 +149,7 @@ abstract class BaseYarnClusterSuite props.store(writer, "Spark properties.") writer.close() - val extraJarArgs = if (!extraJars.isEmpty()) Seq("--jars", extraJars.mkString(",")) else Nil + val extraJarArgs = if (extraJars.nonEmpty) Seq("--jars", extraJars.mkString(",")) else Nil val mainArgs = if (klass.endsWith(".py")) { Seq(klass) diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala index 0a5402c89e76..e7f2501e7899 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala @@ -20,8 +20,8 @@ package org.apache.spark.deploy.yarn import java.io.File import java.net.URI -import scala.collection.JavaConversions._ -import scala.collection.mutable.{ HashMap => MutableHashMap } +import scala.collection.JavaConverters._ +import scala.collection.mutable.{HashMap => MutableHashMap} import scala.reflect.ClassTag import scala.util.Try @@ -38,7 +38,7 @@ import org.mockito.Matchers._ import org.mockito.Mockito._ import org.scalatest.{BeforeAndAfterAll, Matchers} -import org.apache.spark.{SparkConf, SparkException, SparkFunSuite} +import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.util.Utils class ClientSuite extends SparkFunSuite with Matchers with BeforeAndAfterAll { @@ -201,7 +201,7 @@ class ClientSuite extends SparkFunSuite with Matchers with BeforeAndAfterAll { appContext.getClass.getMethods.filter(_.getName.equals("getApplicationTags")).foreach{ method => val tags = method.invoke(appContext).asInstanceOf[java.util.Set[String]] tags should contain allOf ("tag1", "dup", "tag2", "multi word") - tags.filter(!_.isEmpty).size should be (4) + tags.asScala.filter(_.nonEmpty).size should be (4) } appContext.getMaxAppAttempts should be (42) } diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala index 128350b64899..5a4ea2ea2f4f 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala @@ -21,7 +21,6 @@ import java.io.File import java.net.URL import scala.collection.mutable -import scala.collection.JavaConversions._ import com.google.common.base.Charsets.UTF_8 import com.google.common.io.{ByteStreams, Files} @@ -216,8 +215,8 @@ private object YarnClusterDriver extends Logging with Matchers { assert(listener.driverLogs.nonEmpty) val driverLogs = listener.driverLogs.get assert(driverLogs.size === 2) - assert(driverLogs.containsKey("stderr")) - assert(driverLogs.containsKey("stdout")) + assert(driverLogs.contains("stderr")) + assert(driverLogs.contains("stdout")) val urlStr = driverLogs("stderr") // Ensure that this is a valid URL, else this will throw an exception new URL(urlStr) From 5c08c86bfa43462fb2ca5f7c5980ddfb44dd57f8 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Tue, 25 Aug 2015 
10:22:54 -0700 Subject: [PATCH 0410/1824] [SPARK-10198] [SQL] Turn off partition verification by default Author: Michael Armbrust Closes #8404 from marmbrus/turnOffPartitionVerification. --- .../scala/org/apache/spark/sql/SQLConf.scala | 2 +- .../spark/sql/hive/QueryPartitionSuite.scala | 64 ++++++++++--------- 2 files changed, 35 insertions(+), 31 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index e6f7619519e6..9de75f4c4d08 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -312,7 +312,7 @@ private[spark] object SQLConf { doc = "When true, enable filter pushdown for ORC files.") val HIVE_VERIFY_PARTITION_PATH = booleanConf("spark.sql.hive.verifyPartitionPath", - defaultValue = Some(true), + defaultValue = Some(false), doc = "") val HIVE_METASTORE_PARTITION_PRUNING = booleanConf("spark.sql.hive.metastorePartitionPruning", diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala index 017bc2adc103..1cc8a93e8341 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala @@ -18,50 +18,54 @@ package org.apache.spark.sql.hive import com.google.common.io.Files +import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.{QueryTest, _} import org.apache.spark.util.Utils -class QueryPartitionSuite extends QueryTest { +class QueryPartitionSuite extends QueryTest with SQLTestUtils { private lazy val ctx = org.apache.spark.sql.hive.test.TestHive import ctx.implicits._ - import ctx.sql + + protected def _sqlContext = ctx test("SPARK-5068: query data when path doesn't exist"){ - val testData = ctx.sparkContext.parallelize( - (1 to 10).map(i => TestData(i, i.toString))).toDF() - testData.registerTempTable("testData") + withSQLConf((SQLConf.HIVE_VERIFY_PARTITION_PATH.key, "true")) { + val testData = ctx.sparkContext.parallelize( + (1 to 10).map(i => TestData(i, i.toString))).toDF() + testData.registerTempTable("testData") - val tmpDir = Files.createTempDir() - // create the table for test - sql(s"CREATE TABLE table_with_partition(key int,value string) " + - s"PARTITIONED by (ds string) location '${tmpDir.toURI.toString}' ") - sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='1') " + - "SELECT key,value FROM testData") - sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='2') " + - "SELECT key,value FROM testData") - sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='3') " + - "SELECT key,value FROM testData") - sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='4') " + - "SELECT key,value FROM testData") + val tmpDir = Files.createTempDir() + // create the table for test + sql(s"CREATE TABLE table_with_partition(key int,value string) " + + s"PARTITIONED by (ds string) location '${tmpDir.toURI.toString}' ") + sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='1') " + + "SELECT key,value FROM testData") + sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='2') " + + "SELECT key,value FROM testData") + sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='3') " + + "SELECT key,value FROM testData") + sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='4') " + + "SELECT key,value FROM 
testData") - // test for the exist path - checkAnswer(sql("select key,value from table_with_partition"), - testData.toDF.collect ++ testData.toDF.collect - ++ testData.toDF.collect ++ testData.toDF.collect) + // test for the exist path + checkAnswer(sql("select key,value from table_with_partition"), + testData.toDF.collect ++ testData.toDF.collect + ++ testData.toDF.collect ++ testData.toDF.collect) - // delete the path of one partition - tmpDir.listFiles - .find { f => f.isDirectory && f.getName().startsWith("ds=") } - .foreach { f => Utils.deleteRecursively(f) } + // delete the path of one partition + tmpDir.listFiles + .find { f => f.isDirectory && f.getName().startsWith("ds=") } + .foreach { f => Utils.deleteRecursively(f) } - // test for after delete the path - checkAnswer(sql("select key,value from table_with_partition"), - testData.toDF.collect ++ testData.toDF.collect ++ testData.toDF.collect) + // test for after delete the path + checkAnswer(sql("select key,value from table_with_partition"), + testData.toDF.collect ++ testData.toDF.collect ++ testData.toDF.collect) - sql("DROP TABLE table_with_partition") - sql("DROP TABLE createAndInsertTest") + sql("DROP TABLE table_with_partition") + sql("DROP TABLE createAndInsertTest") + } } } From b37f0cc1b4c064d6f09edb161250fa8b783de52a Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Tue, 25 Aug 2015 10:54:03 -0700 Subject: [PATCH 0411/1824] [SPARK-8531] [ML] Update ML user guide for MinMaxScaler jira: https://issues.apache.org/jira/browse/SPARK-8531 Update ML user guide for MinMaxScaler Author: Yuhao Yang Author: unknown Closes #7211 from hhbyyh/minmaxdoc. --- docs/ml-features.md | 71 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 71 insertions(+) diff --git a/docs/ml-features.md b/docs/ml-features.md index 642a4b4c5318..62de4838981c 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -1133,6 +1133,7 @@ val scaledData = scalerModel.transform(dataFrame) {% highlight java %} import org.apache.spark.api.java.JavaRDD; import org.apache.spark.ml.feature.StandardScaler; +import org.apache.spark.ml.feature.StandardScalerModel; import org.apache.spark.mllib.regression.LabeledPoint; import org.apache.spark.mllib.util.MLUtils; import org.apache.spark.sql.DataFrame; @@ -1173,6 +1174,76 @@ scaledData = scalerModel.transform(dataFrame) +## MinMaxScaler + +`MinMaxScaler` transforms a dataset of `Vector` rows, rescaling each feature to a specific range (often [0, 1]). It takes parameters: + +* `min`: 0.0 by default. Lower bound after transformation, shared by all features. +* `max`: 1.0 by default. Upper bound after transformation, shared by all features. + +`MinMaxScaler` computes summary statistics on a data set and produces a `MinMaxScalerModel`. The model can then transform each feature individually such that it is in the given range. + +The rescaled value for a feature E is calculated as, +`\begin{equation} + Rescaled(e_i) = \frac{e_i - E_{min}}{E_{max} - E_{min}} * (max - min) + min +\end{equation}` +For the case `E_{max} == E_{min}`, `Rescaled(e_i) = 0.5 * (max + min)` + +Note that since zero values will probably be transformed to non-zero values, output of the transformer will be DenseVector even for sparse input. + +The following example demonstrates how to load a dataset in libsvm format and then rescale each feature to [0, 1]. + +
    +<div class="codetabs">
    +<div data-lang="scala" markdown="1">
    +More details can be found in the API docs for +[MinMaxScaler](api/scala/index.html#org.apache.spark.ml.feature.MinMaxScaler) and +[MinMaxScalerModel](api/scala/index.html#org.apache.spark.ml.feature.MinMaxScalerModel). +{% highlight scala %} +import org.apache.spark.ml.feature.MinMaxScaler +import org.apache.spark.mllib.util.MLUtils + +val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") +val dataFrame = sqlContext.createDataFrame(data) +val scaler = new MinMaxScaler() + .setInputCol("features") + .setOutputCol("scaledFeatures") + +// Compute summary statistics and generate MinMaxScalerModel +val scalerModel = scaler.fit(dataFrame) + +// rescale each feature to range [min, max]. +val scaledData = scalerModel.transform(dataFrame) +{% endhighlight %} +
    +</div>
    +<div data-lang="java" markdown="1">
    +More details can be found in the API docs for +[MinMaxScaler](api/java/org/apache/spark/ml/feature/MinMaxScaler.html) and +[MinMaxScalerModel](api/java/org/apache/spark/ml/feature/MinMaxScalerModel.html). +{% highlight java %} +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.MinMaxScaler; +import org.apache.spark.ml.feature.MinMaxScalerModel; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.util.MLUtils; +import org.apache.spark.sql.DataFrame; + +JavaRDD data = + MLUtils.loadLibSVMFile(jsc.sc(), "data/mllib/sample_libsvm_data.txt").toJavaRDD(); +DataFrame dataFrame = jsql.createDataFrame(data, LabeledPoint.class); +MinMaxScaler scaler = new MinMaxScaler() + .setInputCol("features") + .setOutputCol("scaledFeatures"); + +// Compute summary statistics and generate MinMaxScalerModel +MinMaxScalerModel scalerModel = scaler.fit(dataFrame); + +// rescale each feature to range [min, max]. +DataFrame scaledData = scalerModel.transform(dataFrame); +{% endhighlight %} +
    +</div>
    +</div>
    + ## Bucketizer `Bucketizer` transforms a column of continuous features to a column of feature buckets, where the buckets are specified by users. It takes a parameter: From 881208a8e849facf54166bdd69d3634407f952e7 Mon Sep 17 00:00:00 2001 From: Feynman Liang Date: Tue, 25 Aug 2015 11:58:47 -0700 Subject: [PATCH 0412/1824] [SPARK-10230] [MLLIB] Rename optimizeAlpha to optimizeDocConcentration See [discussion](https://github.com/apache/spark/pull/8254#discussion_r37837770) CC jkbradley Author: Feynman Liang Closes #8422 from feynmanliang/SPARK-10230. --- .../spark/mllib/clustering/LDAOptimizer.scala | 16 ++++++++-------- .../apache/spark/mllib/clustering/LDASuite.scala | 2 +- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala index 5c2aae6403be..38486e949bbc 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala @@ -258,7 +258,7 @@ final class OnlineLDAOptimizer extends LDAOptimizer { private var tau0: Double = 1024 private var kappa: Double = 0.51 private var miniBatchFraction: Double = 0.05 - private var optimizeAlpha: Boolean = false + private var optimizeDocConcentration: Boolean = false // internal data structure private var docs: RDD[(Long, Vector)] = null @@ -335,20 +335,20 @@ final class OnlineLDAOptimizer extends LDAOptimizer { } /** - * Optimize alpha, indicates whether alpha (Dirichlet parameter for document-topic distribution) - * will be optimized during training. + * Optimize docConcentration, indicates whether docConcentration (Dirichlet parameter for + * document-topic distribution) will be optimized during training. */ @Since("1.5.0") - def getOptimzeAlpha: Boolean = this.optimizeAlpha + def getOptimizeDocConcentration: Boolean = this.optimizeDocConcentration /** - * Sets whether to optimize alpha parameter during training. + * Sets whether to optimize docConcentration parameter during training. 
* * Default: false */ @Since("1.5.0") - def setOptimzeAlpha(optimizeAlpha: Boolean): this.type = { - this.optimizeAlpha = optimizeAlpha + def setOptimizeDocConcentration(optimizeDocConcentration: Boolean): this.type = { + this.optimizeDocConcentration = optimizeDocConcentration this } @@ -458,7 +458,7 @@ final class OnlineLDAOptimizer extends LDAOptimizer { // Note that this is an optimization to avoid batch.count updateLambda(batchResult, (miniBatchFraction * corpusSize).ceil.toInt) - if (optimizeAlpha) updateAlpha(gammat) + if (optimizeDocConcentration) updateAlpha(gammat) this } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala index 8a714f9b79e0..746a76a7e5fa 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala @@ -423,7 +423,7 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext { val k = 2 val docs = sc.parallelize(toyData) val op = new OnlineLDAOptimizer().setMiniBatchFraction(1).setTau0(1024).setKappa(0.51) - .setGammaShape(100).setOptimzeAlpha(true).setSampleWithReplacement(false) + .setGammaShape(100).setOptimizeDocConcentration(true).setSampleWithReplacement(false) val lda = new LDA().setK(k) .setDocConcentration(1D / k) .setTopicConcentration(0.01) From 16a2be1a84c0a274a60c0a584faaf58b55d4942b Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Tue, 25 Aug 2015 12:16:23 -0700 Subject: [PATCH 0413/1824] [SPARK-10231] [MLLIB] update @Since annotation for mllib.classification Update `Since` annotation in `mllib.classification`: 1. add version to classes, objects, constructors, and public variables declared in constructors 2. correct some versions 3. remove `Since` on `toString` MechCoder dbtsai Author: Xiangrui Meng Closes #8421 from mengxr/SPARK-10231 and squashes the following commits: b2dce80 [Xiangrui Meng] update @Since annotation for mllib.classification --- .../classification/ClassificationModel.scala | 7 +++-- .../classification/LogisticRegression.scala | 20 +++++++++---- .../mllib/classification/NaiveBayes.scala | 28 +++++++++++++++---- .../spark/mllib/classification/SVM.scala | 15 ++++++---- .../StreamingLogisticRegressionWithSGD.scala | 9 +++++- 5 files changed, 58 insertions(+), 21 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala index a29b425a71fd..85a413243b04 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala @@ -30,6 +30,7 @@ import org.apache.spark.rdd.RDD * belongs. The categories are represented by double values: 0.0, 1.0, 2.0, etc. */ @Experimental +@Since("0.8.0") trait ClassificationModel extends Serializable { /** * Predict values for the given data set using the model trained. 
@@ -37,7 +38,7 @@ trait ClassificationModel extends Serializable { * @param testData RDD representing data points to be predicted * @return an RDD[Double] where each entry contains the corresponding prediction */ - @Since("0.8.0") + @Since("1.0.0") def predict(testData: RDD[Vector]): RDD[Double] /** @@ -46,7 +47,7 @@ trait ClassificationModel extends Serializable { * @param testData array representing a single data point * @return predicted category from the trained model */ - @Since("0.8.0") + @Since("1.0.0") def predict(testData: Vector): Double /** @@ -54,7 +55,7 @@ trait ClassificationModel extends Serializable { * @param testData JavaRDD representing data points to be predicted * @return a JavaRDD[java.lang.Double] where each entry contains the corresponding prediction */ - @Since("0.8.0") + @Since("1.0.0") def predict(testData: JavaRDD[Vector]): JavaRDD[java.lang.Double] = predict(testData.rdd).toJavaRDD().asInstanceOf[JavaRDD[java.lang.Double]] } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala index e03e662227d1..5ceff5b2259e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala @@ -41,11 +41,12 @@ import org.apache.spark.rdd.RDD * Multinomial Logistic Regression. By default, it is binary logistic regression * so numClasses will be set to 2. */ -class LogisticRegressionModel ( - override val weights: Vector, - override val intercept: Double, - val numFeatures: Int, - val numClasses: Int) +@Since("0.8.0") +class LogisticRegressionModel @Since("1.3.0") ( + @Since("1.0.0") override val weights: Vector, + @Since("1.0.0") override val intercept: Double, + @Since("1.3.0") val numFeatures: Int, + @Since("1.3.0") val numClasses: Int) extends GeneralizedLinearModel(weights, intercept) with ClassificationModel with Serializable with Saveable with PMMLExportable { @@ -75,6 +76,7 @@ class LogisticRegressionModel ( /** * Constructs a [[LogisticRegressionModel]] with weights and intercept for binary classification. */ + @Since("1.0.0") def this(weights: Vector, intercept: Double) = this(weights, intercept, weights.size, 2) private var threshold: Option[Double] = Some(0.5) @@ -166,12 +168,12 @@ class LogisticRegressionModel ( override protected def formatVersion: String = "1.0" - @Since("1.4.0") override def toString: String = { s"${super.toString}, numClasses = ${numClasses}, threshold = ${threshold.getOrElse("None")}" } } +@Since("1.3.0") object LogisticRegressionModel extends Loader[LogisticRegressionModel] { @Since("1.3.0") @@ -207,6 +209,7 @@ object LogisticRegressionModel extends Loader[LogisticRegressionModel] { * for k classes multi-label classification problem. * Using [[LogisticRegressionWithLBFGS]] is recommended over this. 
*/ +@Since("0.8.0") class LogisticRegressionWithSGD private[mllib] ( private var stepSize: Double, private var numIterations: Int, @@ -216,6 +219,7 @@ class LogisticRegressionWithSGD private[mllib] ( private val gradient = new LogisticGradient() private val updater = new SquaredL2Updater() + @Since("0.8.0") override val optimizer = new GradientDescent(gradient, updater) .setStepSize(stepSize) .setNumIterations(numIterations) @@ -227,6 +231,7 @@ class LogisticRegressionWithSGD private[mllib] ( * Construct a LogisticRegression object with default parameters: {stepSize: 1.0, * numIterations: 100, regParm: 0.01, miniBatchFraction: 1.0}. */ + @Since("0.8.0") def this() = this(1.0, 100, 0.01, 1.0) override protected[mllib] def createModel(weights: Vector, intercept: Double) = { @@ -238,6 +243,7 @@ class LogisticRegressionWithSGD private[mllib] ( * Top-level methods for calling Logistic Regression using Stochastic Gradient Descent. * NOTE: Labels used in Logistic Regression should be {0, 1} */ +@Since("0.8.0") object LogisticRegressionWithSGD { // NOTE(shivaram): We use multiple train methods instead of default arguments to support // Java programs. @@ -333,11 +339,13 @@ object LogisticRegressionWithSGD { * NOTE: Labels used in Logistic Regression should be {0, 1, ..., k - 1} * for k classes multi-label classification problem. */ +@Since("1.1.0") class LogisticRegressionWithLBFGS extends GeneralizedLinearAlgorithm[LogisticRegressionModel] with Serializable { this.setFeatureScaling(true) + @Since("1.1.0") override val optimizer = new LBFGS(new LogisticGradient, new SquaredL2Updater) override protected val validators = List(multiLabelValidator) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala index dab369207cc9..a956084ae06e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala @@ -41,11 +41,12 @@ import org.apache.spark.sql.{DataFrame, SQLContext} * where D is number of features * @param modelType The type of NB model to fit can be "multinomial" or "bernoulli" */ +@Since("0.9.0") class NaiveBayesModel private[spark] ( - val labels: Array[Double], - val pi: Array[Double], - val theta: Array[Array[Double]], - val modelType: String) + @Since("1.0.0") val labels: Array[Double], + @Since("0.9.0") val pi: Array[Double], + @Since("0.9.0") val theta: Array[Array[Double]], + @Since("1.4.0") val modelType: String) extends ClassificationModel with Serializable with Saveable { import NaiveBayes.{Bernoulli, Multinomial, supportedModelTypes} @@ -83,6 +84,7 @@ class NaiveBayesModel private[spark] ( throw new UnknownError(s"Invalid modelType: $modelType.") } + @Since("1.0.0") override def predict(testData: RDD[Vector]): RDD[Double] = { val bcModel = testData.context.broadcast(this) testData.mapPartitions { iter => @@ -91,6 +93,7 @@ class NaiveBayesModel private[spark] ( } } + @Since("1.0.0") override def predict(testData: Vector): Double = { modelType match { case Multinomial => @@ -107,6 +110,7 @@ class NaiveBayesModel private[spark] ( * @return an RDD[Vector] where each entry contains the predicted posterior class probabilities, * in the same order as class labels */ + @Since("1.5.0") def predictProbabilities(testData: RDD[Vector]): RDD[Vector] = { val bcModel = testData.context.broadcast(this) testData.mapPartitions { iter => @@ -122,6 +126,7 @@ class NaiveBayesModel 
private[spark] ( * @return predicted posterior class probabilities from the trained model, * in the same order as class labels */ + @Since("1.5.0") def predictProbabilities(testData: Vector): Vector = { modelType match { case Multinomial => @@ -158,6 +163,7 @@ class NaiveBayesModel private[spark] ( new DenseVector(scaledProbs.map(_ / probSum)) } + @Since("1.3.0") override def save(sc: SparkContext, path: String): Unit = { val data = NaiveBayesModel.SaveLoadV2_0.Data(labels, pi, theta, modelType) NaiveBayesModel.SaveLoadV2_0.save(sc, path, data) @@ -166,6 +172,7 @@ class NaiveBayesModel private[spark] ( override protected def formatVersion: String = "2.0" } +@Since("1.3.0") object NaiveBayesModel extends Loader[NaiveBayesModel] { import org.apache.spark.mllib.util.Loader._ @@ -199,6 +206,7 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] { dataRDD.write.parquet(dataPath(path)) } + @Since("1.3.0") def load(sc: SparkContext, path: String): NaiveBayesModel = { val sqlContext = new SQLContext(sc) // Load Parquet data. @@ -301,30 +309,35 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] { * document classification. By making every vector a 0-1 vector, it can also be used as * Bernoulli NB ([[http://tinyurl.com/p7c96j6]]). The input feature values must be nonnegative. */ - +@Since("0.9.0") class NaiveBayes private ( private var lambda: Double, private var modelType: String) extends Serializable with Logging { import NaiveBayes.{Bernoulli, Multinomial} + @Since("1.4.0") def this(lambda: Double) = this(lambda, NaiveBayes.Multinomial) + @Since("0.9.0") def this() = this(1.0, NaiveBayes.Multinomial) /** Set the smoothing parameter. Default: 1.0. */ + @Since("0.9.0") def setLambda(lambda: Double): NaiveBayes = { this.lambda = lambda this } /** Get the smoothing parameter. */ + @Since("1.4.0") def getLambda: Double = lambda /** * Set the model type using a string (case-sensitive). * Supported options: "multinomial" (default) and "bernoulli". */ + @Since("1.4.0") def setModelType(modelType: String): NaiveBayes = { require(NaiveBayes.supportedModelTypes.contains(modelType), s"NaiveBayes was created with an unknown modelType: $modelType.") @@ -333,6 +346,7 @@ class NaiveBayes private ( } /** Get the model type. */ + @Since("1.4.0") def getModelType: String = this.modelType /** @@ -340,6 +354,7 @@ class NaiveBayes private ( * * @param data RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. */ + @Since("0.9.0") def run(data: RDD[LabeledPoint]): NaiveBayesModel = { val requireNonnegativeValues: Vector => Unit = (v: Vector) => { val values = v match { @@ -423,6 +438,7 @@ class NaiveBayes private ( /** * Top-level methods for calling naive Bayes. */ +@Since("0.9.0") object NaiveBayes { /** String name for multinomial model type. 
*/ @@ -485,7 +501,7 @@ object NaiveBayes { * @param modelType The type of NB model to fit from the enumeration NaiveBayesModels, can be * multinomial or bernoulli */ - @Since("0.9.0") + @Since("1.4.0") def train(input: RDD[LabeledPoint], lambda: Double, modelType: String): NaiveBayesModel = { require(supportedModelTypes.contains(modelType), s"NaiveBayes was created with an unknown modelType: $modelType.") diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala index 5f8726986357..896565cd90e8 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala @@ -33,9 +33,10 @@ import org.apache.spark.rdd.RDD * @param weights Weights computed for every feature. * @param intercept Intercept computed for this model. */ -class SVMModel ( - override val weights: Vector, - override val intercept: Double) +@Since("0.8.0") +class SVMModel @Since("1.1.0") ( + @Since("1.0.0") override val weights: Vector, + @Since("0.8.0") override val intercept: Double) extends GeneralizedLinearModel(weights, intercept) with ClassificationModel with Serializable with Saveable with PMMLExportable { @@ -47,7 +48,7 @@ class SVMModel ( * with prediction score greater than or equal to this threshold is identified as an positive, * and negative otherwise. The default value is 0.0. */ - @Since("1.3.0") + @Since("1.0.0") @Experimental def setThreshold(threshold: Double): this.type = { this.threshold = Some(threshold) @@ -92,12 +93,12 @@ class SVMModel ( override protected def formatVersion: String = "1.0" - @Since("1.4.0") override def toString: String = { s"${super.toString}, numClasses = 2, threshold = ${threshold.getOrElse("None")}" } } +@Since("1.3.0") object SVMModel extends Loader[SVMModel] { @Since("1.3.0") @@ -132,6 +133,7 @@ object SVMModel extends Loader[SVMModel] { * regularization is used, which can be changed via [[SVMWithSGD.optimizer]]. * NOTE: Labels used in SVM should be {0, 1}. */ +@Since("0.8.0") class SVMWithSGD private ( private var stepSize: Double, private var numIterations: Int, @@ -141,6 +143,7 @@ class SVMWithSGD private ( private val gradient = new HingeGradient() private val updater = new SquaredL2Updater() + @Since("0.8.0") override val optimizer = new GradientDescent(gradient, updater) .setStepSize(stepSize) .setNumIterations(numIterations) @@ -152,6 +155,7 @@ class SVMWithSGD private ( * Construct a SVM object with default parameters: {stepSize: 1.0, numIterations: 100, * regParm: 0.01, miniBatchFraction: 1.0}. */ + @Since("0.8.0") def this() = this(1.0, 100, 0.01, 1.0) override protected def createModel(weights: Vector, intercept: Double) = { @@ -162,6 +166,7 @@ class SVMWithSGD private ( /** * Top-level methods for calling SVM. NOTE: Labels used in SVM should be {0, 1}. 
*/ +@Since("0.8.0") object SVMWithSGD { /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionWithSGD.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionWithSGD.scala index 7d33df3221fb..75630054d136 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionWithSGD.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionWithSGD.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.classification -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.regression.StreamingLinearAlgorithm @@ -44,6 +44,7 @@ import org.apache.spark.mllib.regression.StreamingLinearAlgorithm * }}} */ @Experimental +@Since("1.3.0") class StreamingLogisticRegressionWithSGD private[mllib] ( private var stepSize: Double, private var numIterations: Int, @@ -58,6 +59,7 @@ class StreamingLogisticRegressionWithSGD private[mllib] ( * Initial weights must be set before using trainOn or predictOn * (see `StreamingLinearAlgorithm`) */ + @Since("1.3.0") def this() = this(0.1, 50, 1.0, 0.0) protected val algorithm = new LogisticRegressionWithSGD( @@ -66,30 +68,35 @@ class StreamingLogisticRegressionWithSGD private[mllib] ( protected var model: Option[LogisticRegressionModel] = None /** Set the step size for gradient descent. Default: 0.1. */ + @Since("1.3.0") def setStepSize(stepSize: Double): this.type = { this.algorithm.optimizer.setStepSize(stepSize) this } /** Set the number of iterations of gradient descent to run per update. Default: 50. */ + @Since("1.3.0") def setNumIterations(numIterations: Int): this.type = { this.algorithm.optimizer.setNumIterations(numIterations) this } /** Set the fraction of each batch to use for updates. Default: 1.0. */ + @Since("1.3.0") def setMiniBatchFraction(miniBatchFraction: Double): this.type = { this.algorithm.optimizer.setMiniBatchFraction(miniBatchFraction) this } /** Set the regularization parameter. Default: 0.0. */ + @Since("1.3.0") def setRegParam(regParam: Double): this.type = { this.algorithm.optimizer.setRegParam(regParam) this } /** Set the initial weights. Default: [0.0, 0.0]. */ + @Since("1.3.0") def setInitialWeights(initialWeights: Vector): this.type = { this.model = Some(algorithm.createModel(initialWeights, 0.0)) this From 71a138cd0e0a14e8426f97877e3b52a562bbd02c Mon Sep 17 00:00:00 2001 From: Sun Rui Date: Tue, 25 Aug 2015 13:14:10 -0700 Subject: [PATCH 0414/1824] [SPARK-10048] [SPARKR] Support arbitrary nested Java array in serde. This PR: 1. supports transferring arbitrary nested array from JVM to R side in SerDe; 2. based on 1, collect() implemenation is improved. Now it can support collecting data of complex types from a DataFrame. Author: Sun Rui Closes #8276 from sun-rui/SPARK-10048. 
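As background for the change below: the reworked SerDe writes every value sent from the JVM to R with a leading type tag, and serializes an array of objects as a "list" whose elements are written recursively, which is what makes arbitrarily nested arrays possible. The Scala sketch below illustrates that general idea only; the object and method names here (TaggedSerDeSketch, its simplified tag set) are hypothetical stand-ins and not the SerDe API shown in the diff that follows.

    import java.io.{ByteArrayOutputStream, DataOutputStream}

    // Minimal sketch of type-tagged, recursive serialization with a toy tag set
    // of my own; it mirrors the idea behind the SerDe change but is not the
    // actual Spark code. Arrays of objects are written as a tag, a length, and
    // then each element serialized recursively, which allows arbitrary nesting.
    object TaggedSerDeSketch {
      def writeObject(dos: DataOutputStream, value: Any): Unit = value match {
        case null       => dos.writeByte('n')
        case i: Int     => dos.writeByte('i'); dos.writeInt(i)
        case d: Double  => dos.writeByte('d'); dos.writeDouble(d)
        case b: Boolean => dos.writeByte('b'); dos.writeBoolean(b)
        case s: String  => dos.writeByte('c'); dos.writeInt(s.length); dos.writeBytes(s)
        case a: Array[_] =>
          dos.writeByte('l'); dos.writeInt(a.length)
          a.foreach(elem => writeObject(dos, elem))
        case other =>
          throw new IllegalArgumentException(s"Unsupported type: ${other.getClass.getName}")
      }

      def main(args: Array[String]): Unit = {
        val bos = new ByteArrayOutputStream()
        val dos = new DataOutputStream(bos)
        // A nested structure, similar in shape to a collected column of complex type.
        writeObject(dos, Array(Array(1, 2, 3), Array("a", "b"), Array(true, false)))
        dos.flush()
        println(s"Serialized ${bos.size()} bytes")
      }
    }
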
--- R/pkg/R/DataFrame.R | 55 +++++++++--- R/pkg/R/deserialize.R | 72 +++++++--------- R/pkg/R/serialize.R | 10 +-- R/pkg/inst/tests/test_Serde.R | 77 +++++++++++++++++ R/pkg/inst/worker/worker.R | 4 +- .../apache/spark/api/r/RBackendHandler.scala | 7 ++ .../scala/org/apache/spark/api/r/SerDe.scala | 86 +++++++++++-------- .../org/apache/spark/sql/api/r/SQLUtils.scala | 32 +------ 8 files changed, 216 insertions(+), 127 deletions(-) create mode 100644 R/pkg/inst/tests/test_Serde.R diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 10f3c4ea5986..ae1d912cf6da 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -652,18 +652,49 @@ setMethod("dim", setMethod("collect", signature(x = "DataFrame"), function(x, stringsAsFactors = FALSE) { - # listCols is a list of raw vectors, one per column - listCols <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "dfToCols", x@sdf) - cols <- lapply(listCols, function(col) { - objRaw <- rawConnection(col) - numRows <- readInt(objRaw) - col <- readCol(objRaw, numRows) - close(objRaw) - col - }) - names(cols) <- columns(x) - do.call(cbind.data.frame, list(cols, stringsAsFactors = stringsAsFactors)) - }) + names <- columns(x) + ncol <- length(names) + if (ncol <= 0) { + # empty data.frame with 0 columns and 0 rows + data.frame() + } else { + # listCols is a list of columns + listCols <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "dfToCols", x@sdf) + stopifnot(length(listCols) == ncol) + + # An empty data.frame with 0 columns and number of rows as collected + nrow <- length(listCols[[1]]) + if (nrow <= 0) { + df <- data.frame() + } else { + df <- data.frame(row.names = 1 : nrow) + } + + # Append columns one by one + for (colIndex in 1 : ncol) { + # Note: appending a column of list type into a data.frame so that + # data of complex type can be held. But getting a cell from a column + # of list type returns a list instead of a vector. So for columns of + # non-complex type, append them as vector. + col <- listCols[[colIndex]] + if (length(col) <= 0) { + df[[names[colIndex]]] <- col + } else { + # TODO: more robust check on column of primitive types + vec <- do.call(c, col) + if (class(vec) != "list") { + df[[names[colIndex]]] <- vec + } else { + # For columns of complex type, be careful to access them. + # Get a column of complex type returns a list. + # Get a cell from a column of complex type returns a list instead of a vector. + df[[names[colIndex]]] <- col + } + } + } + df + } + }) #' Limit #' diff --git a/R/pkg/R/deserialize.R b/R/pkg/R/deserialize.R index 33bf13ec9e78..6cf628e3007d 100644 --- a/R/pkg/R/deserialize.R +++ b/R/pkg/R/deserialize.R @@ -48,6 +48,7 @@ readTypedObject <- function(con, type) { "r" = readRaw(con), "D" = readDate(con), "t" = readTime(con), + "a" = readArray(con), "l" = readList(con), "n" = NULL, "j" = getJobj(readString(con)), @@ -85,8 +86,7 @@ readTime <- function(con) { as.POSIXct(t, origin = "1970-01-01") } -# We only support lists where all elements are of same type -readList <- function(con) { +readArray <- function(con) { type <- readType(con) len <- readInt(con) if (len > 0) { @@ -100,6 +100,25 @@ readList <- function(con) { } } +# Read a list. Types of each element may be different. +# Null objects are read as NA. 
+readList <- function(con) { + len <- readInt(con) + if (len > 0) { + l <- vector("list", len) + for (i in 1:len) { + elem <- readObject(con) + if (is.null(elem)) { + elem <- NA + } + l[[i]] <- elem + } + l + } else { + list() + } +} + readRaw <- function(con) { dataLen <- readInt(con) readBin(con, raw(), as.integer(dataLen), endian = "big") @@ -132,18 +151,19 @@ readDeserialize <- function(con) { } } -readDeserializeRows <- function(inputCon) { - # readDeserializeRows will deserialize a DataOutputStream composed of - # a list of lists. Since the DOS is one continuous stream and - # the number of rows varies, we put the readRow function in a while loop - # that termintates when the next row is empty. +readMultipleObjects <- function(inputCon) { + # readMultipleObjects will read multiple continuous objects from + # a DataOutputStream. There is no preceding field telling the count + # of the objects, so the number of objects varies, we try to read + # all objects in a loop until the end of the stream. data <- list() while(TRUE) { - row <- readRow(inputCon) - if (length(row) == 0) { + # If reaching the end of the stream, type returned should be "". + type <- readType(inputCon) + if (type == "") { break } - data[[length(data) + 1L]] <- row + data[[length(data) + 1L]] <- readTypedObject(inputCon, type) } data # this is a list of named lists now } @@ -155,35 +175,5 @@ readRowList <- function(obj) { # deserialize the row. rawObj <- rawConnection(obj, "r+") on.exit(close(rawObj)) - readRow(rawObj) -} - -readRow <- function(inputCon) { - numCols <- readInt(inputCon) - if (length(numCols) > 0 && numCols > 0) { - lapply(1:numCols, function(x) { - obj <- readObject(inputCon) - if (is.null(obj)) { - NA - } else { - obj - } - }) # each row is a list now - } else { - list() - } -} - -# Take a single column as Array[Byte] and deserialize it into an atomic vector -readCol <- function(inputCon, numRows) { - if (numRows > 0) { - # sapply can not work with POSIXlt - do.call(c, lapply(1:numRows, function(x) { - value <- readObject(inputCon) - # Replace NULL with NA so we can coerce to vectors - if (is.null(value)) NA else value - })) - } else { - vector() - } + readObject(rawObj) } diff --git a/R/pkg/R/serialize.R b/R/pkg/R/serialize.R index 311021e5d847..e3676f57f907 100644 --- a/R/pkg/R/serialize.R +++ b/R/pkg/R/serialize.R @@ -110,18 +110,10 @@ writeRowSerialize <- function(outputCon, rows) { serializeRow <- function(row) { rawObj <- rawConnection(raw(0), "wb") on.exit(close(rawObj)) - writeRow(rawObj, row) + writeGenericList(rawObj, row) rawConnectionValue(rawObj) } -writeRow <- function(con, row) { - numCols <- length(row) - writeInt(con, numCols) - for (i in 1:numCols) { - writeObject(con, row[[i]]) - } -} - writeRaw <- function(con, batch) { writeInt(con, length(batch)) writeBin(batch, con, endian = "big") diff --git a/R/pkg/inst/tests/test_Serde.R b/R/pkg/inst/tests/test_Serde.R new file mode 100644 index 000000000000..009db85da2be --- /dev/null +++ b/R/pkg/inst/tests/test_Serde.R @@ -0,0 +1,77 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +context("SerDe functionality") + +sc <- sparkR.init() + +test_that("SerDe of primitive types", { + x <- callJStatic("SparkRHandler", "echo", 1L) + expect_equal(x, 1L) + expect_equal(class(x), "integer") + + x <- callJStatic("SparkRHandler", "echo", 1) + expect_equal(x, 1) + expect_equal(class(x), "numeric") + + x <- callJStatic("SparkRHandler", "echo", TRUE) + expect_true(x) + expect_equal(class(x), "logical") + + x <- callJStatic("SparkRHandler", "echo", "abc") + expect_equal(x, "abc") + expect_equal(class(x), "character") +}) + +test_that("SerDe of list of primitive types", { + x <- list(1L, 2L, 3L) + y <- callJStatic("SparkRHandler", "echo", x) + expect_equal(x, y) + expect_equal(class(y[[1]]), "integer") + + x <- list(1, 2, 3) + y <- callJStatic("SparkRHandler", "echo", x) + expect_equal(x, y) + expect_equal(class(y[[1]]), "numeric") + + x <- list(TRUE, FALSE) + y <- callJStatic("SparkRHandler", "echo", x) + expect_equal(x, y) + expect_equal(class(y[[1]]), "logical") + + x <- list("a", "b", "c") + y <- callJStatic("SparkRHandler", "echo", x) + expect_equal(x, y) + expect_equal(class(y[[1]]), "character") + + # Empty list + x <- list() + y <- callJStatic("SparkRHandler", "echo", x) + expect_equal(x, y) +}) + +test_that("SerDe of list of lists", { + x <- list(list(1L, 2L, 3L), list(1, 2, 3), + list(TRUE, FALSE), list("a", "b", "c")) + y <- callJStatic("SparkRHandler", "echo", x) + expect_equal(x, y) + + # List of empty lists + x <- list(list(), list()) + y <- callJStatic("SparkRHandler", "echo", x) + expect_equal(x, y) +}) diff --git a/R/pkg/inst/worker/worker.R b/R/pkg/inst/worker/worker.R index 7e3b5fc403b2..0c3b0d1f4be2 100644 --- a/R/pkg/inst/worker/worker.R +++ b/R/pkg/inst/worker/worker.R @@ -94,7 +94,7 @@ if (isEmpty != 0) { } else if (deserializer == "string") { data <- as.list(readLines(inputCon)) } else if (deserializer == "row") { - data <- SparkR:::readDeserializeRows(inputCon) + data <- SparkR:::readMultipleObjects(inputCon) } # Timing reading input data for execution inputElap <- elapsedSecs() @@ -120,7 +120,7 @@ if (isEmpty != 0) { } else if (deserializer == "string") { data <- readLines(inputCon) } else if (deserializer == "row") { - data <- SparkR:::readDeserializeRows(inputCon) + data <- SparkR:::readMultipleObjects(inputCon) } # Timing reading input data for execution inputElap <- elapsedSecs() diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala index 6ce02e2ea336..bb82f3285f1d 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala @@ -53,6 +53,13 @@ private[r] class RBackendHandler(server: RBackend) if (objId == "SparkRHandler") { methodName match { + // This function is for test-purpose only + case "echo" => + val args = readArgs(numArgs, dis) + assert(numArgs == 1) + + writeInt(dos, 0) + writeObject(dos, args(0)) case "stopBackend" => writeInt(dos, 0) writeType(dos, "void") diff --git a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala 
b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala index dbbbcf40c1e9..26ad4f1d4697 100644 --- a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala +++ b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala @@ -149,6 +149,10 @@ private[spark] object SerDe { case 'b' => readBooleanArr(dis) case 'j' => readStringArr(dis).map(x => JVMObjectTracker.getObject(x)) case 'r' => readBytesArr(dis) + case 'l' => { + val len = readInt(dis) + (0 until len).map(_ => readList(dis)).toArray + } case _ => throw new IllegalArgumentException(s"Invalid array type $arrType") } } @@ -200,6 +204,9 @@ private[spark] object SerDe { case "date" => dos.writeByte('D') case "time" => dos.writeByte('t') case "raw" => dos.writeByte('r') + // Array of primitive types + case "array" => dos.writeByte('a') + // Array of objects case "list" => dos.writeByte('l') case "jobj" => dos.writeByte('j') case _ => throw new IllegalArgumentException(s"Invalid type $typeStr") @@ -211,26 +218,35 @@ private[spark] object SerDe { writeType(dos, "void") } else { value.getClass.getName match { + case "java.lang.Character" => + writeType(dos, "character") + writeString(dos, value.asInstanceOf[Character].toString) case "java.lang.String" => writeType(dos, "character") writeString(dos, value.asInstanceOf[String]) - case "long" | "java.lang.Long" => + case "java.lang.Long" => writeType(dos, "double") writeDouble(dos, value.asInstanceOf[Long].toDouble) - case "float" | "java.lang.Float" => + case "java.lang.Float" => writeType(dos, "double") writeDouble(dos, value.asInstanceOf[Float].toDouble) - case "decimal" | "java.math.BigDecimal" => + case "java.math.BigDecimal" => writeType(dos, "double") val javaDecimal = value.asInstanceOf[java.math.BigDecimal] writeDouble(dos, scala.math.BigDecimal(javaDecimal).toDouble) - case "double" | "java.lang.Double" => + case "java.lang.Double" => writeType(dos, "double") writeDouble(dos, value.asInstanceOf[Double]) - case "int" | "java.lang.Integer" => + case "java.lang.Byte" => + writeType(dos, "integer") + writeInt(dos, value.asInstanceOf[Byte].toInt) + case "java.lang.Short" => + writeType(dos, "integer") + writeInt(dos, value.asInstanceOf[Short].toInt) + case "java.lang.Integer" => writeType(dos, "integer") writeInt(dos, value.asInstanceOf[Int]) - case "boolean" | "java.lang.Boolean" => + case "java.lang.Boolean" => writeType(dos, "logical") writeBoolean(dos, value.asInstanceOf[Boolean]) case "java.sql.Date" => @@ -242,43 +258,48 @@ private[spark] object SerDe { case "java.sql.Timestamp" => writeType(dos, "time") writeTime(dos, value.asInstanceOf[Timestamp]) + + // Handle arrays + + // Array of primitive types + + // Special handling for byte array case "[B" => writeType(dos, "raw") writeBytes(dos, value.asInstanceOf[Array[Byte]]) - // TODO: Types not handled right now include - // byte, char, short, float - // Handle arrays - case "[Ljava.lang.String;" => - writeType(dos, "list") - writeStringArr(dos, value.asInstanceOf[Array[String]]) + case "[C" => + writeType(dos, "array") + writeStringArr(dos, value.asInstanceOf[Array[Char]].map(_.toString)) + case "[S" => + writeType(dos, "array") + writeIntArr(dos, value.asInstanceOf[Array[Short]].map(_.toInt)) case "[I" => - writeType(dos, "list") + writeType(dos, "array") writeIntArr(dos, value.asInstanceOf[Array[Int]]) case "[J" => - writeType(dos, "list") + writeType(dos, "array") writeDoubleArr(dos, value.asInstanceOf[Array[Long]].map(_.toDouble)) + case "[F" => + writeType(dos, "array") + writeDoubleArr(dos, 
value.asInstanceOf[Array[Float]].map(_.toDouble)) case "[D" => - writeType(dos, "list") + writeType(dos, "array") writeDoubleArr(dos, value.asInstanceOf[Array[Double]]) case "[Z" => - writeType(dos, "list") + writeType(dos, "array") writeBooleanArr(dos, value.asInstanceOf[Array[Boolean]]) - case "[[B" => + + // Array of objects, null objects use "void" type + case c if c.startsWith("[") => writeType(dos, "list") - writeBytesArr(dos, value.asInstanceOf[Array[Array[Byte]]]) - case otherName => - // Handle array of objects - if (otherName.startsWith("[L")) { - val objArr = value.asInstanceOf[Array[Object]] - writeType(dos, "list") - writeType(dos, "jobj") - dos.writeInt(objArr.length) - objArr.foreach(o => writeJObj(dos, o)) - } else { - writeType(dos, "jobj") - writeJObj(dos, value) - } + val array = value.asInstanceOf[Array[Object]] + writeInt(dos, array.length) + array.foreach(elem => writeObject(dos, elem)) + + case _ => + writeType(dos, "jobj") + writeJObj(dos, value) } } } @@ -350,11 +371,6 @@ private[spark] object SerDe { value.foreach(v => writeString(out, v)) } - def writeBytesArr(out: DataOutputStream, value: Array[Array[Byte]]): Unit = { - writeType(out, "raw") - out.writeInt(value.length) - value.foreach(v => writeBytes(out, v)) - } } private[r] object SerializationFormats { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala index 92861ab038f1..7f3defec3d42 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala @@ -98,27 +98,17 @@ private[r] object SQLUtils { val bos = new ByteArrayOutputStream() val dos = new DataOutputStream(bos) - SerDe.writeInt(dos, row.length) - (0 until row.length).map { idx => - val obj: Object = row(idx).asInstanceOf[Object] - SerDe.writeObject(dos, obj) - } + val cols = (0 until row.length).map(row(_).asInstanceOf[Object]).toArray + SerDe.writeObject(dos, cols) bos.toByteArray() } - def dfToCols(df: DataFrame): Array[Array[Byte]] = { + def dfToCols(df: DataFrame): Array[Array[Any]] = { // localDF is Array[Row] val localDF = df.collect() val numCols = df.columns.length - // dfCols is Array[Array[Any]] - val dfCols = convertRowsToColumns(localDF, numCols) - - dfCols.map { col => - colToRBytes(col) - } - } - def convertRowsToColumns(localDF: Array[Row], numCols: Int): Array[Array[Any]] = { + // result is Array[Array[Any]] (0 until numCols).map { colIdx => localDF.map { row => row(colIdx) @@ -126,20 +116,6 @@ private[r] object SQLUtils { }.toArray } - def colToRBytes(col: Array[Any]): Array[Byte] = { - val numRows = col.length - val bos = new ByteArrayOutputStream() - val dos = new DataOutputStream(bos) - - SerDe.writeInt(dos, numRows) - - col.map { item => - val obj: Object = item.asInstanceOf[Object] - SerDe.writeObject(dos, obj) - } - bos.toByteArray() - } - def saveMode(mode: String): SaveMode = { mode match { case "append" => SaveMode.Append From c0e9ff1588b4d9313cc6ec6e00e5c7663eb67910 Mon Sep 17 00:00:00 2001 From: Feynman Liang Date: Tue, 25 Aug 2015 13:21:05 -0700 Subject: [PATCH 0415/1824] [SPARK-9800] Adds docs for GradientDescent$.runMiniBatchSGD alias * Adds doc for alias of runMIniBatchSGD documenting default value for convergeTol * Cleans up a note in code Author: Feynman Liang Closes #8425 from feynmanliang/SPARK-9800. 
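To make the documented default concrete: the overload being documented here simply omits the convergenceTol argument and behaves as if 0.001 had been passed. The sketch below shows a caller reaching that overload through the developer-level GradientDescent object; the toy data, step size, and regularization values are invented for illustration, and the argument order after the first two parameters is reconstructed from context rather than copied from the source, so it should be checked against GradientDescent.scala before use.

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.mllib.linalg.Vectors
    import org.apache.spark.mllib.optimization.{GradientDescent, LogisticGradient, SquaredL2Updater}

    object SgdAliasSketch {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(new SparkConf().setAppName("sgd-sketch").setMaster("local[2]"))
        // Toy binary-classification data: (label, feature vector).
        val data = sc.parallelize(Seq(
          (0.0, Vectors.dense(0.0, 1.0)),
          (1.0, Vectors.dense(1.0, 0.0)),
          (1.0, Vectors.dense(1.0, 1.0)),
          (0.0, Vectors.dense(0.0, 0.0)))).cache()

        // Calling the overload that omits convergenceTol behaves like passing 0.001 explicitly.
        val (weights, lossHistory) = GradientDescent.runMiniBatchSGD(
          data,
          new LogisticGradient(),
          new SquaredL2Updater(),
          /* stepSize = */ 1.0,
          /* numIterations = */ 20,
          /* regParam = */ 0.1,
          /* miniBatchFraction = */ 1.0,
          /* initialWeights = */ Vectors.dense(0.0, 0.0))

        println(s"final loss = ${lossHistory.last}, weights = $weights")
        sc.stop()
      }
    }

Passing 0.001 explicitly as an extra convergence-tolerance argument to the full overload should give identical behavior to the alias documented in the diff that follows.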
--- .../apache/spark/mllib/optimization/GradientDescent.scala | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala index 8f0d1e4aa010..3b663b5defb0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala @@ -235,7 +235,7 @@ object GradientDescent extends Logging { if (miniBatchSize > 0) { /** - * NOTE(Xinghao): lossSum is computed using the weights from the previous iteration + * lossSum is computed using the weights from the previous iteration * and regVal is the regularization value computed in the previous iteration as well. */ stochasticLossHistory.append(lossSum / miniBatchSize + regVal) @@ -264,6 +264,9 @@ object GradientDescent extends Logging { } + /** + * Alias of [[runMiniBatchSGD]] with convergenceTol set to default value of 0.001. + */ def runMiniBatchSGD( data: RDD[(Double, Vector)], gradient: Gradient, From c619c7552f22d28cfa321ce671fc9ca854dd655f Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Tue, 25 Aug 2015 13:22:38 -0700 Subject: [PATCH 0416/1824] [SPARK-10237] [MLLIB] update since versions in mllib.fpm Same as #8421 but for `mllib.fpm`. cc feynmanliang Author: Xiangrui Meng Closes #8429 from mengxr/SPARK-10237. --- .../spark/mllib/fpm/AssociationRules.scala | 7 ++++-- .../org/apache/spark/mllib/fpm/FPGrowth.scala | 9 ++++++-- .../apache/spark/mllib/fpm/PrefixSpan.scala | 23 ++++++++++++++++--- 3 files changed, 32 insertions(+), 7 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/AssociationRules.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/AssociationRules.scala index ba3b447a8339..95c688c86a7e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/AssociationRules.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/AssociationRules.scala @@ -82,12 +82,15 @@ class AssociationRules private[fpm] ( }.filter(_.confidence >= minConfidence) } + /** Java-friendly version of [[run]]. */ + @Since("1.5.0") def run[Item](freqItemsets: JavaRDD[FreqItemset[Item]]): JavaRDD[Rule[Item]] = { val tag = fakeClassTag[Item] run(freqItemsets.rdd)(tag) } } +@Since("1.5.0") object AssociationRules { /** @@ -104,8 +107,8 @@ object AssociationRules { @Since("1.5.0") @Experimental class Rule[Item] private[fpm] ( - val antecedent: Array[Item], - val consequent: Array[Item], + @Since("1.5.0") val antecedent: Array[Item], + @Since("1.5.0") val consequent: Array[Item], freqUnion: Double, freqAntecedent: Double) extends Serializable { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala index e37f80627168..aea5c4f8a8a7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala @@ -42,7 +42,8 @@ import org.apache.spark.storage.StorageLevel */ @Since("1.3.0") @Experimental -class FPGrowthModel[Item: ClassTag](val freqItemsets: RDD[FreqItemset[Item]]) extends Serializable { +class FPGrowthModel[Item: ClassTag] @Since("1.3.0") ( + @Since("1.3.0") val freqItemsets: RDD[FreqItemset[Item]]) extends Serializable { /** * Generates association rules for the [[Item]]s in [[freqItemsets]]. 
* @param confidence minimal confidence of the rules produced @@ -126,6 +127,8 @@ class FPGrowth private ( new FPGrowthModel(freqItemsets) } + /** Java-friendly version of [[run]]. */ + @Since("1.3.0") def run[Item, Basket <: JavaIterable[Item]](data: JavaRDD[Basket]): FPGrowthModel[Item] = { implicit val tag = fakeClassTag[Item] run(data.rdd.map(_.asScala.toArray)) @@ -226,7 +229,9 @@ object FPGrowth { * */ @Since("1.3.0") - class FreqItemset[Item](val items: Array[Item], val freq: Long) extends Serializable { + class FreqItemset[Item] @Since("1.3.0") ( + @Since("1.3.0") val items: Array[Item], + @Since("1.3.0") val freq: Long) extends Serializable { /** * Returns items in a Java List. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala index dc4ae1d0b69e..97916daa2e9a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala @@ -25,7 +25,7 @@ import scala.collection.JavaConverters._ import scala.reflect.ClassTag import org.apache.spark.Logging -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.api.java.JavaRDD import org.apache.spark.api.java.JavaSparkContext.fakeClassTag import org.apache.spark.rdd.RDD @@ -51,6 +51,7 @@ import org.apache.spark.storage.StorageLevel * (Wikipedia)]] */ @Experimental +@Since("1.5.0") class PrefixSpan private ( private var minSupport: Double, private var maxPatternLength: Int, @@ -61,17 +62,20 @@ class PrefixSpan private ( * Constructs a default instance with default parameters * {minSupport: `0.1`, maxPatternLength: `10`, maxLocalProjDBSize: `32000000L`}. */ + @Since("1.5.0") def this() = this(0.1, 10, 32000000L) /** * Get the minimal support (i.e. the frequency of occurrence before a pattern is considered * frequent). */ + @Since("1.5.0") def getMinSupport: Double = minSupport /** * Sets the minimal support level (default: `0.1`). */ + @Since("1.5.0") def setMinSupport(minSupport: Double): this.type = { require(minSupport >= 0 && minSupport <= 1, s"The minimum support value must be in [0, 1], but got $minSupport.") @@ -82,11 +86,13 @@ class PrefixSpan private ( /** * Gets the maximal pattern length (i.e. the length of the longest sequential pattern to consider. */ + @Since("1.5.0") def getMaxPatternLength: Int = maxPatternLength /** * Sets maximal pattern length (default: `10`). */ + @Since("1.5.0") def setMaxPatternLength(maxPatternLength: Int): this.type = { // TODO: support unbounded pattern length when maxPatternLength = 0 require(maxPatternLength >= 1, @@ -98,12 +104,14 @@ class PrefixSpan private ( /** * Gets the maximum number of items allowed in a projected database before local processing. */ + @Since("1.5.0") def getMaxLocalProjDBSize: Long = maxLocalProjDBSize /** * Sets the maximum number of items (including delimiters used in the internal storage format) * allowed in a projected database before local processing (default: `32000000L`). */ + @Since("1.5.0") def setMaxLocalProjDBSize(maxLocalProjDBSize: Long): this.type = { require(maxLocalProjDBSize >= 0L, s"The maximum local projected database size must be nonnegative, but got $maxLocalProjDBSize") @@ -116,6 +124,7 @@ class PrefixSpan private ( * @param data sequences of itemsets. 
* @return a [[PrefixSpanModel]] that contains the frequent patterns */ + @Since("1.5.0") def run[Item: ClassTag](data: RDD[Array[Array[Item]]]): PrefixSpanModel[Item] = { if (data.getStorageLevel == StorageLevel.NONE) { logWarning("Input data is not cached.") @@ -202,6 +211,7 @@ class PrefixSpan private ( * @tparam Sequence sequence type, which is an Iterable of Itemsets * @return a [[PrefixSpanModel]] that contains the frequent sequential patterns */ + @Since("1.5.0") def run[Item, Itemset <: jl.Iterable[Item], Sequence <: jl.Iterable[Itemset]]( data: JavaRDD[Sequence]): PrefixSpanModel[Item] = { implicit val tag = fakeClassTag[Item] @@ -211,6 +221,7 @@ class PrefixSpan private ( } @Experimental +@Since("1.5.0") object PrefixSpan extends Logging { /** @@ -535,10 +546,14 @@ object PrefixSpan extends Logging { * @param freq frequency * @tparam Item item type */ - class FreqSequence[Item](val sequence: Array[Array[Item]], val freq: Long) extends Serializable { + @Since("1.5.0") + class FreqSequence[Item] @Since("1.5.0") ( + @Since("1.5.0") val sequence: Array[Array[Item]], + @Since("1.5.0") val freq: Long) extends Serializable { /** * Returns sequence as a Java List of lists for Java users. */ + @Since("1.5.0") def javaSequence: ju.List[ju.List[Item]] = sequence.map(_.toList.asJava).toList.asJava } } @@ -548,5 +563,7 @@ object PrefixSpan extends Logging { * @param freqSequences frequent sequences * @tparam Item item type */ -class PrefixSpanModel[Item](val freqSequences: RDD[PrefixSpan.FreqSequence[Item]]) +@Since("1.5.0") +class PrefixSpanModel[Item] @Since("1.5.0") ( + @Since("1.5.0") val freqSequences: RDD[PrefixSpan.FreqSequence[Item]]) extends Serializable From 9205907876cf65695e56c2a94bedd83df3675c03 Mon Sep 17 00:00:00 2001 From: Feynman Liang Date: Tue, 25 Aug 2015 13:23:15 -0700 Subject: [PATCH 0417/1824] [SPARK-9797] [MLLIB] [DOC] StreamingLinearRegressionWithSGD.setConvergenceTol default value Adds default convergence tolerance (0.001, set in `GradientDescent.convergenceTol`) to `setConvergenceTol`'s scaladoc Author: Feynman Liang Closes #8424 from feynmanliang/SPARK-9797. --- .../mllib/regression/StreamingLinearRegressionWithSGD.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala index 537a05274eec..26654e4a0683 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala @@ -93,7 +93,7 @@ class StreamingLinearRegressionWithSGD private[mllib] ( } /** - * Set the convergence tolerance. + * Set the convergence tolerance. Default: 0.001. */ def setConvergenceTol(tolerance: Double): this.type = { this.algorithm.optimizer.setConvergenceTol(tolerance) From 00ae4be97f7b205432db2967ba6d506286ef2ca6 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Tue, 25 Aug 2015 14:11:38 -0700 Subject: [PATCH 0418/1824] [SPARK-10239] [SPARK-10244] [MLLIB] update since versions in mllib.pmml and mllib.util Same as #8421 but for `mllib.pmml` and `mllib.util`. 
cc dbtsai Author: Xiangrui Meng Closes #8430 from mengxr/SPARK-10239 and squashes the following commits: a189acf [Xiangrui Meng] update since versions in mllib.pmml and mllib.util --- .../org/apache/spark/mllib/pmml/PMMLExportable.scala | 7 ++++++- .../org/apache/spark/mllib/util/DataValidators.scala | 7 +++++-- .../apache/spark/mllib/util/KMeansDataGenerator.scala | 5 ++++- .../apache/spark/mllib/util/LinearDataGenerator.scala | 10 ++++++++-- .../mllib/util/LogisticRegressionDataGenerator.scala | 5 ++++- .../org/apache/spark/mllib/util/MFDataGenerator.scala | 4 +++- .../scala/org/apache/spark/mllib/util/MLUtils.scala | 2 ++ .../org/apache/spark/mllib/util/SVMDataGenerator.scala | 6 ++++-- .../org/apache/spark/mllib/util/modelSaveLoad.scala | 6 +++++- 9 files changed, 41 insertions(+), 11 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/PMMLExportable.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/PMMLExportable.scala index 5e882d4ebb10..274ac7c99553 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/pmml/PMMLExportable.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/PMMLExportable.scala @@ -23,7 +23,7 @@ import javax.xml.transform.stream.StreamResult import org.jpmml.model.JAXBUtil import org.apache.spark.SparkContext -import org.apache.spark.annotation.{DeveloperApi, Experimental} +import org.apache.spark.annotation.{DeveloperApi, Experimental, Since} import org.apache.spark.mllib.pmml.export.PMMLModelExportFactory /** @@ -33,6 +33,7 @@ import org.apache.spark.mllib.pmml.export.PMMLModelExportFactory * developed by the Data Mining Group (www.dmg.org). */ @DeveloperApi +@Since("1.4.0") trait PMMLExportable { /** @@ -48,6 +49,7 @@ trait PMMLExportable { * Export the model to a local file in PMML format */ @Experimental + @Since("1.4.0") def toPMML(localPath: String): Unit = { toPMML(new StreamResult(new File(localPath))) } @@ -57,6 +59,7 @@ trait PMMLExportable { * Export the model to a directory on a distributed file system in PMML format */ @Experimental + @Since("1.4.0") def toPMML(sc: SparkContext, path: String): Unit = { val pmml = toPMML() sc.parallelize(Array(pmml), 1).saveAsTextFile(path) @@ -67,6 +70,7 @@ trait PMMLExportable { * Export the model to the OutputStream in PMML format */ @Experimental + @Since("1.4.0") def toPMML(outputStream: OutputStream): Unit = { toPMML(new StreamResult(outputStream)) } @@ -76,6 +80,7 @@ trait PMMLExportable { * Export the model to a String in PMML format */ @Experimental + @Since("1.4.0") def toPMML(): String = { val writer = new StringWriter toPMML(new StreamResult(writer)) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/DataValidators.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/DataValidators.scala index be335a1aca58..dffe6e78939e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/DataValidators.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/DataValidators.scala @@ -17,16 +17,17 @@ package org.apache.spark.mllib.util -import org.apache.spark.annotation.DeveloperApi import org.apache.spark.Logging -import org.apache.spark.rdd.RDD +import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.rdd.RDD /** * :: DeveloperApi :: * A collection of methods used to validate data before applying ML algorithms. 
*/ @DeveloperApi +@Since("0.8.0") object DataValidators extends Logging { /** @@ -34,6 +35,7 @@ object DataValidators extends Logging { * * @return True if labels are all zero or one, false otherwise. */ + @Since("1.0.0") val binaryLabelValidator: RDD[LabeledPoint] => Boolean = { data => val numInvalid = data.filter(x => x.label != 1.0 && x.label != 0.0).count() if (numInvalid != 0) { @@ -48,6 +50,7 @@ object DataValidators extends Logging { * * @return True if labels are all in the range of {0, 1, ..., k-1}, false otherwise. */ + @Since("1.3.0") def multiLabelValidator(k: Int): RDD[LabeledPoint] => Boolean = { data => val numInvalid = data.filter(x => x.label - x.label.toInt != 0.0 || x.label < 0 || x.label > k - 1).count() diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/KMeansDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/KMeansDataGenerator.scala index e6bcff48b022..00fd1606a369 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/KMeansDataGenerator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/KMeansDataGenerator.scala @@ -19,8 +19,8 @@ package org.apache.spark.mllib.util import scala.util.Random -import org.apache.spark.annotation.DeveloperApi import org.apache.spark.SparkContext +import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.rdd.RDD /** @@ -30,6 +30,7 @@ import org.apache.spark.rdd.RDD * cluster with scale 1 around each center. */ @DeveloperApi +@Since("0.8.0") object KMeansDataGenerator { /** @@ -42,6 +43,7 @@ object KMeansDataGenerator { * @param r Scaling factor for the distribution of the initial centers * @param numPartitions Number of partitions of the generated RDD; default 2 */ + @Since("0.8.0") def generateKMeansRDD( sc: SparkContext, numPoints: Int, @@ -62,6 +64,7 @@ object KMeansDataGenerator { } } + @Since("0.8.0") def main(args: Array[String]) { if (args.length < 6) { // scalastyle:off println diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala index 7a1c7796065e..d0ba454f379a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala @@ -22,11 +22,11 @@ import scala.util.Random import com.github.fommil.netlib.BLAS.{getInstance => blas} -import org.apache.spark.annotation.DeveloperApi import org.apache.spark.SparkContext -import org.apache.spark.rdd.RDD +import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.rdd.RDD /** * :: DeveloperApi :: @@ -35,6 +35,7 @@ import org.apache.spark.mllib.regression.LabeledPoint * response variable `Y`. */ @DeveloperApi +@Since("0.8.0") object LinearDataGenerator { /** @@ -46,6 +47,7 @@ object LinearDataGenerator { * @param seed Random seed * @return Java List of input. */ + @Since("0.8.0") def generateLinearInputAsList( intercept: Double, weights: Array[Double], @@ -68,6 +70,7 @@ object LinearDataGenerator { * @param eps Epsilon scaling factor. * @return Seq of input. */ + @Since("0.8.0") def generateLinearInput( intercept: Double, weights: Array[Double], @@ -92,6 +95,7 @@ object LinearDataGenerator { * @param eps Epsilon scaling factor. * @return Seq of input. 
*/ + @Since("0.8.0") def generateLinearInput( intercept: Double, weights: Array[Double], @@ -132,6 +136,7 @@ object LinearDataGenerator { * * @return RDD of LabeledPoint containing sample data. */ + @Since("0.8.0") def generateLinearRDD( sc: SparkContext, nexamples: Int, @@ -151,6 +156,7 @@ object LinearDataGenerator { data } + @Since("0.8.0") def main(args: Array[String]) { if (args.length < 2) { // scalastyle:off println diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/LogisticRegressionDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/LogisticRegressionDataGenerator.scala index c09cbe69bb97..33477ee20ebb 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/LogisticRegressionDataGenerator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/LogisticRegressionDataGenerator.scala @@ -19,7 +19,7 @@ package org.apache.spark.mllib.util import scala.util.Random -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.{Since, DeveloperApi} import org.apache.spark.SparkContext import org.apache.spark.rdd.RDD import org.apache.spark.mllib.regression.LabeledPoint @@ -31,6 +31,7 @@ import org.apache.spark.mllib.linalg.Vectors * with probability `probOne` and scales features for positive examples by `eps`. */ @DeveloperApi +@Since("0.8.0") object LogisticRegressionDataGenerator { /** @@ -43,6 +44,7 @@ object LogisticRegressionDataGenerator { * @param nparts Number of partitions of the generated RDD. Default value is 2. * @param probOne Probability that a label is 1 (and not 0). Default value is 0.5. */ + @Since("0.8.0") def generateLogisticRDD( sc: SparkContext, nexamples: Int, @@ -62,6 +64,7 @@ object LogisticRegressionDataGenerator { data } + @Since("0.8.0") def main(args: Array[String]) { if (args.length != 5) { // scalastyle:off println diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MFDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MFDataGenerator.scala index 16f430599a51..906bd30563bd 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/MFDataGenerator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MFDataGenerator.scala @@ -23,7 +23,7 @@ import scala.language.postfixOps import scala.util.Random import org.apache.spark.SparkContext -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.{Since, DeveloperApi} import org.apache.spark.mllib.linalg.{BLAS, DenseMatrix} import org.apache.spark.rdd.RDD @@ -52,7 +52,9 @@ import org.apache.spark.rdd.RDD * testSampFact (Double) Percentage of training data to use as test data. */ @DeveloperApi +@Since("0.8.0") object MFDataGenerator { + @Since("0.8.0") def main(args: Array[String]) { if (args.length < 2) { // scalastyle:off println diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala index 4940974bf4f4..81c2f0ce6e12 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala @@ -36,6 +36,7 @@ import org.apache.spark.streaming.dstream.DStream /** * Helper methods to load, save and pre-process data used in ML Lib. 
*/ +@Since("0.8.0") object MLUtils { private[mllib] lazy val EPSILON = { @@ -168,6 +169,7 @@ object MLUtils { * * @see [[org.apache.spark.mllib.util.MLUtils#loadLibSVMFile]] */ + @Since("1.0.0") def saveAsLibSVMFile(data: RDD[LabeledPoint], dir: String) { // TODO: allow to specify label precision and feature precision. val dataStr = data.map { case LabeledPoint(label, features) => diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/SVMDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/SVMDataGenerator.scala index ad20b7694a77..cde597939617 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/SVMDataGenerator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/SVMDataGenerator.scala @@ -21,11 +21,11 @@ import scala.util.Random import com.github.fommil.netlib.BLAS.{getInstance => blas} -import org.apache.spark.annotation.DeveloperApi import org.apache.spark.SparkContext -import org.apache.spark.rdd.RDD +import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.rdd.RDD /** * :: DeveloperApi :: @@ -33,8 +33,10 @@ import org.apache.spark.mllib.regression.LabeledPoint * for the features and adds Gaussian noise with weight 0.1 to generate labels. */ @DeveloperApi +@Since("0.8.0") object SVMDataGenerator { + @Since("0.8.0") def main(args: Array[String]) { if (args.length < 2) { // scalastyle:off println diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/modelSaveLoad.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/modelSaveLoad.scala index 30d642c754b7..4d71d534a077 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/modelSaveLoad.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/modelSaveLoad.scala @@ -24,7 +24,7 @@ import org.json4s._ import org.json4s.jackson.JsonMethods._ import org.apache.spark.SparkContext -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.types.{DataType, StructField, StructType} @@ -35,6 +35,7 @@ import org.apache.spark.sql.types.{DataType, StructField, StructType} * This should be inherited by the class which implements model instances. */ @DeveloperApi +@Since("1.3.0") trait Saveable { /** @@ -50,6 +51,7 @@ trait Saveable { * @param path Path specifying the directory in which to save this model. * If the directory already exists, this method throws an exception. */ + @Since("1.3.0") def save(sc: SparkContext, path: String): Unit /** Current version of model save/load format. */ @@ -64,6 +66,7 @@ trait Saveable { * This should be inherited by an object paired with the model class. */ @DeveloperApi +@Since("1.3.0") trait Loader[M <: Saveable] { /** @@ -75,6 +78,7 @@ trait Loader[M <: Saveable] { * @param path Path specifying the directory to which the model was saved. * @return Model instance */ + @Since("1.3.0") def load(sc: SparkContext, path: String): M } From ec89bd840a6862751999d612f586a962cae63f6d Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 25 Aug 2015 14:55:34 -0700 Subject: [PATCH 0419/1824] [SPARK-10245] [SQL] Fix decimal literals with precision < scale In BigDecimal or java.math.BigDecimal, the precision could be smaller than scale, for example, BigDecimal("0.001") has precision = 1 and scale = 3. 
But DecimalType require that the precision should be larger than scale, so we should use the maximum of precision and scale when inferring the schema from decimal literal. Author: Davies Liu Closes #8428 from davies/smaller_decimal. --- .../spark/sql/catalyst/expressions/literals.scala | 7 ++++--- .../catalyst/expressions/LiteralExpressionSuite.scala | 8 +++++--- .../scala/org/apache/spark/sql/SQLQuerySuite.scala | 10 ++++++++++ 3 files changed, 19 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index 34bad23802ba..8c0c5d5b1e31 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -36,9 +36,10 @@ object Literal { case s: Short => Literal(s, ShortType) case s: String => Literal(UTF8String.fromString(s), StringType) case b: Boolean => Literal(b, BooleanType) - case d: BigDecimal => Literal(Decimal(d), DecimalType(d.precision, d.scale)) - case d: java.math.BigDecimal => Literal(Decimal(d), DecimalType(d.precision(), d.scale())) - case d: Decimal => Literal(d, DecimalType(d.precision, d.scale)) + case d: BigDecimal => Literal(Decimal(d), DecimalType(Math.max(d.precision, d.scale), d.scale)) + case d: java.math.BigDecimal => + Literal(Decimal(d), DecimalType(Math.max(d.precision, d.scale), d.scale())) + case d: Decimal => Literal(d, DecimalType(Math.max(d.precision, d.scale), d.scale)) case t: Timestamp => Literal(DateTimeUtils.fromJavaTimestamp(t), TimestampType) case d: Date => Literal(DateTimeUtils.fromJavaDate(d), DateType) case a: Array[Byte] => Literal(a, BinaryType) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala index f6404d21611e..015eb1897fb8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala @@ -83,12 +83,14 @@ class LiteralExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { } test("decimal") { - List(0.0, 1.2, 1.1111, 5).foreach { d => + List(-0.0001, 0.0, 0.001, 1.2, 1.1111, 5).foreach { d => checkEvaluation(Literal(Decimal(d)), Decimal(d)) checkEvaluation(Literal(Decimal(d.toInt)), Decimal(d.toInt)) checkEvaluation(Literal(Decimal(d.toLong)), Decimal(d.toLong)) - checkEvaluation(Literal(Decimal((d * 1000L).toLong, 10, 1)), - Decimal((d * 1000L).toLong, 10, 1)) + checkEvaluation(Literal(Decimal((d * 1000L).toLong, 10, 3)), + Decimal((d * 1000L).toLong, 10, 3)) + checkEvaluation(Literal(BigDecimal(d.toString)), Decimal(d)) + checkEvaluation(Literal(new java.math.BigDecimal(d.toString)), Decimal(d)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index dcb4e8371098..aa07665c6b70 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -1627,6 +1627,16 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { Row(null)) } + test("precision smaller than scale") { + checkAnswer(sql("select 10.00"), Row(BigDecimal("10.00"))) + checkAnswer(sql("select 
1.00"), Row(BigDecimal("1.00"))) + checkAnswer(sql("select 0.10"), Row(BigDecimal("0.10"))) + checkAnswer(sql("select 0.01"), Row(BigDecimal("0.01"))) + checkAnswer(sql("select 0.001"), Row(BigDecimal("0.001"))) + checkAnswer(sql("select -0.01"), Row(BigDecimal("-0.01"))) + checkAnswer(sql("select -0.001"), Row(BigDecimal("-0.001"))) + } + test("external sorting updates peak execution memory") { withSQLConf((SQLConf.EXTERNAL_SORT.key, "true")) { val sc = sqlContext.sparkContext From 7467b52ed07f174d93dfc4cb544dc4b69a2c2826 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 25 Aug 2015 15:19:41 -0700 Subject: [PATCH 0420/1824] [SPARK-10215] [SQL] Fix precision of division (follow the rule in Hive) Follow the rule in Hive for decimal division. see https://github.com/apache/hive/blob/ac755ebe26361a4647d53db2a28500f71697b276/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFOPDivide.java#L113 cc chenghao-intel Author: Davies Liu Closes #8415 from davies/decimal_div2. --- .../catalyst/analysis/HiveTypeCoercion.scala | 10 ++++++-- .../sql/catalyst/analysis/AnalysisSuite.scala | 9 +++---- .../analysis/DecimalPrecisionSuite.scala | 8 +++--- .../org/apache/spark/sql/SQLQuerySuite.scala | 25 +++++++++++++++++-- 4 files changed, 39 insertions(+), 13 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index a1aa2a2b2c68..87c11abbad49 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -396,8 +396,14 @@ object HiveTypeCoercion { resultType) case Divide(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => - val resultType = DecimalType.bounded(p1 - s1 + s2 + max(6, s1 + p2 + 1), - max(6, s1 + p2 + 1)) + var intDig = min(DecimalType.MAX_SCALE, p1 - s1 + s2) + var decDig = min(DecimalType.MAX_SCALE, max(6, s1 + p2 + 1)) + val diff = (intDig + decDig) - DecimalType.MAX_SCALE + if (diff > 0) { + decDig -= diff / 2 + 1 + intDig = DecimalType.MAX_SCALE - decDig + } + val resultType = DecimalType.bounded(intDig + decDig, decDig) val widerType = widerDecimalType(p1, s1, p2, s2) CheckOverflow(Divide(promotePrecision(e1, widerType), promotePrecision(e2, widerType)), resultType) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 1e0cc81dae97..820b336aac75 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -17,15 +17,14 @@ package org.apache.spark.sql.catalyst.analysis +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.types._ -import org.apache.spark.sql.catalyst.SimpleCatalystConf -import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.dsl.plans._ class AnalysisSuite extends AnalysisTest { - import TestRelations._ + import org.apache.spark.sql.catalyst.analysis.TestRelations._ test("union project *") { val plan = (1 to 100) @@ -96,7 +95,7 @@ class AnalysisSuite extends 
AnalysisTest { assert(pl(1).dataType == DoubleType) assert(pl(2).dataType == DoubleType) // StringType will be promoted into Decimal(38, 18) - assert(pl(3).dataType == DecimalType(38, 29)) + assert(pl(3).dataType == DecimalType(38, 22)) assert(pl(4).dataType == DoubleType) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala index fc11627da6fd..b4ad618c23e3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala @@ -136,10 +136,10 @@ class DecimalPrecisionSuite extends SparkFunSuite with BeforeAndAfter { checkType(Multiply(i, u), DecimalType(38, 18)) checkType(Multiply(u, u), DecimalType(38, 36)) - checkType(Divide(u, d1), DecimalType(38, 21)) - checkType(Divide(u, d2), DecimalType(38, 24)) - checkType(Divide(u, i), DecimalType(38, 29)) - checkType(Divide(u, u), DecimalType(38, 38)) + checkType(Divide(u, d1), DecimalType(38, 18)) + checkType(Divide(u, d2), DecimalType(38, 19)) + checkType(Divide(u, i), DecimalType(38, 23)) + checkType(Divide(u, u), DecimalType(38, 18)) checkType(Remainder(d1, u), DecimalType(19, 18)) checkType(Remainder(d2, u), DecimalType(21, 18)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index aa07665c6b70..9e172b2c264c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -1622,9 +1622,30 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { checkAnswer(sql("select 10.3000 / 3.0"), Row(BigDecimal("3.4333333"))) checkAnswer(sql("select 10.30000 / 30.0"), Row(BigDecimal("0.343333333"))) checkAnswer(sql("select 10.300000000000000000 / 3.00000000000000000"), - Row(BigDecimal("3.4333333333333333333333333333333333333", new MathContext(38)))) + Row(BigDecimal("3.433333333333333333333333333", new MathContext(38)))) checkAnswer(sql("select 10.3000000000000000000 / 3.00000000000000000"), - Row(null)) + Row(BigDecimal("3.4333333333333333333333333333", new MathContext(38)))) + } + + test("SPARK-10215 Div of Decimal returns null") { + val d = Decimal(1.12321) + val df = Seq((d, 1)).toDF("a", "b") + + checkAnswer( + df.selectExpr("b * a / b"), + Seq(Row(d.toBigDecimal))) + checkAnswer( + df.selectExpr("b * a / b / b"), + Seq(Row(d.toBigDecimal))) + checkAnswer( + df.selectExpr("b * a + b"), + Seq(Row(BigDecimal(2.12321)))) + checkAnswer( + df.selectExpr("b * a - b"), + Seq(Row(BigDecimal(0.12321)))) + checkAnswer( + df.selectExpr("b * a * b"), + Seq(Row(d.toBigDecimal))) } test("precision smaller than scale") { From 125205cdb35530cdb4a8fff3e1ee49cf4a299583 Mon Sep 17 00:00:00 2001 From: Feynman Liang Date: Tue, 25 Aug 2015 17:39:20 -0700 Subject: [PATCH 0421/1824] [SPARK-9888] [MLLIB] User guide for new LDA features * Adds two new sections to LDA's user guide; one for each optimizer/model * Documents new features added to LDA (e.g. topXXXperXXX, asymmetric priors, hyperpam optimization) * Cleans up a TODO and sets a default parameter in LDA code jkbradley hhbyyh Author: Feynman Liang Closes #8254 from feynmanliang/SPARK-9888. 
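To make the two optimizer paths described in the new guide concrete, here is a minimal training sketch using the online optimizer. The toy corpus, topic count, and iteration counts are arbitrary; the builder calls follow the API named in the guide text below (LDA, OnlineLDAOptimizer, optimizeDocConcentration, describeTopics) and are meant as a sketch, not a definitive recipe.

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.mllib.clustering.{LDA, OnlineLDAOptimizer}
    import org.apache.spark.mllib.linalg.Vectors

    object LdaUserGuideSketch {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(new SparkConf().setAppName("lda-sketch").setMaster("local[2]"))
        // Toy corpus of (document id, term-count vector) pairs over a 5-term vocabulary.
        val corpus = sc.parallelize(Seq(
          (0L, Vectors.dense(1.0, 2.0, 0.0, 0.0, 1.0)),
          (1L, Vectors.dense(0.0, 0.0, 3.0, 1.0, 0.0)),
          (2L, Vectors.dense(2.0, 1.0, 0.0, 0.0, 2.0))))

        // Online variational Bayes yields a LocalLDAModel, which stores only the inferred topics.
        val model = new LDA()
          .setK(2)
          .setOptimizer(new OnlineLDAOptimizer()
            .setMiniBatchFraction(1.0)
            .setOptimizeDocConcentration(true))
          .setMaxIterations(20)
          .run(corpus)

        // Top terms per topic, returned as (term indices, term weights).
        model.describeTopics(maxTermsPerTopic = 3).zipWithIndex.foreach {
          case ((terms, weights), topic) =>
            println(s"topic $topic: " + terms.zip(weights).mkString(", "))
        }
        sc.stop()
      }
    }

Swapping in new EMLDAOptimizer and treating the result as a DistributedLDAModel would instead expose the corpus-level accessors the guide describes, such as topDocumentsPerTopic and topTopicsPerDocument.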
--- docs/mllib-clustering.md | 135 +++++++++++++++--- .../spark/mllib/clustering/LDAModel.scala | 1 - .../spark/mllib/clustering/LDASuite.scala | 1 + 3 files changed, 117 insertions(+), 20 deletions(-) diff --git a/docs/mllib-clustering.md b/docs/mllib-clustering.md index fd9ab258e196..3fb35d3c50b0 100644 --- a/docs/mllib-clustering.md +++ b/docs/mllib-clustering.md @@ -438,28 +438,125 @@ sameModel = PowerIterationClusteringModel.load(sc, "myModelPath") is a topic model which infers topics from a collection of text documents. LDA can be thought of as a clustering algorithm as follows: -* Topics correspond to cluster centers, and documents correspond to examples (rows) in a dataset. -* Topics and documents both exist in a feature space, where feature vectors are vectors of word counts. -* Rather than estimating a clustering using a traditional distance, LDA uses a function based - on a statistical model of how text documents are generated. - -LDA takes in a collection of documents as vectors of word counts. -It supports different inference algorithms via `setOptimizer` function. EMLDAOptimizer learns clustering using [expectation-maximization](http://en.wikipedia.org/wiki/Expectation%E2%80%93maximization_algorithm) -on the likelihood function and yields comprehensive results, while OnlineLDAOptimizer uses iterative mini-batch sampling for [online variational inference](https://www.cs.princeton.edu/~blei/papers/HoffmanBleiBach2010b.pdf) and is generally memory friendly. After fitting on the documents, LDA provides: - -* Topics: Inferred topics, each of which is a probability distribution over terms (words). -* Topic distributions for documents: For each non empty document in the training set, LDA gives a probability distribution over topics. (EM only). Note that for empty documents, we don't create the topic distributions. (EM only) +* Topics correspond to cluster centers, and documents correspond to +examples (rows) in a dataset. +* Topics and documents both exist in a feature space, where feature +vectors are vectors of word counts (bag of words). +* Rather than estimating a clustering using a traditional distance, LDA +uses a function based on a statistical model of how text documents are +generated. + +LDA supports different inference algorithms via `setOptimizer` function. +`EMLDAOptimizer` learns clustering using +[expectation-maximization](http://en.wikipedia.org/wiki/Expectation%E2%80%93maximization_algorithm) +on the likelihood function and yields comprehensive results, while +`OnlineLDAOptimizer` uses iterative mini-batch sampling for [online +variational +inference](https://www.cs.princeton.edu/~blei/papers/HoffmanBleiBach2010b.pdf) +and is generally memory friendly. -LDA takes the following parameters: +LDA takes in a collection of documents as vectors of word counts and the +following parameters (set using the builder pattern): * `k`: Number of topics (i.e., cluster centers) -* `maxIterations`: Limit on the number of iterations of EM used for learning -* `docConcentration`: Hyperparameter for prior over documents' distributions over topics. Currently must be > 1, where larger values encourage smoother inferred distributions. -* `topicConcentration`: Hyperparameter for prior over topics' distributions over terms (words). Currently must be > 1, where larger values encourage smoother inferred distributions. -* `checkpointInterval`: If using checkpointing (set in the Spark configuration), this parameter specifies the frequency with which checkpoints will be created. 
If `maxIterations` is large, using checkpointing can help reduce shuffle file sizes on disk and help with failure recovery. - -*Note*: LDA is a new feature with some missing functionality. In particular, it does not yet -support prediction on new documents, and it does not have a Python API. These will be added in the future. +* `optimizer`: Optimizer to use for learning the LDA model, either +`EMLDAOptimizer` or `OnlineLDAOptimizer` +* `docConcentration`: Dirichlet parameter for prior over documents' +distributions over topics. Larger values encourage smoother inferred +distributions. +* `topicConcentration`: Dirichlet parameter for prior over topics' +distributions over terms (words). Larger values encourage smoother +inferred distributions. +* `maxIterations`: Limit on the number of iterations. +* `checkpointInterval`: If using checkpointing (set in the Spark +configuration), this parameter specifies the frequency with which +checkpoints will be created. If `maxIterations` is large, using +checkpointing can help reduce shuffle file sizes on disk and help with +failure recovery. + + +All of MLlib's LDA models support: + +* `describeTopics`: Returns topics as arrays of most important terms and +term weights +* `topicsMatrix`: Returns a `vocabSize` by `k` matrix where each column +is a topic + +*Note*: LDA is still an experimental feature under active development. +As a result, certain features are only available in one of the two +optimizers / models generated by the optimizer. Currently, a distributed +model can be converted into a local model, but not vice-versa. + +The following discussion will describe each optimizer/model pair +separately. + +**Expectation Maximization** + +Implemented in +[`EMLDAOptimizer`](api/scala/index.html#org.apache.spark.mllib.clustering.EMLDAOptimizer) +and +[`DistributedLDAModel`](api/scala/index.html#org.apache.spark.mllib.clustering.DistributedLDAModel). + +For the parameters provided to `LDA`: + +* `docConcentration`: Only symmetric priors are supported, so all values +in the provided `k`-dimensional vector must be identical. All values +must also be $> 1.0$. Providing `Vector(-1)` results in default behavior +(uniform `k` dimensional vector with value $(50 / k) + 1$ +* `topicConcentration`: Only symmetric priors supported. Values must be +$> 1.0$. Providing `-1` results in defaulting to a value of $0.1 + 1$. +* `maxIterations`: The maximum number of EM iterations. + +`EMLDAOptimizer` produces a `DistributedLDAModel`, which stores not only +the inferred topics but also the full training corpus and topic +distributions for each document in the training corpus. A +`DistributedLDAModel` supports: + + * `topTopicsPerDocument`: The top topics and their weights for + each document in the training corpus + * `topDocumentsPerTopic`: The top documents for each topic and + the corresponding weight of the topic in the documents. + * `logPrior`: log probability of the estimated topics and + document-topic distributions given the hyperparameters + `docConcentration` and `topicConcentration` + * `logLikelihood`: log likelihood of the training corpus, given the + inferred topics and document-topic distributions + +**Online Variational Bayes** + +Implemented in +[`OnlineLDAOptimizer`](api/scala/org/apache/spark/mllib/clustering/OnlineLDAOptimizer.html) +and +[`LocalLDAModel`](api/scala/org/apache/spark/mllib/clustering/LocalLDAModel.html). 
+ +For the parameters provided to `LDA`: + +* `docConcentration`: Asymmetric priors can be used by passing in a +vector with values equal to the Dirichlet parameter in each of the `k` +dimensions. Values should be $>= 0$. Providing `Vector(-1)` results in +default behavior (uniform `k` dimensional vector with value $(1.0 / k)$) +* `topicConcentration`: Only symmetric priors supported. Values must be +$>= 0$. Providing `-1` results in defaulting to a value of $(1.0 / k)$. +* `maxIterations`: Maximum number of minibatches to submit. + +In addition, `OnlineLDAOptimizer` accepts the following parameters: + +* `miniBatchFraction`: Fraction of corpus sampled and used at each +iteration +* `optimizeDocConcentration`: If set to true, performs maximum-likelihood +estimation of the hyperparameter `docConcentration` (aka `alpha`) +after each minibatch and sets the optimized `docConcentration` in the +returned `LocalLDAModel` +* `tau0` and `kappa`: Used for learning-rate decay, which is computed by +$(\tau_0 + iter)^{-\kappa}$ where $iter$ is the current number of iterations. + +`OnlineLDAOptimizer` produces a `LocalLDAModel`, which only stores the +inferred topics. A `LocalLDAModel` supports: + +* `logLikelihood(documents)`: Calculates a lower bound on the provided +`documents` given the inferred topics. +* `logPerplexity(documents)`: Calculates an upper bound on the +perplexity of the provided `documents` given the inferred topics. **Examples** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala index 667374a2bc41..432bbedc8d6f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala @@ -435,7 +435,6 @@ object LocalLDAModel extends Loader[LocalLDAModel] { } val topicsMat = Matrices.fromBreeze(brzTopics) - // TODO: initialize with docConcentration, topicConcentration, and gammaShape after SPARK-9940 new LocalLDAModel(topicsMat, docConcentration, topicConcentration, gammaShape) } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala index 746a76a7e5fa..37fb69d68f6b 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala @@ -68,6 +68,7 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext { // Train a model val lda = new LDA() lda.setK(k) + .setOptimizer(new EMLDAOptimizer) .setDocConcentration(topicSmoothing) .setTopicConcentration(termSmoothing) .setMaxIterations(5) From 8668ead2e7097b9591069599fbfccf67c53db659 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Tue, 25 Aug 2015 18:17:54 -0700 Subject: [PATCH 0422/1824] [SPARK-10233] [MLLIB] update since version in mllib.evaluation Same as #8421 but for `mllib.evaluation`. cc avulanov Author: Xiangrui Meng Closes #8423 from mengxr/SPARK-10233. 
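The patch below only attaches @Since annotations; the public surface it touches is unchanged. For orientation, here is a minimal sketch of that surface in MulticlassMetrics, assuming the accessors visible in the diff (precision, recall, weightedFMeasure, labels); the toy predictions and labels are invented.

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.mllib.evaluation.MulticlassMetrics

    object MulticlassMetricsSketch {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(new SparkConf().setAppName("metrics-sketch").setMaster("local[2]"))
        // (prediction, label) pairs for a toy three-class problem.
        val predictionAndLabels = sc.parallelize(Seq(
          (0.0, 0.0), (1.0, 1.0), (2.0, 2.0), (1.0, 0.0), (2.0, 1.0), (0.0, 0.0)))
        val metrics = new MulticlassMetrics(predictionAndLabels)

        // Overall statistics exposed by the class being annotated.
        println(s"precision = ${metrics.precision}, recall = ${metrics.recall}")
        println(s"weighted F1 = ${metrics.weightedFMeasure}")
        // Per-label statistics, keyed by the labels seen in the input.
        metrics.labels.foreach { l =>
          println(s"label $l: precision = ${metrics.precision(l)}, recall = ${metrics.recall(l)}")
        }
        sc.stop()
      }
    }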
--- .../evaluation/BinaryClassificationMetrics.scala | 8 ++++---- .../spark/mllib/evaluation/MulticlassMetrics.scala | 11 ++++++++++- .../spark/mllib/evaluation/MultilabelMetrics.scala | 12 +++++++++++- .../spark/mllib/evaluation/RegressionMetrics.scala | 3 ++- 4 files changed, 27 insertions(+), 7 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala index 76ae847921f4..508fe532b130 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala @@ -42,11 +42,11 @@ import org.apache.spark.sql.DataFrame * be smaller as a result, meaning there may be an extra sample at * partition boundaries. */ -@Since("1.3.0") +@Since("1.0.0") @Experimental -class BinaryClassificationMetrics( - val scoreAndLabels: RDD[(Double, Double)], - val numBins: Int) extends Logging { +class BinaryClassificationMetrics @Since("1.3.0") ( + @Since("1.3.0") val scoreAndLabels: RDD[(Double, Double)], + @Since("1.3.0") val numBins: Int) extends Logging { require(numBins >= 0, "numBins must be nonnegative") diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala index 02e89d921033..00e837661dfc 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala @@ -33,7 +33,7 @@ import org.apache.spark.sql.DataFrame */ @Since("1.1.0") @Experimental -class MulticlassMetrics(predictionAndLabels: RDD[(Double, Double)]) { +class MulticlassMetrics @Since("1.1.0") (predictionAndLabels: RDD[(Double, Double)]) { /** * An auxiliary constructor taking a DataFrame. 
@@ -140,6 +140,7 @@ class MulticlassMetrics(predictionAndLabels: RDD[(Double, Double)]) { /** * Returns precision */ + @Since("1.1.0") lazy val precision: Double = tpByClass.values.sum.toDouble / labelCount /** @@ -148,23 +149,27 @@ class MulticlassMetrics(predictionAndLabels: RDD[(Double, Double)]) { * because sum of all false positives is equal to sum * of all false negatives) */ + @Since("1.1.0") lazy val recall: Double = precision /** * Returns f-measure * (equals to precision and recall because precision equals recall) */ + @Since("1.1.0") lazy val fMeasure: Double = precision /** * Returns weighted true positive rate * (equals to precision, recall and f-measure) */ + @Since("1.1.0") lazy val weightedTruePositiveRate: Double = weightedRecall /** * Returns weighted false positive rate */ + @Since("1.1.0") lazy val weightedFalsePositiveRate: Double = labelCountByClass.map { case (category, count) => falsePositiveRate(category) * count.toDouble / labelCount }.sum @@ -173,6 +178,7 @@ class MulticlassMetrics(predictionAndLabels: RDD[(Double, Double)]) { * Returns weighted averaged recall * (equals to precision, recall and f-measure) */ + @Since("1.1.0") lazy val weightedRecall: Double = labelCountByClass.map { case (category, count) => recall(category) * count.toDouble / labelCount }.sum @@ -180,6 +186,7 @@ class MulticlassMetrics(predictionAndLabels: RDD[(Double, Double)]) { /** * Returns weighted averaged precision */ + @Since("1.1.0") lazy val weightedPrecision: Double = labelCountByClass.map { case (category, count) => precision(category) * count.toDouble / labelCount }.sum @@ -196,6 +203,7 @@ class MulticlassMetrics(predictionAndLabels: RDD[(Double, Double)]) { /** * Returns weighted averaged f1-measure */ + @Since("1.1.0") lazy val weightedFMeasure: Double = labelCountByClass.map { case (category, count) => fMeasure(category, 1.0) * count.toDouble / labelCount }.sum @@ -203,5 +211,6 @@ class MulticlassMetrics(predictionAndLabels: RDD[(Double, Double)]) { /** * Returns the sequence of labels in ascending order */ + @Since("1.1.0") lazy val labels: Array[Double] = tpByClass.keys.toArray.sorted } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala index a0a8d9c56847..c100b3c9ec14 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.DataFrame * both are non-null Arrays, each with unique elements. */ @Since("1.2.0") -class MultilabelMetrics(predictionAndLabels: RDD[(Array[Double], Array[Double])]) { +class MultilabelMetrics @Since("1.2.0") (predictionAndLabels: RDD[(Array[Double], Array[Double])]) { /** * An auxiliary constructor taking a DataFrame. 
@@ -46,6 +46,7 @@ class MultilabelMetrics(predictionAndLabels: RDD[(Array[Double], Array[Double])] * Returns subset accuracy * (for equal sets of labels) */ + @Since("1.2.0") lazy val subsetAccuracy: Double = predictionAndLabels.filter { case (predictions, labels) => predictions.deep == labels.deep }.count().toDouble / numDocs @@ -53,6 +54,7 @@ class MultilabelMetrics(predictionAndLabels: RDD[(Array[Double], Array[Double])] /** * Returns accuracy */ + @Since("1.2.0") lazy val accuracy: Double = predictionAndLabels.map { case (predictions, labels) => labels.intersect(predictions).size.toDouble / (labels.size + predictions.size - labels.intersect(predictions).size)}.sum / numDocs @@ -61,6 +63,7 @@ class MultilabelMetrics(predictionAndLabels: RDD[(Array[Double], Array[Double])] /** * Returns Hamming-loss */ + @Since("1.2.0") lazy val hammingLoss: Double = predictionAndLabels.map { case (predictions, labels) => labels.size + predictions.size - 2 * labels.intersect(predictions).size }.sum / (numDocs * numLabels) @@ -68,6 +71,7 @@ class MultilabelMetrics(predictionAndLabels: RDD[(Array[Double], Array[Double])] /** * Returns document-based precision averaged by the number of documents */ + @Since("1.2.0") lazy val precision: Double = predictionAndLabels.map { case (predictions, labels) => if (predictions.size > 0) { predictions.intersect(labels).size.toDouble / predictions.size @@ -79,6 +83,7 @@ class MultilabelMetrics(predictionAndLabels: RDD[(Array[Double], Array[Double])] /** * Returns document-based recall averaged by the number of documents */ + @Since("1.2.0") lazy val recall: Double = predictionAndLabels.map { case (predictions, labels) => labels.intersect(predictions).size.toDouble / labels.size }.sum / numDocs @@ -86,6 +91,7 @@ class MultilabelMetrics(predictionAndLabels: RDD[(Array[Double], Array[Double])] /** * Returns document-based f1-measure averaged by the number of documents */ + @Since("1.2.0") lazy val f1Measure: Double = predictionAndLabels.map { case (predictions, labels) => 2.0 * predictions.intersect(labels).size / (predictions.size + labels.size) }.sum / numDocs @@ -143,6 +149,7 @@ class MultilabelMetrics(predictionAndLabels: RDD[(Array[Double], Array[Double])] * Returns micro-averaged label-based precision * (equals to micro-averaged document-based precision) */ + @Since("1.2.0") lazy val microPrecision: Double = { val sumFp = fpPerClass.foldLeft(0L){ case(cum, (_, fp)) => cum + fp} sumTp.toDouble / (sumTp + sumFp) @@ -152,6 +159,7 @@ class MultilabelMetrics(predictionAndLabels: RDD[(Array[Double], Array[Double])] * Returns micro-averaged label-based recall * (equals to micro-averaged document-based recall) */ + @Since("1.2.0") lazy val microRecall: Double = { val sumFn = fnPerClass.foldLeft(0.0){ case(cum, (_, fn)) => cum + fn} sumTp.toDouble / (sumTp + sumFn) @@ -161,10 +169,12 @@ class MultilabelMetrics(predictionAndLabels: RDD[(Array[Double], Array[Double])] * Returns micro-averaged label-based f1-measure * (equals to micro-averaged document-based f1-measure) */ + @Since("1.2.0") lazy val microF1Measure: Double = 2.0 * sumTp / (2 * sumTp + sumFnClass + sumFpClass) /** * Returns the sequence of labels in ascending order */ + @Since("1.2.0") lazy val labels: Array[Double] = tpPerClass.keys.toArray.sorted } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala index 36a6c357c389..799ebb980ef0 100644 --- 
a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala @@ -32,7 +32,8 @@ import org.apache.spark.sql.DataFrame */ @Since("1.2.0") @Experimental -class RegressionMetrics(predictionAndObservations: RDD[(Double, Double)]) extends Logging { +class RegressionMetrics @Since("1.2.0") ( + predictionAndObservations: RDD[(Double, Double)]) extends Logging { /** * An auxiliary constructor taking a DataFrame. From ab431f8a970b85fba34ccb506c0f8815e55c63bf Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Tue, 25 Aug 2015 20:07:56 -0700 Subject: [PATCH 0423/1824] [SPARK-10238] [MLLIB] update since versions in mllib.linalg Same as #8421 but for `mllib.linalg`. cc dbtsai Author: Xiangrui Meng Closes #8440 from mengxr/SPARK-10238 and squashes the following commits: b38437e [Xiangrui Meng] update since versions in mllib.linalg --- .../apache/spark/mllib/linalg/Matrices.scala | 44 ++++++++++++------- .../linalg/SingularValueDecomposition.scala | 1 + .../apache/spark/mllib/linalg/Vectors.scala | 25 ++++++++--- .../linalg/distributed/BlockMatrix.scala | 10 +++-- .../linalg/distributed/CoordinateMatrix.scala | 4 +- .../distributed/DistributedMatrix.scala | 2 + .../linalg/distributed/IndexedRowMatrix.scala | 4 +- .../mllib/linalg/distributed/RowMatrix.scala | 5 ++- 8 files changed, 64 insertions(+), 31 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala index 28b5b4637bf1..c02ba426fcc3 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala @@ -32,18 +32,23 @@ import org.apache.spark.sql.types._ * Trait for a local matrix. */ @SQLUserDefinedType(udt = classOf[MatrixUDT]) +@Since("1.0.0") sealed trait Matrix extends Serializable { /** Number of rows. */ + @Since("1.0.0") def numRows: Int /** Number of columns. */ + @Since("1.0.0") def numCols: Int /** Flag that keeps track whether the matrix is transposed or not. False by default. */ + @Since("1.3.0") val isTransposed: Boolean = false /** Converts to a dense array in column major. */ + @Since("1.0.0") def toArray: Array[Double] = { val newArray = new Array[Double](numRows * numCols) foreachActive { (i, j, v) => @@ -56,6 +61,7 @@ sealed trait Matrix extends Serializable { private[mllib] def toBreeze: BM[Double] /** Gets the (i, j)-th element. */ + @Since("1.3.0") def apply(i: Int, j: Int): Double /** Return the index for the (i, j)-th element in the backing array. */ @@ -65,12 +71,15 @@ sealed trait Matrix extends Serializable { private[mllib] def update(i: Int, j: Int, v: Double): Unit /** Get a deep copy of the matrix. */ + @Since("1.2.0") def copy: Matrix /** Transpose the Matrix. Returns a new `Matrix` instance sharing the same underlying data. */ + @Since("1.3.0") def transpose: Matrix /** Convenience method for `Matrix`-`DenseMatrix` multiplication. */ + @Since("1.2.0") def multiply(y: DenseMatrix): DenseMatrix = { val C: DenseMatrix = DenseMatrix.zeros(numRows, y.numCols) BLAS.gemm(1.0, this, y, 0.0, C) @@ -78,11 +87,13 @@ sealed trait Matrix extends Serializable { } /** Convenience method for `Matrix`-`DenseVector` multiplication. For binary compatibility. */ + @Since("1.2.0") def multiply(y: DenseVector): DenseVector = { multiply(y.asInstanceOf[Vector]) } /** Convenience method for `Matrix`-`Vector` multiplication. 
*/ + @Since("1.4.0") def multiply(y: Vector): DenseVector = { val output = new DenseVector(new Array[Double](numRows)) BLAS.gemv(1.0, this, y, 0.0, output) @@ -93,6 +104,7 @@ sealed trait Matrix extends Serializable { override def toString: String = toBreeze.toString() /** A human readable representation of the matrix with maximum lines and width */ + @Since("1.4.0") def toString(maxLines: Int, maxLineWidth: Int): String = toBreeze.toString(maxLines, maxLineWidth) /** Map the values of this matrix using a function. Generates a new matrix. Performs the @@ -118,11 +130,13 @@ sealed trait Matrix extends Serializable { /** * Find the number of non-zero active values. */ + @Since("1.5.0") def numNonzeros: Int /** * Find the number of values stored explicitly. These values can be zero as well. */ + @Since("1.5.0") def numActives: Int } @@ -230,11 +244,11 @@ private[spark] class MatrixUDT extends UserDefinedType[Matrix] { */ @Since("1.0.0") @SQLUserDefinedType(udt = classOf[MatrixUDT]) -class DenseMatrix( - val numRows: Int, - val numCols: Int, - val values: Array[Double], - override val isTransposed: Boolean) extends Matrix { +class DenseMatrix @Since("1.3.0") ( + @Since("1.0.0") val numRows: Int, + @Since("1.0.0") val numCols: Int, + @Since("1.0.0") val values: Array[Double], + @Since("1.3.0") override val isTransposed: Boolean) extends Matrix { require(values.length == numRows * numCols, "The number of values supplied doesn't match the " + s"size of the matrix! values.length: ${values.length}, numRows * numCols: ${numRows * numCols}") @@ -254,7 +268,7 @@ class DenseMatrix( * @param numCols number of columns * @param values matrix entries in column major */ - @Since("1.3.0") + @Since("1.0.0") def this(numRows: Int, numCols: Int, values: Array[Double]) = this(numRows, numCols, values, false) @@ -491,13 +505,13 @@ object DenseMatrix { */ @Since("1.2.0") @SQLUserDefinedType(udt = classOf[MatrixUDT]) -class SparseMatrix( - val numRows: Int, - val numCols: Int, - val colPtrs: Array[Int], - val rowIndices: Array[Int], - val values: Array[Double], - override val isTransposed: Boolean) extends Matrix { +class SparseMatrix @Since("1.3.0") ( + @Since("1.2.0") val numRows: Int, + @Since("1.2.0") val numCols: Int, + @Since("1.2.0") val colPtrs: Array[Int], + @Since("1.2.0") val rowIndices: Array[Int], + @Since("1.2.0") val values: Array[Double], + @Since("1.3.0") override val isTransposed: Boolean) extends Matrix { require(values.length == rowIndices.length, "The number of row indices and values don't match! " + s"values.length: ${values.length}, rowIndices.length: ${rowIndices.length}") @@ -527,7 +541,7 @@ class SparseMatrix( * order for each column * @param values non-zero matrix entries in column major */ - @Since("1.3.0") + @Since("1.2.0") def this( numRows: Int, numCols: Int, @@ -549,8 +563,6 @@ class SparseMatrix( } } - /** - */ @Since("1.3.0") override def apply(i: Int, j: Int): Double = { val ind = index(i, j) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/SingularValueDecomposition.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/SingularValueDecomposition.scala index a37aca99d5e7..4dcf8f28f202 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/SingularValueDecomposition.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/SingularValueDecomposition.scala @@ -31,6 +31,7 @@ case class SingularValueDecomposition[UType, VType](U: UType, s: Vector, V: VTyp * :: Experimental :: * Represents QR factors. 
*/ +@Since("1.5.0") @Experimental case class QRDecomposition[QType, RType](Q: QType, R: RType) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index 3d577edbe23e..06ebb1586990 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -38,16 +38,19 @@ import org.apache.spark.sql.types._ * Note: Users should not implement this interface. */ @SQLUserDefinedType(udt = classOf[VectorUDT]) +@Since("1.0.0") sealed trait Vector extends Serializable { /** * Size of the vector. */ + @Since("1.0.0") def size: Int /** * Converts the instance to a double array. */ + @Since("1.0.0") def toArray: Array[Double] override def equals(other: Any): Boolean = { @@ -99,11 +102,13 @@ sealed trait Vector extends Serializable { * Gets the value of the ith element. * @param i index */ + @Since("1.1.0") def apply(i: Int): Double = toBreeze(i) /** * Makes a deep copy of this vector. */ + @Since("1.1.0") def copy: Vector = { throw new NotImplementedError(s"copy is not implemented for ${this.getClass}.") } @@ -121,26 +126,31 @@ sealed trait Vector extends Serializable { * Number of active entries. An "active entry" is an element which is explicitly stored, * regardless of its value. Note that inactive entries have value 0. */ + @Since("1.4.0") def numActives: Int /** * Number of nonzero elements. This scans all active values and count nonzeros. */ + @Since("1.4.0") def numNonzeros: Int /** * Converts this vector to a sparse vector with all explicit zeros removed. */ + @Since("1.4.0") def toSparse: SparseVector /** * Converts this vector to a dense vector. */ + @Since("1.4.0") def toDense: DenseVector = new DenseVector(this.toArray) /** * Returns a vector in either dense or sparse format, whichever uses less storage. */ + @Since("1.4.0") def compressed: Vector = { val nnz = numNonzeros // A dense vector needs 8 * size + 8 bytes, while a sparse vector needs 12 * nnz + 20 bytes. @@ -155,6 +165,7 @@ sealed trait Vector extends Serializable { * Find the index of a maximal element. Returns the first maximal element in case of a tie. * Returns -1 if vector has length 0. */ + @Since("1.5.0") def argmax: Int } @@ -532,7 +543,8 @@ object Vectors { */ @Since("1.0.0") @SQLUserDefinedType(udt = classOf[VectorUDT]) -class DenseVector(val values: Array[Double]) extends Vector { +class DenseVector @Since("1.0.0") ( + @Since("1.0.0") val values: Array[Double]) extends Vector { @Since("1.0.0") override def size: Int = values.length @@ -632,7 +644,9 @@ class DenseVector(val values: Array[Double]) extends Vector { @Since("1.3.0") object DenseVector { + /** Extracts the value array from a dense vector. */ + @Since("1.3.0") def unapply(dv: DenseVector): Option[Array[Double]] = Some(dv.values) } @@ -645,10 +659,10 @@ object DenseVector { */ @Since("1.0.0") @SQLUserDefinedType(udt = classOf[VectorUDT]) -class SparseVector( - override val size: Int, - val indices: Array[Int], - val values: Array[Double]) extends Vector { +class SparseVector @Since("1.0.0") ( + @Since("1.0.0") override val size: Int, + @Since("1.0.0") val indices: Array[Int], + @Since("1.0.0") val values: Array[Double]) extends Vector { require(indices.length == values.length, "Sparse vectors require that the dimension of the" + s" indices match the dimension of the values. 
You provided ${indices.length} indices and " + @@ -819,6 +833,7 @@ class SparseVector( @Since("1.3.0") object SparseVector { + @Since("1.3.0") def unapply(sv: SparseVector): Option[(Int, Array[Int], Array[Double])] = Some((sv.size, sv.indices, sv.values)) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala index 94376c24a7ac..a33b6137cf9c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala @@ -131,10 +131,10 @@ private[mllib] object GridPartitioner { */ @Since("1.3.0") @Experimental -class BlockMatrix( - val blocks: RDD[((Int, Int), Matrix)], - val rowsPerBlock: Int, - val colsPerBlock: Int, +class BlockMatrix @Since("1.3.0") ( + @Since("1.3.0") val blocks: RDD[((Int, Int), Matrix)], + @Since("1.3.0") val rowsPerBlock: Int, + @Since("1.3.0") val colsPerBlock: Int, private var nRows: Long, private var nCols: Long) extends DistributedMatrix with Logging { @@ -171,7 +171,9 @@ class BlockMatrix( nCols } + @Since("1.3.0") val numRowBlocks = math.ceil(numRows() * 1.0 / rowsPerBlock).toInt + @Since("1.3.0") val numColBlocks = math.ceil(numCols() * 1.0 / colsPerBlock).toInt private[mllib] def createPartitioner(): GridPartitioner = diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrix.scala index 4bb27ec84090..644f293d88a7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrix.scala @@ -46,8 +46,8 @@ case class MatrixEntry(i: Long, j: Long, value: Double) */ @Since("1.0.0") @Experimental -class CoordinateMatrix( - val entries: RDD[MatrixEntry], +class CoordinateMatrix @Since("1.0.0") ( + @Since("1.0.0") val entries: RDD[MatrixEntry], private var nRows: Long, private var nCols: Long) extends DistributedMatrix { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/DistributedMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/DistributedMatrix.scala index e51327ebb7b5..db3433a5e245 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/DistributedMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/DistributedMatrix.scala @@ -28,9 +28,11 @@ import org.apache.spark.annotation.Since trait DistributedMatrix extends Serializable { /** Gets or computes the number of rows. */ + @Since("1.0.0") def numRows(): Long /** Gets or computes the number of columns. */ + @Since("1.0.0") def numCols(): Long /** Collects data and assembles a local dense breeze matrix (for test only). 
*/ diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala index 6d2c05a47d04..b20ea0dc50da 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala @@ -45,8 +45,8 @@ case class IndexedRow(index: Long, vector: Vector) */ @Since("1.0.0") @Experimental -class IndexedRowMatrix( - val rows: RDD[IndexedRow], +class IndexedRowMatrix @Since("1.0.0") ( + @Since("1.0.0") val rows: RDD[IndexedRow], private var nRows: Long, private var nCols: Int) extends DistributedMatrix { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala index 78036eba5c3e..9a423ddafdc0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala @@ -47,8 +47,8 @@ import org.apache.spark.storage.StorageLevel */ @Since("1.0.0") @Experimental -class RowMatrix( - val rows: RDD[Vector], +class RowMatrix @Since("1.0.0") ( + @Since("1.0.0") val rows: RDD[Vector], private var nRows: Long, private var nCols: Int) extends DistributedMatrix with Logging { @@ -519,6 +519,7 @@ class RowMatrix( * @param computeQ whether to computeQ * @return QRDecomposition(Q, R), Q = null if computeQ = false. */ + @Since("1.5.0") def tallSkinnyQR(computeQ: Boolean = false): QRDecomposition[RowMatrix, Matrix] = { val col = numCols().toInt // split rows horizontally into smaller matrices, and compute QR for each of them From c3a54843c0c8a14059da4e6716c1ad45c69bbe6c Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Tue, 25 Aug 2015 22:31:23 -0700 Subject: [PATCH 0424/1824] [SPARK-10240] [SPARK-10242] [MLLIB] update since versions in mlilb.random and mllib.stat The same as #8241 but for `mllib.stat` and `mllib.random`. cc feynmanliang Author: Xiangrui Meng Closes #8439 from mengxr/SPARK-10242. --- .../mllib/random/RandomDataGenerator.scala | 43 ++++++++++-- .../spark/mllib/random/RandomRDDs.scala | 69 ++++++++++++++++--- .../distribution/MultivariateGaussian.scala | 6 +- .../spark/mllib/stat/test/TestResult.scala | 24 ++++--- 4 files changed, 117 insertions(+), 25 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala index 9349ecaa13f5..a2d85a68cd32 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala @@ -20,7 +20,7 @@ package org.apache.spark.mllib.random import org.apache.commons.math3.distribution.{ExponentialDistribution, GammaDistribution, LogNormalDistribution, PoissonDistribution} -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.{Since, DeveloperApi} import org.apache.spark.util.random.{XORShiftRandom, Pseudorandom} /** @@ -28,17 +28,20 @@ import org.apache.spark.util.random.{XORShiftRandom, Pseudorandom} * Trait for random data generators that generate i.i.d. data. */ @DeveloperApi +@Since("1.1.0") trait RandomDataGenerator[T] extends Pseudorandom with Serializable { /** * Returns an i.i.d. sample as a generic type from an underlying distribution. 
*/ + @Since("1.1.0") def nextValue(): T /** * Returns a copy of the RandomDataGenerator with a new instance of the rng object used in the * class when applicable for non-locking concurrent usage. */ + @Since("1.1.0") def copy(): RandomDataGenerator[T] } @@ -47,17 +50,21 @@ trait RandomDataGenerator[T] extends Pseudorandom with Serializable { * Generates i.i.d. samples from U[0.0, 1.0] */ @DeveloperApi +@Since("1.1.0") class UniformGenerator extends RandomDataGenerator[Double] { // XORShiftRandom for better performance. Thread safety isn't necessary here. private val random = new XORShiftRandom() + @Since("1.1.0") override def nextValue(): Double = { random.nextDouble() } + @Since("1.1.0") override def setSeed(seed: Long): Unit = random.setSeed(seed) + @Since("1.1.0") override def copy(): UniformGenerator = new UniformGenerator() } @@ -66,17 +73,21 @@ class UniformGenerator extends RandomDataGenerator[Double] { * Generates i.i.d. samples from the standard normal distribution. */ @DeveloperApi +@Since("1.1.0") class StandardNormalGenerator extends RandomDataGenerator[Double] { // XORShiftRandom for better performance. Thread safety isn't necessary here. private val random = new XORShiftRandom() + @Since("1.1.0") override def nextValue(): Double = { random.nextGaussian() } + @Since("1.1.0") override def setSeed(seed: Long): Unit = random.setSeed(seed) + @Since("1.1.0") override def copy(): StandardNormalGenerator = new StandardNormalGenerator() } @@ -87,16 +98,21 @@ class StandardNormalGenerator extends RandomDataGenerator[Double] { * @param mean mean for the Poisson distribution. */ @DeveloperApi -class PoissonGenerator(val mean: Double) extends RandomDataGenerator[Double] { +@Since("1.1.0") +class PoissonGenerator @Since("1.1.0") ( + @Since("1.1.0") val mean: Double) extends RandomDataGenerator[Double] { private val rng = new PoissonDistribution(mean) + @Since("1.1.0") override def nextValue(): Double = rng.sample() + @Since("1.1.0") override def setSeed(seed: Long) { rng.reseedRandomGenerator(seed) } + @Since("1.1.0") override def copy(): PoissonGenerator = new PoissonGenerator(mean) } @@ -107,16 +123,21 @@ class PoissonGenerator(val mean: Double) extends RandomDataGenerator[Double] { * @param mean mean for the exponential distribution. 
*/ @DeveloperApi -class ExponentialGenerator(val mean: Double) extends RandomDataGenerator[Double] { +@Since("1.3.0") +class ExponentialGenerator @Since("1.3.0") ( + @Since("1.3.0") val mean: Double) extends RandomDataGenerator[Double] { private val rng = new ExponentialDistribution(mean) + @Since("1.3.0") override def nextValue(): Double = rng.sample() + @Since("1.3.0") override def setSeed(seed: Long) { rng.reseedRandomGenerator(seed) } + @Since("1.3.0") override def copy(): ExponentialGenerator = new ExponentialGenerator(mean) } @@ -128,16 +149,22 @@ class ExponentialGenerator(val mean: Double) extends RandomDataGenerator[Double] * @param scale scale for the gamma distribution */ @DeveloperApi -class GammaGenerator(val shape: Double, val scale: Double) extends RandomDataGenerator[Double] { +@Since("1.3.0") +class GammaGenerator @Since("1.3.0") ( + @Since("1.3.0") val shape: Double, + @Since("1.3.0") val scale: Double) extends RandomDataGenerator[Double] { private val rng = new GammaDistribution(shape, scale) + @Since("1.3.0") override def nextValue(): Double = rng.sample() + @Since("1.3.0") override def setSeed(seed: Long) { rng.reseedRandomGenerator(seed) } + @Since("1.3.0") override def copy(): GammaGenerator = new GammaGenerator(shape, scale) } @@ -150,15 +177,21 @@ class GammaGenerator(val shape: Double, val scale: Double) extends RandomDataGen * @param std standard deviation for the log normal distribution */ @DeveloperApi -class LogNormalGenerator(val mean: Double, val std: Double) extends RandomDataGenerator[Double] { +@Since("1.3.0") +class LogNormalGenerator @Since("1.3.0") ( + @Since("1.3.0") val mean: Double, + @Since("1.3.0") val std: Double) extends RandomDataGenerator[Double] { private val rng = new LogNormalDistribution(mean, std) + @Since("1.3.0") override def nextValue(): Double = rng.sample() + @Since("1.3.0") override def setSeed(seed: Long) { rng.reseedRandomGenerator(seed) } + @Since("1.3.0") override def copy(): LogNormalGenerator = new LogNormalGenerator(mean, std) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala index 174d5e0f6c9f..4dd5ea214d67 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala @@ -20,7 +20,7 @@ package org.apache.spark.mllib.random import scala.reflect.ClassTag import org.apache.spark.SparkContext -import org.apache.spark.annotation.{DeveloperApi, Experimental} +import org.apache.spark.annotation.{DeveloperApi, Experimental, Since} import org.apache.spark.api.java.{JavaDoubleRDD, JavaRDD, JavaSparkContext} import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.rdd.{RandomRDD, RandomVectorRDD} @@ -32,6 +32,7 @@ import org.apache.spark.util.Utils * Generator methods for creating RDDs comprised of `i.i.d.` samples from some distribution. */ @Experimental +@Since("1.1.0") object RandomRDDs { /** @@ -46,6 +47,7 @@ object RandomRDDs { * @param seed Random seed (default: a random long integer). * @return RDD[Double] comprised of `i.i.d.` samples ~ `U(0.0, 1.0)`. */ + @Since("1.1.0") def uniformRDD( sc: SparkContext, size: Long, @@ -58,6 +60,7 @@ object RandomRDDs { /** * Java-friendly version of [[RandomRDDs#uniformRDD]]. */ + @Since("1.1.0") def uniformJavaRDD( jsc: JavaSparkContext, size: Long, @@ -69,6 +72,7 @@ object RandomRDDs { /** * [[RandomRDDs#uniformJavaRDD]] with the default seed. 
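* Example (an editorial illustration, not present in the original patch; it assumes a
* live `JavaSparkContext` named `jsc`):
* {{{
*   import org.apache.spark.mllib.random.RandomRDDs
*   // one million i.i.d. samples from U(0.0, 1.0), spread across 10 partitions
*   val u = RandomRDDs.uniformJavaRDD(jsc, 1000000L, 10)
*   println(u.mean())
* }}}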
*/ + @Since("1.1.0") def uniformJavaRDD(jsc: JavaSparkContext, size: Long, numPartitions: Int): JavaDoubleRDD = { JavaDoubleRDD.fromRDD(uniformRDD(jsc.sc, size, numPartitions)) } @@ -76,6 +80,7 @@ object RandomRDDs { /** * [[RandomRDDs#uniformJavaRDD]] with the default number of partitions and the default seed. */ + @Since("1.1.0") def uniformJavaRDD(jsc: JavaSparkContext, size: Long): JavaDoubleRDD = { JavaDoubleRDD.fromRDD(uniformRDD(jsc.sc, size)) } @@ -92,6 +97,7 @@ object RandomRDDs { * @param seed Random seed (default: a random long integer). * @return RDD[Double] comprised of `i.i.d.` samples ~ N(0.0, 1.0). */ + @Since("1.1.0") def normalRDD( sc: SparkContext, size: Long, @@ -104,6 +110,7 @@ object RandomRDDs { /** * Java-friendly version of [[RandomRDDs#normalRDD]]. */ + @Since("1.1.0") def normalJavaRDD( jsc: JavaSparkContext, size: Long, @@ -115,6 +122,7 @@ object RandomRDDs { /** * [[RandomRDDs#normalJavaRDD]] with the default seed. */ + @Since("1.1.0") def normalJavaRDD(jsc: JavaSparkContext, size: Long, numPartitions: Int): JavaDoubleRDD = { JavaDoubleRDD.fromRDD(normalRDD(jsc.sc, size, numPartitions)) } @@ -122,6 +130,7 @@ object RandomRDDs { /** * [[RandomRDDs#normalJavaRDD]] with the default number of partitions and the default seed. */ + @Since("1.1.0") def normalJavaRDD(jsc: JavaSparkContext, size: Long): JavaDoubleRDD = { JavaDoubleRDD.fromRDD(normalRDD(jsc.sc, size)) } @@ -137,6 +146,7 @@ object RandomRDDs { * @param seed Random seed (default: a random long integer). * @return RDD[Double] comprised of `i.i.d.` samples ~ Pois(mean). */ + @Since("1.1.0") def poissonRDD( sc: SparkContext, mean: Double, @@ -150,6 +160,7 @@ object RandomRDDs { /** * Java-friendly version of [[RandomRDDs#poissonRDD]]. */ + @Since("1.1.0") def poissonJavaRDD( jsc: JavaSparkContext, mean: Double, @@ -162,6 +173,7 @@ object RandomRDDs { /** * [[RandomRDDs#poissonJavaRDD]] with the default seed. */ + @Since("1.1.0") def poissonJavaRDD( jsc: JavaSparkContext, mean: Double, @@ -173,6 +185,7 @@ object RandomRDDs { /** * [[RandomRDDs#poissonJavaRDD]] with the default number of partitions and the default seed. */ + @Since("1.1.0") def poissonJavaRDD(jsc: JavaSparkContext, mean: Double, size: Long): JavaDoubleRDD = { JavaDoubleRDD.fromRDD(poissonRDD(jsc.sc, mean, size)) } @@ -188,6 +201,7 @@ object RandomRDDs { * @param seed Random seed (default: a random long integer). * @return RDD[Double] comprised of `i.i.d.` samples ~ Pois(mean). */ + @Since("1.3.0") def exponentialRDD( sc: SparkContext, mean: Double, @@ -201,6 +215,7 @@ object RandomRDDs { /** * Java-friendly version of [[RandomRDDs#exponentialRDD]]. */ + @Since("1.3.0") def exponentialJavaRDD( jsc: JavaSparkContext, mean: Double, @@ -213,6 +228,7 @@ object RandomRDDs { /** * [[RandomRDDs#exponentialJavaRDD]] with the default seed. */ + @Since("1.3.0") def exponentialJavaRDD( jsc: JavaSparkContext, mean: Double, @@ -224,6 +240,7 @@ object RandomRDDs { /** * [[RandomRDDs#exponentialJavaRDD]] with the default number of partitions and the default seed. */ + @Since("1.3.0") def exponentialJavaRDD(jsc: JavaSparkContext, mean: Double, size: Long): JavaDoubleRDD = { JavaDoubleRDD.fromRDD(exponentialRDD(jsc.sc, mean, size)) } @@ -240,6 +257,7 @@ object RandomRDDs { * @param seed Random seed (default: a random long integer). * @return RDD[Double] comprised of `i.i.d.` samples ~ Pois(mean). */ + @Since("1.3.0") def gammaRDD( sc: SparkContext, shape: Double, @@ -254,6 +272,7 @@ object RandomRDDs { /** * Java-friendly version of [[RandomRDDs#gammaRDD]]. 
*/ + @Since("1.3.0") def gammaJavaRDD( jsc: JavaSparkContext, shape: Double, @@ -267,6 +286,7 @@ object RandomRDDs { /** * [[RandomRDDs#gammaJavaRDD]] with the default seed. */ + @Since("1.3.0") def gammaJavaRDD( jsc: JavaSparkContext, shape: Double, @@ -279,11 +299,12 @@ object RandomRDDs { /** * [[RandomRDDs#gammaJavaRDD]] with the default number of partitions and the default seed. */ + @Since("1.3.0") def gammaJavaRDD( - jsc: JavaSparkContext, - shape: Double, - scale: Double, - size: Long): JavaDoubleRDD = { + jsc: JavaSparkContext, + shape: Double, + scale: Double, + size: Long): JavaDoubleRDD = { JavaDoubleRDD.fromRDD(gammaRDD(jsc.sc, shape, scale, size)) } @@ -299,6 +320,7 @@ object RandomRDDs { * @param seed Random seed (default: a random long integer). * @return RDD[Double] comprised of `i.i.d.` samples ~ Pois(mean). */ + @Since("1.3.0") def logNormalRDD( sc: SparkContext, mean: Double, @@ -313,6 +335,7 @@ object RandomRDDs { /** * Java-friendly version of [[RandomRDDs#logNormalRDD]]. */ + @Since("1.3.0") def logNormalJavaRDD( jsc: JavaSparkContext, mean: Double, @@ -326,6 +349,7 @@ object RandomRDDs { /** * [[RandomRDDs#logNormalJavaRDD]] with the default seed. */ + @Since("1.3.0") def logNormalJavaRDD( jsc: JavaSparkContext, mean: Double, @@ -338,11 +362,12 @@ object RandomRDDs { /** * [[RandomRDDs#logNormalJavaRDD]] with the default number of partitions and the default seed. */ + @Since("1.3.0") def logNormalJavaRDD( - jsc: JavaSparkContext, - mean: Double, - std: Double, - size: Long): JavaDoubleRDD = { + jsc: JavaSparkContext, + mean: Double, + std: Double, + size: Long): JavaDoubleRDD = { JavaDoubleRDD.fromRDD(logNormalRDD(jsc.sc, mean, std, size)) } @@ -359,6 +384,7 @@ object RandomRDDs { * @return RDD[Double] comprised of `i.i.d.` samples produced by generator. */ @DeveloperApi + @Since("1.1.0") def randomRDD[T: ClassTag]( sc: SparkContext, generator: RandomDataGenerator[T], @@ -381,6 +407,7 @@ object RandomRDDs { * @param seed Seed for the RNG that generates the seed for the generator in each partition. * @return RDD[Vector] with vectors containing i.i.d samples ~ `U(0.0, 1.0)`. */ + @Since("1.1.0") def uniformVectorRDD( sc: SparkContext, numRows: Long, @@ -394,6 +421,7 @@ object RandomRDDs { /** * Java-friendly version of [[RandomRDDs#uniformVectorRDD]]. */ + @Since("1.1.0") def uniformJavaVectorRDD( jsc: JavaSparkContext, numRows: Long, @@ -406,6 +434,7 @@ object RandomRDDs { /** * [[RandomRDDs#uniformJavaVectorRDD]] with the default seed. */ + @Since("1.1.0") def uniformJavaVectorRDD( jsc: JavaSparkContext, numRows: Long, @@ -417,6 +446,7 @@ object RandomRDDs { /** * [[RandomRDDs#uniformJavaVectorRDD]] with the default number of partitions and the default seed. */ + @Since("1.1.0") def uniformJavaVectorRDD( jsc: JavaSparkContext, numRows: Long, @@ -435,6 +465,7 @@ object RandomRDDs { * @param seed Random seed (default: a random long integer). * @return RDD[Vector] with vectors containing `i.i.d.` samples ~ `N(0.0, 1.0)`. */ + @Since("1.1.0") def normalVectorRDD( sc: SparkContext, numRows: Long, @@ -448,6 +479,7 @@ object RandomRDDs { /** * Java-friendly version of [[RandomRDDs#normalVectorRDD]]. */ + @Since("1.1.0") def normalJavaVectorRDD( jsc: JavaSparkContext, numRows: Long, @@ -460,6 +492,7 @@ object RandomRDDs { /** * [[RandomRDDs#normalJavaVectorRDD]] with the default seed. 
*/ + @Since("1.1.0") def normalJavaVectorRDD( jsc: JavaSparkContext, numRows: Long, @@ -471,6 +504,7 @@ object RandomRDDs { /** * [[RandomRDDs#normalJavaVectorRDD]] with the default number of partitions and the default seed. */ + @Since("1.1.0") def normalJavaVectorRDD( jsc: JavaSparkContext, numRows: Long, @@ -491,6 +525,7 @@ object RandomRDDs { * @param seed Random seed (default: a random long integer). * @return RDD[Vector] with vectors containing `i.i.d.` samples. */ + @Since("1.3.0") def logNormalVectorRDD( sc: SparkContext, mean: Double, @@ -507,6 +542,7 @@ object RandomRDDs { /** * Java-friendly version of [[RandomRDDs#logNormalVectorRDD]]. */ + @Since("1.3.0") def logNormalJavaVectorRDD( jsc: JavaSparkContext, mean: Double, @@ -521,6 +557,7 @@ object RandomRDDs { /** * [[RandomRDDs#logNormalJavaVectorRDD]] with the default seed. */ + @Since("1.3.0") def logNormalJavaVectorRDD( jsc: JavaSparkContext, mean: Double, @@ -535,6 +572,7 @@ object RandomRDDs { * [[RandomRDDs#logNormalJavaVectorRDD]] with the default number of partitions and * the default seed. */ + @Since("1.3.0") def logNormalJavaVectorRDD( jsc: JavaSparkContext, mean: Double, @@ -556,6 +594,7 @@ object RandomRDDs { * @param seed Random seed (default: a random long integer). * @return RDD[Vector] with vectors containing `i.i.d.` samples ~ Pois(mean). */ + @Since("1.1.0") def poissonVectorRDD( sc: SparkContext, mean: Double, @@ -570,6 +609,7 @@ object RandomRDDs { /** * Java-friendly version of [[RandomRDDs#poissonVectorRDD]]. */ + @Since("1.1.0") def poissonJavaVectorRDD( jsc: JavaSparkContext, mean: Double, @@ -583,6 +623,7 @@ object RandomRDDs { /** * [[RandomRDDs#poissonJavaVectorRDD]] with the default seed. */ + @Since("1.1.0") def poissonJavaVectorRDD( jsc: JavaSparkContext, mean: Double, @@ -595,6 +636,7 @@ object RandomRDDs { /** * [[RandomRDDs#poissonJavaVectorRDD]] with the default number of partitions and the default seed. */ + @Since("1.1.0") def poissonJavaVectorRDD( jsc: JavaSparkContext, mean: Double, @@ -615,6 +657,7 @@ object RandomRDDs { * @param seed Random seed (default: a random long integer). * @return RDD[Vector] with vectors containing `i.i.d.` samples ~ Exp(mean). */ + @Since("1.3.0") def exponentialVectorRDD( sc: SparkContext, mean: Double, @@ -630,6 +673,7 @@ object RandomRDDs { /** * Java-friendly version of [[RandomRDDs#exponentialVectorRDD]]. */ + @Since("1.3.0") def exponentialJavaVectorRDD( jsc: JavaSparkContext, mean: Double, @@ -643,6 +687,7 @@ object RandomRDDs { /** * [[RandomRDDs#exponentialJavaVectorRDD]] with the default seed. */ + @Since("1.3.0") def exponentialJavaVectorRDD( jsc: JavaSparkContext, mean: Double, @@ -656,6 +701,7 @@ object RandomRDDs { * [[RandomRDDs#exponentialJavaVectorRDD]] with the default number of partitions * and the default seed. */ + @Since("1.3.0") def exponentialJavaVectorRDD( jsc: JavaSparkContext, mean: Double, @@ -678,6 +724,7 @@ object RandomRDDs { * @param seed Random seed (default: a random long integer). * @return RDD[Vector] with vectors containing `i.i.d.` samples ~ Exp(mean). */ + @Since("1.3.0") def gammaVectorRDD( sc: SparkContext, shape: Double, @@ -693,6 +740,7 @@ object RandomRDDs { /** * Java-friendly version of [[RandomRDDs#gammaVectorRDD]]. */ + @Since("1.3.0") def gammaJavaVectorRDD( jsc: JavaSparkContext, shape: Double, @@ -707,6 +755,7 @@ object RandomRDDs { /** * [[RandomRDDs#gammaJavaVectorRDD]] with the default seed. 
*/ + @Since("1.3.0") def gammaJavaVectorRDD( jsc: JavaSparkContext, shape: Double, @@ -720,6 +769,7 @@ object RandomRDDs { /** * [[RandomRDDs#gammaJavaVectorRDD]] with the default number of partitions and the default seed. */ + @Since("1.3.0") def gammaJavaVectorRDD( jsc: JavaSparkContext, shape: Double, @@ -744,6 +794,7 @@ object RandomRDDs { * @return RDD[Vector] with vectors containing `i.i.d.` samples produced by generator. */ @DeveloperApi + @Since("1.1.0") def randomVectorRDD(sc: SparkContext, generator: RandomDataGenerator[Double], numRows: Long, diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala index bd4d81390bfa..92a5af708d04 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala @@ -35,9 +35,9 @@ import org.apache.spark.mllib.util.MLUtils */ @Since("1.3.0") @DeveloperApi -class MultivariateGaussian ( - val mu: Vector, - val sigma: Matrix) extends Serializable { +class MultivariateGaussian @Since("1.3.0") ( + @Since("1.3.0") val mu: Vector, + @Since("1.3.0") val sigma: Matrix) extends Serializable { require(sigma.numCols == sigma.numRows, "Covariance matrix must be square") require(mu.size == sigma.numCols, "Mean vector length must match covariance matrix size") diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/TestResult.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/TestResult.scala index f44be1370669..d01b3707be94 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/TestResult.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/TestResult.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.stat.test -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} /** * :: Experimental :: @@ -25,28 +25,33 @@ import org.apache.spark.annotation.Experimental * @tparam DF Return type of `degreesOfFreedom`. */ @Experimental +@Since("1.1.0") trait TestResult[DF] { /** * The probability of obtaining a test statistic result at least as extreme as the one that was * actually observed, assuming that the null hypothesis is true. */ + @Since("1.1.0") def pValue: Double /** * Returns the degree(s) of freedom of the hypothesis test. * Return type should be Number(e.g. Int, Double) or tuples of Numbers for toString compatibility. */ + @Since("1.1.0") def degreesOfFreedom: DF /** * Test statistic. */ + @Since("1.1.0") def statistic: Double /** * Null hypothesis of the test. */ + @Since("1.1.0") def nullHypothesis: String /** @@ -78,11 +83,12 @@ trait TestResult[DF] { * Object containing the test results for the chi-squared hypothesis test. 
*/ @Experimental +@Since("1.1.0") class ChiSqTestResult private[stat] (override val pValue: Double, - override val degreesOfFreedom: Int, - override val statistic: Double, - val method: String, - override val nullHypothesis: String) extends TestResult[Int] { + @Since("1.1.0") override val degreesOfFreedom: Int, + @Since("1.1.0") override val statistic: Double, + @Since("1.1.0") val method: String, + @Since("1.1.0") override val nullHypothesis: String) extends TestResult[Int] { override def toString: String = { "Chi squared test summary:\n" + @@ -96,11 +102,13 @@ class ChiSqTestResult private[stat] (override val pValue: Double, * Object containing the test results for the Kolmogorov-Smirnov test. */ @Experimental +@Since("1.5.0") class KolmogorovSmirnovTestResult private[stat] ( - override val pValue: Double, - override val statistic: Double, - override val nullHypothesis: String) extends TestResult[Int] { + @Since("1.5.0") override val pValue: Double, + @Since("1.5.0") override val statistic: Double, + @Since("1.5.0") override val nullHypothesis: String) extends TestResult[Int] { + @Since("1.5.0") override val degreesOfFreedom = 0 override def toString: String = { From d703372f86d6a59383ba8569fcd9d379849cffbf Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Tue, 25 Aug 2015 22:33:48 -0700 Subject: [PATCH 0425/1824] [SPARK-10234] [MLLIB] update since version in mllib.clustering Same as #8421 but for `mllib.clustering`. cc feynmanliang yu-iskw Author: Xiangrui Meng Closes #8435 from mengxr/SPARK-10234. --- .../mllib/clustering/GaussianMixture.scala | 1 + .../clustering/GaussianMixtureModel.scala | 8 +++--- .../spark/mllib/clustering/KMeans.scala | 1 + .../spark/mllib/clustering/KMeansModel.scala | 4 +-- .../spark/mllib/clustering/LDAModel.scala | 28 ++++++++++++++----- .../clustering/PowerIterationClustering.scala | 10 +++++-- .../mllib/clustering/StreamingKMeans.scala | 15 +++++----- 7 files changed, 44 insertions(+), 23 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala index daa947e81d44..f82bd82c2037 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala @@ -53,6 +53,7 @@ import org.apache.spark.util.Utils * @param maxIterations The maximum number of iterations to perform */ @Experimental +@Since("1.3.0") class GaussianMixture private ( private var k: Int, private var convergenceTol: Double, diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala index 1a10a8b62421..7f6163e04bf1 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala @@ -46,9 +46,9 @@ import org.apache.spark.sql.{SQLContext, Row} */ @Since("1.3.0") @Experimental -class GaussianMixtureModel( - val weights: Array[Double], - val gaussians: Array[MultivariateGaussian]) extends Serializable with Saveable { +class GaussianMixtureModel @Since("1.3.0") ( + @Since("1.3.0") val weights: Array[Double], + @Since("1.3.0") val gaussians: Array[MultivariateGaussian]) extends Serializable with Saveable { require(weights.length == gaussians.length, "Length of weight and Gaussian arrays must match") @@ -178,7 +178,7 @@ object GaussianMixtureModel 
extends Loader[GaussianMixtureModel] { (weight, new MultivariateGaussian(mu, sigma)) }.unzip - return new GaussianMixtureModel(weights.toArray, gaussians.toArray) + new GaussianMixtureModel(weights.toArray, gaussians.toArray) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala index 3e9545a74bef..46920fffe6e1 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala @@ -37,6 +37,7 @@ import org.apache.spark.util.random.XORShiftRandom * This is an iterative algorithm that will make multiple passes over the data, so any RDDs given * to it should be cached by the user. */ +@Since("0.8.0") class KMeans private ( private var k: Int, private var maxIterations: Int, diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala index e425ecdd481c..a74158498272 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala @@ -37,8 +37,8 @@ import org.apache.spark.sql.Row * A clustering model for K-means. Each point belongs to the cluster with the closest center. */ @Since("0.8.0") -class KMeansModel ( - val clusterCenters: Array[Vector]) extends Saveable with Serializable with PMMLExportable { +class KMeansModel @Since("1.1.0") (@Since("1.0.0") val clusterCenters: Array[Vector]) + extends Saveable with Serializable with PMMLExportable { /** * A Java-friendly constructor that takes an Iterable of Vectors. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala index 432bbedc8d6f..15129e0dd5a9 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala @@ -43,12 +43,15 @@ import org.apache.spark.util.BoundedPriorityQueue * including local and distributed data structures. */ @Experimental +@Since("1.3.0") abstract class LDAModel private[clustering] extends Saveable { /** Number of topics */ + @Since("1.3.0") def k: Int /** Vocabulary size (number of terms or terms in the vocabulary) */ + @Since("1.3.0") def vocabSize: Int /** @@ -57,6 +60,7 @@ abstract class LDAModel private[clustering] extends Saveable { * * This is the parameter to a Dirichlet distribution. */ + @Since("1.5.0") def docConcentration: Vector /** @@ -68,6 +72,7 @@ abstract class LDAModel private[clustering] extends Saveable { * Note: The topics' distributions over terms are called "beta" in the original LDA paper * by Blei et al., but are called "phi" in many later papers such as Asuncion et al., 2009. */ + @Since("1.5.0") def topicConcentration: Double /** @@ -81,6 +86,7 @@ abstract class LDAModel private[clustering] extends Saveable { * This is a matrix of size vocabSize x k, where each column is a topic. * No guarantees are given about the ordering of the topics. */ + @Since("1.3.0") def topicsMatrix: Matrix /** @@ -91,6 +97,7 @@ abstract class LDAModel private[clustering] extends Saveable { * (term indices, term weights in topic). * Each topic's terms are sorted in order of decreasing weight. 
*/ + @Since("1.3.0") def describeTopics(maxTermsPerTopic: Int): Array[(Array[Int], Array[Double])] /** @@ -102,6 +109,7 @@ abstract class LDAModel private[clustering] extends Saveable { * (term indices, term weights in topic). * Each topic's terms are sorted in order of decreasing weight. */ + @Since("1.3.0") def describeTopics(): Array[(Array[Int], Array[Double])] = describeTopics(vocabSize) /* TODO (once LDA can be trained with Strings or given a dictionary) @@ -185,10 +193,11 @@ abstract class LDAModel private[clustering] extends Saveable { * @param topics Inferred topics (vocabSize x k matrix). */ @Experimental +@Since("1.3.0") class LocalLDAModel private[clustering] ( - val topics: Matrix, - override val docConcentration: Vector, - override val topicConcentration: Double, + @Since("1.3.0") val topics: Matrix, + @Since("1.5.0") override val docConcentration: Vector, + @Since("1.5.0") override val topicConcentration: Double, override protected[clustering] val gammaShape: Double = 100) extends LDAModel with Serializable { @@ -376,6 +385,7 @@ class LocalLDAModel private[clustering] ( } @Experimental +@Since("1.5.0") object LocalLDAModel extends Loader[LocalLDAModel] { private object SaveLoadV1_0 { @@ -479,13 +489,14 @@ object LocalLDAModel extends Loader[LocalLDAModel] { * than the [[LocalLDAModel]]. */ @Experimental +@Since("1.3.0") class DistributedLDAModel private[clustering] ( private[clustering] val graph: Graph[LDA.TopicCounts, LDA.TokenCount], private[clustering] val globalTopicTotals: LDA.TopicCounts, - val k: Int, - val vocabSize: Int, - override val docConcentration: Vector, - override val topicConcentration: Double, + @Since("1.3.0") val k: Int, + @Since("1.3.0") val vocabSize: Int, + @Since("1.5.0") override val docConcentration: Vector, + @Since("1.5.0") override val topicConcentration: Double, private[spark] val iterationTimes: Array[Double], override protected[clustering] val gammaShape: Double = 100) extends LDAModel { @@ -603,6 +614,7 @@ class DistributedLDAModel private[clustering] ( * (term indices, topic indices). Note that terms will be omitted if not present in * the document. */ + @Since("1.5.0") lazy val topicAssignments: RDD[(Long, Array[Int], Array[Int])] = { // For reference, compare the below code with the core part of EMLDAOptimizer.next(). 
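// Editorial note, not part of the patch: the implementation below (only its first statement is
// visible in this hunk) walks the document-term graph and, for each term occurring in a document,
// keeps the single most likely topic under the current EM estimates, which is what produces the
// (term indices, topic indices) arrays described in the scaladoc above.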
val eta = topicConcentration @@ -634,6 +646,7 @@ class DistributedLDAModel private[clustering] ( } /** Java-friendly version of [[topicAssignments]] */ + @Since("1.5.0") lazy val javaTopicAssignments: JavaRDD[(java.lang.Long, Array[Int], Array[Int])] = { topicAssignments.asInstanceOf[RDD[(java.lang.Long, Array[Int], Array[Int])]].toJavaRDD() } @@ -770,6 +783,7 @@ class DistributedLDAModel private[clustering] ( @Experimental +@Since("1.5.0") object DistributedLDAModel extends Loader[DistributedLDAModel] { private object SaveLoadV1_0 { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala index 396b36f2f645..da234bdbb29e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala @@ -42,9 +42,10 @@ import org.apache.spark.{Logging, SparkContext, SparkException} */ @Since("1.3.0") @Experimental -class PowerIterationClusteringModel( - val k: Int, - val assignments: RDD[PowerIterationClustering.Assignment]) extends Saveable with Serializable { +class PowerIterationClusteringModel @Since("1.3.0") ( + @Since("1.3.0") val k: Int, + @Since("1.3.0") val assignments: RDD[PowerIterationClustering.Assignment]) + extends Saveable with Serializable { @Since("1.4.0") override def save(sc: SparkContext, path: String): Unit = { @@ -56,6 +57,8 @@ class PowerIterationClusteringModel( @Since("1.4.0") object PowerIterationClusteringModel extends Loader[PowerIterationClusteringModel] { + + @Since("1.4.0") override def load(sc: SparkContext, path: String): PowerIterationClusteringModel = { PowerIterationClusteringModel.SaveLoadV1_0.load(sc, path) } @@ -120,6 +123,7 @@ object PowerIterationClusteringModel extends Loader[PowerIterationClusteringMode * @see [[http://en.wikipedia.org/wiki/Spectral_clustering Spectral clustering (Wikipedia)]] */ @Experimental +@Since("1.3.0") class PowerIterationClustering private[clustering] ( private var k: Int, private var maxIterations: Int, diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala index 41f2668ec6a7..1d50ffec96fa 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala @@ -66,9 +66,10 @@ import org.apache.spark.util.random.XORShiftRandom */ @Since("1.2.0") @Experimental -class StreamingKMeansModel( - override val clusterCenters: Array[Vector], - val clusterWeights: Array[Double]) extends KMeansModel(clusterCenters) with Logging { +class StreamingKMeansModel @Since("1.2.0") ( + @Since("1.2.0") override val clusterCenters: Array[Vector], + @Since("1.2.0") val clusterWeights: Array[Double]) + extends KMeansModel(clusterCenters) with Logging { /** * Perform a k-means update on a batch of data. 
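// A minimal, illustrative sketch (added editorially, not part of the patch) of how the
// StreamingKMeans estimator annotated in this file is typically wired up. The input
// DStream `trainingStream` is assumed to exist; it is not defined anywhere in this patch.
import org.apache.spark.mllib.clustering.StreamingKMeans
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.streaming.dstream.DStream

def streamingKMeansSketch(trainingStream: DStream[Vector]): StreamingKMeans = {
  val skm = new StreamingKMeans()
    .setK(2)                                  // number of clusters
    .setDecayFactor(0.5)                      // geometrically down-weight older batches
    .setRandomCenters(dim = 3, weight = 0.0)  // 3-dimensional random initial centers
  skm.trainOn(trainingStream)                 // updates clusterCenters/clusterWeights per batch
  skm                                         // skm.latestModel() exposes the StreamingKMeansModel
}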
@@ -168,10 +169,10 @@ class StreamingKMeansModel( */ @Since("1.2.0") @Experimental -class StreamingKMeans( - var k: Int, - var decayFactor: Double, - var timeUnit: String) extends Logging with Serializable { +class StreamingKMeans @Since("1.2.0") ( + @Since("1.2.0") var k: Int, + @Since("1.2.0") var decayFactor: Double, + @Since("1.2.0") var timeUnit: String) extends Logging with Serializable { @Since("1.2.0") def this() = this(2, 1.0, StreamingKMeans.BATCHES) From fb7e12fe2e14af8de4c206ca8096b2e8113bfddc Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Tue, 25 Aug 2015 22:35:49 -0700 Subject: [PATCH 0426/1824] [SPARK-10243] [MLLIB] update since versions in mllib.tree Same as #8421 but for `mllib.tree`. cc jkbradley Author: Xiangrui Meng Closes #8442 from mengxr/SPARK-10236. --- .../spark/mllib/tree/DecisionTree.scala | 3 +- .../mllib/tree/GradientBoostedTrees.scala | 2 +- .../spark/mllib/tree/configuration/Algo.scala | 2 ++ .../tree/configuration/BoostingStrategy.scala | 12 ++++---- .../tree/configuration/FeatureType.scala | 2 ++ .../tree/configuration/QuantileStrategy.scala | 2 ++ .../mllib/tree/configuration/Strategy.scala | 29 ++++++++++--------- .../mllib/tree/model/DecisionTreeModel.scala | 5 +++- .../apache/spark/mllib/tree/model/Node.scala | 18 ++++++------ .../spark/mllib/tree/model/Predict.scala | 6 ++-- .../apache/spark/mllib/tree/model/Split.scala | 8 ++--- .../mllib/tree/model/treeEnsembleModels.scala | 12 ++++---- 12 files changed, 57 insertions(+), 44 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 972841015d4f..4a77d4adcd86 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -46,7 +46,8 @@ import org.apache.spark.util.random.XORShiftRandom */ @Since("1.0.0") @Experimental -class DecisionTree (private val strategy: Strategy) extends Serializable with Logging { +class DecisionTree @Since("1.0.0") (private val strategy: Strategy) + extends Serializable with Logging { strategy.assertValid() diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala index e750408600c3..95ed48cea671 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala @@ -51,7 +51,7 @@ import org.apache.spark.storage.StorageLevel */ @Since("1.2.0") @Experimental -class GradientBoostedTrees(private val boostingStrategy: BoostingStrategy) +class GradientBoostedTrees @Since("1.2.0") (private val boostingStrategy: BoostingStrategy) extends Serializable with Logging { /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala index 8301ad160836..853c7319ec44 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala @@ -26,7 +26,9 @@ import org.apache.spark.annotation.{Experimental, Since} @Since("1.0.0") @Experimental object Algo extends Enumeration { + @Since("1.0.0") type Algo = Value + @Since("1.0.0") val Classification, Regression = Value private[mllib] def fromString(name: String): Algo = name match { diff --git 
a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala index 7c569981977b..b5c72fba3ede 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala @@ -41,14 +41,14 @@ import org.apache.spark.mllib.tree.loss.{LogLoss, SquaredError, Loss} */ @Since("1.2.0") @Experimental -case class BoostingStrategy( +case class BoostingStrategy @Since("1.4.0") ( // Required boosting parameters - @BeanProperty var treeStrategy: Strategy, - @BeanProperty var loss: Loss, + @Since("1.2.0") @BeanProperty var treeStrategy: Strategy, + @Since("1.2.0") @BeanProperty var loss: Loss, // Optional boosting parameters - @BeanProperty var numIterations: Int = 100, - @BeanProperty var learningRate: Double = 0.1, - @BeanProperty var validationTol: Double = 1e-5) extends Serializable { + @Since("1.2.0") @BeanProperty var numIterations: Int = 100, + @Since("1.2.0") @BeanProperty var learningRate: Double = 0.1, + @Since("1.4.0") @BeanProperty var validationTol: Double = 1e-5) extends Serializable { /** * Check validity of parameters. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/FeatureType.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/FeatureType.scala index bb7c7ee4f964..4e0cd473def0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/FeatureType.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/FeatureType.scala @@ -26,6 +26,8 @@ import org.apache.spark.annotation.{Experimental, Since} @Since("1.0.0") @Experimental object FeatureType extends Enumeration { + @Since("1.0.0") type FeatureType = Value + @Since("1.0.0") val Continuous, Categorical = Value } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/QuantileStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/QuantileStrategy.scala index 904e42deebb5..8262db8a4f11 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/QuantileStrategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/QuantileStrategy.scala @@ -26,6 +26,8 @@ import org.apache.spark.annotation.{Experimental, Since} @Since("1.0.0") @Experimental object QuantileStrategy extends Enumeration { + @Since("1.0.0") type QuantileStrategy = Value + @Since("1.0.0") val Sort, MinMax, ApproxHist = Value } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala index b74e3f1f4652..89cc13b7c06c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -69,20 +69,20 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ */ @Since("1.0.0") @Experimental -class Strategy ( - @BeanProperty var algo: Algo, - @BeanProperty var impurity: Impurity, - @BeanProperty var maxDepth: Int, - @BeanProperty var numClasses: Int = 2, - @BeanProperty var maxBins: Int = 32, - @BeanProperty var quantileCalculationStrategy: QuantileStrategy = Sort, - @BeanProperty var categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int](), - @BeanProperty var minInstancesPerNode: Int = 1, - @BeanProperty var minInfoGain: Double = 0.0, - @BeanProperty var 
maxMemoryInMB: Int = 256, - @BeanProperty var subsamplingRate: Double = 1, - @BeanProperty var useNodeIdCache: Boolean = false, - @BeanProperty var checkpointInterval: Int = 10) extends Serializable { +class Strategy @Since("1.3.0") ( + @Since("1.0.0") @BeanProperty var algo: Algo, + @Since("1.0.0") @BeanProperty var impurity: Impurity, + @Since("1.0.0") @BeanProperty var maxDepth: Int, + @Since("1.2.0") @BeanProperty var numClasses: Int = 2, + @Since("1.0.0") @BeanProperty var maxBins: Int = 32, + @Since("1.0.0") @BeanProperty var quantileCalculationStrategy: QuantileStrategy = Sort, + @Since("1.0.0") @BeanProperty var categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int](), + @Since("1.2.0") @BeanProperty var minInstancesPerNode: Int = 1, + @Since("1.2.0") @BeanProperty var minInfoGain: Double = 0.0, + @Since("1.0.0") @BeanProperty var maxMemoryInMB: Int = 256, + @Since("1.2.0") @BeanProperty var subsamplingRate: Double = 1, + @Since("1.2.0") @BeanProperty var useNodeIdCache: Boolean = false, + @Since("1.2.0") @BeanProperty var checkpointInterval: Int = 10) extends Serializable { /** */ @@ -206,6 +206,7 @@ object Strategy { } @deprecated("Use Strategy.defaultStrategy instead.", "1.5.0") + @Since("1.2.0") def defaultStategy(algo: Algo): Strategy = defaultStrategy(algo) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala index 3eefd135f783..e1bf23f4c34b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala @@ -43,7 +43,9 @@ import org.apache.spark.util.Utils */ @Since("1.0.0") @Experimental -class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable with Saveable { +class DecisionTreeModel @Since("1.0.0") ( + @Since("1.0.0") val topNode: Node, + @Since("1.0.0") val algo: Algo) extends Serializable with Saveable { /** * Predict values for a single data point using the model trained. @@ -110,6 +112,7 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable /** * Print the full model to a string. 
*/ + @Since("1.2.0") def toDebugString: String = { val header = toString + "\n" header + topNode.subtreeToString(2) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala index 8c54c5510723..ea6e5aa5d94e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala @@ -41,15 +41,15 @@ import org.apache.spark.mllib.linalg.Vector */ @Since("1.0.0") @DeveloperApi -class Node ( - val id: Int, - var predict: Predict, - var impurity: Double, - var isLeaf: Boolean, - var split: Option[Split], - var leftNode: Option[Node], - var rightNode: Option[Node], - var stats: Option[InformationGainStats]) extends Serializable with Logging { +class Node @Since("1.2.0") ( + @Since("1.0.0") val id: Int, + @Since("1.0.0") var predict: Predict, + @Since("1.2.0") var impurity: Double, + @Since("1.0.0") var isLeaf: Boolean, + @Since("1.0.0") var split: Option[Split], + @Since("1.0.0") var leftNode: Option[Node], + @Since("1.0.0") var rightNode: Option[Node], + @Since("1.0.0") var stats: Option[InformationGainStats]) extends Serializable with Logging { override def toString: String = { s"id = $id, isLeaf = $isLeaf, predict = $predict, impurity = $impurity, " + diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala index 965784051ede..06ceff19d863 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala @@ -26,9 +26,9 @@ import org.apache.spark.annotation.{DeveloperApi, Since} */ @Since("1.2.0") @DeveloperApi -class Predict( - val predict: Double, - val prob: Double = 0.0) extends Serializable { +class Predict @Since("1.2.0") ( + @Since("1.2.0") val predict: Double, + @Since("1.2.0") val prob: Double = 0.0) extends Serializable { override def toString: String = s"$predict (prob = $prob)" diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala index 45db83ae3a1f..b85a66c05a81 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala @@ -34,10 +34,10 @@ import org.apache.spark.mllib.tree.configuration.FeatureType.FeatureType @Since("1.0.0") @DeveloperApi case class Split( - feature: Int, - threshold: Double, - featureType: FeatureType, - categories: List[Double]) { + @Since("1.0.0") feature: Int, + @Since("1.0.0") threshold: Double, + @Since("1.0.0") featureType: FeatureType, + @Since("1.0.0") categories: List[Double]) { override def toString: String = { s"Feature = $feature, threshold = $threshold, featureType = $featureType, " + diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala index 19571447a2c5..df5b8feab5d5 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala @@ -48,7 +48,9 @@ import org.apache.spark.util.Utils */ @Since("1.2.0") @Experimental -class RandomForestModel(override val algo: Algo, override val trees: Array[DecisionTreeModel]) +class RandomForestModel @Since("1.2.0") ( + @Since("1.2.0") override val 
algo: Algo, + @Since("1.2.0") override val trees: Array[DecisionTreeModel]) extends TreeEnsembleModel(algo, trees, Array.fill(trees.length)(1.0), combiningStrategy = if (algo == Classification) Vote else Average) with Saveable { @@ -115,10 +117,10 @@ object RandomForestModel extends Loader[RandomForestModel] { */ @Since("1.2.0") @Experimental -class GradientBoostedTreesModel( - override val algo: Algo, - override val trees: Array[DecisionTreeModel], - override val treeWeights: Array[Double]) +class GradientBoostedTreesModel @Since("1.2.0") ( + @Since("1.2.0") override val algo: Algo, + @Since("1.2.0") override val trees: Array[DecisionTreeModel], + @Since("1.2.0") override val treeWeights: Array[Double]) extends TreeEnsembleModel(algo, trees, treeWeights, combiningStrategy = Sum) with Saveable { From 4657fa1f37d41dd4c7240a960342b68c7c591f48 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Tue, 25 Aug 2015 22:49:33 -0700 Subject: [PATCH 0427/1824] [SPARK-10235] [MLLIB] update since versions in mllib.regression Same as #8421 but for `mllib.regression`. cc freeman-lab dbtsai Author: Xiangrui Meng Closes #8426 from mengxr/SPARK-10235 and squashes the following commits: 6cd28e4 [Xiangrui Meng] update since versions in mllib.regression --- .../regression/GeneralizedLinearAlgorithm.scala | 6 ++++-- .../mllib/regression/IsotonicRegression.scala | 16 +++++++++------- .../spark/mllib/regression/LabeledPoint.scala | 5 +++-- .../apache/spark/mllib/regression/Lasso.scala | 9 ++++++--- .../mllib/regression/LinearRegression.scala | 9 ++++++--- .../spark/mllib/regression/RidgeRegression.scala | 12 +++++++----- .../regression/StreamingLinearAlgorithm.scala | 8 +++----- .../StreamingLinearRegressionWithSGD.scala | 11 +++++++++-- 8 files changed, 47 insertions(+), 29 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala index 509f6a2d169c..7e3b4d5648fe 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala @@ -38,7 +38,9 @@ import org.apache.spark.storage.StorageLevel */ @Since("0.8.0") @DeveloperApi -abstract class GeneralizedLinearModel(val weights: Vector, val intercept: Double) +abstract class GeneralizedLinearModel @Since("1.0.0") ( + @Since("1.0.0") val weights: Vector, + @Since("0.8.0") val intercept: Double) extends Serializable { /** @@ -107,7 +109,7 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] * The optimizer to solve the problem. * */ - @Since("1.0.0") + @Since("0.8.0") def optimizer: Optimizer /** Whether to add intercept (default: false). 
*/ diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala index 31ca7c2f207d..877d31ba4130 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala @@ -50,10 +50,10 @@ import org.apache.spark.sql.SQLContext */ @Since("1.3.0") @Experimental -class IsotonicRegressionModel ( - val boundaries: Array[Double], - val predictions: Array[Double], - val isotonic: Boolean) extends Serializable with Saveable { +class IsotonicRegressionModel @Since("1.3.0") ( + @Since("1.3.0") val boundaries: Array[Double], + @Since("1.3.0") val predictions: Array[Double], + @Since("1.3.0") val isotonic: Boolean) extends Serializable with Saveable { private val predictionOrd = if (isotonic) Ordering[Double] else Ordering[Double].reverse @@ -63,7 +63,6 @@ class IsotonicRegressionModel ( /** * A Java-friendly constructor that takes two Iterable parameters and one Boolean parameter. - * */ @Since("1.4.0") def this(boundaries: java.lang.Iterable[Double], @@ -214,8 +213,6 @@ object IsotonicRegressionModel extends Loader[IsotonicRegressionModel] { } } - /** - */ @Since("1.4.0") override def load(sc: SparkContext, path: String): IsotonicRegressionModel = { implicit val formats = DefaultFormats @@ -256,6 +253,7 @@ object IsotonicRegressionModel extends Loader[IsotonicRegressionModel] { * @see [[http://en.wikipedia.org/wiki/Isotonic_regression Isotonic regression (Wikipedia)]] */ @Experimental +@Since("1.3.0") class IsotonicRegression private (private var isotonic: Boolean) extends Serializable { /** @@ -263,6 +261,7 @@ class IsotonicRegression private (private var isotonic: Boolean) extends Seriali * * @return New instance of IsotonicRegression. */ + @Since("1.3.0") def this() = this(true) /** @@ -271,6 +270,7 @@ class IsotonicRegression private (private var isotonic: Boolean) extends Seriali * @param isotonic Isotonic (increasing) or antitonic (decreasing) sequence. * @return This instance of IsotonicRegression. */ + @Since("1.3.0") def setIsotonic(isotonic: Boolean): this.type = { this.isotonic = isotonic this @@ -286,6 +286,7 @@ class IsotonicRegression private (private var isotonic: Boolean) extends Seriali * the algorithm is executed. * @return Isotonic regression model. */ + @Since("1.3.0") def run(input: RDD[(Double, Double, Double)]): IsotonicRegressionModel = { val preprocessedInput = if (isotonic) { input @@ -311,6 +312,7 @@ class IsotonicRegression private (private var isotonic: Boolean) extends Seriali * the algorithm is executed. * @return Isotonic regression model. */ + @Since("1.3.0") def run(input: JavaRDD[(JDouble, JDouble, JDouble)]): IsotonicRegressionModel = { run(input.rdd.retag.asInstanceOf[RDD[(Double, Double, Double)]]) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala index f7fe1b7b21fc..c284ad232537 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala @@ -29,11 +29,12 @@ import org.apache.spark.SparkException * * @param label Label for this data point. * @param features List of features for this data point. 
- * */ @Since("0.8.0") @BeanInfo -case class LabeledPoint(label: Double, features: Vector) { +case class LabeledPoint @Since("1.0.0") ( + @Since("0.8.0") label: Double, + @Since("1.0.0") features: Vector) { override def toString: String = { s"($label,$features)" } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala index 556411a366bd..a9aba173fa0e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala @@ -34,9 +34,9 @@ import org.apache.spark.rdd.RDD * */ @Since("0.8.0") -class LassoModel ( - override val weights: Vector, - override val intercept: Double) +class LassoModel @Since("1.1.0") ( + @Since("1.0.0") override val weights: Vector, + @Since("0.8.0") override val intercept: Double) extends GeneralizedLinearModel(weights, intercept) with RegressionModel with Serializable with Saveable with PMMLExportable { @@ -84,6 +84,7 @@ object LassoModel extends Loader[LassoModel] { * its corresponding right hand side label y. * See also the documentation for the precise formulation. */ +@Since("0.8.0") class LassoWithSGD private ( private var stepSize: Double, private var numIterations: Int, @@ -93,6 +94,7 @@ class LassoWithSGD private ( private val gradient = new LeastSquaresGradient() private val updater = new L1Updater() + @Since("0.8.0") override val optimizer = new GradientDescent(gradient, updater) .setStepSize(stepSize) .setNumIterations(numIterations) @@ -103,6 +105,7 @@ class LassoWithSGD private ( * Construct a Lasso object with default parameters: {stepSize: 1.0, numIterations: 100, * regParam: 0.01, miniBatchFraction: 1.0}. */ + @Since("0.8.0") def this() = this(1.0, 100, 0.01, 1.0) override protected def createModel(weights: Vector, intercept: Double) = { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala index 00ab06e3ba73..4996ace5df85 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala @@ -34,9 +34,9 @@ import org.apache.spark.rdd.RDD * */ @Since("0.8.0") -class LinearRegressionModel ( - override val weights: Vector, - override val intercept: Double) +class LinearRegressionModel @Since("1.1.0") ( + @Since("1.0.0") override val weights: Vector, + @Since("0.8.0") override val intercept: Double) extends GeneralizedLinearModel(weights, intercept) with RegressionModel with Serializable with Saveable with PMMLExportable { @@ -85,6 +85,7 @@ object LinearRegressionModel extends Loader[LinearRegressionModel] { * its corresponding right hand side label y. * See also the documentation for the precise formulation. */ +@Since("0.8.0") class LinearRegressionWithSGD private[mllib] ( private var stepSize: Double, private var numIterations: Int, @@ -93,6 +94,7 @@ class LinearRegressionWithSGD private[mllib] ( private val gradient = new LeastSquaresGradient() private val updater = new SimpleUpdater() + @Since("0.8.0") override val optimizer = new GradientDescent(gradient, updater) .setStepSize(stepSize) .setNumIterations(numIterations) @@ -102,6 +104,7 @@ class LinearRegressionWithSGD private[mllib] ( * Construct a LinearRegression object with default parameters: {stepSize: 1.0, * numIterations: 100, miniBatchFraction: 1.0}. 
*/ + @Since("0.8.0") def this() = this(1.0, 100, 1.0) override protected[mllib] def createModel(weights: Vector, intercept: Double) = { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala index 21a791d98b2c..0a44ff559d55 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala @@ -35,9 +35,9 @@ import org.apache.spark.rdd.RDD * */ @Since("0.8.0") -class RidgeRegressionModel ( - override val weights: Vector, - override val intercept: Double) +class RidgeRegressionModel @Since("1.1.0") ( + @Since("1.0.0") override val weights: Vector, + @Since("0.8.0") override val intercept: Double) extends GeneralizedLinearModel(weights, intercept) with RegressionModel with Serializable with Saveable with PMMLExportable { @@ -85,6 +85,7 @@ object RidgeRegressionModel extends Loader[RidgeRegressionModel] { * its corresponding right hand side label y. * See also the documentation for the precise formulation. */ +@Since("0.8.0") class RidgeRegressionWithSGD private ( private var stepSize: Double, private var numIterations: Int, @@ -94,7 +95,7 @@ class RidgeRegressionWithSGD private ( private val gradient = new LeastSquaresGradient() private val updater = new SquaredL2Updater() - + @Since("0.8.0") override val optimizer = new GradientDescent(gradient, updater) .setStepSize(stepSize) .setNumIterations(numIterations) @@ -105,6 +106,7 @@ class RidgeRegressionWithSGD private ( * Construct a RidgeRegression object with default parameters: {stepSize: 1.0, numIterations: 100, * regParam: 0.01, miniBatchFraction: 1.0}. */ + @Since("0.8.0") def this() = this(1.0, 100, 0.01, 1.0) override protected def createModel(weights: Vector, intercept: Double) = { @@ -134,7 +136,7 @@ object RidgeRegressionWithSGD { * the number of features in the data. * */ - @Since("0.8.0") + @Since("1.0.0") def train( input: RDD[LabeledPoint], numIterations: Int, diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala index cd3ed8a1549d..73948b2d9851 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala @@ -22,7 +22,7 @@ import scala.reflect.ClassTag import org.apache.spark.Logging import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.api.java.JavaSparkContext.fakeClassTag -import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.linalg.Vector import org.apache.spark.streaming.api.java.{JavaDStream, JavaPairDStream} import org.apache.spark.streaming.dstream.DStream @@ -83,9 +83,8 @@ abstract class StreamingLinearAlgorithm[ * batch of data from the stream. * * @param data DStream containing labeled data - * */ - @Since("1.3.0") + @Since("1.1.0") def trainOn(data: DStream[LabeledPoint]): Unit = { if (model.isEmpty) { throw new IllegalArgumentException("Model must be initialized before starting training.") @@ -105,7 +104,6 @@ abstract class StreamingLinearAlgorithm[ /** * Java-friendly version of `trainOn`. - * */ @Since("1.3.0") def trainOn(data: JavaDStream[LabeledPoint]): Unit = trainOn(data.dstream) @@ -129,7 +127,7 @@ abstract class StreamingLinearAlgorithm[ * Java-friendly version of `predictOn`. 
* */ - @Since("1.1.0") + @Since("1.3.0") def predictOn(data: JavaDStream[Vector]): JavaDStream[java.lang.Double] = { JavaDStream.fromDStream(predictOn(data.dstream).asInstanceOf[DStream[java.lang.Double]]) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala index 26654e4a0683..fe1d487cdd07 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.regression -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.mllib.linalg.Vector /** @@ -41,6 +41,7 @@ import org.apache.spark.mllib.linalg.Vector * .trainOn(DStream) */ @Experimental +@Since("1.1.0") class StreamingLinearRegressionWithSGD private[mllib] ( private var stepSize: Double, private var numIterations: Int, @@ -54,8 +55,10 @@ class StreamingLinearRegressionWithSGD private[mllib] ( * Initial weights must be set before using trainOn or predictOn * (see `StreamingLinearAlgorithm`) */ + @Since("1.1.0") def this() = this(0.1, 50, 1.0) + @Since("1.1.0") val algorithm = new LinearRegressionWithSGD(stepSize, numIterations, miniBatchFraction) protected var model: Option[LinearRegressionModel] = None @@ -63,6 +66,7 @@ class StreamingLinearRegressionWithSGD private[mllib] ( /** * Set the step size for gradient descent. Default: 0.1. */ + @Since("1.1.0") def setStepSize(stepSize: Double): this.type = { this.algorithm.optimizer.setStepSize(stepSize) this @@ -71,6 +75,7 @@ class StreamingLinearRegressionWithSGD private[mllib] ( /** * Set the number of iterations of gradient descent to run per update. Default: 50. */ + @Since("1.1.0") def setNumIterations(numIterations: Int): this.type = { this.algorithm.optimizer.setNumIterations(numIterations) this @@ -79,6 +84,7 @@ class StreamingLinearRegressionWithSGD private[mllib] ( /** * Set the fraction of each batch to use for updates. Default: 1.0. */ + @Since("1.1.0") def setMiniBatchFraction(miniBatchFraction: Double): this.type = { this.algorithm.optimizer.setMiniBatchFraction(miniBatchFraction) this @@ -87,6 +93,7 @@ class StreamingLinearRegressionWithSGD private[mllib] ( /** * Set the initial weights. */ + @Since("1.1.0") def setInitialWeights(initialWeights: Vector): this.type = { this.model = Some(algorithm.createModel(initialWeights, 0.0)) this @@ -95,9 +102,9 @@ class StreamingLinearRegressionWithSGD private[mllib] ( /** * Set the convergence tolerance. Default: 0.001. */ + @Since("1.5.0") def setConvergenceTol(tolerance: Double): this.type = { this.algorithm.optimizer.setConvergenceTol(tolerance) this } - } From 321d7759691bed9867b1f0470f12eab2faa50aff Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Tue, 25 Aug 2015 23:45:41 -0700 Subject: [PATCH 0428/1824] [SPARK-10236] [MLLIB] update since versions in mllib.feature Same as #8421 but for `mllib.feature`. 
cc dbtsai Author: Xiangrui Meng Closes #8449 from mengxr/SPARK-10236.feature and squashes the following commits: 0e8d658 [Xiangrui Meng] remove unnecessary comment ad70b03 [Xiangrui Meng] update since versions in mllib.feature --- .../mllib/clustering/PowerIterationClustering.scala | 2 -- .../apache/spark/mllib/feature/ChiSqSelector.scala | 4 ++-- .../spark/mllib/feature/ElementwiseProduct.scala | 3 ++- .../scala/org/apache/spark/mllib/feature/IDF.scala | 6 ++++-- .../org/apache/spark/mllib/feature/Normalizer.scala | 2 +- .../scala/org/apache/spark/mllib/feature/PCA.scala | 7 +++++-- .../apache/spark/mllib/feature/StandardScaler.scala | 12 ++++++------ .../org/apache/spark/mllib/feature/Word2Vec.scala | 1 + 8 files changed, 21 insertions(+), 16 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala index da234bdbb29e..6c76e26fd162 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala @@ -71,8 +71,6 @@ object PowerIterationClusteringModel extends Loader[PowerIterationClusteringMode private[clustering] val thisClassName = "org.apache.spark.mllib.clustering.PowerIterationClusteringModel" - /** - */ @Since("1.4.0") def save(sc: SparkContext, model: PowerIterationClusteringModel, path: String): Unit = { val sqlContext = new SQLContext(sc) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala index fdd974d7a391..4743cfd1a2c3 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala @@ -33,7 +33,7 @@ import org.apache.spark.rdd.RDD */ @Since("1.3.0") @Experimental -class ChiSqSelectorModel ( +class ChiSqSelectorModel @Since("1.3.0") ( @Since("1.3.0") val selectedFeatures: Array[Int]) extends VectorTransformer { require(isSorted(selectedFeatures), "Array has to be sorted asc") @@ -112,7 +112,7 @@ class ChiSqSelectorModel ( */ @Since("1.3.0") @Experimental -class ChiSqSelector ( +class ChiSqSelector @Since("1.3.0") ( @Since("1.3.0") val numTopFeatures: Int) extends Serializable { /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/ElementwiseProduct.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/ElementwiseProduct.scala index 33e2d17bb472..d0a6cf61687a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/ElementwiseProduct.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/ElementwiseProduct.scala @@ -29,7 +29,8 @@ import org.apache.spark.mllib.linalg._ */ @Since("1.4.0") @Experimental -class ElementwiseProduct(val scalingVec: Vector) extends VectorTransformer { +class ElementwiseProduct @Since("1.4.0") ( + @Since("1.4.0") val scalingVec: Vector) extends VectorTransformer { /** * Does the hadamard product transformation. 
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala index d5353ddd972e..68078ccfa3d6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala @@ -39,8 +39,9 @@ import org.apache.spark.rdd.RDD */ @Since("1.1.0") @Experimental -class IDF(val minDocFreq: Int) { +class IDF @Since("1.2.0") (@Since("1.2.0") val minDocFreq: Int) { + @Since("1.1.0") def this() = this(0) // TODO: Allow different IDF formulations. @@ -162,7 +163,8 @@ private object IDF { * Represents an IDF model that can transform term frequency vectors. */ @Experimental -class IDFModel private[spark] (val idf: Vector) extends Serializable { +@Since("1.1.0") +class IDFModel private[spark] (@Since("1.1.0") val idf: Vector) extends Serializable { /** * Transforms term frequency (TF) vectors to TF-IDF vectors. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala index 0e070257d9fb..8d5a22520d6b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala @@ -33,7 +33,7 @@ import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors */ @Since("1.1.0") @Experimental -class Normalizer(p: Double) extends VectorTransformer { +class Normalizer @Since("1.1.0") (p: Double) extends VectorTransformer { @Since("1.1.0") def this() = this(2) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/PCA.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/PCA.scala index a48b7bba665d..ecb3c1e6c1c8 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/PCA.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/PCA.scala @@ -29,7 +29,7 @@ import org.apache.spark.rdd.RDD * @param k number of principal components */ @Since("1.4.0") -class PCA(val k: Int) { +class PCA @Since("1.4.0") (@Since("1.4.0") val k: Int) { require(k >= 1, s"PCA requires a number of principal components k >= 1 but was given $k") /** @@ -74,7 +74,10 @@ class PCA(val k: Int) { * @param k number of principal components. * @param pc a principal components Matrix. Each column is one principal component. */ -class PCAModel private[spark] (val k: Int, val pc: DenseMatrix) extends VectorTransformer { +@Since("1.4.0") +class PCAModel private[spark] ( + @Since("1.4.0") val k: Int, + @Since("1.4.0") val pc: DenseMatrix) extends VectorTransformer { /** * Transform a vector by computed Principal Components. 
* diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala index b95d5a899001..f018b453bae7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala @@ -34,7 +34,7 @@ import org.apache.spark.rdd.RDD */ @Since("1.1.0") @Experimental -class StandardScaler(withMean: Boolean, withStd: Boolean) extends Logging { +class StandardScaler @Since("1.1.0") (withMean: Boolean, withStd: Boolean) extends Logging { @Since("1.1.0") def this() = this(false, true) @@ -74,11 +74,11 @@ class StandardScaler(withMean: Boolean, withStd: Boolean) extends Logging { */ @Since("1.1.0") @Experimental -class StandardScalerModel ( - val std: Vector, - val mean: Vector, - var withStd: Boolean, - var withMean: Boolean) extends VectorTransformer { +class StandardScalerModel @Since("1.3.0") ( + @Since("1.3.0") val std: Vector, + @Since("1.1.0") val mean: Vector, + @Since("1.3.0") var withStd: Boolean, + @Since("1.3.0") var withMean: Boolean) extends VectorTransformer { /** */ diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index e6f45ae4b01d..36b124c5d296 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -436,6 +436,7 @@ class Word2Vec extends Serializable with Logging { * (i * vectorSize, i * vectorSize + vectorSize) */ @Experimental +@Since("1.1.0") class Word2VecModel private[mllib] ( private val wordIndex: Map[String, Int], private val wordVectors: Array[Float]) extends Serializable with Saveable { From 75d4773aa50e24972c533e8b48697fde586429eb Mon Sep 17 00:00:00 2001 From: felixcheung Date: Tue, 25 Aug 2015 23:48:16 -0700 Subject: [PATCH 0429/1824] [SPARK-9316] [SPARKR] Add support for filtering using `[` (synonym for filter / select) Add support for ``` df[df$name == "Smith", c(1,2)] df[df$age %in% c(19, 30), 1:2] ``` shivaram Author: felixcheung Closes #8394 from felixcheung/rsubset. --- R/pkg/R/DataFrame.R | 22 +++++++++++++++++++++- R/pkg/inst/tests/test_sparkSQL.R | 27 +++++++++++++++++++++++++++ 2 files changed, 48 insertions(+), 1 deletion(-) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index ae1d912cf6da..a5162de705f8 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -985,9 +985,11 @@ setMethod("$<-", signature(x = "DataFrame"), x }) +setClassUnion("numericOrcharacter", c("numeric", "character")) + #' @rdname select #' @name [[ -setMethod("[[", signature(x = "DataFrame"), +setMethod("[[", signature(x = "DataFrame", i = "numericOrcharacter"), function(x, i) { if (is.numeric(i)) { cols <- columns(x) @@ -1010,6 +1012,20 @@ setMethod("[", signature(x = "DataFrame", i = "missing"), select(x, j) }) +#' @rdname select +#' @name [ +setMethod("[", signature(x = "DataFrame", i = "Column"), + function(x, i, j, ...) { + # It could handle i as "character" but it seems confusing and not required + # https://stat.ethz.ch/R-manual/R-devel/library/base/html/Extract.data.frame.html + filtered <- filter(x, i) + if (!missing(j)) { + filtered[, j] + } else { + filtered + } + }) + #' Select #' #' Selects a set of columns with names or Column expressions. 
@@ -1028,8 +1044,12 @@ setMethod("[", signature(x = "DataFrame", i = "missing"), #' # Columns can also be selected using `[[` and `[` #' df[[2]] == df[["age"]] #' df[,2] == df[,"age"] +#' df[,c("name", "age")] #' # Similar to R data frames columns can also be selected using `$` #' df$age +#' # It can also be subset on rows and Columns +#' df[df$name == "Smith", c(1,2)] +#' df[df$age %in% c(19, 30), 1:2] #' } setMethod("select", signature(x = "DataFrame", col = "character"), function(x, col, ...) { diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index 556b8c544705..ee48a3dc0cc0 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -587,6 +587,33 @@ test_that("select with column", { expect_equal(collect(select(df3, "x"))[[1, 1]], "x") }) +test_that("subsetting", { + # jsonFile returns columns in random order + df <- select(jsonFile(sqlContext, jsonPath), "name", "age") + filtered <- df[df$age > 20,] + expect_equal(count(filtered), 1) + expect_equal(columns(filtered), c("name", "age")) + expect_equal(collect(filtered)$name, "Andy") + + df2 <- df[df$age == 19, 1] + expect_is(df2, "DataFrame") + expect_equal(count(df2), 1) + expect_equal(columns(df2), c("name")) + expect_equal(collect(df2)$name, "Justin") + + df3 <- df[df$age > 20, 2] + expect_equal(count(df3), 1) + expect_equal(columns(df3), c("age")) + + df4 <- df[df$age %in% c(19, 30), 1:2] + expect_equal(count(df4), 2) + expect_equal(columns(df4), c("name", "age")) + + df5 <- df[df$age %in% c(19), c(1,2)] + expect_equal(count(df5), 1) + expect_equal(columns(df5), c("name", "age")) +}) + test_that("selectExpr() on a DataFrame", { df <- jsonFile(sqlContext, jsonPath) selected <- selectExpr(df, "age * 2") From bb1640529725c6c38103b95af004f8bd90eeee5c Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 26 Aug 2015 00:37:04 -0700 Subject: [PATCH 0430/1824] Closes #8443 From 6519fd06cc8175c9182ef16cf8a37d7f255eb846 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Wed, 26 Aug 2015 11:47:05 -0700 Subject: [PATCH 0431/1824] [SPARK-9665] [MLLIB] audit MLlib API annotations I only found `ml.NaiveBayes` missing `Experimental` annotation. This PR doesn't cover Python APIs. cc jkbradley Author: Xiangrui Meng Closes #8452 from mengxr/SPARK-9665. 
--- .../apache/spark/ml/classification/NaiveBayes.scala | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala index 97cbaf1fa876..69cb88a7e671 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala @@ -18,11 +18,11 @@ package org.apache.spark.ml.classification import org.apache.spark.SparkException -import org.apache.spark.ml.{PredictorParams, PredictionModel, Predictor} -import org.apache.spark.ml.param.{ParamMap, ParamValidators, Param, DoubleParam} +import org.apache.spark.annotation.Experimental +import org.apache.spark.ml.PredictorParams +import org.apache.spark.ml.param.{DoubleParam, Param, ParamMap, ParamValidators} import org.apache.spark.ml.util.Identifiable -import org.apache.spark.mllib.classification.{NaiveBayes => OldNaiveBayes} -import org.apache.spark.mllib.classification.{NaiveBayesModel => OldNaiveBayesModel} +import org.apache.spark.mllib.classification.{NaiveBayes => OldNaiveBayes, NaiveBayesModel => OldNaiveBayesModel} import org.apache.spark.mllib.linalg._ import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.rdd.RDD @@ -59,6 +59,7 @@ private[ml] trait NaiveBayesParams extends PredictorParams { } /** + * :: Experimental :: * Naive Bayes Classifiers. * It supports both Multinomial NB * ([[http://nlp.stanford.edu/IR-book/html/htmledition/naive-bayes-text-classification-1.html]]) @@ -68,6 +69,7 @@ private[ml] trait NaiveBayesParams extends PredictorParams { * ([[http://nlp.stanford.edu/IR-book/html/htmledition/the-bernoulli-model-1.html]]). * The input feature values must be nonnegative. */ +@Experimental class NaiveBayes(override val uid: String) extends ProbabilisticClassifier[Vector, NaiveBayes, NaiveBayesModel] with NaiveBayesParams { @@ -101,11 +103,13 @@ class NaiveBayes(override val uid: String) } /** + * :: Experimental :: * Model produced by [[NaiveBayes]] * @param pi log of class priors, whose dimension is C (number of classes) * @param theta log of class conditional probabilities, whose dimension is C (number of classes) * by D (number of features) */ +@Experimental class NaiveBayesModel private[ml] ( override val uid: String, val pi: Vector, From de7209c256aaf79a0978cfcf6e98bb013267b93a Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Wed, 26 Aug 2015 12:19:36 -0700 Subject: [PATCH 0432/1824] HOTFIX: Increase PRB timeout --- dev/run-tests-jenkins | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dev/run-tests-jenkins b/dev/run-tests-jenkins index c4d39d95d589..f144c053046c 100755 --- a/dev/run-tests-jenkins +++ b/dev/run-tests-jenkins @@ -48,8 +48,8 @@ COMMIT_URL="https://github.com/apache/spark/commit/${ghprbActualCommit}" SHORT_COMMIT_HASH="${ghprbActualCommit:0:7}" # format: http://linux.die.net/man/1/timeout -# must be less than the timeout configured on Jenkins (currently 180m) -TESTS_TIMEOUT="175m" +# must be less than the timeout configured on Jenkins (currently 300m) +TESTS_TIMEOUT="250m" # Array to capture all tests to run on the pull request. These tests are held under the #+ dev/tests/ directory. 
From 086d4681df3ebfccfc04188262c10482f44553b0 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Wed, 26 Aug 2015 14:02:19 -0700 Subject: [PATCH 0433/1824] [SPARK-10241] [MLLIB] update since versions in mllib.recommendation Same as #8421 but for `mllib.recommendation`. cc srowen coderxiang Author: Xiangrui Meng Closes #8432 from mengxr/SPARK-10241. --- .../spark/mllib/recommendation/ALS.scala | 22 ++++++++++++++++++- .../MatrixFactorizationModel.scala | 8 +++---- 2 files changed, 25 insertions(+), 5 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala index b27ef1b949e2..33aaf853e599 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala @@ -28,7 +28,10 @@ import org.apache.spark.storage.StorageLevel * A more compact class to represent a rating than Tuple3[Int, Int, Double]. */ @Since("0.8.0") -case class Rating(user: Int, product: Int, rating: Double) +case class Rating @Since("0.8.0") ( + @Since("0.8.0") user: Int, + @Since("0.8.0") product: Int, + @Since("0.8.0") rating: Double) /** * Alternating Least Squares matrix factorization. @@ -59,6 +62,7 @@ case class Rating(user: Int, product: Int, rating: Double) * indicated user * preferences rather than explicit ratings given to items. */ +@Since("0.8.0") class ALS private ( private var numUserBlocks: Int, private var numProductBlocks: Int, @@ -74,6 +78,7 @@ class ALS private ( * Constructs an ALS instance with default parameters: {numBlocks: -1, rank: 10, iterations: 10, * lambda: 0.01, implicitPrefs: false, alpha: 1.0}. */ + @Since("0.8.0") def this() = this(-1, -1, 10, 10, 0.01, false, 1.0) /** If true, do alternating nonnegative least squares. */ @@ -90,6 +95,7 @@ class ALS private ( * Set the number of blocks for both user blocks and product blocks to parallelize the computation * into; pass -1 for an auto-configured number of blocks. Default: -1. */ + @Since("0.8.0") def setBlocks(numBlocks: Int): this.type = { this.numUserBlocks = numBlocks this.numProductBlocks = numBlocks @@ -99,6 +105,7 @@ class ALS private ( /** * Set the number of user blocks to parallelize the computation. */ + @Since("1.1.0") def setUserBlocks(numUserBlocks: Int): this.type = { this.numUserBlocks = numUserBlocks this @@ -107,30 +114,35 @@ class ALS private ( /** * Set the number of product blocks to parallelize the computation. */ + @Since("1.1.0") def setProductBlocks(numProductBlocks: Int): this.type = { this.numProductBlocks = numProductBlocks this } /** Set the rank of the feature matrices computed (number of features). Default: 10. */ + @Since("0.8.0") def setRank(rank: Int): this.type = { this.rank = rank this } /** Set the number of iterations to run. Default: 10. */ + @Since("0.8.0") def setIterations(iterations: Int): this.type = { this.iterations = iterations this } /** Set the regularization parameter, lambda. Default: 0.01. */ + @Since("0.8.0") def setLambda(lambda: Double): this.type = { this.lambda = lambda this } /** Sets whether to use implicit preference. Default: false. */ + @Since("0.8.1") def setImplicitPrefs(implicitPrefs: Boolean): this.type = { this.implicitPrefs = implicitPrefs this @@ -139,12 +151,14 @@ class ALS private ( /** * Sets the constant used in computing confidence in implicit ALS. Default: 1.0. 
*/ + @Since("0.8.1") def setAlpha(alpha: Double): this.type = { this.alpha = alpha this } /** Sets a random seed to have deterministic results. */ + @Since("1.0.0") def setSeed(seed: Long): this.type = { this.seed = seed this @@ -154,6 +168,7 @@ class ALS private ( * Set whether the least-squares problems solved at each iteration should have * nonnegativity constraints. */ + @Since("1.1.0") def setNonnegative(b: Boolean): this.type = { this.nonnegative = b this @@ -166,6 +181,7 @@ class ALS private ( * set `spark.rdd.compress` to `true` to reduce the space requirement, at the cost of speed. */ @DeveloperApi + @Since("1.1.0") def setIntermediateRDDStorageLevel(storageLevel: StorageLevel): this.type = { require(storageLevel != StorageLevel.NONE, "ALS is not designed to run without persisting intermediate RDDs.") @@ -181,6 +197,7 @@ class ALS private ( * at the cost of speed. */ @DeveloperApi + @Since("1.3.0") def setFinalRDDStorageLevel(storageLevel: StorageLevel): this.type = { this.finalRDDStorageLevel = storageLevel this @@ -194,6 +211,7 @@ class ALS private ( * this setting is ignored. */ @DeveloperApi + @Since("1.4.0") def setCheckpointInterval(checkpointInterval: Int): this.type = { this.checkpointInterval = checkpointInterval this @@ -203,6 +221,7 @@ class ALS private ( * Run ALS with the configured parameters on an input RDD of (user, product, rating) triples. * Returns a MatrixFactorizationModel with feature vectors for each user and product. */ + @Since("0.8.0") def run(ratings: RDD[Rating]): MatrixFactorizationModel = { val sc = ratings.context @@ -250,6 +269,7 @@ class ALS private ( /** * Java-friendly version of [[ALS.run]]. */ + @Since("1.3.0") def run(ratings: JavaRDD[Rating]): MatrixFactorizationModel = run(ratings.rdd) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala index ba4cfdcd9f1d..46562eb2ad0f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala @@ -52,10 +52,10 @@ import org.apache.spark.storage.StorageLevel * and the features computed for this product. */ @Since("0.8.0") -class MatrixFactorizationModel( - val rank: Int, - val userFeatures: RDD[(Int, Array[Double])], - val productFeatures: RDD[(Int, Array[Double])]) +class MatrixFactorizationModel @Since("0.8.0") ( + @Since("0.8.0") val rank: Int, + @Since("0.8.0") val userFeatures: RDD[(Int, Array[Double])], + @Since("0.8.0") val productFeatures: RDD[(Int, Array[Double])]) extends Saveable with Serializable with Logging { require(rank > 0) From d41d6c48207159490c1e1d9cc54015725cfa41b2 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 26 Aug 2015 16:04:44 -0700 Subject: [PATCH 0434/1824] [SPARK-10305] [SQL] fix create DataFrame from Python class cc jkbradley Author: Davies Liu Closes #8470 from davies/fix_create_df. 
--- python/pyspark/sql/tests.py | 12 ++++++++++++ python/pyspark/sql/types.py | 6 ++++++ 2 files changed, 18 insertions(+) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index aacfb34c7761..cd32e26c64f2 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -145,6 +145,12 @@ class PythonOnlyPoint(ExamplePoint): __UDT__ = PythonOnlyUDT() +class MyObject(object): + def __init__(self, key, value): + self.key = key + self.value = value + + class DataTypeTests(unittest.TestCase): # regression test for SPARK-6055 def test_data_type_eq(self): @@ -383,6 +389,12 @@ def test_infer_nested_schema(self): df = self.sqlCtx.inferSchema(rdd) self.assertEquals(Row(field1=1, field2=u'row1'), df.first()) + def test_create_dataframe_from_objects(self): + data = [MyObject(1, "1"), MyObject(2, "2")] + df = self.sqlCtx.createDataFrame(data) + self.assertEqual(df.dtypes, [("key", "bigint"), ("value", "string")]) + self.assertEqual(df.first(), Row(key=1, value="1")) + def test_select_null_literal(self): df = self.sqlCtx.sql("select null as col") self.assertEquals(Row(col=None), df.first()) diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index ed4e5b594bd6..94e581a78364 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -537,6 +537,9 @@ def toInternal(self, obj): return tuple(f.toInternal(obj.get(n)) for n, f in zip(self.names, self.fields)) elif isinstance(obj, (tuple, list)): return tuple(f.toInternal(v) for f, v in zip(self.fields, obj)) + elif hasattr(obj, "__dict__"): + d = obj.__dict__ + return tuple(f.toInternal(d.get(n)) for n, f in zip(self.names, self.fields)) else: raise ValueError("Unexpected tuple %r with StructType" % obj) else: @@ -544,6 +547,9 @@ def toInternal(self, obj): return tuple(obj.get(n) for n in self.names) elif isinstance(obj, (list, tuple)): return tuple(obj) + elif hasattr(obj, "__dict__"): + d = obj.__dict__ + return tuple(d.get(n) for n in self.names) else: raise ValueError("Unexpected tuple %r with StructType" % obj) From ad7f0f160be096c0fdae6e6cf7e3b6ba4a606de7 Mon Sep 17 00:00:00 2001 From: Shivaram Venkataraman Date: Wed, 26 Aug 2015 18:13:07 -0700 Subject: [PATCH 0435/1824] [SPARK-10308] [SPARKR] Add %in% to the exported namespace I also checked all the other functions defined in column.R, functions.R and DataFrame.R and everything else looked fine. cc yu-iskw Author: Shivaram Venkataraman Closes #8473 from shivaram/in-namespace. --- R/pkg/NAMESPACE | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 3e5c89d779b7..5286c0198620 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -47,12 +47,12 @@ exportMethods("arrange", "join", "limit", "merge", + "mutate", + "na.omit", "names", "ncol", "nrow", "orderBy", - "mutate", - "names", "persist", "printSchema", "rbind", @@ -82,7 +82,8 @@ exportMethods("arrange", exportClasses("Column") -exportMethods("abs", +exportMethods("%in%", + "abs", "acos", "add_months", "alias", From 773ca037a43d464ce7f16fe693ca6034f09a35b7 Mon Sep 17 00:00:00 2001 From: Yu ISHIKAWA Date: Wed, 26 Aug 2015 18:14:32 -0700 Subject: [PATCH 0436/1824] [MINOR] [SPARKR] Fix some validation problems in SparkR Getting rid of some validation problems in SparkR https://github.com/apache/spark/pull/7883 cc shivaram ``` inst/tests/test_Serde.R:26:1: style: Trailing whitespace is superfluous. ^~ inst/tests/test_Serde.R:34:1: style: Trailing whitespace is superfluous. 
^~ inst/tests/test_Serde.R:37:38: style: Trailing whitespace is superfluous. expect_equal(class(x), "character") ^~ inst/tests/test_Serde.R:50:1: style: Trailing whitespace is superfluous. ^~ inst/tests/test_Serde.R:55:1: style: Trailing whitespace is superfluous. ^~ inst/tests/test_Serde.R:60:1: style: Trailing whitespace is superfluous. ^~ inst/tests/test_sparkSQL.R:611:1: style: Trailing whitespace is superfluous. ^~ R/DataFrame.R:664:1: style: Trailing whitespace is superfluous. ^~~~~~~~~~~~~~ R/DataFrame.R:670:55: style: Trailing whitespace is superfluous. df <- data.frame(row.names = 1 : nrow) ^~~~~~~~~~~~~~~~ R/DataFrame.R:672:1: style: Trailing whitespace is superfluous. ^~~~~~~~~~~~~~ R/DataFrame.R:686:49: style: Trailing whitespace is superfluous. df[[names[colIndex]]] <- vec ^~~~~~~~~~~~~~~~~~ ``` Author: Yu ISHIKAWA Closes #8474 from yu-iskw/minor-fix-sparkr. --- R/pkg/R/DataFrame.R | 8 ++++---- R/pkg/inst/tests/test_Serde.R | 12 ++++++------ R/pkg/inst/tests/test_sparkSQL.R | 2 +- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index a5162de705f8..dd8126aebf46 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -661,15 +661,15 @@ setMethod("collect", # listCols is a list of columns listCols <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "dfToCols", x@sdf) stopifnot(length(listCols) == ncol) - + # An empty data.frame with 0 columns and number of rows as collected nrow <- length(listCols[[1]]) if (nrow <= 0) { df <- data.frame() } else { - df <- data.frame(row.names = 1 : nrow) + df <- data.frame(row.names = 1 : nrow) } - + # Append columns one by one for (colIndex in 1 : ncol) { # Note: appending a column of list type into a data.frame so that @@ -683,7 +683,7 @@ setMethod("collect", # TODO: more robust check on column of primitive types vec <- do.call(c, col) if (class(vec) != "list") { - df[[names[colIndex]]] <- vec + df[[names[colIndex]]] <- vec } else { # For columns of complex type, be careful to access them. # Get a column of complex type returns a list. 
diff --git a/R/pkg/inst/tests/test_Serde.R b/R/pkg/inst/tests/test_Serde.R index 009db85da2be..dddce54d7044 100644 --- a/R/pkg/inst/tests/test_Serde.R +++ b/R/pkg/inst/tests/test_Serde.R @@ -23,7 +23,7 @@ test_that("SerDe of primitive types", { x <- callJStatic("SparkRHandler", "echo", 1L) expect_equal(x, 1L) expect_equal(class(x), "integer") - + x <- callJStatic("SparkRHandler", "echo", 1) expect_equal(x, 1) expect_equal(class(x), "numeric") @@ -31,10 +31,10 @@ test_that("SerDe of primitive types", { x <- callJStatic("SparkRHandler", "echo", TRUE) expect_true(x) expect_equal(class(x), "logical") - + x <- callJStatic("SparkRHandler", "echo", "abc") expect_equal(x, "abc") - expect_equal(class(x), "character") + expect_equal(class(x), "character") }) test_that("SerDe of list of primitive types", { @@ -47,17 +47,17 @@ test_that("SerDe of list of primitive types", { y <- callJStatic("SparkRHandler", "echo", x) expect_equal(x, y) expect_equal(class(y[[1]]), "numeric") - + x <- list(TRUE, FALSE) y <- callJStatic("SparkRHandler", "echo", x) expect_equal(x, y) expect_equal(class(y[[1]]), "logical") - + x <- list("a", "b", "c") y <- callJStatic("SparkRHandler", "echo", x) expect_equal(x, y) expect_equal(class(y[[1]]), "character") - + # Empty list x <- list() y <- callJStatic("SparkRHandler", "echo", x) diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index ee48a3dc0cc0..8e22c56824b1 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -608,7 +608,7 @@ test_that("subsetting", { df4 <- df[df$age %in% c(19, 30), 1:2] expect_equal(count(df4), 2) expect_equal(columns(df4), c("name", "age")) - + df5 <- df[df$age %in% c(19), c(1,2)] expect_equal(count(df5), 1) expect_equal(columns(df5), c("name", "age")) From 0fac144f6bd835395059154532d72cdb5dc7ef8d Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Wed, 26 Aug 2015 18:14:54 -0700 Subject: [PATCH 0437/1824] [SPARK-9424] [SQL] Parquet programming guide updates for 1.5 Author: Cheng Lian Closes #8467 from liancheng/spark-9424/parquet-docs-for-1.5. --- docs/sql-programming-guide.md | 45 ++++++++++++++++++++++++++++------- 1 file changed, 37 insertions(+), 8 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 33e7893d7bd0..e64190b9b209 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1124,6 +1124,13 @@ a simple schema, and gradually add more columns to the schema as needed. In thi up with multiple Parquet files with different but mutually compatible schemas. The Parquet data source is now able to automatically detect this case and merge schemas of all these files. +Since schema merging is a relatively expensive operation, and is not a necessity in most cases, we +turned it off by default starting from 1.5.0. You may enable it by + +1. setting data source option `mergeSchema` to `true` when reading Parquet files (as shown in the + examples below), or +2. setting the global SQL option `spark.sql.parquet.mergeSchema` to `true`. +
    @@ -1143,7 +1150,7 @@ val df2 = sc.makeRDD(6 to 10).map(i => (i, i * 3)).toDF("single", "triple") df2.write.parquet("data/test_table/key=2") // Read the partitioned table -val df3 = sqlContext.read.parquet("data/test_table") +val df3 = sqlContext.read.option("mergeSchema", "true").parquet("data/test_table") df3.printSchema() // The final schema consists of all 3 columns in the Parquet files together @@ -1165,16 +1172,16 @@ df3.printSchema() # Create a simple DataFrame, stored into a partition directory df1 = sqlContext.createDataFrame(sc.parallelize(range(1, 6))\ .map(lambda i: Row(single=i, double=i * 2))) -df1.save("data/test_table/key=1", "parquet") +df1.write.parquet("data/test_table/key=1") # Create another DataFrame in a new partition directory, # adding a new column and dropping an existing column df2 = sqlContext.createDataFrame(sc.parallelize(range(6, 11)) .map(lambda i: Row(single=i, triple=i * 3))) -df2.save("data/test_table/key=2", "parquet") +df2.write.parquet("data/test_table/key=2") # Read the partitioned table -df3 = sqlContext.load("data/test_table", "parquet") +df3 = sqlContext.read.option("mergeSchema", "true").parquet("data/test_table") df3.printSchema() # The final schema consists of all 3 columns in the Parquet files together @@ -1201,7 +1208,7 @@ saveDF(df1, "data/test_table/key=1", "parquet", "overwrite") saveDF(df2, "data/test_table/key=2", "parquet", "overwrite") # Read the partitioned table -df3 <- loadDF(sqlContext, "data/test_table", "parquet") +df3 <- loadDF(sqlContext, "data/test_table", "parquet", mergeSchema="true") printSchema(df3) # The final schema consists of all 3 columns in the Parquet files together @@ -1301,7 +1308,7 @@ Configuration of Parquet can be done using the `setConf` method on `SQLContext` spark.sql.parquet.binaryAsString false - Some other Parquet-producing systems, in particular Impala and older versions of Spark SQL, do + Some other Parquet-producing systems, in particular Impala, Hive, and older versions of Spark SQL, do not differentiate between binary data and strings when writing out the Parquet schema. This flag tells Spark SQL to interpret binary data as a string to provide compatibility with these systems. @@ -1310,8 +1317,7 @@ Configuration of Parquet can be done using the `setConf` method on `SQLContext` spark.sql.parquet.int96AsTimestamp true - Some Parquet-producing systems, in particular Impala, store Timestamp into INT96. Spark would also - store Timestamp as INT96 because we need to avoid precision lost of the nanoseconds field. This + Some Parquet-producing systems, in particular Impala and Hive, store Timestamp into INT96. This flag tells Spark SQL to interpret INT96 data as a timestamp to provide compatibility with these systems. @@ -1355,6 +1361,9 @@ Configuration of Parquet can be done using the `setConf` method on `SQLContext`

    Note:
      • + This option is automatically ignored if spark.speculation is turned on.
      • This option must be set via Hadoop Configuration rather than Spark SQLConf.
@@ -1371,6 +1380,26 @@ Configuration of Parquet can be done using the `setConf` method on `SQLContext`

+ spark.sql.parquet.mergeSchema
+ false
+ When true, the Parquet data source merges schemas collected from all data files, otherwise the
+ schema is picked from the summary file or a random data file if no summary file is available.
+ spark.sql.parquet.mergeSchema
+ false
+ When true, the Parquet data source merges schemas collected from all data files, otherwise the
+ schema is picked from the summary file or a random data file if no summary file is available.
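For reference, a minimal PySpark sketch of the two ways to enable schema merging that the added rows above describe; it assumes an existing `sqlContext` and partitioned Parquet data under `data/test_table`:

```python
# 1. Per-read data source option:
df = sqlContext.read.option("mergeSchema", "true").parquet("data/test_table")

# 2. Global SQL option added by this change (off by default starting from 1.5.0):
sqlContext.setConf("spark.sql.parquet.mergeSchema", "true")
df = sqlContext.read.parquet("data/test_table")
```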

      + + ## JSON Datasets From ce97834dc0cc55eece0e909a4061ca6f2123f60d Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Wed, 26 Aug 2015 22:19:11 -0700 Subject: [PATCH 0438/1824] [SPARK-9964] [PYSPARK] [SQL] PySpark DataFrameReader accept RDD of String for JSON PySpark DataFrameReader should could accept an RDD of Strings (like the Scala version does) for JSON, rather than only taking a path. If this PR is merged, it should be duplicated to cover the other input types (not just JSON). Author: Yanbo Liang Closes #8444 from yanboliang/spark-9964. --- python/pyspark/sql/readwriter.py | 28 ++++++++++++++++++++++------ 1 file changed, 22 insertions(+), 6 deletions(-) diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 78247c8fa737..3fa6895880a9 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -15,8 +15,14 @@ # limitations under the License. # +import sys + +if sys.version >= '3': + basestring = unicode = str + from py4j.java_gateway import JavaClass +from pyspark import RDD from pyspark.sql import since from pyspark.sql.column import _to_seq from pyspark.sql.types import * @@ -125,23 +131,33 @@ def load(self, path=None, format=None, schema=None, **options): @since(1.4) def json(self, path, schema=None): """ - Loads a JSON file (one object per line) and returns the result as - a :class`DataFrame`. + Loads a JSON file (one object per line) or an RDD of Strings storing JSON objects + (one object per record) and returns the result as a :class`DataFrame`. If the ``schema`` parameter is not specified, this function goes through the input once to determine the input schema. - :param path: string, path to the JSON dataset. + :param path: string represents path to the JSON dataset, + or RDD of Strings storing JSON objects. :param schema: an optional :class:`StructType` for the input schema. - >>> df = sqlContext.read.json('python/test_support/sql/people.json') - >>> df.dtypes + >>> df1 = sqlContext.read.json('python/test_support/sql/people.json') + >>> df1.dtypes + [('age', 'bigint'), ('name', 'string')] + >>> rdd = sc.textFile('python/test_support/sql/people.json') + >>> df2 = sqlContext.read.json(rdd) + >>> df2.dtypes [('age', 'bigint'), ('name', 'string')] """ if schema is not None: self.schema(schema) - return self._df(self._jreader.json(path)) + if isinstance(path, basestring): + return self._df(self._jreader.json(path)) + elif isinstance(path, RDD): + return self._df(self._jreader.json(path._jrdd)) + else: + raise TypeError("path can be only string or RDD") @since(1.4) def table(self, tableName): From e936cf8088a06d6aefce44305f3904bbeb17b432 Mon Sep 17 00:00:00 2001 From: Shivaram Venkataraman Date: Wed, 26 Aug 2015 22:27:31 -0700 Subject: [PATCH 0439/1824] [SPARK-10219] [SPARKR] Fix varargsToEnv and add test case cc sun-rui davies Author: Shivaram Venkataraman Closes #8475 from shivaram/varargs-fix. --- R/pkg/R/utils.R | 3 ++- R/pkg/inst/tests/test_sparkSQL.R | 6 ++++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R index 4f9f4d9cad2a..3babcb519378 100644 --- a/R/pkg/R/utils.R +++ b/R/pkg/R/utils.R @@ -314,7 +314,8 @@ convertEnvsToList <- function(keys, vals) { # Utility function to capture the varargs into environment object varargsToEnv <- function(...) { - pairs <- as.list(substitute(list(...)))[-1L] + # Based on http://stackoverflow.com/a/3057419/4577954 + pairs <- list(...) 
env <- new.env() for (name in names(pairs)) { env[[name]] <- pairs[[name]] diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index 8e22c56824b1..4b672e115f92 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -1060,6 +1060,12 @@ test_that("parquetFile works with multiple input paths", { parquetDF <- parquetFile(sqlContext, parquetPath, parquetPath2) expect_is(parquetDF, "DataFrame") expect_equal(count(parquetDF), count(df) * 2) + + # Test if varargs works with variables + saveMode <- "overwrite" + mergeSchema <- "true" + parquetPath3 <- tempfile(pattern = "parquetPath3", fileext = ".parquet") + write.df(df, parquetPath2, "parquet", mode = saveMode, mergeSchema = mergeSchema) }) test_that("describe() and summarize() on a DataFrame", { From de0278286cf6db8df53b0b68918ea114f2c77f1f Mon Sep 17 00:00:00 2001 From: Ram Sriharsha Date: Wed, 26 Aug 2015 23:12:55 -0700 Subject: [PATCH 0440/1824] =?UTF-8?q?[SPARK-10251]=20[CORE]=20some=20commo?= =?UTF-8?q?n=20types=20are=20not=20registered=20for=20Kryo=20Serializat?= =?UTF-8?q?=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …ion by default Author: Ram Sriharsha Closes #8465 from harsha2010/SPARK-10251. --- .../spark/serializer/KryoSerializer.scala | 35 ++++++++++++++++++- .../serializer/KryoSerializerSuite.scala | 30 ++++++++++++++++ 2 files changed, 64 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index 048a93850727..b977711e7d5a 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -22,6 +22,7 @@ import java.nio.ByteBuffer import javax.annotation.Nullable import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer import scala.reflect.ClassTag import com.esotericsoftware.kryo.{Kryo, KryoException} @@ -38,7 +39,7 @@ import org.apache.spark.network.nio.{GetBlock, GotBlock, PutBlock} import org.apache.spark.network.util.ByteUnit import org.apache.spark.scheduler.{CompressedMapStatus, HighlyCompressedMapStatus} import org.apache.spark.storage._ -import org.apache.spark.util.{BoundedPriorityQueue, SerializableConfiguration, SerializableJobConf} +import org.apache.spark.util.{Utils, BoundedPriorityQueue, SerializableConfiguration, SerializableJobConf} import org.apache.spark.util.collection.CompactBuffer /** @@ -131,6 +132,38 @@ class KryoSerializer(conf: SparkConf) // our code override the generic serializers in Chill for things like Seq new AllScalaRegistrar().apply(kryo) + // Register types missed by Chill. 
+ // scalastyle:off + kryo.register(classOf[Array[Tuple1[Any]]]) + kryo.register(classOf[Array[Tuple2[Any, Any]]]) + kryo.register(classOf[Array[Tuple3[Any, Any, Any]]]) + kryo.register(classOf[Array[Tuple4[Any, Any, Any, Any]]]) + kryo.register(classOf[Array[Tuple5[Any, Any, Any, Any, Any]]]) + kryo.register(classOf[Array[Tuple6[Any, Any, Any, Any, Any, Any]]]) + kryo.register(classOf[Array[Tuple7[Any, Any, Any, Any, Any, Any, Any]]]) + kryo.register(classOf[Array[Tuple8[Any, Any, Any, Any, Any, Any, Any, Any]]]) + kryo.register(classOf[Array[Tuple9[Any, Any, Any, Any, Any, Any, Any, Any, Any]]]) + kryo.register(classOf[Array[Tuple10[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]]]) + kryo.register(classOf[Array[Tuple11[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]]]) + kryo.register(classOf[Array[Tuple12[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]]]) + kryo.register(classOf[Array[Tuple13[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]]]) + kryo.register(classOf[Array[Tuple14[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]]]) + kryo.register(classOf[Array[Tuple15[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]]]) + kryo.register(classOf[Array[Tuple16[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]]]) + kryo.register(classOf[Array[Tuple17[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]]]) + kryo.register(classOf[Array[Tuple18[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]]]) + kryo.register(classOf[Array[Tuple19[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]]]) + kryo.register(classOf[Array[Tuple20[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]]]) + kryo.register(classOf[Array[Tuple21[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]]]) + kryo.register(classOf[Array[Tuple22[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]]]) + + // scalastyle:on + + kryo.register(None.getClass) + kryo.register(Nil.getClass) + kryo.register(Utils.classForName("scala.collection.immutable.$colon$colon")) + kryo.register(classOf[ArrayBuffer[Any]]) + kryo.setClassLoader(classLoader) kryo } diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala index 8d1c9d17e977..e428414cf6e8 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala @@ -150,6 +150,36 @@ class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext { mutable.HashMap(1->"one", 2->"two", 3->"three"))) } + test("Bug: SPARK-10251") { + val ser = new KryoSerializer(conf.clone.set("spark.kryo.registrationRequired", "true")) + .newInstance() + def check[T: ClassTag](t: T) { + assert(ser.deserialize[T](ser.serialize(t)) === t) + } + check((1, 3)) + check(Array((1, 3))) + check(List((1, 3))) + check(List[Int]()) + check(List[Int](1, 2, 3)) + check(List[String]()) + check(List[String]("x", "y", "z")) + check(None) + check(Some(1)) + check(Some("hi")) + check(1 -> 1) + check(mutable.ArrayBuffer(1, 2, 3)) + check(mutable.ArrayBuffer("1", "2", "3")) + check(mutable.Map()) + check(mutable.Map(1 -> "one", 2 -> "two")) + 
check(mutable.Map("one" -> 1, "two" -> 2)) + check(mutable.HashMap(1 -> "one", 2 -> "two")) + check(mutable.HashMap("one" -> 1, "two" -> 2)) + check(List(Some(mutable.HashMap(1->1, 2->2)), None, Some(mutable.HashMap(3->4)))) + check(List( + mutable.HashMap("one" -> 1, "two" -> 2), + mutable.HashMap(1->"one", 2->"two", 3->"three"))) + } + test("ranges") { val ser = new KryoSerializer(conf).newInstance() def check[T: ClassTag](t: T) { From 9625d13d575c97bbff264f6a94838aae72c9202d Mon Sep 17 00:00:00 2001 From: Moussa Taifi Date: Thu, 27 Aug 2015 10:34:47 +0100 Subject: [PATCH 0441/1824] [DOCS] [STREAMING] [KAFKA] Fix typo in exactly once semantics Fix Typo in exactly once semantics [Semantics of output operations] link Author: Moussa Taifi Closes #8468 from moutai/patch-3. --- docs/streaming-kafka-integration.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/streaming-kafka-integration.md b/docs/streaming-kafka-integration.md index 7571e22575ef..5db39ae54a27 100644 --- a/docs/streaming-kafka-integration.md +++ b/docs/streaming-kafka-integration.md @@ -82,7 +82,7 @@ This approach has the following advantages over the receiver-based approach (i.e - *Efficiency:* Achieving zero-data loss in the first approach required the data to be stored in a Write Ahead Log, which further replicated the data. This is actually inefficient as the data effectively gets replicated twice - once by Kafka, and a second time by the Write Ahead Log. This second approach eliminates the problem as there is no receiver, and hence no need for Write Ahead Logs. As long as you have sufficient Kafka retention, messages can be recovered from Kafka. -- *Exactly-once semantics:* The first approach uses Kafka's high level API to store consumed offsets in Zookeeper. This is traditionally the way to consume data from Kafka. While this approach (in combination with write ahead logs) can ensure zero data loss (i.e. at-least once semantics), there is a small chance some records may get consumed twice under some failures. This occurs because of inconsistencies between data reliably received by Spark Streaming and offsets tracked by Zookeeper. Hence, in this second approach, we use simple Kafka API that does not use Zookeeper. Offsets are tracked by Spark Streaming within its checkpoints. This eliminates inconsistencies between Spark Streaming and Zookeeper/Kafka, and so each record is received by Spark Streaming effectively exactly once despite failures. In order to achieve exactly-once semantics for output of your results, your output operation that saves the data to an external data store must be either idempotent, or an atomic transaction that saves results and offsets (see [Semanitcs of output operations](streaming-programming-guide.html#semantics-of-output-operations) in the main programming guide for further information). +- *Exactly-once semantics:* The first approach uses Kafka's high level API to store consumed offsets in Zookeeper. This is traditionally the way to consume data from Kafka. While this approach (in combination with write ahead logs) can ensure zero data loss (i.e. at-least once semantics), there is a small chance some records may get consumed twice under some failures. This occurs because of inconsistencies between data reliably received by Spark Streaming and offsets tracked by Zookeeper. Hence, in this second approach, we use simple Kafka API that does not use Zookeeper. Offsets are tracked by Spark Streaming within its checkpoints. 
This eliminates inconsistencies between Spark Streaming and Zookeeper/Kafka, and so each record is received by Spark Streaming effectively exactly once despite failures. In order to achieve exactly-once semantics for output of your results, your output operation that saves the data to an external data store must be either idempotent, or an atomic transaction that saves results and offsets (see [Semantics of output operations](streaming-programming-guide.html#semantics-of-output-operations) in the main programming guide for further information). Note that one disadvantage of this approach is that it does not update offsets in Zookeeper, hence Zookeeper-based Kafka monitoring tools will not show progress. However, you can access the offsets processed by this approach in each batch and update Zookeeper yourself (see below). From 1650f6f56ed4b7f1a7f645c9e8d5ac533464bd78 Mon Sep 17 00:00:00 2001 From: Feynman Liang Date: Thu, 27 Aug 2015 10:44:44 +0100 Subject: [PATCH 0442/1824] [SPARK-10254] [ML] Removes Guava dependencies in spark.ml.feature JavaTests * Replaces `com.google.common` dependencies with `java.util.Arrays` * Small clean up in `JavaNormalizerSuite` Author: Feynman Liang Closes #8445 from feynmanliang/SPARK-10254. --- .../apache/spark/ml/feature/JavaBucketizerSuite.java | 5 +++-- .../org/apache/spark/ml/feature/JavaDCTSuite.java | 5 +++-- .../apache/spark/ml/feature/JavaHashingTFSuite.java | 5 +++-- .../apache/spark/ml/feature/JavaNormalizerSuite.java | 11 +++++------ .../org/apache/spark/ml/feature/JavaPCASuite.java | 4 ++-- .../ml/feature/JavaPolynomialExpansionSuite.java | 5 +++-- .../spark/ml/feature/JavaStandardScalerSuite.java | 4 ++-- .../apache/spark/ml/feature/JavaTokenizerSuite.java | 6 ++++-- .../spark/ml/feature/JavaVectorIndexerSuite.java | 5 ++--- .../spark/ml/feature/JavaVectorSlicerSuite.java | 4 ++-- .../apache/spark/ml/feature/JavaWord2VecSuite.java | 11 ++++++----- 11 files changed, 35 insertions(+), 30 deletions(-) diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java index d5bd230a957a..47d68de599da 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java @@ -17,7 +17,8 @@ package org.apache.spark.ml.feature; -import com.google.common.collect.Lists; +import java.util.Arrays; + import org.junit.After; import org.junit.Assert; import org.junit.Before; @@ -54,7 +55,7 @@ public void tearDown() { public void bucketizerTest() { double[] splits = {-0.5, 0.0, 0.5}; - JavaRDD data = jsc.parallelize(Lists.newArrayList( + JavaRDD data = jsc.parallelize(Arrays.asList( RowFactory.create(-0.5), RowFactory.create(-0.3), RowFactory.create(0.0), diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java index 845eed61c45c..0f6ec64d97d3 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java @@ -17,7 +17,8 @@ package org.apache.spark.ml.feature; -import com.google.common.collect.Lists; +import java.util.Arrays; + import edu.emory.mathcs.jtransforms.dct.DoubleDCT_1D; import org.junit.After; import org.junit.Assert; @@ -56,7 +57,7 @@ public void tearDown() { @Test public void javaCompatibilityTest() { double[] input = new double[] {1D, 2D, 3D, 4D}; - JavaRDD data = 
jsc.parallelize(Lists.newArrayList( + JavaRDD data = jsc.parallelize(Arrays.asList( RowFactory.create(Vectors.dense(input)) )); DataFrame dataset = jsql.createDataFrame(data, new StructType(new StructField[]{ diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java index 599e9cfd23ad..03dd5369bddf 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java @@ -17,7 +17,8 @@ package org.apache.spark.ml.feature; -import com.google.common.collect.Lists; +import java.util.Arrays; + import org.junit.After; import org.junit.Assert; import org.junit.Before; @@ -54,7 +55,7 @@ public void tearDown() { @Test public void hashingTF() { - JavaRDD jrdd = jsc.parallelize(Lists.newArrayList( + JavaRDD jrdd = jsc.parallelize(Arrays.asList( RowFactory.create(0.0, "Hi I heard about Spark"), RowFactory.create(0.0, "I wish Java could use case classes"), RowFactory.create(1.0, "Logistic regression models are neat") diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaNormalizerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaNormalizerSuite.java index d82f3b7e8c07..e17d549c5059 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaNormalizerSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaNormalizerSuite.java @@ -17,15 +17,15 @@ package org.apache.spark.ml.feature; -import java.util.List; +import java.util.Arrays; -import com.google.common.collect.Lists; import org.junit.After; import org.junit.Before; import org.junit.Test; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.api.java.JavaRDD; import org.apache.spark.sql.DataFrame; import org.apache.spark.sql.SQLContext; @@ -48,13 +48,12 @@ public void tearDown() { @Test public void normalizer() { // The tests are to check Java compatibility. 
- List points = Lists.newArrayList( + JavaRDD points = jsc.parallelize(Arrays.asList( new VectorIndexerSuite.FeatureData(Vectors.dense(0.0, -2.0)), new VectorIndexerSuite.FeatureData(Vectors.dense(1.0, 3.0)), new VectorIndexerSuite.FeatureData(Vectors.dense(1.0, 4.0)) - ); - DataFrame dataFrame = jsql.createDataFrame(jsc.parallelize(points, 2), - VectorIndexerSuite.FeatureData.class); + )); + DataFrame dataFrame = jsql.createDataFrame(points, VectorIndexerSuite.FeatureData.class); Normalizer normalizer = new Normalizer() .setInputCol("features") .setOutputCol("normFeatures"); diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaPCASuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaPCASuite.java index 5cf43fec6f29..e8f329f9cf29 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaPCASuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaPCASuite.java @@ -18,11 +18,11 @@ package org.apache.spark.ml.feature; import java.io.Serializable; +import java.util.Arrays; import java.util.List; import scala.Tuple2; -import com.google.common.collect.Lists; import org.junit.After; import org.junit.Assert; import org.junit.Before; @@ -78,7 +78,7 @@ public Vector getExpected() { @Test public void testPCA() { - List points = Lists.newArrayList( + List points = Arrays.asList( Vectors.sparse(5, new int[]{1, 3}, new double[]{1.0, 7.0}), Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0), Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0) diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaPolynomialExpansionSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaPolynomialExpansionSuite.java index 5e8211c2c511..834fedbb59e1 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaPolynomialExpansionSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaPolynomialExpansionSuite.java @@ -17,7 +17,8 @@ package org.apache.spark.ml.feature; -import com.google.common.collect.Lists; +import java.util.Arrays; + import org.junit.After; import org.junit.Assert; import org.junit.Before; @@ -59,7 +60,7 @@ public void polynomialExpansionTest() { .setOutputCol("polyFeatures") .setDegree(3); - JavaRDD data = jsc.parallelize(Lists.newArrayList( + JavaRDD data = jsc.parallelize(Arrays.asList( RowFactory.create( Vectors.dense(-2.0, 2.3), Vectors.dense(-2.0, 4.0, -8.0, 2.3, -4.6, 9.2, 5.29, -10.58, 12.17) diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStandardScalerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStandardScalerSuite.java index 74eb2733f06e..ed74363f59e3 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStandardScalerSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStandardScalerSuite.java @@ -17,9 +17,9 @@ package org.apache.spark.ml.feature; +import java.util.Arrays; import java.util.List; -import com.google.common.collect.Lists; import org.junit.After; import org.junit.Before; import org.junit.Test; @@ -48,7 +48,7 @@ public void tearDown() { @Test public void standardScaler() { // The tests are to check Java compatibility. 
- List points = Lists.newArrayList( + List points = Arrays.asList( new VectorIndexerSuite.FeatureData(Vectors.dense(0.0, -2.0)), new VectorIndexerSuite.FeatureData(Vectors.dense(1.0, 3.0)), new VectorIndexerSuite.FeatureData(Vectors.dense(1.0, 4.0)) diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java index 3806f650025b..02309ce63219 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java @@ -17,7 +17,8 @@ package org.apache.spark.ml.feature; -import com.google.common.collect.Lists; +import java.util.Arrays; + import org.junit.After; import org.junit.Assert; import org.junit.Before; @@ -54,7 +55,8 @@ public void regexTokenizer() { .setGaps(true) .setMinTokenLength(3); - JavaRDD rdd = jsc.parallelize(Lists.newArrayList( + + JavaRDD rdd = jsc.parallelize(Arrays.asList( new TokenizerTestData("Test of tok.", new String[] {"Test", "tok."}), new TokenizerTestData("Te,st. punct", new String[] {"Te,st.", "punct"}) )); diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java index c7ae5468b942..bfcca62fa1c9 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java @@ -18,6 +18,7 @@ package org.apache.spark.ml.feature; import java.io.Serializable; +import java.util.Arrays; import java.util.List; import java.util.Map; @@ -26,8 +27,6 @@ import org.junit.Before; import org.junit.Test; -import com.google.common.collect.Lists; - import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.ml.feature.VectorIndexerSuite.FeatureData; import org.apache.spark.mllib.linalg.Vectors; @@ -52,7 +51,7 @@ public void tearDown() { @Test public void vectorIndexerAPI() { // The tests are to check Java compatibility. 
- List points = Lists.newArrayList( + List points = Arrays.asList( new FeatureData(Vectors.dense(0.0, -2.0)), new FeatureData(Vectors.dense(1.0, 3.0)), new FeatureData(Vectors.dense(1.0, 4.0)) diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java index 56988b9fb29c..f95336142758 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java @@ -17,7 +17,7 @@ package org.apache.spark.ml.feature; -import com.google.common.collect.Lists; +import java.util.Arrays; import org.junit.After; import org.junit.Assert; @@ -63,7 +63,7 @@ public void vectorSlice() { }; AttributeGroup group = new AttributeGroup("userFeatures", attrs); - JavaRDD jrdd = jsc.parallelize(Lists.newArrayList( + JavaRDD jrdd = jsc.parallelize(Arrays.asList( RowFactory.create(Vectors.sparse(3, new int[]{0, 1}, new double[]{-2.0, 2.3})), RowFactory.create(Vectors.dense(-2.0, 2.3, 0.0)) )); diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaWord2VecSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaWord2VecSuite.java index 39c70157f83c..70f5ad943221 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaWord2VecSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaWord2VecSuite.java @@ -17,7 +17,8 @@ package org.apache.spark.ml.feature; -import com.google.common.collect.Lists; +import java.util.Arrays; + import org.junit.After; import org.junit.Assert; import org.junit.Before; @@ -50,10 +51,10 @@ public void tearDown() { @Test public void testJavaWord2Vec() { - JavaRDD jrdd = jsc.parallelize(Lists.newArrayList( - RowFactory.create(Lists.newArrayList("Hi I heard about Spark".split(" "))), - RowFactory.create(Lists.newArrayList("I wish Java could use case classes".split(" "))), - RowFactory.create(Lists.newArrayList("Logistic regression models are neat".split(" "))) + JavaRDD jrdd = jsc.parallelize(Arrays.asList( + RowFactory.create(Arrays.asList("Hi I heard about Spark".split(" "))), + RowFactory.create(Arrays.asList("I wish Java could use case classes".split(" "))), + RowFactory.create(Arrays.asList("Logistic regression models are neat".split(" "))) )); StructType schema = new StructType(new StructField[]{ new StructField("text", new ArrayType(DataTypes.StringType, true), false, Metadata.empty()) From 75d62307946283b03bec6aaf1bdd4f2b08c93915 Mon Sep 17 00:00:00 2001 From: Feynman Liang Date: Thu, 27 Aug 2015 10:45:35 +0100 Subject: [PATCH 0443/1824] [SPARK-10255] [ML] Removes Guava dependencies from spark.ml.param JavaTests Author: Feynman Liang Closes #8446 from feynmanliang/SPARK-10255. 
--- .../java/org/apache/spark/ml/param/JavaParamsSuite.java | 7 ++++--- .../java/org/apache/spark/ml/param/JavaTestParams.java | 5 ++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/mllib/src/test/java/org/apache/spark/ml/param/JavaParamsSuite.java b/mllib/src/test/java/org/apache/spark/ml/param/JavaParamsSuite.java index 9890155e9f86..fa777f3d42a9 100644 --- a/mllib/src/test/java/org/apache/spark/ml/param/JavaParamsSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/param/JavaParamsSuite.java @@ -17,7 +17,8 @@ package org.apache.spark.ml.param; -import com.google.common.collect.Lists; +import java.util.Arrays; + import org.junit.After; import org.junit.Assert; import org.junit.Before; @@ -61,7 +62,7 @@ public void testParamValidate() { ParamValidators.ltEq(1.0); ParamValidators.inRange(0, 1, true, false); ParamValidators.inRange(0, 1); - ParamValidators.inArray(Lists.newArrayList(0, 1, 3)); - ParamValidators.inArray(Lists.newArrayList("a", "b")); + ParamValidators.inArray(Arrays.asList(0, 1, 3)); + ParamValidators.inArray(Arrays.asList("a", "b")); } } diff --git a/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java b/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java index dc6ce8061f62..65841182df9b 100644 --- a/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java +++ b/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java @@ -17,10 +17,9 @@ package org.apache.spark.ml.param; +import java.util.Arrays; import java.util.List; -import com.google.common.collect.Lists; - import org.apache.spark.ml.util.Identifiable$; /** @@ -89,7 +88,7 @@ private void init() { myIntParam_ = new IntParam(this, "myIntParam", "this is an int param", ParamValidators.gt(0)); myDoubleParam_ = new DoubleParam(this, "myDoubleParam", "this is a double param", ParamValidators.inRange(0.0, 1.0)); - List validStrings = Lists.newArrayList("a", "b"); + List validStrings = Arrays.asList("a", "b"); myStringParam_ = new Param(this, "myStringParam", "this is a string param", ParamValidators.inArray(validStrings)); myDoubleArrayParam_ = From 1a446f75b6cac46caea0217a66abeb226946ac71 Mon Sep 17 00:00:00 2001 From: Feynman Liang Date: Thu, 27 Aug 2015 10:46:18 +0100 Subject: [PATCH 0444/1824] [SPARK-10256] [ML] Removes guava dependency from spark.ml.classification JavaTests Author: Feynman Liang Closes #8447 from feynmanliang/SPARK-10256. 
--- .../apache/spark/ml/classification/JavaNaiveBayesSuite.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java index a700c9cddb20..8fd7bf55a2e5 100644 --- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java @@ -18,8 +18,8 @@ package org.apache.spark.ml.classification; import java.io.Serializable; +import java.util.Arrays; -import com.google.common.collect.Lists; import org.junit.After; import org.junit.Before; import org.junit.Test; @@ -74,7 +74,7 @@ public void naiveBayesDefaultParams() { @Test public void testNaiveBayes() { - JavaRDD jrdd = jsc.parallelize(Lists.newArrayList( + JavaRDD jrdd = jsc.parallelize(Arrays.asList( RowFactory.create(0.0, Vectors.dense(1.0, 0.0, 0.0)), RowFactory.create(0.0, Vectors.dense(2.0, 0.0, 0.0)), RowFactory.create(1.0, Vectors.dense(0.0, 1.0, 0.0)), From b02e8187225d1765f67ce38864dfaca487be8a44 Mon Sep 17 00:00:00 2001 From: Jacek Laskowski Date: Thu, 27 Aug 2015 11:07:37 +0100 Subject: [PATCH 0445/1824] [SPARK-9613] [HOTFIX] Fix usage of JavaConverters removed in Scala 2.11 Fix for [JavaConverters.asJavaListConverter](http://www.scala-lang.org/api/2.10.5/index.html#scala.collection.JavaConverters$) being removed in 2.11.7 and hence the build fails with the 2.11 profile enabled. Tested with the default 2.10 and 2.11 profiles. BUILD SUCCESS in both cases. Build for 2.10: ./build/mvn -Pyarn -Phadoop-2.6 -Dhadoop.version=2.7.1 -DskipTests clean install and 2.11: ./dev/change-scala-version.sh 2.11 ./build/mvn -Pyarn -Phadoop-2.6 -Dhadoop.version=2.7.1 -Dscala-2.11 -DskipTests clean install Author: Jacek Laskowski Closes #8479 from jaceklaskowski/SPARK-9613-hotfix. --- .../org/apache/spark/ml/classification/JavaOneVsRestSuite.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaOneVsRestSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaOneVsRestSuite.java index 2744e020e9e4..253cabf0133d 100644 --- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaOneVsRestSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaOneVsRestSuite.java @@ -55,7 +55,7 @@ public void setUp() { double[] xMean = {5.843, 3.057, 3.758, 1.199}; double[] xVariance = {0.6856, 0.1899, 3.116, 0.581}; - List points = JavaConverters.asJavaListConverter( + List points = JavaConverters.seqAsJavaListConverter( generateMultinomialLogisticInput(weights, xMean, xVariance, true, nPoints, 42) ).asJava(); datasetRDD = jsc.parallelize(points, 2); From e1f4de4a7d15d4ca4b5c64ff929ac3980f5d706f Mon Sep 17 00:00:00 2001 From: Feynman Liang Date: Thu, 27 Aug 2015 18:46:41 +0100 Subject: [PATCH 0446/1824] [SPARK-10257] [MLLIB] Removes Guava from all spark.mllib Java tests * Replaces instances of `Lists.newArrayList` with `Arrays.asList` * Replaces `commons.lang.StringUtils` over `com.google.collections.Strings` * Replaces `List` interface over `ArrayList` implementations This PR along with #8445 #8446 #8447 completely removes all `com.google.collections.Lists` dependencies within mllib's Java tests. Author: Feynman Liang Closes #8451 from feynmanliang/SPARK-10257. 
--- .../JavaStreamingLogisticRegressionSuite.java | 10 +++---- .../clustering/JavaGaussianMixtureSuite.java | 4 +-- .../mllib/clustering/JavaKMeansSuite.java | 9 +++---- .../clustering/JavaStreamingKMeansSuite.java | 10 +++---- .../spark/mllib/feature/JavaTfIdfSuite.java | 19 +++++++------ .../mllib/feature/JavaWord2VecSuite.java | 6 ++--- .../mllib/fpm/JavaAssociationRulesSuite.java | 5 ++-- .../spark/mllib/fpm/JavaFPGrowthSuite.java | 17 ++++++------ .../spark/mllib/linalg/JavaVectorsSuite.java | 5 ++-- .../mllib/random/JavaRandomRDDsSuite.java | 27 ++++++++++--------- .../mllib/recommendation/JavaALSSuite.java | 5 ++-- .../JavaIsotonicRegressionSuite.java | 7 ++--- .../JavaStreamingLinearRegressionSuite.java | 10 +++---- .../spark/mllib/stat/JavaStatisticsSuite.java | 11 ++++---- 14 files changed, 71 insertions(+), 74 deletions(-) diff --git a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaStreamingLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaStreamingLogisticRegressionSuite.java index 55787f8606d4..c9e5ee22f327 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaStreamingLogisticRegressionSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaStreamingLogisticRegressionSuite.java @@ -18,11 +18,11 @@ package org.apache.spark.mllib.classification; import java.io.Serializable; +import java.util.Arrays; import java.util.List; import scala.Tuple2; -import com.google.common.collect.Lists; import org.junit.After; import org.junit.Before; import org.junit.Test; @@ -60,16 +60,16 @@ public void tearDown() { @Test @SuppressWarnings("unchecked") public void javaAPI() { - List trainingBatch = Lists.newArrayList( + List trainingBatch = Arrays.asList( new LabeledPoint(1.0, Vectors.dense(1.0)), new LabeledPoint(0.0, Vectors.dense(0.0))); JavaDStream training = - attachTestInputStream(ssc, Lists.newArrayList(trainingBatch, trainingBatch), 2); - List> testBatch = Lists.newArrayList( + attachTestInputStream(ssc, Arrays.asList(trainingBatch, trainingBatch), 2); + List> testBatch = Arrays.asList( new Tuple2(10, Vectors.dense(1.0)), new Tuple2(11, Vectors.dense(0.0))); JavaPairDStream test = JavaPairDStream.fromJavaDStream( - attachTestInputStream(ssc, Lists.newArrayList(testBatch, testBatch), 2)); + attachTestInputStream(ssc, Arrays.asList(testBatch, testBatch), 2)); StreamingLogisticRegressionWithSGD slr = new StreamingLogisticRegressionWithSGD() .setNumIterations(2) .setInitialWeights(Vectors.dense(0.0)); diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaGaussianMixtureSuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaGaussianMixtureSuite.java index 467a7a69e8f3..123f78da54e3 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaGaussianMixtureSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaGaussianMixtureSuite.java @@ -18,9 +18,9 @@ package org.apache.spark.mllib.clustering; import java.io.Serializable; +import java.util.Arrays; import java.util.List; -import com.google.common.collect.Lists; import org.junit.After; import org.junit.Before; import org.junit.Test; @@ -48,7 +48,7 @@ public void tearDown() { @Test public void runGaussianMixture() { - List points = Lists.newArrayList( + List points = Arrays.asList( Vectors.dense(1.0, 2.0, 6.0), Vectors.dense(1.0, 3.0, 0.0), Vectors.dense(1.0, 4.0, 6.0) diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaKMeansSuite.java 
b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaKMeansSuite.java index 31676e64025d..ad06676c72ac 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaKMeansSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaKMeansSuite.java @@ -18,6 +18,7 @@ package org.apache.spark.mllib.clustering; import java.io.Serializable; +import java.util.Arrays; import java.util.List; import org.junit.After; @@ -25,8 +26,6 @@ import org.junit.Test; import static org.junit.Assert.*; -import com.google.common.collect.Lists; - import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.linalg.Vector; @@ -48,7 +47,7 @@ public void tearDown() { @Test public void runKMeansUsingStaticMethods() { - List points = Lists.newArrayList( + List points = Arrays.asList( Vectors.dense(1.0, 2.0, 6.0), Vectors.dense(1.0, 3.0, 0.0), Vectors.dense(1.0, 4.0, 6.0) @@ -67,7 +66,7 @@ public void runKMeansUsingStaticMethods() { @Test public void runKMeansUsingConstructor() { - List points = Lists.newArrayList( + List points = Arrays.asList( Vectors.dense(1.0, 2.0, 6.0), Vectors.dense(1.0, 3.0, 0.0), Vectors.dense(1.0, 4.0, 6.0) @@ -90,7 +89,7 @@ public void runKMeansUsingConstructor() { @Test public void testPredictJavaRDD() { - List points = Lists.newArrayList( + List points = Arrays.asList( Vectors.dense(1.0, 2.0, 6.0), Vectors.dense(1.0, 3.0, 0.0), Vectors.dense(1.0, 4.0, 6.0) diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaStreamingKMeansSuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaStreamingKMeansSuite.java index 3b0e879eec77..d644766d1e54 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaStreamingKMeansSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaStreamingKMeansSuite.java @@ -18,11 +18,11 @@ package org.apache.spark.mllib.clustering; import java.io.Serializable; +import java.util.Arrays; import java.util.List; import scala.Tuple2; -import com.google.common.collect.Lists; import org.junit.After; import org.junit.Before; import org.junit.Test; @@ -60,16 +60,16 @@ public void tearDown() { @Test @SuppressWarnings("unchecked") public void javaAPI() { - List trainingBatch = Lists.newArrayList( + List trainingBatch = Arrays.asList( Vectors.dense(1.0), Vectors.dense(0.0)); JavaDStream training = - attachTestInputStream(ssc, Lists.newArrayList(trainingBatch, trainingBatch), 2); - List> testBatch = Lists.newArrayList( + attachTestInputStream(ssc, Arrays.asList(trainingBatch, trainingBatch), 2); + List> testBatch = Arrays.asList( new Tuple2(10, Vectors.dense(1.0)), new Tuple2(11, Vectors.dense(0.0))); JavaPairDStream test = JavaPairDStream.fromJavaDStream( - attachTestInputStream(ssc, Lists.newArrayList(testBatch, testBatch), 2)); + attachTestInputStream(ssc, Arrays.asList(testBatch, testBatch), 2)); StreamingKMeans skmeans = new StreamingKMeans() .setK(1) .setDecayFactor(1.0) diff --git a/mllib/src/test/java/org/apache/spark/mllib/feature/JavaTfIdfSuite.java b/mllib/src/test/java/org/apache/spark/mllib/feature/JavaTfIdfSuite.java index fbc26167ce66..8a320afa4b13 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/feature/JavaTfIdfSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/feature/JavaTfIdfSuite.java @@ -18,14 +18,13 @@ package org.apache.spark.mllib.feature; import java.io.Serializable; -import java.util.ArrayList; +import java.util.Arrays; import java.util.List; import org.junit.After; import 
org.junit.Assert; import org.junit.Before; import org.junit.Test; -import com.google.common.collect.Lists; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; @@ -50,10 +49,10 @@ public void tfIdf() { // The tests are to check Java compatibility. HashingTF tf = new HashingTF(); @SuppressWarnings("unchecked") - JavaRDD> documents = sc.parallelize(Lists.newArrayList( - Lists.newArrayList("this is a sentence".split(" ")), - Lists.newArrayList("this is another sentence".split(" ")), - Lists.newArrayList("this is still a sentence".split(" "))), 2); + JavaRDD> documents = sc.parallelize(Arrays.asList( + Arrays.asList("this is a sentence".split(" ")), + Arrays.asList("this is another sentence".split(" ")), + Arrays.asList("this is still a sentence".split(" "))), 2); JavaRDD termFreqs = tf.transform(documents); termFreqs.collect(); IDF idf = new IDF(); @@ -70,10 +69,10 @@ public void tfIdfMinimumDocumentFrequency() { // The tests are to check Java compatibility. HashingTF tf = new HashingTF(); @SuppressWarnings("unchecked") - JavaRDD> documents = sc.parallelize(Lists.newArrayList( - Lists.newArrayList("this is a sentence".split(" ")), - Lists.newArrayList("this is another sentence".split(" ")), - Lists.newArrayList("this is still a sentence".split(" "))), 2); + JavaRDD> documents = sc.parallelize(Arrays.asList( + Arrays.asList("this is a sentence".split(" ")), + Arrays.asList("this is another sentence".split(" ")), + Arrays.asList("this is still a sentence".split(" "))), 2); JavaRDD termFreqs = tf.transform(documents); termFreqs.collect(); IDF idf = new IDF(2); diff --git a/mllib/src/test/java/org/apache/spark/mllib/feature/JavaWord2VecSuite.java b/mllib/src/test/java/org/apache/spark/mllib/feature/JavaWord2VecSuite.java index fb7afe8c6434..e13ed07e283d 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/feature/JavaWord2VecSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/feature/JavaWord2VecSuite.java @@ -18,11 +18,11 @@ package org.apache.spark.mllib.feature; import java.io.Serializable; +import java.util.Arrays; import java.util.List; import scala.Tuple2; -import com.google.common.collect.Lists; import com.google.common.base.Strings; import org.junit.After; import org.junit.Assert; @@ -51,8 +51,8 @@ public void tearDown() { public void word2Vec() { // The tests are to check Java compatibility. 
String sentence = Strings.repeat("a b ", 100) + Strings.repeat("a c ", 10); - List words = Lists.newArrayList(sentence.split(" ")); - List> localDoc = Lists.newArrayList(words, words); + List words = Arrays.asList(sentence.split(" ")); + List> localDoc = Arrays.asList(words, words); JavaRDD> doc = sc.parallelize(localDoc); Word2Vec word2vec = new Word2Vec() .setVectorSize(10) diff --git a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaAssociationRulesSuite.java b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaAssociationRulesSuite.java index d7c2cb3ae206..2bef7a860975 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaAssociationRulesSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaAssociationRulesSuite.java @@ -17,17 +17,16 @@ package org.apache.spark.mllib.fpm; import java.io.Serializable; +import java.util.Arrays; import org.junit.After; import org.junit.Before; import org.junit.Test; -import com.google.common.collect.Lists; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.fpm.FPGrowth.FreqItemset; - public class JavaAssociationRulesSuite implements Serializable { private transient JavaSparkContext sc; @@ -46,7 +45,7 @@ public void tearDown() { public void runAssociationRules() { @SuppressWarnings("unchecked") - JavaRDD> freqItemsets = sc.parallelize(Lists.newArrayList( + JavaRDD> freqItemsets = sc.parallelize(Arrays.asList( new FreqItemset(new String[] {"a"}, 15L), new FreqItemset(new String[] {"b"}, 35L), new FreqItemset(new String[] {"a", "b"}, 12L) diff --git a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java index 9ce2c52dca8b..154f75d75e4a 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java @@ -18,13 +18,12 @@ package org.apache.spark.mllib.fpm; import java.io.Serializable; -import java.util.ArrayList; +import java.util.Arrays; import java.util.List; import org.junit.After; import org.junit.Before; import org.junit.Test; -import com.google.common.collect.Lists; import static org.junit.Assert.*; import org.apache.spark.api.java.JavaRDD; @@ -48,13 +47,13 @@ public void tearDown() { public void runFPGrowth() { @SuppressWarnings("unchecked") - JavaRDD> rdd = sc.parallelize(Lists.newArrayList( - Lists.newArrayList("r z h k p".split(" ")), - Lists.newArrayList("z y x w v u t s".split(" ")), - Lists.newArrayList("s x o n r".split(" ")), - Lists.newArrayList("x z y m t s q e".split(" ")), - Lists.newArrayList("z".split(" ")), - Lists.newArrayList("x z y r q t p".split(" "))), 2); + JavaRDD> rdd = sc.parallelize(Arrays.asList( + Arrays.asList("r z h k p".split(" ")), + Arrays.asList("z y x w v u t s".split(" ")), + Arrays.asList("s x o n r".split(" ")), + Arrays.asList("x z y m t s q e".split(" ")), + Arrays.asList("z".split(" ")), + Arrays.asList("x z y r q t p".split(" "))), 2); FPGrowthModel model = new FPGrowth() .setMinSupport(0.5) diff --git a/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaVectorsSuite.java b/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaVectorsSuite.java index 1421067dc61e..77c8c6274f37 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaVectorsSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaVectorsSuite.java @@ -18,11 +18,10 @@ package org.apache.spark.mllib.linalg; import java.io.Serializable; +import 
java.util.Arrays; import scala.Tuple2; -import com.google.common.collect.Lists; - import org.junit.Test; import static org.junit.Assert.*; @@ -37,7 +36,7 @@ public void denseArrayConstruction() { @Test public void sparseArrayConstruction() { @SuppressWarnings("unchecked") - Vector v = Vectors.sparse(3, Lists.>newArrayList( + Vector v = Vectors.sparse(3, Arrays.asList( new Tuple2(0, 2.0), new Tuple2(2, 3.0))); assertArrayEquals(new double[]{2.0, 0.0, 3.0}, v.toArray(), 0.0); diff --git a/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java b/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java index fcc13c00cbdc..33d81b1e9592 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java @@ -17,7 +17,8 @@ package org.apache.spark.mllib.random; -import com.google.common.collect.Lists; +import java.util.Arrays; + import org.apache.spark.api.java.JavaRDD; import org.junit.Assert; import org.junit.After; @@ -51,7 +52,7 @@ public void testUniformRDD() { JavaDoubleRDD rdd1 = uniformJavaRDD(sc, m); JavaDoubleRDD rdd2 = uniformJavaRDD(sc, m, p); JavaDoubleRDD rdd3 = uniformJavaRDD(sc, m, p, seed); - for (JavaDoubleRDD rdd: Lists.newArrayList(rdd1, rdd2, rdd3)) { + for (JavaDoubleRDD rdd: Arrays.asList(rdd1, rdd2, rdd3)) { Assert.assertEquals(m, rdd.count()); } } @@ -64,7 +65,7 @@ public void testNormalRDD() { JavaDoubleRDD rdd1 = normalJavaRDD(sc, m); JavaDoubleRDD rdd2 = normalJavaRDD(sc, m, p); JavaDoubleRDD rdd3 = normalJavaRDD(sc, m, p, seed); - for (JavaDoubleRDD rdd: Lists.newArrayList(rdd1, rdd2, rdd3)) { + for (JavaDoubleRDD rdd: Arrays.asList(rdd1, rdd2, rdd3)) { Assert.assertEquals(m, rdd.count()); } } @@ -79,7 +80,7 @@ public void testLNormalRDD() { JavaDoubleRDD rdd1 = logNormalJavaRDD(sc, mean, std, m); JavaDoubleRDD rdd2 = logNormalJavaRDD(sc, mean, std, m, p); JavaDoubleRDD rdd3 = logNormalJavaRDD(sc, mean, std, m, p, seed); - for (JavaDoubleRDD rdd: Lists.newArrayList(rdd1, rdd2, rdd3)) { + for (JavaDoubleRDD rdd: Arrays.asList(rdd1, rdd2, rdd3)) { Assert.assertEquals(m, rdd.count()); } } @@ -93,7 +94,7 @@ public void testPoissonRDD() { JavaDoubleRDD rdd1 = poissonJavaRDD(sc, mean, m); JavaDoubleRDD rdd2 = poissonJavaRDD(sc, mean, m, p); JavaDoubleRDD rdd3 = poissonJavaRDD(sc, mean, m, p, seed); - for (JavaDoubleRDD rdd: Lists.newArrayList(rdd1, rdd2, rdd3)) { + for (JavaDoubleRDD rdd: Arrays.asList(rdd1, rdd2, rdd3)) { Assert.assertEquals(m, rdd.count()); } } @@ -107,7 +108,7 @@ public void testExponentialRDD() { JavaDoubleRDD rdd1 = exponentialJavaRDD(sc, mean, m); JavaDoubleRDD rdd2 = exponentialJavaRDD(sc, mean, m, p); JavaDoubleRDD rdd3 = exponentialJavaRDD(sc, mean, m, p, seed); - for (JavaDoubleRDD rdd: Lists.newArrayList(rdd1, rdd2, rdd3)) { + for (JavaDoubleRDD rdd: Arrays.asList(rdd1, rdd2, rdd3)) { Assert.assertEquals(m, rdd.count()); } } @@ -122,7 +123,7 @@ public void testGammaRDD() { JavaDoubleRDD rdd1 = gammaJavaRDD(sc, shape, scale, m); JavaDoubleRDD rdd2 = gammaJavaRDD(sc, shape, scale, m, p); JavaDoubleRDD rdd3 = gammaJavaRDD(sc, shape, scale, m, p, seed); - for (JavaDoubleRDD rdd: Lists.newArrayList(rdd1, rdd2, rdd3)) { + for (JavaDoubleRDD rdd: Arrays.asList(rdd1, rdd2, rdd3)) { Assert.assertEquals(m, rdd.count()); } } @@ -138,7 +139,7 @@ public void testUniformVectorRDD() { JavaRDD rdd1 = uniformJavaVectorRDD(sc, m, n); JavaRDD rdd2 = uniformJavaVectorRDD(sc, m, n, p); JavaRDD rdd3 = uniformJavaVectorRDD(sc, m, n, 
p, seed); - for (JavaRDD rdd: Lists.newArrayList(rdd1, rdd2, rdd3)) { + for (JavaRDD rdd: Arrays.asList(rdd1, rdd2, rdd3)) { Assert.assertEquals(m, rdd.count()); Assert.assertEquals(n, rdd.first().size()); } @@ -154,7 +155,7 @@ public void testNormalVectorRDD() { JavaRDD rdd1 = normalJavaVectorRDD(sc, m, n); JavaRDD rdd2 = normalJavaVectorRDD(sc, m, n, p); JavaRDD rdd3 = normalJavaVectorRDD(sc, m, n, p, seed); - for (JavaRDD rdd: Lists.newArrayList(rdd1, rdd2, rdd3)) { + for (JavaRDD rdd: Arrays.asList(rdd1, rdd2, rdd3)) { Assert.assertEquals(m, rdd.count()); Assert.assertEquals(n, rdd.first().size()); } @@ -172,7 +173,7 @@ public void testLogNormalVectorRDD() { JavaRDD rdd1 = logNormalJavaVectorRDD(sc, mean, std, m, n); JavaRDD rdd2 = logNormalJavaVectorRDD(sc, mean, std, m, n, p); JavaRDD rdd3 = logNormalJavaVectorRDD(sc, mean, std, m, n, p, seed); - for (JavaRDD rdd: Lists.newArrayList(rdd1, rdd2, rdd3)) { + for (JavaRDD rdd: Arrays.asList(rdd1, rdd2, rdd3)) { Assert.assertEquals(m, rdd.count()); Assert.assertEquals(n, rdd.first().size()); } @@ -189,7 +190,7 @@ public void testPoissonVectorRDD() { JavaRDD rdd1 = poissonJavaVectorRDD(sc, mean, m, n); JavaRDD rdd2 = poissonJavaVectorRDD(sc, mean, m, n, p); JavaRDD rdd3 = poissonJavaVectorRDD(sc, mean, m, n, p, seed); - for (JavaRDD rdd: Lists.newArrayList(rdd1, rdd2, rdd3)) { + for (JavaRDD rdd: Arrays.asList(rdd1, rdd2, rdd3)) { Assert.assertEquals(m, rdd.count()); Assert.assertEquals(n, rdd.first().size()); } @@ -206,7 +207,7 @@ public void testExponentialVectorRDD() { JavaRDD rdd1 = exponentialJavaVectorRDD(sc, mean, m, n); JavaRDD rdd2 = exponentialJavaVectorRDD(sc, mean, m, n, p); JavaRDD rdd3 = exponentialJavaVectorRDD(sc, mean, m, n, p, seed); - for (JavaRDD rdd: Lists.newArrayList(rdd1, rdd2, rdd3)) { + for (JavaRDD rdd: Arrays.asList(rdd1, rdd2, rdd3)) { Assert.assertEquals(m, rdd.count()); Assert.assertEquals(n, rdd.first().size()); } @@ -224,7 +225,7 @@ public void testGammaVectorRDD() { JavaRDD rdd1 = gammaJavaVectorRDD(sc, shape, scale, m, n); JavaRDD rdd2 = gammaJavaVectorRDD(sc, shape, scale, m, n, p); JavaRDD rdd3 = gammaJavaVectorRDD(sc, shape, scale, m, n, p, seed); - for (JavaRDD rdd: Lists.newArrayList(rdd1, rdd2, rdd3)) { + for (JavaRDD rdd: Arrays.asList(rdd1, rdd2, rdd3)) { Assert.assertEquals(m, rdd.count()); Assert.assertEquals(n, rdd.first().size()); } diff --git a/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java b/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java index af688c504cf1..271dda4662e0 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java @@ -18,12 +18,12 @@ package org.apache.spark.mllib.recommendation; import java.io.Serializable; +import java.util.ArrayList; import java.util.List; import scala.Tuple2; import scala.Tuple3; -import com.google.common.collect.Lists; import org.jblas.DoubleMatrix; import org.junit.After; import org.junit.Assert; @@ -56,8 +56,7 @@ void validatePrediction( double matchThreshold, boolean implicitPrefs, DoubleMatrix truePrefs) { - List> localUsersProducts = - Lists.newArrayListWithCapacity(users * products); + List> localUsersProducts = new ArrayList(users * products); for (int u=0; u < users; ++u) { for (int p=0; p < products; ++p) { localUsersProducts.add(new Tuple2(u, p)); diff --git a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaIsotonicRegressionSuite.java 
b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaIsotonicRegressionSuite.java index d38fc91ace3c..32c2f4f3395b 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaIsotonicRegressionSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaIsotonicRegressionSuite.java @@ -18,11 +18,12 @@ package org.apache.spark.mllib.regression; import java.io.Serializable; +import java.util.ArrayList; +import java.util.Arrays; import java.util.List; import scala.Tuple3; -import com.google.common.collect.Lists; import org.junit.After; import org.junit.Assert; import org.junit.Before; @@ -36,7 +37,7 @@ public class JavaIsotonicRegressionSuite implements Serializable { private transient JavaSparkContext sc; private List> generateIsotonicInput(double[] labels) { - List> input = Lists.newArrayList(); + ArrayList> input = new ArrayList(labels.length); for (int i = 1; i <= labels.length; i++) { input.add(new Tuple3(labels[i-1], (double) i, 1d)); @@ -77,7 +78,7 @@ public void testIsotonicRegressionPredictionsJavaRDD() { IsotonicRegressionModel model = runIsotonicRegression(new double[]{1, 2, 3, 3, 1, 6, 7, 8, 11, 9, 10, 12}); - JavaDoubleRDD testRDD = sc.parallelizeDoubles(Lists.newArrayList(0.0, 1.0, 9.5, 12.0, 13.0)); + JavaDoubleRDD testRDD = sc.parallelizeDoubles(Arrays.asList(0.0, 1.0, 9.5, 12.0, 13.0)); List predictions = model.predict(testRDD).collect(); Assert.assertTrue(predictions.get(0) == 1d); diff --git a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaStreamingLinearRegressionSuite.java b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaStreamingLinearRegressionSuite.java index 899c4ea60786..dbf6488d4108 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaStreamingLinearRegressionSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaStreamingLinearRegressionSuite.java @@ -18,11 +18,11 @@ package org.apache.spark.mllib.regression; import java.io.Serializable; +import java.util.Arrays; import java.util.List; import scala.Tuple2; -import com.google.common.collect.Lists; import org.junit.After; import org.junit.Before; import org.junit.Test; @@ -59,16 +59,16 @@ public void tearDown() { @Test @SuppressWarnings("unchecked") public void javaAPI() { - List trainingBatch = Lists.newArrayList( + List trainingBatch = Arrays.asList( new LabeledPoint(1.0, Vectors.dense(1.0)), new LabeledPoint(0.0, Vectors.dense(0.0))); JavaDStream training = - attachTestInputStream(ssc, Lists.newArrayList(trainingBatch, trainingBatch), 2); - List> testBatch = Lists.newArrayList( + attachTestInputStream(ssc, Arrays.asList(trainingBatch, trainingBatch), 2); + List> testBatch = Arrays.asList( new Tuple2(10, Vectors.dense(1.0)), new Tuple2(11, Vectors.dense(0.0))); JavaPairDStream test = JavaPairDStream.fromJavaDStream( - attachTestInputStream(ssc, Lists.newArrayList(testBatch, testBatch), 2)); + attachTestInputStream(ssc, Arrays.asList(testBatch, testBatch), 2)); StreamingLinearRegressionWithSGD slr = new StreamingLinearRegressionWithSGD() .setNumIterations(2) .setInitialWeights(Vectors.dense(0.0)); diff --git a/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java b/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java index eb4e3698624b..4795809e47a4 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java @@ -19,7 +19,8 @@ import java.io.Serializable; -import 
com.google.common.collect.Lists; +import java.util.Arrays; + import org.junit.After; import org.junit.Before; import org.junit.Test; @@ -50,8 +51,8 @@ public void tearDown() { @Test public void testCorr() { - JavaRDD x = sc.parallelize(Lists.newArrayList(1.0, 2.0, 3.0, 4.0)); - JavaRDD y = sc.parallelize(Lists.newArrayList(1.1, 2.2, 3.1, 4.3)); + JavaRDD x = sc.parallelize(Arrays.asList(1.0, 2.0, 3.0, 4.0)); + JavaRDD y = sc.parallelize(Arrays.asList(1.1, 2.2, 3.1, 4.3)); Double corr1 = Statistics.corr(x, y); Double corr2 = Statistics.corr(x, y, "pearson"); @@ -61,7 +62,7 @@ public void testCorr() { @Test public void kolmogorovSmirnovTest() { - JavaDoubleRDD data = sc.parallelizeDoubles(Lists.newArrayList(0.2, 1.0, -1.0, 2.0)); + JavaDoubleRDD data = sc.parallelizeDoubles(Arrays.asList(0.2, 1.0, -1.0, 2.0)); KolmogorovSmirnovTestResult testResult1 = Statistics.kolmogorovSmirnovTest(data, "norm"); KolmogorovSmirnovTestResult testResult2 = Statistics.kolmogorovSmirnovTest( data, "norm", 0.0, 1.0); @@ -69,7 +70,7 @@ public void kolmogorovSmirnovTest() { @Test public void chiSqTest() { - JavaRDD data = sc.parallelize(Lists.newArrayList( + JavaRDD data = sc.parallelize(Arrays.asList( new LabeledPoint(0.0, Vectors.dense(0.1, 2.3)), new LabeledPoint(1.0, Vectors.dense(1.5, 5.1)), new LabeledPoint(0.0, Vectors.dense(2.4, 8.1)))); From fdd466bed7a7151dd066d732ef98d225f4acda4a Mon Sep 17 00:00:00 2001 From: Vyacheslav Baranov Date: Thu, 27 Aug 2015 18:56:18 +0100 Subject: [PATCH 0447/1824] [SPARK-10182] [MLLIB] GeneralizedLinearModel doesn't unpersist cached data `GeneralizedLinearModel` creates a cached RDD when building a model. It's inconvenient, since these RDDs flood the memory when building several models in a row, so useful data might get evicted from the cache. The proposed solution is to always cache the dataset & remove the warning. There's a caveat though: input dataset gets evaluated twice, in line 270 when fitting `StandardScaler` for the first time, and when running optimizer for the second time. So, it might worth to return removed warning. Another possible solution is to disable caching entirely & return removed warning. I don't really know what approach is better. Author: Vyacheslav Baranov Closes #8395 from SlavikBaranov/SPARK-10182. --- .../spark/mllib/regression/GeneralizedLinearAlgorithm.scala | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala index 7e3b4d5648fe..8f657bfb9c73 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala @@ -359,6 +359,11 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] + " parent RDDs are also uncached.") } + // Unpersist cached data + if (data.getStorageLevel != StorageLevel.NONE) { + data.unpersist(false) + } + createModel(weights, intercept) } } From dc86a227e4fc8a9d8c3e8c68da8dff9298447fd0 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Thu, 27 Aug 2015 11:45:15 -0700 Subject: [PATCH 0448/1824] [SPARK-9148] [SPARK-10252] [SQL] Update SQL Programming Guide Author: Michael Armbrust Closes #8441 from marmbrus/documentation. 
--- docs/sql-programming-guide.md | 92 +++++++++++++++++++++++++++-------- 1 file changed, 73 insertions(+), 19 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index e64190b9b209..99fec6c7785a 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -11,7 +11,7 @@ title: Spark SQL and DataFrames Spark SQL is a Spark module for structured data processing. It provides a programming abstraction called DataFrames and can also act as distributed SQL query engine. -For how to enable Hive support, please refer to the [Hive Tables](#hive-tables) section. +Spark SQL can also be used to read data from an existing Hive installation. For more on how to configure this feature, please refer to the [Hive Tables](#hive-tables) section. # DataFrames @@ -213,6 +213,11 @@ df.groupBy("age").count().show() // 30 1 {% endhighlight %} +For a complete list of the types of operations that can be performed on a DataFrame refer to the [API Documentation](api/scala/index.html#org.apache.spark.sql.DataFrame). + +In addition to simple column references and expressions, DataFrames also have a rich library of functions including string manipulation, date arithmetic, common math operations and more. The complete list is available in the [DataFrame Function Reference](api/scala/index.html#org.apache.spark.sql.DataFrame). + +
    @@ -263,6 +268,10 @@ df.groupBy("age").count().show(); // 30 1 {% endhighlight %} +For a complete list of the types of operations that can be performed on a DataFrame refer to the [API Documentation](api/java/org/apache/spark/sql/DataFrame.html). + +In addition to simple column references and expressions, DataFrames also have a rich library of functions including string manipulation, date arithmetic, common math operations and more. The complete list is available in the [DataFrame Function Reference](api/java/org/apache/spark/sql/functions.html). +
    @@ -320,6 +329,10 @@ df.groupBy("age").count().show() {% endhighlight %} +For a complete list of the types of operations that can be performed on a DataFrame refer to the [API Documentation](api/python/pyspark.sql.html#pyspark.sql.DataFrame). + +In addition to simple column references and expressions, DataFrames also have a rich library of functions including string manipulation, date arithmetic, common math operations and more. The complete list is available in the [DataFrame Function Reference](api/python/pyspark.sql.html#module-pyspark.sql.functions). +
    @@ -370,10 +383,13 @@ showDF(count(groupBy(df, "age"))) {% endhighlight %} -
    +For a complete list of the types of operations that can be performed on a DataFrame refer to the [API Documentation](api/R/index.html). + +In addition to simple column references and expressions, DataFrames also have a rich library of functions including string manipulation, date arithmetic, common math operations and more. The complete list is available in the [DataFrame Function Reference](api/R/index.html).
    + ## Running SQL Queries Programmatically @@ -870,12 +886,11 @@ saveDF(select(df, "name", "age"), "namesAndAges.parquet", "parquet") Save operations can optionally take a `SaveMode`, that specifies how to handle existing data if present. It is important to realize that these save modes do not utilize any locking and are not -atomic. Thus, it is not safe to have multiple writers attempting to write to the same location. -Additionally, when performing a `Overwrite`, the data will be deleted before writing out the +atomic. Additionally, when performing a `Overwrite`, the data will be deleted before writing out the new data. - + @@ -1671,12 +1686,12 @@ results <- collect(sql(sqlContext, "FROM src SELECT key, value")) ### Interacting with Different Versions of Hive Metastore One of the most important pieces of Spark SQL's Hive support is interaction with Hive metastore, -which enables Spark SQL to access metadata of Hive tables. Starting from Spark 1.4.0, a single binary build of Spark SQL can be used to query different versions of Hive metastores, using the configuration described below. +which enables Spark SQL to access metadata of Hive tables. Starting from Spark 1.4.0, a single binary +build of Spark SQL can be used to query different versions of Hive metastores, using the configuration described below. +Note that independent of the version of Hive that is being used to talk to the metastore, internally Spark SQL +will compile against Hive 1.2.1 and use those classes for internal execution (serdes, UDFs, UDAFs, etc). -Internally, Spark SQL uses two Hive clients, one for executing native Hive commands like `SET` -and `DESCRIBE`, the other dedicated for communicating with Hive metastore. The former uses Hive -jars of version 0.13.1, which are bundled with Spark 1.4.0. The latter uses Hive jars of the -version specified by users. An isolated classloader is used here to avoid dependency conflicts. +The following options can be used to configure the version of Hive that is used to retrieve metadata:
    Scala/Java | Python | Meaning
    Scala/Java | Any Language | Meaning
    SaveMode.ErrorIfExists (default) | "error" (default)
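A minimal sketch of passing one of these save modes through the DataFrame writer (assuming a DataFrame `df` and an illustrative output path):

{% highlight scala %}
import org.apache.spark.sql.SaveMode

// Overwrite deletes existing data at the target before writing; the other modes
// described above (Append, Ignore, ErrorIfExists) are passed the same way.
df.write.format("parquet").mode(SaveMode.Overwrite).save("namesAndAges.parquet")
{% endhighlight %}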
    @@ -1685,7 +1700,7 @@ version specified by users. An isolated classloader is used here to avoid depend @@ -1696,12 +1711,16 @@ version specified by users. An isolated classloader is used here to avoid depend property can be one of three options:
    1. builtin
    2. - Use Hive 0.13.1, which is bundled with the Spark assembly jar when -Phive is + Use Hive 1.2.1, which is bundled with the Spark assembly jar when -Phive is enabled. When this option is chosen, spark.sql.hive.metastore.version must be - either 0.13.1 or not defined. + either 1.2.1 or not defined.
    3. maven
    4. - Use Hive jars of specified version downloaded from Maven repositories. -
    5. A classpath in the standard format for both Hive and Hadoop.
    6. + Use Hive jars of specified version downloaded from Maven repositories. This configuration + is not generally recommended for production deployments. +
    7. A classpath in the standard format for the JVM. This classpath must include all of Hive + and its dependencies, including the correct version of Hadoop. These jars only need to be + present on the driver, but if you are running in yarn cluster mode then you must ensure + they are packaged with your application.
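A sketch of how these two properties might be set when constructing the context; the keys are the ones documented above, the values are only illustrative:

{% highlight scala %}
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.hive.HiveContext

// Talk to a 0.12.0 metastore using Hive jars resolved from Maven.
val conf = new SparkConf()
  .setAppName("metastore-example")
  .set("spark.sql.hive.metastore.version", "0.12.0")
  .set("spark.sql.hive.metastore.jars", "maven")
val sc = new SparkContext(conf)
val hiveContext = new HiveContext(sc)
{% endhighlight %}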
    @@ -2017,6 +2036,28 @@ options. # Migration Guide +## Upgrading From Spark SQL 1.4 to 1.5 + + - Optimized execution using manually managed memory (Tungsten) is now enabled by default, along with + code generation for expression evaluation. These features can both be disabled by setting + `spark.sql.tungsten.enabled` to `false. + - Parquet schema merging is no longer enabled by default. It can be re-enabled by setting + `spark.sql.parquet.mergeSchema` to `true`. + - Resolution of strings to columns in python now supports using dots (`.`) to qualify the column or + access nested values. For example `df['table.column.nestedField']`. However, this means that if + your column name contains any dots you must now escape them using backticks (e.g., ``table.`column.with.dots`.nested``). + - In-memory columnar storage partition pruning is on by default. It can be disabled by setting + `spark.sql.inMemoryColumnarStorage.partitionPruning` to `false`. + - Unlimited precision decimal columns are no longer supported, instead Spark SQL enforces a maximum + precision of 38. When inferring schema from `BigDecimal` objects, a precision of (38, 18) is now + used. When no precision is specified in DDL then the default remains `Decimal(10, 0)`. + - Timestamps are now stored at a precision of 1us, rather than 1ns + - In the `sql` dialect, floating point numbers are now parsed as decimal. HiveQL parsing remains + unchanged. + - The canonical name of SQL/DataFrame functions are now lower case (e.g. sum vs SUM). + - It has been determined that using the DirectOutputCommitter when speculation is enabled is unsafe + and thus this output committer will not be used when speculation is on, independent of configuration. + ## Upgrading from Spark SQL 1.3 to 1.4 #### DataFrame data reader/writer interface @@ -2038,7 +2079,8 @@ See the API docs for `SQLContext.read` ( #### DataFrame.groupBy retains grouping columns -Based on user feedback, we changed the default behavior of `DataFrame.groupBy().agg()` to retain the grouping columns in the resulting `DataFrame`. To keep the behavior in 1.3, set `spark.sql.retainGroupColumns` to `false`. +Based on user feedback, we changed the default behavior of `DataFrame.groupBy().agg()` to retain the +grouping columns in the resulting `DataFrame`. To keep the behavior in 1.3, set `spark.sql.retainGroupColumns` to `false`.
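Several of the 1.5 behavior changes listed above are controlled by configuration keys, so the earlier behavior can be restored per setting. A sketch using the keys named in the list, assuming an existing `sqlContext`:

{% highlight scala %}
// Revert selected defaults (sketch only; each line is independent).
sqlContext.setConf("spark.sql.tungsten.enabled", "false")                           // disable Tungsten and codegen
sqlContext.setConf("spark.sql.parquet.mergeSchema", "true")                         // re-enable Parquet schema merging
sqlContext.setConf("spark.sql.inMemoryColumnarStorage.partitionPruning", "false")   // disable in-memory partition pruning
sqlContext.setConf("spark.sql.retainGroupColumns", "false")                         // 1.3-style groupBy().agg()
{% endhighlight %}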
    @@ -2175,7 +2217,7 @@ Python UDF registration is unchanged. When using DataTypes in Python you will need to construct them (i.e. `StringType()`) instead of referencing a singleton. -## Migration Guide for Shark User +## Migration Guide for Shark Users ### Scheduling To set a [Fair Scheduler](job-scheduling.html#fair-scheduler-pools) pool for a JDBC client session, @@ -2251,6 +2293,7 @@ Spark SQL supports the vast majority of Hive features, such as: * User defined functions (UDF) * User defined aggregation functions (UDAF) * User defined serialization formats (SerDes) +* Window functions * Joins * `JOIN` * `{LEFT|RIGHT|FULL} OUTER JOIN` @@ -2261,7 +2304,7 @@ Spark SQL supports the vast majority of Hive features, such as: * `SELECT col FROM ( SELECT a + b AS col from t1) t2` * Sampling * Explain -* Partitioned tables +* Partitioned tables including dynamic partition insertion * View * All Hive DDL Functions, including: * `CREATE TABLE` @@ -2323,8 +2366,9 @@ releases of Spark SQL. Hive can optionally merge the small files into fewer large files to avoid overflowing the HDFS metadata. Spark SQL does not support that. +# Reference -# Data Types +## Data Types Spark SQL and DataFrames support the following data types: @@ -2937,3 +2981,13 @@ from pyspark.sql.types import *
    +## NaN Semantics + +There is specially handling for not-a-number (NaN) when dealing with `float` or `double` types that +does not exactly match standard floating point semantics. +Specifically: + + - NaN = NaN returns true. + - In aggregations all NaN values are grouped together. + - NaN is treated as a normal value in join keys. + - NaN values go last when in ascending order, larger than any other numeric value. From 84baa5e9b5edc8c55871fbed5057324450bf097f Mon Sep 17 00:00:00 2001 From: CodingCat Date: Thu, 27 Aug 2015 20:19:09 +0100 Subject: [PATCH 0449/1824] [SPARK-10315] remove document on spark.akka.failure-detector.threshold https://issues.apache.org/jira/browse/SPARK-10315 this parameter is not used any longer and there is some mistake in the current document , should be 'akka.remote.watch-failure-detector.threshold' Author: CodingCat Closes #8483 from CodingCat/SPARK_10315. --- docs/configuration.md | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/docs/configuration.md b/docs/configuration.md index 4a6e4dd05b66..77c5cbc7b319 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -906,16 +906,6 @@ Apart from these, the following properties are also available, and may be useful #### Networking
    Property Name | Default | Meaning
    0.13.1 | Version of the Hive metastore. Available - options are 0.12.0 and 0.13.1. Support for more versions is coming in the future. + options are 0.12.0 through 1.2.1.
    - - - - - From 6185cdd2afcd492b77ff225b477b3624e3bc7bb2 Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Thu, 27 Aug 2015 13:57:20 -0700 Subject: [PATCH 0450/1824] [SPARK-9901] User guide for RowMatrix Tall-and-skinny QR jira: https://issues.apache.org/jira/browse/SPARK-9901 The jira covers only the document update. I can further provide example code for QR (like the ones for SVD and PCA) in a separate PR. Author: Yuhao Yang Closes #8462 from hhbyyh/qrDoc. --- docs/mllib-data-types.md | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/docs/mllib-data-types.md b/docs/mllib-data-types.md index f0e8d5495675..065bf4727624 100644 --- a/docs/mllib-data-types.md +++ b/docs/mllib-data-types.md @@ -337,7 +337,10 @@ limited by the integer range but it should be much smaller in practice.
    A [`RowMatrix`](api/scala/index.html#org.apache.spark.mllib.linalg.distributed.RowMatrix) can be -created from an `RDD[Vector]` instance. Then we can compute its column summary statistics. +created from an `RDD[Vector]` instance. Then we can compute its column summary statistics and decompositions. +[QR decomposition](https://en.wikipedia.org/wiki/QR_decomposition) is of the form A = QR where Q is an orthogonal matrix and R is an upper triangular matrix. +For [singular value decomposition (SVD)](https://en.wikipedia.org/wiki/Singular_value_decomposition) and [principal component analysis (PCA)](https://en.wikipedia.org/wiki/Principal_component_analysis), please refer to [Dimensionality reduction](mllib-dimensionality-reduction.html). + {% highlight scala %} import org.apache.spark.mllib.linalg.Vector @@ -350,6 +353,9 @@ val mat: RowMatrix = new RowMatrix(rows) // Get its size. val m = mat.numRows() val n = mat.numCols() + +// QR decomposition +val qrResult = mat.tallSkinnyQR(true) {% endhighlight %}
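The returned `QRDecomposition` bundles both factors; continuing the snippet above:

{% highlight scala %}
// Q is a RowMatrix (computed because computeQ = true above), R a local upper-triangular matrix.
val Q = qrResult.Q
val R = qrResult.R
println(s"R is ${R.numRows} x ${R.numCols}")
{% endhighlight %}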
    @@ -370,6 +376,9 @@ RowMatrix mat = new RowMatrix(rows.rdd()); // Get its size. long m = mat.numRows(); long n = mat.numCols(); + +// QR decomposition +QRDecomposition result = mat.tallSkinnyQR(true); {% endhighlight %} From c94ecdfc5b3c0fe6c38a170dc2af9259354dc9e3 Mon Sep 17 00:00:00 2001 From: MechCoder Date: Thu, 27 Aug 2015 15:33:43 -0700 Subject: [PATCH 0451/1824] [SPARK-9906] [ML] User guide for LogisticRegressionSummary User guide for LogisticRegression summaries Author: MechCoder Author: Manoj Kumar Author: Feynman Liang Closes #8197 from MechCoder/log_summary_user_guide. --- docs/ml-linear-methods.md | 149 ++++++++++++++++++++++++++++++++++---- 1 file changed, 133 insertions(+), 16 deletions(-) diff --git a/docs/ml-linear-methods.md b/docs/ml-linear-methods.md index 1ac83d94c9e8..2761aeb78962 100644 --- a/docs/ml-linear-methods.md +++ b/docs/ml-linear-methods.md @@ -23,20 +23,41 @@ displayTitle: ML - Linear Methods \]` -In MLlib, we implement popular linear methods such as logistic regression and linear least squares with L1 or L2 regularization. Refer to [the linear methods in mllib](mllib-linear-methods.html) for details. In `spark.ml`, we also include Pipelines API for [Elastic net](http://en.wikipedia.org/wiki/Elastic_net_regularization), a hybrid of L1 and L2 regularization proposed in [this paper](http://users.stat.umn.edu/~zouxx019/Papers/elasticnet.pdf). Mathematically it is defined as a linear combination of the L1-norm and the L2-norm: +In MLlib, we implement popular linear methods such as logistic +regression and linear least squares with $L_1$ or $L_2$ regularization. +Refer to [the linear methods in mllib](mllib-linear-methods.html) for +details. In `spark.ml`, we also include Pipelines API for [Elastic +net](http://en.wikipedia.org/wiki/Elastic_net_regularization), a hybrid +of $L_1$ and $L_2$ regularization proposed in [Zou et al, Regularization +and variable selection via the elastic +net](http://users.stat.umn.edu/~zouxx019/Papers/elasticnet.pdf). +Mathematically, it is defined as a convex combination of the $L_1$ and +the $L_2$ regularization terms: `\[ -\alpha \|\wv\|_1 + (1-\alpha) \frac{1}{2}\|\wv\|_2^2, \alpha \in [0, 1]. +\alpha~\lambda \|\wv\|_1 + (1-\alpha) \frac{\lambda}{2}\|\wv\|_2^2, \alpha \in [0, 1], \lambda \geq 0. \]` -By setting $\alpha$ properly, it contains both L1 and L2 regularization as special cases. For example, if a [linear regression](https://en.wikipedia.org/wiki/Linear_regression) model is trained with the elastic net parameter $\alpha$ set to $1$, it is equivalent to a [Lasso](http://en.wikipedia.org/wiki/Least_squares#Lasso_method) model. On the other hand, if $\alpha$ is set to $0$, the trained model reduces to a [ridge regression](http://en.wikipedia.org/wiki/Tikhonov_regularization) model. We implement Pipelines API for both linear regression and logistic regression with elastic net regularization. - -**Examples** +By setting $\alpha$ properly, elastic net contains both $L_1$ and $L_2$ +regularization as special cases. For example, if a [linear +regression](https://en.wikipedia.org/wiki/Linear_regression) model is +trained with the elastic net parameter $\alpha$ set to $1$, it is +equivalent to a +[Lasso](http://en.wikipedia.org/wiki/Least_squares#Lasso_method) model. +On the other hand, if $\alpha$ is set to $0$, the trained model reduces +to a [ridge +regression](http://en.wikipedia.org/wiki/Tikhonov_regularization) model. 
+We implement Pipelines API for both linear regression and logistic +regression with elastic net regularization. + +## Example: Logistic Regression + +The following example shows how to train a logistic regression model +with elastic net regularization. `elasticNetParam` corresponds to +$\alpha$ and `regParam` corresponds to $\lambda$.
    - {% highlight scala %} - import org.apache.spark.ml.classification.LogisticRegression import org.apache.spark.mllib.util.MLUtils @@ -53,15 +74,11 @@ val lrModel = lr.fit(training) // Print the weights and intercept for logistic regression println(s"Weights: ${lrModel.weights} Intercept: ${lrModel.intercept}") - {% endhighlight %} -
    - {% highlight java %} - import org.apache.spark.ml.classification.LogisticRegression; import org.apache.spark.ml.classification.LogisticRegressionModel; import org.apache.spark.mllib.regression.LabeledPoint; @@ -99,9 +116,7 @@ public class LogisticRegressionWithElasticNetExample {
    - {% highlight python %} - from pyspark.ml.classification import LogisticRegression from pyspark.mllib.regression import LabeledPoint from pyspark.mllib.util import MLUtils @@ -118,12 +133,114 @@ lrModel = lr.fit(training) print("Weights: " + str(lrModel.weights)) print("Intercept: " + str(lrModel.intercept)) {% endhighlight %} +
    +The `spark.ml` implementation of logistic regression also supports +extracting a summary of the model over the training set. Note that the +predictions and metrics which are stored as `Dataframe` in +`BinaryLogisticRegressionSummary` are annotated `@transient` and hence +only available on the driver. + +
    + +
    + +[`LogisticRegressionTrainingSummary`](api/scala/index.html#org.apache.spark.ml.classification.LogisticRegressionTrainingSummary) +provides a summary for a +[`LogisticRegressionModel`](api/scala/index.html#org.apache.spark.ml.classification.LogisticRegressionModel). +Currently, only binary classification is supported and the +summary must be explicitly cast to +[`BinaryLogisticRegressionTrainingSummary`](api/scala/index.html#org.apache.spark.ml.classification.BinaryLogisticRegressionTrainingSummary). +This will likely change when multiclass classification is supported. + +Continuing the earlier example: + +{% highlight scala %} +// Extract the summary from the returned LogisticRegressionModel instance trained in the earlier example +val trainingSummary = lrModel.summary + +// Obtain the loss per iteration. +val objectiveHistory = trainingSummary.objectiveHistory +objectiveHistory.foreach(loss => println(loss)) + +// Obtain the metrics useful to judge performance on test data. +// We cast the summary to a BinaryLogisticRegressionSummary since the problem is a +// binary classification problem. +val binarySummary = trainingSummary.asInstanceOf[BinaryLogisticRegressionSummary] + +// Obtain the receiver-operating characteristic as a dataframe and areaUnderROC. +val roc = binarySummary.roc +roc.show() +roc.select("FPR").show() +println(binarySummary.areaUnderROC) + +// Get the threshold corresponding to the maximum F-Measure and rerun LogisticRegression with +// this selected threshold. +val fMeasure = binarySummary.fMeasureByThreshold +val maxFMeasure = fMeasure.select(max("F-Measure")).head().getDouble(0) +val bestThreshold = fMeasure.where($"F-Measure" === maxFMeasure). + select("threshold").head().getDouble(0) +logReg.setThreshold(bestThreshold) +logReg.fit(logRegDataFrame) +{% endhighlight %}
    -### Optimization +
    +[`LogisticRegressionTrainingSummary`](api/java/org/apache/spark/ml/classification/LogisticRegressionTrainingSummary.html) +provides a summary for a +[`LogisticRegressionModel`](api/java/org/apache/spark/ml/classification/LogisticRegressionModel.html). +Currently, only binary classification is supported and the +summary must be explicitly cast to +[`BinaryLogisticRegressionTrainingSummary`](api/java/org/apache/spark/ml/classification/BinaryLogisticRegressionTrainingSummary.html). +This will likely change when multiclass classification is supported. + +Continuing the earlier example: + +{% highlight java %} +// Extract the summary from the returned LogisticRegressionModel instance trained in the earlier example +LogisticRegressionTrainingSummary trainingSummary = logRegModel.summary(); + +// Obtain the loss per iteration. +double[] objectiveHistory = trainingSummary.objectiveHistory(); +for (double lossPerIteration : objectiveHistory) { + System.out.println(lossPerIteration); +} + +// Obtain the metrics useful to judge performance on test data. +// We cast the summary to a BinaryLogisticRegressionSummary since the problem is a +// binary classification problem. +BinaryLogisticRegressionSummary binarySummary = (BinaryLogisticRegressionSummary) trainingSummary; + +// Obtain the receiver-operating characteristic as a dataframe and areaUnderROC. +DataFrame roc = binarySummary.roc(); +roc.show(); +roc.select("FPR").show(); +System.out.println(binarySummary.areaUnderROC()); + +// Get the threshold corresponding to the maximum F-Measure and rerun LogisticRegression with +// this selected threshold. +DataFrame fMeasure = binarySummary.fMeasureByThreshold(); +double maxFMeasure = fMeasure.select(max("F-Measure")).head().getDouble(0); +double bestThreshold = fMeasure.where(fMeasure.col("F-Measure").equalTo(maxFMeasure)). + select("threshold").head().getDouble(0); +logReg.setThreshold(bestThreshold); +logReg.fit(logRegDataFrame); +{% endhighlight %} +
    + +
    +Logistic regression model summary is not yet supported in Python. +
    + +
    + +# Optimization + +The optimization algorithm underlying the implementation is called +[Orthant-Wise Limited-memory +QuasiNewton](http://research-srv.microsoft.com/en-us/um/people/jfgao/paper/icml07scalable.pdf) +(OWL-QN). It is an extension of L-BFGS that can effectively handle L1 +regularization and elastic net. -The optimization algorithm underlies the implementation is called [Orthant-Wise Limited-memory QuasiNewton](http://research-srv.microsoft.com/en-us/um/people/jfgao/paper/icml07scalable.pdf) -(OWL-QN). It is an extension of L-BFGS that can effectively handle L1 regularization and elastic net. From 5bfe9e1111d9862084586549a7dc79476f67bab9 Mon Sep 17 00:00:00 2001 From: Feynman Liang Date: Thu, 27 Aug 2015 16:10:37 -0700 Subject: [PATCH 0452/1824] [SPARK-9680] [MLLIB] [DOC] StopWordsRemovers user guide and Java compatibility test * Adds user guide for ml.feature.StopWordsRemovers, ran code examples on my machine * Cleans up scaladocs for public methods * Adds test for Java compatibility * Follow up Python user guide code example is tracked by SPARK-10249 Author: Feynman Liang Closes #8436 from feynmanliang/SPARK-10230. --- docs/ml-features.md | 102 +++++++++++++++++- .../ml/feature/JavaStopWordsRemoverSuite.java | 72 +++++++++++++ 2 files changed, 171 insertions(+), 3 deletions(-) create mode 100644 mllib/src/test/java/org/apache/spark/ml/feature/JavaStopWordsRemoverSuite.java diff --git a/docs/ml-features.md b/docs/ml-features.md index 62de4838981c..89a9bad57044 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -306,15 +306,111 @@ regexTokenizer = RegexTokenizer(inputCol="sentence", outputCol="words", pattern= +## StopWordsRemover +[Stop words](https://en.wikipedia.org/wiki/Stop_words) are words which +should be excluded from the input, typically because the words appear +frequently and don't carry as much meaning. + +`StopWordsRemover` takes as input a sequence of strings (e.g. the output +of a [Tokenizer](ml-features.html#tokenizer)) and drops all the stop +words from the input sequences. The list of stopwords is specified by +the `stopWords` parameter. We provide [a list of stop +words](http://ir.dcs.gla.ac.uk/resources/linguistic_utils/stop_words) by +default, accessible by calling `getStopWords` on a newly instantiated +`StopWordsRemover` instance. -## $n$-gram +**Examples** -An [n-gram](https://en.wikipedia.org/wiki/N-gram) is a sequence of $n$ tokens (typically words) for some integer $n$. The `NGram` class can be used to transform input features into $n$-grams. +Assume that we have the following DataFrame with columns `id` and `raw`: -`NGram` takes as input a sequence of strings (e.g. the output of a [Tokenizer](ml-features.html#tokenizer). The parameter `n` is used to determine the number of terms in each $n$-gram. The output will consist of a sequence of $n$-grams where each $n$-gram is represented by a space-delimited string of $n$ consecutive words. If the input sequence contains fewer than `n` strings, no output is produced. +~~~~ + id | raw +----|---------- + 0 | [I, saw, the, red, baloon] + 1 | [Mary, had, a, little, lamb] +~~~~ + +Applying `StopWordsRemover` with `raw` as the input column and `filtered` as the output +column, we should get the following: + +~~~~ + id | raw | filtered +----|-----------------------------|-------------------- + 0 | [I, saw, the, red, baloon] | [saw, red, baloon] + 1 | [Mary, had, a, little, lamb]|[Mary, little, lamb] +~~~~ + +In `filtered`, the stop words "I", "the", "had", and "a" have been +filtered out.
    +
    + +[`StopWordsRemover`](api/scala/index.html#org.apache.spark.ml.feature.StopWordsRemover) +takes an input column name, an output column name, a list of stop words, +and a boolean indicating if the matches should be case sensitive (false +by default). + +{% highlight scala %} +import org.apache.spark.ml.feature.StopWordsRemover + +val remover = new StopWordsRemover() + .setInputCol("raw") + .setOutputCol("filtered") +val dataSet = sqlContext.createDataFrame(Seq( + (0, Seq("I", "saw", "the", "red", "baloon")), + (1, Seq("Mary", "had", "a", "little", "lamb")) +)).toDF("id", "raw") + +remover.transform(dataSet).show() +{% endhighlight %} +
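The defaults described above can be overridden; a brief sketch with a custom stop-word list and case-sensitive matching, reusing the same `dataSet`:

{% highlight scala %}
import org.apache.spark.ml.feature.StopWordsRemover

val customRemover = new StopWordsRemover()
  .setInputCol("raw")
  .setOutputCol("filtered")
  .setStopWords(Array("saw", "Mary"))  // replaces the default English list
  .setCaseSensitive(true)              // "mary" would no longer match "Mary"
customRemover.transform(dataSet).show()
{% endhighlight %}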
    + +
    + +[`StopWordsRemover`](api/java/org/apache/spark/ml/feature/StopWordsRemover.html) +takes an input column name, an output column name, a list of stop words, +and a boolean indicating if the matches should be case sensitive (false +by default). + +{% highlight java %} +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.StopWordsRemover; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +StopWordsRemover remover = new StopWordsRemover() + .setInputCol("raw") + .setOutputCol("filtered"); + +JavaRDD rdd = jsc.parallelize(Arrays.asList( + RowFactory.create(Arrays.asList("I", "saw", "the", "red", "baloon")), + RowFactory.create(Arrays.asList("Mary", "had", "a", "little", "lamb")) +)); +StructType schema = new StructType(new StructField[] { + new StructField("raw", DataTypes.createArrayType(DataTypes.StringType), false, Metadata.empty()) +}); +DataFrame dataset = jsql.createDataFrame(rdd, schema); + +remover.transform(dataset).show(); +{% endhighlight %} +
    +
    + +## $n$-gram + +An [n-gram](https://en.wikipedia.org/wiki/N-gram) is a sequence of $n$ tokens (typically words) for some integer $n$. The `NGram` class can be used to transform input features into $n$-grams. + +`NGram` takes as input a sequence of strings (e.g. the output of a [Tokenizer](ml-features.html#tokenizer)). The parameter `n` is used to determine the number of terms in each $n$-gram. The output will consist of a sequence of $n$-grams where each $n$-gram is represented by a space-delimited string of $n$ consecutive words. If the input sequence contains fewer than `n` strings, no output is produced. +
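A minimal sketch of the transformer described above, using illustrative column names and an assumed `sqlContext`:

{% highlight scala %}
import org.apache.spark.ml.feature.NGram

val wordDataFrame = sqlContext.createDataFrame(Seq(
  (0, Seq("Hi", "I", "heard", "about", "Spark"))
)).toDF("id", "words")

// Bigrams: each output entry joins two consecutive input tokens with a space.
val ngram = new NGram().setN(2).setInputCol("words").setOutputCol("ngrams")
ngram.transform(wordDataFrame).select("ngrams").show()
{% endhighlight %}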
    diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStopWordsRemoverSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStopWordsRemoverSuite.java new file mode 100644 index 000000000000..76cdd0fae84a --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStopWordsRemoverSuite.java @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature; + +import java.util.Arrays; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + + +public class JavaStopWordsRemoverSuite { + + private transient JavaSparkContext jsc; + private transient SQLContext jsql; + + @Before + public void setUp() { + jsc = new JavaSparkContext("local", "JavaStopWordsRemoverSuite"); + jsql = new SQLContext(jsc); + } + + @After + public void tearDown() { + jsc.stop(); + jsc = null; + } + + @Test + public void javaCompatibilityTest() { + StopWordsRemover remover = new StopWordsRemover() + .setInputCol("raw") + .setOutputCol("filtered"); + + JavaRDD rdd = jsc.parallelize(Arrays.asList( + RowFactory.create(Arrays.asList("I", "saw", "the", "red", "baloon")), + RowFactory.create(Arrays.asList("Mary", "had", "a", "little", "lamb")) + )); + StructType schema = new StructType(new StructField[] { + new StructField("raw", DataTypes.createArrayType(DataTypes.StringType), false, Metadata.empty()) + }); + DataFrame dataset = jsql.createDataFrame(rdd, schema); + + remover.transform(dataset).collect(); + } +} From b3dd569ad40905f8861a547a1e25ed3ca8e1d272 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Thu, 27 Aug 2015 16:11:25 -0700 Subject: [PATCH 0453/1824] [SPARK-10287] [SQL] Fixes JSONRelation refreshing on read path https://issues.apache.org/jira/browse/SPARK-10287 After porting json to HadoopFsRelation, it seems hard to keep the behavior of picking up new files automatically for JSON. This PR removes this behavior, so JSON is consistent with others (ORC and Parquet). Author: Yin Huai Closes #8469 from yhuai/jsonRefresh. 
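With the behavior change below, files added to a JSON dataset by other applications have to be picked up explicitly. A sketch of the two documented options, using a hypothetical table name and path:

{% highlight scala %}
// Persistent table whose metadata lives in the Hive metastore:
hiveContext.refreshTable("json_events")
// or, via SQL: sqlContext.sql("REFRESH TABLE json_events")

// DataFrame over a path: re-read it to include files created since the last read.
val refreshed = sqlContext.read.json("/data/events")
{% endhighlight %}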
--- docs/sql-programming-guide.md | 6 ++++++ .../execution/datasources/json/JSONRelation.scala | 9 --------- .../org/apache/spark/sql/sources/interfaces.scala | 2 +- .../apache/spark/sql/sources/InsertSuite.scala | 15 --------------- 4 files changed, 7 insertions(+), 25 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 99fec6c7785a..e8eb88488ee2 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -2057,6 +2057,12 @@ options. - The canonical name of SQL/DataFrame functions are now lower case (e.g. sum vs SUM). - It has been determined that using the DirectOutputCommitter when speculation is enabled is unsafe and thus this output committer will not be used when speculation is on, independent of configuration. + - JSON data source will not automatically load new files that are created by other applications + (i.e. files that are not inserted to the dataset through Spark SQL). + For a JSON persistent table (i.e. the metadata of the table is stored in Hive Metastore), + users can use `REFRESH TABLE` SQL command or `HiveContext`'s `refreshTable` method + to include those new files to the table. For a DataFrame representing a JSON dataset, users need to recreate + the DataFrame and the new DataFrame will include new files. ## Upgrading from Spark SQL 1.3 to 1.4 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala index 114c8b211891..ab8ca5f748f2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala @@ -111,15 +111,6 @@ private[sql] class JSONRelation( jsonSchema } - override private[sql] def buildScan( - requiredColumns: Array[String], - filters: Array[Filter], - inputPaths: Array[String], - broadcastedConf: Broadcast[SerializableConfiguration]): RDD[Row] = { - refresh() - super.buildScan(requiredColumns, filters, inputPaths, broadcastedConf) - } - override def buildScan( requiredColumns: Array[String], filters: Array[Filter], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index b3b326fe612c..dff726b33fc7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -562,7 +562,7 @@ abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[Partitio }) } - private[sql] def buildScan( + final private[sql] def buildScan( requiredColumns: Array[String], filters: Array[Filter], inputPaths: Array[String], diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala index 78bd3e558296..084d83f6e9bf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala @@ -167,21 +167,6 @@ class InsertSuite extends DataSourceTest with SharedSQLContext { ) } - test("save directly to the path of a JSON table") { - caseInsensitiveContext.table("jt").selectExpr("a * 5 as a", "b") - .write.mode(SaveMode.Overwrite).json(path.toString) - checkAnswer( - sql("SELECT a, b FROM jsonTable"), - (1 to 10).map(i => Row(i * 5, s"str$i")) - ) - - 
caseInsensitiveContext.table("jt").write.mode(SaveMode.Overwrite).json(path.toString) - checkAnswer( - sql("SELECT a, b FROM jsonTable"), - (1 to 10).map(i => Row(i, s"str$i")) - ) - } - test("it is not allowed to write to a table while querying it.") { val message = intercept[AnalysisException] { sql( From 54cda0deb6bebf1470f16ba5bcc6c4fb842bdac1 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 27 Aug 2015 16:38:00 -0700 Subject: [PATCH 0454/1824] [SPARK-10321] sizeInBytes in HadoopFsRelation Having sizeInBytes in HadoopFsRelation to enable broadcast join. cc marmbrus Author: Davies Liu Closes #8490 from davies/sizeInByte. --- .../main/scala/org/apache/spark/sql/sources/interfaces.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index dff726b33fc7..7b030b7d73bd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -518,6 +518,8 @@ abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[Partitio override def inputFiles: Array[String] = cachedLeafStatuses().map(_.getPath.toString).toArray + override def sizeInBytes: Long = cachedLeafStatuses().map(_.getLen).sum + /** * Partition columns. Can be either defined by [[userDefinedPartitionColumns]] or automatically * discovered. Note that they should always be nullable. From 1f90c5e2198bcf49e115d97ec300c17c1be4dcb4 Mon Sep 17 00:00:00 2001 From: Yu ISHIKAWA Date: Thu, 27 Aug 2015 19:38:53 -0700 Subject: [PATCH 0455/1824] [SPARK-8505] [SPARKR] Add settings to kick `lint-r` from `./dev/run-test.py` JoshRosen we'd like to check the SparkR source code with the `dev/lint-r` script on the Jenkins. I tried to incorporate the script into `dev/run-test.py`. Could you review it when you have time? shivaram I modified `dev/lint-r` and `dev/lint-r.R` to install lintr package into a local directory(`R/lib/`) and to exit with a lint status. Could you review it? - [[SPARK-8505] Add settings to kick `lint-r` from `./dev/run-test.py` - ASF JIRA](https://issues.apache.org/jira/browse/SPARK-8505) Author: Yu ISHIKAWA Closes #7883 from yu-iskw/SPARK-8505. --- dev/lint-r | 11 +++++++++++ dev/lint-r.R | 12 +++++++----- dev/run-tests-codes.sh | 13 +++++++------ dev/run-tests-jenkins | 2 ++ dev/run-tests.py | 21 ++++++++++++++++++++- 5 files changed, 47 insertions(+), 12 deletions(-) diff --git a/dev/lint-r b/dev/lint-r index 7d5f4cd31153..c15d57aad86d 100755 --- a/dev/lint-r +++ b/dev/lint-r @@ -28,3 +28,14 @@ if ! type "Rscript" > /dev/null; then fi `which Rscript` --vanilla "$SPARK_ROOT_DIR/dev/lint-r.R" "$SPARK_ROOT_DIR" | tee "$LINT_R_REPORT_FILE_NAME" + +NUM_LINES=`wc -l < "$LINT_R_REPORT_FILE_NAME"` +if [ "$NUM_LINES" = "0" ] ; then + lint_status=0 + echo "lintr checks passed." +else + lint_status=1 + echo "lintr checks failed." +fi + +exit "$lint_status" diff --git a/dev/lint-r.R b/dev/lint-r.R index 48bd6246096a..999eef571b82 100644 --- a/dev/lint-r.R +++ b/dev/lint-r.R @@ -17,8 +17,14 @@ argv <- commandArgs(TRUE) SPARK_ROOT_DIR <- as.character(argv[1]) +LOCAL_LIB_LOC <- file.path(SPARK_ROOT_DIR, "R", "lib") -# Installs lintr from Github. +# Checks if SparkR is installed in a local directory. +if (! 
library(SparkR, lib.loc = LOCAL_LIB_LOC, logical.return = TRUE)) { + stop("You should install SparkR in a local directory with `R/install-dev.sh`.") +} + +# Installs lintr from Github in a local directory. # NOTE: The CRAN's version is too old to adapt to our rules. if ("lintr" %in% row.names(installed.packages()) == FALSE) { devtools::install_github("jimhester/lintr") @@ -27,9 +33,5 @@ if ("lintr" %in% row.names(installed.packages()) == FALSE) { library(lintr) library(methods) library(testthat) -if (! library(SparkR, lib.loc = file.path(SPARK_ROOT_DIR, "R", "lib"), logical.return = TRUE)) { - stop("You should install SparkR in a local directory with `R/install-dev.sh`.") -} - path.to.package <- file.path(SPARK_ROOT_DIR, "R", "pkg") lint_package(path.to.package, cache = FALSE) diff --git a/dev/run-tests-codes.sh b/dev/run-tests-codes.sh index f4b238e1b78a..1f16790522e7 100644 --- a/dev/run-tests-codes.sh +++ b/dev/run-tests-codes.sh @@ -21,9 +21,10 @@ readonly BLOCK_GENERAL=10 readonly BLOCK_RAT=11 readonly BLOCK_SCALA_STYLE=12 readonly BLOCK_PYTHON_STYLE=13 -readonly BLOCK_DOCUMENTATION=14 -readonly BLOCK_BUILD=15 -readonly BLOCK_MIMA=16 -readonly BLOCK_SPARK_UNIT_TESTS=17 -readonly BLOCK_PYSPARK_UNIT_TESTS=18 -readonly BLOCK_SPARKR_UNIT_TESTS=19 +readonly BLOCK_R_STYLE=14 +readonly BLOCK_DOCUMENTATION=15 +readonly BLOCK_BUILD=16 +readonly BLOCK_MIMA=17 +readonly BLOCK_SPARK_UNIT_TESTS=18 +readonly BLOCK_PYSPARK_UNIT_TESTS=19 +readonly BLOCK_SPARKR_UNIT_TESTS=20 diff --git a/dev/run-tests-jenkins b/dev/run-tests-jenkins index f144c053046c..39cf54f78104 100755 --- a/dev/run-tests-jenkins +++ b/dev/run-tests-jenkins @@ -210,6 +210,8 @@ done failing_test="Scala style tests" elif [ "$test_result" -eq "$BLOCK_PYTHON_STYLE" ]; then failing_test="Python style tests" + elif [ "$test_result" -eq "$BLOCK_R_STYLE" ]; then + failing_test="R style tests" elif [ "$test_result" -eq "$BLOCK_DOCUMENTATION" ]; then failing_test="to generate documentation" elif [ "$test_result" -eq "$BLOCK_BUILD" ]; then diff --git a/dev/run-tests.py b/dev/run-tests.py index f689425ee40b..4fd703a7c219 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -209,6 +209,18 @@ def run_python_style_checks(): run_cmd([os.path.join(SPARK_HOME, "dev", "lint-python")]) +def run_sparkr_style_checks(): + set_title_and_block("Running R style checks", "BLOCK_R_STYLE") + + if which("R"): + # R style check should be executed after `install-dev.sh`. + # Since warnings about `no visible global function definition` appear + # without the installation. SEE ALSO: SPARK-9121. 
+ run_cmd([os.path.join(SPARK_HOME, "dev", "lint-r")]) + else: + print("Ignoring SparkR style check as R was not found in PATH") + + def build_spark_documentation(): set_title_and_block("Building Spark Documentation", "BLOCK_DOCUMENTATION") os.environ["PRODUCTION"] = "1 jekyll build" @@ -387,7 +399,6 @@ def run_sparkr_tests(): set_title_and_block("Running SparkR tests", "BLOCK_SPARKR_UNIT_TESTS") if which("R"): - run_cmd([os.path.join(SPARK_HOME, "R", "install-dev.sh")]) run_cmd([os.path.join(SPARK_HOME, "R", "run-tests.sh")]) else: print("Ignoring SparkR tests as R was not found in PATH") @@ -438,6 +449,12 @@ def main(): if java_version.minor < 8: print("[warn] Java 8 tests will not run because JDK version is < 1.8.") + # install SparkR + if which("R"): + run_cmd([os.path.join(SPARK_HOME, "R", "install-dev.sh")]) + else: + print("Can't install SparkR as R is was not found in PATH") + if os.environ.get("AMPLAB_JENKINS"): # if we're on the Amplab Jenkins build servers setup variables # to reflect the environment settings @@ -485,6 +502,8 @@ def main(): run_scala_style_checks() if not changed_files or any(f.endswith(".py") for f in changed_files): run_python_style_checks() + if not changed_files or any(f.endswith(".R") for f in changed_files): + run_sparkr_style_checks() # determine if docs were changed and if we're inside the amplab environment # note - the below commented out until *all* Jenkins workers can get `jekyll` installed From 30734d45fbbb269437c062241a9161e198805a76 Mon Sep 17 00:00:00 2001 From: MechCoder Date: Thu, 27 Aug 2015 21:44:06 -0700 Subject: [PATCH 0456/1824] [SPARK-9911] [DOC] [ML] Update Userguide for Evaluator I added a small note about the different types of evaluator and the metrics used. Author: MechCoder Closes #8304 from MechCoder/multiclass_evaluator. --- docs/ml-guide.md | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/docs/ml-guide.md b/docs/ml-guide.md index de8fead3529e..01bf5ee18e32 100644 --- a/docs/ml-guide.md +++ b/docs/ml-guide.md @@ -643,6 +643,13 @@ An important task in ML is *model selection*, or using data to find the best mod Currently, `spark.ml` supports model selection using the [`CrossValidator`](api/scala/index.html#org.apache.spark.ml.tuning.CrossValidator) class, which takes an `Estimator`, a set of `ParamMap`s, and an [`Evaluator`](api/scala/index.html#org.apache.spark.ml.Evaluator). `CrossValidator` begins by splitting the dataset into a set of *folds* which are used as separate training and test datasets; e.g., with `$k=3$` folds, `CrossValidator` will generate 3 (training, test) dataset pairs, each of which uses 2/3 of the data for training and 1/3 for testing. `CrossValidator` iterates through the set of `ParamMap`s. For each `ParamMap`, it trains the given `Estimator` and evaluates it using the given `Evaluator`. + +The `Evaluator` can be a [`RegressionEvaluator`](api/scala/index.html#org.apache.spark.ml.RegressionEvaluator) +for regression problems, a [`BinaryClassificationEvaluator`](api/scala/index.html#org.apache.spark.ml.BinaryClassificationEvaluator) +for binary data or a [`MultiClassClassificationEvaluator`](api/scala/index.html#org.apache.spark.ml.MultiClassClassificationEvaluator) +for multiclass problems. The default metric used to choose the best `ParamMap` can be overriden by the setMetric +method in each of these evaluators. + The `ParamMap` which produces the best evaluation metric (averaged over the `$k$` folds) is selected as the best model. 
`CrossValidator` finally fits the `Estimator` using the best `ParamMap` and the entire dataset. @@ -708,9 +715,12 @@ val pipeline = new Pipeline() // We now treat the Pipeline as an Estimator, wrapping it in a CrossValidator instance. // This will allow us to jointly choose parameters for all Pipeline stages. // A CrossValidator requires an Estimator, a set of Estimator ParamMaps, and an Evaluator. +// Note that the evaluator here is a BinaryClassificationEvaluator and the default metric +// used is areaUnderROC. val crossval = new CrossValidator() .setEstimator(pipeline) .setEvaluator(new BinaryClassificationEvaluator) + // We use a ParamGridBuilder to construct a grid of parameters to search over. // With 3 values for hashingTF.numFeatures and 2 values for lr.regParam, // this grid will have 3 x 2 = 6 parameter settings for CrossValidator to choose from. @@ -831,9 +841,12 @@ Pipeline pipeline = new Pipeline() // We now treat the Pipeline as an Estimator, wrapping it in a CrossValidator instance. // This will allow us to jointly choose parameters for all Pipeline stages. // A CrossValidator requires an Estimator, a set of Estimator ParamMaps, and an Evaluator. +// Note that the evaluator here is a BinaryClassificationEvaluator and the default metric +// used is areaUnderROC. CrossValidator crossval = new CrossValidator() .setEstimator(pipeline) .setEvaluator(new BinaryClassificationEvaluator()); + // We use a ParamGridBuilder to construct a grid of parameters to search over. // With 3 values for hashingTF.numFeatures and 2 values for lr.regParam, // this grid will have 3 x 2 = 6 parameter settings for CrossValidator to choose from. From af0e1249b1c881c0fa7a921fd21fd2c27214b980 Mon Sep 17 00:00:00 2001 From: Feynman Liang Date: Thu, 27 Aug 2015 21:55:20 -0700 Subject: [PATCH 0457/1824] [SPARK-9905] [ML] [DOC] Adds LinearRegressionSummary user guide * Adds user guide for `LinearRegressionSummary` * Fixes unresolved issues in #8197 CC jkbradley mengxr Author: Feynman Liang Closes #8491 from feynmanliang/SPARK-9905. --- docs/ml-linear-methods.md | 140 ++++++++++++++++++++++++++++++++++---- 1 file changed, 127 insertions(+), 13 deletions(-) diff --git a/docs/ml-linear-methods.md b/docs/ml-linear-methods.md index 2761aeb78962..cdd9d4999fa1 100644 --- a/docs/ml-linear-methods.md +++ b/docs/ml-linear-methods.md @@ -34,7 +34,7 @@ net](http://users.stat.umn.edu/~zouxx019/Papers/elasticnet.pdf). Mathematically, it is defined as a convex combination of the $L_1$ and the $L_2$ regularization terms: `\[ -\alpha~\lambda \|\wv\|_1 + (1-\alpha) \frac{\lambda}{2}\|\wv\|_2^2, \alpha \in [0, 1], \lambda \geq 0. +\alpha \left( \lambda \|\wv\|_1 \right) + (1-\alpha) \left( \frac{\lambda}{2}\|\wv\|_2^2 \right) , \alpha \in [0, 1], \lambda \geq 0 \]` By setting $\alpha$ properly, elastic net contains both $L_1$ and $L_2$ regularization as special cases. 
For example, if a [linear @@ -95,7 +95,7 @@ public class LogisticRegressionWithElasticNetExample { SparkContext sc = new SparkContext(conf); SQLContext sql = new SQLContext(sc); - String path = "sample_libsvm_data.txt"; + String path = "data/mllib/sample_libsvm_data.txt"; // Load training data DataFrame training = sql.createDataFrame(MLUtils.loadLibSVMFile(sc, path).toJavaRDD(), LabeledPoint.class); @@ -103,7 +103,7 @@ public class LogisticRegressionWithElasticNetExample { LogisticRegression lr = new LogisticRegression() .setMaxIter(10) .setRegParam(0.3) - .setElasticNetParam(0.8) + .setElasticNetParam(0.8); // Fit the model LogisticRegressionModel lrModel = lr.fit(training); @@ -158,10 +158,12 @@ This will likely change when multiclass classification is supported. Continuing the earlier example: {% highlight scala %} +import org.apache.spark.ml.classification.BinaryLogisticRegressionSummary + // Extract the summary from the returned LogisticRegressionModel instance trained in the earlier example val trainingSummary = lrModel.summary -// Obtain the loss per iteration. +// Obtain the objective per iteration. val objectiveHistory = trainingSummary.objectiveHistory objectiveHistory.foreach(loss => println(loss)) @@ -173,17 +175,14 @@ val binarySummary = trainingSummary.asInstanceOf[BinaryLogisticRegressionSummary // Obtain the receiver-operating characteristic as a dataframe and areaUnderROC. val roc = binarySummary.roc roc.show() -roc.select("FPR").show() println(binarySummary.areaUnderROC) -// Get the threshold corresponding to the maximum F-Measure and rerun LogisticRegression with -// this selected threshold. +// Set the model threshold to maximize F-Measure val fMeasure = binarySummary.fMeasureByThreshold val maxFMeasure = fMeasure.select(max("F-Measure")).head().getDouble(0) val bestThreshold = fMeasure.where($"F-Measure" === maxFMeasure). select("threshold").head().getDouble(0) -logReg.setThreshold(bestThreshold) -logReg.fit(logRegDataFrame) +lrModel.setThreshold(bestThreshold) {% endhighlight %}
    @@ -199,8 +198,12 @@ This will likely change when multiclass classification is supported. Continuing the earlier example: {% highlight java %} +import org.apache.spark.ml.classification.LogisticRegressionTrainingSummary; +import org.apache.spark.ml.classification.BinaryLogisticRegressionSummary; +import org.apache.spark.sql.functions; + // Extract the summary from the returned LogisticRegressionModel instance trained in the earlier example -LogisticRegressionTrainingSummary trainingSummary = logRegModel.summary(); +LogisticRegressionTrainingSummary trainingSummary = lrModel.summary(); // Obtain the loss per iteration. double[] objectiveHistory = trainingSummary.objectiveHistory(); @@ -222,20 +225,131 @@ System.out.println(binarySummary.areaUnderROC()); // Get the threshold corresponding to the maximum F-Measure and rerun LogisticRegression with // this selected threshold. DataFrame fMeasure = binarySummary.fMeasureByThreshold(); -double maxFMeasure = fMeasure.select(max("F-Measure")).head().getDouble(0); +double maxFMeasure = fMeasure.select(functions.max("F-Measure")).head().getDouble(0); double bestThreshold = fMeasure.where(fMeasure.col("F-Measure").equalTo(maxFMeasure)). select("threshold").head().getDouble(0); -logReg.setThreshold(bestThreshold); -logReg.fit(logRegDataFrame); +lrModel.setThreshold(bestThreshold); {% endhighlight %}
    +
    Logistic regression model summary is not yet supported in Python.
    +## Example: Linear Regression + +The interface for working with linear regression models and model +summaries is similar to the logistic regression case. The following +example demonstrates training an elastic net regularized linear +regression model and extracting model summary statistics. + +
    + +
    +{% highlight scala %} +import org.apache.spark.ml.regression.LinearRegression +import org.apache.spark.mllib.util.MLUtils + +// Load training data +val training = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() + +val lr = new LinearRegression() + .setMaxIter(10) + .setRegParam(0.3) + .setElasticNetParam(0.8) + +// Fit the model +val lrModel = lr.fit(training) + +// Print the weights and intercept for linear regression +println(s"Weights: ${lrModel.weights} Intercept: ${lrModel.intercept}") + +// Summarize the model over the training set and print out some metrics +val trainingSummary = lrModel.summary +println(s"numIterations: ${trainingSummary.totalIterations}") +println(s"objectiveHistory: ${trainingSummary.objectiveHistory.toList}") +trainingSummary.residuals.show() +println(s"RMSE: ${trainingSummary.rootMeanSquaredError}") +println(s"r2: ${trainingSummary.r2}") +{% endhighlight %} +
    + +
    +{% highlight java %} +import org.apache.spark.ml.regression.LinearRegression; +import org.apache.spark.ml.regression.LinearRegressionModel; +import org.apache.spark.ml.regression.LinearRegressionTrainingSummary; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.util.MLUtils; +import org.apache.spark.SparkConf; +import org.apache.spark.SparkContext; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.SQLContext; + +public class LinearRegressionWithElasticNetExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf() + .setAppName("Linear Regression with Elastic Net Example"); + + SparkContext sc = new SparkContext(conf); + SQLContext sql = new SQLContext(sc); + String path = "data/mllib/sample_libsvm_data.txt"; + + // Load training data + DataFrame training = sql.createDataFrame(MLUtils.loadLibSVMFile(sc, path).toJavaRDD(), LabeledPoint.class); + + LinearRegression lr = new LinearRegression() + .setMaxIter(10) + .setRegParam(0.3) + .setElasticNetParam(0.8); + + // Fit the model + LinearRegressionModel lrModel = lr.fit(training); + + // Print the weights and intercept for linear regression + System.out.println("Weights: " + lrModel.weights() + " Intercept: " + lrModel.intercept()); + + // Summarize the model over the training set and print out some metrics + LinearRegressionTrainingSummary trainingSummary = lrModel.summary(); + System.out.println("numIterations: " + trainingSummary.totalIterations()); + System.out.println("objectiveHistory: " + Vectors.dense(trainingSummary.objectiveHistory())); + trainingSummary.residuals().show(); + System.out.println("RMSE: " + trainingSummary.rootMeanSquaredError()); + System.out.println("r2: " + trainingSummary.r2()); + } +} +{% endhighlight %} +
    + +
    + +{% highlight python %} +from pyspark.ml.regression import LinearRegression +from pyspark.mllib.regression import LabeledPoint +from pyspark.mllib.util import MLUtils + +# Load training data +training = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() + +lr = LinearRegression(maxIter=10, regParam=0.3, elasticNetParam=0.8) + +# Fit the model +lrModel = lr.fit(training) + +# Print the weights and intercept for linear regression +print("Weights: " + str(lrModel.weights)) +print("Intercept: " + str(lrModel.intercept)) + +# Linear regression model summary is not yet supported in Python. +{% endhighlight %} +
    + +
    + # Optimization The optimization algorithm underlying the implementation is called From 89b943438512fcfb239c268b43431397de46cbcf Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Thu, 27 Aug 2015 22:30:01 -0700 Subject: [PATCH 0458/1824] [SPARK-SQL] [MINOR] Fixes some typos in HiveContext Author: Cheng Lian Closes #8481 from liancheng/hive-context-typo. --- .../scala/org/apache/spark/sql/hive/HiveContext.scala | 8 ++++---- .../scala/org/apache/spark/sql/hive/test/TestHive.scala | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index c0a458fa9ab8..2e791cea96b4 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -171,11 +171,11 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging { * Overrides default Hive configurations to avoid breaking changes to Spark SQL users. * - allow SQL11 keywords to be used as identifiers */ - private[sql] def defaultOverides() = { + private[sql] def defaultOverrides() = { setConf(ConfVars.HIVE_SUPPORT_SQL11_RESERVED_KEYWORDS.varname, "false") } - defaultOverides() + defaultOverrides() /** * The copy of the Hive client that is used to retrieve metadata from the Hive MetaStore. @@ -190,8 +190,8 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging { // into the isolated client loader val metadataConf = new HiveConf() - val defaltWarehouseLocation = metadataConf.get("hive.metastore.warehouse.dir") - logInfo("defalt warehouse location is " + defaltWarehouseLocation) + val defaultWarehouseLocation = metadataConf.get("hive.metastore.warehouse.dir") + logInfo("default warehouse location is " + defaultWarehouseLocation) // `configure` goes second to override other settings. val allConfig = metadataConf.asScala.map(e => e.getKey -> e.getValue).toMap ++ configure diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index 572eaebe81ac..57fea5d8db34 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -434,7 +434,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { case (k, v) => metadataHive.runSqlHive(s"SET $k=$v") } - defaultOverides() + defaultOverrides() runSqlHive("USE default") From 7583681e6b0824d7eed471dc4d8fa0b2addf9ffc Mon Sep 17 00:00:00 2001 From: noelsmith Date: Thu, 27 Aug 2015 23:59:30 -0700 Subject: [PATCH 0459/1824] [SPARK-10188] [PYSPARK] Pyspark CrossValidator with RMSE selects incorrect model * Added isLargerBetter() method to Pyspark Evaluator to match the Scala version. * JavaEvaluator delegates isLargerBetter() to underlying Scala object. * Added check for isLargerBetter() in CrossValidator to determine whether to use argmin or argmax. * Added test cases for where smaller is better (RMSE) and larger is better (R-Squared). (This contribution is my original work and that I license the work to the project under Sparks' open source license) Author: noelsmith Closes #8399 from noel-smith/pyspark-rmse-xval-fix. 
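The heart of the fix is checking the metric direction before picking the best parameter set; a short Scala sketch of that selection logic (the PR itself patches the Python implementation):

{% highlight scala %}
// Choose the index of the best ParamMap from the averaged metrics, honoring the
// direction reported by the evaluator (RMSE: smaller is better; R^2, AUC: larger is better).
def bestIndex(metrics: Array[Double], isLargerBetter: Boolean): Int = {
  val best = if (isLargerBetter) metrics.max else metrics.min
  metrics.indexOf(best)
}

bestIndex(Array(0.9, 0.2, 0.5), isLargerBetter = false)  // == 1
{% endhighlight %}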
--- python/pyspark/ml/evaluation.py | 12 +++++ python/pyspark/ml/tests.py | 87 +++++++++++++++++++++++++++++++++ python/pyspark/ml/tuning.py | 6 ++- 3 files changed, 104 insertions(+), 1 deletion(-) diff --git a/python/pyspark/ml/evaluation.py b/python/pyspark/ml/evaluation.py index 6b0a9ffde9f4..cb3b07947e48 100644 --- a/python/pyspark/ml/evaluation.py +++ b/python/pyspark/ml/evaluation.py @@ -66,6 +66,14 @@ def evaluate(self, dataset, params=None): else: raise ValueError("Params must be a param map but got %s." % type(params)) + def isLargerBetter(self): + """ + Indicates whether the metric returned by :py:meth:`evaluate` should be maximized + (True, default) or minimized (False). + A given evaluator may support multiple metrics which may be maximized or minimized. + """ + return True + @inherit_doc class JavaEvaluator(Evaluator, JavaWrapper): @@ -85,6 +93,10 @@ def _evaluate(self, dataset): self._transfer_params_to_java() return self._java_obj.evaluate(dataset._jdf) + def isLargerBetter(self): + self._transfer_params_to_java() + return self._java_obj.isLargerBetter() + @inherit_doc class BinaryClassificationEvaluator(JavaEvaluator, HasLabelCol, HasRawPredictionCol): diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index c151d21fd661..60e4237293ad 100644 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -32,11 +32,14 @@ from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase from pyspark.sql import DataFrame, SQLContext +from pyspark.sql.functions import rand +from pyspark.ml.evaluation import RegressionEvaluator from pyspark.ml.param import Param, Params from pyspark.ml.param.shared import HasMaxIter, HasInputCol, HasSeed from pyspark.ml.util import keyword_only from pyspark.ml import Estimator, Model, Pipeline, Transformer from pyspark.ml.feature import * +from pyspark.ml.tuning import ParamGridBuilder, CrossValidator, CrossValidatorModel from pyspark.mllib.linalg import DenseVector @@ -264,5 +267,89 @@ def test_ngram(self): self.assertEquals(transformedDF.head().output, ["a b c d", "b c d e"]) +class HasInducedError(Params): + + def __init__(self): + super(HasInducedError, self).__init__() + self.inducedError = Param(self, "inducedError", + "Uniformly-distributed error added to feature") + + def getInducedError(self): + return self.getOrDefault(self.inducedError) + + +class InducedErrorModel(Model, HasInducedError): + + def __init__(self): + super(InducedErrorModel, self).__init__() + + def _transform(self, dataset): + return dataset.withColumn("prediction", + dataset.feature + (rand(0) * self.getInducedError())) + + +class InducedErrorEstimator(Estimator, HasInducedError): + + def __init__(self, inducedError=1.0): + super(InducedErrorEstimator, self).__init__() + self._set(inducedError=inducedError) + + def _fit(self, dataset): + model = InducedErrorModel() + self._copyValues(model) + return model + + +class CrossValidatorTests(PySparkTestCase): + + def test_fit_minimize_metric(self): + sqlContext = SQLContext(self.sc) + dataset = sqlContext.createDataFrame([ + (10, 10.0), + (50, 50.0), + (100, 100.0), + (500, 500.0)] * 10, + ["feature", "label"]) + + iee = InducedErrorEstimator() + evaluator = RegressionEvaluator(metricName="rmse") + + grid = (ParamGridBuilder() + .addGrid(iee.inducedError, [100.0, 0.0, 10000.0]) + .build()) + cv = CrossValidator(estimator=iee, estimatorParamMaps=grid, evaluator=evaluator) + cvModel = cv.fit(dataset) + bestModel = cvModel.bestModel + bestModelMetric = evaluator.evaluate(bestModel.transform(dataset)) 
+ + self.assertEqual(0.0, bestModel.getOrDefault('inducedError'), + "Best model should have zero induced error") + self.assertEqual(0.0, bestModelMetric, "Best model has RMSE of 0") + + def test_fit_maximize_metric(self): + sqlContext = SQLContext(self.sc) + dataset = sqlContext.createDataFrame([ + (10, 10.0), + (50, 50.0), + (100, 100.0), + (500, 500.0)] * 10, + ["feature", "label"]) + + iee = InducedErrorEstimator() + evaluator = RegressionEvaluator(metricName="r2") + + grid = (ParamGridBuilder() + .addGrid(iee.inducedError, [100.0, 0.0, 10000.0]) + .build()) + cv = CrossValidator(estimator=iee, estimatorParamMaps=grid, evaluator=evaluator) + cvModel = cv.fit(dataset) + bestModel = cvModel.bestModel + bestModelMetric = evaluator.evaluate(bestModel.transform(dataset)) + + self.assertEqual(0.0, bestModel.getOrDefault('inducedError'), + "Best model should have zero induced error") + self.assertEqual(1.0, bestModelMetric, "Best model has R-squared of 1") + + if __name__ == "__main__": unittest.main() diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py index dcfee6a3170a..cae778869e9c 100644 --- a/python/pyspark/ml/tuning.py +++ b/python/pyspark/ml/tuning.py @@ -223,7 +223,11 @@ def _fit(self, dataset): # TODO: duplicate evaluator to take extra params from input metric = eva.evaluate(model.transform(validation, epm[j])) metrics[j] += metric - bestIndex = np.argmax(metrics) + + if eva.isLargerBetter(): + bestIndex = np.argmax(metrics) + else: + bestIndex = np.argmin(metrics) bestModel = est.fit(dataset, epm[bestIndex]) return CrossValidatorModel(bestModel) From 2f99c37273c1d82e2ba39476e4429ea4aaba7ec6 Mon Sep 17 00:00:00 2001 From: Shivaram Venkataraman Date: Fri, 28 Aug 2015 00:37:50 -0700 Subject: [PATCH 0460/1824] [SPARK-10328] [SPARKR] Fix generic for na.omit S3 function is at https://stat.ethz.ch/R-manual/R-patched/library/stats/html/na.fail.html Author: Shivaram Venkataraman Author: Shivaram Venkataraman Author: Yu ISHIKAWA Closes #8495 from shivaram/na-omit-fix. --- R/pkg/R/DataFrame.R | 6 +++--- R/pkg/R/generics.R | 2 +- R/pkg/inst/tests/test_sparkSQL.R | 23 ++++++++++++++++++++++- dev/lint-r | 2 +- 4 files changed, 27 insertions(+), 6 deletions(-) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index dd8126aebf46..74de7c81e35a 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -1699,9 +1699,9 @@ setMethod("dropna", #' @name na.omit #' @export setMethod("na.omit", - signature(x = "DataFrame"), - function(x, how = c("any", "all"), minNonNulls = NULL, cols = NULL) { - dropna(x, how, minNonNulls, cols) + signature(object = "DataFrame"), + function(object, how = c("any", "all"), minNonNulls = NULL, cols = NULL) { + dropna(object, how, minNonNulls, cols) }) #' fillna diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index a829d46c1894..b578b8789d2c 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -413,7 +413,7 @@ setGeneric("dropna", #' @rdname nafunctions #' @export setGeneric("na.omit", - function(x, how = c("any", "all"), minNonNulls = NULL, cols = NULL) { + function(object, ...) 
{ standardGeneric("na.omit") }) diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index 4b672e115f92..933b11c8ee7e 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -1083,7 +1083,7 @@ test_that("describe() and summarize() on a DataFrame", { expect_equal(collect(stats2)[5, "age"], "30") }) -test_that("dropna() on a DataFrame", { +test_that("dropna() and na.omit() on a DataFrame", { df <- jsonFile(sqlContext, jsonPathNa) rows <- collect(df) @@ -1092,6 +1092,8 @@ test_that("dropna() on a DataFrame", { expected <- rows[!is.na(rows$name),] actual <- collect(dropna(df, cols = "name")) expect_identical(expected, actual) + actual <- collect(na.omit(df, cols = "name")) + expect_identical(expected, actual) expected <- rows[!is.na(rows$age),] actual <- collect(dropna(df, cols = "age")) @@ -1101,48 +1103,67 @@ test_that("dropna() on a DataFrame", { expect_identical(expected$age, actual$age) expect_identical(expected$height, actual$height) expect_identical(expected$name, actual$name) + actual <- collect(na.omit(df, cols = "age")) expected <- rows[!is.na(rows$age) & !is.na(rows$height),] actual <- collect(dropna(df, cols = c("age", "height"))) expect_identical(expected, actual) + actual <- collect(na.omit(df, cols = c("age", "height"))) + expect_identical(expected, actual) expected <- rows[!is.na(rows$age) & !is.na(rows$height) & !is.na(rows$name),] actual <- collect(dropna(df)) expect_identical(expected, actual) + actual <- collect(na.omit(df)) + expect_identical(expected, actual) # drop with how expected <- rows[!is.na(rows$age) & !is.na(rows$height) & !is.na(rows$name),] actual <- collect(dropna(df)) expect_identical(expected, actual) + actual <- collect(na.omit(df)) + expect_identical(expected, actual) expected <- rows[!is.na(rows$age) | !is.na(rows$height) | !is.na(rows$name),] actual <- collect(dropna(df, "all")) expect_identical(expected, actual) + actual <- collect(na.omit(df, "all")) + expect_identical(expected, actual) expected <- rows[!is.na(rows$age) & !is.na(rows$height) & !is.na(rows$name),] actual <- collect(dropna(df, "any")) expect_identical(expected, actual) + actual <- collect(na.omit(df, "any")) + expect_identical(expected, actual) expected <- rows[!is.na(rows$age) & !is.na(rows$height),] actual <- collect(dropna(df, "any", cols = c("age", "height"))) expect_identical(expected, actual) + actual <- collect(na.omit(df, "any", cols = c("age", "height"))) + expect_identical(expected, actual) expected <- rows[!is.na(rows$age) | !is.na(rows$height),] actual <- collect(dropna(df, "all", cols = c("age", "height"))) expect_identical(expected, actual) + actual <- collect(na.omit(df, "all", cols = c("age", "height"))) + expect_identical(expected, actual) # drop with threshold expected <- rows[as.integer(!is.na(rows$age)) + as.integer(!is.na(rows$height)) >= 2,] actual <- collect(dropna(df, minNonNulls = 2, cols = c("age", "height"))) expect_identical(expected, actual) + actual <- collect(na.omit(df, minNonNulls = 2, cols = c("age", "height"))) + expect_identical(expected, actual) expected <- rows[as.integer(!is.na(rows$age)) + as.integer(!is.na(rows$height)) + as.integer(!is.na(rows$name)) >= 3,] actual <- collect(dropna(df, minNonNulls = 3, cols = c("name", "age", "height"))) expect_identical(expected, actual) + actual <- collect(na.omit(df, minNonNulls = 3, cols = c("name", "age", "height"))) + expect_identical(expected, actual) }) test_that("fillna() on a DataFrame", { diff --git a/dev/lint-r b/dev/lint-r index 
c15d57aad86d..bfda0bca15eb 100755 --- a/dev/lint-r +++ b/dev/lint-r @@ -29,7 +29,7 @@ fi `which Rscript` --vanilla "$SPARK_ROOT_DIR/dev/lint-r.R" "$SPARK_ROOT_DIR" | tee "$LINT_R_REPORT_FILE_NAME" -NUM_LINES=`wc -l < "$LINT_R_REPORT_FILE_NAME"` +NUM_LINES=`wc -l < "$LINT_R_REPORT_FILE_NAME" | awk '{print $1}'` if [ "$NUM_LINES" = "0" ] ; then lint_status=0 echo "lintr checks passed." From 4eeda8d472498acd40ef54723d1be9924a273d76 Mon Sep 17 00:00:00 2001 From: Yu ISHIKAWA Date: Fri, 28 Aug 2015 00:50:26 -0700 Subject: [PATCH 0461/1824] [SPARK-10260] [ML] Add @Since annotation to ml.clustering ### JIRA [[SPARK-10260] Add Since annotation to ml.clustering - ASF JIRA](https://issues.apache.org/jira/browse/SPARK-10260) Author: Yu ISHIKAWA Closes #8455 from yu-iskw/SPARK-10260. --- .../apache/spark/ml/clustering/KMeans.scala | 32 +++++++++++++++++-- 1 file changed, 29 insertions(+), 3 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index 47a18cdb31b5..f40ab71fb22a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -17,7 +17,7 @@ package org.apache.spark.ml.clustering -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Since, Experimental} import org.apache.spark.ml.param.{Param, Params, IntParam, ParamMap} import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util.{Identifiable, SchemaUtils} @@ -39,9 +39,11 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe * Set the number of clusters to create (k). Must be > 1. Default: 2. * @group param */ + @Since("1.5.0") final val k = new IntParam(this, "k", "number of clusters to create", (x: Int) => x > 1) /** @group getParam */ + @Since("1.5.0") def getK: Int = $(k) /** @@ -50,10 +52,12 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe * (Bahmani et al., Scalable K-Means++, VLDB 2012). Default: k-means||. * @group expertParam */ + @Since("1.5.0") final val initMode = new Param[String](this, "initMode", "initialization algorithm", (value: String) => MLlibKMeans.validateInitMode(value)) /** @group expertGetParam */ + @Since("1.5.0") def getInitMode: String = $(initMode) /** @@ -61,10 +65,12 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe * setting -- the default of 5 is almost always enough. Must be > 0. Default: 5. * @group expertParam */ + @Since("1.5.0") final val initSteps = new IntParam(this, "initSteps", "number of steps for k-means||", (value: Int) => value > 0) /** @group expertGetParam */ + @Since("1.5.0") def getInitSteps: Int = $(initSteps) /** @@ -84,27 +90,32 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe * * @param parentModel a model trained by spark.mllib.clustering.KMeans. 
*/ +@Since("1.5.0") @Experimental class KMeansModel private[ml] ( - override val uid: String, + @Since("1.5.0") override val uid: String, private val parentModel: MLlibKMeansModel) extends Model[KMeansModel] with KMeansParams { + @Since("1.5.0") override def copy(extra: ParamMap): KMeansModel = { val copied = new KMeansModel(uid, parentModel) copyValues(copied, extra) } + @Since("1.5.0") override def transform(dataset: DataFrame): DataFrame = { val predictUDF = udf((vector: Vector) => predict(vector)) dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) } + @Since("1.5.0") override def transformSchema(schema: StructType): StructType = { validateAndTransformSchema(schema) } private[clustering] def predict(features: Vector): Int = parentModel.predict(features) + @Since("1.5.0") def clusterCenters: Array[Vector] = parentModel.clusterCenters } @@ -114,8 +125,11 @@ class KMeansModel private[ml] ( * * @see [[http://dx.doi.org/10.14778/2180912.2180915 Bahmani et al., Scalable k-means++.]] */ +@Since("1.5.0") @Experimental -class KMeans(override val uid: String) extends Estimator[KMeansModel] with KMeansParams { +class KMeans @Since("1.5.0") ( + @Since("1.5.0") override val uid: String) + extends Estimator[KMeansModel] with KMeansParams { setDefault( k -> 2, @@ -124,34 +138,45 @@ class KMeans(override val uid: String) extends Estimator[KMeansModel] with KMean initSteps -> 5, tol -> 1e-4) + @Since("1.5.0") override def copy(extra: ParamMap): KMeans = defaultCopy(extra) + @Since("1.5.0") def this() = this(Identifiable.randomUID("kmeans")) /** @group setParam */ + @Since("1.5.0") def setFeaturesCol(value: String): this.type = set(featuresCol, value) /** @group setParam */ + @Since("1.5.0") def setPredictionCol(value: String): this.type = set(predictionCol, value) /** @group setParam */ + @Since("1.5.0") def setK(value: Int): this.type = set(k, value) /** @group expertSetParam */ + @Since("1.5.0") def setInitMode(value: String): this.type = set(initMode, value) /** @group expertSetParam */ + @Since("1.5.0") def setInitSteps(value: Int): this.type = set(initSteps, value) /** @group setParam */ + @Since("1.5.0") def setMaxIter(value: Int): this.type = set(maxIter, value) /** @group setParam */ + @Since("1.5.0") def setTol(value: Double): this.type = set(tol, value) /** @group setParam */ + @Since("1.5.0") def setSeed(value: Long): this.type = set(seed, value) + @Since("1.5.0") override def fit(dataset: DataFrame): KMeansModel = { val rdd = dataset.select(col($(featuresCol))).map { case Row(point: Vector) => point } @@ -167,6 +192,7 @@ class KMeans(override val uid: String) extends Estimator[KMeansModel] with KMean copyValues(model) } + @Since("1.5.0") override def transformSchema(schema: StructType): StructType = { validateAndTransformSchema(schema) } From cc39803062119c1d14611dc227b9ed0ed1284d38 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Fri, 28 Aug 2015 09:32:23 +0100 Subject: [PATCH 0462/1824] [SPARK-10295] [CORE] Dynamic allocation in Mesos does not release when RDDs are cached Remove obsolete warning about dynamic allocation not working with cached RDDs See discussion in https://issues.apache.org/jira/browse/SPARK-10295 Author: Sean Owen Closes #8489 from srowen/SPARK-10295. 
--- core/src/main/scala/org/apache/spark/SparkContext.scala | 5 ----- 1 file changed, 5 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index f3da04a7f55d..738887076b0d 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -1590,11 +1590,6 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * Register an RDD to be persisted in memory and/or disk storage */ private[spark] def persistRDD(rdd: RDD[_]) { - _executorAllocationManager.foreach { _ => - logWarning( - s"Dynamic allocation currently does not support cached RDDs. Cached data for RDD " + - s"${rdd.id} will be lost when executors are removed.") - } persistentRdds(rdd.id) = rdd } From 18294cd8710427076caa86bfac596de67089d57e Mon Sep 17 00:00:00 2001 From: Keiji Yoshida Date: Fri, 28 Aug 2015 09:36:50 +0100 Subject: [PATCH 0463/1824] Fix DynamodDB/DynamoDB typo in Kinesis Integration doc Fix DynamodDB/DynamoDB typo in Kinesis Integration doc Author: Keiji Yoshida Closes #8501 from yosssi/patch-1. --- docs/streaming-kinesis-integration.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/streaming-kinesis-integration.md b/docs/streaming-kinesis-integration.md index a7bcaec6fcd8..238a911a9199 100644 --- a/docs/streaming-kinesis-integration.md +++ b/docs/streaming-kinesis-integration.md @@ -91,7 +91,7 @@ A Kinesis stream can be set up at one of the valid Kinesis endpoints with 1 or m - Kinesis data processing is ordered per partition and occurs at-least once per message. - - Multiple applications can read from the same Kinesis stream. Kinesis will maintain the application-specific shard and checkpoint info in DynamodDB. + - Multiple applications can read from the same Kinesis stream. Kinesis will maintain the application-specific shard and checkpoint info in DynamoDB. - A single Kinesis stream shard is processed by one input DStream at a time. From 71a077f6c16c8816eae13341f645ba50d997f63d Mon Sep 17 00:00:00 2001 From: Dharmesh Kakadia Date: Fri, 28 Aug 2015 09:38:35 +0100 Subject: [PATCH 0464/1824] typo in comment Author: Dharmesh Kakadia Closes #8497 from dharmeshkakadia/patch-2. --- .../apache/spark/network/shuffle/protocol/RegisterExecutor.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterExecutor.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterExecutor.java index cca8b17c4f12..167ef3310422 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterExecutor.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterExecutor.java @@ -27,7 +27,7 @@ /** * Initial registration message between an executor and its local shuffle server. - * Returns nothing (empty bye array). + * Returns nothing (empty byte array). */ public class RegisterExecutor extends BlockTransferMessage { public final String appId; From 1502a0f6c5d2f85a331b29d3bf50002911ea393e Mon Sep 17 00:00:00 2001 From: jerryshao Date: Fri, 28 Aug 2015 09:32:52 -0500 Subject: [PATCH 0465/1824] [YARN] [MINOR] Avoid hard code port number in YarnShuffleService test Current port number is fixed as default (7337) in test, this will introduce port contention exception, better to change to a random number in unit test. 
squito , seems you're author of this unit test, mind taking a look at this fix? Thanks a lot. ``` [info] - executor state kept across NM restart *** FAILED *** (597 milliseconds) [info] org.apache.hadoop.service.ServiceStateException: java.net.BindException: Address already in use [info] at org.apache.hadoop.service.ServiceStateException.convert(ServiceStateException.java:59) [info] at org.apache.hadoop.service.AbstractService.init(AbstractService.java:172) [info] at org.apache.spark.network.yarn.YarnShuffleServiceSuite$$anonfun$1.apply$mcV$sp(YarnShuffleServiceSuite.scala:72) [info] at org.apache.spark.network.yarn.YarnShuffleServiceSuite$$anonfun$1.apply(YarnShuffleServiceSuite.scala:70) [info] at org.apache.spark.network.yarn.YarnShuffleServiceSuite$$anonfun$1.apply(YarnShuffleServiceSuite.scala:70) [info] at org.scalatest.Transformer$$anonfun$apply$1.apply$mcV$sp(Transformer.scala:22) [info] at org.scalatest.OutcomeOf$class.outcomeOf(OutcomeOf.scala:85) [info] at org.scalatest.OutcomeOf$.outcomeOf(OutcomeOf.scala:104) [info] at org.scalatest.Transformer.apply(Transformer.scala:22) [info] at org.scalatest.Transformer.apply(Transformer.scala:20) [info] at org.scalatest.FunSuiteLike$$anon$1.apply(FunSuiteLike.scala:166) [info] at org.apache.spark.SparkFunSuite.withFixture(SparkFunSuite.scala:42) ... ``` Author: jerryshao Closes #8502 from jerryshao/avoid-hardcode-port. --- .../org/apache/spark/network/yarn/YarnShuffleServiceSuite.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/yarn/src/test/scala/org/apache/spark/network/yarn/YarnShuffleServiceSuite.scala b/yarn/src/test/scala/org/apache/spark/network/yarn/YarnShuffleServiceSuite.scala index 2f22cbdbeac3..6aa8c814cd4f 100644 --- a/yarn/src/test/scala/org/apache/spark/network/yarn/YarnShuffleServiceSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/network/yarn/YarnShuffleServiceSuite.scala @@ -37,6 +37,7 @@ class YarnShuffleServiceSuite extends SparkFunSuite with Matchers with BeforeAnd yarnConfig.set(YarnConfiguration.NM_AUX_SERVICES, "spark_shuffle") yarnConfig.set(YarnConfiguration.NM_AUX_SERVICE_FMT.format("spark_shuffle"), classOf[YarnShuffleService].getCanonicalName) + yarnConfig.setInt("spark.shuffle.service.port", 0) yarnConfig.get("yarn.nodemanager.local-dirs").split(",").foreach { dir => val d = new File(dir) From e2a843090cb031f6aa774f6d9c031a7f0f732ee1 Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Fri, 28 Aug 2015 08:00:44 -0700 Subject: [PATCH 0466/1824] [SPARK-9890] [DOC] [ML] User guide for CountVectorizer jira: https://issues.apache.org/jira/browse/SPARK-9890 document with Scala and java examples Author: Yuhao Yang Closes #8487 from hhbyyh/cvDoc. --- docs/ml-features.md | 109 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 109 insertions(+) diff --git a/docs/ml-features.md b/docs/ml-features.md index 89a9bad57044..90654d1e5a24 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -211,6 +211,115 @@ for feature in result.select("result").take(3): +## CountVectorizer + +`CountVectorizer` and `CountVectorizerModel` aim to help convert a collection of text documents + to vectors of token counts. When an a-priori dictionary is not available, `CountVectorizer` can + be used as an `Estimator` to extract the vocabulary and generates a `CountVectorizerModel`. The + model produces sparse representations for the documents over the vocabulary, which can then be + passed to other algorithms like LDA. 
+ + During the fitting process, `CountVectorizer` will select the top `vocabSize` words ordered by + term frequency across the corpus. An optional parameter "minDF" also affect the fitting process + by specifying the minimum number (or fraction if < 1.0) of documents a term must appear in to be + included in the vocabulary. + +**Examples** + +Assume that we have the following DataFrame with columns `id` and `texts`: + +~~~~ + id | texts +----|---------- + 0 | Array("a", "b", "c") + 1 | Array("a", "b", "b", "c", "a") +~~~~ + +each row in`texts` is a document of type Array[String]. +Invoking fit of `CountVectorizer` produces a `CountVectorizerModel` with vocabulary (a, b, c), +then the output column "vector" after transformation contains: + +~~~~ + id | texts | vector +----|---------------------------------|--------------- + 0 | Array("a", "b", "c") | (3,[0,1,2],[1.0,1.0,1.0]) + 1 | Array("a", "b", "b", "c", "a") | (3,[0,1,2],[2.0,2.0,1.0]) +~~~~ + +each vector represents the token counts of the document over the vocabulary. + +
+<div data-lang="scala" markdown="1">
    +More details can be found in the API docs for +[CountVectorizer](api/scala/index.html#org.apache.spark.ml.feature.CountVectorizer) and +[CountVectorizerModel](api/scala/index.html#org.apache.spark.ml.feature.CountVectorizerModel). +{% highlight scala %} +import org.apache.spark.ml.feature.CountVectorizer +import org.apache.spark.mllib.util.CountVectorizerModel + +val df = sqlContext.createDataFrame(Seq( + (0, Array("a", "b", "c")), + (1, Array("a", "b", "b", "c", "a")) +)).toDF("id", "words") + +// fit a CountVectorizerModel from the corpus +val cvModel: CountVectorizerModel = new CountVectorizer() + .setInputCol("words") + .setOutputCol("features") + .setVocabSize(3) + .setMinDF(2) // a term must appear in more or equal to 2 documents to be included in the vocabulary + .fit(df) + +// alternatively, define CountVectorizerModel with a-priori vocabulary +val cvm = new CountVectorizerModel(Array("a", "b", "c")) + .setInputCol("words") + .setOutputCol("features") + +cvModel.transform(df).select("features").show() +{% endhighlight %} +
+
+<div data-lang="java" markdown="1">
    +More details can be found in the API docs for +[CountVectorizer](api/java/org/apache/spark/ml/feature/CountVectorizer.html) and +[CountVectorizerModel](api/java/org/apache/spark/ml/feature/CountVectorizerModel.html). +{% highlight java %} +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.CountVectorizer; +import org.apache.spark.ml.feature.CountVectorizerModel; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.*; + +// Input data: Each row is a bag of words from a sentence or document. +JavaRDD jrdd = jsc.parallelize(Arrays.asList( + RowFactory.create(Arrays.asList("a", "b", "c")), + RowFactory.create(Arrays.asList("a", "b", "b", "c", "a")) +)); +StructType schema = new StructType(new StructField [] { + new StructField("text", new ArrayType(DataTypes.StringType, true), false, Metadata.empty()) +}); +DataFrame df = sqlContext.createDataFrame(jrdd, schema); + +// fit a CountVectorizerModel from the corpus +CountVectorizerModel cvModel = new CountVectorizer() + .setInputCol("text") + .setOutputCol("feature") + .setVocabSize(3) + .setMinDF(2) // a term must appear in more or equal to 2 documents to be included in the vocabulary + .fit(df); + +// alternatively, define CountVectorizerModel with a-priori vocabulary +CountVectorizerModel cvm = new CountVectorizerModel(new String[]{"a", "b", "c"}) + .setInputCol("text") + .setOutputCol("feature"); + +cvModel.transform(df).show(); +{% endhighlight %} +
+</div>
    + # Feature Transformers ## Tokenizer From 499e8e154bdcc9d7b2f685b159e0ddb4eae48fe4 Mon Sep 17 00:00:00 2001 From: Luciano Resende Date: Fri, 28 Aug 2015 09:13:21 -0700 Subject: [PATCH 0467/1824] [SPARK-8952] [SPARKR] - Wrap normalizePath calls with suppressWarnings This is based on davies comment on SPARK-8952 which suggests to only call normalizePath() when path starts with '~' Author: Luciano Resende Closes #8343 from lresende/SPARK-8952. --- R/pkg/R/SQLContext.R | 4 ++-- R/pkg/R/sparkR.R | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R index 110117a18ccb..1bc644531147 100644 --- a/R/pkg/R/SQLContext.R +++ b/R/pkg/R/SQLContext.R @@ -201,7 +201,7 @@ setMethod("toDF", signature(x = "RDD"), jsonFile <- function(sqlContext, path) { # Allow the user to have a more flexible definiton of the text file path - path <- normalizePath(path) + path <- suppressWarnings(normalizePath(path)) # Convert a string vector of paths to a string containing comma separated paths path <- paste(path, collapse = ",") sdf <- callJMethod(sqlContext, "jsonFile", path) @@ -251,7 +251,7 @@ jsonRDD <- function(sqlContext, rdd, schema = NULL, samplingRatio = 1.0) { # TODO: Implement saveasParquetFile and write examples for both parquetFile <- function(sqlContext, ...) { # Allow the user to have a more flexible definiton of the text file path - paths <- lapply(list(...), normalizePath) + paths <- lapply(list(...), function(x) suppressWarnings(normalizePath(x))) sdf <- callJMethod(sqlContext, "parquetFile", paths) dataFrame(sdf) } diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R index e83104f11642..3c57a44db257 100644 --- a/R/pkg/R/sparkR.R +++ b/R/pkg/R/sparkR.R @@ -160,7 +160,7 @@ sparkR.init <- function( }) if (nchar(sparkHome) != 0) { - sparkHome <- normalizePath(sparkHome) + sparkHome <- suppressWarnings(normalizePath(sparkHome)) } sparkEnvirMap <- new.env() From d3f87dc39480f075170817bbd00142967a938078 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Fri, 28 Aug 2015 11:51:42 -0700 Subject: [PATCH 0468/1824] [SPARK-10325] Override hashCode() for public Row This commit fixes an issue where the public SQL `Row` class did not override `hashCode`, causing it to violate the hashCode() + equals() contract. To fix this, I simply ported the `hashCode` implementation from the 1.4.x version of `Row`. Author: Josh Rosen Closes #8500 from JoshRosen/SPARK-10325 and squashes the following commits: 51ffea1 [Josh Rosen] Override hashCode() for public Row. --- .../src/main/scala/org/apache/spark/sql/Row.scala | 13 +++++++++++++ .../test/scala/org/apache/spark/sql/RowSuite.scala | 9 +++++++++ 2 files changed, 22 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala index cfd9cb0e6259..ed2fdf9f2f7c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql import scala.collection.JavaConverters._ +import scala.util.hashing.MurmurHash3 import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.GenericRow @@ -410,6 +411,18 @@ trait Row extends Serializable { true } + override def hashCode: Int = { + // Using Scala's Seq hash code implementation. 
+ var n = 0 + var h = MurmurHash3.seqSeed + val len = length + while (n < len) { + h = MurmurHash3.mix(h, apply(n).##) + n += 1 + } + MurmurHash3.finalizeHash(h, n) + } + /* ---------------------- utility methods for Scala ---------------------- */ /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala index 795d4e983f27..77ccd6f775e5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala @@ -85,4 +85,13 @@ class RowSuite extends SparkFunSuite with SharedSQLContext { val r2 = Row(Double.NaN) assert(r1 === r2) } + + test("equals and hashCode") { + val r1 = Row("Hello") + val r2 = Row("Hello") + assert(r1 === r2) + assert(r1.hashCode() === r2.hashCode()) + val r3 = Row("World") + assert(r3.hashCode() != r1.hashCode()) + } } From c53c902fa9c458200245f919067b41dde9cd9418 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Fri, 28 Aug 2015 12:33:40 -0700 Subject: [PATCH 0469/1824] [SPARK-9284] [TESTS] Allow all tests to run without an assembly. This change aims at speeding up the dev cycle a little bit, by making sure that all tests behave the same w.r.t. where the code to be tested is loaded from. Namely, that means that tests don't rely on the assembly anymore, rather loading all needed classes from the build directories. The main change is to make sure all build directories (classes and test-classes) are added to the classpath of child processes when running tests. YarnClusterSuite required some custom code since the executors are run differently (i.e. not through the launcher library, like standalone and Mesos do). I also found a couple of tests that could leak a SparkContext on failure, and added code to handle those. With this patch, it's possible to run the following command from a clean source directory and have all tests pass: mvn -Pyarn -Phadoop-2.4 -Phive-thriftserver install Author: Marcelo Vanzin Closes #7629 from vanzin/SPARK-9284. --- bin/spark-class | 16 ++++---- .../spark/launcher/SparkLauncherSuite.java | 0 .../spark/broadcast/BroadcastSuite.scala | 10 ++++- .../launcher/AbstractCommandBuilder.java | 28 ++++++++------ pom.xml | 8 ++++ project/SparkBuild.scala | 2 + .../deploy/yarn/BaseYarnClusterSuite.scala | 37 +++++++++++++------ .../spark/deploy/yarn/YarnClusterSuite.scala | 20 ++++++++-- .../yarn/YarnShuffleIntegrationSuite.scala | 2 +- .../spark/launcher/TestClasspathBuilder.scala | 36 ++++++++++++++++++ 10 files changed, 122 insertions(+), 37 deletions(-) rename {launcher => core}/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java (100%) create mode 100644 yarn/src/test/scala/org/apache/spark/launcher/TestClasspathBuilder.scala diff --git a/bin/spark-class b/bin/spark-class index 2b59e5df5736..e38e08dec40e 100755 --- a/bin/spark-class +++ b/bin/spark-class @@ -43,17 +43,19 @@ else fi num_jars="$(ls -1 "$ASSEMBLY_DIR" | grep "^spark-assembly.*hadoop.*\.jar$" | wc -l)" -if [ "$num_jars" -eq "0" -a -z "$SPARK_ASSEMBLY_JAR" ]; then +if [ "$num_jars" -eq "0" -a -z "$SPARK_ASSEMBLY_JAR" -a "$SPARK_PREPEND_CLASSES" != "1" ]; then echo "Failed to find Spark assembly in $ASSEMBLY_DIR." 1>&2 echo "You need to build Spark before running this program." 
1>&2 exit 1 fi -ASSEMBLY_JARS="$(ls -1 "$ASSEMBLY_DIR" | grep "^spark-assembly.*hadoop.*\.jar$" || true)" -if [ "$num_jars" -gt "1" ]; then - echo "Found multiple Spark assembly jars in $ASSEMBLY_DIR:" 1>&2 - echo "$ASSEMBLY_JARS" 1>&2 - echo "Please remove all but one jar." 1>&2 - exit 1 +if [ -d "$ASSEMBLY_DIR" ]; then + ASSEMBLY_JARS="$(ls -1 "$ASSEMBLY_DIR" | grep "^spark-assembly.*hadoop.*\.jar$" || true)" + if [ "$num_jars" -gt "1" ]; then + echo "Found multiple Spark assembly jars in $ASSEMBLY_DIR:" 1>&2 + echo "$ASSEMBLY_JARS" 1>&2 + echo "Please remove all but one jar." 1>&2 + exit 1 + fi fi SPARK_ASSEMBLY_JAR="${ASSEMBLY_DIR}/${ASSEMBLY_JARS}" diff --git a/launcher/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java b/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java similarity index 100% rename from launcher/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java rename to core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java diff --git a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala index 48e74f06f79b..fb7a8ae3f9d4 100644 --- a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala +++ b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala @@ -310,8 +310,14 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext { val _sc = new SparkContext("local-cluster[%d, 1, 1024]".format(numSlaves), "test", broadcastConf) // Wait until all salves are up - _sc.jobProgressListener.waitUntilExecutorsUp(numSlaves, 10000) - _sc + try { + _sc.jobProgressListener.waitUntilExecutorsUp(numSlaves, 10000) + _sc + } catch { + case e: Throwable => + _sc.stop() + throw e + } } else { new SparkContext("local", "test", broadcastConf) } diff --git a/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java index 5e793a5c4877..0a237ee73b67 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java +++ b/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java @@ -169,9 +169,11 @@ List buildClassPath(String appClassPath) throws IOException { "streaming", "tools", "sql/catalyst", "sql/core", "sql/hive", "sql/hive-thriftserver", "yarn", "launcher"); if (prependClasses) { - System.err.println( - "NOTE: SPARK_PREPEND_CLASSES is set, placing locally compiled Spark classes ahead of " + - "assembly."); + if (!isTesting) { + System.err.println( + "NOTE: SPARK_PREPEND_CLASSES is set, placing locally compiled Spark classes ahead of " + + "assembly."); + } for (String project : projects) { addToClassPath(cp, String.format("%s/%s/target/scala-%s/classes", sparkHome, project, scala)); @@ -200,7 +202,7 @@ List buildClassPath(String appClassPath) throws IOException { // For the user code case, we fall back to looking for the Spark assembly under SPARK_HOME. // That duplicates some of the code in the shell scripts that look for the assembly, though. 
String assembly = getenv(ENV_SPARK_ASSEMBLY); - if (assembly == null && isEmpty(getenv("SPARK_TESTING"))) { + if (assembly == null && !isTesting) { assembly = findAssembly(); } addToClassPath(cp, assembly); @@ -215,12 +217,14 @@ List buildClassPath(String appClassPath) throws IOException { libdir = new File(sparkHome, "lib_managed/jars"); } - checkState(libdir.isDirectory(), "Library directory '%s' does not exist.", - libdir.getAbsolutePath()); - for (File jar : libdir.listFiles()) { - if (jar.getName().startsWith("datanucleus-")) { - addToClassPath(cp, jar.getAbsolutePath()); + if (libdir.isDirectory()) { + for (File jar : libdir.listFiles()) { + if (jar.getName().startsWith("datanucleus-")) { + addToClassPath(cp, jar.getAbsolutePath()); + } } + } else { + checkState(isTesting, "Library directory '%s' does not exist.", libdir.getAbsolutePath()); } addToClassPath(cp, getenv("HADOOP_CONF_DIR")); @@ -256,15 +260,15 @@ String getScalaVersion() { return scala; } String sparkHome = getSparkHome(); - File scala210 = new File(sparkHome, "assembly/target/scala-2.10"); - File scala211 = new File(sparkHome, "assembly/target/scala-2.11"); + File scala210 = new File(sparkHome, "launcher/target/scala-2.10"); + File scala211 = new File(sparkHome, "launcher/target/scala-2.11"); checkState(!scala210.isDirectory() || !scala211.isDirectory(), "Presence of build for both scala versions (2.10 and 2.11) detected.\n" + "Either clean one of them or set SPARK_SCALA_VERSION in your environment."); if (scala210.isDirectory()) { return "2.10"; } else { - checkState(scala211.isDirectory(), "Cannot find any assembly build directories."); + checkState(scala211.isDirectory(), "Cannot find any build directories."); return "2.11"; } } diff --git a/pom.xml b/pom.xml index 0716016523ee..88ebceca769e 100644 --- a/pom.xml +++ b/pom.xml @@ -1421,6 +1421,10 @@ org.apache.thrift libthrift + + org.mortbay.jetty + servlet-api + com.google.guava guava @@ -1892,6 +1896,8 @@ launched by the tests have access to the correct test-time classpath. --> ${test_classpath} + 1 + 1 ${test.java.home} @@ -1929,6 +1935,8 @@ launched by the tests have access to the correct test-time classpath. 
--> ${test_classpath} + 1 + 1 ${test.java.home} diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index ea52bfd67944..901cfa538d23 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -547,6 +547,8 @@ object TestSettings { envVars in Test ++= Map( "SPARK_DIST_CLASSPATH" -> (fullClasspath in Test).value.files.map(_.getAbsolutePath).mkString(":").stripSuffix(":"), + "SPARK_PREPEND_CLASSES" -> "1", + "SPARK_TESTING" -> "1", "JAVA_HOME" -> sys.env.get("JAVA_HOME").getOrElse(sys.props("java.home"))), javaOptions in Test += s"-Djava.io.tmpdir=$testTempDir", javaOptions in Test += "-Dspark.test.home=" + sparkHome, diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala index b4f8049bff57..17c59ff06e0c 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala @@ -30,6 +30,7 @@ import org.apache.hadoop.yarn.server.MiniYARNCluster import org.scalatest.{BeforeAndAfterAll, Matchers} import org.apache.spark._ +import org.apache.spark.launcher.TestClasspathBuilder import org.apache.spark.util.Utils abstract class BaseYarnClusterSuite @@ -43,6 +44,9 @@ abstract class BaseYarnClusterSuite |log4j.appender.console.target=System.err |log4j.appender.console.layout=org.apache.log4j.PatternLayout |log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n + |log4j.logger.org.apache.hadoop=WARN + |log4j.logger.org.eclipse.jetty=WARN + |log4j.logger.org.spark-project.jetty=WARN """.stripMargin private var yarnCluster: MiniYARNCluster = _ @@ -51,8 +55,7 @@ abstract class BaseYarnClusterSuite private var hadoopConfDir: File = _ private var logConfDir: File = _ - - def yarnConfig: YarnConfiguration + def newYarnConfig(): YarnConfiguration override def beforeAll() { super.beforeAll() @@ -65,8 +68,14 @@ abstract class BaseYarnClusterSuite val logConfFile = new File(logConfDir, "log4j.properties") Files.write(LOG4J_CONF, logConfFile, UTF_8) + // Disable the disk utilization check to avoid the test hanging when people's disks are + // getting full. 
+ val yarnConf = newYarnConfig() + yarnConf.set("yarn.nodemanager.disk-health-checker.max-disk-utilization-per-disk-percentage", + "100.0") + yarnCluster = new MiniYARNCluster(getClass().getName(), 1, 1, 1) - yarnCluster.init(yarnConfig) + yarnCluster.init(yarnConf) yarnCluster.start() // There's a race in MiniYARNCluster in which start() may return before the RM has updated @@ -114,19 +123,23 @@ abstract class BaseYarnClusterSuite sparkArgs: Seq[String] = Nil, extraClassPath: Seq[String] = Nil, extraJars: Seq[String] = Nil, - extraConf: Map[String, String] = Map()): Unit = { + extraConf: Map[String, String] = Map(), + extraEnv: Map[String, String] = Map()): Unit = { val master = if (clientMode) "yarn-client" else "yarn-cluster" val props = new Properties() props.setProperty("spark.yarn.jar", "local:" + fakeSparkJar.getAbsolutePath()) - val childClasspath = logConfDir.getAbsolutePath() + - File.pathSeparator + - sys.props("java.class.path") + - File.pathSeparator + - extraClassPath.mkString(File.pathSeparator) - props.setProperty("spark.driver.extraClassPath", childClasspath) - props.setProperty("spark.executor.extraClassPath", childClasspath) + val testClasspath = new TestClasspathBuilder() + .buildClassPath( + logConfDir.getAbsolutePath() + + File.pathSeparator + + extraClassPath.mkString(File.pathSeparator)) + .asScala + .mkString(File.pathSeparator) + + props.setProperty("spark.driver.extraClassPath", testClasspath) + props.setProperty("spark.executor.extraClassPath", testClasspath) // SPARK-4267: make sure java options are propagated correctly. props.setProperty("spark.driver.extraJavaOptions", "-Dfoo=\"one two three\"") @@ -168,7 +181,7 @@ abstract class BaseYarnClusterSuite appArgs Utils.executeAndGetOutput(argv, - extraEnvironment = Map("YARN_CONF_DIR" -> hadoopConfDir.getAbsolutePath())) + extraEnvironment = Map("YARN_CONF_DIR" -> hadoopConfDir.getAbsolutePath()) ++ extraEnv) } /** diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala index 5a4ea2ea2f4f..b5a42fd6afd9 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala @@ -28,7 +28,9 @@ import org.apache.hadoop.yarn.conf.YarnConfiguration import org.scalatest.Matchers import org.apache.spark._ -import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationStart, SparkListenerExecutorAdded} +import org.apache.spark.launcher.TestClasspathBuilder +import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationStart, + SparkListenerExecutorAdded} import org.apache.spark.scheduler.cluster.ExecutorInfo import org.apache.spark.util.Utils @@ -39,7 +41,7 @@ import org.apache.spark.util.Utils */ class YarnClusterSuite extends BaseYarnClusterSuite { - override def yarnConfig: YarnConfiguration = new YarnConfiguration() + override def newYarnConfig(): YarnConfiguration = new YarnConfiguration() private val TEST_PYFILE = """ |import mod1, mod2 @@ -111,6 +113,17 @@ class YarnClusterSuite extends BaseYarnClusterSuite { val primaryPyFile = new File(tempDir, "test.py") Files.write(TEST_PYFILE, primaryPyFile, UTF_8) + // When running tests, let's not assume the user has built the assembly module, which also + // creates the pyspark archive. Instead, let's use PYSPARK_ARCHIVES_PATH to point at the + // needed locations. 
+ val sparkHome = sys.props("spark.test.home"); + val pythonPath = Seq( + s"$sparkHome/python/lib/py4j-0.8.2.1-src.zip", + s"$sparkHome/python") + val extraEnv = Map( + "PYSPARK_ARCHIVES_PATH" -> pythonPath.map("local:" + _).mkString(File.pathSeparator), + "PYTHONPATH" -> pythonPath.mkString(File.pathSeparator)) + val moduleDir = if (clientMode) { // In client-mode, .py files added with --py-files are not visible in the driver. @@ -130,7 +143,8 @@ class YarnClusterSuite extends BaseYarnClusterSuite { runSpark(clientMode, primaryPyFile.getAbsolutePath(), sparkArgs = Seq("--py-files", pyFiles), - appArgs = Seq(result.getAbsolutePath())) + appArgs = Seq(result.getAbsolutePath()), + extraEnv = extraEnv) checkResult(result) } diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleIntegrationSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleIntegrationSuite.scala index 5e8238822b90..8d9c9b3004ed 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleIntegrationSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleIntegrationSuite.scala @@ -34,7 +34,7 @@ import org.apache.spark.network.yarn.{YarnShuffleService, YarnTestAccessor} */ class YarnShuffleIntegrationSuite extends BaseYarnClusterSuite { - override def yarnConfig: YarnConfiguration = { + override def newYarnConfig(): YarnConfiguration = { val yarnConfig = new YarnConfiguration() yarnConfig.set(YarnConfiguration.NM_AUX_SERVICES, "spark_shuffle") yarnConfig.set(YarnConfiguration.NM_AUX_SERVICE_FMT.format("spark_shuffle"), diff --git a/yarn/src/test/scala/org/apache/spark/launcher/TestClasspathBuilder.scala b/yarn/src/test/scala/org/apache/spark/launcher/TestClasspathBuilder.scala new file mode 100644 index 000000000000..da9e8e21a26a --- /dev/null +++ b/yarn/src/test/scala/org/apache/spark/launcher/TestClasspathBuilder.scala @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.launcher + +import java.util.{List => JList, Map => JMap} + +/** + * Exposes AbstractCommandBuilder to the YARN tests, so that they can build classpaths the same + * way other cluster managers do. + */ +private[spark] class TestClasspathBuilder extends AbstractCommandBuilder { + + childEnv.put(CommandBuilderUtils.ENV_SPARK_HOME, sys.props("spark.test.home")) + + override def buildClassPath(extraCp: String): JList[String] = super.buildClassPath(extraCp) + + /** Not used by the YARN tests. 
*/ + override def buildCommand(env: JMap[String, String]): JList[String] = + throw new UnsupportedOperationException() + +} From 45723214e694b9a440723e9504c562e6393709f3 Mon Sep 17 00:00:00 2001 From: Shuo Xiang Date: Fri, 28 Aug 2015 13:09:13 -0700 Subject: [PATCH 0470/1824] [SPARK-10336][example] fix not being able to set intercept in LR example `fitIntercept` is a command line option but not set in the main program. dbtsai Author: Shuo Xiang Closes #8510 from coderxiang/intercept and squashes the following commits: 57c9b7d [Shuo Xiang] fix not being able to set intercept in LR example --- .../org/apache/spark/examples/ml/LogisticRegressionExample.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionExample.scala index 7682557127b5..8e3760ddb50a 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionExample.scala @@ -136,6 +136,7 @@ object LogisticRegressionExample { .setElasticNetParam(params.elasticNetParam) .setMaxIter(params.maxIter) .setTol(params.tol) + .setFitIntercept(params.fitIntercept) stages += lor val pipeline = new Pipeline().setStages(stages.toArray) From 88032ecaf0455886aed7a66b30af80dae7f6cff7 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Fri, 28 Aug 2015 13:53:31 -0700 Subject: [PATCH 0471/1824] [SPARK-9671] [MLLIB] re-org user guide and add migration guide This PR updates the MLlib user guide and adds migration guide for 1.4->1.5. * merge migration guide for `spark.mllib` and `spark.ml` packages * remove dependency section from `spark.ml` guide * move the paragraph about `spark.mllib` and `spark.ml` to the top and recommend `spark.ml` * move Sam's talk to footnote to make the section focus on dependencies Minor changes to code examples and other wording will be in a separate PR. jkbradley srowen feynmanliang Author: Xiangrui Meng Closes #8498 from mengxr/SPARK-9671. --- docs/ml-guide.md | 52 ++------------ docs/mllib-guide.md | 119 ++++++++++++++++----------------- docs/mllib-migration-guides.md | 30 +++++++++ 3 files changed, 95 insertions(+), 106 deletions(-) diff --git a/docs/ml-guide.md b/docs/ml-guide.md index 01bf5ee18e32..ce53400b6ee5 100644 --- a/docs/ml-guide.md +++ b/docs/ml-guide.md @@ -21,19 +21,11 @@ title: Spark ML Programming Guide \]` -Spark 1.2 introduced a new package called `spark.ml`, which aims to provide a uniform set of -high-level APIs that help users create and tune practical machine learning pipelines. - -*Graduated from Alpha!* The Pipelines API is no longer an alpha component, although many elements of it are still `Experimental` or `DeveloperApi`. - -Note that we will keep supporting and adding features to `spark.mllib` along with the -development of `spark.ml`. -Users should be comfortable using `spark.mllib` features and expect more features coming. -Developers should contribute new algorithms to `spark.mllib` and can optionally contribute -to `spark.ml`. - -See the [Algorithm Guides section](#algorithm-guides) below for guides on sub-packages of `spark.ml`, including feature transformers unique to the Pipelines API, ensembles, and more. - +The `spark.ml` package aims to provide a uniform set of high-level APIs built on top of +[DataFrames](sql-programming-guide.html#dataframes) that help users create and tune practical +machine learning pipelines. 
+See the [Algorithm Guides section](#algorithm-guides) below for guides on sub-packages of +`spark.ml`, including feature transformers unique to the Pipelines API, ensembles, and more. **Table of Contents** @@ -171,7 +163,7 @@ This is useful if there are two algorithms with the `maxIter` parameter in a `Pi # Algorithm Guides -There are now several algorithms in the Pipelines API which are not in the lower-level MLlib API, so we link to documentation for them here. These algorithms are mostly feature transformers, which fit naturally into the `Transformer` abstraction in Pipelines, and ensembles, which fit naturally into the `Estimator` abstraction in the Pipelines. +There are now several algorithms in the Pipelines API which are not in the `spark.mllib` API, so we link to documentation for them here. These algorithms are mostly feature transformers, which fit naturally into the `Transformer` abstraction in Pipelines, and ensembles, which fit naturally into the `Estimator` abstraction in the Pipelines. **Pipelines API Algorithm Guides** @@ -880,35 +872,3 @@ jsc.stop(); - -# Dependencies - -Spark ML currently depends on MLlib and has the same dependencies. -Please see the [MLlib Dependencies guide](mllib-guide.html#dependencies) for more info. - -Spark ML also depends upon Spark SQL, but the relevant parts of Spark SQL do not bring additional dependencies. - -# Migration Guide - -## From 1.3 to 1.4 - -Several major API changes occurred, including: -* `Param` and other APIs for specifying parameters -* `uid` unique IDs for Pipeline components -* Reorganization of certain classes -Since the `spark.ml` API was an Alpha Component in Spark 1.3, we do not list all changes here. - -However, now that `spark.ml` is no longer an Alpha Component, we will provide details on any API changes for future releases. - -## From 1.2 to 1.3 - -The main API changes are from Spark SQL. We list the most important changes here: - -* The old [SchemaRDD](http://spark.apache.org/docs/1.2.1/api/scala/index.html#org.apache.spark.sql.SchemaRDD) has been replaced with [DataFrame](api/scala/index.html#org.apache.spark.sql.DataFrame) with a somewhat modified API. All algorithms in Spark ML which used to use SchemaRDD now use DataFrame. -* In Spark 1.2, we used implicit conversions from `RDD`s of `LabeledPoint` into `SchemaRDD`s by calling `import sqlContext._` where `sqlContext` was an instance of `SQLContext`. These implicits have been moved, so we now call `import sqlContext.implicits._`. -* Java APIs for SQL have also changed accordingly. Please see the examples above and the [Spark SQL Programming Guide](sql-programming-guide.html) for details. - -Other changes were in `LogisticRegression`: - -* The `scoreCol` output column (with default value "score") was renamed to be `probabilityCol` (with default value "probability"). The type was originally `Double` (for the probability of class 1.0), but it is now `Vector` (for the probability of each class, to support multiclass classification in the future). -* In Spark 1.2, `LogisticRegressionModel` did not include an intercept. In Spark 1.3, it includes an intercept; however, it will always be 0.0 since it uses the default settings for [spark.mllib.LogisticRegressionWithLBFGS](api/scala/index.html#org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS). The option to use an intercept will be added in the future. 
diff --git a/docs/mllib-guide.md b/docs/mllib-guide.md index 6330c977552d..876dcfd40ed7 100644 --- a/docs/mllib-guide.md +++ b/docs/mllib-guide.md @@ -5,21 +5,28 @@ displayTitle: Machine Learning Library (MLlib) Guide description: MLlib machine learning library overview for Spark SPARK_VERSION_SHORT --- -MLlib is Spark's scalable machine learning library consisting of common learning algorithms and utilities, -including classification, regression, clustering, collaborative -filtering, dimensionality reduction, as well as underlying optimization primitives. -Guides for individual algorithms are listed below. +MLlib is Spark's machine learning (ML) library. +Its goal is to make practical machine learning scalable and easy. +It consists of common learning algorithms and utilities, including classification, regression, +clustering, collaborative filtering, dimensionality reduction, as well as lower-level optimization +primitives and higher-level pipeline APIs. -The API is divided into 2 parts: +It divides into two packages: -* [The original `spark.mllib` API](mllib-guide.html#mllib-types-algorithms-and-utilities) is the primary API. -* [The "Pipelines" `spark.ml` API](mllib-guide.html#sparkml-high-level-apis-for-ml-pipelines) is a higher-level API for constructing ML workflows. +* [`spark.mllib`](mllib-guide.html#mllib-types-algorithms-and-utilities) contains the original API + built on top of RDDs. +* [`spark.ml`](mllib-guide.html#sparkml-high-level-apis-for-ml-pipelines) provides higher-level API + built on top of DataFrames for constructing ML pipelines. -We list major functionality from both below, with links to detailed guides. +Using `spark.ml` is recommended because with DataFrames the API is more versatile and flexible. +But we will keep supporting `spark.mllib` along with the development of `spark.ml`. +Users should be comfortable using `spark.mllib` features and expect more features coming. +Developers should contribute new algorithms to `spark.ml` if they fit the ML pipeline concept well, +e.g., feature extractors and transformers. -# MLlib types, algorithms and utilities +We list major functionality from both below, with links to detailed guides. -This lists functionality included in `spark.mllib`, the main MLlib API. +# spark.mllib: data types, algorithms, and utilities * [Data types](mllib-data-types.html) * [Basic statistics](mllib-statistics.html) @@ -56,71 +63,63 @@ This lists functionality included in `spark.mllib`, the main MLlib API. * [limited-memory BFGS (L-BFGS)](mllib-optimization.html#limited-memory-bfgs-l-bfgs) * [PMML model export](mllib-pmml-model-export.html) -MLlib is under active development. -The APIs marked `Experimental`/`DeveloperApi` may change in future releases, -and the migration guide below will explain all changes between releases. - # spark.ml: high-level APIs for ML pipelines -Spark 1.2 introduced a new package called `spark.ml`, which aims to provide a uniform set of -high-level APIs that help users create and tune practical machine learning pipelines. - -*Graduated from Alpha!* The Pipelines API is no longer an alpha component, although many elements of it are still `Experimental` or `DeveloperApi`. - -Note that we will keep supporting and adding features to `spark.mllib` along with the -development of `spark.ml`. -Users should be comfortable using `spark.mllib` features and expect more features coming. -Developers should contribute new algorithms to `spark.mllib` and can optionally contribute -to `spark.ml`. 
- -Guides for `spark.ml` include: +**[spark.ml programming guide](ml-guide.html)** provides an overview of the Pipelines API and major +concepts. It also contains sections on using algorithms within the Pipelines API, for example: -* **[spark.ml programming guide](ml-guide.html)**: overview of the Pipelines API and major concepts -* Guides on using algorithms within the Pipelines API: - * [Feature transformers](ml-features.html), including a few not in the lower-level `spark.mllib` API - * [Decision trees](ml-decision-tree.html) - * [Ensembles](ml-ensembles.html) - * [Linear methods](ml-linear-methods.html) +* [Feature Extraction, Transformation, and Selection](ml-features.html) +* [Decision Trees for Classification and Regression](ml-decision-tree.html) +* [Ensembles](ml-ensembles.html) +* [Linear methods with elastic net regularization](ml-linear-methods.html) +* [Multilayer perceptron classifier](ml-ann.html) # Dependencies -MLlib uses the linear algebra package -[Breeze](http://www.scalanlp.org/), which depends on -[netlib-java](https://github.com/fommil/netlib-java) for optimised -numerical processing. If natives are not available at runtime, you -will see a warning message and a pure JVM implementation will be used -instead. +MLlib uses the linear algebra package [Breeze](http://www.scalanlp.org/), which depends on +[netlib-java](https://github.com/fommil/netlib-java) for optimised numerical processing. +If natives libraries[^1] are not available at runtime, you will see a warning message and a pure JVM +implementation will be used instead. -To learn more about the benefits and background of system optimised -natives, you may wish to watch Sam Halliday's ScalaX talk on -[High Performance Linear Algebra in Scala](http://fommil.github.io/scalax14/#/)). +Due to licensing issues with runtime proprietary binaries, we do not include `netlib-java`'s native +proxies by default. +To configure `netlib-java` / Breeze to use system optimised binaries, include +`com.github.fommil.netlib:all:1.1.2` (or build Spark with `-Pnetlib-lgpl`) as a dependency of your +project and read the [netlib-java](https://github.com/fommil/netlib-java) documentation for your +platform's additional installation instructions. -Due to licensing issues with runtime proprietary binaries, we do not -include `netlib-java`'s native proxies by default. To configure -`netlib-java` / Breeze to use system optimised binaries, include -`com.github.fommil.netlib:all:1.1.2` (or build Spark with -`-Pnetlib-lgpl`) as a dependency of your project and read the -[netlib-java](https://github.com/fommil/netlib-java) documentation for -your platform's additional installation instructions. +To use MLlib in Python, you will need [NumPy](http://www.numpy.org) version 1.4 or newer. -To use MLlib in Python, you will need [NumPy](http://www.numpy.org) -version 1.4 or newer. +[^1]: To learn more about the benefits and background of system optimised natives, you may wish to + watch Sam Halliday's ScalaX talk on [High Performance Linear Algebra in Scala](http://fommil.github.io/scalax14/#/). ---- +# Migration guide -# Migration Guide +MLlib is under active development. +The APIs marked `Experimental`/`DeveloperApi` may change in future releases, +and the migration guide below will explain all changes between releases. + +## From 1.4 to 1.5 -For the `spark.ml` package, please see the [spark.ml Migration Guide](ml-guide.html#migration-guide). 
+In the `spark.mllib` package, there are no break API changes but several behavior changes: -## From 1.3 to 1.4 +* [SPARK-9005](https://issues.apache.org/jira/browse/SPARK-9005): + `RegressionMetrics.explainedVariance` returns the average regression sum of squares. +* [SPARK-8600](https://issues.apache.org/jira/browse/SPARK-8600): `NaiveBayesModel.labels` become + sorted. +* [SPARK-3382](https://issues.apache.org/jira/browse/SPARK-3382): `GradientDescent` has a default + convergence tolerance `1e-3`, and hence iterations might end earlier than 1.4. -In the `spark.mllib` package, there were several breaking changes, but all in `DeveloperApi` or `Experimental` APIs: +In the `spark.ml` package, there exists one break API change and one behavior change: -* Gradient-Boosted Trees - * *(Breaking change)* The signature of the [`Loss.gradient`](api/scala/index.html#org.apache.spark.mllib.tree.loss.Loss) method was changed. This is only an issues for users who wrote their own losses for GBTs. - * *(Breaking change)* The `apply` and `copy` methods for the case class [`BoostingStrategy`](api/scala/index.html#org.apache.spark.mllib.tree.configuration.BoostingStrategy) have been changed because of a modification to the case class fields. This could be an issue for users who use `BoostingStrategy` to set GBT parameters. -* *(Breaking change)* The return value of [`LDA.run`](api/scala/index.html#org.apache.spark.mllib.clustering.LDA) has changed. It now returns an abstract class `LDAModel` instead of the concrete class `DistributedLDAModel`. The object of type `LDAModel` can still be cast to the appropriate concrete type, which depends on the optimization algorithm. +* [SPARK-9268](https://issues.apache.org/jira/browse/SPARK-9268): Java's varargs support is removed + from `Params.setDefault` due to a + [Scala compiler bug](https://issues.scala-lang.org/browse/SI-9013). +* [SPARK-10097](https://issues.apache.org/jira/browse/SPARK-10097): `Evaluator.isLargerBetter` is + added to indicate metric ordering. Metrics like RMSE no longer flip signs as in 1.4. -## Previous Spark Versions +## Previous Spark versions Earlier migration guides are archived [on this page](mllib-migration-guides.html). + +--- diff --git a/docs/mllib-migration-guides.md b/docs/mllib-migration-guides.md index 8df68d81f3c7..774b85d1f773 100644 --- a/docs/mllib-migration-guides.md +++ b/docs/mllib-migration-guides.md @@ -7,6 +7,25 @@ description: MLlib migration guides from before Spark SPARK_VERSION_SHORT The migration guide for the current Spark version is kept on the [MLlib Programming Guide main page](mllib-guide.html#migration-guide). +## From 1.3 to 1.4 + +In the `spark.mllib` package, there were several breaking changes, but all in `DeveloperApi` or `Experimental` APIs: + +* Gradient-Boosted Trees + * *(Breaking change)* The signature of the [`Loss.gradient`](api/scala/index.html#org.apache.spark.mllib.tree.loss.Loss) method was changed. This is only an issues for users who wrote their own losses for GBTs. + * *(Breaking change)* The `apply` and `copy` methods for the case class [`BoostingStrategy`](api/scala/index.html#org.apache.spark.mllib.tree.configuration.BoostingStrategy) have been changed because of a modification to the case class fields. This could be an issue for users who use `BoostingStrategy` to set GBT parameters. +* *(Breaking change)* The return value of [`LDA.run`](api/scala/index.html#org.apache.spark.mllib.clustering.LDA) has changed. 
It now returns an abstract class `LDAModel` instead of the concrete class `DistributedLDAModel`. The object of type `LDAModel` can still be cast to the appropriate concrete type, which depends on the optimization algorithm. + +In the `spark.ml` package, several major API changes occurred, including: + +* `Param` and other APIs for specifying parameters +* `uid` unique IDs for Pipeline components +* Reorganization of certain classes + +Since the `spark.ml` API was an alpha component in Spark 1.3, we do not list all changes here. +However, since 1.4 `spark.ml` is no longer an alpha component, we will provide details on any API +changes for future releases. + ## From 1.2 to 1.3 In the `spark.mllib` package, there were several breaking changes. The first change (in `ALS`) is the only one in a component not marked as Alpha or Experimental. @@ -23,6 +42,17 @@ In the `spark.mllib` package, there were several breaking changes. The first ch * In linear regression (including Lasso and ridge regression), the squared loss is now divided by 2. So in order to produce the same result as in 1.2, the regularization parameter needs to be divided by 2 and the step size needs to be multiplied by 2. +In the `spark.ml` package, the main API changes are from Spark SQL. We list the most important changes here: + +* The old [SchemaRDD](http://spark.apache.org/docs/1.2.1/api/scala/index.html#org.apache.spark.sql.SchemaRDD) has been replaced with [DataFrame](api/scala/index.html#org.apache.spark.sql.DataFrame) with a somewhat modified API. All algorithms in Spark ML which used to use SchemaRDD now use DataFrame. +* In Spark 1.2, we used implicit conversions from `RDD`s of `LabeledPoint` into `SchemaRDD`s by calling `import sqlContext._` where `sqlContext` was an instance of `SQLContext`. These implicits have been moved, so we now call `import sqlContext.implicits._`. +* Java APIs for SQL have also changed accordingly. Please see the examples above and the [Spark SQL Programming Guide](sql-programming-guide.html) for details. + +Other changes were in `LogisticRegression`: + +* The `scoreCol` output column (with default value "score") was renamed to be `probabilityCol` (with default value "probability"). The type was originally `Double` (for the probability of class 1.0), but it is now `Vector` (for the probability of each class, to support multiclass classification in the future). +* In Spark 1.2, `LogisticRegressionModel` did not include an intercept. In Spark 1.3, it includes an intercept; however, it will always be 0.0 since it uses the default settings for [spark.mllib.LogisticRegressionWithLBFGS](api/scala/index.html#org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS). The option to use an intercept will be added in the future. + ## From 1.1 to 1.2 The only API changes in MLlib v1.2 are in From bb7f35239385ec74b5ee69631b5480fbcee253e4 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 28 Aug 2015 14:38:20 -0700 Subject: [PATCH 0472/1824] [SPARK-10323] [SQL] fix nullability of In/InSet/ArrayContain After this PR, In/InSet/ArrayContain will return null if value is null, instead of false. They also will return null even if there is a null in the set/array. Author: Davies Liu Closes #8492 from davies/fix_in. 
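For reference, a minimal sketch of the new semantics (not part of this patch; it simply mirrors the cases added to PredicateSuite below and assumes the internal Catalyst expression classes are on the classpath):

import org.apache.spark.sql.catalyst.expressions.{In, Literal}
import org.apache.spark.sql.types.IntegerType

// Internal Catalyst API, used here only to illustrate the new null handling.
val nullInt = Literal.create(null, IntegerType)
// A matching element wins even if the list also contains a null:
In(Literal(1), Seq(Literal(1), nullInt)).eval(null)   // true
// No match and the list contains a null, so the result is unknown:
In(Literal(2), Seq(Literal(1), nullInt)).eval(null)   // null
// A null test value always yields null:
In(nullInt, Seq(Literal(1), Literal(2))).eval(null)   // null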
--- .../expressions/collectionOperations.scala | 62 ++++++++-------- .../sql/catalyst/expressions/predicates.scala | 71 +++++++++++++++---- .../sql/catalyst/optimizer/Optimizer.scala | 6 -- .../CollectionFunctionsSuite.scala | 12 +++- .../catalyst/expressions/PredicateSuite.scala | 49 +++++++++---- .../optimizer/ConstantFoldingSuite.scala | 21 +----- .../spark/sql/DataFrameFunctionsSuite.scala | 14 ++-- 7 files changed, 138 insertions(+), 97 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 646afa4047d8..7b8c5b723ded 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -19,9 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import java.util.Comparator import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.expressions.codegen.{ - CodegenFallback, CodeGenContext, GeneratedExpressionCode} -import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, CodegenFallback, GeneratedExpressionCode} import org.apache.spark.sql.types._ /** @@ -145,46 +143,42 @@ case class ArrayContains(left: Expression, right: Expression) } } - override def nullable: Boolean = false + override def nullable: Boolean = { + left.nullable || right.nullable || left.dataType.asInstanceOf[ArrayType].containsNull + } - override def eval(input: InternalRow): Boolean = { - val arr = left.eval(input) - if (arr == null) { - false - } else { - val value = right.eval(input) - if (value == null) { - false - } else { - arr.asInstanceOf[ArrayData].foreach(right.dataType, (i, v) => - if (v == value) return true - ) - false + override def nullSafeEval(arr: Any, value: Any): Any = { + var hasNull = false + arr.asInstanceOf[ArrayData].foreach(right.dataType, (i, v) => + if (v == null) { + hasNull = true + } else if (v == value) { + return true } + ) + if (hasNull) { + null + } else { + false } } override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val arrGen = left.gen(ctx) - val elementGen = right.gen(ctx) - val i = ctx.freshName("i") - val getValue = ctx.getValue(arrGen.primitive, right.dataType, i) - s""" - ${arrGen.code} - boolean ${ev.isNull} = false; - boolean ${ev.primitive} = false; - if (!${arrGen.isNull}) { - ${elementGen.code} - if (!${elementGen.isNull}) { - for (int $i = 0; $i < ${arrGen.primitive}.numElements(); $i ++) { - if (${ctx.genEqual(right.dataType, elementGen.primitive, getValue)}) { - ${ev.primitive} = true; - break; - } - } + nullSafeCodeGen(ctx, ev, (arr, value) => { + val i = ctx.freshName("i") + val getValue = ctx.getValue(arr, right.dataType, i) + s""" + for (int $i = 0; $i < $arr.numElements(); $i ++) { + if ($arr.isNullAt($i)) { + ${ev.isNull} = true; + } else if (${ctx.genEqual(right.dataType, value, getValue)}) { + ${ev.isNull} = false; + ${ev.primitive} = true; + break; } } """ + }) } override def prettyName: String = "array_contains" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index fe7dffb81598..65706dba7d97 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -17,11 +17,9 @@ package org.apache.spark.sql.catalyst.expressions -import scala.collection.mutable - -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenFallback, GeneratedExpressionCode, CodeGenContext} import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ @@ -103,6 +101,8 @@ case class Not(child: Expression) case class In(value: Expression, list: Seq[Expression]) extends Predicate with ImplicitCastInputTypes { + require(list != null, "list should not be null") + override def inputTypes: Seq[AbstractDataType] = value.dataType +: list.map(_.dataType) override def checkInputDataTypes(): TypeCheckResult = { @@ -116,12 +116,31 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate override def children: Seq[Expression] = value +: list - override def nullable: Boolean = false // TODO: Figure out correct nullability semantics of IN. + override def nullable: Boolean = children.exists(_.nullable) + override def foldable: Boolean = children.forall(_.foldable) + override def toString: String = s"$value IN ${list.mkString("(", ",", ")")}" override def eval(input: InternalRow): Any = { val evaluatedValue = value.eval(input) - list.exists(e => e.eval(input) == evaluatedValue) + if (evaluatedValue == null) { + null + } else { + var hasNull = false + list.foreach { e => + val v = e.eval(input) + if (v == evaluatedValue) { + return true + } else if (v == null) { + hasNull = true + } + } + if (hasNull) { + null + } else { + false + } + } } override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { @@ -131,7 +150,10 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate s""" if (!${ev.primitive}) { ${x.code} - if (${ctx.genEqual(value.dataType, valueGen.primitive, x.primitive)}) { + if (${x.isNull}) { + ${ev.isNull} = true; + } else if (${ctx.genEqual(value.dataType, valueGen.primitive, x.primitive)}) { + ${ev.isNull} = false; ${ev.primitive} = true; } } @@ -139,8 +161,10 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate s""" ${valueGen.code} boolean ${ev.primitive} = false; - boolean ${ev.isNull} = false; - $listCode + boolean ${ev.isNull} = ${valueGen.isNull}; + if (!${ev.isNull}) { + $listCode + } """ } } @@ -151,11 +175,22 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate */ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with Predicate { - override def nullable: Boolean = false // TODO: Figure out correct nullability semantics of IN. 
+ require(hset != null, "hset could not be null") + override def toString: String = s"$child INSET ${hset.mkString("(", ",", ")")}" - override def eval(input: InternalRow): Any = { - hset.contains(child.eval(input)) + @transient private[this] lazy val hasNull: Boolean = hset.contains(null) + + override def nullable: Boolean = child.nullable || hasNull + + protected override def nullSafeEval(value: Any): Any = { + if (hset.contains(value)) { + true + } else if (hasNull) { + null + } else { + false + } } def getHSet(): Set[Any] = hset @@ -166,12 +201,20 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with val childGen = child.gen(ctx) ctx.references += this val hsetTerm = ctx.freshName("hset") + val hasNullTerm = ctx.freshName("hasNull") ctx.addMutableState(setName, hsetTerm, s"$hsetTerm = (($InSetName)expressions[${ctx.references.size - 1}]).getHSet();") + ctx.addMutableState("boolean", hasNullTerm, s"$hasNullTerm = $hsetTerm.contains(null);") s""" ${childGen.code} - boolean ${ev.isNull} = false; - boolean ${ev.primitive} = $hsetTerm.contains(${childGen.primitive}); + boolean ${ev.isNull} = ${childGen.isNull}; + boolean ${ev.primitive} = false; + if (!${ev.isNull}) { + ${ev.primitive} = $hsetTerm.contains(${childGen.primitive}); + if (!${ev.primitive} && $hasNullTerm) { + ${ev.isNull} = true; + } + } """ } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 854463dd11c7..a430000bef65 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -395,12 +395,6 @@ object ConstantFolding extends Rule[LogicalPlan] { // Fold expressions that are foldable. case e if e.foldable => Literal.create(e.eval(EmptyRow), e.dataType) - - // Fold "literal in (item1, item2, ..., literal, ...)" into true directly. 
- case In(Literal(v, _), list) if list.exists { - case Literal(candidate, _) if candidate == v => true - case _ => false - } => Literal.create(true, BooleanType) } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala index 95f0e38212a1..a3e81888dfd0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala @@ -70,14 +70,20 @@ class CollectionFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { val a0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType)) val a1 = Literal.create(Seq[String](null, ""), ArrayType(StringType)) val a2 = Literal.create(Seq(null), ArrayType(LongType)) + val a3 = Literal.create(null, ArrayType(StringType)) checkEvaluation(ArrayContains(a0, Literal(1)), true) checkEvaluation(ArrayContains(a0, Literal(0)), false) - checkEvaluation(ArrayContains(a0, Literal(null)), false) + checkEvaluation(ArrayContains(a0, Literal.create(null, IntegerType)), null) checkEvaluation(ArrayContains(a1, Literal("")), true) - checkEvaluation(ArrayContains(a1, Literal(null)), false) + checkEvaluation(ArrayContains(a1, Literal("a")), null) + checkEvaluation(ArrayContains(a1, Literal.create(null, StringType)), null) - checkEvaluation(ArrayContains(a2, Literal(null)), false) + checkEvaluation(ArrayContains(a2, Literal(1L)), null) + checkEvaluation(ArrayContains(a2, Literal.create(null, LongType)), null) + + checkEvaluation(ArrayContains(a3, Literal("")), null) + checkEvaluation(ArrayContains(a3, Literal.create(null, StringType)), null) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala index 54c04faddb47..03e7611fce8f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala @@ -20,7 +20,6 @@ package org.apache.spark.sql.catalyst.expressions import scala.collection.immutable.HashSet import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.RandomDataGenerator import org.apache.spark.sql.types._ @@ -119,6 +118,12 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { (null, null, null) :: Nil) test("IN") { + checkEvaluation(In(Literal.create(null, IntegerType), Seq(Literal(1), Literal(2))), null) + checkEvaluation(In(Literal.create(null, IntegerType), Seq(Literal.create(null, IntegerType))), + null) + checkEvaluation(In(Literal(1), Seq(Literal.create(null, IntegerType))), null) + checkEvaluation(In(Literal(1), Seq(Literal(1), Literal.create(null, IntegerType))), true) + checkEvaluation(In(Literal(2), Seq(Literal(1), Literal.create(null, IntegerType))), null) checkEvaluation(In(Literal(1), Seq(Literal(1), Literal(2))), true) checkEvaluation(In(Literal(2), Seq(Literal(1), Literal(2))), true) checkEvaluation(In(Literal(3), Seq(Literal(1), Literal(2))), false) @@ -126,14 +131,18 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { And(In(Literal(1), Seq(Literal(1), Literal(2))), In(Literal(2), Seq(Literal(1), Literal(2)))), true) - checkEvaluation(In(Literal("^Ba*n"), 
Seq(Literal("^Ba*n"))), true) + val ns = Literal.create(null, StringType) + checkEvaluation(In(ns, Seq(Literal("1"), Literal("2"))), null) + checkEvaluation(In(ns, Seq(ns)), null) + checkEvaluation(In(Literal("a"), Seq(ns)), null) + checkEvaluation(In(Literal("^Ba*n"), Seq(Literal("^Ba*n"), ns)), true) checkEvaluation(In(Literal("^Ba*n"), Seq(Literal("aa"), Literal("^Ba*n"))), true) checkEvaluation(In(Literal("^Ba*n"), Seq(Literal("aa"), Literal("^n"))), false) val primitiveTypes = Seq(IntegerType, FloatType, DoubleType, StringType, ByteType, ShortType, LongType, BinaryType, BooleanType, DecimalType.USER_DEFAULT, TimestampType) primitiveTypes.map { t => - val dataGen = RandomDataGenerator.forType(t, nullable = false).get + val dataGen = RandomDataGenerator.forType(t, nullable = true).get val inputData = Seq.fill(10) { val value = dataGen.apply() value match { @@ -142,9 +151,17 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { case _ => value } } - val input = inputData.map(Literal(_)) - checkEvaluation(In(input(0), input.slice(1, 10)), - inputData.slice(1, 10).contains(inputData(0))) + val input = inputData.map(Literal.create(_, t)) + val expected = if (inputData(0) == null) { + null + } else if (inputData.slice(1, 10).contains(inputData(0))) { + true + } else if (inputData.slice(1, 10).contains(null)) { + null + } else { + false + } + checkEvaluation(In(input(0), input.slice(1, 10)), expected) } } @@ -158,15 +175,15 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(InSet(one, hS), true) checkEvaluation(InSet(two, hS), true) checkEvaluation(InSet(two, nS), true) - checkEvaluation(InSet(nl, nS), true) checkEvaluation(InSet(three, hS), false) - checkEvaluation(InSet(three, nS), false) - checkEvaluation(And(InSet(one, hS), InSet(two, hS)), true) + checkEvaluation(InSet(three, nS), null) + checkEvaluation(InSet(nl, hS), null) + checkEvaluation(InSet(nl, nS), null) val primitiveTypes = Seq(IntegerType, FloatType, DoubleType, StringType, ByteType, ShortType, LongType, BinaryType, BooleanType, DecimalType.USER_DEFAULT, TimestampType) primitiveTypes.map { t => - val dataGen = RandomDataGenerator.forType(t, nullable = false).get + val dataGen = RandomDataGenerator.forType(t, nullable = true).get val inputData = Seq.fill(10) { val value = dataGen.apply() value match { @@ -176,8 +193,16 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { } } val input = inputData.map(Literal(_)) - checkEvaluation(InSet(input(0), inputData.slice(1, 10).toSet), - inputData.slice(1, 10).contains(inputData(0))) + val expected = if (inputData(0) == null) { + null + } else if (inputData.slice(1, 10).contains(inputData(0))) { + true + } else if (inputData.slice(1, 10).contains(null)) { + null + } else { + false + } + checkEvaluation(InSet(input(0), inputData.slice(1, 10).toSet), expected) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala index ec3b2f1edfa0..e67606288f51 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala @@ -250,29 +250,14 @@ class ConstantFoldingSuite extends PlanTest { } test("Constant folding test: Fold In(v, list) into true or false") { - var originalQuery = + val originalQuery = testRelation .select('a) 
.where(In(Literal(1), Seq(Literal(1), Literal(2)))) - var optimized = Optimize.execute(originalQuery.analyze) - - var correctAnswer = - testRelation - .select('a) - .where(Literal(true)) - .analyze - - comparePlans(optimized, correctAnswer) - - originalQuery = - testRelation - .select('a) - .where(In(Literal(1), Seq(Literal(1), 'a.attr))) - - optimized = Optimize.execute(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) - correctAnswer = + val correctAnswer = testRelation .select('a) .where(Literal(true)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 9d965258e389..3a3f19af1473 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -366,10 +366,6 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { df.selectExpr("array_contains(a, 1)"), Seq(Row(true), Row(false)) ) - checkAnswer( - df.select(array_contains(array(lit(2), lit(null)), 1)), - Seq(Row(false), Row(false)) - ) // In hive, this errors because null has no type information intercept[AnalysisException] { @@ -382,15 +378,13 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { df.selectExpr("array_contains(null, 1)") } - // In hive, if either argument has a matching type has a null value, return false, even if - // the first argument array contains a null and the second argument is null checkAnswer( - df.selectExpr("array_contains(array(array(1), null)[1], 1)"), - Seq(Row(false), Row(false)) + df.selectExpr("array_contains(array(array(1), null)[0], 1)"), + Seq(Row(true), Row(true)) ) checkAnswer( - df.selectExpr("array_contains(array(0, null), array(1, null)[1])"), - Seq(Row(false), Row(false)) + df.selectExpr("array_contains(array(1, null), array(1, null)[0])"), + Seq(Row(true), Row(true)) ) } } From 2a4e00ca4d4e7a148b4ff8ce0ad1c6d517cee55f Mon Sep 17 00:00:00 2001 From: felixcheung Date: Fri, 28 Aug 2015 18:35:01 -0700 Subject: [PATCH 0473/1824] [SPARK-9803] [SPARKR] Add subset and transform + tests Add subset and transform Also reorganize `[` & `[[` to subset instead of select Note: for transform, transform is very similar to mutate. Spark doesn't seem to replace existing column with the name in mutate (ie. `mutate(df, age = df$age + 2)` - returned DataFrame has 2 columns with the same name 'age'), so therefore not doing that for now in transform. Though it is clearly stated it should replace column with matching name (should I open a JIRA for mutate/transform?) Author: felixcheung Closes #8503 from felixcheung/rsubset_transform. 
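For orientation only, a rough Scala analogue of what the new R `[` / `subset` methods delegate to, namely a row filter followed by a column selection (hypothetical sketch, not part of this patch; assumes a `sqlContext` with implicits in scope, and the sample names/ages are made up):

import sqlContext.implicits._

val df = sqlContext.sparkContext.parallelize(Seq(("Andy", 30), ("Justin", 19)))
  .toDF("name", "age")

// Roughly the Scala-side counterpart of df[df$age %in% c(19, 30), c("name", "age")]:
val subsetDF = df.filter(df("age").isin(19, 30)).select("name", "age")
subsetDF.show()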
--- R/pkg/NAMESPACE | 2 + R/pkg/R/DataFrame.R | 70 +++++++++++++++++++++++++------- R/pkg/R/generics.R | 10 ++++- R/pkg/inst/tests/test_sparkSQL.R | 20 ++++++++- 4 files changed, 85 insertions(+), 17 deletions(-) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 5286c0198620..9d3963070643 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -69,9 +69,11 @@ exportMethods("arrange", "selectExpr", "show", "showDF", + "subset", "summarize", "summary", "take", + "transform", "unionAll", "unique", "unpersist", diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 74de7c81e35a..8a00238b41d6 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -987,7 +987,7 @@ setMethod("$<-", signature(x = "DataFrame"), setClassUnion("numericOrcharacter", c("numeric", "character")) -#' @rdname select +#' @rdname subset #' @name [[ setMethod("[[", signature(x = "DataFrame", i = "numericOrcharacter"), function(x, i) { @@ -998,7 +998,7 @@ setMethod("[[", signature(x = "DataFrame", i = "numericOrcharacter"), getColumn(x, i) }) -#' @rdname select +#' @rdname subset #' @name [ setMethod("[", signature(x = "DataFrame", i = "missing"), function(x, i, j, ...) { @@ -1012,7 +1012,7 @@ setMethod("[", signature(x = "DataFrame", i = "missing"), select(x, j) }) -#' @rdname select +#' @rdname subset #' @name [ setMethod("[", signature(x = "DataFrame", i = "Column"), function(x, i, j, ...) { @@ -1020,12 +1020,43 @@ setMethod("[", signature(x = "DataFrame", i = "Column"), # https://stat.ethz.ch/R-manual/R-devel/library/base/html/Extract.data.frame.html filtered <- filter(x, i) if (!missing(j)) { - filtered[, j] + filtered[, j, ...] } else { filtered } }) +#' Subset +#' +#' Return subsets of DataFrame according to given conditions +#' @param x A DataFrame +#' @param subset A logical expression to filter on rows +#' @param select expression for the single Column or a list of columns to select from the DataFrame +#' @return A new DataFrame containing only the rows that meet the condition with selected columns +#' @export +#' @rdname subset +#' @name subset +#' @aliases [ +#' @family subsetting functions +#' @examples +#' \dontrun{ +#' # Columns can be selected using `[[` and `[` +#' df[[2]] == df[["age"]] +#' df[,2] == df[,"age"] +#' df[,c("name", "age")] +#' # Or to filter rows +#' df[df$age > 20,] +#' # DataFrame can be subset on both rows and Columns +#' df[df$name == "Smith", c(1,2)] +#' df[df$age %in% c(19, 30), 1:2] +#' subset(df, df$age %in% c(19, 30), 1:2) +#' subset(df, df$age %in% c(19), select = c(1,2)) +#' } +setMethod("subset", signature(x = "DataFrame"), + function(x, subset, select, ...) { + x[subset, select, ...] + }) + #' Select #' #' Selects a set of columns with names or Column expressions. 
@@ -1034,6 +1065,8 @@ setMethod("[", signature(x = "DataFrame", i = "Column"), #' @return A new DataFrame with selected columns #' @export #' @rdname select +#' @name select +#' @family subsetting functions #' @examples #' \dontrun{ #' select(df, "*") @@ -1041,15 +1074,8 @@ setMethod("[", signature(x = "DataFrame", i = "Column"), #' select(df, df$name, df$age + 1) #' select(df, c("col1", "col2")) #' select(df, list(df$name, df$age + 1)) -#' # Columns can also be selected using `[[` and `[` -#' df[[2]] == df[["age"]] -#' df[,2] == df[,"age"] -#' df[,c("name", "age")] #' # Similar to R data frames columns can also be selected using `$` #' df$age -#' # It can also be subset on rows and Columns -#' df[df$name == "Smith", c(1,2)] -#' df[df$age %in% c(19, 30), 1:2] #' } setMethod("select", signature(x = "DataFrame", col = "character"), function(x, col, ...) { @@ -1121,7 +1147,7 @@ setMethod("selectExpr", #' @return A DataFrame with the new column added. #' @rdname withColumn #' @name withColumn -#' @aliases mutate +#' @aliases mutate transform #' @export #' @examples #'\dontrun{ @@ -1141,11 +1167,12 @@ setMethod("withColumn", #' #' Return a new DataFrame with the specified columns added. #' -#' @param x A DataFrame +#' @param .data A DataFrame #' @param col a named argument of the form name = col #' @return A new DataFrame with the new columns added. #' @rdname withColumn #' @name mutate +#' @aliases withColumn transform #' @export #' @examples #'\dontrun{ @@ -1155,10 +1182,12 @@ setMethod("withColumn", #' df <- jsonFile(sqlContext, path) #' newDF <- mutate(df, newCol = df$col1 * 5, newCol2 = df$col1 * 2) #' names(newDF) # Will contain newCol, newCol2 +#' newDF2 <- transform(df, newCol = df$col1 / 5, newCol2 = df$col1 * 2) #' } setMethod("mutate", - signature(x = "DataFrame"), - function(x, ...) { + signature(.data = "DataFrame"), + function(.data, ...) { + x <- .data cols <- list(...) stopifnot(length(cols) > 0) stopifnot(class(cols[[1]]) == "Column") @@ -1173,6 +1202,16 @@ setMethod("mutate", do.call(select, c(x, x$"*", cols)) }) +#' @export +#' @rdname withColumn +#' @name transform +#' @aliases withColumn mutate +setMethod("transform", + signature(`_data` = "DataFrame"), + function(`_data`, ...) { + mutate(`_data`, ...) + }) + #' WithColumnRenamed #' #' Rename an existing column in a DataFrame. @@ -1300,6 +1339,7 @@ setMethod("orderBy", #' @return A DataFrame containing only the rows that meet the condition. #' @rdname filter #' @name filter +#' @family subsetting functions #' @export #' @examples #'\dontrun{ diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index b578b8789d2c..43dd8d283ab6 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -467,7 +467,7 @@ setGeneric("merge") #' @rdname withColumn #' @export -setGeneric("mutate", function(x, ...) {standardGeneric("mutate") }) +setGeneric("mutate", function(.data, ...) {standardGeneric("mutate") }) #' @rdname arrange #' @export @@ -507,6 +507,10 @@ setGeneric("saveAsTable", function(df, tableName, source, mode, ...) { standardGeneric("saveAsTable") }) +#' @rdname withColumn +#' @export +setGeneric("transform", function(`_data`, ...) {standardGeneric("transform") }) + #' @rdname write.df #' @export setGeneric("write.df", function(df, path, ...) { standardGeneric("write.df") }) @@ -531,6 +535,10 @@ setGeneric("selectExpr", function(x, expr, ...) { standardGeneric("selectExpr") #' @export setGeneric("showDF", function(x,...) 
{ standardGeneric("showDF") }) +# @rdname subset +# @export +setGeneric("subset", function(x, subset, select, ...) { standardGeneric("subset") }) + #' @rdname agg #' @export setGeneric("summarize", function(x,...) { standardGeneric("summarize") }) diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index 933b11c8ee7e..0da5e3865473 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -612,6 +612,10 @@ test_that("subsetting", { df5 <- df[df$age %in% c(19), c(1,2)] expect_equal(count(df5), 1) expect_equal(columns(df5), c("name", "age")) + + df6 <- subset(df, df$age %in% c(30), c(1,2)) + expect_equal(count(df6), 1) + expect_equal(columns(df6), c("name", "age")) }) test_that("selectExpr() on a DataFrame", { @@ -1028,7 +1032,7 @@ test_that("withColumn() and withColumnRenamed()", { expect_equal(columns(newDF2)[1], "newerAge") }) -test_that("mutate(), rename() and names()", { +test_that("mutate(), transform(), rename() and names()", { df <- jsonFile(sqlContext, jsonPath) newDF <- mutate(df, newAge = df$age + 2) expect_equal(length(columns(newDF)), 3) @@ -1042,6 +1046,20 @@ test_that("mutate(), rename() and names()", { names(newDF2) <- c("newerName", "evenNewerAge") expect_equal(length(names(newDF2)), 2) expect_equal(names(newDF2)[1], "newerName") + + transformedDF <- transform(df, newAge = -df$age, newAge2 = df$age / 2) + expect_equal(length(columns(transformedDF)), 4) + expect_equal(columns(transformedDF)[3], "newAge") + expect_equal(columns(transformedDF)[4], "newAge2") + expect_equal(first(filter(transformedDF, transformedDF$name == "Andy"))$newAge, -30) + + # test if transform on local data frames works + # ensure the proper signature is used - otherwise this will fail to run + attach(airquality) + result <- transform(Ozone, logOzone = log(Ozone)) + expect_equal(nrow(result), 153) + expect_equal(ncol(result), 2) + detach(airquality) }) test_that("write.df() on DataFrame and works with parquetFile", { From e8ea5bafee9ca734edf62021145d0c2d5491cba8 Mon Sep 17 00:00:00 2001 From: martinzapletal Date: Fri, 28 Aug 2015 21:03:48 -0700 Subject: [PATCH 0474/1824] [SPARK-9910] [ML] User guide for train validation split Author: martinzapletal Closes #8377 from zapletal-martin/SPARK-9910. --- docs/ml-guide.md | 117 ++++++++++++++++++ .../ml/JavaTrainValidationSplitExample.java | 90 ++++++++++++++ .../ml/TrainValidationSplitExample.scala | 80 ++++++++++++ 3 files changed, 287 insertions(+) create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaTrainValidationSplitExample.java create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/TrainValidationSplitExample.scala diff --git a/docs/ml-guide.md b/docs/ml-guide.md index ce53400b6ee5..a92a285f3af8 100644 --- a/docs/ml-guide.md +++ b/docs/ml-guide.md @@ -872,3 +872,120 @@ jsc.stop(); + +## Example: Model Selection via Train Validation Split +In addition to `CrossValidator` Spark also offers `TrainValidationSplit` for hyper-parameter tuning. +`TrainValidationSplit` only evaluates each combination of parameters once as opposed to k times in + case of `CrossValidator`. It is therefore less expensive, + but will not produce as reliable results when the training dataset is not sufficiently large.. + +`TrainValidationSplit` takes an `Estimator`, a set of `ParamMap`s provided in the `estimatorParamMaps` parameter, +and an `Evaluator`. +It begins by splitting the dataset into two parts using `trainRatio` parameter +which are used as separate training and test datasets. 
For example with `$trainRatio=0.75$` (default), +`TrainValidationSplit` will generate a training and test dataset pair where 75% of the data is used for training and 25% for validation. +Similar to `CrossValidator`, `TrainValidationSplit` also iterates through the set of `ParamMap`s. +For each combination of parameters, it trains the given `Estimator` and evaluates it using the given `Evaluator`. +The `ParamMap` which produces the best evaluation metric is selected as the best option. +`TrainValidationSplit` finally fits the `Estimator` using the best `ParamMap` and the entire dataset. + +
+
+<div data-lang="scala" markdown="1">
    +{% highlight scala %} +import org.apache.spark.ml.evaluation.RegressionEvaluator +import org.apache.spark.ml.regression.LinearRegression +import org.apache.spark.ml.tuning.{ParamGridBuilder, TrainValidationSplit} +import org.apache.spark.mllib.util.MLUtils + +// Prepare training and test data. +val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() +val Array(training, test) = data.randomSplit(Array(0.9, 0.1), seed = 12345) + +val lr = new LinearRegression() + +// We use a ParamGridBuilder to construct a grid of parameters to search over. +// TrainValidationSplit will try all combinations of values and determine best model using +// the evaluator. +val paramGrid = new ParamGridBuilder() + .addGrid(lr.regParam, Array(0.1, 0.01)) + .addGrid(lr.fitIntercept, Array(true, false)) + .addGrid(lr.elasticNetParam, Array(0.0, 0.5, 1.0)) + .build() + +// In this case the estimator is simply the linear regression. +// A TrainValidationSplit requires an Estimator, a set of Estimator ParamMaps, and an Evaluator. +val trainValidationSplit = new TrainValidationSplit() + .setEstimator(lr) + .setEvaluator(new RegressionEvaluator) + .setEstimatorParamMaps(paramGrid) + +// 80% of the data will be used for training and the remaining 20% for validation. +trainValidationSplit.setTrainRatio(0.8) + +// Run train validation split, and choose the best set of parameters. +val model = trainValidationSplit.fit(training) + +// Make predictions on test data. model is the model with combination of parameters +// that performed best. +model.transform(test) + .select("features", "label", "prediction") + .show() + +{% endhighlight %} +
+
+<div data-lang="java" markdown="1">
    +{% highlight java %} +import org.apache.spark.ml.evaluation.RegressionEvaluator; +import org.apache.spark.ml.param.ParamMap; +import org.apache.spark.ml.regression.LinearRegression; +import org.apache.spark.ml.tuning.*; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.util.MLUtils; +import org.apache.spark.rdd.RDD; +import org.apache.spark.sql.DataFrame; + +DataFrame data = jsql.createDataFrame( + MLUtils.loadLibSVMFile(jsc.sc(), "data/mllib/sample_libsvm_data.txt"), + LabeledPoint.class); + +// Prepare training and test data. +DataFrame[] splits = data.randomSplit(new double [] {0.9, 0.1}, 12345); +DataFrame training = splits[0]; +DataFrame test = splits[1]; + +LinearRegression lr = new LinearRegression(); + +// We use a ParamGridBuilder to construct a grid of parameters to search over. +// TrainValidationSplit will try all combinations of values and determine best model using +// the evaluator. +ParamMap[] paramGrid = new ParamGridBuilder() + .addGrid(lr.regParam(), new double[] {0.1, 0.01}) + .addGrid(lr.fitIntercept()) + .addGrid(lr.elasticNetParam(), new double[] {0.0, 0.5, 1.0}) + .build(); + +// In this case the estimator is simply the linear regression. +// A TrainValidationSplit requires an Estimator, a set of Estimator ParamMaps, and an Evaluator. +TrainValidationSplit trainValidationSplit = new TrainValidationSplit() + .setEstimator(lr) + .setEvaluator(new RegressionEvaluator()) + .setEstimatorParamMaps(paramGrid); + +// 80% of the data will be used for training and the remaining 20% for validation. +trainValidationSplit.setTrainRatio(0.8); + +// Run train validation split, and choose the best set of parameters. +TrainValidationSplitModel model = trainValidationSplit.fit(training); + +// Make predictions on test data. model is the model with combination of parameters +// that performed best. +model.transform(test) + .select("features", "label", "prediction") + .show(); + +{% endhighlight %} +
+
+</div>
    diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaTrainValidationSplitExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaTrainValidationSplitExample.java new file mode 100644 index 000000000000..23f834ab4332 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaTrainValidationSplitExample.java @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.ml.evaluation.RegressionEvaluator; +import org.apache.spark.ml.param.ParamMap; +import org.apache.spark.ml.regression.LinearRegression; +import org.apache.spark.ml.tuning.*; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.util.MLUtils; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.SQLContext; + +/** + * A simple example demonstrating model selection using TrainValidationSplit. + * + * The example is based on {@link org.apache.spark.examples.ml.JavaSimpleParamsExample} + * using linear regression. + * + * Run with + * {{{ + * bin/run-example ml.JavaTrainValidationSplitExample + * }}} + */ +public class JavaTrainValidationSplitExample { + + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaTrainValidationSplitExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext jsql = new SQLContext(jsc); + + DataFrame data = jsql.createDataFrame( + MLUtils.loadLibSVMFile(jsc.sc(), "data/mllib/sample_libsvm_data.txt"), + LabeledPoint.class); + + // Prepare training and test data. + DataFrame[] splits = data.randomSplit(new double [] {0.9, 0.1}, 12345); + DataFrame training = splits[0]; + DataFrame test = splits[1]; + + LinearRegression lr = new LinearRegression(); + + // We use a ParamGridBuilder to construct a grid of parameters to search over. + // TrainValidationSplit will try all combinations of values and determine best model using + // the evaluator. + ParamMap[] paramGrid = new ParamGridBuilder() + .addGrid(lr.regParam(), new double[] {0.1, 0.01}) + .addGrid(lr.fitIntercept()) + .addGrid(lr.elasticNetParam(), new double[] {0.0, 0.5, 1.0}) + .build(); + + // In this case the estimator is simply the linear regression. + // A TrainValidationSplit requires an Estimator, a set of Estimator ParamMaps, and an Evaluator. + TrainValidationSplit trainValidationSplit = new TrainValidationSplit() + .setEstimator(lr) + .setEvaluator(new RegressionEvaluator()) + .setEstimatorParamMaps(paramGrid); + + // 80% of the data will be used for training and the remaining 20% for validation. 
+ trainValidationSplit.setTrainRatio(0.8); + + // Run train validation split, and choose the best set of parameters. + TrainValidationSplitModel model = trainValidationSplit.fit(training); + + // Make predictions on test data. model is the model with combination of parameters + // that performed best. + model.transform(test) + .select("features", "label", "prediction") + .show(); + + jsc.stop(); + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/TrainValidationSplitExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/TrainValidationSplitExample.scala new file mode 100644 index 000000000000..1abdf219b1c0 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/TrainValidationSplitExample.scala @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml + +import org.apache.spark.ml.evaluation.RegressionEvaluator +import org.apache.spark.ml.regression.LinearRegression +import org.apache.spark.ml.tuning.{ParamGridBuilder, TrainValidationSplit} +import org.apache.spark.mllib.util.MLUtils +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +/** + * A simple example demonstrating model selection using TrainValidationSplit. + * + * The example is based on [[SimpleParamsExample]] using linear regression. + * Run with + * {{{ + * bin/run-example ml.TrainValidationSplitExample + * }}} + */ +object TrainValidationSplitExample { + + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("TrainValidationSplitExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + import sqlContext.implicits._ + + // Prepare training and test data. + val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() + val Array(training, test) = data.randomSplit(Array(0.9, 0.1), seed = 12345) + + val lr = new LinearRegression() + + // We use a ParamGridBuilder to construct a grid of parameters to search over. + // TrainValidationSplit will try all combinations of values and determine best model using + // the evaluator. + val paramGrid = new ParamGridBuilder() + .addGrid(lr.regParam, Array(0.1, 0.01)) + .addGrid(lr.fitIntercept, Array(true, false)) + .addGrid(lr.elasticNetParam, Array(0.0, 0.5, 1.0)) + .build() + + // In this case the estimator is simply the linear regression. + // A TrainValidationSplit requires an Estimator, a set of Estimator ParamMaps, and an Evaluator. + val trainValidationSplit = new TrainValidationSplit() + .setEstimator(lr) + .setEvaluator(new RegressionEvaluator) + .setEstimatorParamMaps(paramGrid) + + // 80% of the data will be used for training and the remaining 20% for validation. 
+ trainValidationSplit.setTrainRatio(0.8) + + // Run train validation split, and choose the best set of parameters. + val model = trainValidationSplit.fit(training) + + // Make predictions on test data. model is the model with combination of parameters + // that performed best. + model.transform(test) + .select("features", "label", "prediction") + .show() + + sc.stop() + } +} From 5369be806848f43cb87c76504258c4e7de930c90 Mon Sep 17 00:00:00 2001 From: GuoQiang Li Date: Sat, 29 Aug 2015 13:20:22 -0700 Subject: [PATCH 0475/1824] [SPARK-10350] [DOC] [SQL] Removed duplicated option description from SQL guide Author: GuoQiang Li Closes #8520 from witgo/SPARK-10350. --- docs/sql-programming-guide.md | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index e8eb88488ee2..6a1b0fbfa1eb 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1405,16 +1405,6 @@ Configuration of Parquet can be done using the `setConf` method on `SQLContext`

 <tr><th>Property Name</th><th>Default</th><th>Meaning</th></tr>
-<tr>
-  <td><code>spark.akka.failure-detector.threshold</code></td>
-  <td>300.0</td>
-  <td>
-    This is set to a larger value to disable failure detector that comes inbuilt akka. It can be
-    enabled again, if you plan to use this feature (Not recommended). This maps to akka's
-    `akka.remote.transport-failure-detector.threshold`. Tune this in combination of
-    `spark.akka.heartbeat.pauses` and `spark.akka.heartbeat.interval` if you need to.
-  </td>
-</tr>
 <tr>
   <td><code>spark.akka.frameSize</code></td>
   <td>128</td>
 </tr>
-<tr>
-  <td><code>spark.sql.parquet.mergeSchema</code></td>
-  <td><code>false</code></td>
-  <td>
-    When true, the Parquet data source merges schemas collected from all data files, otherwise the
-    schema is picked from the summary file or a random data file if no summary file is available.
-  </td>
-</tr>
    ## JSON Datasets From 24ffa85c002a095ffb270175ec838995d3ed5469 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Sat, 29 Aug 2015 13:24:32 -0700 Subject: [PATCH 0476/1824] [SPARK-10289] [SQL] A direct write API for testing Parquet This PR introduces a direct write API for testing Parquet. It's a DSL flavored version of the [`writeDirect` method] [1] comes with parquet-avro testing code. With this API, it's much easier to construct arbitrary Parquet structures. It's especially useful when adding regression tests for various compatibility corner cases. Sample usage of this API can be found in the new test case added in `ParquetThriftCompatibilitySuite`. [1]: https://github.com/apache/parquet-mr/blob/apache-parquet-1.8.1/parquet-avro/src/test/java/org/apache/parquet/avro/TestArrayCompatibility.java#L945-L972 Author: Cheng Lian Closes #8454 from liancheng/spark-10289/parquet-testing-direct-write-api. --- .../parquet/ParquetCompatibilityTest.scala | 84 +++++++++++++-- .../ParquetThriftCompatibilitySuite.scala | 100 +++++++++++++++--- 2 files changed, 160 insertions(+), 24 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala index df68432faeeb..91f3ce4d34c8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala @@ -17,11 +17,15 @@ package org.apache.spark.sql.execution.datasources.parquet -import scala.collection.JavaConverters._ +import scala.collection.JavaConverters.{collectionAsScalaIterableConverter, mapAsJavaMapConverter, seqAsJavaListConverter} +import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{Path, PathFilter} -import org.apache.parquet.hadoop.ParquetFileReader -import org.apache.parquet.schema.MessageType +import org.apache.parquet.hadoop.api.WriteSupport +import org.apache.parquet.hadoop.api.WriteSupport.WriteContext +import org.apache.parquet.hadoop.{ParquetFileReader, ParquetWriter} +import org.apache.parquet.io.api.RecordConsumer +import org.apache.parquet.schema.{MessageType, MessageTypeParser} import org.apache.spark.sql.QueryTest @@ -38,11 +42,10 @@ private[sql] abstract class ParquetCompatibilityTest extends QueryTest with Parq val fs = fsPath.getFileSystem(configuration) val parquetFiles = fs.listStatus(fsPath, new PathFilter { override def accept(path: Path): Boolean = pathFilter(path) - }).toSeq + }).toSeq.asJava - val footers = - ParquetFileReader.readAllFootersInParallel(configuration, parquetFiles.asJava, true) - footers.iterator().next().getParquetMetadata.getFileMetaData.getSchema + val footers = ParquetFileReader.readAllFootersInParallel(configuration, parquetFiles, true) + footers.asScala.head.getParquetMetadata.getFileMetaData.getSchema } protected def logParquetSchema(path: String): Unit = { @@ -53,8 +56,69 @@ private[sql] abstract class ParquetCompatibilityTest extends QueryTest with Parq } } -object ParquetCompatibilityTest { - def makeNullable[T <: AnyRef](i: Int)(f: => T): T = { - if (i % 3 == 0) null.asInstanceOf[T] else f +private[sql] object ParquetCompatibilityTest { + implicit class RecordConsumerDSL(consumer: RecordConsumer) { + def message(f: => Unit): Unit = { + consumer.startMessage() + f + consumer.endMessage() + } + + def group(f: => Unit): Unit = { + 
consumer.startGroup() + f + consumer.endGroup() + } + + def field(name: String, index: Int)(f: => Unit): Unit = { + consumer.startField(name, index) + f + consumer.endField(name, index) + } + } + + /** + * A testing Parquet [[WriteSupport]] implementation used to write manually constructed Parquet + * records with arbitrary structures. + */ + private class DirectWriteSupport(schema: MessageType, metadata: Map[String, String]) + extends WriteSupport[RecordConsumer => Unit] { + + private var recordConsumer: RecordConsumer = _ + + override def init(configuration: Configuration): WriteContext = { + new WriteContext(schema, metadata.asJava) + } + + override def write(recordWriter: RecordConsumer => Unit): Unit = { + recordWriter.apply(recordConsumer) + } + + override def prepareForWrite(recordConsumer: RecordConsumer): Unit = { + this.recordConsumer = recordConsumer + } + } + + /** + * Writes arbitrary messages conforming to a given `schema` to a Parquet file located by `path`. + * Records are produced by `recordWriters`. + */ + def writeDirect(path: String, schema: String, recordWriters: (RecordConsumer => Unit)*): Unit = { + writeDirect(path, schema, Map.empty[String, String], recordWriters: _*) + } + + /** + * Writes arbitrary messages conforming to a given `schema` to a Parquet file located by `path` + * with given user-defined key-value `metadata`. Records are produced by `recordWriters`. + */ + def writeDirect( + path: String, + schema: String, + metadata: Map[String, String], + recordWriters: (RecordConsumer => Unit)*): Unit = { + val messageType = MessageTypeParser.parseMessageType(schema) + val writeSupport = new DirectWriteSupport(messageType, metadata) + val parquetWriter = new ParquetWriter[RecordConsumer => Unit](new Path(path), writeSupport) + try recordWriters.foreach(parquetWriter.write) finally parquetWriter.close() } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetThriftCompatibilitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetThriftCompatibilitySuite.scala index b789c5a106e5..88a3d878f97f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetThriftCompatibilitySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetThriftCompatibilitySuite.scala @@ -33,11 +33,9 @@ class ParquetThriftCompatibilitySuite extends ParquetCompatibilityTest with Shar """.stripMargin) checkAnswer(sqlContext.read.parquet(parquetFilePath.toString), (0 until 10).map { i => - def nullable[T <: AnyRef]: ( => T) => T = makeNullable[T](i) - val suits = Array("SPADES", "HEARTS", "DIAMONDS", "CLUBS") - Row( + val nonNullablePrimitiveValues = Seq( i % 2 == 0, i.toByte, (i + 1).toShort, @@ -50,18 +48,15 @@ class ParquetThriftCompatibilitySuite extends ParquetCompatibilityTest with Shar s"val_$i", s"val_$i", // Thrift ENUM values are converted to Parquet binaries containing UTF-8 strings - suits(i % 4), - - nullable(i % 2 == 0: java.lang.Boolean), - nullable(i.toByte: java.lang.Byte), - nullable((i + 1).toShort: java.lang.Short), - nullable(i + 2: Integer), - nullable((i * 10).toLong: java.lang.Long), - nullable(i.toDouble + 0.2d: java.lang.Double), - nullable(s"val_$i"), - nullable(s"val_$i"), - nullable(suits(i % 4)), + suits(i % 4)) + + val nullablePrimitiveValues = if (i % 3 == 0) { + Seq.fill(nonNullablePrimitiveValues.length)(null) + } else { + nonNullablePrimitiveValues + } + val complexValues = Seq( Seq.tabulate(3)(n => 
s"arr_${i + n}"), // Thrift `SET`s are converted to Parquet `LIST`s Seq(i), @@ -71,6 +66,83 @@ class ParquetThriftCompatibilitySuite extends ParquetCompatibilityTest with Shar Row(Seq.tabulate(3)(j => i + j + m), s"val_${i + m}") } }.toMap) + + Row(nonNullablePrimitiveValues ++ nullablePrimitiveValues ++ complexValues: _*) }) } + + test("SPARK-10136 list of primitive list") { + withTempPath { dir => + val path = dir.getCanonicalPath + + // This Parquet schema is translated from the following Thrift schema: + // + // struct ListOfPrimitiveList { + // 1: list> f; + // } + val schema = + s"""message ListOfPrimitiveList { + | required group f (LIST) { + | repeated group f_tuple (LIST) { + | repeated int32 f_tuple_tuple; + | } + | } + |} + """.stripMargin + + writeDirect(path, schema, { rc => + rc.message { + rc.field("f", 0) { + rc.group { + rc.field("f_tuple", 0) { + rc.group { + rc.field("f_tuple_tuple", 0) { + rc.addInteger(0) + rc.addInteger(1) + } + } + + rc.group { + rc.field("f_tuple_tuple", 0) { + rc.addInteger(2) + rc.addInteger(3) + } + } + } + } + } + } + }, { rc => + rc.message { + rc.field("f", 0) { + rc.group { + rc.field("f_tuple", 0) { + rc.group { + rc.field("f_tuple_tuple", 0) { + rc.addInteger(4) + rc.addInteger(5) + } + } + + rc.group { + rc.field("f_tuple_tuple", 0) { + rc.addInteger(6) + rc.addInteger(7) + } + } + } + } + } + } + }) + + logParquetSchema(path) + + checkAnswer( + sqlContext.read.parquet(path), + Seq( + Row(Seq(Seq(0, 1), Seq(2, 3))), + Row(Seq(Seq(4, 5), Seq(6, 7))))) + } + } } From 5c3d16a9b91bb9a458d3ba141f7bef525cf3d285 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Sat, 29 Aug 2015 13:26:01 -0700 Subject: [PATCH 0477/1824] [SPARK-10344] [SQL] Add tests for extraStrategies Actually using this API requires access to a lot of classes that we might make private by accident. I've added some tests to prevent this. Author: Michael Armbrust Closes #8516 from marmbrus/extraStrategiesTests. --- .../spark/sql/ExtraStrategiesSuite.scala | 67 +++++++++++++++++++ .../spark/sql/test/SharedSQLContext.scala | 2 +- 2 files changed, 68 insertions(+), 1 deletion(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala new file mode 100644 index 000000000000..8d2f45d70308 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package test.org.apache.spark.sql + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Literal, GenericInternalRow, Attribute} +import org.apache.spark.sql.catalyst.plans.logical.{Project, LogicalPlan} +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.{Row, Strategy, QueryTest} +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.unsafe.types.UTF8String + +case class FastOperator(output: Seq[Attribute]) extends SparkPlan { + + override protected def doExecute(): RDD[InternalRow] = { + val str = Literal("so fast").value + val row = new GenericInternalRow(Array[Any](str)) + sparkContext.parallelize(Seq(row)) + } + + override def children: Seq[SparkPlan] = Nil +} + +object TestStrategy extends Strategy { + def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case Project(Seq(attr), _) if attr.name == "a" => + FastOperator(attr.toAttribute :: Nil) :: Nil + case _ => Nil + } +} + +class ExtraStrategiesSuite extends QueryTest with SharedSQLContext { + import testImplicits._ + + test("insert an extraStrategy") { + try { + sqlContext.experimental.extraStrategies = TestStrategy :: Nil + + val df = sqlContext.sparkContext.parallelize(Seq(("so slow", 1))).toDF("a", "b") + checkAnswer( + df.select("a"), + Row("so fast")) + + checkAnswer( + df.select("a", "b"), + Row("so slow", 1)) + } finally { + sqlContext.experimental.extraStrategies = Nil + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala index 8a061b6bc690..d23c6a073266 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.{ColumnName, SQLContext} /** * Helper trait for SQL test suites where all tests share a single [[TestSQLContext]]. */ -private[sql] trait SharedSQLContext extends SQLTestUtils { +trait SharedSQLContext extends SQLTestUtils { /** * The [[TestSQLContext]] to use for all tests in this suite. From 277148b285748e863f2b9fdf6cf12963977f91ca Mon Sep 17 00:00:00 2001 From: wangwei Date: Sat, 29 Aug 2015 13:29:50 -0700 Subject: [PATCH 0478/1824] [SPARK-10226] [SQL] Fix exclamation mark issue in SparkSQL When I tested the latest version of Spark with an exclamation mark, I got errors. I then reset the Spark version and found that commit id "a2409d1c8e8ddec04b529ac6f6a12b5993f0eeda" introduced the bug. Because that commit changed the jline version from 0.9.94 to 2.12, the exclamation mark is treated as a special character in ConsoleReader. Author: wangwei Closes #8420 from small-wang/jline-SPARK-10226.
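The one-line change in the diff below turns off jline's event expansion. As a rough standalone sketch of the behaviour being fixed (assuming jline 2.12 on the classpath; this is an illustration, not the actual Spark CLI code):

    import jline.console.ConsoleReader

    // jline 2.x performs csh-style history expansion by default, so a literal "!"
    // typed at the prompt gets rewritten before SparkSQL ever sees the statement.
    val reader = new ConsoleReader()
    reader.setBellEnabled(false)
    reader.setExpandEvents(false) // keep "!" literal so the SQL parser receives it unchanged

    val line = reader.readLine("spark-sql> ")
    println(s"read: $line")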
--- .../apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala index a29df567983b..b5073961a1c8 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala @@ -171,6 +171,7 @@ private[hive] object SparkSQLCLIDriver extends Logging { val reader = new ConsoleReader() reader.setBellEnabled(false) + reader.setExpandEvents(false) // reader.setDebug(new PrintWriter(new FileWriter("writer.debug", true))) CliDriver.getCommandCompleter.foreach((e) => reader.addCompleter(e)) From 6a6f3c91ee1f63dd464eb03d156d02c1a5887d88 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sat, 29 Aug 2015 13:36:25 -0700 Subject: [PATCH 0479/1824] [SPARK-10330] Use SparkHadoopUtil TaskAttemptContext reflection methods in more places SparkHadoopUtil contains methods that use reflection to work around TaskAttemptContext binary incompatibilities between Hadoop 1.x and 2.x. We should use these methods in more places. Author: Josh Rosen Closes #8499 from JoshRosen/use-hadoop-reflection-in-more-places. --- .../sql/execution/datasources/WriterContainer.scala | 10 +++++++--- .../sql/execution/datasources/json/JSONRelation.scala | 7 +++++-- .../datasources/parquet/ParquetRelation.scala | 7 +++++-- .../org/apache/spark/sql/hive/orc/OrcRelation.scala | 9 ++++++--- .../apache/spark/sql/sources/SimpleTextRelation.scala | 7 +++++-- 5 files changed, 28 insertions(+), 12 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala index 78f48a5cd72c..879fd6986321 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala @@ -25,6 +25,7 @@ import org.apache.hadoop.fs.Path import org.apache.hadoop.mapreduce._ import org.apache.hadoop.mapreduce.lib.output.{FileOutputCommitter => MapReduceFileOutputCommitter} import org.apache.spark._ +import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.mapred.SparkHadoopMapRedUtil import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil import org.apache.spark.sql._ @@ -145,7 +146,8 @@ private[sql] abstract class BaseWriterContainer( "because spark.speculation is configured to be true.") defaultOutputCommitter } else { - val committerClass = context.getConfiguration.getClass( + val configuration = SparkHadoopUtil.get.getConfigurationFromJobContext(context) + val committerClass = configuration.getClass( SQLConf.OUTPUT_COMMITTER_CLASS.key, null, classOf[OutputCommitter]) Option(committerClass).map { clazz => @@ -227,7 +229,8 @@ private[sql] class DefaultWriterContainer( def writeRows(taskContext: TaskContext, iterator: Iterator[InternalRow]): Unit = { executorSideSetup(taskContext) - taskAttemptContext.getConfiguration.set("spark.sql.sources.output.path", outputPath) + val configuration = SparkHadoopUtil.get.getConfigurationFromJobContext(taskAttemptContext) + configuration.set("spark.sql.sources.output.path", outputPath) val writer = outputWriterFactory.newInstance(getWorkPath, dataSchema, 
taskAttemptContext) writer.initConverter(dataSchema) @@ -395,7 +398,8 @@ private[sql] class DynamicPartitionWriterContainer( def newOutputWriter(key: InternalRow): OutputWriter = { val partitionPath = getPartitionString(key).getString(0) val path = new Path(getWorkPath, partitionPath) - taskAttemptContext.getConfiguration.set( + val configuration = SparkHadoopUtil.get.getConfigurationFromJobContext(taskAttemptContext) + configuration.set( "spark.sql.sources.output.path", new Path(outputPath, partitionPath).toString) val newWriter = outputWriterFactory.newInstance(path.toString, dataSchema, taskAttemptContext) newWriter.initConverter(dataSchema) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala index ab8ca5f748f2..7a49157d9e72 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala @@ -30,6 +30,7 @@ import org.apache.hadoop.mapreduce.{Job, RecordWriter, TaskAttemptContext} import org.apache.spark.Logging import org.apache.spark.broadcast.Broadcast +import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.mapred.SparkHadoopMapRedUtil import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow @@ -169,8 +170,10 @@ private[json] class JsonOutputWriter( private val recordWriter: RecordWriter[NullWritable, Text] = { new TextOutputFormat[NullWritable, Text]() { override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = { - val uniqueWriteJobId = context.getConfiguration.get("spark.sql.sources.writeJobUUID") - val split = context.getTaskAttemptID.getTaskID.getId + val configuration = SparkHadoopUtil.get.getConfigurationFromJobContext(context) + val uniqueWriteJobId = configuration.get("spark.sql.sources.writeJobUUID") + val taskAttemptId = SparkHadoopUtil.get.getTaskAttemptIDFromTaskAttemptContext(context) + val split = taskAttemptId.getTaskID.getId new Path(path, f"part-r-$split%05d-$uniqueWriteJobId$extension") } }.getRecordWriter(context) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala index 64982f37cf87..c6bbc392cad4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala @@ -40,6 +40,7 @@ import org.apache.parquet.{Log => ApacheParquetLog} import org.slf4j.bridge.SLF4JBridgeHandler import org.apache.spark.broadcast.Broadcast +import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.rdd.{RDD, SqlNewHadoopPartition, SqlNewHadoopRDD} import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow @@ -81,8 +82,10 @@ private[sql] class ParquetOutputWriter(path: String, context: TaskAttemptContext // `FileOutputCommitter.getWorkPath()`, which points to the base directory of all // partitions in the case of dynamic partitioning. 
override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = { - val uniqueWriteJobId = context.getConfiguration.get("spark.sql.sources.writeJobUUID") - val split = context.getTaskAttemptID.getTaskID.getId + val configuration = SparkHadoopUtil.get.getConfigurationFromJobContext(context) + val uniqueWriteJobId = configuration.get("spark.sql.sources.writeJobUUID") + val taskAttemptId = SparkHadoopUtil.get.getTaskAttemptIDFromTaskAttemptContext(context) + val split = taskAttemptId.getTaskID.getId new Path(path, f"part-r-$split%05d-$uniqueWriteJobId$extension") } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala index 1cff5cf9c354..4eeca9aec12b 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala @@ -34,6 +34,7 @@ import org.apache.hadoop.mapreduce.lib.input.FileInputFormat import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext} import org.apache.spark.Logging +import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.mapred.SparkHadoopMapRedUtil import org.apache.spark.rdd.{HadoopRDD, RDD} import org.apache.spark.sql.catalyst.InternalRow @@ -77,7 +78,8 @@ private[orc] class OrcOutputWriter( }.mkString(":")) val serde = new OrcSerde - serde.initialize(context.getConfiguration, table) + val configuration = SparkHadoopUtil.get.getConfigurationFromJobContext(context) + serde.initialize(configuration, table) serde } @@ -109,9 +111,10 @@ private[orc] class OrcOutputWriter( private lazy val recordWriter: RecordWriter[NullWritable, Writable] = { recordWriterInstantiated = true - val conf = context.getConfiguration + val conf = SparkHadoopUtil.get.getConfigurationFromJobContext(context) val uniqueWriteJobId = conf.get("spark.sql.sources.writeJobUUID") - val partition = context.getTaskAttemptID.getTaskID.getId + val taskAttemptId = SparkHadoopUtil.get.getTaskAttemptIDFromTaskAttemptContext(context) + val partition = taskAttemptId.getTaskID.getId val filename = f"part-r-$partition%05d-$uniqueWriteJobId.orc" new OrcOutputFormat().getRecordWriter( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala index e8141923a9b5..527ca7a81cad 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala @@ -27,6 +27,7 @@ import org.apache.hadoop.mapreduce.lib.output.{FileOutputFormat, TextOutputForma import org.apache.hadoop.mapreduce.{Job, RecordWriter, TaskAttemptContext} import org.apache.spark.rdd.RDD +import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.catalyst.expressions.{Cast, Literal} import org.apache.spark.sql.types.{DataType, StructType} @@ -53,8 +54,10 @@ class AppendingTextOutputFormat(outputFile: Path) extends TextOutputFormat[NullW numberFormat.setGroupingUsed(false) override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = { - val uniqueWriteJobId = context.getConfiguration.get("spark.sql.sources.writeJobUUID") - val split = context.getTaskAttemptID.getTaskID.getId + val configuration = SparkHadoopUtil.get.getConfigurationFromJobContext(context) + val uniqueWriteJobId = 
configuration.get("spark.sql.sources.writeJobUUID") + val taskAttemptId = SparkHadoopUtil.get.getTaskAttemptIDFromTaskAttemptContext(context) + val split = taskAttemptId.getTaskID.getId val name = FileOutputFormat.getOutputName(context) new Path(outputFile, s"$name-${numberFormat.format(split)}-$uniqueWriteJobId") } From 097a7e36e0bf7290b1879331375bacc905583bd3 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Sat, 29 Aug 2015 16:39:40 -0700 Subject: [PATCH 0480/1824] [SPARK-10339] [SPARK-10334] [SPARK-10301] [SQL] Partitioned table scan can OOM driver and throw a better error message when users need to enable parquet schema merging This fixes the problem that scanning partitioned table causes driver have a high memory pressure and takes down the cluster. Also, with this fix, we will be able to correctly show the query plan of a query consuming partitioned tables. https://issues.apache.org/jira/browse/SPARK-10339 https://issues.apache.org/jira/browse/SPARK-10334 Finally, this PR squeeze in a "quick fix" for SPARK-10301. It is not a real fix, but it just throw a better error message to let user know what to do. Author: Yin Huai Closes #8515 from yhuai/partitionedTableScan. --- .../datasources/DataSourceStrategy.scala | 85 ++++++++++--------- .../parquet/CatalystRowConverter.scala | 7 ++ .../ParquetHadoopFsRelationSuite.scala | 15 +++- 3 files changed, 65 insertions(+), 42 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 6c1ef6a6df88..c58213155daa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.{StringType, StructType} import org.apache.spark.sql.{SaveMode, Strategy, execution, sources, _} @@ -121,7 +122,7 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { projections: Seq[NamedExpression], filters: Seq[Expression], partitionColumns: StructType, - partitions: Array[Partition]) = { + partitions: Array[Partition]): SparkPlan = { val relation = logicalRelation.relation.asInstanceOf[HadoopFsRelation] // Because we are creating one RDD per partition, we need to have a shared HadoopConf. @@ -130,49 +131,51 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { val confBroadcast = relation.sqlContext.sparkContext.broadcast(new SerializableConfiguration(sharedHadoopConf)) - // Builds RDD[Row]s for each selected partition. - val perPartitionRows = partitions.map { case Partition(partitionValues, dir) => - // The table scan operator (PhysicalRDD) which retrieves required columns from data files. - // Notice that the schema of data files, represented by `relation.dataSchema`, may contain - // some partition column(s). - val scan = - pruneFilterProject( - logicalRelation, - projections, - filters, - (columns: Seq[Attribute], filters) => { - val partitionColNames = partitionColumns.fieldNames - - // Don't scan any partition columns to save I/O. 
Here we are being optimistic and - // assuming partition columns data stored in data files are always consistent with those - // partition values encoded in partition directory paths. - val needed = columns.filterNot(a => partitionColNames.contains(a.name)) - val dataRows = - relation.buildScan(needed.map(_.name).toArray, filters, Array(dir), confBroadcast) - - // Merges data values with partition values. - mergeWithPartitionValues( - relation.schema, - columns.map(_.name).toArray, - partitionColNames, - partitionValues, - toCatalystRDD(logicalRelation, needed, dataRows)) - }) - - scan.execute() - } + // Now, we create a scan builder, which will be used by pruneFilterProject. This scan builder + // will union all partitions and attach partition values if needed. + val scanBuilder = { + (columns: Seq[Attribute], filters: Array[Filter]) => { + // Builds RDD[Row]s for each selected partition. + val perPartitionRows = partitions.map { case Partition(partitionValues, dir) => + val partitionColNames = partitionColumns.fieldNames + + // Don't scan any partition columns to save I/O. Here we are being optimistic and + // assuming partition columns data stored in data files are always consistent with those + // partition values encoded in partition directory paths. + val needed = columns.filterNot(a => partitionColNames.contains(a.name)) + val dataRows = + relation.buildScan(needed.map(_.name).toArray, filters, Array(dir), confBroadcast) + + // Merges data values with partition values. + mergeWithPartitionValues( + relation.schema, + columns.map(_.name).toArray, + partitionColNames, + partitionValues, + toCatalystRDD(logicalRelation, needed, dataRows)) + } + + val unionedRows = + if (perPartitionRows.length == 0) { + relation.sqlContext.emptyResult + } else { + new UnionRDD(relation.sqlContext.sparkContext, perPartitionRows) + } - val unionedRows = - if (perPartitionRows.length == 0) { - relation.sqlContext.emptyResult - } else { - new UnionRDD(relation.sqlContext.sparkContext, perPartitionRows) + unionedRows } + } + + // Create the scan operator. If needed, add Filter and/or Project on top of the scan. + // The added Filter/Project is on top of the unioned RDD. We do not want to create + // one Filter/Project for every partition. + val sparkPlan = pruneFilterProject( + logicalRelation, + projections, + filters, + scanBuilder) - execution.PhysicalRDD.createFromDataSource( - projections.map(_.toAttribute), - unionedRows, - logicalRelation.relation) + sparkPlan } // TODO: refactor this thing. It is very complicated because it does projection internally. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala index f682ca0d8ff4..fe13dfbbed38 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala @@ -196,6 +196,13 @@ private[parquet] class CatalystRowConverter( } } + if (paddedParquetFields.length != catalystType.length) { + throw new UnsupportedOperationException( + "A Parquet file's schema has different number of fields with the table schema. 
" + + "Please enable schema merging by setting \"mergeSchema\" to true when load " + + "a Parquet dataset or set spark.sql.parquet.mergeSchema to true in SQLConf.") + } + paddedParquetFields.zip(catalystType).zipWithIndex.map { case ((parquetFieldType, catalystField), ordinal) => // Converted field value should be set to the `ordinal`-th cell of `currentRow` diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala index cb4cedddbfdd..06dadbb5feab 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala @@ -23,7 +23,7 @@ import com.google.common.io.Files import org.apache.hadoop.fs.Path import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.sql.{AnalysisException, SaveMode} +import org.apache.spark.sql.{execution, AnalysisException, SaveMode} import org.apache.spark.sql.types.{IntegerType, StructField, StructType} @@ -136,4 +136,17 @@ class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest { assert(fs.exists(commonSummaryPath)) } } + + test("SPARK-10334 Projections and filters should be kept in physical plan") { + withTempPath { dir => + val path = dir.getCanonicalPath + + sqlContext.range(2).select('id as 'a, 'id as 'b).write.partitionBy("b").parquet(path) + val df = sqlContext.read.parquet(path).filter('a === 0).select('b) + val physicalPlan = df.queryExecution.executedPlan + + assert(physicalPlan.collect { case p: execution.Project => p }.length === 1) + assert(physicalPlan.collect { case p: execution.Filter => p }.length === 1) + } + } } From 13f5f8ec97c6886346641b73bd99004e0d70836c Mon Sep 17 00:00:00 2001 From: zsxwing Date: Sat, 29 Aug 2015 18:10:44 -0700 Subject: [PATCH 0481/1824] [SPARK-9986] [SPARK-9991] [SPARK-9993] [SQL] Create a simple test framework for local operators This PR includes the following changes: - Add `LocalNodeTest` for local operator tests and add unit tests for FilterNode and ProjectNode. - Add `LimitNode` and `UnionNode` and their unit tests to show how to use `LocalNodeTest`. (SPARK-9991, SPARK-9993) Author: zsxwing Closes #8464 from zsxwing/local-execution. 
--- .../sql/execution/local/FilterNode.scala | 6 +- .../spark/sql/execution/local/LimitNode.scala | 45 ++++++ .../spark/sql/execution/local/LocalNode.scala | 13 +- .../sql/execution/local/ProjectNode.scala | 4 +- .../sql/execution/local/SeqScanNode.scala | 2 +- .../spark/sql/execution/local/UnionNode.scala | 72 +++++++++ .../spark/sql/execution/SparkPlanTest.scala | 46 +----- .../sql/execution/local/FilterNodeSuite.scala | 41 +++++ .../sql/execution/local/LimitNodeSuite.scala | 39 +++++ .../sql/execution/local/LocalNodeTest.scala | 146 ++++++++++++++++++ .../execution/local/ProjectNodeSuite.scala | 44 ++++++ .../sql/execution/local/UnionNodeSuite.scala | 52 +++++++ .../apache/spark/sql/test/SQLTestData.scala | 8 + .../apache/spark/sql/test/SQLTestUtils.scala | 46 +++++- 14 files changed, 509 insertions(+), 55 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/local/LimitNode.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/local/UnionNode.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/local/FilterNodeSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/local/LimitNodeSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/local/ProjectNodeSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/local/UnionNodeSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/FilterNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/FilterNode.scala index a485a1a1d7ae..81dd37c7da73 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/FilterNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/FilterNode.scala @@ -35,13 +35,13 @@ case class FilterNode(condition: Expression, child: LocalNode) extends UnaryLoca override def next(): Boolean = { var found = false - while (child.next() && !found) { - found = predicate.apply(child.get()) + while (!found && child.next()) { + found = predicate.apply(child.fetch()) } found } - override def get(): InternalRow = child.get() + override def fetch(): InternalRow = child.fetch() override def close(): Unit = child.close() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LimitNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LimitNode.scala new file mode 100644 index 000000000000..fffc52abf6dd --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LimitNode.scala @@ -0,0 +1,45 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/ + +package org.apache.spark.sql.execution.local + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute + + +case class LimitNode(limit: Int, child: LocalNode) extends UnaryLocalNode { + + private[this] var count = 0 + + override def output: Seq[Attribute] = child.output + + override def open(): Unit = child.open() + + override def close(): Unit = child.close() + + override def fetch(): InternalRow = child.fetch() + + override def next(): Boolean = { + if (count < limit) { + count += 1 + child.next() + } else { + false + } + } + +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala index 341c81438e6d..1c4469acbf26 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala @@ -48,10 +48,10 @@ abstract class LocalNode extends TreeNode[LocalNode] { /** * Returns the current tuple. */ - def get(): InternalRow + def fetch(): InternalRow /** - * Closes the iterator and releases all resources. + * Closes the iterator and releases all resources. It should be idempotent. * * Implementations of this must also call the `close()` function of its children. */ @@ -64,10 +64,13 @@ abstract class LocalNode extends TreeNode[LocalNode] { val converter = CatalystTypeConverters.createToScalaConverter(StructType.fromAttributes(output)) val result = new scala.collection.mutable.ArrayBuffer[Row] open() - while (next()) { - result += converter.apply(get()).asInstanceOf[Row] + try { + while (next()) { + result += converter.apply(fetch()).asInstanceOf[Row] + } + } finally { + close() } - close() result } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ProjectNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ProjectNode.scala index e574d1473cdc..9b8a4fe49302 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ProjectNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ProjectNode.scala @@ -34,8 +34,8 @@ case class ProjectNode(projectList: Seq[NamedExpression], child: LocalNode) exte override def next(): Boolean = child.next() - override def get(): InternalRow = { - project.apply(child.get()) + override def fetch(): InternalRow = { + project.apply(child.fetch()) } override def close(): Unit = child.close() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SeqScanNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SeqScanNode.scala index 994de8afa9a0..242cb66e07b7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SeqScanNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SeqScanNode.scala @@ -41,7 +41,7 @@ case class SeqScanNode(output: Seq[Attribute], data: Seq[InternalRow]) extends L } } - override def get(): InternalRow = currentRow + override def fetch(): InternalRow = currentRow override def close(): Unit = { // Do nothing diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/UnionNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/UnionNode.scala new file mode 100644 index 000000000000..ba4aa7671aeb --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/UnionNode.scala @@ -0,0 +1,72 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* 
contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.execution.local + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute + +case class UnionNode(children: Seq[LocalNode]) extends LocalNode { + + override def output: Seq[Attribute] = children.head.output + + private[this] var currentChild: LocalNode = _ + + private[this] var nextChildIndex: Int = _ + + override def open(): Unit = { + currentChild = children.head + currentChild.open() + nextChildIndex = 1 + } + + private def advanceToNextChild(): Boolean = { + var found = false + var exit = false + while (!exit && !found) { + if (currentChild != null) { + currentChild.close() + } + if (nextChildIndex >= children.size) { + found = false + exit = true + } else { + currentChild = children(nextChildIndex) + nextChildIndex += 1 + currentChild.open() + found = currentChild.next() + } + } + found + } + + override def close(): Unit = { + if (currentChild != null) { + currentChild.close() + } + } + + override def fetch(): InternalRow = currentChild.fetch() + + override def next(): Boolean = { + if (currentChild.next()) { + true + } else { + advanceToNextChild() + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala index 3a87f374d94b..5ab8f44faebf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala @@ -24,7 +24,7 @@ import scala.util.control.NonFatal import org.apache.spark.SparkFunSuite import org.apache.spark.sql.{DataFrame, DataFrameHolder, Row, SQLContext} import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute -import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.test.SQLTestUtils /** * Base class for writing tests for individual physical operators. For an example of how this @@ -184,7 +184,7 @@ object SparkPlanTest { return Some(errorMessage) } - compareAnswers(actualAnswer, expectedAnswer, sortAnswers).map { errorMessage => + SQLTestUtils.compareAnswers(actualAnswer, expectedAnswer, sortAnswers).map { errorMessage => s""" | Results do not match. 
| Actual result Spark plan: @@ -229,7 +229,7 @@ object SparkPlanTest { return Some(errorMessage) } - compareAnswers(sparkAnswer, expectedAnswer, sortAnswers).map { errorMessage => + SQLTestUtils.compareAnswers(sparkAnswer, expectedAnswer, sortAnswers).map { errorMessage => s""" | Results do not match for Spark plan: | $outputPlan @@ -238,46 +238,6 @@ object SparkPlanTest { } } - private def compareAnswers( - sparkAnswer: Seq[Row], - expectedAnswer: Seq[Row], - sort: Boolean): Option[String] = { - def prepareAnswer(answer: Seq[Row]): Seq[Row] = { - // Converts data to types that we can do equality comparison using Scala collections. - // For BigDecimal type, the Scala type has a better definition of equality test (similar to - // Java's java.math.BigDecimal.compareTo). - // For binary arrays, we convert it to Seq to avoid of calling java.util.Arrays.equals for - // equality test. - // This function is copied from Catalyst's QueryTest - val converted: Seq[Row] = answer.map { s => - Row.fromSeq(s.toSeq.map { - case d: java.math.BigDecimal => BigDecimal(d) - case b: Array[Byte] => b.toSeq - case o => o - }) - } - if (sort) { - converted.sortBy(_.toString()) - } else { - converted - } - } - if (prepareAnswer(expectedAnswer) != prepareAnswer(sparkAnswer)) { - val errorMessage = - s""" - | == Results == - | ${sideBySide( - s"== Expected Answer - ${expectedAnswer.size} ==" +: - prepareAnswer(expectedAnswer).map(_.toString()), - s"== Actual Answer - ${sparkAnswer.size} ==" +: - prepareAnswer(sparkAnswer).map(_.toString())).mkString("\n")} - """.stripMargin - Some(errorMessage) - } else { - None - } - } - private def executePlan(outputPlan: SparkPlan, _sqlContext: SQLContext): Seq[Row] = { // A very simple resolver to make writing tests easier. In contrast to the real resolver // this is always case sensitive and does not try to handle scoping or complex type resolution. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/FilterNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/FilterNodeSuite.scala new file mode 100644 index 000000000000..07209f377924 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/FilterNodeSuite.scala @@ -0,0 +1,41 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/ + +package org.apache.spark.sql.execution.local + +import org.apache.spark.sql.test.SharedSQLContext + +class FilterNodeSuite extends LocalNodeTest with SharedSQLContext { + + test("basic") { + val condition = (testData.col("key") % 2) === 0 + checkAnswer( + testData, + node => FilterNode(condition.expr, node), + testData.filter(condition).collect() + ) + } + + test("empty") { + val condition = (emptyTestData.col("key") % 2) === 0 + checkAnswer( + emptyTestData, + node => FilterNode(condition.expr, node), + emptyTestData.filter(condition).collect() + ) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LimitNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LimitNodeSuite.scala new file mode 100644 index 000000000000..523c02f4a601 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LimitNodeSuite.scala @@ -0,0 +1,39 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.execution.local + +import org.apache.spark.sql.test.SharedSQLContext + +class LimitNodeSuite extends LocalNodeTest with SharedSQLContext { + + test("basic") { + checkAnswer( + testData, + node => LimitNode(10, node), + testData.limit(10).collect() + ) + } + + test("empty") { + checkAnswer( + emptyTestData, + node => LimitNode(10, node), + emptyTestData.limit(10).collect() + ) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala new file mode 100644 index 000000000000..95f06081bd0a --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala @@ -0,0 +1,146 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/ + +package org.apache.spark.sql.execution.local + +import scala.util.control.NonFatal + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.test.SQLTestUtils + +class LocalNodeTest extends SparkFunSuite { + + /** + * Runs the LocalNode and makes sure the answer matches the expected result. + * @param input the input data to be used. + * @param nodeFunction a function which accepts the input LocalNode and uses it to instantiate + * the local physical operator that's being tested. + * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s. + * @param sortAnswers if true, the answers will be sorted by their toString representations prior + * to being compared. + */ + protected def checkAnswer( + input: DataFrame, + nodeFunction: LocalNode => LocalNode, + expectedAnswer: Seq[Row], + sortAnswers: Boolean = true): Unit = { + doCheckAnswer( + input :: Nil, + nodes => nodeFunction(nodes.head), + expectedAnswer, + sortAnswers) + } + + /** + * Runs the LocalNode and makes sure the answer matches the expected result. + * @param left the left input data to be used. + * @param right the right input data to be used. + * @param nodeFunction a function which accepts the input LocalNode and uses it to instantiate + * the local physical operator that's being tested. + * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s. + * @param sortAnswers if true, the answers will be sorted by their toString representations prior + * to being compared. + */ + protected def checkAnswer2( + left: DataFrame, + right: DataFrame, + nodeFunction: (LocalNode, LocalNode) => LocalNode, + expectedAnswer: Seq[Row], + sortAnswers: Boolean = true): Unit = { + doCheckAnswer( + left :: right :: Nil, + nodes => nodeFunction(nodes(0), nodes(1)), + expectedAnswer, + sortAnswers) + } + + /** + * Runs the `LocalNode`s and makes sure the answer matches the expected result. + * @param input the input data to be used. + * @param nodeFunction a function which accepts a sequence of input `LocalNode`s and uses them to + * instantiate the local physical operator that's being tested. + * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s. + * @param sortAnswers if true, the answers will be sorted by their toString representations prior + * to being compared. + */ + protected def doCheckAnswer( + input: Seq[DataFrame], + nodeFunction: Seq[LocalNode] => LocalNode, + expectedAnswer: Seq[Row], + sortAnswers: Boolean = true): Unit = { + LocalNodeTest.checkAnswer( + input.map(dataFrameToSeqScanNode), nodeFunction, expectedAnswer, sortAnswers) match { + case Some(errorMessage) => fail(errorMessage) + case None => + } + } + + protected def dataFrameToSeqScanNode(df: DataFrame): SeqScanNode = { + new SeqScanNode( + df.queryExecution.sparkPlan.output, + df.queryExecution.toRdd.map(_.copy()).collect()) + } + +} + +/** + * Helper methods for writing tests of individual local physical operators. + */ +object LocalNodeTest { + + /** + * Runs the `LocalNode`s and makes sure the answer matches the expected result. + * @param input the input data to be used. + * @param nodeFunction a function which accepts the input `LocalNode`s and uses them to + * instantiate the local physical operator that's being tested. + * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s. + * @param sortAnswers if true, the answers will be sorted by their toString representations prior + * to being compared. 
+ */ + def checkAnswer( + input: Seq[SeqScanNode], + nodeFunction: Seq[LocalNode] => LocalNode, + expectedAnswer: Seq[Row], + sortAnswers: Boolean): Option[String] = { + + val outputNode = nodeFunction(input) + + val outputResult: Seq[Row] = try { + outputNode.collect() + } catch { + case NonFatal(e) => + val errorMessage = + s""" + | Exception thrown while executing local plan: + | $outputNode + | == Exception == + | $e + | ${org.apache.spark.sql.catalyst.util.stackTraceToString(e)} + """.stripMargin + return Some(errorMessage) + } + + SQLTestUtils.compareAnswers(outputResult, expectedAnswer, sortAnswers).map { errorMessage => + s""" + | Results do not match for local plan: + | $outputNode + | $errorMessage + """.stripMargin + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ProjectNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ProjectNodeSuite.scala new file mode 100644 index 000000000000..ffcf092e2c66 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ProjectNodeSuite.scala @@ -0,0 +1,44 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.execution.local + +import org.apache.spark.sql.test.SharedSQLContext + +class ProjectNodeSuite extends LocalNodeTest with SharedSQLContext { + + test("basic") { + val output = testData.queryExecution.sparkPlan.output + val columns = Seq(output(1), output(0)) + checkAnswer( + testData, + node => ProjectNode(columns, node), + testData.select("value", "key").collect() + ) + } + + test("empty") { + val output = emptyTestData.queryExecution.sparkPlan.output + val columns = Seq(output(1), output(0)) + checkAnswer( + emptyTestData, + node => ProjectNode(columns, node), + emptyTestData.select("value", "key").collect() + ) + } + +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/UnionNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/UnionNodeSuite.scala new file mode 100644 index 000000000000..34670287c3e1 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/UnionNodeSuite.scala @@ -0,0 +1,52 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. 
You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.execution.local + +import org.apache.spark.sql.test.SharedSQLContext + +class UnionNodeSuite extends LocalNodeTest with SharedSQLContext { + + test("basic") { + checkAnswer2( + testData, + testData, + (node1, node2) => UnionNode(Seq(node1, node2)), + testData.unionAll(testData).collect() + ) + } + + test("empty") { + checkAnswer2( + emptyTestData, + emptyTestData, + (node1, node2) => UnionNode(Seq(node1, node2)), + emptyTestData.unionAll(emptyTestData).collect() + ) + } + + test("complicated union") { + val dfs = Seq(testData, emptyTestData, emptyTestData, testData, testData, emptyTestData, + emptyTestData, emptyTestData, testData, emptyTestData) + doCheckAnswer( + dfs, + nodes => UnionNode(nodes), + dfs.reduce(_.unionAll(_)).collect() + ) + } + +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala index 1374a97476ca..3fc02df954e2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala @@ -36,6 +36,13 @@ private[sql] trait SQLTestData { self => // Note: all test data should be lazy because the SQLContext is not set up yet. + protected lazy val emptyTestData: DataFrame = { + val df = _sqlContext.sparkContext.parallelize( + Seq.empty[Int].map(i => TestData(i, i.toString))).toDF() + df.registerTempTable("emptyTestData") + df + } + protected lazy val testData: DataFrame = { val df = _sqlContext.sparkContext.parallelize( (1 to 100).map(i => TestData(i, i.toString))).toDF() @@ -240,6 +247,7 @@ private[sql] trait SQLTestData { self => */ def loadTestData(): Unit = { assert(_sqlContext != null, "attempted to initialize test data before SQLContext.") + emptyTestData testData testData2 testData3 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index cdd691e03589..dc08306ad9cb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -27,8 +27,9 @@ import org.apache.hadoop.conf.Configuration import org.scalatest.BeforeAndAfterAll import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.{DataFrame, SQLContext, SQLImplicits} +import org.apache.spark.sql.{DataFrame, Row, SQLContext, SQLImplicits} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.util._ import org.apache.spark.util.Utils /** @@ -179,3 +180,46 @@ private[sql] trait SQLTestUtils DataFrame(_sqlContext, plan) } } + +private[sql] object SQLTestUtils { + + def compareAnswers( + sparkAnswer: Seq[Row], + expectedAnswer: Seq[Row], + sort: Boolean): Option[String] = { + def prepareAnswer(answer: Seq[Row]): Seq[Row] = { + // Converts data to types that we can do equality comparison using Scala collections. + // For BigDecimal type, the Scala type has a better definition of equality test (similar to + // Java's java.math.BigDecimal.compareTo). 
+ // For binary arrays, we convert it to Seq to avoid of calling java.util.Arrays.equals for + // equality test. + // This function is copied from Catalyst's QueryTest + val converted: Seq[Row] = answer.map { s => + Row.fromSeq(s.toSeq.map { + case d: java.math.BigDecimal => BigDecimal(d) + case b: Array[Byte] => b.toSeq + case o => o + }) + } + if (sort) { + converted.sortBy(_.toString()) + } else { + converted + } + } + if (prepareAnswer(expectedAnswer) != prepareAnswer(sparkAnswer)) { + val errorMessage = + s""" + | == Results == + | ${sideBySide( + s"== Expected Answer - ${expectedAnswer.size} ==" +: + prepareAnswer(expectedAnswer).map(_.toString()), + s"== Actual Answer - ${sparkAnswer.size} ==" +: + prepareAnswer(sparkAnswer).map(_.toString())).mkString("\n")} + """.stripMargin + Some(errorMessage) + } else { + None + } + } +} From 905fbe498bdd29116468628e6a2a553c1fd57165 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Sat, 29 Aug 2015 23:26:23 -0700 Subject: [PATCH 0482/1824] [SPARK-10348] [MLLIB] updates ml-guide * replace `ML Dataset` by `DataFrame` to unify the abstraction * ML algorithms -> pipeline components to describe the main concept * remove Scala API doc links from the main guide * `Section Title` -> `Section tile` to be consistent with other section titles in MLlib guide * modified lines break at 100 chars or periods jkbradley feynmanliang Author: Xiangrui Meng Closes #8517 from mengxr/SPARK-10348. --- docs/ml-guide.md | 118 +++++++++++++++++++++++++++----------------- docs/mllib-guide.md | 12 ++--- 2 files changed, 78 insertions(+), 52 deletions(-) diff --git a/docs/ml-guide.md b/docs/ml-guide.md index a92a285f3af8..4ba07542bfb4 100644 --- a/docs/ml-guide.md +++ b/docs/ml-guide.md @@ -24,61 +24,74 @@ title: Spark ML Programming Guide The `spark.ml` package aims to provide a uniform set of high-level APIs built on top of [DataFrames](sql-programming-guide.html#dataframes) that help users create and tune practical machine learning pipelines. -See the [Algorithm Guides section](#algorithm-guides) below for guides on sub-packages of +See the [algorithm guides](#algorithm-guides) section below for guides on sub-packages of `spark.ml`, including feature transformers unique to the Pipelines API, ensembles, and more. -**Table of Contents** +**Table of contents** * This will become a table of contents (this text will be scraped). {:toc} -# Main Concepts +# Main concepts -Spark ML standardizes APIs for machine learning algorithms to make it easier to combine multiple algorithms into a single pipeline, or workflow. This section covers the key concepts introduced by the Spark ML API. +Spark ML standardizes APIs for machine learning algorithms to make it easier to combine multiple +algorithms into a single pipeline, or workflow. +This section covers the key concepts introduced by the Spark ML API, where the pipeline concept is +mostly inspired by the [scikit-learn](http://scikit-learn.org/) project. -* **[ML Dataset](ml-guide.html#ml-dataset)**: Spark ML uses the [`DataFrame`](api/scala/index.html#org.apache.spark.sql.DataFrame) from Spark SQL as a dataset which can hold a variety of data types. -E.g., a dataset could have different columns storing text, feature vectors, true labels, and predictions. +* **[`DataFrame`](ml-guide.html#dataframe)**: Spark ML uses `DataFrame` from Spark SQL as an ML + dataset, which can hold a variety of data types. + E.g., a `DataFrame` could have different columns storing text, feature vectors, true labels, and predictions. 
* **[`Transformer`](ml-guide.html#transformers)**: A `Transformer` is an algorithm which can transform one `DataFrame` into another `DataFrame`. -E.g., an ML model is a `Transformer` which transforms an RDD with features into an RDD with predictions. +E.g., an ML model is a `Transformer` which transforms `DataFrame` with features into a `DataFrame` with predictions. * **[`Estimator`](ml-guide.html#estimators)**: An `Estimator` is an algorithm which can be fit on a `DataFrame` to produce a `Transformer`. -E.g., a learning algorithm is an `Estimator` which trains on a dataset and produces a model. +E.g., a learning algorithm is an `Estimator` which trains on a `DataFrame` and produces a model. * **[`Pipeline`](ml-guide.html#pipeline)**: A `Pipeline` chains multiple `Transformer`s and `Estimator`s together to specify an ML workflow. -* **[`Param`](ml-guide.html#parameters)**: All `Transformer`s and `Estimator`s now share a common API for specifying parameters. +* **[`Parameter`](ml-guide.html#parameters)**: All `Transformer`s and `Estimator`s now share a common API for specifying parameters. -## ML Dataset +## DataFrame Machine learning can be applied to a wide variety of data types, such as vectors, text, images, and structured data. -Spark ML adopts the [`DataFrame`](api/scala/index.html#org.apache.spark.sql.DataFrame) from Spark SQL in order to support a variety of data types under a unified Dataset concept. +Spark ML adopts the `DataFrame` from Spark SQL in order to support a variety of data types. `DataFrame` supports many basic and structured types; see the [Spark SQL datatype reference](sql-programming-guide.html#spark-sql-datatype-reference) for a list of supported types. -In addition to the types listed in the Spark SQL guide, `DataFrame` can use ML [`Vector`](api/scala/index.html#org.apache.spark.mllib.linalg.Vector) types. +In addition to the types listed in the Spark SQL guide, `DataFrame` can use ML [`Vector`](mllib-data-types.html#local-vector) types. A `DataFrame` can be created either implicitly or explicitly from a regular `RDD`. See the code examples below and the [Spark SQL programming guide](sql-programming-guide.html) for examples. Columns in a `DataFrame` are named. The code examples below use names such as "text," "features," and "label." -## ML Algorithms +## Pipeline components ### Transformers -A [`Transformer`](api/scala/index.html#org.apache.spark.ml.Transformer) is an abstraction which includes feature transformers and learned models. Technically, a `Transformer` implements a method `transform()` which converts one `DataFrame` into another, generally by appending one or more columns. +A `Transformer` is an abstraction that includes feature transformers and learned models. +Technically, a `Transformer` implements a method `transform()`, which converts one `DataFrame` into +another, generally by appending one or more columns. For example: -* A feature transformer might take a dataset, read a column (e.g., text), convert it into a new column (e.g., feature vectors), append the new column to the dataset, and output the updated dataset. -* A learning model might take a dataset, read the column containing feature vectors, predict the label for each feature vector, append the labels as a new column, and output the updated dataset. +* A feature transformer might take a `DataFrame`, read a column (e.g., text), map it into a new + column (e.g., feature vectors), and output a new `DataFrame` with the mapped column appended. 
+* A learning model might take a `DataFrame`, read the column containing feature vectors, predict the + label for each feature vector, and output a new `DataFrame` with predicted labels appended as a + column. ### Estimators -An [`Estimator`](api/scala/index.html#org.apache.spark.ml.Estimator) abstracts the concept of a learning algorithm or any algorithm which fits or trains on data. Technically, an `Estimator` implements a method `fit()` which accepts a `DataFrame` and produces a `Transformer`. -For example, a learning algorithm such as `LogisticRegression` is an `Estimator`, and calling `fit()` trains a `LogisticRegressionModel`, which is a `Transformer`. +An `Estimator` abstracts the concept of a learning algorithm or any algorithm that fits or trains on +data. +Technically, an `Estimator` implements a method `fit()`, which accepts a `DataFrame` and produces a +`Model`, which is a `Transformer`. +For example, a learning algorithm such as `LogisticRegression` is an `Estimator`, and calling +`fit()` trains a `LogisticRegressionModel`, which is a `Model` and hence a `Transformer`. -### Properties of ML Algorithms +### Properties of pipeline components -`Transformer`s and `Estimator`s are both stateless. In the future, stateful algorithms may be supported via alternative concepts. +`Transformer.transform()`s and `Estimator.fit()`s are both stateless. In the future, stateful algorithms may be supported via alternative concepts. Each instance of a `Transformer` or `Estimator` has a unique ID, which is useful in specifying parameters (discussed below). @@ -91,15 +104,16 @@ E.g., a simple text document processing workflow might include several stages: * Convert each document's words into a numerical feature vector. * Learn a prediction model using the feature vectors and labels. -Spark ML represents such a workflow as a [`Pipeline`](api/scala/index.html#org.apache.spark.ml.Pipeline), -which consists of a sequence of [`PipelineStage`s](api/scala/index.html#org.apache.spark.ml.PipelineStage) (`Transformer`s and `Estimator`s) to be run in a specific order. We will use this simple workflow as a running example in this section. +Spark ML represents such a workflow as a `Pipeline`, which consists of a sequence of +`PipelineStage`s (`Transformer`s and `Estimator`s) to be run in a specific order. +We will use this simple workflow as a running example in this section. -### How It Works +### How it works A `Pipeline` is specified as a sequence of stages, and each stage is either a `Transformer` or an `Estimator`. -These stages are run in order, and the input dataset is modified as it passes through each stage. -For `Transformer` stages, the `transform()` method is called on the dataset. -For `Estimator` stages, the `fit()` method is called to produce a `Transformer` (which becomes part of the `PipelineModel`, or fitted `Pipeline`), and that `Transformer`'s `transform()` method is called on the dataset. +These stages are run in order, and the input `DataFrame` is transformed as it passes through each stage. +For `Transformer` stages, the `transform()` method is called on the `DataFrame`. +For `Estimator` stages, the `fit()` method is called to produce a `Transformer` (which becomes part of the `PipelineModel`, or fitted `Pipeline`), and that `Transformer`'s `transform()` method is called on the `DataFrame`. We illustrate this for the simple text document workflow. The figure below is for the *training time* usage of a `Pipeline`. 
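As a concrete illustration of the workflow described in this section, the three stages can be assembled and fitted roughly as follows (a sketch against the Spark 1.5-era `spark.ml` API; `training` is an assumed `DataFrame` with "text" and "label" columns):

    import org.apache.spark.ml.Pipeline
    import org.apache.spark.ml.classification.LogisticRegression
    import org.apache.spark.ml.feature.{HashingTF, Tokenizer}

    // Tokenizer and HashingTF are Transformers; LogisticRegression is an Estimator.
    val tokenizer = new Tokenizer().setInputCol("text").setOutputCol("words")
    val hashingTF = new HashingTF().setInputCol("words").setOutputCol("features")
    val lr = new LogisticRegression().setMaxIter(10)

    val pipeline = new Pipeline().setStages(Array(tokenizer, hashingTF, lr))

    // fit() runs the stages in order and returns a PipelineModel, which is a Transformer.
    val model = pipeline.fit(training)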
@@ -115,14 +129,17 @@ We illustrate this for the simple text document workflow. The figure below is f Above, the top row represents a `Pipeline` with three stages. The first two (`Tokenizer` and `HashingTF`) are `Transformer`s (blue), and the third (`LogisticRegression`) is an `Estimator` (red). The bottom row represents data flowing through the pipeline, where cylinders indicate `DataFrame`s. -The `Pipeline.fit()` method is called on the original dataset which has raw text documents and labels. -The `Tokenizer.transform()` method splits the raw text documents into words, adding a new column with words into the dataset. -The `HashingTF.transform()` method converts the words column into feature vectors, adding a new column with those vectors to the dataset. +The `Pipeline.fit()` method is called on the original `DataFrame`, which has raw text documents and labels. +The `Tokenizer.transform()` method splits the raw text documents into words, adding a new column with words to the `DataFrame`. +The `HashingTF.transform()` method converts the words column into feature vectors, adding a new column with those vectors to the `DataFrame`. Now, since `LogisticRegression` is an `Estimator`, the `Pipeline` first calls `LogisticRegression.fit()` to produce a `LogisticRegressionModel`. -If the `Pipeline` had more stages, it would call the `LogisticRegressionModel`'s `transform()` method on the dataset before passing the dataset to the next stage. +If the `Pipeline` had more stages, it would call the `LogisticRegressionModel`'s `transform()` +method on the `DataFrame` before passing the `DataFrame` to the next stage. A `Pipeline` is an `Estimator`. -Thus, after a `Pipeline`'s `fit()` method runs, it produces a `PipelineModel` which is a `Transformer`. This `PipelineModel` is used at *test time*; the figure below illustrates this usage. +Thus, after a `Pipeline`'s `fit()` method runs, it produces a `PipelineModel`, which is a +`Transformer`. +This `PipelineModel` is used at *test time*; the figure below illustrates this usage.

    In the figure above, the `PipelineModel` has the same number of stages as the original `Pipeline`, but all `Estimator`s in the original `Pipeline` have become `Transformer`s. -When the `PipelineModel`'s `transform()` method is called on a test dataset, the data are passed through the `Pipeline` in order. +When the `PipelineModel`'s `transform()` method is called on a test dataset, the data are passed +through the fitted pipeline in order. Each stage's `transform()` method updates the dataset and passes it to the next stage. `Pipeline`s and `PipelineModel`s help to ensure that training and test data go through identical feature processing steps. @@ -143,40 +161,48 @@ Each stage's `transform()` method updates the dataset and passes it to the next *DAG `Pipeline`s*: A `Pipeline`'s stages are specified as an ordered array. The examples given here are all for linear `Pipeline`s, i.e., `Pipeline`s in which each stage uses data produced by the previous stage. It is possible to create non-linear `Pipeline`s as long as the data flow graph forms a Directed Acyclic Graph (DAG). This graph is currently specified implicitly based on the input and output column names of each stage (generally specified as parameters). If the `Pipeline` forms a DAG, then the stages must be specified in topological order. -*Runtime checking*: Since `Pipeline`s can operate on datasets with varied types, they cannot use compile-time type checking. `Pipeline`s and `PipelineModel`s instead do runtime checking before actually running the `Pipeline`. This type checking is done using the dataset *schema*, a description of the data types of columns in the `DataFrame`. +*Runtime checking*: Since `Pipeline`s can operate on `DataFrame`s with varied types, they cannot use +compile-time type checking. +`Pipeline`s and `PipelineModel`s instead do runtime checking before actually running the `Pipeline`. +This type checking is done using the `DataFrame` *schema*, a description of the data types of columns in the `DataFrame`. ## Parameters Spark ML `Estimator`s and `Transformer`s use a uniform API for specifying parameters. -A [`Param`](api/scala/index.html#org.apache.spark.ml.param.Param) is a named parameter with self-contained documentation. -A [`ParamMap`](api/scala/index.html#org.apache.spark.ml.param.ParamMap) is a set of (parameter, value) pairs. +A `Param` is a named parameter with self-contained documentation. +A `ParamMap` is a set of (parameter, value) pairs. There are two main ways to pass parameters to an algorithm: -1. Set parameters for an instance. E.g., if `lr` is an instance of `LogisticRegression`, one could call `lr.setMaxIter(10)` to make `lr.fit()` use at most 10 iterations. This API resembles the API used in MLlib. +1. Set parameters for an instance. E.g., if `lr` is an instance of `LogisticRegression`, one could + call `lr.setMaxIter(10)` to make `lr.fit()` use at most 10 iterations. + This API resembles the API used in `spark.mllib` package. 2. Pass a `ParamMap` to `fit()` or `transform()`. Any parameters in the `ParamMap` will override parameters previously specified via setter methods. Parameters belong to specific instances of `Estimator`s and `Transformer`s. For example, if we have two `LogisticRegression` instances `lr1` and `lr2`, then we can build a `ParamMap` with both `maxIter` parameters specified: `ParamMap(lr1.maxIter -> 10, lr2.maxIter -> 20)`. This is useful if there are two algorithms with the `maxIter` parameter in a `Pipeline`. 
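A condensed sketch of the two ways of passing parameters described above, using only API calls that appear in the full examples later in this guide (`sqlContext` is assumed to be in scope, as those examples assume):

{% highlight scala %}
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.mllib.linalg.Vectors

// A tiny training DataFrame of (label, features) rows.
val training = sqlContext.createDataFrame(Seq(
  (1.0, Vectors.dense(0.0, 1.1, 0.1)),
  (0.0, Vectors.dense(2.0, 1.0, -1.0))
)).toDF("label", "features")

val lr = new LogisticRegression()

// 1. Set parameters on the instance itself; this fit() uses maxIter = 10 and regParam = 0.01.
lr.setMaxIter(10).setRegParam(0.01)
val model1 = lr.fit(training)

// 2. Pass a ParamMap to fit(); values in the map override those set via setters above.
val paramMap = ParamMap(lr.maxIter -> 20, lr.regParam -> 0.1)
val model2 = lr.fit(training, paramMap)
{% endhighlight %}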
-# Algorithm Guides +# Algorithm guides There are now several algorithms in the Pipelines API which are not in the `spark.mllib` API, so we link to documentation for them here. These algorithms are mostly feature transformers, which fit naturally into the `Transformer` abstraction in Pipelines, and ensembles, which fit naturally into the `Estimator` abstraction in the Pipelines. -**Pipelines API Algorithm Guides** - -* [Feature Extraction, Transformation, and Selection](ml-features.html) -* [Decision Trees for Classification and Regression](ml-decision-tree.html) +* [Feature extraction, transformation, and selection](ml-features.html) +* [Decision Trees for classification and regression](ml-decision-tree.html) * [Ensembles](ml-ensembles.html) * [Linear methods with elastic net regularization](ml-linear-methods.html) * [Multilayer perceptron classifier](ml-ann.html) -# Code Examples +# Code examples This section gives code examples illustrating the functionality discussed above. -There is not yet documentation for specific algorithms in Spark ML. For more info, please refer to the [API Documentation](api/scala/index.html#org.apache.spark.ml.package). Spark ML algorithms are currently wrappers for MLlib algorithms, and the [MLlib programming guide](mllib-guide.html) has details on specific algorithms. +For more info, please refer to the API documentation +([Scala](api/scala/index.html#org.apache.spark.ml.package), +[Java](api/java/org/apache/spark/ml/package-summary.html), +and [Python](api/python/pyspark.ml.html)). +Some Spark ML algorithms are wrappers for `spark.mllib` algorithms, and the +[MLlib programming guide](mllib-guide.html) has details on specific algorithms. ## Example: Estimator, Transformer, and Param @@ -627,7 +653,7 @@ sc.stop() -## Example: Model Selection via Cross-Validation +## Example: model selection via cross-validation An important task in ML is *model selection*, or using data to find the best model or parameters for a given task. This is also called *tuning*. `Pipeline`s facilitate model selection by making it easy to tune an entire `Pipeline` at once, rather than tuning each element in the `Pipeline` separately. @@ -873,11 +899,11 @@ jsc.stop(); -## Example: Model Selection via Train Validation Split +## Example: model selection via train validation split In addition to `CrossValidator` Spark also offers `TrainValidationSplit` for hyper-parameter tuning. `TrainValidationSplit` only evaluates each combination of parameters once as opposed to k times in case of `CrossValidator`. It is therefore less expensive, - but will not produce as reliable results when the training dataset is not sufficiently large.. + but will not produce as reliable results when the training dataset is not sufficiently large. `TrainValidationSplit` takes an `Estimator`, a set of `ParamMap`s provided in the `estimatorParamMaps` parameter, and an `Evaluator`. diff --git a/docs/mllib-guide.md b/docs/mllib-guide.md index 876dcfd40ed7..257f7cc7603f 100644 --- a/docs/mllib-guide.md +++ b/docs/mllib-guide.md @@ -14,9 +14,9 @@ primitives and higher-level pipeline APIs. It divides into two packages: * [`spark.mllib`](mllib-guide.html#mllib-types-algorithms-and-utilities) contains the original API - built on top of RDDs. + built on top of [RDDs](programming-guide.html#resilient-distributed-datasets-rdds). * [`spark.ml`](mllib-guide.html#sparkml-high-level-apis-for-ml-pipelines) provides higher-level API - built on top of DataFrames for constructing ML pipelines. 
+ built on top of [DataFrames](sql-programming-guide.html#dataframes) for constructing ML pipelines. Using `spark.ml` is recommended because with DataFrames the API is more versatile and flexible. But we will keep supporting `spark.mllib` along with the development of `spark.ml`. @@ -57,19 +57,19 @@ We list major functionality from both below, with links to detailed guides. * [FP-growth](mllib-frequent-pattern-mining.html#fp-growth) * [association rules](mllib-frequent-pattern-mining.html#association-rules) * [PrefixSpan](mllib-frequent-pattern-mining.html#prefix-span) -* [Evaluation Metrics](mllib-evaluation-metrics.html) +* [Evaluation metrics](mllib-evaluation-metrics.html) +* [PMML model export](mllib-pmml-model-export.html) * [Optimization (developer)](mllib-optimization.html) * [stochastic gradient descent](mllib-optimization.html#stochastic-gradient-descent-sgd) * [limited-memory BFGS (L-BFGS)](mllib-optimization.html#limited-memory-bfgs-l-bfgs) -* [PMML model export](mllib-pmml-model-export.html) # spark.ml: high-level APIs for ML pipelines **[spark.ml programming guide](ml-guide.html)** provides an overview of the Pipelines API and major concepts. It also contains sections on using algorithms within the Pipelines API, for example: -* [Feature Extraction, Transformation, and Selection](ml-features.html) -* [Decision Trees for Classification and Regression](ml-decision-tree.html) +* [Feature extraction, transformation, and selection](ml-features.html) +* [Decision trees for classification and regression](ml-decision-tree.html) * [Ensembles](ml-ensembles.html) * [Linear methods with elastic net regularization](ml-linear-methods.html) * [Multilayer perceptron classifier](ml-ann.html) From ca69fc8efda8a3e5442ffa16692a2b1eb86b7673 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Sat, 29 Aug 2015 23:57:09 -0700 Subject: [PATCH 0483/1824] [SPARK-10331] [MLLIB] Update example code in ml-guide * The example code was added in 1.2, before `createDataFrame`. This PR switches to `createDataFrame`. Java code still uses JavaBean. * assume `sqlContext` is available * fix some minor issues from previous code review jkbradley srowen feynmanliang Author: Xiangrui Meng Closes #8518 from mengxr/SPARK-10331. --- docs/ml-guide.md | 362 +++++++++++++++++++---------------------------- 1 file changed, 147 insertions(+), 215 deletions(-) diff --git a/docs/ml-guide.md b/docs/ml-guide.md index 4ba07542bfb4..78c93a95c780 100644 --- a/docs/ml-guide.md +++ b/docs/ml-guide.md @@ -212,26 +212,18 @@ This example covers the concepts of `Estimator`, `Transformer`, and `Param`.

    {% highlight scala %} -import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.ml.classification.LogisticRegression import org.apache.spark.ml.param.ParamMap import org.apache.spark.mllib.linalg.{Vector, Vectors} -import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.sql.{Row, SQLContext} +import org.apache.spark.sql.Row -val conf = new SparkConf().setAppName("SimpleParamsExample") -val sc = new SparkContext(conf) -val sqlContext = new SQLContext(sc) -import sqlContext.implicits._ - -// Prepare training data. -// We use LabeledPoint, which is a case class. Spark SQL can convert RDDs of case classes -// into DataFrames, where it uses the case class metadata to infer the schema. -val training = sc.parallelize(Seq( - LabeledPoint(1.0, Vectors.dense(0.0, 1.1, 0.1)), - LabeledPoint(0.0, Vectors.dense(2.0, 1.0, -1.0)), - LabeledPoint(0.0, Vectors.dense(2.0, 1.3, 1.0)), - LabeledPoint(1.0, Vectors.dense(0.0, 1.2, -0.5)))) +// Prepare training data from a list of (label, features) tuples. +val training = sqlContext.createDataFrame(Seq( + (1.0, Vectors.dense(0.0, 1.1, 0.1)), + (0.0, Vectors.dense(2.0, 1.0, -1.0)), + (0.0, Vectors.dense(2.0, 1.3, 1.0)), + (1.0, Vectors.dense(0.0, 1.2, -0.5)) +)).toDF("label", "features") // Create a LogisticRegression instance. This instance is an Estimator. val lr = new LogisticRegression() @@ -243,7 +235,7 @@ lr.setMaxIter(10) .setRegParam(0.01) // Learn a LogisticRegression model. This uses the parameters stored in lr. -val model1 = lr.fit(training.toDF) +val model1 = lr.fit(training) // Since model1 is a Model (i.e., a Transformer produced by an Estimator), // we can view the parameters it used during fit(). // This prints the parameter (name: value) pairs, where names are unique IDs for this @@ -253,8 +245,8 @@ println("Model 1 was fit using parameters: " + model1.parent.extractParamMap) // We may alternatively specify parameters using a ParamMap, // which supports several methods for specifying parameters. val paramMap = ParamMap(lr.maxIter -> 20) -paramMap.put(lr.maxIter, 30) // Specify 1 Param. This overwrites the original maxIter. -paramMap.put(lr.regParam -> 0.1, lr.threshold -> 0.55) // Specify multiple Params. + .put(lr.maxIter, 30) // Specify 1 Param. This overwrites the original maxIter. + .put(lr.regParam -> 0.1, lr.threshold -> 0.55) // Specify multiple Params. // One can also combine ParamMaps. val paramMap2 = ParamMap(lr.probabilityCol -> "myProbability") // Change output column name @@ -262,27 +254,27 @@ val paramMapCombined = paramMap ++ paramMap2 // Now learn a new model using the paramMapCombined parameters. // paramMapCombined overrides all parameters set earlier via lr.set* methods. -val model2 = lr.fit(training.toDF, paramMapCombined) +val model2 = lr.fit(training, paramMapCombined) println("Model 2 was fit using parameters: " + model2.parent.extractParamMap) // Prepare test data. -val test = sc.parallelize(Seq( - LabeledPoint(1.0, Vectors.dense(-1.0, 1.5, 1.3)), - LabeledPoint(0.0, Vectors.dense(3.0, 2.0, -0.1)), - LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5)))) +val test = sqlContext.createDataFrame(Seq( + (1.0, Vectors.dense(-1.0, 1.5, 1.3)), + (0.0, Vectors.dense(3.0, 2.0, -0.1)), + (1.0, Vectors.dense(0.0, 2.2, -1.5)) +)).toDF("label", "features") // Make predictions on test data using the Transformer.transform() method. // LogisticRegression.transform will only use the 'features' column. 
// Note that model2.transform() outputs a 'myProbability' column instead of the usual // 'probability' column since we renamed the lr.probabilityCol parameter previously. -model2.transform(test.toDF) +model2.transform(test) .select("features", "label", "myProbability", "prediction") .collect() .foreach { case Row(features: Vector, label: Double, prob: Vector, prediction: Double) => println(s"($features, $label) -> prob=$prob, prediction=$prediction") } -sc.stop() {% endhighlight %}
    @@ -291,30 +283,23 @@ sc.stop() import java.util.Arrays; import java.util.List; -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.ml.classification.LogisticRegressionModel; import org.apache.spark.ml.param.ParamMap; import org.apache.spark.ml.classification.LogisticRegression; import org.apache.spark.mllib.linalg.Vectors; import org.apache.spark.mllib.regression.LabeledPoint; import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.SQLContext; import org.apache.spark.sql.Row; -SparkConf conf = new SparkConf().setAppName("JavaSimpleParamsExample"); -JavaSparkContext jsc = new JavaSparkContext(conf); -SQLContext jsql = new SQLContext(jsc); - // Prepare training data. // We use LabeledPoint, which is a JavaBean. Spark SQL can convert RDDs of JavaBeans // into DataFrames, where it uses the bean metadata to infer the schema. -List localTraining = Arrays.asList( +DataFrame training = sqlContext.createDataFrame(Arrays.asList( new LabeledPoint(1.0, Vectors.dense(0.0, 1.1, 0.1)), new LabeledPoint(0.0, Vectors.dense(2.0, 1.0, -1.0)), new LabeledPoint(0.0, Vectors.dense(2.0, 1.3, 1.0)), - new LabeledPoint(1.0, Vectors.dense(0.0, 1.2, -0.5))); -DataFrame training = jsql.createDataFrame(jsc.parallelize(localTraining), LabeledPoint.class); + new LabeledPoint(1.0, Vectors.dense(0.0, 1.2, -0.5)) +), LabeledPoint.class); // Create a LogisticRegression instance. This instance is an Estimator. LogisticRegression lr = new LogisticRegression(); @@ -334,14 +319,14 @@ LogisticRegressionModel model1 = lr.fit(training); System.out.println("Model 1 was fit using parameters: " + model1.parent().extractParamMap()); // We may alternatively specify parameters using a ParamMap. -ParamMap paramMap = new ParamMap(); -paramMap.put(lr.maxIter().w(20)); // Specify 1 Param. -paramMap.put(lr.maxIter(), 30); // This overwrites the original maxIter. -paramMap.put(lr.regParam().w(0.1), lr.threshold().w(0.55)); // Specify multiple Params. +ParamMap paramMap = new ParamMap() + .put(lr.maxIter().w(20)) // Specify 1 Param. + .put(lr.maxIter(), 30) // This overwrites the original maxIter. + .put(lr.regParam().w(0.1), lr.threshold().w(0.55)); // Specify multiple Params. // One can also combine ParamMaps. -ParamMap paramMap2 = new ParamMap(); -paramMap2.put(lr.probabilityCol().w("myProbability")); // Change output column name +ParamMap paramMap2 = new ParamMap() + .put(lr.probabilityCol().w("myProbability")); // Change output column name ParamMap paramMapCombined = paramMap.$plus$plus(paramMap2); // Now learn a new model using the paramMapCombined parameters. @@ -350,11 +335,11 @@ LogisticRegressionModel model2 = lr.fit(training, paramMapCombined); System.out.println("Model 2 was fit using parameters: " + model2.parent().extractParamMap()); // Prepare test documents. -List localTest = Arrays.asList( - new LabeledPoint(1.0, Vectors.dense(-1.0, 1.5, 1.3)), - new LabeledPoint(0.0, Vectors.dense(3.0, 2.0, -0.1)), - new LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5))); -DataFrame test = jsql.createDataFrame(jsc.parallelize(localTest), LabeledPoint.class); +DataFrame test = sqlContext.createDataFrame(Arrays.asList( + new LabeledPoint(1.0, Vectors.dense(-1.0, 1.5, 1.3)), + new LabeledPoint(0.0, Vectors.dense(3.0, 2.0, -0.1)), + new LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5)) +), LabeledPoint.class); // Make predictions on test documents using the Transformer.transform() method. // LogisticRegression.transform will only use the 'features' column. 
@@ -366,28 +351,21 @@ for (Row r: results.select("features", "label", "myProbability", "prediction").c + ", prediction=" + r.get(3)); } -jsc.stop(); {% endhighlight %}
    {% highlight python %} -from pyspark import SparkContext -from pyspark.mllib.regression import LabeledPoint +from pyspark.mllib.linalg import Vectors from pyspark.ml.classification import LogisticRegression from pyspark.ml.param import Param, Params -from pyspark.sql import Row, SQLContext -sc = SparkContext(appName="SimpleParamsExample") -sqlContext = SQLContext(sc) - -# Prepare training data. -# We use LabeledPoint. -# Spark SQL can convert RDDs of LabeledPoints into DataFrames. -training = sc.parallelize([LabeledPoint(1.0, [0.0, 1.1, 0.1]), - LabeledPoint(0.0, [2.0, 1.0, -1.0]), - LabeledPoint(0.0, [2.0, 1.3, 1.0]), - LabeledPoint(1.0, [0.0, 1.2, -0.5])]) +# Prepare training data from a list of (label, features) tuples. +training = sqlContext.createDataFrame([ + (1.0, Vectors.dense([0.0, 1.1, 0.1])), + (0.0, Vectors.dense([2.0, 1.0, -1.0])), + (0.0, Vectors.dense([2.0, 1.3, 1.0])), + (1.0, Vectors.dense([0.0, 1.2, -0.5]))], ["label", "features"]) # Create a LogisticRegression instance. This instance is an Estimator. lr = LogisticRegression(maxIter=10, regParam=0.01) @@ -395,7 +373,7 @@ lr = LogisticRegression(maxIter=10, regParam=0.01) print "LogisticRegression parameters:\n" + lr.explainParams() + "\n" # Learn a LogisticRegression model. This uses the parameters stored in lr. -model1 = lr.fit(training.toDF()) +model1 = lr.fit(training) # Since model1 is a Model (i.e., a transformer produced by an Estimator), # we can view the parameters it used during fit(). @@ -416,25 +394,25 @@ paramMapCombined.update(paramMap2) # Now learn a new model using the paramMapCombined parameters. # paramMapCombined overrides all parameters set earlier via lr.set* methods. -model2 = lr.fit(training.toDF(), paramMapCombined) +model2 = lr.fit(training, paramMapCombined) print "Model 2 was fit using parameters: " print model2.extractParamMap() # Prepare test data -test = sc.parallelize([LabeledPoint(1.0, [-1.0, 1.5, 1.3]), - LabeledPoint(0.0, [ 3.0, 2.0, -0.1]), - LabeledPoint(1.0, [ 0.0, 2.2, -1.5])]) +test = sqlContext.createDataFrame([ + (1.0, Vectors.dense([-1.0, 1.5, 1.3])), + (0.0, Vectors.dense([3.0, 2.0, -0.1])), + (1.0, Vectors.dense([0.0, 2.2, -1.5]))], ["label", "features"]) # Make predictions on test data using the Transformer.transform() method. # LogisticRegression.transform will only use the 'features' column. # Note that model2.transform() outputs a "myProbability" column instead of the usual # 'probability' column since we renamed the lr.probabilityCol parameter previously. -prediction = model2.transform(test.toDF()) +prediction = model2.transform(test) selected = prediction.select("features", "label", "myProbability", "prediction") for row in selected.collect(): print row -sc.stop() {% endhighlight %}
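The rewritten snippets above follow this patch's convention of assuming `sc` and `sqlContext` are already available, as they are in `spark-shell` and `pyspark`. A sketch of the setup a standalone Scala application would still need around them (the application name here is an arbitrary placeholder):

{% highlight scala %}
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext

val conf = new SparkConf().setAppName("MLGuideExamples")
val sc = new SparkContext(conf)
val sqlContext = new SQLContext(sc)
// Enables implicit conversions such as rdd.toDF().
import sqlContext.implicits._

// ... example code from above goes here ...

sc.stop()
{% endhighlight %}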
    @@ -448,30 +426,19 @@ This example follows the simple text document `Pipeline` illustrated in the figu
    {% highlight scala %} -import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.ml.Pipeline import org.apache.spark.ml.classification.LogisticRegression import org.apache.spark.ml.feature.{HashingTF, Tokenizer} import org.apache.spark.mllib.linalg.Vector -import org.apache.spark.sql.{Row, SQLContext} - -// Labeled and unlabeled instance types. -// Spark SQL can infer schema from case classes. -case class LabeledDocument(id: Long, text: String, label: Double) -case class Document(id: Long, text: String) +import org.apache.spark.sql.Row -// Set up contexts. Import implicit conversions to DataFrame from sqlContext. -val conf = new SparkConf().setAppName("SimpleTextClassificationPipeline") -val sc = new SparkContext(conf) -val sqlContext = new SQLContext(sc) -import sqlContext.implicits._ - -// Prepare training documents, which are labeled. -val training = sc.parallelize(Seq( - LabeledDocument(0L, "a b c d e spark", 1.0), - LabeledDocument(1L, "b d", 0.0), - LabeledDocument(2L, "spark f g h", 1.0), - LabeledDocument(3L, "hadoop mapreduce", 0.0))) +// Prepare training documents from a list of (id, text, label) tuples. +val training = sqlContext.createDataFrame(Seq( + (0L, "a b c d e spark", 1.0), + (1L, "b d", 0.0), + (2L, "spark f g h", 1.0), + (3L, "hadoop mapreduce", 0.0) +)).toDF("id", "text", "label") // Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr. val tokenizer = new Tokenizer() @@ -488,14 +455,15 @@ val pipeline = new Pipeline() .setStages(Array(tokenizer, hashingTF, lr)) // Fit the pipeline to training documents. -val model = pipeline.fit(training.toDF) +val model = pipeline.fit(training) -// Prepare test documents, which are unlabeled. -val test = sc.parallelize(Seq( - Document(4L, "spark i j k"), - Document(5L, "l m n"), - Document(6L, "mapreduce spark"), - Document(7L, "apache hadoop"))) +// Prepare test documents, which are unlabeled (id, text) tuples. +val test = sqlContext.createDataFrame(Seq( + (4L, "spark i j k"), + (5L, "l m n"), + (6L, "mapreduce spark"), + (7L, "apache hadoop") +)).toDF("id", "text") // Make predictions on test documents. model.transform(test.toDF) @@ -505,7 +473,6 @@ model.transform(test.toDF) println(s"($id, $text) --> prob=$prob, prediction=$prediction") } -sc.stop() {% endhighlight %}
    @@ -514,8 +481,6 @@ sc.stop() import java.util.Arrays; import java.util.List; -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.ml.Pipeline; import org.apache.spark.ml.PipelineModel; import org.apache.spark.ml.PipelineStage; @@ -524,7 +489,6 @@ import org.apache.spark.ml.feature.HashingTF; import org.apache.spark.ml.feature.Tokenizer; import org.apache.spark.sql.DataFrame; import org.apache.spark.sql.Row; -import org.apache.spark.sql.SQLContext; // Labeled and unlabeled instance types. // Spark SQL can infer schema from Java Beans. @@ -556,18 +520,13 @@ public class LabeledDocument extends Document implements Serializable { public void setLabel(double label) { this.label = label; } } -// Set up contexts. -SparkConf conf = new SparkConf().setAppName("JavaSimpleTextClassificationPipeline"); -JavaSparkContext jsc = new JavaSparkContext(conf); -SQLContext jsql = new SQLContext(jsc); - // Prepare training documents, which are labeled. -List localTraining = Arrays.asList( +DataFrame training = sqlContext.createDataFrame(Arrays.asList( new LabeledDocument(0L, "a b c d e spark", 1.0), new LabeledDocument(1L, "b d", 0.0), new LabeledDocument(2L, "spark f g h", 1.0), - new LabeledDocument(3L, "hadoop mapreduce", 0.0)); -DataFrame training = jsql.createDataFrame(jsc.parallelize(localTraining), LabeledDocument.class); + new LabeledDocument(3L, "hadoop mapreduce", 0.0) +), LabeledDocument.class); // Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr. Tokenizer tokenizer = new Tokenizer() @@ -587,12 +546,12 @@ Pipeline pipeline = new Pipeline() PipelineModel model = pipeline.fit(training); // Prepare test documents, which are unlabeled. -List localTest = Arrays.asList( +DataFrame test = sqlContext.createDataFrame(Arrays.asList( new Document(4L, "spark i j k"), new Document(5L, "l m n"), new Document(6L, "mapreduce spark"), - new Document(7L, "apache hadoop")); -DataFrame test = jsql.createDataFrame(jsc.parallelize(localTest), Document.class); + new Document(7L, "apache hadoop") +), Document.class); // Make predictions on test documents. DataFrame predictions = model.transform(test); @@ -601,28 +560,23 @@ for (Row r: predictions.select("id", "text", "probability", "prediction").collec + ", prediction=" + r.get(3)); } -jsc.stop(); {% endhighlight %}
    {% highlight python %} -from pyspark import SparkContext from pyspark.ml import Pipeline from pyspark.ml.classification import LogisticRegression from pyspark.ml.feature import HashingTF, Tokenizer -from pyspark.sql import Row, SQLContext - -sc = SparkContext(appName="SimpleTextClassificationPipeline") -sqlContext = SQLContext(sc) +from pyspark.sql import Row -# Prepare training documents, which are labeled. +# Prepare training documents from a list of (id, text, label) tuples. LabeledDocument = Row("id", "text", "label") -training = sc.parallelize([(0L, "a b c d e spark", 1.0), - (1L, "b d", 0.0), - (2L, "spark f g h", 1.0), - (3L, "hadoop mapreduce", 0.0)]) \ - .map(lambda x: LabeledDocument(*x)).toDF() +training = sqlContext.createDataFrame([ + (0L, "a b c d e spark", 1.0), + (1L, "b d", 0.0), + (2L, "spark f g h", 1.0), + (3L, "hadoop mapreduce", 0.0)], ["id", "text", "label"]) # Configure an ML pipeline, which consists of tree stages: tokenizer, hashingTF, and lr. tokenizer = Tokenizer(inputCol="text", outputCol="words") @@ -633,13 +587,12 @@ pipeline = Pipeline(stages=[tokenizer, hashingTF, lr]) # Fit the pipeline to training documents. model = pipeline.fit(training) -# Prepare test documents, which are unlabeled. -Document = Row("id", "text") -test = sc.parallelize([(4L, "spark i j k"), - (5L, "l m n"), - (6L, "mapreduce spark"), - (7L, "apache hadoop")]) \ - .map(lambda x: Document(*x)).toDF() +# Prepare test documents, which are unlabeled (id, text) tuples. +test = sqlContext.createDataFrame([ + (4L, "spark i j k"), + (5L, "l m n"), + (6L, "mapreduce spark"), + (7L, "apache hadoop")], ["id", "text"]) # Make predictions on test documents and print columns of interest. prediction = model.transform(test) @@ -647,7 +600,6 @@ selected = prediction.select("id", "text", "prediction") for row in selected.collect(): print(row) -sc.stop() {% endhighlight %}
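As a small follow-up to the Scala pipeline example above (a sketch that reuses `model`, the fitted `PipelineModel`, from that example): the fitted stages can be inspected after training, for instance to see the parameters the final `LogisticRegressionModel` was fit with.

{% highlight scala %}
import org.apache.spark.ml.classification.LogisticRegressionModel

// `model` is the PipelineModel produced by pipeline.fit(training) in the Scala example above.
val fittedLr = model.stages.last.asInstanceOf[LogisticRegressionModel]
println("LogisticRegression stage was fit using parameters: " + fittedLr.parent.extractParamMap)
{% endhighlight %}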
    @@ -664,8 +616,8 @@ Currently, `spark.ml` supports model selection using the [`CrossValidator`](api/ The `Evaluator` can be a [`RegressionEvaluator`](api/scala/index.html#org.apache.spark.ml.RegressionEvaluator) for regression problems, a [`BinaryClassificationEvaluator`](api/scala/index.html#org.apache.spark.ml.BinaryClassificationEvaluator) -for binary data or a [`MultiClassClassificationEvaluator`](api/scala/index.html#org.apache.spark.ml.MultiClassClassificationEvaluator) -for multiclass problems. The default metric used to choose the best `ParamMap` can be overriden by the setMetric +for binary data, or a [`MultiClassClassificationEvaluator`](api/scala/index.html#org.apache.spark.ml.MultiClassClassificationEvaluator) +for multiclass problems. The default metric used to choose the best `ParamMap` can be overriden by the `setMetric` method in each of these evaluators. The `ParamMap` which produces the best evaluation metric (averaged over the `$k$` folds) is selected as the best model. @@ -684,39 +636,29 @@ However, it is also a well-established method for choosing parameters which is m
    {% highlight scala %} -import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.ml.Pipeline import org.apache.spark.ml.classification.LogisticRegression import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator import org.apache.spark.ml.feature.{HashingTF, Tokenizer} import org.apache.spark.ml.tuning.{ParamGridBuilder, CrossValidator} import org.apache.spark.mllib.linalg.Vector -import org.apache.spark.sql.{Row, SQLContext} - -// Labeled and unlabeled instance types. -// Spark SQL can infer schema from case classes. -case class LabeledDocument(id: Long, text: String, label: Double) -case class Document(id: Long, text: String) - -val conf = new SparkConf().setAppName("CrossValidatorExample") -val sc = new SparkContext(conf) -val sqlContext = new SQLContext(sc) -import sqlContext.implicits._ - -// Prepare training documents, which are labeled. -val training = sc.parallelize(Seq( - LabeledDocument(0L, "a b c d e spark", 1.0), - LabeledDocument(1L, "b d", 0.0), - LabeledDocument(2L, "spark f g h", 1.0), - LabeledDocument(3L, "hadoop mapreduce", 0.0), - LabeledDocument(4L, "b spark who", 1.0), - LabeledDocument(5L, "g d a y", 0.0), - LabeledDocument(6L, "spark fly", 1.0), - LabeledDocument(7L, "was mapreduce", 0.0), - LabeledDocument(8L, "e spark program", 1.0), - LabeledDocument(9L, "a e c l", 0.0), - LabeledDocument(10L, "spark compile", 1.0), - LabeledDocument(11L, "hadoop software", 0.0))) +import org.apache.spark.sql.Row + +// Prepare training data from a list of (id, text, label) tuples. +val training = sqlContext.createDataFrame(Seq( + (0L, "a b c d e spark", 1.0), + (1L, "b d", 0.0), + (2L, "spark f g h", 1.0), + (3L, "hadoop mapreduce", 0.0), + (4L, "b spark who", 1.0), + (5L, "g d a y", 0.0), + (6L, "spark fly", 1.0), + (7L, "was mapreduce", 0.0), + (8L, "e spark program", 1.0), + (9L, "a e c l", 0.0), + (10L, "spark compile", 1.0), + (11L, "hadoop software", 0.0) +)).toDF("id", "text", "label") // Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr. val tokenizer = new Tokenizer() @@ -730,15 +672,6 @@ val lr = new LogisticRegression() val pipeline = new Pipeline() .setStages(Array(tokenizer, hashingTF, lr)) -// We now treat the Pipeline as an Estimator, wrapping it in a CrossValidator instance. -// This will allow us to jointly choose parameters for all Pipeline stages. -// A CrossValidator requires an Estimator, a set of Estimator ParamMaps, and an Evaluator. -// Note that the evaluator here is a BinaryClassificationEvaluator and the default metric -// used is areaUnderROC. -val crossval = new CrossValidator() - .setEstimator(pipeline) - .setEvaluator(new BinaryClassificationEvaluator) - // We use a ParamGridBuilder to construct a grid of parameters to search over. // With 3 values for hashingTF.numFeatures and 2 values for lr.regParam, // this grid will have 3 x 2 = 6 parameter settings for CrossValidator to choose from. @@ -746,28 +679,37 @@ val paramGrid = new ParamGridBuilder() .addGrid(hashingTF.numFeatures, Array(10, 100, 1000)) .addGrid(lr.regParam, Array(0.1, 0.01)) .build() -crossval.setEstimatorParamMaps(paramGrid) -crossval.setNumFolds(2) // Use 3+ in practice + +// We now treat the Pipeline as an Estimator, wrapping it in a CrossValidator instance. +// This will allow us to jointly choose parameters for all Pipeline stages. +// A CrossValidator requires an Estimator, a set of Estimator ParamMaps, and an Evaluator. 
+// Note that the evaluator here is a BinaryClassificationEvaluator and its default metric +// is areaUnderROC. +val cv = new CrossValidator() + .setEstimator(pipeline) + .setEvaluator(new BinaryClassificationEvaluator) + .setEstimatorParamMaps(paramGrid) + .setNumFolds(2) // Use 3+ in practice // Run cross-validation, and choose the best set of parameters. -val cvModel = crossval.fit(training.toDF) +val cvModel = cv.fit(training) -// Prepare test documents, which are unlabeled. -val test = sc.parallelize(Seq( - Document(4L, "spark i j k"), - Document(5L, "l m n"), - Document(6L, "mapreduce spark"), - Document(7L, "apache hadoop"))) +// Prepare test documents, which are unlabeled (id, text) tuples. +val test = sqlContext.createDataFrame(Seq( + (4L, "spark i j k"), + (5L, "l m n"), + (6L, "mapreduce spark"), + (7L, "apache hadoop") +)).toDF("id", "text") // Make predictions on test documents. cvModel uses the best model found (lrModel). -cvModel.transform(test.toDF) +cvModel.transform(test) .select("id", "text", "probability", "prediction") .collect() .foreach { case Row(id: Long, text: String, prob: Vector, prediction: Double) => - println(s"($id, $text) --> prob=$prob, prediction=$prediction") -} + println(s"($id, $text) --> prob=$prob, prediction=$prediction") + } -sc.stop() {% endhighlight %}
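A brief follow-up to the Scala example above (a sketch reusing `cvModel` from it): the fitted `CrossValidatorModel` exposes the winning model directly as `bestModel`, here the fitted `PipelineModel`, which can be inspected like any other fitted pipeline.

{% highlight scala %}
import org.apache.spark.ml.PipelineModel

// `cvModel` comes from the cross-validation example above.
val bestPipeline = cvModel.bestModel.asInstanceOf[PipelineModel]
println("The pipeline chosen by cross-validation has " + bestPipeline.stages.length + " stages")
{% endhighlight %}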
    @@ -776,8 +718,6 @@ sc.stop() import java.util.Arrays; import java.util.List; -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.ml.Pipeline; import org.apache.spark.ml.PipelineStage; import org.apache.spark.ml.classification.LogisticRegression; @@ -790,7 +730,6 @@ import org.apache.spark.ml.tuning.CrossValidatorModel; import org.apache.spark.ml.tuning.ParamGridBuilder; import org.apache.spark.sql.DataFrame; import org.apache.spark.sql.Row; -import org.apache.spark.sql.SQLContext; // Labeled and unlabeled instance types. // Spark SQL can infer schema from Java Beans. @@ -822,12 +761,9 @@ public class LabeledDocument extends Document implements Serializable { public void setLabel(double label) { this.label = label; } } -SparkConf conf = new SparkConf().setAppName("JavaCrossValidatorExample"); -JavaSparkContext jsc = new JavaSparkContext(conf); -SQLContext jsql = new SQLContext(jsc); // Prepare training documents, which are labeled. -List localTraining = Arrays.asList( +DataFrame training = sqlContext.createDataFrame(Arrays.asList( new LabeledDocument(0L, "a b c d e spark", 1.0), new LabeledDocument(1L, "b d", 0.0), new LabeledDocument(2L, "spark f g h", 1.0), @@ -839,8 +775,8 @@ List localTraining = Arrays.asList( new LabeledDocument(8L, "e spark program", 1.0), new LabeledDocument(9L, "a e c l", 0.0), new LabeledDocument(10L, "spark compile", 1.0), - new LabeledDocument(11L, "hadoop software", 0.0)); -DataFrame training = jsql.createDataFrame(jsc.parallelize(localTraining), LabeledDocument.class); + new LabeledDocument(11L, "hadoop software", 0.0) +), LabeledDocument.class); // Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr. Tokenizer tokenizer = new Tokenizer() @@ -856,15 +792,6 @@ LogisticRegression lr = new LogisticRegression() Pipeline pipeline = new Pipeline() .setStages(new PipelineStage[] {tokenizer, hashingTF, lr}); -// We now treat the Pipeline as an Estimator, wrapping it in a CrossValidator instance. -// This will allow us to jointly choose parameters for all Pipeline stages. -// A CrossValidator requires an Estimator, a set of Estimator ParamMaps, and an Evaluator. -// Note that the evaluator here is a BinaryClassificationEvaluator and the default metric -// used is areaUnderROC. -CrossValidator crossval = new CrossValidator() - .setEstimator(pipeline) - .setEvaluator(new BinaryClassificationEvaluator()); - // We use a ParamGridBuilder to construct a grid of parameters to search over. // With 3 values for hashingTF.numFeatures and 2 values for lr.regParam, // this grid will have 3 x 2 = 6 parameter settings for CrossValidator to choose from. @@ -872,19 +799,28 @@ ParamMap[] paramGrid = new ParamGridBuilder() .addGrid(hashingTF.numFeatures(), new int[]{10, 100, 1000}) .addGrid(lr.regParam(), new double[]{0.1, 0.01}) .build(); -crossval.setEstimatorParamMaps(paramGrid); -crossval.setNumFolds(2); // Use 3+ in practice + +// We now treat the Pipeline as an Estimator, wrapping it in a CrossValidator instance. +// This will allow us to jointly choose parameters for all Pipeline stages. +// A CrossValidator requires an Estimator, a set of Estimator ParamMaps, and an Evaluator. +// Note that the evaluator here is a BinaryClassificationEvaluator and its default metric +// is areaUnderROC. 
+CrossValidator cv = new CrossValidator() + .setEstimator(pipeline) + .setEvaluator(new BinaryClassificationEvaluator()) + .setEstimatorParamMaps(paramGrid) + .setNumFolds(2); // Use 3+ in practice // Run cross-validation, and choose the best set of parameters. -CrossValidatorModel cvModel = crossval.fit(training); +CrossValidatorModel cvModel = cv.fit(training); // Prepare test documents, which are unlabeled. -List localTest = Arrays.asList( +DataFrame test = sqlContext.createDataFrame(Arrays.asList( new Document(4L, "spark i j k"), new Document(5L, "l m n"), new Document(6L, "mapreduce spark"), - new Document(7L, "apache hadoop")); -DataFrame test = jsql.createDataFrame(jsc.parallelize(localTest), Document.class); + new Document(7L, "apache hadoop") +), Document.class); // Make predictions on test documents. cvModel uses the best model found (lrModel). DataFrame predictions = cvModel.transform(test); @@ -893,7 +829,6 @@ for (Row r: predictions.select("id", "text", "probability", "prediction").collec + ", prediction=" + r.get(3)); } -jsc.stop(); {% endhighlight %} @@ -935,7 +870,7 @@ val lr = new LinearRegression() // the evaluator. val paramGrid = new ParamGridBuilder() .addGrid(lr.regParam, Array(0.1, 0.01)) - .addGrid(lr.fitIntercept, Array(true, false)) + .addGrid(lr.fitIntercept) .addGrid(lr.elasticNetParam, Array(0.0, 0.5, 1.0)) .build() @@ -945,9 +880,8 @@ val trainValidationSplit = new TrainValidationSplit() .setEstimator(lr) .setEvaluator(new RegressionEvaluator) .setEstimatorParamMaps(paramGrid) - -// 80% of the data will be used for training and the remaining 20% for validation. -trainValidationSplit.setTrainRatio(0.8) + // 80% of the data will be used for training and the remaining 20% for validation. + .setTrainRatio(0.8) // Run train validation split, and choose the best set of parameters. val model = trainValidationSplit.fit(training) @@ -972,12 +906,12 @@ import org.apache.spark.mllib.util.MLUtils; import org.apache.spark.rdd.RDD; import org.apache.spark.sql.DataFrame; -DataFrame data = jsql.createDataFrame( +DataFrame data = sqlContext.createDataFrame( MLUtils.loadLibSVMFile(jsc.sc(), "data/mllib/sample_libsvm_data.txt"), LabeledPoint.class); // Prepare training and test data. -DataFrame[] splits = data.randomSplit(new double [] {0.9, 0.1}, 12345); +DataFrame[] splits = data.randomSplit(new double[] {0.9, 0.1}, 12345); DataFrame training = splits[0]; DataFrame test = splits[1]; @@ -997,10 +931,8 @@ ParamMap[] paramGrid = new ParamGridBuilder() TrainValidationSplit trainValidationSplit = new TrainValidationSplit() .setEstimator(lr) .setEvaluator(new RegressionEvaluator()) - .setEstimatorParamMaps(paramGrid); - -// 80% of the data will be used for training and the remaining 20% for validation. -trainValidationSplit.setTrainRatio(0.8); + .setEstimatorParamMaps(paramGrid) + .setTrainRatio(0.8); // 80% for training and the remaining 20% for validation // Run train validation split, and choose the best set of parameters. TrainValidationSplitModel model = trainValidationSplit.fit(training); From 1bfd9347822df65e76201c4c471a26488d722319 Mon Sep 17 00:00:00 2001 From: ihainan Date: Sun, 30 Aug 2015 08:26:14 +0100 Subject: [PATCH 0484/1824] [SPARK-10184] [CORE] Optimization for bounds determination in RangePartitioner JIRA Issue: https://issues.apache.org/jira/browse/SPARK-10184 Change `cumWeight > target` to `cumWeight >= target` in `RangePartitioner.determineBounds` method to make the output partitions more balanced. 
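A toy illustration of why the comparison matters (a simplified sketch of the bounds loop, not the actual `determineBounds` code, which also advances the target for each successive bound and skips duplicate keys): with four equally weighted keys and two partitions the target weight is 2.0, and the strict comparison only cuts after the third key.

{% highlight scala %}
val ordered = Seq((1, 1.0), (2, 1.0), (3, 1.0), (4, 1.0))  // ordered (key, weight) candidates
val target = ordered.map(_._2).sum / 2                     // 2 partitions => target weight 2.0

def firstBound(reached: (Double, Double) => Boolean): Int = {
  var cumWeight = 0.0
  var bound = -1
  for ((key, weight) <- ordered if bound == -1) {
    cumWeight += weight
    if (reached(cumWeight, target)) bound = key
  }
  bound
}

firstBound(_ > _)   // 3: the keys split 3 / 1 with the old strict comparison
firstBound(_ >= _)  // 2: the keys split 2 / 2 with the new comparison
{% endhighlight %}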
Author: ihainan Closes #8397 from ihainan/opt_for_rangepartitioner. --- core/src/main/scala/org/apache/spark/Partitioner.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/Partitioner.scala b/core/src/main/scala/org/apache/spark/Partitioner.scala index 4b9d59975bdc..29e581bb57cb 100644 --- a/core/src/main/scala/org/apache/spark/Partitioner.scala +++ b/core/src/main/scala/org/apache/spark/Partitioner.scala @@ -291,7 +291,7 @@ private[spark] object RangePartitioner { while ((i < numCandidates) && (j < partitions - 1)) { val (key, weight) = ordered(i) cumWeight += weight - if (cumWeight > target) { + if (cumWeight >= target) { // Skip duplicate values. if (previousBound.isEmpty || ordering.gt(key, previousBound.get)) { bounds += key From 8d2ab75d3b71b632f2394f2453af32f417cb45e5 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Sun, 30 Aug 2015 12:21:15 -0700 Subject: [PATCH 0485/1824] [SPARK-10353] [MLLIB] BLAS gemm not scaling when beta = 0.0 for some subset of matrix multiplications mengxr jkbradley rxin It would be great if this fix made it into RC3! Author: Burak Yavuz Closes #8525 from brkyvz/blas-scaling. --- .../org/apache/spark/mllib/linalg/BLAS.scala | 26 +++++++------------ .../apache/spark/mllib/linalg/BLASSuite.scala | 5 ++++ 2 files changed, 15 insertions(+), 16 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala index bbbcc8436b7c..ab475af264dd 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala @@ -305,6 +305,8 @@ private[spark] object BLAS extends Serializable with Logging { "The matrix C cannot be the product of a transpose() call. C.isTransposed must be false.") if (alpha == 0.0 && beta == 1.0) { logDebug("gemm: alpha is equal to 0 and beta is equal to 1. Returning C.") + } else if (alpha == 0.0) { + f2jBLAS.dscal(C.values.length, beta, C.values, 1) } else { A match { case sparse: SparseMatrix => gemm(alpha, sparse, B, beta, C) @@ -408,8 +410,8 @@ private[spark] object BLAS extends Serializable with Logging { } } } else { - // Scale matrix first if `beta` is not equal to 0.0 - if (beta != 0.0) { + // Scale matrix first if `beta` is not equal to 1.0 + if (beta != 1.0) { f2jBLAS.dscal(C.values.length, beta, C.values, 1) } // Perform matrix multiplication and add to C. The rows of A are multiplied by the columns of @@ -470,8 +472,10 @@ private[spark] object BLAS extends Serializable with Logging { s"The columns of A don't match the number of elements of x. A: ${A.numCols}, x: ${x.size}") require(A.numRows == y.size, s"The rows of A don't match the number of elements of y. A: ${A.numRows}, y:${y.size}") - if (alpha == 0.0) { - logDebug("gemv: alpha is equal to 0. Returning y.") + if (alpha == 0.0 && beta == 1.0) { + logDebug("gemv: alpha is equal to 0 and beta is equal to 1. 
Returning y.") + } else if (alpha == 0.0) { + scal(beta, y) } else { (A, x) match { case (smA: SparseMatrix, dvx: DenseVector) => @@ -526,11 +530,6 @@ private[spark] object BLAS extends Serializable with Logging { val xValues = x.values val yValues = y.values - if (alpha == 0.0) { - scal(beta, y) - return - } - if (A.isTransposed) { var rowCounterForA = 0 while (rowCounterForA < mA) { @@ -581,11 +580,6 @@ private[spark] object BLAS extends Serializable with Logging { val Arows = if (!A.isTransposed) A.rowIndices else A.colPtrs val Acols = if (!A.isTransposed) A.colPtrs else A.rowIndices - if (alpha == 0.0) { - scal(beta, y) - return - } - if (A.isTransposed) { var rowCounter = 0 while (rowCounter < mA) { @@ -604,7 +598,7 @@ private[spark] object BLAS extends Serializable with Logging { rowCounter += 1 } } else { - scal(beta, y) + if (beta != 1.0) scal(beta, y) var colCounterForA = 0 var k = 0 @@ -659,7 +653,7 @@ private[spark] object BLAS extends Serializable with Logging { rowCounter += 1 } } else { - scal(beta, y) + if (beta != 1.0) scal(beta, y) // Perform matrix-vector multiplication and add to y var colCounterForA = 0 while (colCounterForA < nA) { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala index d119e0b50a39..8db5c8424abe 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala @@ -204,6 +204,7 @@ class BLASSuite extends SparkFunSuite { val C14 = C1.copy val C15 = C1.copy val C16 = C1.copy + val C17 = C1.copy val expected2 = new DenseMatrix(4, 2, Array(2.0, 1.0, 4.0, 2.0, 4.0, 0.0, 4.0, 3.0)) val expected3 = new DenseMatrix(4, 2, Array(2.0, 2.0, 4.0, 2.0, 8.0, 0.0, 6.0, 6.0)) val expected4 = new DenseMatrix(4, 2, Array(5.0, 0.0, 10.0, 5.0, 0.0, 0.0, 5.0, 0.0)) @@ -217,6 +218,10 @@ class BLASSuite extends SparkFunSuite { assert(C2 ~== expected2 absTol 1e-15) assert(C3 ~== expected3 absTol 1e-15) assert(C4 ~== expected3 absTol 1e-15) + gemm(1.0, dA, B, 0.0, C17) + assert(C17 ~== expected absTol 1e-15) + gemm(1.0, sA, B, 0.0, C17) + assert(C17 ~== expected absTol 1e-15) withClue("columns of A don't match the rows of B") { intercept[Exception] { From 35e896a79bb5e72d63b82b047f46f4f6fa2e1970 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Sun, 30 Aug 2015 21:39:16 -0700 Subject: [PATCH 0486/1824] SPARK-9545, SPARK-9547: Use Maven in PRB if title contains "[test-maven]" This is just some small glue code to actually make use of the AMPLAB_JENKINS_BUILD_TOOL switch. As far as I can tell, we actually don't currently use the Maven support in the tool even though it exists. This patch switches to Maven when the PR title contains "test-maven". There are a few small other pieces of cleanup in the patch as well. Author: Patrick Wendell Closes #7878 from pwendell/maven-tests. 
--- dev/run-tests-jenkins | 18 ++++++++++++++++-- dev/run-tests.py | 28 ++++++++++++++++++++++++++-- 2 files changed, 42 insertions(+), 4 deletions(-) diff --git a/dev/run-tests-jenkins b/dev/run-tests-jenkins index 39cf54f78104..3be78575e70f 100755 --- a/dev/run-tests-jenkins +++ b/dev/run-tests-jenkins @@ -164,8 +164,9 @@ pr_message="" current_pr_head="`git rev-parse HEAD`" echo "HEAD: `git rev-parse HEAD`" -echo "GHPRB: $ghprbActualCommit" -echo "SHA1: $sha1" +echo "\$ghprbActualCommit: $ghprbActualCommit" +echo "\$sha1: $sha1" +echo "\$ghprbPullTitle: $ghprbPullTitle" # Run pull request tests for t in "${PR_TESTS[@]}"; do @@ -189,6 +190,19 @@ done { # Marks this build is a pull request build. export AMP_JENKINS_PRB=true + if [[ $ghprbPullTitle == *"test-maven"* ]]; then + export AMPLAB_JENKINS_BUILD_TOOL="maven" + fi + if [[ $ghprbPullTitle == *"test-hadoop1.0"* ]]; then + export AMPLAB_JENKINS_BUILD_PROFILE="hadoop1.0" + elif [[ $ghprbPullTitle == *"test-hadoop2.0"* ]]; then + export AMPLAB_JENKINS_BUILD_PROFILE="hadoop2.0" + elif [[ $ghprbPullTitle == *"test-hadoop2.2"* ]]; then + export AMPLAB_JENKINS_BUILD_PROFILE="hadoop2.2" + elif [[ $ghprbPullTitle == *"test-hadoop2.3"* ]]; then + export AMPLAB_JENKINS_BUILD_PROFILE="hadoop2.3" + fi + timeout "${TESTS_TIMEOUT}" ./dev/run-tests test_result="$?" diff --git a/dev/run-tests.py b/dev/run-tests.py index 4fd703a7c219..d8b22e1665e7 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -21,6 +21,7 @@ import itertools from optparse import OptionParser import os +import random import re import sys import subprocess @@ -239,11 +240,32 @@ def build_spark_documentation(): os.chdir(SPARK_HOME) +def get_zinc_port(): + """ + Get a randomized port on which to start Zinc + """ + return random.randrange(3030, 4030) + + +def kill_zinc_on_port(zinc_port): + """ + Kill the Zinc process running on the given port, if one exists. + """ + cmd = ("/usr/sbin/lsof -P |grep %s | grep LISTEN " + "| awk '{ print $2; }' | xargs kill") % zinc_port + subprocess.check_call(cmd, shell=True) + + def exec_maven(mvn_args=()): """Will call Maven in the current directory with the list of mvn_args passed in and returns the subprocess for any further processing""" - run_cmd([os.path.join(SPARK_HOME, "build", "mvn")] + mvn_args) + zinc_port = get_zinc_port() + os.environ["ZINC_PORT"] = "%s" % zinc_port + zinc_flag = "-DzincPort=%s" % zinc_port + flags = [os.path.join(SPARK_HOME, "build", "mvn"), "--force", zinc_flag] + run_cmd(flags + mvn_args) + kill_zinc_on_port(zinc_port) def exec_sbt(sbt_args=()): @@ -514,7 +536,9 @@ def main(): build_apache_spark(build_tool, hadoop_version) # backwards compatibility checks - detect_binary_inop_with_mima() + if build_tool == "sbt": + # Note: compatiblity tests only supported in sbt for now + detect_binary_inop_with_mima() # run the test suites run_scala_tests(build_tool, hadoop_version, test_modules) From 8694c3ad7dcafca9563649e93b7a08076748d6f2 Mon Sep 17 00:00:00 2001 From: Feynman Liang Date: Sun, 30 Aug 2015 23:12:56 -0700 Subject: [PATCH 0487/1824] [SPARK-10351] [SQL] Fixes UTF8String.fromAddress to handle off-heap memory CC rxin marmbrus Author: Feynman Liang Closes #8523 from feynmanliang/SPARK-10351. 
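Context for the one-line fix below (a hedged sketch, not code from the patch): an off-heap `UnsafeRow` has no backing Java object, so `UTF8String.fromAddress` is called with a `null` base and an absolute address, and the old `base != null` guard silently turned such valid off-heap strings into `null`.

{% highlight scala %}
import org.apache.spark.unsafe.Platform
import org.apache.spark.unsafe.types.UTF8String

val bytes = "hello".getBytes("UTF-8")
// On-heap usage: the byte array is the base object and the offset is relative to it.
val onHeap = UTF8String.fromAddress(bytes, Platform.BYTE_ARRAY_OFFSET, bytes.length)
// Off-heap usage: no base object at all, just an absolute address from the unsafe allocator,
// so base is null and must not be treated as "no data":
// val offHeap = UTF8String.fromAddress(null, absoluteAddress, numBytes)
{% endhighlight %}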
--- .../test/scala/org/apache/spark/sql/UnsafeRowSuite.scala | 9 +++++---- .../java/org/apache/spark/unsafe/types/UTF8String.java | 6 +----- 2 files changed, 6 insertions(+), 9 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala index 219435dff5bc..2476b10e3cf9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala @@ -43,12 +43,12 @@ class UnsafeRowSuite extends SparkFunSuite { val arrayBackedUnsafeRow: UnsafeRow = UnsafeProjection.create(Array[DataType](StringType, StringType, IntegerType)).apply(row) assert(arrayBackedUnsafeRow.getBaseObject.isInstanceOf[Array[Byte]]) - val bytesFromArrayBackedRow: Array[Byte] = { + val (bytesFromArrayBackedRow, field0StringFromArrayBackedRow): (Array[Byte], String) = { val baos = new ByteArrayOutputStream() arrayBackedUnsafeRow.writeToStream(baos, null) - baos.toByteArray + (baos.toByteArray, arrayBackedUnsafeRow.getString(0)) } - val bytesFromOffheapRow: Array[Byte] = { + val (bytesFromOffheapRow, field0StringFromOffheapRow): (Array[Byte], String) = { val offheapRowPage = MemoryAllocator.UNSAFE.allocate(arrayBackedUnsafeRow.getSizeInBytes) try { Platform.copyMemory( @@ -69,13 +69,14 @@ class UnsafeRowSuite extends SparkFunSuite { val baos = new ByteArrayOutputStream() val writeBuffer = new Array[Byte](1024) offheapUnsafeRow.writeToStream(baos, writeBuffer) - baos.toByteArray + (baos.toByteArray, offheapUnsafeRow.getString(0)) } finally { MemoryAllocator.UNSAFE.free(offheapRowPage) } } assert(bytesFromArrayBackedRow === bytesFromOffheapRow) + assert(field0StringFromArrayBackedRow === field0StringFromOffheapRow) } test("calling getDouble() and getFloat() on null columns") { diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index cbcab958c05a..216aeea60d1c 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -90,11 +90,7 @@ public static UTF8String fromBytes(byte[] bytes, int offset, int numBytes) { * Creates an UTF8String from given address (base and offset) and length. */ public static UTF8String fromAddress(Object base, long offset, int numBytes) { - if (base != null) { - return new UTF8String(base, offset, numBytes); - } else { - return null; - } + return new UTF8String(base, offset, numBytes); } /** From f0f563a3c43fc9683e6920890cce44611c0c5f4b Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Sun, 30 Aug 2015 23:20:03 -0700 Subject: [PATCH 0488/1824] [SPARK-100354] [MLLIB] fix some apparent memory issues in k-means|| initializaiton * do not cache first cost RDD * change following cost RDD cache level to MEMORY_AND_DISK * remove Vector wrapper to save a object per instance Further improvements will be addressed in SPARK-10329 cc: yu-iskw HuJiayin Author: Xiangrui Meng Closes #8526 from mengxr/SPARK-10354. 
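The storage-level part of the change below in isolation: `cache()` is shorthand for `persist(StorageLevel.MEMORY_ONLY)`, under which partitions that do not fit in memory are dropped and recomputed on access, whereas `MEMORY_AND_DISK` spills them to disk. A minimal, self-contained sketch (not code from the patch):

{% highlight scala %}
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.storage.StorageLevel

val sc = new SparkContext(new SparkConf().setAppName("StorageLevels").setMaster("local[2]"))
val costs = sc.parallelize(1 to 1000).map(_ => Array.fill(4)(Double.PositiveInfinity))

// costs.cache() would be persist(StorageLevel.MEMORY_ONLY): dropped partitions are recomputed.
// The patch keeps the per-iteration cost RDD on MEMORY_AND_DISK instead, spilling to disk.
costs.persist(StorageLevel.MEMORY_AND_DISK)
costs.count()
costs.unpersist()
sc.stop()
{% endhighlight %}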
--- .../spark/mllib/clustering/KMeans.scala | 21 ++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala index 46920fffe6e1..7168aac32c99 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala @@ -369,7 +369,7 @@ class KMeans private ( : Array[Array[VectorWithNorm]] = { // Initialize empty centers and point costs. val centers = Array.tabulate(runs)(r => ArrayBuffer.empty[VectorWithNorm]) - var costs = data.map(_ => Vectors.dense(Array.fill(runs)(Double.PositiveInfinity))).cache() + var costs = data.map(_ => Array.fill(runs)(Double.PositiveInfinity)) // Initialize each run's first center to a random point. val seed = new XORShiftRandom(this.seed).nextInt() @@ -394,21 +394,28 @@ class KMeans private ( val bcNewCenters = data.context.broadcast(newCenters) val preCosts = costs costs = data.zip(preCosts).map { case (point, cost) => - Vectors.dense( Array.tabulate(runs) { r => math.min(KMeans.pointCost(bcNewCenters.value(r), point), cost(r)) - }) - }.cache() + } + }.persist(StorageLevel.MEMORY_AND_DISK) val sumCosts = costs - .aggregate(Vectors.zeros(runs))( + .aggregate(new Array[Double](runs))( seqOp = (s, v) => { // s += v - axpy(1.0, v, s) + var r = 0 + while (r < runs) { + s(r) += v(r) + r += 1 + } s }, combOp = (s0, s1) => { // s0 += s1 - axpy(1.0, s1, s0) + var r = 0 + while (r < runs) { + s0(r) += s1(r) + r += 1 + } s0 } ) From 72f6dbf7b0c8b271f5f9c762374422c69c8ab43d Mon Sep 17 00:00:00 2001 From: EugenCepoi Date: Mon, 31 Aug 2015 13:24:35 -0500 Subject: [PATCH 0489/1824] [SPARK-8730] Fixes - Deser objects containing a primitive class attribute Author: EugenCepoi Closes #7122 from EugenCepoi/master. 
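Why a fallback mapping is needed at all (a small illustration, not part of the patch): when a serialized object has a field holding a primitive `Class` value such as `classOf[Int]`, the stream descriptor names it "int", and `Class.forName` cannot resolve primitive type names, so the overridden `resolveClass` failed with `ClassNotFoundException`.

{% highlight scala %}
import scala.util.Try

// Class.forName does not handle primitive type names, which is what broke deserialization:
Try(Class.forName("int")).isFailure                     // true (ClassNotFoundException)

// The fix answers such names from an explicit table instead, e.g. "int" -> classOf[Int]:
val primitiveMappings = Map[String, Class[_]]("int" -> classOf[Int], "boolean" -> classOf[Boolean])
primitiveMappings("int") == java.lang.Integer.TYPE      // true: classOf[Int] is the primitive class
{% endhighlight %}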
--- .../spark/serializer/JavaSerializer.scala | 27 +++++++++++++++---- .../serializer/JavaSerializerSuite.scala | 18 +++++++++++++ 2 files changed, 40 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala index 4a5274b46b7a..b463a71d5bd7 100644 --- a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala @@ -62,17 +62,34 @@ private[spark] class JavaDeserializationStream(in: InputStream, loader: ClassLoa extends DeserializationStream { private val objIn = new ObjectInputStream(in) { - override def resolveClass(desc: ObjectStreamClass): Class[_] = { - // scalastyle:off classforname - Class.forName(desc.getName, false, loader) - // scalastyle:on classforname - } + override def resolveClass(desc: ObjectStreamClass): Class[_] = + try { + // scalastyle:off classforname + Class.forName(desc.getName, false, loader) + // scalastyle:on classforname + } catch { + case e: ClassNotFoundException => + JavaDeserializationStream.primitiveMappings.get(desc.getName).getOrElse(throw e) + } } def readObject[T: ClassTag](): T = objIn.readObject().asInstanceOf[T] def close() { objIn.close() } } +private object JavaDeserializationStream { + val primitiveMappings = Map[String, Class[_]]( + "boolean" -> classOf[Boolean], + "byte" -> classOf[Byte], + "char" -> classOf[Char], + "short" -> classOf[Short], + "int" -> classOf[Int], + "long" -> classOf[Long], + "float" -> classOf[Float], + "double" -> classOf[Double], + "void" -> classOf[Void] + ) +} private[spark] class JavaSerializerInstance( counterReset: Int, extraDebugInfo: Boolean, defaultClassLoader: ClassLoader) diff --git a/core/src/test/scala/org/apache/spark/serializer/JavaSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/JavaSerializerSuite.scala index 329a2b6dad83..20f45670bc2b 100644 --- a/core/src/test/scala/org/apache/spark/serializer/JavaSerializerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/JavaSerializerSuite.scala @@ -25,4 +25,22 @@ class JavaSerializerSuite extends SparkFunSuite { val instance = serializer.newInstance() instance.deserialize[JavaSerializer](instance.serialize(serializer)) } + + test("Deserialize object containing a primitive Class as attribute") { + val serializer = new JavaSerializer(new SparkConf()) + val instance = serializer.newInstance() + instance.deserialize[JavaSerializer](instance.serialize(new ContainsPrimitiveClass())) + } +} + +private class ContainsPrimitiveClass extends Serializable { + val intClass = classOf[Int] + val longClass = classOf[Long] + val shortClass = classOf[Short] + val charClass = classOf[Char] + val doubleClass = classOf[Double] + val floatClass = classOf[Float] + val booleanClass = classOf[Boolean] + val byteClass = classOf[Byte] + val voidClass = classOf[Void] } From 4a5fe091658b1d06f427e404a11a84fc84f953c5 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Mon, 31 Aug 2015 12:19:11 -0700 Subject: [PATCH 0490/1824] [SPARK-10369] [STREAMING] Don't remove ReceiverTrackingInfo when deregisterReceivering since we may reuse it later `deregisterReceiver` should not remove `ReceiverTrackingInfo`. Otherwise, it will throw `java.util.NoSuchElementException: key not found` when restarting it. Author: zsxwing Closes #8538 from zsxwing/SPARK-10369. 
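The shape of the fix reduced to a toy sketch (not the tracker code, which uses `ReceiverTrackingInfo` and a proper state enum): keep the entry and mark it inactive instead of removing it, so a later restart can still look the stream id up, and report only non-inactive ids as active receivers.

{% highlight scala %}
import scala.collection.mutable

val receiverTrackingInfos = mutable.HashMap(0 -> "ACTIVE")

// Old behavior: receiverTrackingInfos -= 0
//   => a later lookup during restart fails with NoSuchElementException: key not found.
// New behavior: keep the entry and just flip its state.
receiverTrackingInfos(0) = "INACTIVE"

// Mirrors the AllReceiverIds change: only receivers that are not INACTIVE are reported.
val activeIds = receiverTrackingInfos.filter(_._2 != "INACTIVE").keys.toSeq  // Seq()
{% endhighlight %}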
--- .../streaming/scheduler/ReceiverTracker.scala | 4 +- .../scheduler/ReceiverTrackerSuite.scala | 51 +++++++++++++++++++ 2 files changed, 53 insertions(+), 2 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala index 3d532a675db0..f86fd44b4871 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala @@ -291,7 +291,7 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false ReceiverTrackingInfo( streamId, ReceiverState.INACTIVE, None, None, None, None, Some(errorInfo)) } - receiverTrackingInfos -= streamId + receiverTrackingInfos(streamId) = newReceiverTrackingInfo listenerBus.post(StreamingListenerReceiverStopped(newReceiverTrackingInfo.toReceiverInfo)) val messageWithError = if (error != null && !error.isEmpty) { s"$message - $error" @@ -483,7 +483,7 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false context.reply(true) // Local messages case AllReceiverIds => - context.reply(receiverTrackingInfos.keys.toSeq) + context.reply(receiverTrackingInfos.filter(_._2.state != ReceiverState.INACTIVE).keys.toSeq) case StopAllReceivers => assert(isTrackerStopping || isTrackerStopped) stopReceivers() diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala index dd292ba4dd94..45138b748eca 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala @@ -60,6 +60,26 @@ class ReceiverTrackerSuite extends TestSuiteBase { } } } + + test("should restart receiver after stopping it") { + withStreamingContext(new StreamingContext(conf, Milliseconds(100))) { ssc => + @volatile var startTimes = 0 + ssc.addStreamingListener(new StreamingListener { + override def onReceiverStarted(receiverStarted: StreamingListenerReceiverStarted): Unit = { + startTimes += 1 + } + }) + val input = ssc.receiverStream(new StoppableReceiver) + val output = new TestOutputStream(input) + output.register() + ssc.start() + StoppableReceiver.shouldStop = true + eventually(timeout(10 seconds), interval(10 millis)) { + // The receiver is stopped once, so if it's restarted, it should be started twice. 
+ assert(startTimes === 2) + } + } + } } /** An input DStream with for testing rate controlling */ @@ -132,3 +152,34 @@ private[streaming] object RateTestReceiver { def getActive(): Option[RateTestReceiver] = Option(activeReceiver) } + +/** + * A custom receiver that could be stopped via StoppableReceiver.shouldStop + */ +class StoppableReceiver extends Receiver[Int](StorageLevel.MEMORY_ONLY) { + + var receivingThreadOption: Option[Thread] = None + + def onStart() { + val thread = new Thread() { + override def run() { + while (!StoppableReceiver.shouldStop) { + Thread.sleep(10) + } + StoppableReceiver.this.stop("stop") + } + } + thread.start() + } + + def onStop() { + StoppableReceiver.shouldStop = true + receivingThreadOption.foreach(_.join()) + // Reset it so as to restart it + StoppableReceiver.shouldStop = false + } +} + +object StoppableReceiver { + @volatile var shouldStop = false +} From a2d5c72091b1c602694dbca823a7b26f86b02864 Mon Sep 17 00:00:00 2001 From: sureshthalamati Date: Mon, 31 Aug 2015 12:39:58 -0700 Subject: [PATCH 0491/1824] [SPARK-10170] [SQL] Add DB2 JDBC dialect support. Data frame write to DB2 database is failing because by default JDBC data source implementation is generating a table schema with DB2 unsupported data types TEXT for String, and BIT1(1) for Boolean. This patch registers DB2 JDBC Dialect that maps String, Boolean to valid DB2 data types. Author: sureshthalamati Closes #8393 from sureshthalamati/db2_dialect_spark-10170. --- .../apache/spark/sql/jdbc/JdbcDialects.scala | 18 ++++++++++++++++++ .../org/apache/spark/sql/jdbc/JDBCSuite.scala | 7 +++++++ 2 files changed, 25 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index 8849fc2f1f0e..c6d05c9b83b9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -125,6 +125,7 @@ object JdbcDialects { registerDialect(MySQLDialect) registerDialect(PostgresDialect) + registerDialect(DB2Dialect) /** * Fetch the JdbcDialect class corresponding to a given database url. @@ -222,3 +223,20 @@ case object MySQLDialect extends JdbcDialect { s"`$colName`" } } + +/** + * :: DeveloperApi :: + * Default DB2 dialect, mapping string/boolean on write to valid DB2 types. + * By default string, and boolean gets mapped to db2 invalid types TEXT, and BIT(1). 
+ */ +@DeveloperApi +case object DB2Dialect extends JdbcDialect { + + override def canHandle(url: String): Boolean = url.startsWith("jdbc:db2") + + override def getJDBCType(dt: DataType): Option[JdbcType] = dt match { + case StringType => Some(JdbcType("CLOB", java.sql.Types.CLOB)) + case BooleanType => Some(JdbcType("CHAR(1)", java.sql.Types.CHAR)) + case _ => None + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 0edac0848c3b..d8c9a08d84c6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -407,6 +407,7 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext test("Default jdbc dialect registration") { assert(JdbcDialects.get("jdbc:mysql://127.0.0.1/db") == MySQLDialect) assert(JdbcDialects.get("jdbc:postgresql://127.0.0.1/db") == PostgresDialect) + assert(JdbcDialects.get("jdbc:db2://127.0.0.1/db") == DB2Dialect) assert(JdbcDialects.get("test.invalid") == NoopDialect) } @@ -443,4 +444,10 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext assert(agg.getCatalystType(0, "", 1, null) === Some(LongType)) assert(agg.getCatalystType(1, "", 1, null) === Some(StringType)) } + + test("DB2Dialect type mapping") { + val db2Dialect = JdbcDialects.get("jdbc:db2://127.0.0.1/db") + assert(db2Dialect.getJDBCType(StringType).map(_.databaseTypeDefinition).get == "CLOB") + assert(db2Dialect.getJDBCType(BooleanType).map(_.databaseTypeDefinition).get == "CHAR(1)") + } } From 23e39cc7b1bb7f1087c4706234c9b5165a571357 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Mon, 31 Aug 2015 15:49:25 -0700 Subject: [PATCH 0492/1824] [SPARK-9954] [MLLIB] use first 128 nonzeros to compute Vector.hashCode This could help reduce hash collisions, e.g., in `RDD[Vector].repartition`. jkbradley Author: Xiangrui Meng Closes #8182 from mengxr/SPARK-9954. --- .../apache/spark/mllib/linalg/Vectors.scala | 38 ++++++++++--------- 1 file changed, 21 insertions(+), 17 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index 06ebb1586990..3642e9286504 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -71,20 +71,22 @@ sealed trait Vector extends Serializable { } /** - * Returns a hash code value for the vector. The hash code is based on its size and its nonzeros - * in the first 16 entries, using a hash algorithm similar to [[java.util.Arrays.hashCode]]. + * Returns a hash code value for the vector. The hash code is based on its size and its first 128 + * nonzero entries, using a hash algorithm similar to [[java.util.Arrays.hashCode]]. */ override def hashCode(): Int = { // This is a reference implementation. It calls return in foreachActive, which is slow. // Subclasses should override it with optimized implementation. 
var result: Int = 31 + size + var nnz = 0 this.foreachActive { (index, value) => - if (index < 16) { + if (nnz < Vectors.MAX_HASH_NNZ) { // ignore explicit 0 for comparison between sparse and dense if (value != 0) { result = 31 * result + index val bits = java.lang.Double.doubleToLongBits(value) result = 31 * result + (bits ^ (bits >>> 32)).toInt + nnz += 1 } } else { return result @@ -536,6 +538,9 @@ object Vectors { } allEqual } + + /** Max number of nonzero entries used in computing hash code. */ + private[linalg] val MAX_HASH_NNZ = 128 } /** @@ -578,13 +583,15 @@ class DenseVector @Since("1.0.0") ( override def hashCode(): Int = { var result: Int = 31 + size var i = 0 - val end = math.min(values.length, 16) - while (i < end) { + val end = values.length + var nnz = 0 + while (i < end && nnz < Vectors.MAX_HASH_NNZ) { val v = values(i) if (v != 0.0) { result = 31 * result + i val bits = java.lang.Double.doubleToLongBits(values(i)) result = 31 * result + (bits ^ (bits >>> 32)).toInt + nnz += 1 } i += 1 } @@ -707,19 +714,16 @@ class SparseVector @Since("1.0.0") ( override def hashCode(): Int = { var result: Int = 31 + size val end = values.length - var continue = true var k = 0 - while ((k < end) & continue) { - val i = indices(k) - if (i < 16) { - val v = values(k) - if (v != 0.0) { - result = 31 * result + i - val bits = java.lang.Double.doubleToLongBits(v) - result = 31 * result + (bits ^ (bits >>> 32)).toInt - } - } else { - continue = false + var nnz = 0 + while (k < end && nnz < Vectors.MAX_HASH_NNZ) { + val v = values(k) + if (v != 0.0) { + val i = indices(k) + result = 31 * result + i + val bits = java.lang.Double.doubleToLongBits(v) + result = 31 * result + (bits ^ (bits >>> 32)).toInt + nnz += 1 } k += 1 } From 5b3245d6dff65972fc39c73f90d5cbdf84d19129 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Mon, 31 Aug 2015 15:50:41 -0700 Subject: [PATCH 0493/1824] [SPARK-8472] [ML] [PySpark] Python API for DCT Add Python API for ml.feature.DCT. Author: Yanbo Liang Closes #8485 from yanboliang/spark-8472. --- python/pyspark/ml/feature.py | 65 +++++++++++++++++++++++++++++++++++- 1 file changed, 64 insertions(+), 1 deletion(-) diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 04b2b2ccc9e5..59300a607815 100644 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -26,7 +26,7 @@ from pyspark.mllib.common import inherit_doc from pyspark.mllib.linalg import _convert_to_vector -__all__ = ['Binarizer', 'Bucketizer', 'ElementwiseProduct', 'HashingTF', 'IDF', 'IDFModel', +__all__ = ['Binarizer', 'Bucketizer', 'DCT', 'ElementwiseProduct', 'HashingTF', 'IDF', 'IDFModel', 'NGram', 'Normalizer', 'OneHotEncoder', 'PolynomialExpansion', 'RegexTokenizer', 'StandardScaler', 'StandardScalerModel', 'StringIndexer', 'StringIndexerModel', 'Tokenizer', 'VectorAssembler', 'VectorIndexer', 'Word2Vec', 'Word2VecModel', @@ -166,6 +166,69 @@ def getSplits(self): return self.getOrDefault(self.splits) +@inherit_doc +class DCT(JavaTransformer, HasInputCol, HasOutputCol): + """ + A feature transformer that takes the 1D discrete cosine transform + of a real vector. No zero padding is performed on the input vector. + It returns a real vector of the same length representing the DCT. + The return vector is scaled such that the transform matrix is + unitary (aka scaled DCT-II). + + More information on + `https://en.wikipedia.org/wiki/Discrete_cosine_transform#DCT-II Wikipedia`. 
+ + >>> from pyspark.mllib.linalg import Vectors + >>> df1 = sqlContext.createDataFrame([(Vectors.dense([5.0, 8.0, 6.0]),)], ["vec"]) + >>> dct = DCT(inverse=False, inputCol="vec", outputCol="resultVec") + >>> df2 = dct.transform(df1) + >>> df2.head().resultVec + DenseVector([10.969..., -0.707..., -2.041...]) + >>> df3 = DCT(inverse=True, inputCol="resultVec", outputCol="origVec").transform(df2) + >>> df3.head().origVec + DenseVector([5.0, 8.0, 6.0]) + """ + + # a placeholder to make it appear in the generated doc + inverse = Param(Params._dummy(), "inverse", "Set transformer to perform inverse DCT, " + + "default False.") + + @keyword_only + def __init__(self, inverse=False, inputCol=None, outputCol=None): + """ + __init__(self, inverse=False, inputCol=None, outputCol=None) + """ + super(DCT, self).__init__() + self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.DCT", self.uid) + self.inverse = Param(self, "inverse", "Set transformer to perform inverse DCT, " + + "default False.") + self._setDefault(inverse=False) + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + + @keyword_only + def setParams(self, inverse=False, inputCol=None, outputCol=None): + """ + setParams(self, inverse=False, inputCol=None, outputCol=None) + Sets params for this DCT. + """ + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) + + def setInverse(self, value): + """ + Sets the value of :py:attr:`inverse`. + """ + self._paramMap[self.inverse] = value + return self + + def getInverse(self): + """ + Gets the value of inverse or its default value. + """ + return self.getOrDefault(self.inverse) + + @inherit_doc class ElementwiseProduct(JavaTransformer, HasInputCol, HasOutputCol): """ From 540bdee93103a73736d282b95db6a8cda8f6a2b1 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 31 Aug 2015 15:55:22 -0700 Subject: [PATCH 0494/1824] [SPARK-10341] [SQL] fix memory starving in unsafe SMJ In SMJ, the first ExternalSorter could consume all the memory before spilling, then the second can not even acquire the first page. Before we have a better memory allocator, SMJ should call prepare() before call any compute() of it's children. cc rxin JoshRosen Author: Davies Liu Closes #8511 from davies/smj_memory. --- .../rdd/MapPartitionsWithPreparationRDD.scala | 21 +++++++++++++++++-- .../spark/rdd/ZippedPartitionsRDD.scala | 13 ++++++++++++ ...MapPartitionsWithPreparationRDDSuite.scala | 14 +++++++++---- 3 files changed, 42 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithPreparationRDD.scala b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithPreparationRDD.scala index b475bd8d79f8..1f2213d0c434 100644 --- a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithPreparationRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithPreparationRDD.scala @@ -17,6 +17,7 @@ package org.apache.spark.rdd +import scala.collection.mutable.ArrayBuffer import scala.reflect.ClassTag import org.apache.spark.{Partition, Partitioner, TaskContext} @@ -38,12 +39,28 @@ private[spark] class MapPartitionsWithPreparationRDD[U: ClassTag, T: ClassTag, M override def getPartitions: Array[Partition] = firstParent[T].partitions + // In certain join operations, prepare can be called on the same partition multiple times. + // In this case, we need to ensure that each call to compute gets a separate prepare argument. 
+ private[this] var preparedArguments: ArrayBuffer[M] = new ArrayBuffer[M] + + /** + * Prepare a partition for a single call to compute. + */ + def prepare(): Unit = { + preparedArguments += preparePartition() + } + /** * Prepare a partition before computing it from its parent. */ override def compute(partition: Partition, context: TaskContext): Iterator[U] = { - val preparedArgument = preparePartition() + val prepared = + if (preparedArguments.isEmpty) { + preparePartition() + } else { + preparedArguments.remove(0) + } val parentIterator = firstParent[T].iterator(partition, context) - executePartition(context, partition.index, preparedArgument, parentIterator) + executePartition(context, partition.index, prepared, parentIterator) } } diff --git a/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala index 81f40ad33aa5..b3c64394abc7 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala @@ -73,6 +73,16 @@ private[spark] abstract class ZippedPartitionsBaseRDD[V: ClassTag]( super.clearDependencies() rdds = null } + + /** + * Call the prepare method of every parent that has one. + * This is needed for reserving execution memory in advance. + */ + protected def tryPrepareParents(): Unit = { + rdds.collect { + case rdd: MapPartitionsWithPreparationRDD[_, _, _] => rdd.prepare() + } + } } private[spark] class ZippedPartitionsRDD2[A: ClassTag, B: ClassTag, V: ClassTag]( @@ -84,6 +94,7 @@ private[spark] class ZippedPartitionsRDD2[A: ClassTag, B: ClassTag, V: ClassTag] extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2), preservesPartitioning) { override def compute(s: Partition, context: TaskContext): Iterator[V] = { + tryPrepareParents() val partitions = s.asInstanceOf[ZippedPartitionsPartition].partitions f(rdd1.iterator(partitions(0), context), rdd2.iterator(partitions(1), context)) } @@ -107,6 +118,7 @@ private[spark] class ZippedPartitionsRDD3 extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2, rdd3), preservesPartitioning) { override def compute(s: Partition, context: TaskContext): Iterator[V] = { + tryPrepareParents() val partitions = s.asInstanceOf[ZippedPartitionsPartition].partitions f(rdd1.iterator(partitions(0), context), rdd2.iterator(partitions(1), context), @@ -134,6 +146,7 @@ private[spark] class ZippedPartitionsRDD4 extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2, rdd3, rdd4), preservesPartitioning) { override def compute(s: Partition, context: TaskContext): Iterator[V] = { + tryPrepareParents() val partitions = s.asInstanceOf[ZippedPartitionsPartition].partitions f(rdd1.iterator(partitions(0), context), rdd2.iterator(partitions(1), context), diff --git a/core/src/test/scala/org/apache/spark/rdd/MapPartitionsWithPreparationRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/MapPartitionsWithPreparationRDDSuite.scala index c16930e7d649..e281e817e493 100644 --- a/core/src/test/scala/org/apache/spark/rdd/MapPartitionsWithPreparationRDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/MapPartitionsWithPreparationRDDSuite.scala @@ -46,11 +46,17 @@ class MapPartitionsWithPreparationRDDSuite extends SparkFunSuite with LocalSpark } // Verify that the numbers are pushed in the order expected - val result = { - new MapPartitionsWithPreparationRDD[Int, Int, Unit]( - parent, preparePartition, executePartition).collect() - } + val rdd = new MapPartitionsWithPreparationRDD[Int, Int, Unit]( 
+ parent, preparePartition, executePartition) + val result = rdd.collect() assert(result === Array(10, 20, 30)) + + TestObject.things.clear() + // Zip two of these RDDs, both should be prepared before the parent is executed + val rdd2 = new MapPartitionsWithPreparationRDD[Int, Int, Unit]( + parent, preparePartition, executePartition) + val result2 = rdd.zipPartitions(rdd2)((a, b) => a).collect() + assert(result2 === Array(10, 10, 20, 30, 20, 30)) } } From fe16fd0b8b717f01151bc659ec3299dab091c97a Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Mon, 31 Aug 2015 16:06:38 -0700 Subject: [PATCH 0495/1824] [SPARK-10349] [ML] OneVsRest use 'when ... otherwise' not UDF to generate new label at binary reduction Currently OneVsRest use UDF to generate new binary label during training. Considering that [SPARK-7321](https://issues.apache.org/jira/browse/SPARK-7321) has been merged, we can use ```when ... otherwise``` which will be more efficiency. Author: Yanbo Liang Closes #8519 from yanboliang/spark-10349. --- .../org/apache/spark/ml/classification/OneVsRest.scala | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala index c62e132f5d53..debc164bf243 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala @@ -91,7 +91,6 @@ final class OneVsRestModel private[ml] ( // add an accumulator column to store predictions of all the models val accColName = "mbc$acc" + UUID.randomUUID().toString val initUDF = udf { () => Map[Int, Double]() } - val mapType = MapType(IntegerType, DoubleType, valueContainsNull = false) val newDataset = dataset.withColumn(accColName, initUDF()) // persist if underlying dataset is not persistent. @@ -195,16 +194,11 @@ final class OneVsRest(override val uid: String) // create k columns, one for each binary classifier. val models = Range(0, numClasses).par.map { index => - val labelUDF = udf { (label: Double) => - if (label.toInt == index) 1.0 else 0.0 - } - // generate new label metadata for the binary problem. - // TODO: use when ... otherwise after SPARK-7321 is merged val newLabelMeta = BinaryAttribute.defaultAttr.withName("label").toMetadata() val labelColName = "mc2b$" + index - val trainingDataset = - multiclassLabeled.withColumn(labelColName, labelUDF(col($(labelCol))), newLabelMeta) + val trainingDataset = multiclassLabeled.withColumn( + labelColName, when(col($(labelCol)) === index.toDouble, 1.0).otherwise(0.0), newLabelMeta) val classifier = getClassifier val paramMap = new ParamMap() paramMap.put(classifier.labelCol -> labelColName) From 52ea399e6ee37b7c44aae7709863e006fca88906 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Mon, 31 Aug 2015 16:11:27 -0700 Subject: [PATCH 0496/1824] [SPARK-10355] [ML] [PySpark] Add Python API for SQLTransformer Add Python API for SQLTransformer Author: Yanbo Liang Closes #8527 from yanboliang/spark-10355. 
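For reference, a rough Scala counterpart of the docstring example added below, assuming the `org.apache.spark.ml.feature.SQLTransformer` class that this Python wrapper delegates to (run e.g. from a spark-shell where `sqlContext` is available):

```scala
import org.apache.spark.ml.feature.SQLTransformer

val df = sqlContext
  .createDataFrame(Seq((0, 1.0, 3.0), (2, 2.0, 5.0)))
  .toDF("id", "v1", "v2")

val sqlTrans = new SQLTransformer()
  .setStatement("SELECT *, (v1 + v2) AS v3, (v1 * v2) AS v4 FROM __THIS__")

// Expected first row: id=0, v1=1.0, v2=3.0, v3=4.0, v4=3.0
sqlTrans.transform(df).show()
```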
--- python/pyspark/ml/feature.py | 57 ++++++++++++++++++++++++++++++++++-- 1 file changed, 54 insertions(+), 3 deletions(-) diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 59300a607815..0626281e200a 100644 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -28,9 +28,9 @@ __all__ = ['Binarizer', 'Bucketizer', 'DCT', 'ElementwiseProduct', 'HashingTF', 'IDF', 'IDFModel', 'NGram', 'Normalizer', 'OneHotEncoder', 'PolynomialExpansion', 'RegexTokenizer', - 'StandardScaler', 'StandardScalerModel', 'StringIndexer', 'StringIndexerModel', - 'Tokenizer', 'VectorAssembler', 'VectorIndexer', 'Word2Vec', 'Word2VecModel', - 'PCA', 'PCAModel', 'RFormula', 'RFormulaModel'] + 'SQLTransformer', 'StandardScaler', 'StandardScalerModel', 'StringIndexer', + 'StringIndexerModel', 'Tokenizer', 'VectorAssembler', 'VectorIndexer', 'Word2Vec', + 'Word2VecModel', 'PCA', 'PCAModel', 'RFormula', 'RFormulaModel'] @inherit_doc @@ -743,6 +743,57 @@ def getPattern(self): return self.getOrDefault(self.pattern) +@inherit_doc +class SQLTransformer(JavaTransformer): + """ + Implements the transforms which are defined by SQL statement. + Currently we only support SQL syntax like 'SELECT ... FROM __THIS__' + where '__THIS__' represents the underlying table of the input dataset. + + >>> df = sqlContext.createDataFrame([(0, 1.0, 3.0), (2, 2.0, 5.0)], ["id", "v1", "v2"]) + >>> sqlTrans = SQLTransformer( + ... statement="SELECT *, (v1 + v2) AS v3, (v1 * v2) AS v4 FROM __THIS__") + >>> sqlTrans.transform(df).head() + Row(id=0, v1=1.0, v2=3.0, v3=4.0, v4=3.0) + """ + + # a placeholder to make it appear in the generated doc + statement = Param(Params._dummy(), "statement", "SQL statement") + + @keyword_only + def __init__(self, statement=None): + """ + __init__(self, statement=None) + """ + super(SQLTransformer, self).__init__() + self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.SQLTransformer", self.uid) + self.statement = Param(self, "statement", "SQL statement") + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + + @keyword_only + def setParams(self, statement=None): + """ + setParams(self, statement=None) + Sets params for this SQLTransformer. + """ + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) + + def setStatement(self, value): + """ + Sets the value of :py:attr:`statement`. + """ + self._paramMap[self.statement] = value + return self + + def getStatement(self): + """ + Gets the value of statement or its default value. + """ + return self.getOrDefault(self.statement) + + @inherit_doc class StandardScaler(JavaEstimator, HasInputCol, HasOutputCol): """ From d65656c455d19b83c6412571873586b458aa355e Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 31 Aug 2015 18:09:24 -0700 Subject: [PATCH 0497/1824] [SPARK-10378][SQL][Test] Remove HashJoinCompatibilitySuite. They don't bring much value since we now have better unit test coverage for hash joins. This will also help reduce the test time. Author: Reynold Xin Closes #8542 from rxin/SPARK-10378. 
--- .../HashJoinCompatibilitySuite.scala | 169 ------------------ 1 file changed, 169 deletions(-) delete mode 100644 sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HashJoinCompatibilitySuite.scala diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HashJoinCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HashJoinCompatibilitySuite.scala deleted file mode 100644 index 1a5ba20404c4..000000000000 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HashJoinCompatibilitySuite.scala +++ /dev/null @@ -1,169 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.hive.execution - -import java.io.File - -import org.apache.spark.sql.SQLConf -import org.apache.spark.sql.hive.test.TestHive - -/** - * Runs the test cases that are included in the hive distribution with hash joins. - */ -class HashJoinCompatibilitySuite extends HiveCompatibilitySuite { - override def beforeAll() { - super.beforeAll() - TestHive.setConf(SQLConf.SORTMERGE_JOIN, false) - } - - override def afterAll() { - TestHive.setConf(SQLConf.SORTMERGE_JOIN, true) - super.afterAll() - } - - override def whiteList = Seq( - "auto_join0", - "auto_join1", - "auto_join10", - "auto_join11", - "auto_join12", - "auto_join13", - "auto_join14", - "auto_join14_hadoop20", - "auto_join15", - "auto_join17", - "auto_join18", - "auto_join19", - "auto_join2", - "auto_join20", - "auto_join21", - "auto_join22", - "auto_join23", - "auto_join24", - "auto_join25", - "auto_join26", - "auto_join27", - "auto_join28", - "auto_join3", - "auto_join30", - "auto_join31", - "auto_join32", - "auto_join4", - "auto_join5", - "auto_join6", - "auto_join7", - "auto_join8", - "auto_join9", - "auto_join_filters", - "auto_join_nulls", - "auto_join_reordering_values", - "auto_smb_mapjoin_14", - "auto_sortmerge_join_1", - "auto_sortmerge_join_10", - "auto_sortmerge_join_11", - "auto_sortmerge_join_12", - "auto_sortmerge_join_13", - "auto_sortmerge_join_14", - "auto_sortmerge_join_15", - "auto_sortmerge_join_16", - "auto_sortmerge_join_2", - "auto_sortmerge_join_3", - "auto_sortmerge_join_4", - "auto_sortmerge_join_5", - "auto_sortmerge_join_6", - "auto_sortmerge_join_7", - "auto_sortmerge_join_8", - "auto_sortmerge_join_9", - "correlationoptimizer1", - "correlationoptimizer10", - "correlationoptimizer11", - "correlationoptimizer13", - "correlationoptimizer14", - "correlationoptimizer15", - "correlationoptimizer2", - "correlationoptimizer3", - "correlationoptimizer4", - "correlationoptimizer6", - "correlationoptimizer7", - "correlationoptimizer8", - "correlationoptimizer9", - "join0", - "join1", - "join10", - "join11", - "join12", - "join13", - "join14", - 
"join14_hadoop20", - "join15", - "join16", - "join17", - "join18", - "join19", - "join2", - "join20", - "join21", - "join22", - "join23", - "join24", - "join25", - "join26", - "join27", - "join28", - "join29", - "join3", - "join30", - "join31", - "join32", - "join32_lessSize", - "join33", - "join34", - "join35", - "join36", - "join37", - "join38", - "join39", - "join4", - "join40", - "join41", - "join5", - "join6", - "join7", - "join8", - "join9", - "join_1to1", - "join_array", - "join_casesensitive", - "join_empty", - "join_filters", - "join_hive_626", - "join_map_ppr", - "join_nulls", - "join_nullsafe", - "join_rc", - "join_reorder2", - "join_reorder3", - "join_reorder4", - "join_star" - ) - - // Only run those query tests in the realWhileList (do not try other ignored query files). - override def testCases: Seq[(String, File)] = super.testCases.filter { - case (name, _) => realWhiteList.contains(name) - } -} From 391e6be0ae883f3ea0fab79463eb8b618af79afb Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Tue, 1 Sep 2015 16:52:59 +0800 Subject: [PATCH 0498/1824] [SPARK-10301] [SQL] Fixes schema merging for nested structs This PR can be quite challenging to review. I'm trying to give a detailed description of the problem as well as its solution here. When reading Parquet files, we need to specify a potentially nested Parquet schema (of type `MessageType`) as requested schema for column pruning. This Parquet schema is translated from a Catalyst schema (of type `StructType`), which is generated by the query planner and represents all requested columns. However, this translation can be fairly complicated because of several reasons: 1. Requested schema must conform to the real schema of the physical file to be read. This means we have to tailor the actual file schema of every individual physical Parquet file to be read according to the given Catalyst schema. Fortunately we are already doing this in Spark 1.5 by pushing request schema conversion to executor side in PR #7231. 1. Support for schema merging. A single Parquet dataset may consist of multiple physical Parquet files come with different but compatible schemas. This means we may request for a column path that doesn't exist in a physical Parquet file. All requested column paths can be nested. For example, for a Parquet file schema ``` message root { required group f0 { required group f00 { required int32 f000; required binary f001 (UTF8); } } } ``` we may request for column paths defined in the following schema: ``` message root { required group f0 { required group f00 { required binary f001 (UTF8); required float f002; } } optional double f1; } ``` Notice that we pruned column path `f0.f00.f000`, but added `f0.f00.f002` and `f1`. The good news is that Parquet handles non-existing column paths properly and always returns null for them. 1. The map from `StructType` to `MessageType` is a one-to-many map. This is the most unfortunate part. Due to historical reasons (dark histories!), schemas of Parquet files generated by different libraries have different "flavors". 
For example, to handle a schema with a single non-nullable column, whose type is an array of non-nullable integers, parquet-protobuf generates the following Parquet schema: ``` message m0 { repeated int32 f; } ``` while parquet-avro generates another version: ``` message m1 { required group f (LIST) { repeated int32 array; } } ``` and parquet-thrift spills this: ``` message m1 { required group f (LIST) { repeated int32 f_tuple; } } ``` All of them can be mapped to the following _unique_ Catalyst schema: ``` StructType( StructField( "f", ArrayType(IntegerType, containsNull = false), nullable = false)) ``` This greatly complicates Parquet requested schema construction, since the path of a given column varies in different cases. To read the array elements from files with the above schemas, we must use `f` for `m0`, `f.array` for `m1`, and `f.f_tuple` for `m2`. In earlier Spark versions, we didn't try to fix this issue properly. Spark 1.4 and prior versions simply translate the Catalyst schema in a way more or less compatible with parquet-hive and parquet-avro, but is broken in many other cases. Earlier revisions of Spark 1.5 only try to tailor the Parquet file schema at the first level, and ignore nested ones. This caused [SPARK-10301] [spark-10301] as well as [SPARK-10005] [spark-10005]. In PR #8228, I tried to avoid the hard part of the problem and made a minimum change in `CatalystRowConverter` to fix SPARK-10005. However, when taking SPARK-10301 into consideration, keeping hacking `CatalystRowConverter` doesn't seem to be a good idea. So this PR is an attempt to fix the problem in a proper way. For a given physical Parquet file with schema `ps` and a compatible Catalyst requested schema `cs`, we use the following algorithm to tailor `ps` to get the result Parquet requested schema `ps'`: For a leaf column path `c` in `cs`: - if `c` exists in `cs` and a corresponding Parquet column path `c'` can be found in `ps`, `c'` should be included in `ps'`; - otherwise, we convert `c` to a Parquet column path `c"` using `CatalystSchemaConverter`, and include `c"` in `ps'`; - no other column paths should exist in `ps'`. Then comes the most tedious part: > Given `cs`, `ps`, and `c`, how to locate `c'` in `ps`? Unfortunately, there's no quick answer, and we have to enumerate all possible structures defined in parquet-format spec. They are: 1. the standard structure of nested types, and 1. cases defined in all backwards-compatibility rules for `LIST` and `MAP`. The core part of this PR is `CatalystReadSupport.clipParquetType()`, which tailors a given Parquet file schema according to a requested schema in its Catalyst form. Backwards-compatibility rules of `LIST` and `MAP` are covered in `clipParquetListType()` and `clipParquetMapType()` respectively. The column path selection algorithm is implemented in `clipParquetGroupFields()`. With this PR, we no longer need to do schema tailoring in `CatalystReadSupport` and `CatalystRowConverter`. Another benefit is that, now we can also read Parquet datasets consist of files with different physical Parquet schema but share the same logical schema, for example, files generated by different Parquet libraries. This situation is illustrated by [this test case] [test-case]. 
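To make the tailoring rule concrete, here is a rough sketch using the example schemas from this description and the new `CatalystReadSupport.clipParquetSchema` introduced below; since that helper is `private[parquet]`, assume the snippet lives in the same package, as the new `ParquetSchemaSuite` tests do:

```scala
import org.apache.parquet.schema.MessageTypeParser
import org.apache.spark.sql.types._

// Physical schema of one file: only f0.f00.f000 and f0.f00.f001 exist.
val fileSchema = MessageTypeParser.parseMessageType(
  """message root {
    |  required group f0 {
    |    required group f00 {
    |      required int32 f000;
    |      required binary f001 (UTF8);
    |    }
    |  }
    |}
  """.stripMargin)

// Requested Catalyst schema: prunes f0.f00.f000, adds f0.f00.f002 and f1.
val requestedSchema = new StructType()
  .add("f0",
    new StructType()
      .add("f00",
        new StructType()
          .add("f001", StringType, nullable = false)
          .add("f002", FloatType, nullable = false),
        nullable = false),
    nullable = false)
  .add("f1", DoubleType, nullable = true)

// The clipped schema keeps f0.f00.f001 from the file and synthesizes
// f0.f00.f002 and f1, which Parquet will then return as nulls.
val clipped = CatalystReadSupport.clipParquetSchema(fileSchema, requestedSchema)
println(clipped)
```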
[spark-10301]: https://issues.apache.org/jira/browse/SPARK-10301 [spark-10005]: https://issues.apache.org/jira/browse/SPARK-10005 [test-case]: https://github.com/liancheng/spark/commit/38644d8a45175cbdf20d2ace021c2c2544a50ab3#diff-a9b98e28ce3ae30641829dffd1173be2R26 Author: Cheng Lian Closes #8509 from liancheng/spark-10301/fix-parquet-requested-schema. --- .../parquet/CatalystReadSupport.scala | 235 +++++++++---- .../parquet/CatalystRowConverter.scala | 51 +-- .../parquet/CatalystSchemaConverter.scala | 14 +- .../ParquetAvroCompatibilitySuite.scala | 1 + .../ParquetInteroperabilitySuite.scala | 90 +++++ .../parquet/ParquetQuerySuite.scala | 77 +++++ .../parquet/ParquetSchemaSuite.scala | 310 ++++++++++++++++++ 7 files changed, 653 insertions(+), 125 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetInteroperabilitySuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystReadSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystReadSupport.scala index 0a6bb44445f6..dc4ff06df6f2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystReadSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystReadSupport.scala @@ -19,17 +19,18 @@ package org.apache.spark.sql.execution.datasources.parquet import java.util.{Map => JMap} -import scala.collection.JavaConverters._ +import scala.collection.JavaConverters.{collectionAsScalaIterableConverter, mapAsJavaMapConverter, mapAsScalaMapConverter} import org.apache.hadoop.conf.Configuration import org.apache.parquet.hadoop.api.ReadSupport.ReadContext import org.apache.parquet.hadoop.api.{InitContext, ReadSupport} import org.apache.parquet.io.api.RecordMaterializer -import org.apache.parquet.schema.MessageType +import org.apache.parquet.schema.Type.Repetition +import org.apache.parquet.schema._ import org.apache.spark.Logging import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types._ private[parquet] class CatalystReadSupport extends ReadSupport[InternalRow] with Logging { // Called after `init()` when initializing Parquet record reader. @@ -81,70 +82,10 @@ private[parquet] class CatalystReadSupport extends ReadSupport[InternalRow] with // `StructType` containing all requested columns. val maybeRequestedSchema = Option(conf.get(CatalystReadSupport.SPARK_ROW_REQUESTED_SCHEMA)) - // Below we construct a Parquet schema containing all requested columns. This schema tells - // Parquet which columns to read. - // - // If `maybeRequestedSchema` is defined, we assemble an equivalent Parquet schema. Otherwise, - // we have to fallback to the full file schema which contains all columns in the file. - // Obviously this may waste IO bandwidth since it may read more columns than requested. - // - // Two things to note: - // - // 1. It's possible that some requested columns don't exist in the target Parquet file. For - // example, in the case of schema merging, the globally merged schema may contain extra - // columns gathered from other Parquet files. These columns will be simply filled with nulls - // when actually reading the target Parquet file. - // - // 2. 
When `maybeRequestedSchema` is available, we can't simply convert the Catalyst schema to - // Parquet schema using `CatalystSchemaConverter`, because the mapping is not unique due to - // non-standard behaviors of some Parquet libraries/tools. For example, a Parquet file - // containing a single integer array field `f1` may have the following legacy 2-level - // structure: - // - // message root { - // optional group f1 (LIST) { - // required INT32 element; - // } - // } - // - // while `CatalystSchemaConverter` may generate a standard 3-level structure: - // - // message root { - // optional group f1 (LIST) { - // repeated group list { - // required INT32 element; - // } - // } - // } - // - // Apparently, we can't use the 2nd schema to read the target Parquet file as they have - // different physical structures. val parquetRequestedSchema = maybeRequestedSchema.fold(context.getFileSchema) { schemaString => - val toParquet = new CatalystSchemaConverter(conf) - val fileSchema = context.getFileSchema.asGroupType() - val fileFieldNames = fileSchema.getFields.asScala.map(_.getName).toSet - - StructType - // Deserializes the Catalyst schema of requested columns - .fromString(schemaString) - .map { field => - if (fileFieldNames.contains(field.name)) { - // If the field exists in the target Parquet file, extracts the field type from the - // full file schema and makes a single-field Parquet schema - new MessageType("root", fileSchema.getType(field.name)) - } else { - // Otherwise, just resorts to `CatalystSchemaConverter` - toParquet.convert(StructType(Array(field))) - } - } - // Merges all single-field Parquet schemas to form a complete schema for all requested - // columns. Note that it's possible that no columns are requested at all (e.g., count - // some partition column of a partitioned Parquet table). That's why `fold` is used here - // and always fallback to an empty Parquet schema. - .fold(new MessageType("root")) { - _ union _ - } + val catalystRequestedSchema = StructType.fromString(schemaString) + CatalystReadSupport.clipParquetSchema(context.getFileSchema, catalystRequestedSchema) } val metadata = @@ -160,4 +101,168 @@ private[parquet] object CatalystReadSupport { val SPARK_ROW_REQUESTED_SCHEMA = "org.apache.spark.sql.parquet.row.requested_schema" val SPARK_METADATA_KEY = "org.apache.spark.sql.parquet.row.metadata" + + /** + * Tailors `parquetSchema` according to `catalystSchema` by removing column paths don't exist + * in `catalystSchema`, and adding those only exist in `catalystSchema`. + */ + def clipParquetSchema(parquetSchema: MessageType, catalystSchema: StructType): MessageType = { + val clippedParquetFields = clipParquetGroupFields(parquetSchema.asGroupType(), catalystSchema) + Types.buildMessage().addFields(clippedParquetFields: _*).named("root") + } + + private def clipParquetType(parquetType: Type, catalystType: DataType): Type = { + catalystType match { + case t: ArrayType if !isPrimitiveCatalystType(t.elementType) => + // Only clips array types with nested type as element type. + clipParquetListType(parquetType.asGroupType(), t.elementType) + + case t: MapType if !isPrimitiveCatalystType(t.valueType) => + // Only clips map types with nested type as value type. + clipParquetMapType(parquetType.asGroupType(), t.keyType, t.valueType) + + case t: StructType => + clipParquetGroup(parquetType.asGroupType(), t) + + case _ => + parquetType + } + } + + /** + * Whether a Catalyst [[DataType]] is primitive. Primitive [[DataType]] is not equivalent to + * [[AtomicType]]. 
For example, [[CalendarIntervalType]] is primitive, but it's not an + * [[AtomicType]]. + */ + private def isPrimitiveCatalystType(dataType: DataType): Boolean = { + dataType match { + case _: ArrayType | _: MapType | _: StructType => false + case _ => true + } + } + + /** + * Clips a Parquet [[GroupType]] which corresponds to a Catalyst [[ArrayType]]. The element type + * of the [[ArrayType]] should also be a nested type, namely an [[ArrayType]], a [[MapType]], or a + * [[StructType]]. + */ + private def clipParquetListType(parquetList: GroupType, elementType: DataType): Type = { + // Precondition of this method, should only be called for lists with nested element types. + assert(!isPrimitiveCatalystType(elementType)) + + // Unannotated repeated group should be interpreted as required list of required element, so + // list element type is just the group itself. Clip it. + if (parquetList.getOriginalType == null && parquetList.isRepetition(Repetition.REPEATED)) { + clipParquetType(parquetList, elementType) + } else { + assert( + parquetList.getOriginalType == OriginalType.LIST, + "Invalid Parquet schema. " + + "Original type of annotated Parquet lists must be LIST: " + + parquetList.toString) + + assert( + parquetList.getFieldCount == 1 && parquetList.getType(0).isRepetition(Repetition.REPEATED), + "Invalid Parquet schema. " + + "LIST-annotated group should only have exactly one repeated field: " + + parquetList) + + // Precondition of this method, should only be called for lists with nested element types. + assert(!parquetList.getType(0).isPrimitive) + + val repeatedGroup = parquetList.getType(0).asGroupType() + + // If the repeated field is a group with multiple fields, or the repeated field is a group + // with one field and is named either "array" or uses the LIST-annotated group's name with + // "_tuple" appended then the repeated type is the element type and elements are required. + // Build a new LIST-annotated group with clipped `repeatedGroup` as element type and the + // only field. + if ( + repeatedGroup.getFieldCount > 1 || + repeatedGroup.getName == "array" || + repeatedGroup.getName == parquetList.getName + "_tuple" + ) { + Types + .buildGroup(parquetList.getRepetition) + .as(OriginalType.LIST) + .addField(clipParquetType(repeatedGroup, elementType)) + .named(parquetList.getName) + } else { + // Otherwise, the repeated field's type is the element type with the repeated field's + // repetition. + Types + .buildGroup(parquetList.getRepetition) + .as(OriginalType.LIST) + .addField( + Types + .repeatedGroup() + .addField(clipParquetType(repeatedGroup.getType(0), elementType)) + .named(repeatedGroup.getName)) + .named(parquetList.getName) + } + } + } + + /** + * Clips a Parquet [[GroupType]] which corresponds to a Catalyst [[MapType]]. The value type + * of the [[MapType]] should also be a nested type, namely an [[ArrayType]], a [[MapType]], or a + * [[StructType]]. Note that key type of any [[MapType]] is always a primitive type. + */ + private def clipParquetMapType( + parquetMap: GroupType, keyType: DataType, valueType: DataType): GroupType = { + // Precondition of this method, should only be called for maps with nested value types. 
+ assert(!isPrimitiveCatalystType(valueType)) + + val repeatedGroup = parquetMap.getType(0).asGroupType() + val parquetKeyType = repeatedGroup.getType(0) + val parquetValueType = repeatedGroup.getType(1) + + val clippedRepeatedGroup = + Types + .repeatedGroup() + .as(repeatedGroup.getOriginalType) + .addField(parquetKeyType) + .addField(clipParquetType(parquetValueType, valueType)) + .named(repeatedGroup.getName) + + Types + .buildGroup(parquetMap.getRepetition) + .as(parquetMap.getOriginalType) + .addField(clippedRepeatedGroup) + .named(parquetMap.getName) + } + + /** + * Clips a Parquet [[GroupType]] which corresponds to a Catalyst [[StructType]]. + * + * @return A clipped [[GroupType]], which has at least one field. + * @note Parquet doesn't allow creating empty [[GroupType]] instances except for empty + * [[MessageType]]. Because it's legal to construct an empty requested schema for column + * pruning. + */ + private def clipParquetGroup(parquetRecord: GroupType, structType: StructType): GroupType = { + val clippedParquetFields = clipParquetGroupFields(parquetRecord, structType) + Types + .buildGroup(parquetRecord.getRepetition) + .as(parquetRecord.getOriginalType) + .addFields(clippedParquetFields: _*) + .named(parquetRecord.getName) + } + + /** + * Clips a Parquet [[GroupType]] which corresponds to a Catalyst [[StructType]]. + * + * @return A list of clipped [[GroupType]] fields, which can be empty. + */ + private def clipParquetGroupFields( + parquetRecord: GroupType, structType: StructType): Seq[Type] = { + val parquetFieldMap = parquetRecord.getFields.asScala.map(f => f.getName -> f).toMap + val toParquet = new CatalystSchemaConverter(followParquetFormatSpec = true) + structType.map { f => + parquetFieldMap + .get(f.name) + .map(clipParquetType(_, f.dataType)) + .getOrElse(toParquet.convertField(f)) + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala index fe13dfbbed38..f17e794b7665 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala @@ -113,31 +113,6 @@ private[parquet] class CatalystPrimitiveConverter(val updater: ParentContainerUp * When used as a root converter, [[NoopUpdater]] should be used since root converters don't have * any "parent" container. * - * @note Constructor argument [[parquetType]] refers to requested fields of the actual schema of the - * Parquet file being read, while constructor argument [[catalystType]] refers to requested - * fields of the global schema. The key difference is that, in case of schema merging, - * [[parquetType]] can be a subset of [[catalystType]]. 
For example, it's possible to have - * the following [[catalystType]]: - * {{{ - * new StructType() - * .add("f1", IntegerType, nullable = false) - * .add("f2", StringType, nullable = true) - * .add("f3", new StructType() - * .add("f31", DoubleType, nullable = false) - * .add("f32", IntegerType, nullable = true) - * .add("f33", StringType, nullable = true), nullable = false) - * }}} - * and the following [[parquetType]] (`f2` and `f32` are missing): - * {{{ - * message root { - * required int32 f1; - * required group f3 { - * required double f31; - * optional binary f33 (utf8); - * } - * } - * }}} - * * @param parquetType Parquet schema of Parquet records * @param catalystType Spark SQL schema that corresponds to the Parquet record type * @param updater An updater which propagates converted field values to the parent container @@ -179,31 +154,7 @@ private[parquet] class CatalystRowConverter( // Converters for each field. private val fieldConverters: Array[Converter with HasParentContainerUpdater] = { - // In case of schema merging, `parquetType` can be a subset of `catalystType`. We need to pad - // those missing fields and create converters for them, although values of these fields are - // always null. - val paddedParquetFields = { - val parquetFields = parquetType.getFields.asScala - val parquetFieldNames = parquetFields.map(_.getName).toSet - val missingFields = catalystType.filterNot(f => parquetFieldNames.contains(f.name)) - - // We don't need to worry about feature flag arguments like `assumeBinaryIsString` when - // creating the schema converter here, since values of missing fields are always null. - val toParquet = new CatalystSchemaConverter() - - (parquetFields ++ missingFields.map(toParquet.convertField)).sortBy { f => - catalystType.indexWhere(_.name == f.getName) - } - } - - if (paddedParquetFields.length != catalystType.length) { - throw new UnsupportedOperationException( - "A Parquet file's schema has different number of fields with the table schema. " + - "Please enable schema merging by setting \"mergeSchema\" to true when load " + - "a Parquet dataset or set spark.sql.parquet.mergeSchema to true in SQLConf.") - } - - paddedParquetFields.zip(catalystType).zipWithIndex.map { + parquetType.getFields.asScala.zip(catalystType).zipWithIndex.map { case ((parquetFieldType, catalystField), ordinal) => // Converted field value should be set to the `ordinal`-th cell of `currentRow` newConverter(parquetFieldType, catalystField.dataType, new RowUpdater(currentRow, ordinal)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala index be6c0545f5a0..a21ab1dbb25d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala @@ -55,16 +55,10 @@ import org.apache.spark.sql.{AnalysisException, SQLConf} * to old style non-standard behaviors. */ private[parquet] class CatalystSchemaConverter( - private val assumeBinaryIsString: Boolean, - private val assumeInt96IsTimestamp: Boolean, - private val followParquetFormatSpec: Boolean) { - - // Only used when constructing converter for converting Spark SQL schema to Parquet schema, in - // which case `assumeInt96IsTimestamp` and `assumeBinaryIsString` are irrelevant. 
- def this() = this( - assumeBinaryIsString = SQLConf.PARQUET_BINARY_AS_STRING.defaultValue.get, - assumeInt96IsTimestamp = SQLConf.PARQUET_INT96_AS_TIMESTAMP.defaultValue.get, - followParquetFormatSpec = SQLConf.PARQUET_FOLLOW_PARQUET_FORMAT_SPEC.defaultValue.get) + assumeBinaryIsString: Boolean = SQLConf.PARQUET_BINARY_AS_STRING.defaultValue.get, + assumeInt96IsTimestamp: Boolean = SQLConf.PARQUET_INT96_AS_TIMESTAMP.defaultValue.get, + followParquetFormatSpec: Boolean = SQLConf.PARQUET_FOLLOW_PARQUET_FORMAT_SPEC.defaultValue.get +) { def this(conf: SQLConf) = this( assumeBinaryIsString = conf.isParquetBinaryAsString, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAvroCompatibilitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAvroCompatibilitySuite.scala index bd7cf8c10abe..36b929ee1f40 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAvroCompatibilitySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAvroCompatibilitySuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution.datasources.parquet +import java.io.File import java.nio.ByteBuffer import java.util.{List => JList, Map => JMap} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetInteroperabilitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetInteroperabilitySuite.scala new file mode 100644 index 000000000000..83b65fb419ed --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetInteroperabilitySuite.scala @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.parquet + +import java.io.File + +import org.apache.spark.sql.Row +import org.apache.spark.sql.test.SharedSQLContext + +class ParquetInteroperabilitySuite extends ParquetCompatibilityTest with SharedSQLContext { + test("parquet files with different physical schemas but share the same logical schema") { + import ParquetCompatibilityTest._ + + // This test case writes two Parquet files, both representing the following Catalyst schema + // + // StructType( + // StructField( + // "f", + // ArrayType(IntegerType, containsNull = false), + // nullable = false)) + // + // The first Parquet file comes with parquet-avro style 2-level LIST-annotated group, while the + // other one comes with parquet-protobuf style 1-level unannotated primitive field. 
+ withTempDir { dir => + val avroStylePath = new File(dir, "avro-style").getCanonicalPath + val protobufStylePath = new File(dir, "protobuf-style").getCanonicalPath + + val avroStyleSchema = + """message avro_style { + | required group f (LIST) { + | repeated int32 array; + | } + |} + """.stripMargin + + writeDirect(avroStylePath, avroStyleSchema, { rc => + rc.message { + rc.field("f", 0) { + rc.group { + rc.field("array", 0) { + rc.addInteger(0) + rc.addInteger(1) + } + } + } + } + }) + + logParquetSchema(avroStylePath) + + val protobufStyleSchema = + """message protobuf_style { + | repeated int32 f; + |} + """.stripMargin + + writeDirect(protobufStylePath, protobufStyleSchema, { rc => + rc.message { + rc.field("f", 0) { + rc.addInteger(2) + rc.addInteger(3) + } + } + }) + + logParquetSchema(protobufStylePath) + + checkAnswer( + sqlContext.read.parquet(dir.getCanonicalPath), + Seq( + Row(Seq(0, 1)), + Row(Seq(2, 3)))) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala index b7b70c2bbbd5..a379523d67f8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala @@ -229,4 +229,81 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext } } } + + test("SPARK-10301 Clipping nested structs in requested schema") { + withTempPath { dir => + val path = dir.getCanonicalPath + val df = sqlContext + .range(1) + .selectExpr("NAMED_STRUCT('a', id, 'b', id) AS s") + .coalesce(1) + + df.write.mode("append").parquet(path) + + val userDefinedSchema = new StructType() + .add("s", new StructType().add("a", LongType, nullable = true), nullable = true) + + checkAnswer( + sqlContext.read.schema(userDefinedSchema).parquet(path), + Row(Row(0))) + } + + withTempPath { dir => + val path = dir.getCanonicalPath + + val df1 = sqlContext + .range(1) + .selectExpr("NAMED_STRUCT('a', id, 'b', id) AS s") + .coalesce(1) + + val df2 = sqlContext + .range(1, 2) + .selectExpr("NAMED_STRUCT('b', id, 'c', id) AS s") + .coalesce(1) + + df1.write.parquet(path) + df2.write.mode(SaveMode.Append).parquet(path) + + val userDefinedSchema = new StructType() + .add("s", + new StructType() + .add("a", LongType, nullable = true) + .add("c", LongType, nullable = true), + nullable = true) + + checkAnswer( + sqlContext.read.schema(userDefinedSchema).parquet(path), + Seq( + Row(Row(0, null)), + Row(Row(null, 1)))) + } + + withTempPath { dir => + val path = dir.getCanonicalPath + + val df = sqlContext + .range(1) + .selectExpr("NAMED_STRUCT('a', ARRAY(NAMED_STRUCT('b', id, 'c', id))) AS s") + .coalesce(1) + + df.write.parquet(path) + + val userDefinedSchema = new StructType() + .add("s", + new StructType() + .add( + "a", + ArrayType( + new StructType() + .add("b", LongType, nullable = true) + .add("d", StringType, nullable = true), + containsNull = true), + nullable = true), + nullable = true) + + checkAnswer( + sqlContext.read.schema(userDefinedSchema).parquet(path), + Row(Row(Seq(Row(0, null))))) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala index 9dcbc1a047be..28c59a4abdd7 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala @@ -22,6 +22,7 @@ import scala.reflect.runtime.universe.TypeTag import org.apache.parquet.schema.MessageTypeParser +import org.apache.spark.sql.SQLConf import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ @@ -941,4 +942,313 @@ class ParquetSchemaSuite extends ParquetSchemaTest { | optional fixed_len_byte_array(8) f1 (DECIMAL(18, 3)); |} """.stripMargin) + + private def testSchemaClipping( + testName: String, + parquetSchema: String, + catalystSchema: StructType, + expectedSchema: String): Unit = { + test(s"Clipping - $testName") { + val expected = MessageTypeParser.parseMessageType(expectedSchema) + val actual = CatalystReadSupport.clipParquetSchema( + MessageTypeParser.parseMessageType(parquetSchema), catalystSchema) + + try { + expected.checkContains(actual) + actual.checkContains(expected) + } catch { case cause: Throwable => + fail( + s"""Expected clipped schema: + |$expected + |Actual clipped schema: + |$actual + """.stripMargin, + cause) + } + } + } + + testSchemaClipping( + "simple nested struct", + + parquetSchema = + """message root { + | required group f0 { + | optional int32 f00; + | optional int32 f01; + | } + |} + """.stripMargin, + + catalystSchema = { + val f0Type = new StructType().add("f00", IntegerType, nullable = true) + new StructType() + .add("f0", f0Type, nullable = false) + .add("f1", IntegerType, nullable = true) + }, + + expectedSchema = + """message root { + | required group f0 { + | optional int32 f00; + | } + | optional int32 f1; + |} + """.stripMargin) + + testSchemaClipping( + "parquet-protobuf style array", + + parquetSchema = + """message root { + | required group f0 { + | repeated binary f00 (UTF8); + | repeated group f01 { + | optional int32 f010; + | optional double f011; + | } + | } + |} + """.stripMargin, + + catalystSchema = { + val f11Type = new StructType().add("f011", DoubleType, nullable = true) + val f01Type = ArrayType(StringType, containsNull = false) + val f0Type = new StructType() + .add("f00", f01Type, nullable = false) + .add("f01", f11Type, nullable = false) + val f1Type = ArrayType(IntegerType, containsNull = true) + new StructType() + .add("f0", f0Type, nullable = false) + .add("f1", f1Type, nullable = true) + }, + + expectedSchema = + """message root { + | required group f0 { + | repeated binary f00 (UTF8); + | repeated group f01 { + | optional double f011; + | } + | } + | + | optional group f1 (LIST) { + | repeated group list { + | optional int32 element; + | } + | } + |} + """.stripMargin) + + testSchemaClipping( + "parquet-thrift style array", + + parquetSchema = + """message root { + | required group f0 { + | optional group f00 { + | repeated binary f00_tuple (UTF8); + | } + | + | optional group f01 (LIST) { + | repeated group f01_tuple { + | optional int32 f010; + | optional double f011; + | } + | } + | } + |} + """.stripMargin, + + catalystSchema = { + val f11ElementType = new StructType() + .add("f011", DoubleType, nullable = true) + .add("f012", LongType, nullable = true) + + val f0Type = new StructType() + .add("f00", ArrayType(StringType, containsNull = false), nullable = false) + .add("f01", ArrayType(f11ElementType, containsNull = false), nullable = false) + + new StructType().add("f0", f0Type, nullable = false) + }, + + expectedSchema = 
+ """message root { + | required group f0 { + | optional group f00 { + | repeated binary f00_tuple (UTF8); + | } + | + | optional group f01 (LIST) { + | repeated group f01_tuple { + | optional double f011; + | optional int64 f012; + | } + | } + | } + |} + """.stripMargin) + + testSchemaClipping( + "parquet-avro style array", + + parquetSchema = + """message root { + | required group f0 { + | optional group f00 { + | repeated binary array (UTF8); + | } + | + | optional group f01 (LIST) { + | repeated group array { + | optional int32 f010; + | optional double f011; + | } + | } + | } + |} + """.stripMargin, + + catalystSchema = { + val f11ElementType = new StructType() + .add("f011", DoubleType, nullable = true) + .add("f012", LongType, nullable = true) + + val f0Type = new StructType() + .add("f00", ArrayType(StringType, containsNull = false), nullable = false) + .add("f01", ArrayType(f11ElementType, containsNull = false), nullable = false) + + new StructType().add("f0", f0Type, nullable = false) + }, + + expectedSchema = + """message root { + | required group f0 { + | optional group f00 { + | repeated binary array (UTF8); + | } + | + | optional group f01 (LIST) { + | repeated group array { + | optional double f011; + | optional int64 f012; + | } + | } + | } + |} + """.stripMargin) + + testSchemaClipping( + "parquet-hive style array", + + parquetSchema = + """message root { + | optional group f0 { + | optional group f00 (LIST) { + | repeated group bag { + | optional binary array_element; + | } + | } + | + | optional group f01 (LIST) { + | repeated group bag { + | optional group array_element { + | optional int32 f010; + | optional double f011; + | } + | } + | } + | } + |} + """.stripMargin, + + catalystSchema = { + val f01ElementType = new StructType() + .add("f011", DoubleType, nullable = true) + .add("f012", LongType, nullable = true) + + val f0Type = new StructType() + .add("f00", ArrayType(StringType, containsNull = true), nullable = true) + .add("f01", ArrayType(f01ElementType, containsNull = true), nullable = true) + + new StructType().add("f0", f0Type, nullable = true) + }, + + expectedSchema = + """message root { + | optional group f0 { + | optional group f00 (LIST) { + | repeated group bag { + | optional binary array_element; + | } + | } + | + | optional group f01 (LIST) { + | repeated group bag { + | optional group array_element { + | optional double f011; + | optional int64 f012; + | } + | } + | } + | } + |} + """.stripMargin) + + testSchemaClipping( + "2-level list of required struct", + + parquetSchema = + s"""message root { + | required group f0 { + | required group f00 (LIST) { + | repeated group element { + | required int32 f000; + | optional int64 f001; + | } + | } + | } + |} + """.stripMargin, + + catalystSchema = { + val f00ElementType = + new StructType() + .add("f001", LongType, nullable = true) + .add("f002", DoubleType, nullable = false) + + val f00Type = ArrayType(f00ElementType, containsNull = false) + val f0Type = new StructType().add("f00", f00Type, nullable = false) + + new StructType().add("f0", f0Type, nullable = false) + }, + + expectedSchema = + s"""message root { + | required group f0 { + | required group f00 (LIST) { + | repeated group element { + | optional int64 f001; + | required double f002; + | } + | } + | } + |} + """.stripMargin) + + testSchemaClipping( + "empty requested schema", + + parquetSchema = + """message root { + | required group f0 { + | required int32 f00; + | required int64 f01; + | } + |} + """.stripMargin, + + catalystSchema = new 
StructType(), + + expectedSchema = "message root {}") } From e6e483cc4de740c46398385b03ffe0e662edae39 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Tue, 1 Sep 2015 10:48:57 -0700 Subject: [PATCH 0499/1824] [SPARK-9679] [ML] [PYSPARK] Add Python API for Stop Words Remover Add a python API for the Stop Words Remover. Author: Holden Karau Closes #8118 from holdenk/SPARK-9679-python-StopWordsRemover. --- .../spark/ml/feature/StopWordsRemover.scala | 6 +- .../ml/feature/StopWordsRemoverSuite.scala | 2 +- python/pyspark/ml/feature.py | 73 ++++++++++++++++++- python/pyspark/ml/tests.py | 20 ++++- 4 files changed, 93 insertions(+), 8 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala index 5d77ea08db65..7da430c7d16d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala @@ -29,14 +29,14 @@ import org.apache.spark.sql.types.{ArrayType, StringType, StructField, StructTyp /** * stop words list */ -private object StopWords { +private[spark] object StopWords { /** * Use the same default stopwords list as scikit-learn. * The original list can be found from "Glasgow Information Retrieval Group" * [[http://ir.dcs.gla.ac.uk/resources/linguistic_utils/stop_words]] */ - val EnglishStopWords = Array( "a", "about", "above", "across", "after", "afterwards", "again", + val English = Array( "a", "about", "above", "across", "after", "afterwards", "again", "against", "all", "almost", "alone", "along", "already", "also", "although", "always", "am", "among", "amongst", "amoungst", "amount", "an", "and", "another", "any", "anyhow", "anyone", "anything", "anyway", "anywhere", "are", @@ -121,7 +121,7 @@ class StopWordsRemover(override val uid: String) /** @group getParam */ def getCaseSensitive: Boolean = $(caseSensitive) - setDefault(stopWords -> StopWords.EnglishStopWords, caseSensitive -> false) + setDefault(stopWords -> StopWords.English, caseSensitive -> false) override def transform(dataset: DataFrame): DataFrame = { val outputSchema = transformSchema(dataset.schema) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala index f01306f89cb5..e0d433f566c2 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala @@ -65,7 +65,7 @@ class StopWordsRemoverSuite extends SparkFunSuite with MLlibTestSparkContext { } test("StopWordsRemover with additional words") { - val stopWords = StopWords.EnglishStopWords ++ Array("python", "scala") + val stopWords = StopWords.English ++ Array("python", "scala") val remover = new StopWordsRemover() .setInputCol("raw") .setOutputCol("filtered") diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 0626281e200a..d955307e27ef 100644 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -22,7 +22,7 @@ from pyspark.rdd import ignore_unicode_prefix from pyspark.ml.param.shared import * from pyspark.ml.util import keyword_only -from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaTransformer +from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaTransformer, _jvm from pyspark.mllib.common import inherit_doc from pyspark.mllib.linalg import _convert_to_vector @@ -30,7 +30,7 @@ 
'NGram', 'Normalizer', 'OneHotEncoder', 'PolynomialExpansion', 'RegexTokenizer', 'SQLTransformer', 'StandardScaler', 'StandardScalerModel', 'StringIndexer', 'StringIndexerModel', 'Tokenizer', 'VectorAssembler', 'VectorIndexer', 'Word2Vec', - 'Word2VecModel', 'PCA', 'PCAModel', 'RFormula', 'RFormulaModel'] + 'Word2VecModel', 'PCA', 'PCAModel', 'RFormula', 'RFormulaModel', 'StopWordsRemover'] @inherit_doc @@ -933,6 +933,75 @@ class StringIndexerModel(JavaModel): """ +class StopWordsRemover(JavaTransformer, HasInputCol, HasOutputCol): + """ + .. note:: Experimental + + A feature transformer that filters out stop words from input. + Note: null values from input array are preserved unless adding null to stopWords explicitly. + """ + # a placeholder to make the stopwords show up in generated doc + stopWords = Param(Params._dummy(), "stopWords", "The words to be filtered out") + caseSensitive = Param(Params._dummy(), "caseSensitive", "whether to do a case sensitive " + + "comparison over the stop words") + + @keyword_only + def __init__(self, inputCol=None, outputCol=None, stopWords=None, + caseSensitive=False): + """ + __init__(self, inputCol=None, outputCol=None, stopWords=None,\ + caseSensitive=false) + """ + super(StopWordsRemover, self).__init__() + self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.StopWordsRemover", + self.uid) + self.stopWords = Param(self, "stopWords", "The words to be filtered out") + self.caseSensitive = Param(self, "caseSensitive", "whether to do a case " + + "sensitive comparison over the stop words") + stopWordsObj = _jvm().org.apache.spark.ml.feature.StopWords + defaultStopWords = stopWordsObj.English() + self._setDefault(stopWords=defaultStopWords) + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + + @keyword_only + def setParams(self, inputCol=None, outputCol=None, stopWords=None, + caseSensitive=False): + """ + setParams(self, inputCol="input", outputCol="output", stopWords=None,\ + caseSensitive=false) + Sets params for this StopWordRemover. + """ + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) + + def setStopWords(self, value): + """ + Specify the stopwords to be filtered. + """ + self._paramMap[self.stopWords] = value + return self + + def getStopWords(self): + """ + Get the stopwords. + """ + return self.getOrDefault(self.stopWords) + + def setCaseSensitive(self, value): + """ + Set whether to do a case sensitive comparison over the stop words + """ + self._paramMap[self.caseSensitive] = value + return self + + def getCaseSensitive(self): + """ + Get whether to do a case sensitive comparison over the stop words. 
+ """ + return self.getOrDefault(self.caseSensitive) + + @inherit_doc @ignore_unicode_prefix class Tokenizer(JavaTransformer, HasInputCol, HasOutputCol): diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 60e4237293ad..b892318f50bd 100644 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -31,7 +31,7 @@ import unittest from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase -from pyspark.sql import DataFrame, SQLContext +from pyspark.sql import DataFrame, SQLContext, Row from pyspark.sql.functions import rand from pyspark.ml.evaluation import RegressionEvaluator from pyspark.ml.param import Param, Params @@ -258,7 +258,7 @@ def test_idf(self): def test_ngram(self): sqlContext = SQLContext(self.sc) dataset = sqlContext.createDataFrame([ - ([["a", "b", "c", "d", "e"]])], ["input"]) + Row(input=["a", "b", "c", "d", "e"])]) ngram0 = NGram(n=4, inputCol="input", outputCol="output") self.assertEqual(ngram0.getN(), 4) self.assertEqual(ngram0.getInputCol(), "input") @@ -266,6 +266,22 @@ def test_ngram(self): transformedDF = ngram0.transform(dataset) self.assertEquals(transformedDF.head().output, ["a b c d", "b c d e"]) + def test_stopwordsremover(self): + sqlContext = SQLContext(self.sc) + dataset = sqlContext.createDataFrame([Row(input=["a", "panda"])]) + stopWordRemover = StopWordsRemover(inputCol="input", outputCol="output") + # Default + self.assertEquals(stopWordRemover.getInputCol(), "input") + transformedDF = stopWordRemover.transform(dataset) + self.assertEquals(transformedDF.head().output, ["panda"]) + # Custom + stopwords = ["panda"] + stopWordRemover.setStopWords(stopwords) + self.assertEquals(stopWordRemover.getInputCol(), "input") + self.assertEquals(stopWordRemover.getStopWords(), stopwords) + transformedDF = stopWordRemover.transform(dataset) + self.assertEquals(transformedDF.head().output, ["a"]) + class HasInducedError(Params): From 3f63bd6023edcc9af268933a235f34e10bc3d2ba Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Tue, 1 Sep 2015 20:06:01 +0100 Subject: [PATCH 0500/1824] [SPARK-10398] [DOCS] Migrate Spark download page to use new lua mirroring scripts Migrate Apache download closer.cgi refs to new closer.lua This is the bit of the change that affects the project docs; I'm implementing the changes to the Apache site separately. Author: Sean Owen Closes #8557 from srowen/SPARK-10398. --- docker/spark-mesos/Dockerfile | 2 +- docs/running-on-mesos.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docker/spark-mesos/Dockerfile b/docker/spark-mesos/Dockerfile index b90aef3655de..fb3f267fe5c7 100644 --- a/docker/spark-mesos/Dockerfile +++ b/docker/spark-mesos/Dockerfile @@ -24,7 +24,7 @@ RUN apt-get update && \ apt-get install -y python libnss3 openjdk-7-jre-headless curl RUN mkdir /opt/spark && \ - curl http://www.apache.org/dyn/closer.cgi/spark/spark-1.4.0/spark-1.4.0-bin-hadoop2.4.tgz \ + curl http://www.apache.org/dyn/closer.lua/spark/spark-1.4.0/spark-1.4.0-bin-hadoop2.4.tgz \ | tar -xzC /opt ENV SPARK_HOME /opt/spark ENV MESOS_NATIVE_JAVA_LIBRARY /usr/local/lib/libmesos.so diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index cfd219ab02e2..f36921ae30c2 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -45,7 +45,7 @@ frameworks. You can install Mesos either from source or using prebuilt packages To install Apache Mesos from source, follow these steps: 1. 
Download a Mesos release from a - [mirror](http://www.apache.org/dyn/closer.cgi/mesos/{{site.MESOS_VERSION}}/) + [mirror](http://www.apache.org/dyn/closer.lua/mesos/{{site.MESOS_VERSION}}/) 2. Follow the Mesos [Getting Started](http://mesos.apache.org/gettingstarted) page for compiling and installing Mesos From ec012805337926e56343be2761a1037296446880 Mon Sep 17 00:00:00 2001 From: zhuol Date: Tue, 1 Sep 2015 11:14:59 -1000 Subject: [PATCH 0501/1824] [SPARK-4223] [CORE] Support * in acls. SPARK-4223. Currently we support setting view and modify acls but you have to specify a list of users. It would be nice to support * meaning all users have access. Manual tests to verify that: "*" works for any user in: a. Spark ui: view and kill stage. Done. b. Spark history server. Done. c. Yarn application killing. Done. Author: zhuol Closes #8398 from zhuoliu/4223. --- .../org/apache/spark/SecurityManager.scala | 26 ++++++++++-- .../apache/spark/SecurityManagerSuite.scala | 41 +++++++++++++++++++ docs/configuration.md | 9 ++-- 3 files changed, 69 insertions(+), 7 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala index 673ef49e7c1c..746d2081d439 100644 --- a/core/src/main/scala/org/apache/spark/SecurityManager.scala +++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala @@ -310,7 +310,16 @@ private[spark] class SecurityManager(sparkConf: SparkConf) setViewAcls(Set[String](defaultUser), allowedUsers) } - def getViewAcls: String = viewAcls.mkString(",") + /** + * Checking the existence of "*" is necessary as YARN can't recognize the "*" in "defaultuser,*" + */ + def getViewAcls: String = { + if (viewAcls.contains("*")) { + "*" + } else { + viewAcls.mkString(",") + } + } /** * Admin acls should be set before the view or modify acls. If you modify the admin @@ -321,7 +330,16 @@ private[spark] class SecurityManager(sparkConf: SparkConf) logInfo("Changing modify acls to: " + modifyAcls.mkString(",")) } - def getModifyAcls: String = modifyAcls.mkString(",") + /** + * Checking the existence of "*" is necessary as YARN can't recognize the "*" in "defaultuser,*" + */ + def getModifyAcls: String = { + if (modifyAcls.contains("*")) { + "*" + } else { + modifyAcls.mkString(",") + } + } /** * Admin acls should be set before the view or modify acls. 
If you modify the admin @@ -394,7 +412,7 @@ private[spark] class SecurityManager(sparkConf: SparkConf) def checkUIViewPermissions(user: String): Boolean = { logDebug("user=" + user + " aclsEnabled=" + aclsEnabled() + " viewAcls=" + viewAcls.mkString(",")) - !aclsEnabled || user == null || viewAcls.contains(user) + !aclsEnabled || user == null || viewAcls.contains(user) || viewAcls.contains("*") } /** @@ -409,7 +427,7 @@ private[spark] class SecurityManager(sparkConf: SparkConf) def checkModifyPermissions(user: String): Boolean = { logDebug("user=" + user + " aclsEnabled=" + aclsEnabled() + " modifyAcls=" + modifyAcls.mkString(",")) - !aclsEnabled || user == null || modifyAcls.contains(user) + !aclsEnabled || user == null || modifyAcls.contains(user) || modifyAcls.contains("*") } diff --git a/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala b/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala index f34aefca4eb1..f29160d83408 100644 --- a/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala @@ -125,6 +125,47 @@ class SecurityManagerSuite extends SparkFunSuite { } + test("set security with * in acls") { + val conf = new SparkConf + conf.set("spark.ui.acls.enable", "true") + conf.set("spark.admin.acls", "user1,user2") + conf.set("spark.ui.view.acls", "*") + conf.set("spark.modify.acls", "user4") + + val securityManager = new SecurityManager(conf) + assert(securityManager.aclsEnabled() === true) + + // check for viewAcls with * + assert(securityManager.checkUIViewPermissions("user1") === true) + assert(securityManager.checkUIViewPermissions("user5") === true) + assert(securityManager.checkUIViewPermissions("user6") === true) + assert(securityManager.checkModifyPermissions("user4") === true) + assert(securityManager.checkModifyPermissions("user7") === false) + assert(securityManager.checkModifyPermissions("user8") === false) + + // check for modifyAcls with * + securityManager.setModifyAcls(Set("user4"), "*") + assert(securityManager.checkModifyPermissions("user7") === true) + assert(securityManager.checkModifyPermissions("user8") === true) + + securityManager.setAdminAcls("user1,user2") + securityManager.setModifyAcls(Set("user1"), "user2") + securityManager.setViewAcls(Set("user1"), "user2") + assert(securityManager.checkUIViewPermissions("user5") === false) + assert(securityManager.checkUIViewPermissions("user6") === false) + assert(securityManager.checkModifyPermissions("user7") === false) + assert(securityManager.checkModifyPermissions("user8") === false) + + // check for adminAcls with * + securityManager.setAdminAcls("user1,*") + securityManager.setModifyAcls(Set("user1"), "user2") + securityManager.setViewAcls(Set("user1"), "user2") + assert(securityManager.checkUIViewPermissions("user5") === true) + assert(securityManager.checkUIViewPermissions("user6") === true) + assert(securityManager.checkModifyPermissions("user7") === true) + assert(securityManager.checkModifyPermissions("user8") === true) + } + test("ssl on setup") { val conf = SSLSampleConfigs.sparkSSLConfig() val expectedAlgorithms = Set( diff --git a/docs/configuration.md b/docs/configuration.md index 77c5cbc7b319..fb0315ce7c3c 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -1286,7 +1286,8 @@ Apart from these, the following properties are also available, and may be useful Comma separated list of users/administrators that have view and modify access to all Spark jobs. 
This can be used if you run on a shared cluster and have a set of administrators or devs who - help debug when things work. + help debug when things work. Putting a "*" in the list means any user can have the priviledge + of admin. @@ -1327,7 +1328,8 @@ Apart from these, the following properties are also available, and may be useful Empty Comma separated list of users that have modify access to the Spark job. By default only the - user that started the Spark job has access to modify it (kill it for example). + user that started the Spark job has access to modify it (kill it for example). Putting a "*" in + the list means any user can have access to modify it. @@ -1349,7 +1351,8 @@ Apart from these, the following properties are also available, and may be useful Empty Comma separated list of users that have view access to the Spark web ui. By default only the - user that started the Spark job has view access. + user that started the Spark job has view access. Putting a "*" in the list means any user can + have view access to this Spark job. From bf550a4b551b6dd18fea3eb3f70497f9a6ad8e6c Mon Sep 17 00:00:00 2001 From: 0x0FFF Date: Tue, 1 Sep 2015 14:34:59 -0700 Subject: [PATCH 0502/1824] [SPARK-10162] [SQL] Fix the timezone omitting for PySpark Dataframe filter function This PR addresses [SPARK-10162](https://issues.apache.org/jira/browse/SPARK-10162) The issue is with DataFrame filter() function, if datetime.datetime is passed to it: * Timezone information of this datetime is ignored * This datetime is assumed to be in local timezone, which depends on the OS timezone setting Fix includes both code change and regression test. Problem reproduction code on master: ```python import pytz from datetime import datetime from pyspark.sql import * from pyspark.sql.types import * sqc = SQLContext(sc) df = sqc.createDataFrame([], StructType([StructField("dt", TimestampType())])) m1 = pytz.timezone('UTC') m2 = pytz.timezone('Etc/GMT+3') df.filter(df.dt > datetime(2000, 01, 01, tzinfo=m1)).explain() df.filter(df.dt > datetime(2000, 01, 01, tzinfo=m2)).explain() ``` It gives the same timestamp ignoring time zone: ``` >>> df.filter(df.dt > datetime(2000, 01, 01, tzinfo=m1)).explain() Filter (dt#0 > 946713600000000) Scan PhysicalRDD[dt#0] >>> df.filter(df.dt > datetime(2000, 01, 01, tzinfo=m2)).explain() Filter (dt#0 > 946713600000000) Scan PhysicalRDD[dt#0] ``` After the fix: ``` >>> df.filter(df.dt > datetime(2000, 01, 01, tzinfo=m1)).explain() Filter (dt#0 > 946684800000000) Scan PhysicalRDD[dt#0] >>> df.filter(df.dt > datetime(2000, 01, 01, tzinfo=m2)).explain() Filter (dt#0 > 946695600000000) Scan PhysicalRDD[dt#0] ``` PR [8536](https://github.com/apache/spark/pull/8536) was occasionally closed by me dropping the repo Author: 0x0FFF Closes #8555 from 0x0FFF/SPARK-10162. 
--- python/pyspark/sql/tests.py | 26 ++++++++++++++++++-------- python/pyspark/sql/types.py | 7 +++++-- 2 files changed, 23 insertions(+), 10 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index cd32e26c64f2..59a891bd7c42 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -50,16 +50,17 @@ from pyspark.sql.utils import AnalysisException, IllegalArgumentException -class UTC(datetime.tzinfo): - """UTC""" - ZERO = datetime.timedelta(0) +class UTCOffsetTimezone(datetime.tzinfo): + """ + Specifies timezone in UTC offset + """ + + def __init__(self, offset=0): + self.ZERO = datetime.timedelta(hours=offset) def utcoffset(self, dt): return self.ZERO - def tzname(self, dt): - return "UTC" - def dst(self, dt): return self.ZERO @@ -841,13 +842,22 @@ def test_filter_with_datetime(self): self.assertEqual(0, df.filter(df.date > date).count()) self.assertEqual(0, df.filter(df.time > time).count()) + def test_filter_with_datetime_timezone(self): + dt1 = datetime.datetime(2015, 4, 17, 23, 1, 2, 3000, tzinfo=UTCOffsetTimezone(0)) + dt2 = datetime.datetime(2015, 4, 17, 23, 1, 2, 3000, tzinfo=UTCOffsetTimezone(1)) + row = Row(date=dt1) + df = self.sqlCtx.createDataFrame([row]) + self.assertEqual(0, df.filter(df.date == dt2).count()) + self.assertEqual(1, df.filter(df.date > dt2).count()) + self.assertEqual(0, df.filter(df.date < dt2).count()) + def test_time_with_timezone(self): day = datetime.date.today() now = datetime.datetime.now() ts = time.mktime(now.timetuple()) # class in __main__ is not serializable - from pyspark.sql.tests import UTC - utc = UTC() + from pyspark.sql.tests import UTCOffsetTimezone + utc = UTCOffsetTimezone() utcnow = datetime.datetime.utcfromtimestamp(ts) # without microseconds # add microseconds to utcnow (keeping year,month,day,hour,minute,second) utcnow = datetime.datetime(*(utcnow.timetuple()[:6] + (now.microsecond, utc))) diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 94e581a78364..f84d08d7098a 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -1290,8 +1290,11 @@ def can_convert(self, obj): def convert(self, obj, gateway_client): Timestamp = JavaClass("java.sql.Timestamp", gateway_client) - return Timestamp(int(time.mktime(obj.timetuple())) * 1000 + obj.microsecond // 1000) - + seconds = (calendar.timegm(obj.utctimetuple()) if obj.tzinfo + else time.mktime(obj.timetuple())) + t = Timestamp(int(seconds) * 1000) + t.setNanos(obj.microsecond * 1000) + return t # datetime is a subclass of date, we should register DatetimeConverter first register_input_converter(DatetimeConverter()) From 00d9af5e190475affffb8b50467fcddfc40f50dc Mon Sep 17 00:00:00 2001 From: 0x0FFF Date: Tue, 1 Sep 2015 14:58:49 -0700 Subject: [PATCH 0503/1824] [SPARK-10392] [SQL] Pyspark - Wrong DateType support on JDBC connection This PR addresses issue [SPARK-10392](https://issues.apache.org/jira/browse/SPARK-10392) The problem is that for "start of epoch" date (01 Jan 1970) PySpark class DateType returns 0 instead of the `datetime.date` due to implementation of its return statement Issue reproduction on master: ``` >>> from pyspark.sql.types import * >>> a = DateType() >>> a.fromInternal(0) 0 >>> a.fromInternal(1) datetime.date(1970, 1, 2) ``` Author: 0x0FFF Closes #8556 from 0x0FFF/SPARK-10392. 
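The root cause is Python's short-circuiting `and`: for the start-of-epoch date the internal value is `0`, which is falsy, so the old return statement yields `0` instead of building the date. A minimal standalone sketch of the pitfall and of the explicit `None` check the fix below switches to (illustrative only, not part of the patch):

```python
import datetime

EPOCH_ORDINAL = datetime.datetime(1970, 1, 1).toordinal()

def from_internal_buggy(v):
    # `0 and <expr>` evaluates to 0, so the epoch date is never constructed.
    return v and datetime.date.fromordinal(v + EPOCH_ORDINAL)

def from_internal_fixed(v):
    # Only skip the conversion for None, not for every falsy value.
    if v is not None:
        return datetime.date.fromordinal(v + EPOCH_ORDINAL)

print(from_internal_buggy(0))  # 0
print(from_internal_fixed(0))  # 1970-01-01
```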
--- python/pyspark/sql/tests.py | 5 +++++ python/pyspark/sql/types.py | 6 ++++-- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 59a891bd7c42..fc778631d93a 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -168,6 +168,11 @@ def test_decimal_type(self): t3 = DecimalType(8) self.assertNotEqual(t2, t3) + # regression test for SPARK-10392 + def test_datetype_equal_zero(self): + dt = DateType() + self.assertEqual(dt.fromInternal(0), datetime.date(1970, 1, 1)) + class SQLTests(ReusedPySparkTestCase): diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index f84d08d7098a..8bd58d69eeec 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -168,10 +168,12 @@ def needConversion(self): return True def toInternal(self, d): - return d and d.toordinal() - self.EPOCH_ORDINAL + if d is not None: + return d.toordinal() - self.EPOCH_ORDINAL def fromInternal(self, v): - return v and datetime.date.fromordinal(v + self.EPOCH_ORDINAL) + if v is not None: + return datetime.date.fromordinal(v + self.EPOCH_ORDINAL) class TimestampType(AtomicType): From c3b881a7d7e4736f7131ff002a80e25def1f63af Mon Sep 17 00:00:00 2001 From: Chuan Shao Date: Wed, 2 Sep 2015 11:02:27 -0700 Subject: [PATCH 0504/1824] [SPARK-7336] [HISTORYSERVER] Fix bug that applications status incorrect on JobHistory UI. Author: ArcherShao Closes #5886 from ArcherShao/SPARK-7336. --- .../deploy/history/FsHistoryProvider.scala | 27 +++++++++++++++---- 1 file changed, 22 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index e573ff16c50a..a5755eac3639 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -18,6 +18,7 @@ package org.apache.spark.deploy.history import java.io.{BufferedInputStream, FileNotFoundException, InputStream, IOException, OutputStream} +import java.util.UUID import java.util.concurrent.{ExecutorService, Executors, TimeUnit} import java.util.zip.{ZipEntry, ZipOutputStream} @@ -73,7 +74,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) // The modification time of the newest log detected during the last scan. This is used // to ignore logs that are older during subsequent scans, to avoid processing data that // is already known. - private var lastModifiedTime = -1L + private var lastScanTime = -1L // Mapping of application IDs to their metadata, in descending end time order. Apps are inserted // into the map in order, so the LinkedHashMap maintains the correct ordering. 
@@ -179,15 +180,14 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) */ private[history] def checkForLogs(): Unit = { try { + val newLastScanTime = getNewLastScanTime() val statusList = Option(fs.listStatus(new Path(logDir))).map(_.toSeq) .getOrElse(Seq[FileStatus]()) - var newLastModifiedTime = lastModifiedTime val logInfos: Seq[FileStatus] = statusList .filter { entry => try { getModificationTime(entry).map { time => - newLastModifiedTime = math.max(newLastModifiedTime, time) - time >= lastModifiedTime + time >= lastScanTime }.getOrElse(false) } catch { case e: AccessControlException => @@ -224,12 +224,29 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) } } - lastModifiedTime = newLastModifiedTime + lastScanTime = newLastScanTime } catch { case e: Exception => logError("Exception in checking for event log updates", e) } } + private def getNewLastScanTime(): Long = { + val fileName = "." + UUID.randomUUID().toString + val path = new Path(logDir, fileName) + val fos = fs.create(path) + + try { + fos.close() + fs.getFileStatus(path).getModificationTime + } catch { + case e: Exception => + logError("Exception encountered when attempting to update last scan time", e) + lastScanTime + } finally { + fs.delete(path) + } + } + override def writeEventLogs( appId: String, attemptId: Option[String], From 56c4c172e99a5e14f4bc3308e7ff36d94113b63e Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 2 Sep 2015 11:13:17 -0700 Subject: [PATCH 0505/1824] [SPARK-10034] [SQL] add regression test for Sort on Aggregate Before #8371, there was a bug for `Sort` on `Aggregate` that we can't use aggregate expressions named `_aggOrdering` and can't use more than one ordering expressions which contains aggregate functions. The reason of this bug is that: The aggregate expression in `SortOrder` never get resolved, we alias it with `_aggOrdering` and call `toAttribute` which gives us an `UnresolvedAttribute`. So actually we are referencing aggregate expression by name, not by exprId like we thought. And if there is already an aggregate expression named `_aggOrdering` or there are more than one ordering expressions having aggregate functions, we will have conflict names and can't search by name. However, after #8371 got merged, the `SortOrder`s are guaranteed to be resolved and we are always referencing aggregate expression by exprId. The Bug doesn't exist anymore and this PR add regression tests for it. Author: Wenchen Fan Closes #8231 from cloud-fan/sort-agg. 
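For illustration, the shape of query the regression tests below guard against, expressed through the DataFrame API (a sketch for a PySpark shell where `sqlContext` is available; illustrative only, not part of the patch):

```python
from pyspark.sql import functions as F

df = sqlContext.createDataFrame([(1, 2)], ["i", "j"])

# An aggregate alias that resembles the analyzer's internal ordering alias,
# combined with ordering by an aggregate that is not in the select list.
(df.groupBy("i")
   .agg(F.max("j").alias("aggOrdering"))
   .orderBy(F.sum("j"))
   .show())
```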
--- .../scala/org/apache/spark/sql/DataFrameSuite.scala | 8 ++++++++ .../scala/org/apache/spark/sql/SQLQuerySuite.scala | 10 ++++++++++ 2 files changed, 18 insertions(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 284fff184085..a4871e247cff 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -887,4 +887,12 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { .select(struct($"b")) .collect() } + + test("SPARK-10034: Sort on Aggregate with aggregation expression named 'aggOrdering'") { + val df = Seq(1 -> 2).toDF("i", "j") + val query = df.groupBy('i) + .agg(max('j).as("aggOrdering")) + .orderBy(sum('j)) + checkAnswer(query, Row(1, 2)) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 9e172b2c264c..28201073a2d7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -1490,6 +1490,16 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { """.stripMargin), Row(3) :: Row(7) :: Row(11) :: Row(15) :: Nil) + checkAnswer( + sql( + """ + |SELECT sum(b) + |FROM orderByData + |GROUP BY a + |ORDER BY sum(b), max(b) + """.stripMargin), + Row(3) :: Row(7) :: Row(11) :: Row(15) :: Nil) + checkAnswer( sql( """ From fc48307797912dc1d53893dce741ddda8630957b Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 2 Sep 2015 11:32:27 -0700 Subject: [PATCH 0506/1824] [SPARK-10389] [SQL] support order by non-attribute grouping expression on Aggregate For example, we can write `SELECT MAX(value) FROM src GROUP BY key + 1 ORDER BY key + 1` in PostgreSQL, and we should support this in Spark SQL. Author: Wenchen Fan Closes #8548 from cloud-fan/support-order-by-non-attribute. --- .../sql/catalyst/analysis/Analyzer.scala | 72 ++++++++++--------- .../org/apache/spark/sql/SQLQuerySuite.scala | 19 +++-- 2 files changed, 52 insertions(+), 39 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 1a5de15c61f8..591747b45c37 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -560,43 +560,47 @@ class Analyzer( filter } - case sort @ Sort(sortOrder, global, - aggregate @ Aggregate(grouping, originalAggExprs, child)) + case sort @ Sort(sortOrder, global, aggregate: Aggregate) if aggregate.resolved && !sort.resolved => // Try resolving the ordering as though it is in the aggregate clause. try { - val aliasedOrder = sortOrder.map(o => Alias(o.child, "aggOrder")()) - val aggregatedOrdering = Aggregate(grouping, aliasedOrder, child) - val resolvedOperator: Aggregate = execute(aggregatedOrdering).asInstanceOf[Aggregate] - def resolvedAggregateOrdering = resolvedOperator.aggregateExpressions - - // Expressions that have an aggregate can be pushed down. - val needsAggregate = resolvedAggregateOrdering.exists(containsAggregate) - - // Attribute references, that are missing from the order but are present in the grouping - // expressions can also be pushed down. 
- val requiredAttributes = resolvedAggregateOrdering.map(_.references).reduce(_ ++ _) - val missingAttributes = requiredAttributes -- aggregate.outputSet - val validPushdownAttributes = - missingAttributes.filter(a => grouping.exists(a.semanticEquals)) - - // If resolution was successful and we see the ordering either has an aggregate in it or - // it is missing something that is projected away by the aggregate, add the ordering - // the original aggregate operator. - if (resolvedOperator.resolved && (needsAggregate || validPushdownAttributes.nonEmpty)) { - val evaluatedOrderings: Seq[SortOrder] = sortOrder.zip(resolvedAggregateOrdering).map { - case (order, evaluated) => order.copy(child = evaluated.toAttribute) - } - val aggExprsWithOrdering: Seq[NamedExpression] = - resolvedAggregateOrdering ++ originalAggExprs - - Project(aggregate.output, - Sort(evaluatedOrderings, global, - aggregate.copy(aggregateExpressions = aggExprsWithOrdering))) - } else { - sort + val aliasedOrdering = sortOrder.map(o => Alias(o.child, "aggOrder")()) + val aggregatedOrdering = aggregate.copy(aggregateExpressions = aliasedOrdering) + val resolvedAggregate: Aggregate = execute(aggregatedOrdering).asInstanceOf[Aggregate] + val resolvedAliasedOrdering: Seq[Alias] = + resolvedAggregate.aggregateExpressions.asInstanceOf[Seq[Alias]] + + // If we pass the analysis check, then the ordering expressions should only reference to + // aggregate expressions or grouping expressions, and it's safe to push them down to + // Aggregate. + checkAnalysis(resolvedAggregate) + + val originalAggExprs = aggregate.aggregateExpressions.map( + CleanupAliases.trimNonTopLevelAliases(_).asInstanceOf[NamedExpression]) + + // If the ordering expression is same with original aggregate expression, we don't need + // to push down this ordering expression and can reference the original aggregate + // expression instead. + val needsPushDown = ArrayBuffer.empty[NamedExpression] + val evaluatedOrderings = resolvedAliasedOrdering.zip(sortOrder).map { + case (evaluated, order) => + val index = originalAggExprs.indexWhere { + case Alias(child, _) => child semanticEquals evaluated.child + case other => other semanticEquals evaluated.child + } + + if (index == -1) { + needsPushDown += evaluated + order.copy(child = evaluated.toAttribute) + } else { + order.copy(child = originalAggExprs(index).toAttribute) + } } + + Project(aggregate.output, + Sort(evaluatedOrderings, global, + aggregate.copy(aggregateExpressions = originalAggExprs ++ needsPushDown))) } catch { // Attempting to resolve in the aggregate can result in ambiguity. When this happens, // just return the original plan. 
@@ -605,9 +609,7 @@ class Analyzer( } protected def containsAggregate(condition: Expression): Boolean = { - condition - .collect { case ae: AggregateExpression => ae } - .nonEmpty + condition.find(_.isInstanceOf[AggregateExpression]).isDefined } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 28201073a2d7..0ef25fe0faef 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -1722,9 +1722,20 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } test("SPARK-10130 type coercion for IF should have children resolved first") { - val df = Seq((1, 1), (-1, 1)).toDF("key", "value") - df.registerTempTable("src") - checkAnswer( - sql("SELECT IF(a > 0, a, 0) FROM (SELECT key a FROM src) temp"), Seq(Row(1), Row(0))) + withTempTable("src") { + Seq((1, 1), (-1, 1)).toDF("key", "value").registerTempTable("src") + checkAnswer( + sql("SELECT IF(a > 0, a, 0) FROM (SELECT key a FROM src) temp"), Seq(Row(1), Row(0))) + } + } + + test("SPARK-10389: order by non-attribute grouping expression on Aggregate") { + withTempTable("src") { + Seq((1, 1), (-1, 1)).toDF("key", "value").registerTempTable("src") + checkAnswer(sql("SELECT MAX(value) FROM src GROUP BY key + 1 ORDER BY key + 1"), + Seq(Row(1), Row(1))) + checkAnswer(sql("SELECT MAX(value) FROM src GROUP BY key + 1 ORDER BY (key + 1) * 2"), + Seq(Row(1), Row(1))) + } } } From 2da3a9e98e5d129d4507b5db01bba5ee9558d28e Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Wed, 2 Sep 2015 12:53:24 -0700 Subject: [PATCH 0507/1824] [SPARK-10004] [SHUFFLE] Perform auth checks when clients read shuffle data. To correctly isolate applications, when requests to read shuffle data arrive at the shuffle service, proper authorization checks need to be performed. This change makes sure that only the application that created the shuffle data can read from it. Such checks are only enabled when "spark.authenticate" is enabled, otherwise there's no secure way to make sure that the client is really who it says it is. Author: Marcelo Vanzin Closes #8218 from vanzin/SPARK-10004. --- .../network/netty/NettyBlockRpcServer.scala | 3 +- .../netty/NettyBlockTransferService.scala | 2 +- network/common/pom.xml | 4 + .../spark/network/client/TransportClient.java | 22 +++ .../network/sasl/SaslClientBootstrap.java | 2 + .../spark/network/sasl/SaslRpcHandler.java | 1 + .../server/OneForOneStreamManager.java | 31 +++- .../spark/network/server/StreamManager.java | 9 + .../server/TransportRequestHandler.java | 1 + .../shuffle/ExternalShuffleBlockHandler.java | 16 +- .../network/sasl/SaslIntegrationSuite.java | 163 +++++++++++++++--- .../ExternalShuffleBlockHandlerSuite.java | 2 +- project/MimaExcludes.scala | 1 + 13 files changed, 221 insertions(+), 36 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala index 7c170a742fb6..76968249fb62 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala @@ -38,6 +38,7 @@ import org.apache.spark.storage.{BlockId, StorageLevel} * is equivalent to one Spark-level shuffle block. 
*/ class NettyBlockRpcServer( + appId: String, serializer: Serializer, blockManager: BlockDataManager) extends RpcHandler with Logging { @@ -55,7 +56,7 @@ class NettyBlockRpcServer( case openBlocks: OpenBlocks => val blocks: Seq[ManagedBuffer] = openBlocks.blockIds.map(BlockId.apply).map(blockManager.getBlockData) - val streamId = streamManager.registerStream(blocks.iterator.asJava) + val streamId = streamManager.registerStream(appId, blocks.iterator.asJava) logTrace(s"Registered streamId $streamId with ${blocks.size} buffers") responseContext.onSuccess(new StreamHandle(streamId, blocks.size).toByteArray) diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala index ff8aae9ebe9f..d5ad2c9ad00e 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala @@ -49,7 +49,7 @@ class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManage private[this] var appId: String = _ override def init(blockDataManager: BlockDataManager): Unit = { - val rpcHandler = new NettyBlockRpcServer(serializer, blockDataManager) + val rpcHandler = new NettyBlockRpcServer(conf.getAppId, serializer, blockDataManager) var serverBootstrap: Option[TransportServerBootstrap] = None var clientBootstrap: Option[TransportClientBootstrap] = None if (authEnabled) { diff --git a/network/common/pom.xml b/network/common/pom.xml index 7dc3068ab8cb..4141fcb8267a 100644 --- a/network/common/pom.xml +++ b/network/common/pom.xml @@ -48,6 +48,10 @@ slf4j-api provided
    + + com.google.code.findbugs + jsr305 + {% highlight python %} from pyspark.ml.regression import LinearRegression -from pyspark.mllib.regression import LabeledPoint -from pyspark.mllib.util import MLUtils # Load training data -training = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() +training = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") lr = LinearRegression(maxIter=10, regParam=0.3, elasticNetParam=0.8) From b656e6134fc5cd27e1fe6b6ab30fd7633cab0b14 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Fri, 11 Sep 2015 08:50:35 -0700 Subject: [PATCH 0589/1824] [SPARK-10026] [ML] [PySpark] Implement some common Params for regression in PySpark LinearRegression and LogisticRegression lack of some Params for Python, and some Params are not shared classes which lead we need to write them for each class. These kinds of Params are list here: ```scala HasElasticNetParam HasFitIntercept HasStandardization HasThresholds ``` Here we implement them in shared params at Python side and make LinearRegression/LogisticRegression parameters peer with Scala one. Author: Yanbo Liang Closes #8508 from yanboliang/spark-10026. --- python/pyspark/ml/classification.py | 75 ++---------- .../ml/param/_shared_params_code_gen.py | 11 +- python/pyspark/ml/param/shared.py | 111 ++++++++++++++++++ python/pyspark/ml/regression.py | 42 ++----- 4 files changed, 143 insertions(+), 96 deletions(-) diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 83f808efc3bf..22bdd1b322ac 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -31,7 +31,8 @@ @inherit_doc class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter, - HasRegParam, HasTol, HasProbabilityCol, HasRawPredictionCol): + HasRegParam, HasTol, HasProbabilityCol, HasRawPredictionCol, + HasElasticNetParam, HasFitIntercept, HasStandardization, HasThresholds): """ Logistic regression. Currently, this class only supports binary classification. @@ -65,17 +66,6 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti """ # a placeholder to make it appear in the generated doc - elasticNetParam = \ - Param(Params._dummy(), "elasticNetParam", - "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, " + - "the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.") - fitIntercept = Param(Params._dummy(), "fitIntercept", "whether to fit an intercept term.") - thresholds = Param(Params._dummy(), "thresholds", - "Thresholds in multi-class classification" + - " to adjust the probability of predicting each class." + - " Array must have length equal to the number of classes, with values >= 0." + - " The class with largest value p/t is predicted, where p is the original" + - " probability of that class and t is the class' threshold.") threshold = Param(Params._dummy(), "threshold", "Threshold in binary classification prediction, in range [0, 1]." 
+ " If threshold and thresholds are both set, they must match.") @@ -83,40 +73,23 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti @keyword_only def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, - threshold=0.5, thresholds=None, - probabilityCol="probability", rawPredictionCol="rawPrediction"): + threshold=0.5, thresholds=None, probabilityCol="probability", + rawPredictionCol="rawPrediction", standardization=True): """ __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \ - threshold=0.5, thresholds=None, \ - probabilityCol="probability", rawPredictionCol="rawPrediction") + threshold=0.5, thresholds=None, probabilityCol="probability", \ + rawPredictionCol="rawPrediction", standardization=True) If the threshold and thresholds Params are both set, they must be equivalent. """ super(LogisticRegression, self).__init__() self._java_obj = self._new_java_obj( "org.apache.spark.ml.classification.LogisticRegression", self.uid) - #: param for the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty - # is an L2 penalty. For alpha = 1, it is an L1 penalty. - self.elasticNetParam = \ - Param(self, "elasticNetParam", - "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, " + - "the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.") - #: param for whether to fit an intercept term. - self.fitIntercept = Param(self, "fitIntercept", "whether to fit an intercept term.") #: param for threshold in binary classification, in range [0, 1]. self.threshold = Param(self, "threshold", "Threshold in binary classification prediction, in range [0, 1]." + " If threshold and thresholds are both set, they must match.") - #: param for thresholds or cutoffs in binary or multiclass classification - self.thresholds = \ - Param(self, "thresholds", - "Thresholds in multi-class classification" + - " to adjust the probability of predicting each class." + - " Array must have length equal to the number of classes, with values >= 0." 
+ - " The class with largest value p/t is predicted, where p is the original" + - " probability of that class and t is the class' threshold.") - self._setDefault(maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1E-6, - fitIntercept=True, threshold=0.5) + self._setDefault(maxIter=100, regParam=0.1, tol=1E-6, threshold=0.5) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) self._checkThresholdConsistency() @@ -124,13 +97,13 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred @keyword_only def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, - threshold=0.5, thresholds=None, - probabilityCol="probability", rawPredictionCol="rawPrediction"): + threshold=0.5, thresholds=None, probabilityCol="probability", + rawPredictionCol="rawPrediction", standardization=True): """ setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \ - threshold=0.5, thresholds=None, \ - probabilityCol="probability", rawPredictionCol="rawPrediction") + threshold=0.5, thresholds=None, probabilityCol="probability", \ + rawPredictionCol="rawPrediction", standardization=True) Sets params for logistic regression. If the threshold and thresholds Params are both set, they must be equivalent. """ @@ -142,32 +115,6 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre def _create_model(self, java_model): return LogisticRegressionModel(java_model) - def setElasticNetParam(self, value): - """ - Sets the value of :py:attr:`elasticNetParam`. - """ - self._paramMap[self.elasticNetParam] = value - return self - - def getElasticNetParam(self): - """ - Gets the value of elasticNetParam or its default value. - """ - return self.getOrDefault(self.elasticNetParam) - - def setFitIntercept(self, value): - """ - Sets the value of :py:attr:`fitIntercept`. - """ - self._paramMap[self.fitIntercept] = value - return self - - def getFitIntercept(self): - """ - Gets the value of fitIntercept or its default value. - """ - return self.getOrDefault(self.fitIntercept) - def setThreshold(self, value): """ Sets the value of :py:attr:`threshold`. diff --git a/python/pyspark/ml/param/_shared_params_code_gen.py b/python/pyspark/ml/param/_shared_params_code_gen.py index 926375e44871..5b39e5dd4e25 100644 --- a/python/pyspark/ml/param/_shared_params_code_gen.py +++ b/python/pyspark/ml/param/_shared_params_code_gen.py @@ -124,7 +124,16 @@ def get$Name(self): ("stepSize", "Step size to be used for each iteration of optimization.", None), ("handleInvalid", "how to handle invalid entries. Options are skip (which will filter " + "out rows with bad values), or error (which will throw an errror). More options may be " + - "added later.", None)] + "added later.", None), + ("elasticNetParam", "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, " + + "the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.", "0.0"), + ("fitIntercept", "whether to fit an intercept term.", "True"), + ("standardization", "whether to standardize the training features before fitting the " + + "model.", "True"), + ("thresholds", "Thresholds in multi-class classification to adjust the probability of " + + "predicting each class. Array must have length equal to the number of classes, with " + + "values >= 0. 
The class with largest value p/t is predicted, where p is the original " + + "probability of that class and t is the class' threshold.", None)] code = [] for name, doc, defaultValueStr in shared: param_code = _gen_param_header(name, doc, defaultValueStr) diff --git a/python/pyspark/ml/param/shared.py b/python/pyspark/ml/param/shared.py index 682170aee85f..af1218128602 100644 --- a/python/pyspark/ml/param/shared.py +++ b/python/pyspark/ml/param/shared.py @@ -459,6 +459,117 @@ def getHandleInvalid(self): return self.getOrDefault(self.handleInvalid) +class HasElasticNetParam(Params): + """ + Mixin for param elasticNetParam: the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.. + """ + + # a placeholder to make it appear in the generated doc + elasticNetParam = Param(Params._dummy(), "elasticNetParam", "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.") + + def __init__(self): + super(HasElasticNetParam, self).__init__() + #: param for the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty. + self.elasticNetParam = Param(self, "elasticNetParam", "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.") + self._setDefault(elasticNetParam=0.0) + + def setElasticNetParam(self, value): + """ + Sets the value of :py:attr:`elasticNetParam`. + """ + self._paramMap[self.elasticNetParam] = value + return self + + def getElasticNetParam(self): + """ + Gets the value of elasticNetParam or its default value. + """ + return self.getOrDefault(self.elasticNetParam) + + +class HasFitIntercept(Params): + """ + Mixin for param fitIntercept: whether to fit an intercept term.. + """ + + # a placeholder to make it appear in the generated doc + fitIntercept = Param(Params._dummy(), "fitIntercept", "whether to fit an intercept term.") + + def __init__(self): + super(HasFitIntercept, self).__init__() + #: param for whether to fit an intercept term. + self.fitIntercept = Param(self, "fitIntercept", "whether to fit an intercept term.") + self._setDefault(fitIntercept=True) + + def setFitIntercept(self, value): + """ + Sets the value of :py:attr:`fitIntercept`. + """ + self._paramMap[self.fitIntercept] = value + return self + + def getFitIntercept(self): + """ + Gets the value of fitIntercept or its default value. + """ + return self.getOrDefault(self.fitIntercept) + + +class HasStandardization(Params): + """ + Mixin for param standardization: whether to standardize the training features before fitting the model.. + """ + + # a placeholder to make it appear in the generated doc + standardization = Param(Params._dummy(), "standardization", "whether to standardize the training features before fitting the model.") + + def __init__(self): + super(HasStandardization, self).__init__() + #: param for whether to standardize the training features before fitting the model. + self.standardization = Param(self, "standardization", "whether to standardize the training features before fitting the model.") + self._setDefault(standardization=True) + + def setStandardization(self, value): + """ + Sets the value of :py:attr:`standardization`. + """ + self._paramMap[self.standardization] = value + return self + + def getStandardization(self): + """ + Gets the value of standardization or its default value. 
+ """ + return self.getOrDefault(self.standardization) + + +class HasThresholds(Params): + """ + Mixin for param thresholds: Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold.. + """ + + # a placeholder to make it appear in the generated doc + thresholds = Param(Params._dummy(), "thresholds", "Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold.") + + def __init__(self): + super(HasThresholds, self).__init__() + #: param for Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold. + self.thresholds = Param(self, "thresholds", "Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold.") + + def setThresholds(self, value): + """ + Sets the value of :py:attr:`thresholds`. + """ + self._paramMap[self.thresholds] = value + return self + + def getThresholds(self): + """ + Gets the value of thresholds or its default value. + """ + return self.getOrDefault(self.thresholds) + + class DecisionTreeParams(Params): """ Mixin for Decision Tree parameters. diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index 44f60a769566..a9503608b7f2 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -28,7 +28,8 @@ @inherit_doc class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter, - HasRegParam, HasTol): + HasRegParam, HasTol, HasElasticNetParam, HasFitIntercept, + HasStandardization): """ Linear regression. @@ -63,38 +64,30 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction TypeError: Method setParams forces keyword arguments. """ - # a placeholder to make it appear in the generated doc - elasticNetParam = \ - Param(Params._dummy(), "elasticNetParam", - "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, " + - "the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.") - @keyword_only def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", - maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6): + maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, + standardization=True): """ __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ - maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6) + maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \ + standardization=True) """ super(LinearRegression, self).__init__() self._java_obj = self._new_java_obj( "org.apache.spark.ml.regression.LinearRegression", self.uid) - #: param for the ElasticNet mixing parameter, in range [0, 1]. 
For alpha = 0, the penalty - # is an L2 penalty. For alpha = 1, it is an L1 penalty. - self.elasticNetParam = \ - Param(self, "elasticNetParam", - "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty " + - "is an L2 penalty. For alpha = 1, it is an L1 penalty.") - self._setDefault(maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6) + self._setDefault(maxIter=100, regParam=0.0, tol=1e-6) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @keyword_only def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", - maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6): + maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, + standardization=True): """ setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ - maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6) + maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \ + standardization=True) Sets params for linear regression. """ kwargs = self.setParams._input_kwargs @@ -103,19 +96,6 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre def _create_model(self, java_model): return LinearRegressionModel(java_model) - def setElasticNetParam(self, value): - """ - Sets the value of :py:attr:`elasticNetParam`. - """ - self._paramMap[self.elasticNetParam] = value - return self - - def getElasticNetParam(self): - """ - Gets the value of elasticNetParam or its default value. - """ - return self.getOrDefault(self.elasticNetParam) - class LinearRegressionModel(JavaModel): """ From b01b26260625f0ba14e5f3010207666d62d93864 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Fri, 11 Sep 2015 08:52:28 -0700 Subject: [PATCH 0590/1824] [SPARK-9773] [ML] [PySpark] Add Python API for MultilayerPerceptronClassifier Add Python API for ```MultilayerPerceptronClassifier```. Author: Yanbo Liang Closes #8067 from yanboliang/SPARK-9773. --- .../MultilayerPerceptronClassifier.scala | 9 ++ python/pyspark/ml/classification.py | 132 +++++++++++++++++- 2 files changed, 140 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala index 82fc80c58054..5f60dea91fcf 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala @@ -17,6 +17,8 @@ package org.apache.spark.ml.classification +import scala.collection.JavaConverters._ + import org.apache.spark.annotation.Experimental import org.apache.spark.ml.param.shared.{HasTol, HasMaxIter, HasSeed} import org.apache.spark.ml.{PredictorParams, PredictionModel, Predictor} @@ -181,6 +183,13 @@ class MultilayerPerceptronClassificationModel private[ml] ( private val mlpModel = FeedForwardTopology.multiLayerPerceptron(layers, true).getInstance(weights) + /** + * Returns layers in a Java List. + */ + private[ml] def javaLayers: java.util.List[Int] = { + layers.toList.asJava + } + /** * Predict label for the given features. * This internal method is used to implement [[transform()]] and output [[predictionCol]]. 
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 22bdd1b322ac..88815e561f57 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -26,7 +26,8 @@ __all__ = ['LogisticRegression', 'LogisticRegressionModel', 'DecisionTreeClassifier', 'DecisionTreeClassificationModel', 'GBTClassifier', 'GBTClassificationModel', 'RandomForestClassifier', 'RandomForestClassificationModel', 'NaiveBayes', - 'NaiveBayesModel'] + 'NaiveBayesModel', 'MultilayerPerceptronClassifier', + 'MultilayerPerceptronClassificationModel'] @inherit_doc @@ -755,6 +756,135 @@ def theta(self): return self._call_java("theta") +@inherit_doc +class MultilayerPerceptronClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, + HasMaxIter, HasTol, HasSeed): + """ + Classifier trainer based on the Multilayer Perceptron. + Each layer has sigmoid activation function, output layer has softmax. + Number of inputs has to be equal to the size of feature vectors. + Number of outputs has to be equal to the total number of labels. + + >>> from pyspark.mllib.linalg import Vectors + >>> df = sqlContext.createDataFrame([ + ... (0.0, Vectors.dense([0.0, 0.0])), + ... (1.0, Vectors.dense([0.0, 1.0])), + ... (1.0, Vectors.dense([1.0, 0.0])), + ... (0.0, Vectors.dense([1.0, 1.0]))], ["label", "features"]) + >>> mlp = MultilayerPerceptronClassifier(maxIter=100, layers=[2, 5, 2], blockSize=1, seed=11) + >>> model = mlp.fit(df) + >>> model.layers + [2, 5, 2] + >>> model.weights.size + 27 + >>> testDF = sqlContext.createDataFrame([ + ... (Vectors.dense([1.0, 0.0]),), + ... (Vectors.dense([0.0, 0.0]),)], ["features"]) + >>> model.transform(testDF).show() + +---------+----------+ + | features|prediction| + +---------+----------+ + |[1.0,0.0]| 1.0| + |[0.0,0.0]| 0.0| + +---------+----------+ + ... + """ + + # a placeholder to make it appear in the generated doc + layers = Param(Params._dummy(), "layers", "Sizes of layers from input layer to output layer " + + "E.g., Array(780, 100, 10) means 780 inputs, one hidden layer with 100 " + + "neurons and output layer of 10 neurons, default is [1, 1].") + blockSize = Param(Params._dummy(), "blockSize", "Block size for stacking input data in " + + "matrices. Data is stacked within partitions. If block size is more than " + + "remaining data in a partition then it is adjusted to the size of this " + + "data. Recommended size is between 10 and 1000, default is 128.") + + @keyword_only + def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", + maxIter=100, tol=1e-4, seed=None, layers=None, blockSize=128): + """ + __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ + maxIter=100, tol=1e-4, seed=None, layers=[1, 1], blockSize=128) + """ + super(MultilayerPerceptronClassifier, self).__init__() + self._java_obj = self._new_java_obj( + "org.apache.spark.ml.classification.MultilayerPerceptronClassifier", self.uid) + self.layers = Param(self, "layers", "Sizes of layers from input layer to output layer " + + "E.g., Array(780, 100, 10) means 780 inputs, one hidden layer with " + + "100 neurons and output layer of 10 neurons, default is [1, 1].") + self.blockSize = Param(self, "blockSize", "Block size for stacking input data in " + + "matrices. Data is stacked within partitions. If block size is " + + "more than remaining data in a partition then it is adjusted to " + + "the size of this data. 
Recommended size is between 10 and 1000, " + + "default is 128.") + self._setDefault(maxIter=100, tol=1E-4, layers=[1, 1], blockSize=128) + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + + @keyword_only + def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", + maxIter=100, tol=1e-4, seed=None, layers=None, blockSize=128): + """ + setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ + maxIter=100, tol=1e-4, seed=None, layers=[1, 1], blockSize=128) + Sets params for MultilayerPerceptronClassifier. + """ + kwargs = self.setParams._input_kwargs + if layers is None: + return self._set(**kwargs).setLayers([1, 1]) + else: + return self._set(**kwargs) + + def _create_model(self, java_model): + return MultilayerPerceptronClassificationModel(java_model) + + def setLayers(self, value): + """ + Sets the value of :py:attr:`layers`. + """ + self._paramMap[self.layers] = value + return self + + def getLayers(self): + """ + Gets the value of layers or its default value. + """ + return self.getOrDefault(self.layers) + + def setBlockSize(self, value): + """ + Sets the value of :py:attr:`blockSize`. + """ + self._paramMap[self.blockSize] = value + return self + + def getBlockSize(self): + """ + Gets the value of blockSize or its default value. + """ + return self.getOrDefault(self.blockSize) + + +class MultilayerPerceptronClassificationModel(JavaModel): + """ + Model fitted by MultilayerPerceptronClassifier. + """ + + @property + def layers(self): + """ + array of layer sizes including input and output layers. + """ + return self._call_java("javaLayers") + + @property + def weights(self): + """ + vector of initial weights for the model that consists of the weights of layers. + """ + return self._call_java("weights") + + if __name__ == "__main__": import doctest from pyspark.context import SparkContext From 960d2d0ac6b5a22242a922f87f745f7d1f736181 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Fri, 11 Sep 2015 08:53:40 -0700 Subject: [PATCH 0591/1824] [SPARK-10537] [ML] document LIBSVM source options in public API doc and some minor improvements We should document options in public API doc. Otherwise, it is hard to find out the options without looking at the code. I tried to make `DefaultSource` private and put the documentation to package doc. However, since then there exists no public class under `source.libsvm`, the Java package doc doesn't show up in the generated html file (http://bugs.java.com/bugdatabase/view_bug.do?bug_id=4492654). So I put the doc to `DefaultSource` instead. There are several minor updates in this PR: 1. Do `vectorType == "sparse"` only once. 2. Update `hashCode` and `equals`. 3. Remove inherited doc. 4. Delete temp dir in `afterAll`. Lewuathe Author: Xiangrui Meng Closes #8699 from mengxr/SPARK-10537. 
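For reference, a minimal load through the documented options (a sketch only; the option names and the 780-feature sample path are taken from the scaladoc added in this patch, and `sqlContext` is assumed to be an existing SQLContext):

```scala
// Sketch: read LIBSVM data via the "libsvm" data source.
// "numFeatures" skips the extra pass that infers the feature dimension;
// "vectorType" chooses between "sparse" (default) and "dense" feature vectors.
val df = sqlContext.read.format("libsvm")
  .option("numFeatures", "780")
  .option("vectorType", "dense")
  .load("data/mllib/sample_libsvm_data.txt")

df.printSchema()                        // label: double, features: vector
df.select("label", "features").show(5)
```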
--- .../ml/source/libsvm/LibSVMRelation.scala | 71 ++++++++++++------- .../{ => libsvm}/JavaLibSVMRelationSuite.java | 24 +++---- .../{ => libsvm}/LibSVMRelationSuite.scala | 14 ++-- 3 files changed, 66 insertions(+), 43 deletions(-) rename mllib/src/test/java/org/apache/spark/ml/source/{ => libsvm}/JavaLibSVMRelationSuite.java (79%) rename mllib/src/test/scala/org/apache/spark/ml/source/{ => libsvm}/LibSVMRelationSuite.scala (88%) diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala index b12cb62a4ef1..1f627777fc68 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala @@ -21,12 +21,12 @@ import com.google.common.base.Objects import org.apache.spark.Logging import org.apache.spark.annotation.Since -import org.apache.spark.mllib.linalg.VectorUDT +import org.apache.spark.mllib.linalg.{Vector, VectorUDT} import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD -import org.apache.spark.sql.types.{StructType, StructField, DoubleType} -import org.apache.spark.sql.{Row, SQLContext} +import org.apache.spark.sql.{DataFrameReader, DataFrame, Row, SQLContext} import org.apache.spark.sql.sources._ +import org.apache.spark.sql.types.{DoubleType, StructField, StructType} /** * LibSVMRelation provides the DataFrame constructed from LibSVM format data. @@ -35,7 +35,7 @@ import org.apache.spark.sql.sources._ * @param vectorType The type of vector. It can be 'sparse' or 'dense' * @param sqlContext The Spark SQLContext */ -private[ml] class LibSVMRelation(val path: String, val numFeatures: Int, val vectorType: String) +private[libsvm] class LibSVMRelation(val path: String, val numFeatures: Int, val vectorType: String) (@transient val sqlContext: SQLContext) extends BaseRelation with TableScan with Logging with Serializable { @@ -47,27 +47,56 @@ private[ml] class LibSVMRelation(val path: String, val numFeatures: Int, val vec override def buildScan(): RDD[Row] = { val sc = sqlContext.sparkContext val baseRdd = MLUtils.loadLibSVMFile(sc, path, numFeatures) - + val sparse = vectorType == "sparse" baseRdd.map { pt => - val features = if (vectorType == "dense") pt.features.toDense else pt.features.toSparse + val features = if (sparse) pt.features.toSparse else pt.features.toDense Row(pt.label, features) } } override def hashCode(): Int = { - Objects.hashCode(path, schema) + Objects.hashCode(path, Double.box(numFeatures), vectorType) } override def equals(other: Any): Boolean = other match { - case that: LibSVMRelation => (this.path == that.path) && this.schema.equals(that.schema) - case _ => false + case that: LibSVMRelation => + path == that.path && + numFeatures == that.numFeatures && + vectorType == that.vectorType + case _ => + false } - } /** - * This is used for creating DataFrame from LibSVM format file. - * The LibSVM file path must be specified to DefaultSource. + * `libsvm` package implements Spark SQL data source API for loading LIBSVM data as [[DataFrame]]. + * The loaded [[DataFrame]] has two columns: `label` containing labels stored as doubles and + * `features` containing feature vectors stored as [[Vector]]s. 
+ * + * To use LIBSVM data source, you need to set "libsvm" as the format in [[DataFrameReader]] and + * optionally specify options, for example: + * {{{ + * // Scala + * val df = sqlContext.read.format("libsvm") + * .option("numFeatures", "780") + * .load("data/mllib/sample_libsvm_data.txt") + * + * // Java + * DataFrame df = sqlContext.read.format("libsvm") + * .option("numFeatures, "780") + * .load("data/mllib/sample_libsvm_data.txt"); + * }}} + * + * LIBSVM data source supports the following options: + * - "numFeatures": number of features. + * If unspecified or nonpositive, the number of features will be determined automatically at the + * cost of one additional pass. + * This is also useful when the dataset is already split into multiple files and you want to load + * them separately, because some features may not present in certain files, which leads to + * inconsistent feature dimensions. + * - "vectorType": feature vector type, "sparse" (default) or "dense". + * + * @see [[https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/ LIBSVM datasets]] */ @Since("1.6.0") class DefaultSource extends RelationProvider with DataSourceRegister { @@ -75,24 +104,12 @@ class DefaultSource extends RelationProvider with DataSourceRegister { @Since("1.6.0") override def shortName(): String = "libsvm" - private def checkPath(parameters: Map[String, String]): String = { - require(parameters.contains("path"), "'path' must be specified") - parameters.get("path").get - } - - /** - * Returns a new base relation with the given parameters. - * Note: the parameters' keywords are case insensitive and this insensitivity is enforced - * by the Map that is passed to the function. - */ + @Since("1.6.0") override def createRelation(sqlContext: SQLContext, parameters: Map[String, String]) : BaseRelation = { - val path = checkPath(parameters) + val path = parameters.getOrElse("path", + throw new IllegalArgumentException("'path' must be specified")) val numFeatures = parameters.getOrElse("numFeatures", "-1").toInt - /** - * featuresType can be selected "dense" or "sparse". - * This parameter decides the type of returned feature vector. - */ val vectorType = parameters.getOrElse("vectorType", "sparse") new LibSVMRelation(path, numFeatures, vectorType)(sqlContext) } diff --git a/mllib/src/test/java/org/apache/spark/ml/source/JavaLibSVMRelationSuite.java b/mllib/src/test/java/org/apache/spark/ml/source/libsvm/JavaLibSVMRelationSuite.java similarity index 79% rename from mllib/src/test/java/org/apache/spark/ml/source/JavaLibSVMRelationSuite.java rename to mllib/src/test/java/org/apache/spark/ml/source/libsvm/JavaLibSVMRelationSuite.java index 11fa4eec0ccf..2976b38e4503 100644 --- a/mllib/src/test/java/org/apache/spark/ml/source/JavaLibSVMRelationSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/source/libsvm/JavaLibSVMRelationSuite.java @@ -15,7 +15,7 @@ * limitations under the License. 
*/ -package org.apache.spark.ml.source; +package org.apache.spark.ml.source.libsvm; import java.io.File; import java.io.IOException; @@ -42,34 +42,34 @@ */ public class JavaLibSVMRelationSuite { private transient JavaSparkContext jsc; - private transient SQLContext jsql; - private transient DataFrame dataset; + private transient SQLContext sqlContext; - private File tmpDir; - private File path; + private File tempDir; + private String path; @Before public void setUp() throws IOException { jsc = new JavaSparkContext("local", "JavaLibSVMRelationSuite"); - jsql = new SQLContext(jsc); - - tmpDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "datasource"); - path = new File(tmpDir.getPath(), "part-00000"); + sqlContext = new SQLContext(jsc); + tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "datasource"); + File file = new File(tempDir, "part-00000"); String s = "1 1:1.0 3:2.0 5:3.0\n0\n0 2:4.0 4:5.0 6:6.0"; - Files.write(s, path, Charsets.US_ASCII); + Files.write(s, file, Charsets.US_ASCII); + path = tempDir.toURI().toString(); } @After public void tearDown() { jsc.stop(); jsc = null; - Utils.deleteRecursively(tmpDir); + Utils.deleteRecursively(tempDir); } @Test public void verifyLibSVMDF() { - dataset = jsql.read().format("libsvm").option("vectorType", "dense").load(path.getPath()); + DataFrame dataset = sqlContext.read().format("libsvm").option("vectorType", "dense") + .load(path); Assert.assertEquals("label", dataset.columns()[0]); Assert.assertEquals("features", dataset.columns()[1]); Row r = dataset.first(); diff --git a/mllib/src/test/scala/org/apache/spark/ml/source/LibSVMRelationSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala similarity index 88% rename from mllib/src/test/scala/org/apache/spark/ml/source/LibSVMRelationSuite.scala rename to mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala index 8ed134128c8d..997f574e51f6 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/source/LibSVMRelationSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.ml.source +package org.apache.spark.ml.source.libsvm import java.io.File @@ -23,11 +23,12 @@ import com.google.common.base.Charsets import com.google.common.io.Files import org.apache.spark.SparkFunSuite -import org.apache.spark.mllib.linalg.{SparseVector, Vectors, DenseVector} +import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.util.Utils class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext { + var tempDir: File = _ var path: String = _ override def beforeAll(): Unit = { @@ -38,12 +39,17 @@ class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext { |0 |0 2:4.0 4:5.0 6:6.0 """.stripMargin - val tempDir = Utils.createTempDir() - val file = new File(tempDir.getPath, "part-00000") + tempDir = Utils.createTempDir() + val file = new File(tempDir, "part-00000") Files.write(lines, file, Charsets.US_ASCII) path = tempDir.toURI.toString } + override def afterAll(): Unit = { + Utils.deleteRecursively(tempDir) + super.afterAll() + } + test("select as sparse vector") { val df = sqlContext.read.format("libsvm").load(path) assert(df.columns(0) == "label") From 2e3a280754a28dc36a71b9ff988e34cbf457f6c3 Mon Sep 17 00:00:00 2001 From: "Joseph K. 
Bradley" Date: Fri, 11 Sep 2015 08:55:35 -0700 Subject: [PATCH 0592/1824] [MINOR] [MLLIB] [ML] [DOC] Minor doc fixes for StringIndexer and MetadataUtils MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Changes: * Make Scala doc for StringIndexerInverse clearer. Also remove Scala doc from transformSchema, so that the doc is inherited. * MetadataUtils.scala: “ Helper utilities for tree-based algorithms†—> not just trees anymore CC: holdenk mengxr Author: Joseph K. Bradley Closes #8679 from jkbradley/doc-fixes-1.5. --- .../spark/ml/feature/StringIndexer.scala | 31 +++++++------------ .../apache/spark/ml/util/MetadataUtils.scala | 2 +- python/pyspark/ml/feature.py | 16 +++++----- 3 files changed, 20 insertions(+), 29 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index b6482ffe0b2e..3a4ab9a85764 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -181,10 +181,10 @@ class StringIndexerModel ( /** * :: Experimental :: - * A [[Transformer]] that maps a column of string indices back to a new column of corresponding - * string values using either the ML attributes of the input column, or if provided using the labels - * supplied by the user. - * All original columns are kept during transformation. + * A [[Transformer]] that maps a column of indices back to a new column of corresponding + * string values. + * The index-string mapping is either from the ML attributes of the input column, + * or from user-supplied labels (which take precedence over ML attributes). * * @see [[StringIndexer]] for converting strings into indices */ @@ -202,32 +202,23 @@ class IndexToString private[ml] ( /** @group setParam */ def setOutputCol(value: String): this.type = set(outputCol, value) - /** - * Optional labels to be provided by the user, if not supplied column - * metadata is read for labels. The default value is an empty array, - * but the empty array is ignored and column metadata used instead. - * @group setParam - */ + /** @group setParam */ def setLabels(value: Array[String]): this.type = set(labels, value) /** - * Param for array of labels. - * Optional labels to be provided by the user. - * Default: Empty array, in which case column metadata is used for labels. + * Optional param for array of labels specifying index-string mapping. + * + * Default: Empty array, in which case [[inputCol]] metadata is used for labels. * @group param */ final val labels: StringArrayParam = new StringArrayParam(this, "labels", - "array of labels, if not provided metadata from inputCol is used instead.") + "Optional array of labels specifying index-string mapping." + + " If not provided or if empty, then metadata from inputCol is used instead.") setDefault(labels, Array.empty[String]) - /** - * Optional labels to be provided by the user, if not supplied column - * metadata is read for labels. 
- * @group getParam - */ + /** @group getParam */ final def getLabels: Array[String] = $(labels) - /** Transform the schema for the inverse transformation */ override def transformSchema(schema: StructType): StructType = { val inputColName = $(inputCol) val inputDataType = schema(inputColName).dataType diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala index fcb517b5f735..96a38a3bde96 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.types.StructField /** - * Helper utilities for tree-based algorithms + * Helper utilities for algorithms using ML metadata */ private[spark] object MetadataUtils { diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 71dc636b83ea..97cbee73a05e 100644 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -985,17 +985,17 @@ class IndexToString(JavaTransformer, HasInputCol, HasOutputCol): """ .. note:: Experimental - A :py:class:`Transformer` that maps a column of string indices back to a new column of - corresponding string values using either the ML attributes of the input column, or if - provided using the labels supplied by the user. - All original columns are kept during transformation. + A :py:class:`Transformer` that maps a column of indices back to a new column of + corresponding string values. + The index-string mapping is either from the ML attributes of the input column, + or from user-supplied labels (which take precedence over ML attributes). See L{StringIndexer} for converting strings into indices. """ # a placeholder to make the labels show up in generated doc labels = Param(Params._dummy(), "labels", - "Optional array of labels to be provided by the user, if not supplied or " + - "empty, column metadata is read for labels") + "Optional array of labels specifying index-string mapping." + + " If not provided or if empty, then metadata from inputCol is used instead.") @keyword_only def __init__(self, inputCol=None, outputCol=None, labels=None): @@ -1006,8 +1006,8 @@ def __init__(self, inputCol=None, outputCol=None, labels=None): self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.IndexToString", self.uid) self.labels = Param(self, "labels", - "Optional array of labels to be provided by the user, if not " + - "supplied or empty, column metadata is read for labels") + "Optional array of labels specifying index-string mapping. If not" + + " provided or if empty, then metadata from inputCol is used instead.") kwargs = self.__init__._input_kwargs self.setParams(**kwargs) From 6ce0886eb0916a985db142c0b6d2c2b14db5063d Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Fri, 11 Sep 2015 09:42:53 -0700 Subject: [PATCH 0593/1824] [SPARK-10540] [SQL] Ignore HadoopFsRelationTest's "test all data types" if it is too flaky If hadoopFsRelationSuites's "test all data types" is too flaky we can disable it for now. https://issues.apache.org/jira/browse/SPARK-10540 Author: Yin Huai Closes #8705 from yhuai/SPARK-10540-ignore. 
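The mechanism is just ScalaTest's `ignore`, so the test body keeps compiling and can be switched back on with a one-word change once SPARK-10540 is resolved. A minimal sketch (hypothetical suite, not part of this patch):

```scala
import org.apache.spark.SparkFunSuite

class FlakyExampleSuite extends SparkFunSuite {
  // Reported as "ignored" in the test run; the body still compiles.
  ignore("test all data types") {
    assert(1 + 1 === 2)
  }
}
```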
--- .../org/apache/spark/sql/sources/hadoopFsRelationSuites.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala index 24f43cf7c15c..13223c61584b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala @@ -100,7 +100,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes } } - test("test all data types") { + ignore("test all data types") { withTempPath { file => // Create the schema. val struct = From 5f46444765a377696af76af6e2c77ab14bfdab8e Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Fri, 11 Sep 2015 10:32:35 -0700 Subject: [PATCH 0594/1824] [SPARK-8530] [ML] add python API for MinMaxScaler jira: https://issues.apache.org/jira/browse/SPARK-8530 add python API for MinMaxScaler jira for MinMaxScaler: https://issues.apache.org/jira/browse/SPARK-7514 Author: Yuhao Yang Closes #7150 from hhbyyh/pythonMinMax. --- python/pyspark/ml/feature.py | 104 +++++++++++++++++++++++++++++++++-- 1 file changed, 99 insertions(+), 5 deletions(-) diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 97cbee73a05e..92db8df80280 100644 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -27,11 +27,11 @@ from pyspark.mllib.linalg import _convert_to_vector __all__ = ['Binarizer', 'Bucketizer', 'DCT', 'ElementwiseProduct', 'HashingTF', 'IDF', 'IDFModel', - 'IndexToString', 'NGram', 'Normalizer', 'OneHotEncoder', 'PCA', 'PCAModel', - 'PolynomialExpansion', 'RegexTokenizer', 'RFormula', 'RFormulaModel', 'SQLTransformer', - 'StandardScaler', 'StandardScalerModel', 'StopWordsRemover', 'StringIndexer', - 'StringIndexerModel', 'Tokenizer', 'VectorAssembler', 'VectorIndexer', 'VectorSlicer', - 'Word2Vec', 'Word2VecModel'] + 'IndexToString', 'MinMaxScaler', 'MinMaxScalerModel', 'NGram', 'Normalizer', + 'OneHotEncoder', 'PCA', 'PCAModel', 'PolynomialExpansion', 'RegexTokenizer', + 'RFormula', 'RFormulaModel', 'SQLTransformer', 'StandardScaler', 'StandardScalerModel', + 'StopWordsRemover', 'StringIndexer', 'StringIndexerModel', 'Tokenizer', + 'VectorAssembler', 'VectorIndexer', 'VectorSlicer', 'Word2Vec', 'Word2VecModel'] @inherit_doc @@ -406,6 +406,100 @@ class IDFModel(JavaModel): """ +@inherit_doc +class MinMaxScaler(JavaEstimator, HasInputCol, HasOutputCol): + """ + .. note:: Experimental + + Rescale each feature individually to a common range [min, max] linearly using column summary + statistics, which is also known as min-max normalization or Rescaling. The rescaled value for + feature E is calculated as, + + Rescaled(e_i) = (e_i - E_min) / (E_max - E_min) * (max - min) + min + + For the case E_max == E_min, Rescaled(e_i) = 0.5 * (max + min) + + Note that since zero values will probably be transformed to non-zero values, output of the + transformer will be DenseVector even for sparse input. + + >>> from pyspark.mllib.linalg import Vectors + >>> df = sqlContext.createDataFrame([(Vectors.dense([0.0]),), (Vectors.dense([2.0]),)], ["a"]) + >>> mmScaler = MinMaxScaler(inputCol="a", outputCol="scaled") + >>> model = mmScaler.fit(df) + >>> model.transform(df).show() + +-----+------+ + | a|scaled| + +-----+------+ + |[0.0]| [0.0]| + |[2.0]| [1.0]| + +-----+------+ + ... 
+ """ + + # a placeholder to make it appear in the generated doc + min = Param(Params._dummy(), "min", "Lower bound of the output feature range") + max = Param(Params._dummy(), "max", "Upper bound of the output feature range") + + @keyword_only + def __init__(self, min=0.0, max=1.0, inputCol=None, outputCol=None): + """ + __init__(self, min=0.0, max=1.0, inputCol=None, outputCol=None) + """ + super(MinMaxScaler, self).__init__() + self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.MinMaxScaler", self.uid) + self.min = Param(self, "min", "Lower bound of the output feature range") + self.max = Param(self, "max", "Upper bound of the output feature range") + self._setDefault(min=0.0, max=1.0) + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + + @keyword_only + def setParams(self, min=0.0, max=1.0, inputCol=None, outputCol=None): + """ + setParams(self, min=0.0, max=1.0, inputCol=None, outputCol=None) + Sets params for this MinMaxScaler. + """ + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) + + def setMin(self, value): + """ + Sets the value of :py:attr:`min`. + """ + self._paramMap[self.min] = value + return self + + def getMin(self): + """ + Gets the value of min or its default value. + """ + return self.getOrDefault(self.min) + + def setMax(self, value): + """ + Sets the value of :py:attr:`max`. + """ + self._paramMap[self.max] = value + return self + + def getMax(self): + """ + Gets the value of max or its default value. + """ + return self.getOrDefault(self.max) + + def _create_model(self, java_model): + return MinMaxScalerModel(java_model) + + +class MinMaxScalerModel(JavaModel): + """ + .. note:: Experimental + + Model fitted by :py:class:`MinMaxScaler`. + """ + + @inherit_doc @ignore_unicode_prefix class NGram(JavaTransformer, HasInputCol, HasOutputCol): From b231ab8938ae3c4fc2089cfc69c0d8164807d533 Mon Sep 17 00:00:00 2001 From: tedyu Date: Fri, 11 Sep 2015 21:45:45 +0100 Subject: [PATCH 0595/1824] [SPARK-10546] Check partitionId's range in ExternalSorter#spill() See this thread for background: http://search-hadoop.com/m/q3RTt0rWvIkHAE81 We should check the range of partition Id and provide meaningful message through exception. Alternatively, we can use abs() and modulo to force the partition Id into legitimate range. However, expectation is that user should correct the logic error in his / her code. Author: tedyu Closes #8703 from tedyu/master. 
--- .../scala/org/apache/spark/util/collection/ExternalSorter.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index 138c05dff19e..31230d5978b2 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -297,6 +297,8 @@ private[spark] class ExternalSorter[K, V, C]( val it = collection.destructiveSortedWritablePartitionedIterator(comparator) while (it.hasNext) { val partitionId = it.nextPartition() + require(partitionId >= 0 && partitionId < numPartitions, + s"partition Id: ${partitionId} should be in the range [0, ${numPartitions})") it.writeNext(writer) elementsPerPartition(partitionId) += 1 objectsWritten += 1 From c373866774c082885a50daaf7c83f3a14b0cd714 Mon Sep 17 00:00:00 2001 From: Icaro Medeiros Date: Fri, 11 Sep 2015 21:46:52 +0100 Subject: [PATCH 0596/1824] [PYTHON] Fixed typo in exception message Just fixing a typo in exception message, raised when attempting to pickle SparkContext. Author: Icaro Medeiros Closes #8724 from icaromedeiros/master. --- python/pyspark/context.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 1b2a52ad6411..a0a1ccbeefb0 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -255,7 +255,7 @@ def __getnewargs__(self): # This method is called when attempting to pickle SparkContext, which is always an error: raise Exception( "It appears that you are attempting to reference SparkContext from a broadcast " - "variable, action, or transforamtion. SparkContext can only be used on the driver, " + "variable, action, or transformation. SparkContext can only be used on the driver, " "not in code that it run on workers. For more information, see SPARK-5063." ) From d5d647380f93f4773f9cb85ea6544892d409b5a1 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 11 Sep 2015 14:15:16 -0700 Subject: [PATCH 0597/1824] [SPARK-10442] [SQL] fix string to boolean cast When we cast string to boolean in hive, it returns `true` if the length of string is > 0, and spark SQL follows this behavior. However, this behavior is very different from other SQL systems: 1. [presto](https://github.com/facebook/presto/blob/master/presto-main/src/main/java/com/facebook/presto/type/VarcharOperators.java#L89-L118) will return `true` for 't' 'true' '1', `false` for 'f' 'false' '0', throw exception for others. 2. [redshift](http://docs.aws.amazon.com/redshift/latest/dg/r_Boolean_type.html) will return `true` for 't' 'true' 'y' 'yes' '1', `false` for 'f' 'false' 'n' 'no' '0', null for others. 3. [postgresql](http://www.postgresql.org/docs/devel/static/datatype-boolean.html) will return `true` for 't' 'true' 'y' 'yes' 'on' '1', `false` for 'f' 'false' 'n' 'no' 'off' '0', throw exception for others. 4. [vertica](https://my.vertica.com/docs/5.0/HTML/Master/2983.htm) will return `true` for 't' 'true' 'y' 'yes' '1', `false` for 'f' 'false' 'n' 'no' '0', null for others. 5. [impala](http://www.cloudera.com/content/cloudera/en/documentation/cloudera-impala/latest/topics/impala_boolean.html) throw exception when try to cast string to boolean. 6. 
mysql, oracle, sqlserver don't have boolean type Whether we should change the cast behavior according to other SQL system or not is not decided yet, this PR is a test to see if we changed, how many compatibility tests will fail. Author: Wenchen Fan Closes #8698 from cloud-fan/string2boolean. --- .../spark/sql/catalyst/expressions/Cast.scala | 24 +++++++- .../spark/sql/catalyst/util/StringUtils.scala | 8 +++ .../sql/catalyst/expressions/CastSuite.scala | 61 ++++++++++++------- .../sql/sources/hadoopFsRelationSuites.scala | 13 ++++ 4 files changed, 82 insertions(+), 24 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 2db954257be3..f0bce388d959 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -22,7 +22,7 @@ import java.math.{BigDecimal => JavaBigDecimal} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ -import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.catalyst.util.{StringUtils, DateTimeUtils} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} @@ -140,7 +140,15 @@ case class Cast(child: Expression, dataType: DataType) // UDFToBoolean private[this] def castToBoolean(from: DataType): Any => Any = from match { case StringType => - buildCast[UTF8String](_, _.numBytes() != 0) + buildCast[UTF8String](_, s => { + if (StringUtils.isTrueString(s)) { + true + } else if (StringUtils.isFalseString(s)) { + false + } else { + null + } + }) case TimestampType => buildCast[Long](_, t => t != 0) case DateType => @@ -646,7 +654,17 @@ case class Cast(child: Expression, dataType: DataType) private[this] def castToBooleanCode(from: DataType): CastFunction = from match { case StringType => - (c, evPrim, evNull) => s"$evPrim = $c.numBytes() != 0;" + val stringUtils = StringUtils.getClass.getName.stripSuffix("$") + (c, evPrim, evNull) => + s""" + if ($stringUtils.isTrueString($c)) { + $evPrim = true; + } else if ($stringUtils.isFalseString($c)) { + $evPrim = false; + } else { + $evNull = true; + } + """ case TimestampType => (c, evPrim, evNull) => s"$evPrim = $c != 0;" case DateType => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala index 9ddfb3a0d375..c2eeb3c5650a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst.util import java.util.regex.Pattern +import org.apache.spark.unsafe.types.UTF8String + object StringUtils { // replace the _ with .{1} exactly match 1 time of any character @@ -44,4 +46,10 @@ object StringUtils { v } } + + private[this] val trueStrings = Set("t", "true", "y", "yes", "1").map(UTF8String.fromString) + private[this] val falseStrings = Set("f", "false", "n", "no", "0").map(UTF8String.fromString) + + def isTrueString(s: UTF8String): Boolean = trueStrings.contains(s.toLowerCase) + def isFalseString(s: UTF8String): Boolean = falseStrings.contains(s.toLowerCase) } diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala index 1ad70733eae0..f4db4da7646f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala @@ -503,9 +503,9 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { } test("cast from array") { - val array = Literal.create(Seq("123", "abc", "", null), + val array = Literal.create(Seq("123", "true", "f", null), ArrayType(StringType, containsNull = true)) - val array_notNull = Literal.create(Seq("123", "abc", ""), + val array_notNull = Literal.create(Seq("123", "true", "f"), ArrayType(StringType, containsNull = false)) checkNullCast(ArrayType(StringType), ArrayType(IntegerType)) @@ -522,7 +522,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { { val ret = cast(array, ArrayType(BooleanType, containsNull = true)) assert(ret.resolved === true) - checkEvaluation(ret, Seq(true, true, false, null)) + checkEvaluation(ret, Seq(null, true, false, null)) } { val ret = cast(array, ArrayType(BooleanType, containsNull = false)) @@ -541,12 +541,12 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { { val ret = cast(array_notNull, ArrayType(BooleanType, containsNull = true)) assert(ret.resolved === true) - checkEvaluation(ret, Seq(true, true, false)) + checkEvaluation(ret, Seq(null, true, false)) } { val ret = cast(array_notNull, ArrayType(BooleanType, containsNull = false)) assert(ret.resolved === true) - checkEvaluation(ret, Seq(true, true, false)) + checkEvaluation(ret, Seq(null, true, false)) } { @@ -557,10 +557,10 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { test("cast from map") { val map = Literal.create( - Map("a" -> "123", "b" -> "abc", "c" -> "", "d" -> null), + Map("a" -> "123", "b" -> "true", "c" -> "f", "d" -> null), MapType(StringType, StringType, valueContainsNull = true)) val map_notNull = Literal.create( - Map("a" -> "123", "b" -> "abc", "c" -> ""), + Map("a" -> "123", "b" -> "true", "c" -> "f"), MapType(StringType, StringType, valueContainsNull = false)) checkNullCast(MapType(StringType, IntegerType), MapType(StringType, StringType)) @@ -577,7 +577,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { { val ret = cast(map, MapType(StringType, BooleanType, valueContainsNull = true)) assert(ret.resolved === true) - checkEvaluation(ret, Map("a" -> true, "b" -> true, "c" -> false, "d" -> null)) + checkEvaluation(ret, Map("a" -> null, "b" -> true, "c" -> false, "d" -> null)) } { val ret = cast(map, MapType(StringType, BooleanType, valueContainsNull = false)) @@ -600,12 +600,12 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { { val ret = cast(map_notNull, MapType(StringType, BooleanType, valueContainsNull = true)) assert(ret.resolved === true) - checkEvaluation(ret, Map("a" -> true, "b" -> true, "c" -> false)) + checkEvaluation(ret, Map("a" -> null, "b" -> true, "c" -> false)) } { val ret = cast(map_notNull, MapType(StringType, BooleanType, valueContainsNull = false)) assert(ret.resolved === true) - checkEvaluation(ret, Map("a" -> true, "b" -> true, "c" -> false)) + checkEvaluation(ret, Map("a" -> null, "b" -> true, "c" -> false)) } { val ret = cast(map_notNull, MapType(IntegerType, StringType, valueContainsNull = true)) @@ -630,8 +630,8 @@ class CastSuite extends SparkFunSuite with 
ExpressionEvalHelper { val struct = Literal.create( InternalRow( UTF8String.fromString("123"), - UTF8String.fromString("abc"), - UTF8String.fromString(""), + UTF8String.fromString("true"), + UTF8String.fromString("f"), null), StructType(Seq( StructField("a", StringType, nullable = true), @@ -641,8 +641,8 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { val struct_notNull = Literal.create( InternalRow( UTF8String.fromString("123"), - UTF8String.fromString("abc"), - UTF8String.fromString("")), + UTF8String.fromString("true"), + UTF8String.fromString("f")), StructType(Seq( StructField("a", StringType, nullable = false), StructField("b", StringType, nullable = false), @@ -672,7 +672,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { StructField("c", BooleanType, nullable = true), StructField("d", BooleanType, nullable = true)))) assert(ret.resolved === true) - checkEvaluation(ret, InternalRow(true, true, false, null)) + checkEvaluation(ret, InternalRow(null, true, false, null)) } { val ret = cast(struct, StructType(Seq( @@ -704,7 +704,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { StructField("b", BooleanType, nullable = true), StructField("c", BooleanType, nullable = true)))) assert(ret.resolved === true) - checkEvaluation(ret, InternalRow(true, true, false)) + checkEvaluation(ret, InternalRow(null, true, false)) } { val ret = cast(struct_notNull, StructType(Seq( @@ -712,7 +712,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { StructField("b", BooleanType, nullable = true), StructField("c", BooleanType, nullable = false)))) assert(ret.resolved === true) - checkEvaluation(ret, InternalRow(true, true, false)) + checkEvaluation(ret, InternalRow(null, true, false)) } { @@ -731,8 +731,8 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { test("complex casting") { val complex = Literal.create( Row( - Seq("123", "abc", ""), - Map("a" ->"123", "b" -> "abc", "c" -> ""), + Seq("123", "true", "f"), + Map("a" ->"123", "b" -> "true", "c" -> "f"), Row(0)), StructType(Seq( StructField("a", @@ -755,11 +755,11 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { assert(ret.resolved === true) checkEvaluation(ret, Row( Seq(123, null, null), - Map("a" -> true, "b" -> true, "c" -> false), + Map("a" -> null, "b" -> true, "c" -> false), Row(0L))) } - test("case between string and interval") { + test("cast between string and interval") { import org.apache.spark.unsafe.types.CalendarInterval checkEvaluation(Cast(Literal("interval -3 month 7 hours"), CalendarIntervalType), @@ -769,4 +769,23 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { StringType), "interval 1 years 3 months -3 days") } + + test("cast string to boolean") { + checkCast("t", true) + checkCast("true", true) + checkCast("tRUe", true) + checkCast("y", true) + checkCast("yes", true) + checkCast("1", true) + + checkCast("f", false) + checkCast("false", false) + checkCast("FAlsE", false) + checkCast("n", false) + checkCast("no", false) + checkCast("0", false) + + checkEvaluation(cast("abc", BooleanType), null) + checkEvaluation(cast("", BooleanType), null) + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala index 13223c61584b..8ffcef85668d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala +++ 
b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala @@ -375,6 +375,19 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes } } + test("saveAsTable()/load() - partitioned table - boolean type") { + sqlContext.range(2) + .select('id, ('id % 2 === 0).as("b")) + .write.partitionBy("b").saveAsTable("t") + + withTable("t") { + checkAnswer( + sqlContext.table("t").sort('id), + Row(0, true) :: Row(1, false) :: Nil + ) + } + } + test("saveAsTable()/load() - partitioned table - Overwrite") { partitionedTestDF.write .format(dataSourceName) From 1eede3b254ee3793841c92971707094ac8afee35 Mon Sep 17 00:00:00 2001 From: Yash Datta Date: Fri, 11 Sep 2015 14:55:15 -0700 Subject: [PATCH 0598/1824] [SPARK-7142] [SQL] Minor enhancement to BooleanSimplification Optimizer rule. Incorporate review comments Adding changes suggested by cloud-fan in #5700 cc marmbrus Author: Yash Datta Closes #8716 from saucam/bool_simp. --- .../apache/spark/sql/catalyst/optimizer/Optimizer.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index d9b50f3c97da..0f4caec7451a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -435,10 +435,10 @@ object BooleanSimplification extends Rule[LogicalPlan] with PredicateHelper { // a && a => a case (l, r) if l fastEquals r => l // a && (not(a) || b) => a && b - case (l, Or(l1, r)) if (Not(l) fastEquals l1) => And(l, r) - case (l, Or(r, l1)) if (Not(l) fastEquals l1) => And(l, r) - case (Or(l, l1), r) if (l1 fastEquals Not(r)) => And(l, r) - case (Or(l1, l), r) if (l1 fastEquals Not(r)) => And(l, r) + case (l, Or(l1, r)) if (Not(l) == l1) => And(l, r) + case (l, Or(r, l1)) if (Not(l) == l1) => And(l, r) + case (Or(l, l1), r) if (l1 == Not(r)) => And(l, r) + case (Or(l1, l), r) if (l1 == Not(r)) => And(l, r) // (a || b) && (a || c) => a || (b && c) case _ => // 1. Split left and right to get the disjunctive predicates, From e626ac5f5c27dcc74113070f2fec03682bcd12bd Mon Sep 17 00:00:00 2001 From: zsxwing Date: Fri, 11 Sep 2015 15:00:13 -0700 Subject: [PATCH 0599/1824] [SPARK-9992] [SPARK-9994] [SPARK-9998] [SQL] Implement the local TopK, sample and intersect operators This PR is in conflict with #8535. I will update this one when #8535 gets merged. Author: zsxwing Closes #8573 from zsxwing/more-local-operators. 
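All of these operators share the same pull-based contract (`open()`, then `next()`/`fetch()` until exhausted, then `close()`), plus the `asIterator` view added below. A schematic consumer, assuming it sits somewhere with access to these internal classes, might look like:

```scala
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.local.LocalNode

// Drains a LocalNode by hand using the open/next/fetch/close protocol.
def collectRows(node: LocalNode): Seq[InternalRow] = {
  val rows = ArrayBuffer.empty[InternalRow]
  node.open()
  try {
    while (node.next()) {
      // fetch() returns the node's current row, which may be reused on the next call,
      // so copy() it before buffering (IntersectNode does the same for its build side).
      rows += node.fetch().copy()
    }
  } finally {
    node.close()
  }
  rows
}
```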
--- .../spark/sql/execution/basicOperators.scala | 2 +- .../sql/execution/local/IntersectNode.scala | 63 ++++++++++++++ .../spark/sql/execution/local/LocalNode.scala | 5 ++ .../sql/execution/local/SampleNode.scala | 82 +++++++++++++++++++ .../local/TakeOrderedAndProjectNode.scala | 73 +++++++++++++++++ .../execution/local/IntersectNodeSuite.scala | 35 ++++++++ .../sql/execution/local/SampleNodeSuite.scala | 40 +++++++++ .../TakeOrderedAndProjectNodeSuite.scala | 54 ++++++++++++ 8 files changed, 353 insertions(+), 1 deletion(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/local/IntersectNode.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/local/SampleNode.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNode.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/local/IntersectNodeSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/local/SampleNodeSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNodeSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 3f68b05a24f4..bf6d44c098ee 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -138,7 +138,7 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode { * will be ub - lb. * @param withReplacement Whether to sample with replacement. * @param seed the random seed - * @param child the QueryPlan + * @param child the SparkPlan */ @DeveloperApi case class Sample( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/IntersectNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/IntersectNode.scala new file mode 100644 index 000000000000..740d485f8d9e --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/IntersectNode.scala @@ -0,0 +1,63 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/ + +package org.apache.spark.sql.execution.local + +import scala.collection.mutable + +import org.apache.spark.sql.SQLConf +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute + +case class IntersectNode(conf: SQLConf, left: LocalNode, right: LocalNode) + extends BinaryLocalNode(conf) { + + override def output: Seq[Attribute] = left.output + + private[this] var leftRows: mutable.HashSet[InternalRow] = _ + + private[this] var currentRow: InternalRow = _ + + override def open(): Unit = { + left.open() + leftRows = mutable.HashSet[InternalRow]() + while (left.next()) { + leftRows += left.fetch().copy() + } + left.close() + right.open() + } + + override def next(): Boolean = { + currentRow = null + while (currentRow == null && right.next()) { + currentRow = right.fetch() + if (!leftRows.contains(currentRow)) { + currentRow = null + } + } + currentRow != null + } + + override def fetch(): InternalRow = currentRow + + override def close(): Unit = { + left.close() + right.close() + } + +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala index c4f8ae304db3..a2c275db9b35 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala @@ -69,6 +69,11 @@ abstract class LocalNode(conf: SQLConf) extends TreeNode[LocalNode] with Logging */ def close(): Unit + /** + * Returns the content through the [[Iterator]] interface. + */ + final def asIterator: Iterator[InternalRow] = new LocalNodeIterator(this) + /** * Returns the content of the iterator from the beginning to the end in the form of a Scala Seq. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SampleNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SampleNode.scala new file mode 100644 index 000000000000..abf3df1c0c2a --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SampleNode.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.local + +import java.util.Random + +import org.apache.spark.sql.SQLConf +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.util.random.{BernoulliCellSampler, PoissonSampler} + +/** + * Sample the dataset. + * + * @param conf the SQLConf + * @param lowerBound Lower-bound of the sampling probability (usually 0.0) + * @param upperBound Upper-bound of the sampling probability. The expected fraction sampled + * will be ub - lb. 
+ * @param withReplacement Whether to sample with replacement. + * @param seed the random seed + * @param child the LocalNode + */ +case class SampleNode( + conf: SQLConf, + lowerBound: Double, + upperBound: Double, + withReplacement: Boolean, + seed: Long, + child: LocalNode) extends UnaryLocalNode(conf) { + + override def output: Seq[Attribute] = child.output + + private[this] var iterator: Iterator[InternalRow] = _ + + private[this] var currentRow: InternalRow = _ + + override def open(): Unit = { + child.open() + val (sampler, _seed) = if (withReplacement) { + val random = new Random(seed) + // Disable gap sampling since the gap sampling method buffers two rows internally, + // requiring us to copy the row, which is more expensive than the random number generator. + (new PoissonSampler[InternalRow](upperBound - lowerBound, useGapSamplingIfPossible = false), + // Use the seed for partition 0 like PartitionwiseSampledRDD to generate the same result + // of DataFrame + random.nextLong()) + } else { + (new BernoulliCellSampler[InternalRow](lowerBound, upperBound), seed) + } + sampler.setSeed(_seed) + iterator = sampler.sample(child.asIterator) + } + + override def next(): Boolean = { + if (iterator.hasNext) { + currentRow = iterator.next() + true + } else { + false + } + } + + override def fetch(): InternalRow = currentRow + + override def close(): Unit = child.close() + +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNode.scala new file mode 100644 index 000000000000..53f1dcc65d8c --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNode.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.local + +import org.apache.spark.sql.SQLConf +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.util.BoundedPriorityQueue + +case class TakeOrderedAndProjectNode( + conf: SQLConf, + limit: Int, + sortOrder: Seq[SortOrder], + projectList: Option[Seq[NamedExpression]], + child: LocalNode) extends UnaryLocalNode(conf) { + + private[this] var projection: Option[Projection] = _ + private[this] var ord: InterpretedOrdering = _ + private[this] var iterator: Iterator[InternalRow] = _ + private[this] var currentRow: InternalRow = _ + + override def output: Seq[Attribute] = { + val projectOutput = projectList.map(_.map(_.toAttribute)) + projectOutput.getOrElse(child.output) + } + + override def open(): Unit = { + child.open() + projection = projectList.map(new InterpretedProjection(_, child.output)) + ord = new InterpretedOrdering(sortOrder, child.output) + // Priority keeps the largest elements, so let's reverse the ordering. + val queue = new BoundedPriorityQueue[InternalRow](limit)(ord.reverse) + while (child.next()) { + queue += child.fetch() + } + // Close it eagerly since we don't need it. + child.close() + iterator = queue.iterator + } + + override def next(): Boolean = { + if (iterator.hasNext) { + val _currentRow = iterator.next() + currentRow = projection match { + case Some(p) => p(_currentRow) + case None => _currentRow + } + true + } else { + false + } + } + + override def fetch(): InternalRow = currentRow + + override def close(): Unit = child.close() + +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/IntersectNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/IntersectNodeSuite.scala new file mode 100644 index 000000000000..7deaa375fcfc --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/IntersectNodeSuite.scala @@ -0,0 +1,35 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/ + +package org.apache.spark.sql.execution.local + +class IntersectNodeSuite extends LocalNodeTest { + + import testImplicits._ + + test("basic") { + val input1 = (1 to 10).map(i => (i, i.toString)).toDF("key", "value") + val input2 = (1 to 10).filter(_ % 2 == 0).map(i => (i, i.toString)).toDF("key", "value") + + checkAnswer2( + input1, + input2, + (node1, node2) => IntersectNode(conf, node1, node2), + input1.intersect(input2).collect() + ) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/SampleNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/SampleNodeSuite.scala new file mode 100644 index 000000000000..87a7da453999 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/SampleNodeSuite.scala @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.local + +class SampleNodeSuite extends LocalNodeTest { + + import testImplicits._ + + private def testSample(withReplacement: Boolean): Unit = { + test(s"withReplacement: $withReplacement") { + val seed = 0L + val input = sqlContext.sparkContext. + parallelize((1 to 10).map(i => (i, i.toString)), 1). // Should be only 1 partition + toDF("key", "value") + checkAnswer( + input, + node => SampleNode(conf, 0.0, 0.3, withReplacement, seed, node), + input.sample(withReplacement, 0.3, seed).collect() + ) + } + } + + testSample(withReplacement = true) + testSample(withReplacement = false) +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNodeSuite.scala new file mode 100644 index 000000000000..ff28b24eeff1 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNodeSuite.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.local + +import org.apache.spark.sql.Column +import org.apache.spark.sql.catalyst.expressions.{Ascending, Expression, SortOrder} + +class TakeOrderedAndProjectNodeSuite extends LocalNodeTest { + + import testImplicits._ + + private def columnToSortOrder(sortExprs: Column*): Seq[SortOrder] = { + val sortOrder: Seq[SortOrder] = sortExprs.map { col => + col.expr match { + case expr: SortOrder => + expr + case expr: Expression => + SortOrder(expr, Ascending) + } + } + sortOrder + } + + private def testTakeOrderedAndProjectNode(desc: Boolean): Unit = { + val testCaseName = if (desc) "desc" else "asc" + test(testCaseName) { + val input = (1 to 10).map(i => (i, i.toString)).toDF("key", "value") + val sortColumn = if (desc) input.col("key").desc else input.col("key") + checkAnswer( + input, + node => TakeOrderedAndProjectNode(conf, 5, columnToSortOrder(sortColumn), None, node), + input.sort(sortColumn).limit(5).collect() + ) + } + } + + testTakeOrderedAndProjectNode(desc = false) + testTakeOrderedAndProjectNode(desc = true) +} From c2af42b5f32287ff595ad027a8191d4b75702d8d Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Fri, 11 Sep 2015 15:01:37 -0700 Subject: [PATCH 0600/1824] [SPARK-9990] [SQL] Local hash join follow-ups 1. Hide `LocalNodeIterator` behind the `LocalNode#asIterator` method 2. Add tests for this Author: Andrew Or Closes #8708 from andrewor14/local-hash-join-follow-up. --- .../sql/execution/joins/HashedRelation.scala | 7 +- .../sql/execution/local/HashJoinNode.scala | 3 +- .../spark/sql/execution/local/LocalNode.scala | 4 +- .../sql/execution/local/LocalNodeSuite.scala | 116 ++++++++++++++++++ 4 files changed, 125 insertions(+), 5 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index 0cff21ca618b..bc255b27502b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -25,7 +25,8 @@ import org.apache.spark.shuffle.ShuffleMemoryManager import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.SparkSqlSerializer -import org.apache.spark.sql.execution.metric.LongSQLMetric +import org.apache.spark.sql.execution.local.LocalNode +import org.apache.spark.sql.execution.metric.{LongSQLMetric, SQLMetrics} import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.map.BytesToBytesMap import org.apache.spark.unsafe.memory.{MemoryLocation, ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager} @@ -113,6 +114,10 @@ final class UniqueKeyHashedRelation(private var hashTable: JavaHashMap[InternalR private[execution] object HashedRelation { + def apply(localNode: LocalNode, keyGenerator: Projection): HashedRelation = { + apply(localNode.asIterator, SQLMetrics.nullLongMetric, keyGenerator) + } + def apply( input: Iterator[InternalRow], numInputRows: LongSQLMetric, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/HashJoinNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/HashJoinNode.scala index a3e68d6a7c34..e7b24e3fca2b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/HashJoinNode.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/HashJoinNode.scala @@ -75,8 +75,7 @@ case class HashJoinNode( override def open(): Unit = { buildNode.open() - hashed = HashedRelation.apply( - new LocalNodeIterator(buildNode), SQLMetrics.nullLongMetric, buildSideKeyGenerator) + hashed = HashedRelation(buildNode, buildSideKeyGenerator) streamedNode.open() joinRow = new JoinedRow resultProjection = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala index a2c275db9b35..e540ef8555eb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala @@ -77,7 +77,7 @@ abstract class LocalNode(conf: SQLConf) extends TreeNode[LocalNode] with Logging /** * Returns the content of the iterator from the beginning to the end in the form of a Scala Seq. */ - def collect(): Seq[Row] = { + final def collect(): Seq[Row] = { val converter = CatalystTypeConverters.createToScalaConverter(StructType.fromAttributes(output)) val result = new scala.collection.mutable.ArrayBuffer[Row] open() @@ -140,7 +140,7 @@ abstract class BinaryLocalNode(conf: SQLConf) extends LocalNode(conf) { /** * An thin wrapper around a [[LocalNode]] that provides an `Iterator` interface. */ -private[local] class LocalNodeIterator(localNode: LocalNode) extends Iterator[InternalRow] { +private class LocalNodeIterator(localNode: LocalNode) extends Iterator[InternalRow] { private var nextRow: InternalRow = _ override def hasNext: Boolean = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeSuite.scala new file mode 100644 index 000000000000..b89fa46f8b3b --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeSuite.scala @@ -0,0 +1,116 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/ + +package org.apache.spark.sql.execution.local + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.SQLConf +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types.IntegerType + +class LocalNodeSuite extends SparkFunSuite { + private val data = (1 to 100).toArray + + test("basic open, next, fetch, close") { + val node = new DummyLocalNode(data) + assert(!node.isOpen) + node.open() + assert(node.isOpen) + data.foreach { i => + assert(node.next()) + // fetch should be idempotent + val fetched = node.fetch() + assert(node.fetch() === fetched) + assert(node.fetch() === fetched) + assert(node.fetch().numFields === 1) + assert(node.fetch().getInt(0) === i) + } + assert(!node.next()) + node.close() + assert(!node.isOpen) + } + + test("asIterator") { + val node = new DummyLocalNode(data) + val iter = node.asIterator + node.open() + data.foreach { i => + // hasNext should be idempotent + assert(iter.hasNext) + assert(iter.hasNext) + val item = iter.next() + assert(item.numFields === 1) + assert(item.getInt(0) === i) + } + intercept[NoSuchElementException] { + iter.next() + } + node.close() + } + + test("collect") { + val node = new DummyLocalNode(data) + node.open() + val collected = node.collect() + assert(collected.size === data.size) + assert(collected.forall(_.size === 1)) + assert(collected.map(_.getInt(0)) === data) + node.close() + } + +} + +/** + * A dummy [[LocalNode]] that just returns one row per integer in the input. + */ +private case class DummyLocalNode(conf: SQLConf, input: Array[Int]) extends LocalNode(conf) { + private var index = Int.MinValue + + def this(input: Array[Int]) { + this(new SQLConf, input) + } + + def isOpen: Boolean = { + index != Int.MinValue + } + + override def output: Seq[Attribute] = { + Seq(AttributeReference("something", IntegerType)()) + } + + override def children: Seq[LocalNode] = Seq.empty + + override def open(): Unit = { + index = -1 + } + + override def next(): Boolean = { + index += 1 + index < input.size + } + + override def fetch(): InternalRow = { + assert(index >= 0 && index < input.size) + val values = Array(input(index).asInstanceOf[Any]) + new GenericInternalRow(values) + } + + override def close(): Unit = { + index = Int.MinValue + } +} From d74c6a143cbd060c25bf14a8d306841b3ec55d03 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Fri, 11 Sep 2015 15:02:59 -0700 Subject: [PATCH 0601/1824] [SPARK-10564] ThreadingSuite: assertion failures in threads don't fail the test This commit ensures if an assertion fails within a thread, it will ultimately fail the test. Otherwise we end up potentially masking real bugs by not propagating assertion failures properly. Author: Andrew Or Closes #8723 from andrewor14/fix-threading-suite. 
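The fix follows the usual pattern for surfacing failures from worker threads in a test: catch any Throwable inside the thread, stash the first one in a shared reference, release the semaphore in a finally block so the test never hangs, and rethrow from the test thread once all workers have signalled. A minimal stand-alone sketch of that pattern (illustrative only; the patch itself uses a plain `Option[Throwable]` var in Scala rather than this hypothetical `AtomicReference`):

```
import java.util.concurrent.Semaphore;
import java.util.concurrent.atomic.AtomicReference;

public class ThreadAssertionExample {
  public static void main(String[] args) throws Throwable {
    final Semaphore sem = new Semaphore(0);
    // Holds the first failure observed in any worker thread, if there is one.
    final AtomicReference<Throwable> failure = new AtomicReference<>();
    for (int i = 0; i < 5; i++) {
      final int id = i;
      new Thread() {
        @Override
        public void run() {
          try {
            if (id == 3) {
              // Stand-in for an assertion that fails inside the worker.
              throw new AssertionError("assertion failed in worker " + id);
            }
          } catch (Throwable t) {
            failure.compareAndSet(null, t);  // remember the first failure
          } finally {
            sem.release();                   // always unblock the waiting thread
          }
        }
      }.start();
    }
    sem.acquire(5);                          // wait for all workers to finish
    Throwable t = failure.get();
    if (t != null) {
      throw t;                               // propagate so the "test" actually fails
    }
  }
}
```

Releasing the semaphore in `finally` matters: without it, a failing worker would leave `sem.acquire(5)` blocked and the test would time out instead of reporting the assertion failure.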
--- .../org/apache/spark/ThreadingSuite.scala | 68 ++++++++++++------- 1 file changed, 45 insertions(+), 23 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/ThreadingSuite.scala b/core/src/test/scala/org/apache/spark/ThreadingSuite.scala index 48509f0759a3..cda2b245526f 100644 --- a/core/src/test/scala/org/apache/spark/ThreadingSuite.scala +++ b/core/src/test/scala/org/apache/spark/ThreadingSuite.scala @@ -119,23 +119,30 @@ class ThreadingSuite extends SparkFunSuite with LocalSparkContext with Logging { val nums = sc.parallelize(1 to 2, 2) val sem = new Semaphore(0) ThreadingSuiteState.clear() + var throwable: Option[Throwable] = None for (i <- 0 until 2) { new Thread { override def run() { - val ans = nums.map(number => { - val running = ThreadingSuiteState.runningThreads - running.getAndIncrement() - val time = System.currentTimeMillis() - while (running.get() != 4 && System.currentTimeMillis() < time + 1000) { - Thread.sleep(100) - } - if (running.get() != 4) { - ThreadingSuiteState.failed.set(true) - } - number - }).collect() - assert(ans.toList === List(1, 2)) - sem.release() + try { + val ans = nums.map(number => { + val running = ThreadingSuiteState.runningThreads + running.getAndIncrement() + val time = System.currentTimeMillis() + while (running.get() != 4 && System.currentTimeMillis() < time + 1000) { + Thread.sleep(100) + } + if (running.get() != 4) { + ThreadingSuiteState.failed.set(true) + } + number + }).collect() + assert(ans.toList === List(1, 2)) + } catch { + case t: Throwable => + throwable = Some(t) + } finally { + sem.release() + } } }.start() } @@ -145,18 +152,25 @@ class ThreadingSuite extends SparkFunSuite with LocalSparkContext with Logging { ThreadingSuiteState.runningThreads.get() + "); failing test") fail("One or more threads didn't see runningThreads = 4") } + throwable.foreach { t => throw t } } test("set local properties in different thread") { sc = new SparkContext("local", "test") val sem = new Semaphore(0) - + var throwable: Option[Throwable] = None val threads = (1 to 5).map { i => new Thread() { override def run() { - sc.setLocalProperty("test", i.toString) - assert(sc.getLocalProperty("test") === i.toString) - sem.release() + try { + sc.setLocalProperty("test", i.toString) + assert(sc.getLocalProperty("test") === i.toString) + } catch { + case t: Throwable => + throwable = Some(t) + } finally { + sem.release() + } } } } @@ -165,20 +179,27 @@ class ThreadingSuite extends SparkFunSuite with LocalSparkContext with Logging { sem.acquire(5) assert(sc.getLocalProperty("test") === null) + throwable.foreach { t => throw t } } test("set and get local properties in parent-children thread") { sc = new SparkContext("local", "test") sc.setLocalProperty("test", "parent") val sem = new Semaphore(0) - + var throwable: Option[Throwable] = None val threads = (1 to 5).map { i => new Thread() { override def run() { - assert(sc.getLocalProperty("test") === "parent") - sc.setLocalProperty("test", i.toString) - assert(sc.getLocalProperty("test") === i.toString) - sem.release() + try { + assert(sc.getLocalProperty("test") === "parent") + sc.setLocalProperty("test", i.toString) + assert(sc.getLocalProperty("test") === i.toString) + } catch { + case t: Throwable => + throwable = Some(t) + } finally { + sem.release() + } } } } @@ -188,6 +209,7 @@ class ThreadingSuite extends SparkFunSuite with LocalSparkContext with Logging { sem.acquire(5) assert(sc.getLocalProperty("test") === "parent") assert(sc.getLocalProperty("Foo") === null) + throwable.foreach { t => throw t 
} } test("mutations to local properties should not affect submitted jobs (SPARK-6629)") { From c34fc19765bdf55365cdce78d9ba11b220b73bb6 Mon Sep 17 00:00:00 2001 From: 0x0FFF Date: Fri, 11 Sep 2015 15:19:04 -0700 Subject: [PATCH 0602/1824] [SPARK-9014] [SQL] Allow Python spark API to use built-in exponential operator This PR addresses (SPARK-9014)[https://issues.apache.org/jira/browse/SPARK-9014] Added functionality: `Column` object in Python now supports exponential operator `**` Example: ``` from pyspark.sql import * df = sqlContext.createDataFrame([Row(a=2)]) df.select(3**df.a,df.a**3,df.a**df.a).collect() ``` Outputs: ``` [Row(POWER(3.0, a)=9.0, POWER(a, 3.0)=8.0, POWER(a, a)=4.0)] ``` Author: 0x0FFF Closes #8658 from 0x0FFF/SPARK-9014. --- python/pyspark/sql/column.py | 13 +++++++++++++ python/pyspark/sql/tests.py | 2 +- 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py index 573f65f5bf09..9ca8e1f264cf 100644 --- a/python/pyspark/sql/column.py +++ b/python/pyspark/sql/column.py @@ -91,6 +91,17 @@ def _(self): return _ +def _bin_func_op(name, reverse=False, doc="binary function"): + def _(self, other): + sc = SparkContext._active_spark_context + fn = getattr(sc._jvm.functions, name) + jc = other._jc if isinstance(other, Column) else _create_column_from_literal(other) + njc = fn(self._jc, jc) if not reverse else fn(jc, self._jc) + return Column(njc) + _.__doc__ = doc + return _ + + def _bin_op(name, doc="binary operator"): """ Create a method for given binary operator """ @@ -151,6 +162,8 @@ def __init__(self, jc): __rdiv__ = _reverse_op("divide") __rtruediv__ = _reverse_op("divide") __rmod__ = _reverse_op("mod") + __pow__ = _bin_func_op("pow") + __rpow__ = _bin_func_op("pow", reverse=True) # logistic operators __eq__ = _bin_op("equalTo") diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index eb449e8679fa..f2172b7a27d8 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -568,7 +568,7 @@ def test_column_operators(self): cs = self.df.value c = ci == cs self.assertTrue(isinstance((- ci - 1 - 2) % 3 * 2.5 / 3.5, Column)) - rcc = (1 + ci), (1 - ci), (1 * ci), (1 / ci), (1 % ci) + rcc = (1 + ci), (1 - ci), (1 * ci), (1 / ci), (1 % ci), (1 ** ci), (ci ** 1) self.assertTrue(all(isinstance(c, Column) for c in rcc)) cb = [ci == 5, ci != 0, ci > 3, ci < 4, ci >= 0, ci <= 7] self.assertTrue(all(isinstance(c, Column) for c in cb)) From 6d8367807cb62c2cb139cee1d039dc8b12c63385 Mon Sep 17 00:00:00 2001 From: Daniel Imfeld Date: Sat, 12 Sep 2015 09:19:59 +0100 Subject: [PATCH 0603/1824] [SPARK-10566] [CORE] SnappyCompressionCodec init exception handling masks important error information When throwing an IllegalArgumentException in SnappyCompressionCodec.init, chain the existing exception. This allows potentially important debugging info to be passed to the user. Manual testing shows the exception chained properly, and the test suite still looks fine as well. This contribution is my original work and I license the work to the project under the project's open source license. Author: Daniel Imfeld Closes #8725 from dimfeld/dimfeld-patch-1. 
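The change is a single constructor argument: passing the caught `Error` as the cause of the `IllegalArgumentException` keeps the original Snappy loading failure reachable via `getCause()` and visible in the printed stack trace. A small illustration of the difference (the failing native check here is only a placeholder, not Snappy's real API):

```
public class ExceptionChainingExample {
  // Placeholder for a native-library check that fails with an Error.
  private static void checkNativeLibrary() {
    throw new UnsatisfiedLinkError("no snappyjava in java.library.path");
  }

  public static void main(String[] args) {
    try {
      checkNativeLibrary();
    } catch (Error e) {
      // Before the fix: new IllegalArgumentException() -- the original cause is dropped.
      // After the fix: chain the Error so callers can still see why init failed.
      IllegalArgumentException wrapped = new IllegalArgumentException(e);
      System.out.println("cause preserved: " + wrapped.getCause());
      wrapped.printStackTrace();
    }
  }
}
```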
--- core/src/main/scala/org/apache/spark/io/CompressionCodec.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala index 607d5a321efc..9dc36704a676 100644 --- a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala +++ b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala @@ -148,7 +148,7 @@ class SnappyCompressionCodec(conf: SparkConf) extends CompressionCodec { try { Snappy.getNativeLibraryVersion } catch { - case e: Error => throw new IllegalArgumentException + case e: Error => throw new IllegalArgumentException(e) } override def compressedOutputStream(s: OutputStream): OutputStream = { From 8285e3b0d3dc0eff669eba993742dfe0401116f9 Mon Sep 17 00:00:00 2001 From: Nithin Asokan Date: Sat, 12 Sep 2015 09:50:49 +0100 Subject: [PATCH 0604/1824] [SPARK-10554] [CORE] Fix NPE with ShutdownHook https://issues.apache.org/jira/browse/SPARK-10554 Fixes NPE when ShutdownHook tries to cleanup temporary folders Author: Nithin Asokan Closes #8720 from nasokan/SPARK-10554. --- .../scala/org/apache/spark/storage/DiskBlockManager.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala index 3f8d26e1d4ca..f7e84a2c2e14 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala @@ -164,7 +164,9 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon private def doStop(): Unit = { // Only perform cleanup if an external service is not serving our shuffle files. - if (!blockManager.externalShuffleServiceEnabled || blockManager.blockManagerId.isDriver) { + // Also blockManagerId could be null if block manager is not initialized properly. + if (!blockManager.externalShuffleServiceEnabled || + (blockManager.blockManagerId != null && blockManager.blockManagerId.isDriver)) { localDirs.foreach { localDir => if (localDir.isDirectory() && localDir.exists()) { try { From 22730ad54d681ad30e63fe910e8d89360853177d Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Sat, 12 Sep 2015 10:40:10 +0100 Subject: [PATCH 0605/1824] [SPARK-10547] [TEST] Streamline / improve style of Java API tests Fix a few Java API test style issues: unused generic types, exceptions, wrong assert argument order Author: Sean Owen Closes #8706 from srowen/SPARK-10547. 
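The cleanup below is mechanical: fully spelled-out `new Tuple2<Integer, String>(...)` constructions become `new Tuple2<>(...)` with the diamond operator, `throws Exception` clauses that cannot throw are dropped, and assertions are written as `assertEquals(expected, actual)` rather than the reverse. A tiny hypothetical test method showing the target style (assumes JUnit 4 and scala-library on the classpath; it is not part of the patch):

```
import static org.junit.Assert.assertEquals;

import org.junit.Test;

import scala.Tuple2;

public class StyleExample {
  @Test
  public void pairStyle() {  // no needless "throws Exception"
    // Diamond operator instead of repeating the type arguments.
    Tuple2<Integer, String> pair = new Tuple2<>(1, "a");
    // Expected value first, actual value second.
    assertEquals(Integer.valueOf(1), pair._1());
    assertEquals("a", pair._2());
  }
}
```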
--- .../java/org/apache/spark/JavaAPISuite.java | 451 ++++++----- .../kafka/JavaDirectKafkaStreamSuite.java | 24 +- .../streaming/kafka/JavaKafkaRDDSuite.java | 17 +- .../streaming/kafka/JavaKafkaStreamSuite.java | 14 +- .../twitter/JavaTwitterStreamSuite.java | 4 +- .../java/org/apache/spark/Java8APISuite.java | 46 +- .../spark/sql/JavaApplySchemaSuite.java | 39 +- .../apache/spark/sql/JavaDataFrameSuite.java | 29 +- .../org/apache/spark/sql/JavaRowSuite.java | 15 +- .../org/apache/spark/sql/JavaUDFSuite.java | 9 +- .../spark/sql/sources/JavaSaveLoadSuite.java | 10 +- .../spark/sql/hive/JavaDataFrameSuite.java | 8 +- .../hive/JavaMetastoreDataSourcesSuite.java | 12 +- .../apache/spark/streaming/JavaAPISuite.java | 752 +++++++++--------- .../spark/streaming/JavaReceiverAPISuite.java | 86 +- 15 files changed, 755 insertions(+), 761 deletions(-) diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java index ebd3d61ae732..fd8f7f39b7cc 100644 --- a/core/src/test/java/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java @@ -90,7 +90,7 @@ public void sparkContextUnion() { JavaRDD sUnion = sc.union(s1, s2); Assert.assertEquals(4, sUnion.count()); // List - List> list = new ArrayList>(); + List> list = new ArrayList<>(); list.add(s2); sUnion = sc.union(s1, list); Assert.assertEquals(4, sUnion.count()); @@ -103,9 +103,9 @@ public void sparkContextUnion() { Assert.assertEquals(4, dUnion.count()); // Union of JavaPairRDDs - List> pairs = new ArrayList>(); - pairs.add(new Tuple2(1, 2)); - pairs.add(new Tuple2(3, 4)); + List> pairs = new ArrayList<>(); + pairs.add(new Tuple2<>(1, 2)); + pairs.add(new Tuple2<>(3, 4)); JavaPairRDD p1 = sc.parallelizePairs(pairs); JavaPairRDD p2 = sc.parallelizePairs(pairs); JavaPairRDD pUnion = sc.union(p1, p2); @@ -133,9 +133,9 @@ public void intersection() { JavaDoubleRDD dIntersection = d1.intersection(d2); Assert.assertEquals(2, dIntersection.count()); - List> pairs = new ArrayList>(); - pairs.add(new Tuple2(1, 2)); - pairs.add(new Tuple2(3, 4)); + List> pairs = new ArrayList<>(); + pairs.add(new Tuple2<>(1, 2)); + pairs.add(new Tuple2<>(3, 4)); JavaPairRDD p1 = sc.parallelizePairs(pairs); JavaPairRDD p2 = sc.parallelizePairs(pairs); JavaPairRDD pIntersection = p1.intersection(p2); @@ -165,47 +165,49 @@ public void randomSplit() { @Test public void sortByKey() { - List> pairs = new ArrayList>(); - pairs.add(new Tuple2(0, 4)); - pairs.add(new Tuple2(3, 2)); - pairs.add(new Tuple2(-1, 1)); + List> pairs = new ArrayList<>(); + pairs.add(new Tuple2<>(0, 4)); + pairs.add(new Tuple2<>(3, 2)); + pairs.add(new Tuple2<>(-1, 1)); JavaPairRDD rdd = sc.parallelizePairs(pairs); // Default comparator JavaPairRDD sortedRDD = rdd.sortByKey(); - Assert.assertEquals(new Tuple2(-1, 1), sortedRDD.first()); + Assert.assertEquals(new Tuple2<>(-1, 1), sortedRDD.first()); List> sortedPairs = sortedRDD.collect(); - Assert.assertEquals(new Tuple2(0, 4), sortedPairs.get(1)); - Assert.assertEquals(new Tuple2(3, 2), sortedPairs.get(2)); + Assert.assertEquals(new Tuple2<>(0, 4), sortedPairs.get(1)); + Assert.assertEquals(new Tuple2<>(3, 2), sortedPairs.get(2)); // Custom comparator sortedRDD = rdd.sortByKey(Collections.reverseOrder(), false); - Assert.assertEquals(new Tuple2(-1, 1), sortedRDD.first()); + Assert.assertEquals(new Tuple2<>(-1, 1), sortedRDD.first()); sortedPairs = sortedRDD.collect(); - Assert.assertEquals(new Tuple2(0, 4), sortedPairs.get(1)); - Assert.assertEquals(new 
Tuple2(3, 2), sortedPairs.get(2)); + Assert.assertEquals(new Tuple2<>(0, 4), sortedPairs.get(1)); + Assert.assertEquals(new Tuple2<>(3, 2), sortedPairs.get(2)); } @SuppressWarnings("unchecked") @Test public void repartitionAndSortWithinPartitions() { - List> pairs = new ArrayList>(); - pairs.add(new Tuple2(0, 5)); - pairs.add(new Tuple2(3, 8)); - pairs.add(new Tuple2(2, 6)); - pairs.add(new Tuple2(0, 8)); - pairs.add(new Tuple2(3, 8)); - pairs.add(new Tuple2(1, 3)); + List> pairs = new ArrayList<>(); + pairs.add(new Tuple2<>(0, 5)); + pairs.add(new Tuple2<>(3, 8)); + pairs.add(new Tuple2<>(2, 6)); + pairs.add(new Tuple2<>(0, 8)); + pairs.add(new Tuple2<>(3, 8)); + pairs.add(new Tuple2<>(1, 3)); JavaPairRDD rdd = sc.parallelizePairs(pairs); Partitioner partitioner = new Partitioner() { + @Override public int numPartitions() { return 2; } + @Override public int getPartition(Object key) { - return ((Integer)key).intValue() % 2; + return (Integer) key % 2; } }; @@ -214,10 +216,10 @@ public int getPartition(Object key) { Assert.assertTrue(repartitioned.partitioner().isPresent()); Assert.assertEquals(repartitioned.partitioner().get(), partitioner); List>> partitions = repartitioned.glom().collect(); - Assert.assertEquals(partitions.get(0), Arrays.asList(new Tuple2(0, 5), - new Tuple2(0, 8), new Tuple2(2, 6))); - Assert.assertEquals(partitions.get(1), Arrays.asList(new Tuple2(1, 3), - new Tuple2(3, 8), new Tuple2(3, 8))); + Assert.assertEquals(partitions.get(0), + Arrays.asList(new Tuple2<>(0, 5), new Tuple2<>(0, 8), new Tuple2<>(2, 6))); + Assert.assertEquals(partitions.get(1), + Arrays.asList(new Tuple2<>(1, 3), new Tuple2<>(3, 8), new Tuple2<>(3, 8))); } @Test @@ -228,35 +230,37 @@ public void emptyRDD() { @Test public void sortBy() { - List> pairs = new ArrayList>(); - pairs.add(new Tuple2(0, 4)); - pairs.add(new Tuple2(3, 2)); - pairs.add(new Tuple2(-1, 1)); + List> pairs = new ArrayList<>(); + pairs.add(new Tuple2<>(0, 4)); + pairs.add(new Tuple2<>(3, 2)); + pairs.add(new Tuple2<>(-1, 1)); JavaRDD> rdd = sc.parallelize(pairs); // compare on first value JavaRDD> sortedRDD = rdd.sortBy(new Function, Integer>() { - public Integer call(Tuple2 t) throws Exception { + @Override + public Integer call(Tuple2 t) { return t._1(); } }, true, 2); - Assert.assertEquals(new Tuple2(-1, 1), sortedRDD.first()); + Assert.assertEquals(new Tuple2<>(-1, 1), sortedRDD.first()); List> sortedPairs = sortedRDD.collect(); - Assert.assertEquals(new Tuple2(0, 4), sortedPairs.get(1)); - Assert.assertEquals(new Tuple2(3, 2), sortedPairs.get(2)); + Assert.assertEquals(new Tuple2<>(0, 4), sortedPairs.get(1)); + Assert.assertEquals(new Tuple2<>(3, 2), sortedPairs.get(2)); // compare on second value sortedRDD = rdd.sortBy(new Function, Integer>() { - public Integer call(Tuple2 t) throws Exception { + @Override + public Integer call(Tuple2 t) { return t._2(); } }, true, 2); - Assert.assertEquals(new Tuple2(-1, 1), sortedRDD.first()); + Assert.assertEquals(new Tuple2<>(-1, 1), sortedRDD.first()); sortedPairs = sortedRDD.collect(); - Assert.assertEquals(new Tuple2(3, 2), sortedPairs.get(1)); - Assert.assertEquals(new Tuple2(0, 4), sortedPairs.get(2)); + Assert.assertEquals(new Tuple2<>(3, 2), sortedPairs.get(1)); + Assert.assertEquals(new Tuple2<>(0, 4), sortedPairs.get(2)); } @Test @@ -265,7 +269,7 @@ public void foreach() { JavaRDD rdd = sc.parallelize(Arrays.asList("Hello", "World")); rdd.foreach(new VoidFunction() { @Override - public void call(String s) throws IOException { + public void call(String s) { accum.add(1); } 
}); @@ -278,7 +282,7 @@ public void foreachPartition() { JavaRDD rdd = sc.parallelize(Arrays.asList("Hello", "World")); rdd.foreachPartition(new VoidFunction>() { @Override - public void call(Iterator iter) throws IOException { + public void call(Iterator iter) { while (iter.hasNext()) { iter.next(); accum.add(1); @@ -301,7 +305,7 @@ public void zipWithUniqueId() { List dataArray = Arrays.asList(1, 2, 3, 4); JavaPairRDD zip = sc.parallelize(dataArray).zipWithUniqueId(); JavaRDD indexes = zip.values(); - Assert.assertEquals(4, new HashSet(indexes.collect()).size()); + Assert.assertEquals(4, new HashSet<>(indexes.collect()).size()); } @Test @@ -317,10 +321,10 @@ public void zipWithIndex() { @Test public void lookup() { JavaPairRDD categories = sc.parallelizePairs(Arrays.asList( - new Tuple2("Apples", "Fruit"), - new Tuple2("Oranges", "Fruit"), - new Tuple2("Oranges", "Citrus") - )); + new Tuple2<>("Apples", "Fruit"), + new Tuple2<>("Oranges", "Fruit"), + new Tuple2<>("Oranges", "Citrus") + )); Assert.assertEquals(2, categories.lookup("Oranges").size()); Assert.assertEquals(2, Iterables.size(categories.groupByKey().lookup("Oranges").get(0))); } @@ -390,18 +394,17 @@ public String call(Tuple2 x) { @Test public void cogroup() { JavaPairRDD categories = sc.parallelizePairs(Arrays.asList( - new Tuple2("Apples", "Fruit"), - new Tuple2("Oranges", "Fruit"), - new Tuple2("Oranges", "Citrus") + new Tuple2<>("Apples", "Fruit"), + new Tuple2<>("Oranges", "Fruit"), + new Tuple2<>("Oranges", "Citrus") )); JavaPairRDD prices = sc.parallelizePairs(Arrays.asList( - new Tuple2("Oranges", 2), - new Tuple2("Apples", 3) + new Tuple2<>("Oranges", 2), + new Tuple2<>("Apples", 3) )); JavaPairRDD, Iterable>> cogrouped = categories.cogroup(prices); - Assert.assertEquals("[Fruit, Citrus]", - Iterables.toString(cogrouped.lookup("Oranges").get(0)._1())); + Assert.assertEquals("[Fruit, Citrus]", Iterables.toString(cogrouped.lookup("Oranges").get(0)._1())); Assert.assertEquals("[2]", Iterables.toString(cogrouped.lookup("Oranges").get(0)._2())); cogrouped.collect(); @@ -411,23 +414,22 @@ public void cogroup() { @Test public void cogroup3() { JavaPairRDD categories = sc.parallelizePairs(Arrays.asList( - new Tuple2("Apples", "Fruit"), - new Tuple2("Oranges", "Fruit"), - new Tuple2("Oranges", "Citrus") + new Tuple2<>("Apples", "Fruit"), + new Tuple2<>("Oranges", "Fruit"), + new Tuple2<>("Oranges", "Citrus") )); JavaPairRDD prices = sc.parallelizePairs(Arrays.asList( - new Tuple2("Oranges", 2), - new Tuple2("Apples", 3) + new Tuple2<>("Oranges", 2), + new Tuple2<>("Apples", 3) )); JavaPairRDD quantities = sc.parallelizePairs(Arrays.asList( - new Tuple2("Oranges", 21), - new Tuple2("Apples", 42) + new Tuple2<>("Oranges", 21), + new Tuple2<>("Apples", 42) )); JavaPairRDD, Iterable, Iterable>> cogrouped = categories.cogroup(prices, quantities); - Assert.assertEquals("[Fruit, Citrus]", - Iterables.toString(cogrouped.lookup("Oranges").get(0)._1())); + Assert.assertEquals("[Fruit, Citrus]", Iterables.toString(cogrouped.lookup("Oranges").get(0)._1())); Assert.assertEquals("[2]", Iterables.toString(cogrouped.lookup("Oranges").get(0)._2())); Assert.assertEquals("[42]", Iterables.toString(cogrouped.lookup("Apples").get(0)._3())); @@ -439,27 +441,26 @@ public void cogroup3() { @Test public void cogroup4() { JavaPairRDD categories = sc.parallelizePairs(Arrays.asList( - new Tuple2("Apples", "Fruit"), - new Tuple2("Oranges", "Fruit"), - new Tuple2("Oranges", "Citrus") + new Tuple2<>("Apples", "Fruit"), + new Tuple2<>("Oranges", "Fruit"), + 
new Tuple2<>("Oranges", "Citrus") )); JavaPairRDD prices = sc.parallelizePairs(Arrays.asList( - new Tuple2("Oranges", 2), - new Tuple2("Apples", 3) + new Tuple2<>("Oranges", 2), + new Tuple2<>("Apples", 3) )); JavaPairRDD quantities = sc.parallelizePairs(Arrays.asList( - new Tuple2("Oranges", 21), - new Tuple2("Apples", 42) + new Tuple2<>("Oranges", 21), + new Tuple2<>("Apples", 42) )); JavaPairRDD countries = sc.parallelizePairs(Arrays.asList( - new Tuple2("Oranges", "BR"), - new Tuple2("Apples", "US") + new Tuple2<>("Oranges", "BR"), + new Tuple2<>("Apples", "US") )); JavaPairRDD, Iterable, Iterable, Iterable>> cogrouped = categories.cogroup(prices, quantities, countries); - Assert.assertEquals("[Fruit, Citrus]", - Iterables.toString(cogrouped.lookup("Oranges").get(0)._1())); + Assert.assertEquals("[Fruit, Citrus]", Iterables.toString(cogrouped.lookup("Oranges").get(0)._1())); Assert.assertEquals("[2]", Iterables.toString(cogrouped.lookup("Oranges").get(0)._2())); Assert.assertEquals("[42]", Iterables.toString(cogrouped.lookup("Apples").get(0)._3())); Assert.assertEquals("[BR]", Iterables.toString(cogrouped.lookup("Oranges").get(0)._4())); @@ -471,16 +472,16 @@ public void cogroup4() { @Test public void leftOuterJoin() { JavaPairRDD rdd1 = sc.parallelizePairs(Arrays.asList( - new Tuple2(1, 1), - new Tuple2(1, 2), - new Tuple2(2, 1), - new Tuple2(3, 1) + new Tuple2<>(1, 1), + new Tuple2<>(1, 2), + new Tuple2<>(2, 1), + new Tuple2<>(3, 1) )); JavaPairRDD rdd2 = sc.parallelizePairs(Arrays.asList( - new Tuple2(1, 'x'), - new Tuple2(2, 'y'), - new Tuple2(2, 'z'), - new Tuple2(4, 'w') + new Tuple2<>(1, 'x'), + new Tuple2<>(2, 'y'), + new Tuple2<>(2, 'z'), + new Tuple2<>(4, 'w') )); List>>> joined = rdd1.leftOuterJoin(rdd2).collect(); @@ -548,11 +549,11 @@ public Integer call(Integer a, Integer b) { public void aggregateByKey() { JavaPairRDD pairs = sc.parallelizePairs( Arrays.asList( - new Tuple2(1, 1), - new Tuple2(1, 1), - new Tuple2(3, 2), - new Tuple2(5, 1), - new Tuple2(5, 3)), 2); + new Tuple2<>(1, 1), + new Tuple2<>(1, 1), + new Tuple2<>(3, 2), + new Tuple2<>(5, 1), + new Tuple2<>(5, 3)), 2); Map> sets = pairs.aggregateByKey(new HashSet(), new Function2, Integer, Set>() { @@ -570,20 +571,20 @@ public Set call(Set a, Set b) { } }).collectAsMap(); Assert.assertEquals(3, sets.size()); - Assert.assertEquals(new HashSet(Arrays.asList(1)), sets.get(1)); - Assert.assertEquals(new HashSet(Arrays.asList(2)), sets.get(3)); - Assert.assertEquals(new HashSet(Arrays.asList(1, 3)), sets.get(5)); + Assert.assertEquals(new HashSet<>(Arrays.asList(1)), sets.get(1)); + Assert.assertEquals(new HashSet<>(Arrays.asList(2)), sets.get(3)); + Assert.assertEquals(new HashSet<>(Arrays.asList(1, 3)), sets.get(5)); } @SuppressWarnings("unchecked") @Test public void foldByKey() { List> pairs = Arrays.asList( - new Tuple2(2, 1), - new Tuple2(2, 1), - new Tuple2(1, 1), - new Tuple2(3, 2), - new Tuple2(3, 1) + new Tuple2<>(2, 1), + new Tuple2<>(2, 1), + new Tuple2<>(1, 1), + new Tuple2<>(3, 2), + new Tuple2<>(3, 1) ); JavaPairRDD rdd = sc.parallelizePairs(pairs); JavaPairRDD sums = rdd.foldByKey(0, @@ -602,11 +603,11 @@ public Integer call(Integer a, Integer b) { @Test public void reduceByKey() { List> pairs = Arrays.asList( - new Tuple2(2, 1), - new Tuple2(2, 1), - new Tuple2(1, 1), - new Tuple2(3, 2), - new Tuple2(3, 1) + new Tuple2<>(2, 1), + new Tuple2<>(2, 1), + new Tuple2<>(1, 1), + new Tuple2<>(3, 2), + new Tuple2<>(3, 1) ); JavaPairRDD rdd = sc.parallelizePairs(pairs); JavaPairRDD counts = rdd.reduceByKey( @@ 
-690,7 +691,7 @@ public void cartesian() { JavaDoubleRDD doubleRDD = sc.parallelizeDoubles(Arrays.asList(1.0, 1.0, 2.0, 3.0, 5.0, 8.0)); JavaRDD stringRDD = sc.parallelize(Arrays.asList("Hello", "World")); JavaPairRDD cartesian = stringRDD.cartesian(doubleRDD); - Assert.assertEquals(new Tuple2("Hello", 1.0), cartesian.first()); + Assert.assertEquals(new Tuple2<>("Hello", 1.0), cartesian.first()); } @Test @@ -743,6 +744,7 @@ public void javaDoubleRDDHistoGram() { } private static class DoubleComparator implements Comparator, Serializable { + @Override public int compare(Double o1, Double o2) { return o1.compareTo(o2); } @@ -766,14 +768,14 @@ public void min() { public void naturalMax() { JavaDoubleRDD rdd = sc.parallelizeDoubles(Arrays.asList(1.0, 2.0, 3.0, 4.0)); double max = rdd.max(); - Assert.assertTrue(4.0 == max); + Assert.assertEquals(4.0, max, 0.0); } @Test public void naturalMin() { JavaDoubleRDD rdd = sc.parallelizeDoubles(Arrays.asList(1.0, 2.0, 3.0, 4.0)); double max = rdd.min(); - Assert.assertTrue(1.0 == max); + Assert.assertEquals(1.0, max, 0.0); } @Test @@ -809,7 +811,7 @@ public void reduceOnJavaDoubleRDD() { JavaDoubleRDD rdd = sc.parallelizeDoubles(Arrays.asList(1.0, 2.0, 3.0, 4.0)); double sum = rdd.reduce(new Function2() { @Override - public Double call(Double v1, Double v2) throws Exception { + public Double call(Double v1, Double v2) { return v1 + v2; } }); @@ -844,7 +846,7 @@ public double call(Integer x) { new PairFunction() { @Override public Tuple2 call(Integer x) { - return new Tuple2(x, x); + return new Tuple2<>(x, x); } }).cache(); pairs.collect(); @@ -870,26 +872,25 @@ public Iterable call(String x) { Assert.assertEquals("Hello", words.first()); Assert.assertEquals(11, words.count()); - JavaPairRDD pairs = rdd.flatMapToPair( + JavaPairRDD pairsRDD = rdd.flatMapToPair( new PairFlatMapFunction() { - @Override public Iterable> call(String s) { - List> pairs = new LinkedList>(); + List> pairs = new LinkedList<>(); for (String word : s.split(" ")) { - pairs.add(new Tuple2(word, word)); + pairs.add(new Tuple2<>(word, word)); } return pairs; } } ); - Assert.assertEquals(new Tuple2("Hello", "Hello"), pairs.first()); - Assert.assertEquals(11, pairs.count()); + Assert.assertEquals(new Tuple2<>("Hello", "Hello"), pairsRDD.first()); + Assert.assertEquals(11, pairsRDD.count()); JavaDoubleRDD doubles = rdd.flatMapToDouble(new DoubleFlatMapFunction() { @Override public Iterable call(String s) { - List lengths = new LinkedList(); + List lengths = new LinkedList<>(); for (String word : s.split(" ")) { lengths.add((double) word.length()); } @@ -897,36 +898,36 @@ public Iterable call(String s) { } }); Assert.assertEquals(5.0, doubles.first(), 0.01); - Assert.assertEquals(11, pairs.count()); + Assert.assertEquals(11, pairsRDD.count()); } @SuppressWarnings("unchecked") @Test public void mapsFromPairsToPairs() { - List> pairs = Arrays.asList( - new Tuple2(1, "a"), - new Tuple2(2, "aa"), - new Tuple2(3, "aaa") - ); - JavaPairRDD pairRDD = sc.parallelizePairs(pairs); - - // Regression test for SPARK-668: - JavaPairRDD swapped = pairRDD.flatMapToPair( - new PairFlatMapFunction, String, Integer>() { - @Override - public Iterable> call(Tuple2 item) { - return Collections.singletonList(item.swap()); - } + List> pairs = Arrays.asList( + new Tuple2<>(1, "a"), + new Tuple2<>(2, "aa"), + new Tuple2<>(3, "aaa") + ); + JavaPairRDD pairRDD = sc.parallelizePairs(pairs); + + // Regression test for SPARK-668: + JavaPairRDD swapped = pairRDD.flatMapToPair( + new PairFlatMapFunction, String, 
Integer>() { + @Override + public Iterable> call(Tuple2 item) { + return Collections.singletonList(item.swap()); + } }); - swapped.collect(); + swapped.collect(); - // There was never a bug here, but it's worth testing: - pairRDD.mapToPair(new PairFunction, String, Integer>() { - @Override - public Tuple2 call(Tuple2 item) { - return item.swap(); - } - }).collect(); + // There was never a bug here, but it's worth testing: + pairRDD.mapToPair(new PairFunction, String, Integer>() { + @Override + public Tuple2 call(Tuple2 item) { + return item.swap(); + } + }).collect(); } @Test @@ -953,7 +954,7 @@ public void mapPartitionsWithIndex() { JavaRDD partitionSums = rdd.mapPartitionsWithIndex( new Function2, Iterator>() { @Override - public Iterator call(Integer index, Iterator iter) throws Exception { + public Iterator call(Integer index, Iterator iter) { int sum = 0; while (iter.hasNext()) { sum += iter.next(); @@ -972,8 +973,8 @@ public void repartition() { JavaRDD repartitioned1 = in1.repartition(4); List> result1 = repartitioned1.glom().collect(); Assert.assertEquals(4, result1.size()); - for (List l: result1) { - Assert.assertTrue(l.size() > 0); + for (List l : result1) { + Assert.assertFalse(l.isEmpty()); } // Growing number of partitions @@ -982,7 +983,7 @@ public void repartition() { List> result2 = repartitioned2.glom().collect(); Assert.assertEquals(2, result2.size()); for (List l: result2) { - Assert.assertTrue(l.size() > 0); + Assert.assertFalse(l.isEmpty()); } } @@ -994,9 +995,9 @@ public void persist() { Assert.assertEquals(20, doubleRDD.sum(), 0.1); List> pairs = Arrays.asList( - new Tuple2(1, "a"), - new Tuple2(2, "aa"), - new Tuple2(3, "aaa") + new Tuple2<>(1, "a"), + new Tuple2<>(2, "aa"), + new Tuple2<>(3, "aaa") ); JavaPairRDD pairRDD = sc.parallelizePairs(pairs); pairRDD = pairRDD.persist(StorageLevel.DISK_ONLY()); @@ -1046,7 +1047,7 @@ public void wholeTextFiles() throws Exception { Files.write(content1, new File(tempDirName + "/part-00000")); Files.write(content2, new File(tempDirName + "/part-00001")); - Map container = new HashMap(); + Map container = new HashMap<>(); container.put(tempDirName+"/part-00000", new Text(content1).toString()); container.put(tempDirName+"/part-00001", new Text(content2).toString()); @@ -1075,16 +1076,16 @@ public void textFilesCompressed() throws IOException { public void sequenceFile() { String outputDir = new File(tempDir, "output").getAbsolutePath(); List> pairs = Arrays.asList( - new Tuple2(1, "a"), - new Tuple2(2, "aa"), - new Tuple2(3, "aaa") + new Tuple2<>(1, "a"), + new Tuple2<>(2, "aa"), + new Tuple2<>(3, "aaa") ); JavaPairRDD rdd = sc.parallelizePairs(pairs); rdd.mapToPair(new PairFunction, IntWritable, Text>() { @Override public Tuple2 call(Tuple2 pair) { - return new Tuple2(new IntWritable(pair._1()), new Text(pair._2())); + return new Tuple2<>(new IntWritable(pair._1()), new Text(pair._2())); } }).saveAsHadoopFile(outputDir, IntWritable.class, Text.class, SequenceFileOutputFormat.class); @@ -1093,7 +1094,7 @@ public Tuple2 call(Tuple2 pair) { Text.class).mapToPair(new PairFunction, Integer, String>() { @Override public Tuple2 call(Tuple2 pair) { - return new Tuple2(pair._1().get(), pair._2().toString()); + return new Tuple2<>(pair._1().get(), pair._2().toString()); } }); Assert.assertEquals(pairs, readRDD.collect()); @@ -1110,7 +1111,7 @@ public void binaryFiles() throws Exception { FileOutputStream fos1 = new FileOutputStream(file1); FileChannel channel1 = fos1.getChannel(); - ByteBuffer bbuf = java.nio.ByteBuffer.wrap(content1); 
+ ByteBuffer bbuf = ByteBuffer.wrap(content1); channel1.write(bbuf); channel1.close(); JavaPairRDD readRDD = sc.binaryFiles(tempDirName, 3); @@ -1131,14 +1132,14 @@ public void binaryFilesCaching() throws Exception { FileOutputStream fos1 = new FileOutputStream(file1); FileChannel channel1 = fos1.getChannel(); - ByteBuffer bbuf = java.nio.ByteBuffer.wrap(content1); + ByteBuffer bbuf = ByteBuffer.wrap(content1); channel1.write(bbuf); channel1.close(); JavaPairRDD readRDD = sc.binaryFiles(tempDirName).cache(); readRDD.foreach(new VoidFunction>() { @Override - public void call(Tuple2 pair) throws Exception { + public void call(Tuple2 pair) { pair._2().toArray(); // force the file to read } }); @@ -1162,7 +1163,7 @@ public void binaryRecords() throws Exception { FileChannel channel1 = fos1.getChannel(); for (int i = 0; i < numOfCopies; i++) { - ByteBuffer bbuf = java.nio.ByteBuffer.wrap(content1); + ByteBuffer bbuf = ByteBuffer.wrap(content1); channel1.write(bbuf); } channel1.close(); @@ -1180,24 +1181,23 @@ public void binaryRecords() throws Exception { public void writeWithNewAPIHadoopFile() { String outputDir = new File(tempDir, "output").getAbsolutePath(); List> pairs = Arrays.asList( - new Tuple2(1, "a"), - new Tuple2(2, "aa"), - new Tuple2(3, "aaa") + new Tuple2<>(1, "a"), + new Tuple2<>(2, "aa"), + new Tuple2<>(3, "aaa") ); JavaPairRDD rdd = sc.parallelizePairs(pairs); rdd.mapToPair(new PairFunction, IntWritable, Text>() { @Override public Tuple2 call(Tuple2 pair) { - return new Tuple2(new IntWritable(pair._1()), new Text(pair._2())); + return new Tuple2<>(new IntWritable(pair._1()), new Text(pair._2())); } - }).saveAsNewAPIHadoopFile(outputDir, IntWritable.class, Text.class, - org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat.class); + }).saveAsNewAPIHadoopFile( + outputDir, IntWritable.class, Text.class, + org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat.class); - JavaPairRDD output = sc.sequenceFile(outputDir, IntWritable.class, - Text.class); - Assert.assertEquals(pairs.toString(), output.map(new Function, - String>() { + JavaPairRDD output = sc.sequenceFile(outputDir, IntWritable.class, Text.class); + Assert.assertEquals(pairs.toString(), output.map(new Function, String>() { @Override public String call(Tuple2 x) { return x.toString(); @@ -1210,24 +1210,23 @@ public String call(Tuple2 x) { public void readWithNewAPIHadoopFile() throws IOException { String outputDir = new File(tempDir, "output").getAbsolutePath(); List> pairs = Arrays.asList( - new Tuple2(1, "a"), - new Tuple2(2, "aa"), - new Tuple2(3, "aaa") + new Tuple2<>(1, "a"), + new Tuple2<>(2, "aa"), + new Tuple2<>(3, "aaa") ); JavaPairRDD rdd = sc.parallelizePairs(pairs); rdd.mapToPair(new PairFunction, IntWritable, Text>() { @Override public Tuple2 call(Tuple2 pair) { - return new Tuple2(new IntWritable(pair._1()), new Text(pair._2())); + return new Tuple2<>(new IntWritable(pair._1()), new Text(pair._2())); } }).saveAsHadoopFile(outputDir, IntWritable.class, Text.class, SequenceFileOutputFormat.class); JavaPairRDD output = sc.newAPIHadoopFile(outputDir, - org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat.class, IntWritable.class, - Text.class, new Job().getConfiguration()); - Assert.assertEquals(pairs.toString(), output.map(new Function, - String>() { + org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat.class, + IntWritable.class, Text.class, new Job().getConfiguration()); + Assert.assertEquals(pairs.toString(), output.map(new Function, String>() { @Override public String 
call(Tuple2 x) { return x.toString(); @@ -1251,9 +1250,9 @@ public void objectFilesOfInts() { public void objectFilesOfComplexTypes() { String outputDir = new File(tempDir, "output").getAbsolutePath(); List> pairs = Arrays.asList( - new Tuple2(1, "a"), - new Tuple2(2, "aa"), - new Tuple2(3, "aaa") + new Tuple2<>(1, "a"), + new Tuple2<>(2, "aa"), + new Tuple2<>(3, "aaa") ); JavaPairRDD rdd = sc.parallelizePairs(pairs); rdd.saveAsObjectFile(outputDir); @@ -1267,23 +1266,22 @@ public void objectFilesOfComplexTypes() { public void hadoopFile() { String outputDir = new File(tempDir, "output").getAbsolutePath(); List> pairs = Arrays.asList( - new Tuple2(1, "a"), - new Tuple2(2, "aa"), - new Tuple2(3, "aaa") + new Tuple2<>(1, "a"), + new Tuple2<>(2, "aa"), + new Tuple2<>(3, "aaa") ); JavaPairRDD rdd = sc.parallelizePairs(pairs); rdd.mapToPair(new PairFunction, IntWritable, Text>() { @Override public Tuple2 call(Tuple2 pair) { - return new Tuple2(new IntWritable(pair._1()), new Text(pair._2())); + return new Tuple2<>(new IntWritable(pair._1()), new Text(pair._2())); } }).saveAsHadoopFile(outputDir, IntWritable.class, Text.class, SequenceFileOutputFormat.class); JavaPairRDD output = sc.hadoopFile(outputDir, - SequenceFileInputFormat.class, IntWritable.class, Text.class); - Assert.assertEquals(pairs.toString(), output.map(new Function, - String>() { + SequenceFileInputFormat.class, IntWritable.class, Text.class); + Assert.assertEquals(pairs.toString(), output.map(new Function, String>() { @Override public String call(Tuple2 x) { return x.toString(); @@ -1296,16 +1294,16 @@ public String call(Tuple2 x) { public void hadoopFileCompressed() { String outputDir = new File(tempDir, "output_compressed").getAbsolutePath(); List> pairs = Arrays.asList( - new Tuple2(1, "a"), - new Tuple2(2, "aa"), - new Tuple2(3, "aaa") + new Tuple2<>(1, "a"), + new Tuple2<>(2, "aa"), + new Tuple2<>(3, "aaa") ); JavaPairRDD rdd = sc.parallelizePairs(pairs); rdd.mapToPair(new PairFunction, IntWritable, Text>() { @Override public Tuple2 call(Tuple2 pair) { - return new Tuple2(new IntWritable(pair._1()), new Text(pair._2())); + return new Tuple2<>(new IntWritable(pair._1()), new Text(pair._2())); } }).saveAsHadoopFile(outputDir, IntWritable.class, Text.class, SequenceFileOutputFormat.class, DefaultCodec.class); @@ -1313,8 +1311,7 @@ public Tuple2 call(Tuple2 pair) { JavaPairRDD output = sc.hadoopFile(outputDir, SequenceFileInputFormat.class, IntWritable.class, Text.class); - Assert.assertEquals(pairs.toString(), output.map(new Function, - String>() { + Assert.assertEquals(pairs.toString(), output.map(new Function, String>() { @Override public String call(Tuple2 x) { return x.toString(); @@ -1414,8 +1411,8 @@ public String call(Integer t) { return t.toString(); } }).collect(); - Assert.assertEquals(new Tuple2("1", 1), s.get(0)); - Assert.assertEquals(new Tuple2("2", 2), s.get(1)); + Assert.assertEquals(new Tuple2<>("1", 1), s.get(0)); + Assert.assertEquals(new Tuple2<>("2", 2), s.get(1)); } @Test @@ -1448,20 +1445,20 @@ public void combineByKey() { JavaRDD originalRDD = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5, 6)); Function keyFunction = new Function() { @Override - public Integer call(Integer v1) throws Exception { + public Integer call(Integer v1) { return v1 % 3; } }; Function createCombinerFunction = new Function() { @Override - public Integer call(Integer v1) throws Exception { + public Integer call(Integer v1) { return v1; } }; Function2 mergeValueFunction = new Function2() { @Override - public Integer call(Integer v1, 
Integer v2) throws Exception { + public Integer call(Integer v1, Integer v2) { return v1 + v2; } }; @@ -1496,21 +1493,21 @@ public void mapOnPairRDD() { new PairFunction() { @Override public Tuple2 call(Integer i) { - return new Tuple2(i, i % 2); + return new Tuple2<>(i, i % 2); } }); JavaPairRDD rdd3 = rdd2.mapToPair( new PairFunction, Integer, Integer>() { - @Override - public Tuple2 call(Tuple2 in) { - return new Tuple2(in._2(), in._1()); - } - }); + @Override + public Tuple2 call(Tuple2 in) { + return new Tuple2<>(in._2(), in._1()); + } + }); Assert.assertEquals(Arrays.asList( - new Tuple2(1, 1), - new Tuple2(0, 2), - new Tuple2(1, 3), - new Tuple2(0, 4)), rdd3.collect()); + new Tuple2<>(1, 1), + new Tuple2<>(0, 2), + new Tuple2<>(1, 3), + new Tuple2<>(0, 4)), rdd3.collect()); } @@ -1523,7 +1520,7 @@ public void collectPartitions() { new PairFunction() { @Override public Tuple2 call(Integer i) { - return new Tuple2(i, i % 2); + return new Tuple2<>(i, i % 2); } }); @@ -1534,23 +1531,23 @@ public Tuple2 call(Integer i) { Assert.assertEquals(Arrays.asList(3, 4), parts[0]); Assert.assertEquals(Arrays.asList(5, 6, 7), parts[1]); - Assert.assertEquals(Arrays.asList(new Tuple2(1, 1), - new Tuple2(2, 0)), + Assert.assertEquals(Arrays.asList(new Tuple2<>(1, 1), + new Tuple2<>(2, 0)), rdd2.collectPartitions(new int[] {0})[0]); List>[] parts2 = rdd2.collectPartitions(new int[] {1, 2}); - Assert.assertEquals(Arrays.asList(new Tuple2(3, 1), - new Tuple2(4, 0)), + Assert.assertEquals(Arrays.asList(new Tuple2<>(3, 1), + new Tuple2<>(4, 0)), parts2[0]); - Assert.assertEquals(Arrays.asList(new Tuple2(5, 1), - new Tuple2(6, 0), - new Tuple2(7, 1)), + Assert.assertEquals(Arrays.asList(new Tuple2<>(5, 1), + new Tuple2<>(6, 0), + new Tuple2<>(7, 1)), parts2[1]); } @Test public void countApproxDistinct() { - List arrayData = new ArrayList(); + List arrayData = new ArrayList<>(); int size = 100; for (int i = 0; i < 100000; i++) { arrayData.add(i % size); @@ -1561,15 +1558,15 @@ public void countApproxDistinct() { @Test public void countApproxDistinctByKey() { - List> arrayData = new ArrayList>(); + List> arrayData = new ArrayList<>(); for (int i = 10; i < 100; i++) { for (int j = 0; j < i; j++) { - arrayData.add(new Tuple2(i, j)); + arrayData.add(new Tuple2<>(i, j)); } } double relativeSD = 0.001; JavaPairRDD pairRdd = sc.parallelizePairs(arrayData); - List> res = pairRdd.countApproxDistinctByKey(8, 0).collect(); + List> res = pairRdd.countApproxDistinctByKey(relativeSD, 8).collect(); for (Tuple2 resItem : res) { double count = (double)resItem._1(); Long resCount = (Long)resItem._2(); @@ -1587,7 +1584,7 @@ public void collectAsMapWithIntArrayValues() { new PairFunction() { @Override public Tuple2 call(Integer x) { - return new Tuple2(x, new int[] { x }); + return new Tuple2<>(x, new int[]{x}); } }); pairRDD.collect(); // Works fine @@ -1598,7 +1595,7 @@ public Tuple2 call(Integer x) { @Test public void collectAsMapAndSerialize() throws Exception { JavaPairRDD rdd = - sc.parallelizePairs(Arrays.asList(new Tuple2("foo", 1))); + sc.parallelizePairs(Arrays.asList(new Tuple2<>("foo", 1))); Map map = rdd.collectAsMap(); ByteArrayOutputStream bytes = new ByteArrayOutputStream(); new ObjectOutputStream(bytes).writeObject(map); @@ -1615,7 +1612,7 @@ public void sampleByKey() { new PairFunction() { @Override public Tuple2 call(Integer i) { - return new Tuple2(i % 2, 1); + return new Tuple2<>(i % 2, 1); } }); Map fractions = Maps.newHashMap(); @@ -1623,12 +1620,12 @@ public Tuple2 call(Integer i) { fractions.put(1, 
1.0); JavaPairRDD wr = rdd2.sampleByKey(true, fractions, 1L); Map wrCounts = (Map) (Object) wr.countByKey(); - Assert.assertTrue(wrCounts.size() == 2); + Assert.assertEquals(2, wrCounts.size()); Assert.assertTrue(wrCounts.get(0) > 0); Assert.assertTrue(wrCounts.get(1) > 0); JavaPairRDD wor = rdd2.sampleByKey(false, fractions, 1L); Map worCounts = (Map) (Object) wor.countByKey(); - Assert.assertTrue(worCounts.size() == 2); + Assert.assertEquals(2, worCounts.size()); Assert.assertTrue(worCounts.get(0) > 0); Assert.assertTrue(worCounts.get(1) > 0); } @@ -1641,7 +1638,7 @@ public void sampleByKeyExact() { new PairFunction() { @Override public Tuple2 call(Integer i) { - return new Tuple2(i % 2, 1); + return new Tuple2<>(i % 2, 1); } }); Map fractions = Maps.newHashMap(); @@ -1649,25 +1646,25 @@ public Tuple2 call(Integer i) { fractions.put(1, 1.0); JavaPairRDD wrExact = rdd2.sampleByKeyExact(true, fractions, 1L); Map wrExactCounts = (Map) (Object) wrExact.countByKey(); - Assert.assertTrue(wrExactCounts.size() == 2); + Assert.assertEquals(2, wrExactCounts.size()); Assert.assertTrue(wrExactCounts.get(0) == 2); Assert.assertTrue(wrExactCounts.get(1) == 4); JavaPairRDD worExact = rdd2.sampleByKeyExact(false, fractions, 1L); Map worExactCounts = (Map) (Object) worExact.countByKey(); - Assert.assertTrue(worExactCounts.size() == 2); + Assert.assertEquals(2, worExactCounts.size()); Assert.assertTrue(worExactCounts.get(0) == 2); Assert.assertTrue(worExactCounts.get(1) == 4); } private static class SomeCustomClass implements Serializable { - public SomeCustomClass() { + SomeCustomClass() { // Intentionally left blank } } @Test public void collectUnderlyingScalaRDD() { - List data = new ArrayList(); + List data = new ArrayList<>(); for (int i = 0; i < 100; i++) { data.add(new SomeCustomClass()); } @@ -1679,7 +1676,7 @@ public void collectUnderlyingScalaRDD() { private static final class BuggyMapFunction implements Function { @Override - public T call(T x) throws Exception { + public T call(T x) { throw new IllegalStateException("Custom exception!"); } } @@ -1716,7 +1713,7 @@ public void foreachAsync() throws Exception { JavaFutureAction future = rdd.foreachAsync( new VoidFunction() { @Override - public void call(Integer integer) throws Exception { + public void call(Integer integer) { // intentionally left blank. } } @@ -1745,7 +1742,7 @@ public void testAsyncActionCancellation() throws Exception { JavaRDD rdd = sc.parallelize(data, 1); JavaFutureAction future = rdd.foreachAsync(new VoidFunction() { @Override - public void call(Integer integer) throws Exception { + public void call(Integer integer) throws InterruptedException { Thread.sleep(10000); // To ensure that the job won't finish before it's cancelled. 
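[Editor's note: a minimal, self-contained sketch, not part of the patch, showing the two mechanical clean-ups applied throughout the test diffs above: the Java 7 diamond operator for generic instantiation, and JUnit's assertEquals in place of assertTrue(x == y). The class and variable names are hypothetical.]

    import java.util.ArrayList;
    import java.util.HashMap;
    import java.util.List;
    import java.util.Map;

    import org.junit.Assert;
    import org.junit.Test;

    public class DiamondAndAssertExample {   // hypothetical name
      @Test
      public void diamondAndAssertEquals() {
        // Before: new HashMap<String, List<Integer>>() repeated the type
        // arguments (or used a raw type); the diamond lets the compiler
        // infer them from the declared type of the variable.
        Map<String, List<Integer>> counts = new HashMap<>();
        List<Integer> ones = new ArrayList<>();
        ones.add(1);
        counts.put("a", ones);

        // assertEquals(expected, actual) prints both values on failure;
        // assertTrue(counts.size() == 1) would only report "false".
        Assert.assertEquals(1, counts.size());
        Assert.assertEquals(1, counts.get("a").size());
      }
    }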
} }); diff --git a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaDirectKafkaStreamSuite.java b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaDirectKafkaStreamSuite.java index 9db07d0507fe..fbdfbf7e509b 100644 --- a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaDirectKafkaStreamSuite.java +++ b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaDirectKafkaStreamSuite.java @@ -75,11 +75,11 @@ public void testKafkaStream() throws InterruptedException { String[] topic1data = createTopicAndSendData(topic1); String[] topic2data = createTopicAndSendData(topic2); - HashSet sent = new HashSet(); + Set sent = new HashSet<>(); sent.addAll(Arrays.asList(topic1data)); sent.addAll(Arrays.asList(topic2data)); - HashMap kafkaParams = new HashMap(); + Map kafkaParams = new HashMap<>(); kafkaParams.put("metadata.broker.list", kafkaTestUtils.brokerAddress()); kafkaParams.put("auto.offset.reset", "smallest"); @@ -95,17 +95,17 @@ public void testKafkaStream() throws InterruptedException { // Make sure you can get offset ranges from the rdd new Function, JavaPairRDD>() { @Override - public JavaPairRDD call(JavaPairRDD rdd) throws Exception { + public JavaPairRDD call(JavaPairRDD rdd) { OffsetRange[] offsets = ((HasOffsetRanges) rdd.rdd()).offsetRanges(); offsetRanges.set(offsets); - Assert.assertEquals(offsets[0].topic(), topic1); + Assert.assertEquals(topic1, offsets[0].topic()); return rdd; } } ).map( new Function, String>() { @Override - public String call(Tuple2 kv) throws Exception { + public String call(Tuple2 kv) { return kv._2(); } } @@ -119,10 +119,10 @@ public String call(Tuple2 kv) throws Exception { StringDecoder.class, String.class, kafkaParams, - topicOffsetToMap(topic2, (long) 0), + topicOffsetToMap(topic2, 0L), new Function, String>() { @Override - public String call(MessageAndMetadata msgAndMd) throws Exception { + public String call(MessageAndMetadata msgAndMd) { return msgAndMd.message(); } } @@ -133,7 +133,7 @@ public String call(MessageAndMetadata msgAndMd) throws Exception unifiedStream.foreachRDD( new Function, Void>() { @Override - public Void call(JavaRDD rdd) throws Exception { + public Void call(JavaRDD rdd) { result.addAll(rdd.collect()); for (OffsetRange o : offsetRanges.get()) { System.out.println( @@ -155,14 +155,14 @@ public Void call(JavaRDD rdd) throws Exception { ssc.stop(); } - private HashSet topicToSet(String topic) { - HashSet topicSet = new HashSet(); + private static Set topicToSet(String topic) { + Set topicSet = new HashSet<>(); topicSet.add(topic); return topicSet; } - private HashMap topicOffsetToMap(String topic, Long offsetToStart) { - HashMap topicMap = new HashMap(); + private static Map topicOffsetToMap(String topic, Long offsetToStart) { + Map topicMap = new HashMap<>(); topicMap.put(new TopicAndPartition(topic, 0), offsetToStart); return topicMap; } diff --git a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaRDDSuite.java b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaRDDSuite.java index a9dc6e50613c..afcc6cfccd39 100644 --- a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaRDDSuite.java +++ b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaRDDSuite.java @@ -19,6 +19,7 @@ import java.io.Serializable; import java.util.HashMap; +import java.util.Map; import scala.Tuple2; @@ -66,10 +67,10 @@ public void testKafkaRDD() throws InterruptedException { String topic1 = "topic1"; String topic2 = 
"topic2"; - String[] topic1data = createTopicAndSendData(topic1); - String[] topic2data = createTopicAndSendData(topic2); + createTopicAndSendData(topic1); + createTopicAndSendData(topic2); - HashMap kafkaParams = new HashMap(); + Map kafkaParams = new HashMap<>(); kafkaParams.put("metadata.broker.list", kafkaTestUtils.brokerAddress()); OffsetRange[] offsetRanges = { @@ -77,8 +78,8 @@ public void testKafkaRDD() throws InterruptedException { OffsetRange.create(topic2, 0, 0, 1) }; - HashMap emptyLeaders = new HashMap(); - HashMap leaders = new HashMap(); + Map emptyLeaders = new HashMap<>(); + Map leaders = new HashMap<>(); String[] hostAndPort = kafkaTestUtils.brokerAddress().split(":"); Broker broker = Broker.create(hostAndPort[0], Integer.parseInt(hostAndPort[1])); leaders.put(new TopicAndPartition(topic1, 0), broker); @@ -95,7 +96,7 @@ public void testKafkaRDD() throws InterruptedException { ).map( new Function, String>() { @Override - public String call(Tuple2 kv) throws Exception { + public String call(Tuple2 kv) { return kv._2(); } } @@ -113,7 +114,7 @@ public String call(Tuple2 kv) throws Exception { emptyLeaders, new Function, String>() { @Override - public String call(MessageAndMetadata msgAndMd) throws Exception { + public String call(MessageAndMetadata msgAndMd) { return msgAndMd.message(); } } @@ -131,7 +132,7 @@ public String call(MessageAndMetadata msgAndMd) throws Exception leaders, new Function, String>() { @Override - public String call(MessageAndMetadata msgAndMd) throws Exception { + public String call(MessageAndMetadata msgAndMd) { return msgAndMd.message(); } } diff --git a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java index e4c659215b76..1e69de46cd35 100644 --- a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java +++ b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java @@ -67,10 +67,10 @@ public void tearDown() { @Test public void testKafkaStream() throws InterruptedException { String topic = "topic1"; - HashMap topics = new HashMap(); + Map topics = new HashMap<>(); topics.put(topic, 1); - HashMap sent = new HashMap(); + Map sent = new HashMap<>(); sent.put("a", 5); sent.put("b", 3); sent.put("c", 10); @@ -78,7 +78,7 @@ public void testKafkaStream() throws InterruptedException { kafkaTestUtils.createTopic(topic); kafkaTestUtils.sendMessages(topic, sent); - HashMap kafkaParams = new HashMap(); + Map kafkaParams = new HashMap<>(); kafkaParams.put("zookeeper.connect", kafkaTestUtils.zkAddress()); kafkaParams.put("group.id", "test-consumer-" + random.nextInt(10000)); kafkaParams.put("auto.offset.reset", "smallest"); @@ -97,7 +97,7 @@ public void testKafkaStream() throws InterruptedException { JavaDStream words = stream.map( new Function, String>() { @Override - public String call(Tuple2 tuple2) throws Exception { + public String call(Tuple2 tuple2) { return tuple2._2(); } } @@ -106,7 +106,7 @@ public String call(Tuple2 tuple2) throws Exception { words.countByValue().foreachRDD( new Function, Void>() { @Override - public Void call(JavaPairRDD rdd) throws Exception { + public Void call(JavaPairRDD rdd) { List> ret = rdd.collect(); for (Tuple2 r : ret) { if (result.containsKey(r._1())) { @@ -130,8 +130,8 @@ public Void call(JavaPairRDD rdd) throws Exception { Thread.sleep(200); } Assert.assertEquals(sent.size(), result.size()); - for (String k : sent.keySet()) { - 
Assert.assertEquals(sent.get(k).intValue(), result.get(k).intValue()); + for (Map.Entry e : sent.entrySet()) { + Assert.assertEquals(e.getValue().intValue(), result.get(e.getKey()).intValue()); } } } diff --git a/external/twitter/src/test/java/org/apache/spark/streaming/twitter/JavaTwitterStreamSuite.java b/external/twitter/src/test/java/org/apache/spark/streaming/twitter/JavaTwitterStreamSuite.java index e46b4e5c7531..26ec8af455bc 100644 --- a/external/twitter/src/test/java/org/apache/spark/streaming/twitter/JavaTwitterStreamSuite.java +++ b/external/twitter/src/test/java/org/apache/spark/streaming/twitter/JavaTwitterStreamSuite.java @@ -17,8 +17,6 @@ package org.apache.spark.streaming.twitter; -import java.util.Arrays; - import org.junit.Test; import twitter4j.Status; import twitter4j.auth.Authorization; @@ -30,7 +28,7 @@ public class JavaTwitterStreamSuite extends LocalJavaStreamingContext { @Test public void testTwitterStream() { - String[] filters = (String[])Arrays.asList("filter1", "filter2").toArray(); + String[] filters = { "filter1", "filter2" }; Authorization auth = NullAuthorization.getInstance(); // tests the API, does not actually test data receiving diff --git a/extras/java8-tests/src/test/java/org/apache/spark/Java8APISuite.java b/extras/java8-tests/src/test/java/org/apache/spark/Java8APISuite.java index 729bc0459ce5..14975265ab2c 100644 --- a/extras/java8-tests/src/test/java/org/apache/spark/Java8APISuite.java +++ b/extras/java8-tests/src/test/java/org/apache/spark/Java8APISuite.java @@ -77,7 +77,7 @@ public void call(String s) { public void foreach() { foreachCalls = 0; JavaRDD rdd = sc.parallelize(Arrays.asList("Hello", "World")); - rdd.foreach((x) -> foreachCalls++); + rdd.foreach(x -> foreachCalls++); Assert.assertEquals(2, foreachCalls); } @@ -180,7 +180,7 @@ public void map() { JavaPairRDD pairs = rdd.mapToPair(x -> new Tuple2<>(x, x)) .cache(); pairs.collect(); - JavaRDD strings = rdd.map(x -> x.toString()).cache(); + JavaRDD strings = rdd.map(Object::toString).cache(); strings.collect(); } @@ -195,7 +195,9 @@ public void flatMap() { JavaPairRDD pairs = rdd.flatMapToPair(s -> { List> pairs2 = new LinkedList<>(); - for (String word : s.split(" ")) pairs2.add(new Tuple2<>(word, word)); + for (String word : s.split(" ")) { + pairs2.add(new Tuple2<>(word, word)); + } return pairs2; }); @@ -204,11 +206,12 @@ public void flatMap() { JavaDoubleRDD doubles = rdd.flatMapToDouble(s -> { List lengths = new LinkedList<>(); - for (String word : s.split(" ")) lengths.add(word.length() * 1.0); + for (String word : s.split(" ")) { + lengths.add((double) word.length()); + } return lengths; }); - Double x = doubles.first(); Assert.assertEquals(5.0, doubles.first(), 0.01); Assert.assertEquals(11, pairs.count()); } @@ -228,7 +231,7 @@ public void mapsFromPairsToPairs() { swapped.collect(); // There was never a bug here, but it's worth testing: - pairRDD.map(item -> item.swap()).collect(); + pairRDD.map(Tuple2::swap).collect(); } @Test @@ -282,11 +285,11 @@ public void zipPartitions() { FlatMapFunction2, Iterator, Integer> sizesFn = (Iterator i, Iterator s) -> { int sizeI = 0; - int sizeS = 0; while (i.hasNext()) { sizeI += 1; i.next(); } + int sizeS = 0; while (s.hasNext()) { sizeS += 1; s.next(); @@ -301,30 +304,31 @@ public void zipPartitions() { public void accumulators() { JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5)); - final Accumulator intAccum = sc.intAccumulator(10); - rdd.foreach(x -> intAccum.add(x)); + Accumulator intAccum = sc.intAccumulator(10); + 
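[Editor's note: an illustrative sketch with hypothetical names, not part of the patch, for the keySet() loop rewritten to entrySet() just above in the Kafka stream suite. Each entry carries both key and value, so the loop body avoids a second map lookup per key.]

    import java.util.HashMap;
    import java.util.Map;

    public class EntrySetExample {   // hypothetical name
      public static void main(String[] args) {
        Map<String, Integer> sent = new HashMap<>();
        sent.put("a", 5);
        sent.put("b", 3);

        // Old style: iterate the keys, then look each value up again.
        for (String k : sent.keySet()) {
          System.out.println(k + " -> " + sent.get(k));
        }

        // Preferred: the entry already holds the value.
        for (Map.Entry<String, Integer> e : sent.entrySet()) {
          System.out.println(e.getKey() + " -> " + e.getValue());
        }
      }
    }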
rdd.foreach(intAccum::add); Assert.assertEquals((Integer) 25, intAccum.value()); - final Accumulator doubleAccum = sc.doubleAccumulator(10.0); + Accumulator doubleAccum = sc.doubleAccumulator(10.0); rdd.foreach(x -> doubleAccum.add((double) x)); Assert.assertEquals((Double) 25.0, doubleAccum.value()); // Try a custom accumulator type AccumulatorParam floatAccumulatorParam = new AccumulatorParam() { + @Override public Float addInPlace(Float r, Float t) { return r + t; } - + @Override public Float addAccumulator(Float r, Float t) { return r + t; } - + @Override public Float zero(Float initialValue) { return 0.0f; } }; - final Accumulator floatAccum = sc.accumulator(10.0f, floatAccumulatorParam); + Accumulator floatAccum = sc.accumulator(10.0f, floatAccumulatorParam); rdd.foreach(x -> floatAccum.add((float) x)); Assert.assertEquals((Float) 25.0f, floatAccum.value()); @@ -336,7 +340,7 @@ public Float zero(Float initialValue) { @Test public void keyBy() { JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2)); - List> s = rdd.keyBy(x -> x.toString()).collect(); + List> s = rdd.keyBy(Object::toString).collect(); Assert.assertEquals(new Tuple2<>("1", 1), s.get(0)); Assert.assertEquals(new Tuple2<>("2", 2), s.get(1)); } @@ -349,7 +353,7 @@ public void mapOnPairRDD() { JavaPairRDD rdd3 = rdd2.mapToPair(in -> new Tuple2<>(in._2(), in._1())); Assert.assertEquals(Arrays.asList( - new Tuple2(1, 1), + new Tuple2<>(1, 1), new Tuple2<>(0, 2), new Tuple2<>(1, 3), new Tuple2<>(0, 4)), rdd3.collect()); @@ -361,7 +365,7 @@ public void collectPartitions() { JavaPairRDD rdd2 = rdd1.mapToPair(i -> new Tuple2<>(i, i % 2)); - List[] parts = rdd1.collectPartitions(new int[]{0}); + List[] parts = rdd1.collectPartitions(new int[]{0}); Assert.assertEquals(Arrays.asList(1, 2), parts[0]); parts = rdd1.collectPartitions(new int[]{1, 2}); @@ -371,19 +375,19 @@ public void collectPartitions() { Assert.assertEquals(Arrays.asList(new Tuple2<>(1, 1), new Tuple2<>(2, 0)), rdd2.collectPartitions(new int[]{0})[0]); - parts = rdd2.collectPartitions(new int[]{1, 2}); - Assert.assertEquals(Arrays.asList(new Tuple2<>(3, 1), new Tuple2<>(4, 0)), parts[0]); + List>[] parts2 = rdd2.collectPartitions(new int[]{1, 2}); + Assert.assertEquals(Arrays.asList(new Tuple2<>(3, 1), new Tuple2<>(4, 0)), parts2[0]); Assert.assertEquals(Arrays.asList(new Tuple2<>(5, 1), new Tuple2<>(6, 0), new Tuple2<>(7, 1)), - parts[1]); + parts2[1]); } @Test public void collectAsMapWithIntArrayValues() { // Regression test for SPARK-1040 - JavaRDD rdd = sc.parallelize(Arrays.asList(new Integer[]{1})); + JavaRDD rdd = sc.parallelize(Arrays.asList(1)); JavaPairRDD pairRDD = rdd.mapToPair(x -> new Tuple2<>(x, new int[]{x})); pairRDD.collect(); // Works fine - Map map = pairRDD.collectAsMap(); // Used to crash with ClassCastException + pairRDD.collectAsMap(); // Used to crash with ClassCastException } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java index bf693c7c393f..7b50aad4ad49 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java @@ -18,6 +18,7 @@ package test.org.apache.spark.sql; import java.io.Serializable; +import java.math.BigDecimal; import java.util.ArrayList; import java.util.Arrays; import java.util.List; @@ -83,7 +84,7 @@ public void setAge(int age) { @Test public void applySchema() { - List personList = new ArrayList(2); + 
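[Editor's note: the Java8APISuite hunks around here swap trivial lambdas for method references (x -> x.toString() becomes Object::toString, x -> intAccum.add(x) becomes intAccum::add). The sketch below is illustrative only and uses plain java.util.stream so it stands alone without Spark; names are hypothetical.]

    import java.util.Arrays;
    import java.util.List;
    import java.util.concurrent.atomic.AtomicInteger;
    import java.util.stream.Collectors;

    public class MethodReferenceExample {   // hypothetical name
      public static void main(String[] args) {
        List<Integer> nums = Arrays.asList(1, 2, 3);

        // Unbound instance method reference in place of x -> x.toString().
        List<String> strings = nums.stream()
            .map(Object::toString)
            .collect(Collectors.toList());
        System.out.println(strings);          // [1, 2, 3]

        // Bound instance method reference in place of x -> sum.addAndGet(x);
        // the int return value is simply discarded by the Consumer.
        AtomicInteger sum = new AtomicInteger(0);
        nums.forEach(sum::addAndGet);
        System.out.println(sum.get());        // 6
      }
    }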
List personList = new ArrayList<>(2); Person person1 = new Person(); person1.setName("Michael"); person1.setAge(29); @@ -95,12 +96,13 @@ public void applySchema() { JavaRDD rowRDD = javaCtx.parallelize(personList).map( new Function() { + @Override public Row call(Person person) throws Exception { return RowFactory.create(person.getName(), person.getAge()); } }); - List fields = new ArrayList(2); + List fields = new ArrayList<>(2); fields.add(DataTypes.createStructField("name", DataTypes.StringType, false)); fields.add(DataTypes.createStructField("age", DataTypes.IntegerType, false)); StructType schema = DataTypes.createStructType(fields); @@ -118,7 +120,7 @@ public Row call(Person person) throws Exception { @Test public void dataFrameRDDOperations() { - List personList = new ArrayList(2); + List personList = new ArrayList<>(2); Person person1 = new Person(); person1.setName("Michael"); person1.setAge(29); @@ -129,27 +131,28 @@ public void dataFrameRDDOperations() { personList.add(person2); JavaRDD rowRDD = javaCtx.parallelize(personList).map( - new Function() { - public Row call(Person person) throws Exception { - return RowFactory.create(person.getName(), person.getAge()); - } - }); - - List fields = new ArrayList(2); - fields.add(DataTypes.createStructField("name", DataTypes.StringType, false)); + new Function() { + @Override + public Row call(Person person) { + return RowFactory.create(person.getName(), person.getAge()); + } + }); + + List fields = new ArrayList<>(2); + fields.add(DataTypes.createStructField("", DataTypes.StringType, false)); fields.add(DataTypes.createStructField("age", DataTypes.IntegerType, false)); StructType schema = DataTypes.createStructType(fields); DataFrame df = sqlContext.applySchema(rowRDD, schema); df.registerTempTable("people"); List actual = sqlContext.sql("SELECT * FROM people").toJavaRDD().map(new Function() { - + @Override public String call(Row row) { - return row.getString(0) + "_" + row.get(1).toString(); + return row.getString(0) + "_" + row.get(1); } }).collect(); - List expected = new ArrayList(2); + List expected = new ArrayList<>(2); expected.add("Michael_29"); expected.add("Yin_28"); @@ -165,7 +168,7 @@ public void applySchemaToJSON() { "{\"string\":\"this is another simple string.\", \"integer\":11, \"long\":21474836469, " + "\"bigInteger\":92233720368547758069, \"double\":1.7976931348623157E305, " + "\"boolean\":false, \"null\":null}")); - List fields = new ArrayList(7); + List fields = new ArrayList<>(7); fields.add(DataTypes.createStructField("bigInteger", DataTypes.createDecimalType(20, 0), true)); fields.add(DataTypes.createStructField("boolean", DataTypes.BooleanType, true)); @@ -175,10 +178,10 @@ public void applySchemaToJSON() { fields.add(DataTypes.createStructField("null", DataTypes.StringType, true)); fields.add(DataTypes.createStructField("string", DataTypes.StringType, true)); StructType expectedSchema = DataTypes.createStructType(fields); - List expectedResult = new ArrayList(2); + List expectedResult = new ArrayList<>(2); expectedResult.add( RowFactory.create( - new java.math.BigDecimal("92233720368547758070"), + new BigDecimal("92233720368547758070"), true, 1.7976931348623157E308, 10, @@ -187,7 +190,7 @@ public void applySchemaToJSON() { "this is a simple string.")); expectedResult.add( RowFactory.create( - new java.math.BigDecimal("92233720368547758069"), + new BigDecimal("92233720368547758069"), false, 1.7976931348623157E305, 11, diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java 
b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index 4867cebf5328..d981ce947f43 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -61,7 +61,7 @@ public void tearDown() { @Test public void testExecution() { DataFrame df = context.table("testData").filter("key = 1"); - Assert.assertEquals(df.select("key").collect()[0].get(0), 1); + Assert.assertEquals(1, df.select("key").collect()[0].get(0)); } /** @@ -119,7 +119,7 @@ public void testShow() { public static class Bean implements Serializable { private double a = 0.0; - private Integer[] b = new Integer[]{0, 1}; + private Integer[] b = { 0, 1 }; private Map c = ImmutableMap.of("hello", new int[] { 1, 2 }); private List d = Arrays.asList("floppy", "disk"); @@ -161,7 +161,7 @@ public void testCreateDataFrameFromJavaBeans() { schema.apply("d")); Row first = df.select("a", "b", "c", "d").first(); Assert.assertEquals(bean.getA(), first.getDouble(0), 0.0); - // Now Java lists and maps are converetd to Scala Seq's and Map's. Once we get a Seq below, + // Now Java lists and maps are converted to Scala Seq's and Map's. Once we get a Seq below, // verify that it has the expected length, and contains expected elements. Seq result = first.getAs(1); Assert.assertEquals(bean.getB().length, result.length()); @@ -180,7 +180,8 @@ public void testCreateDataFrameFromJavaBeans() { } } - private static Comparator CrosstabRowComparator = new Comparator() { + private static final Comparator crosstabRowComparator = new Comparator() { + @Override public int compare(Row row1, Row row2) { String item1 = row1.getString(0); String item2 = row2.getString(0); @@ -193,16 +194,16 @@ public void testCrosstab() { DataFrame df = context.table("testData2"); DataFrame crosstab = df.stat().crosstab("a", "b"); String[] columnNames = crosstab.schema().fieldNames(); - Assert.assertEquals(columnNames[0], "a_b"); - Assert.assertEquals(columnNames[1], "1"); - Assert.assertEquals(columnNames[2], "2"); + Assert.assertEquals("a_b", columnNames[0]); + Assert.assertEquals("1", columnNames[1]); + Assert.assertEquals("2", columnNames[2]); Row[] rows = crosstab.collect(); - Arrays.sort(rows, CrosstabRowComparator); + Arrays.sort(rows, crosstabRowComparator); Integer count = 1; for (Row row : rows) { Assert.assertEquals(row.get(0).toString(), count.toString()); - Assert.assertEquals(row.getLong(1), 1L); - Assert.assertEquals(row.getLong(2), 1L); + Assert.assertEquals(1L, row.getLong(1)); + Assert.assertEquals(1L, row.getLong(2)); count++; } } @@ -210,7 +211,7 @@ public void testCrosstab() { @Test public void testFrequentItems() { DataFrame df = context.table("testData2"); - String[] cols = new String[]{"a"}; + String[] cols = {"a"}; DataFrame results = df.stat().freqItems(cols, 0.2); Assert.assertTrue(results.collect()[0].getSeq(0).contains(1)); } @@ -219,14 +220,14 @@ public void testFrequentItems() { public void testCorrelation() { DataFrame df = context.table("testData2"); Double pearsonCorr = df.stat().corr("a", "b", "pearson"); - Assert.assertTrue(Math.abs(pearsonCorr) < 1e-6); + Assert.assertTrue(Math.abs(pearsonCorr) < 1.0e-6); } @Test public void testCovariance() { DataFrame df = context.table("testData2"); Double result = df.stat().cov("a", "b"); - Assert.assertTrue(Math.abs(result) < 1e-6); + Assert.assertTrue(Math.abs(result) < 1.0e-6); } @Test @@ -234,7 +235,7 @@ public void testSampleBy() { DataFrame df = context.range(0, 100, 1, 
2).select(col("id").mod(3).as("key")); DataFrame sampled = df.stat().sampleBy("key", ImmutableMap.of(0, 0.1, 1, 0.2), 0L); Row[] actual = sampled.groupBy("key").count().orderBy("key").collect(); - Row[] expected = new Row[] {RowFactory.create(0, 5), RowFactory.create(1, 8)}; + Row[] expected = {RowFactory.create(0, 5), RowFactory.create(1, 8)}; Assert.assertArrayEquals(expected, actual); } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaRowSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaRowSuite.java index 4ce1d1dddb26..3ab4db2a035d 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaRowSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaRowSuite.java @@ -18,6 +18,7 @@ package test.org.apache.spark.sql; import java.math.BigDecimal; +import java.nio.charset.StandardCharsets; import java.sql.Date; import java.sql.Timestamp; import java.util.Arrays; @@ -52,12 +53,12 @@ public void setUp() { shortValue = (short)32767; intValue = 2147483647; longValue = 9223372036854775807L; - floatValue = (float)3.4028235E38; + floatValue = 3.4028235E38f; doubleValue = 1.7976931348623157E308; decimalValue = new BigDecimal("1.7976931348623157E328"); booleanValue = true; stringValue = "this is a string"; - binaryValue = stringValue.getBytes(); + binaryValue = stringValue.getBytes(StandardCharsets.UTF_8); dateValue = Date.valueOf("2014-06-30"); timestampValue = Timestamp.valueOf("2014-06-30 09:20:00.0"); } @@ -123,8 +124,8 @@ public void constructSimpleRow() { Assert.assertEquals(binaryValue, simpleRow.get(16)); Assert.assertEquals(dateValue, simpleRow.get(17)); Assert.assertEquals(timestampValue, simpleRow.get(18)); - Assert.assertEquals(true, simpleRow.isNullAt(19)); - Assert.assertEquals(null, simpleRow.get(19)); + Assert.assertTrue(simpleRow.isNullAt(19)); + Assert.assertNull(simpleRow.get(19)); } @Test @@ -134,7 +135,7 @@ public void constructComplexRow() { stringValue + " (1)", stringValue + " (2)", stringValue + "(3)"); // Simple map - Map simpleMap = new HashMap(); + Map simpleMap = new HashMap<>(); simpleMap.put(stringValue + " (1)", longValue); simpleMap.put(stringValue + " (2)", longValue - 1); simpleMap.put(stringValue + " (3)", longValue - 2); @@ -149,7 +150,7 @@ public void constructComplexRow() { List arrayOfRows = Arrays.asList(simpleStruct); // Complex map - Map, Row> complexMap = new HashMap, Row>(); + Map, Row> complexMap = new HashMap<>(); complexMap.put(arrayOfRows, simpleStruct); // Complex struct @@ -167,7 +168,7 @@ public void constructComplexRow() { Assert.assertEquals(arrayOfMaps, complexStruct.get(3)); Assert.assertEquals(arrayOfRows, complexStruct.get(4)); Assert.assertEquals(complexMap, complexStruct.get(5)); - Assert.assertEquals(null, complexStruct.get(6)); + Assert.assertNull(complexStruct.get(6)); // A very complex row Row complexRow = RowFactory.create(arrayOfMaps, arrayOfRows, complexMap, complexStruct); diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java index bb02b58cca9b..4a78dca7fea6 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java @@ -20,6 +20,7 @@ import java.io.Serializable; import org.junit.After; +import org.junit.Assert; import org.junit.Before; import org.junit.Test; @@ -61,13 +62,13 @@ public void udf1Test() { sqlContext.udf().register("stringLengthTest", new UDF1() { @Override - public Integer 
call(String str) throws Exception { + public Integer call(String str) { return str.length(); } }, DataTypes.IntegerType); Row result = sqlContext.sql("SELECT stringLengthTest('test')").head(); - assert(result.getInt(0) == 4); + Assert.assertEquals(4, result.getInt(0)); } @SuppressWarnings("unchecked") @@ -81,12 +82,12 @@ public void udf2Test() { sqlContext.udf().register("stringLengthTest", new UDF2() { @Override - public Integer call(String str1, String str2) throws Exception { + public Integer call(String str1, String str2) { return str1.length() + str2.length(); } }, DataTypes.IntegerType); Row result = sqlContext.sql("SELECT stringLengthTest('test', 'test2')").head(); - assert(result.getInt(0) == 9); + Assert.assertEquals(9, result.getInt(0)); } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java index 6f9e7f68dc39..9e241f20987c 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java @@ -44,7 +44,7 @@ public class JavaSaveLoadSuite { File path; DataFrame df; - private void checkAnswer(DataFrame actual, List expected) { + private static void checkAnswer(DataFrame actual, List expected) { String errorMessage = QueryTest$.MODULE$.checkAnswer(actual, expected); if (errorMessage != null) { Assert.fail(errorMessage); @@ -64,7 +64,7 @@ public void setUp() throws IOException { path.delete(); } - List jsonObjects = new ArrayList(10); + List jsonObjects = new ArrayList<>(10); for (int i = 0; i < 10; i++) { jsonObjects.add("{\"a\":" + i + ", \"b\":\"str" + i + "\"}"); } @@ -82,7 +82,7 @@ public void tearDown() { @Test public void saveAndLoad() { - Map options = new HashMap(); + Map options = new HashMap<>(); options.put("path", path.toString()); df.write().mode(SaveMode.ErrorIfExists).format("json").options(options).save(); DataFrame loadedDF = sqlContext.read().format("json").options(options).load(); @@ -91,11 +91,11 @@ public void saveAndLoad() { @Test public void saveAndLoadWithSchema() { - Map options = new HashMap(); + Map options = new HashMap<>(); options.put("path", path.toString()); df.write().format("json").mode(SaveMode.ErrorIfExists).options(options).save(); - List fields = new ArrayList(); + List fields = new ArrayList<>(); fields.add(DataTypes.createStructField("b", DataTypes.StringType, true)); StructType schema = DataTypes.createStructType(fields); DataFrame loadedDF = sqlContext.read().format("json").schema(schema).options(options).load(); diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaDataFrameSuite.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaDataFrameSuite.java index 019d8a30266e..b4bf9eef8fca 100644 --- a/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaDataFrameSuite.java +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaDataFrameSuite.java @@ -40,7 +40,7 @@ public class JavaDataFrameSuite { DataFrame df; - private void checkAnswer(DataFrame actual, List expected) { + private static void checkAnswer(DataFrame actual, List expected) { String errorMessage = QueryTest$.MODULE$.checkAnswer(actual, expected); if (errorMessage != null) { Assert.fail(errorMessage); @@ -52,7 +52,7 @@ public void setUp() throws IOException { hc = TestHive$.MODULE$; sc = new JavaSparkContext(hc.sparkContext()); - List jsonObjects = new ArrayList(10); + List jsonObjects = new ArrayList<>(10); for (int i = 0; i < 
10; i++) { jsonObjects.add("{\"key\":" + i + ", \"value\":\"str" + i + "\"}"); } @@ -71,7 +71,7 @@ public void tearDown() throws IOException { @Test public void saveTableAndQueryIt() { checkAnswer( - df.select(functions.avg("key").over( + df.select(avg("key").over( Window.partitionBy("value").orderBy("key").rowsBetween(-1, 1))), hc.sql("SELECT avg(key) " + "OVER (PARTITION BY value " + @@ -95,7 +95,7 @@ public void testUDAF() { registeredUDAF.apply(col("value")), callUDF("mydoublesum", col("value"))); - List expectedResult = new ArrayList(); + List expectedResult = new ArrayList<>(); expectedResult.add(RowFactory.create(4950.0, 9900.0, 9900.0, 9900.0)); checkAnswer( aggregatedDF, diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java index 4192155975c4..c8d272794d10 100644 --- a/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java @@ -53,7 +53,7 @@ public class JavaMetastoreDataSourcesSuite { FileSystem fs; DataFrame df; - private void checkAnswer(DataFrame actual, List expected) { + private static void checkAnswer(DataFrame actual, List expected) { String errorMessage = QueryTest$.MODULE$.checkAnswer(actual, expected); if (errorMessage != null) { Assert.fail(errorMessage); @@ -77,7 +77,7 @@ public void setUp() throws IOException { fs.delete(hiveManagedPath, true); } - List jsonObjects = new ArrayList(10); + List jsonObjects = new ArrayList<>(10); for (int i = 0; i < 10; i++) { jsonObjects.add("{\"a\":" + i + ", \"b\":\"str" + i + "\"}"); } @@ -97,7 +97,7 @@ public void tearDown() throws IOException { @Test public void saveExternalTableAndQueryIt() { - Map options = new HashMap(); + Map options = new HashMap<>(); options.put("path", path.toString()); df.write() .format("org.apache.spark.sql.json") @@ -120,7 +120,7 @@ public void saveExternalTableAndQueryIt() { @Test public void saveExternalTableWithSchemaAndQueryIt() { - Map options = new HashMap(); + Map options = new HashMap<>(); options.put("path", path.toString()); df.write() .format("org.apache.spark.sql.json") @@ -132,7 +132,7 @@ public void saveExternalTableWithSchemaAndQueryIt() { sqlContext.sql("SELECT * FROM javaSavedTable"), df.collectAsList()); - List fields = new ArrayList(); + List fields = new ArrayList<>(); fields.add(DataTypes.createStructField("b", DataTypes.StringType, true)); StructType schema = DataTypes.createStructType(fields); DataFrame loadedDF = @@ -148,7 +148,7 @@ public void saveExternalTableWithSchemaAndQueryIt() { @Test public void saveTableAndQueryIt() { - Map options = new HashMap(); + Map options = new HashMap<>(); df.write() .format("org.apache.spark.sql.json") .mode(SaveMode.Append) diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java index e0718f73aa13..c5217149224e 100644 --- a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java +++ b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java @@ -18,24 +18,22 @@ package org.apache.spark.streaming; import java.io.*; -import java.lang.Iterable; import java.nio.charset.Charset; import java.util.*; import java.util.concurrent.atomic.AtomicBoolean; +import scala.Tuple2; + +import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.Path; import 
org.apache.hadoop.io.LongWritable; import org.apache.hadoop.io.Text; import org.apache.hadoop.mapreduce.lib.input.TextInputFormat; -import scala.Tuple2; - import org.junit.Assert; -import static org.junit.Assert.*; import org.junit.Test; import com.google.common.base.Optional; -import com.google.common.collect.Lists; import com.google.common.io.Files; import com.google.common.collect.Sets; @@ -54,14 +52,14 @@ // see http://stackoverflow.com/questions/758570/. public class JavaAPISuite extends LocalJavaStreamingContext implements Serializable { - public void equalIterator(Iterator a, Iterator b) { + public static void equalIterator(Iterator a, Iterator b) { while (a.hasNext() && b.hasNext()) { Assert.assertEquals(a.next(), b.next()); } Assert.assertEquals(a.hasNext(), b.hasNext()); } - public void equalIterable(Iterable a, Iterable b) { + public static void equalIterable(Iterable a, Iterable b) { equalIterator(a.iterator(), b.iterator()); } @@ -74,14 +72,14 @@ public void testInitialization() { @Test public void testContextState() { List> inputData = Arrays.asList(Arrays.asList(1, 2, 3, 4)); - Assert.assertTrue(ssc.getState() == StreamingContextState.INITIALIZED); + Assert.assertEquals(StreamingContextState.INITIALIZED, ssc.getState()); JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaTestUtils.attachTestOutputStream(stream); - Assert.assertTrue(ssc.getState() == StreamingContextState.INITIALIZED); + Assert.assertEquals(StreamingContextState.INITIALIZED, ssc.getState()); ssc.start(); - Assert.assertTrue(ssc.getState() == StreamingContextState.ACTIVE); + Assert.assertEquals(StreamingContextState.ACTIVE, ssc.getState()); ssc.stop(); - Assert.assertTrue(ssc.getState() == StreamingContextState.STOPPED); + Assert.assertEquals(StreamingContextState.STOPPED, ssc.getState()); } @SuppressWarnings("unchecked") @@ -118,7 +116,7 @@ public void testMap() { JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaDStream letterCount = stream.map(new Function() { @Override - public Integer call(String s) throws Exception { + public Integer call(String s) { return s.length(); } }); @@ -180,7 +178,7 @@ public void testWindowWithSlideDuration() { public void testFilter() { List> inputData = Arrays.asList( Arrays.asList("giants", "dodgers"), - Arrays.asList("yankees", "red socks")); + Arrays.asList("yankees", "red sox")); List> expected = Arrays.asList( Arrays.asList("giants"), @@ -189,7 +187,7 @@ public void testFilter() { JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaDStream filtered = stream.filter(new Function() { @Override - public Boolean call(String s) throws Exception { + public Boolean call(String s) { return s.contains("a"); } }); @@ -243,11 +241,11 @@ public void testRepartitionFewerPartitions() { public void testGlom() { List> inputData = Arrays.asList( Arrays.asList("giants", "dodgers"), - Arrays.asList("yankees", "red socks")); + Arrays.asList("yankees", "red sox")); List>> expected = Arrays.asList( Arrays.asList(Arrays.asList("giants", "dodgers")), - Arrays.asList(Arrays.asList("yankees", "red socks"))); + Arrays.asList(Arrays.asList("yankees", "red sox"))); JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaDStream> glommed = stream.glom(); @@ -262,22 +260,22 @@ public void testGlom() { public void testMapPartitions() { List> inputData = Arrays.asList( Arrays.asList("giants", "dodgers"), - Arrays.asList("yankees", "red socks")); + Arrays.asList("yankees", "red sox")); List> 
expected = Arrays.asList( Arrays.asList("GIANTSDODGERS"), - Arrays.asList("YANKEESRED SOCKS")); + Arrays.asList("YANKEESRED SOX")); JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaDStream mapped = stream.mapPartitions( new FlatMapFunction, String>() { @Override public Iterable call(Iterator in) { - String out = ""; + StringBuilder out = new StringBuilder(); while (in.hasNext()) { - out = out + in.next().toUpperCase(); + out.append(in.next().toUpperCase(Locale.ENGLISH)); } - return Lists.newArrayList(out); + return Arrays.asList(out.toString()); } }); JavaTestUtils.attachTestOutputStream(mapped); @@ -286,16 +284,16 @@ public Iterable call(Iterator in) { Assert.assertEquals(expected, result); } - private class IntegerSum implements Function2 { + private static class IntegerSum implements Function2 { @Override - public Integer call(Integer i1, Integer i2) throws Exception { + public Integer call(Integer i1, Integer i2) { return i1 + i2; } } - private class IntegerDifference implements Function2 { + private static class IntegerDifference implements Function2 { @Override - public Integer call(Integer i1, Integer i2) throws Exception { + public Integer call(Integer i1, Integer i2) { return i1 - i2; } } @@ -347,13 +345,13 @@ private void testReduceByWindow(boolean withInverse) { Arrays.asList(24)); JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); - JavaDStream reducedWindowed = null; + JavaDStream reducedWindowed; if (withInverse) { reducedWindowed = stream.reduceByWindow(new IntegerSum(), - new IntegerDifference(), new Duration(2000), new Duration(1000)); + new IntegerDifference(), new Duration(2000), new Duration(1000)); } else { reducedWindowed = stream.reduceByWindow(new IntegerSum(), - new Duration(2000), new Duration(1000)); + new Duration(2000), new Duration(1000)); } JavaTestUtils.attachTestOutputStream(reducedWindowed); List> result = JavaTestUtils.runStreams(ssc, 4, 4); @@ -378,11 +376,11 @@ public void testQueueStream() { Arrays.asList(7,8,9)); JavaSparkContext jsc = new JavaSparkContext(ssc.ssc().sc()); - JavaRDD rdd1 = ssc.sparkContext().parallelize(Arrays.asList(1, 2, 3)); - JavaRDD rdd2 = ssc.sparkContext().parallelize(Arrays.asList(4, 5, 6)); - JavaRDD rdd3 = ssc.sparkContext().parallelize(Arrays.asList(7,8,9)); + JavaRDD rdd1 = jsc.parallelize(Arrays.asList(1, 2, 3)); + JavaRDD rdd2 = jsc.parallelize(Arrays.asList(4, 5, 6)); + JavaRDD rdd3 = jsc.parallelize(Arrays.asList(7,8,9)); - LinkedList> rdds = Lists.newLinkedList(); + Queue> rdds = new LinkedList<>(); rdds.add(rdd1); rdds.add(rdd2); rdds.add(rdd3); @@ -410,10 +408,10 @@ public void testTransform() { JavaDStream transformed = stream.transform( new Function, JavaRDD>() { @Override - public JavaRDD call(JavaRDD in) throws Exception { + public JavaRDD call(JavaRDD in) { return in.map(new Function() { @Override - public Integer call(Integer i) throws Exception { + public Integer call(Integer i) { return i + 2; } }); @@ -435,70 +433,70 @@ public void testVariousTransform() { JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); List>> pairInputData = - Arrays.asList(Arrays.asList(new Tuple2("x", 1))); + Arrays.asList(Arrays.asList(new Tuple2<>("x", 1))); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream( JavaTestUtils.attachTestInputStream(ssc, pairInputData, 1)); - JavaDStream transformed1 = stream.transform( + stream.transform( new Function, JavaRDD>() { @Override - public JavaRDD call(JavaRDD in) throws Exception { + public JavaRDD 
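[Editor's note: an illustrative sketch with hypothetical names for the mapPartitions hunk just above, which replaces string concatenation in a loop with a StringBuilder and pins the locale of toUpperCase. Concatenation copies the accumulated string on every iteration; the explicit Locale avoids surprises under locales such as Turkish, where "i".toUpperCase() is not "I".]

    import java.util.Arrays;
    import java.util.Iterator;
    import java.util.Locale;

    public class UpperCaseConcatExample {   // hypothetical name
      static String concatUpper(Iterator<String> in) {
        StringBuilder out = new StringBuilder();
        while (in.hasNext()) {
          // Append in place instead of rebuilding the string each pass.
          out.append(in.next().toUpperCase(Locale.ENGLISH));
        }
        return out.toString();
      }

      public static void main(String[] args) {
        Iterator<String> words = Arrays.asList("giants", "dodgers").iterator();
        System.out.println(concatUpper(words));   // GIANTSDODGERS
      }
    }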
call(JavaRDD in) { return null; } } ); - JavaDStream transformed2 = stream.transform( + stream.transform( new Function2, Time, JavaRDD>() { - @Override public JavaRDD call(JavaRDD in, Time time) throws Exception { + @Override public JavaRDD call(JavaRDD in, Time time) { return null; } } ); - JavaPairDStream transformed3 = stream.transformToPair( + stream.transformToPair( new Function, JavaPairRDD>() { - @Override public JavaPairRDD call(JavaRDD in) throws Exception { + @Override public JavaPairRDD call(JavaRDD in) { return null; } } ); - JavaPairDStream transformed4 = stream.transformToPair( + stream.transformToPair( new Function2, Time, JavaPairRDD>() { - @Override public JavaPairRDD call(JavaRDD in, Time time) throws Exception { + @Override public JavaPairRDD call(JavaRDD in, Time time) { return null; } } ); - JavaDStream pairTransformed1 = pairStream.transform( + pairStream.transform( new Function, JavaRDD>() { - @Override public JavaRDD call(JavaPairRDD in) throws Exception { + @Override public JavaRDD call(JavaPairRDD in) { return null; } } ); - JavaDStream pairTransformed2 = pairStream.transform( + pairStream.transform( new Function2, Time, JavaRDD>() { - @Override public JavaRDD call(JavaPairRDD in, Time time) throws Exception { + @Override public JavaRDD call(JavaPairRDD in, Time time) { return null; } } ); - JavaPairDStream pairTransformed3 = pairStream.transformToPair( + pairStream.transformToPair( new Function, JavaPairRDD>() { - @Override public JavaPairRDD call(JavaPairRDD in) throws Exception { + @Override public JavaPairRDD call(JavaPairRDD in) { return null; } } ); - JavaPairDStream pairTransformed4 = pairStream.transformToPair( + pairStream.transformToPair( new Function2, Time, JavaPairRDD>() { - @Override public JavaPairRDD call(JavaPairRDD in, Time time) throws Exception { + @Override public JavaPairRDD call(JavaPairRDD in, Time time) { return null; } } @@ -511,32 +509,32 @@ public JavaRDD call(JavaRDD in) throws Exception { public void testTransformWith() { List>> stringStringKVStream1 = Arrays.asList( Arrays.asList( - new Tuple2("california", "dodgers"), - new Tuple2("new york", "yankees")), + new Tuple2<>("california", "dodgers"), + new Tuple2<>("new york", "yankees")), Arrays.asList( - new Tuple2("california", "sharks"), - new Tuple2("new york", "rangers"))); + new Tuple2<>("california", "sharks"), + new Tuple2<>("new york", "rangers"))); List>> stringStringKVStream2 = Arrays.asList( Arrays.asList( - new Tuple2("california", "giants"), - new Tuple2("new york", "mets")), + new Tuple2<>("california", "giants"), + new Tuple2<>("new york", "mets")), Arrays.asList( - new Tuple2("california", "ducks"), - new Tuple2("new york", "islanders"))); + new Tuple2<>("california", "ducks"), + new Tuple2<>("new york", "islanders"))); List>>> expected = Arrays.asList( Sets.newHashSet( - new Tuple2>("california", - new Tuple2("dodgers", "giants")), - new Tuple2>("new york", - new Tuple2("yankees", "mets"))), + new Tuple2<>("california", + new Tuple2<>("dodgers", "giants")), + new Tuple2<>("new york", + new Tuple2<>("yankees", "mets"))), Sets.newHashSet( - new Tuple2>("california", - new Tuple2("sharks", "ducks")), - new Tuple2>("new york", - new Tuple2("rangers", "islanders")))); + new Tuple2<>("california", + new Tuple2<>("sharks", "ducks")), + new Tuple2<>("new york", + new Tuple2<>("rangers", "islanders")))); JavaDStream> stream1 = JavaTestUtils.attachTestInputStream( ssc, stringStringKVStream1, 1); @@ -552,14 +550,12 @@ public void testTransformWith() { JavaPairRDD, JavaPairRDD, 
Time, - JavaPairRDD> - >() { + JavaPairRDD>>() { @Override public JavaPairRDD> call( JavaPairRDD rdd1, JavaPairRDD rdd2, - Time time - ) throws Exception { + Time time) { return rdd1.join(rdd2); } } @@ -567,9 +563,9 @@ public JavaPairRDD> call( JavaTestUtils.attachTestOutputStream(joined); List>>> result = JavaTestUtils.runStreams(ssc, 2, 2); - List>>> unorderedResult = Lists.newArrayList(); + List>>> unorderedResult = new ArrayList<>(); for (List>> res: result) { - unorderedResult.add(Sets.newHashSet(res)); + unorderedResult.add(Sets.newHashSet(res)); } Assert.assertEquals(expected, unorderedResult); @@ -587,89 +583,89 @@ public void testVariousTransformWith() { JavaDStream stream2 = JavaTestUtils.attachTestInputStream(ssc, inputData2, 1); List>> pairInputData1 = - Arrays.asList(Arrays.asList(new Tuple2("x", 1))); + Arrays.asList(Arrays.asList(new Tuple2<>("x", 1))); List>> pairInputData2 = - Arrays.asList(Arrays.asList(new Tuple2(1.0, 'x'))); + Arrays.asList(Arrays.asList(new Tuple2<>(1.0, 'x'))); JavaPairDStream pairStream1 = JavaPairDStream.fromJavaDStream( JavaTestUtils.attachTestInputStream(ssc, pairInputData1, 1)); JavaPairDStream pairStream2 = JavaPairDStream.fromJavaDStream( JavaTestUtils.attachTestInputStream(ssc, pairInputData2, 1)); - JavaDStream transformed1 = stream1.transformWith( + stream1.transformWith( stream2, new Function3, JavaRDD, Time, JavaRDD>() { @Override - public JavaRDD call(JavaRDD rdd1, JavaRDD rdd2, Time time) throws Exception { + public JavaRDD call(JavaRDD rdd1, JavaRDD rdd2, Time time) { return null; } } ); - JavaDStream transformed2 = stream1.transformWith( + stream1.transformWith( pairStream1, new Function3, JavaPairRDD, Time, JavaRDD>() { @Override - public JavaRDD call(JavaRDD rdd1, JavaPairRDD rdd2, Time time) throws Exception { + public JavaRDD call(JavaRDD rdd1, JavaPairRDD rdd2, Time time) { return null; } } ); - JavaPairDStream transformed3 = stream1.transformWithToPair( + stream1.transformWithToPair( stream2, new Function3, JavaRDD, Time, JavaPairRDD>() { @Override - public JavaPairRDD call(JavaRDD rdd1, JavaRDD rdd2, Time time) throws Exception { + public JavaPairRDD call(JavaRDD rdd1, JavaRDD rdd2, Time time) { return null; } } ); - JavaPairDStream transformed4 = stream1.transformWithToPair( + stream1.transformWithToPair( pairStream1, new Function3, JavaPairRDD, Time, JavaPairRDD>() { @Override - public JavaPairRDD call(JavaRDD rdd1, JavaPairRDD rdd2, Time time) throws Exception { + public JavaPairRDD call(JavaRDD rdd1, JavaPairRDD rdd2, Time time) { return null; } } ); - JavaDStream pairTransformed1 = pairStream1.transformWith( + pairStream1.transformWith( stream2, new Function3, JavaRDD, Time, JavaRDD>() { @Override - public JavaRDD call(JavaPairRDD rdd1, JavaRDD rdd2, Time time) throws Exception { + public JavaRDD call(JavaPairRDD rdd1, JavaRDD rdd2, Time time) { return null; } } ); - JavaDStream pairTransformed2_ = pairStream1.transformWith( + pairStream1.transformWith( pairStream1, new Function3, JavaPairRDD, Time, JavaRDD>() { @Override - public JavaRDD call(JavaPairRDD rdd1, JavaPairRDD rdd2, Time time) throws Exception { + public JavaRDD call(JavaPairRDD rdd1, JavaPairRDD rdd2, Time time) { return null; } } ); - JavaPairDStream pairTransformed3 = pairStream1.transformWithToPair( + pairStream1.transformWithToPair( stream2, new Function3, JavaRDD, Time, JavaPairRDD>() { @Override - public JavaPairRDD call(JavaPairRDD rdd1, JavaRDD rdd2, Time time) throws Exception { + public JavaPairRDD call(JavaPairRDD rdd1, JavaRDD rdd2, Time time) { 
return null; } } ); - JavaPairDStream pairTransformed4 = pairStream1.transformWithToPair( + pairStream1.transformWithToPair( pairStream2, new Function3, JavaPairRDD, Time, JavaPairRDD>() { @Override - public JavaPairRDD call(JavaPairRDD rdd1, JavaPairRDD rdd2, Time time) throws Exception { + public JavaPairRDD call(JavaPairRDD rdd1, JavaPairRDD rdd2, Time time) { return null; } } @@ -690,13 +686,13 @@ public void testStreamingContextTransform(){ ); List>> pairStream1input = Arrays.asList( - Arrays.asList(new Tuple2(1, "x")), - Arrays.asList(new Tuple2(2, "y")) + Arrays.asList(new Tuple2<>(1, "x")), + Arrays.asList(new Tuple2<>(2, "y")) ); List>>> expected = Arrays.asList( - Arrays.asList(new Tuple2>(1, new Tuple2(1, "x"))), - Arrays.asList(new Tuple2>(2, new Tuple2(2, "y"))) + Arrays.asList(new Tuple2<>(1, new Tuple2<>(1, "x"))), + Arrays.asList(new Tuple2<>(2, new Tuple2<>(2, "y"))) ); JavaDStream stream1 = JavaTestUtils.attachTestInputStream(ssc, stream1input, 1); @@ -707,7 +703,7 @@ public void testStreamingContextTransform(){ List> listOfDStreams1 = Arrays.>asList(stream1, stream2); // This is just to test whether this transform to JavaStream compiles - JavaDStream transformed1 = ssc.transform( + ssc.transform( listOfDStreams1, new Function2>, Time, JavaRDD>() { @Override @@ -733,8 +729,8 @@ public JavaPairRDD> call(List> listO JavaPairRDD prdd3 = JavaPairRDD.fromJavaRDD(rdd3); PairFunction mapToTuple = new PairFunction() { @Override - public Tuple2 call(Integer i) throws Exception { - return new Tuple2(i, i); + public Tuple2 call(Integer i) { + return new Tuple2<>(i, i); } }; return rdd1.union(rdd2).mapToPair(mapToTuple).join(prdd3); @@ -763,7 +759,7 @@ public void testFlatMap() { JavaDStream flatMapped = stream.flatMap(new FlatMapFunction() { @Override public Iterable call(String x) { - return Lists.newArrayList(x.split("(?!^)")); + return Arrays.asList(x.split("(?!^)")); } }); JavaTestUtils.attachTestOutputStream(flatMapped); @@ -782,39 +778,39 @@ public void testPairFlatMap() { List>> expected = Arrays.asList( Arrays.asList( - new Tuple2(6, "g"), - new Tuple2(6, "i"), - new Tuple2(6, "a"), - new Tuple2(6, "n"), - new Tuple2(6, "t"), - new Tuple2(6, "s")), + new Tuple2<>(6, "g"), + new Tuple2<>(6, "i"), + new Tuple2<>(6, "a"), + new Tuple2<>(6, "n"), + new Tuple2<>(6, "t"), + new Tuple2<>(6, "s")), Arrays.asList( - new Tuple2(7, "d"), - new Tuple2(7, "o"), - new Tuple2(7, "d"), - new Tuple2(7, "g"), - new Tuple2(7, "e"), - new Tuple2(7, "r"), - new Tuple2(7, "s")), + new Tuple2<>(7, "d"), + new Tuple2<>(7, "o"), + new Tuple2<>(7, "d"), + new Tuple2<>(7, "g"), + new Tuple2<>(7, "e"), + new Tuple2<>(7, "r"), + new Tuple2<>(7, "s")), Arrays.asList( - new Tuple2(9, "a"), - new Tuple2(9, "t"), - new Tuple2(9, "h"), - new Tuple2(9, "l"), - new Tuple2(9, "e"), - new Tuple2(9, "t"), - new Tuple2(9, "i"), - new Tuple2(9, "c"), - new Tuple2(9, "s"))); + new Tuple2<>(9, "a"), + new Tuple2<>(9, "t"), + new Tuple2<>(9, "h"), + new Tuple2<>(9, "l"), + new Tuple2<>(9, "e"), + new Tuple2<>(9, "t"), + new Tuple2<>(9, "i"), + new Tuple2<>(9, "c"), + new Tuple2<>(9, "s"))); JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream flatMapped = stream.flatMapToPair( new PairFlatMapFunction() { @Override - public Iterable> call(String in) throws Exception { - List> out = Lists.newArrayList(); + public Iterable> call(String in) { + List> out = new ArrayList<>(); for (String letter: in.split("(?!^)")) { - out.add(new Tuple2(in.length(), letter)); + out.add(new 
Tuple2<>(in.length(), letter)); } return out; } @@ -859,13 +855,13 @@ public void testUnion() { */ public static void assertOrderInvariantEquals( List> expected, List> actual) { - List> expectedSets = new ArrayList>(); + List> expectedSets = new ArrayList<>(); for (List list: expected) { - expectedSets.add(Collections.unmodifiableSet(new HashSet(list))); + expectedSets.add(Collections.unmodifiableSet(new HashSet<>(list))); } - List> actualSets = new ArrayList>(); + List> actualSets = new ArrayList<>(); for (List list: actual) { - actualSets.add(Collections.unmodifiableSet(new HashSet(list))); + actualSets.add(Collections.unmodifiableSet(new HashSet<>(list))); } Assert.assertEquals(expectedSets, actualSets); } @@ -877,25 +873,25 @@ public static void assertOrderInvariantEquals( public void testPairFilter() { List> inputData = Arrays.asList( Arrays.asList("giants", "dodgers"), - Arrays.asList("yankees", "red socks")); + Arrays.asList("yankees", "red sox")); List>> expected = Arrays.asList( - Arrays.asList(new Tuple2("giants", 6)), - Arrays.asList(new Tuple2("yankees", 7))); + Arrays.asList(new Tuple2<>("giants", 6)), + Arrays.asList(new Tuple2<>("yankees", 7))); JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream pairStream = stream.mapToPair( new PairFunction() { @Override - public Tuple2 call(String in) throws Exception { - return new Tuple2(in, in.length()); + public Tuple2 call(String in) { + return new Tuple2<>(in, in.length()); } }); JavaPairDStream filtered = pairStream.filter( new Function, Boolean>() { @Override - public Boolean call(Tuple2 in) throws Exception { + public Boolean call(Tuple2 in) { return in._1().contains("a"); } }); @@ -906,28 +902,28 @@ public Boolean call(Tuple2 in) throws Exception { } @SuppressWarnings("unchecked") - private List>> stringStringKVStream = Arrays.asList( - Arrays.asList(new Tuple2("california", "dodgers"), - new Tuple2("california", "giants"), - new Tuple2("new york", "yankees"), - new Tuple2("new york", "mets")), - Arrays.asList(new Tuple2("california", "sharks"), - new Tuple2("california", "ducks"), - new Tuple2("new york", "rangers"), - new Tuple2("new york", "islanders"))); + private final List>> stringStringKVStream = Arrays.asList( + Arrays.asList(new Tuple2<>("california", "dodgers"), + new Tuple2<>("california", "giants"), + new Tuple2<>("new york", "yankees"), + new Tuple2<>("new york", "mets")), + Arrays.asList(new Tuple2<>("california", "sharks"), + new Tuple2<>("california", "ducks"), + new Tuple2<>("new york", "rangers"), + new Tuple2<>("new york", "islanders"))); @SuppressWarnings("unchecked") - private List>> stringIntKVStream = Arrays.asList( + private final List>> stringIntKVStream = Arrays.asList( Arrays.asList( - new Tuple2("california", 1), - new Tuple2("california", 3), - new Tuple2("new york", 4), - new Tuple2("new york", 1)), + new Tuple2<>("california", 1), + new Tuple2<>("california", 3), + new Tuple2<>("new york", 4), + new Tuple2<>("new york", 1)), Arrays.asList( - new Tuple2("california", 5), - new Tuple2("california", 5), - new Tuple2("new york", 3), - new Tuple2("new york", 1))); + new Tuple2<>("california", 5), + new Tuple2<>("california", 5), + new Tuple2<>("new york", 3), + new Tuple2<>("new york", 1))); @SuppressWarnings("unchecked") @Test @@ -936,22 +932,22 @@ public void testPairMap() { // Maps pair -> pair of different type List>> expected = Arrays.asList( Arrays.asList( - new Tuple2(1, "california"), - new Tuple2(3, "california"), - new Tuple2(4, "new york"), - new 
Tuple2(1, "new york")), + new Tuple2<>(1, "california"), + new Tuple2<>(3, "california"), + new Tuple2<>(4, "new york"), + new Tuple2<>(1, "new york")), Arrays.asList( - new Tuple2(5, "california"), - new Tuple2(5, "california"), - new Tuple2(3, "new york"), - new Tuple2(1, "new york"))); + new Tuple2<>(5, "california"), + new Tuple2<>(5, "california"), + new Tuple2<>(3, "new york"), + new Tuple2<>(1, "new york"))); JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); JavaPairDStream reversed = pairStream.mapToPair( new PairFunction, Integer, String>() { @Override - public Tuple2 call(Tuple2 in) throws Exception { + public Tuple2 call(Tuple2 in) { return in.swap(); } }); @@ -969,23 +965,23 @@ public void testPairMapPartitions() { // Maps pair -> pair of different type List>> expected = Arrays.asList( Arrays.asList( - new Tuple2(1, "california"), - new Tuple2(3, "california"), - new Tuple2(4, "new york"), - new Tuple2(1, "new york")), + new Tuple2<>(1, "california"), + new Tuple2<>(3, "california"), + new Tuple2<>(4, "new york"), + new Tuple2<>(1, "new york")), Arrays.asList( - new Tuple2(5, "california"), - new Tuple2(5, "california"), - new Tuple2(3, "new york"), - new Tuple2(1, "new york"))); + new Tuple2<>(5, "california"), + new Tuple2<>(5, "california"), + new Tuple2<>(3, "new york"), + new Tuple2<>(1, "new york"))); JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); JavaPairDStream reversed = pairStream.mapPartitionsToPair( new PairFlatMapFunction>, Integer, String>() { @Override - public Iterable> call(Iterator> in) throws Exception { - LinkedList> out = new LinkedList>(); + public Iterable> call(Iterator> in) { + List> out = new LinkedList<>(); while (in.hasNext()) { Tuple2 next = in.next(); out.add(next.swap()); @@ -1014,7 +1010,7 @@ public void testPairMap2() { // Maps pair -> single JavaDStream reversed = pairStream.map( new Function, Integer>() { @Override - public Integer call(Tuple2 in) throws Exception { + public Integer call(Tuple2 in) { return in._2(); } }); @@ -1030,23 +1026,23 @@ public Integer call(Tuple2 in) throws Exception { public void testPairToPairFlatMapWithChangingTypes() { // Maps pair -> pair List>> inputData = Arrays.asList( Arrays.asList( - new Tuple2("hi", 1), - new Tuple2("ho", 2)), + new Tuple2<>("hi", 1), + new Tuple2<>("ho", 2)), Arrays.asList( - new Tuple2("hi", 1), - new Tuple2("ho", 2))); + new Tuple2<>("hi", 1), + new Tuple2<>("ho", 2))); List>> expected = Arrays.asList( Arrays.asList( - new Tuple2(1, "h"), - new Tuple2(1, "i"), - new Tuple2(2, "h"), - new Tuple2(2, "o")), + new Tuple2<>(1, "h"), + new Tuple2<>(1, "i"), + new Tuple2<>(2, "h"), + new Tuple2<>(2, "o")), Arrays.asList( - new Tuple2(1, "h"), - new Tuple2(1, "i"), - new Tuple2(2, "h"), - new Tuple2(2, "o"))); + new Tuple2<>(1, "h"), + new Tuple2<>(1, "i"), + new Tuple2<>(2, "h"), + new Tuple2<>(2, "o"))); JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); @@ -1054,10 +1050,10 @@ public void testPairToPairFlatMapWithChangingTypes() { // Maps pair -> pair JavaPairDStream flatMapped = pairStream.flatMapToPair( new PairFlatMapFunction, Integer, String>() { @Override - public Iterable> call(Tuple2 in) throws Exception { - List> out = new LinkedList>(); + public Iterable> call(Tuple2 in) { + List> out = new LinkedList<>(); for (Character s : in._1().toCharArray()) { - 
out.add(new Tuple2(in._2(), s.toString())); + out.add(new Tuple2<>(in._2(), s.toString())); } return out; } @@ -1075,11 +1071,11 @@ public void testPairGroupByKey() { List>>> expected = Arrays.asList( Arrays.asList( - new Tuple2>("california", Arrays.asList("dodgers", "giants")), - new Tuple2>("new york", Arrays.asList("yankees", "mets"))), + new Tuple2<>("california", Arrays.asList("dodgers", "giants")), + new Tuple2<>("new york", Arrays.asList("yankees", "mets"))), Arrays.asList( - new Tuple2>("california", Arrays.asList("sharks", "ducks")), - new Tuple2>("new york", Arrays.asList("rangers", "islanders")))); + new Tuple2<>("california", Arrays.asList("sharks", "ducks")), + new Tuple2<>("new york", Arrays.asList("rangers", "islanders")))); JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); @@ -1111,11 +1107,11 @@ public void testPairReduceByKey() { List>> expected = Arrays.asList( Arrays.asList( - new Tuple2("california", 4), - new Tuple2("new york", 5)), + new Tuple2<>("california", 4), + new Tuple2<>("new york", 5)), Arrays.asList( - new Tuple2("california", 10), - new Tuple2("new york", 4))); + new Tuple2<>("california", 10), + new Tuple2<>("new york", 4))); JavaDStream> stream = JavaTestUtils.attachTestInputStream( ssc, inputData, 1); @@ -1136,20 +1132,20 @@ public void testCombineByKey() { List>> expected = Arrays.asList( Arrays.asList( - new Tuple2("california", 4), - new Tuple2("new york", 5)), + new Tuple2<>("california", 4), + new Tuple2<>("new york", 5)), Arrays.asList( - new Tuple2("california", 10), - new Tuple2("new york", 4))); + new Tuple2<>("california", 10), + new Tuple2<>("new york", 4))); JavaDStream> stream = JavaTestUtils.attachTestInputStream( ssc, inputData, 1); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); - JavaPairDStream combined = pairStream.combineByKey( + JavaPairDStream combined = pairStream.combineByKey( new Function() { @Override - public Integer call(Integer i) throws Exception { + public Integer call(Integer i) { return i; } }, new IntegerSum(), new IntegerSum(), new HashPartitioner(2)); @@ -1170,13 +1166,13 @@ public void testCountByValue() { List>> expected = Arrays.asList( Arrays.asList( - new Tuple2("hello", 1L), - new Tuple2("world", 1L)), + new Tuple2<>("hello", 1L), + new Tuple2<>("world", 1L)), Arrays.asList( - new Tuple2("hello", 1L), - new Tuple2("moon", 1L)), + new Tuple2<>("hello", 1L), + new Tuple2<>("moon", 1L)), Arrays.asList( - new Tuple2("hello", 1L))); + new Tuple2<>("hello", 1L))); JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream counted = stream.countByValue(); @@ -1193,16 +1189,16 @@ public void testGroupByKeyAndWindow() { List>>> expected = Arrays.asList( Arrays.asList( - new Tuple2>("california", Arrays.asList(1, 3)), - new Tuple2>("new york", Arrays.asList(1, 4)) + new Tuple2<>("california", Arrays.asList(1, 3)), + new Tuple2<>("new york", Arrays.asList(1, 4)) ), Arrays.asList( - new Tuple2>("california", Arrays.asList(1, 3, 5, 5)), - new Tuple2>("new york", Arrays.asList(1, 1, 3, 4)) + new Tuple2<>("california", Arrays.asList(1, 3, 5, 5)), + new Tuple2<>("new york", Arrays.asList(1, 1, 3, 4)) ), Arrays.asList( - new Tuple2>("california", Arrays.asList(5, 5)), - new Tuple2>("new york", Arrays.asList(1, 3)) + new Tuple2<>("california", Arrays.asList(5, 5)), + new Tuple2<>("new york", Arrays.asList(1, 3)) ) ); @@ -1220,16 +1216,16 @@ public void testGroupByKeyAndWindow() 
{ } } - private HashSet>> convert(List>> listOfTuples) { - List>> newListOfTuples = new ArrayList>>(); + private static Set>> convert(List>> listOfTuples) { + List>> newListOfTuples = new ArrayList<>(); for (Tuple2> tuple: listOfTuples) { newListOfTuples.add(convert(tuple)); } - return new HashSet>>(newListOfTuples); + return new HashSet<>(newListOfTuples); } - private Tuple2> convert(Tuple2> tuple) { - return new Tuple2>(tuple._1(), new HashSet(tuple._2())); + private static Tuple2> convert(Tuple2> tuple) { + return new Tuple2<>(tuple._1(), new HashSet<>(tuple._2())); } @SuppressWarnings("unchecked") @@ -1238,12 +1234,12 @@ public void testReduceByKeyAndWindow() { List>> inputData = stringIntKVStream; List>> expected = Arrays.asList( - Arrays.asList(new Tuple2("california", 4), - new Tuple2("new york", 5)), - Arrays.asList(new Tuple2("california", 14), - new Tuple2("new york", 9)), - Arrays.asList(new Tuple2("california", 10), - new Tuple2("new york", 4))); + Arrays.asList(new Tuple2<>("california", 4), + new Tuple2<>("new york", 5)), + Arrays.asList(new Tuple2<>("california", 14), + new Tuple2<>("new york", 9)), + Arrays.asList(new Tuple2<>("california", 10), + new Tuple2<>("new york", 4))); JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); @@ -1262,12 +1258,12 @@ public void testUpdateStateByKey() { List>> inputData = stringIntKVStream; List>> expected = Arrays.asList( - Arrays.asList(new Tuple2("california", 4), - new Tuple2("new york", 5)), - Arrays.asList(new Tuple2("california", 14), - new Tuple2("new york", 9)), - Arrays.asList(new Tuple2("california", 14), - new Tuple2("new york", 9))); + Arrays.asList(new Tuple2<>("california", 4), + new Tuple2<>("new york", 5)), + Arrays.asList(new Tuple2<>("california", 14), + new Tuple2<>("new york", 9)), + Arrays.asList(new Tuple2<>("california", 14), + new Tuple2<>("new york", 9))); JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); @@ -1278,10 +1274,10 @@ public void testUpdateStateByKey() { public Optional call(List values, Optional state) { int out = 0; if (state.isPresent()) { - out = out + state.get(); + out += state.get(); } for (Integer v : values) { - out = out + v; + out += v; } return Optional.of(out); } @@ -1298,19 +1294,19 @@ public void testUpdateStateByKeyWithInitial() { List>> inputData = stringIntKVStream; List> initial = Arrays.asList ( - new Tuple2 ("california", 1), - new Tuple2 ("new york", 2)); + new Tuple2<>("california", 1), + new Tuple2<>("new york", 2)); JavaRDD> tmpRDD = ssc.sparkContext().parallelize(initial); JavaPairRDD initialRDD = JavaPairRDD.fromJavaRDD (tmpRDD); List>> expected = Arrays.asList( - Arrays.asList(new Tuple2("california", 5), - new Tuple2("new york", 7)), - Arrays.asList(new Tuple2("california", 15), - new Tuple2("new york", 11)), - Arrays.asList(new Tuple2("california", 15), - new Tuple2("new york", 11))); + Arrays.asList(new Tuple2<>("california", 5), + new Tuple2<>("new york", 7)), + Arrays.asList(new Tuple2<>("california", 15), + new Tuple2<>("new york", 11)), + Arrays.asList(new Tuple2<>("california", 15), + new Tuple2<>("new york", 11))); JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); @@ -1321,10 +1317,10 @@ public void testUpdateStateByKeyWithInitial() { public Optional call(List values, Optional 
state) { int out = 0; if (state.isPresent()) { - out = out + state.get(); + out += state.get(); } for (Integer v : values) { - out = out + v; + out += v; } return Optional.of(out); } @@ -1341,19 +1337,19 @@ public void testReduceByKeyAndWindowWithInverse() { List>> inputData = stringIntKVStream; List>> expected = Arrays.asList( - Arrays.asList(new Tuple2("california", 4), - new Tuple2("new york", 5)), - Arrays.asList(new Tuple2("california", 14), - new Tuple2("new york", 9)), - Arrays.asList(new Tuple2("california", 10), - new Tuple2("new york", 4))); + Arrays.asList(new Tuple2<>("california", 4), + new Tuple2<>("new york", 5)), + Arrays.asList(new Tuple2<>("california", 14), + new Tuple2<>("new york", 9)), + Arrays.asList(new Tuple2<>("california", 10), + new Tuple2<>("new york", 4))); JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); JavaPairDStream reduceWindowed = pairStream.reduceByKeyAndWindow(new IntegerSum(), new IntegerDifference(), - new Duration(2000), new Duration(1000)); + new Duration(2000), new Duration(1000)); JavaTestUtils.attachTestOutputStream(reduceWindowed); List>> result = JavaTestUtils.runStreams(ssc, 3, 3); @@ -1370,15 +1366,15 @@ public void testCountByValueAndWindow() { List>> expected = Arrays.asList( Sets.newHashSet( - new Tuple2("hello", 1L), - new Tuple2("world", 1L)), + new Tuple2<>("hello", 1L), + new Tuple2<>("world", 1L)), Sets.newHashSet( - new Tuple2("hello", 2L), - new Tuple2("world", 1L), - new Tuple2("moon", 1L)), + new Tuple2<>("hello", 2L), + new Tuple2<>("world", 1L), + new Tuple2<>("moon", 1L)), Sets.newHashSet( - new Tuple2("hello", 2L), - new Tuple2("moon", 1L))); + new Tuple2<>("hello", 2L), + new Tuple2<>("moon", 1L))); JavaDStream stream = JavaTestUtils.attachTestInputStream( ssc, inputData, 1); @@ -1386,7 +1382,7 @@ public void testCountByValueAndWindow() { stream.countByValueAndWindow(new Duration(2000), new Duration(1000)); JavaTestUtils.attachTestOutputStream(counted); List>> result = JavaTestUtils.runStreams(ssc, 3, 3); - List>> unorderedResult = Lists.newArrayList(); + List>> unorderedResult = new ArrayList<>(); for (List> res: result) { unorderedResult.add(Sets.newHashSet(res)); } @@ -1399,27 +1395,27 @@ public void testCountByValueAndWindow() { public void testPairTransform() { List>> inputData = Arrays.asList( Arrays.asList( - new Tuple2(3, 5), - new Tuple2(1, 5), - new Tuple2(4, 5), - new Tuple2(2, 5)), + new Tuple2<>(3, 5), + new Tuple2<>(1, 5), + new Tuple2<>(4, 5), + new Tuple2<>(2, 5)), Arrays.asList( - new Tuple2(2, 5), - new Tuple2(3, 5), - new Tuple2(4, 5), - new Tuple2(1, 5))); + new Tuple2<>(2, 5), + new Tuple2<>(3, 5), + new Tuple2<>(4, 5), + new Tuple2<>(1, 5))); List>> expected = Arrays.asList( Arrays.asList( - new Tuple2(1, 5), - new Tuple2(2, 5), - new Tuple2(3, 5), - new Tuple2(4, 5)), + new Tuple2<>(1, 5), + new Tuple2<>(2, 5), + new Tuple2<>(3, 5), + new Tuple2<>(4, 5)), Arrays.asList( - new Tuple2(1, 5), - new Tuple2(2, 5), - new Tuple2(3, 5), - new Tuple2(4, 5))); + new Tuple2<>(1, 5), + new Tuple2<>(2, 5), + new Tuple2<>(3, 5), + new Tuple2<>(4, 5))); JavaDStream> stream = JavaTestUtils.attachTestInputStream( ssc, inputData, 1); @@ -1428,7 +1424,7 @@ public void testPairTransform() { JavaPairDStream sorted = pairStream.transformToPair( new Function, JavaPairRDD>() { @Override - public JavaPairRDD call(JavaPairRDD in) throws Exception { + public JavaPairRDD call(JavaPairRDD in) { return in.sortByKey(); } }); @@ 
-1444,15 +1440,15 @@ public JavaPairRDD call(JavaPairRDD in) thro public void testPairToNormalRDDTransform() { List>> inputData = Arrays.asList( Arrays.asList( - new Tuple2(3, 5), - new Tuple2(1, 5), - new Tuple2(4, 5), - new Tuple2(2, 5)), + new Tuple2<>(3, 5), + new Tuple2<>(1, 5), + new Tuple2<>(4, 5), + new Tuple2<>(2, 5)), Arrays.asList( - new Tuple2(2, 5), - new Tuple2(3, 5), - new Tuple2(4, 5), - new Tuple2(1, 5))); + new Tuple2<>(2, 5), + new Tuple2<>(3, 5), + new Tuple2<>(4, 5), + new Tuple2<>(1, 5))); List> expected = Arrays.asList( Arrays.asList(3,1,4,2), @@ -1465,11 +1461,11 @@ public void testPairToNormalRDDTransform() { JavaDStream firstParts = pairStream.transform( new Function, JavaRDD>() { @Override - public JavaRDD call(JavaPairRDD in) throws Exception { + public JavaRDD call(JavaPairRDD in) { return in.map(new Function, Integer>() { @Override - public Integer call(Tuple2 in) { - return in._1(); + public Integer call(Tuple2 in2) { + return in2._1(); } }); } @@ -1487,14 +1483,14 @@ public void testMapValues() { List>> inputData = stringStringKVStream; List>> expected = Arrays.asList( - Arrays.asList(new Tuple2("california", "DODGERS"), - new Tuple2("california", "GIANTS"), - new Tuple2("new york", "YANKEES"), - new Tuple2("new york", "METS")), - Arrays.asList(new Tuple2("california", "SHARKS"), - new Tuple2("california", "DUCKS"), - new Tuple2("new york", "RANGERS"), - new Tuple2("new york", "ISLANDERS"))); + Arrays.asList(new Tuple2<>("california", "DODGERS"), + new Tuple2<>("california", "GIANTS"), + new Tuple2<>("new york", "YANKEES"), + new Tuple2<>("new york", "METS")), + Arrays.asList(new Tuple2<>("california", "SHARKS"), + new Tuple2<>("california", "DUCKS"), + new Tuple2<>("new york", "RANGERS"), + new Tuple2<>("new york", "ISLANDERS"))); JavaDStream> stream = JavaTestUtils.attachTestInputStream( ssc, inputData, 1); @@ -1502,8 +1498,8 @@ public void testMapValues() { JavaPairDStream mapped = pairStream.mapValues(new Function() { @Override - public String call(String s) throws Exception { - return s.toUpperCase(); + public String call(String s) { + return s.toUpperCase(Locale.ENGLISH); } }); @@ -1519,22 +1515,22 @@ public void testFlatMapValues() { List>> inputData = stringStringKVStream; List>> expected = Arrays.asList( - Arrays.asList(new Tuple2("california", "dodgers1"), - new Tuple2("california", "dodgers2"), - new Tuple2("california", "giants1"), - new Tuple2("california", "giants2"), - new Tuple2("new york", "yankees1"), - new Tuple2("new york", "yankees2"), - new Tuple2("new york", "mets1"), - new Tuple2("new york", "mets2")), - Arrays.asList(new Tuple2("california", "sharks1"), - new Tuple2("california", "sharks2"), - new Tuple2("california", "ducks1"), - new Tuple2("california", "ducks2"), - new Tuple2("new york", "rangers1"), - new Tuple2("new york", "rangers2"), - new Tuple2("new york", "islanders1"), - new Tuple2("new york", "islanders2"))); + Arrays.asList(new Tuple2<>("california", "dodgers1"), + new Tuple2<>("california", "dodgers2"), + new Tuple2<>("california", "giants1"), + new Tuple2<>("california", "giants2"), + new Tuple2<>("new york", "yankees1"), + new Tuple2<>("new york", "yankees2"), + new Tuple2<>("new york", "mets1"), + new Tuple2<>("new york", "mets2")), + Arrays.asList(new Tuple2<>("california", "sharks1"), + new Tuple2<>("california", "sharks2"), + new Tuple2<>("california", "ducks1"), + new Tuple2<>("california", "ducks2"), + new Tuple2<>("new york", "rangers1"), + new Tuple2<>("new york", "rangers2"), + new Tuple2<>("new york", 
"islanders1"), + new Tuple2<>("new york", "islanders2"))); JavaDStream> stream = JavaTestUtils.attachTestInputStream( ssc, inputData, 1); @@ -1545,7 +1541,7 @@ public void testFlatMapValues() { new Function>() { @Override public Iterable call(String in) { - List out = new ArrayList(); + List out = new ArrayList<>(); out.add(in + "1"); out.add(in + "2"); return out; @@ -1562,29 +1558,29 @@ public Iterable call(String in) { @Test public void testCoGroup() { List>> stringStringKVStream1 = Arrays.asList( - Arrays.asList(new Tuple2("california", "dodgers"), - new Tuple2("new york", "yankees")), - Arrays.asList(new Tuple2("california", "sharks"), - new Tuple2("new york", "rangers"))); + Arrays.asList(new Tuple2<>("california", "dodgers"), + new Tuple2<>("new york", "yankees")), + Arrays.asList(new Tuple2<>("california", "sharks"), + new Tuple2<>("new york", "rangers"))); List>> stringStringKVStream2 = Arrays.asList( - Arrays.asList(new Tuple2("california", "giants"), - new Tuple2("new york", "mets")), - Arrays.asList(new Tuple2("california", "ducks"), - new Tuple2("new york", "islanders"))); + Arrays.asList(new Tuple2<>("california", "giants"), + new Tuple2<>("new york", "mets")), + Arrays.asList(new Tuple2<>("california", "ducks"), + new Tuple2<>("new york", "islanders"))); List, List>>>> expected = Arrays.asList( Arrays.asList( - new Tuple2, List>>("california", - new Tuple2, List>(Arrays.asList("dodgers"), Arrays.asList("giants"))), - new Tuple2, List>>("new york", - new Tuple2, List>(Arrays.asList("yankees"), Arrays.asList("mets")))), + new Tuple2<>("california", + new Tuple2<>(Arrays.asList("dodgers"), Arrays.asList("giants"))), + new Tuple2<>("new york", + new Tuple2<>(Arrays.asList("yankees"), Arrays.asList("mets")))), Arrays.asList( - new Tuple2, List>>("california", - new Tuple2, List>(Arrays.asList("sharks"), Arrays.asList("ducks"))), - new Tuple2, List>>("new york", - new Tuple2, List>(Arrays.asList("rangers"), Arrays.asList("islanders"))))); + new Tuple2<>("california", + new Tuple2<>(Arrays.asList("sharks"), Arrays.asList("ducks"))), + new Tuple2<>("new york", + new Tuple2<>(Arrays.asList("rangers"), Arrays.asList("islanders"))))); JavaDStream> stream1 = JavaTestUtils.attachTestInputStream( @@ -1620,29 +1616,29 @@ public void testCoGroup() { @Test public void testJoin() { List>> stringStringKVStream1 = Arrays.asList( - Arrays.asList(new Tuple2("california", "dodgers"), - new Tuple2("new york", "yankees")), - Arrays.asList(new Tuple2("california", "sharks"), - new Tuple2("new york", "rangers"))); + Arrays.asList(new Tuple2<>("california", "dodgers"), + new Tuple2<>("new york", "yankees")), + Arrays.asList(new Tuple2<>("california", "sharks"), + new Tuple2<>("new york", "rangers"))); List>> stringStringKVStream2 = Arrays.asList( - Arrays.asList(new Tuple2("california", "giants"), - new Tuple2("new york", "mets")), - Arrays.asList(new Tuple2("california", "ducks"), - new Tuple2("new york", "islanders"))); + Arrays.asList(new Tuple2<>("california", "giants"), + new Tuple2<>("new york", "mets")), + Arrays.asList(new Tuple2<>("california", "ducks"), + new Tuple2<>("new york", "islanders"))); List>>> expected = Arrays.asList( Arrays.asList( - new Tuple2>("california", - new Tuple2("dodgers", "giants")), - new Tuple2>("new york", - new Tuple2("yankees", "mets"))), + new Tuple2<>("california", + new Tuple2<>("dodgers", "giants")), + new Tuple2<>("new york", + new Tuple2<>("yankees", "mets"))), Arrays.asList( - new Tuple2>("california", - new Tuple2("sharks", "ducks")), - new Tuple2>("new 
york", - new Tuple2("rangers", "islanders")))); + new Tuple2<>("california", + new Tuple2<>("sharks", "ducks")), + new Tuple2<>("new york", + new Tuple2<>("rangers", "islanders")))); JavaDStream> stream1 = JavaTestUtils.attachTestInputStream( @@ -1664,13 +1660,13 @@ public void testJoin() { @Test public void testLeftOuterJoin() { List>> stringStringKVStream1 = Arrays.asList( - Arrays.asList(new Tuple2("california", "dodgers"), - new Tuple2("new york", "yankees")), - Arrays.asList(new Tuple2("california", "sharks") )); + Arrays.asList(new Tuple2<>("california", "dodgers"), + new Tuple2<>("new york", "yankees")), + Arrays.asList(new Tuple2<>("california", "sharks") )); List>> stringStringKVStream2 = Arrays.asList( - Arrays.asList(new Tuple2("california", "giants") ), - Arrays.asList(new Tuple2("new york", "islanders") ) + Arrays.asList(new Tuple2<>("california", "giants") ), + Arrays.asList(new Tuple2<>("new york", "islanders") ) ); @@ -1713,7 +1709,7 @@ public void testCheckpointMasterRecovery() throws InterruptedException { JavaDStream stream = JavaCheckpointTestUtils.attachTestInputStream(ssc, inputData, 1); JavaDStream letterCount = stream.map(new Function() { @Override - public Integer call(String s) throws Exception { + public Integer call(String s) { return s.length(); } }); @@ -1752,6 +1748,7 @@ public void testContextGetOrCreate() throws InterruptedException { // (used to detect the new context) final AtomicBoolean newContextCreated = new AtomicBoolean(false); Function0 creatingFunc = new Function0() { + @Override public JavaStreamingContext call() { newContextCreated.set(true); return new JavaStreamingContext(conf, Seconds.apply(1)); @@ -1765,20 +1762,20 @@ public JavaStreamingContext call() { newContextCreated.set(false); ssc = JavaStreamingContext.getOrCreate(corruptedCheckpointDir, creatingFunc, - new org.apache.hadoop.conf.Configuration(), true); + new Configuration(), true); Assert.assertTrue("new context not created", newContextCreated.get()); ssc.stop(); newContextCreated.set(false); ssc = JavaStreamingContext.getOrCreate(checkpointDir, creatingFunc, - new org.apache.hadoop.conf.Configuration()); + new Configuration()); Assert.assertTrue("old context not recovered", !newContextCreated.get()); ssc.stop(); newContextCreated.set(false); JavaSparkContext sc = new JavaSparkContext(conf); ssc = JavaStreamingContext.getOrCreate(checkpointDir, creatingFunc, - new org.apache.hadoop.conf.Configuration()); + new Configuration()); Assert.assertTrue("old context not recovered", !newContextCreated.get()); ssc.stop(); } @@ -1800,7 +1797,7 @@ public void testCheckpointofIndividualStream() throws InterruptedException { JavaDStream stream = JavaCheckpointTestUtils.attachTestInputStream(ssc, inputData, 1); JavaDStream letterCount = stream.map(new Function() { @Override - public Integer call(String s) throws Exception { + public Integer call(String s) { return s.length(); } }); @@ -1818,29 +1815,26 @@ public Integer call(String s) throws Exception { // InputStream functionality is deferred to the existing Scala tests. 
@Test public void testSocketTextStream() { - JavaReceiverInputDStream test = ssc.socketTextStream("localhost", 12345); + ssc.socketTextStream("localhost", 12345); } @Test public void testSocketString() { - - class Converter implements Function> { - public Iterable call(InputStream in) throws IOException { - BufferedReader reader = new BufferedReader(new InputStreamReader(in)); - List out = new ArrayList(); - while (true) { - String line = reader.readLine(); - if (line == null) { break; } - out.add(line); - } - return out; - } - } - - JavaDStream test = ssc.socketStream( + ssc.socketStream( "localhost", 12345, - new Converter(), + new Function>() { + @Override + public Iterable call(InputStream in) throws IOException { + List out = new ArrayList<>(); + try (BufferedReader reader = new BufferedReader(new InputStreamReader(in))) { + for (String line; (line = reader.readLine()) != null;) { + out.add(line); + } + } + return out; + } + }, StorageLevel.MEMORY_ONLY()); } @@ -1870,7 +1864,7 @@ public void testFileStream() throws IOException { TextInputFormat.class, new Function() { @Override - public Boolean call(Path v1) throws Exception { + public Boolean call(Path v1) { return Boolean.TRUE; } }, @@ -1879,7 +1873,7 @@ public Boolean call(Path v1) throws Exception { JavaDStream test = inputStream.map( new Function, String>() { @Override - public String call(Tuple2 v1) throws Exception { + public String call(Tuple2 v1) { return v1._2().toString(); } }); @@ -1892,19 +1886,15 @@ public String call(Tuple2 v1) throws Exception { @Test public void testRawSocketStream() { - JavaReceiverInputDStream test = ssc.rawSocketStream("localhost", 12345); + ssc.rawSocketStream("localhost", 12345); } - private List> fileTestPrepare(File testDir) throws IOException { + private static List> fileTestPrepare(File testDir) throws IOException { File existingFile = new File(testDir, "0"); Files.write("0\n", existingFile, Charset.forName("UTF-8")); - assertTrue(existingFile.setLastModified(1000) && existingFile.lastModified() == 1000); - - List> expected = Arrays.asList( - Arrays.asList("0") - ); - - return expected; + Assert.assertTrue(existingFile.setLastModified(1000)); + Assert.assertEquals(1000, existingFile.lastModified()); + return Arrays.asList(Arrays.asList("0")); } @SuppressWarnings("unchecked") diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaReceiverAPISuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaReceiverAPISuite.java index 1b0787fe69de..ec2bffd6a5b9 100644 --- a/streaming/src/test/java/org/apache/spark/streaming/JavaReceiverAPISuite.java +++ b/streaming/src/test/java/org/apache/spark/streaming/JavaReceiverAPISuite.java @@ -36,7 +36,6 @@ import java.io.Serializable; import java.net.ConnectException; import java.net.Socket; -import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicLong; public class JavaReceiverAPISuite implements Serializable { @@ -64,16 +63,16 @@ public void testReceiver() throws InterruptedException { ssc.receiverStream(new JavaSocketReceiver("localhost", server.port())); JavaDStream mapped = input.map(new Function() { @Override - public String call(String v1) throws Exception { + public String call(String v1) { return v1 + "."; } }); mapped.foreachRDD(new Function, Void>() { @Override - public Void call(JavaRDD rdd) throws Exception { - long count = rdd.count(); - dataCounter.addAndGet(count); - return null; + public Void call(JavaRDD rdd) { + long count = rdd.count(); + dataCounter.addAndGet(count); + return null; } 
}); @@ -83,7 +82,7 @@ public Void call(JavaRDD rdd) throws Exception { Thread.sleep(200); for (int i = 0; i < 6; i++) { - server.send("" + i + "\n"); // \n to make sure these are separate lines + server.send(i + "\n"); // \n to make sure these are separate lines Thread.sleep(100); } while (dataCounter.get() == 0 && System.currentTimeMillis() - startTime < timeout) { @@ -95,50 +94,49 @@ public Void call(JavaRDD rdd) throws Exception { server.stop(); } } -} -class JavaSocketReceiver extends Receiver { + private static class JavaSocketReceiver extends Receiver { - String host = null; - int port = -1; + String host = null; + int port = -1; - public JavaSocketReceiver(String host_ , int port_) { - super(StorageLevel.MEMORY_AND_DISK()); - host = host_; - port = port_; - } + JavaSocketReceiver(String host_ , int port_) { + super(StorageLevel.MEMORY_AND_DISK()); + host = host_; + port = port_; + } - @Override - public void onStart() { - new Thread() { - @Override public void run() { - receive(); - } - }.start(); - } + @Override + public void onStart() { + new Thread() { + @Override public void run() { + receive(); + } + }.start(); + } - @Override - public void onStop() { - } + @Override + public void onStop() { + } - private void receive() { - Socket socket = null; - try { - socket = new Socket(host, port); - BufferedReader in = new BufferedReader(new InputStreamReader(socket.getInputStream())); - String userInput; - while ((userInput = in.readLine()) != null) { - store(userInput); + private void receive() { + try { + Socket socket = new Socket(host, port); + BufferedReader in = new BufferedReader(new InputStreamReader(socket.getInputStream())); + String userInput; + while ((userInput = in.readLine()) != null) { + store(userInput); + } + in.close(); + socket.close(); + } catch(ConnectException ce) { + ce.printStackTrace(); + restart("Could not connect", ce); + } catch(Throwable t) { + t.printStackTrace(); + restart("Error receiving data", t); } - in.close(); - socket.close(); - } catch(ConnectException ce) { - ce.printStackTrace(); - restart("Could not connect", ce); - } catch(Throwable t) { - t.printStackTrace(); - restart("Error receiving data", t); } } -} +} From f4a22808e03fa12bfe1bfc82cf713cfda7e063a9 Mon Sep 17 00:00:00 2001 From: JihongMa Date: Sat, 12 Sep 2015 10:17:15 -0700 Subject: [PATCH 0606/1824] [SPARK-6548] Adding stddev to DataFrame functions Adding STDDEV support for DataFrame using 1-pass online /parallel algorithm to compute variance. Please review the code change. Author: JihongMa Author: Jihong MA Author: Jihong MA Author: Jihong MA Closes #6297 from JihongMA/SPARK-SQL. 
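The new Stddev / StddevPop / StddevSamp expressions keep a (count, mean, M2) buffer, fold values in one at a time, and merge partial buffers across partitions, following the online/parallel variance formulation referenced in the added code comments. A minimal standalone Scala sketch of that update/merge scheme is included here for reference only; the WelfordState class and its method names are illustrative and are not part of this patch.

// Illustrative only -- not part of this patch. A plain-Scala version of the
// one-pass (Welford) update and the partition merge that the new Stddev
// expressions encode with Catalyst expressions.
case class WelfordState(count: Long, mean: Double, m2: Double) {

  // Fold one value into the running (count, mean, M2) state.
  def update(x: Double): WelfordState = {
    val newCount = count + 1
    val delta = x - mean
    val newMean = mean + delta / newCount
    // M2 accumulates the sum of squared differences from the current mean.
    WelfordState(newCount, newMean, m2 + delta * (x - newMean))
  }

  // Combine two partial states computed on different partitions.
  def merge(that: WelfordState): WelfordState = {
    if (count == 0L) that
    else if (that.count == 0L) this
    else {
      val newCount = count + that.count
      val delta = that.mean - mean
      val newMean = mean + delta * that.count / newCount
      val newM2 = m2 + that.m2 + delta * delta * count * that.count / newCount
      WelfordState(newCount, newMean, newM2)
    }
  }

  // Sample stddev divides M2 by (n - 1), population stddev by n.
  // Empty input yields None (null in SQL); a single value yields 0.0.
  def stddevSamp: Option[Double] =
    if (count == 0L) None
    else if (count == 1L) Some(0.0)
    else Some(math.sqrt(m2 / (count - 1)))

  def stddevPop: Option[Double] =
    if (count == 0L) None
    else Some(math.sqrt(m2 / count))
}

object WelfordState {
  val empty = WelfordState(0L, 0.0, 0.0)
}

For example, Seq(2.0, 5.0).foldLeft(WelfordState.empty)((s, x) => s.update(x)).stddevSamp evaluates to Some(2.1213203435596424), which matches the stddev shown in the updated Python describe() doctest below.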
--- R/pkg/inst/tests/test_sparkSQL.R | 2 +- python/pyspark/sql/dataframe.py | 36 +-- .../catalyst/analysis/FunctionRegistry.scala | 3 + .../catalyst/analysis/HiveTypeCoercion.scala | 3 + .../spark/sql/catalyst/dsl/package.scala | 3 + .../expressions/aggregate/functions.scala | 143 ++++++++++ .../expressions/aggregate/utils.scala | 18 ++ .../sql/catalyst/expressions/aggregates.scala | 245 ++++++++++++++++++ .../org/apache/spark/sql/DataFrame.scala | 6 +- .../org/apache/spark/sql/GroupedData.scala | 39 +++ .../org/apache/spark/sql/functions.scala | 27 ++ .../apache/spark/sql/JavaDataFrameSuite.java | 1 + .../spark/sql/DataFrameAggregateSuite.scala | 33 +++ .../org/apache/spark/sql/DataFrameSuite.scala | 2 +- .../org/apache/spark/sql/SQLQuerySuite.scala | 42 ++- .../execution/AggregationQuerySuite.scala | 35 --- 16 files changed, 574 insertions(+), 64 deletions(-) diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index 1ccfde59176f..98d4402d368e 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -1147,7 +1147,7 @@ test_that("describe() and summarize() on a DataFrame", { stats <- describe(df, "age") expect_equal(collect(stats)[1, "summary"], "count") expect_equal(collect(stats)[2, "age"], "24.5") - expect_equal(collect(stats)[3, "age"], "5.5") + expect_equal(collect(stats)[3, "age"], "7.7781745930520225") stats <- describe(df) expect_equal(collect(stats)[4, "name"], "Andy") expect_equal(collect(stats)[5, "age"], "30") diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index c5bf55791240..fb995fa3a76b 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -653,25 +653,25 @@ def describe(self, *cols): guarantee about the backward compatibility of the schema of the resulting DataFrame. 
>>> df.describe().show() - +-------+---+ - |summary|age| - +-------+---+ - | count| 2| - | mean|3.5| - | stddev|1.5| - | min| 2| - | max| 5| - +-------+---+ + +-------+------------------+ + |summary| age| + +-------+------------------+ + | count| 2| + | mean| 3.5| + | stddev|2.1213203435596424| + | min| 2| + | max| 5| + +-------+------------------+ >>> df.describe(['age', 'name']).show() - +-------+---+-----+ - |summary|age| name| - +-------+---+-----+ - | count| 2| 2| - | mean|3.5| null| - | stddev|1.5| null| - | min| 2|Alice| - | max| 5| Bob| - +-------+---+-----+ + +-------+------------------+-----+ + |summary| age| name| + +-------+------------------+-----+ + | count| 2| 2| + | mean| 3.5| null| + | stddev|2.1213203435596424| null| + | min| 2|Alice| + | max| 5| Bob| + +-------+------------------+-----+ """ if len(cols) == 1 and isinstance(cols[0], list): cols = cols[0] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index cd5a90d78815..11b4866bf264 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -168,6 +168,9 @@ object FunctionRegistry { expression[Last]("last"), expression[Max]("max"), expression[Min]("min"), + expression[Stddev]("stddev"), + expression[StddevPop]("stddev_pop"), + expression[StddevSamp]("stddev_samp"), expression[Sum]("sum"), // string functions diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 87c11abbad49..87a3845b2d9e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -297,6 +297,9 @@ object HiveTypeCoercion { case Sum(e @ StringType()) => Sum(Cast(e, DoubleType)) case SumDistinct(e @ StringType()) => Sum(Cast(e, DoubleType)) case Average(e @ StringType()) => Average(Cast(e, DoubleType)) + case Stddev(e @ StringType()) => Stddev(Cast(e, DoubleType)) + case StddevPop(e @ StringType()) => StddevPop(Cast(e, DoubleType)) + case StddevSamp(e @ StringType()) => StddevSamp(Cast(e, DoubleType)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index a7e3a4932765..699c4cc63d09 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -159,6 +159,9 @@ package object dsl { def lower(e: Expression): Expression = Lower(e) def sqrt(e: Expression): Expression = Sqrt(e) def abs(e: Expression): Expression = Abs(e) + def stddev(e: Expression): Expression = Stddev(e) + def stddev_pop(e: Expression): Expression = StddevPop(e) + def stddev_samp(e: Expression): Expression = StddevSamp(e) implicit class DslSymbol(sym: Symbol) extends ImplicitAttribute { def s: String = sym.name } // TODO more implicit class for literal? 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala index a73024d6adba..02cd0ac0db11 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala @@ -249,6 +249,149 @@ case class Min(child: Expression) extends AlgebraicAggregate { override val evaluateExpression = min } +// Compute the sample standard deviation of a column +case class Stddev(child: Expression) extends StddevAgg(child) { + + override def isSample: Boolean = true + override def prettyName: String = "stddev" +} + +// Compute the population standard deviation of a column +case class StddevPop(child: Expression) extends StddevAgg(child) { + + override def isSample: Boolean = false + override def prettyName: String = "stddev_pop" +} + +// Compute the sample standard deviation of a column +case class StddevSamp(child: Expression) extends StddevAgg(child) { + + override def isSample: Boolean = true + override def prettyName: String = "stddev_samp" +} + +// Compute standard deviation based on online algorithm specified here: +// http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance +abstract class StddevAgg(child: Expression) extends AlgebraicAggregate { + + override def children: Seq[Expression] = child :: Nil + + override def nullable: Boolean = true + + def isSample: Boolean + + // Return data type. + override def dataType: DataType = resultType + + // Expected input data type. + // TODO: Right now, we replace old aggregate functions (based on AggregateExpression1) to the + // new version at planning time (after analysis phase). For now, NullType is added at here + // to make it resolved when we have cases like `select stddev(null)`. + // We can use our analyzer to cast NullType to the default data type of the NumericType once + // we remove the old aggregate functions. Then, we will not need NullType at here. 
+ override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(NumericType, NullType)) + + private val resultType = DoubleType + + private val preCount = AttributeReference("preCount", resultType)() + private val currentCount = AttributeReference("currentCount", resultType)() + private val preAvg = AttributeReference("preAvg", resultType)() + private val currentAvg = AttributeReference("currentAvg", resultType)() + private val currentMk = AttributeReference("currentMk", resultType)() + + override val bufferAttributes = preCount :: currentCount :: preAvg :: + currentAvg :: currentMk :: Nil + + override val initialValues = Seq( + /* preCount = */ Cast(Literal(0), resultType), + /* currentCount = */ Cast(Literal(0), resultType), + /* preAvg = */ Cast(Literal(0), resultType), + /* currentAvg = */ Cast(Literal(0), resultType), + /* currentMk = */ Cast(Literal(0), resultType) + ) + + override val updateExpressions = { + + // update average + // avg = avg + (value - avg)/count + def avgAdd: Expression = { + currentAvg + ((Cast(child, resultType) - currentAvg) / currentCount) + } + + // update sum of square of difference from mean + // Mk = Mk + (value - preAvg) * (value - updatedAvg) + def mkAdd: Expression = { + val delta1 = Cast(child, resultType) - preAvg + val delta2 = Cast(child, resultType) - currentAvg + currentMk + (delta1 * delta2) + } + + Seq( + /* preCount = */ If(IsNull(child), preCount, currentCount), + /* currentCount = */ If(IsNull(child), currentCount, + Add(currentCount, Cast(Literal(1), resultType))), + /* preAvg = */ If(IsNull(child), preAvg, currentAvg), + /* currentAvg = */ If(IsNull(child), currentAvg, avgAdd), + /* currentMk = */ If(IsNull(child), currentMk, mkAdd) + ) + } + + override val mergeExpressions = { + + // count merge + def countMerge: Expression = { + currentCount.left + currentCount.right + } + + // average merge + def avgMerge: Expression = { + ((currentAvg.left * preCount) + (currentAvg.right * currentCount.right)) / + (preCount + currentCount.right) + } + + // update sum of square differences + def mkMerge: Expression = { + val avgDelta = currentAvg.right - preAvg + val mkDelta = (avgDelta * avgDelta) * (preCount * currentCount.right) / + (preCount + currentCount.right) + + currentMk.left + currentMk.right + mkDelta + } + + Seq( + /* preCount = */ If(IsNull(currentCount.left), + Cast(Literal(0), resultType), currentCount.left), + /* currentCount = */ If(IsNull(currentCount.left), currentCount.right, + If(IsNull(currentCount.right), currentCount.left, countMerge)), + /* preAvg = */ If(IsNull(currentAvg.left), Cast(Literal(0), resultType), currentAvg.left), + /* currentAvg = */ If(IsNull(currentAvg.left), currentAvg.right, + If(IsNull(currentAvg.right), currentAvg.left, avgMerge)), + /* currentMk = */ If(IsNull(currentMk.left), currentMk.right, + If(IsNull(currentMk.right), currentMk.left, mkMerge)) + ) + } + + override val evaluateExpression = { + // when currentCount == 0, return null + // when currentCount == 1, return 0 + // when currentCount >1 + // stddev_samp = sqrt (currentMk/(currentCount -1)) + // stddev_pop = sqrt (currentMk/currentCount) + val varCol = { + if (isSample) { + currentMk / Cast((currentCount - Cast(Literal(1), resultType)), resultType) + } + else { + currentMk / currentCount + } + } + + If(EqualTo(currentCount, Cast(Literal(0), resultType)), Cast(Literal(null), resultType), + If(EqualTo(currentCount, Cast(Literal(1), resultType)), Cast(Literal(0), resultType), + Cast(Sqrt(varCol), resultType))) + } +} + case class Sum(child: 
Expression) extends AlgebraicAggregate { override def children: Seq[Expression] = child :: Nil diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala index 4a43318a9549..ce3dddad87f5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala @@ -85,6 +85,24 @@ object Utils { mode = aggregate.Complete, isDistinct = false) + case expressions.Stddev(child) => + aggregate.AggregateExpression2( + aggregateFunction = aggregate.Stddev(child), + mode = aggregate.Complete, + isDistinct = false) + + case expressions.StddevPop(child) => + aggregate.AggregateExpression2( + aggregateFunction = aggregate.StddevPop(child), + mode = aggregate.Complete, + isDistinct = false) + + case expressions.StddevSamp(child) => + aggregate.AggregateExpression2( + aggregateFunction = aggregate.StddevSamp(child), + mode = aggregate.Complete, + isDistinct = false) + case expressions.Sum(child) => aggregate.AggregateExpression2( aggregateFunction = aggregate.Sum(child), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala index 5e8298aaaa9c..f1c47f39043c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala @@ -691,3 +691,248 @@ case class LastFunction(expr: Expression, base: AggregateExpression1) extends Ag result } } + +// Compute standard deviation based on online algorithm specified here: +// http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance +abstract class StddevAgg1(child: Expression) extends UnaryExpression with PartialAggregate1 { + override def nullable: Boolean = true + override def dataType: DataType = DoubleType + + def isSample: Boolean + + override def asPartial: SplitEvaluation = { + val partialStd = Alias(ComputePartialStd(child), "PartialStddev")() + SplitEvaluation(MergePartialStd(partialStd.toAttribute, isSample), partialStd :: Nil) + } + + override def newInstance(): StddevFunction = new StddevFunction(child, this, isSample) + + override def checkInputDataTypes(): TypeCheckResult = + TypeUtils.checkForNumericExpr(child.dataType, "function stddev") + +} + +// Compute the sample standard deviation of a column +case class Stddev(child: Expression) extends StddevAgg1(child) { + + override def toString: String = s"STDDEV($child)" + override def isSample: Boolean = true +} + +// Compute the population standard deviation of a column +case class StddevPop(child: Expression) extends StddevAgg1(child) { + + override def toString: String = s"STDDEV_POP($child)" + override def isSample: Boolean = false +} + +// Compute the sample standard deviation of a column +case class StddevSamp(child: Expression) extends StddevAgg1(child) { + + override def toString: String = s"STDDEV_SAMP($child)" + override def isSample: Boolean = true +} + +case class ComputePartialStd(child: Expression) extends UnaryExpression with AggregateExpression1 { + def this() = this(null) + + override def children: Seq[Expression] = child :: Nil + override def nullable: Boolean = false + override def dataType: DataType = ArrayType(DoubleType) + override def toString: String = 
s"computePartialStddev($child)" + override def newInstance(): ComputePartialStdFunction = + new ComputePartialStdFunction(child, this) +} + +case class ComputePartialStdFunction ( + expr: Expression, + base: AggregateExpression1 +) extends AggregateFunction1 { + def this() = this(null, null) // Required for serialization + + private val computeType = DoubleType + private val zero = Cast(Literal(0), computeType) + private var partialCount: Long = 0L + + // the mean of data processed so far + private val partialAvg: MutableLiteral = MutableLiteral(zero.eval(null), computeType) + + // update average based on this formula: + // avg = avg + (value - avg)/count + private def avgAddFunction (value: Literal): Expression = { + val delta = Subtract(Cast(value, computeType), partialAvg) + Add(partialAvg, Divide(delta, Cast(Literal(partialCount), computeType))) + } + + // the sum of squares of difference from mean + private val partialMk: MutableLiteral = MutableLiteral(zero.eval(null), computeType) + + // update sum of square of difference from mean based on following formula: + // Mk = Mk + (value - preAvg) * (value - updatedAvg) + private def mkAddFunction(value: Literal, prePartialAvg: MutableLiteral): Expression = { + val delta1 = Subtract(Cast(value, computeType), prePartialAvg) + val delta2 = Subtract(Cast(value, computeType), partialAvg) + Add(partialMk, Multiply(delta1, delta2)) + } + + override def update(input: InternalRow): Unit = { + val evaluatedExpr = expr.eval(input) + if (evaluatedExpr != null) { + val exprValue = Literal.create(evaluatedExpr, expr.dataType) + val prePartialAvg = partialAvg.copy() + partialCount += 1 + partialAvg.update(avgAddFunction(exprValue), input) + partialMk.update(mkAddFunction(exprValue, prePartialAvg), input) + } + } + + override def eval(input: InternalRow): Any = { + new GenericArrayData(Array(Cast(Literal(partialCount), computeType).eval(null), + partialAvg.eval(null), + partialMk.eval(null))) + } +} + +case class MergePartialStd( + child: Expression, + isSample: Boolean +) extends UnaryExpression with AggregateExpression1 { + def this() = this(null, false) // required for serialization + + override def children: Seq[Expression] = child:: Nil + override def nullable: Boolean = false + override def dataType: DataType = DoubleType + override def toString: String = s"MergePartialStd($child)" + override def newInstance(): MergePartialStdFunction = { + new MergePartialStdFunction(child, this, isSample) + } +} + +case class MergePartialStdFunction( + expr: Expression, + base: AggregateExpression1, + isSample: Boolean +) extends AggregateFunction1 { + def this() = this (null, null, false) // Required for serialization + + private val computeType = DoubleType + private val zero = Cast(Literal(0), computeType) + private val combineCount = MutableLiteral(zero.eval(null), computeType) + private val combineAvg = MutableLiteral(zero.eval(null), computeType) + private val combineMk = MutableLiteral(zero.eval(null), computeType) + + private def avgUpdateFunction(preCount: Expression, + partialCount: Expression, + partialAvg: Expression): Expression = { + Divide(Add(Multiply(combineAvg, preCount), + Multiply(partialAvg, partialCount)), + Add(preCount, partialCount)) + } + + override def update(input: InternalRow): Unit = { + val evaluatedExpr = expr.eval(input).asInstanceOf[ArrayData] + + if (evaluatedExpr != null) { + val exprValue = evaluatedExpr.toArray(computeType) + val (partialCount, partialAvg, partialMk) = + (Literal.create(exprValue(0), computeType), + 
Literal.create(exprValue(1), computeType), + Literal.create(exprValue(2), computeType)) + + if (Cast(partialCount, LongType).eval(null).asInstanceOf[Long] > 0) { + val preCount = combineCount.copy() + combineCount.update(Add(combineCount, partialCount), input) + + val preAvg = combineAvg.copy() + val avgDelta = Subtract(partialAvg, preAvg) + val mkDelta = Multiply(Multiply(avgDelta, avgDelta), + Divide(Multiply(preCount, partialCount), + combineCount)) + + // update average based on following formula + // (combineAvg * preCount + partialAvg * partialCount) / (preCount + partialCount) + combineAvg.update(avgUpdateFunction(preCount, partialCount, partialAvg), input) + + // update sum of square differences from mean based on following formula + // (combineMk + partialMk + (avgDelta * avgDelta) * (preCount * partialCount/combineCount) + combineMk.update(Add(combineMk, Add(partialMk, mkDelta)), input) + } + } + } + + override def eval(input: InternalRow): Any = { + val count: Long = Cast(combineCount, LongType).eval(null).asInstanceOf[Long] + + if (count == 0) null + else if (count < 2) zero.eval(null) + else { + // when total count > 2 + // stddev_samp = sqrt (combineMk/(combineCount -1)) + // stddev_pop = sqrt (combineMk/combineCount) + val varCol = { + if (isSample) { + Divide(combineMk, Cast(Literal(count - 1), computeType)) + } + else { + Divide(combineMk, Cast(Literal(count), computeType)) + } + } + Sqrt(varCol).eval(null) + } + } +} + +case class StddevFunction( + expr: Expression, + base: AggregateExpression1, + isSample: Boolean +) extends AggregateFunction1 { + + def this() = this(null, null, false) // Required for serialization + + private val computeType = DoubleType + private var curCount: Long = 0L + private val zero = Cast(Literal(0), computeType) + private val curAvg = MutableLiteral(zero.eval(null), computeType) + private val curMk = MutableLiteral(zero.eval(null), computeType) + + private def curAvgAddFunction(value: Literal): Expression = { + val delta = Subtract(Cast(value, computeType), curAvg) + Add(curAvg, Divide(delta, Cast(Literal(curCount), computeType))) + } + private def curMkAddFunction(value: Literal, preAvg: MutableLiteral): Expression = { + val delta1 = Subtract(Cast(value, computeType), preAvg) + val delta2 = Subtract(Cast(value, computeType), curAvg) + Add(curMk, Multiply(delta1, delta2)) + } + + override def update(input: InternalRow): Unit = { + val evaluatedExpr = expr.eval(input) + if (evaluatedExpr != null) { + val preAvg: MutableLiteral = curAvg.copy() + val exprValue = Literal.create(evaluatedExpr, expr.dataType) + curCount += 1L + curAvg.update(curAvgAddFunction(exprValue), input) + curMk.update(curMkAddFunction(exprValue, preAvg), input) + } + } + + override def eval(input: InternalRow): Any = { + if (curCount == 0) null + else if (curCount < 2) zero.eval(null) + else { + // when total count > 2, + // stddev_samp = sqrt(curMk/(curCount - 1)) + // stddev_pop = sqrt(curMk/curCount) + val varCol = { + if (isSample) { + Divide(curMk, Cast(Literal(curCount - 1), computeType)) + } + else { + Divide(curMk, Cast(Literal(curCount), computeType)) + } + } + Sqrt(varCol).eval(null) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 791c10c3d7ce..1a687b2374f1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -1288,15 +1288,11 @@ class DataFrame private[sql]( 
@scala.annotation.varargs def describe(cols: String*): DataFrame = { - // TODO: Add stddev as an expression, and remove it from here. - def stddevExpr(expr: Expression): Expression = - Sqrt(Subtract(Average(Multiply(expr, expr)), Multiply(Average(expr), Average(expr)))) - // The list of summary statistics to compute, in the form of expressions. val statistics = List[(String, Expression => Expression)]( "count" -> Count, "mean" -> Average, - "stddev" -> stddevExpr, + "stddev" -> Stddev, "min" -> Min, "max" -> Max) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala index ee31d83cce42..102b802ad0a0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala @@ -124,6 +124,9 @@ class GroupedData protected[sql]( case "avg" | "average" | "mean" => Average case "max" => Max case "min" => Min + case "stddev" => Stddev + case "stddev_pop" => StddevPop + case "stddev_samp" => StddevSamp case "sum" => Sum case "count" | "size" => // Turn count(*) into count(1) @@ -283,6 +286,42 @@ class GroupedData protected[sql]( aggregateNumericColumns(colNames : _*)(Min) } + /** + * Compute the sample standard deviation for each numeric columns for each group. + * The resulting [[DataFrame]] will also contain the grouping columns. + * When specified columns are given, only compute the stddev for them. + * + * @since 1.6.0 + */ + @scala.annotation.varargs + def stddev(colNames: String*): DataFrame = { + aggregateNumericColumns(colNames : _*)(Stddev) + } + + /** + * Compute the population standard deviation for each numeric columns for each group. + * The resulting [[DataFrame]] will also contain the grouping columns. + * When specified columns are given, only compute the stddev for them. + * + * @since 1.6.0 + */ + @scala.annotation.varargs + def stddev_pop(colNames: String*): DataFrame = { + aggregateNumericColumns(colNames : _*)(StddevPop) + } + + /** + * Compute the sample standard deviation for each numeric columns for each group. + * The resulting [[DataFrame]] will also contain the grouping columns. + * When specified columns are given, only compute the stddev for them. + * + * @since 1.6.0 + */ + @scala.annotation.varargs + def stddev_samp(colNames: String*): DataFrame = { + aggregateNumericColumns(colNames : _*)(StddevSamp) + } + /** * Compute the sum for each numeric columns for each group. * The resulting [[DataFrame]] will also contain the grouping columns. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 435e6319a64c..60d9c509104d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -294,6 +294,33 @@ object functions { */ def min(columnName: String): Column = min(Column(columnName)) + /** + * Aggregate function: returns the unbiased sample standard deviation + * of the expression in a group. + * + * @group agg_funcs + * @since 1.6.0 + */ + def stddev(e: Column): Column = Stddev(e.expr) + + /** + * Aggregate function: returns the population standard deviation of + * the expression in a group. + * + * @group agg_funcs + * @since 1.6.0 + */ + def stddev_pop(e: Column): Column = StddevPop(e.expr) + + /** + * Aggregate function: returns the unbiased sample standard deviation of + * the expression in a group. 
+ * + * @group agg_funcs + * @since 1.6.0 + */ + def stddev_samp(e: Column): Column = StddevSamp(e.expr) + /** * Aggregate function: returns the sum of all values in the expression. * diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index d981ce947f43..5f9abd4999ce 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -90,6 +90,7 @@ public void testVarargMethods() { df.groupBy().mean("key"); df.groupBy().max("key"); df.groupBy().min("key"); + df.groupBy().stddev("key"); df.groupBy().sum("key"); // Varargs in column expressions diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index c0950b09b14a..f5ef9ffd7f4f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -175,6 +175,39 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { Row(0, null)) } + test("stddev") { + val testData2ADev = math.sqrt(4/5.0) + + checkAnswer( + testData2.agg(stddev('a)), + Row(testData2ADev)) + + checkAnswer( + testData2.agg(stddev_pop('a)), + Row(math.sqrt(4/6.0))) + + checkAnswer( + testData2.agg(stddev_samp('a)), + Row(testData2ADev)) + } + + test("zero stddev") { + val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b") + assert(emptyTableData.count() == 0) + + checkAnswer( + emptyTableData.agg(stddev('a)), + Row(null)) + + checkAnswer( + emptyTableData.agg(stddev_pop('a)), + Row(null)) + + checkAnswer( + emptyTableData.agg(stddev_samp('a)), + Row(null)) + } + test("zero sum") { val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b") checkAnswer( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index dbed4fc24714..c167999af580 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -436,7 +436,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { val describeResult = Seq( Row("count", "4", "4"), Row("mean", "33.0", "178.0"), - Row("stddev", "16.583123951777", "10.0"), + Row("stddev", "19.148542155126762", "11.547005383792516"), Row("min", "16", "164"), Row("max", "60", "192")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 664b7a1512fa..962b100b532c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -328,6 +328,13 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { testCodeGen( "SELECT min(key) FROM testData3x", Row(1) :: Nil) + // STDDEV + testCodeGen( + "SELECT a, stddev(b), stddev_pop(b) FROM testData2 GROUP BY a", + (1 to 3).map(i => Row(i, math.sqrt(0.5), math.sqrt(0.25)))) + testCodeGen( + "SELECT stddev(b), stddev_pop(b), stddev_samp(b) FROM testData2", + Row(math.sqrt(1.5 / 5), math.sqrt(1.5 / 6), math.sqrt(1.5 / 5)) :: Nil) // Some combinations. 
testCodeGen( """ @@ -348,8 +355,8 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { Row(100, 1, 50.5, 300, 100) :: Nil) // Aggregate with Code generation handling all null values testCodeGen( - "SELECT sum('a'), avg('a'), count(null) FROM testData", - Row(null, null, 0) :: Nil) + "SELECT sum('a'), avg('a'), stddev('a'), count(null) FROM testData", + Row(null, null, null, 0) :: Nil) } finally { sqlContext.dropTempTable("testData3x") sqlContext.setConf(SQLConf.CODEGEN_ENABLED, originalValue) @@ -515,8 +522,8 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("aggregates with nulls") { checkAnswer( - sql("SELECT MIN(a), MAX(a), AVG(a), SUM(a), COUNT(a) FROM nullInts"), - Row(1, 3, 2, 6, 3) + sql("SELECT MIN(a), MAX(a), AVG(a), STDDEV(a), SUM(a), COUNT(a) FROM nullInts"), + Row(1, 3, 2, 1, 6, 3) ) } @@ -722,6 +729,33 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } } + test("stddev") { + checkAnswer( + sql("SELECT STDDEV(a) FROM testData2"), + Row(math.sqrt(4/5.0)) + ) + } + + test("stddev_pop") { + checkAnswer( + sql("SELECT STDDEV_POP(a) FROM testData2"), + Row(math.sqrt(4/6.0)) + ) + } + + test("stddev_samp") { + checkAnswer( + sql("SELECT STDDEV_SAMP(a) FROM testData2"), + Row(math.sqrt(4/5.0)) + ) + } + + test("stddev agg") { + checkAnswer( + sql("SELECT a, stddev(b), stddev_pop(b), stddev_samp(b) FROM testData2 GROUP BY a"), + (1 to 3).map(i => Row(i, math.sqrt(1/2.0), math.sqrt(1/4.0), math.sqrt(1/2.0)))) + } + test("inner join where, one match per row") { checkAnswer( sql("SELECT * FROM upperCaseData JOIN lowerCaseData WHERE n = N"), diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala index b126ec455fc6..a73b1bd52c09 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala @@ -507,41 +507,6 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te }.getMessage assert(errorMessage.contains("implemented based on the new Aggregate Function interface")) } - - // TODO: once we support Hive UDAF in the new interface, - // we can remove the following two tests. - withSQLConf("spark.sql.useAggregate2" -> "true") { - val errorMessage = intercept[AnalysisException] { - sqlContext.sql( - """ - |SELECT - | key, - | mydoublesum(value + 1.5 * key), - | stddev_samp(value) - |FROM agg1 - |GROUP BY key - """.stripMargin).collect() - }.getMessage - assert(errorMessage.contains("implemented based on the new Aggregate Function interface")) - - // This will fall back to the old aggregate - val newAggregateOperators = sqlContext.sql( - """ - |SELECT - | key, - | sum(value + 1.5 * key), - | stddev_samp(value) - |FROM agg1 - |GROUP BY key - """.stripMargin).queryExecution.executedPlan.collect { - case agg: aggregate.SortBasedAggregate => agg - case agg: aggregate.TungstenAggregate => agg - } - val message = - "We should fallback to the old aggregation code path if " + - "there is any aggregate function that cannot be converted to the new interface." 
- assert(newAggregateOperators.isEmpty, message) - } } } From b3a7480ab0821ab38f710de96e3ac4a13f62dbca Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sat, 12 Sep 2015 16:23:55 -0700 Subject: [PATCH 0607/1824] [SPARK-10330] Add Scalastyle rule to require use of SparkHadoopUtil JobContext methods This is a followup to #8499 which adds a Scalastyle rule to mandate the use of SparkHadoopUtil's JobContext accessor methods and fixes the existing violations. Author: Josh Rosen Closes #8521 from JoshRosen/SPARK-10330-part2. --- .../src/main/scala/org/apache/spark/SparkContext.scala | 6 +++--- .../org/apache/spark/deploy/SparkHadoopUtil.scala | 4 ++++ .../scala/org/apache/spark/rdd/PairRDDFunctions.scala | 8 +++++--- .../scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala | 2 +- core/src/test/scala/org/apache/spark/FileSuite.scala | 6 ++++-- .../org/apache/spark/examples/CassandraCQLTest.scala | 3 +++ .../org/apache/spark/examples/CassandraTest.scala | 2 ++ scalastyle-config.xml | 8 ++++++++ .../sql/execution/datasources/WriterContainer.scala | 8 ++++++-- .../sql/execution/datasources/json/JSONRelation.scala | 2 +- .../datasources/parquet/CatalystReadSupport.scala | 6 +++++- .../parquet/DirectParquetOutputCommitter.scala | 6 +++++- .../datasources/parquet/ParquetRelation.scala | 10 +++++++--- .../datasources/parquet/ParquetTypesConverter.scala | 6 +++++- .../org/apache/spark/sql/hive/orc/OrcRelation.scala | 4 ++-- 15 files changed, 61 insertions(+), 20 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index cbfe8bf31c3d..e27b3c496222 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -858,7 +858,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli // Use setInputPaths so that wholeTextFiles aligns with hadoopFile/textFile in taking // comma separated files as input. (see SPARK-7155) NewFileInputFormat.setInputPaths(job, path) - val updateConf = job.getConfiguration + val updateConf = SparkHadoopUtil.get.getConfigurationFromJobContext(job) new WholeTextFileRDD( this, classOf[WholeTextFileInputFormat], @@ -910,7 +910,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli // Use setInputPaths so that binaryFiles aligns with hadoopFile/textFile in taking // comma separated files as input. (see SPARK-7155) NewFileInputFormat.setInputPaths(job, path) - val updateConf = job.getConfiguration + val updateConf = SparkHadoopUtil.get.getConfigurationFromJobContext(job) new BinaryFileRDD( this, classOf[StreamInputFormat], @@ -1092,7 +1092,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli // Use setInputPaths so that newAPIHadoopFile aligns with hadoopFile/textFile in taking // comma separated files as input. 
(see SPARK-7155) NewFileInputFormat.setInputPaths(job, path) - val updatedConf = job.getConfiguration + val updatedConf = SparkHadoopUtil.get.getConfigurationFromJobContext(job) new NewHadoopRDD(this, fClass, kClass, vClass, updatedConf).setName(path) } diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala index f7723ef5bde4..a0b7365df900 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala @@ -192,7 +192,9 @@ class SparkHadoopUtil extends Logging { * while it's interface in Hadoop 2.+. */ def getConfigurationFromJobContext(context: JobContext): Configuration = { + // scalastyle:off jobconfig val method = context.getClass.getMethod("getConfiguration") + // scalastyle:on jobconfig method.invoke(context).asInstanceOf[Configuration] } @@ -204,7 +206,9 @@ class SparkHadoopUtil extends Logging { */ def getTaskAttemptIDFromTaskAttemptContext( context: MapReduceTaskAttemptContext): MapReduceTaskAttemptID = { + // scalastyle:off jobconfig val method = context.getClass.getMethod("getTaskAttemptID") + // scalastyle:on jobconfig method.invoke(context).asInstanceOf[MapReduceTaskAttemptID] } diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index c59f0d4aa75a..199d79b811d6 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -996,8 +996,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) job.setOutputKeyClass(keyClass) job.setOutputValueClass(valueClass) job.setOutputFormatClass(outputFormatClass) - job.getConfiguration.set("mapred.output.dir", path) - saveAsNewAPIHadoopDataset(job.getConfiguration) + val jobConfiguration = SparkHadoopUtil.get.getConfigurationFromJobContext(job) + jobConfiguration.set("mapred.output.dir", path) + saveAsNewAPIHadoopDataset(jobConfiguration) } /** @@ -1064,7 +1065,8 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) val formatter = new SimpleDateFormat("yyyyMMddHHmm") val jobtrackerID = formatter.format(new Date()) val stageId = self.id - val wrappedConf = new SerializableConfiguration(job.getConfiguration) + val jobConfiguration = SparkHadoopUtil.get.getConfigurationFromJobContext(job) + val wrappedConf = new SerializableConfiguration(jobConfiguration) val outfmt = job.getOutputFormatClass val jobFormat = outfmt.newInstance diff --git a/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala index 9babe56267e0..0228c54e0511 100644 --- a/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala @@ -86,7 +86,7 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag]( if (isDriverSide) { initDriverSideJobFuncOpt.map(f => f(job)) } - job.getConfiguration + SparkHadoopUtil.get.getConfigurationFromJobContext(job) } private val jobTrackerId: String = { diff --git a/core/src/test/scala/org/apache/spark/FileSuite.scala b/core/src/test/scala/org/apache/spark/FileSuite.scala index 418763f4e5ff..fdb00aafc4a4 100644 --- a/core/src/test/scala/org/apache/spark/FileSuite.scala +++ b/core/src/test/scala/org/apache/spark/FileSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark import java.io.{File, FileWriter} +import org.apache.spark.deploy.SparkHadoopUtil import 
org.apache.spark.input.PortableDataStream import org.apache.spark.storage.StorageLevel @@ -506,8 +507,9 @@ class FileSuite extends SparkFunSuite with LocalSparkContext { job.setOutputKeyClass(classOf[String]) job.setOutputValueClass(classOf[String]) job.setOutputFormatClass(classOf[NewTextOutputFormat[String, String]]) - job.getConfiguration.set("mapred.output.dir", tempDir.getPath + "/outputDataset_new") - randomRDD.saveAsNewAPIHadoopDataset(job.getConfiguration) + val jobConfig = SparkHadoopUtil.get.getConfigurationFromJobContext(job) + jobConfig.set("mapred.output.dir", tempDir.getPath + "/outputDataset_new") + randomRDD.saveAsNewAPIHadoopDataset(jobConfig) assert(new File(tempDir.getPath + "/outputDataset_new/part-r-00000").exists() === true) } diff --git a/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala b/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala index fa07c1e5017c..d1b9b8d398dd 100644 --- a/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala @@ -16,6 +16,7 @@ */ // scalastyle:off println + // scalastyle:off jobcontext package org.apache.spark.examples import java.nio.ByteBuffer @@ -81,6 +82,7 @@ object CassandraCQLTest { val job = new Job() job.setInputFormatClass(classOf[CqlPagingInputFormat]) + val configuration = job.getConfiguration ConfigHelper.setInputInitialAddress(job.getConfiguration(), cHost) ConfigHelper.setInputRpcPort(job.getConfiguration(), cPort) ConfigHelper.setInputColumnFamily(job.getConfiguration(), KeySpace, InputColumnFamily) @@ -135,3 +137,4 @@ object CassandraCQLTest { } } // scalastyle:on println +// scalastyle:on jobcontext diff --git a/examples/src/main/scala/org/apache/spark/examples/CassandraTest.scala b/examples/src/main/scala/org/apache/spark/examples/CassandraTest.scala index 2e56d24c60c3..1e679bfb5534 100644 --- a/examples/src/main/scala/org/apache/spark/examples/CassandraTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/CassandraTest.scala @@ -16,6 +16,7 @@ */ // scalastyle:off println +// scalastyle:off jobcontext package org.apache.spark.examples import java.nio.ByteBuffer @@ -130,6 +131,7 @@ object CassandraTest { } } // scalastyle:on println +// scalastyle:on jobcontext /* create keyspace casDemo; diff --git a/scalastyle-config.xml b/scalastyle-config.xml index 68fdb4141cf2..64a0c71bbef2 100644 --- a/scalastyle-config.xml +++ b/scalastyle-config.xml @@ -168,6 +168,14 @@ This file is divided into 3 sections: scala.collection.JavaConverters._ and use .asScala / .asJava methods + + + ^getConfiguration$|^getTaskAttemptID$ + Instead of calling .getConfiguration() or .getTaskAttemptID() directly, + use SparkHadoopUtil's getConfigurationFromJobContext() and getTaskAttemptIDFromTaskAttemptContext() methods. 
+ + + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala index 9a573db0c023..f8ef674ed29c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala @@ -47,7 +47,8 @@ private[sql] abstract class BaseWriterContainer( protected val dataSchema = relation.dataSchema - protected val serializableConf = new SerializableConfiguration(job.getConfiguration) + protected val serializableConf = + new SerializableConfiguration(SparkHadoopUtil.get.getConfigurationFromJobContext(job)) // This UUID is used to avoid output file name collision between different appending write jobs. // These jobs may belong to different SparkContext instances. Concrete data source implementations @@ -89,7 +90,8 @@ private[sql] abstract class BaseWriterContainer( // This UUID is sent to executor side together with the serialized `Configuration` object within // the `Job` instance. `OutputWriters` on the executor side should use this UUID to generate // unique task output files. - job.getConfiguration.set("spark.sql.sources.writeJobUUID", uniqueWriteJobId.toString) + SparkHadoopUtil.get.getConfigurationFromJobContext(job). + set("spark.sql.sources.writeJobUUID", uniqueWriteJobId.toString) // Order of the following two lines is important. For Hadoop 1, TaskAttemptContext constructor // clones the Configuration object passed in. If we initialize the TaskAttemptContext first, @@ -182,7 +184,9 @@ private[sql] abstract class BaseWriterContainer( private def setupIDs(jobId: Int, splitId: Int, attemptId: Int): Unit = { this.jobId = SparkHadoopWriter.createJobID(new Date, jobId) this.taskId = new TaskID(this.jobId, true, splitId) + // scalastyle:off jobcontext this.taskAttemptId = new TaskAttemptID(taskId, attemptId) + // scalastyle:on jobcontext } private def setupConf(): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala index 7a49157d9e72..8ee0127c3bde 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala @@ -81,7 +81,7 @@ private[sql] class JSONRelation( private def createBaseRdd(inputPaths: Array[FileStatus]): RDD[String] = { val job = new Job(sqlContext.sparkContext.hadoopConfiguration) - val conf = job.getConfiguration + val conf = SparkHadoopUtil.get.getConfigurationFromJobContext(job) val paths = inputPaths.map(_.getPath) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystReadSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystReadSupport.scala index 5a8166fac541..8c819f1a48cd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystReadSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystReadSupport.scala @@ -72,7 +72,11 @@ private[parquet] class CatalystReadSupport extends ReadSupport[InternalRow] with // Called before `prepareForRead()` when initializing Parquet record reader. 
override def init(context: InitContext): ReadContext = { - val conf = context.getConfiguration + val conf = { + // scalastyle:off jobcontext + context.getConfiguration + // scalastyle:on jobcontext + } // If the target file was written by Spark SQL, we should be able to find a serialized Catalyst // schema of this file from its metadata. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/DirectParquetOutputCommitter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/DirectParquetOutputCommitter.scala index 2c6b914328b6..de1fd0166ac5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/DirectParquetOutputCommitter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/DirectParquetOutputCommitter.scala @@ -53,7 +53,11 @@ private[parquet] class DirectParquetOutputCommitter(outputPath: Path, context: T override def setupTask(taskContext: TaskAttemptContext): Unit = {} override def commitJob(jobContext: JobContext) { - val configuration = ContextUtil.getConfiguration(jobContext) + val configuration = { + // scalastyle:off jobcontext + ContextUtil.getConfiguration(jobContext) + // scalastyle:on jobcontext + } val fileSystem = outputPath.getFileSystem(configuration) if (configuration.getBoolean(ParquetOutputFormat.ENABLE_JOB_SUMMARY, true)) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala index c6bbc392cad4..953fcab12697 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala @@ -211,7 +211,11 @@ private[sql] class ParquetRelation( override def sizeInBytes: Long = metadataCache.dataStatuses.map(_.getLen).sum override def prepareJobForWrite(job: Job): OutputWriterFactory = { - val conf = ContextUtil.getConfiguration(job) + val conf = { + // scalastyle:off jobcontext + ContextUtil.getConfiguration(job) + // scalastyle:on jobcontext + } // SPARK-9849 DirectParquetOutputCommitter qualified name should be backward compatible val committerClassname = conf.get(SQLConf.PARQUET_OUTPUT_COMMITTER_CLASS.key) @@ -528,7 +532,7 @@ private[sql] object ParquetRelation extends Logging { assumeBinaryIsString: Boolean, assumeInt96IsTimestamp: Boolean, followParquetFormatSpec: Boolean)(job: Job): Unit = { - val conf = job.getConfiguration + val conf = SparkHadoopUtil.get.getConfigurationFromJobContext(job) conf.set(ParquetInputFormat.READ_SUPPORT_CLASS, classOf[CatalystReadSupport].getName) // Try to push down filters when filter push-down is enabled. 
@@ -572,7 +576,7 @@ private[sql] object ParquetRelation extends Logging { FileInputFormat.setInputPaths(job, inputFiles.map(_.getPath): _*) } - overrideMinSplitSize(parquetBlockSize, job.getConfiguration) + overrideMinSplitSize(parquetBlockSize, SparkHadoopUtil.get.getConfigurationFromJobContext(job)) } private[parquet] def readSchema( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTypesConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTypesConverter.scala index 142301fe87cb..b647bb6116af 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTypesConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTypesConverter.scala @@ -123,7 +123,11 @@ private[parquet] object ParquetTypesConverter extends Logging { throw new IllegalArgumentException("Unable to read Parquet metadata: path is null") } val job = new Job() - val conf = configuration.getOrElse(ContextUtil.getConfiguration(job)) + val conf = { + // scalastyle:off jobcontext + configuration.getOrElse(ContextUtil.getConfiguration(job)) + // scalastyle:on jobcontext + } val fs: FileSystem = origPath.getFileSystem(conf) if (fs == null) { throw new IllegalArgumentException(s"Incorrectly formatted Parquet metadata path $origPath") diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala index 7e8910925995..d1f30e188eaf 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala @@ -208,7 +208,7 @@ private[sql] class OrcRelation( } override def prepareJobForWrite(job: Job): OutputWriterFactory = { - job.getConfiguration match { + SparkHadoopUtil.get.getConfigurationFromJobContext(job) match { case conf: JobConf => conf.setOutputFormat(classOf[OrcOutputFormat]) case conf => @@ -289,7 +289,7 @@ private[orc] case class OrcTableScan( def execute(): RDD[InternalRow] = { val job = new Job(sqlContext.sparkContext.hadoopConfiguration) - val conf = job.getConfiguration + val conf = SparkHadoopUtil.get.getConfigurationFromJobContext(job) // Tries to push down filters if ORC filter push-down is enabled if (sqlContext.conf.orcFilterPushDown) { From 1dc614b874badde0eee60def46fb47f608bc4759 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Sun, 13 Sep 2015 08:36:46 +0100 Subject: [PATCH 0608/1824] [SPARK-10222] [GRAPHX] [DOCS] More thoroughly deprecate Bagel in favor of GraphX Finish deprecating Bagel; remove reference to nonexistent example Author: Sean Owen Closes #8731 from srowen/SPARK-10222. 
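For anyone still on Bagel, GraphX's Pregel operator is the intended replacement. The sketch below is illustrative only and is not part of this patch; it follows the single-source shortest paths example from the GraphX programming guide and assumes a SparkContext named sc is in scope.

    import org.apache.spark.graphx._
    import org.apache.spark.graphx.util.GraphGenerators

    // Build a small random graph with Double edge weights (any Graph[VD, Double] works).
    val graph: Graph[Long, Double] =
      GraphGenerators.logNormalGraph(sc, numVertices = 100).mapEdges(e => e.attr.toDouble)
    val sourceId: VertexId = 42L

    // Initialize every vertex except the source to infinity.
    val initialGraph = graph.mapVertices((id, _) =>
      if (id == sourceId) 0.0 else Double.PositiveInfinity)

    val sssp = initialGraph.pregel(Double.PositiveInfinity)(
      (id, dist, newDist) => math.min(dist, newDist),   // vertex program: keep the shorter distance
      triplet => {                                      // send a message only along improving edges
        if (triplet.srcAttr + triplet.attr < triplet.dstAttr) {
          Iterator((triplet.dstId, triplet.srcAttr + triplet.attr))
        } else {
          Iterator.empty
        }
      },
      (a, b) => math.min(a, b)                          // merge incoming messages
    )
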
--- .../src/main/scala/org/apache/spark/bagel/Bagel.scala | 6 ++++++ docs/bagel-programming-guide.md | 10 +--------- docs/index.md | 1 - pom.xml | 2 +- 4 files changed, 8 insertions(+), 11 deletions(-) diff --git a/bagel/src/main/scala/org/apache/spark/bagel/Bagel.scala b/bagel/src/main/scala/org/apache/spark/bagel/Bagel.scala index 4e6b7686f771..8399033ac61e 100644 --- a/bagel/src/main/scala/org/apache/spark/bagel/Bagel.scala +++ b/bagel/src/main/scala/org/apache/spark/bagel/Bagel.scala @@ -22,6 +22,7 @@ import org.apache.spark.SparkContext._ import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel +@deprecated("Uses of Bagel should migrate to GraphX", "1.6.0") object Bagel extends Logging { val DEFAULT_STORAGE_LEVEL = StorageLevel.MEMORY_AND_DISK @@ -270,18 +271,21 @@ object Bagel extends Logging { } } +@deprecated("Uses of Bagel should migrate to GraphX", "1.6.0") trait Combiner[M, C] { def createCombiner(msg: M): C def mergeMsg(combiner: C, msg: M): C def mergeCombiners(a: C, b: C): C } +@deprecated("Uses of Bagel should migrate to GraphX", "1.6.0") trait Aggregator[V, A] { def createAggregator(vert: V): A def mergeAggregators(a: A, b: A): A } /** Default combiner that simply appends messages together (i.e. performs no aggregation) */ +@deprecated("Uses of Bagel should migrate to GraphX", "1.6.0") class DefaultCombiner[M: Manifest] extends Combiner[M, Array[M]] with Serializable { def createCombiner(msg: M): Array[M] = Array(msg) @@ -297,6 +301,7 @@ class DefaultCombiner[M: Manifest] extends Combiner[M, Array[M]] with Serializab * Subclasses may store state along with each vertex and must * inherit from java.io.Serializable or scala.Serializable. */ +@deprecated("Uses of Bagel should migrate to GraphX", "1.6.0") trait Vertex { def active: Boolean } @@ -307,6 +312,7 @@ trait Vertex { * Subclasses may contain a payload to deliver to the target vertex * and must inherit from java.io.Serializable or scala.Serializable. */ +@deprecated("Uses of Bagel should migrate to GraphX", "1.6.0") trait Message[K] { def targetId: K } diff --git a/docs/bagel-programming-guide.md b/docs/bagel-programming-guide.md index c2fe6b0e286c..347ca4a7af98 100644 --- a/docs/bagel-programming-guide.md +++ b/docs/bagel-programming-guide.md @@ -4,7 +4,7 @@ displayTitle: Bagel Programming Guide title: Bagel --- -**Bagel will soon be superseded by [GraphX](graphx-programming-guide.html); we recommend that new users try GraphX instead.** +**Bagel is deprecated, and superseded by [GraphX](graphx-programming-guide.html).** Bagel is a Spark implementation of Google's [Pregel](http://portal.acm.org/citation.cfm?id=1807184) graph processing framework. Bagel currently supports basic graph computation, combiners, and aggregators. @@ -157,11 +157,3 @@ trait Message[K] { def targetId: K } {% endhighlight %} - -# Where to Go from Here - -Two example jobs, PageRank and shortest path, are included in `examples/src/main/scala/org/apache/spark/examples/bagel`. You can run them by passing the class name to the `bin/run-example` script included in Spark; e.g.: - - ./bin/run-example org.apache.spark.examples.bagel.WikipediaPageRank - -Each example program prints usage help when run without any arguments. 
diff --git a/docs/index.md b/docs/index.md index d85cf12defef..c0dc2b8d7412 100644 --- a/docs/index.md +++ b/docs/index.md @@ -90,7 +90,6 @@ options for deployment: * [Spark SQL and DataFrames](sql-programming-guide.html): support for structured data and relational queries * [MLlib](mllib-guide.html): built-in machine learning library * [GraphX](graphx-programming-guide.html): Spark's new API for graph processing - * [Bagel (Pregel on Spark)](bagel-programming-guide.html): older, simple graph processing model **API Docs:** diff --git a/pom.xml b/pom.xml index 88ebceca769e..421357e14157 100644 --- a/pom.xml +++ b/pom.xml @@ -87,7 +87,7 @@ core - bagel + bagel graphx mllib tools From d81565465cc6d4f38b4ed78036cded630c700388 Mon Sep 17 00:00:00 2001 From: Bertrand Dechoux Date: Mon, 14 Sep 2015 09:18:46 +0100 Subject: [PATCH 0609/1824] [SPARK-9720] [ML] Identifiable types need UID in toString methods A few Identifiable types did override their toString method but without using the parent implementation. As a consequence, the uid was not present anymore in the toString result. It is the default behaviour. This patch is a quick fix. The question of enforcement is still up. No tests have been written to verify the toString method behaviour. That would be long to do because all types should be tested and not only those which have a regression now. It is possible to enforce the condition using the compiler by making the toString method final but that would introduce unwanted potential API breaking changes (see jira). Author: Bertrand Dechoux Closes #8062 from BertrandDechoux/SPARK-9720. --- .../spark/ml/classification/DecisionTreeClassifier.scala | 2 +- .../org/apache/spark/ml/classification/GBTClassifier.scala | 2 +- .../scala/org/apache/spark/ml/classification/NaiveBayes.scala | 2 +- .../spark/ml/classification/RandomForestClassifier.scala | 2 +- .../src/main/scala/org/apache/spark/ml/feature/RFormula.scala | 4 ++-- .../apache/spark/ml/regression/DecisionTreeRegressor.scala | 2 +- .../scala/org/apache/spark/ml/regression/GBTRegressor.scala | 2 +- .../apache/spark/ml/regression/RandomForestRegressor.scala | 2 +- 8 files changed, 9 insertions(+), 9 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala index 0a75d5d22280..b8eb49f9bdb4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala @@ -146,7 +146,7 @@ final class DecisionTreeClassificationModel private[ml] ( } override def toString: String = { - s"DecisionTreeClassificationModel of depth $depth with $numNodes nodes" + s"DecisionTreeClassificationModel (uid=$uid) of depth $depth with $numNodes nodes" } /** (private[ml]) Convert to a model in the old API */ diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index 3073a2a61ce8..ad8683648b97 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -200,7 +200,7 @@ final class GBTClassificationModel( } override def toString: String = { - s"GBTClassificationModel with $numTrees trees" + s"GBTClassificationModel (uid=$uid) with $numTrees trees" } /** (private[ml]) Convert to a model in the old API 
*/ diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala index 69cb88a7e671..082ea1ffad58 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala @@ -198,7 +198,7 @@ class NaiveBayesModel private[ml] ( } override def toString: String = { - s"NaiveBayesModel with ${pi.size} classes" + s"NaiveBayesModel (uid=$uid) with ${pi.size} classes" } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index 11a6d7246833..a6ebee1bb10a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -193,7 +193,7 @@ final class RandomForestClassificationModel private[ml] ( } override def toString: String = { - s"RandomForestClassificationModel with $numTrees trees" + s"RandomForestClassificationModel (uid=$uid) with $numTrees trees" } /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala index a7fa50444209..dcd6fe3c406a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala @@ -129,7 +129,7 @@ class RFormula(override val uid: String) extends Estimator[RFormulaModel] with R override def copy(extra: ParamMap): RFormula = defaultCopy(extra) - override def toString: String = s"RFormula(${get(formula)})" + override def toString: String = s"RFormula(${get(formula)}) (uid=$uid)" } /** @@ -171,7 +171,7 @@ class RFormulaModel private[feature]( override def copy(extra: ParamMap): RFormulaModel = copyValues( new RFormulaModel(uid, resolvedFormula, pipelineModel)) - override def toString: String = s"RFormulaModel(${resolvedFormula})" + override def toString: String = s"RFormulaModel(${resolvedFormula}) (uid=$uid)" private def transformLabel(dataset: DataFrame): DataFrame = { val labelName = resolvedFormula.label diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala index a2bcd67401d0..d9a244bea28d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala @@ -118,7 +118,7 @@ final class DecisionTreeRegressionModel private[ml] ( } override def toString: String = { - s"DecisionTreeRegressionModel of depth $depth with $numNodes nodes" + s"DecisionTreeRegressionModel (uid=$uid) of depth $depth with $numNodes nodes" } /** Convert to a model in the old API */ diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala index b66e61f37dd5..d841ecb9e58d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala @@ -189,7 +189,7 @@ final class GBTRegressionModel( } override def toString: String = { - s"GBTRegressionModel with $numTrees trees" + s"GBTRegressionModel (uid=$uid) with $numTrees trees" } /** (private[ml]) Convert to a model in the 
old API */ diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala index 2f36da371f57..ddb7214416a6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala @@ -155,7 +155,7 @@ final class RandomForestRegressionModel private[ml] ( } override def toString: String = { - s"RandomForestRegressionModel with $numTrees trees" + s"RandomForestRegressionModel (uid=$uid) with $numTrees trees" } /** From 32407bfd2bdbf84d65cacfa7554dae6a2332bc37 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 14 Sep 2015 11:51:39 -0700 Subject: [PATCH 0610/1824] [SPARK-9899] [SQL] log warning for direct output committer with speculation enabled This is a follow-up of https://github.com/apache/spark/pull/8317. When speculation is enabled, there may be multiply tasks writing to the same path. Generally it's OK as we will write to a temporary directory first and only one task can commit the temporary directory to target path. However, when we use direct output committer, tasks will write data to target path directly without temporary directory. This causes problems like corrupted data. Please see [PR comment](https://github.com/apache/spark/pull/8191#issuecomment-131598385) for more details. Unfortunately, we don't have a simple flag to tell if a output committer will write to temporary directory or not, so for safety, we have to disable any customized output committer when `speculation` is true. Author: Wenchen Fan Closes #8687 from cloud-fan/direct-committer. --- .../apache/spark/rdd/PairRDDFunctions.scala | 44 ++++++++++++++++--- .../hive/execution/InsertIntoHiveTable.scala | 17 ++++++- .../spark/sql/hive/hiveWriterContainers.scala | 1 - 3 files changed, 53 insertions(+), 9 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index 199d79b811d6..a981b63942e6 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -1018,6 +1018,11 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) /** * Output the RDD to any Hadoop-supported file system, using a Hadoop `OutputFormat` class * supporting the key and value types K and V in this RDD. + * + * Note that, we should make sure our tasks are idempotent when speculation is enabled, i.e. do + * not use output committer that writes data directly. + * There is an example in https://issues.apache.org/jira/browse/SPARK-10063 to show the bad + * result of using direct output committer with speculation enabled. */ def saveAsHadoopFile( path: String, @@ -1030,10 +1035,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) val hadoopConf = conf hadoopConf.setOutputKeyClass(keyClass) hadoopConf.setOutputValueClass(valueClass) - // Doesn't work in Scala 2.9 due to what may be a generics bug - // TODO: Should we uncomment this for Scala 2.10? 
- // conf.setOutputFormat(outputFormatClass) - hadoopConf.set("mapred.output.format.class", outputFormatClass.getName) + conf.setOutputFormat(outputFormatClass) for (c <- codec) { hadoopConf.setCompressMapOutput(true) hadoopConf.set("mapred.output.compress", "true") @@ -1047,6 +1049,19 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) hadoopConf.setOutputCommitter(classOf[FileOutputCommitter]) } + // When speculation is on and output committer class name contains "Direct", we should warn + // users that they may loss data if they are using a direct output committer. + val speculationEnabled = self.conf.getBoolean("spark.speculation", false) + val outputCommitterClass = hadoopConf.get("mapred.output.committer.class", "") + if (speculationEnabled && outputCommitterClass.contains("Direct")) { + val warningMessage = + s"$outputCommitterClass may be an output committer that writes data directly to " + + "the final location. Because speculation is enabled, this output committer may " + + "cause data loss (see the case in SPARK-10063). If possible, please use a output " + + "committer that does not have this behavior (e.g. FileOutputCommitter)." + logWarning(warningMessage) + } + FileOutputFormat.setOutputPath(hadoopConf, SparkHadoopWriter.createPathFromString(path, hadoopConf)) saveAsHadoopDataset(hadoopConf) @@ -1057,6 +1072,11 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * Configuration object for that storage system. The Conf should set an OutputFormat and any * output paths required (e.g. a table name to write to) in the same way as it would be * configured for a Hadoop MapReduce job. + * + * Note that, we should make sure our tasks are idempotent when speculation is enabled, i.e. do + * not use output committer that writes data directly. + * There is an example in https://issues.apache.org/jira/browse/SPARK-10063 to show the bad + * result of using direct output committer with speculation enabled. */ def saveAsNewAPIHadoopDataset(conf: Configuration): Unit = self.withScope { // Rename this as hadoopConf internally to avoid shadowing (see SPARK-2038). @@ -1115,6 +1135,20 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) val jobAttemptId = newTaskAttemptID(jobtrackerID, stageId, isMap = true, 0, 0) val jobTaskContext = newTaskAttemptContext(wrappedConf.value, jobAttemptId) val jobCommitter = jobFormat.getOutputCommitter(jobTaskContext) + + // When speculation is on and output committer class name contains "Direct", we should warn + // users that they may loss data if they are using a direct output committer. + val speculationEnabled = self.conf.getBoolean("spark.speculation", false) + val outputCommitterClass = jobCommitter.getClass.getSimpleName + if (speculationEnabled && outputCommitterClass.contains("Direct")) { + val warningMessage = + s"$outputCommitterClass may be an output committer that writes data directly to " + + "the final location. Because speculation is enabled, this output committer may " + + "cause data loss (see the case in SPARK-10063). If possible, please use a output " + + "committer that does not have this behavior (e.g. FileOutputCommitter)." + logWarning(warningMessage) + } + jobCommitter.setupJob(jobTaskContext) self.context.runJob(self, writeShard) jobCommitter.commitJob(jobTaskContext) @@ -1129,7 +1163,6 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) def saveAsHadoopDataset(conf: JobConf): Unit = self.withScope { // Rename this as hadoopConf internally to avoid shadowing (see SPARK-2038). 
val hadoopConf = conf - val wrappedConf = new SerializableConfiguration(hadoopConf) val outputFormatInstance = hadoopConf.getOutputFormat val keyClass = hadoopConf.getOutputKeyClass val valueClass = hadoopConf.getOutputValueClass @@ -1157,7 +1190,6 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) writer.preSetup() val writeToFile = (context: TaskContext, iter: Iterator[(K, V)]) => { - val config = wrappedConf.value // Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it // around by taking a mod. We expect that no task will be attempted 2 billion times. val taskAttemptId = (context.taskAttemptId % Int.MaxValue).toInt diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index 58f7fa640e8a..0c700bdb370a 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -28,7 +28,7 @@ import org.apache.hadoop.hive.ql.{Context, ErrorMsg} import org.apache.hadoop.hive.serde2.Serializer import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.ObjectInspectorCopyOption import org.apache.hadoop.hive.serde2.objectinspector._ -import org.apache.hadoop.mapred.{FileOutputFormat, JobConf} +import org.apache.hadoop.mapred.{FileOutputCommitter, FileOutputFormat, JobConf} import org.apache.spark.rdd.RDD import org.apache.spark.sql.Row @@ -62,7 +62,7 @@ case class InsertIntoHiveTable( def output: Seq[Attribute] = Seq.empty - def saveAsHiveFile( + private def saveAsHiveFile( rdd: RDD[InternalRow], valueClass: Class[_], fileSinkConf: FileSinkDesc, @@ -178,6 +178,19 @@ case class InsertIntoHiveTable( val jobConf = new JobConf(sc.hiveconf) val jobConfSer = new SerializableJobConf(jobConf) + // When speculation is on and output committer class name contains "Direct", we should warn + // users that they may loss data if they are using a direct output committer. + val speculationEnabled = sqlContext.sparkContext.conf.getBoolean("spark.speculation", false) + val outputCommitterClass = jobConf.get("mapred.output.committer.class", "") + if (speculationEnabled && outputCommitterClass.contains("Direct")) { + val warningMessage = + s"$outputCommitterClass may be an output committer that writes data directly to " + + "the final location. Because speculation is enabled, this output committer may " + + "cause data loss (see the case in SPARK-10063). If possible, please use a output " + + "committer that does not have this behavior (e.g. FileOutputCommitter)." 
+ logWarning(warningMessage) + } + val writerContainer = if (numDynamicPartitions > 0) { val dynamicPartColNames = partitionColumnNames.takeRight(numDynamicPartitions) new SparkHiveDynamicPartitionWriterContainer(jobConf, fileSinkConf, dynamicPartColNames) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala index 29a6f08f4072..4ca8042d2236 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala @@ -32,7 +32,6 @@ import org.apache.hadoop.mapred._ import org.apache.hadoop.hive.common.FileUtils import org.apache.spark.mapred.SparkHadoopMapRedUtil -import org.apache.spark.sql.Row import org.apache.spark.{Logging, SerializableWritable, SparkHadoopWriter} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.DateTimeUtils From cf2821ef5fd9965eb6256e8e8b3f1e00c0788098 Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Mon, 14 Sep 2015 12:06:23 -0700 Subject: [PATCH 0611/1824] [SPARK-10584] [DOC] [SQL] Documentation about spark.sql.hive.metastore.version is wrong. The default value of hive metastore version is 1.2.1 but the documentation says the value of `spark.sql.hive.metastore.version` is 0.13.1. Also, we cannot get the default value by `sqlContext.getConf("spark.sql.hive.metastore.version")`. Author: Kousuke Saruta Closes #8739 from sarutak/SPARK-10584. --- docs/sql-programming-guide.md | 2 +- .../scala/org/apache/spark/sql/hive/HiveContext.scala | 11 +++++++---- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 6a1b0fbfa1eb..a0b911d20724 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1687,7 +1687,7 @@ The following options can be used to configure the version of Hive that is used Property NameDefaultMeaning spark.sql.hive.metastore.version - 0.13.1 + 1.2.1 Version of the Hive metastore. Available options are 0.12.0 through 1.2.1. diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 2e791cea96b4..d37ba5ddc2d8 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -111,8 +111,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging { * this does not necessarily need to be the same version of Hive that is used internally by * Spark SQL for execution. */ - protected[hive] def hiveMetastoreVersion: String = - getConf(HIVE_METASTORE_VERSION, hiveExecutionVersion) + protected[hive] def hiveMetastoreVersion: String = getConf(HIVE_METASTORE_VERSION) /** * The location of the jars that should be used to instantiate the HiveMetastoreClient. This @@ -202,7 +201,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging { "Builtin jars can only be used when hive execution version == hive metastore version. " + s"Execution: ${hiveExecutionVersion} != Metastore: ${hiveMetastoreVersion}. 
" + "Specify a vaild path to the correct hive jars using $HIVE_METASTORE_JARS " + - s"or change $HIVE_METASTORE_VERSION to $hiveExecutionVersion.") + s"or change ${HIVE_METASTORE_VERSION.key} to $hiveExecutionVersion.") } // We recursively find all jars in the class loader chain, @@ -606,7 +605,11 @@ private[hive] object HiveContext { /** The version of hive used internally by Spark SQL. */ val hiveExecutionVersion: String = "1.2.1" - val HIVE_METASTORE_VERSION: String = "spark.sql.hive.metastore.version" + val HIVE_METASTORE_VERSION = stringConf("spark.sql.hive.metastore.version", + defaultValue = Some(hiveExecutionVersion), + doc = "Version of the Hive metastore. Available options are " + + s"0.12.0 through $hiveExecutionVersion.") + val HIVE_METASTORE_JARS = stringConf("spark.sql.hive.metastore.jars", defaultValue = Some("builtin"), doc = s""" From ce6f3f163bc667cb5da9ab4331c8bad10cc0d701 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Mon, 14 Sep 2015 12:08:52 -0700 Subject: [PATCH 0612/1824] [SPARK-10194] [MLLIB] [PYSPARK] SGD algorithms need convergenceTol parameter in Python [SPARK-3382](https://issues.apache.org/jira/browse/SPARK-3382) added a ```convergenceTol``` parameter for GradientDescent-based methods in Scala. We need that parameter in Python; otherwise, Python users will not be able to adjust that behavior (or even reproduce behavior from previous releases since the default changed). Author: Yanbo Liang Closes #8457 from yanboliang/spark-10194. --- .../mllib/api/python/PythonMLLibAPI.scala | 20 +++++++++--- python/pyspark/mllib/classification.py | 17 +++++++--- python/pyspark/mllib/regression.py | 32 ++++++++++++------- 3 files changed, 48 insertions(+), 21 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index f585aacd452e..69ce7f50709a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -132,7 +132,8 @@ private[python] class PythonMLLibAPI extends Serializable { regParam: Double, regType: String, intercept: Boolean, - validateData: Boolean): JList[Object] = { + validateData: Boolean, + convergenceTol: Double): JList[Object] = { val lrAlg = new LinearRegressionWithSGD() lrAlg.setIntercept(intercept) .setValidateData(validateData) @@ -141,6 +142,7 @@ private[python] class PythonMLLibAPI extends Serializable { .setRegParam(regParam) .setStepSize(stepSize) .setMiniBatchFraction(miniBatchFraction) + .setConvergenceTol(convergenceTol) lrAlg.optimizer.setUpdater(getUpdaterFromString(regType)) trainRegressionModel( lrAlg, @@ -159,7 +161,8 @@ private[python] class PythonMLLibAPI extends Serializable { miniBatchFraction: Double, initialWeights: Vector, intercept: Boolean, - validateData: Boolean): JList[Object] = { + validateData: Boolean, + convergenceTol: Double): JList[Object] = { val lassoAlg = new LassoWithSGD() lassoAlg.setIntercept(intercept) .setValidateData(validateData) @@ -168,6 +171,7 @@ private[python] class PythonMLLibAPI extends Serializable { .setRegParam(regParam) .setStepSize(stepSize) .setMiniBatchFraction(miniBatchFraction) + .setConvergenceTol(convergenceTol) trainRegressionModel( lassoAlg, data, @@ -185,7 +189,8 @@ private[python] class PythonMLLibAPI extends Serializable { miniBatchFraction: Double, initialWeights: Vector, intercept: Boolean, - validateData: Boolean): JList[Object] = { + validateData: Boolean, 
+ convergenceTol: Double): JList[Object] = { val ridgeAlg = new RidgeRegressionWithSGD() ridgeAlg.setIntercept(intercept) .setValidateData(validateData) @@ -194,6 +199,7 @@ private[python] class PythonMLLibAPI extends Serializable { .setRegParam(regParam) .setStepSize(stepSize) .setMiniBatchFraction(miniBatchFraction) + .setConvergenceTol(convergenceTol) trainRegressionModel( ridgeAlg, data, @@ -212,7 +218,8 @@ private[python] class PythonMLLibAPI extends Serializable { initialWeights: Vector, regType: String, intercept: Boolean, - validateData: Boolean): JList[Object] = { + validateData: Boolean, + convergenceTol: Double): JList[Object] = { val SVMAlg = new SVMWithSGD() SVMAlg.setIntercept(intercept) .setValidateData(validateData) @@ -221,6 +228,7 @@ private[python] class PythonMLLibAPI extends Serializable { .setRegParam(regParam) .setStepSize(stepSize) .setMiniBatchFraction(miniBatchFraction) + .setConvergenceTol(convergenceTol) SVMAlg.optimizer.setUpdater(getUpdaterFromString(regType)) trainRegressionModel( SVMAlg, @@ -240,7 +248,8 @@ private[python] class PythonMLLibAPI extends Serializable { regParam: Double, regType: String, intercept: Boolean, - validateData: Boolean): JList[Object] = { + validateData: Boolean, + convergenceTol: Double): JList[Object] = { val LogRegAlg = new LogisticRegressionWithSGD() LogRegAlg.setIntercept(intercept) .setValidateData(validateData) @@ -249,6 +258,7 @@ private[python] class PythonMLLibAPI extends Serializable { .setRegParam(regParam) .setStepSize(stepSize) .setMiniBatchFraction(miniBatchFraction) + .setConvergenceTol(convergenceTol) LogRegAlg.optimizer.setUpdater(getUpdaterFromString(regType)) trainRegressionModel( LogRegAlg, diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py index 8f27c446a66e..cb4ee8367808 100644 --- a/python/pyspark/mllib/classification.py +++ b/python/pyspark/mllib/classification.py @@ -241,7 +241,7 @@ class LogisticRegressionWithSGD(object): @classmethod def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, initialWeights=None, regParam=0.01, regType="l2", intercept=False, - validateData=True): + validateData=True, convergenceTol=0.001): """ Train a logistic regression model on the given data. @@ -274,11 +274,13 @@ def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, :param validateData: Boolean parameter which indicates if the algorithm should validate data before training. (default: True) + :param convergenceTol: A condition which decides iteration termination. + (default: 0.001) """ def train(rdd, i): return callMLlibFunc("trainLogisticRegressionModelWithSGD", rdd, int(iterations), float(step), float(miniBatchFraction), i, float(regParam), regType, - bool(intercept), bool(validateData)) + bool(intercept), bool(validateData), float(convergenceTol)) return _regression_train_wrapper(train, LogisticRegressionModel, data, initialWeights) @@ -439,7 +441,7 @@ class SVMWithSGD(object): @classmethod def train(cls, data, iterations=100, step=1.0, regParam=0.01, miniBatchFraction=1.0, initialWeights=None, regType="l2", - intercept=False, validateData=True): + intercept=False, validateData=True, convergenceTol=0.001): """ Train a support vector machine on the given data. @@ -472,11 +474,13 @@ def train(cls, data, iterations=100, step=1.0, regParam=0.01, :param validateData: Boolean parameter which indicates if the algorithm should validate data before training. (default: True) + :param convergenceTol: A condition which decides iteration termination. 
+ (default: 0.001) """ def train(rdd, i): return callMLlibFunc("trainSVMModelWithSGD", rdd, int(iterations), float(step), float(regParam), float(miniBatchFraction), i, regType, - bool(intercept), bool(validateData)) + bool(intercept), bool(validateData), float(convergenceTol)) return _regression_train_wrapper(train, SVMModel, data, initialWeights) @@ -600,12 +604,15 @@ class StreamingLogisticRegressionWithSGD(StreamingLinearAlgorithm): :param miniBatchFraction: Fraction of data on which SGD is run for each iteration. :param regParam: L2 Regularization parameter. + :param convergenceTol: A condition which decides iteration termination. """ - def __init__(self, stepSize=0.1, numIterations=50, miniBatchFraction=1.0, regParam=0.01): + def __init__(self, stepSize=0.1, numIterations=50, miniBatchFraction=1.0, regParam=0.01, + convergenceTol=0.001): self.stepSize = stepSize self.numIterations = numIterations self.regParam = regParam self.miniBatchFraction = miniBatchFraction + self.convergenceTol = convergenceTol self._model = None super(StreamingLogisticRegressionWithSGD, self).__init__( model=self._model) diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py index 41946e3674fb..256b7537fef6 100644 --- a/python/pyspark/mllib/regression.py +++ b/python/pyspark/mllib/regression.py @@ -28,7 +28,8 @@ 'LinearRegressionModel', 'LinearRegressionWithSGD', 'RidgeRegressionModel', 'RidgeRegressionWithSGD', 'LassoModel', 'LassoWithSGD', 'IsotonicRegressionModel', - 'IsotonicRegression'] + 'IsotonicRegression', 'StreamingLinearAlgorithm', + 'StreamingLinearRegressionWithSGD'] class LabeledPoint(object): @@ -202,7 +203,7 @@ class LinearRegressionWithSGD(object): @classmethod def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, initialWeights=None, regParam=0.0, regType=None, intercept=False, - validateData=True): + validateData=True, convergenceTol=0.001): """ Train a linear regression model using Stochastic Gradient Descent (SGD). @@ -244,11 +245,14 @@ def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, :param validateData: Boolean parameter which indicates if the algorithm should validate data before training. (default: True) + :param convergenceTol: A condition which decides iteration termination. + (default: 0.001) """ def train(rdd, i): return callMLlibFunc("trainLinearRegressionModelWithSGD", rdd, int(iterations), float(step), float(miniBatchFraction), i, float(regParam), - regType, bool(intercept), bool(validateData)) + regType, bool(intercept), bool(validateData), + float(convergenceTol)) return _regression_train_wrapper(train, LinearRegressionModel, data, initialWeights) @@ -330,7 +334,7 @@ class LassoWithSGD(object): @classmethod def train(cls, data, iterations=100, step=1.0, regParam=0.01, miniBatchFraction=1.0, initialWeights=None, intercept=False, - validateData=True): + validateData=True, convergenceTol=0.001): """ Train a regression model with L1-regularization using Stochastic Gradient Descent. @@ -362,11 +366,13 @@ def train(cls, data, iterations=100, step=1.0, regParam=0.01, :param validateData: Boolean parameter which indicates if the algorithm should validate data before training. (default: True) + :param convergenceTol: A condition which decides iteration termination. 
+ (default: 0.001) """ def train(rdd, i): return callMLlibFunc("trainLassoModelWithSGD", rdd, int(iterations), float(step), float(regParam), float(miniBatchFraction), i, bool(intercept), - bool(validateData)) + bool(validateData), float(convergenceTol)) return _regression_train_wrapper(train, LassoModel, data, initialWeights) @@ -449,7 +455,7 @@ class RidgeRegressionWithSGD(object): @classmethod def train(cls, data, iterations=100, step=1.0, regParam=0.01, miniBatchFraction=1.0, initialWeights=None, intercept=False, - validateData=True): + validateData=True, convergenceTol=0.001): """ Train a regression model with L2-regularization using Stochastic Gradient Descent. @@ -481,11 +487,13 @@ def train(cls, data, iterations=100, step=1.0, regParam=0.01, :param validateData: Boolean parameter which indicates if the algorithm should validate data before training. (default: True) + :param convergenceTol: A condition which decides iteration termination. + (default: 0.001) """ def train(rdd, i): return callMLlibFunc("trainRidgeModelWithSGD", rdd, int(iterations), float(step), float(regParam), float(miniBatchFraction), i, bool(intercept), - bool(validateData)) + bool(validateData), float(convergenceTol)) return _regression_train_wrapper(train, RidgeRegressionModel, data, initialWeights) @@ -636,15 +644,17 @@ class StreamingLinearRegressionWithSGD(StreamingLinearAlgorithm): After training on a batch of data, the weights obtained at the end of training are used as initial weights for the next batch. - :param: stepSize Step size for each iteration of gradient descent. - :param: numIterations Total number of iterations run. - :param: miniBatchFraction Fraction of data on which SGD is run for each + :param stepSize: Step size for each iteration of gradient descent. + :param numIterations: Total number of iterations run. + :param miniBatchFraction: Fraction of data on which SGD is run for each iteration. + :param convergenceTol: A condition which decides iteration termination. """ - def __init__(self, stepSize=0.1, numIterations=50, miniBatchFraction=1.0): + def __init__(self, stepSize=0.1, numIterations=50, miniBatchFraction=1.0, convergenceTol=0.001): self.stepSize = stepSize self.numIterations = numIterations self.miniBatchFraction = miniBatchFraction + self.convergenceTol = convergenceTol self._model = None super(StreamingLinearRegressionWithSGD, self).__init__( model=self._model) From 8a634e9bcc671167613fb575c6c0c054fb4b3479 Mon Sep 17 00:00:00 2001 From: Nick Pritchard Date: Mon, 14 Sep 2015 13:27:45 -0700 Subject: [PATCH 0613/1824] [SPARK-10573] [ML] IndexToString output schema should be StringType Fixes bug where IndexToString output schema was DoubleType. Correct me if I'm wrong, but it doesn't seem like the output needs to have any "ML Attribute" metadata. Author: Nick Pritchard Closes #8751 from pnpritchard/SPARK-10573. 
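Illustrative usage, not part of the patch: a minimal round trip through StringIndexer and IndexToString, assuming a SQLContext named sqlContext and a toy data set. With this fix the converted column is declared as StringType instead of DoubleType, matching the string values it actually produces.

    import org.apache.spark.ml.feature.{IndexToString, StringIndexer}
    import org.apache.spark.sql.types.StringType

    // Hypothetical toy data set.
    val df = sqlContext.createDataFrame(
      Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"))).toDF("id", "category")

    val indexed = new StringIndexer()
      .setInputCol("category")
      .setOutputCol("categoryIndex")
      .fit(df)
      .transform(df)

    val converted = new IndexToString()
      .setInputCol("categoryIndex")
      .setOutputCol("originalCategory")
      .transform(indexed)

    // The declared schema now agrees with the produced values.
    assert(converted.schema("originalCategory").dataType == StringType)
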
--- .../scala/org/apache/spark/ml/feature/StringIndexer.scala | 5 ++--- .../org/apache/spark/ml/feature/StringIndexerSuite.scala | 8 ++++++++ 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index 3a4ab9a85764..2b1592930e77 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -27,7 +27,7 @@ import org.apache.spark.ml.Transformer import org.apache.spark.ml.util.Identifiable import org.apache.spark.sql.DataFrame import org.apache.spark.sql.functions._ -import org.apache.spark.sql.types.{DoubleType, NumericType, StringType, StructType} +import org.apache.spark.sql.types._ import org.apache.spark.util.collection.OpenHashMap /** @@ -229,8 +229,7 @@ class IndexToString private[ml] ( val outputColName = $(outputCol) require(inputFields.forall(_.name != outputColName), s"Output column $outputColName already exists.") - val attr = NominalAttribute.defaultAttr.withName($(outputCol)) - val outputFields = inputFields :+ attr.toStructField() + val outputFields = inputFields :+ StructField($(outputCol), StringType) StructType(outputFields) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala index 05e05bdc64bb..ddcdb5f4212b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.ml.feature +import org.apache.spark.sql.types.{StringType, StructType, StructField, DoubleType} import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.attribute.{Attribute, NominalAttribute} import org.apache.spark.ml.param.ParamsSuite @@ -165,4 +166,11 @@ class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext { assert(a === b) } } + + test("IndexToString.transformSchema (SPARK-10573)") { + val idxToStr = new IndexToString().setInputCol("input").setOutputCol("output") + val inSchema = StructType(Seq(StructField("input", DoubleType))) + val outSchema = idxToStr.transformSchema(inSchema) + assert(outSchema("output").dataType === StringType) + } } From 7e32387ae6303fd1cd32389d47df87170b841c67 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 14 Sep 2015 14:10:54 -0700 Subject: [PATCH 0614/1824] [SPARK-10522] [SQL] Nanoseconds of Timestamp in Parquet should be positive Or Hive can't read it back correctly. Thanks vanzin for report this. Author: Davies Liu Closes #8674 from davies/positive_nano. 
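A worked illustration of the change (not part of the patch): a pre-epoch timestamp has a negative microseconds-since-epoch value, so dividing it directly yields a negative nanosecond-of-day. Shifting by the full Julian-epoch offset before taking the remainder keeps the nanoseconds in [0, MICROS_PER_DAY), which is what Hive expects on read. The snippet mirrors the new test case.

    import java.sql.Timestamp
    import org.apache.spark.sql.catalyst.util.DateTimeUtils._

    // 1900-06-11 is before the Unix epoch, so `us` below is negative.
    val us = fromJavaTimestamp(Timestamp.valueOf("1900-06-11 20:10:10.100"))
    val (day, nanos) = toJulianDay(us)

    assert(nanos >= 0)                      // non-negative after this change
    assert(fromJulianDay(day, nanos) == us) // conversion still round-trips exactly
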
--- .../spark/sql/catalyst/util/DateTimeUtils.scala | 12 +++++++----- .../sql/catalyst/util/DateTimeUtilsSuite.scala | 17 ++++++++--------- 2 files changed, 15 insertions(+), 14 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index d652fce3fd9b..687ca000d12b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -42,6 +42,7 @@ object DateTimeUtils { final val SECONDS_PER_DAY = 60 * 60 * 24L final val MICROS_PER_SECOND = 1000L * 1000L final val NANOS_PER_SECOND = MICROS_PER_SECOND * 1000L + final val MICROS_PER_DAY = MICROS_PER_SECOND * SECONDS_PER_DAY final val MILLIS_PER_DAY = SECONDS_PER_DAY * 1000L @@ -190,13 +191,14 @@ object DateTimeUtils { /** * Returns Julian day and nanoseconds in a day from the number of microseconds + * + * Note: support timestamp since 4717 BC (without negative nanoseconds, compatible with Hive). */ def toJulianDay(us: SQLTimestamp): (Int, Long) = { - val seconds = us / MICROS_PER_SECOND - val day = seconds / SECONDS_PER_DAY + JULIAN_DAY_OF_EPOCH - val secondsInDay = seconds % SECONDS_PER_DAY - val nanos = (us % MICROS_PER_SECOND) * 1000L - (day.toInt, secondsInDay * NANOS_PER_SECOND + nanos) + val julian_us = us + JULIAN_DAY_OF_EPOCH * MICROS_PER_DAY + val day = julian_us / MICROS_PER_DAY + val micros = julian_us % MICROS_PER_DAY + (day.toInt, micros * 1000L) } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala index 1596bb79fa94..6b9a11f0ff74 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala @@ -52,15 +52,14 @@ class DateTimeUtilsSuite extends SparkFunSuite { assert(ns === 0) assert(fromJulianDay(d, ns) == 0L) - val t = Timestamp.valueOf("2015-06-11 10:10:10.100") - val (d1, ns1) = toJulianDay(fromJavaTimestamp(t)) - val t1 = toJavaTimestamp(fromJulianDay(d1, ns1)) - assert(t.equals(t1)) - - val t2 = Timestamp.valueOf("2015-06-11 20:10:10.100") - val (d2, ns2) = toJulianDay(fromJavaTimestamp(t2)) - val t22 = toJavaTimestamp(fromJulianDay(d2, ns2)) - assert(t2.equals(t22)) + Seq(Timestamp.valueOf("2015-06-11 10:10:10.100"), + Timestamp.valueOf("2015-06-11 20:10:10.100"), + Timestamp.valueOf("1900-06-11 20:10:10.100")).foreach { t => + val (d, ns) = toJulianDay(fromJavaTimestamp(t)) + assert(ns > 0) + val t1 = toJavaTimestamp(fromJulianDay(d, ns)) + assert(t.equals(t1)) + } } test("SPARK-6785: java date conversion before and after epoch") { From 64f04154e3078ec7340da97e3c2b07cf24e89098 Mon Sep 17 00:00:00 2001 From: Edoardo Vacchi Date: Mon, 14 Sep 2015 14:56:04 -0700 Subject: [PATCH 0615/1824] [SPARK-6981] [SQL] Factor out SparkPlanner and QueryExecution from SQLContext Alternative to PR #6122; in this case the refactored out classes are replaced by inner classes with the same name for backwards binary compatibility * process in a lighter-weight, backwards-compatible way Author: Edoardo Vacchi Closes #6356 from evacchi/sqlctx-refactoring-lite. 
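In practical terms, the refactoring turns the query-execution machinery into ordinary top-level classes constructed from a `SQLContext`, rather than inner classes whose types were path-dependent on a particular `SQLContext` instance. A sketch of what that enables (assuming the caller already has a `LogicalPlan`; this snippet is not part of the patch):

```scala
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.QueryExecution

// Explain any logical plan against any SQLContext; no SQLContext#QueryExecution needed.
def explain(ctx: SQLContext, plan: LogicalPlan): String =
  new QueryExecution(ctx, plan).toString // parsed, analyzed, optimized and physical plans
```

The deprecated `SQLContext#QueryExecution` and `SQLContext#SparkPlanner` subclasses kept below let existing call sites keep compiling.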
--- .../org/apache/spark/sql/DataFrame.scala | 4 +- .../org/apache/spark/sql/SQLContext.scala | 138 ++---------------- .../spark/sql/execution/QueryExecution.scala | 85 +++++++++++ .../spark/sql/execution/SQLExecution.scala | 2 +- .../spark/sql/execution/SparkPlanner.scala | 92 ++++++++++++ .../spark/sql/execution/SparkStrategies.scala | 2 +- 6 files changed, 195 insertions(+), 128 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 1a687b2374f1..3e61123c145c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -37,7 +37,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.{Inner, JoinType} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection, SqlParser} -import org.apache.spark.sql.execution.{EvaluatePython, ExplainCommand, FileRelation, LogicalRDD, SQLExecution} +import org.apache.spark.sql.execution.{EvaluatePython, ExplainCommand, FileRelation, LogicalRDD, QueryExecution, SQLExecution} import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, LogicalRelation} import org.apache.spark.sql.execution.datasources.json.JacksonGenerator import org.apache.spark.sql.sources.HadoopFsRelation @@ -114,7 +114,7 @@ private[sql] object DataFrame { @Experimental class DataFrame private[sql]( @transient val sqlContext: SQLContext, - @DeveloperApi @transient val queryExecution: SQLContext#QueryExecution) extends Serializable { + @DeveloperApi @transient val queryExecution: QueryExecution) extends Serializable { // Note for Spark contributors: if adding or updating any action in `DataFrame`, please make sure // you wrap it with `withNewExecutionId` if this actions doesn't call other action. 
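The `DataFrame` constructor change above is the motivating case. As a rough illustration in plain Scala (unrelated to Spark's classes), an inner class type is path-dependent, so instances created through different outer objects have incompatible types unless the type-projection form `Outer#Inner` is used in signatures:

```scala
class Outer { class Inner }
val o1 = new Outer
val o2 = new Outer
val i1: o1.Inner = new o1.Inner
val viaProjection: Outer#Inner = i1 // the form SQLContext#QueryExecution expressed
// val mismatch: o2.Inner = i1      // does not compile: o1.Inner is not o2.Inner
```

Making `QueryExecution` a top-level class removes that coupling from `DataFrame`'s public signature.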
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 4e8414af50b4..e3fdd782e6ff 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -38,6 +38,10 @@ import org.apache.spark.sql.catalyst.optimizer.{DefaultOptimizer, Optimizer} import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.catalyst.{InternalRow, ParserDialect, _} +import org.apache.spark.sql.execution.{Filter, _} +import org.apache.spark.sql.{execution => sparkexecution} +import org.apache.spark.sql.execution._ +import org.apache.spark.sql.sources._ import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.ui.{SQLListener, SQLTab} @@ -188,9 +192,11 @@ class SQLContext(@transient val sparkContext: SparkContext) protected[sql] def parseSql(sql: String): LogicalPlan = ddlParser.parse(sql, false) - protected[sql] def executeSql(sql: String): this.QueryExecution = executePlan(parseSql(sql)) + protected[sql] def executeSql(sql: String): + org.apache.spark.sql.execution.QueryExecution = executePlan(parseSql(sql)) - protected[sql] def executePlan(plan: LogicalPlan) = new this.QueryExecution(plan) + protected[sql] def executePlan(plan: LogicalPlan) = + new sparkexecution.QueryExecution(this, plan) @transient protected[sql] val tlSession = new ThreadLocal[SQLSession]() { @@ -781,77 +787,11 @@ class SQLContext(@transient val sparkContext: SparkContext) }.toArray } - protected[sql] class SparkPlanner extends SparkStrategies { - val sparkContext: SparkContext = self.sparkContext - - val sqlContext: SQLContext = self - - def codegenEnabled: Boolean = self.conf.codegenEnabled - - def unsafeEnabled: Boolean = self.conf.unsafeEnabled - - def numPartitions: Int = self.conf.numShufflePartitions - - def strategies: Seq[Strategy] = - experimental.extraStrategies ++ ( - DataSourceStrategy :: - DDLStrategy :: - TakeOrderedAndProject :: - HashAggregation :: - Aggregation :: - LeftSemiJoin :: - EquiJoinSelection :: - InMemoryScans :: - BasicOperators :: - CartesianProduct :: - BroadcastNestedLoopJoin :: Nil) - - /** - * Used to build table scan operators where complex projection and filtering are done using - * separate physical operators. This function returns the given scan operator with Project and - * Filter nodes added only when needed. For example, a Project operator is only used when the - * final desired output requires complex expressions to be evaluated or when columns can be - * further eliminated out after filtering has been done. - * - * The `prunePushedDownFilters` parameter is used to remove those filters that can be optimized - * away by the filter pushdown optimization. - * - * The required attributes for both filtering and expression evaluation are passed to the - * provided `scanBuilder` function so that it can avoid unnecessary column materialization. 
- */ - def pruneFilterProject( - projectList: Seq[NamedExpression], - filterPredicates: Seq[Expression], - prunePushedDownFilters: Seq[Expression] => Seq[Expression], - scanBuilder: Seq[Attribute] => SparkPlan): SparkPlan = { - - val projectSet = AttributeSet(projectList.flatMap(_.references)) - val filterSet = AttributeSet(filterPredicates.flatMap(_.references)) - val filterCondition = - prunePushedDownFilters(filterPredicates).reduceLeftOption(catalyst.expressions.And) - - // Right now we still use a projection even if the only evaluation is applying an alias - // to a column. Since this is a no-op, it could be avoided. However, using this - // optimization with the current implementation would change the output schema. - // TODO: Decouple final output schema from expression evaluation so this copy can be - // avoided safely. - - if (AttributeSet(projectList.map(_.toAttribute)) == projectSet && - filterSet.subsetOf(projectSet)) { - // When it is possible to just use column pruning to get the right projection and - // when the columns of this projection are enough to evaluate all filter conditions, - // just do a scan followed by a filter, with no extra project. - val scan = scanBuilder(projectList.asInstanceOf[Seq[Attribute]]) - filterCondition.map(Filter(_, scan)).getOrElse(scan) - } else { - val scan = scanBuilder((projectSet ++ filterSet).toSeq) - Project(projectList, filterCondition.map(Filter(_, scan)).getOrElse(scan)) - } - } - } + @deprecated("use org.apache.spark.sql.SparkPlanner", "1.6.0") + protected[sql] class SparkPlanner extends sparkexecution.SparkPlanner(this) @transient - protected[sql] val planner = new SparkPlanner + protected[sql] val planner: sparkexecution.SparkPlanner = new sparkexecution.SparkPlanner(this) @transient protected[sql] lazy val emptyResult = sparkContext.parallelize(Seq.empty[InternalRow], 1) @@ -898,59 +838,9 @@ class SQLContext(@transient val sparkContext: SparkContext) protected[sql] lazy val conf: SQLConf = new SQLConf } - /** - * :: DeveloperApi :: - * The primary workflow for executing relational queries using Spark. Designed to allow easy - * access to the intermediate phases of query execution for developers. - */ - @DeveloperApi - protected[sql] class QueryExecution(val logical: LogicalPlan) { - def assertAnalyzed(): Unit = analyzer.checkAnalysis(analyzed) - - lazy val analyzed: LogicalPlan = analyzer.execute(logical) - lazy val withCachedData: LogicalPlan = { - assertAnalyzed() - cacheManager.useCachedData(analyzed) - } - lazy val optimizedPlan: LogicalPlan = optimizer.execute(withCachedData) - - // TODO: Don't just pick the first one... - lazy val sparkPlan: SparkPlan = { - SparkPlan.currentContext.set(self) - planner.plan(optimizedPlan).next() - } - // executedPlan should not be used to initialize any SparkPlan. It should be - // only used for execution. - lazy val executedPlan: SparkPlan = prepareForExecution.execute(sparkPlan) - - /** Internal version of the RDD. 
Avoids copies and has no schema */ - lazy val toRdd: RDD[InternalRow] = executedPlan.execute() - - protected def stringOrError[A](f: => A): String = - try f.toString catch { case e: Throwable => e.toString } - - def simpleString: String = - s"""== Physical Plan == - |${stringOrError(executedPlan)} - """.stripMargin.trim - - override def toString: String = { - def output = - analyzed.output.map(o => s"${o.name}: ${o.dataType.simpleString}").mkString(", ") - - s"""== Parsed Logical Plan == - |${stringOrError(logical)} - |== Analyzed Logical Plan == - |${stringOrError(output)} - |${stringOrError(analyzed)} - |== Optimized Logical Plan == - |${stringOrError(optimizedPlan)} - |== Physical Plan == - |${stringOrError(executedPlan)} - |Code Generation: ${stringOrError(executedPlan.codegenEnabled)} - """.stripMargin.trim - } - } + @deprecated("use org.apache.spark.sql.QueryExecution", "1.6.0") + protected[sql] class QueryExecution(logical: LogicalPlan) + extends sparkexecution.QueryExecution(this, logical) /** * Parses the data type in our internal string representation. The data type string should diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala new file mode 100644 index 000000000000..7bb4133a2905 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.annotation.{Experimental, DeveloperApi} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.{InternalRow, optimizer} +import org.apache.spark.sql.{SQLContext, Row} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan + +/** + * :: DeveloperApi :: + * The primary workflow for executing relational queries using Spark. Designed to allow easy + * access to the intermediate phases of query execution for developers. + */ +@DeveloperApi +class QueryExecution(val sqlContext: SQLContext, val logical: LogicalPlan) { + val analyzer = sqlContext.analyzer + val optimizer = sqlContext.optimizer + val planner = sqlContext.planner + val cacheManager = sqlContext.cacheManager + val prepareForExecution = sqlContext.prepareForExecution + + def assertAnalyzed(): Unit = analyzer.checkAnalysis(analyzed) + + lazy val analyzed: LogicalPlan = analyzer.execute(logical) + lazy val withCachedData: LogicalPlan = { + assertAnalyzed() + cacheManager.useCachedData(analyzed) + } + lazy val optimizedPlan: LogicalPlan = optimizer.execute(withCachedData) + + // TODO: Don't just pick the first one... 
+ lazy val sparkPlan: SparkPlan = { + SparkPlan.currentContext.set(sqlContext) + planner.plan(optimizedPlan).next() + } + // executedPlan should not be used to initialize any SparkPlan. It should be + // only used for execution. + lazy val executedPlan: SparkPlan = prepareForExecution.execute(sparkPlan) + + /** Internal version of the RDD. Avoids copies and has no schema */ + lazy val toRdd: RDD[InternalRow] = executedPlan.execute() + + protected def stringOrError[A](f: => A): String = + try f.toString catch { case e: Throwable => e.toString } + + def simpleString: String = + s"""== Physical Plan == + |${stringOrError(executedPlan)} + """.stripMargin.trim + + + override def toString: String = { + def output = + analyzed.output.map(o => s"${o.name}: ${o.dataType.simpleString}").mkString(", ") + + s"""== Parsed Logical Plan == + |${stringOrError(logical)} + |== Analyzed Logical Plan == + |${stringOrError(output)} + |${stringOrError(analyzed)} + |== Optimized Logical Plan == + |${stringOrError(optimizedPlan)} + |== Physical Plan == + |${stringOrError(executedPlan)} + |Code Generation: ${stringOrError(executedPlan.codegenEnabled)} + """.stripMargin.trim + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala index cee58218a885..1422e15549c9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala @@ -37,7 +37,7 @@ private[sql] object SQLExecution { * we can connect them with an execution. */ def withNewExecutionId[T]( - sqlContext: SQLContext, queryExecution: SQLContext#QueryExecution)(body: => T): T = { + sqlContext: SQLContext, queryExecution: QueryExecution)(body: => T): T = { val sc = sqlContext.sparkContext val oldExecutionId = sc.getLocalProperty(EXECUTION_ID_KEY) if (oldExecutionId == null) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala new file mode 100644 index 000000000000..b346f43faebe --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution + +import org.apache.spark.SparkContext +import org.apache.spark.annotation.Experimental +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.execution.datasources.DataSourceStrategy + +@Experimental +class SparkPlanner(val sqlContext: SQLContext) extends SparkStrategies { + val sparkContext: SparkContext = sqlContext.sparkContext + + def codegenEnabled: Boolean = sqlContext.conf.codegenEnabled + + def unsafeEnabled: Boolean = sqlContext.conf.unsafeEnabled + + def numPartitions: Int = sqlContext.conf.numShufflePartitions + + def strategies: Seq[Strategy] = + sqlContext.experimental.extraStrategies ++ ( + DataSourceStrategy :: + DDLStrategy :: + TakeOrderedAndProject :: + HashAggregation :: + Aggregation :: + LeftSemiJoin :: + EquiJoinSelection :: + InMemoryScans :: + BasicOperators :: + CartesianProduct :: + BroadcastNestedLoopJoin :: Nil) + + /** + * Used to build table scan operators where complex projection and filtering are done using + * separate physical operators. This function returns the given scan operator with Project and + * Filter nodes added only when needed. For example, a Project operator is only used when the + * final desired output requires complex expressions to be evaluated or when columns can be + * further eliminated out after filtering has been done. + * + * The `prunePushedDownFilters` parameter is used to remove those filters that can be optimized + * away by the filter pushdown optimization. + * + * The required attributes for both filtering and expression evaluation are passed to the + * provided `scanBuilder` function so that it can avoid unnecessary column materialization. + */ + def pruneFilterProject( + projectList: Seq[NamedExpression], + filterPredicates: Seq[Expression], + prunePushedDownFilters: Seq[Expression] => Seq[Expression], + scanBuilder: Seq[Attribute] => SparkPlan): SparkPlan = { + + val projectSet = AttributeSet(projectList.flatMap(_.references)) + val filterSet = AttributeSet(filterPredicates.flatMap(_.references)) + val filterCondition = + prunePushedDownFilters(filterPredicates).reduceLeftOption(catalyst.expressions.And) + + // Right now we still use a projection even if the only evaluation is applying an alias + // to a column. Since this is a no-op, it could be avoided. However, using this + // optimization with the current implementation would change the output schema. + // TODO: Decouple final output schema from expression evaluation so this copy can be + // avoided safely. + + if (AttributeSet(projectList.map(_.toAttribute)) == projectSet && + filterSet.subsetOf(projectSet)) { + // When it is possible to just use column pruning to get the right projection and + // when the columns of this projection are enough to evaluate all filter conditions, + // just do a scan followed by a filter, with no extra project. 
+ val scan = scanBuilder(projectList.asInstanceOf[Seq[Attribute]]) + filterCondition.map(Filter(_, scan)).getOrElse(scan) + } else { + val scan = scanBuilder((projectSet ++ filterSet).toSeq) + Project(projectList, filterCondition.map(Filter(_, scan)).getOrElse(scan)) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 4572d5efc92b..5e40d7768904 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.types._ import org.apache.spark.sql.{SQLContext, Strategy, execution} private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { - self: SQLContext#SparkPlanner => + self: SparkPlanner => object LeftSemiJoin extends Strategy with PredicateHelper { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { From 217e4964444f4e07b894b1bca768a0cbbe799ea0 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Mon, 14 Sep 2015 15:00:27 -0700 Subject: [PATCH 0616/1824] [SPARK-9996] [SPARK-9997] [SQL] Add local expand and NestedLoopJoin operators This PR is in conflict with #8535 and #8573. Will update this one when they are merged. Author: zsxwing Closes #8642 from zsxwing/expand-nest-join. --- .../sql/execution/local/ExpandNode.scala | 60 +++++ .../spark/sql/execution/local/LocalNode.scala | 55 +++- .../execution/local/NestedLoopJoinNode.scala | 156 ++++++++++++ .../sql/execution/local/ExpandNodeSuite.scala | 51 ++++ .../execution/local/HashJoinNodeSuite.scala | 14 - .../sql/execution/local/LocalNodeTest.scala | 14 + .../local/NestedLoopJoinNodeSuite.scala | 239 ++++++++++++++++++ 7 files changed, 574 insertions(+), 15 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/local/ExpandNode.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNode.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/local/ExpandNodeSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNodeSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ExpandNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ExpandNode.scala new file mode 100644 index 000000000000..2aff156d18b5 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ExpandNode.scala @@ -0,0 +1,60 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/ + +package org.apache.spark.sql.execution.local + +import org.apache.spark.sql.SQLConf +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, Projection} + +case class ExpandNode( + conf: SQLConf, + projections: Seq[Seq[Expression]], + output: Seq[Attribute], + child: LocalNode) extends UnaryLocalNode(conf) { + + assert(projections.size > 0) + + private[this] var result: InternalRow = _ + private[this] var idx: Int = _ + private[this] var input: InternalRow = _ + private[this] var groups: Array[Projection] = _ + + override def open(): Unit = { + child.open() + groups = projections.map(ee => newProjection(ee, child.output)).toArray + idx = groups.length + } + + override def next(): Boolean = { + if (idx >= groups.length) { + if (child.next()) { + input = child.fetch() + idx = 0 + } else { + return false + } + } + result = groups(idx)(input) + idx += 1 + true + } + + override def fetch(): InternalRow = result + + override def close(): Unit = child.close() +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala index e540ef8555eb..9840080e1695 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala @@ -23,7 +23,7 @@ import org.apache.spark.Logging import org.apache.spark.sql.{SQLConf, Row} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection +import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.types.StructType @@ -69,6 +69,18 @@ abstract class LocalNode(conf: SQLConf) extends TreeNode[LocalNode] with Logging */ def close(): Unit + /** Specifies whether this operator outputs UnsafeRows */ + def outputsUnsafeRows: Boolean = false + + /** Specifies whether this operator is capable of processing UnsafeRows */ + def canProcessUnsafeRows: Boolean = false + + /** + * Specifies whether this operator is capable of processing Java-object-based Rows (i.e. rows + * that are not UnsafeRows). + */ + def canProcessSafeRows: Boolean = true + /** * Returns the content through the [[Iterator]] interface. 
*/ @@ -91,6 +103,28 @@ abstract class LocalNode(conf: SQLConf) extends TreeNode[LocalNode] with Logging result } + protected def newProjection( + expressions: Seq[Expression], + inputSchema: Seq[Attribute]): Projection = { + log.debug( + s"Creating Projection: $expressions, inputSchema: $inputSchema, codegen:$codegenEnabled") + if (codegenEnabled) { + try { + GenerateProjection.generate(expressions, inputSchema) + } catch { + case NonFatal(e) => + if (isTesting) { + throw e + } else { + log.error("Failed to generate projection, fallback to interpret", e) + new InterpretedProjection(expressions, inputSchema) + } + } + } else { + new InterpretedProjection(expressions, inputSchema) + } + } + protected def newMutableProjection( expressions: Seq[Expression], inputSchema: Seq[Attribute]): () => MutableProjection = { @@ -113,6 +147,25 @@ abstract class LocalNode(conf: SQLConf) extends TreeNode[LocalNode] with Logging } } + protected def newPredicate( + expression: Expression, + inputSchema: Seq[Attribute]): (InternalRow) => Boolean = { + if (codegenEnabled) { + try { + GeneratePredicate.generate(expression, inputSchema) + } catch { + case NonFatal(e) => + if (isTesting) { + throw e + } else { + log.error("Failed to generate predicate, fallback to interpreted", e) + InterpretedPredicate.create(expression, inputSchema) + } + } + } else { + InterpretedPredicate.create(expression, inputSchema) + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNode.scala new file mode 100644 index 000000000000..7321fc66b4dd --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNode.scala @@ -0,0 +1,156 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.local + +import org.apache.spark.sql.SQLConf +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.{FullOuter, RightOuter, LeftOuter, JoinType} +import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide} +import org.apache.spark.util.collection.{BitSet, CompactBuffer} + +case class NestedLoopJoinNode( + conf: SQLConf, + left: LocalNode, + right: LocalNode, + buildSide: BuildSide, + joinType: JoinType, + condition: Option[Expression]) extends BinaryLocalNode(conf) { + + override def output: Seq[Attribute] = { + joinType match { + case LeftOuter => + left.output ++ right.output.map(_.withNullability(true)) + case RightOuter => + left.output.map(_.withNullability(true)) ++ right.output + case FullOuter => + left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true)) + case x => + throw new IllegalArgumentException( + s"NestedLoopJoin should not take $x as the JoinType") + } + } + + private[this] def genResultProjection: InternalRow => InternalRow = { + if (outputsUnsafeRows) { + UnsafeProjection.create(schema) + } else { + identity[InternalRow] + } + } + + private[this] var currentRow: InternalRow = _ + + private[this] var iterator: Iterator[InternalRow] = _ + + override def open(): Unit = { + val (streamed, build) = buildSide match { + case BuildRight => (left, right) + case BuildLeft => (right, left) + } + build.open() + val buildRelation = new CompactBuffer[InternalRow] + while (build.next()) { + buildRelation += build.fetch().copy() + } + build.close() + + val boundCondition = + newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output) + + val leftNulls = new GenericMutableRow(left.output.size) + val rightNulls = new GenericMutableRow(right.output.size) + val joinedRow = new JoinedRow + val matchedBuildTuples = new BitSet(buildRelation.size) + val resultProj = genResultProjection + streamed.open() + + // streamedRowMatches also contains null rows if using outer join + val streamedRowMatches: Iterator[InternalRow] = streamed.asIterator.flatMap { streamedRow => + val matchedRows = new CompactBuffer[InternalRow] + + var i = 0 + var streamRowMatched = false + + // Scan the build relation to look for matches for each streamed row + while (i < buildRelation.size) { + val buildRow = buildRelation(i) + buildSide match { + case BuildRight => joinedRow(streamedRow, buildRow) + case BuildLeft => joinedRow(buildRow, streamedRow) + } + if (boundCondition(joinedRow)) { + matchedRows += resultProj(joinedRow).copy() + streamRowMatched = true + matchedBuildTuples.set(i) + } + i += 1 + } + + // If this row had no matches and we're using outer join, join it with the null rows + if (!streamRowMatched) { + (joinType, buildSide) match { + case (LeftOuter | FullOuter, BuildRight) => + matchedRows += resultProj(joinedRow(streamedRow, rightNulls)).copy() + case (RightOuter | FullOuter, BuildLeft) => + matchedRows += resultProj(joinedRow(leftNulls, streamedRow)).copy() + case _ => + } + } + + matchedRows.iterator + } + + // If we're using outer join, find rows on the build side that didn't match anything + // and join them with the null row + lazy val unmatchedBuildRows: Iterator[InternalRow] = { + var i = 0 + buildRelation.filter { row => + val r = !matchedBuildTuples.get(i) + i += 1 + r + }.iterator + } + iterator = (joinType, buildSide) match { + case (RightOuter | FullOuter, BuildRight) => + 
streamedRowMatches ++ + unmatchedBuildRows.map { buildRow => resultProj(joinedRow(leftNulls, buildRow)) } + case (LeftOuter | FullOuter, BuildLeft) => + streamedRowMatches ++ + unmatchedBuildRows.map { buildRow => resultProj(joinedRow(buildRow, rightNulls)) } + case _ => streamedRowMatches + } + } + + override def next(): Boolean = { + if (iterator.hasNext) { + currentRow = iterator.next() + true + } else { + false + } + } + + override def fetch(): InternalRow = currentRow + + override def close(): Unit = { + left.close() + right.close() + } + +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ExpandNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ExpandNodeSuite.scala new file mode 100644 index 000000000000..cfa7f3f6dcb9 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ExpandNodeSuite.scala @@ -0,0 +1,51 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.execution.local + +class ExpandNodeSuite extends LocalNodeTest { + + import testImplicits._ + + test("expand") { + val input = Seq((1, 1), (2, 2), (3, 3), (4, 4), (5, 5)).toDF("key", "value") + checkAnswer( + input, + node => + ExpandNode(conf, Seq( + Seq( + input.col("key") + input.col("value"), input.col("key") - input.col("value") + ).map(_.expr), + Seq( + input.col("key") * input.col("value"), input.col("key") / input.col("value") + ).map(_.expr) + ), node.output, node), + Seq( + (2, 0), + (1, 1), + (4, 0), + (4, 1), + (6, 0), + (9, 1), + (8, 0), + (16, 1), + (10, 0), + (25, 1) + ).toDF().collect() + ) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala index 43b6f06aead8..78d891351f4a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala @@ -24,20 +24,6 @@ class HashJoinNodeSuite extends LocalNodeTest { import testImplicits._ - private def wrapForUnsafe( - f: (LocalNode, LocalNode) => LocalNode): (LocalNode, LocalNode) => LocalNode = { - if (conf.unsafeEnabled) { - (left: LocalNode, right: LocalNode) => { - val _left = ConvertToUnsafeNode(conf, left) - val _right = ConvertToUnsafeNode(conf, right) - val r = f(_left, _right) - ConvertToSafeNode(conf, r) - } - } else { - f - } - } - def joinSuite(suiteName: String, confPairs: (String, String)*): Unit = { test(s"$suiteName: inner join with one match per row") { withSQLConf(confPairs: _*) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala 
index b95d4ea7f8f2..86dd28064cc6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala @@ -27,6 +27,20 @@ class LocalNodeTest extends SparkFunSuite with SharedSQLContext { def conf: SQLConf = sqlContext.conf + protected def wrapForUnsafe( + f: (LocalNode, LocalNode) => LocalNode): (LocalNode, LocalNode) => LocalNode = { + if (conf.unsafeEnabled) { + (left: LocalNode, right: LocalNode) => { + val _left = ConvertToUnsafeNode(conf, left) + val _right = ConvertToUnsafeNode(conf, right) + val r = f(_left, _right) + ConvertToSafeNode(conf, r) + } + } else { + f + } + } + /** * Runs the LocalNode and makes sure the answer matches the expected result. * @param input the input data to be used. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNodeSuite.scala new file mode 100644 index 000000000000..b1ef26ba82f1 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNodeSuite.scala @@ -0,0 +1,239 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/ + +package org.apache.spark.sql.execution.local + +import org.apache.spark.sql.SQLConf +import org.apache.spark.sql.catalyst.plans.{FullOuter, LeftOuter, RightOuter} +import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide} + +class NestedLoopJoinNodeSuite extends LocalNodeTest { + + import testImplicits._ + + private def joinSuite( + suiteName: String, buildSide: BuildSide, confPairs: (String, String)*): Unit = { + test(s"$suiteName: left outer join") { + withSQLConf(confPairs: _*) { + checkAnswer2( + upperCaseData, + lowerCaseData, + wrapForUnsafe( + (node1, node2) => NestedLoopJoinNode( + conf, + node1, + node2, + buildSide, + LeftOuter, + Some((upperCaseData.col("N") === lowerCaseData.col("n")).expr)) + ), + upperCaseData.join(lowerCaseData, $"n" === $"N", "left").collect()) + + checkAnswer2( + upperCaseData, + lowerCaseData, + wrapForUnsafe( + (node1, node2) => NestedLoopJoinNode( + conf, + node1, + node2, + buildSide, + LeftOuter, + Some( + (upperCaseData.col("N") === lowerCaseData.col("n") && + lowerCaseData.col("n") > 1).expr)) + ), + upperCaseData.join(lowerCaseData, $"n" === $"N" && $"n" > 1, "left").collect()) + + checkAnswer2( + upperCaseData, + lowerCaseData, + wrapForUnsafe( + (node1, node2) => NestedLoopJoinNode( + conf, + node1, + node2, + buildSide, + LeftOuter, + Some( + (upperCaseData.col("N") === lowerCaseData.col("n") && + upperCaseData.col("N") > 1).expr)) + ), + upperCaseData.join(lowerCaseData, $"n" === $"N" && $"N" > 1, "left").collect()) + + checkAnswer2( + upperCaseData, + lowerCaseData, + wrapForUnsafe( + (node1, node2) => NestedLoopJoinNode( + conf, + node1, + node2, + buildSide, + LeftOuter, + Some( + (upperCaseData.col("N") === lowerCaseData.col("n") && + lowerCaseData.col("l") > upperCaseData.col("L")).expr)) + ), + upperCaseData.join(lowerCaseData, $"n" === $"N" && $"l" > $"L", "left").collect()) + } + } + + test(s"$suiteName: right outer join") { + withSQLConf(confPairs: _*) { + checkAnswer2( + lowerCaseData, + upperCaseData, + wrapForUnsafe( + (node1, node2) => NestedLoopJoinNode( + conf, + node1, + node2, + buildSide, + RightOuter, + Some((lowerCaseData.col("n") === upperCaseData.col("N")).expr)) + ), + lowerCaseData.join(upperCaseData, $"n" === $"N", "right").collect()) + + checkAnswer2( + lowerCaseData, + upperCaseData, + wrapForUnsafe( + (node1, node2) => NestedLoopJoinNode( + conf, + node1, + node2, + buildSide, + RightOuter, + Some((lowerCaseData.col("n") === upperCaseData.col("N") && + lowerCaseData.col("n") > 1).expr)) + ), + lowerCaseData.join(upperCaseData, $"n" === $"N" && $"n" > 1, "right").collect()) + + checkAnswer2( + lowerCaseData, + upperCaseData, + wrapForUnsafe( + (node1, node2) => NestedLoopJoinNode( + conf, + node1, + node2, + buildSide, + RightOuter, + Some((lowerCaseData.col("n") === upperCaseData.col("N") && + upperCaseData.col("N") > 1).expr)) + ), + lowerCaseData.join(upperCaseData, $"n" === $"N" && $"N" > 1, "right").collect()) + + checkAnswer2( + lowerCaseData, + upperCaseData, + wrapForUnsafe( + (node1, node2) => NestedLoopJoinNode( + conf, + node1, + node2, + buildSide, + RightOuter, + Some((lowerCaseData.col("n") === upperCaseData.col("N") && + lowerCaseData.col("l") > upperCaseData.col("L")).expr)) + ), + lowerCaseData.join(upperCaseData, $"n" === $"N" && $"l" > $"L", "right").collect()) + } + } + + test(s"$suiteName: full outer join") { + withSQLConf(confPairs: _*) { + checkAnswer2( + lowerCaseData, + upperCaseData, + wrapForUnsafe( + (node1, node2) => NestedLoopJoinNode( + conf, + node1, + node2, 
+ buildSide, + FullOuter, + Some((lowerCaseData.col("n") === upperCaseData.col("N")).expr)) + ), + lowerCaseData.join(upperCaseData, $"n" === $"N", "full").collect()) + + checkAnswer2( + lowerCaseData, + upperCaseData, + wrapForUnsafe( + (node1, node2) => NestedLoopJoinNode( + conf, + node1, + node2, + buildSide, + FullOuter, + Some((lowerCaseData.col("n") === upperCaseData.col("N") && + lowerCaseData.col("n") > 1).expr)) + ), + lowerCaseData.join(upperCaseData, $"n" === $"N" && $"n" > 1, "full").collect()) + + checkAnswer2( + lowerCaseData, + upperCaseData, + wrapForUnsafe( + (node1, node2) => NestedLoopJoinNode( + conf, + node1, + node2, + buildSide, + FullOuter, + Some((lowerCaseData.col("n") === upperCaseData.col("N") && + upperCaseData.col("N") > 1).expr)) + ), + lowerCaseData.join(upperCaseData, $"n" === $"N" && $"N" > 1, "full").collect()) + + checkAnswer2( + lowerCaseData, + upperCaseData, + wrapForUnsafe( + (node1, node2) => NestedLoopJoinNode( + conf, + node1, + node2, + buildSide, + FullOuter, + Some((lowerCaseData.col("n") === upperCaseData.col("N") && + lowerCaseData.col("l") > upperCaseData.col("L")).expr)) + ), + lowerCaseData.join(upperCaseData, $"n" === $"N" && $"l" > $"L", "full").collect()) + } + } + } + + joinSuite( + "general-build-left", + BuildLeft, + SQLConf.CODEGEN_ENABLED.key -> "false", SQLConf.UNSAFE_ENABLED.key -> "false") + joinSuite( + "general-build-right", + BuildRight, + SQLConf.CODEGEN_ENABLED.key -> "false", SQLConf.UNSAFE_ENABLED.key -> "false") + joinSuite( + "tungsten-build-left", + BuildLeft, + SQLConf.CODEGEN_ENABLED.key -> "true", SQLConf.UNSAFE_ENABLED.key -> "true") + joinSuite( + "tungsten-build-right", + BuildRight, + SQLConf.CODEGEN_ENABLED.key -> "true", SQLConf.UNSAFE_ENABLED.key -> "true") +} From 16b6d18613e150c7038c613992d80a7828413e66 Mon Sep 17 00:00:00 2001 From: Erick Tryzelaar Date: Mon, 14 Sep 2015 15:02:38 -0700 Subject: [PATCH 0617/1824] [SPARK-10594] [YARN] Remove reference to --num-executors, add --properties-file `ApplicationMaster` no longer has the `--num-executors` flag, and had an undocumented `--properties-file` configuration option. cc srowen Author: Erick Tryzelaar Closes #8754 from erickt/master. --- .../apache/spark/deploy/yarn/ApplicationMasterArguments.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala index b08412414aa1..17d9943c795e 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala @@ -105,9 +105,9 @@ class ApplicationMasterArguments(val args: Array[String]) { | place on the PYTHONPATH for Python apps. | --args ARGS Arguments to be passed to your application's main class. | Multiple invocations are possible, each will be passed in order. - | --num-executors NUM Number of executors to start (Default: 2) | --executor-cores NUM Number of cores for the executors (Default: 1) | --executor-memory MEM Memory per executor (e.g. 1000M, 2G) (Default: 1G) + | --properties-file FILE Path to a custom Spark properties file. 
""".stripMargin) // scalastyle:on println System.exit(exitCode) From 4e2242bb41dda922573046c00c5142745632f95f Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Mon, 14 Sep 2015 15:03:51 -0700 Subject: [PATCH 0618/1824] [SPARK-10576] [BUILD] Move .java files out of src/main/scala Move .java files in `src/main/scala` to `src/main/java` root, except for `package-info.java` (to stay next to package.scala) Author: Sean Owen Closes #8736 from srowen/SPARK-10576. --- .../org/apache/spark/annotation/AlphaComponent.java | 0 .../{scala => java}/org/apache/spark/annotation/DeveloperApi.java | 0 .../{scala => java}/org/apache/spark/annotation/Experimental.java | 0 .../main/{scala => java}/org/apache/spark/annotation/Private.java | 0 .../{scala => java}/org/apache/spark/graphx/TripletFields.java | 0 .../org/apache/spark/graphx/impl/EdgeActiveness.java | 0 .../org/apache/spark/sql/types/SQLUserDefinedType.java | 0 .../org/apache/spark/streaming/StreamingContextState.java | 0 8 files changed, 0 insertions(+), 0 deletions(-) rename core/src/main/{scala => java}/org/apache/spark/annotation/AlphaComponent.java (100%) rename core/src/main/{scala => java}/org/apache/spark/annotation/DeveloperApi.java (100%) rename core/src/main/{scala => java}/org/apache/spark/annotation/Experimental.java (100%) rename core/src/main/{scala => java}/org/apache/spark/annotation/Private.java (100%) rename graphx/src/main/{scala => java}/org/apache/spark/graphx/TripletFields.java (100%) rename graphx/src/main/{scala => java}/org/apache/spark/graphx/impl/EdgeActiveness.java (100%) rename sql/catalyst/src/main/{scala => java}/org/apache/spark/sql/types/SQLUserDefinedType.java (100%) rename streaming/src/main/{scala => java}/org/apache/spark/streaming/StreamingContextState.java (100%) diff --git a/core/src/main/scala/org/apache/spark/annotation/AlphaComponent.java b/core/src/main/java/org/apache/spark/annotation/AlphaComponent.java similarity index 100% rename from core/src/main/scala/org/apache/spark/annotation/AlphaComponent.java rename to core/src/main/java/org/apache/spark/annotation/AlphaComponent.java diff --git a/core/src/main/scala/org/apache/spark/annotation/DeveloperApi.java b/core/src/main/java/org/apache/spark/annotation/DeveloperApi.java similarity index 100% rename from core/src/main/scala/org/apache/spark/annotation/DeveloperApi.java rename to core/src/main/java/org/apache/spark/annotation/DeveloperApi.java diff --git a/core/src/main/scala/org/apache/spark/annotation/Experimental.java b/core/src/main/java/org/apache/spark/annotation/Experimental.java similarity index 100% rename from core/src/main/scala/org/apache/spark/annotation/Experimental.java rename to core/src/main/java/org/apache/spark/annotation/Experimental.java diff --git a/core/src/main/scala/org/apache/spark/annotation/Private.java b/core/src/main/java/org/apache/spark/annotation/Private.java similarity index 100% rename from core/src/main/scala/org/apache/spark/annotation/Private.java rename to core/src/main/java/org/apache/spark/annotation/Private.java diff --git a/graphx/src/main/scala/org/apache/spark/graphx/TripletFields.java b/graphx/src/main/java/org/apache/spark/graphx/TripletFields.java similarity index 100% rename from graphx/src/main/scala/org/apache/spark/graphx/TripletFields.java rename to graphx/src/main/java/org/apache/spark/graphx/TripletFields.java diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeActiveness.java b/graphx/src/main/java/org/apache/spark/graphx/impl/EdgeActiveness.java similarity index 100% rename from 
graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeActiveness.java rename to graphx/src/main/java/org/apache/spark/graphx/impl/EdgeActiveness.java diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/SQLUserDefinedType.java b/sql/catalyst/src/main/java/org/apache/spark/sql/types/SQLUserDefinedType.java similarity index 100% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/types/SQLUserDefinedType.java rename to sql/catalyst/src/main/java/org/apache/spark/sql/types/SQLUserDefinedType.java diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContextState.java b/streaming/src/main/java/org/apache/spark/streaming/StreamingContextState.java similarity index 100% rename from streaming/src/main/scala/org/apache/spark/streaming/StreamingContextState.java rename to streaming/src/main/java/org/apache/spark/streaming/StreamingContextState.java From ffbbc2c58b9bf1e2abc2ea797feada6821ab4de8 Mon Sep 17 00:00:00 2001 From: Tom Graves Date: Mon, 14 Sep 2015 15:05:19 -0700 Subject: [PATCH 0619/1824] [SPARK-10549] scala 2.11 spark on yarn with security - Repl doesn't work Make this lazy so that it can set the yarn mode before creating the securityManager. Author: Tom Graves Author: Thomas Graves Closes #8719 from tgravescs/SPARK-10549. --- .../scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala index be31eb2eda54..627148df80c1 100644 --- a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala +++ b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala @@ -35,7 +35,8 @@ object Main extends Logging { s.processArguments(List("-Yrepl-class-based", "-Yrepl-outdir", s"${outputDir.getAbsolutePath}", "-classpath", getAddedJars.mkString(File.pathSeparator)), true) - val classServer = new HttpServer(conf, outputDir, new SecurityManager(conf)) + // the creation of SecurityManager has to be lazy so SPARK_YARN_MODE is set if needed + lazy val classServer = new HttpServer(conf, outputDir, new SecurityManager(conf)) var sparkContext: SparkContext = _ var sqlContext: SQLContext = _ var interp = new SparkILoop // this is a public var because tests reset it. From fd1e8cddf2635c55fec2ac6e1f1c221c9685af0f Mon Sep 17 00:00:00 2001 From: Forest Fang Date: Mon, 14 Sep 2015 15:07:13 -0700 Subject: [PATCH 0620/1824] [SPARK-10543] [CORE] Peak Execution Memory Quantile should be Per-task Basis Read `PEAK_EXECUTION_MEMORY` using `update` to get per task partial value instead of cumulative value. I tested with this workload: ```scala val size = 1000 val repetitions = 10 val data = sc.parallelize(1 to size, 5).map(x => (util.Random.nextInt(size / repetitions),util.Random.nextDouble)).toDF("key", "value") val res = data.toDF.groupBy("key").agg(sum("value")).count ``` Before: ![image](https://cloud.githubusercontent.com/assets/4317392/9828197/07dd6874-58b8-11e5-9bd9-6ba927c38b26.png) After: ![image](https://cloud.githubusercontent.com/assets/4317392/9828151/a5ddff30-58b7-11e5-8d31-eda5dc4eae79.png) Tasks view: ![image](https://cloud.githubusercontent.com/assets/4317392/9828199/17dc2b84-58b8-11e5-92a8-be89ce4d29d1.png) cc andrewor14 I appreciate if you can give feedback on this since I think you introduced display of this metric. Author: Forest Fang Closes #8726 from saurfang/stagepage. 
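The essence of the fix, as a sketch that mirrors the `StagePage` code changed below (it assumes Spark-internal visibility, since `InternalAccumulator` is private to the `org.apache.spark` package): for an accumulator recorded on a finished task, `AccumulableInfo.update` holds the value contributed by that single task, while `AccumulableInfo.value` holds the running total across all tasks seen so far, so building the per-task quantiles from `value` inflates every percentile.

```scala
import org.apache.spark.InternalAccumulator // private[spark]; shown for completeness
import org.apache.spark.scheduler.TaskInfo

// Per-task peak execution memory, derived from the task's own partial update.
def peakExecutionMemory(info: TaskInfo): Long =
  info.accumulables
    .find(_.name == InternalAccumulator.PEAK_EXECUTION_MEMORY)
    .map(_.update.getOrElse("0").toLong) // using `value` here would be cumulative
    .getOrElse(0L)
```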
--- .../org/apache/spark/ui/jobs/StagePage.scala | 2 +- .../org/apache/spark/ui/StagePageSuite.scala | 29 ++++++++++++++----- 2 files changed, 23 insertions(+), 8 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 4adc6596ba21..2b71f55b7bb4 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -368,7 +368,7 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { val peakExecutionMemory = validTasks.map { case TaskUIData(info, _, _) => info.accumulables .find { acc => acc.name == InternalAccumulator.PEAK_EXECUTION_MEMORY } - .map { acc => acc.value.toLong } + .map { acc => acc.update.getOrElse("0").toLong } .getOrElse(0L) .toDouble } diff --git a/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala b/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala index 3388c6dca81f..86699e7f5695 100644 --- a/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala @@ -23,7 +23,7 @@ import scala.xml.Node import org.mockito.Mockito.{mock, when, RETURNS_SMART_NULLS} -import org.apache.spark.{LocalSparkContext, SparkConf, SparkFunSuite, Success} +import org.apache.spark._ import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler._ import org.apache.spark.ui.jobs.{JobProgressListener, StagePage, StagesTab} @@ -47,6 +47,14 @@ class StagePageSuite extends SparkFunSuite with LocalSparkContext { assert(html3.contains(targetString)) } + test("SPARK-10543: peak execution memory should be per-task rather than cumulative") { + val unsafeConf = "spark.sql.unsafe.enabled" + val conf = new SparkConf(false).set(unsafeConf, "true") + val html = renderStagePage(conf).toString().toLowerCase + // verify min/25/50/75/max show task value not cumulative values + assert(html.contains("10.0 b" * 5)) + } + /** * Render a stage page started with the given conf and return the HTML. * This also runs a dummy stage to populate the page with useful content. 
@@ -67,12 +75,19 @@ class StagePageSuite extends SparkFunSuite with LocalSparkContext { // Simulate a stage in job progress listener val stageInfo = new StageInfo(0, 0, "dummy", 1, Seq.empty, Seq.empty, "details") - val taskInfo = new TaskInfo(0, 0, 0, 0, "0", "localhost", TaskLocality.ANY, false) - jobListener.onStageSubmitted(SparkListenerStageSubmitted(stageInfo)) - jobListener.onTaskStart(SparkListenerTaskStart(0, 0, taskInfo)) - taskInfo.markSuccessful() - jobListener.onTaskEnd( - SparkListenerTaskEnd(0, 0, "result", Success, taskInfo, TaskMetrics.empty)) + // Simulate two tasks to test PEAK_EXECUTION_MEMORY correctness + (1 to 2).foreach { + taskId => + val taskInfo = new TaskInfo(taskId, taskId, 0, 0, "0", "localhost", TaskLocality.ANY, false) + val peakExecutionMemory = 10 + taskInfo.accumulables += new AccumulableInfo(0, InternalAccumulator.PEAK_EXECUTION_MEMORY, + Some(peakExecutionMemory.toString), (peakExecutionMemory * taskId).toString, true) + jobListener.onStageSubmitted(SparkListenerStageSubmitted(stageInfo)) + jobListener.onTaskStart(SparkListenerTaskStart(0, 0, taskInfo)) + taskInfo.markSuccessful() + jobListener.onTaskEnd( + SparkListenerTaskEnd(0, 0, "result", Success, taskInfo, TaskMetrics.empty)) + } jobListener.onStageCompleted(SparkListenerStageCompleted(stageInfo)) page.render(request) } From 7b6c856367b9c36348e80e83959150da9656c4dd Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Mon, 14 Sep 2015 15:09:43 -0700 Subject: [PATCH 0621/1824] [SPARK-10564] ThreadingSuite: assertion failures in threads don't fail the test (round 2) This is a follow-up patch to #8723. I missed one case there. Author: Andrew Or Closes #8727 from andrewor14/fix-threading-suite. --- .../org/apache/spark/ThreadingSuite.scala | 23 ++++++++++++------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/ThreadingSuite.scala b/core/src/test/scala/org/apache/spark/ThreadingSuite.scala index cda2b245526f..a96a4ce201c2 100644 --- a/core/src/test/scala/org/apache/spark/ThreadingSuite.scala +++ b/core/src/test/scala/org/apache/spark/ThreadingSuite.scala @@ -147,12 +147,12 @@ class ThreadingSuite extends SparkFunSuite with LocalSparkContext with Logging { }.start() } sem.acquire(2) + throwable.foreach { t => throw t } if (ThreadingSuiteState.failed.get()) { logError("Waited 1 second without seeing runningThreads = 4 (it was " + ThreadingSuiteState.runningThreads.get() + "); failing test") fail("One or more threads didn't see runningThreads = 4") } - throwable.foreach { t => throw t } } test("set local properties in different thread") { @@ -178,8 +178,8 @@ class ThreadingSuite extends SparkFunSuite with LocalSparkContext with Logging { threads.foreach(_.start()) sem.acquire(5) - assert(sc.getLocalProperty("test") === null) throwable.foreach { t => throw t } + assert(sc.getLocalProperty("test") === null) } test("set and get local properties in parent-children thread") { @@ -207,15 +207,16 @@ class ThreadingSuite extends SparkFunSuite with LocalSparkContext with Logging { threads.foreach(_.start()) sem.acquire(5) + throwable.foreach { t => throw t } assert(sc.getLocalProperty("test") === "parent") assert(sc.getLocalProperty("Foo") === null) - throwable.foreach { t => throw t } } test("mutations to local properties should not affect submitted jobs (SPARK-6629)") { val jobStarted = new Semaphore(0) val jobEnded = new Semaphore(0) @volatile var jobResult: JobResult = null + var throwable: Option[Throwable] = None sc = new SparkContext("local", "test") 
sc.setJobGroup("originalJobGroupId", "description") @@ -232,14 +233,19 @@ class ThreadingSuite extends SparkFunSuite with LocalSparkContext with Logging { // Create a new thread which will inherit the current thread's properties val thread = new Thread() { override def run(): Unit = { - assert(sc.getLocalProperty(SparkContext.SPARK_JOB_GROUP_ID) === "originalJobGroupId") - // Sleeps for a total of 10 seconds, but allows cancellation to interrupt the task try { - sc.parallelize(1 to 100).foreach { x => - Thread.sleep(100) + assert(sc.getLocalProperty(SparkContext.SPARK_JOB_GROUP_ID) === "originalJobGroupId") + // Sleeps for a total of 10 seconds, but allows cancellation to interrupt the task + try { + sc.parallelize(1 to 100).foreach { x => + Thread.sleep(100) + } + } catch { + case s: SparkException => // ignored so that we don't print noise in test logs } } catch { - case s: SparkException => // ignored so that we don't print noise in test logs + case t: Throwable => + throwable = Some(t) } } } @@ -252,6 +258,7 @@ class ThreadingSuite extends SparkFunSuite with LocalSparkContext with Logging { // modification of the properties object should not affect the properties of running jobs sc.cancelJobGroup("originalJobGroupId") jobEnded.tryAcquire(10, TimeUnit.SECONDS) + throwable.foreach { t => throw t } assert(jobResult.isInstanceOf[JobFailed]) } } From 1a0955250bb65cd6f5818ad60efb62ea4b45d18e Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Mon, 14 Sep 2015 21:47:40 -0400 Subject: [PATCH 0622/1824] [SPARK-9851] Support submitting map stages individually in DAGScheduler This patch adds support for submitting map stages in a DAG individually so that we can make downstream decisions after seeing statistics about their output, as part of SPARK-9850. I also added more comments to many of the key classes in DAGScheduler. By itself, the patch is not super useful except maybe to switch between a shuffle and broadcast join, but with the other subtasks of SPARK-9850 we'll be able to do more interesting decisions. The main entry point is SparkContext.submitMapStage, which lets you run a map stage and see stats about the map output sizes. Other stats could also be collected through accumulators. See AdaptiveSchedulingSuite for a short example. Author: Matei Zaharia Closes #8180 from mateiz/spark-9851. 
--- .../apache/spark/MapOutputStatistics.scala | 27 ++ .../org/apache/spark/MapOutputTracker.scala | 49 +++- .../scala/org/apache/spark/SparkContext.scala | 17 ++ .../apache/spark/scheduler/ActiveJob.scala | 34 ++- .../apache/spark/scheduler/DAGScheduler.scala | 251 +++++++++++++++--- .../spark/scheduler/DAGSchedulerEvent.scala | 10 + .../apache/spark/scheduler/ResultStage.scala | 17 +- .../spark/scheduler/ShuffleMapStage.scala | 13 +- .../org/apache/spark/scheduler/Stage.scala | 26 +- .../scala/org/apache/spark/FailureSuite.scala | 21 ++ .../scheduler/AdaptiveSchedulingSuite.scala | 65 +++++ .../spark/scheduler/DAGSchedulerSuite.scala | 243 ++++++++++++++++- 12 files changed, 710 insertions(+), 63 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/MapOutputStatistics.scala create mode 100644 core/src/test/scala/org/apache/spark/scheduler/AdaptiveSchedulingSuite.scala diff --git a/core/src/main/scala/org/apache/spark/MapOutputStatistics.scala b/core/src/main/scala/org/apache/spark/MapOutputStatistics.scala new file mode 100644 index 000000000000..f8a6f1d0d8cb --- /dev/null +++ b/core/src/main/scala/org/apache/spark/MapOutputStatistics.scala @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +/** + * Holds statistics about the output sizes in a map stage. May become a DeveloperApi in the future. + * + * @param shuffleId ID of the shuffle + * @param bytesByPartitionId approximate number of output bytes for each map output partition + * (may be inexact due to use of compressed map statuses) + */ +private[spark] class MapOutputStatistics(val shuffleId: Int, val bytesByPartitionId: Array[Long]) diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index a38759278385..94eb8daa85c5 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -18,6 +18,7 @@ package org.apache.spark import java.io._ +import java.util.Arrays import java.util.concurrent.ConcurrentHashMap import java.util.zip.{GZIPInputStream, GZIPOutputStream} @@ -132,13 +133,43 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging * describing the shuffle blocks that are stored at that block manager. 
*/ def getMapSizesByExecutorId(shuffleId: Int, reduceId: Int) - : Seq[(BlockManagerId, Seq[(BlockId, Long)])] = { + : Seq[(BlockManagerId, Seq[(BlockId, Long)])] = { logDebug(s"Fetching outputs for shuffle $shuffleId, reduce $reduceId") - val startTime = System.currentTimeMillis + val statuses = getStatuses(shuffleId) + // Synchronize on the returned array because, on the driver, it gets mutated in place + statuses.synchronized { + return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, statuses) + } + } + /** + * Return statistics about all of the outputs for a given shuffle. + */ + def getStatistics(dep: ShuffleDependency[_, _, _]): MapOutputStatistics = { + val statuses = getStatuses(dep.shuffleId) + // Synchronize on the returned array because, on the driver, it gets mutated in place + statuses.synchronized { + val totalSizes = new Array[Long](dep.partitioner.numPartitions) + for (s <- statuses) { + for (i <- 0 until totalSizes.length) { + totalSizes(i) += s.getSizeForBlock(i) + } + } + new MapOutputStatistics(dep.shuffleId, totalSizes) + } + } + + /** + * Get or fetch the array of MapStatuses for a given shuffle ID. NOTE: clients MUST synchronize + * on this array when reading it, because on the driver, we may be changing it in place. + * + * (It would be nice to remove this restriction in the future.) + */ + private def getStatuses(shuffleId: Int): Array[MapStatus] = { val statuses = mapStatuses.get(shuffleId).orNull if (statuses == null) { logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them") + val startTime = System.currentTimeMillis var fetchedStatuses: Array[MapStatus] = null fetching.synchronized { // Someone else is fetching it; wait for them to be done @@ -160,7 +191,7 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging } if (fetchedStatuses == null) { - // We won the race to fetch the output locs; do so + // We won the race to fetch the statuses; do so logInfo("Doing the fetch; tracker endpoint = " + trackerEndpoint) // This try-finally prevents hangs due to timeouts: try { @@ -175,22 +206,18 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging } } } - logDebug(s"Fetching map output location for shuffle $shuffleId, reduce $reduceId took " + + logDebug(s"Fetching map output statuses for shuffle $shuffleId took " + s"${System.currentTimeMillis - startTime} ms") if (fetchedStatuses != null) { - fetchedStatuses.synchronized { - return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, fetchedStatuses) - } + return fetchedStatuses } else { logError("Missing all output locations for shuffle " + shuffleId) throw new MetadataFetchFailedException( - shuffleId, reduceId, "Missing all output locations for shuffle " + shuffleId) + shuffleId, -1, "Missing all output locations for shuffle " + shuffleId) } } else { - statuses.synchronized { - return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, statuses) - } + return statuses } } diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index e27b3c496222..dee6091ce3ca 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -1984,6 +1984,23 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli new SimpleFutureAction(waiter, resultFunc) } + /** + * Submit a map stage for execution. 
This is currently an internal API only, but might be + * promoted to DeveloperApi in the future. + */ + private[spark] def submitMapStage[K, V, C](dependency: ShuffleDependency[K, V, C]) + : SimpleFutureAction[MapOutputStatistics] = { + assertNotStopped() + val callSite = getCallSite() + var result: MapOutputStatistics = null + val waiter = dagScheduler.submitMapStage( + dependency, + (r: MapOutputStatistics) => { result = r }, + callSite, + localProperties.get) + new SimpleFutureAction[MapOutputStatistics](waiter, result) + } + /** * Cancel active jobs for the specified group. See [[org.apache.spark.SparkContext.setJobGroup]] * for more information. diff --git a/core/src/main/scala/org/apache/spark/scheduler/ActiveJob.scala b/core/src/main/scala/org/apache/spark/scheduler/ActiveJob.scala index 50a69379412d..a3d2db31301b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ActiveJob.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ActiveJob.scala @@ -23,18 +23,42 @@ import org.apache.spark.TaskContext import org.apache.spark.util.CallSite /** - * Tracks information about an active job in the DAGScheduler. + * A running job in the DAGScheduler. Jobs can be of two types: a result job, which computes a + * ResultStage to execute an action, or a map-stage job, which computes the map outputs for a + * ShuffleMapStage before any downstream stages are submitted. The latter is used for adaptive + * query planning, to look at map output statistics before submitting later stages. We distinguish + * between these two types of jobs using the finalStage field of this class. + * + * Jobs are only tracked for "leaf" stages that clients directly submitted, through DAGScheduler's + * submitJob or submitMapStage methods. However, either type of job may cause the execution of + * other earlier stages (for RDDs in the DAG it depends on), and multiple jobs may share some of + * these previous stages. These dependencies are managed inside DAGScheduler. + * + * @param jobId A unique ID for this job. + * @param finalStage The stage that this job computes (either a ResultStage for an action or a + * ShuffleMapStage for submitMapStage). + * @param callSite Where this job was initiated in the user's program (shown on UI). + * @param listener A listener to notify if tasks in this job finish or the job fails. + * @param properties Scheduling properties attached to the job, such as fair scheduler pool name. */ private[spark] class ActiveJob( val jobId: Int, - val finalStage: ResultStage, - val func: (TaskContext, Iterator[_]) => _, - val partitions: Array[Int], + val finalStage: Stage, val callSite: CallSite, val listener: JobListener, val properties: Properties) { - val numPartitions = partitions.length + /** + * Number of partitions we need to compute for this job. Note that result stages may not need + * to compute all partitions in their target RDD, for actions like first() and lookup(). 
+ */ + val numPartitions = finalStage match { + case r: ResultStage => r.partitions.length + case m: ShuffleMapStage => m.rdd.partitions.length + } + + /** Which partitions of the stage have finished */ val finished = Array.fill[Boolean](numPartitions)(false) + var numFinished = 0 } diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 09e963f5cdf6..b4f90e834789 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -45,17 +45,65 @@ import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat * The high-level scheduling layer that implements stage-oriented scheduling. It computes a DAG of * stages for each job, keeps track of which RDDs and stage outputs are materialized, and finds a * minimal schedule to run the job. It then submits stages as TaskSets to an underlying - * TaskScheduler implementation that runs them on the cluster. + * TaskScheduler implementation that runs them on the cluster. A TaskSet contains fully independent + * tasks that can run right away based on the data that's already on the cluster (e.g. map output + * files from previous stages), though it may fail if this data becomes unavailable. * - * In addition to coming up with a DAG of stages, this class also determines the preferred + * Spark stages are created by breaking the RDD graph at shuffle boundaries. RDD operations with + * "narrow" dependencies, like map() and filter(), are pipelined together into one set of tasks + * in each stage, but operations with shuffle dependencies require multiple stages (one to write a + * set of map output files, and another to read those files after a barrier). In the end, every + * stage will have only shuffle dependencies on other stages, and may compute multiple operations + * inside it. The actual pipelining of these operations happens in the RDD.compute() functions of + * various RDDs (MappedRDD, FilteredRDD, etc). + * + * In addition to coming up with a DAG of stages, the DAGScheduler also determines the preferred * locations to run each task on, based on the current cache status, and passes these to the * low-level TaskScheduler. Furthermore, it handles failures due to shuffle output files being * lost, in which case old stages may need to be resubmitted. Failures *within* a stage that are * not caused by shuffle file loss are handled by the TaskScheduler, which will retry each task * a small number of times before cancelling the whole stage. * + * When looking through this code, there are several key concepts: + * + * - Jobs (represented by [[ActiveJob]]) are the top-level work items submitted to the scheduler. + * For example, when the user calls an action, like count(), a job will be submitted through + * submitJob. Each Job may require the execution of multiple stages to build intermediate data. + * + * - Stages ([[Stage]]) are sets of tasks that compute intermediate results in jobs, where each + * task computes the same function on partitions of the same RDD. Stages are separated at shuffle + * boundaries, which introduce a barrier (where we must wait for the previous stage to finish to + * fetch outputs). There are two types of stages: [[ResultStage]], for the final stage that + * executes an action, and [[ShuffleMapStage]], which writes map output files for a shuffle. + * Stages are often shared across multiple jobs, if these jobs reuse the same RDDs. 
+ * + * - Tasks are individual units of work, each sent to one machine. + * + * - Cache tracking: the DAGScheduler figures out which RDDs are cached to avoid recomputing them + * and likewise remembers which shuffle map stages have already produced output files to avoid + * redoing the map side of a shuffle. + * + * - Preferred locations: the DAGScheduler also computes where to run each task in a stage based + * on the preferred locations of its underlying RDDs, or the location of cached or shuffle data. + * + * - Cleanup: all data structures are cleared when the running jobs that depend on them finish, + * to prevent memory leaks in a long-running application. + * + * To recover from failures, the same stage might need to run multiple times, which are called + * "attempts". If the TaskScheduler reports that a task failed because a map output file from a + * previous stage was lost, the DAGScheduler resubmits that lost stage. This is detected through a + * CompletionEvent with FetchFailed, or an ExecutorLost event. The DAGScheduler will wait a small + * amount of time to see whether other nodes or tasks fail, then resubmit TaskSets for any lost + * stage(s) that compute the missing tasks. As part of this process, we might also have to create + * Stage objects for old (finished) stages where we previously cleaned up the Stage object. Since + * tasks from the old attempt of a stage could still be running, care must be taken to map any + * events received in the correct Stage object. + * * Here's a checklist to use when making or reviewing changes to this class: * + * - All data structures should be cleared when the jobs involving them end to avoid indefinite + * accumulation of state in long-running programs. + * * - When adding a new data structure, update `DAGSchedulerSuite.assertDataStructuresEmpty` to * include the new structure. This will help to catch memory leaks. */ @@ -295,12 +343,12 @@ class DAGScheduler( */ private def newResultStage( rdd: RDD[_], - numTasks: Int, + func: (TaskContext, Iterator[_]) => _, + partitions: Array[Int], jobId: Int, callSite: CallSite): ResultStage = { val (parentStages: List[Stage], id: Int) = getParentStagesAndId(rdd, jobId) - val stage: ResultStage = new ResultStage(id, rdd, numTasks, parentStages, jobId, callSite) - + val stage = new ResultStage(id, rdd, func, partitions, parentStages, jobId, callSite) stageIdToStage(id) = stage updateJobIdStageIdMaps(jobId, stage) stage @@ -500,12 +548,25 @@ class DAGScheduler( jobIdToStageIds -= job.jobId jobIdToActiveJob -= job.jobId activeJobs -= job - job.finalStage.resultOfJob = None + job.finalStage match { + case r: ResultStage => + r.resultOfJob = None + case m: ShuffleMapStage => + m.mapStageJobs = m.mapStageJobs.filter(_ != job) + } } /** - * Submit a job to the job scheduler and get a JobWaiter object back. The JobWaiter object + * Submit an action job to the scheduler and get a JobWaiter object back. The JobWaiter object * can be used to block until the the job finishes executing or can be used to cancel the job. + * + * @param rdd target RDD to run tasks on + * @param func a function to run on each partition of the RDD + * @param partitions set of partitions to run on; some jobs may not want to compute on all + * partitions of the target RDD, e.g. for operations like first() + * @param callSite where in the user program this job was called + * @param resultHandler callback to pass each result to + * @param properties scheduler properties to attach to this job, e.g. 
fair scheduler pool name */ def submitJob[T, U]( rdd: RDD[T], @@ -524,6 +585,7 @@ class DAGScheduler( val jobId = nextJobId.getAndIncrement() if (partitions.size == 0) { + // Return immediately if the job is running 0 tasks return new JobWaiter[U](this, jobId, 0, resultHandler) } @@ -536,6 +598,18 @@ class DAGScheduler( waiter } + /** + * Run an action job on the given RDD and pass all the results to the resultHandler function as + * they arrive. Throws an exception if the job fials, or returns normally if successful. + * + * @param rdd target RDD to run tasks on + * @param func a function to run on each partition of the RDD + * @param partitions set of partitions to run on; some jobs may not want to compute on all + * partitions of the target RDD, e.g. for operations like first() + * @param callSite where in the user program this job was called + * @param resultHandler callback to pass each result to + * @param properties scheduler properties to attach to this job, e.g. fair scheduler pool name + */ def runJob[T, U]( rdd: RDD[T], func: (TaskContext, Iterator[T]) => U, @@ -559,6 +633,17 @@ class DAGScheduler( } } + /** + * Run an approximate job on the given RDD and pass all the results to an ApproximateEvaluator + * as they arrive. Returns a partial result object from the evaluator. + * + * @param rdd target RDD to run tasks on + * @param func a function to run on each partition of the RDD + * @param evaluator [[ApproximateEvaluator]] to receive the partial results + * @param callSite where in the user program this job was called + * @param timeout maximum time to wait for the job, in milliseconds + * @param properties scheduler properties to attach to this job, e.g. fair scheduler pool name + */ def runApproximateJob[T, U, R]( rdd: RDD[T], func: (TaskContext, Iterator[T]) => U, @@ -575,6 +660,41 @@ class DAGScheduler( listener.awaitResult() // Will throw an exception if the job fails } + /** + * Submit a shuffle map stage to run independently and get a JobWaiter object back. The waiter + * can be used to block until the the job finishes executing or can be used to cancel the job. + * This method is used for adaptive query planning, to run map stages and look at statistics + * about their outputs before submitting downstream stages. + * + * @param dependency the ShuffleDependency to run a map stage for + * @param callback function called with the result of the job, which in this case will be a + * single MapOutputStatistics object showing how much data was produced for each partition + * @param callSite where in the user program this job was submitted + * @param properties scheduler properties to attach to this job, e.g. fair scheduler pool name + */ + def submitMapStage[K, V, C]( + dependency: ShuffleDependency[K, V, C], + callback: MapOutputStatistics => Unit, + callSite: CallSite, + properties: Properties): JobWaiter[MapOutputStatistics] = { + + val rdd = dependency.rdd + val jobId = nextJobId.getAndIncrement() + if (rdd.partitions.length == 0) { + throw new SparkException("Can't run submitMapStage on RDD with 0 partitions") + } + + // We create a JobWaiter with only one "task", which will be marked as complete when the whole + // map stage has completed, and will be passed the MapOutputStatistics for that stage. 
+ // This makes it easier to avoid race conditions between the user code and the map output + // tracker that might result if we told the user the stage had finished, but then they queries + // the map output tracker and some node failures had caused the output statistics to be lost. + val waiter = new JobWaiter(this, jobId, 1, (i: Int, r: MapOutputStatistics) => callback(r)) + eventProcessLoop.post(MapStageSubmitted( + jobId, dependency, callSite, waiter, SerializationUtils.clone(properties))) + waiter + } + /** * Cancel a job that is running or waiting in the queue. */ @@ -583,6 +703,9 @@ class DAGScheduler( eventProcessLoop.post(JobCancelled(jobId)) } + /** + * Cancel all jobs in the given job group ID. + */ def cancelJobGroup(groupId: String): Unit = { logInfo("Asked to cancel job group " + groupId) eventProcessLoop.post(JobGroupCancelled(groupId)) @@ -720,31 +843,77 @@ class DAGScheduler( try { // New stage creation may throw an exception if, for example, jobs are run on a // HadoopRDD whose underlying HDFS files have been deleted. - finalStage = newResultStage(finalRDD, partitions.length, jobId, callSite) + finalStage = newResultStage(finalRDD, func, partitions, jobId, callSite) } catch { case e: Exception => logWarning("Creating new stage failed due to exception - job: " + jobId, e) listener.jobFailed(e) return } - if (finalStage != null) { - val job = new ActiveJob(jobId, finalStage, func, partitions, callSite, listener, properties) - clearCacheLocs() - logInfo("Got job %s (%s) with %d output partitions".format( - job.jobId, callSite.shortForm, partitions.length)) - logInfo("Final stage: " + finalStage + "(" + finalStage.name + ")") - logInfo("Parents of final stage: " + finalStage.parents) - logInfo("Missing parents: " + getMissingParentStages(finalStage)) - val jobSubmissionTime = clock.getTimeMillis() - jobIdToActiveJob(jobId) = job - activeJobs += job - finalStage.resultOfJob = Some(job) - val stageIds = jobIdToStageIds(jobId).toArray - val stageInfos = stageIds.flatMap(id => stageIdToStage.get(id).map(_.latestInfo)) - listenerBus.post( - SparkListenerJobStart(job.jobId, jobSubmissionTime, stageInfos, properties)) - submitStage(finalStage) + + val job = new ActiveJob(jobId, finalStage, callSite, listener, properties) + clearCacheLocs() + logInfo("Got job %s (%s) with %d output partitions".format( + job.jobId, callSite.shortForm, partitions.length)) + logInfo("Final stage: " + finalStage + " (" + finalStage.name + ")") + logInfo("Parents of final stage: " + finalStage.parents) + logInfo("Missing parents: " + getMissingParentStages(finalStage)) + + val jobSubmissionTime = clock.getTimeMillis() + jobIdToActiveJob(jobId) = job + activeJobs += job + finalStage.resultOfJob = Some(job) + val stageIds = jobIdToStageIds(jobId).toArray + val stageInfos = stageIds.flatMap(id => stageIdToStage.get(id).map(_.latestInfo)) + listenerBus.post( + SparkListenerJobStart(job.jobId, jobSubmissionTime, stageInfos, properties)) + submitStage(finalStage) + + submitWaitingStages() + } + + private[scheduler] def handleMapStageSubmitted(jobId: Int, + dependency: ShuffleDependency[_, _, _], + callSite: CallSite, + listener: JobListener, + properties: Properties) { + // Submitting this map stage might still require the creation of some parent stages, so make + // sure that happens. + var finalStage: ShuffleMapStage = null + try { + // New stage creation may throw an exception if, for example, jobs are run on a + // HadoopRDD whose underlying HDFS files have been deleted. 
+ finalStage = getShuffleMapStage(dependency, jobId) + } catch { + case e: Exception => + logWarning("Creating new stage failed due to exception - job: " + jobId, e) + listener.jobFailed(e) + return + } + + val job = new ActiveJob(jobId, finalStage, callSite, listener, properties) + clearCacheLocs() + logInfo("Got map stage job %s (%s) with %d output partitions".format( + jobId, callSite.shortForm, dependency.rdd.partitions.size)) + logInfo("Final stage: " + finalStage + " (" + finalStage.name + ")") + logInfo("Parents of final stage: " + finalStage.parents) + logInfo("Missing parents: " + getMissingParentStages(finalStage)) + + val jobSubmissionTime = clock.getTimeMillis() + jobIdToActiveJob(jobId) = job + activeJobs += job + finalStage.mapStageJobs = job :: finalStage.mapStageJobs + val stageIds = jobIdToStageIds(jobId).toArray + val stageInfos = stageIds.flatMap(id => stageIdToStage.get(id).map(_.latestInfo)) + listenerBus.post( + SparkListenerJobStart(job.jobId, jobSubmissionTime, stageInfos, properties)) + submitStage(finalStage) + + // If the whole stage has already finished, tell the listener and remove it + if (!finalStage.outputLocs.contains(Nil)) { + markMapStageJobAsFinished(job, mapOutputTracker.getStatistics(dependency)) } + submitWaitingStages() } @@ -814,7 +983,7 @@ class DAGScheduler( case s: ResultStage => val job = s.resultOfJob.get partitionsToCompute.map { id => - val p = job.partitions(id) + val p = s.partitions(id) (id, getPreferredLocs(stage.rdd, p)) }.toMap } @@ -844,7 +1013,7 @@ class DAGScheduler( case stage: ShuffleMapStage => closureSerializer.serialize((stage.rdd, stage.shuffleDep): AnyRef).array() case stage: ResultStage => - closureSerializer.serialize((stage.rdd, stage.resultOfJob.get.func): AnyRef).array() + closureSerializer.serialize((stage.rdd, stage.func): AnyRef).array() } taskBinary = sc.broadcast(taskBinaryBytes) @@ -875,7 +1044,7 @@ class DAGScheduler( case stage: ResultStage => val job = stage.resultOfJob.get partitionsToCompute.map { id => - val p: Int = job.partitions(id) + val p: Int = stage.partitions(id) val part = stage.rdd.partitions(p) val locs = taskIdToLocations(id) new ResultTask(stage.id, stage.latestInfo.attemptId, @@ -1052,13 +1221,21 @@ class DAGScheduler( logInfo("Resubmitting " + shuffleStage + " (" + shuffleStage.name + ") because some of its tasks had failed: " + shuffleStage.outputLocs.zipWithIndex.filter(_._1.isEmpty) - .map(_._2).mkString(", ")) + .map(_._2).mkString(", ")) submitStage(shuffleStage) + } else { + // Mark any map-stage jobs waiting on this stage as finished + if (shuffleStage.mapStageJobs.nonEmpty) { + val stats = mapOutputTracker.getStatistics(shuffleStage.shuffleDep) + for (job <- shuffleStage.mapStageJobs) { + markMapStageJobAsFinished(job, stats) + } + } } // Note: newly runnable stages will be submitted below when we submit waiting stages } - } + } case Resubmitted => logInfo("Resubmitted " + task + ", so marking it as still running") @@ -1412,6 +1589,17 @@ class DAGScheduler( Nil } + /** Mark a map stage job as finished with the given output stats, and report to its listener. 
*/ + def markMapStageJobAsFinished(job: ActiveJob, stats: MapOutputStatistics): Unit = { + // In map stage jobs, we only create a single "task", which is to finish all of the stage + // (including reusing any previous map outputs, etc); so we just mark task 0 as done + job.finished(0) = true + job.numFinished += 1 + job.listener.taskSucceeded(0, stats) + cleanupStateForJobAndIndependentStages(job) + listenerBus.post(SparkListenerJobEnd(job.jobId, clock.getTimeMillis(), JobSucceeded)) + } + def stop() { logInfo("Stopping DAGScheduler") messageScheduler.shutdownNow() @@ -1445,6 +1633,9 @@ private[scheduler] class DAGSchedulerEventProcessLoop(dagScheduler: DAGScheduler case JobSubmitted(jobId, rdd, func, partitions, callSite, listener, properties) => dagScheduler.handleJobSubmitted(jobId, rdd, func, partitions, callSite, listener, properties) + case MapStageSubmitted(jobId, dependency, callSite, listener, properties) => + dagScheduler.handleMapStageSubmitted(jobId, dependency, callSite, listener, properties) + case StageCancelled(stageId) => dagScheduler.handleStageCancellation(stageId) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala index f72a52e85dc1..dda3b6cc7f96 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala @@ -35,6 +35,7 @@ import org.apache.spark.util.CallSite */ private[scheduler] sealed trait DAGSchedulerEvent +/** A result-yielding job was submitted on a target RDD */ private[scheduler] case class JobSubmitted( jobId: Int, finalRDD: RDD[_], @@ -45,6 +46,15 @@ private[scheduler] case class JobSubmitted( properties: Properties = null) extends DAGSchedulerEvent +/** A map stage as submitted to run as a separate job */ +private[scheduler] case class MapStageSubmitted( + jobId: Int, + dependency: ShuffleDependency[_, _, _], + callSite: CallSite, + listener: JobListener, + properties: Properties = null) + extends DAGSchedulerEvent + private[scheduler] case class StageCancelled(stageId: Int) extends DAGSchedulerEvent private[scheduler] case class JobCancelled(jobId: Int) extends DAGSchedulerEvent diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultStage.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultStage.scala index bf81b9aca481..c0451da1f024 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ResultStage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ResultStage.scala @@ -17,23 +17,30 @@ package org.apache.spark.scheduler +import org.apache.spark.TaskContext import org.apache.spark.rdd.RDD import org.apache.spark.util.CallSite /** - * The ResultStage represents the final stage in a job. + * ResultStages apply a function on some partitions of an RDD to compute the result of an action. + * The ResultStage object captures the function to execute, `func`, which will be applied to each + * partition, and the set of partition IDs, `partitions`. Some stages may not run on all partitions + * of the RDD, for actions like first() and lookup(). 
*/ private[spark] class ResultStage( id: Int, rdd: RDD[_], - numTasks: Int, + val func: (TaskContext, Iterator[_]) => _, + val partitions: Array[Int], parents: List[Stage], firstJobId: Int, callSite: CallSite) - extends Stage(id, rdd, numTasks, parents, firstJobId, callSite) { + extends Stage(id, rdd, partitions.length, parents, firstJobId, callSite) { - // The active job for this result stage. Will be empty if the job has already finished - // (e.g., because the job was cancelled). + /** + * The active job for this result stage. Will be empty if the job has already finished + * (e.g., because the job was cancelled). + */ var resultOfJob: Option[ActiveJob] = None override def toString: String = "ResultStage " + id diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala index 48d8d8e9c4b7..7d9296087640 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala @@ -23,7 +23,15 @@ import org.apache.spark.storage.BlockManagerId import org.apache.spark.util.CallSite /** - * The ShuffleMapStage represents the intermediate stages in a job. + * ShuffleMapStages are intermediate stages in the execution DAG that produce data for a shuffle. + * They occur right before each shuffle operation, and might contain multiple pipelined operations + * before that (e.g. map and filter). When executed, they save map output files that can later be + * fetched by reduce tasks. The `shuffleDep` field describes the shuffle each stage is part of, + * and variables like `outputLocs` and `numAvailableOutputs` track how many map outputs are ready. + * + * ShuffleMapStages can also be submitted independently as jobs with DAGScheduler.submitMapStage. + * For such stages, the ActiveJobs that submitted them are tracked in `mapStageJobs`. Note that + * there can be multiple ActiveJobs trying to compute the same shuffle map stage. */ private[spark] class ShuffleMapStage( id: Int, @@ -37,6 +45,9 @@ private[spark] class ShuffleMapStage( override def toString: String = "ShuffleMapStage " + id + /** Running map-stage jobs that were submitted to execute this stage independently (if any) */ + var mapStageJobs: List[ActiveJob] = Nil + var numAvailableOutputs: Int = 0 def isAvailable: Boolean = numAvailableOutputs == numPartitions diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala index c086535782c2..b37eccbd0f7b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala @@ -24,27 +24,33 @@ import org.apache.spark.rdd.RDD import org.apache.spark.util.CallSite /** - * A stage is a set of independent tasks all computing the same function that need to run as part + * A stage is a set of parallel tasks all computing the same function that need to run as part * of a Spark job, where all the tasks have the same shuffle dependencies. Each DAG of tasks run * by the scheduler is split up into stages at the boundaries where shuffle occurs, and then the * DAGScheduler runs these stages in topological order. * * Each Stage can either be a shuffle map stage, in which case its tasks' results are input for - * another stage, or a result stage, in which case its tasks directly compute the action that - * initiated a job (e.g. count(), save(), etc). 
For shuffle map stages, we also track the nodes - * that each output partition is on. + * other stage(s), or a result stage, in which case its tasks directly compute a Spark action + * (e.g. count(), save(), etc) by running a function on an RDD. For shuffle map stages, we also + * track the nodes that each output partition is on. * * Each Stage also has a firstJobId, identifying the job that first submitted the stage. When FIFO * scheduling is used, this allows Stages from earlier jobs to be computed first or recovered * faster on failure. * - * The callSite provides a location in user code which relates to the stage. For a shuffle map - * stage, the callSite gives the user code that created the RDD being shuffled. For a result - * stage, the callSite gives the user code that executes the associated action (e.g. count()). - * - * A single stage can consist of multiple attempts. In that case, the latestInfo field will - * be updated for each attempt. + * Finally, a single stage can be re-executed in multiple attempts due to fault recovery. In that + * case, the Stage object will track multiple StageInfo objects to pass to listeners or the web UI. + * The latest one will be accessible through latestInfo. * + * @param id Unique stage ID + * @param rdd RDD that this stage runs on: for a shuffle map stage, it's the RDD we run map tasks + * on, while for a result stage, it's the target RDD that we ran an action on + * @param numTasks Total number of tasks in stage; result stages in particular may not need to + * compute all partitions, e.g. for first(), lookup(), and take(). + * @param parents List of stages that this stage depends on (through shuffle dependencies). + * @param firstJobId ID of the first job this stage was part of, for FIFO scheduling. + * @param callSite Location in the user program associated with this stage: either where the target + * RDD was created, for a shuffle map stage, or where the action for a result stage was called. */ private[scheduler] abstract class Stage( val id: Int, diff --git a/core/src/test/scala/org/apache/spark/FailureSuite.scala b/core/src/test/scala/org/apache/spark/FailureSuite.scala index aa50a49c5023..f58756e6f617 100644 --- a/core/src/test/scala/org/apache/spark/FailureSuite.scala +++ b/core/src/test/scala/org/apache/spark/FailureSuite.scala @@ -217,6 +217,27 @@ class FailureSuite extends SparkFunSuite with LocalSparkContext { FailureSuiteState.clear() } + // Run a 3-task map stage where one task fails once. + test("failure in tasks in a submitMapStage") { + sc = new SparkContext("local[1,2]", "test") + val rdd = sc.makeRDD(1 to 3, 3).map { x => + FailureSuiteState.synchronized { + FailureSuiteState.tasksRun += 1 + if (x == 1 && FailureSuiteState.tasksFailed == 0) { + FailureSuiteState.tasksFailed += 1 + throw new Exception("Intentional task failure") + } + } + (x, x) + } + val dep = new ShuffleDependency[Int, Int, Int](rdd, new HashPartitioner(2)) + sc.submitMapStage(dep).get() + FailureSuiteState.synchronized { + assert(FailureSuiteState.tasksRun === 4) + } + FailureSuiteState.clear() + } + // TODO: Need to add tests with shuffle fetch failures. 
} diff --git a/core/src/test/scala/org/apache/spark/scheduler/AdaptiveSchedulingSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/AdaptiveSchedulingSuite.scala new file mode 100644 index 000000000000..3fe28027c3c2 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/scheduler/AdaptiveSchedulingSuite.scala @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler + +import org.apache.spark.rdd.{ShuffledRDDPartition, RDD, ShuffledRDD} +import org.apache.spark._ + +object AdaptiveSchedulingSuiteState { + var tasksRun = 0 + + def clear(): Unit = { + tasksRun = 0 + } +} + +/** A special ShuffledRDD where we can pass a ShuffleDependency object to use */ +class CustomShuffledRDD[K, V, C](@transient dep: ShuffleDependency[K, V, C]) + extends RDD[(K, C)](dep.rdd.context, Seq(dep)) { + + override def compute(split: Partition, context: TaskContext): Iterator[(K, C)] = { + val dep = dependencies.head.asInstanceOf[ShuffleDependency[K, V, C]] + SparkEnv.get.shuffleManager.getReader(dep.shuffleHandle, split.index, split.index + 1, context) + .read() + .asInstanceOf[Iterator[(K, C)]] + } + + override def getPartitions: Array[Partition] = { + Array.tabulate[Partition](dep.partitioner.numPartitions)(i => new ShuffledRDDPartition(i)) + } +} + +class AdaptiveSchedulingSuite extends SparkFunSuite with LocalSparkContext { + test("simple use of submitMapStage") { + try { + sc = new SparkContext("local[1,2]", "test") + val rdd = sc.parallelize(1 to 3, 3).map { x => + AdaptiveSchedulingSuiteState.tasksRun += 1 + (x, x) + } + val dep = new ShuffleDependency[Int, Int, Int](rdd, new HashPartitioner(2)) + val shuffled = new CustomShuffledRDD[Int, Int, Int](dep) + sc.submitMapStage(dep).get() + assert(AdaptiveSchedulingSuiteState.tasksRun == 3) + assert(shuffled.collect().toSet == Set((1, 1), (2, 2), (3, 3))) + assert(AdaptiveSchedulingSuiteState.tasksRun == 3) + } finally { + AdaptiveSchedulingSuiteState.clear() + } + } +} diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 1b9ff740ff53..1c55f90ad9b4 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -152,6 +152,14 @@ class DAGSchedulerSuite override def jobFailed(exception: Exception) = { failure = exception } } + /** A simple helper class for creating custom JobListeners */ + class SimpleListener extends JobListener { + val results = new HashMap[Int, Any] + var failure: Exception = null + override def taskSucceeded(index: Int, result: Any): Unit = results.put(index, result) + override def jobFailed(exception: Exception): Unit = { failure 
= exception } + } + before { sc = new SparkContext("local", "DAGSchedulerSuite") sparkListener.submittedStageInfos.clear() @@ -229,7 +237,7 @@ class DAGSchedulerSuite } } - /** Sends the rdd to the scheduler for scheduling and returns the job id. */ + /** Submits a job to the scheduler and returns the job id. */ private def submit( rdd: RDD[_], partitions: Array[Int], @@ -240,6 +248,15 @@ class DAGSchedulerSuite jobId } + /** Submits a map stage to the scheduler and returns the job id. */ + private def submitMapStage( + shuffleDep: ShuffleDependency[_, _, _], + listener: JobListener = jobListener): Int = { + val jobId = scheduler.nextJobId.getAndIncrement() + runEvent(MapStageSubmitted(jobId, shuffleDep, CallSite("", ""), listener)) + jobId + } + /** Sends TaskSetFailed to the scheduler. */ private def failed(taskSet: TaskSet, message: String) { runEvent(TaskSetFailed(taskSet, message, None)) @@ -1313,6 +1330,230 @@ class DAGSchedulerSuite assert(stackTraceString.contains("org.scalatest.FunSuite")) } + test("simple map stage submission") { + val shuffleMapRdd = new MyRDD(sc, 2, Nil) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(1)) + val reduceRdd = new MyRDD(sc, 1, List(shuffleDep)) + + // Submit a map stage by itself + submitMapStage(shuffleDep) + assert(results.size === 0) // No results yet + completeShuffleMapStageSuccessfully(0, 0, 1) + assert(results.size === 1) + results.clear() + assertDataStructuresEmpty() + + // Submit a reduce job that depends on this map stage; it should directly do the reduce + submit(reduceRdd, Array(0)) + completeNextResultStageWithSuccess(2, 0) + assert(results === Map(0 -> 42)) + results.clear() + assertDataStructuresEmpty() + + // Check that if we submit the map stage again, no tasks run + submitMapStage(shuffleDep) + assert(results.size === 1) + assertDataStructuresEmpty() + } + + test("map stage submission with reduce stage also depending on the data") { + val shuffleMapRdd = new MyRDD(sc, 2, Nil) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(1)) + val reduceRdd = new MyRDD(sc, 1, List(shuffleDep)) + + // Submit the map stage by itself + submitMapStage(shuffleDep) + + // Submit a reduce job that depends on this map stage + submit(reduceRdd, Array(0)) + + // Complete tasks for the map stage + completeShuffleMapStageSuccessfully(0, 0, 1) + assert(results.size === 1) + results.clear() + + // Complete tasks for the reduce stage + completeNextResultStageWithSuccess(1, 0) + assert(results === Map(0 -> 42)) + results.clear() + assertDataStructuresEmpty() + + // Check that if we submit the map stage again, no tasks run + submitMapStage(shuffleDep) + assert(results.size === 1) + assertDataStructuresEmpty() + } + + test("map stage submission with fetch failure") { + val shuffleMapRdd = new MyRDD(sc, 2, Nil) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(2)) + val shuffleId = shuffleDep.shuffleId + val reduceRdd = new MyRDD(sc, 2, List(shuffleDep)) + + // Submit a map stage by itself + submitMapStage(shuffleDep) + complete(taskSets(0), Seq( + (Success, makeMapStatus("hostA", reduceRdd.partitions.size)), + (Success, makeMapStatus("hostB", reduceRdd.partitions.size)))) + assert(results.size === 1) + results.clear() + assertDataStructuresEmpty() + + // Submit a reduce job that depends on this map stage, but where one reduce will fail a fetch + submit(reduceRdd, Array(0, 1)) + complete(taskSets(1), Seq( + (Success, 42), + (FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0, 
"ignored"), null))) + // Ask the scheduler to try it again; TaskSet 2 will rerun the map task that we couldn't fetch + // from, then TaskSet 3 will run the reduce stage + scheduler.resubmitFailedStages() + complete(taskSets(2), Seq((Success, makeMapStatus("hostA", reduceRdd.partitions.size)))) + complete(taskSets(3), Seq((Success, 43))) + assert(results === Map(0 -> 42, 1 -> 43)) + results.clear() + assertDataStructuresEmpty() + + // Run another reduce job without a failure; this should just work + submit(reduceRdd, Array(0, 1)) + complete(taskSets(4), Seq( + (Success, 44), + (Success, 45))) + assert(results === Map(0 -> 44, 1 -> 45)) + results.clear() + assertDataStructuresEmpty() + + // Resubmit the map stage; this should also just work + submitMapStage(shuffleDep) + assert(results.size === 1) + results.clear() + assertDataStructuresEmpty() + } + + /** + * In this test, we have three RDDs with shuffle dependencies, and we submit map stage jobs + * that are waiting on each one, as well as a reduce job on the last one. We test that all of + * these jobs complete even if there are some fetch failures in both shuffles. + */ + test("map stage submission with multiple shared stages and failures") { + val rdd1 = new MyRDD(sc, 2, Nil) + val dep1 = new ShuffleDependency(rdd1, new HashPartitioner(2)) + val rdd2 = new MyRDD(sc, 2, List(dep1)) + val dep2 = new ShuffleDependency(rdd2, new HashPartitioner(2)) + val rdd3 = new MyRDD(sc, 2, List(dep2)) + + val listener1 = new SimpleListener + val listener2 = new SimpleListener + val listener3 = new SimpleListener + + submitMapStage(dep1, listener1) + submitMapStage(dep2, listener2) + submit(rdd3, Array(0, 1), listener = listener3) + + // Complete the first stage + assert(taskSets(0).stageId === 0) + complete(taskSets(0), Seq( + (Success, makeMapStatus("hostA", rdd1.partitions.size)), + (Success, makeMapStatus("hostB", rdd1.partitions.size)))) + assert(mapOutputTracker.getMapSizesByExecutorId(dep1.shuffleId, 0).map(_._1).toSet === + HashSet(makeBlockManagerId("hostA"), makeBlockManagerId("hostB"))) + assert(listener1.results.size === 1) + + // When attempting the second stage, show a fetch failure + assert(taskSets(1).stageId === 1) + complete(taskSets(1), Seq( + (Success, makeMapStatus("hostA", rdd2.partitions.size)), + (FetchFailed(makeBlockManagerId("hostA"), dep1.shuffleId, 0, 0, "ignored"), null))) + scheduler.resubmitFailedStages() + assert(listener2.results.size === 0) // Second stage listener should not have a result yet + + // Stage 0 should now be running as task set 2; make its task succeed + assert(taskSets(2).stageId === 0) + complete(taskSets(2), Seq( + (Success, makeMapStatus("hostC", rdd2.partitions.size)))) + assert(mapOutputTracker.getMapSizesByExecutorId(dep1.shuffleId, 0).map(_._1).toSet === + HashSet(makeBlockManagerId("hostC"), makeBlockManagerId("hostB"))) + assert(listener2.results.size === 0) // Second stage listener should still not have a result + + // Stage 1 should now be running as task set 3; make its first task succeed + assert(taskSets(3).stageId === 1) + complete(taskSets(3), Seq( + (Success, makeMapStatus("hostB", rdd2.partitions.size)), + (Success, makeMapStatus("hostD", rdd2.partitions.size)))) + assert(mapOutputTracker.getMapSizesByExecutorId(dep2.shuffleId, 0).map(_._1).toSet === + HashSet(makeBlockManagerId("hostB"), makeBlockManagerId("hostD"))) + assert(listener2.results.size === 1) + + // Finally, the reduce job should be running as task set 4; make it see a fetch failure, + // then make it run again and succeed 
+ assert(taskSets(4).stageId === 2) + complete(taskSets(4), Seq( + (Success, 52), + (FetchFailed(makeBlockManagerId("hostD"), dep2.shuffleId, 0, 0, "ignored"), null))) + scheduler.resubmitFailedStages() + + // TaskSet 5 will rerun stage 1's lost task, then TaskSet 6 will rerun stage 2 + assert(taskSets(5).stageId === 1) + complete(taskSets(5), Seq( + (Success, makeMapStatus("hostE", rdd2.partitions.size)))) + complete(taskSets(6), Seq( + (Success, 53))) + assert(listener3.results === Map(0 -> 52, 1 -> 53)) + assertDataStructuresEmpty() + } + + /** + * In this test, we run a map stage where one of the executors fails but we still receive a + * "zombie" complete message from that executor. We want to make sure the stage is not reported + * as done until all tasks have completed. + */ + test("map stage submission with executor failure late map task completions") { + val shuffleMapRdd = new MyRDD(sc, 3, Nil) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(2)) + + submitMapStage(shuffleDep) + + val oldTaskSet = taskSets(0) + runEvent(CompletionEvent(oldTaskSet.tasks(0), Success, makeMapStatus("hostA", 2), + null, createFakeTaskInfo(), null)) + assert(results.size === 0) // Map stage job should not be complete yet + + // Pretend host A was lost + val oldEpoch = mapOutputTracker.getEpoch + runEvent(ExecutorLost("exec-hostA")) + val newEpoch = mapOutputTracker.getEpoch + assert(newEpoch > oldEpoch) + + // Suppose we also get a completed event from task 1 on the same host; this should be ignored + runEvent(CompletionEvent(oldTaskSet.tasks(1), Success, makeMapStatus("hostA", 2), + null, createFakeTaskInfo(), null)) + assert(results.size === 0) // Map stage job should not be complete yet + + // A completion from another task should work because it's a non-failed host + runEvent(CompletionEvent(oldTaskSet.tasks(2), Success, makeMapStatus("hostB", 2), + null, createFakeTaskInfo(), null)) + assert(results.size === 0) // Map stage job should not be complete yet + + // Now complete tasks in the second task set + val newTaskSet = taskSets(1) + assert(newTaskSet.tasks.size === 2) // Both tasks 0 and 1 were on on hostA + runEvent(CompletionEvent(newTaskSet.tasks(0), Success, makeMapStatus("hostB", 2), + null, createFakeTaskInfo(), null)) + assert(results.size === 0) // Map stage job should not be complete yet + runEvent(CompletionEvent(newTaskSet.tasks(1), Success, makeMapStatus("hostB", 2), + null, createFakeTaskInfo(), null)) + assert(results.size === 1) // Map stage job should now finally be complete + assertDataStructuresEmpty() + + // Also test that a reduce stage using this shuffled data can immediately run + val reduceRDD = new MyRDD(sc, 2, List(shuffleDep)) + results.clear() + submit(reduceRDD, Array(0, 1)) + complete(taskSets(2), Seq((Success, 42), (Success, 43))) + assert(results === Map(0 -> 42, 1 -> 43)) + results.clear() + assertDataStructuresEmpty() + } + /** * Assert that the supplied TaskSet has exactly the given hosts as its preferred locations. * Note that this checks only the host and not the executor ID. From 55204181004c105c7a3e8c31a099b37e48bfd953 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 14 Sep 2015 19:46:34 -0700 Subject: [PATCH 0623/1824] [SPARK-10542] [PYSPARK] fix serialize namedtuple Author: Davies Liu Closes #8707 from davies/fix_namedtuple. 
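For context, a minimal standalone sketch of what this fix enables, modeled on the regression test added to python/pyspark/tests.py below; the snippet itself (the `cloud_dumps` alias, the toy `P` namedtuple, and the standalone asserts) is illustrative, not part of the patch. Pickling namedtuple *instances* already worked, because PySpark's hijacked `namedtuple` injects a `__reduce__`; the new `_is_namedtuple_` flag plus the `_load_namedtuple` helper let cloudpickle round-trip the generated *class* itself as well.

```python
# Sketch only -- mirrors the regression test added below, not a verbatim excerpt.
import pyspark  # importing pyspark applies its namedtuple hijack (pyspark.serializers)
from pyspark.cloudpickle import dumps as cloud_dumps
from pickle import dumps, loads
from collections import namedtuple

P = namedtuple("P", "x y")   # the generated class now carries _is_namedtuple_ = True
p1 = P(1, 3)

# Instances round-tripped before this patch via the injected __reduce__.
assert loads(dumps(p1, 2)) == p1

# New behavior: the class object itself can be cloudpickled (for example when it
# is captured in a closure shipped to executors) and is rebuilt through
# _load_namedtuple(name, fields).
P2 = loads(cloud_dumps(P))
assert P2(1, 3) == p1
```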
--- python/pyspark/cloudpickle.py | 15 ++++++++++++++- python/pyspark/serializers.py | 1 + python/pyspark/tests.py | 5 +++++ 3 files changed, 20 insertions(+), 1 deletion(-) diff --git a/python/pyspark/cloudpickle.py b/python/pyspark/cloudpickle.py index 3b647985801b..95b3abc74244 100644 --- a/python/pyspark/cloudpickle.py +++ b/python/pyspark/cloudpickle.py @@ -350,6 +350,11 @@ def save_global(self, obj, name=None, pack=struct.pack): if new_override: d['__new__'] = obj.__new__ + # workaround for namedtuple (hijacked by PySpark) + if getattr(obj, '_is_namedtuple_', False): + self.save_reduce(_load_namedtuple, (obj.__name__, obj._fields)) + return + self.save(_load_class) self.save_reduce(typ, (obj.__name__, obj.__bases__, {"__doc__": obj.__doc__}), obj=obj) d.pop('__doc__', None) @@ -382,7 +387,7 @@ def save_instancemethod(self, obj): self.save_reduce(types.MethodType, (obj.__func__, obj.__self__), obj=obj) else: self.save_reduce(types.MethodType, (obj.__func__, obj.__self__, obj.__self__.__class__), - obj=obj) + obj=obj) dispatch[types.MethodType] = save_instancemethod def save_inst(self, obj): @@ -744,6 +749,14 @@ def _load_class(cls, d): return cls +def _load_namedtuple(name, fields): + """ + Loads a class generated by namedtuple + """ + from collections import namedtuple + return namedtuple(name, fields) + + """Constructors for 3rd party libraries Note: These can never be renamed due to client compatibility issues""" diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 411b4dbf481f..2a1326947f4f 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -359,6 +359,7 @@ def _hack_namedtuple(cls): def __reduce__(self): return (_restore, (name, fields, tuple(self))) cls.__reduce__ = __reduce__ + cls._is_namedtuple_ = True return cls diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 8bfed074c905..647504c32f15 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -218,6 +218,11 @@ def test_namedtuple(self): p2 = loads(dumps(p1, 2)) self.assertEqual(p1, p2) + from pyspark.cloudpickle import dumps + P2 = loads(dumps(P)) + p3 = P2(1, 3) + self.assertEqual(p1, p3) + def test_itemgetter(self): from operator import itemgetter ser = CloudPickleSerializer() From 4ae4d54794778042b2cc983e52757edac02412ab Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Mon, 14 Sep 2015 21:37:43 -0700 Subject: [PATCH 0624/1824] [SPARK-9793] [MLLIB] [PYSPARK] PySpark DenseVector, SparseVector implement __eq__ and __hash__ correctly PySpark DenseVector, SparseVector ```__eq__``` method should use semantics equality, and DenseVector can compared with SparseVector. Implement PySpark DenseVector, SparseVector ```__hash__``` method based on the first 16 entries. That will make PySpark Vector objects can be used in collections. Author: Yanbo Liang Closes #8166 from yanboliang/spark-9793. 
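A short usage sketch of the new semantics, modeled on the `test_eq` and `test_hash` cases added to python/pyspark/mllib/tests.py below; the variable names and standalone asserts are illustrative, not part of the patch. Dense and sparse vectors with the same size and entries now compare equal and hash alike, so they behave consistently as dictionary keys and set members; the hash folds in only a bounded prefix of the non-zero entries (see the nnz cap in `__hash__` below), which keeps hashing cheap for large vectors.

```python
# Sketch only -- mirrors the tests added below, not a verbatim excerpt.
from pyspark.mllib.linalg import DenseVector, SparseVector

dense = DenseVector([0.0, 1.0, 0.0, 5.5])
sparse = SparseVector(4, [(1, 1.0), (3, 5.5)])

# Semantic equality across representations.
assert dense == sparse
# Equal vectors hash equally, so they collapse in hashed collections.
assert hash(dense) == hash(sparse)
assert len({dense, sparse}) == 1

# A different size or different values still compares unequal.
assert sparse != SparseVector(6, [(1, 1.0), (3, 5.5)])
assert dense != DenseVector([0.0, 1.0, 0.0, 2.5])
```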
--- python/pyspark/mllib/linalg/__init__.py | 90 ++++++++++++++++++++----- python/pyspark/mllib/tests.py | 32 +++++++++ 2 files changed, 107 insertions(+), 15 deletions(-) diff --git a/python/pyspark/mllib/linalg/__init__.py b/python/pyspark/mllib/linalg/__init__.py index 334dc8e38bb8..380f86e9b44f 100644 --- a/python/pyspark/mllib/linalg/__init__.py +++ b/python/pyspark/mllib/linalg/__init__.py @@ -25,6 +25,7 @@ import sys import array +import struct if sys.version >= '3': basestring = str @@ -122,6 +123,13 @@ def _format_float_list(l): return [_format_float(x) for x in l] +def _double_to_long_bits(value): + if np.isnan(value): + value = float('nan') + # pack double into 64 bits, then unpack as long int + return struct.unpack('Q', struct.pack('d', value))[0] + + class VectorUDT(UserDefinedType): """ SQL user-defined type (UDT) for Vector. @@ -404,11 +412,31 @@ def __repr__(self): return "DenseVector([%s])" % (', '.join(_format_float(i) for i in self.array)) def __eq__(self, other): - return isinstance(other, DenseVector) and np.array_equal(self.array, other.array) + if isinstance(other, DenseVector): + return np.array_equal(self.array, other.array) + elif isinstance(other, SparseVector): + if len(self) != other.size: + return False + return Vectors._equals(list(xrange(len(self))), self.array, other.indices, other.values) + return False def __ne__(self, other): return not self == other + def __hash__(self): + size = len(self) + result = 31 + size + nnz = 0 + i = 0 + while i < size and nnz < 128: + if self.array[i] != 0: + result = 31 * result + i + bits = _double_to_long_bits(self.array[i]) + result = 31 * result + (bits ^ (bits >> 32)) + nnz += 1 + i += 1 + return result + def __getattr__(self, item): return getattr(self.array, item) @@ -704,20 +732,14 @@ def __repr__(self): return "SparseVector({0}, {{{1}}})".format(self.size, entries) def __eq__(self, other): - """ - Test SparseVectors for equality. - - >>> v1 = SparseVector(4, [(1, 1.0), (3, 5.5)]) - >>> v2 = SparseVector(4, [(1, 1.0), (3, 5.5)]) - >>> v1 == v2 - True - >>> v1 != v2 - False - """ - return (isinstance(other, self.__class__) - and other.size == self.size - and np.array_equal(other.indices, self.indices) - and np.array_equal(other.values, self.values)) + if isinstance(other, SparseVector): + return other.size == self.size and np.array_equal(other.indices, self.indices) \ + and np.array_equal(other.values, self.values) + elif isinstance(other, DenseVector): + if self.size != len(other): + return False + return Vectors._equals(self.indices, self.values, list(xrange(len(other))), other.array) + return False def __getitem__(self, index): inds = self.indices @@ -739,6 +761,19 @@ def __getitem__(self, index): def __ne__(self, other): return not self.__eq__(other) + def __hash__(self): + result = 31 + self.size + nnz = 0 + i = 0 + while i < len(self.values) and nnz < 128: + if self.values[i] != 0: + result = 31 * result + int(self.indices[i]) + bits = _double_to_long_bits(self.values[i]) + result = 31 * result + (bits ^ (bits >> 32)) + nnz += 1 + i += 1 + return result + class Vectors(object): @@ -841,6 +876,31 @@ def parse(s): def zeros(size): return DenseVector(np.zeros(size)) + @staticmethod + def _equals(v1_indices, v1_values, v2_indices, v2_values): + """ + Check equality between sparse/dense vectors, + v1_indices and v2_indices assume to be strictly increasing. 
+ """ + v1_size = len(v1_values) + v2_size = len(v2_values) + k1 = 0 + k2 = 0 + all_equal = True + while all_equal: + while k1 < v1_size and v1_values[k1] == 0: + k1 += 1 + while k2 < v2_size and v2_values[k2] == 0: + k2 += 1 + + if k1 >= v1_size or k2 >= v2_size: + return k1 >= v1_size and k2 >= v2_size + + all_equal = v1_indices[k1] == v2_indices[k2] and v1_values[k1] == v2_values[k2] + k1 += 1 + k2 += 1 + return all_equal + class Matrix(object): diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index 5097c5e8ba4c..636f9a06cab7 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -194,6 +194,38 @@ def test_squared_distance(self): self.assertEquals(3.0, _squared_distance(sv, arr)) self.assertEquals(3.0, _squared_distance(sv, narr)) + def test_hash(self): + v1 = DenseVector([0.0, 1.0, 0.0, 5.5]) + v2 = SparseVector(4, [(1, 1.0), (3, 5.5)]) + v3 = DenseVector([0.0, 1.0, 0.0, 5.5]) + v4 = SparseVector(4, [(1, 1.0), (3, 2.5)]) + self.assertEquals(hash(v1), hash(v2)) + self.assertEquals(hash(v1), hash(v3)) + self.assertEquals(hash(v2), hash(v3)) + self.assertFalse(hash(v1) == hash(v4)) + self.assertFalse(hash(v2) == hash(v4)) + + def test_eq(self): + v1 = DenseVector([0.0, 1.0, 0.0, 5.5]) + v2 = SparseVector(4, [(1, 1.0), (3, 5.5)]) + v3 = DenseVector([0.0, 1.0, 0.0, 5.5]) + v4 = SparseVector(6, [(1, 1.0), (3, 5.5)]) + v5 = DenseVector([0.0, 1.0, 0.0, 2.5]) + v6 = SparseVector(4, [(1, 1.0), (3, 2.5)]) + self.assertEquals(v1, v2) + self.assertEquals(v1, v3) + self.assertFalse(v2 == v4) + self.assertFalse(v1 == v5) + self.assertFalse(v1 == v6) + + def test_equals(self): + indices = [1, 2, 4] + values = [1., 3., 2.] + self.assertTrue(Vectors._equals(indices, values, list(range(5)), [0., 1., 3., 0., 2.])) + self.assertFalse(Vectors._equals(indices, values, list(range(5)), [0., 3., 1., 0., 2.])) + self.assertFalse(Vectors._equals(indices, values, list(range(5)), [0., 3., 0., 2.])) + self.assertFalse(Vectors._equals(indices, values, list(range(5)), [0., 1., 3., 2., 2.])) + def test_conversion(self): # numpy arrays should be automatically upcast to float64 # tests for fix of [SPARK-5089] From 610971ecfe858b1a48ce69b25614afe52bcbe77f Mon Sep 17 00:00:00 2001 From: noelsmith Date: Mon, 14 Sep 2015 21:58:52 -0700 Subject: [PATCH 0625/1824] [SPARK-10273] Add @since annotation to pyspark.mllib.feature Duplicated the since decorator from pyspark.sql into pyspark (also tweaked to handle functions without docstrings). Added since to methods + "versionadded::" to classes (derived from the git file history in pyspark). Author: noelsmith Closes #8633 from noel-smith/SPARK-10273-since-mllib-feature. --- python/pyspark/mllib/feature.py | 58 ++++++++++++++++++++++++++++++++- 1 file changed, 57 insertions(+), 1 deletion(-) diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py index f921e3ad1a31..7b077b058c3f 100644 --- a/python/pyspark/mllib/feature.py +++ b/python/pyspark/mllib/feature.py @@ -30,7 +30,7 @@ from py4j.protocol import Py4JJavaError -from pyspark import SparkContext +from pyspark import SparkContext, since from pyspark.rdd import RDD, ignore_unicode_prefix from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper from pyspark.mllib.linalg import ( @@ -84,11 +84,14 @@ class Normalizer(VectorTransformer): >>> nor2 = Normalizer(float("inf")) >>> nor2.transform(v) DenseVector([0.0, 0.5, 1.0]) + + .. 
versionadded:: 1.2.0 """ def __init__(self, p=2.0): assert p >= 1.0, "p should be greater than 1.0" self.p = float(p) + @since('1.2.0') def transform(self, vector): """ Applies unit length normalization on a vector. @@ -133,7 +136,11 @@ class StandardScalerModel(JavaVectorTransformer): .. note:: Experimental Represents a StandardScaler model that can transform vectors. + + .. versionadded:: 1.2.0 """ + + @since('1.2.0') def transform(self, vector): """ Applies standardization transformation on a vector. @@ -149,6 +156,7 @@ def transform(self, vector): """ return JavaVectorTransformer.transform(self, vector) + @since('1.4.0') def setWithMean(self, withMean): """ Setter of the boolean which decides @@ -157,6 +165,7 @@ def setWithMean(self, withMean): self.call("setWithMean", withMean) return self + @since('1.4.0') def setWithStd(self, withStd): """ Setter of the boolean which decides @@ -189,6 +198,8 @@ class StandardScaler(object): >>> for r in result.collect(): r DenseVector([-0.7071, 0.7071, -0.7071]) DenseVector([0.7071, -0.7071, 0.7071]) + + .. versionadded:: 1.2.0 """ def __init__(self, withMean=False, withStd=True): if not (withMean or withStd): @@ -196,6 +207,7 @@ def __init__(self, withMean=False, withStd=True): self.withMean = withMean self.withStd = withStd + @since('1.2.0') def fit(self, dataset): """ Computes the mean and variance and stores as a model to be used @@ -215,7 +227,11 @@ class ChiSqSelectorModel(JavaVectorTransformer): .. note:: Experimental Represents a Chi Squared selector model. + + .. versionadded:: 1.4.0 """ + + @since('1.4.0') def transform(self, vector): """ Applies transformation on a vector. @@ -245,10 +261,13 @@ class ChiSqSelector(object): SparseVector(1, {0: 6.0}) >>> model.transform(DenseVector([8.0, 9.0, 5.0])) DenseVector([5.0]) + + .. versionadded:: 1.4.0 """ def __init__(self, numTopFeatures): self.numTopFeatures = int(numTopFeatures) + @since('1.4.0') def fit(self, data): """ Returns a ChiSquared feature selector. @@ -265,6 +284,8 @@ def fit(self, data): class PCAModel(JavaVectorTransformer): """ Model fitted by [[PCA]] that can project vectors to a low-dimensional space using PCA. + + .. versionadded:: 1.5.0 """ @@ -281,6 +302,8 @@ class PCA(object): 1.648... >>> pcArray[1] -4.013... + + .. versionadded:: 1.5.0 """ def __init__(self, k): """ @@ -288,6 +311,7 @@ def __init__(self, k): """ self.k = int(k) + @since('1.5.0') def fit(self, data): """ Computes a [[PCAModel]] that contains the principal components of the input vectors. @@ -312,14 +336,18 @@ class HashingTF(object): >>> doc = "a a b b c d".split(" ") >>> htf.transform(doc) SparseVector(100, {...}) + + .. versionadded:: 1.2.0 """ def __init__(self, numFeatures=1 << 20): self.numFeatures = numFeatures + @since('1.2.0') def indexOf(self, term): """ Returns the index of the input term. """ return hash(term) % self.numFeatures + @since('1.2.0') def transform(self, document): """ Transforms the input document (list of terms) to term frequency @@ -339,7 +367,10 @@ def transform(self, document): class IDFModel(JavaVectorTransformer): """ Represents an IDF model that can transform term frequency vectors. + + .. versionadded:: 1.2.0 """ + @since('1.2.0') def transform(self, x): """ Transforms term frequency (TF) vectors to TF-IDF vectors. @@ -358,6 +389,7 @@ def transform(self, x): """ return JavaVectorTransformer.transform(self, x) + @since('1.4.0') def idf(self): """ Returns the current IDF vector. 
@@ -401,10 +433,13 @@ class IDF(object): DenseVector([0.0, 0.0, 1.3863, 0.863]) >>> model.transform(Vectors.sparse(n, (1, 3), (1.0, 2.0))) SparseVector(4, {1: 0.0, 3: 0.5754}) + + .. versionadded:: 1.2.0 """ def __init__(self, minDocFreq=0): self.minDocFreq = minDocFreq + @since('1.2.0') def fit(self, dataset): """ Computes the inverse document frequency. @@ -420,7 +455,10 @@ def fit(self, dataset): class Word2VecModel(JavaVectorTransformer, JavaSaveable, JavaLoader): """ class for Word2Vec model + + .. versionadded:: 1.2.0 """ + @since('1.2.0') def transform(self, word): """ Transforms a word to its vector representation @@ -435,6 +473,7 @@ def transform(self, word): except Py4JJavaError: raise ValueError("%s not found" % word) + @since('1.2.0') def findSynonyms(self, word, num): """ Find synonyms of a word @@ -450,6 +489,7 @@ def findSynonyms(self, word, num): words, similarity = self.call("findSynonyms", word, num) return zip(words, similarity) + @since('1.4.0') def getVectors(self): """ Returns a map of words to their vector representations. @@ -457,7 +497,11 @@ def getVectors(self): return self.call("getVectors") @classmethod + @since('1.5.0') def load(cls, sc, path): + """ + Load a model from the given path. + """ jmodel = sc._jvm.org.apache.spark.mllib.feature \ .Word2VecModel.load(sc._jsc.sc(), path) return Word2VecModel(jmodel) @@ -507,6 +551,8 @@ class Word2Vec(object): ... rmtree(path) ... except OSError: ... pass + + .. versionadded:: 1.2.0 """ def __init__(self): """ @@ -519,6 +565,7 @@ def __init__(self): self.seed = random.randint(0, sys.maxsize) self.minCount = 5 + @since('1.2.0') def setVectorSize(self, vectorSize): """ Sets vector size (default: 100). @@ -526,6 +573,7 @@ def setVectorSize(self, vectorSize): self.vectorSize = vectorSize return self + @since('1.2.0') def setLearningRate(self, learningRate): """ Sets initial learning rate (default: 0.025). @@ -533,6 +581,7 @@ def setLearningRate(self, learningRate): self.learningRate = learningRate return self + @since('1.2.0') def setNumPartitions(self, numPartitions): """ Sets number of partitions (default: 1). Use a small number for @@ -541,6 +590,7 @@ def setNumPartitions(self, numPartitions): self.numPartitions = numPartitions return self + @since('1.2.0') def setNumIterations(self, numIterations): """ Sets number of iterations (default: 1), which should be smaller @@ -549,6 +599,7 @@ def setNumIterations(self, numIterations): self.numIterations = numIterations return self + @since('1.2.0') def setSeed(self, seed): """ Sets random seed. @@ -556,6 +607,7 @@ def setSeed(self, seed): self.seed = seed return self + @since('1.4.0') def setMinCount(self, minCount): """ Sets minCount, the minimum number of times a token must appear @@ -564,6 +616,7 @@ def setMinCount(self, minCount): self.minCount = minCount return self + @since('1.2.0') def fit(self, data): """ Computes the vector representation of each word in vocabulary. @@ -596,10 +649,13 @@ class ElementwiseProduct(VectorTransformer): >>> rdd = sc.parallelize([a, b]) >>> eprod.transform(rdd).collect() [DenseVector([2.0, 2.0, 9.0]), DenseVector([9.0, 6.0, 12.0])] + + .. versionadded:: 1.5.0 """ def __init__(self, scalingVector): self.scalingVector = _convert_to_vector(scalingVector) + @since('1.5.0') def transform(self, vector): """ Computes the Hadamard product of the vector. 
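The `since` decorator used throughout the preceding SPARK-10273 patch, and again in the SPARK-10275 patch that follows, is imported from `pyspark` but its source is not part of these diffs. The sketch below is a hypothetical reconstruction of the pattern, not the actual implementation: a decorator that appends a Sphinx `.. versionadded::` directive to the wrapped function's docstring and tolerates functions without docstrings.

```python
# Hypothetical reconstruction of the `since` decorator pattern; the real
# decorator is imported `from pyspark` and may differ in detail.
import re


def since(version):
    """Append a Sphinx '.. versionadded:: <version>' note to a docstring."""
    indent_pattern = re.compile(r"\n( +)")   # indentation of existing docstring lines

    def decorator(f):
        doc = f.__doc__ or ""                # tolerate functions without docstrings
        match = indent_pattern.search(doc)
        indent = match.group(1) if match else ""
        f.__doc__ = doc.rstrip() + "\n\n%s.. versionadded:: %s" % (indent, version)
        return f

    return decorator


@since("1.2.0")
def transform(vector):
    """Applies unit length normalization on a vector."""
    return vector


print(transform.__doc__)   # docstring now ends with ".. versionadded:: 1.2.0"
```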
From a2249359d5b0368318a714b292bb1d0dc70c0e27 Mon Sep 17 00:00:00 2001 From: Yu ISHIKAWA Date: Mon, 14 Sep 2015 21:59:40 -0700 Subject: [PATCH 0626/1824] [SPARK-10275] [MLLIB] Add @since annotation to pyspark.mllib.random Author: Yu ISHIKAWA Closes #8666 from yu-iskw/SPARK-10275. --- python/pyspark/mllib/random.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/python/pyspark/mllib/random.py b/python/pyspark/mllib/random.py index 06fbc0eb6aef..9c733b1332bc 100644 --- a/python/pyspark/mllib/random.py +++ b/python/pyspark/mllib/random.py @@ -21,6 +21,7 @@ from functools import wraps +from pyspark import since from pyspark.mllib.common import callMLlibFunc @@ -39,9 +40,12 @@ class RandomRDDs(object): """ Generator methods for creating RDDs comprised of i.i.d samples from some distribution. + + .. addedversion:: 1.1.0 """ @staticmethod + @since("1.1.0") def uniformRDD(sc, size, numPartitions=None, seed=None): """ Generates an RDD comprised of i.i.d. samples from the @@ -72,6 +76,7 @@ def uniformRDD(sc, size, numPartitions=None, seed=None): return callMLlibFunc("uniformRDD", sc._jsc, size, numPartitions, seed) @staticmethod + @since("1.1.0") def normalRDD(sc, size, numPartitions=None, seed=None): """ Generates an RDD comprised of i.i.d. samples from the standard normal @@ -100,6 +105,7 @@ def normalRDD(sc, size, numPartitions=None, seed=None): return callMLlibFunc("normalRDD", sc._jsc, size, numPartitions, seed) @staticmethod + @since("1.3.0") def logNormalRDD(sc, mean, std, size, numPartitions=None, seed=None): """ Generates an RDD comprised of i.i.d. samples from the log normal @@ -132,6 +138,7 @@ def logNormalRDD(sc, mean, std, size, numPartitions=None, seed=None): size, numPartitions, seed) @staticmethod + @since("1.1.0") def poissonRDD(sc, mean, size, numPartitions=None, seed=None): """ Generates an RDD comprised of i.i.d. samples from the Poisson @@ -158,6 +165,7 @@ def poissonRDD(sc, mean, size, numPartitions=None, seed=None): return callMLlibFunc("poissonRDD", sc._jsc, float(mean), size, numPartitions, seed) @staticmethod + @since("1.3.0") def exponentialRDD(sc, mean, size, numPartitions=None, seed=None): """ Generates an RDD comprised of i.i.d. samples from the Exponential @@ -184,6 +192,7 @@ def exponentialRDD(sc, mean, size, numPartitions=None, seed=None): return callMLlibFunc("exponentialRDD", sc._jsc, float(mean), size, numPartitions, seed) @staticmethod + @since("1.3.0") def gammaRDD(sc, shape, scale, size, numPartitions=None, seed=None): """ Generates an RDD comprised of i.i.d. samples from the Gamma @@ -216,6 +225,7 @@ def gammaRDD(sc, shape, scale, size, numPartitions=None, seed=None): @staticmethod @toArray + @since("1.1.0") def uniformVectorRDD(sc, numRows, numCols, numPartitions=None, seed=None): """ Generates an RDD comprised of vectors containing i.i.d. samples drawn @@ -241,6 +251,7 @@ def uniformVectorRDD(sc, numRows, numCols, numPartitions=None, seed=None): @staticmethod @toArray + @since("1.1.0") def normalVectorRDD(sc, numRows, numCols, numPartitions=None, seed=None): """ Generates an RDD comprised of vectors containing i.i.d. samples drawn @@ -266,6 +277,7 @@ def normalVectorRDD(sc, numRows, numCols, numPartitions=None, seed=None): @staticmethod @toArray + @since("1.3.0") def logNormalVectorRDD(sc, mean, std, numRows, numCols, numPartitions=None, seed=None): """ Generates an RDD comprised of vectors containing i.i.d. 
samples drawn @@ -300,6 +312,7 @@ def logNormalVectorRDD(sc, mean, std, numRows, numCols, numPartitions=None, seed @staticmethod @toArray + @since("1.1.0") def poissonVectorRDD(sc, mean, numRows, numCols, numPartitions=None, seed=None): """ Generates an RDD comprised of vectors containing i.i.d. samples drawn @@ -330,6 +343,7 @@ def poissonVectorRDD(sc, mean, numRows, numCols, numPartitions=None, seed=None): @staticmethod @toArray + @since("1.3.0") def exponentialVectorRDD(sc, mean, numRows, numCols, numPartitions=None, seed=None): """ Generates an RDD comprised of vectors containing i.i.d. samples drawn @@ -360,6 +374,7 @@ def exponentialVectorRDD(sc, mean, numRows, numCols, numPartitions=None, seed=No @staticmethod @toArray + @since("1.3.0") def gammaVectorRDD(sc, shape, scale, numRows, numCols, numPartitions=None, seed=None): """ Generates an RDD comprised of vectors containing i.i.d. samples drawn From 833be73314b85b390a9007ed6ed63dc47bbd9e4f Mon Sep 17 00:00:00 2001 From: Jacek Laskowski Date: Mon, 14 Sep 2015 23:40:29 -0700 Subject: [PATCH 0627/1824] Small fixes to docs Links work now properly + consistent use of *Spark standalone cluster* (Spark uppercase + lowercase the rest -- seems agreed in the other places in the docs). Author: Jacek Laskowski Closes #8759 from jaceklaskowski/docs-submitting-apps. --- docs/submitting-applications.md | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/docs/submitting-applications.md b/docs/submitting-applications.md index e58645274e52..7ea4d6f1a3f8 100644 --- a/docs/submitting-applications.md +++ b/docs/submitting-applications.md @@ -65,8 +65,8 @@ For Python applications, simply pass a `.py` file in the place of ` Date: Mon, 14 Sep 2015 23:41:06 -0700 Subject: [PATCH 0628/1824] [SPARK-10598] [DOCS] Comments preceding toMessage method state: "The edge partition is encoded in the lower * 30 bytes of the Int, and the position is encoded in the upper 2 bytes of the Int.". References to bytes should be changed to bits. This contribution is my original work and I license the work to the Spark project under it's open source license. Author: Robin East Closes #8756 from insidedctm/master. --- .../org/apache/spark/graphx/impl/RoutingTablePartition.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala index eb3c997e0f3c..4f1260a5a67b 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala @@ -34,7 +34,7 @@ object RoutingTablePartition { /** * A message from an edge partition to a vertex specifying the position in which the edge * partition references the vertex (src, dst, or both). The edge partition is encoded in the lower - * 30 bytes of the Int, and the position is encoded in the upper 2 bytes of the Int. + * 30 bits of the Int, and the position is encoded in the upper 2 bits of the Int. */ type RoutingTableMessage = (VertexId, Int) From 09b7e7c19897549a8622aec095f27b8b38a1a4d3 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 15 Sep 2015 00:54:20 -0700 Subject: [PATCH 0629/1824] Update version to 1.6.0-SNAPSHOT. Author: Reynold Xin Closes #8350 from rxin/1.6. 
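One illustrative aside on the SPARK-10598 comment fix above before the version-bump diff: the corrected comment describes `RoutingTableMessage` as packing an edge-partition id into the lower 30 bits of an Int and the src/dst position flags into the upper 2 bits. The sketch below shows that packing scheme in isolation; the masks and names are illustrative and are not taken from the GraphX source.

```python
# Illustration of a "lower 30 bits + upper 2 bits" packing scheme, as described
# in the corrected RoutingTableMessage comment. Names and masks are illustrative.
PARTITION_MASK = (1 << 30) - 1   # lower 30 bits: edge partition id
POSITION_SHIFT = 30              # upper 2 bits: src/dst position flags


def pack(edge_partition_id, position):
    assert 0 <= edge_partition_id <= PARTITION_MASK
    assert 0 <= position <= 0b11
    return (position << POSITION_SHIFT) | edge_partition_id


def unpack(packed):
    return packed & PARTITION_MASK, packed >> POSITION_SHIFT


packed = pack(edge_partition_id=12345, position=0b11)  # vertex is both src and dst
assert unpack(packed) == (12345, 0b11)
```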
--- R/pkg/DESCRIPTION | 2 +- assembly/pom.xml | 2 +- bagel/pom.xml | 2 +- core/pom.xml | 2 +- core/src/main/scala/org/apache/spark/package.scala | 2 +- docs/_config.yml | 4 ++-- examples/pom.xml | 2 +- external/flume-assembly/pom.xml | 2 +- external/flume-sink/pom.xml | 2 +- external/flume/pom.xml | 2 +- external/kafka-assembly/pom.xml | 2 +- external/kafka/pom.xml | 2 +- external/mqtt-assembly/pom.xml | 2 +- external/mqtt/pom.xml | 2 +- external/twitter/pom.xml | 2 +- external/zeromq/pom.xml | 2 +- extras/java8-tests/pom.xml | 2 +- extras/kinesis-asl-assembly/pom.xml | 2 +- extras/kinesis-asl/pom.xml | 2 +- extras/spark-ganglia-lgpl/pom.xml | 2 +- graphx/pom.xml | 2 +- launcher/pom.xml | 2 +- mllib/pom.xml | 2 +- network/common/pom.xml | 2 +- network/shuffle/pom.xml | 2 +- network/yarn/pom.xml | 2 +- pom.xml | 2 +- project/MimaBuild.scala | 2 +- project/MimaExcludes.scala | 13 +++++++++++-- repl/pom.xml | 2 +- sql/catalyst/pom.xml | 2 +- sql/core/pom.xml | 2 +- sql/hive-thriftserver/pom.xml | 2 +- sql/hive/pom.xml | 2 +- streaming/pom.xml | 2 +- tools/pom.xml | 2 +- unsafe/pom.xml | 2 +- yarn/pom.xml | 2 +- 38 files changed, 49 insertions(+), 40 deletions(-) diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION index d0d7201f004a..a3a16c42a621 100644 --- a/R/pkg/DESCRIPTION +++ b/R/pkg/DESCRIPTION @@ -1,7 +1,7 @@ Package: SparkR Type: Package Title: R frontend for Spark -Version: 1.5.0 +Version: 1.6.0 Date: 2013-09-09 Author: The Apache Software Foundation Maintainer: Shivaram Venkataraman diff --git a/assembly/pom.xml b/assembly/pom.xml index e9c6d26ccddc..4b60ee00ffbe 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../pom.xml diff --git a/bagel/pom.xml b/bagel/pom.xml index ed5c37e595a9..3baf8d47b4dc 100644 --- a/bagel/pom.xml +++ b/bagel/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../pom.xml diff --git a/core/pom.xml b/core/pom.xml index a46292c13bcc..e31d90f60889 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../pom.xml diff --git a/core/src/main/scala/org/apache/spark/package.scala b/core/src/main/scala/org/apache/spark/package.scala index 8ae76c5f72f2..7515aad09db7 100644 --- a/core/src/main/scala/org/apache/spark/package.scala +++ b/core/src/main/scala/org/apache/spark/package.scala @@ -43,5 +43,5 @@ package org.apache package object spark { // For package docs only - val SPARK_VERSION = "1.5.0-SNAPSHOT" + val SPARK_VERSION = "1.6.0-SNAPSHOT" } diff --git a/docs/_config.yml b/docs/_config.yml index c0e031a83ba9..c59cc465ef89 100644 --- a/docs/_config.yml +++ b/docs/_config.yml @@ -14,8 +14,8 @@ include: # These allow the documentation to be updated with newer releases # of Spark, Scala, and Mesos. 
-SPARK_VERSION: 1.5.0-SNAPSHOT -SPARK_VERSION_SHORT: 1.5.0 +SPARK_VERSION: 1.6.0-SNAPSHOT +SPARK_VERSION_SHORT: 1.6.0 SCALA_BINARY_VERSION: "2.10" SCALA_VERSION: "2.10.4" MESOS_VERSION: 0.21.0 diff --git a/examples/pom.xml b/examples/pom.xml index e6884b09dca9..f5ab2a7fdc09 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../pom.xml diff --git a/external/flume-assembly/pom.xml b/external/flume-assembly/pom.xml index 561ed4babe5d..dceedcf23ed5 100644 --- a/external/flume-assembly/pom.xml +++ b/external/flume-assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../../pom.xml diff --git a/external/flume-sink/pom.xml b/external/flume-sink/pom.xml index 0664cfb2021e..d7c2ac474a18 100644 --- a/external/flume-sink/pom.xml +++ b/external/flume-sink/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../../pom.xml diff --git a/external/flume/pom.xml b/external/flume/pom.xml index 14f7daaf417e..132062f94fb4 100644 --- a/external/flume/pom.xml +++ b/external/flume/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../../pom.xml diff --git a/external/kafka-assembly/pom.xml b/external/kafka-assembly/pom.xml index 6f4e2a89e9af..a9ed39ef8c9a 100644 --- a/external/kafka-assembly/pom.xml +++ b/external/kafka-assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../../pom.xml diff --git a/external/kafka/pom.xml b/external/kafka/pom.xml index ded863bd985e..05abd9e2e681 100644 --- a/external/kafka/pom.xml +++ b/external/kafka/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../../pom.xml diff --git a/external/mqtt-assembly/pom.xml b/external/mqtt-assembly/pom.xml index 841260063373..89713a28ca6a 100644 --- a/external/mqtt-assembly/pom.xml +++ b/external/mqtt-assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../../pom.xml diff --git a/external/mqtt/pom.xml b/external/mqtt/pom.xml index 69b309876a0d..05e6338a08b0 100644 --- a/external/mqtt/pom.xml +++ b/external/mqtt/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../../pom.xml diff --git a/external/twitter/pom.xml b/external/twitter/pom.xml index 178ae8de13b5..244ad58ae959 100644 --- a/external/twitter/pom.xml +++ b/external/twitter/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../../pom.xml diff --git a/external/zeromq/pom.xml b/external/zeromq/pom.xml index 37bfd10d4366..171df8682c84 100644 --- a/external/zeromq/pom.xml +++ b/external/zeromq/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../../pom.xml diff --git a/extras/java8-tests/pom.xml b/extras/java8-tests/pom.xml index 3636a9037d43..81794a853631 100644 --- a/extras/java8-tests/pom.xml +++ b/extras/java8-tests/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../../pom.xml diff --git a/extras/kinesis-asl-assembly/pom.xml b/extras/kinesis-asl-assembly/pom.xml index 51af3e6f2225..61ba4787fbf9 100644 --- a/extras/kinesis-asl-assembly/pom.xml +++ b/extras/kinesis-asl-assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../../pom.xml diff --git a/extras/kinesis-asl/pom.xml b/extras/kinesis-asl/pom.xml 
index 521b53e230c4..6dd8ff69c294 100644 --- a/extras/kinesis-asl/pom.xml +++ b/extras/kinesis-asl/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../../pom.xml diff --git a/extras/spark-ganglia-lgpl/pom.xml b/extras/spark-ganglia-lgpl/pom.xml index 478d0019a25f..87a4f05a0596 100644 --- a/extras/spark-ganglia-lgpl/pom.xml +++ b/extras/spark-ganglia-lgpl/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../../pom.xml diff --git a/graphx/pom.xml b/graphx/pom.xml index 853dea9a7795..202fc19002d1 100644 --- a/graphx/pom.xml +++ b/graphx/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../pom.xml diff --git a/launcher/pom.xml b/launcher/pom.xml index 2fd768d8119c..ed38e66aa246 100644 --- a/launcher/pom.xml +++ b/launcher/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../pom.xml diff --git a/mllib/pom.xml b/mllib/pom.xml index a5db14407b4f..22c0c6008ba3 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../pom.xml diff --git a/network/common/pom.xml b/network/common/pom.xml index 4141fcb8267a..1cc054a8936c 100644 --- a/network/common/pom.xml +++ b/network/common/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../../pom.xml diff --git a/network/shuffle/pom.xml b/network/shuffle/pom.xml index 3d2edf9d9451..7a66c968041c 100644 --- a/network/shuffle/pom.xml +++ b/network/shuffle/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../../pom.xml diff --git a/network/yarn/pom.xml b/network/yarn/pom.xml index a99f7c4392d3..e745180eace7 100644 --- a/network/yarn/pom.xml +++ b/network/yarn/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../../pom.xml diff --git a/pom.xml b/pom.xml index 421357e14157..653599464114 100644 --- a/pom.xml +++ b/pom.xml @@ -26,7 +26,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT pom Spark Project Parent POM http://spark.apache.org/ diff --git a/project/MimaBuild.scala b/project/MimaBuild.scala index f16bf989f200..519052620246 100644 --- a/project/MimaBuild.scala +++ b/project/MimaBuild.scala @@ -91,7 +91,7 @@ object MimaBuild { def mimaSettings(sparkHome: File, projectRef: ProjectRef) = { val organization = "org.apache.spark" - val previousSparkVersion = "1.4.0" + val previousSparkVersion = "1.5.0" val fullId = "spark-" + projectRef.project + "_2.10" mimaDefaultSettings ++ Seq(previousArtifact := Some(organization % fullId % previousSparkVersion), diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 3b8b6c8ffa37..87b141cd3b05 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -35,8 +35,17 @@ object MimaExcludes { def excludes(version: String) = version match { case v if v.startsWith("1.6") => Seq( - MimaBuild.excludeSparkPackage("network") - ) + MimaBuild.excludeSparkPackage("deploy"), + MimaBuild.excludeSparkPackage("network"), + // These are needed if checking against the sbt build, since they are part of + // the maven-generated artifacts in 1.3. + excludePackage("org.spark-project.jetty"), + MimaBuild.excludeSparkPackage("unused"), + // SQL execution is considered private. 
+ excludePackage("org.apache.spark.sql.execution") + ) ++ + MimaBuild.excludeSparkClass("streaming.flume.FlumeTestUtils") ++ + MimaBuild.excludeSparkClass("streaming.flume.PollingFlumeTestUtils") case v if v.startsWith("1.5") => Seq( MimaBuild.excludeSparkPackage("network"), diff --git a/repl/pom.xml b/repl/pom.xml index a5a0f1fc2c85..5cf416a4a544 100644 --- a/repl/pom.xml +++ b/repl/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../pom.xml diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml index 75ab575dfde8..6cfd53e868f8 100644 --- a/sql/catalyst/pom.xml +++ b/sql/catalyst/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../../pom.xml diff --git a/sql/core/pom.xml b/sql/core/pom.xml index 349007789f63..465aa3a3888c 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../../pom.xml diff --git a/sql/hive-thriftserver/pom.xml b/sql/hive-thriftserver/pom.xml index 3566c87dd248..f7fe085f34d8 100644 --- a/sql/hive-thriftserver/pom.xml +++ b/sql/hive-thriftserver/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../../pom.xml diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml index be1607476e25..ac67fe5f47be 100644 --- a/sql/hive/pom.xml +++ b/sql/hive/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../../pom.xml diff --git a/streaming/pom.xml b/streaming/pom.xml index 697895e72fe5..5cc9001b0e9a 100644 --- a/streaming/pom.xml +++ b/streaming/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../pom.xml diff --git a/tools/pom.xml b/tools/pom.xml index 298ee2348b58..1e64f280e5be 100644 --- a/tools/pom.xml +++ b/tools/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../pom.xml diff --git a/unsafe/pom.xml b/unsafe/pom.xml index 89475ee3cf5a..066abe92e51c 100644 --- a/unsafe/pom.xml +++ b/unsafe/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../pom.xml diff --git a/yarn/pom.xml b/yarn/pom.xml index f6737695307a..d8e4a4bbead8 100644 --- a/yarn/pom.xml +++ b/yarn/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../pom.xml From c35fdcb7e9c01271ce560dba4e0bd37569c8f5d1 Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Tue, 15 Sep 2015 09:58:49 -0700 Subject: [PATCH 0630/1824] [SPARK-10491] [MLLIB] move RowMatrix.dspr to BLAS jira: https://issues.apache.org/jira/browse/SPARK-10491 We implemented dspr with sparse vector support in `RowMatrix`. This method is also used in WeightedLeastSquares and other places. It would be useful to move it to `linalg.BLAS`. Let me know if new UT needed. Author: Yuhao Yang Closes #8663 from hhbyyh/movedspr. 
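For context, and as an illustration rather than the Scala code being moved: `dspr`/`spr` performs the symmetric rank-1 update U := alpha * v * v^T while storing only the upper triangle of U column by column in a flat "packed" array of length n(n+1)/2, which is what lets RowMatrix and WeightedLeastSquares accumulate Gram matrices over roughly half the entries. A NumPy sketch of the dense case, assuming NumPy is available:

```python
# Illustration of the packed symmetric rank-1 update performed by BLAS.spr/dspr:
# U += alpha * v * v^T, with only the upper triangle of U stored column-major
# in a flat array of length n*(n+1)/2.
import numpy as np


def spr_dense(alpha, v, U_packed):
    """In-place packed update: U += alpha * v * v^T (upper triangle only)."""
    n = len(v)
    idx = 0
    for col in range(n):            # column-major walk over the upper triangle
        for row in range(col + 1):  # rows 0..col belong to column `col`
            U_packed[idx] += alpha * v[row] * v[col]
            idx += 1
    return U_packed


v = np.array([1.0, 2.0, 3.0])
packed = np.zeros(3 * 4 // 2)       # 6 packed entries for a 3x3 symmetric matrix
spr_dense(0.5, v, packed)

# Cross-check against the full (unpacked) outer product.
full = 0.5 * np.outer(v, v)
expected = np.array([full[r, c] for c in range(3) for r in range(c + 1)])
assert np.allclose(packed, expected)
```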
--- .../spark/ml/optim/WeightedLeastSquares.scala | 4 +- .../org/apache/spark/mllib/linalg/BLAS.scala | 44 +++++++++++++++++++ .../mllib/linalg/distributed/RowMatrix.scala | 40 +---------------- .../apache/spark/mllib/linalg/BLASSuite.scala | 25 +++++++++++ 4 files changed, 72 insertions(+), 41 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala index a99e2ac4c691..0ff8931b0bab 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala @@ -88,7 +88,7 @@ private[ml] class WeightedLeastSquares( if (fitIntercept) { // shift centers // A^T A - aBar aBar^T - RowMatrix.dspr(-1.0, aBar, aaValues) + BLAS.spr(-1.0, aBar, aaValues) // A^T b - bBar aBar BLAS.axpy(-bBar, aBar, abBar) } @@ -203,7 +203,7 @@ private[ml] object WeightedLeastSquares { bbSum += w * b * b BLAS.axpy(w, a, aSum) BLAS.axpy(w * b, a, abSum) - RowMatrix.dspr(w, a, aaSum.values) + BLAS.spr(w, a, aaSum) this } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala index 9ee81eda8a8c..df9f4ae145b8 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala @@ -236,6 +236,50 @@ private[spark] object BLAS extends Serializable with Logging { _nativeBLAS } + /** + * Adds alpha * x * x.t to a matrix in-place. This is the same as BLAS's ?SPR. + * + * @param U the upper triangular part of the matrix in a [[DenseVector]](column major) + */ + def spr(alpha: Double, v: Vector, U: DenseVector): Unit = { + spr(alpha, v, U.values) + } + + /** + * Adds alpha * x * x.t to a matrix in-place. This is the same as BLAS's ?SPR. + * + * @param U the upper triangular part of the matrix packed in an array (column major) + */ + def spr(alpha: Double, v: Vector, U: Array[Double]): Unit = { + val n = v.size + v match { + case DenseVector(values) => + NativeBLAS.dspr("U", n, alpha, values, 1, U) + case SparseVector(size, indices, values) => + val nnz = indices.length + var colStartIdx = 0 + var prevCol = 0 + var col = 0 + var j = 0 + var i = 0 + var av = 0.0 + while (j < nnz) { + col = indices(j) + // Skip empty columns. + colStartIdx += (col - prevCol) * (col + prevCol + 1) / 2 + col = indices(j) + av = alpha * values(j) + i = 0 + while (i <= j) { + U(colStartIdx + indices(i)) += av * values(i) + i += 1 + } + j += 1 + prevCol = col + } + } + } + /** * A := alpha * x * x^T^ + A * @param alpha a real scalar that will be multiplied to x * x^T^. 
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala index 83779ac88989..e55ef26858ad 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala @@ -24,7 +24,6 @@ import scala.collection.mutable.ListBuffer import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, SparseVector => BSV, axpy => brzAxpy, svd => brzSvd, MatrixSingularException, inv} import breeze.numerics.{sqrt => brzSqrt} -import com.github.fommil.netlib.BLAS.{getInstance => blas} import org.apache.spark.Logging import org.apache.spark.SparkContext._ @@ -123,7 +122,7 @@ class RowMatrix @Since("1.0.0") ( // Compute the upper triangular part of the gram matrix. val GU = rows.treeAggregate(new BDV[Double](new Array[Double](nt)))( seqOp = (U, v) => { - RowMatrix.dspr(1.0, v, U.data) + BLAS.spr(1.0, v, U.data) U }, combOp = (U1, U2) => U1 += U2) @@ -673,43 +672,6 @@ class RowMatrix @Since("1.0.0") ( @Experimental object RowMatrix { - /** - * Adds alpha * x * x.t to a matrix in-place. This is the same as BLAS's DSPR. - * - * @param U the upper triangular part of the matrix packed in an array (column major) - */ - // TODO: SPARK-10491 - move this method to linalg.BLAS - private[spark] def dspr(alpha: Double, v: Vector, U: Array[Double]): Unit = { - // TODO: Find a better home (breeze?) for this method. - val n = v.size - v match { - case DenseVector(values) => - blas.dspr("U", n, alpha, values, 1, U) - case SparseVector(size, indices, values) => - val nnz = indices.length - var colStartIdx = 0 - var prevCol = 0 - var col = 0 - var j = 0 - var i = 0 - var av = 0.0 - while (j < nnz) { - col = indices(j) - // Skip empty columns. - colStartIdx += (col - prevCol) * (col + prevCol + 1) / 2 - col = indices(j) - av = alpha * values(j) - i = 0 - while (i <= j) { - U(colStartIdx + indices(i)) += av * values(i) - i += 1 - } - j += 1 - prevCol = col - } - } - } - /** * Fills a full square matrix from its upper triangular part. 
*/ diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala index 8db5c8424abe..96e5ffef7a13 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala @@ -126,6 +126,31 @@ class BLASSuite extends SparkFunSuite { } } + test("spr") { + // test dense vector + val alpha = 0.1 + val x = new DenseVector(Array(1.0, 2, 2.1, 4)) + val U = new DenseVector(Array(1.0, 2, 2, 3, 3, 3, 4, 4, 4, 4)) + val expected = new DenseVector(Array(1.1, 2.2, 2.4, 3.21, 3.42, 3.441, 4.4, 4.8, 4.84, 5.6)) + + spr(alpha, x, U) + assert(U ~== expected absTol 1e-9) + + val matrix33 = new DenseVector(Array(1.0, 2, 3, 4, 5)) + withClue("Size of vector must match the rank of matrix") { + intercept[Exception] { + spr(alpha, x, matrix33) + } + } + + // test sparse vector + val sv = new SparseVector(4, Array(0, 3), Array(1.0, 2)) + val U2 = new DenseVector(Array(1.0, 2, 2, 3, 3, 3, 4, 4, 4, 4)) + spr(0.1, sv, U2) + val expectedSparse = new DenseVector(Array(1.1, 2.0, 2.0, 3.0, 3.0, 3.0, 4.2, 4.0, 4.0, 4.4)) + assert(U2 ~== expectedSparse absTol 1e-15) + } + test("syr") { val dA = new DenseMatrix(4, 4, Array(0.0, 1.2, 2.2, 3.1, 1.2, 3.2, 5.3, 4.6, 2.2, 5.3, 1.8, 3.0, 3.1, 4.6, 3.0, 0.8)) From 8abef21dac1a6538c4e4e0140323b83d804d602b Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Tue, 15 Sep 2015 10:45:02 -0700 Subject: [PATCH 0631/1824] [SPARK-10300] [BUILD] [TESTS] Add support for test tags in run-tests.py. This change does two things: - tag a few tests and adds the mechanism in the build to be able to disable those tags, both in maven and sbt, for both junit and scalatest suites. - add some logic to run-tests.py to disable some tags depending on what files have changed; that's used to disable expensive tests when a module hasn't explicitly been changed, to speed up testing for changes that don't directly affect those modules. Author: Marcelo Vanzin Closes #8437 from vanzin/test-tags. --- core/pom.xml | 10 ------- dev/run-tests.py | 19 ++++++++++++-- dev/sparktestsupport/modules.py | 24 ++++++++++++++++- external/flume/pom.xml | 10 ------- external/kafka/pom.xml | 10 ------- external/mqtt/pom.xml | 10 ------- external/twitter/pom.xml | 10 ------- external/zeromq/pom.xml | 10 ------- extras/java8-tests/pom.xml | 10 ------- extras/kinesis-asl/pom.xml | 5 ---- launcher/pom.xml | 5 ---- mllib/pom.xml | 10 ------- network/common/pom.xml | 10 ------- network/shuffle/pom.xml | 10 ------- pom.xml | 17 ++++++++++-- project/SparkBuild.scala | 13 ++++++++-- sql/core/pom.xml | 5 ---- .../execution/HiveCompatibilitySuite.scala | 2 ++ sql/hive/pom.xml | 5 ---- .../spark/sql/hive/ExtendedHiveTest.java | 26 +++++++++++++++++++ .../spark/sql/hive/client/VersionsSuite.scala | 2 ++ streaming/pom.xml | 10 ------- unsafe/pom.xml | 10 ------- .../spark/deploy/yarn/ExtendedYarnTest.java | 26 +++++++++++++++++++ .../spark/deploy/yarn/YarnClusterSuite.scala | 1 + .../yarn/YarnShuffleIntegrationSuite.scala | 1 + 26 files changed, 124 insertions(+), 147 deletions(-) create mode 100644 sql/hive/src/test/java/org/apache/spark/sql/hive/ExtendedHiveTest.java create mode 100644 yarn/src/test/java/org/apache/spark/deploy/yarn/ExtendedYarnTest.java diff --git a/core/pom.xml b/core/pom.xml index e31d90f60889..8a2018109622 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -331,16 +331,6 @@ scalacheck_${scala.binary.version} test
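In outline, the selection logic is a reverse dependency check: a tag is excluded exactly when no changed module declares it, and the collected tags become a `-Dtest.exclude.tags=...` argument for Maven and sbt. The sketch below uses a cut-down, illustrative module list rather than the full `dev/sparktestsupport/modules.py` graph.

```python
# Simplified sketch of the tag-exclusion idea in this patch: modules declare
# test tags; tags belonging only to unchanged modules are excluded. The module
# list here is illustrative, not the full Spark module graph.
ALL_MODULE_TAGS = {
    "hive": ["org.apache.spark.sql.hive.ExtendedHiveTest"],
    "yarn": ["org.apache.spark.deploy.yarn.ExtendedYarnTest"],
    "core": [],
}


def tags_to_exclude(changed_modules):
    """Collect the tags of every module that was *not* touched by the change."""
    excluded = []
    for module, tags in ALL_MODULE_TAGS.items():
        if module not in changed_modules:
            excluded.extend(tags)
    return excluded


# A change that only touches core skips the expensive Hive and YARN suites.
excluded = tags_to_exclude({"core"})
profile = "-Dtest.exclude.tags=" + ",".join(excluded) if excluded else ""
print(profile)
# -Dtest.exclude.tags=org.apache.spark.sql.hive.ExtendedHiveTest,org.apache.spark.deploy.yarn.ExtendedYarnTest
```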
    - - junit - junit - test - - - com.novocode - junit-interface - test - org.apache.curator curator-test diff --git a/dev/run-tests.py b/dev/run-tests.py index d8b22e1665e7..1a816585187d 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -118,6 +118,14 @@ def determine_modules_to_test(changed_modules): return modules_to_test.union(set(changed_modules)) +def determine_tags_to_exclude(changed_modules): + tags = [] + for m in modules.all_modules: + if m not in changed_modules: + tags += m.test_tags + return tags + + # ------------------------------------------------------------------------------------------------- # Functions for working with subprocesses and shell tools # ------------------------------------------------------------------------------------------------- @@ -369,6 +377,7 @@ def detect_binary_inop_with_mima(): def run_scala_tests_maven(test_profiles): mvn_test_goals = ["test", "--fail-at-end"] + profiles_and_goals = test_profiles + mvn_test_goals print("[info] Running Spark tests using Maven with these arguments: ", @@ -392,7 +401,7 @@ def run_scala_tests_sbt(test_modules, test_profiles): exec_sbt(profiles_and_goals) -def run_scala_tests(build_tool, hadoop_version, test_modules): +def run_scala_tests(build_tool, hadoop_version, test_modules, excluded_tags): """Function to properly execute all tests passed in as a set from the `determine_test_suites` function""" set_title_and_block("Running Spark unit tests", "BLOCK_SPARK_UNIT_TESTS") @@ -401,6 +410,10 @@ def run_scala_tests(build_tool, hadoop_version, test_modules): test_profiles = get_hadoop_profiles(hadoop_version) + \ list(set(itertools.chain.from_iterable(m.build_profile_flags for m in test_modules))) + + if excluded_tags: + test_profiles += ['-Dtest.exclude.tags=' + ",".join(excluded_tags)] + if build_tool == "maven": run_scala_tests_maven(test_profiles) else: @@ -500,8 +513,10 @@ def main(): target_branch = os.environ["ghprbTargetBranch"] changed_files = identify_changed_files_from_git_commits("HEAD", target_branch=target_branch) changed_modules = determine_modules_for_files(changed_files) + excluded_tags = determine_tags_to_exclude(changed_modules) if not changed_modules: changed_modules = [modules.root] + excluded_tags = [] print("[info] Found the following changed modules:", ", ".join(x.name for x in changed_modules)) @@ -541,7 +556,7 @@ def main(): detect_binary_inop_with_mima() # run the test suites - run_scala_tests(build_tool, hadoop_version, test_modules) + run_scala_tests(build_tool, hadoop_version, test_modules, excluded_tags) modules_with_python_tests = [m for m in test_modules if m.python_test_goals] if modules_with_python_tests: diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index 346452f3174e..65397f1f3e0b 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -31,7 +31,7 @@ class Module(object): def __init__(self, name, dependencies, source_file_regexes, build_profile_flags=(), environ={}, sbt_test_goals=(), python_test_goals=(), blacklisted_python_implementations=(), - should_run_r_tests=False): + test_tags=(), should_run_r_tests=False): """ Define a new module. @@ -50,6 +50,8 @@ def __init__(self, name, dependencies, source_file_regexes, build_profile_flags= :param blacklisted_python_implementations: A set of Python implementations that are not supported by this module's Python components. The values in this set should match strings returned by Python's `platform.python_implementation()`. 
+ :param test_tags A set of tags that will be excluded when running unit tests if the module + is not explicitly changed. :param should_run_r_tests: If true, changes in this module will trigger all R tests. """ self.name = name @@ -60,6 +62,7 @@ def __init__(self, name, dependencies, source_file_regexes, build_profile_flags= self.environ = environ self.python_test_goals = python_test_goals self.blacklisted_python_implementations = blacklisted_python_implementations + self.test_tags = test_tags self.should_run_r_tests = should_run_r_tests self.dependent_modules = set() @@ -85,6 +88,9 @@ def contains_file(self, filename): "catalyst/test", "sql/test", "hive/test", + ], + test_tags=[ + "org.apache.spark.sql.hive.ExtendedHiveTest" ] ) @@ -398,6 +404,22 @@ def contains_file(self, filename): ) +yarn = Module( + name="yarn", + dependencies=[], + source_file_regexes=[ + "yarn/", + "network/yarn/", + ], + sbt_test_goals=[ + "yarn/test", + "network-yarn/test", + ], + test_tags=[ + "org.apache.spark.deploy.yarn.ExtendedYarnTest" + ] +) + # The root module is a dummy module which is used to run all of the tests. # No other modules should directly depend on this module. root = Module( diff --git a/external/flume/pom.xml b/external/flume/pom.xml index 132062f94fb4..3154e36c21ef 100644 --- a/external/flume/pom.xml +++ b/external/flume/pom.xml @@ -66,16 +66,6 @@ scalacheck_${scala.binary.version} test - - junit - junit - test - - - com.novocode - junit-interface - test - target/scala-${scala.binary.version}/classes diff --git a/external/kafka/pom.xml b/external/kafka/pom.xml index 05abd9e2e681..7d0d46dadc72 100644 --- a/external/kafka/pom.xml +++ b/external/kafka/pom.xml @@ -86,16 +86,6 @@ scalacheck_${scala.binary.version} test - - junit - junit - test - - - com.novocode - junit-interface - test - target/scala-${scala.binary.version}/classes diff --git a/external/mqtt/pom.xml b/external/mqtt/pom.xml index 05e6338a08b0..913c47d33f48 100644 --- a/external/mqtt/pom.xml +++ b/external/mqtt/pom.xml @@ -58,16 +58,6 @@ scalacheck_${scala.binary.version} test - - junit - junit - test - - - com.novocode - junit-interface - test - org.apache.activemq activemq-core diff --git a/external/twitter/pom.xml b/external/twitter/pom.xml index 244ad58ae959..9137bf25ee8a 100644 --- a/external/twitter/pom.xml +++ b/external/twitter/pom.xml @@ -58,16 +58,6 @@ scalacheck_${scala.binary.version} test - - junit - junit - test - - - com.novocode - junit-interface - test - target/scala-${scala.binary.version}/classes diff --git a/external/zeromq/pom.xml b/external/zeromq/pom.xml index 171df8682c84..6fec4f0e8a0f 100644 --- a/external/zeromq/pom.xml +++ b/external/zeromq/pom.xml @@ -57,16 +57,6 @@ scalacheck_${scala.binary.version} test - - junit - junit - test - - - com.novocode - junit-interface - test - target/scala-${scala.binary.version}/classes diff --git a/extras/java8-tests/pom.xml b/extras/java8-tests/pom.xml index 81794a853631..dba3dda8a956 100644 --- a/extras/java8-tests/pom.xml +++ b/extras/java8-tests/pom.xml @@ -58,16 +58,6 @@ test-jar test - - junit - junit - test - - - com.novocode - junit-interface - test - diff --git a/extras/kinesis-asl/pom.xml b/extras/kinesis-asl/pom.xml index 6dd8ff69c294..760f183a2ef3 100644 --- a/extras/kinesis-asl/pom.xml +++ b/extras/kinesis-asl/pom.xml @@ -74,11 +74,6 @@ scalacheck_${scala.binary.version} test - - com.novocode - junit-interface - test - target/scala-${scala.binary.version}/classes diff --git a/launcher/pom.xml b/launcher/pom.xml index ed38e66aa246..80696280a1d1 100644 --- 
a/launcher/pom.xml +++ b/launcher/pom.xml @@ -42,11 +42,6 @@ log4j test - - junit - junit - test - org.mockito mockito-core diff --git a/mllib/pom.xml b/mllib/pom.xml index 22c0c6008ba3..5dedacb38874 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -94,16 +94,6 @@ scalacheck_${scala.binary.version} test - - junit - junit - test - - - com.novocode - junit-interface - test - org.mockito mockito-core diff --git a/network/common/pom.xml b/network/common/pom.xml index 1cc054a8936c..9c12cca0df60 100644 --- a/network/common/pom.xml +++ b/network/common/pom.xml @@ -64,16 +64,6 @@ - - junit - junit - test - - - com.novocode - junit-interface - test - log4j log4j diff --git a/network/shuffle/pom.xml b/network/shuffle/pom.xml index 7a66c968041c..e4f4c57b683c 100644 --- a/network/shuffle/pom.xml +++ b/network/shuffle/pom.xml @@ -78,16 +78,6 @@ test-jar test - - junit - junit - test - - - com.novocode - junit-interface - test - log4j log4j diff --git a/pom.xml b/pom.xml index 653599464114..2927d3e10756 100644 --- a/pom.xml +++ b/pom.xml @@ -181,6 +181,7 @@ 0.9.2 ${java.home} + @@ -1952,6 +1964,7 @@ __not_used__ + ${test.exclude.tags} diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 901cfa538d23..d80d300f1c3b 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -567,11 +567,20 @@ object TestSettings { javaOptions in Test ++= "-Xmx3g -Xss4096k -XX:PermSize=128M -XX:MaxNewSize=256m -XX:MaxPermSize=1g" .split(" ").toSeq, javaOptions += "-Xmx3g", + // Exclude tags defined in a system property + testOptions in Test += Tests.Argument(TestFrameworks.ScalaTest, + sys.props.get("test.exclude.tags").map { tags => + tags.split(",").flatMap { tag => Seq("-l", tag) }.toSeq + }.getOrElse(Nil): _*), + testOptions in Test += Tests.Argument(TestFrameworks.JUnit, + sys.props.get("test.exclude.tags").map { tags => + Seq("--exclude-categories=" + tags) + }.getOrElse(Nil): _*), // Show full stack trace and duration in test cases. testOptions in Test += Tests.Argument("-oDF"), - testOptions += Tests.Argument(TestFrameworks.JUnit, "-v", "-a"), + testOptions in Test += Tests.Argument(TestFrameworks.JUnit, "-v", "-a"), // Enable Junit testing. - libraryDependencies += "com.novocode" % "junit-interface" % "0.9" % "test", + libraryDependencies += "com.novocode" % "junit-interface" % "0.11" % "test", // Only allow one test at a time, even across projects, since they run in the same JVM parallelExecution in Test := false, // Make sure the test temp directory exists. 
diff --git a/sql/core/pom.xml b/sql/core/pom.xml index 465aa3a3888c..fa6732db183d 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -73,11 +73,6 @@ jackson-databind ${fasterxml.jackson.version} - - junit - junit - test - org.scalacheck scalacheck_${scala.binary.version} diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index ab309e0a1d36..ffc4c32794ca 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -24,11 +24,13 @@ import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.scalatest.BeforeAndAfter import org.apache.spark.sql.SQLConf +import org.apache.spark.sql.hive.ExtendedHiveTest import org.apache.spark.sql.hive.test.TestHive /** * Runs the test cases that are included in the hive distribution. */ +@ExtendedHiveTest class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { // TODO: bundle in jar files... get from classpath private lazy val hiveQueryDir = TestHive.getHiveFile( diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml index ac67fe5f47be..82cfeb2bb95d 100644 --- a/sql/hive/pom.xml +++ b/sql/hive/pom.xml @@ -160,11 +160,6 @@ scalacheck_${scala.binary.version} test - - junit - junit - test - org.apache.spark spark-sql_${scala.binary.version} diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/ExtendedHiveTest.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/ExtendedHiveTest.java new file mode 100644 index 000000000000..e2183183fb55 --- /dev/null +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/ExtendedHiveTest.java @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.hive; + +import java.lang.annotation.*; +import org.scalatest.TagAnnotation; + +@TagAnnotation +@Retention(RetentionPolicy.RUNTIME) +@Target({ElementType.METHOD, ElementType.TYPE}) +public @interface ExtendedHiveTest { } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala index f0bb77092c0c..888d1b7b4553 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.hive.HiveContext import org.apache.spark.{Logging, SparkFunSuite} import org.apache.spark.sql.catalyst.expressions.{NamedExpression, Literal, AttributeReference, EqualTo} import org.apache.spark.sql.catalyst.util.quietly +import org.apache.spark.sql.hive.ExtendedHiveTest import org.apache.spark.sql.types.IntegerType import org.apache.spark.util.Utils @@ -32,6 +33,7 @@ import org.apache.spark.util.Utils * sure that reflective calls are not throwing NoSuchMethod error, but the actually functionality * is not fully tested. */ +@ExtendedHiveTest class VersionsSuite extends SparkFunSuite with Logging { // Do not use a temp path here to speed up subsequent executions of the unit test during diff --git a/streaming/pom.xml b/streaming/pom.xml index 5cc9001b0e9a..1e6ee009ca6d 100644 --- a/streaming/pom.xml +++ b/streaming/pom.xml @@ -84,21 +84,11 @@ scalacheck_${scala.binary.version} test - - junit - junit - test - org.seleniumhq.selenium selenium-java test - - com.novocode - junit-interface - test - target/scala-${scala.binary.version}/classes diff --git a/unsafe/pom.xml b/unsafe/pom.xml index 066abe92e51c..4e8b9a84bb67 100644 --- a/unsafe/pom.xml +++ b/unsafe/pom.xml @@ -55,16 +55,6 @@ - - junit - junit - test - - - com.novocode - junit-interface - test - org.mockito mockito-core diff --git a/yarn/src/test/java/org/apache/spark/deploy/yarn/ExtendedYarnTest.java b/yarn/src/test/java/org/apache/spark/deploy/yarn/ExtendedYarnTest.java new file mode 100644 index 000000000000..7a8f2fe979c1 --- /dev/null +++ b/yarn/src/test/java/org/apache/spark/deploy/yarn/ExtendedYarnTest.java @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.deploy.yarn; + +import java.lang.annotation.*; +import org.scalatest.TagAnnotation; + +@TagAnnotation +@Retention(RetentionPolicy.RUNTIME) +@Target({ElementType.METHOD, ElementType.TYPE}) +public @interface ExtendedYarnTest { } diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala index b5a42fd6afd9..105c3090d489 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala @@ -39,6 +39,7 @@ import org.apache.spark.util.Utils * applications, and require the Spark assembly to be built before they can be successfully * run. */ +@ExtendedYarnTest class YarnClusterSuite extends BaseYarnClusterSuite { override def newYarnConfig(): YarnConfiguration = new YarnConfiguration() diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleIntegrationSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleIntegrationSuite.scala index 8d9c9b3004ed..4700e2428df0 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleIntegrationSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleIntegrationSuite.scala @@ -32,6 +32,7 @@ import org.apache.spark.network.yarn.{YarnShuffleService, YarnTestAccessor} /** * Integration test for the external shuffle service with a yarn mini-cluster */ +@ExtendedYarnTest class YarnShuffleIntegrationSuite extends BaseYarnClusterSuite { override def newYarnConfig(): YarnConfiguration = { From 7ca30b505c3561dc2832b463be4c6301a90380e4 Mon Sep 17 00:00:00 2001 From: noelsmith Date: Tue, 15 Sep 2015 12:23:20 -0700 Subject: [PATCH 0632/1824] [PYSPARK] [MLLIB] [DOCS] Replaced addversion with versionadded in mllib.random Missed this when reviewing `pyspark.mllib.random` for SPARK-10275. Author: noelsmith Closes #8773 from noel-smith/mllib-random-versionadded-fix. --- python/pyspark/mllib/random.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/mllib/random.py b/python/pyspark/mllib/random.py index 9c733b1332bc..6a3c643b6641 100644 --- a/python/pyspark/mllib/random.py +++ b/python/pyspark/mllib/random.py @@ -41,7 +41,7 @@ class RandomRDDs(object): Generator methods for creating RDDs comprised of i.i.d samples from some distribution. - .. addedversion:: 1.1.0 + .. versionadded:: 1.1.0 """ @staticmethod From 0d9ab016755d5b56ce4043f229602169fd752e88 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Tue, 15 Sep 2015 12:25:31 -0700 Subject: [PATCH 0633/1824] Closes #8738 Closes #8767 Closes #2491 Closes #6795 Closes #2096 Closes #7722 From 416003b26401894ec712e1a5291a92adfbc5af01 Mon Sep 17 00:00:00 2001 From: Jacek Laskowski Date: Tue, 15 Sep 2015 20:42:33 +0100 Subject: [PATCH 0634/1824] [DOCS] Small fixes to Spark on Yarn doc * a follow-up to 16b6d18613e150c7038c613992d80a7828413e66 as `--num-executors` flag is not suppported. * links + formatting Author: Jacek Laskowski Closes #8762 from jaceklaskowski/docs-spark-on-yarn. 
--- docs/running-on-yarn.md | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index 5159ef9e3394..d1244323edff 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -18,16 +18,16 @@ Spark application's configuration (driver, executors, and the AM when running in There are two deploy modes that can be used to launch Spark applications on YARN. In `yarn-cluster` mode, the Spark driver runs inside an application master process which is managed by YARN on the cluster, and the client can go away after initiating the application. In `yarn-client` mode, the driver runs in the client process, and the application master is only used for requesting resources from YARN. -Unlike in Spark standalone and Mesos mode, in which the master's address is specified in the `--master` parameter, in YARN mode the ResourceManager's address is picked up from the Hadoop configuration. Thus, the `--master` parameter is `yarn-client` or `yarn-cluster`. +Unlike [Spark standalone](spark-standalone.html) and [Mesos](running-on-mesos.html) modes, in which the master's address is specified in the `--master` parameter, in YARN mode the ResourceManager's address is picked up from the Hadoop configuration. Thus, the `--master` parameter is `yarn-client` or `yarn-cluster`. + To launch a Spark application in `yarn-cluster` mode: - `$ ./bin/spark-submit --class path.to.your.Class --master yarn-cluster [options] [app options]` + $ ./bin/spark-submit --class path.to.your.Class --master yarn-cluster [options] [app options] For example: $ ./bin/spark-submit --class org.apache.spark.examples.SparkPi \ --master yarn-cluster \ - --num-executors 3 \ --driver-memory 4g \ --executor-memory 2g \ --executor-cores 1 \ @@ -37,7 +37,7 @@ For example: The above starts a YARN client program which starts the default Application Master. Then SparkPi will be run as a child thread of Application Master. The client will periodically poll the Application Master for status updates and display them in the console. The client will exit once your application has finished running. Refer to the "Debugging your Application" section below for how to see driver and executor logs. -To launch a Spark application in `yarn-client` mode, do the same, but replace `yarn-cluster` with `yarn-client`. To run spark-shell: +To launch a Spark application in `yarn-client` mode, do the same, but replace `yarn-cluster` with `yarn-client`. The following shows how you can run `spark-shell` in `yarn-client` mode: $ ./bin/spark-shell --master yarn-client @@ -54,8 +54,8 @@ In `yarn-cluster` mode, the driver runs on a different machine than the client, # Preparations -Running Spark-on-YARN requires a binary distribution of Spark which is built with YARN support. -Binary distributions can be downloaded from the Spark project website. +Running Spark on YARN requires a binary distribution of Spark which is built with YARN support. +Binary distributions can be downloaded from the [downloads page](http://spark.apache.org/downloads.html) of the project website. To build Spark yourself, refer to [Building Spark](building-spark.html). # Configuration From b42059d2efdf3322334694205a6d951bcc291644 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Tue, 15 Sep 2015 13:03:38 -0700 Subject: [PATCH 0635/1824] Revert "[SPARK-10300] [BUILD] [TESTS] Add support for test tags in run-tests.py." This reverts commit 8abef21dac1a6538c4e4e0140323b83d804d602b. 
--- core/pom.xml | 10 +++++++ dev/run-tests.py | 19 ++------------ dev/sparktestsupport/modules.py | 24 +---------------- external/flume/pom.xml | 10 +++++++ external/kafka/pom.xml | 10 +++++++ external/mqtt/pom.xml | 10 +++++++ external/twitter/pom.xml | 10 +++++++ external/zeromq/pom.xml | 10 +++++++ extras/java8-tests/pom.xml | 10 +++++++ extras/kinesis-asl/pom.xml | 5 ++++ launcher/pom.xml | 5 ++++ mllib/pom.xml | 10 +++++++ network/common/pom.xml | 10 +++++++ network/shuffle/pom.xml | 10 +++++++ pom.xml | 17 ++---------- project/SparkBuild.scala | 13 ++-------- sql/core/pom.xml | 5 ++++ .../execution/HiveCompatibilitySuite.scala | 2 -- sql/hive/pom.xml | 5 ++++ .../spark/sql/hive/ExtendedHiveTest.java | 26 ------------------- .../spark/sql/hive/client/VersionsSuite.scala | 2 -- streaming/pom.xml | 10 +++++++ unsafe/pom.xml | 10 +++++++ .../spark/deploy/yarn/ExtendedYarnTest.java | 26 ------------------- .../spark/deploy/yarn/YarnClusterSuite.scala | 1 - .../yarn/YarnShuffleIntegrationSuite.scala | 1 - 26 files changed, 147 insertions(+), 124 deletions(-) delete mode 100644 sql/hive/src/test/java/org/apache/spark/sql/hive/ExtendedHiveTest.java delete mode 100644 yarn/src/test/java/org/apache/spark/deploy/yarn/ExtendedYarnTest.java diff --git a/core/pom.xml b/core/pom.xml index 8a2018109622..e31d90f60889 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -331,6 +331,16 @@ scalacheck_${scala.binary.version} test + + junit + junit + test + + + com.novocode + junit-interface + test + org.apache.curator curator-test diff --git a/dev/run-tests.py b/dev/run-tests.py index 1a816585187d..d8b22e1665e7 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -118,14 +118,6 @@ def determine_modules_to_test(changed_modules): return modules_to_test.union(set(changed_modules)) -def determine_tags_to_exclude(changed_modules): - tags = [] - for m in modules.all_modules: - if m not in changed_modules: - tags += m.test_tags - return tags - - # ------------------------------------------------------------------------------------------------- # Functions for working with subprocesses and shell tools # ------------------------------------------------------------------------------------------------- @@ -377,7 +369,6 @@ def detect_binary_inop_with_mima(): def run_scala_tests_maven(test_profiles): mvn_test_goals = ["test", "--fail-at-end"] - profiles_and_goals = test_profiles + mvn_test_goals print("[info] Running Spark tests using Maven with these arguments: ", @@ -401,7 +392,7 @@ def run_scala_tests_sbt(test_modules, test_profiles): exec_sbt(profiles_and_goals) -def run_scala_tests(build_tool, hadoop_version, test_modules, excluded_tags): +def run_scala_tests(build_tool, hadoop_version, test_modules): """Function to properly execute all tests passed in as a set from the `determine_test_suites` function""" set_title_and_block("Running Spark unit tests", "BLOCK_SPARK_UNIT_TESTS") @@ -410,10 +401,6 @@ def run_scala_tests(build_tool, hadoop_version, test_modules, excluded_tags): test_profiles = get_hadoop_profiles(hadoop_version) + \ list(set(itertools.chain.from_iterable(m.build_profile_flags for m in test_modules))) - - if excluded_tags: - test_profiles += ['-Dtest.exclude.tags=' + ",".join(excluded_tags)] - if build_tool == "maven": run_scala_tests_maven(test_profiles) else: @@ -513,10 +500,8 @@ def main(): target_branch = os.environ["ghprbTargetBranch"] changed_files = identify_changed_files_from_git_commits("HEAD", target_branch=target_branch) changed_modules = determine_modules_for_files(changed_files) - 
excluded_tags = determine_tags_to_exclude(changed_modules) if not changed_modules: changed_modules = [modules.root] - excluded_tags = [] print("[info] Found the following changed modules:", ", ".join(x.name for x in changed_modules)) @@ -556,7 +541,7 @@ def main(): detect_binary_inop_with_mima() # run the test suites - run_scala_tests(build_tool, hadoop_version, test_modules, excluded_tags) + run_scala_tests(build_tool, hadoop_version, test_modules) modules_with_python_tests = [m for m in test_modules if m.python_test_goals] if modules_with_python_tests: diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index 65397f1f3e0b..346452f3174e 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -31,7 +31,7 @@ class Module(object): def __init__(self, name, dependencies, source_file_regexes, build_profile_flags=(), environ={}, sbt_test_goals=(), python_test_goals=(), blacklisted_python_implementations=(), - test_tags=(), should_run_r_tests=False): + should_run_r_tests=False): """ Define a new module. @@ -50,8 +50,6 @@ def __init__(self, name, dependencies, source_file_regexes, build_profile_flags= :param blacklisted_python_implementations: A set of Python implementations that are not supported by this module's Python components. The values in this set should match strings returned by Python's `platform.python_implementation()`. - :param test_tags A set of tags that will be excluded when running unit tests if the module - is not explicitly changed. :param should_run_r_tests: If true, changes in this module will trigger all R tests. """ self.name = name @@ -62,7 +60,6 @@ def __init__(self, name, dependencies, source_file_regexes, build_profile_flags= self.environ = environ self.python_test_goals = python_test_goals self.blacklisted_python_implementations = blacklisted_python_implementations - self.test_tags = test_tags self.should_run_r_tests = should_run_r_tests self.dependent_modules = set() @@ -88,9 +85,6 @@ def contains_file(self, filename): "catalyst/test", "sql/test", "hive/test", - ], - test_tags=[ - "org.apache.spark.sql.hive.ExtendedHiveTest" ] ) @@ -404,22 +398,6 @@ def contains_file(self, filename): ) -yarn = Module( - name="yarn", - dependencies=[], - source_file_regexes=[ - "yarn/", - "network/yarn/", - ], - sbt_test_goals=[ - "yarn/test", - "network-yarn/test", - ], - test_tags=[ - "org.apache.spark.deploy.yarn.ExtendedYarnTest" - ] -) - # The root module is a dummy module which is used to run all of the tests. # No other modules should directly depend on this module. 
root = Module( diff --git a/external/flume/pom.xml b/external/flume/pom.xml index 3154e36c21ef..132062f94fb4 100644 --- a/external/flume/pom.xml +++ b/external/flume/pom.xml @@ -66,6 +66,16 @@ scalacheck_${scala.binary.version} test + + junit + junit + test + + + com.novocode + junit-interface + test + target/scala-${scala.binary.version}/classes diff --git a/external/kafka/pom.xml b/external/kafka/pom.xml index 7d0d46dadc72..05abd9e2e681 100644 --- a/external/kafka/pom.xml +++ b/external/kafka/pom.xml @@ -86,6 +86,16 @@ scalacheck_${scala.binary.version} test + + junit + junit + test + + + com.novocode + junit-interface + test + target/scala-${scala.binary.version}/classes diff --git a/external/mqtt/pom.xml b/external/mqtt/pom.xml index 913c47d33f48..05e6338a08b0 100644 --- a/external/mqtt/pom.xml +++ b/external/mqtt/pom.xml @@ -58,6 +58,16 @@ scalacheck_${scala.binary.version} test + + junit + junit + test + + + com.novocode + junit-interface + test + org.apache.activemq activemq-core diff --git a/external/twitter/pom.xml b/external/twitter/pom.xml index 9137bf25ee8a..244ad58ae959 100644 --- a/external/twitter/pom.xml +++ b/external/twitter/pom.xml @@ -58,6 +58,16 @@ scalacheck_${scala.binary.version} test + + junit + junit + test + + + com.novocode + junit-interface + test + target/scala-${scala.binary.version}/classes diff --git a/external/zeromq/pom.xml b/external/zeromq/pom.xml index 6fec4f0e8a0f..171df8682c84 100644 --- a/external/zeromq/pom.xml +++ b/external/zeromq/pom.xml @@ -57,6 +57,16 @@ scalacheck_${scala.binary.version} test + + junit + junit + test + + + com.novocode + junit-interface + test + target/scala-${scala.binary.version}/classes diff --git a/extras/java8-tests/pom.xml b/extras/java8-tests/pom.xml index dba3dda8a956..81794a853631 100644 --- a/extras/java8-tests/pom.xml +++ b/extras/java8-tests/pom.xml @@ -58,6 +58,16 @@ test-jar test + + junit + junit + test + + + com.novocode + junit-interface + test + diff --git a/extras/kinesis-asl/pom.xml b/extras/kinesis-asl/pom.xml index 760f183a2ef3..6dd8ff69c294 100644 --- a/extras/kinesis-asl/pom.xml +++ b/extras/kinesis-asl/pom.xml @@ -74,6 +74,11 @@ scalacheck_${scala.binary.version} test + + com.novocode + junit-interface + test + target/scala-${scala.binary.version}/classes diff --git a/launcher/pom.xml b/launcher/pom.xml index 80696280a1d1..ed38e66aa246 100644 --- a/launcher/pom.xml +++ b/launcher/pom.xml @@ -42,6 +42,11 @@ log4j test + + junit + junit + test + org.mockito mockito-core diff --git a/mllib/pom.xml b/mllib/pom.xml index 5dedacb38874..22c0c6008ba3 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -94,6 +94,16 @@ scalacheck_${scala.binary.version} test + + junit + junit + test + + + com.novocode + junit-interface + test + org.mockito mockito-core diff --git a/network/common/pom.xml b/network/common/pom.xml index 9c12cca0df60..1cc054a8936c 100644 --- a/network/common/pom.xml +++ b/network/common/pom.xml @@ -64,6 +64,16 @@ + + junit + junit + test + + + com.novocode + junit-interface + test + log4j log4j diff --git a/network/shuffle/pom.xml b/network/shuffle/pom.xml index e4f4c57b683c..7a66c968041c 100644 --- a/network/shuffle/pom.xml +++ b/network/shuffle/pom.xml @@ -78,6 +78,16 @@ test-jar test + + junit + junit + test + + + com.novocode + junit-interface + test + log4j log4j diff --git a/pom.xml b/pom.xml index 2927d3e10756..653599464114 100644 --- a/pom.xml +++ b/pom.xml @@ -181,7 +181,6 @@ 0.9.2 ${java.home} - @@ -1964,7 +1952,6 @@ __not_used__ - ${test.exclude.tags} diff --git a/project/SparkBuild.scala 
b/project/SparkBuild.scala index d80d300f1c3b..901cfa538d23 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -567,20 +567,11 @@ object TestSettings { javaOptions in Test ++= "-Xmx3g -Xss4096k -XX:PermSize=128M -XX:MaxNewSize=256m -XX:MaxPermSize=1g" .split(" ").toSeq, javaOptions += "-Xmx3g", - // Exclude tags defined in a system property - testOptions in Test += Tests.Argument(TestFrameworks.ScalaTest, - sys.props.get("test.exclude.tags").map { tags => - tags.split(",").flatMap { tag => Seq("-l", tag) }.toSeq - }.getOrElse(Nil): _*), - testOptions in Test += Tests.Argument(TestFrameworks.JUnit, - sys.props.get("test.exclude.tags").map { tags => - Seq("--exclude-categories=" + tags) - }.getOrElse(Nil): _*), // Show full stack trace and duration in test cases. testOptions in Test += Tests.Argument("-oDF"), - testOptions in Test += Tests.Argument(TestFrameworks.JUnit, "-v", "-a"), + testOptions += Tests.Argument(TestFrameworks.JUnit, "-v", "-a"), // Enable Junit testing. - libraryDependencies += "com.novocode" % "junit-interface" % "0.11" % "test", + libraryDependencies += "com.novocode" % "junit-interface" % "0.9" % "test", // Only allow one test at a time, even across projects, since they run in the same JVM parallelExecution in Test := false, // Make sure the test temp directory exists. diff --git a/sql/core/pom.xml b/sql/core/pom.xml index fa6732db183d..465aa3a3888c 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -73,6 +73,11 @@ jackson-databind ${fasterxml.jackson.version} + + junit + junit + test + org.scalacheck scalacheck_${scala.binary.version} diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index ffc4c32794ca..ab309e0a1d36 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -24,13 +24,11 @@ import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.scalatest.BeforeAndAfter import org.apache.spark.sql.SQLConf -import org.apache.spark.sql.hive.ExtendedHiveTest import org.apache.spark.sql.hive.test.TestHive /** * Runs the test cases that are included in the hive distribution. */ -@ExtendedHiveTest class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { // TODO: bundle in jar files... get from classpath private lazy val hiveQueryDir = TestHive.getHiveFile( diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml index 82cfeb2bb95d..ac67fe5f47be 100644 --- a/sql/hive/pom.xml +++ b/sql/hive/pom.xml @@ -160,6 +160,11 @@ scalacheck_${scala.binary.version} test + + junit + junit + test + org.apache.spark spark-sql_${scala.binary.version} diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/ExtendedHiveTest.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/ExtendedHiveTest.java deleted file mode 100644 index e2183183fb55..000000000000 --- a/sql/hive/src/test/java/org/apache/spark/sql/hive/ExtendedHiveTest.java +++ /dev/null @@ -1,26 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. 
- * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.hive; - -import java.lang.annotation.*; -import org.scalatest.TagAnnotation; - -@TagAnnotation -@Retention(RetentionPolicy.RUNTIME) -@Target({ElementType.METHOD, ElementType.TYPE}) -public @interface ExtendedHiveTest { } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala index 888d1b7b4553..f0bb77092c0c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala @@ -23,7 +23,6 @@ import org.apache.spark.sql.hive.HiveContext import org.apache.spark.{Logging, SparkFunSuite} import org.apache.spark.sql.catalyst.expressions.{NamedExpression, Literal, AttributeReference, EqualTo} import org.apache.spark.sql.catalyst.util.quietly -import org.apache.spark.sql.hive.ExtendedHiveTest import org.apache.spark.sql.types.IntegerType import org.apache.spark.util.Utils @@ -33,7 +32,6 @@ import org.apache.spark.util.Utils * sure that reflective calls are not throwing NoSuchMethod error, but the actually functionality * is not fully tested. */ -@ExtendedHiveTest class VersionsSuite extends SparkFunSuite with Logging { // Do not use a temp path here to speed up subsequent executions of the unit test during diff --git a/streaming/pom.xml b/streaming/pom.xml index 1e6ee009ca6d..5cc9001b0e9a 100644 --- a/streaming/pom.xml +++ b/streaming/pom.xml @@ -84,11 +84,21 @@ scalacheck_${scala.binary.version} test + + junit + junit + test + org.seleniumhq.selenium selenium-java test + + com.novocode + junit-interface + test + target/scala-${scala.binary.version}/classes diff --git a/unsafe/pom.xml b/unsafe/pom.xml index 4e8b9a84bb67..066abe92e51c 100644 --- a/unsafe/pom.xml +++ b/unsafe/pom.xml @@ -55,6 +55,16 @@ + + junit + junit + test + + + com.novocode + junit-interface + test + org.mockito mockito-core diff --git a/yarn/src/test/java/org/apache/spark/deploy/yarn/ExtendedYarnTest.java b/yarn/src/test/java/org/apache/spark/deploy/yarn/ExtendedYarnTest.java deleted file mode 100644 index 7a8f2fe979c1..000000000000 --- a/yarn/src/test/java/org/apache/spark/deploy/yarn/ExtendedYarnTest.java +++ /dev/null @@ -1,26 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.deploy.yarn; - -import java.lang.annotation.*; -import org.scalatest.TagAnnotation; - -@TagAnnotation -@Retention(RetentionPolicy.RUNTIME) -@Target({ElementType.METHOD, ElementType.TYPE}) -public @interface ExtendedYarnTest { } diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala index 105c3090d489..b5a42fd6afd9 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala @@ -39,7 +39,6 @@ import org.apache.spark.util.Utils * applications, and require the Spark assembly to be built before they can be successfully * run. */ -@ExtendedYarnTest class YarnClusterSuite extends BaseYarnClusterSuite { override def newYarnConfig(): YarnConfiguration = new YarnConfiguration() diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleIntegrationSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleIntegrationSuite.scala index 4700e2428df0..8d9c9b3004ed 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleIntegrationSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleIntegrationSuite.scala @@ -32,7 +32,6 @@ import org.apache.spark.network.yarn.{YarnShuffleService, YarnTestAccessor} /** * Integration test for the external shuffle service with a yarn mini-cluster */ -@ExtendedYarnTest class YarnShuffleIntegrationSuite extends BaseYarnClusterSuite { override def newYarnConfig(): YarnConfiguration = { From 841972e22c653ba58e9a65433fed203ff288f13a Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 15 Sep 2015 13:33:32 -0700 Subject: [PATCH 0636/1824] [SPARK-10437] [SQL] Support aggregation expressions in Order By JIRA: https://issues.apache.org/jira/browse/SPARK-10437 If an expression in `SortOrder` is a resolved one, such as `count(1)`, the corresponding rule in `Analyzer` to make it work in order by will not be applied. Author: Liang-Chi Hsieh Closes #8599 from viirya/orderby-agg. --- .../sql/catalyst/analysis/Analyzer.scala | 14 +++++++++---- .../org/apache/spark/sql/SQLQuerySuite.scala | 20 +++++++++++++++++++ 2 files changed, 30 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 591747b45c37..02f34cbf58ad 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -561,7 +561,7 @@ class Analyzer( } case sort @ Sort(sortOrder, global, aggregate: Aggregate) - if aggregate.resolved && !sort.resolved => + if aggregate.resolved => // Try resolving the ordering as though it is in the aggregate clause. 
try { @@ -598,9 +598,15 @@ class Analyzer( } } - Project(aggregate.output, - Sort(evaluatedOrderings, global, - aggregate.copy(aggregateExpressions = originalAggExprs ++ needsPushDown))) + // Since we don't rely on sort.resolved as the stop condition for this rule, + // we need to check this and prevent applying this rule multiple times + if (sortOrder == evaluatedOrderings) { + sort + } else { + Project(aggregate.output, + Sort(evaluatedOrderings, global, + aggregate.copy(aggregateExpressions = originalAggExprs ++ needsPushDown))) + } } catch { // Attempting to resolve in the aggregate can result in ambiguity. When this happens, // just return the original plan. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 962b100b532c..f9981356f364 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -1562,6 +1562,26 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { |ORDER BY sum(b) + 1 """.stripMargin), Row("4", 3) :: Row("1", 7) :: Row("3", 11) :: Row("2", 15) :: Nil) + + checkAnswer( + sql( + """ + |SELECT count(*) + |FROM orderByData + |GROUP BY a + |ORDER BY count(*) + """.stripMargin), + Row(2) :: Row(2) :: Row(2) :: Row(2) :: Nil) + + checkAnswer( + sql( + """ + |SELECT a + |FROM orderByData + |GROUP BY a + |ORDER BY a, count(*), sum(b) + """.stripMargin), + Row("1") :: Row("2") :: Row("3") :: Row("4") :: Nil) } test("SPARK-7952: fix the equality check between boolean and numeric types") { From 31a229aa739b6d05ec6d91b820fcca79b6b7d6fe Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 15 Sep 2015 13:36:52 -0700 Subject: [PATCH 0637/1824] [SPARK-10475] [SQL] improve column pruning for Project on Sort Sometimes we can't push down the whole `Project` through `Sort`, but we still have a chance to push down part of it. Author: Wenchen Fan Closes #8644 from cloud-fan/column-prune. --- .../sql/catalyst/optimizer/Optimizer.scala | 19 +++++++++++++++---- .../optimizer/ColumnPruningSuite.scala | 11 +++++++++++ 2 files changed, 26 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 0f4caec7451a..648a65e7c0eb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -228,10 +228,21 @@ object ColumnPruning extends Rule[LogicalPlan] { case Project(projectList, Limit(exp, child)) => Limit(exp, Project(projectList, child)) - // Push down project if possible when the child is sort - case p @ Project(projectList, s @ Sort(_, _, grandChild)) - if s.references.subsetOf(p.outputSet) => - s.copy(child = Project(projectList, grandChild)) + // Push down project if possible when the child is sort. + case p @ Project(projectList, s @ Sort(_, _, grandChild)) => + if (s.references.subsetOf(p.outputSet)) { + s.copy(child = Project(projectList, grandChild)) + } else { + val neededReferences = s.references ++ p.references + if (neededReferences == grandChild.outputSet) { + // No column we can prune, return the original plan. + p + } else { + // Do not use neededReferences.toSeq directly, should respect grandChild's output order.
+ val newProjectList = grandChild.output.filter(neededReferences.contains) + p.copy(child = s.copy(child = Project(newProjectList, grandChild))) + } + } // Eliminate no-op Projects case Project(projectList, child) if child.output == projectList => child diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala index dbebcb86809d..4a1e7ceaf394 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala @@ -80,5 +80,16 @@ class ColumnPruningSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + test("Column pruning for Project on Sort") { + val input = LocalRelation('a.int, 'b.string, 'c.double) + + val query = input.orderBy('b.asc).select('a).analyze + val optimized = Optimize.execute(query) + + val correctAnswer = input.select('a, 'b).orderBy('b.asc).select('a).analyze + + comparePlans(optimized, correctAnswer) + } + // todo: add more tests for column pruning } From be52faa7c72fb4b95829f09a7dc5eb5dccd03524 Mon Sep 17 00:00:00 2001 From: DB Tsai Date: Tue, 15 Sep 2015 15:46:47 -0700 Subject: [PATCH 0638/1824] [SPARK-7685] [ML] Apply weights to different samples in Logistic Regression In a fraud detection dataset, almost all of the samples are negative and only a couple of them are positive. This kind of highly imbalanced data biases the model toward the negative class, resulting in poor performance. scikit-learn provides a correction that lets users over-/undersample the samples of each class according to given weights; in auto mode it selects weights inversely proportional to the class frequencies in the training set. This can be done more efficiently by multiplying the weights into the loss and gradient instead of actually over-/undersampling the training dataset, which is very expensive. http://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LogisticRegression.html On the other hand, some training data may be more important than others, for example samples from long-tenured users versus samples from new users. We should be able to provide an additional "weight: Double" field in the LabeledPoint so that samples can be weighted differently in the learning algorithm. Author: DB Tsai Author: DB Tsai Closes #7884 from dbtsai/SPARK-7685.
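To make the commit message concrete, here is a minimal, self-contained sketch of how a per-instance weight enters the binary logistic loss and gradient. It is an illustration only: the object WeightedLogisticLossSketch, its lossAndGradient method, and the use of plain arrays are invented for this sketch and are not the API added by the patch; the actual change in the diff that follows threads the weight through LogisticAggregator, MultiClassSummarizer, and MultivariateOnlineSummarizer and normalizes by the total weight rather than the instance count.

object WeightedLogisticLossSketch {
  // Mirrors the patch's Instance(label, weight, features), with a plain array for brevity.
  case class Instance(label: Double, weight: Double, features: Array[Double])

  // Numerically stable log(1 + exp(m)).
  private def log1pExp(m: Double): Double =
    if (m > 0.0) m + math.log1p(math.exp(-m)) else math.log1p(math.exp(m))

  // Weighted average logistic loss and gradient for coefficients w
  // (no intercept, no regularization). Weighting an instance by k gives the
  // same result as replicating it k times, without materializing any copies.
  def lossAndGradient(data: Seq[Instance], w: Array[Double]): (Double, Array[Double]) = {
    val gradSum = Array.ofDim[Double](w.length)
    var lossSum = 0.0
    var weightSum = 0.0
    data.foreach { case Instance(label, weight, x) =>
      val margin = (0 until w.length).map(i => w(i) * x(i)).sum
      // The weight scales this instance's contribution to both loss and gradient.
      val multiplier = weight * (1.0 / (1.0 + math.exp(-margin)) - label)
      var i = 0
      while (i < w.length) {
        gradSum(i) += multiplier * x(i)
        i += 1
      }
      lossSum += weight * (log1pExp(margin) - label * margin)
      weightSum += weight
    }
    require(weightSum > 0.0, "sum of instance weights must be positive")
    (lossSum / weightSum, gradSum.map(_ / weightSum))
  }
}

In this formulation, giving an instance weight 2.0 produces exactly the same loss and gradient as duplicating that instance, which is why weighting can replace explicit over-/undersampling.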
--- .../classification/LogisticRegression.scala | 199 +++++++++++------- .../ml/param/shared/SharedParamsCodeGen.scala | 6 +- .../spark/ml/param/shared/sharedParams.scala | 12 +- .../stat/MultivariateOnlineSummarizer.scala | 75 ++++--- .../LogisticRegressionSuite.scala | 102 ++++++++- .../MultivariateOnlineSummarizerSuite.scala | 27 +++ project/MimaExcludes.scala | 10 +- 7 files changed, 303 insertions(+), 128 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index a460262b87e4..bd96e8d000ff 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -29,12 +29,12 @@ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util.Identifiable import org.apache.spark.mllib.linalg._ import org.apache.spark.mllib.linalg.BLAS._ -import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.functions.{col, lit} import org.apache.spark.storage.StorageLevel /** @@ -42,7 +42,7 @@ import org.apache.spark.storage.StorageLevel */ private[classification] trait LogisticRegressionParams extends ProbabilisticClassifierParams with HasRegParam with HasElasticNetParam with HasMaxIter with HasFitIntercept with HasTol - with HasStandardization with HasThreshold { + with HasStandardization with HasWeightCol with HasThreshold { /** * Set threshold in binary classification, in range [0, 1]. @@ -146,6 +146,17 @@ private[classification] trait LogisticRegressionParams extends ProbabilisticClas } } +/** + * Class that represents an instance of weighted data point with label and features. + * + * TODO: Refactor this class to proper place. + * + * @param label Label for this data point. + * @param weight The weight of this instance. + * @param features The vector of features for this data point. + */ +private[classification] case class Instance(label: Double, weight: Double, features: Vector) + /** * :: Experimental :: * Logistic regression. @@ -218,31 +229,42 @@ class LogisticRegression(override val uid: String) override def getThreshold: Double = super.getThreshold + /** + * Whether to over-/under-sample training instances according to the given weights in weightCol. + * If empty, all instances are treated equally (weight 1.0). + * Default is empty, so all instances have weight one. + * @group setParam + */ + def setWeightCol(value: String): this.type = set(weightCol, value) + setDefault(weightCol -> "") + override def setThresholds(value: Array[Double]): this.type = super.setThresholds(value) override def getThresholds: Array[Double] = super.getThresholds override protected def train(dataset: DataFrame): LogisticRegressionModel = { // Extract columns from data. If dataset is persisted, do not persist oldDataset. 
- val instances = extractLabeledPoints(dataset).map { - case LabeledPoint(label: Double, features: Vector) => (label, features) + val w = if ($(weightCol).isEmpty) lit(1.0) else col($(weightCol)) + val instances: RDD[Instance] = dataset.select(col($(labelCol)), w, col($(featuresCol))).map { + case Row(label: Double, weight: Double, features: Vector) => + Instance(label, weight, features) } + val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK) - val (summarizer, labelSummarizer) = instances.treeAggregate( - (new MultivariateOnlineSummarizer, new MultiClassSummarizer))( - seqOp = (c, v) => (c, v) match { - case ((summarizer: MultivariateOnlineSummarizer, labelSummarizer: MultiClassSummarizer), - (label: Double, features: Vector)) => - (summarizer.add(features), labelSummarizer.add(label)) - }, - combOp = (c1, c2) => (c1, c2) match { - case ((summarizer1: MultivariateOnlineSummarizer, - classSummarizer1: MultiClassSummarizer), (summarizer2: MultivariateOnlineSummarizer, - classSummarizer2: MultiClassSummarizer)) => - (summarizer1.merge(summarizer2), classSummarizer1.merge(classSummarizer2)) - }) + val (summarizer, labelSummarizer) = { + val seqOp = (c: (MultivariateOnlineSummarizer, MultiClassSummarizer), + instance: Instance) => + (c._1.add(instance.features, instance.weight), c._2.add(instance.label, instance.weight)) + + val combOp = (c1: (MultivariateOnlineSummarizer, MultiClassSummarizer), + c2: (MultivariateOnlineSummarizer, MultiClassSummarizer)) => + (c1._1.merge(c2._1), c1._2.merge(c2._2)) + + instances.treeAggregate( + new MultivariateOnlineSummarizer, new MultiClassSummarizer)(seqOp, combOp) + } val histogram = labelSummarizer.histogram val numInvalid = labelSummarizer.countInvalid @@ -295,7 +317,7 @@ class LogisticRegression(override val uid: String) new BreezeOWLQN[Int, BDV[Double]]($(maxIter), 10, regParamL1Fun, $(tol)) } - val initialWeightsWithIntercept = + val initialCoefficientsWithIntercept = Vectors.zeros(if ($(fitIntercept)) numFeatures + 1 else numFeatures) if ($(fitIntercept)) { @@ -312,14 +334,14 @@ class LogisticRegression(override val uid: String) b = \log{P(1) / P(0)} = \log{count_1 / count_0} }}} */ - initialWeightsWithIntercept.toArray(numFeatures) - = math.log(histogram(1).toDouble / histogram(0).toDouble) + initialCoefficientsWithIntercept.toArray(numFeatures) + = math.log(histogram(1) / histogram(0)) } val states = optimizer.iterations(new CachedDiffFunction(costFun), - initialWeightsWithIntercept.toBreeze.toDenseVector) + initialCoefficientsWithIntercept.toBreeze.toDenseVector) - val (weights, intercept, objectiveHistory) = { + val (coefficients, intercept, objectiveHistory) = { /* Note that in Logistic Regression, the objective history (loss + regularization) is log-likelihood which is invariance under feature standardization. As a result, @@ -339,28 +361,29 @@ class LogisticRegression(override val uid: String) } /* - The weights are trained in the scaled space; we're converting them back to + The coefficients are trained in the scaled space; we're converting them back to the original space. Note that the intercept in scaled space and original space is the same; as a result, no scaling is needed. 
*/ - val rawWeights = state.x.toArray.clone() + val rawCoefficients = state.x.toArray.clone() var i = 0 while (i < numFeatures) { - rawWeights(i) *= { if (featuresStd(i) != 0.0) 1.0 / featuresStd(i) else 0.0 } + rawCoefficients(i) *= { if (featuresStd(i) != 0.0) 1.0 / featuresStd(i) else 0.0 } i += 1 } if ($(fitIntercept)) { - (Vectors.dense(rawWeights.dropRight(1)).compressed, rawWeights.last, arrayBuilder.result()) + (Vectors.dense(rawCoefficients.dropRight(1)).compressed, rawCoefficients.last, + arrayBuilder.result()) } else { - (Vectors.dense(rawWeights).compressed, 0.0, arrayBuilder.result()) + (Vectors.dense(rawCoefficients).compressed, 0.0, arrayBuilder.result()) } } if (handlePersistence) instances.unpersist() - val model = copyValues(new LogisticRegressionModel(uid, weights, intercept)) + val model = copyValues(new LogisticRegressionModel(uid, coefficients, intercept)) val logRegSummary = new BinaryLogisticRegressionTrainingSummary( model.transform(dataset), $(probabilityCol), @@ -501,22 +524,29 @@ class LogisticRegressionModel private[ml] ( * corresponding joint dataset. */ private[classification] class MultiClassSummarizer extends Serializable { - private val distinctMap = new mutable.HashMap[Int, Long] + // The first element of value in distinctMap is the actually number of instances, + // and the second element of value is sum of the weights. + private val distinctMap = new mutable.HashMap[Int, (Long, Double)] private var totalInvalidCnt: Long = 0L /** * Add a new label into this MultilabelSummarizer, and update the distinct map. * @param label The label for this data point. + * @param weight The weight of this instances. * @return This MultilabelSummarizer */ - def add(label: Double): this.type = { + def add(label: Double, weight: Double = 1.0): this.type = { + require(weight >= 0.0, s"instance weight, ${weight} has to be >= 0.0") + + if (weight == 0.0) return this + if (label - label.toInt != 0.0 || label < 0) { totalInvalidCnt += 1 this } else { - val counts: Long = distinctMap.getOrElse(label.toInt, 0L) - distinctMap.put(label.toInt, counts + 1) + val (counts: Long, weightSum: Double) = distinctMap.getOrElse(label.toInt, (0L, 0.0)) + distinctMap.put(label.toInt, (counts + 1L, weightSum + weight)) this } } @@ -537,8 +567,8 @@ private[classification] class MultiClassSummarizer extends Serializable { } smallMap.distinctMap.foreach { case (key, value) => - val counts = largeMap.distinctMap.getOrElse(key, 0L) - largeMap.distinctMap.put(key, counts + value) + val (counts: Long, weightSum: Double) = largeMap.distinctMap.getOrElse(key, (0L, 0.0)) + largeMap.distinctMap.put(key, (counts + value._1, weightSum + value._2)) } largeMap.totalInvalidCnt += smallMap.totalInvalidCnt largeMap @@ -550,13 +580,13 @@ private[classification] class MultiClassSummarizer extends Serializable { /** @return The number of distinct labels in the input dataset. */ def numClasses: Int = distinctMap.keySet.max + 1 - /** @return The counts of each label in the input dataset. */ - def histogram: Array[Long] = { - val result = Array.ofDim[Long](numClasses) + /** @return The weightSum of each label in the input dataset. 
*/ + def histogram: Array[Double] = { + val result = Array.ofDim[Double](numClasses) var i = 0 val len = result.length while (i < len) { - result(i) = distinctMap.getOrElse(i, 0L) + result(i) = distinctMap.getOrElse(i, (0L, 0.0))._2 i += 1 } result @@ -565,6 +595,8 @@ private[classification] class MultiClassSummarizer extends Serializable { /** * Abstraction for multinomial Logistic Regression Training results. + * Currently, the training summary ignores the training weights except + * for the objective trace. */ sealed trait LogisticRegressionTrainingSummary extends LogisticRegressionSummary { @@ -584,10 +616,10 @@ sealed trait LogisticRegressionSummary extends Serializable { /** Dataframe outputted by the model's `transform` method. */ def predictions: DataFrame - /** Field in "predictions" which gives the calibrated probability of each sample as a vector. */ + /** Field in "predictions" which gives the calibrated probability of each instance as a vector. */ def probabilityCol: String - /** Field in "predictions" which gives the the true label of each sample. */ + /** Field in "predictions" which gives the the true label of each instance. */ def labelCol: String } @@ -597,8 +629,8 @@ sealed trait LogisticRegressionSummary extends Serializable { * Logistic regression training results. * @param predictions dataframe outputted by the model's `transform` method. * @param probabilityCol field in "predictions" which gives the calibrated probability of - * each sample as a vector. - * @param labelCol field in "predictions" which gives the true label of each sample. + * each instance as a vector. + * @param labelCol field in "predictions" which gives the true label of each instance. * @param objectiveHistory objective function (scaled loss + regularization) at each iteration. */ @Experimental @@ -617,8 +649,8 @@ class BinaryLogisticRegressionTrainingSummary private[classification] ( * Binary Logistic regression results for a given model. * @param predictions dataframe outputted by the model's `transform` method. * @param probabilityCol field in "predictions" which gives the calibrated probability of - * each sample. - * @param labelCol field in "predictions" which gives the true label of each sample. + * each instance. + * @param labelCol field in "predictions" which gives the true label of each instance. */ @Experimental class BinaryLogisticRegressionSummary private[classification] ( @@ -687,14 +719,14 @@ class BinaryLogisticRegressionSummary private[classification] ( /** * LogisticAggregator computes the gradient and loss for binary logistic loss function, as used - * in binary classification for samples in sparse or dense vector in a online fashion. + * in binary classification for instances in sparse or dense vector in a online fashion. * * Note that multinomial logistic loss is not supported yet! * * Two LogisticAggregator can be merged together to have a summary of loss and gradient of * the corresponding joint dataset. * - * @param weights The weights/coefficients corresponding to the features. + * @param coefficients The coefficients corresponding to the features. * @param numClasses the number of possible outcomes for k classes classification problem in * Multinomial Logistic Regression. * @param fitIntercept Whether to fit an intercept term. @@ -702,25 +734,25 @@ class BinaryLogisticRegressionSummary private[classification] ( * @param featuresMean The mean values of the features. 
*/ private class LogisticAggregator( - weights: Vector, + coefficients: Vector, numClasses: Int, fitIntercept: Boolean, featuresStd: Array[Double], featuresMean: Array[Double]) extends Serializable { - private var totalCnt: Long = 0L + private var weightSum = 0.0 private var lossSum = 0.0 - private val weightsArray = weights match { + private val coefficientsArray = coefficients match { case dv: DenseVector => dv.values case _ => throw new IllegalArgumentException( - s"weights only supports dense vector but got type ${weights.getClass}.") + s"coefficients only supports dense vector but got type ${coefficients.getClass}.") } - private val dim = if (fitIntercept) weightsArray.length - 1 else weightsArray.length + private val dim = if (fitIntercept) coefficientsArray.length - 1 else coefficientsArray.length - private val gradientSumArray = Array.ofDim[Double](weightsArray.length) + private val gradientSumArray = Array.ofDim[Double](coefficientsArray.length) /** * Add a new training data to this LogisticAggregator, and update the loss and gradient @@ -729,13 +761,17 @@ private class LogisticAggregator( * @param label The label for this data point. * @param data The features for one data point in dense/sparse vector format to be added * into this aggregator. + * @param weight The weight for over-/undersamples each of training instance. Default is one. * @return This LogisticAggregator object. */ - def add(label: Double, data: Vector): this.type = { - require(dim == data.size, s"Dimensions mismatch when adding new sample." + + def add(label: Double, data: Vector, weight: Double = 1.0): this.type = { + require(dim == data.size, s"Dimensions mismatch when adding new instance." + s" Expecting $dim but got ${data.size}.") + require(weight >= 0.0, s"instance weight, ${weight} has to be >= 0.0") - val localWeightsArray = weightsArray + if (weight == 0.0) return this + + val localCoefficientsArray = coefficientsArray val localGradientSumArray = gradientSumArray numClasses match { @@ -745,13 +781,13 @@ private class LogisticAggregator( var sum = 0.0 data.foreachActive { (index, value) => if (featuresStd(index) != 0.0 && value != 0.0) { - sum += localWeightsArray(index) * (value / featuresStd(index)) + sum += localCoefficientsArray(index) * (value / featuresStd(index)) } } - sum + { if (fitIntercept) localWeightsArray(dim) else 0.0 } + sum + { if (fitIntercept) localCoefficientsArray(dim) else 0.0 } } - val multiplier = (1.0 / (1.0 + math.exp(margin))) - label + val multiplier = weight * (1.0 / (1.0 + math.exp(margin)) - label) data.foreachActive { (index, value) => if (featuresStd(index) != 0.0 && value != 0.0) { @@ -765,15 +801,15 @@ private class LogisticAggregator( if (label > 0) { // The following is equivalent to log(1 + exp(margin)) but more numerically stable. - lossSum += MLUtils.log1pExp(margin) + lossSum += weight * MLUtils.log1pExp(margin) } else { - lossSum += MLUtils.log1pExp(margin) - margin + lossSum += weight * (MLUtils.log1pExp(margin) - margin) } case _ => new NotImplementedError("LogisticRegression with ElasticNet in ML package only supports " + "binary classification for now.") } - totalCnt += 1 + weightSum += weight this } @@ -789,8 +825,8 @@ private class LogisticAggregator( require(dim == other.dim, s"Dimensions mismatch when merging with another " + s"LeastSquaresAggregator. 
Expecting $dim but got ${other.dim}.") - if (other.totalCnt != 0) { - totalCnt += other.totalCnt + if (other.weightSum != 0.0) { + weightSum += other.weightSum lossSum += other.lossSum var i = 0 @@ -805,13 +841,17 @@ private class LogisticAggregator( this } - def count: Long = totalCnt - - def loss: Double = lossSum / totalCnt + def loss: Double = { + require(weightSum > 0.0, s"The effective number of instances should be " + + s"greater than 0.0, but $weightSum.") + lossSum / weightSum + } def gradient: Vector = { + require(weightSum > 0.0, s"The effective number of instances should be " + + s"greater than 0.0, but $weightSum.") val result = Vectors.dense(gradientSumArray.clone()) - scal(1.0 / totalCnt, result) + scal(1.0 / weightSum, result) result } } @@ -823,7 +863,7 @@ private class LogisticAggregator( * It's used in Breeze's convex optimization routines. */ private class LogisticCostFun( - data: RDD[(Double, Vector)], + data: RDD[Instance], numClasses: Int, fitIntercept: Boolean, standardization: Boolean, @@ -831,22 +871,23 @@ private class LogisticCostFun( featuresMean: Array[Double], regParamL2: Double) extends DiffFunction[BDV[Double]] { - override def calculate(weights: BDV[Double]): (Double, BDV[Double]) = { + override def calculate(coefficients: BDV[Double]): (Double, BDV[Double]) = { val numFeatures = featuresStd.length - val w = Vectors.fromBreeze(weights) + val w = Vectors.fromBreeze(coefficients) - val logisticAggregator = data.treeAggregate(new LogisticAggregator(w, numClasses, fitIntercept, - featuresStd, featuresMean))( - seqOp = (c, v) => (c, v) match { - case (aggregator, (label, features)) => aggregator.add(label, features) - }, - combOp = (c1, c2) => (c1, c2) match { - case (aggregator1, aggregator2) => aggregator1.merge(aggregator2) - }) + val logisticAggregator = { + val seqOp = (c: LogisticAggregator, instance: Instance) => + c.add(instance.label, instance.features, instance.weight) + val combOp = (c1: LogisticAggregator, c2: LogisticAggregator) => c1.merge(c2) + + data.treeAggregate( + new LogisticAggregator(w, numClasses, fitIntercept, featuresStd, featuresMean) + )(seqOp, combOp) + } val totalGradientArray = logisticAggregator.gradient.toArray - // regVal is the sum of weight squares excluding intercept for L2 regularization. + // regVal is the sum of coefficients squares excluding intercept for L2 regularization. val regVal = if (regParamL2 == 0.0) { 0.0 } else { diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala index e9e99ed1db40..8049d51fee5e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala @@ -42,7 +42,7 @@ private[shared] object SharedParamsCodeGen { Some("\"rawPrediction\"")), ParamDesc[String]("probabilityCol", "Column name for predicted class conditional" + " probabilities. Note: Not all models output well-calibrated probability estimates!" 
+ - " These probabilities should be treated as confidences, not precise probabilities.", + " These probabilities should be treated as confidences, not precise probabilities", Some("\"probability\"")), ParamDesc[Double]("threshold", "threshold in binary classification prediction, in range [0, 1]", Some("0.5"), @@ -65,10 +65,10 @@ private[shared] object SharedParamsCodeGen { "options may be added later.", isValid = "ParamValidators.inArray(Array(\"skip\", \"error\"))"), ParamDesc[Boolean]("standardization", "whether to standardize the training features" + - " before fitting the model.", Some("true")), + " before fitting the model", Some("true")), ParamDesc[Long]("seed", "random seed", Some("this.getClass.getName.hashCode.toLong")), ParamDesc[Double]("elasticNetParam", "the ElasticNet mixing parameter, in range [0, 1]." + - " For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.", + " For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty", isValid = "ParamValidators.inRange(0, 1)"), ParamDesc[Double]("tol", "the convergence tolerance for iterative algorithms"), ParamDesc[Double]("stepSize", "Step size to be used for each iteration of optimization."), diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala index 30092170863a..aff47fc326c4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala @@ -127,10 +127,10 @@ private[ml] trait HasRawPredictionCol extends Params { private[ml] trait HasProbabilityCol extends Params { /** - * Param for Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities.. + * Param for Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities. * @group param */ - final val probabilityCol: Param[String] = new Param[String](this, "probabilityCol", "Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities.") + final val probabilityCol: Param[String] = new Param[String](this, "probabilityCol", "Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities") setDefault(probabilityCol, "probability") @@ -270,10 +270,10 @@ private[ml] trait HasHandleInvalid extends Params { private[ml] trait HasStandardization extends Params { /** - * Param for whether to standardize the training features before fitting the model.. + * Param for whether to standardize the training features before fitting the model. 
* @group param */ - final val standardization: BooleanParam = new BooleanParam(this, "standardization", "whether to standardize the training features before fitting the model.") + final val standardization: BooleanParam = new BooleanParam(this, "standardization", "whether to standardize the training features before fitting the model") setDefault(standardization, true) @@ -304,10 +304,10 @@ private[ml] trait HasSeed extends Params { private[ml] trait HasElasticNetParam extends Params { /** - * Param for the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.. + * Param for the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty. * @group param */ - final val elasticNetParam: DoubleParam = new DoubleParam(this, "elasticNetParam", "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.", ParamValidators.inRange(0, 1)) + final val elasticNetParam: DoubleParam = new DoubleParam(this, "elasticNetParam", "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty", ParamValidators.inRange(0, 1)) /** @group getParam */ final def getElasticNetParam: Double = $(elasticNetParam) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala index 51b713e263e0..201333c3690d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala @@ -23,16 +23,19 @@ import org.apache.spark.mllib.linalg.{Vectors, Vector} /** * :: DeveloperApi :: * MultivariateOnlineSummarizer implements [[MultivariateStatisticalSummary]] to compute the mean, - * variance, minimum, maximum, counts, and nonzero counts for samples in sparse or dense vector + * variance, minimum, maximum, counts, and nonzero counts for instances in sparse or dense vector * format in a online fashion. * * Two MultivariateOnlineSummarizer can be merged together to have a statistical summary of * the corresponding joint dataset. * - * A numerically stable algorithm is implemented to compute sample mean and variance: + * A numerically stable algorithm is implemented to compute the mean and variance of instances: * Reference: [[http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance variance-wiki]] * Zero elements (including explicit zero values) are skipped when calling add(), * to have time complexity O(nnz) instead of O(n) for each column. + * + * For weighted instances, the unbiased estimation of variance is defined by the reliability + * weights: [[https://en.wikipedia.org/wiki/Weighted_arithmetic_mean#Reliability_weights]]. */ @Since("1.1.0") @DeveloperApi @@ -44,6 +47,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S private var currM2: Array[Double] = _ private var currL1: Array[Double] = _ private var totalCnt: Long = 0 + private var weightSum: Double = 0.0 + private var weightSquareSum: Double = 0.0 private var nnz: Array[Double] = _ private var currMax: Array[Double] = _ private var currMin: Array[Double] = _ @@ -55,10 +60,15 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S * @return This MultivariateOnlineSummarizer object. 
*/ @Since("1.1.0") - def add(sample: Vector): this.type = { + def add(sample: Vector): this.type = add(sample, 1.0) + + private[spark] def add(instance: Vector, weight: Double): this.type = { + require(weight >= 0.0, s"sample weight, ${weight} has to be >= 0.0") + if (weight == 0.0) return this + if (n == 0) { - require(sample.size > 0, s"Vector should have dimension larger than zero.") - n = sample.size + require(instance.size > 0, s"Vector should have dimension larger than zero.") + n = instance.size currMean = Array.ofDim[Double](n) currM2n = Array.ofDim[Double](n) @@ -69,8 +79,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S currMin = Array.fill[Double](n)(Double.MaxValue) } - require(n == sample.size, s"Dimensions mismatch when adding new sample." + - s" Expecting $n but got ${sample.size}.") + require(n == instance.size, s"Dimensions mismatch when adding new sample." + + s" Expecting $n but got ${instance.size}.") val localCurrMean = currMean val localCurrM2n = currM2n @@ -79,7 +89,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S val localNnz = nnz val localCurrMax = currMax val localCurrMin = currMin - sample.foreachActive { (index, value) => + instance.foreachActive { (index, value) => if (value != 0.0) { if (localCurrMax(index) < value) { localCurrMax(index) = value @@ -90,15 +100,17 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S val prevMean = localCurrMean(index) val diff = value - prevMean - localCurrMean(index) = prevMean + diff / (localNnz(index) + 1.0) - localCurrM2n(index) += (value - localCurrMean(index)) * diff - localCurrM2(index) += value * value - localCurrL1(index) += math.abs(value) + localCurrMean(index) = prevMean + weight * diff / (localNnz(index) + weight) + localCurrM2n(index) += weight * (value - localCurrMean(index)) * diff + localCurrM2(index) += weight * value * value + localCurrL1(index) += weight * math.abs(value) - localNnz(index) += 1.0 + localNnz(index) += weight } } + weightSum += weight + weightSquareSum += weight * weight totalCnt += 1 this } @@ -112,10 +124,12 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S */ @Since("1.1.0") def merge(other: MultivariateOnlineSummarizer): this.type = { - if (this.totalCnt != 0 && other.totalCnt != 0) { + if (this.weightSum != 0.0 && other.weightSum != 0.0) { require(n == other.n, s"Dimensions mismatch when merging with another summarizer. 
" + s"Expecting $n but got ${other.n}.") totalCnt += other.totalCnt + weightSum += other.weightSum + weightSquareSum += other.weightSquareSum var i = 0 while (i < n) { val thisNnz = nnz(i) @@ -138,13 +152,15 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S nnz(i) = totalNnz i += 1 } - } else if (totalCnt == 0 && other.totalCnt != 0) { + } else if (weightSum == 0.0 && other.weightSum != 0.0) { this.n = other.n this.currMean = other.currMean.clone() this.currM2n = other.currM2n.clone() this.currM2 = other.currM2.clone() this.currL1 = other.currL1.clone() this.totalCnt = other.totalCnt + this.weightSum = other.weightSum + this.weightSquareSum = other.weightSquareSum this.nnz = other.nnz.clone() this.currMax = other.currMax.clone() this.currMin = other.currMin.clone() @@ -158,28 +174,28 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S */ @Since("1.1.0") override def mean: Vector = { - require(totalCnt > 0, s"Nothing has been added to this summarizer.") + require(weightSum > 0, s"Nothing has been added to this summarizer.") val realMean = Array.ofDim[Double](n) var i = 0 while (i < n) { - realMean(i) = currMean(i) * (nnz(i) / totalCnt) + realMean(i) = currMean(i) * (nnz(i) / weightSum) i += 1 } Vectors.dense(realMean) } /** - * Sample variance of each dimension. + * Unbiased estimate of sample variance of each dimension. * */ @Since("1.1.0") override def variance: Vector = { - require(totalCnt > 0, s"Nothing has been added to this summarizer.") + require(weightSum > 0, s"Nothing has been added to this summarizer.") val realVariance = Array.ofDim[Double](n) - val denominator = totalCnt - 1.0 + val denominator = weightSum - (weightSquareSum / weightSum) // Sample variance is computed, if the denominator is less than 0, the variance is just 0. 
if (denominator > 0.0) { @@ -187,9 +203,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S var i = 0 val len = currM2n.length while (i < len) { - realVariance(i) = - currM2n(i) + deltaMean(i) * deltaMean(i) * nnz(i) * (totalCnt - nnz(i)) / totalCnt - realVariance(i) /= denominator + realVariance(i) = (currM2n(i) + deltaMean(i) * deltaMean(i) * nnz(i) * + (weightSum - nnz(i)) / weightSum) / denominator i += 1 } } @@ -209,7 +224,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S */ @Since("1.1.0") override def numNonzeros: Vector = { - require(totalCnt > 0, s"Nothing has been added to this summarizer.") + require(weightSum > 0, s"Nothing has been added to this summarizer.") Vectors.dense(nnz) } @@ -220,11 +235,11 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S */ @Since("1.1.0") override def max: Vector = { - require(totalCnt > 0, s"Nothing has been added to this summarizer.") + require(weightSum > 0, s"Nothing has been added to this summarizer.") var i = 0 while (i < n) { - if ((nnz(i) < totalCnt) && (currMax(i) < 0.0)) currMax(i) = 0.0 + if ((nnz(i) < weightSum) && (currMax(i) < 0.0)) currMax(i) = 0.0 i += 1 } Vectors.dense(currMax) @@ -236,11 +251,11 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S */ @Since("1.1.0") override def min: Vector = { - require(totalCnt > 0, s"Nothing has been added to this summarizer.") + require(weightSum > 0, s"Nothing has been added to this summarizer.") var i = 0 while (i < n) { - if ((nnz(i) < totalCnt) && (currMin(i) > 0.0)) currMin(i) = 0.0 + if ((nnz(i) < weightSum) && (currMin(i) > 0.0)) currMin(i) = 0.0 i += 1 } Vectors.dense(currMin) @@ -252,7 +267,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S */ @Since("1.2.0") override def normL2: Vector = { - require(totalCnt > 0, s"Nothing has been added to this summarizer.") + require(weightSum > 0, s"Nothing has been added to this summarizer.") val realMagnitude = Array.ofDim[Double](n) @@ -271,7 +286,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S */ @Since("1.2.0") override def normL1: Vector = { - require(totalCnt > 0, s"Nothing has been added to this summarizer.") + require(weightSum > 0, s"Nothing has been added to this summarizer.") Vectors.dense(currL1) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index cce39f382f73..f5219f9f574b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -17,11 +17,14 @@ package org.apache.spark.ml.classification +import scala.util.Random + import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.MLTestingUtils import org.apache.spark.mllib.classification.LogisticRegressionSuite._ import org.apache.spark.mllib.linalg.{Vectors, Vector} +import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.{DataFrame, Row} @@ -59,8 +62,7 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { val testData = generateMultinomialLogisticInput(weights, xMean, xVariance, true, nPoints, 
42) - sqlContext.createDataFrame( - generateMultinomialLogisticInput(weights, xMean, xVariance, true, nPoints, 42)) + sqlContext.createDataFrame(sc.parallelize(testData, 4)) } } @@ -77,6 +79,7 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { assert(lr.getPredictionCol === "prediction") assert(lr.getRawPredictionCol === "rawPrediction") assert(lr.getProbabilityCol === "probability") + assert(lr.getWeightCol === "") assert(lr.getFitIntercept) assert(lr.getStandardization) val model = lr.fit(dataset) @@ -216,43 +219,65 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { test("MultiClassSummarizer") { val summarizer1 = (new MultiClassSummarizer) .add(0.0).add(3.0).add(4.0).add(3.0).add(6.0) - assert(summarizer1.histogram.zip(Array[Long](1, 0, 0, 2, 1, 0, 1)).forall(x => x._1 === x._2)) + assert(summarizer1.histogram === Array[Double](1, 0, 0, 2, 1, 0, 1)) assert(summarizer1.countInvalid === 0) assert(summarizer1.numClasses === 7) val summarizer2 = (new MultiClassSummarizer) .add(1.0).add(5.0).add(3.0).add(0.0).add(4.0).add(1.0) - assert(summarizer2.histogram.zip(Array[Long](1, 2, 0, 1, 1, 1)).forall(x => x._1 === x._2)) + assert(summarizer2.histogram === Array[Double](1, 2, 0, 1, 1, 1)) assert(summarizer2.countInvalid === 0) assert(summarizer2.numClasses === 6) val summarizer3 = (new MultiClassSummarizer) .add(0.0).add(1.3).add(5.2).add(2.5).add(2.0).add(4.0).add(4.0).add(4.0).add(1.0) - assert(summarizer3.histogram.zip(Array[Long](1, 1, 1, 0, 3)).forall(x => x._1 === x._2)) + assert(summarizer3.histogram === Array[Double](1, 1, 1, 0, 3)) assert(summarizer3.countInvalid === 3) assert(summarizer3.numClasses === 5) val summarizer4 = (new MultiClassSummarizer) .add(3.1).add(4.3).add(2.0).add(1.0).add(3.0) - assert(summarizer4.histogram.zip(Array[Long](0, 1, 1, 1)).forall(x => x._1 === x._2)) + assert(summarizer4.histogram === Array[Double](0, 1, 1, 1)) assert(summarizer4.countInvalid === 2) assert(summarizer4.numClasses === 4) // small map merges large one val summarizerA = summarizer1.merge(summarizer2) assert(summarizerA.hashCode() === summarizer2.hashCode()) - assert(summarizerA.histogram.zip(Array[Long](2, 2, 0, 3, 2, 1, 1)).forall(x => x._1 === x._2)) + assert(summarizerA.histogram === Array[Double](2, 2, 0, 3, 2, 1, 1)) assert(summarizerA.countInvalid === 0) assert(summarizerA.numClasses === 7) // large map merges small one val summarizerB = summarizer3.merge(summarizer4) assert(summarizerB.hashCode() === summarizer3.hashCode()) - assert(summarizerB.histogram.zip(Array[Long](1, 2, 2, 1, 3)).forall(x => x._1 === x._2)) + assert(summarizerB.histogram === Array[Double](1, 2, 2, 1, 3)) assert(summarizerB.countInvalid === 5) assert(summarizerB.numClasses === 5) } + test("MultiClassSummarizer with weighted samples") { + val summarizer1 = (new MultiClassSummarizer) + .add(label = 0.0, weight = 0.2).add(3.0, 0.8).add(4.0, 3.2).add(3.0, 1.3).add(6.0, 3.1) + assert(Vectors.dense(summarizer1.histogram) ~== + Vectors.dense(Array(0.2, 0, 0, 2.1, 3.2, 0, 3.1)) absTol 1E-10) + assert(summarizer1.countInvalid === 0) + assert(summarizer1.numClasses === 7) + + val summarizer2 = (new MultiClassSummarizer) + .add(1.0, 1.1).add(5.0, 2.3).add(3.0).add(0.0).add(4.0).add(1.0).add(2, 0.0) + assert(Vectors.dense(summarizer2.histogram) ~== + Vectors.dense(Array[Double](1.0, 2.1, 0.0, 1, 1, 2.3)) absTol 1E-10) + assert(summarizer2.countInvalid === 0) + assert(summarizer2.numClasses === 6) + + val summarizer = summarizer1.merge(summarizer2) + 
assert(Vectors.dense(summarizer.histogram) ~== + Vectors.dense(Array(1.2, 2.1, 0.0, 3.1, 4.2, 2.3, 3.1)) absTol 1E-10) + assert(summarizer.countInvalid === 0) + assert(summarizer.numClasses === 7) + } + test("binary logistic regression with intercept without regularization") { val trainer1 = (new LogisticRegression).setFitIntercept(true).setStandardization(true) val trainer2 = (new LogisticRegression).setFitIntercept(true).setStandardization(false) @@ -713,7 +738,7 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { b = \log{P(1) / P(0)} = \log{count_1 / count_0} }}} */ - val interceptTheory = math.log(histogram(1).toDouble / histogram(0).toDouble) + val interceptTheory = math.log(histogram(1) / histogram(0)) val weightsTheory = Vectors.dense(0.0, 0.0, 0.0, 0.0) assert(model1.intercept ~== interceptTheory relTol 1E-5) @@ -781,4 +806,63 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { .forall(x => x(0) >= x(1))) } + + test("binary logistic regression with weighted samples") { + val (dataset, weightedDataset) = { + val nPoints = 1000 + val weights = Array(-0.57997, 0.912083, -0.371077, -0.819866, 2.688191) + val xMean = Array(5.843, 3.057, 3.758, 1.199) + val xVariance = Array(0.6856, 0.1899, 3.116, 0.581) + val testData = generateMultinomialLogisticInput(weights, xMean, xVariance, true, nPoints, 42) + + // Let's over-sample the positive samples twice. + val data1 = testData.flatMap { case labeledPoint: LabeledPoint => + if (labeledPoint.label == 1.0) { + Iterator(labeledPoint, labeledPoint) + } else { + Iterator(labeledPoint) + } + } + + val rnd = new Random(8392) + val data2 = testData.flatMap { case LabeledPoint(label: Double, features: Vector) => + if (rnd.nextGaussian() > 0.0) { + if (label == 1.0) { + Iterator( + Instance(label, 1.2, features), + Instance(label, 0.8, features), + Instance(0.0, 0.0, features)) + } else { + Iterator( + Instance(label, 0.3, features), + Instance(1.0, 0.0, features), + Instance(label, 0.1, features), + Instance(label, 0.6, features)) + } + } else { + if (label == 1.0) { + Iterator(Instance(label, 2.0, features)) + } else { + Iterator(Instance(label, 1.0, features)) + } + } + } + + (sqlContext.createDataFrame(sc.parallelize(data1, 4)), + sqlContext.createDataFrame(sc.parallelize(data2, 4))) + } + + val trainer1a = (new LogisticRegression).setFitIntercept(true) + .setRegParam(0.0).setStandardization(true) + val trainer1b = (new LogisticRegression).setFitIntercept(true).setWeightCol("weight") + .setRegParam(0.0).setStandardization(true) + val model1a0 = trainer1a.fit(dataset) + val model1a1 = trainer1a.fit(weightedDataset) + val model1b = trainer1b.fit(weightedDataset) + assert(model1a0.weights !~= model1a1.weights absTol 1E-3) + assert(model1a0.intercept !~= model1a1.intercept absTol 1E-3) + assert(model1a0.weights ~== model1b.weights absTol 1E-3) + assert(model1a0.intercept ~== model1b.intercept absTol 1E-3) + + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala index 07efde4f5e6d..b6d41db69be0 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala @@ -218,4 +218,31 @@ class MultivariateOnlineSummarizerSuite extends SparkFunSuite { s0.merge(s1) assert(s0.mean(0) ~== 1.0 absTol 1e-14) } + + test("merging 
summarizer with weighted samples") { + val summarizer = (new MultivariateOnlineSummarizer) + .add(instance = Vectors.sparse(3, Seq((0, -0.8), (1, 1.7))), weight = 0.1) + .add(Vectors.dense(0.0, -1.2, -1.7), 0.2).merge( + (new MultivariateOnlineSummarizer) + .add(Vectors.sparse(3, Seq((0, -0.7), (1, 0.01), (2, 1.3))), 0.15) + .add(Vectors.dense(-0.5, 0.3, -1.5), 0.05)) + + assert(summarizer.count === 4) + + // The following values are hand calculated using the formula: + // [[https://en.wikipedia.org/wiki/Weighted_arithmetic_mean#Reliability_weights]] + // which defines the reliability weight used for computing the unbiased estimation of variance + // for weighted instances. + assert(summarizer.mean ~== Vectors.dense(Array(-0.42, -0.107, -0.44)) + absTol 1E-10, "mean mismatch") + assert(summarizer.variance ~== Vectors.dense(Array(0.17657142857, 1.645115714, 2.42057142857)) + absTol 1E-8, "variance mismatch") + assert(summarizer.numNonzeros ~== Vectors.dense(Array(0.3, 0.5, 0.4)) + absTol 1E-10, "numNonzeros mismatch") + assert(summarizer.max ~== Vectors.dense(Array(0.0, 1.7, 1.3)) absTol 1E-10, "max mismatch") + assert(summarizer.min ~== Vectors.dense(Array(-0.8, -1.2, -1.7)) absTol 1E-10, "min mismatch") + assert(summarizer.normL2 ~== Vectors.dense(0.387298335, 0.762571308141, 0.9715966241192) + absTol 1E-8, "normL2 mismatch") + assert(summarizer.normL1 ~== Vectors.dense(0.21, 0.4265, 0.61) absTol 1E-10, "normL1 mismatch") + } } diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 87b141cd3b05..46026c1e90ea 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -45,7 +45,15 @@ object MimaExcludes { excludePackage("org.apache.spark.sql.execution") ) ++ MimaBuild.excludeSparkClass("streaming.flume.FlumeTestUtils") ++ - MimaBuild.excludeSparkClass("streaming.flume.PollingFlumeTestUtils") + MimaBuild.excludeSparkClass("streaming.flume.PollingFlumeTestUtils") ++ + Seq( + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.ml.classification.LogisticCostFun.this"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.ml.classification.LogisticAggregator.add"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.ml.classification.LogisticAggregator.count") + ) case v if v.startsWith("1.5") => Seq( MimaBuild.excludeSparkPackage("network"), From b6e998634e05db0bb6267173e7b28f885c808c16 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Tue, 15 Sep 2015 16:45:47 -0700 Subject: [PATCH 0639/1824] [SPARK-10548] [SPARK-10563] [SQL] Fix concurrent SQL executions *Note: this is for master branch only.* The fix for branch-1.5 is at #8721. The query execution ID is currently passed from a thread to its children, which is not the intended behavior. This led to `IllegalArgumentException: spark.sql.execution.id is already set` when running queries in parallel, e.g.: ``` (1 to 100).par.foreach { _ => sc.parallelize(1 to 5).map { i => (i, i) }.toDF("a", "b").count() } ``` The cause is `SparkContext`'s local properties are inherited by default. This patch adds a way to exclude keys we don't want to be inherited, and makes SQL go through that code path. Author: Andrew Or Closes #8710 from andrewor14/concurrent-sql-executions. 
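The inheritance pitfall described above is easy to reproduce outside Spark. The sketch below is a minimal, standalone Scala program (plain JDK `Properties`, not Spark code; the key `execution.id` is only an illustrative stand-in for `spark.sql.execution.id`). It contrasts the old `new Properties(parent)` behaviour, where the child keeps a live view of the parent's table through the defaults chain, with a snapshot-style `childValue` in the spirit of the clone this patch introduces:

```
import java.util.Properties

// Standalone demo, not Spark code; "execution.id" stands in for spark.sql.execution.id.
object InheritedPropsDemo {

  // Old behaviour: the child's Properties keeps a live reference to the parent's
  // table through the defaults chain of `new Properties(parent)`.
  private val linked = new InheritableThreadLocal[Properties] {
    override protected def childValue(parent: Properties): Properties = new Properties(parent)
    override protected def initialValue(): Properties = new Properties()
  }

  // Fixed behaviour: the child gets an independent snapshot taken at thread creation.
  // (putAll is a good-enough copy for this demo; the real patch uses SerializationUtils.clone.)
  private val snapshotted = new InheritableThreadLocal[Properties] {
    override protected def childValue(parent: Properties): Properties = {
      val copy = new Properties()
      copy.putAll(parent)
      copy
    }
    override protected def initialValue(): Properties = new Properties()
  }

  def main(args: Array[String]): Unit = {
    linked.get().setProperty("execution.id", "1")
    snapshotted.get().setProperty("execution.id", "1")

    // childValue() runs when the Thread object is constructed, so both children
    // start from the value "1".
    val child = new Thread {
      override def run(): Unit = {
        println("linked sees:      " + linked.get().getProperty("execution.id"))
        println("snapshotted sees: " + snapshotted.get().getProperty("execution.id"))
      }
    }

    // Mutate the parent's properties after the child thread has been constructed.
    linked.get().setProperty("execution.id", "2")
    snapshotted.get().setProperty("execution.id", "2")

    child.start()
    child.join()
    // The linked variant prints "2" (the parent mutation leaked into the child),
    // the snapshotted variant prints "1".
  }
}
```

With the linked variant the child sees the parent's later value; with the snapshot it sees the value that was current when the thread was constructed, which is the behaviour the new ThreadingSuite test below asserts.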
--- .../scala/org/apache/spark/SparkContext.scala | 9 +- .../org/apache/spark/ThreadingSuite.scala | 65 +++++------ .../sql/execution/SQLExecutionSuite.scala | 101 ++++++++++++++++++ 3 files changed, 132 insertions(+), 43 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index dee6091ce3ca..a2f34eafa2c3 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -33,6 +33,7 @@ import scala.collection.mutable.HashMap import scala.reflect.{ClassTag, classTag} import scala.util.control.NonFatal +import org.apache.commons.lang.SerializationUtils import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.io.{ArrayWritable, BooleanWritable, BytesWritable, DoubleWritable, @@ -347,8 +348,12 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli private[spark] var checkpointDir: Option[String] = None // Thread Local variable that can be used by users to pass information down the stack - private val localProperties = new InheritableThreadLocal[Properties] { - override protected def childValue(parent: Properties): Properties = new Properties(parent) + protected[spark] val localProperties = new InheritableThreadLocal[Properties] { + override protected def childValue(parent: Properties): Properties = { + // Note: make a clone such that changes in the parent properties aren't reflected in + // those of the children threads, which has confusing semantics (SPARK-10563). + SerializationUtils.clone(parent).asInstanceOf[Properties] + } override protected def initialValue(): Properties = new Properties() } diff --git a/core/src/test/scala/org/apache/spark/ThreadingSuite.scala b/core/src/test/scala/org/apache/spark/ThreadingSuite.scala index a96a4ce201c2..54c131cdae36 100644 --- a/core/src/test/scala/org/apache/spark/ThreadingSuite.scala +++ b/core/src/test/scala/org/apache/spark/ThreadingSuite.scala @@ -147,7 +147,7 @@ class ThreadingSuite extends SparkFunSuite with LocalSparkContext with Logging { }.start() } sem.acquire(2) - throwable.foreach { t => throw t } + throwable.foreach { t => throw improveStackTrace(t) } if (ThreadingSuiteState.failed.get()) { logError("Waited 1 second without seeing runningThreads = 4 (it was " + ThreadingSuiteState.runningThreads.get() + "); failing test") @@ -178,7 +178,7 @@ class ThreadingSuite extends SparkFunSuite with LocalSparkContext with Logging { threads.foreach(_.start()) sem.acquire(5) - throwable.foreach { t => throw t } + throwable.foreach { t => throw improveStackTrace(t) } assert(sc.getLocalProperty("test") === null) } @@ -207,58 +207,41 @@ class ThreadingSuite extends SparkFunSuite with LocalSparkContext with Logging { threads.foreach(_.start()) sem.acquire(5) - throwable.foreach { t => throw t } + throwable.foreach { t => throw improveStackTrace(t) } assert(sc.getLocalProperty("test") === "parent") assert(sc.getLocalProperty("Foo") === null) } - test("mutations to local properties should not affect submitted jobs (SPARK-6629)") { - val jobStarted = new Semaphore(0) - val jobEnded = new Semaphore(0) - @volatile var jobResult: JobResult = null - var throwable: Option[Throwable] = None - + test("mutation in parent local property does not affect child (SPARK-10563)") { sc = new SparkContext("local", "test") -
sc.setJobGroup("originalJobGroupId", "description") - sc.addSparkListener(new SparkListener { - override def onJobStart(jobStart: SparkListenerJobStart): Unit = { - jobStarted.release() - } - override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = { - jobResult = jobEnd.jobResult - jobEnded.release() - } - }) - - // Create a new thread which will inherit the current thread's properties - val thread = new Thread() { + val originalTestValue: String = "original-value" + var threadTestValue: String = null + sc.setLocalProperty("test", originalTestValue) + var throwable: Option[Throwable] = None + val thread = new Thread { override def run(): Unit = { try { - assert(sc.getLocalProperty(SparkContext.SPARK_JOB_GROUP_ID) === "originalJobGroupId") - // Sleeps for a total of 10 seconds, but allows cancellation to interrupt the task - try { - sc.parallelize(1 to 100).foreach { x => - Thread.sleep(100) - } - } catch { - case s: SparkException => // ignored so that we don't print noise in test logs - } + threadTestValue = sc.getLocalProperty("test") } catch { case t: Throwable => throwable = Some(t) } } } + sc.setLocalProperty("test", "this-should-not-be-inherited") thread.start() - // Wait for the job to start, then mutate the original properties, which should have been - // inherited by the running job but hopefully defensively copied or snapshotted: - jobStarted.tryAcquire(10, TimeUnit.SECONDS) - sc.setJobGroup("modifiedJobGroupId", "description") - // Canceling the original job group should cancel the running job. In other words, the - // modification of the properties object should not affect the properties of running jobs - sc.cancelJobGroup("originalJobGroupId") - jobEnded.tryAcquire(10, TimeUnit.SECONDS) - throwable.foreach { t => throw t } - assert(jobResult.isInstanceOf[JobFailed]) + thread.join() + throwable.foreach { t => throw improveStackTrace(t) } + assert(threadTestValue === originalTestValue) } + + /** + * Improve the stack trace of an error thrown from within a thread. + * Otherwise it's difficult to tell which line in the test the error came from. + */ + private def improveStackTrace(t: Throwable): Throwable = { + t.setStackTrace(t.getStackTrace ++ Thread.currentThread.getStackTrace) + t + } + } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala new file mode 100644 index 000000000000..63639681ef80 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution + +import java.util.Properties + +import scala.collection.parallel.CompositeThrowable + +import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.sql.SQLContext + +class SQLExecutionSuite extends SparkFunSuite { + + test("concurrent query execution (SPARK-10548)") { + // Try to reproduce the issue with the old SparkContext + val conf = new SparkConf() + .setMaster("local[*]") + .setAppName("test") + val badSparkContext = new BadSparkContext(conf) + try { + testConcurrentQueryExecution(badSparkContext) + fail("unable to reproduce SPARK-10548") + } catch { + case e: IllegalArgumentException => + assert(e.getMessage.contains(SQLExecution.EXECUTION_ID_KEY)) + } finally { + badSparkContext.stop() + } + + // Verify that the issue is fixed with the latest SparkContext + val goodSparkContext = new SparkContext(conf) + try { + testConcurrentQueryExecution(goodSparkContext) + } finally { + goodSparkContext.stop() + } + } + + /** + * Trigger SPARK-10548 by mocking a parent and its child thread executing queries concurrently. + */ + private def testConcurrentQueryExecution(sc: SparkContext): Unit = { + val sqlContext = new SQLContext(sc) + import sqlContext.implicits._ + + // Initialize local properties. This is necessary for the test to pass. + sc.getLocalProperties + + // Set up a thread that runs executes a simple SQL query. + // Before starting the thread, mutate the execution ID in the parent. + // The child thread should not see the effect of this change. + var throwable: Option[Throwable] = None + val child = new Thread { + override def run(): Unit = { + try { + sc.parallelize(1 to 100).map { i => (i, i) }.toDF("a", "b").collect() + } catch { + case t: Throwable => + throwable = Some(t) + } + + } + } + sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, "anything") + child.start() + child.join() + + // The throwable is thrown from the child thread so it doesn't have a helpful stack trace + throwable.foreach { t => + t.setStackTrace(t.getStackTrace ++ Thread.currentThread.getStackTrace) + throw t + } + } + +} + +/** + * A bad [[SparkContext]] that does not clone the inheritable thread local properties + * when passing them to children threads. + */ +private class BadSparkContext(conf: SparkConf) extends SparkContext(conf) { + protected[spark] override val localProperties = new InheritableThreadLocal[Properties] { + override protected def childValue(parent: Properties): Properties = new Properties(parent) + override protected def initialValue(): Properties = new Properties() + } +} From a63cdc769f511e98b38c3318bcc732c9a6c76c22 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 15 Sep 2015 16:53:27 -0700 Subject: [PATCH 0640/1824] [SPARK-10612] [SQL] Add prepare to LocalNode. The idea is that we should separate the function call that does memory reservation (i.e. prepare) from the function call that consumes the input (e.g. open()), so all operators can be a chance to reserve memory before they are all consumed. Author: Reynold Xin Closes #8761 from rxin/SPARK-10612. 
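To make the prepare()/open() split concrete, here is a small toy example in the same open/next/close style. All names (`MiniNode`, `ScanNode`, `BufferedSortNode`) are invented for this sketch and it is not the actual `LocalNode` API; it only illustrates how a default `prepare()` that recurses into the children lets the whole operator tree reserve resources before any operator starts consuming input:

```
// Toy volcano-style operators. All names here are invented for this sketch; this is
// not Spark's LocalNode API, just an illustration of the prepare()/open() split.
trait MiniNode {
  def children: Seq[MiniNode]

  // Reserve resources only; must not consume any input.
  // Like the patch, the default implementation recurses into the children.
  def prepare(): Unit = children.foreach(_.prepare())

  def open(): Unit
  def next(): Option[Int]
  def close(): Unit
}

class ScanNode(data: Seq[Int]) extends MiniNode {
  private var it: Iterator[Int] = Iterator.empty
  override def children: Seq[MiniNode] = Nil
  override def open(): Unit = { it = data.iterator }
  override def next(): Option[Int] = if (it.hasNext) Some(it.next()) else None
  override def close(): Unit = { it = Iterator.empty }
}

// An operator that buffers its whole input and therefore wants to reserve memory
// before anything in the tree starts pulling rows.
class BufferedSortNode(child: MiniNode) extends MiniNode {
  private var sorted: Iterator[Int] = Iterator.empty
  override def children: Seq[MiniNode] = Seq(child)

  override def prepare(): Unit = {
    println("reserving sort buffer")  // a real engine would make a memory reservation here
    super.prepare()
  }

  override def open(): Unit = {
    child.open()
    val buf = scala.collection.mutable.ArrayBuffer.empty[Int]
    var row = child.next()
    while (row.isDefined) { buf += row.get; row = child.next() }
    child.close()
    sorted = buf.sorted.iterator
  }

  override def next(): Option[Int] = if (sorted.hasNext) Some(sorted.next()) else None
  override def close(): Unit = { sorted = Iterator.empty }
}

object PrepareThenOpenDemo {
  def main(args: Array[String]): Unit = {
    val plan = new BufferedSortNode(new ScanNode(Seq(3, 1, 2)))
    plan.prepare()  // the whole tree reserves first...
    plan.open()     // ...and only then starts consuming input
    var row = plan.next()
    while (row.isDefined) { print(row.get + " "); row = plan.next() }  // prints: 1 2 3
    println()
    plan.close()
  }
}
```

The caller invokes `prepare()` once on the root before `open()`, so even operators deep in the tree get their reservation hook before the first row is pulled.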
--- .../org/apache/spark/sql/execution/local/LocalNode.scala | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala index 9840080e1695..569cff565c09 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala @@ -45,6 +45,14 @@ abstract class LocalNode(conf: SQLConf) extends TreeNode[LocalNode] with Logging def output: Seq[Attribute] + /** + * Called before open(). Prepare can be used to reserve memory needed. It must NOT consume + * any input data. + * + * Implementations of this must also call the `prepare()` function of its children. + */ + def prepare(): Unit = children.foreach(_.prepare()) + /** * Initializes the iterator state. Must be called before calling `next()`. * From 99ecfa5945aedaa71765ecf5cce59964ae52eebe Mon Sep 17 00:00:00 2001 From: vinodkc Date: Tue, 15 Sep 2015 17:01:10 -0700 Subject: [PATCH 0641/1824] [SPARK-10575] [SPARK CORE] Wrapped RDD.takeSample with Scope Remove return statements in RDD.takeSample and wrap it withScope Author: vinodkc Author: vinodkc Author: Vinod K C Closes #8730 from vinodkc/fix_takesample_return. --- .../main/scala/org/apache/spark/rdd/RDD.scala | 68 +++++++++---------- 1 file changed, 31 insertions(+), 37 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 7dd2bc5d7cd7..a56e542242d5 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -469,50 +469,44 @@ abstract class RDD[T: ClassTag]( * @param seed seed for the random number generator * @return sample of specified size in an array */ - // TODO: rewrite this without return statements so we can wrap it in a scope def takeSample( withReplacement: Boolean, num: Int, - seed: Long = Utils.random.nextLong): Array[T] = { + seed: Long = Utils.random.nextLong): Array[T] = withScope { val numStDev = 10.0 - if (num < 0) { - throw new IllegalArgumentException("Negative number of elements requested") - } else if (num == 0) { - return new Array[T](0) - } - - val initialCount = this.count() - if (initialCount == 0) { - return new Array[T](0) - } - - val maxSampleSize = Int.MaxValue - (numStDev * math.sqrt(Int.MaxValue)).toInt - if (num > maxSampleSize) { - throw new IllegalArgumentException("Cannot support a sample size > Int.MaxValue - " + - s"$numStDev * math.sqrt(Int.MaxValue)") - } - - val rand = new Random(seed) - if (!withReplacement && num >= initialCount) { - return Utils.randomizeInPlace(this.collect(), rand) - } - - val fraction = SamplingUtils.computeFractionForSampleSize(num, initialCount, - withReplacement) - - var samples = this.sample(withReplacement, fraction, rand.nextInt()).collect() + require(num >= 0, "Negative number of elements requested") + require(num <= (Int.MaxValue - (numStDev * math.sqrt(Int.MaxValue)).toInt), + "Cannot support a sample size > Int.MaxValue - " + + s"$numStDev * math.sqrt(Int.MaxValue)") - // If the first sample didn't turn out large enough, keep trying to take samples; - // this shouldn't happen often because we use a big multiplier for the initial size - var numIters = 0 - while (samples.length < num) { - logWarning(s"Needed to re-sample due to insufficient sample size. 
Repeat #$numIters") - samples = this.sample(withReplacement, fraction, rand.nextInt()).collect() - numIters += 1 + if (num == 0) { + new Array[T](0) + } else { + val initialCount = this.count() + if (initialCount == 0) { + new Array[T](0) + } else { + val rand = new Random(seed) + if (!withReplacement && num >= initialCount) { + Utils.randomizeInPlace(this.collect(), rand) + } else { + val fraction = SamplingUtils.computeFractionForSampleSize(num, initialCount, + withReplacement) + var samples = this.sample(withReplacement, fraction, rand.nextInt()).collect() + + // If the first sample didn't turn out large enough, keep trying to take samples; + // this shouldn't happen often because we use a big multiplier for the initial size + var numIters = 0 + while (samples.length < num) { + logWarning(s"Needed to re-sample due to insufficient sample size. Repeat #$numIters") + samples = this.sample(withReplacement, fraction, rand.nextInt()).collect() + numIters += 1 + } + Utils.randomizeInPlace(samples, rand).take(num) + } + } } - - Utils.randomizeInPlace(samples, rand).take(num) } /** From 38700ea40cb1dd0805cc926a9e629f93c99527ad Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 15 Sep 2015 17:11:21 -0700 Subject: [PATCH 0642/1824] [SPARK-10381] Fix mixup of taskAttemptNumber & attemptId in OutputCommitCoordinator When speculative execution is enabled, consider a scenario where the authorized committer of a particular output partition fails during the OutputCommitter.commitTask() call. In this case, the OutputCommitCoordinator is supposed to release that committer's exclusive lock on committing once that task fails. However, due to a unit mismatch (we used task attempt number in one place and task attempt id in another) the lock will not be released, causing Spark to go into an infinite retry loop. This bug was masked by the fact that the OutputCommitCoordinator does not have enough end-to-end tests (the current tests use many mocks). Other factors contributing to this bug are the fact that we have many similarly-named identifiers that have different semantics but the same data types (e.g. attemptNumber and taskAttemptId, with inconsistent variable naming which makes them difficult to distinguish). This patch adds a regression test and fixes this bug by always using task attempt numbers throughout this code. Author: Josh Rosen Closes #8544 from JoshRosen/SPARK-10381. 
--- .../org/apache/spark/SparkHadoopWriter.scala | 3 +- .../org/apache/spark/TaskEndReason.scala | 7 +- .../executor/CommitDeniedException.scala | 4 +- .../spark/mapred/SparkHadoopMapRedUtil.scala | 20 ++---- .../apache/spark/scheduler/DAGScheduler.scala | 7 +- .../scheduler/OutputCommitCoordinator.scala | 48 +++++++------ .../org/apache/spark/scheduler/TaskInfo.scala | 7 +- .../status/api/v1/AllStagesResource.scala | 2 +- .../org/apache/spark/ui/jobs/StagePage.scala | 4 +- .../org/apache/spark/util/JsonProtocol.scala | 2 +- ...putCommitCoordinatorIntegrationSuite.scala | 68 +++++++++++++++++++ .../OutputCommitCoordinatorSuite.scala | 24 ++++--- .../apache/spark/util/JsonProtocolSuite.scala | 2 +- project/MimaExcludes.scala | 36 +++++++++- .../datasources/WriterContainer.scala | 3 +- .../sql/execution/ui/SQLListenerSuite.scala | 4 +- .../spark/sql/hive/hiveWriterContainers.scala | 2 +- 17 files changed, 174 insertions(+), 69 deletions(-) create mode 100644 core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorIntegrationSuite.scala diff --git a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala index ae5926dd534a..ac6eaab20d8d 100644 --- a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala +++ b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala @@ -104,8 +104,7 @@ class SparkHadoopWriter(jobConf: JobConf) } def commit() { - SparkHadoopMapRedUtil.commitTask( - getOutputCommitter(), getTaskContext(), jobID, splitID, attemptID) + SparkHadoopMapRedUtil.commitTask(getOutputCommitter(), getTaskContext(), jobID, splitID) } def commitJob() { diff --git a/core/src/main/scala/org/apache/spark/TaskEndReason.scala b/core/src/main/scala/org/apache/spark/TaskEndReason.scala index 2ae878b3e608..7137246bc34f 100644 --- a/core/src/main/scala/org/apache/spark/TaskEndReason.scala +++ b/core/src/main/scala/org/apache/spark/TaskEndReason.scala @@ -193,9 +193,12 @@ case object TaskKilled extends TaskFailedReason { * Task requested the driver to commit, but was denied. */ @DeveloperApi -case class TaskCommitDenied(jobID: Int, partitionID: Int, attemptID: Int) extends TaskFailedReason { +case class TaskCommitDenied( + jobID: Int, + partitionID: Int, + attemptNumber: Int) extends TaskFailedReason { override def toErrorString: String = s"TaskCommitDenied (Driver denied task commit)" + - s" for job: $jobID, partition: $partitionID, attempt: $attemptID" + s" for job: $jobID, partition: $partitionID, attemptNumber: $attemptNumber" /** * If a task failed because its attempt to commit was denied, do not count this failure * towards failing the stage. 
This is intended to prevent spurious stage failures in cases diff --git a/core/src/main/scala/org/apache/spark/executor/CommitDeniedException.scala b/core/src/main/scala/org/apache/spark/executor/CommitDeniedException.scala index f47d7ef511da..7d84889a2def 100644 --- a/core/src/main/scala/org/apache/spark/executor/CommitDeniedException.scala +++ b/core/src/main/scala/org/apache/spark/executor/CommitDeniedException.scala @@ -26,8 +26,8 @@ private[spark] class CommitDeniedException( msg: String, jobID: Int, splitID: Int, - attemptID: Int) + attemptNumber: Int) extends Exception(msg) { - def toTaskEndReason: TaskEndReason = TaskCommitDenied(jobID, splitID, attemptID) + def toTaskEndReason: TaskEndReason = TaskCommitDenied(jobID, splitID, attemptNumber) } diff --git a/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala b/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala index f405b732e472..f7298e8d5c62 100644 --- a/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala +++ b/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala @@ -91,8 +91,7 @@ object SparkHadoopMapRedUtil extends Logging { committer: MapReduceOutputCommitter, mrTaskContext: MapReduceTaskAttemptContext, jobId: Int, - splitId: Int, - attemptId: Int): Unit = { + splitId: Int): Unit = { val mrTaskAttemptID = SparkHadoopUtil.get.getTaskAttemptIDFromTaskAttemptContext(mrTaskContext) @@ -122,7 +121,8 @@ object SparkHadoopMapRedUtil extends Logging { if (shouldCoordinateWithDriver) { val outputCommitCoordinator = SparkEnv.get.outputCommitCoordinator - val canCommit = outputCommitCoordinator.canCommit(jobId, splitId, attemptId) + val taskAttemptNumber = TaskContext.get().attemptNumber() + val canCommit = outputCommitCoordinator.canCommit(jobId, splitId, taskAttemptNumber) if (canCommit) { performCommit() @@ -132,7 +132,7 @@ object SparkHadoopMapRedUtil extends Logging { logInfo(message) // We need to abort the task so that the driver can reschedule new attempts, if necessary committer.abortTask(mrTaskContext) - throw new CommitDeniedException(message, jobId, splitId, attemptId) + throw new CommitDeniedException(message, jobId, splitId, taskAttemptNumber) } } else { // Speculation is disabled or a user has chosen to manually bypass the commit coordination @@ -143,16 +143,4 @@ object SparkHadoopMapRedUtil extends Logging { logInfo(s"No need to commit output of task because needsTaskCommit=false: $mrTaskAttemptID") } } - - def commitTask( - committer: MapReduceOutputCommitter, - mrTaskContext: MapReduceTaskAttemptContext, - sparkTaskContext: TaskContext): Unit = { - commitTask( - committer, - mrTaskContext, - sparkTaskContext.stageId(), - sparkTaskContext.partitionId(), - sparkTaskContext.attemptNumber()) - } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index b4f90e834789..3c9a66e50440 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -1128,8 +1128,11 @@ class DAGScheduler( val stageId = task.stageId val taskType = Utils.getFormattedClassName(task) - outputCommitCoordinator.taskCompleted(stageId, task.partitionId, - event.taskInfo.attempt, event.reason) + outputCommitCoordinator.taskCompleted( + stageId, + task.partitionId, + event.taskInfo.attemptNumber, // this is a task attempt number + event.reason) // The success case is dealt with separately below, 
since we need to compute accumulator // updates before posting. diff --git a/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala b/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala index 5d926377ce86..add0dedc03f4 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala @@ -25,7 +25,7 @@ import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, RpcEndpoint private sealed trait OutputCommitCoordinationMessage extends Serializable private case object StopCoordinator extends OutputCommitCoordinationMessage -private case class AskPermissionToCommitOutput(stage: Int, task: Long, taskAttempt: Long) +private case class AskPermissionToCommitOutput(stage: Int, partition: Int, attemptNumber: Int) /** * Authority that decides whether tasks can commit output to HDFS. Uses a "first committer wins" @@ -44,8 +44,8 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean) var coordinatorRef: Option[RpcEndpointRef] = None private type StageId = Int - private type PartitionId = Long - private type TaskAttemptId = Long + private type PartitionId = Int + private type TaskAttemptNumber = Int /** * Map from active stages's id => partition id => task attempt with exclusive lock on committing @@ -57,7 +57,8 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean) * Access to this map should be guarded by synchronizing on the OutputCommitCoordinator instance. */ private val authorizedCommittersByStage: CommittersByStageMap = mutable.Map() - private type CommittersByStageMap = mutable.Map[StageId, mutable.Map[PartitionId, TaskAttemptId]] + private type CommittersByStageMap = + mutable.Map[StageId, mutable.Map[PartitionId, TaskAttemptNumber]] /** * Returns whether the OutputCommitCoordinator's internal data structures are all empty. 
@@ -75,14 +76,15 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean) * * @param stage the stage number * @param partition the partition number - * @param attempt a unique identifier for this task attempt + * @param attemptNumber how many times this task has been attempted + * (see [[TaskContext.attemptNumber()]]) * @return true if this task is authorized to commit, false otherwise */ def canCommit( stage: StageId, partition: PartitionId, - attempt: TaskAttemptId): Boolean = { - val msg = AskPermissionToCommitOutput(stage, partition, attempt) + attemptNumber: TaskAttemptNumber): Boolean = { + val msg = AskPermissionToCommitOutput(stage, partition, attemptNumber) coordinatorRef match { case Some(endpointRef) => endpointRef.askWithRetry[Boolean](msg) @@ -95,7 +97,7 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean) // Called by DAGScheduler private[scheduler] def stageStart(stage: StageId): Unit = synchronized { - authorizedCommittersByStage(stage) = mutable.HashMap[PartitionId, TaskAttemptId]() + authorizedCommittersByStage(stage) = mutable.HashMap[PartitionId, TaskAttemptNumber]() } // Called by DAGScheduler @@ -107,7 +109,7 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean) private[scheduler] def taskCompleted( stage: StageId, partition: PartitionId, - attempt: TaskAttemptId, + attemptNumber: TaskAttemptNumber, reason: TaskEndReason): Unit = synchronized { val authorizedCommitters = authorizedCommittersByStage.getOrElse(stage, { logDebug(s"Ignoring task completion for completed stage") @@ -117,12 +119,12 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean) case Success => // The task output has been committed successfully case denied: TaskCommitDenied => - logInfo( - s"Task was denied committing, stage: $stage, partition: $partition, attempt: $attempt") + logInfo(s"Task was denied committing, stage: $stage, partition: $partition, " + + s"attempt: $attemptNumber") case otherReason => - if (authorizedCommitters.get(partition).exists(_ == attempt)) { - logDebug(s"Authorized committer $attempt (stage=$stage, partition=$partition) failed;" + - s" clearing lock") + if (authorizedCommitters.get(partition).exists(_ == attemptNumber)) { + logDebug(s"Authorized committer (attemptNumber=$attemptNumber, stage=$stage, " + + s"partition=$partition) failed; clearing lock") authorizedCommitters.remove(partition) } } @@ -140,21 +142,23 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean) private[scheduler] def handleAskPermissionToCommit( stage: StageId, partition: PartitionId, - attempt: TaskAttemptId): Boolean = synchronized { + attemptNumber: TaskAttemptNumber): Boolean = synchronized { authorizedCommittersByStage.get(stage) match { case Some(authorizedCommitters) => authorizedCommitters.get(partition) match { case Some(existingCommitter) => - logDebug(s"Denying $attempt to commit for stage=$stage, partition=$partition; " + - s"existingCommitter = $existingCommitter") + logDebug(s"Denying attemptNumber=$attemptNumber to commit for stage=$stage, " + + s"partition=$partition; existingCommitter = $existingCommitter") false case None => - logDebug(s"Authorizing $attempt to commit for stage=$stage, partition=$partition") - authorizedCommitters(partition) = attempt + logDebug(s"Authorizing attemptNumber=$attemptNumber to commit for stage=$stage, " + + s"partition=$partition") + authorizedCommitters(partition) = attemptNumber true } case None => - 
logDebug(s"Stage $stage has completed, so not allowing task attempt $attempt to commit") + logDebug(s"Stage $stage has completed, so not allowing attempt number $attemptNumber of" + + s"partition $partition to commit") false } } @@ -174,9 +178,9 @@ private[spark] object OutputCommitCoordinator { } override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { - case AskPermissionToCommitOutput(stage, partition, taskAttempt) => + case AskPermissionToCommitOutput(stage, partition, attemptNumber) => context.reply( - outputCommitCoordinator.handleAskPermissionToCommit(stage, partition, taskAttempt)) + outputCommitCoordinator.handleAskPermissionToCommit(stage, partition, attemptNumber)) } } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala index 132a9ced7770..f113c2b1b843 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala @@ -29,7 +29,7 @@ import org.apache.spark.annotation.DeveloperApi class TaskInfo( val taskId: Long, val index: Int, - val attempt: Int, + val attemptNumber: Int, val launchTime: Long, val executorId: String, val host: String, @@ -95,7 +95,10 @@ class TaskInfo( } } - def id: String = s"$index.$attempt" + @deprecated("Use attemptNumber", "1.6.0") + def attempt: Int = attemptNumber + + def id: String = s"$index.$attemptNumber" def duration: Long = { if (!finished) { diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala index 390c136df79b..24a0b5220695 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala @@ -127,7 +127,7 @@ private[v1] object AllStagesResource { new TaskData( taskId = uiData.taskInfo.taskId, index = uiData.taskInfo.index, - attempt = uiData.taskInfo.attempt, + attempt = uiData.taskInfo.attemptNumber, launchTime = new Date(uiData.taskInfo.launchTime), executorId = uiData.taskInfo.executorId, host = uiData.taskInfo.host, diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 2b71f55b7bb4..712782d27b3c 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -621,7 +621,7 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { serializationTimeProportionPos + serializationTimeProportion val index = taskInfo.index - val attempt = taskInfo.attempt + val attempt = taskInfo.attemptNumber val svgTag = if (totalExecutionTime == 0) { @@ -967,7 +967,7 @@ private[ui] class TaskDataSource( new TaskTableRowData( info.index, info.taskId, - info.attempt, + info.attemptNumber, info.speculative, info.status, info.taskLocality.toString, diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index 24f78744ad74..99614a786bd9 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -266,7 +266,7 @@ private[spark] object JsonProtocol { def taskInfoToJson(taskInfo: TaskInfo): JValue = { ("Task ID" -> taskInfo.taskId) ~ ("Index" -> taskInfo.index) ~ - ("Attempt" -> taskInfo.attempt) ~ + ("Attempt" -> 
taskInfo.attemptNumber) ~ ("Launch Time" -> taskInfo.launchTime) ~ ("Executor ID" -> taskInfo.executorId) ~ ("Host" -> taskInfo.host) ~ diff --git a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorIntegrationSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorIntegrationSuite.scala new file mode 100644 index 000000000000..1ae5b030f083 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorIntegrationSuite.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler + +import org.apache.hadoop.mapred.{FileOutputCommitter, TaskAttemptContext} +import org.scalatest.concurrent.Timeouts +import org.scalatest.time.{Span, Seconds} + +import org.apache.spark.{SparkConf, SparkContext, LocalSparkContext, SparkFunSuite, TaskContext} +import org.apache.spark.util.Utils + +/** + * Integration tests for the OutputCommitCoordinator. + * + * See also: [[OutputCommitCoordinatorSuite]] for unit tests that use mocks. + */ +class OutputCommitCoordinatorIntegrationSuite + extends SparkFunSuite + with LocalSparkContext + with Timeouts { + + override def beforeAll(): Unit = { + super.beforeAll() + val conf = new SparkConf() + .set("master", "local[2,4]") + .set("spark.speculation", "true") + .set("spark.hadoop.mapred.output.committer.class", + classOf[ThrowExceptionOnFirstAttemptOutputCommitter].getCanonicalName) + sc = new SparkContext("local[2, 4]", "test", conf) + } + + test("exception thrown in OutputCommitter.commitTask()") { + // Regression test for SPARK-10381 + failAfter(Span(60, Seconds)) { + val tempDir = Utils.createTempDir() + try { + sc.parallelize(1 to 4, 2).map(_.toString).saveAsTextFile(tempDir.getAbsolutePath + "/out") + } finally { + Utils.deleteRecursively(tempDir) + } + } + } +} + +private class ThrowExceptionOnFirstAttemptOutputCommitter extends FileOutputCommitter { + override def commitTask(context: TaskAttemptContext): Unit = { + val ctx = TaskContext.get() + if (ctx.attemptNumber < 1) { + throw new java.io.FileNotFoundException("Intentional exception") + } + super.commitTask(context) + } +} diff --git a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala index e5ecd4b7c261..6d08d7c5b7d2 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala @@ -63,6 +63,9 @@ import scala.language.postfixOps * was not in SparkHadoopWriter, the tests would still pass because only one of the * increments would be captured even though the commit in both tasks was executed * erroneously. 
+ * + * See also: [[OutputCommitCoordinatorIntegrationSuite]] for integration tests that do + * not use mocks. */ class OutputCommitCoordinatorSuite extends SparkFunSuite with BeforeAndAfter { @@ -164,27 +167,28 @@ class OutputCommitCoordinatorSuite extends SparkFunSuite with BeforeAndAfter { test("Only authorized committer failures can clear the authorized committer lock (SPARK-6614)") { val stage: Int = 1 - val partition: Long = 2 - val authorizedCommitter: Long = 3 - val nonAuthorizedCommitter: Long = 100 + val partition: Int = 2 + val authorizedCommitter: Int = 3 + val nonAuthorizedCommitter: Int = 100 outputCommitCoordinator.stageStart(stage) - assert(outputCommitCoordinator.canCommit(stage, partition, attempt = authorizedCommitter)) - assert(!outputCommitCoordinator.canCommit(stage, partition, attempt = nonAuthorizedCommitter)) + + assert(outputCommitCoordinator.canCommit(stage, partition, authorizedCommitter)) + assert(!outputCommitCoordinator.canCommit(stage, partition, nonAuthorizedCommitter)) // The non-authorized committer fails outputCommitCoordinator.taskCompleted( - stage, partition, attempt = nonAuthorizedCommitter, reason = TaskKilled) + stage, partition, attemptNumber = nonAuthorizedCommitter, reason = TaskKilled) // New tasks should still not be able to commit because the authorized committer has not failed assert( - !outputCommitCoordinator.canCommit(stage, partition, attempt = nonAuthorizedCommitter + 1)) + !outputCommitCoordinator.canCommit(stage, partition, nonAuthorizedCommitter + 1)) // The authorized committer now fails, clearing the lock outputCommitCoordinator.taskCompleted( - stage, partition, attempt = authorizedCommitter, reason = TaskKilled) + stage, partition, attemptNumber = authorizedCommitter, reason = TaskKilled) // A new task should now be allowed to become the authorized committer assert( - outputCommitCoordinator.canCommit(stage, partition, attempt = nonAuthorizedCommitter + 2)) + outputCommitCoordinator.canCommit(stage, partition, nonAuthorizedCommitter + 2)) // There can only be one authorized committer assert( - !outputCommitCoordinator.canCommit(stage, partition, attempt = nonAuthorizedCommitter + 3)) + !outputCommitCoordinator.canCommit(stage, partition, nonAuthorizedCommitter + 3)) } } diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index 47e548ef0d44..143c1b901df1 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -499,7 +499,7 @@ class JsonProtocolSuite extends SparkFunSuite { private def assertEquals(info1: TaskInfo, info2: TaskInfo) { assert(info1.taskId === info2.taskId) assert(info1.index === info2.index) - assert(info1.attempt === info2.attempt) + assert(info1.attemptNumber === info2.attemptNumber) assert(info1.launchTime === info2.launchTime) assert(info1.executorId === info2.executorId) assert(info1.host === info2.host) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 46026c1e90ea..1c96b0958586 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -45,7 +45,7 @@ object MimaExcludes { excludePackage("org.apache.spark.sql.execution") ) ++ MimaBuild.excludeSparkClass("streaming.flume.FlumeTestUtils") ++ - MimaBuild.excludeSparkClass("streaming.flume.PollingFlumeTestUtils") ++ + MimaBuild.excludeSparkClass("streaming.flume.PollingFlumeTestUtils") ++ Seq( 
ProblemFilters.exclude[MissingMethodProblem]( "org.apache.spark.ml.classification.LogisticCostFun.this"), @@ -53,6 +53,23 @@ object MimaExcludes { "org.apache.spark.ml.classification.LogisticAggregator.add"), ProblemFilters.exclude[MissingMethodProblem]( "org.apache.spark.ml.classification.LogisticAggregator.count") + ) ++ Seq( + // SPARK-10381 Fix types / units in private AskPermissionToCommitOutput RPC message. + // This class is marked as `private` but MiMa still seems to be confused by the change. + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.scheduler.AskPermissionToCommitOutput.task"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]( + "org.apache.spark.scheduler.AskPermissionToCommitOutput.copy$default$2"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]( + "org.apache.spark.scheduler.AskPermissionToCommitOutput.copy"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.scheduler.AskPermissionToCommitOutput.taskAttempt"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]( + "org.apache.spark.scheduler.AskPermissionToCommitOutput.copy$default$3"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]( + "org.apache.spark.scheduler.AskPermissionToCommitOutput.this"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]( + "org.apache.spark.scheduler.AskPermissionToCommitOutput.apply") ) case v if v.startsWith("1.5") => Seq( @@ -213,6 +230,23 @@ object MimaExcludes { // SPARK-9704 Made ProbabilisticClassifier, Identifiable, VectorUDT public APIs ProblemFilters.exclude[IncompatibleResultTypeProblem]( "org.apache.spark.mllib.linalg.VectorUDT.serialize") + ) ++ Seq( + // SPARK-10381 Fix types / units in private AskPermissionToCommitOutput RPC message. + // This class is marked as `private` but MiMa still seems to be confused by the change. 
+ ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.scheduler.AskPermissionToCommitOutput.task"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]( + "org.apache.spark.scheduler.AskPermissionToCommitOutput.copy$default$2"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]( + "org.apache.spark.scheduler.AskPermissionToCommitOutput.copy"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.scheduler.AskPermissionToCommitOutput.taskAttempt"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]( + "org.apache.spark.scheduler.AskPermissionToCommitOutput.copy$default$3"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]( + "org.apache.spark.scheduler.AskPermissionToCommitOutput.this"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]( + "org.apache.spark.scheduler.AskPermissionToCommitOutput.apply") ) case v if v.startsWith("1.4") => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala index f8ef674ed29c..cfd64c1d9eb3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala @@ -198,8 +198,7 @@ private[sql] abstract class BaseWriterContainer( } def commitTask(): Unit = { - SparkHadoopMapRedUtil.commitTask( - outputCommitter, taskAttemptContext, jobId.getId, taskId.getId, taskAttemptId.getId) + SparkHadoopMapRedUtil.commitTask(outputCommitter, taskAttemptContext, jobId.getId, taskId.getId) } def abortTask(): Unit = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala index 2bbb41ca777b..7a46c69a056b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala @@ -54,9 +54,9 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { details = "" ) - private def createTaskInfo(taskId: Int, attempt: Int): TaskInfo = new TaskInfo( + private def createTaskInfo(taskId: Int, attemptNumber: Int): TaskInfo = new TaskInfo( taskId = taskId, - attempt = attempt, + attemptNumber = attemptNumber, // The following fields are not used in tests index = 0, launchTime = 0, diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala index 4ca8042d2236..c8d6b718045a 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala @@ -121,7 +121,7 @@ private[hive] class SparkHiveWriterContainer( } protected def commit() { - SparkHadoopMapRedUtil.commitTask(committer, taskContext, jobID, splitID, attemptID) + SparkHadoopMapRedUtil.commitTask(committer, taskContext, jobID, splitID) } private def setIDs(jobId: Int, splitId: Int, attemptId: Int) { From 35a19f3357d2ec017cfefb90f1018403e9617de4 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Tue, 15 Sep 2015 17:24:32 -0700 Subject: [PATCH 0643/1824] [SPARK-10613] [SPARK-10624] [SQL] Reduce LocalNode tests dependency on SQLContext Instead of relying on `DataFrames` to verify our answers, we can just use simple arrays. 
This significantly simplifies the test logic for `LocalNode`s and reduces a lot of code duplicated from `SparkPlanTest`. This also fixes an additional issue [SPARK-10624](https://issues.apache.org/jira/browse/SPARK-10624) where the output of `TakeOrderedAndProjectNode` is not actually ordered. Author: Andrew Or Closes #8764 from andrewor14/sql-local-tests-cleanup. --- .../spark/sql/execution/local/LocalNode.scala | 8 +- .../sql/execution/local/SampleNode.scala | 16 +- .../local/TakeOrderedAndProjectNode.scala | 2 +- .../spark/sql/execution/SparkPlanTest.scala | 2 +- .../spark/sql/execution/local/DummyNode.scala | 68 ++++ .../sql/execution/local/ExpandNodeSuite.scala | 54 ++- .../sql/execution/local/FilterNodeSuite.scala | 34 +- .../execution/local/HashJoinNodeSuite.scala | 141 ++++---- .../execution/local/IntersectNodeSuite.scala | 24 +- .../sql/execution/local/LimitNodeSuite.scala | 28 +- .../sql/execution/local/LocalNodeSuite.scala | 73 +--- .../sql/execution/local/LocalNodeTest.scala | 165 ++------- .../local/NestedLoopJoinNodeSuite.scala | 316 ++++++------------ .../execution/local/ProjectNodeSuite.scala | 39 ++- .../sql/execution/local/SampleNodeSuite.scala | 35 +- .../TakeOrderedAndProjectNodeSuite.scala | 50 ++- .../sql/execution/local/UnionNodeSuite.scala | 49 +-- 17 files changed, 468 insertions(+), 636 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/local/DummyNode.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala index 569cff565c09..f96b62a67a25 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.{SQLConf, Row} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ -import org.apache.spark.sql.catalyst.trees.TreeNode +import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.types.StructType /** @@ -33,18 +33,14 @@ import org.apache.spark.sql.types.StructType * Before consuming the iterator, open function must be called. * After consuming the iterator, close function must be called. */ -abstract class LocalNode(conf: SQLConf) extends TreeNode[LocalNode] with Logging { +abstract class LocalNode(conf: SQLConf) extends QueryPlan[LocalNode] with Logging { protected val codegenEnabled: Boolean = conf.codegenEnabled protected val unsafeEnabled: Boolean = conf.unsafeEnabled - lazy val schema: StructType = StructType.fromAttributes(output) - private[this] lazy val isTesting: Boolean = sys.props.contains("spark.testing") - def output: Seq[Attribute] - /** * Called before open(). Prepare can be used to reserve memory needed. It must NOT consume * any input data. 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SampleNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SampleNode.scala index abf3df1c0c2a..793700803f21 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SampleNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SampleNode.scala @@ -17,13 +17,12 @@ package org.apache.spark.sql.execution.local -import java.util.Random - import org.apache.spark.sql.SQLConf import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.util.random.{BernoulliCellSampler, PoissonSampler} + /** * Sample the dataset. * @@ -51,18 +50,15 @@ case class SampleNode( override def open(): Unit = { child.open() - val (sampler, _seed) = if (withReplacement) { - val random = new Random(seed) + val sampler = + if (withReplacement) { // Disable gap sampling since the gap sampling method buffers two rows internally, // requiring us to copy the row, which is more expensive than the random number generator. - (new PoissonSampler[InternalRow](upperBound - lowerBound, useGapSamplingIfPossible = false), - // Use the seed for partition 0 like PartitionwiseSampledRDD to generate the same result - // of DataFrame - random.nextLong()) + new PoissonSampler[InternalRow](upperBound - lowerBound, useGapSamplingIfPossible = false) } else { - (new BernoulliCellSampler[InternalRow](lowerBound, upperBound), seed) + new BernoulliCellSampler[InternalRow](lowerBound, upperBound) } - sampler.setSeed(_seed) + sampler.setSeed(seed) iterator = sampler.sample(child.asIterator) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNode.scala index 53f1dcc65d8c..ae672fbca8d8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNode.scala @@ -50,7 +50,7 @@ case class TakeOrderedAndProjectNode( } // Close it eagerly since we don't need it. child.close() - iterator = queue.iterator + iterator = queue.toArray.sorted(ord).iterator } override def next(): Boolean = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala index de45ae4635fb..3d218f01c9ea 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala @@ -238,7 +238,7 @@ object SparkPlanTest { outputPlan transform { case plan: SparkPlan => val inputMap = plan.children.flatMap(_.output).map(a => (a.name, a)).toMap - plan.transformExpressions { + plan transformExpressions { case UnresolvedAttribute(Seq(u)) => inputMap.getOrElse(u, sys.error(s"Invalid Test: Cannot resolve $u given input $inputMap")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/DummyNode.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/DummyNode.scala new file mode 100644 index 000000000000..efc3227dd60d --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/DummyNode.scala @@ -0,0 +1,68 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. 
See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.execution.local + +import org.apache.spark.sql.SQLConf +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.plans.logical.LocalRelation + +/** + * A dummy [[LocalNode]] that just returns rows from a [[LocalRelation]]. + */ +private[local] case class DummyNode( + output: Seq[Attribute], + relation: LocalRelation, + conf: SQLConf) + extends LocalNode(conf) { + + import DummyNode._ + + private var index: Int = CLOSED + private val input: Seq[InternalRow] = relation.data + + def this(output: Seq[Attribute], data: Seq[Product], conf: SQLConf = new SQLConf) { + this(output, LocalRelation.fromProduct(output, data), conf) + } + + def isOpen: Boolean = index != CLOSED + + override def children: Seq[LocalNode] = Seq.empty + + override def open(): Unit = { + index = -1 + } + + override def next(): Boolean = { + index += 1 + index < input.size + } + + override def fetch(): InternalRow = { + assert(index >= 0 && index < input.size) + input(index) + } + + override def close(): Unit = { + index = CLOSED + } +} + +private object DummyNode { + val CLOSED: Int = Int.MinValue +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ExpandNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ExpandNodeSuite.scala index cfa7f3f6dcb9..bbd94d8da2d1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ExpandNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ExpandNodeSuite.scala @@ -17,35 +17,33 @@ package org.apache.spark.sql.execution.local +import org.apache.spark.sql.catalyst.dsl.expressions._ + + class ExpandNodeSuite extends LocalNodeTest { - import testImplicits._ - - test("expand") { - val input = Seq((1, 1), (2, 2), (3, 3), (4, 4), (5, 5)).toDF("key", "value") - checkAnswer( - input, - node => - ExpandNode(conf, Seq( - Seq( - input.col("key") + input.col("value"), input.col("key") - input.col("value") - ).map(_.expr), - Seq( - input.col("key") * input.col("value"), input.col("key") / input.col("value") - ).map(_.expr) - ), node.output, node), - Seq( - (2, 0), - (1, 1), - (4, 0), - (4, 1), - (6, 0), - (9, 1), - (8, 0), - (16, 1), - (10, 0), - (25, 1) - ).toDF().collect() - ) + private def testExpand(inputData: Array[(Int, Int)] = Array.empty): Unit = { + val inputNode = new DummyNode(kvIntAttributes, inputData) + val projections = Seq(Seq('k + 'v, 'k - 'v), Seq('k * 'v, 'k / 'v)) + val expandNode = new ExpandNode(conf, projections, inputNode.output, inputNode) + val resolvedNode = resolveExpressions(expandNode) + val expectedOutput = { + val firstHalf = inputData.map { case (k, v) => (k + v, k - v) } + val secondHalf = inputData.map { case (k, v) => (k * v, k / v) } + firstHalf ++ secondHalf + } + val 
actualOutput = resolvedNode.collect().map { case row => + (row.getInt(0), row.getInt(1)) + } + assert(actualOutput.toSet === expectedOutput.toSet) + } + + test("empty") { + testExpand() } + + test("basic") { + testExpand((1 to 100).map { i => (i, i * 1000) }.toArray) + } + } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/FilterNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/FilterNodeSuite.scala index a12670e347c2..4eadce646d37 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/FilterNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/FilterNodeSuite.scala @@ -17,25 +17,29 @@ package org.apache.spark.sql.execution.local -import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.catalyst.dsl.expressions._ -class FilterNodeSuite extends LocalNodeTest with SharedSQLContext { - test("basic") { - val condition = (testData.col("key") % 2) === 0 - checkAnswer( - testData, - node => FilterNode(conf, condition.expr, node), - testData.filter(condition).collect() - ) +class FilterNodeSuite extends LocalNodeTest { + + private def testFilter(inputData: Array[(Int, Int)] = Array.empty): Unit = { + val cond = 'k % 2 === 0 + val inputNode = new DummyNode(kvIntAttributes, inputData) + val filterNode = new FilterNode(conf, cond, inputNode) + val resolvedNode = resolveExpressions(filterNode) + val expectedOutput = inputData.filter { case (k, _) => k % 2 == 0 } + val actualOutput = resolvedNode.collect().map { case row => + (row.getInt(0), row.getInt(1)) + } + assert(actualOutput === expectedOutput) } test("empty") { - val condition = (emptyTestData.col("key") % 2) === 0 - checkAnswer( - emptyTestData, - node => FilterNode(conf, condition.expr, node), - emptyTestData.filter(condition).collect() - ) + testFilter() + } + + test("basic") { + testFilter((1 to 100).map { i => (i, i) }.toArray) } + } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala index 78d891351f4a..5c1bdb088eee 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala @@ -18,99 +18,80 @@ package org.apache.spark.sql.execution.local import org.apache.spark.sql.SQLConf -import org.apache.spark.sql.execution.joins +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide} + class HashJoinNodeSuite extends LocalNodeTest { - import testImplicits._ + // Test all combinations of the two dimensions: with/out unsafe and build sides + private val maybeUnsafeAndCodegen = Seq(false, true) + private val buildSides = Seq(BuildLeft, BuildRight) + maybeUnsafeAndCodegen.foreach { unsafeAndCodegen => + buildSides.foreach { buildSide => + testJoin(unsafeAndCodegen, buildSide) + } + } - def joinSuite(suiteName: String, confPairs: (String, String)*): Unit = { - test(s"$suiteName: inner join with one match per row") { - withSQLConf(confPairs: _*) { - checkAnswer2( - upperCaseData, - lowerCaseData, - wrapForUnsafe( - (node1, node2) => HashJoinNode( - conf, - Seq(upperCaseData.col("N").expr), - Seq(lowerCaseData.col("n").expr), - joins.BuildLeft, - node1, - node2) - ), - upperCaseData.join(lowerCaseData, $"n" === $"N").collect() - ) + /** + * Test inner hash join with varying degrees of matches. 
+ */ + private def testJoin( + unsafeAndCodegen: Boolean, + buildSide: BuildSide): Unit = { + val simpleOrUnsafe = if (!unsafeAndCodegen) "simple" else "unsafe" + val testNamePrefix = s"$simpleOrUnsafe / $buildSide" + val someData = (1 to 100).map { i => (i, "burger" + i) }.toArray + val conf = new SQLConf + conf.setConf(SQLConf.UNSAFE_ENABLED, unsafeAndCodegen) + conf.setConf(SQLConf.CODEGEN_ENABLED, unsafeAndCodegen) + + // Actual test body + def runTest(leftInput: Array[(Int, String)], rightInput: Array[(Int, String)]): Unit = { + val rightInputMap = rightInput.toMap + val leftNode = new DummyNode(joinNameAttributes, leftInput) + val rightNode = new DummyNode(joinNicknameAttributes, rightInput) + val makeNode = (node1: LocalNode, node2: LocalNode) => { + resolveExpressions(new HashJoinNode( + conf, Seq('id1), Seq('id2), buildSide, node1, node2)) + } + val makeUnsafeNode = if (unsafeAndCodegen) wrapForUnsafe(makeNode) else makeNode + val hashJoinNode = makeUnsafeNode(leftNode, rightNode) + val expectedOutput = leftInput + .filter { case (k, _) => rightInputMap.contains(k) } + .map { case (k, v) => (k, v, k, rightInputMap(k)) } + val actualOutput = hashJoinNode.collect().map { row => + // (id, name, id, nickname) + (row.getInt(0), row.getString(1), row.getInt(2), row.getString(3)) } + assert(actualOutput === expectedOutput) } - test(s"$suiteName: inner join with multiple matches") { - withSQLConf(confPairs: _*) { - val x = testData2.where($"a" === 1).as("x") - val y = testData2.where($"a" === 1).as("y") - checkAnswer2( - x, - y, - wrapForUnsafe( - (node1, node2) => HashJoinNode( - conf, - Seq(x.col("a").expr), - Seq(y.col("a").expr), - joins.BuildLeft, - node1, - node2) - ), - x.join(y).where($"x.a" === $"y.a").collect() - ) - } + test(s"$testNamePrefix: empty") { + runTest(Array.empty, Array.empty) + runTest(someData, Array.empty) + runTest(Array.empty, someData) } - test(s"$suiteName: inner join, no matches") { - withSQLConf(confPairs: _*) { - val x = testData2.where($"a" === 1).as("x") - val y = testData2.where($"a" === 2).as("y") - checkAnswer2( - x, - y, - wrapForUnsafe( - (node1, node2) => HashJoinNode( - conf, - Seq(x.col("a").expr), - Seq(y.col("a").expr), - joins.BuildLeft, - node1, - node2) - ), - Nil - ) - } + test(s"$testNamePrefix: no matches") { + val someIrrelevantData = (10000 to 100100).map { i => (i, "piper" + i) }.toArray + runTest(someData, Array.empty) + runTest(Array.empty, someData) + runTest(someData, someIrrelevantData) + runTest(someIrrelevantData, someData) } - test(s"$suiteName: big inner join, 4 matches per row") { - withSQLConf(confPairs: _*) { - val bigData = testData.unionAll(testData).unionAll(testData).unionAll(testData) - val bigDataX = bigData.as("x") - val bigDataY = bigData.as("y") + test(s"$testNamePrefix: partial matches") { + val someOtherData = (50 to 150).map { i => (i, "finnegan" + i) }.toArray + runTest(someData, someOtherData) + runTest(someOtherData, someData) + } - checkAnswer2( - bigDataX, - bigDataY, - wrapForUnsafe( - (node1, node2) => - HashJoinNode( - conf, - Seq(bigDataX.col("key").expr), - Seq(bigDataY.col("key").expr), - joins.BuildLeft, - node1, - node2) - ), - bigDataX.join(bigDataY).where($"x.key" === $"y.key").collect()) - } + test(s"$testNamePrefix: full matches") { + val someSuperRelevantData = someData.map { case (k, v) => (k, "cooper" + v) }.toArray + runTest(someData, someSuperRelevantData) + runTest(someSuperRelevantData, someData) } } - joinSuite( - "general", SQLConf.CODEGEN_ENABLED.key -> "false", 
SQLConf.UNSAFE_ENABLED.key -> "false") - joinSuite("tungsten", SQLConf.CODEGEN_ENABLED.key -> "true", SQLConf.UNSAFE_ENABLED.key -> "true") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/IntersectNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/IntersectNodeSuite.scala index 7deaa375fcfc..c0ad2021b204 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/IntersectNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/IntersectNodeSuite.scala @@ -17,19 +17,21 @@ package org.apache.spark.sql.execution.local -class IntersectNodeSuite extends LocalNodeTest { - import testImplicits._ +class IntersectNodeSuite extends LocalNodeTest { test("basic") { - val input1 = (1 to 10).map(i => (i, i.toString)).toDF("key", "value") - val input2 = (1 to 10).filter(_ % 2 == 0).map(i => (i, i.toString)).toDF("key", "value") - - checkAnswer2( - input1, - input2, - (node1, node2) => IntersectNode(conf, node1, node2), - input1.intersect(input2).collect() - ) + val n = 100 + val leftData = (1 to n).filter { i => i % 2 == 0 }.map { i => (i, i) }.toArray + val rightData = (1 to n).filter { i => i % 3 == 0 }.map { i => (i, i) }.toArray + val leftNode = new DummyNode(kvIntAttributes, leftData) + val rightNode = new DummyNode(kvIntAttributes, rightData) + val intersectNode = new IntersectNode(conf, leftNode, rightNode) + val expectedOutput = leftData.intersect(rightData) + val actualOutput = intersectNode.collect().map { case row => + (row.getInt(0), row.getInt(1)) + } + assert(actualOutput === expectedOutput) } + } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LimitNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LimitNodeSuite.scala index 3b183902007e..fb790636a368 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LimitNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LimitNodeSuite.scala @@ -17,23 +17,25 @@ package org.apache.spark.sql.execution.local -import org.apache.spark.sql.test.SharedSQLContext -class LimitNodeSuite extends LocalNodeTest with SharedSQLContext { +class LimitNodeSuite extends LocalNodeTest { - test("basic") { - checkAnswer( - testData, - node => LimitNode(conf, 10, node), - testData.limit(10).collect() - ) + private def testLimit(inputData: Array[(Int, Int)] = Array.empty, limit: Int = 10): Unit = { + val inputNode = new DummyNode(kvIntAttributes, inputData) + val limitNode = new LimitNode(conf, limit, inputNode) + val expectedOutput = inputData.take(limit) + val actualOutput = limitNode.collect().map { case row => + (row.getInt(0), row.getInt(1)) + } + assert(actualOutput === expectedOutput) } test("empty") { - checkAnswer( - emptyTestData, - node => LimitNode(conf, 10, node), - emptyTestData.limit(10).collect() - ) + testLimit() } + + test("basic") { + testLimit((1 to 100).map { i => (i, i) }.toArray, 20) + } + } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeSuite.scala index b89fa46f8b3b..0d1ed99eec6c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeSuite.scala @@ -17,28 +17,24 @@ package org.apache.spark.sql.execution.local -import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.SQLConf -import 
org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.types.IntegerType -class LocalNodeSuite extends SparkFunSuite { - private val data = (1 to 100).toArray +class LocalNodeSuite extends LocalNodeTest { + private val data = (1 to 100).map { i => (i, i) }.toArray test("basic open, next, fetch, close") { - val node = new DummyLocalNode(data) + val node = new DummyNode(kvIntAttributes, data) assert(!node.isOpen) node.open() assert(node.isOpen) - data.foreach { i => + data.foreach { case (k, v) => assert(node.next()) // fetch should be idempotent val fetched = node.fetch() assert(node.fetch() === fetched) assert(node.fetch() === fetched) - assert(node.fetch().numFields === 1) - assert(node.fetch().getInt(0) === i) + assert(node.fetch().numFields === 2) + assert(node.fetch().getInt(0) === k) + assert(node.fetch().getInt(1) === v) } assert(!node.next()) node.close() @@ -46,16 +42,17 @@ class LocalNodeSuite extends SparkFunSuite { } test("asIterator") { - val node = new DummyLocalNode(data) + val node = new DummyNode(kvIntAttributes, data) val iter = node.asIterator node.open() - data.foreach { i => + data.foreach { case (k, v) => // hasNext should be idempotent assert(iter.hasNext) assert(iter.hasNext) val item = iter.next() - assert(item.numFields === 1) - assert(item.getInt(0) === i) + assert(item.numFields === 2) + assert(item.getInt(0) === k) + assert(item.getInt(1) === v) } intercept[NoSuchElementException] { iter.next() @@ -64,53 +61,13 @@ class LocalNodeSuite extends SparkFunSuite { } test("collect") { - val node = new DummyLocalNode(data) + val node = new DummyNode(kvIntAttributes, data) node.open() val collected = node.collect() assert(collected.size === data.size) - assert(collected.forall(_.size === 1)) - assert(collected.map(_.getInt(0)) === data) + assert(collected.forall(_.size === 2)) + assert(collected.map { case row => (row.getInt(0), row.getInt(0)) } === data) node.close() } } - -/** - * A dummy [[LocalNode]] that just returns one row per integer in the input. 
- */ -private case class DummyLocalNode(conf: SQLConf, input: Array[Int]) extends LocalNode(conf) { - private var index = Int.MinValue - - def this(input: Array[Int]) { - this(new SQLConf, input) - } - - def isOpen: Boolean = { - index != Int.MinValue - } - - override def output: Seq[Attribute] = { - Seq(AttributeReference("something", IntegerType)()) - } - - override def children: Seq[LocalNode] = Seq.empty - - override def open(): Unit = { - index = -1 - } - - override def next(): Boolean = { - index += 1 - index < input.size - } - - override def fetch(): InternalRow = { - assert(index >= 0 && index < input.size) - val values = Array(input(index).asInstanceOf[Any]) - new GenericInternalRow(values) - } - - override def close(): Unit = { - index = Int.MinValue - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala index 86dd28064cc6..098050bcd223 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala @@ -17,147 +17,54 @@ package org.apache.spark.sql.execution.local -import scala.util.control.NonFatal - import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.{DataFrame, Row, SQLConf} -import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} +import org.apache.spark.sql.SQLConf +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute +import org.apache.spark.sql.catalyst.expressions.AttributeReference +import org.apache.spark.sql.types.{IntegerType, StringType} -class LocalNodeTest extends SparkFunSuite with SharedSQLContext { - def conf: SQLConf = sqlContext.conf +class LocalNodeTest extends SparkFunSuite { - protected def wrapForUnsafe( - f: (LocalNode, LocalNode) => LocalNode): (LocalNode, LocalNode) => LocalNode = { - if (conf.unsafeEnabled) { - (left: LocalNode, right: LocalNode) => { - val _left = ConvertToUnsafeNode(conf, left) - val _right = ConvertToUnsafeNode(conf, right) - val r = f(_left, _right) - ConvertToSafeNode(conf, r) - } - } else { - f - } - } - - /** - * Runs the LocalNode and makes sure the answer matches the expected result. - * @param input the input data to be used. - * @param nodeFunction a function which accepts the input LocalNode and uses it to instantiate - * the local physical operator that's being tested. - * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s. - * @param sortAnswers if true, the answers will be sorted by their toString representations prior - * to being compared. - */ - protected def checkAnswer( - input: DataFrame, - nodeFunction: LocalNode => LocalNode, - expectedAnswer: Seq[Row], - sortAnswers: Boolean = true): Unit = { - doCheckAnswer( - input :: Nil, - nodes => nodeFunction(nodes.head), - expectedAnswer, - sortAnswers) - } - - /** - * Runs the LocalNode and makes sure the answer matches the expected result. - * @param left the left input data to be used. - * @param right the right input data to be used. - * @param nodeFunction a function which accepts the input LocalNode and uses it to instantiate - * the local physical operator that's being tested. - * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s. - * @param sortAnswers if true, the answers will be sorted by their toString representations prior - * to being compared. 
- */ - protected def checkAnswer2( - left: DataFrame, - right: DataFrame, - nodeFunction: (LocalNode, LocalNode) => LocalNode, - expectedAnswer: Seq[Row], - sortAnswers: Boolean = true): Unit = { - doCheckAnswer( - left :: right :: Nil, - nodes => nodeFunction(nodes(0), nodes(1)), - expectedAnswer, - sortAnswers) - } + protected val conf: SQLConf = new SQLConf + protected val kvIntAttributes = Seq( + AttributeReference("k", IntegerType)(), + AttributeReference("v", IntegerType)()) + protected val joinNameAttributes = Seq( + AttributeReference("id1", IntegerType)(), + AttributeReference("name", StringType)()) + protected val joinNicknameAttributes = Seq( + AttributeReference("id2", IntegerType)(), + AttributeReference("nickname", StringType)()) /** - * Runs the `LocalNode`s and makes sure the answer matches the expected result. - * @param input the input data to be used. - * @param nodeFunction a function which accepts a sequence of input `LocalNode`s and uses them to - * instantiate the local physical operator that's being tested. - * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s. - * @param sortAnswers if true, the answers will be sorted by their toString representations prior - * to being compared. + * Wrap a function processing two [[LocalNode]]s such that: + * (1) all input rows are automatically converted to unsafe rows + * (2) all output rows are automatically converted back to safe rows */ - protected def doCheckAnswer( - input: Seq[DataFrame], - nodeFunction: Seq[LocalNode] => LocalNode, - expectedAnswer: Seq[Row], - sortAnswers: Boolean = true): Unit = { - LocalNodeTest.checkAnswer( - input.map(dataFrameToSeqScanNode), nodeFunction, expectedAnswer, sortAnswers) match { - case Some(errorMessage) => fail(errorMessage) - case None => + protected def wrapForUnsafe( + f: (LocalNode, LocalNode) => LocalNode): (LocalNode, LocalNode) => LocalNode = { + (left: LocalNode, right: LocalNode) => { + val _left = ConvertToUnsafeNode(conf, left) + val _right = ConvertToUnsafeNode(conf, right) + val r = f(_left, _right) + ConvertToSafeNode(conf, r) } } - protected def dataFrameToSeqScanNode(df: DataFrame): SeqScanNode = { - new SeqScanNode( - conf, - df.queryExecution.sparkPlan.output, - df.queryExecution.toRdd.map(_.copy()).collect()) - } - -} - -/** - * Helper methods for writing tests of individual local physical operators. - */ -object LocalNodeTest { - /** - * Runs the `LocalNode`s and makes sure the answer matches the expected result. - * @param input the input data to be used. - * @param nodeFunction a function which accepts the input `LocalNode`s and uses them to - * instantiate the local physical operator that's being tested. - * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s. - * @param sortAnswers if true, the answers will be sorted by their toString representations prior - * to being compared. + * Recursively resolve all expressions in a [[LocalNode]] using the node's attributes. 
*/ - def checkAnswer( - input: Seq[SeqScanNode], - nodeFunction: Seq[LocalNode] => LocalNode, - expectedAnswer: Seq[Row], - sortAnswers: Boolean): Option[String] = { - - val outputNode = nodeFunction(input) - - val outputResult: Seq[Row] = try { - outputNode.collect() - } catch { - case NonFatal(e) => - val errorMessage = - s""" - | Exception thrown while executing local plan: - | $outputNode - | == Exception == - | $e - | ${org.apache.spark.sql.catalyst.util.stackTraceToString(e)} - """.stripMargin - return Some(errorMessage) - } - - SQLTestUtils.compareAnswers(outputResult, expectedAnswer, sortAnswers).map { errorMessage => - s""" - | Results do not match for local plan: - | $outputNode - | $errorMessage - """.stripMargin + protected def resolveExpressions(outputNode: LocalNode): LocalNode = { + outputNode transform { + case node: LocalNode => + val inputMap = node.output.map { a => (a.name, a) }.toMap + node transformExpressions { + case UnresolvedAttribute(Seq(u)) => + inputMap.getOrElse(u, + sys.error(s"Invalid Test: Cannot resolve $u given input $inputMap")) + } } } + } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNodeSuite.scala index b1ef26ba82f1..40299d9d5ee3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNodeSuite.scala @@ -18,222 +18,128 @@ package org.apache.spark.sql.execution.local import org.apache.spark.sql.SQLConf -import org.apache.spark.sql.catalyst.plans.{FullOuter, LeftOuter, RightOuter} +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.plans.{FullOuter, JoinType, LeftOuter, RightOuter} import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide} + class NestedLoopJoinNodeSuite extends LocalNodeTest { - import testImplicits._ - - private def joinSuite( - suiteName: String, buildSide: BuildSide, confPairs: (String, String)*): Unit = { - test(s"$suiteName: left outer join") { - withSQLConf(confPairs: _*) { - checkAnswer2( - upperCaseData, - lowerCaseData, - wrapForUnsafe( - (node1, node2) => NestedLoopJoinNode( - conf, - node1, - node2, - buildSide, - LeftOuter, - Some((upperCaseData.col("N") === lowerCaseData.col("n")).expr)) - ), - upperCaseData.join(lowerCaseData, $"n" === $"N", "left").collect()) - - checkAnswer2( - upperCaseData, - lowerCaseData, - wrapForUnsafe( - (node1, node2) => NestedLoopJoinNode( - conf, - node1, - node2, - buildSide, - LeftOuter, - Some( - (upperCaseData.col("N") === lowerCaseData.col("n") && - lowerCaseData.col("n") > 1).expr)) - ), - upperCaseData.join(lowerCaseData, $"n" === $"N" && $"n" > 1, "left").collect()) - - checkAnswer2( - upperCaseData, - lowerCaseData, - wrapForUnsafe( - (node1, node2) => NestedLoopJoinNode( - conf, - node1, - node2, - buildSide, - LeftOuter, - Some( - (upperCaseData.col("N") === lowerCaseData.col("n") && - upperCaseData.col("N") > 1).expr)) - ), - upperCaseData.join(lowerCaseData, $"n" === $"N" && $"N" > 1, "left").collect()) - - checkAnswer2( - upperCaseData, - lowerCaseData, - wrapForUnsafe( - (node1, node2) => NestedLoopJoinNode( - conf, - node1, - node2, - buildSide, - LeftOuter, - Some( - (upperCaseData.col("N") === lowerCaseData.col("n") && - lowerCaseData.col("l") > upperCaseData.col("L")).expr)) - ), - upperCaseData.join(lowerCaseData, $"n" === $"N" && $"l" > $"L", 
"left").collect()) + // Test all combinations of the three dimensions: with/out unsafe, build sides, and join types + private val maybeUnsafeAndCodegen = Seq(false, true) + private val buildSides = Seq(BuildLeft, BuildRight) + private val joinTypes = Seq(LeftOuter, RightOuter, FullOuter) + maybeUnsafeAndCodegen.foreach { unsafeAndCodegen => + buildSides.foreach { buildSide => + joinTypes.foreach { joinType => + testJoin(unsafeAndCodegen, buildSide, joinType) } } + } - test(s"$suiteName: right outer join") { - withSQLConf(confPairs: _*) { - checkAnswer2( - lowerCaseData, - upperCaseData, - wrapForUnsafe( - (node1, node2) => NestedLoopJoinNode( - conf, - node1, - node2, - buildSide, - RightOuter, - Some((lowerCaseData.col("n") === upperCaseData.col("N")).expr)) - ), - lowerCaseData.join(upperCaseData, $"n" === $"N", "right").collect()) - - checkAnswer2( - lowerCaseData, - upperCaseData, - wrapForUnsafe( - (node1, node2) => NestedLoopJoinNode( - conf, - node1, - node2, - buildSide, - RightOuter, - Some((lowerCaseData.col("n") === upperCaseData.col("N") && - lowerCaseData.col("n") > 1).expr)) - ), - lowerCaseData.join(upperCaseData, $"n" === $"N" && $"n" > 1, "right").collect()) - - checkAnswer2( - lowerCaseData, - upperCaseData, - wrapForUnsafe( - (node1, node2) => NestedLoopJoinNode( - conf, - node1, - node2, - buildSide, - RightOuter, - Some((lowerCaseData.col("n") === upperCaseData.col("N") && - upperCaseData.col("N") > 1).expr)) - ), - lowerCaseData.join(upperCaseData, $"n" === $"N" && $"N" > 1, "right").collect()) - - checkAnswer2( - lowerCaseData, - upperCaseData, - wrapForUnsafe( - (node1, node2) => NestedLoopJoinNode( - conf, - node1, - node2, - buildSide, - RightOuter, - Some((lowerCaseData.col("n") === upperCaseData.col("N") && - lowerCaseData.col("l") > upperCaseData.col("L")).expr)) - ), - lowerCaseData.join(upperCaseData, $"n" === $"N" && $"l" > $"L", "right").collect()) + /** + * Test outer nested loop joins with varying degrees of matches. 
+ */ + private def testJoin( + unsafeAndCodegen: Boolean, + buildSide: BuildSide, + joinType: JoinType): Unit = { + val simpleOrUnsafe = if (!unsafeAndCodegen) "simple" else "unsafe" + val testNamePrefix = s"$simpleOrUnsafe / $buildSide / $joinType" + val someData = (1 to 100).map { i => (i, "burger" + i) }.toArray + val conf = new SQLConf + conf.setConf(SQLConf.UNSAFE_ENABLED, unsafeAndCodegen) + conf.setConf(SQLConf.CODEGEN_ENABLED, unsafeAndCodegen) + + // Actual test body + def runTest( + joinType: JoinType, + leftInput: Array[(Int, String)], + rightInput: Array[(Int, String)]): Unit = { + val leftNode = new DummyNode(joinNameAttributes, leftInput) + val rightNode = new DummyNode(joinNicknameAttributes, rightInput) + val cond = 'id1 === 'id2 + val makeNode = (node1: LocalNode, node2: LocalNode) => { + resolveExpressions( + new NestedLoopJoinNode(conf, node1, node2, buildSide, joinType, Some(cond))) } + val makeUnsafeNode = if (unsafeAndCodegen) wrapForUnsafe(makeNode) else makeNode + val hashJoinNode = makeUnsafeNode(leftNode, rightNode) + val expectedOutput = generateExpectedOutput(leftInput, rightInput, joinType) + val actualOutput = hashJoinNode.collect().map { row => + // (id, name, id, nickname) + (row.getInt(0), row.getString(1), row.getInt(2), row.getString(3)) + } + assert(actualOutput.toSet === expectedOutput.toSet) } - test(s"$suiteName: full outer join") { - withSQLConf(confPairs: _*) { - checkAnswer2( - lowerCaseData, - upperCaseData, - wrapForUnsafe( - (node1, node2) => NestedLoopJoinNode( - conf, - node1, - node2, - buildSide, - FullOuter, - Some((lowerCaseData.col("n") === upperCaseData.col("N")).expr)) - ), - lowerCaseData.join(upperCaseData, $"n" === $"N", "full").collect()) - - checkAnswer2( - lowerCaseData, - upperCaseData, - wrapForUnsafe( - (node1, node2) => NestedLoopJoinNode( - conf, - node1, - node2, - buildSide, - FullOuter, - Some((lowerCaseData.col("n") === upperCaseData.col("N") && - lowerCaseData.col("n") > 1).expr)) - ), - lowerCaseData.join(upperCaseData, $"n" === $"N" && $"n" > 1, "full").collect()) - - checkAnswer2( - lowerCaseData, - upperCaseData, - wrapForUnsafe( - (node1, node2) => NestedLoopJoinNode( - conf, - node1, - node2, - buildSide, - FullOuter, - Some((lowerCaseData.col("n") === upperCaseData.col("N") && - upperCaseData.col("N") > 1).expr)) - ), - lowerCaseData.join(upperCaseData, $"n" === $"N" && $"N" > 1, "full").collect()) - - checkAnswer2( - lowerCaseData, - upperCaseData, - wrapForUnsafe( - (node1, node2) => NestedLoopJoinNode( - conf, - node1, - node2, - buildSide, - FullOuter, - Some((lowerCaseData.col("n") === upperCaseData.col("N") && - lowerCaseData.col("l") > upperCaseData.col("L")).expr)) - ), - lowerCaseData.join(upperCaseData, $"n" === $"N" && $"l" > $"L", "full").collect()) - } + test(s"$testNamePrefix: empty") { + runTest(joinType, Array.empty, Array.empty) + } + + test(s"$testNamePrefix: no matches") { + val someIrrelevantData = (10000 to 10100).map { i => (i, "piper" + i) }.toArray + runTest(joinType, someData, Array.empty) + runTest(joinType, Array.empty, someData) + runTest(joinType, someData, someIrrelevantData) + runTest(joinType, someIrrelevantData, someData) + } + + test(s"$testNamePrefix: partial matches") { + val someOtherData = (50 to 150).map { i => (i, "finnegan" + i) }.toArray + runTest(joinType, someData, someOtherData) + runTest(joinType, someOtherData, someData) + } + + test(s"$testNamePrefix: full matches") { + val someSuperRelevantData = someData.map { case (k, v) => (k, "cooper" + v) } + runTest(joinType, 
someData, someSuperRelevantData) + runTest(joinType, someSuperRelevantData, someData) + } + } + + /** + * Helper method to generate the expected output of a test based on the join type. + */ + private def generateExpectedOutput( + leftInput: Array[(Int, String)], + rightInput: Array[(Int, String)], + joinType: JoinType): Array[(Int, String, Int, String)] = { + joinType match { + case LeftOuter => + val rightInputMap = rightInput.toMap + leftInput.map { case (k, v) => + val rightKey = rightInputMap.get(k).map { _ => k }.getOrElse(0) + val rightValue = rightInputMap.getOrElse(k, null) + (k, v, rightKey, rightValue) + } + + case RightOuter => + val leftInputMap = leftInput.toMap + rightInput.map { case (k, v) => + val leftKey = leftInputMap.get(k).map { _ => k }.getOrElse(0) + val leftValue = leftInputMap.getOrElse(k, null) + (leftKey, leftValue, k, v) + } + + case FullOuter => + val leftInputMap = leftInput.toMap + val rightInputMap = rightInput.toMap + val leftOutput = leftInput.map { case (k, v) => + val rightKey = rightInputMap.get(k).map { _ => k }.getOrElse(0) + val rightValue = rightInputMap.getOrElse(k, null) + (k, v, rightKey, rightValue) + } + val rightOutput = rightInput.map { case (k, v) => + val leftKey = leftInputMap.get(k).map { _ => k }.getOrElse(0) + val leftValue = leftInputMap.getOrElse(k, null) + (leftKey, leftValue, k, v) + } + (leftOutput ++ rightOutput).distinct + + case other => + throw new IllegalArgumentException(s"Join type $other is not applicable") } } - joinSuite( - "general-build-left", - BuildLeft, - SQLConf.CODEGEN_ENABLED.key -> "false", SQLConf.UNSAFE_ENABLED.key -> "false") - joinSuite( - "general-build-right", - BuildRight, - SQLConf.CODEGEN_ENABLED.key -> "false", SQLConf.UNSAFE_ENABLED.key -> "false") - joinSuite( - "tungsten-build-left", - BuildLeft, - SQLConf.CODEGEN_ENABLED.key -> "true", SQLConf.UNSAFE_ENABLED.key -> "true") - joinSuite( - "tungsten-build-right", - BuildRight, - SQLConf.CODEGEN_ENABLED.key -> "true", SQLConf.UNSAFE_ENABLED.key -> "true") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ProjectNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ProjectNodeSuite.scala index 38e0a230c46d..02ecb23d34b2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ProjectNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ProjectNodeSuite.scala @@ -17,28 +17,33 @@ package org.apache.spark.sql.execution.local -import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, NamedExpression} +import org.apache.spark.sql.types.{IntegerType, StringType} -class ProjectNodeSuite extends LocalNodeTest with SharedSQLContext { - test("basic") { - val output = testData.queryExecution.sparkPlan.output - val columns = Seq(output(1), output(0)) - checkAnswer( - testData, - node => ProjectNode(conf, columns, node), - testData.select("value", "key").collect() - ) +class ProjectNodeSuite extends LocalNodeTest { + private val pieAttributes = Seq( + AttributeReference("id", IntegerType)(), + AttributeReference("age", IntegerType)(), + AttributeReference("name", StringType)()) + + private def testProject(inputData: Array[(Int, Int, String)] = Array.empty): Unit = { + val inputNode = new DummyNode(pieAttributes, inputData) + val columns = Seq[NamedExpression](inputNode.output(0), inputNode.output(2)) + val projectNode = new ProjectNode(conf, columns, inputNode) + val expectedOutput = inputData.map { case 
(id, age, name) => (id, name) } + val actualOutput = projectNode.collect().map { case row => + (row.getInt(0), row.getString(1)) + } + assert(actualOutput === expectedOutput) } test("empty") { - val output = emptyTestData.queryExecution.sparkPlan.output - val columns = Seq(output(1), output(0)) - checkAnswer( - emptyTestData, - node => ProjectNode(conf, columns, node), - emptyTestData.select("value", "key").collect() - ) + testProject() + } + + test("basic") { + testProject((1 to 100).map { i => (i, i + 1, "pie" + i) }.toArray) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/SampleNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/SampleNodeSuite.scala index 87a7da453999..a3e83bbd5145 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/SampleNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/SampleNodeSuite.scala @@ -17,21 +17,32 @@ package org.apache.spark.sql.execution.local -class SampleNodeSuite extends LocalNodeTest { +import org.apache.spark.util.random.{BernoulliCellSampler, PoissonSampler} + - import testImplicits._ +class SampleNodeSuite extends LocalNodeTest { private def testSample(withReplacement: Boolean): Unit = { - test(s"withReplacement: $withReplacement") { - val seed = 0L - val input = sqlContext.sparkContext. - parallelize((1 to 10).map(i => (i, i.toString)), 1). // Should be only 1 partition - toDF("key", "value") - checkAnswer( - input, - node => SampleNode(conf, 0.0, 0.3, withReplacement, seed, node), - input.sample(withReplacement, 0.3, seed).collect() - ) + val seed = 0L + val lowerb = 0.0 + val upperb = 0.3 + val maybeOut = if (withReplacement) "" else "out" + test(s"with$maybeOut replacement") { + val inputData = (1 to 1000).map { i => (i, i) }.toArray + val inputNode = new DummyNode(kvIntAttributes, inputData) + val sampleNode = new SampleNode(conf, lowerb, upperb, withReplacement, seed, inputNode) + val sampler = + if (withReplacement) { + new PoissonSampler[(Int, Int)](upperb - lowerb, useGapSamplingIfPossible = false) + } else { + new BernoulliCellSampler[(Int, Int)](lowerb, upperb) + } + sampler.setSeed(seed) + val expectedOutput = sampler.sample(inputData.iterator).toArray + val actualOutput = sampleNode.collect().map { case row => + (row.getInt(0), row.getInt(1)) + } + assert(actualOutput === expectedOutput) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNodeSuite.scala index ff28b24eeff1..42ebc7bfcaad 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNodeSuite.scala @@ -17,38 +17,34 @@ package org.apache.spark.sql.execution.local -import org.apache.spark.sql.Column -import org.apache.spark.sql.catalyst.expressions.{Ascending, Expression, SortOrder} +import scala.util.Random -class TakeOrderedAndProjectNodeSuite extends LocalNodeTest { +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.SortOrder - import testImplicits._ - private def columnToSortOrder(sortExprs: Column*): Seq[SortOrder] = { - val sortOrder: Seq[SortOrder] = sortExprs.map { col => - col.expr match { - case expr: SortOrder => - expr - case expr: Expression => - SortOrder(expr, Ascending) - } - } - sortOrder - } +class 
TakeOrderedAndProjectNodeSuite extends LocalNodeTest { - private def testTakeOrderedAndProjectNode(desc: Boolean): Unit = { - val testCaseName = if (desc) "desc" else "asc" - test(testCaseName) { - val input = (1 to 10).map(i => (i, i.toString)).toDF("key", "value") - val sortColumn = if (desc) input.col("key").desc else input.col("key") - checkAnswer( - input, - node => TakeOrderedAndProjectNode(conf, 5, columnToSortOrder(sortColumn), None, node), - input.sort(sortColumn).limit(5).collect() - ) + private def testTakeOrderedAndProject(desc: Boolean): Unit = { + val limit = 10 + val ascOrDesc = if (desc) "desc" else "asc" + test(ascOrDesc) { + val inputData = Random.shuffle((1 to 100).toList).map { i => (i, i) }.toArray + val inputNode = new DummyNode(kvIntAttributes, inputData) + val firstColumn = inputNode.output(0) + val sortDirection = if (desc) Descending else Ascending + val sortOrder = SortOrder(firstColumn, sortDirection) + val takeOrderAndProjectNode = new TakeOrderedAndProjectNode( + conf, limit, Seq(sortOrder), Some(Seq(firstColumn)), inputNode) + val expectedOutput = inputData + .map { case (k, _) => k } + .sortBy { k => k * (if (desc) -1 else 1) } + .take(limit) + val actualOutput = takeOrderAndProjectNode.collect().map { row => row.getInt(0) } + assert(actualOutput === expectedOutput) } } - testTakeOrderedAndProjectNode(desc = false) - testTakeOrderedAndProjectNode(desc = true) + testTakeOrderedAndProject(desc = false) + testTakeOrderedAndProject(desc = true) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/UnionNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/UnionNodeSuite.scala index eedd7320900f..666b0235c061 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/UnionNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/UnionNodeSuite.scala @@ -17,36 +17,39 @@ package org.apache.spark.sql.execution.local -import org.apache.spark.sql.test.SharedSQLContext -class UnionNodeSuite extends LocalNodeTest with SharedSQLContext { +class UnionNodeSuite extends LocalNodeTest { - test("basic") { - checkAnswer2( - testData, - testData, - (node1, node2) => UnionNode(conf, Seq(node1, node2)), - testData.unionAll(testData).collect() - ) + private def testUnion(inputData: Seq[Array[(Int, Int)]]): Unit = { + val inputNodes = inputData.map { data => + new DummyNode(kvIntAttributes, data) + } + val unionNode = new UnionNode(conf, inputNodes) + val expectedOutput = inputData.flatten + val actualOutput = unionNode.collect().map { case row => + (row.getInt(0), row.getInt(1)) + } + assert(actualOutput === expectedOutput) } test("empty") { - checkAnswer2( - emptyTestData, - emptyTestData, - (node1, node2) => UnionNode(conf, Seq(node1, node2)), - emptyTestData.unionAll(emptyTestData).collect() - ) + testUnion(Seq(Array.empty)) + testUnion(Seq(Array.empty, Array.empty)) + } + + test("self") { + val data = (1 to 100).map { i => (i, i) }.toArray + testUnion(Seq(data)) + testUnion(Seq(data, data)) + testUnion(Seq(data, data, data)) } - test("complicated union") { - val dfs = Seq(testData, emptyTestData, emptyTestData, testData, testData, emptyTestData, - emptyTestData, emptyTestData, testData, emptyTestData) - doCheckAnswer( - dfs, - nodes => UnionNode(conf, nodes), - dfs.reduce(_.unionAll(_)).collect() - ) + test("basic") { + val zero = Array.empty[(Int, Int)] + val one = (1 to 100).map { i => (i, i) }.toArray + val two = (50 to 150).map { i => (i, i) }.toArray + val three = (800 to 900).map { i 
=> (i, i) }.toArray + testUnion(Seq(zero, one, two, three)) } } From 64c29afcb787d9f176a197c25314295108ba0471 Mon Sep 17 00:00:00 2001 From: sureshthalamati Date: Tue, 15 Sep 2015 19:41:38 -0700 Subject: [PATCH 0644/1824] [SPARK-9078] [SQL] Allow jdbc dialects to override the query used to check the table. The current implementation uses a query with a LIMIT clause to find out whether a table already exists, but that syntax works only in some database systems. This patch changes the default query to one that is likely to work on most databases, and adds a new method to the JdbcDialect abstract class so that dialects can override the default query. I looked at using the JDBC metadata calls, but it turns out there is no common way to find the current schema, catalog, etc. There is a new method, Connection.getSchema(), but it is available only starting with JDK 1.7, and existing JDBC drivers may not have implemented it. Another option was to use the JDBC escape syntax clause for LIMIT, but it is unclear how well that is supported across databases. After looking at all the JDBC metadata options, my conclusion was that the most common approach is a simple select query with 'WHERE 1=0', while allowing dialects to customize it as needed. Author: sureshthalamati Closes #8676 from sureshthalamati/table_exists_spark-9078. --- .../apache/spark/sql/DataFrameWriter.scala | 2 +- .../datasources/jdbc/JdbcUtils.scala | 9 ++++++--- .../apache/spark/sql/jdbc/JdbcDialects.scala | 20 +++++++++++++++++++ .../org/apache/spark/sql/jdbc/JDBCSuite.scala | 14 +++++++++++++ 4 files changed, 41 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index b2a66dd417b4..745bb4ec9cf1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -255,7 +255,7 @@ final class DataFrameWriter private[sql](df: DataFrame) { val conn = JdbcUtils.createConnection(url, props) try { - var tableExists = JdbcUtils.tableExists(conn, table) + var tableExists = JdbcUtils.tableExists(conn, url, table) if (mode == SaveMode.Ignore && tableExists) { return diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index 26788b2a4fd6..f89d55b20e21 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -42,10 +42,13 @@ object JdbcUtils extends Logging { /** * Returns true if the table already exists in the JDBC database. */ - def tableExists(conn: Connection, table: String): Boolean = { + def tableExists(conn: Connection, url: String, table: String): Boolean = { + val dialect = JdbcDialects.get(url) + // Somewhat hacky, but there isn't a good way to identify whether a table exists for all - // SQL database systems, considering "table" could also include the database name. - Try(conn.prepareStatement(s"SELECT 1 FROM $table LIMIT 1").executeQuery().next()).isSuccess + // SQL database systems using JDBC meta data calls, considering "table" could also include + // the database name. Query used to find table exists can be overridden by the dialects. 
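// Illustrative sketch only, not part of this patch: the line below runs whatever query the
// dialect returns, and the getTableExistsQuery hook added in the JdbcDialects.scala hunk
// further down lets a custom dialect supply its own existence check. The dialect object,
// the "jdbc:mylegacydb" URL prefix and the FETCH FIRST query are assumptions made for the
// example; canHandle and JdbcDialects.registerDialect are the existing extension points.
import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects}

object MyLegacyDialect extends JdbcDialect {
  override def canHandle(url: String): Boolean = url.startsWith("jdbc:mylegacydb")
  // Databases that lack LIMIT often accept the SQL:2008 FETCH FIRST form instead.
  override def getTableExistsQuery(table: String): String =
    s"SELECT 1 FROM $table FETCH FIRST 1 ROWS ONLY"
}

object MyLegacyDialectSetup {
  def main(args: Array[String]): Unit = {
    // Once registered, JdbcDialects.get(url) resolves to this dialect, so the
    // tableExists check in this hunk issues the FETCH FIRST query for matching URLs.
    JdbcDialects.registerDialect(MyLegacyDialect)
  }
}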
+ Try(conn.prepareStatement(dialect.getTableExistsQuery(table)).executeQuery()).isSuccess } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index c6d05c9b83b9..68ebaaca6c53 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -88,6 +88,17 @@ abstract class JdbcDialect { def quoteIdentifier(colName: String): String = { s""""$colName"""" } + + /** + * Get the SQL query that should be used to find if the given table exists. Dialects can + * override this method to return a query that works best in a particular database. + * @param table The name of the table. + * @return The SQL query to use for checking the table. + */ + def getTableExistsQuery(table: String): String = { + s"SELECT * FROM $table WHERE 1=0" + } + } /** @@ -198,6 +209,11 @@ case object PostgresDialect extends JdbcDialect { case BooleanType => Some(JdbcType("BOOLEAN", java.sql.Types.BOOLEAN)) case _ => None } + + override def getTableExistsQuery(table: String): String = { + s"SELECT 1 FROM $table LIMIT 1" + } + } /** @@ -222,6 +238,10 @@ case object MySQLDialect extends JdbcDialect { override def quoteIdentifier(colName: String): String = { s"`$colName`" } + + override def getTableExistsQuery(table: String): String = { + s"SELECT 1 FROM $table LIMIT 1" + } } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index ed710689cc67..5ab9381de4d6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -450,4 +450,18 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext assert(db2Dialect.getJDBCType(StringType).map(_.databaseTypeDefinition).get == "CLOB") assert(db2Dialect.getJDBCType(BooleanType).map(_.databaseTypeDefinition).get == "CHAR(1)") } + + test("table exists query by jdbc dialect") { + val MySQL = JdbcDialects.get("jdbc:mysql://127.0.0.1/db") + val Postgres = JdbcDialects.get("jdbc:postgresql://127.0.0.1/db") + val db2 = JdbcDialects.get("jdbc:db2://127.0.0.1/db") + val h2 = JdbcDialects.get(url) + val table = "weblogs" + val defaultQuery = s"SELECT * FROM $table WHERE 1=0" + val limitQuery = s"SELECT 1 FROM $table LIMIT 1" + assert(MySQL.getTableExistsQuery(table) == limitQuery) + assert(Postgres.getTableExistsQuery(table) == limitQuery) + assert(db2.getTableExistsQuery(table) == defaultQuery) + assert(h2.getTableExistsQuery(table) == defaultQuery) + } } From b921fe4dc0442aa133ab7d55fba24bc798d59aa2 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Tue, 15 Sep 2015 19:43:26 -0700 Subject: [PATCH 0645/1824] [SPARK-10595] [ML] [MLLIB] [DOCS] Various ML guide cleanups MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Various ML guide cleanups. * ml-guide.md: Make it easier to access the algorithm-specific guides. * LDA user guide: EM often begins with useless topics, but running longer generally improves them dramatically. E.g., 10 iterations on a Wikipedia dataset produces useless topics, but 50 iterations produces very meaningful topics. * mllib-feature-extraction.html#elementwiseproduct: “w” parameter should be “scalingVec” * Clean up Binarizer user guide a little. 
* Document in Pipeline that users should not put an instance into the Pipeline in more than 1 place. * spark.ml Word2Vec user guide: clean up grammar/writing * Chi Sq Feature Selector docs: Improve text in doc. CC: mengxr feynmanliang Author: Joseph K. Bradley Closes #8752 from jkbradley/mlguide-fixes-1.5. --- docs/ml-features.md | 34 +++++++++++++++++--- docs/ml-guide.md | 31 ++++++++++++------- docs/mllib-clustering.md | 4 +++ docs/mllib-feature-extraction.md | 53 +++++++++++++++++++++----------- docs/mllib-guide.md | 4 +-- 5 files changed, 91 insertions(+), 35 deletions(-) diff --git a/docs/ml-features.md b/docs/ml-features.md index a414c21b5c28..b70da4ac6384 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -123,12 +123,21 @@ for features_label in rescaledData.select("features", "label").take(3): ## Word2Vec -`Word2Vec` is an `Estimator` which takes sequences of words that represents documents and trains a `Word2VecModel`. The model is a `Map(String, Vector)` essentially, which maps each word to an unique fix-sized vector. The `Word2VecModel` transforms each documents into a vector using the average of all words in the document, which aims to other computations of documents such as similarity calculation consequencely. Please refer to the [MLlib user guide on Word2Vec](mllib-feature-extraction.html#Word2Vec) for more details on Word2Vec. +`Word2Vec` is an `Estimator` which takes sequences of words representing documents and trains a +`Word2VecModel`. The model maps each word to a unique fixed-size vector. The `Word2VecModel` +transforms each document into a vector using the average of all words in the document; this vector +can then be used for as features for prediction, document similarity calculations, etc. +Please refer to the [MLlib user guide on Word2Vec](mllib-feature-extraction.html#Word2Vec) for more +details. -Word2Vec is implemented in [Word2Vec](api/scala/index.html#org.apache.spark.ml.feature.Word2Vec). In the following code segment, we start with a set of documents, each of them is represented as a sequence of words. For each document, we transform it into a feature vector. This feature vector could then be passed to a learning algorithm. +In the following code segment, we start with a set of documents, each of which is represented as a sequence of words. For each document, we transform it into a feature vector. This feature vector could then be passed to a learning algorithm.
    + +Refer to the [Word2Vec Scala docs](api/scala/index.html#org.apache.spark.ml.feature.Word2Vec) +for more details on the API. + {% highlight scala %} import org.apache.spark.ml.feature.Word2Vec @@ -152,6 +161,10 @@ result.select("result").take(3).foreach(println)
    + +Refer to the [Word2Vec Java docs](api/java/org/apache/spark/ml/feature/Word2Vec.html) +for more details on the API. + {% highlight java %} import java.util.Arrays; @@ -192,6 +205,10 @@ for (Row r: result.select("result").take(3)) {
    + +Refer to the [Word2Vec Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.Word2Vec) +for more details on the API. + {% highlight python %} from pyspark.ml.feature import Word2Vec @@ -621,12 +638,15 @@ for ngrams_label in ngramDataFrame.select("ngrams", "label").take(3): ## Binarizer -Binarization is the process of thresholding numerical features to binary features. As some probabilistic estimators make assumption that the input data is distributed according to [Bernoulli distribution](http://en.wikipedia.org/wiki/Bernoulli_distribution), a binarizer is useful for pre-processing the input data with continuous numerical features. +Binarization is the process of thresholding numerical features to binary (0/1) features. -A simple [Binarizer](api/scala/index.html#org.apache.spark.ml.feature.Binarizer) class provides this functionality. Besides the common parameters of `inputCol` and `outputCol`, `Binarizer` has the parameter `threshold` used for binarizing continuous numerical features. The features greater than the threshold, will be binarized to 1.0. The features equal to or less than the threshold, will be binarized to 0.0. The example below shows how to binarize numerical features. +`Binarizer` takes the common parameters `inputCol` and `outputCol`, as well as the `threshold` for binarization. Feature values greater than the threshold are binarized to 1.0; values equal to or less than the threshold are binarized to 0.0.
    + +Refer to the [Binarizer API doc](api/scala/index.html#org.apache.spark.ml.feature.Binarizer) for more details. + {% highlight scala %} import org.apache.spark.ml.feature.Binarizer import org.apache.spark.sql.DataFrame @@ -650,6 +670,9 @@ binarizedFeatures.collect().foreach(println)
    + +Refer to the [Binarizer API doc](api/java/org/apache/spark/ml/feature/Binarizer.html) for more details. + {% highlight java %} import java.util.Arrays; @@ -687,6 +710,9 @@ for (Row r : binarizedFeatures.collect()) {
    + +Refer to the [Binarizer API doc](api/python/pyspark.ml.html#pyspark.ml.feature.Binarizer) for more details. + {% highlight python %} from pyspark.ml.feature import Binarizer diff --git a/docs/ml-guide.md b/docs/ml-guide.md index 78c93a95c780..c5d7f990021f 100644 --- a/docs/ml-guide.md +++ b/docs/ml-guide.md @@ -32,7 +32,21 @@ See the [algorithm guides](#algorithm-guides) section below for guides on sub-pa * This will become a table of contents (this text will be scraped). {:toc} -# Main concepts +# Algorithm guides + +We provide several algorithm guides specific to the Pipelines API. +Several of these algorithms, such as certain feature transformers, are not in the `spark.mllib` API. +Also, some algorithms have additional capabilities in the `spark.ml` API; e.g., random forests +provide class probabilities, and linear models provide model summaries. + +* [Feature extraction, transformation, and selection](ml-features.html) +* [Decision Trees for classification and regression](ml-decision-tree.html) +* [Ensembles](ml-ensembles.html) +* [Linear methods with elastic net regularization](ml-linear-methods.html) +* [Multilayer perceptron classifier](ml-ann.html) + + +# Main concepts in Pipelines Spark ML standardizes APIs for machine learning algorithms to make it easier to combine multiple algorithms into a single pipeline, or workflow. @@ -166,6 +180,11 @@ compile-time type checking. `Pipeline`s and `PipelineModel`s instead do runtime checking before actually running the `Pipeline`. This type checking is done using the `DataFrame` *schema*, a description of the data types of columns in the `DataFrame`. +*Unique Pipeline stages*: A `Pipeline`'s stages should be unique instances. E.g., the same instance +`myHashingTF` should not be inserted into the `Pipeline` twice since `Pipeline` stages must have +unique IDs. However, different instances `myHashingTF1` and `myHashingTF2` (both of type `HashingTF`) +can be put into the same `Pipeline` since different instances will be created with different IDs. + ## Parameters Spark ML `Estimator`s and `Transformer`s use a uniform API for specifying parameters. @@ -184,16 +203,6 @@ Parameters belong to specific instances of `Estimator`s and `Transformer`s. For example, if we have two `LogisticRegression` instances `lr1` and `lr2`, then we can build a `ParamMap` with both `maxIter` parameters specified: `ParamMap(lr1.maxIter -> 10, lr2.maxIter -> 20)`. This is useful if there are two algorithms with the `maxIter` parameter in a `Pipeline`. -# Algorithm guides - -There are now several algorithms in the Pipelines API which are not in the `spark.mllib` API, so we link to documentation for them here. These algorithms are mostly feature transformers, which fit naturally into the `Transformer` abstraction in Pipelines, and ensembles, which fit naturally into the `Estimator` abstraction in the Pipelines. - -* [Feature extraction, transformation, and selection](ml-features.html) -* [Decision Trees for classification and regression](ml-decision-tree.html) -* [Ensembles](ml-ensembles.html) -* [Linear methods with elastic net regularization](ml-linear-methods.html) -* [Multilayer perceptron classifier](ml-ann.html) - # Code examples This section gives code examples illustrating the functionality discussed above. diff --git a/docs/mllib-clustering.md b/docs/mllib-clustering.md index 3fb35d3c50b0..c2711cf82deb 100644 --- a/docs/mllib-clustering.md +++ b/docs/mllib-clustering.md @@ -507,6 +507,10 @@ must also be $> 1.0$. 
Providing `Vector(-1)` results in default behavior $> 1.0$. Providing `-1` results in defaulting to a value of $0.1 + 1$. * `maxIterations`: The maximum number of EM iterations. +*Note*: It is important to do enough iterations. In early iterations, EM often has useless topics, +but those topics improve dramatically after more iterations. Using at least 20 and possibly +50-100 iterations is often reasonable, depending on your dataset. + `EMLDAOptimizer` produces a `DistributedLDAModel`, which stores not only the inferred topics but also the full training corpus and topic distributions for each document in the training corpus. A diff --git a/docs/mllib-feature-extraction.md b/docs/mllib-feature-extraction.md index de86aba2ae62..7e417ed5f37a 100644 --- a/docs/mllib-feature-extraction.md +++ b/docs/mllib-feature-extraction.md @@ -380,35 +380,43 @@ data2 = labels.zip(normalizer2.transform(features))
    -## Feature selection -[Feature selection](http://en.wikipedia.org/wiki/Feature_selection) allows selecting the most relevant features for use in model construction. Feature selection reduces the size of the vector space and, in turn, the complexity of any subsequent operation with vectors. The number of features to select can be tuned using a held-out validation set. +## ChiSqSelector -### ChiSqSelector -[`ChiSqSelector`](api/scala/index.html#org.apache.spark.mllib.feature.ChiSqSelector) stands for Chi-Squared feature selection. It operates on labeled data with categorical features. `ChiSqSelector` orders features based on a Chi-Squared test of independence from the class, and then filters (selects) the top features which the class label depends on the most. This is akin to yielding the features with the most predictive power. +[Feature selection](http://en.wikipedia.org/wiki/Feature_selection) tries to identify relevant +features for use in model construction. It reduces the size of the feature space, which can improve +both speed and statistical learning behavior. -#### Model Fitting +[`ChiSqSelector`](api/scala/index.html#org.apache.spark.mllib.feature.ChiSqSelector) implements +Chi-Squared feature selection. It operates on labeled data with categorical features. +`ChiSqSelector` orders features based on a Chi-Squared test of independence from the class, +and then filters (selects) the top features which the class label depends on the most. +This is akin to yielding the features with the most predictive power. -[`ChiSqSelector`](api/scala/index.html#org.apache.spark.mllib.feature.ChiSqSelector) has the -following parameters in the constructor: +The number of features to select can be tuned using a held-out validation set. -* `numTopFeatures` number of top features that the selector will select (filter). +### Model Fitting -We provide a [`fit`](api/scala/index.html#org.apache.spark.mllib.feature.ChiSqSelector) method in -`ChiSqSelector` which can take an input of `RDD[LabeledPoint]` with categorical features, learn the summary statistics, and then -return a `ChiSqSelectorModel` which can transform an input dataset into the reduced feature space. +`ChiSqSelector` takes a `numTopFeatures` parameter specifying the number of top features that +the selector will select. -This model implements [`VectorTransformer`](api/scala/index.html#org.apache.spark.mllib.feature.VectorTransformer) -which can apply the Chi-Squared feature selection on a `Vector` to produce a reduced `Vector` or on +The [`fit`](api/scala/index.html#org.apache.spark.mllib.feature.ChiSqSelector) method takes +an input of `RDD[LabeledPoint]` with categorical features, learns the summary statistics, and then +returns a `ChiSqSelectorModel` which can transform an input dataset into the reduced feature space. +The `ChiSqSelectorModel` can be applied either to a `Vector` to produce a reduced `Vector`, or to an `RDD[Vector]` to produce a reduced `RDD[Vector]`. Note that the user can also construct a `ChiSqSelectorModel` by hand by providing an array of selected feature indices (which must be sorted in ascending order). -#### Example +### Example The following example shows the basic use of ChiSqSelector. The data set used has a feature matrix consisting of greyscale values that vary from 0 to 255 for each feature.
    + +Refer to the [`ChiSqSelector` Scala docs](api/scala/index.html#org.apache.spark.mllib.feature.ChiSqSelector) +for details on the API. + {% highlight scala %} import org.apache.spark.SparkContext._ import org.apache.spark.mllib.linalg.Vectors @@ -434,7 +442,11 @@ val filteredData = discretizedData.map { lp => {% endhighlight %}
    + +Refer to the [`ChiSqSelector` Java docs](api/java/org/apache/spark/mllib/feature/ChiSqSelector.html) +for details on the API. + {% highlight java %} import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaRDD; @@ -486,7 +498,12 @@ sc.stop(); ## ElementwiseProduct -ElementwiseProduct multiplies each input vector by a provided "weight" vector, using element-wise multiplication. In other words, it scales each column of the dataset by a scalar multiplier. This represents the [Hadamard product](https://en.wikipedia.org/wiki/Hadamard_product_%28matrices%29) between the input vector, `v` and transforming vector, `w`, to yield a result vector. +`ElementwiseProduct` multiplies each input vector by a provided "weight" vector, using element-wise +multiplication. In other words, it scales each column of the dataset by a scalar multiplier. This +represents the [Hadamard product](https://en.wikipedia.org/wiki/Hadamard_product_%28matrices%29) +between the input vector, `v` and transforming vector, `scalingVec`, to yield a result vector. +Qu8T948*1# +Denoting the `scalingVec` as "`w`," this transformation may be written as: `\[ \begin{pmatrix} v_1 \\ @@ -506,7 +523,7 @@ v_N [`ElementwiseProduct`](api/scala/index.html#org.apache.spark.mllib.feature.ElementwiseProduct) has the following parameter in the constructor: -* `w`: the transforming vector. +* `scalingVec`: the transforming vector. `ElementwiseProduct` implements [`VectorTransformer`](api/scala/index.html#org.apache.spark.mllib.feature.VectorTransformer) which can apply the weighting on a `Vector` to produce a transformed `Vector` or on an `RDD[Vector]` to produce a transformed `RDD[Vector]`. diff --git a/docs/mllib-guide.md b/docs/mllib-guide.md index 257f7cc7603f..91e50ccfecec 100644 --- a/docs/mllib-guide.md +++ b/docs/mllib-guide.md @@ -13,9 +13,9 @@ primitives and higher-level pipeline APIs. It divides into two packages: -* [`spark.mllib`](mllib-guide.html#mllib-types-algorithms-and-utilities) contains the original API +* [`spark.mllib`](mllib-guide.html#data-types-algorithms-and-utilities) contains the original API built on top of [RDDs](programming-guide.html#resilient-distributed-datasets-rdds). -* [`spark.ml`](mllib-guide.html#sparkml-high-level-apis-for-ml-pipelines) provides higher-level API +* [`spark.ml`](ml-guide.html) provides higher-level API built on top of [DataFrames](sql-programming-guide.html#dataframes) for constructing ML pipelines. Using `spark.ml` is recommended because with DataFrames the API is more versatile and flexible. From 95b6a8103fb527f501ca26b1d6e3a5859970a1e2 Mon Sep 17 00:00:00 2001 From: Vinod K C Date: Tue, 15 Sep 2015 23:25:51 -0700 Subject: [PATCH 0646/1824] [SPARK-10516] [ MLLIB] Added values property in DenseVector Author: Vinod K C Closes #8682 from vinodkc/fix_SPARK-10516. 
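For context, a minimal usage sketch of the new `values` property (assuming a build with this patch applied). It exposes the same NumPy array that `toArray()` returns, mirroring the existing `SparseVector.values`:

{% highlight python %}
# Usage sketch for the new DenseVector.values property (assumes this patch is applied).
from pyspark.mllib.linalg import DenseVector

dv = DenseVector([1.0, 2.0, 3.0])

# `values` exposes the backing NumPy array, exactly what toArray() returns,
# so DenseVector and SparseVector can be handled uniformly.
assert (dv.values == dv.toArray()).all()
{% endhighlight %}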
--- python/pyspark/mllib/linalg/__init__.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/python/pyspark/mllib/linalg/__init__.py b/python/pyspark/mllib/linalg/__init__.py index 380f86e9b44f..4829acb16ed8 100644 --- a/python/pyspark/mllib/linalg/__init__.py +++ b/python/pyspark/mllib/linalg/__init__.py @@ -399,6 +399,10 @@ def squared_distance(self, other): def toArray(self): return self.array + @property + def values(self): + return self.array + def __getitem__(self, item): return self.array[item] From 1894653edce718e874d1ddc9ba442bce43cbc082 Mon Sep 17 00:00:00 2001 From: Luciano Resende Date: Wed, 16 Sep 2015 10:47:30 +0100 Subject: [PATCH 0647/1824] [SPARK-10511] [BUILD] Reset git repository before packaging source distro The calculation of Spark version is downloading Scala and Zinc in the build directory which is inflating the size of the source distribution. Reseting the repo before packaging the source distribution fix this issue. Author: Luciano Resende Closes #8774 from lresende/spark-10511. --- dev/create-release/release-build.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/dev/create-release/release-build.sh b/dev/create-release/release-build.sh index d0b3a54dde1d..9dac43ce5442 100755 --- a/dev/create-release/release-build.sh +++ b/dev/create-release/release-build.sh @@ -99,6 +99,7 @@ fi DEST_DIR_NAME="spark-$SPARK_PACKAGE_VERSION" USER_HOST="$ASF_USERNAME@people.apache.org" +git clean -d -f -x rm .gitignore rm -rf .git cd .. From d9b7f3e4dbceb91ea4d1a1fed3ab847335f8588b Mon Sep 17 00:00:00 2001 From: Yu ISHIKAWA Date: Wed, 16 Sep 2015 04:34:14 -0700 Subject: [PATCH 0648/1824] [SPARK-10276] [MLLIB] [PYSPARK] Add @since annotation to pyspark.mllib.recommendation Author: Yu ISHIKAWA Closes #8677 from yu-iskw/SPARK-10276. --- python/pyspark/mllib/recommendation.py | 36 +++++++++++++++++++++++++- 1 file changed, 35 insertions(+), 1 deletion(-) diff --git a/python/pyspark/mllib/recommendation.py b/python/pyspark/mllib/recommendation.py index 506ca2151cce..95047b5b7b4b 100644 --- a/python/pyspark/mllib/recommendation.py +++ b/python/pyspark/mllib/recommendation.py @@ -18,7 +18,7 @@ import array from collections import namedtuple -from pyspark import SparkContext +from pyspark import SparkContext, since from pyspark.rdd import RDD from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc, inherit_doc from pyspark.mllib.util import JavaLoader, JavaSaveable @@ -36,6 +36,8 @@ class Rating(namedtuple("Rating", ["user", "product", "rating"])): (1, 2, 5.0) >>> (r[0], r[1], r[2]) (1, 2, 5.0) + + .. versionadded:: 1.2.0 """ def __reduce__(self): @@ -111,13 +113,17 @@ class MatrixFactorizationModel(JavaModelWrapper, JavaSaveable, JavaLoader): ... rmtree(path) ... except OSError: ... pass + + .. versionadded:: 0.9.0 """ + @since("0.9.0") def predict(self, user, product): """ Predicts rating for the given user and product. """ return self._java_model.predict(int(user), int(product)) + @since("0.9.0") def predictAll(self, user_product): """ Returns a list of predicted ratings for input user and product pairs. 
@@ -128,6 +134,7 @@ def predictAll(self, user_product): user_product = user_product.map(lambda u_p: (int(u_p[0]), int(u_p[1]))) return self.call("predict", user_product) + @since("1.2.0") def userFeatures(self): """ Returns a paired RDD, where the first element is the user and the @@ -135,6 +142,7 @@ def userFeatures(self): """ return self.call("getUserFeatures").mapValues(lambda v: array.array('d', v)) + @since("1.2.0") def productFeatures(self): """ Returns a paired RDD, where the first element is the product and the @@ -142,6 +150,7 @@ def productFeatures(self): """ return self.call("getProductFeatures").mapValues(lambda v: array.array('d', v)) + @since("1.4.0") def recommendUsers(self, product, num): """ Recommends the top "num" number of users for a given product and returns a list @@ -149,6 +158,7 @@ def recommendUsers(self, product, num): """ return list(self.call("recommendUsers", product, num)) + @since("1.4.0") def recommendProducts(self, user, num): """ Recommends the top "num" number of products for a given user and returns a list @@ -157,17 +167,25 @@ def recommendProducts(self, user, num): return list(self.call("recommendProducts", user, num)) @property + @since("1.4.0") def rank(self): + """Rank for the features in this model""" return self.call("rank") @classmethod + @since("1.3.1") def load(cls, sc, path): + """Load a model from the given path""" model = cls._load_java(sc, path) wrapper = sc._jvm.MatrixFactorizationModelWrapper(model) return MatrixFactorizationModel(wrapper) class ALS(object): + """Alternating Least Squares matrix factorization + + .. versionadded:: 0.9.0 + """ @classmethod def _prepare(cls, ratings): @@ -188,15 +206,31 @@ def _prepare(cls, ratings): return ratings @classmethod + @since("0.9.0") def train(cls, ratings, rank, iterations=5, lambda_=0.01, blocks=-1, nonnegative=False, seed=None): + """ + Train a matrix factorization model given an RDD of ratings given by users to some products, + in the form of (userID, productID, rating) pairs. We approximate the ratings matrix as the + product of two lower-rank matrices of a given rank (number of features). To solve for these + features, we run a given number of iterations of ALS. This is done using a level of + parallelism given by `blocks`. + """ model = callMLlibFunc("trainALSModel", cls._prepare(ratings), rank, iterations, lambda_, blocks, nonnegative, seed) return MatrixFactorizationModel(model) @classmethod + @since("0.9.0") def trainImplicit(cls, ratings, rank, iterations=5, lambda_=0.01, blocks=-1, alpha=0.01, nonnegative=False, seed=None): + """ + Train a matrix factorization model given an RDD of 'implicit preferences' given by users + to some products, in the form of (userID, productID, preference) pairs. We approximate the + ratings matrix as the product of two lower-rank matrices of a given rank (number of + features). To solve for these features, we run a given number of iterations of ALS. + This is done using a level of parallelism given by `blocks`. + """ model = callMLlibFunc("trainImplicitALSModel", cls._prepare(ratings), rank, iterations, lambda_, blocks, alpha, nonnegative, seed) return MatrixFactorizationModel(model) From 5dbaf3d3911bbfa003bc75459aaad66b4f6e0c67 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Wed, 16 Sep 2015 19:19:23 +0100 Subject: [PATCH 0649/1824] [SPARK-10589] [WEBUI] Add defense against external site framing Set `X-Frame-Options: SAMEORIGIN` to protect against frame-related vulnerability Author: Sean Owen Closes #8745 from srowen/SPARK-10589. 
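The defense is opt-out per origin through a new configuration key introduced in the diff below. A hedged sketch of how a deployment might allow one trusted site to frame the UI, assuming this patch is applied (the app name and local master here are placeholders):

{% highlight python %}
# Sketch only: opting a single trusted origin back in after this change.
# Without the setting, every UI response carries "X-Frame-Options: SAMEORIGIN".
from pyspark import SparkConf, SparkContext

conf = (SparkConf()
        .setMaster("local[2]")          # placeholder master for illustration
        .setAppName("framing-demo")     # placeholder app name
        .set("spark.ui.allowFramingFrom", "https://example.com/"))
sc = SparkContext(conf=conf)
{% endhighlight %}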
--- .../spark/deploy/worker/ui/WorkerWebUI.scala | 7 ++++--- .../org/apache/spark/metrics/MetricsSystem.scala | 2 +- .../spark/metrics/sink/MetricsServlet.scala | 6 +++--- .../scala/org/apache/spark/ui/JettyUtils.scala | 16 ++++++++++++++-- .../main/scala/org/apache/spark/ui/WebUI.scala | 4 ++-- 5 files changed, 24 insertions(+), 11 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala index 709a27233598..1a0598e50dcf 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala @@ -20,9 +20,8 @@ package org.apache.spark.deploy.worker.ui import java.io.File import javax.servlet.http.HttpServletRequest -import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.Logging import org.apache.spark.deploy.worker.Worker -import org.apache.spark.deploy.worker.ui.WorkerWebUI._ import org.apache.spark.ui.{SparkUI, WebUI} import org.apache.spark.ui.JettyUtils._ import org.apache.spark.util.RpcUtils @@ -49,7 +48,9 @@ class WorkerWebUI( attachPage(new WorkerPage(this)) attachHandler(createStaticHandler(WorkerWebUI.STATIC_RESOURCE_BASE, "/static")) attachHandler(createServletHandler("/log", - (request: HttpServletRequest) => logPage.renderLog(request), worker.securityMgr)) + (request: HttpServletRequest) => logPage.renderLog(request), + worker.securityMgr, + worker.conf)) } } diff --git a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala index 4517f465ebd3..48afe3ae3511 100644 --- a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala +++ b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala @@ -88,7 +88,7 @@ private[spark] class MetricsSystem private ( */ def getServletHandlers: Array[ServletContextHandler] = { require(running, "Can only call getServletHandlers on a running MetricsSystem") - metricsServlet.map(_.getHandlers).getOrElse(Array()) + metricsServlet.map(_.getHandlers(conf)).getOrElse(Array()) } metricsConfig.initialize() diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala b/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala index 0c2e212a3307..4193e1d21d3c 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala @@ -27,7 +27,7 @@ import com.codahale.metrics.json.MetricsModule import com.fasterxml.jackson.databind.ObjectMapper import org.eclipse.jetty.servlet.ServletContextHandler -import org.apache.spark.SecurityManager +import org.apache.spark.{SparkConf, SecurityManager} import org.apache.spark.ui.JettyUtils._ private[spark] class MetricsServlet( @@ -49,10 +49,10 @@ private[spark] class MetricsServlet( val mapper = new ObjectMapper().registerModule( new MetricsModule(TimeUnit.SECONDS, TimeUnit.MILLISECONDS, servletShowSample)) - def getHandlers: Array[ServletContextHandler] = { + def getHandlers(conf: SparkConf): Array[ServletContextHandler] = { Array[ServletContextHandler]( createServletHandler(servletPath, - new ServletParams(request => getMetricsSnapshot(request), "text/json"), securityMgr) + new ServletParams(request => getMetricsSnapshot(request), "text/json"), securityMgr, conf) ) } diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala 
b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala index 779c0ba08359..b796a44fe01a 100644 --- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala @@ -59,7 +59,17 @@ private[spark] object JettyUtils extends Logging { def createServlet[T <% AnyRef]( servletParams: ServletParams[T], - securityMgr: SecurityManager): HttpServlet = { + securityMgr: SecurityManager, + conf: SparkConf): HttpServlet = { + + // SPARK-10589 avoid frame-related click-jacking vulnerability, using X-Frame-Options + // (see http://tools.ietf.org/html/rfc7034). By default allow framing only from the + // same origin, but allow framing for a specific named URI. + // Example: spark.ui.allowFramingFrom = https://example.com/ + val allowFramingFrom = conf.getOption("spark.ui.allowFramingFrom") + val xFrameOptionsValue = + allowFramingFrom.map(uri => s"ALLOW-FROM $uri").getOrElse("SAMEORIGIN") + new HttpServlet { override def doGet(request: HttpServletRequest, response: HttpServletResponse) { try { @@ -68,6 +78,7 @@ private[spark] object JettyUtils extends Logging { response.setStatus(HttpServletResponse.SC_OK) val result = servletParams.responder(request) response.setHeader("Cache-Control", "no-cache, no-store, must-revalidate") + response.setHeader("X-Frame-Options", xFrameOptionsValue) // scalastyle:off println response.getWriter.println(servletParams.extractFn(result)) // scalastyle:on println @@ -97,8 +108,9 @@ private[spark] object JettyUtils extends Logging { path: String, servletParams: ServletParams[T], securityMgr: SecurityManager, + conf: SparkConf, basePath: String = ""): ServletContextHandler = { - createServletHandler(path, createServlet(servletParams, securityMgr), basePath) + createServletHandler(path, createServlet(servletParams, securityMgr, conf), basePath) } /** Create a context handler that responds to a request with the given path prefix */ diff --git a/core/src/main/scala/org/apache/spark/ui/WebUI.scala b/core/src/main/scala/org/apache/spark/ui/WebUI.scala index 61449847add3..81a121fd441b 100644 --- a/core/src/main/scala/org/apache/spark/ui/WebUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/WebUI.scala @@ -76,9 +76,9 @@ private[spark] abstract class WebUI( def attachPage(page: WebUIPage) { val pagePath = "/" + page.prefix val renderHandler = createServletHandler(pagePath, - (request: HttpServletRequest) => page.render(request), securityManager, basePath) + (request: HttpServletRequest) => page.render(request), securityManager, conf, basePath) val renderJsonHandler = createServletHandler(pagePath.stripSuffix("/") + "/json", - (request: HttpServletRequest) => page.renderJson(request), securityManager, basePath) + (request: HttpServletRequest) => page.renderJson(request), securityManager, conf, basePath) attachHandler(renderHandler) attachHandler(renderJsonHandler) pageToHandlers.getOrElseUpdate(page, ArrayBuffer[ServletContextHandler]()) From 896edb51ab7a88bbb31259e565311a9be6f2ca6d Mon Sep 17 00:00:00 2001 From: Sun Rui Date: Wed, 16 Sep 2015 13:20:39 -0700 Subject: [PATCH 0650/1824] [SPARK-10050] [SPARKR] Support collecting data of MapType in DataFrame. 1. Support collecting data of MapType from DataFrame. 2. Support data of MapType in createDataFrame. Author: Sun Rui Closes #8711 from sun-rui/SPARK-10050. 
--- R/pkg/R/SQLContext.R | 5 +- R/pkg/R/deserialize.R | 14 +++++ R/pkg/R/schema.R | 34 ++++++++--- R/pkg/inst/tests/test_sparkSQL.R | 56 +++++++++++++++---- .../scala/org/apache/spark/api/r/SerDe.scala | 31 ++++++++++ .../org/apache/spark/sql/api/r/SQLUtils.scala | 6 ++ 6 files changed, 123 insertions(+), 23 deletions(-) diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R index 4ac057d0f2d8..1c58fd96d750 100644 --- a/R/pkg/R/SQLContext.R +++ b/R/pkg/R/SQLContext.R @@ -41,10 +41,7 @@ infer_type <- function(x) { if (type == "map") { stopifnot(length(x) > 0) key <- ls(x)[[1]] - list(type = "map", - keyType = "string", - valueType = infer_type(get(key, x)), - valueContainsNull = TRUE) + paste0("map") } else if (type == "array") { stopifnot(length(x) > 0) names <- names(x) diff --git a/R/pkg/R/deserialize.R b/R/pkg/R/deserialize.R index d1858ec227b5..ce88d0b071b7 100644 --- a/R/pkg/R/deserialize.R +++ b/R/pkg/R/deserialize.R @@ -50,6 +50,7 @@ readTypedObject <- function(con, type) { "t" = readTime(con), "a" = readArray(con), "l" = readList(con), + "e" = readEnv(con), "n" = NULL, "j" = getJobj(readString(con)), stop(paste("Unsupported type for deserialization", type))) @@ -121,6 +122,19 @@ readList <- function(con) { } } +readEnv <- function(con) { + env <- new.env() + len <- readInt(con) + if (len > 0) { + for (i in 1:len) { + key <- readString(con) + value <- readObject(con) + env[[key]] <- value + } + } + env +} + readRaw <- function(con) { dataLen <- readInt(con) readBin(con, raw(), as.integer(dataLen), endian = "big") diff --git a/R/pkg/R/schema.R b/R/pkg/R/schema.R index 62d4f73878d2..8df1563f8ebc 100644 --- a/R/pkg/R/schema.R +++ b/R/pkg/R/schema.R @@ -131,13 +131,33 @@ checkType <- function(type) { if (type %in% primtiveTypes) { return() } else { - m <- regexec("^array<(.*)>$", type) - matchedStrings <- regmatches(type, m) - if (length(matchedStrings[[1]]) >= 2) { - elemType <- matchedStrings[[1]][2] - checkType(elemType) - return() - } + # Check complex types + firstChar <- substr(type, 1, 1) + switch (firstChar, + a = { + # Array type + m <- regexec("^array<(.*)>$", type) + matchedStrings <- regmatches(type, m) + if (length(matchedStrings[[1]]) >= 2) { + elemType <- matchedStrings[[1]][2] + checkType(elemType) + return() + } + }, + m = { + # Map type + m <- regexec("^map<(.*),(.*)>$", type) + matchedStrings <- regmatches(type, m) + if (length(matchedStrings[[1]]) >= 3) { + keyType <- matchedStrings[[1]][2] + if (keyType != "string" && keyType != "character") { + stop("Key type in a map must be string or character") + } + valueType <- matchedStrings[[1]][3] + checkType(valueType) + return() + } + }) } stop(paste("Unsupported type for Dataframe:", type)) diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index 98d4402d368e..e159a6958427 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -57,7 +57,7 @@ mockLinesComplexType <- complexTypeJsonPath <- tempfile(pattern="sparkr-test", fileext=".tmp") writeLines(mockLinesComplexType, complexTypeJsonPath) -test_that("infer types", { +test_that("infer types and check types", { expect_equal(infer_type(1L), "integer") expect_equal(infer_type(1.0), "double") expect_equal(infer_type("abc"), "string") @@ -72,9 +72,9 @@ test_that("infer types", { checkStructField(testStruct$fields()[[2]], "b", "StringType", TRUE) e <- new.env() assign("a", 1L, envir = e) - expect_equal(infer_type(e), - list(type = "map", keyType = "string", valueType = "integer", - valueContainsNull = TRUE)) + 
expect_equal(infer_type(e), "map") + + expect_error(checkType("map"), "Key type in a map must be string or character") }) test_that("structType and structField", { @@ -242,7 +242,7 @@ test_that("create DataFrame with different data types", { expect_equal(collect(df), data.frame(l, stringsAsFactors = FALSE)) }) -test_that("create DataFrame with nested array and struct", { +test_that("create DataFrame with nested array and map", { # e <- new.env() # assign("n", 3L, envir = e) # l <- list(1:10, list("a", "b"), e, list(a="aa", b=3L)) @@ -253,21 +253,35 @@ test_that("create DataFrame with nested array and struct", { # ldf <- collect(df) # expect_equal(ldf[1,], l[[1]]) + # ArrayType and MapType + e <- new.env() + assign("n", 3L, envir = e) - # ArrayType only for now - l <- list(as.list(1:10), list("a", "b")) - df <- createDataFrame(sqlContext, list(l), c("a", "b")) - expect_equal(dtypes(df), list(c("a", "array"), c("b", "array"))) + l <- list(as.list(1:10), list("a", "b"), e) + df <- createDataFrame(sqlContext, list(l), c("a", "b", "c")) + expect_equal(dtypes(df), list(c("a", "array"), + c("b", "array"), + c("c", "map"))) expect_equal(count(df), 1) ldf <- collect(df) - expect_equal(names(ldf), c("a", "b")) + expect_equal(names(ldf), c("a", "b", "c")) expect_equal(ldf[1, 1][[1]], l[[1]]) expect_equal(ldf[1, 2][[1]], l[[2]]) + e <- ldf$c[[1]] + expect_equal(class(e), "environment") + expect_equal(ls(e), "n") + expect_equal(e$n, 3L) }) +# For test map type in DataFrame +mockLinesMapType <- c("{\"name\":\"Bob\",\"info\":{\"age\":16,\"height\":176.5}}", + "{\"name\":\"Alice\",\"info\":{\"age\":20,\"height\":164.3}}", + "{\"name\":\"David\",\"info\":{\"age\":60,\"height\":180}}") +mapTypeJsonPath <- tempfile(pattern="sparkr-test", fileext=".tmp") +writeLines(mockLinesMapType, mapTypeJsonPath) + test_that("Collect DataFrame with complex types", { - # only ArrayType now - # TODO: tests for StructType and MapType after they are supported + # ArrayType df <- jsonFile(sqlContext, complexTypeJsonPath) ldf <- collect(df) @@ -277,6 +291,24 @@ test_that("Collect DataFrame with complex types", { expect_equal(ldf$c1, list(list(1, 2, 3), list(4, 5, 6), list (7, 8, 9))) expect_equal(ldf$c2, list(list("a", "b", "c"), list("d", "e", "f"), list ("g", "h", "i"))) expect_equal(ldf$c3, list(list(1.0, 2.0, 3.0), list(4.0, 5.0, 6.0), list (7.0, 8.0, 9.0))) + + # MapType + schema <- structType(structField("name", "string"), + structField("info", "map")) + df <- read.df(sqlContext, mapTypeJsonPath, "json", schema) + expect_equal(dtypes(df), list(c("name", "string"), + c("info", "map"))) + ldf <- collect(df) + expect_equal(nrow(ldf), 3) + expect_equal(ncol(ldf), 2) + expect_equal(names(ldf), c("name", "info")) + expect_equal(ldf$name, c("Bob", "Alice", "David")) + bob <- ldf$info[[1]] + expect_equal(class(bob), "environment") + expect_equal(bob$age, 16) + expect_equal(bob$height, 176.5) + + # TODO: tests for StructType after it is supported }) test_that("jsonFile() on a local file returns a DataFrame", { diff --git a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala index 3c92bb7a1c73..0c78613e406e 100644 --- a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala +++ b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala @@ -209,11 +209,23 @@ private[spark] object SerDe { case "array" => dos.writeByte('a') // Array of objects case "list" => dos.writeByte('l') + case "map" => dos.writeByte('e') case "jobj" => dos.writeByte('j') case _ => throw new 
IllegalArgumentException(s"Invalid type $typeStr") } } + private def writeKeyValue(dos: DataOutputStream, key: Object, value: Object): Unit = { + if (key == null) { + throw new IllegalArgumentException("Key in map can't be null.") + } else if (!key.isInstanceOf[String]) { + throw new IllegalArgumentException(s"Invalid map key type: ${key.getClass.getName}") + } + + writeString(dos, key.asInstanceOf[String]) + writeObject(dos, value) + } + def writeObject(dos: DataOutputStream, obj: Object): Unit = { if (obj == null) { writeType(dos, "void") @@ -306,6 +318,25 @@ private[spark] object SerDe { writeInt(dos, v.length) v.foreach(elem => writeObject(dos, elem)) + // Handle map + case v: java.util.Map[_, _] => + writeType(dos, "map") + writeInt(dos, v.size) + val iter = v.entrySet.iterator + while(iter.hasNext) { + val entry = iter.next + val key = entry.getKey + val value = entry.getValue + + writeKeyValue(dos, key.asInstanceOf[Object], value.asInstanceOf[Object]) + } + case v: scala.collection.Map[_, _] => + writeType(dos, "map") + writeInt(dos, v.size) + v.foreach { case (key, value) => + writeKeyValue(dos, key.asInstanceOf[Object], value.asInstanceOf[Object]) + } + case _ => writeType(dos, "jobj") writeJObj(dos, value) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala index d4b834adb6e3..f45d119c8cfd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala @@ -64,6 +64,12 @@ private[r] object SQLUtils { case r"\Aarray<(.*)${elemType}>\Z" => { org.apache.spark.sql.types.ArrayType(getSQLDataType(elemType)) } + case r"\Amap<(.*)${keyType},(.*)${valueType}>\Z" => { + if (keyType != "string" && keyType != "character") { + throw new IllegalArgumentException("Key type of a map must be string or character") + } + org.apache.spark.sql.types.MapType(getSQLDataType(keyType), getSQLDataType(valueType)) + } case _ => throw new IllegalArgumentException(s"Invaid type $dataType") } } From d39f15ea2b8bed5342d2f8e3c1936f915c470783 Mon Sep 17 00:00:00 2001 From: Kevin Cox Date: Wed, 16 Sep 2015 15:30:17 -0700 Subject: [PATCH 0651/1824] [SPARK-9794] [SQL] Fix datetime parsing in SparkSQL. This fixes https://issues.apache.org/jira/browse/SPARK-9794 by using a real ISO8601 parser. (courtesy of the xml component of the standard java library) cc: angelini Author: Kevin Cox Closes #8396 from kevincox/kevincox-sql-time-parsing. 
--- .../sql/catalyst/util/DateTimeUtils.scala | 27 ++++++---------- .../catalyst/util/DateTimeUtilsSuite.scala | 32 +++++++++++++++++++ 2 files changed, 42 insertions(+), 17 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index 687ca000d12b..400c4327be1c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.util import java.sql.{Date, Timestamp} import java.text.{DateFormat, SimpleDateFormat} import java.util.{TimeZone, Calendar} +import javax.xml.bind.DatatypeConverter; import org.apache.spark.unsafe.types.UTF8String @@ -109,30 +110,22 @@ object DateTimeUtils { } def stringToTime(s: String): java.util.Date = { - if (!s.contains('T')) { + var indexOfGMT = s.indexOf("GMT"); + if (indexOfGMT != -1) { + // ISO8601 with a weird time zone specifier (2000-01-01T00:00GMT+01:00) + val s0 = s.substring(0, indexOfGMT) + val s1 = s.substring(indexOfGMT + 3) + // Mapped to 2000-01-01T00:00+01:00 + stringToTime(s0 + s1) + } else if (!s.contains('T')) { // JDBC escape string if (s.contains(' ')) { Timestamp.valueOf(s) } else { Date.valueOf(s) } - } else if (s.endsWith("Z")) { - // this is zero timezone of ISO8601 - stringToTime(s.substring(0, s.length - 1) + "GMT-00:00") - } else if (s.indexOf("GMT") == -1) { - // timezone with ISO8601 - val inset = "+00.00".length - val s0 = s.substring(0, s.length - inset) - val s1 = s.substring(s.length - inset, s.length) - if (s0.substring(s0.lastIndexOf(':')).contains('.')) { - stringToTime(s0 + "GMT" + s1) - } else { - stringToTime(s0 + ".0GMT" + s1) - } } else { - // ISO8601 with GMT insert - val ISO8601GMT: SimpleDateFormat = new SimpleDateFormat( "yyyy-MM-dd'T'HH:mm:ss.SSSz" ) - ISO8601GMT.parse(s) + DatatypeConverter.parseDateTime(s).getTime() } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala index 6b9a11f0ff74..46335941b62d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala @@ -136,6 +136,38 @@ class DateTimeUtilsSuite extends SparkFunSuite { assert(stringToDate(UTF8String.fromString("2015-031-8")).isEmpty) } + test("string to time") { + // Tests with UTC. + var c = Calendar.getInstance(TimeZone.getTimeZone("UTC")) + c.set(Calendar.MILLISECOND, 0) + + c.set(1900, 0, 1, 0, 0, 0) + assert(stringToTime("1900-01-01T00:00:00GMT-00:00") === c.getTime()) + + c.set(2000, 11, 30, 10, 0, 0) + assert(stringToTime("2000-12-30T10:00:00Z") === c.getTime()) + + // Tests with set time zone. + c.setTimeZone(TimeZone.getTimeZone("GMT-04:00")) + c.set(Calendar.MILLISECOND, 0) + + c.set(1900, 0, 1, 0, 0, 0) + assert(stringToTime("1900-01-01T00:00:00-04:00") === c.getTime()) + + c.set(1900, 0, 1, 0, 0, 0) + assert(stringToTime("1900-01-01T00:00:00GMT-04:00") === c.getTime()) + + // Tests with local time zone. 
+ c.setTimeZone(TimeZone.getDefault()) + c.set(Calendar.MILLISECOND, 0) + + c.set(2000, 11, 30, 0, 0, 0) + assert(stringToTime("2000-12-30") === new Date(c.getTimeInMillis())) + + c.set(2000, 11, 30, 10, 0, 0) + assert(stringToTime("2000-12-30 10:00:00") === new Timestamp(c.getTimeInMillis())) + } + test("string to timestamp") { var c = Calendar.getInstance() c.set(1969, 11, 31, 16, 0, 0) From 49c649fa0b6affed108dbae85373b4b7247b338c Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 16 Sep 2015 15:32:01 -0700 Subject: [PATCH 0652/1824] Tiny style fix for d39f15ea2b8bed5342d2f8e3c1936f915c470783. --- .../org/apache/spark/sql/catalyst/util/DateTimeUtils.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index 400c4327be1c..781ed1688a32 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.util import java.sql.{Date, Timestamp} import java.text.{DateFormat, SimpleDateFormat} import java.util.{TimeZone, Calendar} -import javax.xml.bind.DatatypeConverter; +import javax.xml.bind.DatatypeConverter import org.apache.spark.unsafe.types.UTF8String From 69c9830d288d5b8d7f0abe7c8a65a4c966580a49 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Thu, 17 Sep 2015 00:48:57 -0700 Subject: [PATCH 0653/1824] [MINOR] [CORE] Fixes minor variable name typo Author: Cheng Lian Closes #8784 from liancheng/typo-fix. --- .../apache/spark/serializer/GenericAvroSerializerSuite.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/serializer/GenericAvroSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/GenericAvroSerializerSuite.scala index bc9f3708ed69..87f25e7245e1 100644 --- a/core/src/test/scala/org/apache/spark/serializer/GenericAvroSerializerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/GenericAvroSerializerSuite.scala @@ -76,9 +76,9 @@ class GenericAvroSerializerSuite extends SparkFunSuite with SharedSparkContext { test("caches previously seen schemas") { val genericSer = new GenericAvroSerializer(conf.getAvroSchema) val compressedSchema = genericSer.compress(schema) - val decompressedScheam = genericSer.decompress(ByteBuffer.wrap(compressedSchema)) + val decompressedSchema = genericSer.decompress(ByteBuffer.wrap(compressedSchema)) assert(compressedSchema.eq(genericSer.compress(schema))) - assert(decompressedScheam.eq(genericSer.decompress(ByteBuffer.wrap(compressedSchema)))) + assert(decompressedSchema.eq(genericSer.decompress(ByteBuffer.wrap(compressedSchema)))) } } From c633ed3260140f1288f326acc4d7a10dcd2e27d5 Mon Sep 17 00:00:00 2001 From: Yu ISHIKAWA Date: Thu, 17 Sep 2015 08:43:59 -0700 Subject: [PATCH 0654/1824] [SPARK-10284] [ML] [PYSPARK] [DOCS] Add @since annotation to pyspark.ml.tuning Author: Yu ISHIKAWA Closes #8694 from yu-iskw/SPARK-10284. 
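The `@since` annotations added here (and in the similar patches that follow) attach documentation metadata recording the Spark version in which an API appeared. A simplified sketch of the effect of such a decorator; this is not the actual `pyspark.since` implementation, and the demo class is a hypothetical stand-in:

{% highlight python %}
# Simplified sketch: an @since-style decorator appends a Sphinx
# ".. versionadded::" note to the wrapped callable's docstring.
def since(version):
    def deco(f):
        f.__doc__ = (f.__doc__ or "") + "\n\n.. versionadded:: " + version
        return f
    return deco

class ParamGridBuilderDemo(object):          # hypothetical stand-in class
    @since("1.4.0")
    def addGrid(self, param, values):
        """Sets the given parameters in this grid to fixed values."""
        return self

print(ParamGridBuilderDemo.addGrid.__doc__)
{% endhighlight %}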
--- python/pyspark/ml/tuning.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py index cae778869e9c..ab5621f45c72 100644 --- a/python/pyspark/ml/tuning.py +++ b/python/pyspark/ml/tuning.py @@ -18,6 +18,7 @@ import itertools import numpy as np +from pyspark import since from pyspark.ml.param import Params, Param from pyspark.ml import Estimator, Model from pyspark.ml.util import keyword_only @@ -47,11 +48,14 @@ class ParamGridBuilder(object): True >>> all([m in expected for m in output]) True + + .. versionadded:: 1.4.0 """ def __init__(self): self._param_grid = {} + @since("1.4.0") def addGrid(self, param, values): """ Sets the given parameters in this grid to fixed values. @@ -60,6 +64,7 @@ def addGrid(self, param, values): return self + @since("1.4.0") def baseOn(self, *args): """ Sets the given parameters in this grid to fixed values. @@ -73,6 +78,7 @@ def baseOn(self, *args): return self + @since("1.4.0") def build(self): """ Builds and returns all combinations of parameters specified @@ -104,6 +110,8 @@ class CrossValidator(Estimator): >>> cvModel = cv.fit(dataset) >>> evaluator.evaluate(cvModel.transform(dataset)) 0.8333... + + .. versionadded:: 1.4.0 """ # a placeholder to make it appear in the generated doc @@ -142,6 +150,7 @@ def __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numF self._set(**kwargs) @keyword_only + @since("1.4.0") def setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3): """ setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3): @@ -150,6 +159,7 @@ def setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, num kwargs = self.setParams._input_kwargs return self._set(**kwargs) + @since("1.4.0") def setEstimator(self, value): """ Sets the value of :py:attr:`estimator`. @@ -157,12 +167,14 @@ def setEstimator(self, value): self._paramMap[self.estimator] = value return self + @since("1.4.0") def getEstimator(self): """ Gets the value of estimator or its default value. """ return self.getOrDefault(self.estimator) + @since("1.4.0") def setEstimatorParamMaps(self, value): """ Sets the value of :py:attr:`estimatorParamMaps`. @@ -170,12 +182,14 @@ def setEstimatorParamMaps(self, value): self._paramMap[self.estimatorParamMaps] = value return self + @since("1.4.0") def getEstimatorParamMaps(self): """ Gets the value of estimatorParamMaps or its default value. """ return self.getOrDefault(self.estimatorParamMaps) + @since("1.4.0") def setEvaluator(self, value): """ Sets the value of :py:attr:`evaluator`. @@ -183,12 +197,14 @@ def setEvaluator(self, value): self._paramMap[self.evaluator] = value return self + @since("1.4.0") def getEvaluator(self): """ Gets the value of evaluator or its default value. """ return self.getOrDefault(self.evaluator) + @since("1.4.0") def setNumFolds(self, value): """ Sets the value of :py:attr:`numFolds`. @@ -196,6 +212,7 @@ def setNumFolds(self, value): self._paramMap[self.numFolds] = value return self + @since("1.4.0") def getNumFolds(self): """ Gets the value of numFolds or its default value. @@ -231,7 +248,15 @@ def _fit(self, dataset): bestModel = est.fit(dataset, epm[bestIndex]) return CrossValidatorModel(bestModel) + @since("1.4.0") def copy(self, extra=None): + """ + Creates a copy of this instance with a randomly generated uid + and some extra params. 
This copies creates a deep copy of + the embedded paramMap, and copies the embedded and extra parameters over. + :param extra: Extra parameters to copy to the new instance + :return: Copy of this instance + """ if extra is None: extra = dict() newCV = Params.copy(self, extra) @@ -246,6 +271,8 @@ def copy(self, extra=None): class CrossValidatorModel(Model): """ Model from k-fold cross validation. + + .. versionadded:: 1.4.0 """ def __init__(self, bestModel): @@ -256,6 +283,7 @@ def __init__(self, bestModel): def _transform(self, dataset): return self.bestModel.transform(dataset) + @since("1.4.0") def copy(self, extra=None): """ Creates a copy of this instance with a randomly generated uid From 29bf8aa5a51fdd8c2600533297f991e14fa27c03 Mon Sep 17 00:00:00 2001 From: Yu ISHIKAWA Date: Thu, 17 Sep 2015 08:45:20 -0700 Subject: [PATCH 0655/1824] [SPARK-10283] [ML] [PYSPARK] [DOCS] Add @since annotation to pyspark.ml.regression Author: Yu ISHIKAWA Closes #8693 from yu-iskw/SPARK-10283. --- python/pyspark/ml/regression.py | 65 +++++++++++++++++++++++++++++++++ 1 file changed, 65 insertions(+) diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index a9503608b7f2..21d454f9003b 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -15,6 +15,7 @@ # limitations under the License. # +from pyspark import since from pyspark.ml.util import keyword_only from pyspark.ml.wrapper import JavaEstimator, JavaModel from pyspark.ml.param.shared import * @@ -62,6 +63,8 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction Traceback (most recent call last): ... TypeError: Method setParams forces keyword arguments. + + .. versionadded:: 1.4.0 """ @keyword_only @@ -81,6 +84,7 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred self.setParams(**kwargs) @keyword_only + @since("1.4.0") def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, standardization=True): @@ -96,13 +100,31 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre def _create_model(self, java_model): return LinearRegressionModel(java_model) + @since("1.4.0") + def setElasticNetParam(self, value): + """ + Sets the value of :py:attr:`elasticNetParam`. + """ + self._paramMap[self.elasticNetParam] = value + return self + + @since("1.4.0") + def getElasticNetParam(self): + """ + Gets the value of elasticNetParam or its default value. + """ + return self.getOrDefault(self.elasticNetParam) + class LinearRegressionModel(JavaModel): """ Model fitted by LinearRegression. + + .. versionadded:: 1.4.0 """ @property + @since("1.4.0") def weights(self): """ Model weights. @@ -110,6 +132,7 @@ def weights(self): return self._call_java("weights") @property + @since("1.4.0") def intercept(self): """ Model intercept. @@ -162,6 +185,8 @@ class DecisionTreeRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi >>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"]) >>> model.transform(test1).head().prediction 1.0 + + .. 
versionadded:: 1.4.0 """ # a placeholder to make it appear in the generated doc @@ -193,6 +218,7 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred self.setParams(**kwargs) @keyword_only + @since("1.4.0") def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, @@ -209,6 +235,7 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre def _create_model(self, java_model): return DecisionTreeRegressionModel(java_model) + @since("1.4.0") def setImpurity(self, value): """ Sets the value of :py:attr:`impurity`. @@ -216,6 +243,7 @@ def setImpurity(self, value): self._paramMap[self.impurity] = value return self + @since("1.4.0") def getImpurity(self): """ Gets the value of impurity or its default value. @@ -225,13 +253,19 @@ def getImpurity(self): @inherit_doc class DecisionTreeModel(JavaModel): + """Abstraction for Decision Tree models. + + .. versionadded:: 1.5.0 + """ @property + @since("1.5.0") def numNodes(self): """Return number of nodes of the decision tree.""" return self._call_java("numNodes") @property + @since("1.5.0") def depth(self): """Return depth of the decision tree.""" return self._call_java("depth") @@ -242,8 +276,13 @@ def __repr__(self): @inherit_doc class TreeEnsembleModels(JavaModel): + """Represents a tree ensemble model. + + .. versionadded:: 1.5.0 + """ @property + @since("1.5.0") def treeWeights(self): """Return the weights for each tree""" return list(self._call_java("javaTreeWeights")) @@ -256,6 +295,8 @@ def __repr__(self): class DecisionTreeRegressionModel(DecisionTreeModel): """ Model fitted by DecisionTreeRegressor. + + .. versionadded:: 1.4.0 """ @@ -282,6 +323,8 @@ class RandomForestRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi >>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"]) >>> model.transform(test1).head().prediction 0.5 + + .. versionadded:: 1.4.0 """ # a placeholder to make it appear in the generated doc @@ -336,6 +379,7 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred self.setParams(**kwargs) @keyword_only + @since("1.4.0") def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, seed=None, @@ -353,6 +397,7 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre def _create_model(self, java_model): return RandomForestRegressionModel(java_model) + @since("1.4.0") def setImpurity(self, value): """ Sets the value of :py:attr:`impurity`. @@ -360,12 +405,14 @@ def setImpurity(self, value): self._paramMap[self.impurity] = value return self + @since("1.4.0") def getImpurity(self): """ Gets the value of impurity or its default value. """ return self.getOrDefault(self.impurity) + @since("1.4.0") def setSubsamplingRate(self, value): """ Sets the value of :py:attr:`subsamplingRate`. @@ -373,12 +420,14 @@ def setSubsamplingRate(self, value): self._paramMap[self.subsamplingRate] = value return self + @since("1.4.0") def getSubsamplingRate(self): """ Gets the value of subsamplingRate or its default value. """ return self.getOrDefault(self.subsamplingRate) + @since("1.4.0") def setNumTrees(self, value): """ Sets the value of :py:attr:`numTrees`. 
@@ -386,12 +435,14 @@ def setNumTrees(self, value): self._paramMap[self.numTrees] = value return self + @since("1.4.0") def getNumTrees(self): """ Gets the value of numTrees or its default value. """ return self.getOrDefault(self.numTrees) + @since("1.4.0") def setFeatureSubsetStrategy(self, value): """ Sets the value of :py:attr:`featureSubsetStrategy`. @@ -399,6 +450,7 @@ def setFeatureSubsetStrategy(self, value): self._paramMap[self.featureSubsetStrategy] = value return self + @since("1.4.0") def getFeatureSubsetStrategy(self): """ Gets the value of featureSubsetStrategy or its default value. @@ -409,6 +461,8 @@ def getFeatureSubsetStrategy(self): class RandomForestRegressionModel(TreeEnsembleModels): """ Model fitted by RandomForestRegressor. + + .. versionadded:: 1.4.0 """ @@ -435,6 +489,8 @@ class GBTRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, >>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"]) >>> model.transform(test1).head().prediction 1.0 + + .. versionadded:: 1.4.0 """ # a placeholder to make it appear in the generated doc @@ -481,6 +537,7 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred self.setParams(**kwargs) @keyword_only + @since("1.4.0") def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, @@ -498,6 +555,7 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre def _create_model(self, java_model): return GBTRegressionModel(java_model) + @since("1.4.0") def setLossType(self, value): """ Sets the value of :py:attr:`lossType`. @@ -505,12 +563,14 @@ def setLossType(self, value): self._paramMap[self.lossType] = value return self + @since("1.4.0") def getLossType(self): """ Gets the value of lossType or its default value. """ return self.getOrDefault(self.lossType) + @since("1.4.0") def setSubsamplingRate(self, value): """ Sets the value of :py:attr:`subsamplingRate`. @@ -518,12 +578,14 @@ def setSubsamplingRate(self, value): self._paramMap[self.subsamplingRate] = value return self + @since("1.4.0") def getSubsamplingRate(self): """ Gets the value of subsamplingRate or its default value. """ return self.getOrDefault(self.subsamplingRate) + @since("1.4.0") def setStepSize(self, value): """ Sets the value of :py:attr:`stepSize`. @@ -531,6 +593,7 @@ def setStepSize(self, value): self._paramMap[self.stepSize] = value return self + @since("1.4.0") def getStepSize(self): """ Gets the value of stepSize or its default value. @@ -541,6 +604,8 @@ def getStepSize(self): class GBTRegressionModel(TreeEnsembleModels): """ Model fitted by GBTRegressor. + + .. versionadded:: 1.4.0 """ From 0ded87a4d49d4484e202bd2ec781821b57b5882c Mon Sep 17 00:00:00 2001 From: Yu ISHIKAWA Date: Thu, 17 Sep 2015 08:47:21 -0700 Subject: [PATCH 0656/1824] [SPARK-10281] [ML] [PYSPARK] [DOCS] Add @since annotation to pyspark.ml.clustering Author: Yu ISHIKAWA Closes #8691 from yu-iskw/SPARK-10281. --- python/pyspark/ml/clustering.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py index cb4c16e25a7a..7bb8ab94e17d 100644 --- a/python/pyspark/ml/clustering.py +++ b/python/pyspark/ml/clustering.py @@ -15,6 +15,7 @@ # limitations under the License. 
# +from pyspark import since from pyspark.ml.util import keyword_only from pyspark.ml.wrapper import JavaEstimator, JavaModel from pyspark.ml.param.shared import * @@ -26,8 +27,11 @@ class KMeansModel(JavaModel): """ Model fitted by KMeans. + + .. versionadded:: 1.5.0 """ + @since("1.5.0") def clusterCenters(self): """Get the cluster centers, represented as a list of NumPy arrays.""" return [c.toArray() for c in self._call_java("clusterCenters")] @@ -55,6 +59,8 @@ class KMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol True >>> rows[2].prediction == rows[3].prediction True + + .. versionadded:: 1.5.0 """ # a placeholder to make it appear in the generated doc @@ -88,6 +94,7 @@ def _create_model(self, java_model): return KMeansModel(java_model) @keyword_only + @since("1.5.0") def setParams(self, featuresCol="features", predictionCol="prediction", k=2, initMode="k-means||", initSteps=5, tol=1e-4, maxIter=20, seed=None): """ @@ -99,6 +106,7 @@ def setParams(self, featuresCol="features", predictionCol="prediction", k=2, kwargs = self.setParams._input_kwargs return self._set(**kwargs) + @since("1.5.0") def setK(self, value): """ Sets the value of :py:attr:`k`. @@ -110,12 +118,14 @@ def setK(self, value): self._paramMap[self.k] = value return self + @since("1.5.0") def getK(self): """ Gets the value of `k` """ return self.getOrDefault(self.k) + @since("1.5.0") def setInitMode(self, value): """ Sets the value of :py:attr:`initMode`. @@ -130,12 +140,14 @@ def setInitMode(self, value): self._paramMap[self.initMode] = value return self + @since("1.5.0") def getInitMode(self): """ Gets the value of `initMode` """ return self.getOrDefault(self.initMode) + @since("1.5.0") def setInitSteps(self, value): """ Sets the value of :py:attr:`initSteps`. @@ -147,6 +159,7 @@ def setInitSteps(self, value): self._paramMap[self.initSteps] = value return self + @since("1.5.0") def getInitSteps(self): """ Gets the value of `initSteps` From 39b44cb52eb225469eb4ccdf696f0bc6405b9184 Mon Sep 17 00:00:00 2001 From: Yu ISHIKAWA Date: Thu, 17 Sep 2015 08:48:45 -0700 Subject: [PATCH 0657/1824] [SPARK-10278] [MLLIB] [PYSPARK] Add @since annotation to pyspark.mllib.tree Author: Yu ISHIKAWA Closes #8685 from yu-iskw/SPARK-10278. --- python/pyspark/mllib/tree.py | 36 +++++++++++++++++++++++++++++++++++- 1 file changed, 35 insertions(+), 1 deletion(-) diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py index 372b86a7c95d..0001b60093a6 100644 --- a/python/pyspark/mllib/tree.py +++ b/python/pyspark/mllib/tree.py @@ -19,7 +19,7 @@ import random -from pyspark import SparkContext, RDD +from pyspark import SparkContext, RDD, since from pyspark.mllib.common import callMLlibFunc, inherit_doc, JavaModelWrapper from pyspark.mllib.linalg import _convert_to_vector from pyspark.mllib.regression import LabeledPoint @@ -30,6 +30,11 @@ class TreeEnsembleModel(JavaModelWrapper, JavaSaveable): + """TreeEnsembleModel + + .. versionadded:: 1.3.0 + """ + @since("1.3.0") def predict(self, x): """ Predict values for a single data point or an RDD of points using @@ -45,12 +50,14 @@ def predict(self, x): else: return self.call("predict", _convert_to_vector(x)) + @since("1.3.0") def numTrees(self): """ Get number of trees in ensemble. 
""" return self.call("numTrees") + @since("1.3.0") def totalNumNodes(self): """ Get total number of nodes, summed over all trees in the @@ -62,6 +69,7 @@ def __repr__(self): """ Summary of model """ return self._java_model.toString() + @since("1.3.0") def toDebugString(self): """ Full model """ return self._java_model.toDebugString() @@ -72,7 +80,10 @@ class DecisionTreeModel(JavaModelWrapper, JavaSaveable, JavaLoader): .. note:: Experimental A decision tree model for classification or regression. + + .. versionadded:: 1.1.0 """ + @since("1.1.0") def predict(self, x): """ Predict the label of one or more examples. @@ -90,16 +101,23 @@ def predict(self, x): else: return self.call("predict", _convert_to_vector(x)) + @since("1.1.0") def numNodes(self): + """Get number of nodes in tree, including leaf nodes.""" return self._java_model.numNodes() + @since("1.1.0") def depth(self): + """Get depth of tree. + E.g.: Depth 0 means 1 leaf node. Depth 1 means 1 internal node and 2 leaf nodes. + """ return self._java_model.depth() def __repr__(self): """ summary of model. """ return self._java_model.toString() + @since("1.2.0") def toDebugString(self): """ full model. """ return self._java_model.toDebugString() @@ -115,6 +133,8 @@ class DecisionTree(object): Learning algorithm for a decision tree model for classification or regression. + + .. versionadded:: 1.1.0 """ @classmethod @@ -127,6 +147,7 @@ def _train(cls, data, type, numClasses, features, impurity="gini", maxDepth=5, m return DecisionTreeModel(model) @classmethod + @since("1.1.0") def trainClassifier(cls, data, numClasses, categoricalFeaturesInfo, impurity="gini", maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0): @@ -185,6 +206,7 @@ def trainClassifier(cls, data, numClasses, categoricalFeaturesInfo, impurity, maxDepth, maxBins, minInstancesPerNode, minInfoGain) @classmethod + @since("1.1.0") def trainRegressor(cls, data, categoricalFeaturesInfo, impurity="variance", maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0): @@ -239,6 +261,8 @@ class RandomForestModel(TreeEnsembleModel, JavaLoader): .. note:: Experimental Represents a random forest model. + + .. versionadded:: 1.2.0 """ @classmethod @@ -252,6 +276,8 @@ class RandomForest(object): Learning algorithm for a random forest model for classification or regression. + + .. versionadded:: 1.2.0 """ supportedFeatureSubsetStrategies = ("auto", "all", "sqrt", "log2", "onethird") @@ -271,6 +297,7 @@ def _train(cls, data, algo, numClasses, categoricalFeaturesInfo, numTrees, return RandomForestModel(model) @classmethod + @since("1.2.0") def trainClassifier(cls, data, numClasses, categoricalFeaturesInfo, numTrees, featureSubsetStrategy="auto", impurity="gini", maxDepth=4, maxBins=32, seed=None): @@ -352,6 +379,7 @@ def trainClassifier(cls, data, numClasses, categoricalFeaturesInfo, numTrees, maxDepth, maxBins, seed) @classmethod + @since("1.2.0") def trainRegressor(cls, data, categoricalFeaturesInfo, numTrees, featureSubsetStrategy="auto", impurity="variance", maxDepth=4, maxBins=32, seed=None): """ @@ -418,6 +446,8 @@ class GradientBoostedTreesModel(TreeEnsembleModel, JavaLoader): .. note:: Experimental Represents a gradient-boosted tree model. + + .. versionadded:: 1.3.0 """ @classmethod @@ -431,6 +461,8 @@ class GradientBoostedTrees(object): Learning algorithm for a gradient boosted trees model for classification or regression. + + .. 
versionadded:: 1.3.0 """ @classmethod @@ -443,6 +475,7 @@ def _train(cls, data, algo, categoricalFeaturesInfo, return GradientBoostedTreesModel(model) @classmethod + @since("1.3.0") def trainClassifier(cls, data, categoricalFeaturesInfo, loss="logLoss", numIterations=100, learningRate=0.1, maxDepth=3, maxBins=32): @@ -505,6 +538,7 @@ def trainClassifier(cls, data, categoricalFeaturesInfo, loss, numIterations, learningRate, maxDepth, maxBins) @classmethod + @since("1.3.0") def trainRegressor(cls, data, categoricalFeaturesInfo, loss="leastSquaresError", numIterations=100, learningRate=0.1, maxDepth=3, maxBins=32): From 4a0b56e8dbb3713b16e58738201d838ffc4b258b Mon Sep 17 00:00:00 2001 From: Yu ISHIKAWA Date: Thu, 17 Sep 2015 08:50:00 -0700 Subject: [PATCH 0658/1824] [SPARK-10279] [MLLIB] [PYSPARK] [DOCS] Add @since annotation to pyspark.mllib.util Author: Yu ISHIKAWA Closes #8689 from yu-iskw/SPARK-10279. --- python/pyspark/mllib/util.py | 28 ++++++++++++++++++++++++++-- 1 file changed, 26 insertions(+), 2 deletions(-) diff --git a/python/pyspark/mllib/util.py b/python/pyspark/mllib/util.py index 10a1e4b3eb0f..39bc6586dd58 100644 --- a/python/pyspark/mllib/util.py +++ b/python/pyspark/mllib/util.py @@ -23,7 +23,7 @@ xrange = range basestring = str -from pyspark import SparkContext +from pyspark import SparkContext, since from pyspark.mllib.common import callMLlibFunc, inherit_doc from pyspark.mllib.linalg import Vectors, SparseVector, _convert_to_vector @@ -32,6 +32,8 @@ class MLUtils(object): """ Helper methods to load, save and pre-process data used in MLlib. + + .. versionadded:: 1.0.0 """ @staticmethod @@ -69,6 +71,7 @@ def _convert_labeled_point_to_libsvm(p): return " ".join(items) @staticmethod + @since("1.0.0") def loadLibSVMFile(sc, path, numFeatures=-1, minPartitions=None, multiclass=None): """ Loads labeled data in the LIBSVM format into an RDD of @@ -123,6 +126,7 @@ def loadLibSVMFile(sc, path, numFeatures=-1, minPartitions=None, multiclass=None return parsed.map(lambda x: LabeledPoint(x[0], Vectors.sparse(numFeatures, x[1], x[2]))) @staticmethod + @since("1.0.0") def saveAsLibSVMFile(data, dir): """ Save labeled data in LIBSVM format. @@ -147,6 +151,7 @@ def saveAsLibSVMFile(data, dir): lines.saveAsTextFile(dir) @staticmethod + @since("1.1.0") def loadLabeledPoints(sc, path, minPartitions=None): """ Load labeled points saved using RDD.saveAsTextFile. @@ -172,6 +177,7 @@ def loadLabeledPoints(sc, path, minPartitions=None): return callMLlibFunc("loadLabeledPoints", sc, path, minPartitions) @staticmethod + @since("1.5.0") def appendBias(data): """ Returns a new vector with `1.0` (bias) appended to @@ -186,6 +192,7 @@ def appendBias(data): return _convert_to_vector(np.append(vec.toArray(), 1.0)) @staticmethod + @since("1.5.0") def loadVectors(sc, path): """ Loads vectors saved using `RDD[Vector].saveAsTextFile` @@ -197,6 +204,8 @@ def loadVectors(sc, path): class Saveable(object): """ Mixin for models and transformers which may be saved as files. + + .. versionadded:: 1.3.0 """ def save(self, sc, path): @@ -222,9 +231,13 @@ class JavaSaveable(Saveable): """ Mixin for models that provide save() through their Scala implementation. + + .. 
versionadded:: 1.3.0 """ + @since("1.3.0") def save(self, sc, path): + """Save this model to the given path.""" if not isinstance(sc, SparkContext): raise TypeError("sc should be a SparkContext, got type %s" % type(sc)) if not isinstance(path, basestring): @@ -235,6 +248,8 @@ def save(self, sc, path): class Loader(object): """ Mixin for classes which can load saved models from files. + + .. versionadded:: 1.3.0 """ @classmethod @@ -256,6 +271,8 @@ class JavaLoader(Loader): """ Mixin for classes which can load saved models using its Scala implementation. + + .. versionadded:: 1.3.0 """ @classmethod @@ -280,15 +297,21 @@ def _load_java(cls, sc, path): return java_obj.load(sc._jsc.sc(), path) @classmethod + @since("1.3.0") def load(cls, sc, path): + """Load a model from the given path.""" java_model = cls._load_java(sc, path) return cls(java_model) class LinearDataGenerator(object): - """Utils for generating linear data""" + """Utils for generating linear data. + + .. versionadded:: 1.5.0 + """ @staticmethod + @since("1.5.0") def generateLinearInput(intercept, weights, xMean, xVariance, nPoints, seed, eps): """ @@ -311,6 +334,7 @@ def generateLinearInput(intercept, weights, xMean, xVariance, xVariance, int(nPoints), int(seed), float(eps))) @staticmethod + @since("1.5.0") def generateLinearRDD(sc, nexamples, nfeatures, eps, nParts=2, intercept=0.0): """ From c74d38fd8faf8cba981cf934341d24b9a3167025 Mon Sep 17 00:00:00 2001 From: Yu ISHIKAWA Date: Thu, 17 Sep 2015 08:50:46 -0700 Subject: [PATCH 0659/1824] [SPARK-10274] [MLLIB] Add @since annotation to pyspark.mllib.fpm Author: Yu ISHIKAWA Closes #8665 from yu-iskw/SPARK-10274. --- python/pyspark/mllib/fpm.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/python/pyspark/mllib/fpm.py b/python/pyspark/mllib/fpm.py index bdc4a132b1b1..bdabba9602a8 100644 --- a/python/pyspark/mllib/fpm.py +++ b/python/pyspark/mllib/fpm.py @@ -19,7 +19,7 @@ from numpy import array from collections import namedtuple -from pyspark import SparkContext +from pyspark import SparkContext, since from pyspark.rdd import ignore_unicode_prefix from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc, inherit_doc @@ -41,8 +41,11 @@ class FPGrowthModel(JavaModelWrapper): >>> model = FPGrowth.train(rdd, 0.6, 2) >>> sorted(model.freqItemsets().collect()) [FreqItemset(items=[u'a'], freq=4), FreqItemset(items=[u'c'], freq=3), ... + + .. versionadded:: 1.4.0 """ + @since("1.4.0") def freqItemsets(self): """ Returns the frequent itemsets of this model. @@ -55,9 +58,12 @@ class FPGrowth(object): .. note:: Experimental A Parallel FP-growth algorithm to mine frequent itemsets. + + .. versionadded:: 1.4.0 """ @classmethod + @since("1.4.0") def train(cls, data, minSupport=0.3, numPartitions=-1): """ Computes an FP-Growth model that contains frequent itemsets. @@ -74,6 +80,8 @@ def train(cls, data, minSupport=0.3, numPartitions=-1): class FreqItemset(namedtuple("FreqItemset", ["items", "freq"])): """ Represents an (items, freq) tuple. + + .. versionadded:: 1.4.0 """ From 268088b899e6e165e746aed87840d47bfaf50c43 Mon Sep 17 00:00:00 2001 From: Yu ISHIKAWA Date: Thu, 17 Sep 2015 08:51:19 -0700 Subject: [PATCH 0660/1824] [SPARK-10282] [ML] [PYSPARK] [DOCS] Add @since annotation to pyspark.ml.recommendation Author: Yu ISHIKAWA Closes #8692 from yu-iskw/SPARK-10282. 
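For context, the `@since` decorator these PySpark patches apply — imported as `from pyspark import since` in the diffs above — works by stamping the wrapped method's docstring with a `.. versionadded::` directive, which Sphinx then renders in the API docs. A simplified, illustrative sketch of the idea (not part of this patch; the real helper also normalizes docstring indentation, and `getNumTrees` is only borrowed from the diff above as a demo target):

```python
def since(version):
    """Return a decorator that appends a ``.. versionadded::`` note to a docstring."""
    def deco(f):
        doc = (f.__doc__ or "").rstrip()
        f.__doc__ = doc + "\n\n.. versionadded:: " + str(version)
        return f
    return deco


@since("1.4.0")
def getNumTrees(self):
    """
    Gets the value of numTrees or its default value.
    """
    return self.getOrDefault(self.numTrees)


# The method body is untouched; only the docstring gains the version marker.
print(getNumTrees.__doc__)
```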
--- python/pyspark/ml/recommendation.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/python/pyspark/ml/recommendation.py b/python/pyspark/ml/recommendation.py index b06099ac0aee..ec5748a1cfe9 100644 --- a/python/pyspark/ml/recommendation.py +++ b/python/pyspark/ml/recommendation.py @@ -15,6 +15,7 @@ # limitations under the License. # +from pyspark import since from pyspark.ml.util import keyword_only from pyspark.ml.wrapper import JavaEstimator, JavaModel from pyspark.ml.param.shared import * @@ -80,6 +81,8 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha Row(user=1, item=0, prediction=3.19...) >>> predictions[2] Row(user=2, item=0, prediction=-1.15...) + + .. versionadded:: 1.4.0 """ # a placeholder to make it appear in the generated doc @@ -122,6 +125,7 @@ def __init__(self, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemB self.setParams(**kwargs) @keyword_only + @since("1.4.0") def setParams(self, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemBlocks=10, implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item", seed=None, ratingCol="rating", nonnegative=False, checkpointInterval=10): @@ -137,6 +141,7 @@ def setParams(self, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItem def _create_model(self, java_model): return ALSModel(java_model) + @since("1.4.0") def setRank(self, value): """ Sets the value of :py:attr:`rank`. @@ -144,12 +149,14 @@ def setRank(self, value): self._paramMap[self.rank] = value return self + @since("1.4.0") def getRank(self): """ Gets the value of rank or its default value. """ return self.getOrDefault(self.rank) + @since("1.4.0") def setNumUserBlocks(self, value): """ Sets the value of :py:attr:`numUserBlocks`. @@ -157,12 +164,14 @@ def setNumUserBlocks(self, value): self._paramMap[self.numUserBlocks] = value return self + @since("1.4.0") def getNumUserBlocks(self): """ Gets the value of numUserBlocks or its default value. """ return self.getOrDefault(self.numUserBlocks) + @since("1.4.0") def setNumItemBlocks(self, value): """ Sets the value of :py:attr:`numItemBlocks`. @@ -170,12 +179,14 @@ def setNumItemBlocks(self, value): self._paramMap[self.numItemBlocks] = value return self + @since("1.4.0") def getNumItemBlocks(self): """ Gets the value of numItemBlocks or its default value. """ return self.getOrDefault(self.numItemBlocks) + @since("1.4.0") def setNumBlocks(self, value): """ Sets both :py:attr:`numUserBlocks` and :py:attr:`numItemBlocks` to the specific value. @@ -183,6 +194,7 @@ def setNumBlocks(self, value): self._paramMap[self.numUserBlocks] = value self._paramMap[self.numItemBlocks] = value + @since("1.4.0") def setImplicitPrefs(self, value): """ Sets the value of :py:attr:`implicitPrefs`. @@ -190,12 +202,14 @@ def setImplicitPrefs(self, value): self._paramMap[self.implicitPrefs] = value return self + @since("1.4.0") def getImplicitPrefs(self): """ Gets the value of implicitPrefs or its default value. """ return self.getOrDefault(self.implicitPrefs) + @since("1.4.0") def setAlpha(self, value): """ Sets the value of :py:attr:`alpha`. @@ -203,12 +217,14 @@ def setAlpha(self, value): self._paramMap[self.alpha] = value return self + @since("1.4.0") def getAlpha(self): """ Gets the value of alpha or its default value. """ return self.getOrDefault(self.alpha) + @since("1.4.0") def setUserCol(self, value): """ Sets the value of :py:attr:`userCol`. 
@@ -216,12 +232,14 @@ def setUserCol(self, value): self._paramMap[self.userCol] = value return self + @since("1.4.0") def getUserCol(self): """ Gets the value of userCol or its default value. """ return self.getOrDefault(self.userCol) + @since("1.4.0") def setItemCol(self, value): """ Sets the value of :py:attr:`itemCol`. @@ -229,12 +247,14 @@ def setItemCol(self, value): self._paramMap[self.itemCol] = value return self + @since("1.4.0") def getItemCol(self): """ Gets the value of itemCol or its default value. """ return self.getOrDefault(self.itemCol) + @since("1.4.0") def setRatingCol(self, value): """ Sets the value of :py:attr:`ratingCol`. @@ -242,12 +262,14 @@ def setRatingCol(self, value): self._paramMap[self.ratingCol] = value return self + @since("1.4.0") def getRatingCol(self): """ Gets the value of ratingCol or its default value. """ return self.getOrDefault(self.ratingCol) + @since("1.4.0") def setNonnegative(self, value): """ Sets the value of :py:attr:`nonnegative`. @@ -255,6 +277,7 @@ def setNonnegative(self, value): self._paramMap[self.nonnegative] = value return self + @since("1.4.0") def getNonnegative(self): """ Gets the value of nonnegative or its default value. @@ -265,14 +288,18 @@ def getNonnegative(self): class ALSModel(JavaModel): """ Model fitted by ALS. + + .. versionadded:: 1.4.0 """ @property + @since("1.4.0") def rank(self): """rank of the matrix factorization model""" return self._call_java("rank") @property + @since("1.4.0") def userFactors(self): """ a DataFrame that stores user factors in two columns: `id` and @@ -281,6 +308,7 @@ def userFactors(self): return self._call_java("userFactors") @property + @since("1.4.0") def itemFactors(self): """ a DataFrame that stores item factors in two columns: `id` and From e51345e1e04e439827a07c95887d14ba38333057 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Thu, 17 Sep 2015 09:17:43 -0700 Subject: [PATCH 0661/1824] [SPARK-10077] [DOCS] [ML] Add package info for java of ml/feature Should be the same as SPARK-7808 but use Java for the code example. It would be great to add package doc for `spark.ml.feature`. Author: Holden Karau Closes #8740 from holdenk/SPARK-10077-JAVA-PACKAGE-DOC-FOR-SPARK.ML.FEATURE. --- .../apache/spark/ml/feature/package-info.java | 108 ++++++++++++++++++ 1 file changed, 108 insertions(+) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/feature/package-info.java diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/package-info.java b/mllib/src/main/scala/org/apache/spark/ml/feature/package-info.java new file mode 100644 index 000000000000..c22d2e0cd2d9 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/package-info.java @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
    + */ + + +/** + * Feature transformers + * + * The `ml.feature` package provides common feature transformers that help convert raw data or + * features into more suitable forms for model fitting. + * Most feature transformers are implemented as {@link org.apache.spark.ml.Transformer}s, which + * transform one {@link org.apache.spark.sql.DataFrame} into another, e.g., + * {@link org.apache.spark.ml.feature.HashingTF}. + * Some feature transformers are implemented as {@link org.apache.spark.ml.Estimator}s, because the + * transformation requires some aggregated information of the dataset, e.g., document + * frequencies in {@link org.apache.spark.ml.feature.IDF}. + * For those feature transformers, calling {@link org.apache.spark.ml.Estimator#fit} is required to + * obtain the model first, e.g., {@link org.apache.spark.ml.feature.IDFModel}, in order to apply + * the transformation. + * The transformation is usually done by appending new columns to the input + * {@link org.apache.spark.sql.DataFrame}, so all input columns are carried over. + * + * We try to make each transformer minimal, so it becomes flexible to assemble feature + * transformation pipelines. + * {@link org.apache.spark.ml.Pipeline} can be used to chain feature transformers, and + * {@link org.apache.spark.ml.feature.VectorAssembler} can be used to combine multiple feature + * transformations, for example: + * + * <pre>
    + * <code>
    + *   import java.util.Arrays;
    + *
    + *   import org.apache.spark.api.java.JavaRDD;
    + *   import static org.apache.spark.sql.types.DataTypes.*;
    + *   import org.apache.spark.sql.types.StructType;
    + *   import org.apache.spark.sql.DataFrame;
    + *   import org.apache.spark.sql.RowFactory;
    + *   import org.apache.spark.sql.Row;
    + *
    + *   import org.apache.spark.ml.feature.*;
    + *   import org.apache.spark.ml.Pipeline;
    + *   import org.apache.spark.ml.PipelineStage;
    + *   import org.apache.spark.ml.PipelineModel;
    + *
    + *  // a DataFrame with three columns: id (integer), text (string), and rating (double).
    + *  StructType schema = createStructType(
    + *    Arrays.asList(
    + *      createStructField("id", IntegerType, false),
    + *      createStructField("text", StringType, false),
    + *      createStructField("rating", DoubleType, false)));
    + *  JavaRDD<Row> rowRDD = jsc.parallelize(
    + *    Arrays.asList(
    + *      RowFactory.create(0, "Hi I heard about Spark", 3.0),
    + *      RowFactory.create(1, "I wish Java could use case classes", 4.0),
    + *      RowFactory.create(2, "Logistic regression models are neat", 4.0)));
    + *  DataFrame df = jsql.createDataFrame(rowRDD, schema);
    + *  // define feature transformers
    + *  RegexTokenizer tok = new RegexTokenizer()
    + *    .setInputCol("text")
    + *    .setOutputCol("words");
    + *  StopWordsRemover sw = new StopWordsRemover()
    + *    .setInputCol("words")
    + *    .setOutputCol("filtered_words");
    + *  HashingTF tf = new HashingTF()
    + *    .setInputCol("filtered_words")
    + *    .setOutputCol("tf")
    + *    .setNumFeatures(10000);
    + *  IDF idf = new IDF()
    + *    .setInputCol("tf")
    + *    .setOutputCol("tf_idf");
    + *  VectorAssembler assembler = new VectorAssembler()
    + *    .setInputCols(new String[] {"tf_idf", "rating"})
    + *    .setOutputCol("features");
    + *
    + *  // assemble and fit the feature transformation pipeline
    + *  Pipeline pipeline = new Pipeline()
    + *    .setStages(new PipelineStage[] {tok, sw, tf, idf, assembler});
    + *  PipelineModel model = pipeline.fit(df);
    + *
    + *  // save transformed features with raw data
    + *  model.transform(df)
    + *    .select("id", "text", "rating", "features")
    + *    .write().format("parquet").save("/output/path");
    + * </code>
    + * </pre>
    + * + * Some feature transformers implemented in MLlib are inspired by those implemented in scikit-learn. + * The major difference is that most scikit-learn feature transformers operate eagerly on the entire + * input dataset, while MLlib's feature transformers operate lazily on individual columns, + * which is more efficient and flexible to handle large and complex datasets. + * + * @see + * scikit-learn.preprocessing + */ +package org.apache.spark.ml.feature; From 2a508df20d03b3d4a3c05b65fb02d849bc080ef9 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 17 Sep 2015 09:21:21 -0700 Subject: [PATCH 0662/1824] [SPARK-10459] [SQL] Do not need to have ConvertToSafe for PythonUDF JIRA: https://issues.apache.org/jira/browse/SPARK-10459 As mentioned in the JIRA, `PythonUDF` actually could process `UnsafeRow`. Specially, the rows in `childResults` in `BatchPythonEvaluation` will be projected to a `MutableRow`. So I think we can enable `canProcessUnsafeRows` for `BatchPythonEvaluation` and get rid of redundant `ConvertToSafe`. Author: Liang-Chi Hsieh Closes #8616 from viirya/pyudf-unsafe. --- .../scala/org/apache/spark/sql/execution/pythonUDFs.scala | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala index 5a58d846ad80..d0411da6fdf5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala @@ -337,6 +337,10 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child: def children: Seq[SparkPlan] = child :: Nil + override def outputsUnsafeRows: Boolean = false + override def canProcessUnsafeRows: Boolean = true + override def canProcessSafeRows: Boolean = true + protected override def doExecute(): RDD[InternalRow] = { val childResults = child.execute().map(_.copy()) From c88bb5df94f9696677c3a429472114bc66f32a52 Mon Sep 17 00:00:00 2001 From: "yangping.wu" Date: Thu, 17 Sep 2015 09:52:40 -0700 Subject: [PATCH 0663/1824] [SPARK-10660] Doc describe error in the "Running Spark on YARN" page MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit In the Configuration section, the **spark.yarn.driver.memoryOverhead** and **spark.yarn.am.memoryOverhead**‘s default value should be "driverMemory * 0.10, with minimum of 384" and "AM memory * 0.10, with minimum of 384" respectively. Because from Spark 1.4.0, the **MEMORY_OVERHEAD_FACTOR** is set to 0.1.0, not 0.07. Author: yangping.wu Closes #8797 from 397090770/SparkOnYarnDocError. --- docs/running-on-yarn.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index d1244323edff..3a961d245f3d 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -211,14 +211,14 @@ If you need a reference to the proper location to put log files in the YARN so t spark.yarn.driver.memoryOverhead - driverMemory * 0.07, with minimum of 384 + driverMemory * 0.10, with minimum of 384 The amount of off heap memory (in megabytes) to be allocated per driver in cluster mode. This is memory that accounts for things like VM overheads, interned strings, other native overheads, etc. This tends to grow with the container size (typically 6-10%). 
spark.yarn.am.memoryOverhead - AM memory * 0.07, with minimum of 384 + AM memory * 0.10, with minimum of 384 Same as spark.yarn.driver.memoryOverhead, but for the Application Master in client mode. From 136c77d8bbf48f7c45dd7c3fbe261a0476f455fe Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 17 Sep 2015 10:02:15 -0700 Subject: [PATCH 0664/1824] [SPARK-10642] [PYSPARK] Fix crash when calling rdd.lookup() on tuple keys JIRA: https://issues.apache.org/jira/browse/SPARK-10642 When calling `rdd.lookup()` on a RDD with tuple keys, `portable_hash` will return a long. That causes `DAGScheduler.submitJob` to throw `java.lang.ClassCastException: java.lang.Long cannot be cast to java.lang.Integer`. Author: Liang-Chi Hsieh Closes #8796 from viirya/fix-pyrdd-lookup. --- python/pyspark/rdd.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 9ef60a7e2c84..ab5aab1e115f 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -84,7 +84,7 @@ def portable_hash(x): h ^= len(x) if h == -1: h = -2 - return h + return int(h) return hash(x) @@ -2192,6 +2192,9 @@ def lookup(self, key): [42] >>> sorted.lookup(1024) [] + >>> rdd2 = sc.parallelize([(('a', 'b'), 'c')]).groupByKey() + >>> list(rdd2.lookup(('a', 'b'))[0]) + ['c'] """ values = self.filter(lambda kv: kv[0] == key).values() From 81b4db374dd61b6f1c30511c70b6ab2a52c68faa Mon Sep 17 00:00:00 2001 From: Josiah Samuel Date: Thu, 17 Sep 2015 10:18:21 -0700 Subject: [PATCH 0665/1824] [SPARK-10172] [CORE] disable sort in HistoryServer webUI This pull request is to address the JIRA SPARK-10172 (History Server web UI gets messed up when sorting on any column). The content of the table gets messed up due to the rowspan attribute of the table data(cell) during sorting. The current table sort library used in SparkUI (sorttable.js) doesn't support/handle cells(td) with rowspans. The fix will disable the table sort in the web UI, when there are jobs listed with multiple attempts. Author: Josiah Samuel Closes #8506 from josiahsams/SPARK-10172. --- .../scala/org/apache/spark/deploy/history/HistoryPage.scala | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala index 0830cc1ba124..b347cb3be69f 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala @@ -51,7 +51,10 @@ private[history] class HistoryPage(parent: HistoryServer) extends WebUIPage("") val hasMultipleAttempts = appsToShow.exists(_.attempts.size > 1) val appTable = if (hasMultipleAttempts) { - UIUtils.listingTable(appWithAttemptHeader, appWithAttemptRow, appsToShow) + // Sorting is disable here as table sort on rowspan has issues. + // ref. SPARK-10172 + UIUtils.listingTable(appWithAttemptHeader, appWithAttemptRow, + appsToShow, sortable = false) } else { UIUtils.listingTable(appHeader, appRow, appsToShow) } From 36d8b278d82e788bf583e8438fac524d0023311d Mon Sep 17 00:00:00 2001 From: Jeff Zhang Date: Thu, 17 Sep 2015 10:25:18 -0700 Subject: [PATCH 0666/1824] [SPARK-10531] [CORE] AppId is set as AppName in status rest api Verify it manually. Author: Jeff Zhang Closes #8688 from zjffdu/SPARK-10531. 
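For context on the SPARK-10642 lookup fix a few patches above: the sketch below illustrates the type drift that the added `int()` cast removes. Under Python 2 semantics the tuple-hash arithmetic can promote the hash to `long`, `long % int` stays `long`, and Py4J then hands the JVM a `java.lang.Long` where `DAGScheduler.submitJob` expects an `Integer`. This is a standalone, illustrative snippet, not part of the patch: the key `('a', 'b')` and partition count are made up, and the loop is a trimmed-down copy of `portable_hash`'s tuple branch.

```python
import sys

# Trimmed-down tuple branch of portable_hash, hashing the key ('a', 'b').
h = 0x345678
for element_hash in (hash('a'), hash('b')):
    h ^= element_hash
    h *= 1000003        # on Python 2 this multiplication can overflow int into long
    h &= sys.maxsize
h ^= 2                  # length of the tuple

num_partitions = 16
# Without the cast the partition index is typically a Python 2 long,
# which Py4J maps to java.lang.Long on the JVM side.
print(type(h % num_partitions))
# With the cast it is a plain int, which maps to java.lang.Integer.
print(type(int(h) % num_partitions))
```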
--- .../main/scala/org/apache/spark/SparkContext.scala | 1 + .../spark/deploy/history/FsHistoryProvider.scala | 9 ++++----- .../scala/org/apache/spark/deploy/master/Master.scala | 2 +- core/src/main/scala/org/apache/spark/ui/SparkUI.scala | 11 ++++++----- .../scala/org/apache/spark/ui/UISeleniumSuite.scala | 2 +- 5 files changed, 13 insertions(+), 12 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index a2f34eafa2c3..9c3218719f7f 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -521,6 +521,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli _applicationId = _taskScheduler.applicationId() _applicationAttemptId = taskScheduler.applicationAttemptId() _conf.set("spark.app.id", _applicationId) + _ui.foreach(_.setAppId(_applicationId)) _env.blockManager.initialize(_applicationId) // The metrics system for Driver need to be set spark.app.id to app ID. diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index a5755eac3639..8eb2ba1e8683 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -146,16 +146,15 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) val ui = { val conf = this.conf.clone() val appSecManager = new SecurityManager(conf) - SparkUI.createHistoryUI(conf, replayBus, appSecManager, appId, + SparkUI.createHistoryUI(conf, replayBus, appSecManager, appInfo.name, HistoryServer.getAttemptURI(appId, attempt.attemptId), attempt.startTime) // Do not call ui.bind() to avoid creating a new server for each application } val appListener = new ApplicationEventListener() replayBus.addListener(appListener) - val appInfo = replay(fs.getFileStatus(new Path(logDir, attempt.logPath)), replayBus) - appInfo.map { info => - ui.setAppName(s"${info.name} ($appId)") - + val appAttemptInfo = replay(fs.getFileStatus(new Path(logDir, attempt.logPath)), + replayBus) + appAttemptInfo.map { info => val uiAclsEnabled = conf.getBoolean("spark.history.ui.acls.enable", false) ui.getSecurityManager.setAcls(uiAclsEnabled) // make sure to set admin acls before view acls so they are properly picked up diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index 26904d39a9be..d518e92133aa 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -944,7 +944,7 @@ private[deploy] class Master( val logInput = EventLoggingListener.openEventLog(new Path(eventLogFile), fs) val replayBus = new ReplayListenerBus() val ui = SparkUI.createHistoryUI(new SparkConf, replayBus, new SecurityManager(conf), - appName + status, HistoryServer.UI_PATH_PREFIX + s"/${app.id}", app.startTime) + appName, HistoryServer.UI_PATH_PREFIX + s"/${app.id}", app.startTime) val maybeTruncated = eventLogFile.endsWith(EventLoggingListener.IN_PROGRESS) try { replayBus.replay(logInput, eventLogFile, maybeTruncated) diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala index d8b90568b7b9..99085ada9f0a 100644 --- 
a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala @@ -56,6 +56,8 @@ private[spark] class SparkUI private ( val stagesTab = new StagesTab(this) + var appId: String = _ + /** Initialize all components of the server. */ def initialize() { attachTab(new JobsTab(this)) @@ -75,9 +77,8 @@ private[spark] class SparkUI private ( def getAppName: String = appName - /** Set the app name for this UI. */ - def setAppName(name: String) { - appName = name + def setAppId(id: String): Unit = { + appId = id } /** Stop the server behind this web interface. Only valid after bind(). */ @@ -94,12 +95,12 @@ private[spark] class SparkUI private ( private[spark] def appUIAddress = s"http://$appUIHostPort" def getSparkUI(appId: String): Option[SparkUI] = { - if (appId == appName) Some(this) else None + if (appId == this.appId) Some(this) else None } def getApplicationInfoList: Iterator[ApplicationInfo] = { Iterator(new ApplicationInfo( - id = appName, + id = appId, name = appName, attempts = Seq(new ApplicationAttemptInfo( attemptId = None, diff --git a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala index 22e30ecaf053..18eec7da9763 100644 --- a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala @@ -658,6 +658,6 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B } def apiUrl(ui: SparkUI, path: String): URL = { - new URL(ui.appUIAddress + "/api/v1/applications/test/" + path) + new URL(ui.appUIAddress + "/api/v1/applications/" + ui.sc.get.applicationId + "/" + path) } } From e0dc2bc232206d2f4da4278502c1f88babc8b55a Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Thu, 17 Sep 2015 11:05:30 -0700 Subject: [PATCH 0667/1824] [SPARK-10650] Clean before building docs The [published docs for 1.5.0](http://spark.apache.org/docs/1.5.0/api/java/org/apache/spark/streaming/) have a bunch of test classes in them. The only way I can reproduce this is to `test:compile` before running `unidoc`. To prevent this from happening again, I've added a clean before doc generation. Author: Michael Armbrust Closes #8787 from marmbrus/testsInDocs. --- docs/_plugins/copy_api_dirs.rb | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/docs/_plugins/copy_api_dirs.rb b/docs/_plugins/copy_api_dirs.rb index 15ceda11a8a8..01718d98dffe 100644 --- a/docs/_plugins/copy_api_dirs.rb +++ b/docs/_plugins/copy_api_dirs.rb @@ -26,12 +26,15 @@ curr_dir = pwd cd("..") - puts "Running 'build/sbt -Pkinesis-asl compile unidoc' from " + pwd + "; this may take a few minutes..." - puts `build/sbt -Pkinesis-asl compile unidoc` + puts "Running 'build/sbt -Pkinesis-asl clean compile unidoc' from " + pwd + "; this may take a few minutes..." + puts `build/sbt -Pkinesis-asl clean compile unidoc` puts "Moving back into docs dir." cd("docs") + puts "Removing old docs" + puts `rm -rf api` + # Copy over the unified ScalaDoc for all projects to api/scala. # This directory will be copied over to _site when `jekyll` command is run. source = "../target/scala-2.10/unidoc" From aad644fbe29151aec9004817d42e4928bdb326f3 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Thu, 17 Sep 2015 11:14:52 -0700 Subject: [PATCH 0668/1824] [SPARK-10639] [SQL] Need to convert UDAF's result from scala to sql type https://issues.apache.org/jira/browse/SPARK-10639 Author: Yin Huai Closes #8788 from yhuai/udafConversion. 
--- .../sql/catalyst/CatalystTypeConverters.scala | 7 +- .../spark/sql/RandomDataGenerator.scala | 16 ++- .../spark/sql/execution/aggregate/udaf.scala | 37 +++++- .../org/apache/spark/sql/QueryTest.scala | 21 ++-- .../spark/sql/UserDefinedTypeSuite.scala | 11 ++ .../execution/AggregationQuerySuite.scala | 108 +++++++++++++++++- 6 files changed, 188 insertions(+), 12 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala index 966623ed017b..f25591794abd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala @@ -138,8 +138,13 @@ object CatalystTypeConverters { private case class UDTConverter( udt: UserDefinedType[_]) extends CatalystTypeConverter[Any, Any, Any] { + // toCatalyst (it calls toCatalystImpl) will do null check. override def toCatalystImpl(scalaValue: Any): Any = udt.serialize(scalaValue) - override def toScala(catalystValue: Any): Any = udt.deserialize(catalystValue) + + override def toScala(catalystValue: Any): Any = { + if (catalystValue == null) null else udt.deserialize(catalystValue) + } + override def toScalaImpl(row: InternalRow, column: Int): Any = toScala(row.get(column, udt.sqlType)) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala index 4025cbcec101..e48395028e39 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala @@ -108,7 +108,21 @@ object RandomDataGenerator { arr }) case BooleanType => Some(() => rand.nextBoolean()) - case DateType => Some(() => new java.sql.Date(rand.nextInt())) + case DateType => + val generator = + () => { + var milliseconds = rand.nextLong() % 253402329599999L + // -62135740800000L is the number of milliseconds before January 1, 1970, 00:00:00 GMT + // for "0001-01-01 00:00:00.000000". We need to find a + // number that is greater or equals to this number as a valid timestamp value. + while (milliseconds < -62135740800000L) { + // 253402329599999L is the the number of milliseconds since + // January 1, 1970, 00:00:00 GMT for "9999-12-31 23:59:59.999999". 
+ milliseconds = rand.nextLong() % 253402329599999L + } + DateTimeUtils.toJavaDate((milliseconds / DateTimeUtils.MILLIS_PER_DAY).toInt) + } + Some(generator) case TimestampType => val generator = () => { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala index d43d3dd9ffaa..1114fe6552bd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala @@ -40,6 +40,9 @@ sealed trait BufferSetterGetterUtils { var i = 0 while (i < getters.length) { getters(i) = dataTypes(i) match { + case NullType => + (row: InternalRow, ordinal: Int) => null + case BooleanType => (row: InternalRow, ordinal: Int) => if (row.isNullAt(ordinal)) null else row.getBoolean(ordinal) @@ -74,6 +77,14 @@ sealed trait BufferSetterGetterUtils { (row: InternalRow, ordinal: Int) => if (row.isNullAt(ordinal)) null else row.getDecimal(ordinal, precision, scale) + case DateType => + (row: InternalRow, ordinal: Int) => + if (row.isNullAt(ordinal)) null else row.getInt(ordinal) + + case TimestampType => + (row: InternalRow, ordinal: Int) => + if (row.isNullAt(ordinal)) null else row.getLong(ordinal) + case other => (row: InternalRow, ordinal: Int) => if (row.isNullAt(ordinal)) null else row.get(ordinal, other) @@ -92,6 +103,9 @@ sealed trait BufferSetterGetterUtils { var i = 0 while (i < setters.length) { setters(i) = dataTypes(i) match { + case NullType => + (row: MutableRow, ordinal: Int, value: Any) => row.setNullAt(ordinal) + case b: BooleanType => (row: MutableRow, ordinal: Int, value: Any) => if (value != null) { @@ -150,9 +164,23 @@ sealed trait BufferSetterGetterUtils { case dt: DecimalType => val precision = dt.precision + (row: MutableRow, ordinal: Int, value: Any) => + // To make it work with UnsafeRow, we cannot use setNullAt. + // Please see the comment of UnsafeRow's setDecimal. + row.setDecimal(ordinal, value.asInstanceOf[Decimal], precision) + + case DateType => (row: MutableRow, ordinal: Int, value: Any) => if (value != null) { - row.setDecimal(ordinal, value.asInstanceOf[Decimal], precision) + row.setInt(ordinal, value.asInstanceOf[Int]) + } else { + row.setNullAt(ordinal) + } + + case TimestampType => + (row: MutableRow, ordinal: Int, value: Any) => + if (value != null) { + row.setLong(ordinal, value.asInstanceOf[Long]) } else { row.setNullAt(ordinal) } @@ -205,6 +233,7 @@ private[sql] class MutableAggregationBufferImpl ( throw new IllegalArgumentException( s"Could not access ${i}th value in this buffer because it only has $length values.") } + toScalaConverters(i)(bufferValueGetters(i)(underlyingBuffer, offsets(i))) } @@ -352,6 +381,10 @@ private[sql] case class ScalaUDAF( } } + private[this] lazy val outputToCatalystConverter: Any => Any = { + CatalystTypeConverters.createToCatalystConverter(dataType) + } + // This buffer is only used at executor side. 
private[this] var inputAggregateBuffer: InputAggregationBuffer = null @@ -424,7 +457,7 @@ private[sql] case class ScalaUDAF( override def eval(buffer: InternalRow): Any = { evalAggregateBuffer.underlyingInputBuffer = buffer - udaf.evaluate(evalAggregateBuffer) + outputToCatalystConverter(udaf.evaluate(evalAggregateBuffer)) } override def toString: String = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index cada03e9ac6b..e3c5a426671d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -115,19 +115,26 @@ object QueryTest { */ def checkAnswer(df: DataFrame, expectedAnswer: Seq[Row]): Option[String] = { val isSorted = df.logicalPlan.collect { case s: logical.Sort => s }.nonEmpty + + // We need to call prepareRow recursively to handle schemas with struct types. + def prepareRow(row: Row): Row = { + Row.fromSeq(row.toSeq.map { + case null => null + case d: java.math.BigDecimal => BigDecimal(d) + // Convert array to Seq for easy equality check. + case b: Array[_] => b.toSeq + case r: Row => prepareRow(r) + case o => o + }) + } + def prepareAnswer(answer: Seq[Row]): Seq[Row] = { // Converts data to types that we can do equality comparison using Scala collections. // For BigDecimal type, the Scala type has a better definition of equality test (similar to // Java's java.math.BigDecimal.compareTo). // For binary arrays, we convert it to Seq to avoid of calling java.util.Arrays.equals for // equality test. - val converted: Seq[Row] = answer.map { s => - Row.fromSeq(s.toSeq.map { - case d: java.math.BigDecimal => BigDecimal(d) - case b: Array[Byte] => b.toSeq - case o => o - }) - } + val converted: Seq[Row] = answer.map(prepareRow) if (!isSorted) converted.sortBy(_.toString()) else converted } val sparkAnswer = try df.collect().toSeq catch { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala index 46d87843dfa4..7992fd59ff4b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala @@ -22,6 +22,7 @@ import scala.beans.{BeanInfo, BeanProperty} import com.clearspring.analytics.stream.cardinality.HyperLogLog import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.catalyst.expressions.{OpenHashSetUDT, HyperLogLogUDT} import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext @@ -163,4 +164,14 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext { assert(new MyDenseVectorUDT().typeName === "mydensevector") assert(new OpenHashSetUDT(IntegerType).typeName === "openhashset") } + + test("Catalyst type converter null handling for UDTs") { + val udt = new MyDenseVectorUDT() + val toScalaConverter = CatalystTypeConverters.createToScalaConverter(udt) + assert(toScalaConverter(null) === null) + + val toCatalystConverter = CatalystTypeConverters.createToCatalystConverter(udt) + assert(toCatalystConverter(null) === null) + + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala index a73b1bd52c09..24b1846923c7 100644 --- 
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala @@ -17,13 +17,55 @@ package org.apache.spark.sql.hive.execution +import scala.collection.JavaConverters._ + import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.execution.aggregate +import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction} +import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SQLTestUtils -import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} +import org.apache.spark.sql.types._ import org.apache.spark.sql.hive.aggregate.{MyDoubleAvg, MyDoubleSum} import org.apache.spark.sql.hive.test.TestHiveSingleton +class ScalaAggregateFunction(schema: StructType) extends UserDefinedAggregateFunction { + + def inputSchema: StructType = schema + + def bufferSchema: StructType = schema + + def dataType: DataType = schema + + def deterministic: Boolean = true + + def initialize(buffer: MutableAggregationBuffer): Unit = { + (0 until schema.length).foreach { i => + buffer.update(i, null) + } + } + + def update(buffer: MutableAggregationBuffer, input: Row): Unit = { + if (!input.isNullAt(0) && input.getInt(0) == 50) { + (0 until schema.length).foreach { i => + buffer.update(i, input.get(i)) + } + } + } + + def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = { + if (!buffer2.isNullAt(0) && buffer2.getInt(0) == 50) { + (0 until schema.length).foreach { i => + buffer1.update(i, buffer2.get(i)) + } + } + } + + def evaluate(buffer: Row): Any = { + Row.fromSeq(buffer.toSeq) + } +} + abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { import testImplicits._ @@ -508,6 +550,70 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te assert(errorMessage.contains("implemented based on the new Aggregate Function interface")) } } + + test("udaf with all data types") { + val struct = + StructType( + StructField("f1", FloatType, true) :: + StructField("f2", ArrayType(BooleanType), true) :: Nil) + val dataTypes = Seq(StringType, BinaryType, NullType, BooleanType, + ByteType, ShortType, IntegerType, LongType, + FloatType, DoubleType, DecimalType(25, 5), DecimalType(6, 5), + DateType, TimestampType, + ArrayType(IntegerType), MapType(StringType, LongType), struct, + new MyDenseVectorUDT()) + // Right now, we will use SortBasedAggregate to handle UDAFs. + // UnsafeRow.mutableFieldTypes.asScala.toSeq will trigger SortBasedAggregate to use + // UnsafeRow as the aggregation buffer. While, dataTypes will trigger + // SortBasedAggregate to use a safe row as the aggregation buffer. + Seq(dataTypes, UnsafeRow.mutableFieldTypes.asScala.toSeq).foreach { dataTypes => + val fields = dataTypes.zipWithIndex.map { case (dataType, index) => + StructField(s"col$index", dataType, nullable = true) + } + // The schema used for data generator. + val schemaForGenerator = StructType(fields) + // The schema used for the DataFrame df. + val schema = StructType(StructField("id", IntegerType) +: fields) + + logInfo(s"Testing schema: ${schema.treeString}") + + val udaf = new ScalaAggregateFunction(schema) + // Generate data at the driver side. We need to materialize the data first and then + // create RDD. 
+ val maybeDataGenerator = + RandomDataGenerator.forType( + dataType = schemaForGenerator, + nullable = true, + seed = Some(System.nanoTime())) + val dataGenerator = + maybeDataGenerator + .getOrElse(fail(s"Failed to create data generator for schema $schemaForGenerator")) + val data = (1 to 50).map { i => + dataGenerator.apply() match { + case row: Row => Row.fromSeq(i +: row.toSeq) + case null => Row.fromSeq(i +: Seq.fill(schemaForGenerator.length)(null)) + case other => + fail(s"Row or null is expected to be generated, " + + s"but a ${other.getClass.getCanonicalName} is generated.") + } + } + + // Create a DF for the schema with random data. + val rdd = sqlContext.sparkContext.parallelize(data, 1) + val df = sqlContext.createDataFrame(rdd, schema) + + val allColumns = df.schema.fields.map(f => col(f.name)) + val expectedAnaswer = + data + .find(r => r.getInt(0) == 50) + .getOrElse(fail("A row with id 50 should be the expected answer.")) + checkAnswer( + df.groupBy().agg(udaf(allColumns: _*)), + // udaf returns a Row as the output value. + Row(expectedAnaswer) + ) + } + } } class SortBasedAggregationQuerySuite extends AggregationQuerySuite { From 64743870f23bffb8d96dcc8a0181c1452782a151 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Thu, 17 Sep 2015 11:24:38 -0700 Subject: [PATCH 0669/1824] [SPARK-10394] [ML] Make GBTParams use shared stepSize ```GBTParams``` has ```stepSize``` as learning rate currently. ML has shared param class ```HasStepSize```, ```GBTParams``` can extend from it rather than duplicated implementation. Author: Yanbo Liang Closes #8552 from yanboliang/spark-10394. --- .../org/apache/spark/ml/tree/treeParams.scala | 28 +++++++++---------- 1 file changed, 13 insertions(+), 15 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala index d29f5253c9c3..42e74ce6d2c6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala @@ -20,7 +20,7 @@ package org.apache.spark.ml.tree import org.apache.spark.ml.classification.ClassifierParams import org.apache.spark.ml.PredictorParams import org.apache.spark.ml.param._ -import org.apache.spark.ml.param.shared.{HasCheckpointInterval, HasMaxIter, HasSeed, HasThresholds} +import org.apache.spark.ml.param.shared._ import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, BoostingStrategy => OldBoostingStrategy, Strategy => OldStrategy} import org.apache.spark.mllib.tree.impurity.{Entropy => OldEntropy, Gini => OldGini, Impurity => OldImpurity, Variance => OldVariance} import org.apache.spark.mllib.tree.loss.{Loss => OldLoss} @@ -365,17 +365,7 @@ private[ml] object RandomForestParams { * * Note: Marked as private and DeveloperApi since this may be made public in the future. */ -private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter { - - /** - * Step size (a.k.a. learning rate) in interval (0, 1] for shrinking the contribution of each - * estimator. - * (default = 0.1) - * @group param - */ - final val stepSize: DoubleParam = new DoubleParam(this, "stepSize", "Step size (a.k.a." + - " learning rate) in interval (0, 1] for shrinking the contribution of each estimator", - ParamValidators.inRange(0, 1, lowerInclusive = false, upperInclusive = true)) +private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter with HasStepSize { /* TODO: Add this doc when we add this param. 
SPARK-7132 * Threshold for stopping early when runWithValidation is used. @@ -393,11 +383,19 @@ private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter { /** @group setParam */ def setMaxIter(value: Int): this.type = set(maxIter, value) - /** @group setParam */ + /** + * Step size (a.k.a. learning rate) in interval (0, 1] for shrinking the contribution of each + * estimator. + * (default = 0.1) + * @group setParam + */ def setStepSize(value: Double): this.type = set(stepSize, value) - /** @group getParam */ - final def getStepSize: Double = $(stepSize) + override def validateParams(): Unit = { + require(ParamValidators.inRange(0, 1, lowerInclusive = false, upperInclusive = true)( + getStepSize), "GBT parameter stepSize should be in interval (0, 1], " + + s"but it given invalid value $getStepSize.") + } /** (private[ml]) Create a BoostingStrategy instance to use with the old API. */ private[ml] def getOldBoostingStrategy( From f1c911552cf5d0d60831c79c1881016293aec66c Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 17 Sep 2015 11:40:24 -0700 Subject: [PATCH 0670/1824] [SPARK-10657] Remove SCP-based Jenkins log archiving As of https://issues.apache.org/jira/browse/SPARK-7561, we no longer need to use our custom SCP-based mechanism for archiving Jenkins logs on the master machine; this has been superseded by the use of a Jenkins plugin which archives the logs and provides public links to view them. Per shaneknapp, we should remove this log syncing mechanism if it is no longer necessary; removing the need to SCP from the Jenkins workers to the masters is a desired step as part of some larger Jenkins infra refactoring. Author: Josh Rosen Closes #8793 from JoshRosen/remove-jenkins-ssh-to-master. --- dev/run-tests-jenkins | 35 ----------------------------------- 1 file changed, 35 deletions(-) diff --git a/dev/run-tests-jenkins b/dev/run-tests-jenkins index 3be78575e70f..d3b05fa6df0c 100755 --- a/dev/run-tests-jenkins +++ b/dev/run-tests-jenkins @@ -116,39 +116,6 @@ function post_message () { fi } -function send_archived_logs () { - echo "Archiving unit tests logs..." - - local log_files=$( - find .\ - -name "unit-tests.log" -o\ - -path "./sql/hive/target/HiveCompatibilitySuite.failed" -o\ - -path "./sql/hive/target/HiveCompatibilitySuite.hiveFailed" -o\ - -path "./sql/hive/target/HiveCompatibilitySuite.wrong" - ) - - if [ -z "$log_files" ]; then - echo "> No log files found." >&2 - else - local log_archive="unit-tests-logs.tar.gz" - echo "$log_files" | xargs tar czf ${log_archive} - - local jenkins_build_dir=${JENKINS_HOME}/jobs/${JOB_NAME}/builds/${BUILD_NUMBER} - local scp_output=$(scp ${log_archive} amp-jenkins-master:${jenkins_build_dir}/${log_archive}) - local scp_status="$?" - - if [ "$scp_status" -ne 0 ]; then - echo "Failed to send archived unit tests logs to Jenkins master." >&2 - echo "> scp_status: ${scp_status}" >&2 - echo "> scp_output: ${scp_output}" >&2 - else - echo "> Send successful." - fi - - rm -f ${log_archive} - fi -} - # post start message { start_message="\ @@ -244,8 +211,6 @@ done test_result_note=" * This patch **fails $failing_test**." fi - - send_archived_logs } # post end message From 4fbf3328692e876f39ea78494510f9d9c5a53f15 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Thu, 17 Sep 2015 14:09:06 -0700 Subject: [PATCH 0671/1824] [SPARK-9698] [ML] Add RInteraction transformer for supporting R-style feature interactions This is a pre-req for supporting the ":" operator in the RFormula feature transformer. 
Design doc from umbrella task: https://docs.google.com/document/d/10NZNSEurN2EdWM31uFYsgayIPfCFHiuIu3pCWrUmP_c/edit mengxr Author: Eric Liang Closes #7987 from ericl/interaction. --- .../apache/spark/ml/feature/Interaction.scala | 278 ++++++++++++++++++ .../spark/ml/feature/InteractionSuite.scala | 165 +++++++++++ 2 files changed, 443 insertions(+) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/feature/InteractionSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala new file mode 100644 index 000000000000..9194763fb32f --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala @@ -0,0 +1,278 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import scala.collection.mutable.ArrayBuilder + +import org.apache.spark.SparkException +import org.apache.spark.annotation.Experimental +import org.apache.spark.ml.attribute._ +import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.shared._ +import org.apache.spark.ml.util.Identifiable +import org.apache.spark.ml.Transformer +import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors} +import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types._ + +/** + * :: Experimental :: + * Implements the feature interaction transform. This transformer takes in Double and Vector type + * columns and outputs a flattened vector of their feature interactions. To handle interaction, + * we first one-hot encode any nominal features. Then, a vector of the feature cross-products is + * produced. + * + * For example, given the input feature values `Double(2)` and `Vector(3, 4)`, the output would be + * `Vector(6, 8)` if all input features were numeric. If the first feature was instead nominal + * with four categories, the output would then be `Vector(0, 0, 0, 0, 3, 4, 0, 0)`. 
+ */ +@Experimental +class Interaction(override val uid: String) extends Transformer + with HasInputCols with HasOutputCol { + + def this() = this(Identifiable.randomUID("interaction")) + + /** @group setParam */ + def setInputCols(values: Array[String]): this.type = set(inputCols, values) + + /** @group setParam */ + def setOutputCol(value: String): this.type = set(outputCol, value) + + // optimistic schema; does not contain any ML attributes + override def transformSchema(schema: StructType): StructType = { + validateParams() + StructType(schema.fields :+ StructField($(outputCol), new VectorUDT, false)) + } + + override def transform(dataset: DataFrame): DataFrame = { + validateParams() + val inputFeatures = $(inputCols).map(c => dataset.schema(c)) + val featureEncoders = getFeatureEncoders(inputFeatures) + val featureAttrs = getFeatureAttrs(inputFeatures) + + def interactFunc = udf { row: Row => + var indices = ArrayBuilder.make[Int] + var values = ArrayBuilder.make[Double] + var size = 1 + indices += 0 + values += 1.0 + var featureIndex = row.length - 1 + while (featureIndex >= 0) { + val prevIndices = indices.result() + val prevValues = values.result() + val prevSize = size + val currentEncoder = featureEncoders(featureIndex) + indices = ArrayBuilder.make[Int] + values = ArrayBuilder.make[Double] + size *= currentEncoder.outputSize + currentEncoder.foreachNonzeroOutput(row(featureIndex), (i, a) => { + var j = 0 + while (j < prevIndices.length) { + indices += prevIndices(j) + i * prevSize + values += prevValues(j) * a + j += 1 + } + }) + featureIndex -= 1 + } + Vectors.sparse(size, indices.result(), values.result()).compressed + } + + val featureCols = inputFeatures.map { f => + f.dataType match { + case DoubleType => dataset(f.name) + case _: VectorUDT => dataset(f.name) + case _: NumericType | BooleanType => dataset(f.name).cast(DoubleType) + } + } + dataset.select( + col("*"), + interactFunc(struct(featureCols: _*)).as($(outputCol), featureAttrs.toMetadata())) + } + + /** + * Creates a feature encoder for each input column, which supports efficient iteration over + * one-hot encoded feature values. See also the class-level comment of [[FeatureEncoder]]. + * + * @param features The input feature columns to create encoders for. + */ + private def getFeatureEncoders(features: Seq[StructField]): Array[FeatureEncoder] = { + def getNumFeatures(attr: Attribute): Int = { + attr match { + case nominal: NominalAttribute => + math.max(1, nominal.getNumValues.getOrElse( + throw new SparkException("Nominal features must have attr numValues defined."))) + case _ => + 1 // numeric feature + } + } + features.map { f => + val numFeatures = f.dataType match { + case _: NumericType | BooleanType => + Array(getNumFeatures(Attribute.fromStructField(f))) + case _: VectorUDT => + val attrs = AttributeGroup.fromStructField(f).attributes.getOrElse( + throw new SparkException("Vector attributes must be defined for interaction.")) + attrs.map(getNumFeatures).toArray + } + new FeatureEncoder(numFeatures) + }.toArray + } + + /** + * Generates ML attributes for the output vector of all feature interactions. We make a best + * effort to generate reasonable names for output features, based on the concatenation of the + * interacting feature names and values delimited with `_`. When no feature name is specified, + * we fall back to using the feature index (e.g. `foo:bar_2_0` may indicate an interaction + * between the numeric `foo` feature and a nominal third feature from column `bar`. 
+ * + * @param features The input feature columns to the Interaction transformer. + */ + private def getFeatureAttrs(features: Seq[StructField]): AttributeGroup = { + var featureAttrs: Seq[Attribute] = Nil + features.reverse.foreach { f => + val encodedAttrs = f.dataType match { + case _: NumericType | BooleanType => + val attr = Attribute.fromStructField(f) + encodedFeatureAttrs(Seq(attr), None) + case _: VectorUDT => + val group = AttributeGroup.fromStructField(f) + encodedFeatureAttrs(group.attributes.get, Some(group.name)) + } + if (featureAttrs.isEmpty) { + featureAttrs = encodedAttrs + } else { + featureAttrs = encodedAttrs.flatMap { head => + featureAttrs.map { tail => + NumericAttribute.defaultAttr.withName(head.name.get + ":" + tail.name.get) + } + } + } + } + new AttributeGroup($(outputCol), featureAttrs.toArray) + } + + /** + * Generates the output ML attributes for a single input feature. Each output feature name has + * up to three parts: the group name, feature name, and category name (for nominal features), + * each separated by an underscore. + * + * @param inputAttrs The attributes of the input feature. + * @param groupName Optional name of the input feature group (for Vector type features). + */ + private def encodedFeatureAttrs( + inputAttrs: Seq[Attribute], + groupName: Option[String]): Seq[Attribute] = { + + def format( + index: Int, + attrName: Option[String], + categoryName: Option[String]): String = { + val parts = Seq(groupName, Some(attrName.getOrElse(index.toString)), categoryName) + parts.flatten.mkString("_") + } + + inputAttrs.zipWithIndex.flatMap { + case (nominal: NominalAttribute, i) => + if (nominal.values.isDefined) { + nominal.values.get.map( + v => BinaryAttribute.defaultAttr.withName(format(i, nominal.name, Some(v)))) + } else { + Array.tabulate(nominal.getNumValues.get)( + j => BinaryAttribute.defaultAttr.withName(format(i, nominal.name, Some(j.toString)))) + } + case (a: Attribute, i) => + Seq(NumericAttribute.defaultAttr.withName(format(i, a.name, None))) + } + } + + override def copy(extra: ParamMap): Interaction = defaultCopy(extra) + + override def validateParams(): Unit = { + require(get(inputCols).isDefined, "Input cols must be defined first.") + require(get(outputCol).isDefined, "Output col must be defined first.") + require($(inputCols).length > 0, "Input cols must have non-zero length.") + require($(inputCols).distinct.length == $(inputCols).length, "Input cols must be distinct.") + } +} + +/** + * This class performs on-the-fly one-hot encoding of features as you iterate over them. To + * indicate which input features should be one-hot encoded, an array of the feature counts + * must be passed in ahead of time. + * + * @param numFeatures Array of feature counts for each input feature. For nominal features this + * count is equal to the number of categories. For numeric features the count + * should be set to 1. + */ +private[ml] class FeatureEncoder(numFeatures: Array[Int]) { + assert(numFeatures.forall(_ > 0), "Features counts must all be positive.") + + /** The size of the output vector. */ + val outputSize = numFeatures.sum + + /** Precomputed offsets for the location of each output feature. */ + private val outputOffsets = { + val arr = new Array[Int](numFeatures.length) + var i = 1 + while (i < arr.length) { + arr(i) = arr(i - 1) + numFeatures(i - 1) + i += 1 + } + arr + } + + /** + * Given an input row of features, invokes the specific function for every non-zero output. 
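+ * For example, with `numFeatures = Array(2, 1)` and input `Vectors.dense(1.0, 3.5)`, the
+ * callback is invoked with (1, 1.0) and then (2, 3.5), for an output vector of size 3.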
+ * + * @param value The row value to encode, either a Double or Vector. + * @param f The callback to invoke on each non-zero (index, value) output pair. + */ + def foreachNonzeroOutput(value: Any, f: (Int, Double) => Unit): Unit = value match { + case d: Double => + assert(numFeatures.length == 1, "DoubleType columns should only contain one feature.") + val numOutputCols = numFeatures.head + if (numOutputCols > 1) { + assert( + d >= 0.0 && d == d.toInt && d < numOutputCols, + s"Values from column must be indices, but got $d.") + f(d.toInt, 1.0) + } else { + f(0, d) + } + case vec: Vector => + assert(numFeatures.length == vec.size, + s"Vector column size was ${vec.size}, expected ${numFeatures.length}") + vec.foreachActive { (i, v) => + val numOutputCols = numFeatures(i) + if (numOutputCols > 1) { + assert( + v >= 0.0 && v == v.toInt && v < numOutputCols, + s"Values from column must be indices, but got $v.") + f(outputOffsets(i) + v.toInt, 1.0) + } else { + f(outputOffsets(i), v) + } + } + case null => + throw new SparkException("Values to interact cannot be null.") + case o => + throw new SparkException(s"$o of type ${o.getClass.getName} is not supported.") + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/InteractionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/InteractionSuite.scala new file mode 100644 index 000000000000..2beb62ca0823 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/InteractionSuite.scala @@ -0,0 +1,165 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.feature + +import scala.collection.mutable.ArrayBuilder + +import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.ml.attribute._ +import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.functions.col + +class InteractionSuite extends SparkFunSuite with MLlibTestSparkContext { + test("params") { + ParamsSuite.checkParams(new Interaction()) + } + + test("feature encoder") { + def encode(cardinalities: Array[Int], value: Any): Vector = { + var indices = ArrayBuilder.make[Int] + var values = ArrayBuilder.make[Double] + val encoder = new FeatureEncoder(cardinalities) + encoder.foreachNonzeroOutput(value, (i, v) => { + indices += i + values += v + }) + Vectors.sparse(encoder.outputSize, indices.result(), values.result()).compressed + } + assert(encode(Array(1), 2.2) === Vectors.dense(2.2)) + assert(encode(Array(3), Vectors.dense(1)) === Vectors.dense(0, 1, 0)) + assert(encode(Array(1, 1), Vectors.dense(1.1, 2.2)) === Vectors.dense(1.1, 2.2)) + assert(encode(Array(3, 1), Vectors.dense(1, 2.2)) === Vectors.dense(0, 1, 0, 2.2)) + assert(encode(Array(2, 1), Vectors.dense(1, 2.2)) === Vectors.dense(0, 1, 2.2)) + assert(encode(Array(2, 1, 1), Vectors.dense(0, 2.2, 0)) === Vectors.dense(1, 0, 2.2, 0)) + intercept[SparkException] { encode(Array(1), "foo") } + intercept[SparkException] { encode(Array(1), null) } + intercept[AssertionError] { encode(Array(2), 2.2) } + intercept[AssertionError] { encode(Array(3), Vectors.dense(2.2)) } + intercept[AssertionError] { encode(Array(1), Vectors.dense(1.0, 2.0, 3.0)) } + intercept[AssertionError] { encode(Array(3), Vectors.dense(-1)) } + intercept[AssertionError] { encode(Array(3), Vectors.dense(3)) } + } + + test("numeric interaction") { + val data = sqlContext.createDataFrame( + Seq( + (2, Vectors.dense(3.0, 4.0)), + (1, Vectors.dense(1.0, 5.0))) + ).toDF("a", "b") + val groupAttr = new AttributeGroup( + "b", + Array[Attribute]( + NumericAttribute.defaultAttr.withName("foo"), + NumericAttribute.defaultAttr.withName("bar"))) + val df = data.select( + col("a").as("a", NumericAttribute.defaultAttr.toMetadata()), + col("b").as("b", groupAttr.toMetadata())) + val trans = new Interaction().setInputCols(Array("a", "b")).setOutputCol("features") + val res = trans.transform(df) + val expected = sqlContext.createDataFrame( + Seq( + (2, Vectors.dense(3.0, 4.0), Vectors.dense(6.0, 8.0)), + (1, Vectors.dense(1.0, 5.0), Vectors.dense(1.0, 5.0))) + ).toDF("a", "b", "features") + assert(res.collect() === expected.collect()) + val attrs = AttributeGroup.fromStructField(res.schema("features")) + val expectedAttrs = new AttributeGroup( + "features", + Array[Attribute]( + new NumericAttribute(Some("a:b_foo"), Some(1)), + new NumericAttribute(Some("a:b_bar"), Some(2)))) + assert(attrs === expectedAttrs) + } + + test("nominal interaction") { + val data = sqlContext.createDataFrame( + Seq( + (2, Vectors.dense(3.0, 4.0)), + (1, Vectors.dense(1.0, 5.0))) + ).toDF("a", "b") + val groupAttr = new AttributeGroup( + "b", + Array[Attribute]( + NumericAttribute.defaultAttr.withName("foo"), + NumericAttribute.defaultAttr.withName("bar"))) + val df = data.select( + col("a").as( + "a", NominalAttribute.defaultAttr.withValues(Array("up", "down", "left")).toMetadata()), + col("b").as("b", groupAttr.toMetadata())) + val trans = new Interaction().setInputCols(Array("a", "b")).setOutputCol("features") + val 
res = trans.transform(df) + val expected = sqlContext.createDataFrame( + Seq( + (2, Vectors.dense(3.0, 4.0), Vectors.dense(0, 0, 0, 0, 3, 4)), + (1, Vectors.dense(1.0, 5.0), Vectors.dense(0, 0, 1, 5, 0, 0))) + ).toDF("a", "b", "features") + assert(res.collect() === expected.collect()) + val attrs = AttributeGroup.fromStructField(res.schema("features")) + val expectedAttrs = new AttributeGroup( + "features", + Array[Attribute]( + new NumericAttribute(Some("a_up:b_foo"), Some(1)), + new NumericAttribute(Some("a_up:b_bar"), Some(2)), + new NumericAttribute(Some("a_down:b_foo"), Some(3)), + new NumericAttribute(Some("a_down:b_bar"), Some(4)), + new NumericAttribute(Some("a_left:b_foo"), Some(5)), + new NumericAttribute(Some("a_left:b_bar"), Some(6)))) + assert(attrs === expectedAttrs) + } + + test("default attr names") { + val data = sqlContext.createDataFrame( + Seq( + (2, Vectors.dense(0.0, 4.0), 1.0), + (1, Vectors.dense(1.0, 5.0), 10.0)) + ).toDF("a", "b", "c") + val groupAttr = new AttributeGroup( + "b", + Array[Attribute]( + NominalAttribute.defaultAttr.withNumValues(2), + NumericAttribute.defaultAttr)) + val df = data.select( + col("a").as("a", NominalAttribute.defaultAttr.withNumValues(3).toMetadata()), + col("b").as("b", groupAttr.toMetadata()), + col("c").as("c", NumericAttribute.defaultAttr.toMetadata())) + val trans = new Interaction().setInputCols(Array("a", "b", "c")).setOutputCol("features") + val res = trans.transform(df) + val expected = sqlContext.createDataFrame( + Seq( + (2, Vectors.dense(0.0, 4.0), 1.0, Vectors.dense(0, 0, 0, 0, 0, 0, 1, 0, 4)), + (1, Vectors.dense(1.0, 5.0), 10.0, Vectors.dense(0, 0, 0, 0, 10, 50, 0, 0, 0))) + ).toDF("a", "b", "c", "features") + assert(res.collect() === expected.collect()) + val attrs = AttributeGroup.fromStructField(res.schema("features")) + val expectedAttrs = new AttributeGroup( + "features", + Array[Attribute]( + new NumericAttribute(Some("a_0:b_0_0:c"), Some(1)), + new NumericAttribute(Some("a_0:b_0_1:c"), Some(2)), + new NumericAttribute(Some("a_0:b_1:c"), Some(3)), + new NumericAttribute(Some("a_1:b_0_0:c"), Some(4)), + new NumericAttribute(Some("a_1:b_0_1:c"), Some(5)), + new NumericAttribute(Some("a_1:b_1:c"), Some(6)), + new NumericAttribute(Some("a_2:b_0_0:c"), Some(7)), + new NumericAttribute(Some("a_2:b_0_1:c"), Some(8)), + new NumericAttribute(Some("a_2:b_1:c"), Some(9)))) + assert(attrs === expectedAttrs) + } +} From 0f5ef6dfa67a068606aff8ea9d1addfce73446eb Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Thu, 17 Sep 2015 19:16:34 -0700 Subject: [PATCH 0672/1824] [SPARK-10674] [TESTS] Increase timeouts in SaslIntegrationSuite. 1s seems to trigger too many times on the jenkins build boxes, so increase the timeout and cross fingers. Author: Marcelo Vanzin Closes #8802 from vanzin/SPARK-10674 and squashes the following commits: 3c93117 [Marcelo Vanzin] Use java 7 syntax. d667d1b [Marcelo Vanzin] [SPARK-10674] [tests] Increase timeouts in SaslIntegrationSuite. 
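The fix below simply routes every synchronous RPC call in the suite through one generously sized, named timeout instead of repeated 1000 ms literals. As a rough Scala illustration of the same pattern (hypothetical names, not the actual Spark test code):

object RpcTestTimeouts {
  // Generous bound for slow or overloaded CI machines; normal runs finish far sooner.
  val TimeoutMs: Long = 10000L

  // Every call site reuses the same constant, so future tuning happens in one place.
  def sendSync(rpc: (Array[Byte], Long) => Array[Byte], payload: Array[Byte]): Array[Byte] =
    rpc(payload, TimeoutMs)
}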
--- .../spark/network/sasl/SaslIntegrationSuite.java | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java index 5cb0e4d4a645..c393a5e1e681 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java @@ -56,6 +56,11 @@ import org.apache.spark.network.util.TransportConf; public class SaslIntegrationSuite { + + // Use a long timeout to account for slow / overloaded build machines. In the normal case, + // tests should finish way before the timeout expires. + private final static long TIMEOUT_MS = 10_000; + static TransportServer server; static TransportConf conf; static TransportContext context; @@ -102,7 +107,7 @@ public void testGoodClient() throws IOException { TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); String msg = "Hello, World!"; - byte[] resp = client.sendRpcSync(msg.getBytes(), 1000); + byte[] resp = client.sendRpcSync(msg.getBytes(), TIMEOUT_MS); assertEquals(msg, new String(resp)); // our rpc handler should just return the given msg } @@ -131,7 +136,7 @@ public void testNoSaslClient() throws IOException { TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); try { - client.sendRpcSync(new byte[13], 1000); + client.sendRpcSync(new byte[13], TIMEOUT_MS); fail("Should have failed"); } catch (Exception e) { assertTrue(e.getMessage(), e.getMessage().contains("Expected SaslMessage")); @@ -139,7 +144,7 @@ public void testNoSaslClient() throws IOException { try { // Guessing the right tag byte doesn't magically get you in... - client.sendRpcSync(new byte[] { (byte) 0xEA }, 1000); + client.sendRpcSync(new byte[] { (byte) 0xEA }, TIMEOUT_MS); fail("Should have failed"); } catch (Exception e) { assertTrue(e.getMessage(), e.getMessage().contains("java.lang.IndexOutOfBoundsException")); @@ -217,12 +222,12 @@ public synchronized void onBlockFetchFailure(String blockId, Throwable t) { new String[] { System.getProperty("java.io.tmpdir") }, 1, "org.apache.spark.shuffle.sort.SortShuffleManager"); RegisterExecutor regmsg = new RegisterExecutor("app-1", "0", executorInfo); - client1.sendRpcSync(regmsg.toByteArray(), 10000); + client1.sendRpcSync(regmsg.toByteArray(), TIMEOUT_MS); // Make a successful request to fetch blocks, which creates a new stream. But do not actually // fetch any blocks, to keep the stream open. OpenBlocks openMessage = new OpenBlocks("app-1", "0", blockIds); - byte[] response = client1.sendRpcSync(openMessage.toByteArray(), 10000); + byte[] response = client1.sendRpcSync(openMessage.toByteArray(), TIMEOUT_MS); StreamHandle stream = (StreamHandle) BlockTransferMessage.Decoder.fromByteArray(response); long streamId = stream.streamId; From 98f1ea67da1b0e3aa791c3cbfa06e48e2ba0d75b Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Thu, 17 Sep 2015 21:37:10 -0700 Subject: [PATCH 0673/1824] [SPARK-8518] [ML] Log-linear models for survival analysis [Accelerated Failure Time (AFT) model](https://en.wikipedia.org/wiki/Accelerated_failure_time_model) is the most commonly used and easy to parallel method of survival analysis for censored survival data. It is the log-linear model based on the Weibull distribution of the survival time. 
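As a quick orientation before the diff, a minimal, spark-shell-style usage sketch of the estimator added here could look as follows (the data values are illustrative assumptions, not part of the patch; column names follow the defaults the patch defines, and sqlContext is the shell's predefined context):

import org.apache.spark.ml.regression.AFTSurvivalRegression
import org.apache.spark.mllib.linalg.Vectors

// label = observed time, censor = 1.0 if the event occurred (uncensored),
// 0.0 if right-censored, features = covariates.
val training = sqlContext.createDataFrame(Seq(
  (1.218, 1.0, Vectors.dense(1.560, -0.605)),
  (2.949, 0.0, Vectors.dense(0.346, 2.158)),
  (3.627, 0.0, Vectors.dense(1.380, 0.231)),
  (0.273, 1.0, Vectors.dense(0.520, 1.151)),
  (4.199, 0.0, Vectors.dense(0.795, -0.226))
)).toDF("label", "censor", "features")

val model = new AFTSurvivalRegression().fit(training)  // censorCol defaults to "censor"
println(s"coefficients: ${model.coefficients} intercept: ${model.intercept} scale: ${model.scale}")

// Expected survival time for a new point, plus selected quantiles of the fitted lifetime.
model.setQuantileProbabilities(Array(0.3, 0.6))
println(model.predict(Vectors.dense(1.0, -1.0)))
println(model.predictQuantiles(Vectors.dense(1.0, -1.0)))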
Users can refer to the R function [```survreg```](https://stat.ethz.ch/R-manual/R-devel/library/survival/html/survreg.html) to compare the model and [```predict```](https://stat.ethz.ch/R-manual/R-devel/library/survival/html/predict.survreg.html) to compare the prediction. There are different kinds of model prediction, I have just select the type ```response``` which is default used for R. Author: Yanbo Liang Closes #8611 from yanboliang/spark-8518. --- .../ml/regression/AFTSurvivalRegression.scala | 449 ++++++++++++++++++ .../AFTSurvivalRegressionSuite.scala | 311 ++++++++++++ 2 files changed, 760 insertions(+) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala new file mode 100644 index 000000000000..5b25db651f56 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala @@ -0,0 +1,449 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.regression + +import scala.collection.mutable + +import breeze.linalg.{DenseVector => BDV} +import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS} + +import org.apache.spark.{SparkException, Logging} +import org.apache.spark.annotation.{Since, Experimental} +import org.apache.spark.ml.{Model, Estimator} +import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.shared._ +import org.apache.spark.ml.util.{SchemaUtils, Identifiable} +import org.apache.spark.mllib.linalg.{Vector, Vectors, VectorUDT} +import org.apache.spark.mllib.linalg.BLAS +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{Row, DataFrame} +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types.{DoubleType, StructType} +import org.apache.spark.storage.StorageLevel + +/** + * Params for accelerated failure time (AFT) regression. + */ +private[regression] trait AFTSurvivalRegressionParams extends Params + with HasFeaturesCol with HasLabelCol with HasPredictionCol with HasMaxIter + with HasTol with HasFitIntercept { + + /** + * Param for censor column name. + * The value of this column could be 0 or 1. + * If the value is 1, it means the event has occurred i.e. uncensored; otherwise censored. 
+ * @group param + */ + @Since("1.6.0") + final val censorCol: Param[String] = new Param(this, "censorCol", "censor column name") + + /** @group getParam */ + @Since("1.6.0") + def getCensorCol: String = $(censorCol) + setDefault(censorCol -> "censor") + + /** + * Param for quantile probabilities array. + * Values of the quantile probabilities array should be in the range [0, 1]. + * @group param + */ + @Since("1.6.0") + final val quantileProbabilities: DoubleArrayParam = new DoubleArrayParam(this, + "quantileProbabilities", "quantile probabilities array", + (t: Array[Double]) => t.forall(ParamValidators.inRange(0, 1))) + + /** @group getParam */ + @Since("1.6.0") + def getQuantileProbabilities: Array[Double] = $(quantileProbabilities) + + /** Checks whether the input has quantile probabilities array. */ + protected[regression] def hasQuantileProbabilities: Boolean = { + isDefined(quantileProbabilities) && $(quantileProbabilities).size != 0 + } + + /** + * Validates and transforms the input schema with the provided param map. + * @param schema input schema + * @param fitting whether this is in fitting or prediction + * @return output schema + */ + protected def validateAndTransformSchema( + schema: StructType, + fitting: Boolean): StructType = { + SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT) + if (fitting) { + SchemaUtils.checkColumnType(schema, $(censorCol), DoubleType) + SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType) + } + SchemaUtils.appendColumn(schema, $(predictionCol), DoubleType) + } +} + +/** + * :: Experimental :: + * Fit a parametric survival regression model named accelerated failure time (AFT) model + * ([[https://en.wikipedia.org/wiki/Accelerated_failure_time_model]]) + * based on the Weibull distribution of the survival time. + */ +@Experimental +@Since("1.6.0") +class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: String) + extends Estimator[AFTSurvivalRegressionModel] with AFTSurvivalRegressionParams with Logging { + + @Since("1.6.0") + def this() = this(Identifiable.randomUID("aftSurvReg")) + + /** @group setParam */ + @Since("1.6.0") + def setFeaturesCol(value: String): this.type = set(featuresCol, value) + + /** @group setParam */ + @Since("1.6.0") + def setLabelCol(value: String): this.type = set(labelCol, value) + + /** @group setParam */ + @Since("1.6.0") + def setCensorCol(value: String): this.type = set(censorCol, value) + + /** @group setParam */ + @Since("1.6.0") + def setPredictionCol(value: String): this.type = set(predictionCol, value) + + /** + * Set if we should fit the intercept + * Default is true. + * @group setParam + */ + @Since("1.6.0") + def setFitIntercept(value: Boolean): this.type = set(fitIntercept, value) + setDefault(fitIntercept -> true) + + /** + * Set the maximum number of iterations. + * Default is 100. + * @group setParam + */ + @Since("1.6.0") + def setMaxIter(value: Int): this.type = set(maxIter, value) + setDefault(maxIter -> 100) + + /** + * Set the convergence tolerance of iterations. + * Smaller value will lead to higher accuracy with the cost of more iterations. + * Default is 1E-6. + * @group setParam + */ + @Since("1.6.0") + def setTol(value: Double): this.type = set(tol, value) + setDefault(tol -> 1E-6) + + /** + * Extract [[featuresCol]], [[labelCol]] and [[censorCol]] from input dataset, + * and put it in an RDD with strong types. 
+ */ + protected[ml] def extractAFTPoints(dataset: DataFrame): RDD[AFTPoint] = { + dataset.select($(featuresCol), $(labelCol), $(censorCol)).map { + case Row(features: Vector, label: Double, censor: Double) => + AFTPoint(features, label, censor) + } + } + + @Since("1.6.0") + override def fit(dataset: DataFrame): AFTSurvivalRegressionModel = { + validateAndTransformSchema(dataset.schema, fitting = true) + val instances = extractAFTPoints(dataset) + val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE + if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK) + + val costFun = new AFTCostFun(instances, $(fitIntercept)) + val optimizer = new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol)) + + val numFeatures = dataset.select($(featuresCol)).take(1)(0).getAs[Vector](0).size + /* + The weights vector has three parts: + the first element: Double, log(sigma), the log of scale parameter + the second element: Double, intercept of the beta parameter + the third to the end elements: Doubles, regression coefficients vector of the beta parameter + */ + val initialWeights = Vectors.zeros(numFeatures + 2) + + val states = optimizer.iterations(new CachedDiffFunction(costFun), + initialWeights.toBreeze.toDenseVector) + + val weights = { + val arrayBuilder = mutable.ArrayBuilder.make[Double] + var state: optimizer.State = null + while (states.hasNext) { + state = states.next() + arrayBuilder += state.adjustedValue + } + if (state == null) { + val msg = s"${optimizer.getClass.getName} failed." + throw new SparkException(msg) + } + + state.x.toArray.clone() + } + + if (handlePersistence) instances.unpersist() + + val coefficients = Vectors.dense(weights.slice(2, weights.length)) + val intercept = weights(1) + val scale = math.exp(weights(0)) + val model = new AFTSurvivalRegressionModel(uid, coefficients, intercept, scale) + copyValues(model.setParent(this)) + } + + @Since("1.6.0") + override def transformSchema(schema: StructType): StructType = { + validateAndTransformSchema(schema, fitting = true) + } + + @Since("1.6.0") + override def copy(extra: ParamMap): AFTSurvivalRegression = defaultCopy(extra) +} + +/** + * :: Experimental :: + * Model produced by [[AFTSurvivalRegression]]. 
+ */ +@Experimental +@Since("1.6.0") +class AFTSurvivalRegressionModel private[ml] ( + @Since("1.6.0") override val uid: String, + @Since("1.6.0") val coefficients: Vector, + @Since("1.6.0") val intercept: Double, + @Since("1.6.0") val scale: Double) + extends Model[AFTSurvivalRegressionModel] with AFTSurvivalRegressionParams { + + /** @group setParam */ + @Since("1.6.0") + def setFeaturesCol(value: String): this.type = set(featuresCol, value) + + /** @group setParam */ + @Since("1.6.0") + def setPredictionCol(value: String): this.type = set(predictionCol, value) + + /** @group setParam */ + @Since("1.6.0") + def setQuantileProbabilities(value: Array[Double]): this.type = set(quantileProbabilities, value) + + @Since("1.6.0") + def predictQuantiles(features: Vector): Vector = { + require(hasQuantileProbabilities, + "AFTSurvivalRegressionModel predictQuantiles must set quantile probabilities array") + // scale parameter for the Weibull distribution of lifetime + val lambda = math.exp(BLAS.dot(coefficients, features) + intercept) + // shape parameter for the Weibull distribution of lifetime + val k = 1 / scale + val quantiles = $(quantileProbabilities).map { + q => lambda * math.exp(math.log(-math.log(1 - q)) / k) + } + Vectors.dense(quantiles) + } + + @Since("1.6.0") + def predict(features: Vector): Double = { + math.exp(BLAS.dot(coefficients, features) + intercept) + } + + @Since("1.6.0") + override def transform(dataset: DataFrame): DataFrame = { + transformSchema(dataset.schema) + val predictUDF = udf { features: Vector => predict(features) } + dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) + } + + @Since("1.6.0") + override def transformSchema(schema: StructType): StructType = { + validateAndTransformSchema(schema, fitting = false) + } + + @Since("1.6.0") + override def copy(extra: ParamMap): AFTSurvivalRegressionModel = { + copyValues(new AFTSurvivalRegressionModel(uid, coefficients, intercept, scale), extra) + .setParent(parent) + } +} + +/** + * AFTAggregator computes the gradient and loss for a AFT loss function, + * as used in AFT survival regression for samples in sparse or dense vector in a online fashion. + * + * The loss function and likelihood function under the AFT model based on: + * Lawless, J. F., Statistical Models and Methods for Lifetime Data, + * New York: John Wiley & Sons, Inc. 2003. + * + * Two AFTAggregator can be merged together to have a summary of loss and gradient of + * the corresponding joint dataset. + * + * Given the values of the covariates x^{'}, for random lifetime t_{i} of subjects i = 1, ..., n, + * with possible right-censoring, the likelihood function under the AFT model is given as + * {{{ + * L(\beta,\sigma)=\prod_{i=1}^n[\frac{1}{\sigma}f_{0} + * (\frac{\log{t_{i}}-x^{'}\beta}{\sigma})]^{\delta_{i}}S_{0} + * (\frac{\log{t_{i}}-x^{'}\beta}{\sigma})^{1-\delta_{i}} + * }}} + * Where \delta_{i} is the indicator of the event has occurred i.e. uncensored or not. + * Using \epsilon_{i}=\frac{\log{t_{i}}-x^{'}\beta}{\sigma}, the log-likelihood function + * assumes the form + * {{{ + * \iota(\beta,\sigma)=\sum_{i=1}^{n}[-\delta_{i}\log\sigma+ + * \delta_{i}\log{f_{0}}(\epsilon_{i})+(1-\delta_{i})\log{S_{0}(\epsilon_{i})}] + * }}} + * Where S_{0}(\epsilon_{i}) is the baseline survivor function, + * and f_{0}(\epsilon_{i}) is corresponding density function. + * + * The most commonly used log-linear survival regression method is based on the Weibull + * distribution of the survival time. 
The Weibull distribution for lifetime corresponding + * to extreme value distribution for log of the lifetime, + * and the S_{0}(\epsilon) function is + * {{{ + * S_{0}(\epsilon_{i})=\exp(-e^{\epsilon_{i}}) + * }}} + * the f_{0}(\epsilon_{i}) function is + * {{{ + * f_{0}(\epsilon_{i})=e^{\epsilon_{i}}\exp(-e^{\epsilon_{i}}) + * }}} + * The log-likelihood function for Weibull distribution of lifetime is + * {{{ + * \iota(\beta,\sigma)= + * -\sum_{i=1}^n[\delta_{i}\log\sigma-\delta_{i}\epsilon_{i}+e^{\epsilon_{i}}] + * }}} + * Due to minimizing the negative log-likelihood equivalent to maximum a posteriori probability, + * the loss function we use to optimize is -\iota(\beta,\sigma). + * The gradient functions for \beta and \log\sigma respectively are + * {{{ + * \frac{\partial (-\iota)}{\partial \beta}= + * \sum_{1=1}^{n}[\delta_{i}-e^{\epsilon_{i}}]\frac{x_{i}}{\sigma} + * }}} + * {{{ + * \frac{\partial (-\iota)}{\partial (\log\sigma)}= + * \sum_{i=1}^{n}[\delta_{i}+(\delta_{i}-e^{\epsilon_{i}})\epsilon_{i}] + * }}} + * @param weights The log of scale parameter, the intercept and + * regression coefficients corresponding to the features. + * @param fitIntercept Whether to fit an intercept term. + */ +private class AFTAggregator(weights: BDV[Double], fitIntercept: Boolean) + extends Serializable { + + // beta is the intercept and regression coefficients to the covariates + private val beta = weights.slice(1, weights.length) + // sigma is the scale parameter of the AFT model + private val sigma = math.exp(weights(0)) + + private var totalCnt: Long = 0L + private var lossSum = 0.0 + private var gradientBetaSum = BDV.zeros[Double](beta.length) + private var gradientLogSigmaSum = 0.0 + + def count: Long = totalCnt + + def loss: Double = if (totalCnt == 0) 1.0 else lossSum / totalCnt + + // Here we optimize loss function over beta and log(sigma) + def gradient: BDV[Double] = BDV.vertcat(BDV(Array(gradientLogSigmaSum / totalCnt.toDouble)), + gradientBetaSum/totalCnt.toDouble) + + /** + * Add a new training data to this AFTAggregator, and update the loss and gradient + * of the objective function. + * + * @param data The AFTPoint representation for one data point to be added into this aggregator. + * @return This AFTAggregator object. + */ + def add(data: AFTPoint): this.type = { + + // TODO: Don't create a new xi vector each time. + val xi = if (fitIntercept) { + Vectors.dense(Array(1.0) ++ data.features.toArray).toBreeze + } else { + Vectors.dense(Array(0.0) ++ data.features.toArray).toBreeze + } + val ti = data.label + val delta = data.censor + val epsilon = (math.log(ti) - beta.dot(xi)) / sigma + + lossSum += math.log(sigma) * delta + lossSum += (math.exp(epsilon) - delta * epsilon) + + // Sanity check (should never occur): + assert(!lossSum.isInfinity, + s"AFTAggregator loss sum is infinity. Error for unknown reason.") + + gradientBetaSum += xi * (delta - math.exp(epsilon)) / sigma + gradientLogSigmaSum += delta + (delta - math.exp(epsilon)) * epsilon + + totalCnt += 1 + this + } + + /** + * Merge another AFTAggregator, and update the loss and gradient + * of the objective function. + * (Note that it's in place merging; as a result, `this` object will be modified.) + * + * @param other The other AFTAggregator to be merged. + * @return This AFTAggregator object. 
+ */ + def merge(other: AFTAggregator): this.type = { + if (totalCnt != 0) { + totalCnt += other.totalCnt + lossSum += other.lossSum + + gradientBetaSum += other.gradientBetaSum + gradientLogSigmaSum += other.gradientLogSigmaSum + } + this + } +} + +/** + * AFTCostFun implements Breeze's DiffFunction[T] for AFT cost. + * It returns the loss and gradient at a particular point (coefficients). + * It's used in Breeze's convex optimization routines. + */ +private class AFTCostFun(data: RDD[AFTPoint], fitIntercept: Boolean) + extends DiffFunction[BDV[Double]] { + + override def calculate(coefficients: BDV[Double]): (Double, BDV[Double]) = { + + val aftAggregator = data.treeAggregate(new AFTAggregator(coefficients, fitIntercept))( + seqOp = (c, v) => (c, v) match { + case (aggregator, instance) => aggregator.add(instance) + }, + combOp = (c1, c2) => (c1, c2) match { + case (aggregator1, aggregator2) => aggregator1.merge(aggregator2) + }) + + (aftAggregator.loss, aftAggregator.gradient) + } +} + +/** + * Class that represents the (features, label, censor) of a data point. + * + * @param features List of features for this data point. + * @param label Label for this data point. + * @param censor Indicator of the event has occurred or not. If the value is 1, it means + * the event has occurred i.e. uncensored; otherwise censored. + */ +private[regression] case class AFTPoint(features: Vector, label: Double, censor: Double) { + require(censor == 1.0 || censor == 0.0, "censor of class AFTPoint must be 1.0 or 0.0") +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala new file mode 100644 index 000000000000..ca7140a45ea6 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala @@ -0,0 +1,311 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.regression + +import scala.util.Random + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.MLTestingUtils +import org.apache.spark.mllib.linalg.{DenseVector, Vectors} +import org.apache.spark.mllib.linalg.BLAS +import org.apache.spark.mllib.random.{ExponentialGenerator, WeibullGenerator} +import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.{Row, DataFrame} + +class AFTSurvivalRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { + + @transient var datasetUnivariate: DataFrame = _ + @transient var datasetMultivariate: DataFrame = _ + + override def beforeAll(): Unit = { + super.beforeAll() + datasetUnivariate = sqlContext.createDataFrame( + sc.parallelize(generateAFTInput( + 1, Array(5.5), Array(0.8), 1000, 42, 1.0, 2.0, 2.0))) + datasetMultivariate = sqlContext.createDataFrame( + sc.parallelize(generateAFTInput( + 2, Array(0.9, -1.3), Array(0.7, 1.2), 1000, 42, 1.5, 2.5, 2.0))) + } + + test("params") { + ParamsSuite.checkParams(new AFTSurvivalRegression) + val model = new AFTSurvivalRegressionModel("aftSurvReg", Vectors.dense(0.0), 0.0, 0.0) + ParamsSuite.checkParams(model) + } + + test("aft survival regression: default params") { + val aftr = new AFTSurvivalRegression + assert(aftr.getLabelCol === "label") + assert(aftr.getFeaturesCol === "features") + assert(aftr.getPredictionCol === "prediction") + assert(aftr.getCensorCol === "censor") + assert(aftr.getFitIntercept) + assert(aftr.getMaxIter === 100) + assert(aftr.getTol === 1E-6) + val model = aftr.fit(datasetUnivariate) + + // copied model must have the same parent. + MLTestingUtils.checkCopy(model) + + model.transform(datasetUnivariate) + .select("label", "prediction") + .collect() + assert(model.getFeaturesCol === "features") + assert(model.getPredictionCol === "prediction") + assert(model.intercept !== 0.0) + assert(model.hasParent) + } + + def generateAFTInput( + numFeatures: Int, + xMean: Array[Double], + xVariance: Array[Double], + nPoints: Int, + seed: Int, + weibullShape: Double, + weibullScale: Double, + exponentialMean: Double): Seq[AFTPoint] = { + + def censor(x: Double, y: Double): Double = { if (x <= y) 1.0 else 0.0 } + + val weibull = new WeibullGenerator(weibullShape, weibullScale) + weibull.setSeed(seed) + + val exponential = new ExponentialGenerator(exponentialMean) + exponential.setSeed(seed) + + val rnd = new Random(seed) + val x = Array.fill[Array[Double]](nPoints)(Array.fill[Double](numFeatures)(rnd.nextDouble())) + + x.foreach { v => + var i = 0 + val len = v.length + while (i < len) { + v(i) = (v(i) - 0.5) * math.sqrt(12.0 * xVariance(i)) + xMean(i) + i += 1 + } + } + val y = (1 to nPoints).map { i => (weibull.nextValue(), exponential.nextValue()) } + + y.zip(x).map { p => AFTPoint(Vectors.dense(p._2), p._1._1, censor(p._1._1, p._1._2)) } + } + + test("aft survival regression with univariate") { + val trainer = new AFTSurvivalRegression + val model = trainer.fit(datasetUnivariate) + + /* + Using the following R code to load the data and train the model using survival package. + + library("survival") + data <- read.csv("path", header=FALSE, stringsAsFactors=FALSE) + features <- data$V1 + censor <- data$V2 + label <- data$V3 + sr.fit <- survreg(Surv(label, censor) ~ features, dist='weibull') + summary(sr.fit) + + Value Std. 
Error z p + (Intercept) 1.759 0.4141 4.247 2.16e-05 + features -0.039 0.0735 -0.531 5.96e-01 + Log(scale) 0.344 0.0379 9.073 1.16e-19 + + Scale= 1.41 + + Weibull distribution + Loglik(model)= -1152.2 Loglik(intercept only)= -1152.3 + Chisq= 0.28 on 1 degrees of freedom, p= 0.6 + Number of Newton-Raphson Iterations: 5 + n= 1000 + */ + val coefficientsR = Vectors.dense(-0.039) + val interceptR = 1.759 + val scaleR = 1.41 + + assert(model.intercept ~== interceptR relTol 1E-3) + assert(model.coefficients ~== coefficientsR relTol 1E-3) + assert(model.scale ~== scaleR relTol 1E-3) + + /* + Using the following R code to predict. + + testdata <- list(features=6.559282795753792) + responsePredict <- predict(sr.fit, newdata=testdata) + responsePredict + + 1 + 4.494763 + + quantilePredict <- predict(sr.fit, newdata=testdata, type='quantile', p=c(0.1, 0.5, 0.9)) + quantilePredict + + [1] 0.1879174 2.6801195 14.5779394 + */ + val features = Vectors.dense(6.559282795753792) + val quantileProbabilities = Array(0.1, 0.5, 0.9) + val responsePredictR = 4.494763 + val quantilePredictR = Vectors.dense(0.1879174, 2.6801195, 14.5779394) + + assert(model.predict(features) ~== responsePredictR relTol 1E-3) + model.setQuantileProbabilities(quantileProbabilities) + assert(model.predictQuantiles(features) ~== quantilePredictR relTol 1E-3) + + model.transform(datasetUnivariate).select("features", "prediction").collect().foreach { + case Row(features: DenseVector, prediction1: Double) => + val prediction2 = math.exp(BLAS.dot(model.coefficients, features) + model.intercept) + assert(prediction1 ~== prediction2 relTol 1E-5) + } + } + + test("aft survival regression with multivariate") { + val trainer = new AFTSurvivalRegression + val model = trainer.fit(datasetMultivariate) + + /* + Using the following R code to load the data and train the model using survival package. + + library("survival") + data <- read.csv("path", header=FALSE, stringsAsFactors=FALSE) + feature1 <- data$V1 + feature2 <- data$V2 + censor <- data$V3 + label <- data$V4 + sr.fit <- survreg(Surv(label, censor) ~ feature1 + feature2, dist='weibull') + summary(sr.fit) + + Value Std. Error z p + (Intercept) 1.9206 0.1057 18.171 8.78e-74 + feature1 -0.0844 0.0611 -1.381 1.67e-01 + feature2 0.0677 0.0468 1.447 1.48e-01 + Log(scale) -0.0236 0.0436 -0.542 5.88e-01 + + Scale= 0.977 + + Weibull distribution + Loglik(model)= -1070.7 Loglik(intercept only)= -1072.7 + Chisq= 3.91 on 2 degrees of freedom, p= 0.14 + Number of Newton-Raphson Iterations: 5 + n= 1000 + */ + val coefficientsR = Vectors.dense(-0.0844, 0.0677) + val interceptR = 1.9206 + val scaleR = 0.977 + + assert(model.intercept ~== interceptR relTol 1E-3) + assert(model.coefficients ~== coefficientsR relTol 1E-3) + assert(model.scale ~== scaleR relTol 1E-3) + + /* + Using the following R code to predict. 
+ testdata <- list(feature1=2.233396950271428, feature2=-2.5321374085997683) + responsePredict <- predict(sr.fit, newdata=testdata) + responsePredict + + 1 + 4.761219 + + quantilePredict <- predict(sr.fit, newdata=testdata, type='quantile', p=c(0.1, 0.5, 0.9)) + quantilePredict + + [1] 0.5287044 3.3285858 10.7517072 + */ + val features = Vectors.dense(2.233396950271428, -2.5321374085997683) + val quantileProbabilities = Array(0.1, 0.5, 0.9) + val responsePredictR = 4.761219 + val quantilePredictR = Vectors.dense(0.5287044, 3.3285858, 10.7517072) + + assert(model.predict(features) ~== responsePredictR relTol 1E-3) + model.setQuantileProbabilities(quantileProbabilities) + assert(model.predictQuantiles(features) ~== quantilePredictR relTol 1E-3) + + model.transform(datasetMultivariate).select("features", "prediction").collect().foreach { + case Row(features: DenseVector, prediction1: Double) => + val prediction2 = math.exp(BLAS.dot(model.coefficients, features) + model.intercept) + assert(prediction1 ~== prediction2 relTol 1E-5) + } + } + + test("aft survival regression w/o intercept") { + val trainer = new AFTSurvivalRegression().setFitIntercept(false) + val model = trainer.fit(datasetMultivariate) + + /* + Using the following R code to load the data and train the model using survival package. + + library("survival") + data <- read.csv("path", header=FALSE, stringsAsFactors=FALSE) + feature1 <- data$V1 + feature2 <- data$V2 + censor <- data$V3 + label <- data$V4 + sr.fit <- survreg(Surv(label, censor) ~ feature1 + feature2 - 1, dist='weibull') + summary(sr.fit) + + Value Std. Error z p + feature1 0.896 0.0685 13.1 3.93e-39 + feature2 -0.709 0.0522 -13.6 5.78e-42 + Log(scale) 0.420 0.0401 10.5 1.23e-25 + + Scale= 1.52 + + Weibull distribution + Loglik(model)= -1292.4 Loglik(intercept only)= -1072.7 + Chisq= -439.57 on 1 degrees of freedom, p= 1 + Number of Newton-Raphson Iterations: 6 + n= 1000 + */ + val coefficientsR = Vectors.dense(0.896, -0.709) + val interceptR = 0.0 + val scaleR = 1.52 + + assert(model.intercept === interceptR) + assert(model.coefficients ~== coefficientsR relTol 1E-3) + assert(model.scale ~== scaleR relTol 1E-3) + + /* + Using the following R code to predict. + testdata <- list(feature1=2.233396950271428, feature2=-2.5321374085997683) + responsePredict <- predict(sr.fit, newdata=testdata) + responsePredict + + 1 + 44.54465 + + quantilePredict <- predict(sr.fit, newdata=testdata, type='quantile', p=c(0.1, 0.5, 0.9)) + quantilePredict + + [1] 1.452103 25.506077 158.428600 + */ + val features = Vectors.dense(2.233396950271428, -2.5321374085997683) + val quantileProbabilities = Array(0.1, 0.5, 0.9) + val responsePredictR = 44.54465 + val quantilePredictR = Vectors.dense(1.452103, 25.506077, 158.428600) + + assert(model.predict(features) ~== responsePredictR relTol 1E-3) + model.setQuantileProbabilities(quantileProbabilities) + assert(model.predictQuantiles(features) ~== quantilePredictR relTol 1E-3) + + model.transform(datasetMultivariate).select("features", "prediction").collect().foreach { + case Row(features: DenseVector, prediction1: Double) => + val prediction2 = math.exp(BLAS.dot(model.coefficients, features) + model.intercept) + assert(prediction1 ~== prediction2 relTol 1E-5) + } + } +} From d009da2f5c803f3b7344c96abbfcf3ecef2f5ad2 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 17 Sep 2015 22:05:20 -0700 Subject: [PATCH 0674/1824] [SPARK-10682] [GRAPHX] Remove Bagel test suites. Bagel has been deprecated and we haven't done any changes to it. 
There is no need to run those tests. This should speed up tests by 1 min. Author: Reynold Xin Closes #8807 from rxin/SPARK-10682. --- bagel/src/test/resources/log4j.properties | 27 ----- .../org/apache/spark/bagel/BagelSuite.scala | 113 ------------------ 2 files changed, 140 deletions(-) delete mode 100644 bagel/src/test/resources/log4j.properties delete mode 100644 bagel/src/test/scala/org/apache/spark/bagel/BagelSuite.scala diff --git a/bagel/src/test/resources/log4j.properties b/bagel/src/test/resources/log4j.properties deleted file mode 100644 index edbecdae9209..000000000000 --- a/bagel/src/test/resources/log4j.properties +++ /dev/null @@ -1,27 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -# Set everything to be logged to the file target/unit-tests.log -log4j.rootCategory=INFO, file -log4j.appender.file=org.apache.log4j.FileAppender -log4j.appender.file.append=true -log4j.appender.file.file=target/unit-tests.log -log4j.appender.file.layout=org.apache.log4j.PatternLayout -log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n - -# Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.spark-project.jetty=WARN diff --git a/bagel/src/test/scala/org/apache/spark/bagel/BagelSuite.scala b/bagel/src/test/scala/org/apache/spark/bagel/BagelSuite.scala deleted file mode 100644 index fb10d734ac74..000000000000 --- a/bagel/src/test/scala/org/apache/spark/bagel/BagelSuite.scala +++ /dev/null @@ -1,113 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.bagel - -import org.scalatest.{BeforeAndAfter, Assertions} -import org.scalatest.concurrent.Timeouts -import org.scalatest.time.SpanSugar._ - -import org.apache.spark._ -import org.apache.spark.storage.StorageLevel - -class TestVertex(val active: Boolean, val age: Int) extends Vertex with Serializable -class TestMessage(val targetId: String) extends Message[String] with Serializable - -class BagelSuite extends SparkFunSuite with Assertions with BeforeAndAfter with Timeouts { - - var sc: SparkContext = _ - - after { - if (sc != null) { - sc.stop() - sc = null - } - } - - test("halting by voting") { - sc = new SparkContext("local", "test") - val verts = sc.parallelize(Array("a", "b", "c", "d").map(id => (id, new TestVertex(true, 0)))) - val msgs = sc.parallelize(Array[(String, TestMessage)]()) - val numSupersteps = 5 - val result = - Bagel.run(sc, verts, msgs, sc.defaultParallelism) { - (self: TestVertex, msgs: Option[Array[TestMessage]], superstep: Int) => - (new TestVertex(superstep < numSupersteps - 1, self.age + 1), Array[TestMessage]()) - } - for ((id, vert) <- result.collect) { - assert(vert.age === numSupersteps) - } - } - - test("halting by message silence") { - sc = new SparkContext("local", "test") - val verts = sc.parallelize(Array("a", "b", "c", "d").map(id => (id, new TestVertex(false, 0)))) - val msgs = sc.parallelize(Array("a" -> new TestMessage("a"))) - val numSupersteps = 5 - val result = - Bagel.run(sc, verts, msgs, sc.defaultParallelism) { - (self: TestVertex, msgs: Option[Array[TestMessage]], superstep: Int) => - val msgsOut = - msgs match { - case Some(ms) if (superstep < numSupersteps - 1) => - ms - case _ => - Array[TestMessage]() - } - (new TestVertex(self.active, self.age + 1), msgsOut) - } - for ((id, vert) <- result.collect) { - assert(vert.age === numSupersteps) - } - } - - test("large number of iterations") { - // This tests whether jobs with a large number of iterations finish in a reasonable time, - // because non-memoized recursion in RDD or DAGScheduler used to cause them to hang - failAfter(30 seconds) { - sc = new SparkContext("local", "test") - val verts = sc.parallelize((1 to 4).map(id => (id.toString, new TestVertex(true, 0)))) - val msgs = sc.parallelize(Array[(String, TestMessage)]()) - val numSupersteps = 50 - val result = - Bagel.run(sc, verts, msgs, sc.defaultParallelism) { - (self: TestVertex, msgs: Option[Array[TestMessage]], superstep: Int) => - (new TestVertex(superstep < numSupersteps - 1, self.age + 1), Array[TestMessage]()) - } - for ((id, vert) <- result.collect) { - assert(vert.age === numSupersteps) - } - } - } - - test("using non-default persistence level") { - failAfter(10 seconds) { - sc = new SparkContext("local", "test") - val verts = sc.parallelize((1 to 4).map(id => (id.toString, new TestVertex(true, 0)))) - val msgs = sc.parallelize(Array[(String, TestMessage)]()) - val numSupersteps = 20 - val result = - Bagel.run(sc, verts, msgs, sc.defaultParallelism, StorageLevel.DISK_ONLY) { - (self: TestVertex, msgs: Option[Array[TestMessage]], superstep: Int) => - (new TestVertex(superstep < numSupersteps - 1, self.age + 1), Array[TestMessage]()) - } - for ((id, vert) <- result.collect) { - assert(vert.age === numSupersteps) - } - } - } -} From 93c7650ab60a839a9cbe8b4ea1d5eda93e53ebe0 Mon Sep 17 00:00:00 2001 From: linweizhong Date: Thu, 17 Sep 2015 22:25:24 -0700 Subject: [PATCH 0675/1824] [SPARK-9522] [SQL] SparkSubmit process can not exit if kill application when HiveThriftServer was starting When we start 
HiveThriftServer, we will start SparkContext first, then start HiveServer2, if we kill application while HiveServer2 is starting then SparkContext will stop successfully, but SparkSubmit process can not exit. Author: linweizhong Closes #7853 from Sephiroth-Lin/SPARK-9522. --- core/src/main/scala/org/apache/spark/SparkContext.scala | 2 +- .../spark/sql/hive/thriftserver/HiveThriftServer2.scala | 6 ++++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 9c3218719f7f..ebd8e946ee7a 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -97,7 +97,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli val startTime = System.currentTimeMillis() - private val stopped: AtomicBoolean = new AtomicBoolean(false) + private[spark] val stopped: AtomicBoolean = new AtomicBoolean(false) private def assertNotStopped(): Unit = { if (stopped.get()) { diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala index dd9fef9206d0..a0643cec0fb7 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala @@ -93,6 +93,12 @@ object HiveThriftServer2 extends Logging { } else { None } + // If application was killed before HiveThriftServer2 start successfully then SparkSubmit + // process can not exit, so check whether if SparkContext was stopped. + if (SparkSQLEnv.sparkContext.stopped.get()) { + logError("SparkContext has stopped even if HiveServer2 has started, so exit") + System.exit(-1) + } } catch { case e: Exception => logError("Error starting HiveThriftServer2", e) From 9a56dcdf7f19c9f7f913a2ce9bc981cb43a113c5 Mon Sep 17 00:00:00 2001 From: Felix Bechstein Date: Thu, 17 Sep 2015 22:42:46 -0700 Subject: [PATCH 0676/1824] docs/running-on-mesos.md: state default values in default column This PR simply uses the default value column for defaults. Author: Felix Bechstein Closes #8810 from felixb/fix_mesos_doc. --- docs/running-on-mesos.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index 247e6ecfbdb8..1814fb32ed8a 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -332,21 +332,21 @@ See the [configuration page](configuration.html) for information on Spark config spark.mesos.principal - Framework principal to authenticate to Mesos + (none) Set the principal with which Spark framework will use to authenticate with Mesos. spark.mesos.secret - Framework secret to authenticate to Mesos + (none)/td> Set the secret with which Spark framework will use to authenticate with Mesos. spark.mesos.role - Role for the Spark framework + * Set the role of this Spark framework for Mesos. Roles are used in Mesos for reservations and resource weight sharing. @@ -354,7 +354,7 @@ See the [configuration page](configuration.html) for information on Spark config spark.mesos.constraints - Attribute based constraints to be matched against when accepting resource offers. + (none) Attribute based constraints on mesos resource offers. By default, all resource offers will be accepted. 
Refer to Mesos Attributes & Resources for more information on attributes.
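For context, the entries in this table are ordinary Spark configuration keys, so they can be set like any other property; a hedged, spark-shell-style sketch (all values below are placeholders, not recommendations):

import org.apache.spark.SparkConf

// Placeholder values; the keys are the ones documented in the table above.
val conf = new SparkConf()
  .set("spark.mesos.principal", "my-principal")
  .set("spark.mesos.secret", "my-secret")
  .set("spark.mesos.role", "dev")
  .set("spark.mesos.constraints", "os:centos7;us-east-1:false")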
      From 74d8f7dda82c3a16348f3ff22da83203e0b7f708 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 17 Sep 2015 22:46:13 -0700 Subject: [PATCH 0677/1824] Added tag to documentation. --- docs/running-on-mesos.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index 1814fb32ed8a..330c159c67bc 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -346,7 +346,7 @@ See the [configuration page](configuration.html) for information on Spark config spark.mesos.role - * + * Set the role of this Spark framework for Mesos. Roles are used in Mesos for reservations and resource weight sharing. From e3b5d6cb29e0f983fcc55920619e6433298955f5 Mon Sep 17 00:00:00 2001 From: "navis.ryu" Date: Fri, 18 Sep 2015 00:43:02 -0700 Subject: [PATCH 0678/1824] [SPARK-10684] [SQL] StructType.interpretedOrdering need not to be serialized Kryo fails with buffer overflow even with max value (2G). {noformat} org.apache.spark.SparkException: Kryo serialization failed: Buffer overflow. Available: 0, required: 1 Serialization trace: containsChild (org.apache.spark.sql.catalyst.expressions.BoundReference) child (org.apache.spark.sql.catalyst.expressions.SortOrder) array (scala.collection.mutable.ArraySeq) ordering (org.apache.spark.sql.catalyst.expressions.InterpretedOrdering) interpretedOrdering (org.apache.spark.sql.types.StructType) schema (org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema). To avoid this, increase spark.kryoserializer.buffer.max value. at org.apache.spark.serializer.KryoSerializerInstance.serialize(KryoSerializer.scala:263) at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:240) at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1145) at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:615) at java.lang.Thread.run(Thread.java:745) {noformat} Author: navis.ryu Closes #8808 from navis/SPARK-10684. --- .../main/scala/org/apache/spark/sql/types/StructType.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index d8968ef80639..b29cf22dcb58 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -305,7 +305,9 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru f(this) || fields.exists(field => field.dataType.existsRecursively(f)) } - private[sql] val interpretedOrdering = InterpretedOrdering.forSchema(this.fields.map(_.dataType)) + @transient + private[sql] lazy val interpretedOrdering = + InterpretedOrdering.forSchema(this.fields.map(_.dataType)) } object StructType extends AbstractDataType { From 20fd35dfd1ac402b622604e7bbedcc53a580b0a2 Mon Sep 17 00:00:00 2001 From: Yash Datta Date: Fri, 18 Sep 2015 08:22:38 -0700 Subject: [PATCH 0679/1824] [SPARK-10451] [SQL] Prevent unnecessary serializations in InMemoryColumnarTableScan Many of the fields in InMemoryColumnar scan and InMemoryRelation can be made transient. This reduces my 1000ms job to abt 700 ms . The task size reduces from 2.8 mb to ~1300kb Author: Yash Datta Closes #8604 from saucam/serde. 
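The mechanism behind the size drop: fields marked @transient are skipped by Java serialization, so they never ride along with the plan fragments shipped to executors. A self-contained sketch of the effect (the classes are hypothetical stand-ins, not Spark code):

import java.io.{ByteArrayOutputStream, ObjectOutputStream}

// Hypothetical stand-ins: a "plan" object only needed on the driver, referenced by a
// small scan description that travels with tasks.
class DriverOnlyPlan(val bigMetadata: Array[Byte]) extends Serializable
class ScanDescription(@transient val plan: DriverOnlyPlan, val columnOrdinal: Int)
  extends Serializable

object TransientSizeDemo {
  def serializedSize(o: AnyRef): Int = {
    val bytes = new ByteArrayOutputStream()
    val out = new ObjectOutputStream(bytes)
    out.writeObject(o)
    out.close()
    bytes.size()
  }

  def main(args: Array[String]): Unit = {
    val plan = new DriverOnlyPlan(new Array[Byte](1 << 20))
    // The @transient field is dropped during serialization, so only the small part ships.
    println(serializedSize(new ScanDescription(plan, 3)))  // a few hundred bytes
    println(serializedSize(plan))                          // roughly 1 MB
  }
}

The diff below applies the same idea to fields such as child, relation, partitionStatistics and the cached column buffers, and additionally copies the schema, output and buffers into local vals before the mapPartitions closure so that the closure captures only what it actually uses.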
--- .../columnar/InMemoryColumnarTableScan.scala | 35 +++++++++++-------- 1 file changed, 21 insertions(+), 14 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala index 66d429bc0619..d7e145f9c2bb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala @@ -48,10 +48,10 @@ private[sql] case class InMemoryRelation( useCompression: Boolean, batchSize: Int, storageLevel: StorageLevel, - child: SparkPlan, + @transient child: SparkPlan, tableName: Option[String])( - private var _cachedColumnBuffers: RDD[CachedBatch] = null, - private var _statistics: Statistics = null, + @transient private var _cachedColumnBuffers: RDD[CachedBatch] = null, + @transient private var _statistics: Statistics = null, private var _batchStats: Accumulable[ArrayBuffer[InternalRow], InternalRow] = null) extends LogicalPlan with MultiInstanceRelation { @@ -62,7 +62,7 @@ private[sql] case class InMemoryRelation( _batchStats } - val partitionStatistics = new PartitionStatistics(output) + @transient val partitionStatistics = new PartitionStatistics(output) private def computeSizeInBytes = { val sizeOfRow: Expression = @@ -196,7 +196,7 @@ private[sql] case class InMemoryRelation( private[sql] case class InMemoryColumnarTableScan( attributes: Seq[Attribute], predicates: Seq[Expression], - relation: InMemoryRelation) + @transient relation: InMemoryRelation) extends LeafNode { override def output: Seq[Attribute] = attributes @@ -205,7 +205,7 @@ private[sql] case class InMemoryColumnarTableScan( // Returned filter predicate should return false iff it is impossible for the input expression // to evaluate to `true' based on statistics collected about this partition batch. - val buildFilter: PartialFunction[Expression, Expression] = { + @transient val buildFilter: PartialFunction[Expression, Expression] = { case And(lhs: Expression, rhs: Expression) if buildFilter.isDefinedAt(lhs) || buildFilter.isDefinedAt(rhs) => (buildFilter.lift(lhs) ++ buildFilter.lift(rhs)).reduce(_ && _) @@ -268,16 +268,23 @@ private[sql] case class InMemoryColumnarTableScan( readBatches.setValue(0) } - relation.cachedColumnBuffers.mapPartitions { cachedBatchIterator => - val partitionFilter = newPredicate( - partitionFilters.reduceOption(And).getOrElse(Literal(true)), - relation.partitionStatistics.schema) + // Using these variables here to avoid serialization of entire objects (if referenced directly) + // within the map Partitions closure. + val schema = relation.partitionStatistics.schema + val schemaIndex = schema.zipWithIndex + val relOutput = relation.output + val buffers = relation.cachedColumnBuffers + + buffers.mapPartitions { cachedBatchIterator => + val partitionFilter = newPredicate( + partitionFilters.reduceOption(And).getOrElse(Literal(true)), + schema) // Find the ordinals and data types of the requested columns. If none are requested, use the // narrowest (the field with minimum default element size). 
val (requestedColumnIndices, requestedColumnDataTypes) = if (attributes.isEmpty) { val (narrowestOrdinal, narrowestDataType) = - relation.output.zipWithIndex.map { case (a, ordinal) => + relOutput.zipWithIndex.map { case (a, ordinal) => ordinal -> a.dataType } minBy { case (_, dataType) => ColumnType(dataType).defaultSize @@ -285,7 +292,7 @@ private[sql] case class InMemoryColumnarTableScan( Seq(narrowestOrdinal) -> Seq(narrowestDataType) } else { attributes.map { a => - relation.output.indexWhere(_.exprId == a.exprId) -> a.dataType + relOutput.indexWhere(_.exprId == a.exprId) -> a.dataType }.unzip } @@ -296,7 +303,7 @@ private[sql] case class InMemoryColumnarTableScan( // Build column accessors val columnAccessors = requestedColumnIndices.map { batchColumnIndex => ColumnAccessor( - relation.output(batchColumnIndex).dataType, + relOutput(batchColumnIndex).dataType, ByteBuffer.wrap(cachedBatch.buffers(batchColumnIndex))) } @@ -328,7 +335,7 @@ private[sql] case class InMemoryColumnarTableScan( if (inMemoryPartitionPruningEnabled) { cachedBatchIterator.filter { cachedBatch => if (!partitionFilter(cachedBatch.stats)) { - def statsString: String = relation.partitionStatistics.schema.zipWithIndex.map { + def statsString: String = schemaIndex.map { case (a, i) => val value = cachedBatch.stats.get(i, a.dataType) s"${a.name}: $value" From 35e8ab939000d4a1a01c1af4015c25ff6f4013a3 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Fri, 18 Sep 2015 09:53:52 -0700 Subject: [PATCH 0680/1824] [SPARK-10615] [PYSPARK] change assertEquals to assertEqual As ```assertEquals``` is deprecated, so we need to change ```assertEquals``` to ```assertEqual``` for existing python unit tests. Author: Yanbo Liang Closes #8814 from yanboliang/spark-10615. --- python/pyspark/ml/tests.py | 16 +-- python/pyspark/mllib/tests.py | 162 +++++++++++++++--------------- python/pyspark/sql/tests.py | 18 ++-- python/pyspark/streaming/tests.py | 2 +- 4 files changed, 99 insertions(+), 99 deletions(-) diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index b892318f50bd..648fa8858fba 100644 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -182,7 +182,7 @@ def test_params(self): self.assertEqual(testParams.getMaxIter(), 10) testParams.setMaxIter(100) self.assertTrue(testParams.isSet(maxIter)) - self.assertEquals(testParams.getMaxIter(), 100) + self.assertEqual(testParams.getMaxIter(), 100) self.assertTrue(testParams.hasParam(inputCol)) self.assertFalse(testParams.hasDefault(inputCol)) @@ -195,7 +195,7 @@ def test_params(self): testParams._setDefault(seed=41) testParams.setSeed(43) - self.assertEquals( + self.assertEqual( testParams.explainParams(), "\n".join(["inputCol: input column name (undefined)", "maxIter: max number of iterations (>= 0) (default: 10, current: 100)", @@ -264,23 +264,23 @@ def test_ngram(self): self.assertEqual(ngram0.getInputCol(), "input") self.assertEqual(ngram0.getOutputCol(), "output") transformedDF = ngram0.transform(dataset) - self.assertEquals(transformedDF.head().output, ["a b c d", "b c d e"]) + self.assertEqual(transformedDF.head().output, ["a b c d", "b c d e"]) def test_stopwordsremover(self): sqlContext = SQLContext(self.sc) dataset = sqlContext.createDataFrame([Row(input=["a", "panda"])]) stopWordRemover = StopWordsRemover(inputCol="input", outputCol="output") # Default - self.assertEquals(stopWordRemover.getInputCol(), "input") + self.assertEqual(stopWordRemover.getInputCol(), "input") transformedDF = stopWordRemover.transform(dataset) - 
self.assertEquals(transformedDF.head().output, ["panda"]) + self.assertEqual(transformedDF.head().output, ["panda"]) # Custom stopwords = ["panda"] stopWordRemover.setStopWords(stopwords) - self.assertEquals(stopWordRemover.getInputCol(), "input") - self.assertEquals(stopWordRemover.getStopWords(), stopwords) + self.assertEqual(stopWordRemover.getInputCol(), "input") + self.assertEqual(stopWordRemover.getStopWords(), stopwords) transformedDF = stopWordRemover.transform(dataset) - self.assertEquals(transformedDF.head().output, ["a"]) + self.assertEqual(transformedDF.head().output, ["a"]) class HasInducedError(Params): diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index 636f9a06cab7..96cf13495aa9 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -166,13 +166,13 @@ def test_dot(self): [1., 2., 3., 4.], [1., 2., 3., 4.]]) arr = pyarray.array('d', [0, 1, 2, 3]) - self.assertEquals(10.0, sv.dot(dv)) + self.assertEqual(10.0, sv.dot(dv)) self.assertTrue(array_equal(array([3., 6., 9., 12.]), sv.dot(mat))) - self.assertEquals(30.0, dv.dot(dv)) + self.assertEqual(30.0, dv.dot(dv)) self.assertTrue(array_equal(array([10., 20., 30., 40.]), dv.dot(mat))) - self.assertEquals(30.0, lst.dot(dv)) + self.assertEqual(30.0, lst.dot(dv)) self.assertTrue(array_equal(array([10., 20., 30., 40.]), lst.dot(mat))) - self.assertEquals(7.0, sv.dot(arr)) + self.assertEqual(7.0, sv.dot(arr)) def test_squared_distance(self): sv = SparseVector(4, {1: 1, 3: 2}) @@ -181,27 +181,27 @@ def test_squared_distance(self): lst1 = [4, 3, 2, 1] arr = pyarray.array('d', [0, 2, 1, 3]) narr = array([0, 2, 1, 3]) - self.assertEquals(15.0, _squared_distance(sv, dv)) - self.assertEquals(25.0, _squared_distance(sv, lst)) - self.assertEquals(20.0, _squared_distance(dv, lst)) - self.assertEquals(15.0, _squared_distance(dv, sv)) - self.assertEquals(25.0, _squared_distance(lst, sv)) - self.assertEquals(20.0, _squared_distance(lst, dv)) - self.assertEquals(0.0, _squared_distance(sv, sv)) - self.assertEquals(0.0, _squared_distance(dv, dv)) - self.assertEquals(0.0, _squared_distance(lst, lst)) - self.assertEquals(25.0, _squared_distance(sv, lst1)) - self.assertEquals(3.0, _squared_distance(sv, arr)) - self.assertEquals(3.0, _squared_distance(sv, narr)) + self.assertEqual(15.0, _squared_distance(sv, dv)) + self.assertEqual(25.0, _squared_distance(sv, lst)) + self.assertEqual(20.0, _squared_distance(dv, lst)) + self.assertEqual(15.0, _squared_distance(dv, sv)) + self.assertEqual(25.0, _squared_distance(lst, sv)) + self.assertEqual(20.0, _squared_distance(lst, dv)) + self.assertEqual(0.0, _squared_distance(sv, sv)) + self.assertEqual(0.0, _squared_distance(dv, dv)) + self.assertEqual(0.0, _squared_distance(lst, lst)) + self.assertEqual(25.0, _squared_distance(sv, lst1)) + self.assertEqual(3.0, _squared_distance(sv, arr)) + self.assertEqual(3.0, _squared_distance(sv, narr)) def test_hash(self): v1 = DenseVector([0.0, 1.0, 0.0, 5.5]) v2 = SparseVector(4, [(1, 1.0), (3, 5.5)]) v3 = DenseVector([0.0, 1.0, 0.0, 5.5]) v4 = SparseVector(4, [(1, 1.0), (3, 2.5)]) - self.assertEquals(hash(v1), hash(v2)) - self.assertEquals(hash(v1), hash(v3)) - self.assertEquals(hash(v2), hash(v3)) + self.assertEqual(hash(v1), hash(v2)) + self.assertEqual(hash(v1), hash(v3)) + self.assertEqual(hash(v2), hash(v3)) self.assertFalse(hash(v1) == hash(v4)) self.assertFalse(hash(v2) == hash(v4)) @@ -212,8 +212,8 @@ def test_eq(self): v4 = SparseVector(6, [(1, 1.0), (3, 5.5)]) v5 = DenseVector([0.0, 1.0, 0.0, 2.5]) v6 = 
SparseVector(4, [(1, 1.0), (3, 2.5)]) - self.assertEquals(v1, v2) - self.assertEquals(v1, v3) + self.assertEqual(v1, v2) + self.assertEqual(v1, v3) self.assertFalse(v2 == v4) self.assertFalse(v1 == v5) self.assertFalse(v1 == v6) @@ -238,13 +238,13 @@ def test_conversion(self): def test_sparse_vector_indexing(self): sv = SparseVector(4, {1: 1, 3: 2}) - self.assertEquals(sv[0], 0.) - self.assertEquals(sv[3], 2.) - self.assertEquals(sv[1], 1.) - self.assertEquals(sv[2], 0.) - self.assertEquals(sv[-1], 2) - self.assertEquals(sv[-2], 0) - self.assertEquals(sv[-4], 0) + self.assertEqual(sv[0], 0.) + self.assertEqual(sv[3], 2.) + self.assertEqual(sv[1], 1.) + self.assertEqual(sv[2], 0.) + self.assertEqual(sv[-1], 2) + self.assertEqual(sv[-2], 0) + self.assertEqual(sv[-4], 0) for ind in [4, -5]: self.assertRaises(ValueError, sv.__getitem__, ind) for ind in [7.8, '1']: @@ -255,7 +255,7 @@ def test_matrix_indexing(self): expected = [[0, 6], [1, 8], [4, 10]] for i in range(3): for j in range(2): - self.assertEquals(mat[i, j], expected[i][j]) + self.assertEqual(mat[i, j], expected[i][j]) def test_repr_dense_matrix(self): mat = DenseMatrix(3, 2, [0, 1, 4, 6, 8, 10]) @@ -308,11 +308,11 @@ def test_sparse_matrix(self): # Test sparse matrix creation. sm1 = SparseMatrix( 3, 4, [0, 2, 2, 4, 4], [1, 2, 1, 2], [1.0, 2.0, 4.0, 5.0]) - self.assertEquals(sm1.numRows, 3) - self.assertEquals(sm1.numCols, 4) - self.assertEquals(sm1.colPtrs.tolist(), [0, 2, 2, 4, 4]) - self.assertEquals(sm1.rowIndices.tolist(), [1, 2, 1, 2]) - self.assertEquals(sm1.values.tolist(), [1.0, 2.0, 4.0, 5.0]) + self.assertEqual(sm1.numRows, 3) + self.assertEqual(sm1.numCols, 4) + self.assertEqual(sm1.colPtrs.tolist(), [0, 2, 2, 4, 4]) + self.assertEqual(sm1.rowIndices.tolist(), [1, 2, 1, 2]) + self.assertEqual(sm1.values.tolist(), [1.0, 2.0, 4.0, 5.0]) self.assertTrue( repr(sm1), 'SparseMatrix(3, 4, [0, 2, 2, 4, 4], [1, 2, 1, 2], [1.0, 2.0, 4.0, 5.0], False)') @@ -325,13 +325,13 @@ def test_sparse_matrix(self): for i in range(3): for j in range(4): - self.assertEquals(expected[i][j], sm1[i, j]) + self.assertEqual(expected[i][j], sm1[i, j]) self.assertTrue(array_equal(sm1.toArray(), expected)) # Test conversion to dense and sparse. 
smnew = sm1.toDense().toSparse() - self.assertEquals(sm1.numRows, smnew.numRows) - self.assertEquals(sm1.numCols, smnew.numCols) + self.assertEqual(sm1.numRows, smnew.numRows) + self.assertEqual(sm1.numCols, smnew.numCols) self.assertTrue(array_equal(sm1.colPtrs, smnew.colPtrs)) self.assertTrue(array_equal(sm1.rowIndices, smnew.rowIndices)) self.assertTrue(array_equal(sm1.values, smnew.values)) @@ -339,11 +339,11 @@ def test_sparse_matrix(self): sm1t = SparseMatrix( 3, 4, [0, 2, 3, 5], [0, 1, 2, 0, 2], [3.0, 2.0, 4.0, 9.0, 8.0], isTransposed=True) - self.assertEquals(sm1t.numRows, 3) - self.assertEquals(sm1t.numCols, 4) - self.assertEquals(sm1t.colPtrs.tolist(), [0, 2, 3, 5]) - self.assertEquals(sm1t.rowIndices.tolist(), [0, 1, 2, 0, 2]) - self.assertEquals(sm1t.values.tolist(), [3.0, 2.0, 4.0, 9.0, 8.0]) + self.assertEqual(sm1t.numRows, 3) + self.assertEqual(sm1t.numCols, 4) + self.assertEqual(sm1t.colPtrs.tolist(), [0, 2, 3, 5]) + self.assertEqual(sm1t.rowIndices.tolist(), [0, 1, 2, 0, 2]) + self.assertEqual(sm1t.values.tolist(), [3.0, 2.0, 4.0, 9.0, 8.0]) expected = [ [3, 2, 0, 0], @@ -352,18 +352,18 @@ def test_sparse_matrix(self): for i in range(3): for j in range(4): - self.assertEquals(expected[i][j], sm1t[i, j]) + self.assertEqual(expected[i][j], sm1t[i, j]) self.assertTrue(array_equal(sm1t.toArray(), expected)) def test_dense_matrix_is_transposed(self): mat1 = DenseMatrix(3, 2, [0, 4, 1, 6, 3, 9], isTransposed=True) mat = DenseMatrix(3, 2, [0, 1, 3, 4, 6, 9]) - self.assertEquals(mat1, mat) + self.assertEqual(mat1, mat) expected = [[0, 4], [1, 6], [3, 9]] for i in range(3): for j in range(2): - self.assertEquals(mat1[i, j], expected[i][j]) + self.assertEqual(mat1[i, j], expected[i][j]) self.assertTrue(array_equal(mat1.toArray(), expected)) sm = mat1.toSparse() @@ -412,8 +412,8 @@ def test_kmeans(self): ] clusters = KMeans.train(self.sc.parallelize(data), 2, initializationMode="k-means||", initializationSteps=7, epsilon=1e-4) - self.assertEquals(clusters.predict(data[0]), clusters.predict(data[1])) - self.assertEquals(clusters.predict(data[2]), clusters.predict(data[3])) + self.assertEqual(clusters.predict(data[0]), clusters.predict(data[1])) + self.assertEqual(clusters.predict(data[2]), clusters.predict(data[3])) def test_kmeans_deterministic(self): from pyspark.mllib.clustering import KMeans @@ -443,8 +443,8 @@ def test_gmm(self): clusters = GaussianMixture.train(data, 2, convergenceTol=0.001, maxIterations=10, seed=56) labels = clusters.predict(data).collect() - self.assertEquals(labels[0], labels[1]) - self.assertEquals(labels[2], labels[3]) + self.assertEqual(labels[0], labels[1]) + self.assertEqual(labels[2], labels[3]) def test_gmm_deterministic(self): from pyspark.mllib.clustering import GaussianMixture @@ -456,7 +456,7 @@ def test_gmm_deterministic(self): clusters2 = GaussianMixture.train(data, 5, convergenceTol=0.001, maxIterations=10, seed=63) for c1, c2 in zip(clusters1.weights, clusters2.weights): - self.assertEquals(round(c1, 7), round(c2, 7)) + self.assertEqual(round(c1, 7), round(c2, 7)) def test_classification(self): from pyspark.mllib.classification import LogisticRegressionWithSGD, SVMWithSGD, NaiveBayes @@ -711,18 +711,18 @@ def test_serialize(self): lil[1, 0] = 1 lil[3, 0] = 2 sv = SparseVector(4, {1: 1, 3: 2}) - self.assertEquals(sv, _convert_to_vector(lil)) - self.assertEquals(sv, _convert_to_vector(lil.tocsc())) - self.assertEquals(sv, _convert_to_vector(lil.tocoo())) - self.assertEquals(sv, _convert_to_vector(lil.tocsr())) - self.assertEquals(sv, 
_convert_to_vector(lil.todok())) + self.assertEqual(sv, _convert_to_vector(lil)) + self.assertEqual(sv, _convert_to_vector(lil.tocsc())) + self.assertEqual(sv, _convert_to_vector(lil.tocoo())) + self.assertEqual(sv, _convert_to_vector(lil.tocsr())) + self.assertEqual(sv, _convert_to_vector(lil.todok())) def serialize(l): return ser.loads(ser.dumps(_convert_to_vector(l))) - self.assertEquals(sv, serialize(lil)) - self.assertEquals(sv, serialize(lil.tocsc())) - self.assertEquals(sv, serialize(lil.tocsr())) - self.assertEquals(sv, serialize(lil.todok())) + self.assertEqual(sv, serialize(lil)) + self.assertEqual(sv, serialize(lil.tocsc())) + self.assertEqual(sv, serialize(lil.tocsr())) + self.assertEqual(sv, serialize(lil.todok())) def test_dot(self): from scipy.sparse import lil_matrix @@ -730,7 +730,7 @@ def test_dot(self): lil[1, 0] = 1 lil[3, 0] = 2 dv = DenseVector(array([1., 2., 3., 4.])) - self.assertEquals(10.0, dv.dot(lil)) + self.assertEqual(10.0, dv.dot(lil)) def test_squared_distance(self): from scipy.sparse import lil_matrix @@ -739,8 +739,8 @@ def test_squared_distance(self): lil[3, 0] = 2 dv = DenseVector(array([1., 2., 3., 4.])) sv = SparseVector(4, {0: 1, 1: 2, 2: 3, 3: 4}) - self.assertEquals(15.0, dv.squared_distance(lil)) - self.assertEquals(15.0, sv.squared_distance(lil)) + self.assertEqual(15.0, dv.squared_distance(lil)) + self.assertEqual(15.0, sv.squared_distance(lil)) def scipy_matrix(self, size, values): """Create a column SciPy matrix from a dictionary of values""" @@ -759,8 +759,8 @@ def test_clustering(self): self.scipy_matrix(3, {2: 1.1}) ] clusters = KMeans.train(self.sc.parallelize(data), 2, initializationMode="k-means||") - self.assertEquals(clusters.predict(data[0]), clusters.predict(data[1])) - self.assertEquals(clusters.predict(data[2]), clusters.predict(data[3])) + self.assertEqual(clusters.predict(data[0]), clusters.predict(data[1])) + self.assertEqual(clusters.predict(data[2]), clusters.predict(data[3])) def test_classification(self): from pyspark.mllib.classification import LogisticRegressionWithSGD, SVMWithSGD, NaiveBayes @@ -984,12 +984,12 @@ def test_word2vec_setters(self): .setNumIterations(10) \ .setSeed(1024) \ .setMinCount(3) - self.assertEquals(model.vectorSize, 2) + self.assertEqual(model.vectorSize, 2) self.assertTrue(model.learningRate < 0.02) - self.assertEquals(model.numPartitions, 2) - self.assertEquals(model.numIterations, 10) - self.assertEquals(model.seed, 1024) - self.assertEquals(model.minCount, 3) + self.assertEqual(model.numPartitions, 2) + self.assertEqual(model.numIterations, 10) + self.assertEqual(model.seed, 1024) + self.assertEqual(model.minCount, 3) def test_word2vec_get_vectors(self): data = [ @@ -1002,7 +1002,7 @@ def test_word2vec_get_vectors(self): ["a"] ] model = Word2Vec().fit(self.sc.parallelize(data)) - self.assertEquals(len(model.getVectors()), 3) + self.assertEqual(len(model.getVectors()), 3) class StandardScalerTests(MLlibTestCase): @@ -1044,8 +1044,8 @@ def test_model_params(self): """Test that the model params are set correctly""" stkm = StreamingKMeans() stkm.setK(5).setDecayFactor(0.0) - self.assertEquals(stkm._k, 5) - self.assertEquals(stkm._decayFactor, 0.0) + self.assertEqual(stkm._k, 5) + self.assertEqual(stkm._decayFactor, 0.0) # Model not set yet. 
self.assertIsNone(stkm.latestModel()) @@ -1053,9 +1053,9 @@ def test_model_params(self): stkm.setInitialCenters( centers=[[0.0, 0.0], [1.0, 1.0]], weights=[1.0, 1.0]) - self.assertEquals( + self.assertEqual( stkm.latestModel().centers, [[0.0, 0.0], [1.0, 1.0]]) - self.assertEquals(stkm.latestModel().clusterWeights, [1.0, 1.0]) + self.assertEqual(stkm.latestModel().clusterWeights, [1.0, 1.0]) def test_accuracy_for_single_center(self): """Test that parameters obtained are correct for a single center.""" @@ -1070,7 +1070,7 @@ def test_accuracy_for_single_center(self): self.ssc.start() def condition(): - self.assertEquals(stkm.latestModel().clusterWeights, [25.0]) + self.assertEqual(stkm.latestModel().clusterWeights, [25.0]) return True self._eventually(condition, catch_assertions=True) @@ -1114,7 +1114,7 @@ def test_trainOn_model(self): def condition(): finalModel = stkm.latestModel() self.assertTrue(all(finalModel.centers == array(initCenters))) - self.assertEquals(finalModel.clusterWeights, [5.0, 5.0, 5.0, 5.0]) + self.assertEqual(finalModel.clusterWeights, [5.0, 5.0, 5.0, 5.0]) return True self._eventually(condition, catch_assertions=True) @@ -1141,7 +1141,7 @@ def update(rdd): self.ssc.start() def condition(): - self.assertEquals(result, [[0], [1], [2], [3]]) + self.assertEqual(result, [[0], [1], [2], [3]]) return True self._eventually(condition, catch_assertions=True) @@ -1263,7 +1263,7 @@ def test_convergence(self): self.ssc.start() def condition(): - self.assertEquals(len(models), len(input_batches)) + self.assertEqual(len(models), len(input_batches)) return True # We want all batches to finish for this test. @@ -1297,7 +1297,7 @@ def test_predictions(self): self.ssc.start() def condition(): - self.assertEquals(len(true_predicted), len(input_batches)) + self.assertEqual(len(true_predicted), len(input_batches)) return True self._eventually(condition, catch_assertions=True) @@ -1400,7 +1400,7 @@ def test_parameter_convergence(self): self.ssc.start() def condition(): - self.assertEquals(len(model_weights), len(batches)) + self.assertEqual(len(model_weights), len(batches)) return True # We want all batches to finish for this test. @@ -1433,7 +1433,7 @@ def test_prediction(self): self.ssc.start() def condition(): - self.assertEquals(len(samples), len(batches)) + self.assertEqual(len(samples), len(batches)) return True # We want all batches to finish for this test. 
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index f2172b7a27d8..3e680f1030a7 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -157,7 +157,7 @@ class DataTypeTests(unittest.TestCase): def test_data_type_eq(self): lt = LongType() lt2 = pickle.loads(pickle.dumps(LongType())) - self.assertEquals(lt, lt2) + self.assertEqual(lt, lt2) # regression test for SPARK-7978 def test_decimal_type(self): @@ -393,7 +393,7 @@ def test_infer_nested_schema(self): CustomRow(field1=2, field2="row2"), CustomRow(field1=3, field2="row3")]) df = self.sqlCtx.inferSchema(rdd) - self.assertEquals(Row(field1=1, field2=u'row1'), df.first()) + self.assertEqual(Row(field1=1, field2=u'row1'), df.first()) def test_create_dataframe_from_objects(self): data = [MyObject(1, "1"), MyObject(2, "2")] @@ -403,7 +403,7 @@ def test_create_dataframe_from_objects(self): def test_select_null_literal(self): df = self.sqlCtx.sql("select null as col") - self.assertEquals(Row(col=None), df.first()) + self.assertEqual(Row(col=None), df.first()) def test_apply_schema(self): from datetime import date, datetime @@ -519,14 +519,14 @@ def test_apply_schema_with_udt(self): StructField("point", ExamplePointUDT(), False)]) df = self.sqlCtx.createDataFrame([row], schema) point = df.head().point - self.assertEquals(point, ExamplePoint(1.0, 2.0)) + self.assertEqual(point, ExamplePoint(1.0, 2.0)) row = (1.0, PythonOnlyPoint(1.0, 2.0)) schema = StructType([StructField("label", DoubleType(), False), StructField("point", PythonOnlyUDT(), False)]) df = self.sqlCtx.createDataFrame([row], schema) point = df.head().point - self.assertEquals(point, PythonOnlyPoint(1.0, 2.0)) + self.assertEqual(point, PythonOnlyPoint(1.0, 2.0)) def test_udf_with_udt(self): from pyspark.sql.tests import ExamplePoint, ExamplePointUDT @@ -554,14 +554,14 @@ def test_parquet_with_udt(self): df0.write.parquet(output_dir) df1 = self.sqlCtx.parquetFile(output_dir) point = df1.head().point - self.assertEquals(point, ExamplePoint(1.0, 2.0)) + self.assertEqual(point, ExamplePoint(1.0, 2.0)) row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0)) df0 = self.sqlCtx.createDataFrame([row]) df0.write.parquet(output_dir, mode='overwrite') df1 = self.sqlCtx.parquetFile(output_dir) point = df1.head().point - self.assertEquals(point, PythonOnlyPoint(1.0, 2.0)) + self.assertEqual(point, PythonOnlyPoint(1.0, 2.0)) def test_column_operators(self): ci = self.df.key @@ -826,8 +826,8 @@ def test_infer_long_type(self): output_dir = os.path.join(self.tempdir.name, "infer_long_type") df.saveAsParquetFile(output_dir) df1 = self.sqlCtx.parquetFile(output_dir) - self.assertEquals('a', df1.first().f1) - self.assertEquals(100000000000000, df1.first().f2) + self.assertEqual('a', df1.first().f1) + self.assertEqual(100000000000000, df1.first().f2) self.assertEqual(_infer_type(1), LongType()) self.assertEqual(_infer_type(2**10), LongType()) diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index cfea95b0dec7..e4e56fff3b3f 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -693,7 +693,7 @@ def check_output(n): # Verify that getActiveOrCreate() returns active context self.setupCalled = False - self.assertEquals(StreamingContext.getActiveOrCreate(self.cpd, setup), self.ssc) + self.assertEqual(StreamingContext.getActiveOrCreate(self.cpd, setup), self.ssc) self.assertFalse(self.setupCalled) # Verify that getActiveOrCreate() uses existing SparkContext From 00a2911c5bea67a1a4796fb1d6fd5d0a8ee79001 
Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Fri, 18 Sep 2015 12:19:08 -0700 Subject: [PATCH 0681/1824] [SPARK-10540] Fixes flaky all-data-type test This PR breaks the original test case into multiple ones (one test case for each data type). In this way, test failure output can be much more readable. Within each test case, we build a table with two columns, one of them is for the data type to test, the other is an "index" column, which is used to sort the DataFrame and workaround [SPARK-10591] [1] [1]: https://issues.apache.org/jira/browse/SPARK-10591 Author: Cheng Lian Closes #8768 from liancheng/spark-10540/test-all-data-types. --- .../sql/sources/hadoopFsRelationSuites.scala | 109 +++++++----------- 1 file changed, 43 insertions(+), 66 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala index 8ffcef85668d..d7504936d90e 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala @@ -100,80 +100,57 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes } } - ignore("test all data types") { - withTempPath { file => - // Create the schema. - val struct = - StructType( - StructField("f1", FloatType, true) :: - StructField("f2", ArrayType(BooleanType), true) :: Nil) - // TODO: add CalendarIntervalType to here once we can save it out. - val dataTypes = - Seq( - StringType, BinaryType, NullType, BooleanType, - ByteType, ShortType, IntegerType, LongType, - FloatType, DoubleType, DecimalType(25, 5), DecimalType(6, 5), - DateType, TimestampType, - ArrayType(IntegerType), MapType(StringType, LongType), struct, - new MyDenseVectorUDT()) - val fields = dataTypes.zipWithIndex.map { case (dataType, index) => - StructField(s"col$index", dataType, nullable = true) - } - val schema = StructType(fields) - - // Generate data at the driver side. We need to materialize the data first and then - // create RDD. - val maybeDataGenerator = - RandomDataGenerator.forType( - dataType = schema, + private val supportedDataTypes = Seq( + StringType, BinaryType, + NullType, BooleanType, + ByteType, ShortType, IntegerType, LongType, + FloatType, DoubleType, DecimalType(25, 5), DecimalType(6, 5), + DateType, TimestampType, + ArrayType(IntegerType), + MapType(StringType, LongType), + new StructType() + .add("f1", FloatType, nullable = true) + .add("f2", ArrayType(BooleanType, containsNull = true), nullable = true), + new MyDenseVectorUDT() + ).filter(supportsDataType) + + for (dataType <- supportedDataTypes) { + test(s"test all data types - $dataType") { + withTempPath { file => + val path = file.getCanonicalPath + + val dataGenerator = RandomDataGenerator.forType( + dataType = dataType, nullable = true, - seed = Some(System.nanoTime())) - val dataGenerator = - maybeDataGenerator - .getOrElse(fail(s"Failed to create data generator for schema $schema")) - val data = (1 to 10).map { i => - dataGenerator.apply() match { - case row: Row => row - case null => Row.fromSeq(Seq.fill(schema.length)(null)) - case other => - fail(s"Row or null is expected to be generated, " + - s"but a ${other.getClass.getCanonicalName} is generated.") + seed = Some(System.nanoTime()) + ).getOrElse { + fail(s"Failed to create data generator for schema $dataType") } - } - // Create a DF for the schema with random data. 
- val rdd = sqlContext.sparkContext.parallelize(data, 10) - val df = sqlContext.createDataFrame(rdd, schema) + // Create a DF for the schema with random data. The index field is used to sort the + // DataFrame. This is a workaround for SPARK-10591. + val schema = new StructType() + .add("index", IntegerType, nullable = false) + .add("col", dataType, nullable = true) + val rdd = sqlContext.sparkContext.parallelize((1 to 10).map(i => Row(i, dataGenerator()))) + val df = sqlContext.createDataFrame(rdd, schema).orderBy("index").coalesce(1) - // All columns that have supported data types of this source. - val supportedColumns = schema.fields.collect { - case StructField(name, dataType, _, _) if supportsDataType(dataType) => name - } - val selectedColumns = util.Random.shuffle(supportedColumns.toSeq) - - val dfToBeSaved = df.selectExpr(selectedColumns: _*) - - // Save the data out. - dfToBeSaved - .write - .format(dataSourceName) - .option("dataSchema", dfToBeSaved.schema.json) // This option is just used by tests. - .save(file.getCanonicalPath) + df.write + .mode("overwrite") + .format(dataSourceName) + .option("dataSchema", df.schema.json) + .save(path) - val loadedDF = - sqlContext + val loadedDF = sqlContext .read .format(dataSourceName) - .schema(dfToBeSaved.schema) - .option("dataSchema", dfToBeSaved.schema.json) // This option is just used by tests. - .load(file.getCanonicalPath) - .selectExpr(selectedColumns: _*) + .option("dataSchema", df.schema.json) + .schema(df.schema) + .load(path) + .orderBy("index") - // Read the data back. - checkAnswer( - loadedDF, - dfToBeSaved - ) + checkAnswer(loadedDF, df) + } } } From c6f8135ee52202bd86adb090ab631e80330ea4df Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Fri, 18 Sep 2015 13:20:13 -0700 Subject: [PATCH 0682/1824] [SPARK-10539] [SQL] Project should not be pushed down through Intersect or Except #8742 Intersect and Except are both set operators and they use the all the columns to compare equality between rows. When pushing their Project parent down, the relations they based on would change, therefore not an equivalent transformation. JIRA: https://issues.apache.org/jira/browse/SPARK-10539 I added some comments based on the fix of https://github.com/apache/spark/pull/8742. Author: Yijie Shen Author: Yin Huai Closes #8823 from yhuai/fix_set_optimization. --- .../sql/catalyst/optimizer/Optimizer.scala | 37 ++++++++++--------- .../optimizer/SetOperationPushDownSuite.scala | 23 ++++++------ .../org/apache/spark/sql/DataFrameSuite.scala | 9 +++++ 3 files changed, 39 insertions(+), 30 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 648a65e7c0eb..324f40a051c3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -85,7 +85,22 @@ object SamplePushDown extends Rule[LogicalPlan] { } /** - * Pushes operations to either side of a Union, Intersect or Except. + * Pushes certain operations to both sides of a Union, Intersect or Except operator. + * Operations that are safe to pushdown are listed as follows. + * Union: + * Right now, Union means UNION ALL, which does not de-duplicate rows. So, it is + * safe to pushdown Filters and Projections through it. Once we add UNION DISTINCT, + * we will not be able to pushdown Projections. 
+ * + * Intersect: + * It is not safe to pushdown Projections through it because we need to get the + * intersect of rows by comparing the entire rows. It is fine to pushdown Filters + * because we will not have non-deterministic expressions. + * + * Except: + * It is not safe to pushdown Projections through it because we need to get the + * intersect of rows by comparing the entire rows. It is fine to pushdown Filters + * because we will not have non-deterministic expressions. */ object SetOperationPushDown extends Rule[LogicalPlan] { @@ -122,40 +137,26 @@ object SetOperationPushDown extends Rule[LogicalPlan] { Filter(condition, left), Filter(pushToRight(condition, rewrites), right)) - // Push down projection into union + // Push down projection through UNION ALL case Project(projectList, u @ Union(left, right)) => val rewrites = buildRewrites(u) Union( Project(projectList, left), Project(projectList.map(pushToRight(_, rewrites)), right)) - // Push down filter into intersect + // Push down filter through INTERSECT case Filter(condition, i @ Intersect(left, right)) => val rewrites = buildRewrites(i) Intersect( Filter(condition, left), Filter(pushToRight(condition, rewrites), right)) - // Push down projection into intersect - case Project(projectList, i @ Intersect(left, right)) => - val rewrites = buildRewrites(i) - Intersect( - Project(projectList, left), - Project(projectList.map(pushToRight(_, rewrites)), right)) - - // Push down filter into except + // Push down filter through EXCEPT case Filter(condition, e @ Except(left, right)) => val rewrites = buildRewrites(e) Except( Filter(condition, left), Filter(pushToRight(condition, rewrites), right)) - - // Push down projection into except - case Project(projectList, e @ Except(left, right)) => - val rewrites = buildRewrites(e) - Except( - Project(projectList, left), - Project(projectList.map(pushToRight(_, rewrites)), right)) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationPushDownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationPushDownSuite.scala index 49c979bc7d72..3fca47a023dc 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationPushDownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationPushDownSuite.scala @@ -60,23 +60,22 @@ class SetOperationPushDownSuite extends PlanTest { comparePlans(exceptOptimized, exceptCorrectAnswer) } - test("union/intersect/except: project to each side") { + test("union: project to each side") { val unionQuery = testUnion.select('a) + val unionOptimized = Optimize.execute(unionQuery.analyze) + val unionCorrectAnswer = + Union(testRelation.select('a), testRelation2.select('d)).analyze + comparePlans(unionOptimized, unionCorrectAnswer) + } + + test("SPARK-10539: Project should not be pushed down through Intersect or Except") { val intersectQuery = testIntersect.select('b, 'c) val exceptQuery = testExcept.select('a, 'b, 'c) - val unionOptimized = Optimize.execute(unionQuery.analyze) val intersectOptimized = Optimize.execute(intersectQuery.analyze) val exceptOptimized = Optimize.execute(exceptQuery.analyze) - val unionCorrectAnswer = - Union(testRelation.select('a), testRelation2.select('d)).analyze - val intersectCorrectAnswer = - Intersect(testRelation.select('b, 'c), testRelation2.select('e, 'f)).analyze - val exceptCorrectAnswer = - Except(testRelation.select('a, 'b, 'c), testRelation2.select('d, 'e, 'f)).analyze - - 
comparePlans(unionOptimized, unionCorrectAnswer) - comparePlans(intersectOptimized, intersectCorrectAnswer) - comparePlans(exceptOptimized, exceptCorrectAnswer) } + comparePlans(intersectOptimized, intersectQuery.analyze) + comparePlans(exceptOptimized, exceptQuery.analyze) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index c167999af580..1370713975f2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -907,4 +907,13 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { assert(row.getDouble(1) - row.getDouble(3) === 0.0 +- 0.001) } } + + test("SPARK-10539: Project should not be pushed down through Intersect or Except") { + val df1 = (1 to 100).map(Tuple1.apply).toDF("i") + val df2 = (1 to 30).map(Tuple1.apply).toDF("i") + val intersect = df1.intersect(df2) + val except = df1.except(df2) + assert(intersect.count() === 30) + assert(except.count() === 70) + } } From 3a22b1004f527d54d399dd0225cd7f2f8ffad9c5 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Fri, 18 Sep 2015 13:47:14 -0700 Subject: [PATCH 0683/1824] [SPARK-10449] [SQL] Don't merge decimal types with incompatable precision or scales From JIRA: Schema merging should only handle struct fields. But currently we also reconcile decimal precision and scale information. Author: Holden Karau Closes #8634 from holdenk/SPARK-10449-dont-merge-different-precision. --- .../org/apache/spark/sql/types/StructType.scala | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index b29cf22dcb58..d6b436724b2a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -373,10 +373,19 @@ object StructType extends AbstractDataType { StructType(newFields) case (DecimalType.Fixed(leftPrecision, leftScale), - DecimalType.Fixed(rightPrecision, rightScale)) => - DecimalType( - max(leftScale, rightScale) + max(leftPrecision - leftScale, rightPrecision - rightScale), - max(leftScale, rightScale)) + DecimalType.Fixed(rightPrecision, rightScale)) => + if ((leftPrecision == rightPrecision) && (leftScale == rightScale)) { + DecimalType(leftPrecision, leftScale) + } else if ((leftPrecision != rightPrecision) && (leftScale != rightScale)) { + throw new SparkException("Failed to merge Decimal Tpes with incompatible " + + s"precision $leftPrecision and $rightPrecision & scale $leftScale and $rightScale") + } else if (leftPrecision != rightPrecision) { + throw new SparkException("Failed to merge Decimal Tpes with incompatible " + + s"precision $leftPrecision and $rightPrecision") + } else { + throw new SparkException("Failed to merge Decimal Tpes with incompatible " + + s"scala $leftScale and $rightScale") + } case (leftUdt: UserDefinedType[_], rightUdt: UserDefinedType[_]) if leftUdt.userClass == rightUdt.userClass => leftUdt From 348d7c9a93dd00d3d1859342a8eb0aea2e77f597 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 18 Sep 2015 13:48:41 -0700 Subject: [PATCH 0684/1824] [SPARK-9808] Remove hash shuffle file consolidation. Author: Reynold Xin Closes #8812 from rxin/SPARK-9808-1. 
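A quick illustration for the SPARK-10539 commit above (unrelated to #8812, whose diff follows): Intersect and Except compare entire rows, so pushing a Project beneath them changes which rows match. The data below is made up, and the snippet assumes a SQLContext with its implicits imported, as in the DataFrameSuite test above.

    // import sqlContext.implicits._   // assumed in scope
    val df1 = Seq((1, "a"), (2, "b")).toDF("i", "s")
    val df2 = Seq((1, "x")).toDF("i", "s")

    df1.intersect(df2).select("i").count()              // 0: no complete row is shared
    df1.select("i").intersect(df2.select("i")).count()  // 1: projecting first makes (1) match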
--- .../shuffle/FileShuffleBlockResolver.scala | 178 ++---------------- .../apache/spark/storage/BlockManager.scala | 9 - .../org/apache/spark/storage/DiskStore.scala | 3 - .../hash/HashShuffleManagerSuite.scala | 110 ----------- docs/configuration.md | 10 - .../shuffle/ExternalShuffleBlockResolver.java | 4 - project/MimaExcludes.scala | 4 + 7 files changed, 17 insertions(+), 301 deletions(-) delete mode 100644 core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala diff --git a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala index c057de9b3f4d..d9902f96dfd4 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala @@ -17,9 +17,7 @@ package org.apache.spark.shuffle -import java.io.File import java.util.concurrent.ConcurrentLinkedQueue -import java.util.concurrent.atomic.AtomicInteger import scala.collection.JavaConverters._ @@ -28,10 +26,8 @@ import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.serializer.Serializer -import org.apache.spark.shuffle.FileShuffleBlockResolver.ShuffleFileGroup import org.apache.spark.storage._ import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashMap} -import org.apache.spark.util.collection.{PrimitiveKeyOpenHashMap, PrimitiveVector} /** A group of writers for a ShuffleMapTask, one writer per reducer. */ private[spark] trait ShuffleWriterGroup { @@ -43,24 +39,7 @@ private[spark] trait ShuffleWriterGroup { /** * Manages assigning disk-based block writers to shuffle tasks. Each shuffle task gets one file - * per reducer (this set of files is called a ShuffleFileGroup). - * - * As an optimization to reduce the number of physical shuffle files produced, multiple shuffle - * blocks are aggregated into the same file. There is one "combined shuffle file" per reducer - * per concurrently executing shuffle task. As soon as a task finishes writing to its shuffle - * files, it releases them for another task. - * Regarding the implementation of this feature, shuffle files are identified by a 3-tuple: - * - shuffleId: The unique id given to the entire shuffle stage. - * - bucketId: The id of the output partition (i.e., reducer id) - * - fileId: The unique id identifying a group of "combined shuffle files." Only one task at a - * time owns a particular fileId, and this id is returned to a pool when the task finishes. - * Each shuffle file is then mapped to a FileSegment, which is a 3-tuple (file, offset, length) - * that specifies where in a given file the actual block data is located. - * - * Shuffle file metadata is stored in a space-efficient manner. Rather than simply mapping - * ShuffleBlockIds directly to FileSegments, each ShuffleFileGroup maintains a list of offsets for - * each block stored in each file. In order to find the location of a shuffle block, we search the - * files within a ShuffleFileGroups associated with the block's reducer. + * per reducer. */ // Note: Changes to the format in this file should be kept in sync with // org.apache.spark.network.shuffle.ExternalShuffleBlockResolver#getHashBasedShuffleBlockData(). 
@@ -71,26 +50,15 @@ private[spark] class FileShuffleBlockResolver(conf: SparkConf) private lazy val blockManager = SparkEnv.get.blockManager - // Turning off shuffle file consolidation causes all shuffle Blocks to get their own file. - // TODO: Remove this once the shuffle file consolidation feature is stable. - private val consolidateShuffleFiles = - conf.getBoolean("spark.shuffle.consolidateFiles", false) - // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided private val bufferSize = conf.getSizeAsKb("spark.shuffle.file.buffer", "32k").toInt * 1024 /** - * Contains all the state related to a particular shuffle. This includes a pool of unused - * ShuffleFileGroups, as well as all ShuffleFileGroups that have been created for the shuffle. + * Contains all the state related to a particular shuffle. */ - private class ShuffleState(val numBuckets: Int) { - val nextFileId = new AtomicInteger(0) - val unusedFileGroups = new ConcurrentLinkedQueue[ShuffleFileGroup]() - val allFileGroups = new ConcurrentLinkedQueue[ShuffleFileGroup]() - + private class ShuffleState(val numReducers: Int) { /** * The mapIds of all map tasks completed on this Executor for this shuffle. - * NB: This is only populated if consolidateShuffleFiles is FALSE. We don't need it otherwise. */ val completedMapTasks = new ConcurrentLinkedQueue[Int]() } @@ -104,24 +72,16 @@ private[spark] class FileShuffleBlockResolver(conf: SparkConf) * Get a ShuffleWriterGroup for the given map task, which will register it as complete * when the writers are closed successfully */ - def forMapTask(shuffleId: Int, mapId: Int, numBuckets: Int, serializer: Serializer, + def forMapTask(shuffleId: Int, mapId: Int, numReducers: Int, serializer: Serializer, writeMetrics: ShuffleWriteMetrics): ShuffleWriterGroup = { new ShuffleWriterGroup { - shuffleStates.putIfAbsent(shuffleId, new ShuffleState(numBuckets)) + shuffleStates.putIfAbsent(shuffleId, new ShuffleState(numReducers)) private val shuffleState = shuffleStates(shuffleId) - private var fileGroup: ShuffleFileGroup = null val openStartTime = System.nanoTime val serializerInstance = serializer.newInstance() - val writers: Array[DiskBlockObjectWriter] = if (consolidateShuffleFiles) { - fileGroup = getUnusedFileGroup() - Array.tabulate[DiskBlockObjectWriter](numBuckets) { bucketId => - val blockId = ShuffleBlockId(shuffleId, mapId, bucketId) - blockManager.getDiskWriter(blockId, fileGroup(bucketId), serializerInstance, bufferSize, - writeMetrics) - } - } else { - Array.tabulate[DiskBlockObjectWriter](numBuckets) { bucketId => + val writers: Array[DiskBlockObjectWriter] = { + Array.tabulate[DiskBlockObjectWriter](numReducers) { bucketId => val blockId = ShuffleBlockId(shuffleId, mapId, bucketId) val blockFile = blockManager.diskBlockManager.getFile(blockId) // Because of previous failures, the shuffle file may already exist on this machine. 
@@ -142,58 +102,14 @@ private[spark] class FileShuffleBlockResolver(conf: SparkConf) writeMetrics.incShuffleWriteTime(System.nanoTime - openStartTime) override def releaseWriters(success: Boolean) { - if (consolidateShuffleFiles) { - if (success) { - val offsets = writers.map(_.fileSegment().offset) - val lengths = writers.map(_.fileSegment().length) - fileGroup.recordMapOutput(mapId, offsets, lengths) - } - recycleFileGroup(fileGroup) - } else { - shuffleState.completedMapTasks.add(mapId) - } - } - - private def getUnusedFileGroup(): ShuffleFileGroup = { - val fileGroup = shuffleState.unusedFileGroups.poll() - if (fileGroup != null) fileGroup else newFileGroup() - } - - private def newFileGroup(): ShuffleFileGroup = { - val fileId = shuffleState.nextFileId.getAndIncrement() - val files = Array.tabulate[File](numBuckets) { bucketId => - val filename = physicalFileName(shuffleId, bucketId, fileId) - blockManager.diskBlockManager.getFile(filename) - } - val fileGroup = new ShuffleFileGroup(shuffleId, fileId, files) - shuffleState.allFileGroups.add(fileGroup) - fileGroup - } - - private def recycleFileGroup(group: ShuffleFileGroup) { - shuffleState.unusedFileGroups.add(group) + shuffleState.completedMapTasks.add(mapId) } } } override def getBlockData(blockId: ShuffleBlockId): ManagedBuffer = { - if (consolidateShuffleFiles) { - // Search all file groups associated with this shuffle. - val shuffleState = shuffleStates(blockId.shuffleId) - val iter = shuffleState.allFileGroups.iterator - while (iter.hasNext) { - val segmentOpt = iter.next.getFileSegmentFor(blockId.mapId, blockId.reduceId) - if (segmentOpt.isDefined) { - val segment = segmentOpt.get - return new FileSegmentManagedBuffer( - transportConf, segment.file, segment.offset, segment.length) - } - } - throw new IllegalStateException("Failed to find shuffle block: " + blockId) - } else { - val file = blockManager.diskBlockManager.getFile(blockId) - new FileSegmentManagedBuffer(transportConf, file, 0, file.length) - } + val file = blockManager.diskBlockManager.getFile(blockId) + new FileSegmentManagedBuffer(transportConf, file, 0, file.length) } /** Remove all the blocks / files and metadata related to a particular shuffle. 
*/ @@ -209,17 +125,9 @@ private[spark] class FileShuffleBlockResolver(conf: SparkConf) private def removeShuffleBlocks(shuffleId: ShuffleId): Boolean = { shuffleStates.get(shuffleId) match { case Some(state) => - if (consolidateShuffleFiles) { - for (fileGroup <- state.allFileGroups.asScala; - file <- fileGroup.files) { - file.delete() - } - } else { - for (mapId <- state.completedMapTasks.asScala; - reduceId <- 0 until state.numBuckets) { - val blockId = new ShuffleBlockId(shuffleId, mapId, reduceId) - blockManager.diskBlockManager.getFile(blockId).delete() - } + for (mapId <- state.completedMapTasks.asScala; reduceId <- 0 until state.numReducers) { + val blockId = new ShuffleBlockId(shuffleId, mapId, reduceId) + blockManager.diskBlockManager.getFile(blockId).delete() } logInfo("Deleted all files for shuffle " + shuffleId) true @@ -229,10 +137,6 @@ private[spark] class FileShuffleBlockResolver(conf: SparkConf) } } - private def physicalFileName(shuffleId: Int, bucketId: Int, fileId: Int) = { - "merged_shuffle_%d_%d_%d".format(shuffleId, bucketId, fileId) - } - private def cleanup(cleanupTime: Long) { shuffleStates.clearOldValues(cleanupTime, (shuffleId, state) => removeShuffleBlocks(shuffleId)) } @@ -241,59 +145,3 @@ private[spark] class FileShuffleBlockResolver(conf: SparkConf) metadataCleaner.cancel() } } - -private[spark] object FileShuffleBlockResolver { - /** - * A group of shuffle files, one per reducer. - * A particular mapper will be assigned a single ShuffleFileGroup to write its output to. - */ - private class ShuffleFileGroup(val shuffleId: Int, val fileId: Int, val files: Array[File]) { - private var numBlocks: Int = 0 - - /** - * Stores the absolute index of each mapId in the files of this group. For instance, - * if mapId 5 is the first block in each file, mapIdToIndex(5) = 0. - */ - private val mapIdToIndex = new PrimitiveKeyOpenHashMap[Int, Int]() - - /** - * Stores consecutive offsets and lengths of blocks into each reducer file, ordered by - * position in the file. - * Note: mapIdToIndex(mapId) returns the index of the mapper into the vector for every - * reducer. - */ - private val blockOffsetsByReducer = Array.fill[PrimitiveVector[Long]](files.length) { - new PrimitiveVector[Long]() - } - private val blockLengthsByReducer = Array.fill[PrimitiveVector[Long]](files.length) { - new PrimitiveVector[Long]() - } - - def apply(bucketId: Int): File = files(bucketId) - - def recordMapOutput(mapId: Int, offsets: Array[Long], lengths: Array[Long]) { - assert(offsets.length == lengths.length) - mapIdToIndex(mapId) = numBlocks - numBlocks += 1 - for (i <- 0 until offsets.length) { - blockOffsetsByReducer(i) += offsets(i) - blockLengthsByReducer(i) += lengths(i) - } - } - - /** Returns the FileSegment associated with the given map task, or None if no entry exists. 
*/ - def getFileSegmentFor(mapId: Int, reducerId: Int): Option[FileSegment] = { - val file = files(reducerId) - val blockOffsets = blockOffsetsByReducer(reducerId) - val blockLengths = blockLengthsByReducer(reducerId) - val index = mapIdToIndex.getOrElse(mapId, -1) - if (index >= 0) { - val offset = blockOffsets(index) - val length = blockLengths(index) - Some(new FileSegment(file, offset, length)) - } else { - None - } - } - } -} diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index d31aa68eb695..bca3942f8c55 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -106,15 +106,6 @@ private[spark] class BlockManager( } } - // Check that we're not using external shuffle service with consolidated shuffle files. - if (externalShuffleServiceEnabled - && conf.getBoolean("spark.shuffle.consolidateFiles", false) - && shuffleManager.isInstanceOf[HashShuffleManager]) { - throw new UnsupportedOperationException("Cannot use external shuffle service with consolidated" - + " shuffle files in hash-based shuffle. Please disable spark.shuffle.consolidateFiles or " - + " switch to sort-based shuffle.") - } - var blockManagerId: BlockManagerId = _ // Address of the server that serves this executor's shuffle files. This is either an external diff --git a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala index 1f4595628216..feb9533604ff 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala @@ -154,9 +154,6 @@ private[spark] class DiskStore(blockManager: BlockManager, diskManager: DiskBloc override def remove(blockId: BlockId): Boolean = { val file = diskManager.getFile(blockId.name) - // If consolidation mode is used With HashShuffleMananger, the physical filename for the block - // is different from blockId.name. So the file returns here will not be exist, thus we avoid to - // delete the whole consolidated file by mistake. if (file.exists()) { file.delete() } else { diff --git a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala deleted file mode 100644 index 491dc3659e18..000000000000 --- a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala +++ /dev/null @@ -1,110 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.shuffle.hash - -import java.io.{File, FileWriter} - -import scala.language.reflectiveCalls - -import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkEnv, SparkFunSuite} -import org.apache.spark.executor.ShuffleWriteMetrics -import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} -import org.apache.spark.serializer.JavaSerializer -import org.apache.spark.shuffle.FileShuffleBlockResolver -import org.apache.spark.storage.{ShuffleBlockId, FileSegment} - -class HashShuffleManagerSuite extends SparkFunSuite with LocalSparkContext { - private val testConf = new SparkConf(false) - - private def checkSegments(expected: FileSegment, buffer: ManagedBuffer) { - assert(buffer.isInstanceOf[FileSegmentManagedBuffer]) - val segment = buffer.asInstanceOf[FileSegmentManagedBuffer] - assert(expected.file.getCanonicalPath === segment.getFile.getCanonicalPath) - assert(expected.offset === segment.getOffset) - assert(expected.length === segment.getLength) - } - - test("consolidated shuffle can write to shuffle group without messing existing offsets/lengths") { - - val conf = new SparkConf(false) - // reset after EACH object write. This is to ensure that there are bytes appended after - // an object is written. So if the codepaths assume writeObject is end of data, this should - // flush those bugs out. This was common bug in ExternalAppendOnlyMap, etc. - conf.set("spark.serializer.objectStreamReset", "1") - conf.set("spark.serializer", "org.apache.spark.serializer.JavaSerializer") - conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.hash.HashShuffleManager") - - sc = new SparkContext("local", "test", conf) - - val shuffleBlockResolver = - SparkEnv.get.shuffleManager.shuffleBlockResolver.asInstanceOf[FileShuffleBlockResolver] - - val shuffle1 = shuffleBlockResolver.forMapTask(1, 1, 1, new JavaSerializer(conf), - new ShuffleWriteMetrics) - for (writer <- shuffle1.writers) { - writer.write("test1", "value") - writer.write("test2", "value") - } - for (writer <- shuffle1.writers) { - writer.commitAndClose() - } - - val shuffle1Segment = shuffle1.writers(0).fileSegment() - shuffle1.releaseWriters(success = true) - - val shuffle2 = shuffleBlockResolver.forMapTask(1, 2, 1, new JavaSerializer(conf), - new ShuffleWriteMetrics) - - for (writer <- shuffle2.writers) { - writer.write("test3", "value") - writer.write("test4", "vlue") - } - for (writer <- shuffle2.writers) { - writer.commitAndClose() - } - val shuffle2Segment = shuffle2.writers(0).fileSegment() - shuffle2.releaseWriters(success = true) - - // Now comes the test : - // Write to shuffle 3; and close it, but before registering it, check if the file lengths for - // previous task (forof shuffle1) is the same as 'segments'. Earlier, we were inferring length - // of block based on remaining data in file : which could mess things up when there is - // concurrent read and writes happening to the same shuffle group. - - val shuffle3 = shuffleBlockResolver.forMapTask(1, 3, 1, new JavaSerializer(testConf), - new ShuffleWriteMetrics) - for (writer <- shuffle3.writers) { - writer.write("test3", "value") - writer.write("test4", "value") - } - for (writer <- shuffle3.writers) { - writer.commitAndClose() - } - // check before we register. 
- checkSegments(shuffle2Segment, shuffleBlockResolver.getBlockData(ShuffleBlockId(1, 2, 0))) - shuffle3.releaseWriters(success = true) - checkSegments(shuffle2Segment, shuffleBlockResolver.getBlockData(ShuffleBlockId(1, 2, 0))) - shuffleBlockResolver.removeShuffle(1) - } - - def writeToFile(file: File, numBytes: Int) { - val writer = new FileWriter(file, true) - for (i <- 0 until numBytes) writer.write(i) - writer.close() - } -} diff --git a/docs/configuration.md b/docs/configuration.md index 1a701f18881f..3700051efb44 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -390,16 +390,6 @@ Apart from these, the following properties are also available, and may be useful spark.io.compression.codec. - - spark.shuffle.consolidateFiles - false - - If set to "true", consolidates intermediate files created during a shuffle. Creating fewer - files can improve filesystem performance for shuffles with large numbers of reduce tasks. It - is recommended to set this to "true" when using ext4 or xfs filesystems. On ext3, this option - might degrade performance on machines with many (>8) cores due to filesystem limitations. - - spark.shuffle.file.buffer 32k diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java index 79beec4429a9..c5f93bb47f55 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java @@ -50,9 +50,6 @@ * of Executors. Each Executor must register its own configuration about where it stores its files * (local dirs) and how (shuffle manager). The logic for retrieval of individual files is replicated * from Spark's FileShuffleBlockResolver and IndexShuffleBlockResolver. - * - * Executors with shuffle file consolidation are not currently supported, as the index is stored in - * the Executor's memory, unlike the IndexShuffleBlockResolver. */ public class ExternalShuffleBlockResolver { private static final Logger logger = LoggerFactory.getLogger(ExternalShuffleBlockResolver.class); @@ -254,7 +251,6 @@ private void deleteExecutorDirs(String[] dirs) { * Hash-based shuffle data is simply stored as one file per block. * This logic is from FileShuffleBlockResolver. 
*/ - // TODO: Support consolidated hash shuffle files private ManagedBuffer getHashBasedShuffleBlockData(ExecutorShuffleInfo executor, String blockId) { File shuffleFile = getFile(executor.localDirs, executor.subDirsPerLocalDir, blockId); return new FileSegmentManagedBuffer(conf, shuffleFile, 0, shuffleFile.length()); diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 1c96b0958586..814a11e588ce 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -70,6 +70,10 @@ object MimaExcludes { "org.apache.spark.scheduler.AskPermissionToCommitOutput.this"), ProblemFilters.exclude[IncompatibleMethTypeProblem]( "org.apache.spark.scheduler.AskPermissionToCommitOutput.apply") + ) ++ + Seq( + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.shuffle.FileShuffleBlockResolver$ShuffleFileGroup") ) case v if v.startsWith("1.5") => Seq( From 8074208fa47fa654c1055c48cfa0d923edeeb04f Mon Sep 17 00:00:00 2001 From: Mingyu Kim Date: Fri, 18 Sep 2015 15:40:58 -0700 Subject: [PATCH 0685/1824] [SPARK-10611] Clone Configuration for each task for NewHadoopRDD This patch attempts to fix the Hadoop Configuration thread safety issue for NewHadoopRDD in the same way SPARK-2546 fixed the issue for HadoopRDD. Author: Mingyu Kim Closes #8763 from mingyukim/mkim/SPARK-10611. --- .../org/apache/spark/rdd/BinaryFileRDD.scala | 5 ++- .../org/apache/spark/rdd/NewHadoopRDD.scala | 37 ++++++++++++++++--- 2 files changed, 34 insertions(+), 8 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala b/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala index 6fec00dcd0d8..aedced7408cd 100644 --- a/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala @@ -34,12 +34,13 @@ private[spark] class BinaryFileRDD[T]( override def getPartitions: Array[Partition] = { val inputFormat = inputFormatClass.newInstance + val conf = getConf inputFormat match { case configurable: Configurable => - configurable.setConf(getConf) + configurable.setConf(conf) case _ => } - val jobContext = newJobContext(getConf, jobId) + val jobContext = newJobContext(conf, jobId) inputFormat.setMinPartitions(jobContext, minPartitions) val rawSplits = inputFormat.getSplits(jobContext).toArray val result = new Array[Partition](rawSplits.size) diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index 174979aaeb23..2872b93b8730 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -44,7 +44,6 @@ private[spark] class NewHadoopPartition( extends Partition { val serializableHadoopSplit = new SerializableWritable(rawSplit) - override def hashCode(): Int = 41 * (41 + rddId) + index } @@ -84,6 +83,27 @@ class NewHadoopRDD[K, V]( @transient protected val jobId = new JobID(jobTrackerId, id) + private val shouldCloneJobConf = sparkContext.conf.getBoolean("spark.hadoop.cloneConf", false) + + def getConf: Configuration = { + val conf: Configuration = confBroadcast.value.value + if (shouldCloneJobConf) { + // Hadoop Configuration objects are not thread-safe, which may lead to various problems if + // one job modifies a configuration while another reads it (SPARK-2546, SPARK-10611). This + // problem occurs somewhat rarely because most jobs treat the configuration as though it's + // immutable. 
One solution, implemented here, is to clone the Configuration object. + // Unfortunately, this clone can be very expensive. To avoid unexpected performance + // regressions for workloads and Hadoop versions that do not suffer from these thread-safety + // issues, this cloning is disabled by default. + NewHadoopRDD.CONFIGURATION_INSTANTIATION_LOCK.synchronized { + logDebug("Cloning Hadoop Configuration") + new Configuration(conf) + } + } else { + conf + } + } + override def getPartitions: Array[Partition] = { val inputFormat = inputFormatClass.newInstance inputFormat match { @@ -104,7 +124,7 @@ class NewHadoopRDD[K, V]( val iter = new Iterator[(K, V)] { val split = theSplit.asInstanceOf[NewHadoopPartition] logInfo("Input split: " + split.serializableHadoopSplit) - val conf = confBroadcast.value.value + val conf = getConf val inputMetrics = context.taskMetrics .getInputMetricsForReadMethod(DataReadMethod.Hadoop) @@ -230,11 +250,15 @@ class NewHadoopRDD[K, V]( super.persist(storageLevel) } - - def getConf: Configuration = confBroadcast.value.value } private[spark] object NewHadoopRDD { + /** + * Configuration's constructor is not threadsafe (see SPARK-1097 and HADOOP-10456). + * Therefore, we synchronize on this lock before calling new Configuration(). + */ + val CONFIGURATION_INSTANTIATION_LOCK = new Object() + /** * Analogous to [[org.apache.spark.rdd.MapPartitionsRDD]], but passes in an InputSplit to * the given function rather than the index of the partition. @@ -268,12 +292,13 @@ private[spark] class WholeTextFileRDD( override def getPartitions: Array[Partition] = { val inputFormat = inputFormatClass.newInstance + val conf = getConf inputFormat match { case configurable: Configurable => - configurable.setConf(getConf) + configurable.setConf(conf) case _ => } - val jobContext = newJobContext(getConf, jobId) + val jobContext = newJobContext(conf, jobId) inputFormat.setMinPartitions(jobContext, minPartitions) val rawSplits = inputFormat.getSplits(jobContext).toArray val result = new Array[Partition](rawSplits.size) From c8149ef2c57f5c47ab97ee8d8d58a216d4bc4156 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Fri, 18 Sep 2015 16:23:05 -0700 Subject: [PATCH 0686/1824] [MINOR] [ML] override toString of AttributeGroup This makes equality test failures much more readable. mengxr Author: Eric Liang Author: Eric Liang Closes #8826 from ericl/attrgroupstr. --- .../scala/org/apache/spark/ml/attribute/AttributeGroup.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala b/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala index 457c15830fd3..2c29eeb01a92 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala @@ -183,6 +183,8 @@ class AttributeGroup private ( sum = 37 * sum + attributes.map(_.toSeq).hashCode sum } + + override def toString: String = toMetadata.toString } /** From 22be2ae147a111e88896f6fb42ed46bbf108a99b Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Fri, 18 Sep 2015 18:42:20 -0700 Subject: [PATCH 0687/1824] [SPARK-10623] [SQL] Fixes ORC predicate push-down When pushing down a leaf predicate, ORC `SearchArgument` builder requires an extra "parent" predicate (any one among `AND`/`OR`/`NOT`) to wrap the leaf predicate. E.g., to push down `a < 1`, we must build `AND(a < 1)` instead. 
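For illustration, a minimal sketch of this wrapping rule, assuming Hive 1.2.1's `SearchArgumentFactory` / `Builder` API from `org.apache.hadoop.hive.ql.io.sarg` (the same builder calls used in the diff below); the object name, column name, and literal are placeholders, not part of the patch:

```
import org.apache.hadoop.hive.ql.io.sarg.{SearchArgument, SearchArgumentFactory}

object LeafWrappingSketch {
  // Sketch of pushing down "a < 1": the leaf comparison is wrapped in a
  // startAnd()/end() pair because the builder does not accept a bare leaf.
  def lessThanOne(): SearchArgument =
    SearchArgumentFactory.newBuilder()
      .startAnd()
      .lessThan("a", Integer.valueOf(1))  // column and literal are examples only
      .end()
      .build()
}
```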
Fortunately, when actually constructing the `SearchArgument`, the builder will eliminate all those unnecessary wrappers. This PR is based on #8783 authored by zhzhan. I also took the chance to simply `OrcFilters` a little bit to improve readability. Author: Cheng Lian Closes #8799 from liancheng/spark-10623/fix-orc-ppd. --- .../spark/sql/hive/orc/OrcFilters.scala | 56 ++++++++----------- .../spark/sql/hive/orc/OrcQuerySuite.scala | 30 ++++++++++ 2 files changed, 52 insertions(+), 34 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala index b3d9f7f71a27..27193f54d3a9 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala @@ -31,11 +31,13 @@ import org.apache.spark.sql.sources._ * and cannot be used anymore. */ private[orc] object OrcFilters extends Logging { - def createFilter(expr: Array[Filter]): Option[SearchArgument] = { - expr.reduceOption(And).flatMap { conjunction => - val builder = SearchArgumentFactory.newBuilder() - buildSearchArgument(conjunction, builder).map(_.build()) - } + def createFilter(filters: Array[Filter]): Option[SearchArgument] = { + for { + // Combines all filters with `And`s to produce a single conjunction predicate + conjunction <- filters.reduceOption(And) + // Then tries to build a single ORC `SearchArgument` for the conjunction predicate + builder <- buildSearchArgument(conjunction, SearchArgumentFactory.newBuilder()) + } yield builder.build() } private def buildSearchArgument(expression: Filter, builder: Builder): Option[Builder] = { @@ -102,46 +104,32 @@ private[orc] object OrcFilters extends Logging { negate <- buildSearchArgument(child, builder.startNot()) } yield negate.end() - case EqualTo(attribute, value) => - Option(value) - .filter(isSearchableLiteral) - .map(builder.equals(attribute, _)) + case EqualTo(attribute, value) if isSearchableLiteral(value) => + Some(builder.startAnd().equals(attribute, value).end()) - case EqualNullSafe(attribute, value) => - Option(value) - .filter(isSearchableLiteral) - .map(builder.nullSafeEquals(attribute, _)) + case EqualNullSafe(attribute, value) if isSearchableLiteral(value) => + Some(builder.startAnd().nullSafeEquals(attribute, value).end()) - case LessThan(attribute, value) => - Option(value) - .filter(isSearchableLiteral) - .map(builder.lessThan(attribute, _)) + case LessThan(attribute, value) if isSearchableLiteral(value) => + Some(builder.startAnd().lessThan(attribute, value).end()) - case LessThanOrEqual(attribute, value) => - Option(value) - .filter(isSearchableLiteral) - .map(builder.lessThanEquals(attribute, _)) + case LessThanOrEqual(attribute, value) if isSearchableLiteral(value) => + Some(builder.startAnd().lessThanEquals(attribute, value).end()) - case GreaterThan(attribute, value) => - Option(value) - .filter(isSearchableLiteral) - .map(builder.startNot().lessThanEquals(attribute, _).end()) + case GreaterThan(attribute, value) if isSearchableLiteral(value) => + Some(builder.startNot().lessThanEquals(attribute, value).end()) - case GreaterThanOrEqual(attribute, value) => - Option(value) - .filter(isSearchableLiteral) - .map(builder.startNot().lessThan(attribute, _).end()) + case GreaterThanOrEqual(attribute, value) if isSearchableLiteral(value) => + Some(builder.startNot().lessThan(attribute, value).end()) case IsNull(attribute) => - Some(builder.isNull(attribute)) + 
Some(builder.startAnd().isNull(attribute).end()) case IsNotNull(attribute) => Some(builder.startNot().isNull(attribute).end()) - case In(attribute, values) => - Option(values) - .filter(_.forall(isSearchableLiteral)) - .map(builder.in(attribute, _)) + case In(attribute, values) if values.forall(isSearchableLiteral) => + Some(builder.startAnd().in(attribute, values.map(_.asInstanceOf[AnyRef]): _*).end()) case _ => None } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala index 8bc33fcf5d90..5eb39b112970 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala @@ -344,4 +344,34 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { } } } + + test("SPARK-10623 Enable ORC PPD") { + withTempPath { dir => + withSQLConf(SQLConf.ORC_FILTER_PUSHDOWN_ENABLED.key -> "true") { + import testImplicits._ + + val path = dir.getCanonicalPath + sqlContext.range(10).coalesce(1).write.orc(path) + val df = sqlContext.read.orc(path) + + def checkPredicate(pred: Column, answer: Seq[Long]): Unit = { + checkAnswer(df.where(pred), answer.map(Row(_))) + } + + checkPredicate('id === 5, Seq(5L)) + checkPredicate('id <=> 5, Seq(5L)) + checkPredicate('id < 5, 0L to 4L) + checkPredicate('id <= 5, 0L to 5L) + checkPredicate('id > 5, 6L to 9L) + checkPredicate('id >= 5, 5L to 9L) + checkPredicate('id.isNull, Seq.empty[Long]) + checkPredicate('id.isNotNull, 0L to 9L) + checkPredicate('id.isin(1L, 3L, 5L), Seq(1L, 3L, 5L)) + checkPredicate('id > 0 && 'id < 3, 1L to 2L) + checkPredicate('id < 1 || 'id > 8, Seq(0L, 9L)) + checkPredicate(!('id > 3), 0L to 3L) + checkPredicate(!('id > 0 && 'id < 3), Seq(0L) ++ (3L to 9L)) + } + } + } } From 7ff8d68cc19299e16dedfd819b9e96480fa6cf44 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Fri, 18 Sep 2015 23:58:25 -0700 Subject: [PATCH 0688/1824] [SPARK-10474] [SQL] Aggregation fails to allocate memory for pointer array When `TungstenAggregation` hits memory pressure, it switches from hash-based to sort-based aggregation in-place. However, in the process we try to allocate the pointer array for writing to the new `UnsafeExternalSorter` *before* actually freeing the memory from the hash map. This lead to the following exception: ``` java.io.IOException: Could not acquire 65536 bytes of memory at org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter.initializeForWriting(UnsafeExternalSorter.java:169) at org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter.spill(UnsafeExternalSorter.java:220) at org.apache.spark.sql.execution.UnsafeKVExternalSorter.(UnsafeKVExternalSorter.java:126) at org.apache.spark.sql.execution.UnsafeFixedWidthAggregationMap.destructAndCreateExternalSorter(UnsafeFixedWidthAggregationMap.java:257) at org.apache.spark.sql.execution.aggregate.TungstenAggregationIterator.switchToSortBasedAggregation(TungstenAggregationIterator.scala:435) ``` Author: Andrew Or Closes #8827 from andrewor14/allocate-pointer-array. 
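For illustration, a minimal Scala sketch of the re-ordering this patch applies, using hypothetical stand-in classes rather than the real `UnsafeFixedWidthAggregationMap` / `UnsafeKVExternalSorter` (the actual change is in the diff that follows):

```
// Hypothetical stand-ins for the aggregation map and the external sorter;
// only the ordering of the three calls mirrors the fix below.
object SpillOrderingSketch {
  class MapSide { def free(): Unit = () }                 // releases the map's data pages
  class SorterSide {
    def spill(initializeForWriting: Boolean): Unit = ()   // write records out; optionally re-allocate
    def initializeForWriting(): Unit = ()                 // acquires the new pointer array
  }

  def switchToSortBasedAggregation(map: MapSide, sorter: SorterSide): Unit = {
    // Previously the spill re-allocated the pointer array while the map still held
    // all of its memory, which is what triggered "Could not acquire ... bytes of memory".
    sorter.spill(initializeForWriting = false)  // spill, but defer the allocation
    map.free()                                  // release the map's pages first
    sorter.initializeForWriting()               // now there is room for the pointer array
  }
}
```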
--- .../unsafe/sort/UnsafeExternalSorter.java | 14 +++++- .../sql/execution/UnsafeKVExternalSorter.java | 8 ++- .../UnsafeFixedWidthAggregationMapSuite.scala | 49 ++++++++++++++++++- 3 files changed, 66 insertions(+), 5 deletions(-) diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java index fc364e0a895b..14b6aafdea7d 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -159,7 +159,7 @@ public BoxedUnit apply() { /** * Allocates new sort data structures. Called when creating the sorter and after each spill. */ - private void initializeForWriting() throws IOException { + public void initializeForWriting() throws IOException { this.writeMetrics = new ShuffleWriteMetrics(); final long pointerArrayMemory = UnsafeInMemorySorter.getMemoryRequirementsForPointerArray(initialSize); @@ -187,6 +187,14 @@ public void closeCurrentPage() { * Sort and spill the current records in response to memory pressure. */ public void spill() throws IOException { + spill(true); + } + + /** + * Sort and spill the current records in response to memory pressure. + * @param shouldInitializeForWriting whether to allocate memory for writing after the spill + */ + public void spill(boolean shouldInitializeForWriting) throws IOException { assert(inMemSorter != null); logger.info("Thread {} spilling sort data of {} to disk ({} {} so far)", Thread.currentThread().getId(), @@ -217,7 +225,9 @@ public void spill() throws IOException { // written to disk. This also counts the space needed to store the sorter's pointer array. taskContext.taskMetrics().incMemoryBytesSpilled(spillSize); - initializeForWriting(); + if (shouldInitializeForWriting) { + initializeForWriting(); + } } /** diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java index 7db6b7ff50f2..b81f67a16b81 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java @@ -85,6 +85,7 @@ public UnsafeKVExternalSorter(StructType keySchema, StructType valueSchema, // We will use the number of elements in the map as the initialSize of the // UnsafeInMemorySorter. Because UnsafeInMemorySorter does not accept 0 as the initialSize, // we will use 1 as its initial size if the map is empty. + // TODO: track pointer array memory used by this in-memory sorter! final UnsafeInMemorySorter inMemSorter = new UnsafeInMemorySorter( taskMemoryManager, recordComparator, prefixComparator, Math.max(1, map.numElements())); @@ -123,8 +124,13 @@ public UnsafeKVExternalSorter(StructType keySchema, StructType valueSchema, pageSizeBytes, inMemSorter); - sorter.spill(); + // Note: This spill doesn't actually release any memory, so if we try to allocate a new + // pointer array immediately after the spill then we may fail to acquire sufficient space + // for it (SPARK-10474). For this reason, we must initialize for writing explicitly *after* + // we have actually freed memory from our map. 
+ sorter.spill(false /* initialize for writing */); map.free(); + sorter.initializeForWriting(); } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala index d1f0b2b1fc52..ada4d42f991c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala @@ -23,9 +23,10 @@ import scala.util.{Try, Random} import org.scalatest.Matchers -import org.apache.spark.sql.catalyst.expressions.{UnsafeRow, UnsafeProjection} import org.apache.spark.{TaskContextImpl, TaskContext, SparkFunSuite} +import org.apache.spark.shuffle.ShuffleMemoryManager import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{UnsafeRow, UnsafeProjection} import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager} @@ -325,7 +326,7 @@ class UnsafeFixedWidthAggregationMapSuite // At here, we also test if copy is correct. iter.getKey.copy() iter.getValue.copy() - count += 1; + count += 1 } // 1 record was from the map and 4096 records were explicitly inserted. @@ -333,4 +334,48 @@ class UnsafeFixedWidthAggregationMapSuite map.free() } + + testWithMemoryLeakDetection("convert to external sorter under memory pressure (SPARK-10474)") { + val smm = ShuffleMemoryManager.createForTesting(65536) + val pageSize = 4096 + val map = new UnsafeFixedWidthAggregationMap( + emptyAggregationBuffer, + aggBufferSchema, + groupKeySchema, + taskMemoryManager, + smm, + 128, // initial capacity + pageSize, + false // disable perf metrics + ) + + // Insert into the map until we've run out of space + val rand = new Random(42) + var hasSpace = true + while (hasSpace) { + val str = rand.nextString(1024) + val buf = map.getAggregationBuffer(InternalRow(UTF8String.fromString(str))) + if (buf == null) { + hasSpace = false + } else { + buf.setInt(0, str.length) + } + } + + // Ensure we're actually maxed out by asserting that we can't acquire even just 1 byte + assert(smm.tryToAcquire(1) === 0) + + // Convert the map into a sorter. This used to fail before the fix for SPARK-10474 + // because we would try to acquire space for the in-memory sorter pointer array before + // actually releasing the pages despite having spilled all of them. + var sorter: UnsafeKVExternalSorter = null + try { + sorter = map.destructAndCreateExternalSorter() + } finally { + if (sorter != null) { + sorter.cleanupResources() + } + } + } + } From d507f9c0b7f7a524137a694ed6443747aaf90463 Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Sat, 19 Sep 2015 01:59:36 -0700 Subject: [PATCH 0689/1824] [SPARK-10584] [SQL] [DOC] Documentation about the compatible Hive version is wrong. In Spark 1.5.0, Spark SQL is compatible with Hive 0.12.0 through 1.2.1 but the documentation is wrong. /CC yhuai Author: Kousuke Saruta Closes #8776 from sarutak/SPARK-10584-2. --- docs/sql-programming-guide.md | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index a0b911d20724..82d4243cc6b2 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1954,7 +1954,7 @@ without the need to write any code. 
## Running the Thrift JDBC/ODBC server The Thrift JDBC/ODBC server implemented here corresponds to the [`HiveServer2`](https://cwiki.apache.org/confluence/display/Hive/Setting+Up+HiveServer2) -in Hive 0.13. You can test the JDBC server with the beeline script that comes with either Spark or Hive 0.13. +in Hive 1.2.1 You can test the JDBC server with the beeline script that comes with either Spark or Hive 1.2.1. To start the JDBC/ODBC server, run the following in the Spark directory: @@ -2260,8 +2260,10 @@ Several caching related features are not supported yet: ## Compatibility with Apache Hive -Spark SQL is designed to be compatible with the Hive Metastore, SerDes and UDFs. Currently Spark -SQL is based on Hive 0.12.0 and 0.13.1. +Spark SQL is designed to be compatible with the Hive Metastore, SerDes and UDFs. +Currently Hive SerDes and UDFs are based on Hive 1.2.1, +and Spark SQL can be connected to different versions of Hive Metastore +(from 0.12.0 to 1.2.1. Also see http://spark.apache.org/docs/latest/sql-programming-guide.html#interacting-with-different-versions-of-hive-metastore). #### Deploying in Existing Hive Warehouses From d83b6aae8b4357c56779cc98804eb350ab8af62d Mon Sep 17 00:00:00 2001 From: Alexis Seigneurin Date: Sat, 19 Sep 2015 12:01:22 +0100 Subject: [PATCH 0690/1824] Fixed links to the API Submitting this change on the master branch as requested in https://github.com/apache/spark/pull/8819#issuecomment-141505941 Author: Alexis Seigneurin Closes #8838 from aseigneurin/patch-2. --- docs/ml-guide.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/ml-guide.md b/docs/ml-guide.md index c5d7f990021f..0427ac6695aa 100644 --- a/docs/ml-guide.md +++ b/docs/ml-guide.md @@ -619,13 +619,13 @@ for row in selected.collect(): An important task in ML is *model selection*, or using data to find the best model or parameters for a given task. This is also called *tuning*. `Pipeline`s facilitate model selection by making it easy to tune an entire `Pipeline` at once, rather than tuning each element in the `Pipeline` separately. -Currently, `spark.ml` supports model selection using the [`CrossValidator`](api/scala/index.html#org.apache.spark.ml.tuning.CrossValidator) class, which takes an `Estimator`, a set of `ParamMap`s, and an [`Evaluator`](api/scala/index.html#org.apache.spark.ml.Evaluator). +Currently, `spark.ml` supports model selection using the [`CrossValidator`](api/scala/index.html#org.apache.spark.ml.tuning.CrossValidator) class, which takes an `Estimator`, a set of `ParamMap`s, and an [`Evaluator`](api/scala/index.html#org.apache.spark.ml.evaluation.Evaluator). `CrossValidator` begins by splitting the dataset into a set of *folds* which are used as separate training and test datasets; e.g., with `$k=3$` folds, `CrossValidator` will generate 3 (training, test) dataset pairs, each of which uses 2/3 of the data for training and 1/3 for testing. `CrossValidator` iterates through the set of `ParamMap`s. For each `ParamMap`, it trains the given `Estimator` and evaluates it using the given `Evaluator`. 
-The `Evaluator` can be a [`RegressionEvaluator`](api/scala/index.html#org.apache.spark.ml.RegressionEvaluator) -for regression problems, a [`BinaryClassificationEvaluator`](api/scala/index.html#org.apache.spark.ml.BinaryClassificationEvaluator) -for binary data, or a [`MultiClassClassificationEvaluator`](api/scala/index.html#org.apache.spark.ml.MultiClassClassificationEvaluator) +The `Evaluator` can be a [`RegressionEvaluator`](api/scala/index.html#org.apache.spark.ml.evaluation.RegressionEvaluator) +for regression problems, a [`BinaryClassificationEvaluator`](api/scala/index.html#org.apache.spark.ml.evaluation.BinaryClassificationEvaluator) +for binary data, or a [`MultiClassClassificationEvaluator`](api/scala/index.html#org.apache.spark.ml.evaluation.MultiClassClassificationEvaluator) for multiclass problems. The default metric used to choose the best `ParamMap` can be overriden by the `setMetric` method in each of these evaluators. From e789000b88a6bd840f821c53f42c08b97dc02496 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Sat, 19 Sep 2015 18:22:43 -0700 Subject: [PATCH 0691/1824] [SPARK-10155] [SQL] Change SqlParser to object to avoid memory leak Since `scala.util.parsing.combinator.Parsers` is thread-safe since Scala 2.10 (See [SI-4929](https://issues.scala-lang.org/browse/SI-4929)), we can change SqlParser to object to avoid memory leak. I didn't change other subclasses of `scala.util.parsing.combinator.Parsers` because there is only one instance in one SQLContext, which should not be an issue. Author: zsxwing Closes #8357 from zsxwing/sql-memory-leak. --- .../apache/spark/sql/catalyst/AbstractSparkSQLParser.scala | 2 +- .../scala/org/apache/spark/sql/catalyst/ParserDialect.scala | 2 +- .../scala/org/apache/spark/sql/catalyst/SqlParser.scala | 6 +++--- .../src/main/scala/org/apache/spark/sql/DataFrame.scala | 6 +++--- .../main/scala/org/apache/spark/sql/DataFrameWriter.scala | 4 ++-- .../src/main/scala/org/apache/spark/sql/SQLContext.scala | 6 +++--- .../src/main/scala/org/apache/spark/sql/functions.scala | 2 +- .../main/scala/org/apache/spark/sql/hive/HiveContext.scala | 6 +++--- .../org/apache/spark/sql/hive/HiveMetastoreCatalog.scala | 4 ++-- 9 files changed, 19 insertions(+), 19 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala index 5898a5f93f38..2bac08eac4fe 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.plans.logical._ private[sql] abstract class AbstractSparkSQLParser extends StandardTokenParsers with PackratParsers { - def parse(input: String): LogicalPlan = { + def parse(input: String): LogicalPlan = synchronized { // Initialize the Keywords. 
initLexical phrase(start)(new lexical.Scanner(input)) match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ParserDialect.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ParserDialect.scala index 554fb4eb25eb..e21d3c05464b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ParserDialect.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ParserDialect.scala @@ -61,7 +61,7 @@ abstract class ParserDialect { */ private[spark] class DefaultParserDialect extends ParserDialect { @transient - protected val sqlParser = new SqlParser + protected val sqlParser = SqlParser override def parse(sqlText: String): LogicalPlan = { sqlParser.parse(sqlText) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index f2498861c957..dfab2398857e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -37,9 +37,9 @@ import org.apache.spark.unsafe.types.CalendarInterval * This is currently included mostly for illustrative purposes. Users wanting more complete support * for a SQL like language should checkout the HiveQL support in the sql/hive sub-project. */ -class SqlParser extends AbstractSparkSQLParser with DataTypeParser { +object SqlParser extends AbstractSparkSQLParser with DataTypeParser { - def parseExpression(input: String): Expression = { + def parseExpression(input: String): Expression = synchronized { // Initialize the Keywords. initLexical phrase(projection)(new lexical.Scanner(input)) match { @@ -48,7 +48,7 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { } } - def parseTableIdentifier(input: String): TableIdentifier = { + def parseTableIdentifier(input: String): TableIdentifier = synchronized { // Initialize the Keywords. 
initLexical phrase(tableIdentifier)(new lexical.Scanner(input)) match { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 3e61123c145c..8f737c202393 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -720,7 +720,7 @@ class DataFrame private[sql]( @scala.annotation.varargs def selectExpr(exprs: String*): DataFrame = { select(exprs.map { expr => - Column(new SqlParser().parseExpression(expr)) + Column(SqlParser.parseExpression(expr)) }: _*) } @@ -745,7 +745,7 @@ class DataFrame private[sql]( * @since 1.3.0 */ def filter(conditionExpr: String): DataFrame = { - filter(Column(new SqlParser().parseExpression(conditionExpr))) + filter(Column(SqlParser.parseExpression(conditionExpr))) } /** @@ -769,7 +769,7 @@ class DataFrame private[sql]( * @since 1.5.0 */ def where(conditionExpr: String): DataFrame = { - filter(Column(new SqlParser().parseExpression(conditionExpr))) + filter(Column(SqlParser.parseExpression(conditionExpr))) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 745bb4ec9cf1..03e973666e88 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -163,7 +163,7 @@ final class DataFrameWriter private[sql](df: DataFrame) { * @since 1.4.0 */ def insertInto(tableName: String): Unit = { - insertInto(new SqlParser().parseTableIdentifier(tableName)) + insertInto(SqlParser.parseTableIdentifier(tableName)) } private def insertInto(tableIdent: TableIdentifier): Unit = { @@ -197,7 +197,7 @@ final class DataFrameWriter private[sql](df: DataFrame) { * @since 1.4.0 */ def saveAsTable(tableName: String): Unit = { - saveAsTable(new SqlParser().parseTableIdentifier(tableName)) + saveAsTable(SqlParser.parseTableIdentifier(tableName)) } private def saveAsTable(tableIdent: TableIdentifier): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index e3fdd782e6ff..f099940800cc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -590,7 +590,7 @@ class SQLContext(@transient val sparkContext: SparkContext) tableName: String, source: String, options: Map[String, String]): DataFrame = { - val tableIdent = new SqlParser().parseTableIdentifier(tableName) + val tableIdent = SqlParser.parseTableIdentifier(tableName) val cmd = CreateTableUsing( tableIdent, @@ -636,7 +636,7 @@ class SQLContext(@transient val sparkContext: SparkContext) source: String, schema: StructType, options: Map[String, String]): DataFrame = { - val tableIdent = new SqlParser().parseTableIdentifier(tableName) + val tableIdent = SqlParser.parseTableIdentifier(tableName) val cmd = CreateTableUsing( tableIdent, @@ -732,7 +732,7 @@ class SQLContext(@transient val sparkContext: SparkContext) * @since 1.3.0 */ def table(tableName: String): DataFrame = { - table(new SqlParser().parseTableIdentifier(tableName)) + table(SqlParser.parseTableIdentifier(tableName)) } private def table(tableIdent: TableIdentifier): DataFrame = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 
60d9c509104d..2467b4e48415 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -823,7 +823,7 @@ object functions { * * @group normal_funcs */ - def expr(expr: String): Column = Column(new SqlParser().parseExpression(expr)) + def expr(expr: String): Column = Column(SqlParser.parseExpression(expr)) ////////////////////////////////////////////////////////////////////////////////////////////// // Math Functions diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index d37ba5ddc2d8..c12a73486332 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -291,12 +291,12 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging { * @since 1.3.0 */ def refreshTable(tableName: String): Unit = { - val tableIdent = new SqlParser().parseTableIdentifier(tableName) + val tableIdent = SqlParser.parseTableIdentifier(tableName) catalog.refreshTable(tableIdent) } protected[hive] def invalidateTable(tableName: String): Unit = { - val tableIdent = new SqlParser().parseTableIdentifier(tableName) + val tableIdent = SqlParser.parseTableIdentifier(tableName) catalog.invalidateTable(tableIdent) } @@ -311,7 +311,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging { */ @Experimental def analyze(tableName: String) { - val tableIdent = new SqlParser().parseTableIdentifier(tableName) + val tableIdent = SqlParser.parseTableIdentifier(tableName) val relation = EliminateSubQueries(catalog.lookupRelation(tableIdent.toSeq)) relation match { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 0a5569b0a444..0c1b41e3377e 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -199,7 +199,7 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive options: Map[String, String], isExternal: Boolean): Unit = { createDataSourceTable( - new SqlParser().parseTableIdentifier(tableName), + SqlParser.parseTableIdentifier(tableName), userSpecifiedSchema, partitionColumns, provider, @@ -375,7 +375,7 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive } def hiveDefaultTableFilePath(tableName: String): String = { - hiveDefaultTableFilePath(new SqlParser().parseTableIdentifier(tableName)) + hiveDefaultTableFilePath(SqlParser.parseTableIdentifier(tableName)) } def hiveDefaultTableFilePath(tableIdent: TableIdentifier): String = { From 2117eea71ece825fbc3797c8b38184ae221f5223 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sat, 19 Sep 2015 21:40:21 -0700 Subject: [PATCH 0692/1824] [SPARK-10710] Remove ability to disable spilling in core and SQL It does not make much sense to set `spark.shuffle.spill` or `spark.sql.planner.externalSort` to false: I believe that these configurations were initially added as "escape hatches" to guard against bugs in the external operators, but these operators are now mature and well-tested. In addition, these configurations are not handled in a consistent way anymore: SQL's Tungsten codepath ignores these configurations and will continue to use spilling operators. 
Similarly, Spark Core's `tungsten-sort` shuffle manager does not respect `spark.shuffle.spill=false`. This pull request removes these configurations, adds warnings at the appropriate places, and deletes a large amount of code which was only used in code paths that did not support spilling. Author: Josh Rosen Closes #8831 from JoshRosen/remove-ability-to-disable-spilling. --- .../scala/org/apache/spark/Aggregator.scala | 59 +++++-------------- .../org/apache/spark/rdd/CoGroupedRDD.scala | 40 ++++--------- .../shuffle/hash/HashShuffleManager.scala | 8 ++- .../shuffle/sort/SortShuffleManager.scala | 10 +++- .../util/collection/ExternalSorter.scala | 6 -- .../spark/deploy/SparkSubmitSuite.scala | 22 +++---- docs/configuration.md | 14 +---- docs/sql-programming-guide.md | 7 --- python/pyspark/rdd.py | 25 +++----- python/pyspark/shuffle.py | 30 ---------- python/pyspark/tests.py | 13 +--- .../scala/org/apache/spark/sql/SQLConf.scala | 8 +-- .../spark/sql/execution/SparkStrategies.scala | 2 - .../apache/spark/sql/execution/commands.scala | 9 +++ .../org/apache/spark/sql/execution/sort.scala | 30 +--------- .../org/apache/spark/sql/SQLQuerySuite.scala | 26 ++------ .../execution/RowFormatConvertersSuite.scala | 2 +- .../spark/sql/execution/SortSuite.scala | 4 +- 18 files changed, 81 insertions(+), 234 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/Aggregator.scala b/core/src/main/scala/org/apache/spark/Aggregator.scala index 289aab9bd9e5..7196e57d5d2e 100644 --- a/core/src/main/scala/org/apache/spark/Aggregator.scala +++ b/core/src/main/scala/org/apache/spark/Aggregator.scala @@ -18,7 +18,7 @@ package org.apache.spark import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.util.collection.{AppendOnlyMap, ExternalAppendOnlyMap} +import org.apache.spark.util.collection.ExternalAppendOnlyMap /** * :: DeveloperApi :: @@ -34,59 +34,30 @@ case class Aggregator[K, V, C] ( mergeValue: (C, V) => C, mergeCombiners: (C, C) => C) { - // When spilling is enabled sorting will happen externally, but not necessarily with an - // ExternalSorter. 
- private val isSpillEnabled = SparkEnv.get.conf.getBoolean("spark.shuffle.spill", true) - @deprecated("use combineValuesByKey with TaskContext argument", "0.9.0") def combineValuesByKey(iter: Iterator[_ <: Product2[K, V]]): Iterator[(K, C)] = combineValuesByKey(iter, null) - def combineValuesByKey(iter: Iterator[_ <: Product2[K, V]], - context: TaskContext): Iterator[(K, C)] = { - if (!isSpillEnabled) { - val combiners = new AppendOnlyMap[K, C] - var kv: Product2[K, V] = null - val update = (hadValue: Boolean, oldValue: C) => { - if (hadValue) mergeValue(oldValue, kv._2) else createCombiner(kv._2) - } - while (iter.hasNext) { - kv = iter.next() - combiners.changeValue(kv._1, update) - } - combiners.iterator - } else { - val combiners = new ExternalAppendOnlyMap[K, V, C](createCombiner, mergeValue, mergeCombiners) - combiners.insertAll(iter) - updateMetrics(context, combiners) - combiners.iterator - } + def combineValuesByKey( + iter: Iterator[_ <: Product2[K, V]], + context: TaskContext): Iterator[(K, C)] = { + val combiners = new ExternalAppendOnlyMap[K, V, C](createCombiner, mergeValue, mergeCombiners) + combiners.insertAll(iter) + updateMetrics(context, combiners) + combiners.iterator } @deprecated("use combineCombinersByKey with TaskContext argument", "0.9.0") def combineCombinersByKey(iter: Iterator[_ <: Product2[K, C]]) : Iterator[(K, C)] = combineCombinersByKey(iter, null) - def combineCombinersByKey(iter: Iterator[_ <: Product2[K, C]], context: TaskContext) - : Iterator[(K, C)] = - { - if (!isSpillEnabled) { - val combiners = new AppendOnlyMap[K, C] - var kc: Product2[K, C] = null - val update = (hadValue: Boolean, oldValue: C) => { - if (hadValue) mergeCombiners(oldValue, kc._2) else kc._2 - } - while (iter.hasNext) { - kc = iter.next() - combiners.changeValue(kc._1, update) - } - combiners.iterator - } else { - val combiners = new ExternalAppendOnlyMap[K, C, C](identity, mergeCombiners, mergeCombiners) - combiners.insertAll(iter) - updateMetrics(context, combiners) - combiners.iterator - } + def combineCombinersByKey( + iter: Iterator[_ <: Product2[K, C]], + context: TaskContext): Iterator[(K, C)] = { + val combiners = new ExternalAppendOnlyMap[K, C, C](identity, mergeCombiners, mergeCombiners) + combiners.insertAll(iter) + updateMetrics(context, combiners) + combiners.iterator } /** Update task metrics after populating the external map. 
*/ diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala index 7bad749d5832..935c3babd8ea 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala @@ -26,7 +26,7 @@ import scala.reflect.ClassTag import org.apache.spark._ import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.util.collection.{ExternalAppendOnlyMap, AppendOnlyMap, CompactBuffer} +import org.apache.spark.util.collection.{CompactBuffer, ExternalAppendOnlyMap} import org.apache.spark.util.Utils import org.apache.spark.serializer.Serializer @@ -128,8 +128,6 @@ class CoGroupedRDD[K: ClassTag]( override val partitioner: Some[Partitioner] = Some(part) override def compute(s: Partition, context: TaskContext): Iterator[(K, Array[Iterable[_]])] = { - val sparkConf = SparkEnv.get.conf - val externalSorting = sparkConf.getBoolean("spark.shuffle.spill", true) val split = s.asInstanceOf[CoGroupPartition] val numRdds = dependencies.length @@ -150,34 +148,16 @@ class CoGroupedRDD[K: ClassTag]( rddIterators += ((it, depNum)) } - if (!externalSorting) { - val map = new AppendOnlyMap[K, CoGroupCombiner] - val update: (Boolean, CoGroupCombiner) => CoGroupCombiner = (hadVal, oldVal) => { - if (hadVal) oldVal else Array.fill(numRdds)(new CoGroup) - } - val getCombiner: K => CoGroupCombiner = key => { - map.changeValue(key, update) - } - rddIterators.foreach { case (it, depNum) => - while (it.hasNext) { - val kv = it.next() - getCombiner(kv._1)(depNum) += kv._2 - } - } - new InterruptibleIterator(context, - map.iterator.asInstanceOf[Iterator[(K, Array[Iterable[_]])]]) - } else { - val map = createExternalMap(numRdds) - for ((it, depNum) <- rddIterators) { - map.insertAll(it.map(pair => (pair._1, new CoGroupValue(pair._2, depNum)))) - } - context.taskMetrics().incMemoryBytesSpilled(map.memoryBytesSpilled) - context.taskMetrics().incDiskBytesSpilled(map.diskBytesSpilled) - context.internalMetricsToAccumulators( - InternalAccumulator.PEAK_EXECUTION_MEMORY).add(map.peakMemoryUsedBytes) - new InterruptibleIterator(context, - map.iterator.asInstanceOf[Iterator[(K, Array[Iterable[_]])]]) + val map = createExternalMap(numRdds) + for ((it, depNum) <- rddIterators) { + map.insertAll(it.map(pair => (pair._1, new CoGroupValue(pair._2, depNum)))) } + context.taskMetrics().incMemoryBytesSpilled(map.memoryBytesSpilled) + context.taskMetrics().incDiskBytesSpilled(map.diskBytesSpilled) + context.internalMetricsToAccumulators( + InternalAccumulator.PEAK_EXECUTION_MEMORY).add(map.peakMemoryUsedBytes) + new InterruptibleIterator(context, + map.iterator.asInstanceOf[Iterator[(K, Array[Iterable[_]])]]) } private def createExternalMap(numRdds: Int) diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala index c089088f409d..0b46634b8b46 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala @@ -24,7 +24,13 @@ import org.apache.spark.shuffle._ * A ShuffleManager using hashing, that creates one output file per reduce partition on each * mapper (possibly reusing these across waves of tasks). 
*/ -private[spark] class HashShuffleManager(conf: SparkConf) extends ShuffleManager { +private[spark] class HashShuffleManager(conf: SparkConf) extends ShuffleManager with Logging { + + if (!conf.getBoolean("spark.shuffle.spill", true)) { + logWarning( + "spark.shuffle.spill was set to false, but this configuration is ignored as of Spark 1.6+." + + " Shuffle will continue to spill to disk when necessary.") + } private val fileShuffleBlockResolver = new FileShuffleBlockResolver(conf) diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala index d7fab351ca3b..476cc1f303da 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala @@ -19,11 +19,17 @@ package org.apache.spark.shuffle.sort import java.util.concurrent.ConcurrentHashMap -import org.apache.spark.{SparkConf, TaskContext, ShuffleDependency} +import org.apache.spark.{Logging, SparkConf, TaskContext, ShuffleDependency} import org.apache.spark.shuffle._ import org.apache.spark.shuffle.hash.HashShuffleReader -private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager { +private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager with Logging { + + if (!conf.getBoolean("spark.shuffle.spill", true)) { + logWarning( + "spark.shuffle.spill was set to false, but this configuration is ignored as of Spark 1.6+." + + " Shuffle will continue to spill to disk when necessary.") + } private val indexShuffleBlockResolver = new IndexShuffleBlockResolver(conf) private val shuffleMapNumber = new ConcurrentHashMap[Int, Int]() diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index 31230d5978b2..2a30f751ff03 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -116,8 +116,6 @@ private[spark] class ExternalSorter[K, V, C]( private val ser = Serializer.getSerializer(serializer) private val serInstance = ser.newInstance() - private val spillingEnabled = conf.getBoolean("spark.shuffle.spill", true) - // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided private val fileBufferSize = conf.getSizeAsKb("spark.shuffle.file.buffer", "32k").toInt * 1024 @@ -229,10 +227,6 @@ private[spark] class ExternalSorter[K, V, C]( * @param usingMap whether we're using a map or buffer as our current in-memory collection */ private def maybeSpillCollection(usingMap: Boolean): Unit = { - if (!spillingEnabled) { - return - } - var estimatedSize = 0L if (usingMap) { estimatedSize = map.estimateSize() diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 1110ca6051a4..1fd470cd3b01 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -147,7 +147,7 @@ class SparkSubmitSuite "--archives", "archive1.txt,archive2.txt", "--num-executors", "6", "--name", "beauty", - "--conf", "spark.shuffle.spill=false", + "--conf", "spark.ui.enabled=false", "thejar.jar", "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) @@ -166,7 +166,7 @@ class SparkSubmitSuite mainClass 
should be ("org.apache.spark.deploy.yarn.Client") classpath should have length (0) sysProps("spark.app.name") should be ("beauty") - sysProps("spark.shuffle.spill") should be ("false") + sysProps("spark.ui.enabled") should be ("false") sysProps("SPARK_SUBMIT") should be ("true") sysProps.keys should not contain ("spark.jars") } @@ -185,7 +185,7 @@ class SparkSubmitSuite "--archives", "archive1.txt,archive2.txt", "--num-executors", "6", "--name", "trill", - "--conf", "spark.shuffle.spill=false", + "--conf", "spark.ui.enabled=false", "thejar.jar", "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) @@ -206,7 +206,7 @@ class SparkSubmitSuite sysProps("spark.yarn.dist.archives") should include regex (".*archive1.txt,.*archive2.txt") sysProps("spark.jars") should include regex (".*one.jar,.*two.jar,.*three.jar,.*thejar.jar") sysProps("SPARK_SUBMIT") should be ("true") - sysProps("spark.shuffle.spill") should be ("false") + sysProps("spark.ui.enabled") should be ("false") } test("handles standalone cluster mode") { @@ -229,7 +229,7 @@ class SparkSubmitSuite "--supervise", "--driver-memory", "4g", "--driver-cores", "5", - "--conf", "spark.shuffle.spill=false", + "--conf", "spark.ui.enabled=false", "thejar.jar", "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) @@ -253,9 +253,9 @@ class SparkSubmitSuite sysProps.keys should contain ("spark.driver.memory") sysProps.keys should contain ("spark.driver.cores") sysProps.keys should contain ("spark.driver.supervise") - sysProps.keys should contain ("spark.shuffle.spill") + sysProps.keys should contain ("spark.ui.enabled") sysProps.keys should contain ("spark.submit.deployMode") - sysProps("spark.shuffle.spill") should be ("false") + sysProps("spark.ui.enabled") should be ("false") } test("handles standalone client mode") { @@ -266,7 +266,7 @@ class SparkSubmitSuite "--total-executor-cores", "5", "--class", "org.SomeClass", "--driver-memory", "4g", - "--conf", "spark.shuffle.spill=false", + "--conf", "spark.ui.enabled=false", "thejar.jar", "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) @@ -277,7 +277,7 @@ class SparkSubmitSuite classpath(0) should endWith ("thejar.jar") sysProps("spark.executor.memory") should be ("5g") sysProps("spark.cores.max") should be ("5") - sysProps("spark.shuffle.spill") should be ("false") + sysProps("spark.ui.enabled") should be ("false") } test("handles mesos client mode") { @@ -288,7 +288,7 @@ class SparkSubmitSuite "--total-executor-cores", "5", "--class", "org.SomeClass", "--driver-memory", "4g", - "--conf", "spark.shuffle.spill=false", + "--conf", "spark.ui.enabled=false", "thejar.jar", "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) @@ -299,7 +299,7 @@ class SparkSubmitSuite classpath(0) should endWith ("thejar.jar") sysProps("spark.executor.memory") should be ("5g") sysProps("spark.cores.max") should be ("5") - sysProps("spark.shuffle.spill") should be ("false") + sysProps("spark.ui.enabled") should be ("false") } test("handles confs with flag equivalents") { diff --git a/docs/configuration.md b/docs/configuration.md index 3700051efb44..5ec097c78aa3 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -69,7 +69,7 @@ val sc = new SparkContext(new SparkConf()) Then, you can supply configuration values at runtime: {% highlight bash %} -./bin/spark-submit --name "My app" --master local[4] --conf spark.shuffle.spill=false +./bin/spark-submit --name "My app" --master local[4] --conf spark.eventLog.enabled=false --conf 
"spark.executor.extraJavaOptions=-XX:+PrintGCDetails -XX:+PrintGCTimeStamps" myApp.jar {% endhighlight %} @@ -449,8 +449,8 @@ Apart from these, the following properties are also available, and may be useful spark.shuffle.memoryFraction 0.2 - Fraction of Java heap to use for aggregation and cogroups during shuffles, if - spark.shuffle.spill is true. At any given time, the collective size of + Fraction of Java heap to use for aggregation and cogroups during shuffles. + At any given time, the collective size of all in-memory maps used for shuffles is bounded by this limit, beyond which the contents will begin to spill to disk. If spills are often, consider increasing this value at the expense of spark.storage.memoryFraction. @@ -483,14 +483,6 @@ Apart from these, the following properties are also available, and may be useful map-side aggregation and there are at most this many reduce partitions. - - spark.shuffle.spill - true - - If set to "true", limits the amount of memory used during reduces by spilling data out to disk. - This spilling threshold is specified by spark.shuffle.memoryFraction. - - spark.shuffle.spill.compress true diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 82d4243cc6b2..7ae9244c271e 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1936,13 +1936,6 @@ that these options will be deprecated in future release as more optimizations ar Configures the number of partitions to use when shuffling data for joins or aggregations. - - spark.sql.planner.externalSort - true - - When true, performs sorts spilling to disk as needed otherwise sort each partition in memory. - - # Distributed SQL Engine diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index ab5aab1e115f..73d7d9a5692a 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -48,7 +48,7 @@ from pyspark.rddsampler import RDDSampler, RDDRangeSampler, RDDStratifiedSampler from pyspark.storagelevel import StorageLevel from pyspark.resultiterable import ResultIterable -from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, \ +from pyspark.shuffle import Aggregator, ExternalMerger, \ get_used_memory, ExternalSorter, ExternalGroupBy from pyspark.traceback_utils import SCCallSiteSync @@ -580,12 +580,11 @@ def repartitionAndSortWithinPartitions(self, numPartitions=None, partitionFunc=p if numPartitions is None: numPartitions = self._defaultReducePartitions() - spill = (self.ctx._conf.get("spark.shuffle.spill", 'True').lower() == "true") memory = _parse_memory(self.ctx._conf.get("spark.python.worker.memory", "512m")) serializer = self._jrdd_deserializer def sortPartition(iterator): - sort = ExternalSorter(memory * 0.9, serializer).sorted if spill else sorted + sort = ExternalSorter(memory * 0.9, serializer).sorted return iter(sort(iterator, key=lambda k_v: keyfunc(k_v[0]), reverse=(not ascending))) return self.partitionBy(numPartitions, partitionFunc).mapPartitions(sortPartition, True) @@ -610,12 +609,11 @@ def sortByKey(self, ascending=True, numPartitions=None, keyfunc=lambda x: x): if numPartitions is None: numPartitions = self._defaultReducePartitions() - spill = self._can_spill() memory = self._memory_limit() serializer = self._jrdd_deserializer def sortPartition(iterator): - sort = ExternalSorter(memory * 0.9, serializer).sorted if spill else sorted + sort = ExternalSorter(memory * 0.9, serializer).sorted return iter(sort(iterator, key=lambda kv: keyfunc(kv[0]), reverse=(not ascending))) if numPartitions == 1: @@ -1770,13 
+1768,11 @@ def combineByKey(self, createCombiner, mergeValue, mergeCombiners, numPartitions = self._defaultReducePartitions() serializer = self.ctx.serializer - spill = self._can_spill() memory = self._memory_limit() agg = Aggregator(createCombiner, mergeValue, mergeCombiners) def combineLocally(iterator): - merger = ExternalMerger(agg, memory * 0.9, serializer) \ - if spill else InMemoryMerger(agg) + merger = ExternalMerger(agg, memory * 0.9, serializer) merger.mergeValues(iterator) return merger.items() @@ -1784,8 +1780,7 @@ def combineLocally(iterator): shuffled = locally_combined.partitionBy(numPartitions) def _mergeCombiners(iterator): - merger = ExternalMerger(agg, memory, serializer) \ - if spill else InMemoryMerger(agg) + merger = ExternalMerger(agg, memory, serializer) merger.mergeCombiners(iterator) return merger.items() @@ -1824,9 +1819,6 @@ def createZero(): return self.combineByKey(lambda v: func(createZero(), v), func, func, numPartitions) - def _can_spill(self): - return self.ctx._conf.get("spark.shuffle.spill", "True").lower() == "true" - def _memory_limit(self): return _parse_memory(self.ctx._conf.get("spark.python.worker.memory", "512m")) @@ -1857,14 +1849,12 @@ def mergeCombiners(a, b): a.extend(b) return a - spill = self._can_spill() memory = self._memory_limit() serializer = self._jrdd_deserializer agg = Aggregator(createCombiner, mergeValue, mergeCombiners) def combine(iterator): - merger = ExternalMerger(agg, memory * 0.9, serializer) \ - if spill else InMemoryMerger(agg) + merger = ExternalMerger(agg, memory * 0.9, serializer) merger.mergeValues(iterator) return merger.items() @@ -1872,8 +1862,7 @@ def combine(iterator): shuffled = locally_combined.partitionBy(numPartitions) def groupByKey(it): - merger = ExternalGroupBy(agg, memory, serializer)\ - if spill else InMemoryMerger(agg) + merger = ExternalGroupBy(agg, memory, serializer) merger.mergeCombiners(it) return merger.items() diff --git a/python/pyspark/shuffle.py b/python/pyspark/shuffle.py index b8118bdb7ca7..e974cda9fc3e 100644 --- a/python/pyspark/shuffle.py +++ b/python/pyspark/shuffle.py @@ -131,36 +131,6 @@ def items(self): raise NotImplementedError -class InMemoryMerger(Merger): - - """ - In memory merger based on in-memory dict. 
- """ - - def __init__(self, aggregator): - Merger.__init__(self, aggregator) - self.data = {} - - def mergeValues(self, iterator): - """ Combine the items by creator and combiner """ - # speed up attributes lookup - d, creator = self.data, self.agg.createCombiner - comb = self.agg.mergeValue - for k, v in iterator: - d[k] = comb(d[k], v) if k in d else creator(v) - - def mergeCombiners(self, iterator): - """ Merge the combined items by mergeCombiner """ - # speed up attributes lookup - d, comb = self.data, self.agg.mergeCombiners - for k, v in iterator: - d[k] = comb(d[k], v) if k in d else v - - def items(self): - """ Return the merged items ad iterator """ - return iter(self.data.items()) - - def _compressed_serializer(self, serializer=None): # always use PickleSerializer to simplify implementation ser = PickleSerializer() diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 647504c32f15..f11aaf001c8d 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -62,7 +62,7 @@ CloudPickleSerializer, CompressedSerializer, UTF8Deserializer, NoOpSerializer, \ PairDeserializer, CartesianDeserializer, AutoBatchedSerializer, AutoSerializer, \ FlattenedValuesSerializer -from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, ExternalSorter +from pyspark.shuffle import Aggregator, ExternalMerger, ExternalSorter from pyspark import shuffle from pyspark.profiler import BasicProfiler @@ -95,17 +95,6 @@ def setUp(self): lambda x, y: x.append(y) or x, lambda x, y: x.extend(y) or x) - def test_in_memory(self): - m = InMemoryMerger(self.agg) - m.mergeValues(self.data) - self.assertEqual(sum(sum(v) for k, v in m.items()), - sum(xrange(self.N))) - - m = InMemoryMerger(self.agg) - m.mergeCombiners(map(lambda x_y: (x_y[0], [x_y[1]]), self.data)) - self.assertEqual(sum(sum(v) for k, v in m.items()), - sum(xrange(self.N))) - def test_small_dataset(self): m = ExternalMerger(self.agg, 1000) m.mergeValues(self.data) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 9de75f4c4d08..b9fb90d96420 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -330,11 +330,6 @@ private[spark] object SQLConf { // Options that control which operators can be chosen by the query planner. These should be // considered hints and may be ignored by future versions of Spark SQL. 
- val EXTERNAL_SORT = booleanConf("spark.sql.planner.externalSort", - defaultValue = Some(true), - doc = "When true, performs sorts spilling to disk as needed otherwise sort each partition in" + - " memory.") - val SORTMERGE_JOIN = booleanConf("spark.sql.planner.sortMergeJoin", defaultValue = Some(true), doc = "When true, use sort merge join (as opposed to hash join) by default for large joins.") @@ -422,6 +417,7 @@ private[spark] object SQLConf { object Deprecated { val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" + val EXTERNAL_SORT = "spark.sql.planner.externalSort" } } @@ -476,8 +472,6 @@ private[sql] class SQLConf extends Serializable with CatalystConf { private[spark] def metastorePartitionPruning: Boolean = getConf(HIVE_METASTORE_PARTITION_PRUNING) - private[spark] def externalSortEnabled: Boolean = getConf(EXTERNAL_SORT) - private[spark] def sortMergeJoinEnabled: Boolean = getConf(SORTMERGE_JOIN) private[spark] def codegenEnabled: Boolean = getConf(CODEGEN_ENABLED, getConf(TUNGSTEN_ENABLED)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 5e40d7768904..41b215c79296 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -312,8 +312,6 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { if (sqlContext.conf.unsafeEnabled && sqlContext.conf.codegenEnabled && TungstenSort.supportsSchema(child.schema)) { execution.TungstenSort(sortExprs, global, child) - } else if (sqlContext.conf.externalSortEnabled) { - execution.ExternalSort(sortExprs, global, child) } else { execution.Sort(sortExprs, global, child) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala index 95209e663451..af28e2dfa418 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala @@ -105,6 +105,15 @@ case class SetCommand(kv: Option[(String, Option[String])]) extends RunnableComm } (keyValueOutput, runFunc) + case Some((SQLConf.Deprecated.EXTERNAL_SORT, Some(value))) => + val runFunc = (sqlContext: SQLContext) => { + logWarning( + s"Property ${SQLConf.Deprecated.EXTERNAL_SORT} is deprecated and will be ignored. " + + s"External sort will continue to be used.") + Seq(Row(SQLConf.Deprecated.EXTERNAL_SORT, "true")) + } + (keyValueOutput, runFunc) + // Configures a single property. case Some((key, Some(value))) => val runFunc = (sqlContext: SQLContext) => { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala index 40ef7c3b5353..27f26245a5ef 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala @@ -31,38 +31,12 @@ import org.apache.spark.{SparkEnv, InternalAccumulator, TaskContext} // This file defines various sort operators. //////////////////////////////////////////////////////////////////////////////////////////////////// - -/** - * Performs a sort on-heap. - * @param global when true performs a global sort of all partitions by shuffling the data first - * if necessary. 
- */ -case class Sort( - sortOrder: Seq[SortOrder], - global: Boolean, - child: SparkPlan) - extends UnaryNode { - override def requiredChildDistribution: Seq[Distribution] = - if (global) OrderedDistribution(sortOrder) :: Nil else UnspecifiedDistribution :: Nil - - protected override def doExecute(): RDD[InternalRow] = attachTree(this, "sort") { - child.execute().mapPartitions( { iterator => - val ordering = newOrdering(sortOrder, child.output) - iterator.map(_.copy()).toArray.sorted(ordering).iterator - }, preservesPartitioning = true) - } - - override def output: Seq[Attribute] = child.output - - override def outputOrdering: Seq[SortOrder] = sortOrder -} - /** * Performs a sort, spilling to disk as needed. * @param global when true performs a global sort of all partitions by shuffling the data first * if necessary. */ -case class ExternalSort( +case class Sort( sortOrder: Seq[SortOrder], global: Boolean, child: SparkPlan) @@ -93,7 +67,7 @@ case class ExternalSort( } /** - * Optimized version of [[ExternalSort]] that operates on binary data (implemented as part of + * Optimized version of [[Sort]] that operates on binary data (implemented as part of * Project Tungsten). * * @param global when true performs a global sort of all partitions by shuffling the data first diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index f9981356f364..05b4127cbcaf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -581,28 +581,12 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { mapData.collect().sortBy(_.data(1)).reverse.map(Row.fromTuple).toSeq) } - test("sorting") { - withSQLConf(SQLConf.EXTERNAL_SORT.key -> "false") { - sortTest() - } - } - test("external sorting") { - withSQLConf(SQLConf.EXTERNAL_SORT.key -> "true") { - sortTest() - } - } - - test("SPARK-6927 sorting with codegen on") { - withSQLConf(SQLConf.EXTERNAL_SORT.key -> "false", - SQLConf.CODEGEN_ENABLED.key -> "true") { - sortTest() - } + sortTest() } test("SPARK-6927 external sorting with codegen on") { - withSQLConf(SQLConf.EXTERNAL_SORT.key -> "true", - SQLConf.CODEGEN_ENABLED.key -> "true") { + withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "true") { sortTest() } } @@ -1731,10 +1715,8 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } test("external sorting updates peak execution memory") { - withSQLConf((SQLConf.EXTERNAL_SORT.key, "true")) { - AccumulatorSuite.verifyPeakExecutionMemorySet(sparkContext, "external sort") { - sortTest() - } + AccumulatorSuite.verifyPeakExecutionMemorySet(sparkContext, "external sort") { + sortTest() } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala index 4492e37ad01f..5dc37e5c3c23 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala @@ -32,7 +32,7 @@ class RowFormatConvertersSuite extends SparkPlanTest with SharedSQLContext { case c: ConvertToSafe => c } - private val outputsSafe = ExternalSort(Nil, false, PhysicalRDD(Seq.empty, null, "name")) + private val outputsSafe = Sort(Nil, false, PhysicalRDD(Seq.empty, null, "name")) assert(!outputsSafe.outputsUnsafeRows) private val outputsUnsafe = TungstenSort(Nil, 
false, PhysicalRDD(Seq.empty, null, "name")) assert(outputsUnsafe.outputsUnsafeRows) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala index 3073d492e613..847c188a3033 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala @@ -36,13 +36,13 @@ class SortSuite extends SparkPlanTest with SharedSQLContext { checkAnswer( input.toDF("a", "b", "c"), - ExternalSort('a.asc :: 'b.asc :: Nil, global = true, _: SparkPlan), + Sort('a.asc :: 'b.asc :: Nil, global = true, _: SparkPlan), input.sortBy(t => (t._1, t._2)).map(Row.fromTuple), sortAnswers = false) checkAnswer( input.toDF("a", "b", "c"), - ExternalSort('b.asc :: 'a.asc :: Nil, global = true, _: SparkPlan), + Sort('b.asc :: 'a.asc :: Nil, global = true, _: SparkPlan), input.sortBy(t => (t._2, t._1)).map(Row.fromTuple), sortAnswers = false) } From 1aa9e50256988533fa54584b49dbc408a14438ee Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Sun, 20 Sep 2015 16:05:12 -0700 Subject: [PATCH 0693/1824] [SPARK-5905] [MLLIB] Note requirements for certain RowMatrix methods in docs Note methods that fail for cols > 65535; note that SVD does not require n >= m CC mengxr Author: Sean Owen Closes #8839 from srowen/SPARK-5905. --- .../spark/mllib/linalg/distributed/RowMatrix.scala | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala index e55ef26858ad..7c7d900af3d5 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala @@ -109,7 +109,8 @@ class RowMatrix @Since("1.0.0") ( } /** - * Computes the Gramian matrix `A^T A`. + * Computes the Gramian matrix `A^T A`. Note that this cannot be computed on matrices with + * more than 65535 columns. */ @Since("1.0.0") def computeGramianMatrix(): Matrix = { @@ -150,7 +151,8 @@ class RowMatrix @Since("1.0.0") ( * - s is a Vector of size k, holding the singular values in descending order, * - V is a Matrix of size n x k that satisfies V' * V = eye(k). * - * We assume n is smaller than m. The singular values and the right singular vectors are derived + * We assume n is smaller than m, though this is not strictly required. + * The singular values and the right singular vectors are derived * from the eigenvalues and the eigenvectors of the Gramian matrix A' * A. U, the matrix * storing the right singular vectors, is computed via matrix multiplication as * U = A * (V * S^-1^), if requested by user. The actual method to use is determined @@ -320,7 +322,8 @@ class RowMatrix @Since("1.0.0") ( } /** - * Computes the covariance matrix, treating each row as an observation. + * Computes the covariance matrix, treating each row as an observation. Note that this cannot + * be computed on matrices with more than 65535 columns. * @return a local dense matrix of size n x n */ @Since("1.0.0") @@ -374,6 +377,8 @@ class RowMatrix @Since("1.0.0") ( * The row data do not need to be "centered" first; it is not necessary for * the mean of each column to be 0. * + * Note that this cannot be computed on matrices with more than 65535 columns. + * * @param k number of top principal components. 
* @return a matrix of size n-by-k, whose columns are principal components */ From 0c498717ba9622b6c889e701e8eed5ef9215c030 Mon Sep 17 00:00:00 2001 From: lewuathe Date: Sun, 20 Sep 2015 16:16:31 -0700 Subject: [PATCH 0694/1824] [SPARK-10715] [ML] Duplicate initialization flag in WeightedLeastSquare There are duplicate set of initialization flag in `WeightedLeastSquares#add`. `initialized` is already set in `init(Int)`. Author: lewuathe Closes #8837 from Lewuathe/duplicate-initialization-flag. --- .../scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala index 0ff8931b0bab..4374e9963156 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala @@ -193,7 +193,6 @@ private[ml] object WeightedLeastSquares { val ak = a.size if (!initialized) { init(ak) - initialized = true } assert(ak == k, s"Dimension mismatch. Expect vectors of size $k but got $ak.") count += 1L From 01440395176bdbb2662480f03b27851cb860f385 Mon Sep 17 00:00:00 2001 From: vinodkc Date: Sun, 20 Sep 2015 22:55:24 -0700 Subject: [PATCH 0695/1824] [SPARK-10631] [DOCUMENTATION, MLLIB, PYSPARK] Added documentation for few APIs There are some missing API docs in pyspark.mllib.linalg.Vector (including DenseVector and SparseVector). We should add them based on their Scala counterparts. Author: vinodkc Closes #8834 from vinodkc/fix_SPARK-10631. --- python/pyspark/mllib/linalg/__init__.py | 22 +++++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/python/pyspark/mllib/linalg/__init__.py b/python/pyspark/mllib/linalg/__init__.py index 4829acb16ed8..f929e3e96fbe 100644 --- a/python/pyspark/mllib/linalg/__init__.py +++ b/python/pyspark/mllib/linalg/__init__.py @@ -301,11 +301,14 @@ def __reduce__(self): return DenseVector, (self.array.tostring(),) def numNonzeros(self): + """ + Number of nonzero elements. This scans all active values and count non zeros + """ return np.count_nonzero(self.array) def norm(self, p): """ - Calculte the norm of a DenseVector. + Calculates the norm of a DenseVector. >>> a = DenseVector([0, -1, 2, -3]) >>> a.norm(2) @@ -397,10 +400,16 @@ def squared_distance(self, other): return np.dot(diff, diff) def toArray(self): + """ + Returns an numpy.ndarray + """ return self.array @property def values(self): + """ + Returns a list of values + """ return self.array def __getitem__(self, item): @@ -479,8 +488,8 @@ def __init__(self, size, *args): :param size: Size of the vector. :param args: Active entries, as a dictionary {index: value, ...}, - a list of tuples [(index, value), ...], or a list of strictly i - ncreasing indices and a list of corresponding values [index, ...], + a list of tuples [(index, value), ...], or a list of strictly + increasing indices and a list of corresponding values [index, ...], [value, ...]. Inactive entries are treated as zeros. >>> SparseVector(4, {1: 1.0, 3: 5.5}) @@ -521,11 +530,14 @@ def __init__(self, size, *args): raise TypeError("indices array must be sorted") def numNonzeros(self): + """ + Number of nonzero elements. This scans all active values and count non zeros. + """ return np.count_nonzero(self.values) def norm(self, p): """ - Calculte the norm of a SparseVector. + Calculates the norm of a SparseVector. 
>>> a = SparseVector(4, [0, 1], [3., -4.]) >>> a.norm(1) @@ -797,7 +809,7 @@ def sparse(size, *args): values (sorted by index). :param size: Size of the vector. - :param args: Non-zero entries, as a dictionary, list of tupes, + :param args: Non-zero entries, as a dictionary, list of tuples, or two sorted lists containing indices and values. >>> Vectors.sparse(4, {1: 1.0, 3: 5.5}) From 20a61dbd9b57957fcc5b58ef8935533914172b07 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Mon, 21 Sep 2015 18:53:28 +0100 Subject: [PATCH 0696/1824] [SPARK-10626] [MLLIB] create java friendly method for random rdd SPARK-3136 added a large number of functions for creating Java RandomRDDs, but for people that want to use custom RandomDataGenerators we should make a Java friendly method. Author: Holden Karau Closes #8782 from holdenk/SPARK-10626-create-java-friendly-method-for-randomRDD. --- .../spark/mllib/random/RandomRDDs.scala | 52 ++++++++++++++++++- .../mllib/random/JavaRandomRDDsSuite.java | 30 +++++++++++ 2 files changed, 81 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala index 4dd5ea214d67..f8ff26b5795b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala @@ -22,6 +22,7 @@ import scala.reflect.ClassTag import org.apache.spark.SparkContext import org.apache.spark.annotation.{DeveloperApi, Experimental, Since} import org.apache.spark.api.java.{JavaDoubleRDD, JavaRDD, JavaSparkContext} +import org.apache.spark.api.java.JavaSparkContext.fakeClassTag import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.rdd.{RandomRDD, RandomVectorRDD} import org.apache.spark.rdd.RDD @@ -381,7 +382,7 @@ object RandomRDDs { * @param size Size of the RDD. * @param numPartitions Number of partitions in the RDD (default: `sc.defaultParallelism`). * @param seed Random seed (default: a random long integer). - * @return RDD[Double] comprised of `i.i.d.` samples produced by generator. + * @return RDD[T] comprised of `i.i.d.` samples produced by generator. */ @DeveloperApi @Since("1.1.0") @@ -394,6 +395,55 @@ object RandomRDDs { new RandomRDD[T](sc, size, numPartitionsOrDefault(sc, numPartitions), generator, seed) } + /** + * :: DeveloperApi :: + * Generates an RDD comprised of `i.i.d.` samples produced by the input RandomDataGenerator. + * + * @param jsc JavaSparkContext used to create the RDD. + * @param generator RandomDataGenerator used to populate the RDD. + * @param size Size of the RDD. + * @param numPartitions Number of partitions in the RDD (default: `sc.defaultParallelism`). + * @param seed Random seed (default: a random long integer). + * @return RDD[T] comprised of `i.i.d.` samples produced by generator. + */ + @DeveloperApi + @Since("1.6.0") + def randomJavaRDD[T]( + jsc: JavaSparkContext, + generator: RandomDataGenerator[T], + size: Long, + numPartitions: Int, + seed: Long): JavaRDD[T] = { + implicit val ctag: ClassTag[T] = fakeClassTag + val rdd = randomRDD(jsc.sc, generator, size, numPartitions, seed) + JavaRDD.fromRDD(rdd) + } + + /** + * [[RandomRDDs#randomJavaRDD]] with the default seed. 
+ */ + @DeveloperApi + @Since("1.6.0") + def randomJavaRDD[T]( + jsc: JavaSparkContext, + generator: RandomDataGenerator[T], + size: Long, + numPartitions: Int): JavaRDD[T] = { + randomJavaRDD(jsc, generator, size, numPartitions, Utils.random.nextLong()) + } + + /** + * [[RandomRDDs#randomJavaRDD]] with the default seed & numPartitions + */ + @DeveloperApi + @Since("1.6.0") + def randomJavaRDD[T]( + jsc: JavaSparkContext, + generator: RandomDataGenerator[T], + size: Long): JavaRDD[T] = { + randomJavaRDD(jsc, generator, size, 0); + } + // TODO Generate RDD[Vector] from multivariate distributions. /** diff --git a/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java b/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java index 33d81b1e9592..fce5f6712f46 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java @@ -17,6 +17,7 @@ package org.apache.spark.mllib.random; +import java.io.Serializable; import java.util.Arrays; import org.apache.spark.api.java.JavaRDD; @@ -231,4 +232,33 @@ public void testGammaVectorRDD() { } } + @Test + public void testArbitrary() { + long size = 10; + long seed = 1L; + int numPartitions = 0; + StringGenerator gen = new StringGenerator(); + JavaRDD rdd1 = randomJavaRDD(sc, gen, size); + JavaRDD rdd2 = randomJavaRDD(sc, gen, size, numPartitions); + JavaRDD rdd3 = randomJavaRDD(sc, gen, size, numPartitions, seed); + for (JavaRDD rdd: Arrays.asList(rdd1, rdd2, rdd3)) { + Assert.assertEquals(size, rdd.count()); + Assert.assertEquals(2, rdd.first().length()); + } + } +} + +// This is just a test generator, it always returns a string of 42 +class StringGenerator implements RandomDataGenerator, Serializable { + @Override + public String nextValue() { + return "42"; + } + @Override + public StringGenerator copy() { + return new StringGenerator(); + } + @Override + public void setSeed(long seed) { + } } From ebbf85f07bb8de0d566f1ae4b41f26421180bebe Mon Sep 17 00:00:00 2001 From: zsxwing Date: Mon, 21 Sep 2015 11:39:04 -0700 Subject: [PATCH 0697/1824] [SPARK-7989] [SPARK-10651] [CORE] [TESTS] Increase timeout to fix flaky tests I noticed only one block manager registered with master in an unsuccessful build (https://amplab.cs.berkeley.edu/jenkins/job/Spark-Master-SBT/AMPLAB_JENKINS_BUILD_PROFILE=hadoop2.2,label=spark-test/3534/) ``` 15/09/16 13:02:30.981 pool-1-thread-1-ScalaTest-running-BroadcastSuite INFO SparkContext: Running Spark version 1.6.0-SNAPSHOT ... 15/09/16 13:02:38.133 sparkDriver-akka.actor.default-dispatcher-19 INFO BlockManagerMasterEndpoint: Registering block manager localhost:48196 with 530.3 MB RAM, BlockManagerId(0, localhost, 48196) ``` In addition, the first block manager needed 7+ seconds to start. But the test expected 2 block managers so it failed. However, there was no exception in this log file. So I checked a successful build (https://amplab.cs.berkeley.edu/jenkins/job/Spark-Master-SBT/3536/AMPLAB_JENKINS_BUILD_PROFILE=hadoop2.2,label=spark-test/) and it needed 4-5 seconds to set up the local cluster: ``` 15/09/16 18:11:27.738 sparkWorker1-akka.actor.default-dispatcher-5 INFO Worker: Running Spark version 1.6.0-SNAPSHOT ... 
15/09/16 18:11:30.838 sparkDriver-akka.actor.default-dispatcher-20 INFO BlockManagerMasterEndpoint: Registering block manager localhost:54202 with 530.3 MB RAM, BlockManagerId(1, localhost, 54202) 15/09/16 18:11:32.112 sparkDriver-akka.actor.default-dispatcher-20 INFO BlockManagerMasterEndpoint: Registering block manager localhost:32955 with 530.3 MB RAM, BlockManagerId(0, localhost, 32955) ``` In this build, the first block manager needed only 3+ seconds to start. Comparing these two builds, I guess it's possible that the local cluster in `BroadcastSuite` cannot be ready in 10 seconds if the Jenkins worker is busy. So I just increased the timeout to 60 seconds to see if this can fix the issue. Author: zsxwing Closes #8813 from zsxwing/fix-BroadcastSuite. --- .../scala/org/apache/spark/ExternalShuffleServiceSuite.scala | 2 +- .../test/scala/org/apache/spark/broadcast/BroadcastSuite.scala | 2 +- .../apache/spark/scheduler/SparkListenerWithClusterSuite.scala | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala b/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala index e846a72c888c..231f4631e0a4 100644 --- a/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala +++ b/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala @@ -61,7 +61,7 @@ class ExternalShuffleServiceSuite extends ShuffleSuite with BeforeAndAfterAll { // local blocks from the local BlockManager and won't send requests to ExternalShuffleService. // In this case, we won't receive FetchFailed. And it will make this test fail. // Therefore, we should wait until all slaves are up - sc.jobProgressListener.waitUntilExecutorsUp(2, 10000) + sc.jobProgressListener.waitUntilExecutorsUp(2, 60000) val rdd = sc.parallelize(0 until 1000, 10).map(i => (i, 1)).reduceByKey(_ + _) diff --git a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala index fb7a8ae3f9d4..ba21075ce6be 100644 --- a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala +++ b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala @@ -311,7 +311,7 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext { new SparkContext("local-cluster[%d, 1, 1024]".format(numSlaves), "test", broadcastConf) // Wait until all salves are up try { - _sc.jobProgressListener.waitUntilExecutorsUp(numSlaves, 10000) + _sc.jobProgressListener.waitUntilExecutorsUp(numSlaves, 60000) _sc } catch { case e: Throwable => diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerWithClusterSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerWithClusterSuite.scala index d1e23ed527ff..9fa885938291 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerWithClusterSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerWithClusterSuite.scala @@ -43,7 +43,7 @@ class SparkListenerWithClusterSuite extends SparkFunSuite with LocalSparkContext // This test will check if the number of executors received by "SparkListener" is same as the // number of all executors, so we need to wait until all executors are up - sc.jobProgressListener.waitUntilExecutorsUp(2, 10000) + sc.jobProgressListener.waitUntilExecutorsUp(2, 60000) val rdd1 = sc.parallelize(1 to 100, 4) val rdd2 = rdd1.map(_.toString) From ca9fe540fe04e2e230d1e76526b5502bab152914 Mon Sep 17 00:00:00 2001 From: Jacek Laskowski 
Date: Mon, 21 Sep 2015 19:46:39 +0100 Subject: [PATCH 0698/1824] [SPARK-10662] [DOCS] Code snippets are not properly formatted in tables * Backticks are processed properly in Spark Properties table * Removed unnecessary spaces * See http://people.apache.org/~pwendell/spark-nightly/spark-master-docs/latest/running-on-yarn.html Author: Jacek Laskowski Closes #8795 from jaceklaskowski/docs-yarn-formatting. --- docs/configuration.md | 97 +++++++++++++++-------------- docs/programming-guide.md | 100 +++++++++++++++--------------- docs/running-on-mesos.md | 14 ++--- docs/running-on-yarn.md | 106 ++++++++++++++++---------------- docs/sql-programming-guide.md | 16 ++--- docs/submitting-applications.md | 8 +-- 6 files changed, 171 insertions(+), 170 deletions(-) diff --git a/docs/configuration.md b/docs/configuration.md index 5ec097c78aa3..b22587c70316 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -34,20 +34,20 @@ val conf = new SparkConf() val sc = new SparkContext(conf) {% endhighlight %} -Note that we can have more than 1 thread in local mode, and in cases like Spark Streaming, we may +Note that we can have more than 1 thread in local mode, and in cases like Spark Streaming, we may actually require one to prevent any sort of starvation issues. -Properties that specify some time duration should be configured with a unit of time. +Properties that specify some time duration should be configured with a unit of time. The following format is accepted: - + 25ms (milliseconds) 5s (seconds) 10m or 10min (minutes) 3h (hours) 5d (days) 1y (years) - - + + Properties that specify a byte size should be configured with a unit of size. The following format is accepted: @@ -140,7 +140,7 @@ of the most common options to set are: Amount of memory to use for the driver process, i.e. where SparkContext is initialized. (e.g. 1g, 2g). - +
      Note: In client mode, this config must not be set through the SparkConf directly in your application, because the driver JVM has already started at that point. Instead, please set this through the --driver-memory command line option @@ -207,7 +207,7 @@ Apart from these, the following properties are also available, and may be useful
      Note: In client mode, this config must not be set through the SparkConf directly in your application, because the driver JVM has already started at that point. - Instead, please set this through the --driver-class-path command line option or in + Instead, please set this through the --driver-class-path command line option or in your default properties file. @@ -216,10 +216,10 @@ Apart from these, the following properties are also available, and may be useful (none) A string of extra JVM options to pass to the driver. For instance, GC settings or other logging. - +
      Note: In client mode, this config must not be set through the SparkConf directly in your application, because the driver JVM has already started at that point. - Instead, please set this through the --driver-java-options command line option or in + Instead, please set this through the --driver-java-options command line option or in your default properties file. @@ -228,10 +228,10 @@ Apart from these, the following properties are also available, and may be useful (none) Set a special library path to use when launching the driver JVM. - +
      Note: In client mode, this config must not be set through the SparkConf directly in your application, because the driver JVM has already started at that point. - Instead, please set this through the --driver-library-path command line option or in + Instead, please set this through the --driver-library-path command line option or in your default properties file. @@ -242,7 +242,7 @@ Apart from these, the following properties are also available, and may be useful (Experimental) Whether to give user-added jars precedence over Spark's own jars when loading classes in the the driver. This feature can be used to mitigate conflicts between Spark's dependencies and user dependencies. It is currently an experimental feature. - + This is used in cluster mode only. @@ -250,8 +250,8 @@ Apart from these, the following properties are also available, and may be useful spark.executor.extraClassPath (none) - Extra classpath entries to prepend to the classpath of executors. This exists primarily for - backwards-compatibility with older versions of Spark. Users typically should not need to set + Extra classpath entries to prepend to the classpath of executors. This exists primarily for + backwards-compatibility with older versions of Spark. Users typically should not need to set this option. @@ -259,9 +259,9 @@ Apart from these, the following properties are also available, and may be useful spark.executor.extraJavaOptions (none) - A string of extra JVM options to pass to executors. For instance, GC settings or other logging. - Note that it is illegal to set Spark properties or heap size settings with this option. Spark - properties should be set using a SparkConf object or the spark-defaults.conf file used with the + A string of extra JVM options to pass to executors. For instance, GC settings or other logging. + Note that it is illegal to set Spark properties or heap size settings with this option. Spark + properties should be set using a SparkConf object or the spark-defaults.conf file used with the spark-submit script. Heap size settings can be set with spark.executor.memory. @@ -305,7 +305,7 @@ Apart from these, the following properties are also available, and may be useful daily Set the time interval by which the executor logs will be rolled over. - Rolling is disabled by default. Valid values are `daily`, `hourly`, `minutely` or + Rolling is disabled by default. Valid values are daily, hourly, minutely or any interval in seconds. See spark.executor.logs.rolling.maxRetainedFiles for automatic cleaning of old logs. @@ -330,13 +330,13 @@ Apart from these, the following properties are also available, and may be useful spark.python.profile false - Enable profiling in Python worker, the profile result will show up by `sc.show_profiles()`, + Enable profiling in Python worker, the profile result will show up by sc.show_profiles(), or it will be displayed before the driver exiting. It also can be dumped into disk by - `sc.dump_profiles(path)`. If some of the profile results had been displayed manually, + sc.dump_profiles(path). If some of the profile results had been displayed manually, they will not be displayed automatically before driver exiting. - By default the `pyspark.profiler.BasicProfiler` will be used, but this can be overridden by - passing a profiler class in as a parameter to the `SparkContext` constructor. + By default the pyspark.profiler.BasicProfiler will be used, but this can be overridden by + passing a profiler class in as a parameter to the SparkContext constructor. 
@@ -460,11 +460,11 @@ Apart from these, the following properties are also available, and may be useful spark.shuffle.service.enabled false - Enables the external shuffle service. This service preserves the shuffle files written by - executors so the executors can be safely removed. This must be enabled if + Enables the external shuffle service. This service preserves the shuffle files written by + executors so the executors can be safely removed. This must be enabled if spark.dynamicAllocation.enabled is "true". The external shuffle service must be set up in order to enable it. See - dynamic allocation + dynamic allocation configuration and setup documentation for more information. @@ -747,9 +747,9 @@ Apart from these, the following properties are also available, and may be useful 1 in YARN mode, all the available cores on the worker in standalone mode. The number of cores to use on each executor. For YARN and standalone mode only. - - In standalone mode, setting this parameter allows an application to run multiple executors on - the same worker, provided that there are enough cores on that worker. Otherwise, only one + + In standalone mode, setting this parameter allows an application to run multiple executors on + the same worker, provided that there are enough cores on that worker. Otherwise, only one executor per application will run on each worker. @@ -893,14 +893,14 @@ Apart from these, the following properties are also available, and may be useful spark.akka.heartbeat.interval 1000s - This is set to a larger value to disable the transport failure detector that comes built in to - Akka. It can be enabled again, if you plan to use this feature (Not recommended). A larger - interval value reduces network overhead and a smaller value ( ~ 1 s) might be more - informative for Akka's failure detector. Tune this in combination of `spark.akka.heartbeat.pauses` - if you need to. A likely positive use case for using failure detector would be: a sensistive - failure detector can help evict rogue executors quickly. However this is usually not the case - as GC pauses and network lags are expected in a real Spark cluster. Apart from that enabling - this leads to a lot of exchanges of heart beats between nodes leading to flooding the network + This is set to a larger value to disable the transport failure detector that comes built in to + Akka. It can be enabled again, if you plan to use this feature (Not recommended). A larger + interval value reduces network overhead and a smaller value ( ~ 1 s) might be more + informative for Akka's failure detector. Tune this in combination of spark.akka.heartbeat.pauses + if you need to. A likely positive use case for using failure detector would be: a sensistive + failure detector can help evict rogue executors quickly. However this is usually not the case + as GC pauses and network lags are expected in a real Spark cluster. Apart from that enabling + this leads to a lot of exchanges of heart beats between nodes leading to flooding the network with those. @@ -909,9 +909,9 @@ Apart from these, the following properties are also available, and may be useful 6000s This is set to a larger value to disable the transport failure detector that comes built in to Akka. - It can be enabled again, if you plan to use this feature (Not recommended). Acceptable heart + It can be enabled again, if you plan to use this feature (Not recommended). Acceptable heart beat pause for Akka. This can be used to control sensitivity to GC pauses. 
Tune - this along with `spark.akka.heartbeat.interval` if you need to. + this along with spark.akka.heartbeat.interval if you need to. @@ -978,7 +978,7 @@ Apart from these, the following properties are also available, and may be useful spark.network.timeout 120s - Default timeout for all network interactions. This config will be used in place of + Default timeout for all network interactions. This config will be used in place of spark.core.connection.ack.wait.timeout, spark.akka.timeout, spark.storage.blockManagerSlaveTimeoutMs, spark.shuffle.io.connectionTimeout, spark.rpc.askTimeout or @@ -991,8 +991,8 @@ Apart from these, the following properties are also available, and may be useful Maximum number of retries when binding to a port before giving up. When a port is given a specific value (non 0), each subsequent retry will - increment the port used in the previous attempt by 1 before retrying. This - essentially allows it to try a range of ports from the start port specified + increment the port used in the previous attempt by 1 before retrying. This + essentially allows it to try a range of ports from the start port specified to port + maxRetries. @@ -1191,7 +1191,7 @@ Apart from these, the following properties are also available, and may be useful spark.dynamicAllocation.executorIdleTimeout 60s - If dynamic allocation is enabled and an executor has been idle for more than this duration, + If dynamic allocation is enabled and an executor has been idle for more than this duration, the executor will be removed. For more detail, see this description. @@ -1424,11 +1424,11 @@ Apart from these, the following properties are also available, and may be useful false Enables or disables Spark Streaming's internal backpressure mechanism (since 1.5). - This enables the Spark Streaming to control the receiving rate based on the + This enables the Spark Streaming to control the receiving rate based on the current batch scheduling delays and processing times so that the system receives - only as fast as the system can process. Internally, this dynamically sets the + only as fast as the system can process. Internally, this dynamically sets the maximum receiving rate of receivers. This rate is upper bounded by the values - `spark.streaming.receiver.maxRate` and `spark.streaming.kafka.maxRatePerPartition` + spark.streaming.receiver.maxRate and spark.streaming.kafka.maxRatePerPartition if they are set (see below). @@ -1542,15 +1542,15 @@ The following variables can be set in `spark-env.sh`: Environment VariableMeaning JAVA_HOME - Location where Java is installed (if it's not on your default `PATH`). + Location where Java is installed (if it's not on your default PATH). PYSPARK_PYTHON - Python binary executable to use for PySpark in both driver and workers (default is `python`). + Python binary executable to use for PySpark in both driver and workers (default is python). PYSPARK_DRIVER_PYTHON - Python binary executable to use for PySpark in driver only (default is PYSPARK_PYTHON). + Python binary executable to use for PySpark in driver only (default is PYSPARK_PYTHON). SPARK_LOCAL_IP @@ -1580,4 +1580,3 @@ Spark uses [log4j](http://logging.apache.org/log4j/) for logging. You can config To specify a different configuration directory other than the default "SPARK_HOME/conf", you can set SPARK_CONF_DIR. Spark will use the the configuration files (spark-defaults.conf, spark-env.sh, log4j.properties, etc) from this directory. 
- diff --git a/docs/programming-guide.md b/docs/programming-guide.md index 4cf83bb39263..8ad238315f12 100644 --- a/docs/programming-guide.md +++ b/docs/programming-guide.md @@ -182,8 +182,8 @@ in-process. In the Spark shell, a special interpreter-aware SparkContext is already created for you, in the variable called `sc`. Making your own SparkContext will not work. You can set which master the context connects to using the `--master` argument, and you can add JARs to the classpath -by passing a comma-separated list to the `--jars` argument. You can also add dependencies -(e.g. Spark Packages) to your shell session by supplying a comma-separated list of maven coordinates +by passing a comma-separated list to the `--jars` argument. You can also add dependencies +(e.g. Spark Packages) to your shell session by supplying a comma-separated list of maven coordinates to the `--packages` argument. Any additional repositories where dependencies might exist (e.g. SonaType) can be passed to the `--repositories` argument. For example, to run `bin/spark-shell` on exactly four cores, use: @@ -217,7 +217,7 @@ context connects to using the `--master` argument, and you can add Python .zip, to the runtime path by passing a comma-separated list to `--py-files`. You can also add dependencies (e.g. Spark Packages) to your shell session by supplying a comma-separated list of maven coordinates to the `--packages` argument. Any additional repositories where dependencies might exist (e.g. SonaType) -can be passed to the `--repositories` argument. Any python dependencies a Spark Package has (listed in +can be passed to the `--repositories` argument. Any python dependencies a Spark Package has (listed in the requirements.txt of that package) must be manually installed using pip when necessary. For example, to run `bin/pyspark` on exactly four cores, use: @@ -249,8 +249,8 @@ the [IPython Notebook](http://ipython.org/notebook.html) with PyLab plot support $ PYSPARK_DRIVER_PYTHON=ipython PYSPARK_DRIVER_PYTHON_OPTS="notebook" ./bin/pyspark {% endhighlight %} -After the IPython Notebook server is launched, you can create a new "Python 2" notebook from -the "Files" tab. Inside the notebook, you can input the command `%pylab inline` as part of +After the IPython Notebook server is launched, you can create a new "Python 2" notebook from +the "Files" tab. Inside the notebook, you can input the command `%pylab inline` as part of your notebook before you start to try Spark from the IPython notebook.
    @@ -418,9 +418,9 @@ Apart from text files, Spark's Python API also supports several other data forma **Writable Support** -PySpark SequenceFile support loads an RDD of key-value pairs within Java, converts Writables to base Java types, and pickles the -resulting Java objects using [Pyrolite](https://github.com/irmen/Pyrolite/). When saving an RDD of key-value pairs to SequenceFile, -PySpark does the reverse. It unpickles Python objects into Java objects and then converts them to Writables. The following +PySpark SequenceFile support loads an RDD of key-value pairs within Java, converts Writables to base Java types, and pickles the +resulting Java objects using [Pyrolite](https://github.com/irmen/Pyrolite/). When saving an RDD of key-value pairs to SequenceFile, +PySpark does the reverse. It unpickles Python objects into Java objects and then converts them to Writables. The following Writables are automatically converted: @@ -435,9 +435,9 @@ Writables are automatically converted:
    MapWritabledict
    -Arrays are not handled out-of-the-box. Users need to specify custom `ArrayWritable` subtypes when reading or writing. When writing, -users also need to specify custom converters that convert arrays to custom `ArrayWritable` subtypes. When reading, the default -converter will convert custom `ArrayWritable` subtypes to Java `Object[]`, which then get pickled to Python tuples. To get +Arrays are not handled out-of-the-box. Users need to specify custom `ArrayWritable` subtypes when reading or writing. When writing, +users also need to specify custom converters that convert arrays to custom `ArrayWritable` subtypes. When reading, the default +converter will convert custom `ArrayWritable` subtypes to Java `Object[]`, which then get pickled to Python tuples. To get Python `array.array` for arrays of primitive types, users need to specify custom converters. **Saving and Loading SequenceFiles** @@ -454,7 +454,7 @@ classes can be specified, but for standard Writables this is not required. **Saving and Loading Other Hadoop Input/Output Formats** -PySpark can also read any Hadoop InputFormat or write any Hadoop OutputFormat, for both 'new' and 'old' Hadoop MapReduce APIs. +PySpark can also read any Hadoop InputFormat or write any Hadoop OutputFormat, for both 'new' and 'old' Hadoop MapReduce APIs. If required, a Hadoop configuration can be passed in as a Python dict. Here is an example using the Elasticsearch ESInputFormat: @@ -474,15 +474,15 @@ Note that, if the InputFormat simply depends on a Hadoop configuration and/or in the key and value classes can easily be converted according to the above table, then this approach should work well for such cases. -If you have custom serialized binary data (such as loading data from Cassandra / HBase), then you will first need to +If you have custom serialized binary data (such as loading data from Cassandra / HBase), then you will first need to transform that data on the Scala/Java side to something which can be handled by Pyrolite's pickler. -A [Converter](api/scala/index.html#org.apache.spark.api.python.Converter) trait is provided -for this. Simply extend this trait and implement your transformation code in the ```convert``` -method. Remember to ensure that this class, along with any dependencies required to access your ```InputFormat```, are packaged into your Spark job jar and included on the PySpark +A [Converter](api/scala/index.html#org.apache.spark.api.python.Converter) trait is provided +for this. Simply extend this trait and implement your transformation code in the ```convert``` +method. Remember to ensure that this class, along with any dependencies required to access your ```InputFormat```, are packaged into your Spark job jar and included on the PySpark classpath. -See the [Python examples]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/python) and -the [Converter examples]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/scala/org/apache/spark/examples/pythonconverters) +See the [Python examples]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/python) and +the [Converter examples]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/scala/org/apache/spark/examples/pythonconverters) for examples of using Cassandra / HBase ```InputFormat``` and ```OutputFormat``` with custom converters.
    @@ -758,7 +758,7 @@ One of the harder things about Spark is understanding the scope and life cycle o #### Example -Consider the naive RDD element sum below, which behaves completely differently depending on whether execution is happening within the same JVM. A common example of this is when running Spark in `local` mode (`--master = local[n]`) versus deploying a Spark application to a cluster (e.g. via spark-submit to YARN): +Consider the naive RDD element sum below, which behaves completely differently depending on whether execution is happening within the same JVM. A common example of this is when running Spark in `local` mode (`--master = local[n]`) versus deploying a Spark application to a cluster (e.g. via spark-submit to YARN):
    @@ -777,7 +777,7 @@ println("Counter value: " + counter)
    {% highlight java %} int counter = 0; -JavaRDD rdd = sc.parallelize(data); +JavaRDD rdd = sc.parallelize(data); // Wrong: Don't do this!! rdd.foreach(x -> counter += x); @@ -803,7 +803,7 @@ print("Counter value: " + counter) #### Local vs. cluster modes -The primary challenge is that the behavior of the above code is undefined. In local mode with a single JVM, the above code will sum the values within the RDD and store it in **counter**. This is because both the RDD and the variable **counter** are in the same memory space on the driver node. +The primary challenge is that the behavior of the above code is undefined. In local mode with a single JVM, the above code will sum the values within the RDD and store it in **counter**. This is because both the RDD and the variable **counter** are in the same memory space on the driver node. However, in `cluster` mode, what happens is more complicated, and the above may not work as intended. To execute jobs, Spark breaks up the processing of RDD operations into tasks - each of which is operated on by an executor. Prior to execution, Spark computes the **closure**. The closure is those variables and methods which must be visible for the executor to perform its computations on the RDD (in this case `foreach()`). This closure is serialized and sent to each executor. In `local` mode, there is only the one executors so everything shares the same closure. In other modes however, this is not the case and the executors running on seperate worker nodes each have their own copy of the closure. @@ -813,9 +813,9 @@ To ensure well-defined behavior in these sorts of scenarios one should use an [` In general, closures - constructs like loops or locally defined methods, should not be used to mutate some global state. Spark does not define or guarantee the behavior of mutations to objects referenced from outside of closures. Some code that does this may work in local mode, but that's just by accident and such code will not behave as expected in distributed mode. Use an Accumulator instead if some global aggregation is needed. -#### Printing elements of an RDD +#### Printing elements of an RDD Another common idiom is attempting to print out the elements of an RDD using `rdd.foreach(println)` or `rdd.map(println)`. On a single machine, this will generate the expected output and print all the RDD's elements. However, in `cluster` mode, the output to `stdout` being called by the executors is now writing to the executor's `stdout` instead, not the one on the driver, so `stdout` on the driver won't show these! To print all elements on the driver, one can use the `collect()` method to first bring the RDD to the driver node thus: `rdd.collect().foreach(println)`. This can cause the driver to run out of memory, though, because `collect()` fetches the entire RDD to a single machine; if you only need to print a few elements of the RDD, a safer approach is to use the `take()`: `rdd.take(100).foreach(println)`. - + ### Working with Key-Value Pairs
    @@ -859,7 +859,7 @@ only available on RDDs of key-value pairs. The most common ones are distributed "shuffle" operations, such as grouping or aggregating the elements by a key. -In Java, key-value pairs are represented using the +In Java, key-value pairs are represented using the [scala.Tuple2](http://www.scala-lang.org/api/{{site.SCALA_VERSION}}/index.html#scala.Tuple2) class from the Scala standard library. You can simply call `new Tuple2(a, b)` to create a tuple, and access its fields later with `tuple._1()` and `tuple._2()`. @@ -974,7 +974,7 @@ for details. groupByKey([numTasks]) When called on a dataset of (K, V) pairs, returns a dataset of (K, Iterable<V>) pairs.
    Note: If you are grouping in order to perform an aggregation (such as a sum or - average) over each key, using reduceByKey or aggregateByKey will yield much better + average) over each key, using reduceByKey or aggregateByKey will yield much better performance.
    Note: By default, the level of parallelism in the output depends on the number of partitions of the parent RDD. @@ -1025,7 +1025,7 @@ for details. repartitionAndSortWithinPartitions(partitioner) Repartition the RDD according to the given partitioner and, within each resulting partition, - sort records by their keys. This is more efficient than calling repartition and then sorting within + sort records by their keys. This is more efficient than calling repartition and then sorting within each partition because it can push the sorting down into the shuffle machinery. @@ -1038,7 +1038,7 @@ RDD API doc [Java](api/java/index.html?org/apache/spark/api/java/JavaRDD.html), [Python](api/python/pyspark.html#pyspark.RDD), [R](api/R/index.html)) - + and pair RDD functions doc ([Scala](api/scala/index.html#org.apache.spark.rdd.PairRDDFunctions), [Java](api/java/index.html?org/apache/spark/api/java/JavaPairRDD.html)) @@ -1094,7 +1094,7 @@ for details. foreach(func) - Run a function func on each element of the dataset. This is usually done for side effects such as updating an Accumulator or interacting with external storage systems. + Run a function func on each element of the dataset. This is usually done for side effects such as updating an Accumulator or interacting with external storage systems.
    Note: modifying variables other than Accumulators outside of the foreach() may result in undefined behavior. See Understanding closures for more details. @@ -1118,13 +1118,13 @@ co-located to compute the result. In Spark, data is generally not distributed across partitions to be in the necessary place for a specific operation. During computations, a single task will operate on a single partition - thus, to organize all the data for a single `reduceByKey` reduce task to execute, Spark needs to perform an -all-to-all operation. It must read from all partitions to find all the values for all keys, -and then bring together values across partitions to compute the final result for each key - +all-to-all operation. It must read from all partitions to find all the values for all keys, +and then bring together values across partitions to compute the final result for each key - this is called the **shuffle**. Although the set of elements in each partition of newly shuffled data will be deterministic, and so -is the ordering of partitions themselves, the ordering of these elements is not. If one desires predictably -ordered data following shuffle then it's possible to use: +is the ordering of partitions themselves, the ordering of these elements is not. If one desires predictably +ordered data following shuffle then it's possible to use: * `mapPartitions` to sort each partition using, for example, `.sorted` * `repartitionAndSortWithinPartitions` to efficiently sort partitions while simultaneously repartitioning @@ -1141,26 +1141,26 @@ network I/O. To organize data for the shuffle, Spark generates sets of tasks - * organize the data, and a set of *reduce* tasks to aggregate it. This nomenclature comes from MapReduce and does not directly relate to Spark's `map` and `reduce` operations. -Internally, results from individual map tasks are kept in memory until they can't fit. Then, these -are sorted based on the target partition and written to a single file. On the reduce side, tasks +Internally, results from individual map tasks are kept in memory until they can't fit. Then, these +are sorted based on the target partition and written to a single file. On the reduce side, tasks read the relevant sorted blocks. - -Certain shuffle operations can consume significant amounts of heap memory since they employ -in-memory data structures to organize records before or after transferring them. Specifically, -`reduceByKey` and `aggregateByKey` create these structures on the map side, and `'ByKey` operations -generate these on the reduce side. When data does not fit in memory Spark will spill these tables + +Certain shuffle operations can consume significant amounts of heap memory since they employ +in-memory data structures to organize records before or after transferring them. Specifically, +`reduceByKey` and `aggregateByKey` create these structures on the map side, and `'ByKey` operations +generate these on the reduce side. When data does not fit in memory Spark will spill these tables to disk, incurring the additional overhead of disk I/O and increased garbage collection. Shuffle also generates a large number of intermediate files on disk. As of Spark 1.3, these files -are preserved until the corresponding RDDs are no longer used and are garbage collected. -This is done so the shuffle files don't need to be re-created if the lineage is re-computed. -Garbage collection may happen only after a long period time, if the application retains references -to these RDDs or if GC does not kick in frequently. 
This means that long-running Spark jobs may +are preserved until the corresponding RDDs are no longer used and are garbage collected. +This is done so the shuffle files don't need to be re-created if the lineage is re-computed. +Garbage collection may happen only after a long period time, if the application retains references +to these RDDs or if GC does not kick in frequently. This means that long-running Spark jobs may consume a large amount of disk space. The temporary storage directory is specified by the `spark.local.dir` configuration parameter when configuring the Spark context. Shuffle behavior can be tuned by adjusting a variety of configuration parameters. See the -'Shuffle Behavior' section within the [Spark Configuration Guide](configuration.html). +'Shuffle Behavior' section within the [Spark Configuration Guide](configuration.html). ## RDD Persistence @@ -1246,7 +1246,7 @@ efficiency. We recommend going through the following process to select one: This is the most CPU-efficient option, allowing operations on the RDDs to run as fast as possible. * If not, try using `MEMORY_ONLY_SER` and [selecting a fast serialization library](tuning.html) to -make the objects much more space-efficient, but still reasonably fast to access. +make the objects much more space-efficient, but still reasonably fast to access. * Don't spill to disk unless the functions that computed your datasets are expensive, or they filter a large amount of the data. Otherwise, recomputing a partition may be as fast as reading it from @@ -1345,7 +1345,7 @@ Accumulators are variables that are only "added" to through an associative opera therefore be efficiently supported in parallel. They can be used to implement counters (as in MapReduce) or sums. Spark natively supports accumulators of numeric types, and programmers can add support for new types. If accumulators are created with a name, they will be -displayed in Spark's UI. This can be useful for understanding the progress of +displayed in Spark's UI. This can be useful for understanding the progress of running stages (NOTE: this is not yet supported in Python). An accumulator is created from an initial value `v` by calling `SparkContext.accumulator(v)`. Tasks @@ -1474,8 +1474,8 @@ vecAccum = sc.accumulator(Vector(...), VectorAccumulatorParam())
    -For accumulator updates performed inside actions only, Spark guarantees that each task's update to the accumulator -will only be applied once, i.e. restarted tasks will not update the value. In transformations, users should be aware +For accumulator updates performed inside actions only, Spark guarantees that each task's update to the accumulator +will only be applied once, i.e. restarted tasks will not update the value. In transformations, users should be aware of that each task's update may be applied more than once if tasks or job stages are re-executed. Accumulators do not change the lazy evaluation model of Spark. If they are being updated within an operation on an RDD, their value is only updated once that RDD is computed as part of an action. Consequently, accumulator updates are not guaranteed to be executed when made within a lazy transformation like `map()`. The below code fragment demonstrates this property: @@ -1486,7 +1486,7 @@ Accumulators do not change the lazy evaluation model of Spark. If they are being {% highlight scala %} val accum = sc.accumulator(0) data.map { x => accum += x; f(x) } -// Here, accum is still 0 because no actions have caused the `map` to be computed. +// Here, accum is still 0 because no actions have caused the map to be computed. {% endhighlight %}
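To see the flip side of that fragment, the sketch below (not from the original guide, and reusing the `sc`, `data`, and `f` names assumed above) shows that the pending updates are applied once an action actually runs the `map`:

{% highlight scala %}
val accum = sc.accumulator(0)
val mapped = data.map { x => accum += x; f(x) }

println(accum.value)   // still 0: the map has only been recorded, not executed

mapped.count()         // an action forces the map to be evaluated

println(accum.value)   // now the sum of the elements of data (assuming no task re-execution)
{% endhighlight %}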
    @@ -1553,7 +1553,7 @@ Several changes were made to the Java API: code that `extends Function` should `implement Function` instead. * New variants of the `map` transformations, like `mapToPair` and `mapToDouble`, were added to create RDDs of special data types. -* Grouping operations like `groupByKey`, `cogroup` and `join` have changed from returning +* Grouping operations like `groupByKey`, `cogroup` and `join` have changed from returning `(Key, List)` pairs to `(Key, Iterable)`.
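The Scala API exposes the same `Iterable` shape for grouping operations, as the small sketch below illustrates (the sample data is made up and an active `sc` is assumed):

{% highlight scala %}
import org.apache.spark.rdd.RDD

val pairs: RDD[(String, Int)] = sc.parallelize(Seq(("a", 1), ("a", 2), ("b", 3)))
val grouped: RDD[(String, Iterable[Int])] = pairs.groupByKey()   // values are an Iterable, not a materialized List
grouped.mapValues(_.sum).collect()                               // contains ("a", 3) and ("b", 4)
{% endhighlight %}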
    diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index 330c159c67bc..460a66f37dd6 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -245,7 +245,7 @@ See the [configuration page](configuration.html) for information on Spark config spark.mesos.coarse false - If set to "true", runs over Mesos clusters in + If set to true, runs over Mesos clusters in "coarse-grained" sharing mode, where Spark acquires one long-lived Mesos task on each machine instead of one Mesos task per Spark task. This gives lower-latency scheduling for short queries, but leaves resources in use @@ -254,16 +254,16 @@ See the [configuration page](configuration.html) for information on Spark config spark.mesos.extra.cores - 0 + 0 Set the extra amount of cpus to request per task. This setting is only used for Mesos coarse grain mode. The total amount of cores requested per task is the number of cores in the offer plus the extra cores configured. - Note that total amount of cores the executor will request in total will not exceed the spark.cores.max setting. + Note that total amount of cores the executor will request in total will not exceed the spark.cores.max setting. spark.mesos.mesosExecutor.cores - 1.0 + 1.0 (Fine-grained mode only) Number of cores to give each Mesos executor. This does not include the cores used to run the Spark tasks. In other words, even if no Spark task @@ -287,7 +287,7 @@ See the [configuration page](configuration.html) for information on Spark config Set the list of volumes which will be mounted into the Docker image, which was set using spark.mesos.executor.docker.image. The format of this property is a comma-separated list of - mappings following the form passed to docker run -v. That is they take the form: + mappings following the form passed to docker run -v. That is they take the form:
    [host_path:]container_path[:ro|:rw]
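As a concrete illustration of that format (the image name and paths below are hypothetical, not taken from the guide), the two Docker properties could be set together like this:

{% highlight scala %}
import org.apache.spark.SparkConf

val conf = new SparkConf()
  .set("spark.mesos.executor.docker.image", "example/spark-executor:latest")  // hypothetical image
  // comma-separated [host_path:]container_path[:ro|:rw] mappings, as described above
  .set("spark.mesos.executor.docker.volumes",
    "/etc/hadoop/conf:/etc/hadoop/conf:ro,/var/data/spark:/var/data/spark:rw")
{% endhighlight %}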
    @@ -318,7 +318,7 @@ See the [configuration page](configuration.html) for information on Spark config executor memory * 0.10, with minimum of 384 The amount of additional memory, specified in MB, to be allocated per executor. By default, - the overhead will be larger of either 384 or 10% of `spark.executor.memory`. If it's set, + the overhead will be larger of either 384 or 10% of spark.executor.memory. If set, the final overhead will be this value. @@ -339,7 +339,7 @@ See the [configuration page](configuration.html) for information on Spark config spark.mesos.secret - (none)/td> + (none) Set the secret with which Spark framework will use to authenticate with Mesos. diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index 3a961d245f3d..0e25ccf512c0 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -23,7 +23,7 @@ Unlike [Spark standalone](spark-standalone.html) and [Mesos](running-on-mesos.ht To launch a Spark application in `yarn-cluster` mode: $ ./bin/spark-submit --class path.to.your.Class --master yarn-cluster [options] [app options] - + For example: $ ./bin/spark-submit --class org.apache.spark.examples.SparkPi \ @@ -43,7 +43,7 @@ To launch a Spark application in `yarn-client` mode, do the same, but replace `y ## Adding Other JARs -In `yarn-cluster` mode, the driver runs on a different machine than the client, so `SparkContext.addJar` won't work out of the box with files that are local to the client. To make files on the client available to `SparkContext.addJar`, include them with the `--jars` option in the launch command. +In `yarn-cluster` mode, the driver runs on a different machine than the client, so `SparkContext.addJar` won't work out of the box with files that are local to the client. To make files on the client available to `SparkContext.addJar`, include them with the `--jars` option in the launch command. $ ./bin/spark-submit --class my.main.Class \ --master yarn-cluster \ @@ -64,16 +64,16 @@ Most of the configs are the same for Spark on YARN as for other deployment modes # Debugging your Application -In YARN terminology, executors and application masters run inside "containers". YARN has two modes for handling container logs after an application has completed. If log aggregation is turned on (with the `yarn.log-aggregation-enable` config), container logs are copied to HDFS and deleted on the local machine. These logs can be viewed from anywhere on the cluster with the "yarn logs" command. +In YARN terminology, executors and application masters run inside "containers". YARN has two modes for handling container logs after an application has completed. If log aggregation is turned on (with the `yarn.log-aggregation-enable` config), container logs are copied to HDFS and deleted on the local machine. These logs can be viewed from anywhere on the cluster with the `yarn logs` command. yarn logs -applicationId - + will print out the contents of all log files from all containers from the given application. You can also view the container log files directly in HDFS using the HDFS shell or API. The directory where they are located can be found by looking at your YARN configs (`yarn.nodemanager.remote-app-log-dir` and `yarn.nodemanager.remote-app-log-dir-suffix`). The logs are also available on the Spark Web UI under the Executors Tab. You need to have both the Spark history server and the MapReduce history server running and configure `yarn.log.server.url` in `yarn-site.xml` properly. 
The log URL on the Spark history server UI will redirect you to the MapReduce history server to show the aggregated logs. When log aggregation isn't turned on, logs are retained locally on each machine under `YARN_APP_LOGS_DIR`, which is usually configured to `/tmp/logs` or `$HADOOP_HOME/logs/userlogs` depending on the Hadoop version and installation. Viewing logs for a container requires going to the host that contains them and looking in this directory. Subdirectories organize log files by application ID and container ID. The logs are also available on the Spark Web UI under the Executors Tab and doesn't require running the MapReduce history server. To review per-container launch environment, increase `yarn.nodemanager.delete.debug-delay-sec` to a -large value (e.g. 36000), and then access the application cache through `yarn.nodemanager.local-dirs` +large value (e.g. `36000`), and then access the application cache through `yarn.nodemanager.local-dirs` on the nodes on which containers are launched. This directory contains the launch script, JARs, and all environment variables used for launching each container. This process is useful for debugging classpath problems in particular. (Note that enabling this requires admin privileges on cluster @@ -92,7 +92,7 @@ Note that for the first option, both executors and the application master will s log4j configuration, which may cause issues when they run on the same node (e.g. trying to write to the same log file). -If you need a reference to the proper location to put log files in the YARN so that YARN can properly display and aggregate them, use `spark.yarn.app.container.log.dir` in your log4j.properties. For example, `log4j.appender.file_appender.File=${spark.yarn.app.container.log.dir}/spark.log`. For streaming application, configuring `RollingFileAppender` and setting file location to YARN's log directory will avoid disk overflow caused by large log file, and logs can be accessed using YARN's log utility. +If you need a reference to the proper location to put log files in the YARN so that YARN can properly display and aggregate them, use `spark.yarn.app.container.log.dir` in your `log4j.properties`. For example, `log4j.appender.file_appender.File=${spark.yarn.app.container.log.dir}/spark.log`. For streaming applications, configuring `RollingFileAppender` and setting file location to YARN's log directory will avoid disk overflow caused by large log files, and logs can be accessed using YARN's log utility. #### Spark Properties @@ -100,24 +100,26 @@ If you need a reference to the proper location to put log files in the YARN so t Property NameDefaultMeaning spark.yarn.am.memory - 512m + 512m Amount of memory to use for the YARN Application Master in client mode, in the same format as JVM memory strings (e.g. 512m, 2g). In cluster mode, use spark.driver.memory instead. +

    + Use lower-case suffixes, e.g. k, m, g, t, and p, for kibi-, mebi-, gibi-, tebi-, and pebibytes, respectively. spark.driver.cores - 1 + 1 Number of cores used by the driver in YARN cluster mode. - Since the driver is run in the same JVM as the YARN Application Master in cluster mode, this also controls the cores used by the YARN AM. - In client mode, use spark.yarn.am.cores to control the number of cores used by the YARN AM instead. + Since the driver is run in the same JVM as the YARN Application Master in cluster mode, this also controls the cores used by the YARN Application Master. + In client mode, use spark.yarn.am.cores to control the number of cores used by the YARN Application Master instead. spark.yarn.am.cores - 1 + 1 Number of cores to use for the YARN Application Master in client mode. In cluster mode, use spark.driver.cores instead. @@ -125,39 +127,39 @@ If you need a reference to the proper location to put log files in the YARN so t spark.yarn.am.waitTime - 100s + 100s - In `yarn-cluster` mode, time for the application master to wait for the - SparkContext to be initialized. In `yarn-client` mode, time for the application master to wait + In yarn-cluster mode, time for the YARN Application Master to wait for the + SparkContext to be initialized. In yarn-client mode, time for the YARN Application Master to wait for the driver to connect to it. spark.yarn.submit.file.replication - The default HDFS replication (usually 3) + The default HDFS replication (usually 3) HDFS replication level for the files uploaded into HDFS for the application. These include things like the Spark jar, the app jar, and any distributed cache files/archives. spark.yarn.preserve.staging.files - false + false - Set to true to preserve the staged files (Spark jar, app jar, distributed cache files) at the end of the job rather than delete them. + Set to true to preserve the staged files (Spark jar, app jar, distributed cache files) at the end of the job rather than delete them. spark.yarn.scheduler.heartbeat.interval-ms - 3000 + 3000 The interval in ms in which the Spark application master heartbeats into the YARN ResourceManager. - The value is capped at half the value of YARN's configuration for the expiry interval - (yarn.am.liveness-monitor.expiry-interval-ms). + The value is capped at half the value of YARN's configuration for the expiry interval, i.e. + yarn.am.liveness-monitor.expiry-interval-ms. spark.yarn.scheduler.initial-allocation.interval - 200ms + 200ms The initial interval in which the Spark application master eagerly heartbeats to the YARN ResourceManager when there are pending container allocation requests. It should be no larger than @@ -177,8 +179,8 @@ If you need a reference to the proper location to put log files in the YARN so t spark.yarn.historyServer.address (none) - The address of the Spark history server (i.e. host.com:18080). The address should not contain a scheme (http://). Defaults to not being set since the history server is an optional service. This address is given to the YARN ResourceManager when the Spark application finishes to link the application from the ResourceManager UI to the Spark history server UI. - For this property, YARN properties can be used as variables, and these are substituted by Spark at runtime. For eg, if the Spark history server runs on the same node as the YARN ResourceManager, it can be set to `${hadoopconf-yarn.resourcemanager.hostname}:18080`. + The address of the Spark history server, e.g. host.com:18080. 
The address should not contain a scheme (http://). Defaults to not being set since the history server is an optional service. This address is given to the YARN ResourceManager when the Spark application finishes to link the application from the ResourceManager UI to the Spark history server UI. + For this property, YARN properties can be used as variables, and these are substituted by Spark at runtime. For example, if the Spark history server runs on the same node as the YARN ResourceManager, it can be set to ${hadoopconf-yarn.resourcemanager.hostname}:18080. @@ -197,42 +199,42 @@ If you need a reference to the proper location to put log files in the YARN so t spark.executor.instances - 2 + 2 - The number of executors. Note that this property is incompatible with spark.dynamicAllocation.enabled. If both spark.dynamicAllocation.enabled and spark.executor.instances are specified, dynamic allocation is turned off and the specified number of spark.executor.instances is used. + The number of executors. Note that this property is incompatible with spark.dynamicAllocation.enabled. If both spark.dynamicAllocation.enabled and spark.executor.instances are specified, dynamic allocation is turned off and the specified number of spark.executor.instances is used. spark.yarn.executor.memoryOverhead executorMemory * 0.10, with minimum of 384 - The amount of off heap memory (in megabytes) to be allocated per executor. This is memory that accounts for things like VM overheads, interned strings, other native overheads, etc. This tends to grow with the executor size (typically 6-10%). + The amount of off-heap memory (in megabytes) to be allocated per executor. This is memory that accounts for things like VM overheads, interned strings, other native overheads, etc. This tends to grow with the executor size (typically 6-10%). spark.yarn.driver.memoryOverhead driverMemory * 0.10, with minimum of 384 - The amount of off heap memory (in megabytes) to be allocated per driver in cluster mode. This is memory that accounts for things like VM overheads, interned strings, other native overheads, etc. This tends to grow with the container size (typically 6-10%). + The amount of off-heap memory (in megabytes) to be allocated per driver in cluster mode. This is memory that accounts for things like VM overheads, interned strings, other native overheads, etc. This tends to grow with the container size (typically 6-10%). spark.yarn.am.memoryOverhead AM memory * 0.10, with minimum of 384 - Same as spark.yarn.driver.memoryOverhead, but for the Application Master in client mode. + Same as spark.yarn.driver.memoryOverhead, but for the YARN Application Master in client mode. spark.yarn.am.port (random) - Port for the YARN Application Master to listen on. In YARN client mode, this is used to communicate between the Spark driver running on a gateway and the Application Master running on YARN. In YARN cluster mode, this is used for the dynamic executor feature, where it handles the kill from the scheduler backend. + Port for the YARN Application Master to listen on. In YARN client mode, this is used to communicate between the Spark driver running on a gateway and the YARN Application Master running on YARN. In YARN cluster mode, this is used for the dynamic executor feature, where it handles the kill from the scheduler backend. spark.yarn.queue - default + default The name of the YARN queue to which the application is submitted. 
@@ -245,18 +247,18 @@ If you need a reference to the proper location to put log files in the YARN so t By default, Spark on YARN will use a Spark jar installed locally, but the Spark jar can also be in a world-readable location on HDFS. This allows YARN to cache it on nodes so that it doesn't need to be distributed each time an application runs. To point to a jar on HDFS, for example, - set this configuration to "hdfs:///some/path". + set this configuration to hdfs:///some/path. spark.yarn.access.namenodes (none) - A list of secure HDFS namenodes your Spark application is going to access. For - example, `spark.yarn.access.namenodes=hdfs://nn1.com:8032,hdfs://nn2.com:8032`. - The Spark application must have acess to the namenodes listed and Kerberos must - be properly configured to be able to access them (either in the same realm or in - a trusted realm). Spark acquires security tokens for each of the namenodes so that + A comma-separated list of secure HDFS namenodes your Spark application is going to access. For + example, spark.yarn.access.namenodes=hdfs://nn1.com:8032,hdfs://nn2.com:8032. + The Spark application must have access to the namenodes listed and Kerberos must + be properly configured to be able to access them (either in the same realm or in + a trusted realm). Spark acquires security tokens for each of the namenodes so that the Spark application can access those remote HDFS clusters. @@ -264,18 +266,18 @@ If you need a reference to the proper location to put log files in the YARN so t spark.yarn.appMasterEnv.[EnvironmentVariableName] (none) - Add the environment variable specified by EnvironmentVariableName to the - Application Master process launched on YARN. The user can specify multiple of - these and to set multiple environment variables. In `yarn-cluster` mode this controls - the environment of the SPARK driver and in `yarn-client` mode it only controls - the environment of the executor launcher. + Add the environment variable specified by EnvironmentVariableName to the + Application Master process launched on YARN. The user can specify multiple of + these and to set multiple environment variables. In yarn-cluster mode this controls + the environment of the Spark driver and in yarn-client mode it only controls + the environment of the executor launcher. spark.yarn.containerLauncherMaxThreads - 25 + 25 - The maximum number of threads to use in the application master for launching executor containers. + The maximum number of threads to use in the YARN Application Master for launching executor containers. @@ -283,19 +285,19 @@ If you need a reference to the proper location to put log files in the YARN so t (none) A string of extra JVM options to pass to the YARN Application Master in client mode. - In cluster mode, use `spark.driver.extraJavaOptions` instead. + In cluster mode, use spark.driver.extraJavaOptions instead. spark.yarn.am.extraLibraryPath (none) - Set a special library path to use when launching the application master in client mode. + Set a special library path to use when launching the YARN Application Master in client mode. spark.yarn.maxAppAttempts - yarn.resourcemanager.am.max-attempts in YARN + yarn.resourcemanager.am.max-attempts in YARN The maximum number of attempts that will be made to submit the application. It should be no larger than the global number of max attempts in the YARN configuration. 
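These properties are normally supplied in spark-defaults.conf or with --conf at submit time; purely as a sketch with made-up values, a few of the options above could also be set programmatically on the configuration used for submission:

{% highlight scala %}
import org.apache.spark.SparkConf

val conf = new SparkConf()
  .set("spark.yarn.appMasterEnv.EXAMPLE_FLAG", "true")          // exported into the Application Master's environment
  .set("spark.yarn.am.extraJavaOptions", "-XX:+PrintGCDetails") // client mode only
  .set("spark.yarn.maxAppAttempts", "2")
{% endhighlight %}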
@@ -303,10 +305,10 @@ If you need a reference to the proper location to put log files in the YARN so t spark.yarn.submit.waitAppCompletion - true + true In YARN cluster mode, controls whether the client waits to exit until the application completes. - If set to true, the client process will stay alive reporting the application's status. + If set to true, the client process will stay alive reporting the application's status. Otherwise, the client process will exit after submission. @@ -332,7 +334,7 @@ If you need a reference to the proper location to put log files in the YARN so t (none) The full path to the file that contains the keytab for the principal specified above. - This keytab will be copied to the node running the Application Master via the Secure Distributed Cache, + This keytab will be copied to the node running the YARN Application Master via the Secure Distributed Cache, for renewing the login tickets and the delegation tokens periodically. @@ -371,14 +373,14 @@ If you need a reference to the proper location to put log files in the YARN so t spark.yarn.security.tokens.${service}.enabled - true + true Controls whether to retrieve delegation tokens for non-HDFS services when security is enabled. By default, delegation tokens for all supported services are retrieved when those services are configured, but it's possible to disable that behavior if it somehow conflicts with the application being run.

    - Currently supported services are: hive, hbase + Currently supported services are: hive, hbase @@ -387,5 +389,5 @@ If you need a reference to the proper location to put log files in the YARN so t - Whether core requests are honored in scheduling decisions depends on which scheduler is in use and how it is configured. - In `yarn-cluster` mode, the local directories used by the Spark executors and the Spark driver will be the local directories configured for YARN (Hadoop YARN config `yarn.nodemanager.local-dirs`). If the user specifies `spark.local.dir`, it will be ignored. In `yarn-client` mode, the Spark executors will use the local directories configured for YARN while the Spark driver will use those defined in `spark.local.dir`. This is because the Spark driver does not run on the YARN cluster in `yarn-client` mode, only the Spark executors do. -- The `--files` and `--archives` options support specifying file names with the # similar to Hadoop. For example you can specify: `--files localtest.txt#appSees.txt` and this will upload the file you have locally named localtest.txt into HDFS but this will be linked to by the name `appSees.txt`, and your application should use the name as `appSees.txt` to reference it when running on YARN. +- The `--files` and `--archives` options support specifying file names with the # similar to Hadoop. For example you can specify: `--files localtest.txt#appSees.txt` and this will upload the file you have locally named `localtest.txt` into HDFS but this will be linked to by the name `appSees.txt`, and your application should use the name as `appSees.txt` to reference it when running on YARN. - The `--jars` option allows the `SparkContext.addJar` function to work if you are using it with local files and running in `yarn-cluster` mode. It does not need to be used if you are using it with HDFS, HTTP, HTTPS, or FTP files. diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 7ae9244c271e..a1cbc7de97c6 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1676,7 +1676,7 @@ results <- collect(sql(sqlContext, "FROM src SELECT key, value")) ### Interacting with Different Versions of Hive Metastore One of the most important pieces of Spark SQL's Hive support is interaction with Hive metastore, -which enables Spark SQL to access metadata of Hive tables. Starting from Spark 1.4.0, a single binary +which enables Spark SQL to access metadata of Hive tables. Starting from Spark 1.4.0, a single binary build of Spark SQL can be used to query different versions of Hive metastores, using the configuration described below. Note that independent of the version of Hive that is being used to talk to the metastore, internally Spark SQL will compile against Hive 1.2.1 and use those classes for internal execution (serdes, UDFs, UDAFs, etc). @@ -1706,8 +1706,8 @@ The following options can be used to configure the version of Hive that is used either 1.2.1 or not defined.

  • maven
  • Use Hive jars of specified version downloaded from Maven repositories. This configuration - is not generally recommended for production deployments. -
  • A classpath in the standard format for the JVM. This classpath must include all of Hive + is not generally recommended for production deployments. +
  • A classpath in the standard format for the JVM. This classpath must include all of Hive and its dependencies, including the correct version of Hadoop. These jars only need to be present on the driver, but if you are running in yarn cluster mode then you must ensure they are packaged with you application.
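One way these metastore options can be supplied is through the Spark configuration used to start the application, provided they are in place before the metastore is first contacted; the version and jar source below are illustrative values only:

{% highlight scala %}
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.hive.HiveContext

val conf = new SparkConf()
  .set("spark.sql.hive.metastore.version", "0.13.1")  // illustrative metastore version
  .set("spark.sql.hive.metastore.jars", "maven")      // fetch the matching Hive jars from Maven repositories
val sc = new SparkContext(conf)
val sqlContext = new HiveContext(sc)
{% endhighlight %}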
  • @@ -1806,7 +1806,7 @@ the Data Sources API. The following options are supported:
    {% highlight scala %} -val jdbcDF = sqlContext.read.format("jdbc").options( +val jdbcDF = sqlContext.read.format("jdbc").options( Map("url" -> "jdbc:postgresql:dbserver", "dbtable" -> "schema.tablename")).load() {% endhighlight %} @@ -2023,11 +2023,11 @@ options. - Optimized execution using manually managed memory (Tungsten) is now enabled by default, along with code generation for expression evaluation. These features can both be disabled by setting - `spark.sql.tungsten.enabled` to `false. - - Parquet schema merging is no longer enabled by default. It can be re-enabled by setting + `spark.sql.tungsten.enabled` to `false`. + - Parquet schema merging is no longer enabled by default. It can be re-enabled by setting `spark.sql.parquet.mergeSchema` to `true`. - - Resolution of strings to columns in python now supports using dots (`.`) to qualify the column or - access nested values. For example `df['table.column.nestedField']`. However, this means that if + - Resolution of strings to columns in python now supports using dots (`.`) to qualify the column or + access nested values. For example `df['table.column.nestedField']`. However, this means that if your column name contains any dots you must now escape them using backticks (e.g., ``table.`column.with.dots`.nested``). - In-memory columnar storage partition pruning is on by default. It can be disabled by setting `spark.sql.inMemoryColumnarStorage.partitionPruning` to `false`. diff --git a/docs/submitting-applications.md b/docs/submitting-applications.md index 7ea4d6f1a3f8..915be0f47915 100644 --- a/docs/submitting-applications.md +++ b/docs/submitting-applications.md @@ -103,7 +103,7 @@ run it with `--help`. Here are a few examples of common options: export HADOOP_CONF_DIR=XXX ./bin/spark-submit \ --class org.apache.spark.examples.SparkPi \ - --master yarn-cluster \ # can also be `yarn-client` for client mode + --master yarn-cluster \ # can also be yarn-client for client mode --executor-memory 20G \ --num-executors 50 \ /path/to/examples.jar \ @@ -174,9 +174,9 @@ This can use up a significant amount of space over time and will need to be clea is handled automatically, and with Spark standalone, automatic cleanup can be configured with the `spark.worker.cleanup.appDataTtl` property. -Users may also include any other dependencies by supplying a comma-delimited list of maven coordinates -with `--packages`. All transitive dependencies will be handled when using this command. Additional -repositories (or resolvers in SBT) can be added in a comma-delimited fashion with the flag `--repositories`. +Users may also include any other dependencies by supplying a comma-delimited list of maven coordinates +with `--packages`. All transitive dependencies will be handled when using this command. Additional +repositories (or resolvers in SBT) can be added in a comma-delimited fashion with the flag `--repositories`. These commands can be used with `pyspark`, `spark-shell`, and `spark-submit` to include Spark Packages. For Python, the equivalent `--py-files` option can be used to distribute `.egg`, `.zip` and `.py` libraries From 331f0b10f78a37d96d3e573d211d74a0935265db Mon Sep 17 00:00:00 2001 From: Meihua Wu Date: Mon, 21 Sep 2015 12:09:00 -0700 Subject: [PATCH 0699/1824] [SPARK-9642] [ML] LinearRegression should supported weighted data In many modeling application, data points are not necessarily sampled with equal probabilities. Linear regression should support weighting which account the over or under sampling. work in progress. 
Author: Meihua Wu Closes #8631 from rotationsymmetry/SPARK-9642. --- .../ml/regression/LinearRegression.scala | 164 +++++++++++------- .../ml/regression/LinearRegressionSuite.scala | 88 ++++++++++ project/MimaExcludes.scala | 8 +- 3 files changed, 191 insertions(+), 69 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index e4602d36ccc8..78a67c5fdab2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -31,21 +31,29 @@ import org.apache.spark.ml.util.Identifiable import org.apache.spark.mllib.evaluation.RegressionMetrics import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.linalg.BLAS._ -import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Row} -import org.apache.spark.sql.functions.{col, udf} -import org.apache.spark.sql.types.StructField +import org.apache.spark.sql.functions.{col, udf, lit} import org.apache.spark.storage.StorageLevel -import org.apache.spark.util.StatCounter /** * Params for linear regression. */ private[regression] trait LinearRegressionParams extends PredictorParams with HasRegParam with HasElasticNetParam with HasMaxIter with HasTol - with HasFitIntercept with HasStandardization + with HasFitIntercept with HasStandardization with HasWeightCol + +/** + * Class that represents an instance of weighted data point with label and features. + * + * TODO: Refactor this class to proper place. + * + * @param label Label for this data point. + * @param weight The weight of this instance. + * @param features The vector of features for this data point. + */ +private[regression] case class Instance(label: Double, weight: Double, features: Vector) /** * :: Experimental :: @@ -123,30 +131,43 @@ class LinearRegression(override val uid: String) def setTol(value: Double): this.type = set(tol, value) setDefault(tol -> 1E-6) + /** + * Whether to over-/under-sample training instances according to the given weights in weightCol. + * If empty, all instances are treated equally (weight 1.0). + * Default is empty, so all instances have weight one. + * @group setParam + */ + def setWeightCol(value: String): this.type = set(weightCol, value) + setDefault(weightCol -> "") + override protected def train(dataset: DataFrame): LinearRegressionModel = { // Extract columns from data. If dataset is persisted, do not persist instances. 
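// A minimal sketch of how the weight column support introduced in this patch is meant to be
// used (illustrative only; the DataFrame `training` and its "label", "weight" and "features"
// columns are assumed here, not defined by the patch):
val lr = new LinearRegression()
  .setLabelCol("label")
  .setFeaturesCol("features")
  .setWeightCol("weight")   // each instance counts in proportion to this column
  .setRegParam(0.1)
val weightedModel = lr.fit(training)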
- val instances = extractLabeledPoints(dataset).map { - case LabeledPoint(label: Double, features: Vector) => (label, features) + val w = if ($(weightCol).isEmpty) lit(1.0) else col($(weightCol)) + val instances: RDD[Instance] = dataset.select(col($(labelCol)), w, col($(featuresCol))).map { + case Row(label: Double, weight: Double, features: Vector) => + Instance(label, weight, features) } + val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK) - val (summarizer, statCounter) = instances.treeAggregate( - (new MultivariateOnlineSummarizer, new StatCounter))( - seqOp = (c, v) => (c, v) match { - case ((summarizer: MultivariateOnlineSummarizer, statCounter: StatCounter), - (label: Double, features: Vector)) => - (summarizer.add(features), statCounter.merge(label)) - }, - combOp = (c1, c2) => (c1, c2) match { - case ((summarizer1: MultivariateOnlineSummarizer, statCounter1: StatCounter), - (summarizer2: MultivariateOnlineSummarizer, statCounter2: StatCounter)) => - (summarizer1.merge(summarizer2), statCounter1.merge(statCounter2)) - }) - - val numFeatures = summarizer.mean.size - val yMean = statCounter.mean - val yStd = math.sqrt(statCounter.variance) + val (featuresSummarizer, ySummarizer) = { + val seqOp = (c: (MultivariateOnlineSummarizer, MultivariateOnlineSummarizer), + instance: Instance) => + (c._1.add(instance.features, instance.weight), + c._2.add(Vectors.dense(instance.label), instance.weight)) + + val combOp = (c1: (MultivariateOnlineSummarizer, MultivariateOnlineSummarizer), + c2: (MultivariateOnlineSummarizer, MultivariateOnlineSummarizer)) => + (c1._1.merge(c2._1), c1._2.merge(c2._2)) + + instances.treeAggregate( + new MultivariateOnlineSummarizer, new MultivariateOnlineSummarizer)(seqOp, combOp) + } + + val numFeatures = featuresSummarizer.mean.size + val yMean = ySummarizer.mean(0) + val yStd = math.sqrt(ySummarizer.variance(0)) // If the yStd is zero, then the intercept is yMean with zero weights; // as a result, training is not needed. @@ -167,8 +188,8 @@ class LinearRegression(override val uid: String) return copyValues(model.setSummary(trainingSummary)) } - val featuresMean = summarizer.mean.toArray - val featuresStd = summarizer.variance.toArray.map(math.sqrt) + val featuresMean = featuresSummarizer.mean.toArray + val featuresStd = featuresSummarizer.variance.toArray.map(math.sqrt) // Since we implicitly do the feature scaling when we compute the cost function // to improve the convergence, the effective regParam will be changed. @@ -318,7 +339,8 @@ class LinearRegressionModel private[ml] ( /** * :: Experimental :: - * Linear regression training results. + * Linear regression training results. Currently, the training summary ignores the + * training weights except for the objective trace. * @param predictions predictions outputted by the model's `transform` method. * @param objectiveHistory objective function (scaled loss + regularization) at each iteration. */ @@ -477,7 +499,7 @@ class LinearRegressionSummary private[regression] ( * \frac{\partial L}{\partial\w_i} = 1/N ((\sum_j diff_j x_{ij} / \hat{x_i}) * }}}, * - * @param weights The weights/coefficients corresponding to the features. + * @param coefficients The coefficients corresponding to the features. * @param labelStd The standard deviation value of the label. * @param labelMean The mean value of the label. * @param fitIntercept Whether to fit an intercept term. 
@@ -485,7 +507,7 @@ class LinearRegressionSummary private[regression] ( * @param featuresMean The mean values of the features. */ private class LeastSquaresAggregator( - weights: Vector, + coefficients: Vector, labelStd: Double, labelMean: Double, fitIntercept: Boolean, @@ -493,26 +515,28 @@ private class LeastSquaresAggregator( featuresMean: Array[Double]) extends Serializable { private var totalCnt: Long = 0L + private var weightSum: Double = 0.0 private var lossSum = 0.0 - private val (effectiveWeightsArray: Array[Double], offset: Double, dim: Int) = { - val weightsArray = weights.toArray.clone() + private val (effectiveCoefficientsArray: Array[Double], offset: Double, dim: Int) = { + val coefficientsArray = coefficients.toArray.clone() var sum = 0.0 var i = 0 - val len = weightsArray.length + val len = coefficientsArray.length while (i < len) { if (featuresStd(i) != 0.0) { - weightsArray(i) /= featuresStd(i) - sum += weightsArray(i) * featuresMean(i) + coefficientsArray(i) /= featuresStd(i) + sum += coefficientsArray(i) * featuresMean(i) } else { - weightsArray(i) = 0.0 + coefficientsArray(i) = 0.0 } i += 1 } - (weightsArray, if (fitIntercept) labelMean / labelStd - sum else 0.0, weightsArray.length) + val offset = if (fitIntercept) labelMean / labelStd - sum else 0.0 + (coefficientsArray, offset, coefficientsArray.length) } - private val effectiveWeightsVector = Vectors.dense(effectiveWeightsArray) + private val effectiveCoefficientsVector = Vectors.dense(effectiveCoefficientsArray) private val gradientSumArray = Array.ofDim[Double](dim) @@ -520,30 +544,33 @@ private class LeastSquaresAggregator( * Add a new training data to this LeastSquaresAggregator, and update the loss and gradient * of the objective function. * - * @param label The label for this data point. - * @param data The features for one data point in dense/sparse vector format to be added - * into this aggregator. + * @param instance The data point instance to be added. * @return This LeastSquaresAggregator object. */ - def add(label: Double, data: Vector): this.type = { - require(dim == data.size, s"Dimensions mismatch when adding new sample." + - s" Expecting $dim but got ${data.size}.") + def add(instance: Instance): this.type = + instance match { case Instance(label, weight, features) => + require(dim == features.size, s"Dimensions mismatch when adding new sample." 
+ + s" Expecting $dim but got ${features.size}.") + require(weight >= 0.0, s"instance weight, ${weight} has to be >= 0.0") - val diff = dot(data, effectiveWeightsVector) - label / labelStd + offset + if (weight == 0.0) return this - if (diff != 0) { - val localGradientSumArray = gradientSumArray - data.foreachActive { (index, value) => - if (featuresStd(index) != 0.0 && value != 0.0) { - localGradientSumArray(index) += diff * value / featuresStd(index) + val diff = dot(features, effectiveCoefficientsVector) - label / labelStd + offset + + if (diff != 0) { + val localGradientSumArray = gradientSumArray + features.foreachActive { (index, value) => + if (featuresStd(index) != 0.0 && value != 0.0) { + localGradientSumArray(index) += weight * diff * value / featuresStd(index) + } } + lossSum += weight * diff * diff / 2.0 } - lossSum += diff * diff / 2.0 - } - totalCnt += 1 - this - } + totalCnt += 1 + weightSum += weight + this + } /** * Merge another LeastSquaresAggregator, and update the loss and gradient @@ -557,8 +584,9 @@ private class LeastSquaresAggregator( require(dim == other.dim, s"Dimensions mismatch when merging with another " + s"LeastSquaresAggregator. Expecting $dim but got ${other.dim}.") - if (other.totalCnt != 0) { + if (other.weightSum != 0) { totalCnt += other.totalCnt + weightSum += other.weightSum lossSum += other.lossSum var i = 0 @@ -574,11 +602,17 @@ private class LeastSquaresAggregator( def count: Long = totalCnt - def loss: Double = lossSum / totalCnt + def loss: Double = { + require(weightSum > 0.0, s"The effective number of instances should be " + + s"greater than 0.0, but $weightSum.") + lossSum / weightSum + } def gradient: Vector = { + require(weightSum > 0.0, s"The effective number of instances should be " + + s"greater than 0.0, but $weightSum.") val result = Vectors.dense(gradientSumArray.clone()) - scal(1.0 / totalCnt, result) + scal(1.0 / weightSum, result) result } } @@ -589,7 +623,7 @@ private class LeastSquaresAggregator( * It's used in Breeze's convex optimization routines. 
*/ private class LeastSquaresCostFun( - data: RDD[(Double, Vector)], + data: RDD[Instance], labelStd: Double, labelMean: Double, fitIntercept: Boolean, @@ -598,17 +632,13 @@ private class LeastSquaresCostFun( featuresMean: Array[Double], effectiveL2regParam: Double) extends DiffFunction[BDV[Double]] { - override def calculate(weights: BDV[Double]): (Double, BDV[Double]) = { - val w = Vectors.fromBreeze(weights) + override def calculate(coefficients: BDV[Double]): (Double, BDV[Double]) = { + val coeff = Vectors.fromBreeze(coefficients) - val leastSquaresAggregator = data.treeAggregate(new LeastSquaresAggregator(w, labelStd, + val leastSquaresAggregator = data.treeAggregate(new LeastSquaresAggregator(coeff, labelStd, labelMean, fitIntercept, featuresStd, featuresMean))( - seqOp = (c, v) => (c, v) match { - case (aggregator, (label, features)) => aggregator.add(label, features) - }, - combOp = (c1, c2) => (c1, c2) match { - case (aggregator1, aggregator2) => aggregator1.merge(aggregator2) - }) + seqOp = (aggregator, instance) => aggregator.add(instance), + combOp = (aggregator1, aggregator2) => aggregator1.merge(aggregator2)) val totalGradientArray = leastSquaresAggregator.gradient.toArray @@ -616,7 +646,7 @@ private class LeastSquaresCostFun( 0.0 } else { var sum = 0.0 - w.foreachActive { (index, value) => + coeff.foreachActive { (index, value) => // The following code will compute the loss of the regularization; also // the gradient of the regularization, and add back to totalGradientArray. sum += { diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index 2aaee71ecc73..8428f4f00b37 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -17,10 +17,13 @@ package org.apache.spark.ml.regression +import scala.util.Random + import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.MLTestingUtils import org.apache.spark.mllib.linalg.{DenseVector, Vectors} +import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext} import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.{DataFrame, Row} @@ -510,4 +513,89 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { .zip(testSummary.residuals.select("residuals").collect()) .forall { case (Row(r1: Double), Row(r2: Double)) => r1 ~== r2 relTol 1E-5 } } + + test("linear regression with weighted samples"){ + val (data, weightedData) = { + val activeData = LinearDataGenerator.generateLinearInput( + 6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 500, 1, 0.1) + + val rnd = new Random(8392) + val signedData = activeData.map { case p: LabeledPoint => + (rnd.nextGaussian() > 0.0, p) + } + + val data1 = signedData.flatMap { + case (true, p) => Iterator(p, p) + case (false, p) => Iterator(p) + } + + val weightedSignedData = signedData.flatMap { + case (true, LabeledPoint(label, features)) => + Iterator( + Instance(label, weight = 1.2, features), + Instance(label, weight = 0.8, features) + ) + case (false, LabeledPoint(label, features)) => + Iterator( + Instance(label, weight = 0.3, features), + Instance(label, weight = 0.1, features), + Instance(label, weight = 0.6, features) + ) + } + + val noiseData = 
LinearDataGenerator.generateLinearInput( + 2, Array(1, 3), Array(0.9, -1.3), Array(0.7, 1.2), 500, 1, 0.1) + val weightedNoiseData = noiseData.map { + case LabeledPoint(label, features) => Instance(label, weight = 0, features) + } + val data2 = weightedSignedData ++ weightedNoiseData + + (sqlContext.createDataFrame(sc.parallelize(data1, 4)), + sqlContext.createDataFrame(sc.parallelize(data2, 4))) + } + + val trainer1a = (new LinearRegression).setFitIntercept(true) + .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(true) + val trainer1b = (new LinearRegression).setFitIntercept(true).setWeightCol("weight") + .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(true) + val model1a0 = trainer1a.fit(data) + val model1a1 = trainer1a.fit(weightedData) + val model1b = trainer1b.fit(weightedData) + assert(model1a0.weights !~= model1a1.weights absTol 1E-3) + assert(model1a0.intercept !~= model1a1.intercept absTol 1E-3) + assert(model1a0.weights ~== model1b.weights absTol 1E-3) + assert(model1a0.intercept ~== model1b.intercept absTol 1E-3) + + val trainer2a = (new LinearRegression).setFitIntercept(true) + .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(false) + val trainer2b = (new LinearRegression).setFitIntercept(true).setWeightCol("weight") + .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(false) + val model2a0 = trainer2a.fit(data) + val model2a1 = trainer2a.fit(weightedData) + val model2b = trainer2b.fit(weightedData) + assert(model2a0.weights !~= model2a1.weights absTol 1E-3) + assert(model2a0.intercept !~= model2a1.intercept absTol 1E-3) + assert(model2a0.weights ~== model2b.weights absTol 1E-3) + assert(model2a0.intercept ~== model2b.intercept absTol 1E-3) + + val trainer3a = (new LinearRegression).setFitIntercept(false) + .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(true) + val trainer3b = (new LinearRegression).setFitIntercept(false).setWeightCol("weight") + .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(true) + val model3a0 = trainer3a.fit(data) + val model3a1 = trainer3a.fit(weightedData) + val model3b = trainer3b.fit(weightedData) + assert(model3a0.weights !~= model3a1.weights absTol 1E-3) + assert(model3a0.weights ~== model3b.weights absTol 1E-3) + + val trainer4a = (new LinearRegression).setFitIntercept(false) + .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(false) + val trainer4b = (new LinearRegression).setFitIntercept(false).setWeightCol("weight") + .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(false) + val model4a0 = trainer4a.fit(data) + val model4a1 = trainer4a.fit(weightedData) + val model4b = trainer4b.fit(weightedData) + assert(model4a0.weights !~= model4a1.weights absTol 1E-3) + assert(model4a0.weights ~== model4b.weights absTol 1E-3) + } } diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 814a11e588ce..b2e6be706637 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -70,10 +70,14 @@ object MimaExcludes { "org.apache.spark.scheduler.AskPermissionToCommitOutput.this"), ProblemFilters.exclude[IncompatibleMethTypeProblem]( "org.apache.spark.scheduler.AskPermissionToCommitOutput.apply") - ) ++ - Seq( + ) ++ Seq( ProblemFilters.exclude[MissingClassProblem]( "org.apache.spark.shuffle.FileShuffleBlockResolver$ShuffleFileGroup") + ) ++ Seq( + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.ml.regression.LeastSquaresAggregator.add"), + ProblemFilters.exclude[MissingMethodProblem]( + 
"org.apache.spark.ml.regression.LeastSquaresCostFun.this") ) case v if v.startsWith("1.5") => Seq( From b78c65b03ae87a3ba348c9d29ff4c296349eb49c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?hushan=5B=E8=83=A1=E7=8F=8A=5D?= Date: Mon, 21 Sep 2015 14:26:15 -0500 Subject: [PATCH 0700/1824] [SPARK-5259] [CORE] don't submit stage until its dependencies map outputs are registered MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Track pending tasks by partition ID instead of Task objects. Before this change, failure & retry could result in a case where a stage got submitted before the map output from its dependencies get registered. This was due to an error in the condition for registering map outputs. Author: hushan[胡çŠ] Author: Imran Rashid Closes #7699 from squito/SPARK-5259. --- .../apache/spark/scheduler/DAGScheduler.scala | 12 +- .../org/apache/spark/scheduler/Stage.scala | 2 +- .../spark/scheduler/TaskSetManager.scala | 4 +- .../spark/scheduler/DAGSchedulerSuite.scala | 197 ++++++++++++++++-- 4 files changed, 191 insertions(+), 24 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 3c9a66e50440..394228b2728d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -944,7 +944,7 @@ class DAGScheduler( private def submitMissingTasks(stage: Stage, jobId: Int) { logDebug("submitMissingTasks(" + stage + ")") // Get our pending tasks and remember them in our pendingTasks entry - stage.pendingTasks.clear() + stage.pendingPartitions.clear() // First figure out the indexes of partition ids to compute. val (allPartitions: Seq[Int], partitionsToCompute: Seq[Int]) = { @@ -1060,8 +1060,8 @@ class DAGScheduler( if (tasks.size > 0) { logInfo("Submitting " + tasks.size + " missing tasks from " + stage + " (" + stage.rdd + ")") - stage.pendingTasks ++= tasks - logDebug("New pending tasks: " + stage.pendingTasks) + stage.pendingPartitions ++= tasks.map(_.partitionId) + logDebug("New pending partitions: " + stage.pendingPartitions) taskScheduler.submitTasks(new TaskSet( tasks.toArray, stage.id, stage.latestInfo.attemptId, stage.firstJobId, properties)) stage.latestInfo.submissionTime = Some(clock.getTimeMillis()) @@ -1152,7 +1152,7 @@ class DAGScheduler( case Success => listenerBus.post(SparkListenerTaskEnd(stageId, stage.latestInfo.attemptId, taskType, event.reason, event.taskInfo, event.taskMetrics)) - stage.pendingTasks -= task + stage.pendingPartitions -= task.partitionId task match { case rt: ResultTask[_, _] => // Cast to ResultStage here because it's part of the ResultTask @@ -1198,7 +1198,7 @@ class DAGScheduler( shuffleStage.addOutputLoc(smt.partitionId, status) } - if (runningStages.contains(shuffleStage) && shuffleStage.pendingTasks.isEmpty) { + if (runningStages.contains(shuffleStage) && shuffleStage.pendingPartitions.isEmpty) { markStageAsFinished(shuffleStage) logInfo("looking for newly runnable stages") logInfo("running: " + runningStages) @@ -1242,7 +1242,7 @@ class DAGScheduler( case Resubmitted => logInfo("Resubmitted " + task + ", so marking it as still running") - stage.pendingTasks += task + stage.pendingPartitions += task.partitionId case FetchFailed(bmAddress, shuffleId, mapId, reduceId, failureMessage) => val failedStage = stageIdToStage(task.stageId) diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala 
b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala index b37eccbd0f7b..a3829c319c48 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala @@ -66,7 +66,7 @@ private[scheduler] abstract class Stage( /** Set of jobs that this stage belongs to. */ val jobIds = new HashSet[Int] - var pendingTasks = new HashSet[Task[_]] + val pendingPartitions = new HashSet[Int] /** The ID to use for the next new attempt for this stage. */ private var nextAttemptId: Int = 0 diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 62af9031b9f8..c02597c4365c 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -487,8 +487,8 @@ private[spark] class TaskSetManager( // a good proxy to task serialization time. // val timeTaken = clock.getTime() - startTime val taskName = s"task ${info.id} in stage ${taskSet.id}" - logInfo("Starting %s (TID %d, %s, %s, %d bytes)".format( - taskName, taskId, host, taskLocality, serializedTask.limit)) + logInfo(s"Starting $taskName (TID $taskId, $host, partition ${task.partitionId}," + + s"$taskLocality, ${serializedTask.limit} bytes)") sched.dagScheduler.taskStarted(task, info) return Some(new TaskDescription(taskId = taskId, attemptNumber = attemptNum, execId, diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 1c55f90ad9b4..6b5bcf0574de 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -479,8 +479,8 @@ class DAGSchedulerSuite val reduceRdd = new MyRDD(sc, 2, List(shuffleDep)) submit(reduceRdd, Array(0, 1)) complete(taskSets(0), Seq( - (Success, makeMapStatus("hostA", reduceRdd.partitions.size)), - (Success, makeMapStatus("hostB", reduceRdd.partitions.size)))) + (Success, makeMapStatus("hostA", reduceRdd.partitions.length)), + (Success, makeMapStatus("hostB", reduceRdd.partitions.length)))) // the 2nd ResultTask failed complete(taskSets(1), Seq( (Success, 42), @@ -490,7 +490,7 @@ class DAGSchedulerSuite // ask the scheduler to try it again scheduler.resubmitFailedStages() // have the 2nd attempt pass - complete(taskSets(2), Seq((Success, makeMapStatus("hostA", reduceRdd.partitions.size)))) + complete(taskSets(2), Seq((Success, makeMapStatus("hostA", reduceRdd.partitions.length)))) // we can see both result blocks now assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1.host).toSet === HashSet("hostA", "hostB")) @@ -782,8 +782,8 @@ class DAGSchedulerSuite val reduceRdd = new MyRDD(sc, 2, List(shuffleDep)) submit(reduceRdd, Array(0, 1)) complete(taskSets(0), Seq( - (Success, makeMapStatus("hostA", reduceRdd.partitions.size)), - (Success, makeMapStatus("hostB", reduceRdd.partitions.size)))) + (Success, makeMapStatus("hostA", reduceRdd.partitions.length)), + (Success, makeMapStatus("hostB", reduceRdd.partitions.length)))) // The MapOutputTracker should know about both map output locations. assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1.host).toSet === HashSet("hostA", "hostB")) @@ -1035,6 +1035,173 @@ class DAGSchedulerSuite assertDataStructuresEmpty() } + /** + * This test runs a three stage job, with a fetch failure in stage 1. 
but during the retry, we + * have completions from both the first & second attempt of stage 1. So all the map output is + * available before we finish any task set for stage 1. We want to make sure that we don't + * submit stage 2 until the map output for stage 1 is registered + */ + test("don't submit stage until its dependencies map outputs are registered (SPARK-5259)") { + val firstRDD = new MyRDD(sc, 3, Nil) + val firstShuffleDep = new ShuffleDependency(firstRDD, null) + val firstShuffleId = firstShuffleDep.shuffleId + val shuffleMapRdd = new MyRDD(sc, 3, List(firstShuffleDep)) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, null) + val reduceRdd = new MyRDD(sc, 1, List(shuffleDep)) + submit(reduceRdd, Array(0)) + + // things start out smoothly, stage 0 completes with no issues + complete(taskSets(0), Seq( + (Success, makeMapStatus("hostB", shuffleMapRdd.partitions.length)), + (Success, makeMapStatus("hostB", shuffleMapRdd.partitions.length)), + (Success, makeMapStatus("hostA", shuffleMapRdd.partitions.length)) + )) + + // then one executor dies, and a task fails in stage 1 + runEvent(ExecutorLost("exec-hostA")) + runEvent(CompletionEvent( + taskSets(1).tasks(0), + FetchFailed(null, firstShuffleId, 2, 0, "Fetch failed"), + null, + null, + createFakeTaskInfo(), + null)) + + // so we resubmit stage 0, which completes happily + scheduler.resubmitFailedStages() + val stage0Resubmit = taskSets(2) + assert(stage0Resubmit.stageId == 0) + assert(stage0Resubmit.stageAttemptId === 1) + val task = stage0Resubmit.tasks(0) + assert(task.partitionId === 2) + runEvent(CompletionEvent( + task, + Success, + makeMapStatus("hostC", shuffleMapRdd.partitions.length), + null, + createFakeTaskInfo(), + null)) + + // now here is where things get tricky : we will now have a task set representing + // the second attempt for stage 1, but we *also* have some tasks for the first attempt for + // stage 1 still going + val stage1Resubmit = taskSets(3) + assert(stage1Resubmit.stageId == 1) + assert(stage1Resubmit.stageAttemptId === 1) + assert(stage1Resubmit.tasks.length === 3) + + // we'll have some tasks finish from the first attempt, and some finish from the second attempt, + // so that we actually have all stage outputs, though no attempt has completed all its + // tasks + runEvent(CompletionEvent( + taskSets(3).tasks(0), + Success, + makeMapStatus("hostC", reduceRdd.partitions.length), + null, + createFakeTaskInfo(), + null)) + runEvent(CompletionEvent( + taskSets(3).tasks(1), + Success, + makeMapStatus("hostC", reduceRdd.partitions.length), + null, + createFakeTaskInfo(), + null)) + // late task finish from the first attempt + runEvent(CompletionEvent( + taskSets(1).tasks(2), + Success, + makeMapStatus("hostB", reduceRdd.partitions.length), + null, + createFakeTaskInfo(), + null)) + + // What should happen now is that we submit stage 2. However, we might not see an error + // b/c of DAGScheduler's error handling (it tends to swallow errors and just log them). But + // we can check some conditions. + // Note that the really important thing here is not so much that we submit stage 2 *immediately* + // but that we don't end up with some error from these interleaved completions. 
It would also + // be OK (though sub-optimal) if stage 2 simply waited until the resubmission of stage 1 had + // all its tasks complete + + // check that we have all the map output for stage 0 (it should have been there even before + // the last round of completions from stage 1, but just to double check it hasn't been messed + // up) and also the newly available stage 1 + val stageToReduceIdxs = Seq( + 0 -> (0 until 3), + 1 -> (0 until 1) + ) + for { + (stage, reduceIdxs) <- stageToReduceIdxs + reduceIdx <- reduceIdxs + } { + // this would throw an exception if the map status hadn't been registered + val statuses = mapOutputTracker.getMapSizesByExecutorId(stage, reduceIdx) + // really we should have already thrown an exception rather than fail either of these + // asserts, but just to be extra defensive let's double check the statuses are OK + assert(statuses != null) + assert(statuses.nonEmpty) + } + + // and check that stage 2 has been submitted + assert(taskSets.size == 5) + val stage2TaskSet = taskSets(4) + assert(stage2TaskSet.stageId == 2) + assert(stage2TaskSet.stageAttemptId == 0) + } + + /** + * We lose an executor after completing some shuffle map tasks on it. Those tasks get + * resubmitted, and when they finish the job completes normally + */ + test("register map outputs correctly after ExecutorLost and task Resubmitted") { + val firstRDD = new MyRDD(sc, 3, Nil) + val firstShuffleDep = new ShuffleDependency(firstRDD, null) + val reduceRdd = new MyRDD(sc, 5, List(firstShuffleDep)) + submit(reduceRdd, Array(0)) + + // complete some of the tasks from the first stage, on one host + runEvent(CompletionEvent( + taskSets(0).tasks(0), Success, + makeMapStatus("hostA", reduceRdd.partitions.length), null, createFakeTaskInfo(), null)) + runEvent(CompletionEvent( + taskSets(0).tasks(1), Success, + makeMapStatus("hostA", reduceRdd.partitions.length), null, createFakeTaskInfo(), null)) + + // now that host goes down + runEvent(ExecutorLost("exec-hostA")) + + // so we resubmit those tasks + runEvent(CompletionEvent( + taskSets(0).tasks(0), Resubmitted, null, null, createFakeTaskInfo(), null)) + runEvent(CompletionEvent( + taskSets(0).tasks(1), Resubmitted, null, null, createFakeTaskInfo(), null)) + + // now complete everything on a different host + complete(taskSets(0), Seq( + (Success, makeMapStatus("hostB", reduceRdd.partitions.length)), + (Success, makeMapStatus("hostB", reduceRdd.partitions.length)), + (Success, makeMapStatus("hostB", reduceRdd.partitions.length)) + )) + + // now we should submit stage 1, and the map output from stage 0 should be registered + + // check that we have all the map output for stage 0 + (0 until reduceRdd.partitions.length).foreach { reduceIdx => + val statuses = mapOutputTracker.getMapSizesByExecutorId(0, reduceIdx) + // really we should have already thrown an exception rather than fail either of these + // asserts, but just to be extra defensive let's double check the statuses are OK + assert(statuses != null) + assert(statuses.nonEmpty) + } + + // and check that stage 1 has been submitted + assert(taskSets.size == 2) + val stage1TaskSet = taskSets(1) + assert(stage1TaskSet.stageId == 1) + assert(stage1TaskSet.stageAttemptId == 0) + } + /** * Makes sure that failures of stage used by multiple jobs are correctly handled. 
* @@ -1393,8 +1560,8 @@ class DAGSchedulerSuite // Submit a map stage by itself submitMapStage(shuffleDep) complete(taskSets(0), Seq( - (Success, makeMapStatus("hostA", reduceRdd.partitions.size)), - (Success, makeMapStatus("hostB", reduceRdd.partitions.size)))) + (Success, makeMapStatus("hostA", reduceRdd.partitions.length)), + (Success, makeMapStatus("hostB", reduceRdd.partitions.length)))) assert(results.size === 1) results.clear() assertDataStructuresEmpty() @@ -1407,7 +1574,7 @@ class DAGSchedulerSuite // Ask the scheduler to try it again; TaskSet 2 will rerun the map task that we couldn't fetch // from, then TaskSet 3 will run the reduce stage scheduler.resubmitFailedStages() - complete(taskSets(2), Seq((Success, makeMapStatus("hostA", reduceRdd.partitions.size)))) + complete(taskSets(2), Seq((Success, makeMapStatus("hostA", reduceRdd.partitions.length)))) complete(taskSets(3), Seq((Success, 43))) assert(results === Map(0 -> 42, 1 -> 43)) results.clear() @@ -1452,8 +1619,8 @@ class DAGSchedulerSuite // Complete the first stage assert(taskSets(0).stageId === 0) complete(taskSets(0), Seq( - (Success, makeMapStatus("hostA", rdd1.partitions.size)), - (Success, makeMapStatus("hostB", rdd1.partitions.size)))) + (Success, makeMapStatus("hostA", rdd1.partitions.length)), + (Success, makeMapStatus("hostB", rdd1.partitions.length)))) assert(mapOutputTracker.getMapSizesByExecutorId(dep1.shuffleId, 0).map(_._1).toSet === HashSet(makeBlockManagerId("hostA"), makeBlockManagerId("hostB"))) assert(listener1.results.size === 1) @@ -1461,7 +1628,7 @@ class DAGSchedulerSuite // When attempting the second stage, show a fetch failure assert(taskSets(1).stageId === 1) complete(taskSets(1), Seq( - (Success, makeMapStatus("hostA", rdd2.partitions.size)), + (Success, makeMapStatus("hostA", rdd2.partitions.length)), (FetchFailed(makeBlockManagerId("hostA"), dep1.shuffleId, 0, 0, "ignored"), null))) scheduler.resubmitFailedStages() assert(listener2.results.size === 0) // Second stage listener should not have a result yet @@ -1469,7 +1636,7 @@ class DAGSchedulerSuite // Stage 0 should now be running as task set 2; make its task succeed assert(taskSets(2).stageId === 0) complete(taskSets(2), Seq( - (Success, makeMapStatus("hostC", rdd2.partitions.size)))) + (Success, makeMapStatus("hostC", rdd2.partitions.length)))) assert(mapOutputTracker.getMapSizesByExecutorId(dep1.shuffleId, 0).map(_._1).toSet === HashSet(makeBlockManagerId("hostC"), makeBlockManagerId("hostB"))) assert(listener2.results.size === 0) // Second stage listener should still not have a result @@ -1477,8 +1644,8 @@ class DAGSchedulerSuite // Stage 1 should now be running as task set 3; make its first task succeed assert(taskSets(3).stageId === 1) complete(taskSets(3), Seq( - (Success, makeMapStatus("hostB", rdd2.partitions.size)), - (Success, makeMapStatus("hostD", rdd2.partitions.size)))) + (Success, makeMapStatus("hostB", rdd2.partitions.length)), + (Success, makeMapStatus("hostD", rdd2.partitions.length)))) assert(mapOutputTracker.getMapSizesByExecutorId(dep2.shuffleId, 0).map(_._1).toSet === HashSet(makeBlockManagerId("hostB"), makeBlockManagerId("hostD"))) assert(listener2.results.size === 1) @@ -1494,7 +1661,7 @@ class DAGSchedulerSuite // TaskSet 5 will rerun stage 1's lost task, then TaskSet 6 will rerun stage 2 assert(taskSets(5).stageId === 1) complete(taskSets(5), Seq( - (Success, makeMapStatus("hostE", rdd2.partitions.size)))) + (Success, makeMapStatus("hostE", rdd2.partitions.length)))) complete(taskSets(6), Seq( (Success, 53))) 
assert(listener3.results === Map(0 -> 52, 1 -> 53)) From ba882db6f43dd2bc05675133158e4664ed07030a Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Mon, 21 Sep 2015 13:06:23 -0700 Subject: [PATCH 0701/1824] [SPARK-9769] [ML] [PY] add python api for countvectorizermodel From JIRA: Add Python API, user guide and example for ml.feature.CountVectorizerModel Author: Holden Karau Closes #8561 from holdenk/SPARK-9769-add-python-api-for-countvectorizermodel. --- python/pyspark/ml/feature.py | 148 +++++++++++++++++++++++++++++++++-- 1 file changed, 142 insertions(+), 6 deletions(-) diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 92db8df80280..f41d72f87725 100644 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -26,12 +26,13 @@ from pyspark.mllib.common import inherit_doc from pyspark.mllib.linalg import _convert_to_vector -__all__ = ['Binarizer', 'Bucketizer', 'DCT', 'ElementwiseProduct', 'HashingTF', 'IDF', 'IDFModel', - 'IndexToString', 'MinMaxScaler', 'MinMaxScalerModel', 'NGram', 'Normalizer', - 'OneHotEncoder', 'PCA', 'PCAModel', 'PolynomialExpansion', 'RegexTokenizer', - 'RFormula', 'RFormulaModel', 'SQLTransformer', 'StandardScaler', 'StandardScalerModel', - 'StopWordsRemover', 'StringIndexer', 'StringIndexerModel', 'Tokenizer', - 'VectorAssembler', 'VectorIndexer', 'VectorSlicer', 'Word2Vec', 'Word2VecModel'] +__all__ = ['Binarizer', 'Bucketizer', 'CountVectorizer', 'CountVectorizerModel', 'DCT', + 'ElementwiseProduct', 'HashingTF', 'IDF', 'IDFModel', 'IndexToString', 'MinMaxScaler', + 'MinMaxScalerModel', 'NGram', 'Normalizer', 'OneHotEncoder', 'PCA', 'PCAModel', + 'PolynomialExpansion', 'RegexTokenizer', 'RFormula', 'RFormulaModel', 'SQLTransformer', + 'StandardScaler', 'StandardScalerModel', 'StopWordsRemover', 'StringIndexer', + 'StringIndexerModel', 'Tokenizer', 'VectorAssembler', 'VectorIndexer', 'VectorSlicer', + 'Word2Vec', 'Word2VecModel'] @inherit_doc @@ -171,6 +172,141 @@ def getSplits(self): return self.getOrDefault(self.splits) +@inherit_doc +class CountVectorizer(JavaEstimator, HasInputCol, HasOutputCol): + """ + .. note:: Experimental + + Extracts a vocabulary from document collections and generates a :py:attr:`CountVectorizerModel`. + >>> df = sqlContext.createDataFrame( + ... [(0, ["a", "b", "c"]), (1, ["a", "b", "b", "c", "a"])], + ... ["label", "raw"]) + >>> cv = CountVectorizer(inputCol="raw", outputCol="vectors") + >>> model = cv.fit(df) + >>> model.transform(df).show(truncate=False) + +-----+---------------+-------------------------+ + |label|raw |vectors | + +-----+---------------+-------------------------+ + |0 |[a, b, c] |(3,[0,1,2],[1.0,1.0,1.0])| + |1 |[a, b, b, c, a]|(3,[0,1,2],[2.0,2.0,1.0])| + +-----+---------------+-------------------------+ + ... + >>> sorted(map(str, model.vocabulary)) + ['a', 'b', 'c'] + """ + + # a placeholder to make it appear in the generated doc + minTF = Param( + Params._dummy(), "minTF", "Filter to ignore rare words in" + + " a document. For each document, terms with frequency/count less than the given" + + " threshold are ignored. If this is an integer >= 1, then this specifies a count (of" + + " times the term must appear in the document); if this is a double in [0,1), then this " + + "specifies a fraction (out of the document's token count). Note that the parameter is " + + "only used in transform of CountVectorizerModel and does not affect fitting. 
Default 1.0") + minDF = Param( + Params._dummy(), "minDF", "Specifies the minimum number of" + + " different documents a term must appear in to be included in the vocabulary." + + " If this is an integer >= 1, this specifies the number of documents the term must" + + " appear in; if this is a double in [0,1), then this specifies the fraction of documents." + + " Default 1.0") + vocabSize = Param( + Params._dummy(), "vocabSize", "max size of the vocabulary. Default 1 << 18.") + + @keyword_only + def __init__(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, inputCol=None, outputCol=None): + """ + __init__(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, inputCol=None, outputCol=None) + """ + super(CountVectorizer, self).__init__() + self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.CountVectorizer", + self.uid) + self.minTF = Param( + self, "minTF", "Filter to ignore rare words in" + + " a document. For each document, terms with frequency/count less than the given" + + " threshold are ignored. If this is an integer >= 1, then this specifies a count (of" + + " times the term must appear in the document); if this is a double in [0,1), then " + + "this specifies a fraction (out of the document's token count). Note that the " + + "parameter is only used in transform of CountVectorizerModel and does not affect" + + "fitting. Default 1.0") + self.minDF = Param( + self, "minDF", "Specifies the minimum number of" + + " different documents a term must appear in to be included in the vocabulary." + + " If this is an integer >= 1, this specifies the number of documents the term must" + + " appear in; if this is a double in [0,1), then this specifies the fraction of " + + "documents. Default 1.0") + self.vocabSize = Param( + self, "vocabSize", "max size of the vocabulary. Default 1 << 18.") + self._setDefault(minTF=1.0, minDF=1.0, vocabSize=1 << 18) + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + + @keyword_only + def setParams(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, inputCol=None, outputCol=None): + """ + setParams(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, inputCol=None, outputCol=None) + Set the params for the CountVectorizer + """ + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) + + def setMinTF(self, value): + """ + Sets the value of :py:attr:`minTF`. + """ + self._paramMap[self.minTF] = value + return self + + def getMinTF(self): + """ + Gets the value of minTF or its default value. + """ + return self.getOrDefault(self.minTF) + + def setMinDF(self, value): + """ + Sets the value of :py:attr:`minDF`. + """ + self._paramMap[self.minDF] = value + return self + + def getMinDF(self): + """ + Gets the value of minDF or its default value. + """ + return self.getOrDefault(self.minDF) + + def setVocabSize(self, value): + """ + Sets the value of :py:attr:`vocabSize`. + """ + self._paramMap[self.vocabSize] = value + return self + + def getVocabSize(self): + """ + Gets the value of vocabSize or its default value. + """ + return self.getOrDefault(self.vocabSize) + + def _create_model(self, java_model): + return CountVectorizerModel(java_model) + + +class CountVectorizerModel(JavaModel): + """ + .. note:: Experimental + + Model fitted by CountVectorizer. + """ + + @property + def vocabulary(self): + """ + An array of terms in the vocabulary. 
+ """ + return self._call_java("vocabulary") + + @inherit_doc class DCT(JavaTransformer, HasInputCol, HasOutputCol): """ From aeef44a3e32b53f7adecc8e9cfd684fb4598e87d Mon Sep 17 00:00:00 2001 From: Feynman Liang Date: Mon, 21 Sep 2015 13:11:28 -0700 Subject: [PATCH 0702/1824] [SPARK-3147] [MLLIB] [STREAMING] Streaming 2-sample statistical significance testing Implementation of significance testing using Streaming API. Author: Feynman Liang Author: Feynman Liang Closes #4716 from feynmanliang/ab_testing. --- .../examples/mllib/StreamingTestExample.scala | 90 +++++++ .../spark/mllib/stat/test/StreamingTest.scala | 145 +++++++++++ .../mllib/stat/test/StreamingTestMethod.scala | 167 ++++++++++++ .../spark/mllib/stat/test/TestResult.scala | 22 ++ .../spark/mllib/stat/StreamingTestSuite.scala | 243 ++++++++++++++++++ 5 files changed, 667 insertions(+) create mode 100644 examples/src/main/scala/org/apache/spark/examples/mllib/StreamingTestExample.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTest.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTestMethod.scala create mode 100644 mllib/src/test/scala/org/apache/spark/mllib/stat/StreamingTestSuite.scala diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingTestExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingTestExample.scala new file mode 100644 index 000000000000..ab29f90254d3 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingTestExample.scala @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib + +import org.apache.spark.SparkConf +import org.apache.spark.mllib.stat.test.StreamingTest +import org.apache.spark.streaming.{Seconds, StreamingContext} +import org.apache.spark.util.Utils + +/** + * Perform streaming testing using Welch's 2-sample t-test on a stream of data, where the data + * stream arrives as text files in a directory. Stops when the two groups are statistically + * significant (p-value < 0.05) or after a user-specified timeout in number of batches is exceeded. + * + * The rows of the text files must be in the form `Boolean, Double`. 
For example: + * false, -3.92 + * true, 99.32 + * + * Usage: + * StreamingTestExample + * + * To run on your local machine using the directory `dataDir` with 5 seconds between each batch and + * a timeout after 100 insignificant batches, call: + * $ bin/run-example mllib.StreamingTestExample dataDir 5 100 + * + * As you add text files to `dataDir` the significance test wil continually update every + * `batchDuration` seconds until the test becomes significant (p-value < 0.05) or the number of + * batches processed exceeds `numBatchesTimeout`. + */ +object StreamingTestExample { + + def main(args: Array[String]) { + if (args.length != 3) { + // scalastyle:off println + System.err.println( + "Usage: StreamingTestExample " + + " ") + // scalastyle:on println + System.exit(1) + } + val dataDir = args(0) + val batchDuration = Seconds(args(1).toLong) + val numBatchesTimeout = args(2).toInt + + val conf = new SparkConf().setMaster("local").setAppName("StreamingTestExample") + val ssc = new StreamingContext(conf, batchDuration) + ssc.checkpoint({ + val dir = Utils.createTempDir() + dir.toString + }) + + val data = ssc.textFileStream(dataDir).map(line => line.split(",") match { + case Array(label, value) => (label.toBoolean, value.toDouble) + }) + + val streamingTest = new StreamingTest() + .setPeacePeriod(0) + .setWindowSize(0) + .setTestMethod("welch") + + val out = streamingTest.registerStream(data) + out.print() + + // Stop processing if test becomes significant or we time out + var timeoutCounter = numBatchesTimeout + out.foreachRDD { rdd => + timeoutCounter -= 1 + val anySignificant = rdd.map(_.pValue < 0.05).fold(false)(_ || _) + if (timeoutCounter == 0 || anySignificant) rdd.context.stop() + } + + ssc.start() + ssc.awaitTermination() + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTest.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTest.scala new file mode 100644 index 000000000000..75c6a51d0957 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTest.scala @@ -0,0 +1,145 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.stat.test + +import org.apache.spark.Logging +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.rdd.RDD +import org.apache.spark.streaming.dstream.DStream +import org.apache.spark.util.StatCounter + +/** + * :: Experimental :: + * Performs online 2-sample significance testing for a stream of (Boolean, Double) pairs. The + * Boolean identifies which sample each observation comes from, and the Double is the numeric value + * of the observation. 
+ * + * To address novelty affects, the `peacePeriod` specifies a set number of initial + * [[org.apache.spark.rdd.RDD]] batches of the [[DStream]] to be dropped from significance testing. + * + * The `windowSize` sets the number of batches each significance test is to be performed over. The + * window is sliding with a stride length of 1 batch. Setting windowSize to 0 will perform + * cumulative processing, using all batches seen so far. + * + * Different tests may be used for assessing statistical significance depending on assumptions + * satisfied by data. For more details, see [[StreamingTestMethod]]. The `testMethod` specifies + * which test will be used. + * + * Use a builder pattern to construct a streaming test in an application, for example: + * {{{ + * val model = new StreamingTest() + * .setPeacePeriod(10) + * .setWindowSize(0) + * .setTestMethod("welch") + * .registerStream(DStream) + * }}} + */ +@Experimental +@Since("1.6.0") +class StreamingTest @Since("1.6.0") () extends Logging with Serializable { + private var peacePeriod: Int = 0 + private var windowSize: Int = 0 + private var testMethod: StreamingTestMethod = WelchTTest + + /** Set the number of initial batches to ignore. Default: 0. */ + @Since("1.6.0") + def setPeacePeriod(peacePeriod: Int): this.type = { + this.peacePeriod = peacePeriod + this + } + + /** + * Set the number of batches to compute significance tests over. Default: 0. + * A value of 0 will use all batches seen so far. + */ + @Since("1.6.0") + def setWindowSize(windowSize: Int): this.type = { + this.windowSize = windowSize + this + } + + /** Set the statistical method used for significance testing. Default: "welch" */ + @Since("1.6.0") + def setTestMethod(method: String): this.type = { + this.testMethod = StreamingTestMethod.getTestMethodFromName(method) + this + } + + /** + * Register a [[DStream]] of values for significance testing. + * + * @param data stream of (key,value) pairs where the key denotes group membership (true = + * experiment, false = control) and the value is the numerical metric to test for + * significance + * @return stream of significance testing results + */ + @Since("1.6.0") + def registerStream(data: DStream[(Boolean, Double)]): DStream[StreamingTestResult] = { + val dataAfterPeacePeriod = dropPeacePeriod(data) + val summarizedData = summarizeByKeyAndWindow(dataAfterPeacePeriod) + val pairedSummaries = pairSummaries(summarizedData) + + testMethod.doTest(pairedSummaries) + } + + /** Drop all batches inside the peace period. */ + private[stat] def dropPeacePeriod( + data: DStream[(Boolean, Double)]): DStream[(Boolean, Double)] = { + data.transform { (rdd, time) => + if (time.milliseconds > data.slideDuration.milliseconds * peacePeriod) { + rdd + } else { + data.context.sparkContext.parallelize(Seq()) + } + } + } + + /** Compute summary statistics over each key and the specified test window size. 
*/ + private[stat] def summarizeByKeyAndWindow( + data: DStream[(Boolean, Double)]): DStream[(Boolean, StatCounter)] = { + if (this.windowSize == 0) { + data.updateStateByKey[StatCounter]( + (newValues: Seq[Double], oldSummary: Option[StatCounter]) => { + val newSummary = oldSummary.getOrElse(new StatCounter()) + newSummary.merge(newValues) + Some(newSummary) + }) + } else { + val windowDuration = data.slideDuration * this.windowSize + data + .groupByKeyAndWindow(windowDuration) + .mapValues { values => + val summary = new StatCounter() + values.foreach(value => summary.merge(value)) + summary + } + } + } + + /** + * Transform a stream of summaries into pairs representing summary statistics for control group + * and experiment group up to this batch. + */ + private[stat] def pairSummaries(summarizedData: DStream[(Boolean, StatCounter)]) + : DStream[(StatCounter, StatCounter)] = { + summarizedData + .map[(Int, StatCounter)](x => (0, x._2)) + .groupByKey() // should be length two (control/experiment group) + .map(x => (x._2.head, x._2.last)) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTestMethod.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTestMethod.scala new file mode 100644 index 000000000000..a7eaed51b4d5 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTestMethod.scala @@ -0,0 +1,167 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.stat.test + +import java.io.Serializable + +import scala.language.implicitConversions +import scala.math.pow + +import com.twitter.chill.MeatLocker +import org.apache.commons.math3.stat.descriptive.StatisticalSummaryValues +import org.apache.commons.math3.stat.inference.TTest + +import org.apache.spark.Logging +import org.apache.spark.streaming.dstream.DStream +import org.apache.spark.util.StatCounter + +/** + * Significance testing methods for [[StreamingTest]]. New 2-sample statistical significance tests + * should extend [[StreamingTestMethod]] and introduce a new entry in + * [[StreamingTestMethod.TEST_NAME_TO_OBJECT]] + */ +private[stat] sealed trait StreamingTestMethod extends Serializable { + + val methodName: String + val nullHypothesis: String + + protected type SummaryPairStream = + DStream[(StatCounter, StatCounter)] + + /** + * Perform streaming 2-sample statistical significance testing. + * + * @param sampleSummaries stream pairs of summary statistics for the 2 samples + * @return stream of rest results + */ + def doTest(sampleSummaries: SummaryPairStream): DStream[StreamingTestResult] + + /** + * Implicit adapter to convert between streaming summary statistics type and the type required by + * the t-testing libraries. 
+ */ + protected implicit def toApacheCommonsStats( + summaryStats: StatCounter): StatisticalSummaryValues = { + new StatisticalSummaryValues( + summaryStats.mean, + summaryStats.variance, + summaryStats.count, + summaryStats.max, + summaryStats.min, + summaryStats.mean * summaryStats.count + ) + } +} + +/** + * Performs Welch's 2-sample t-test. The null hypothesis is that the two data sets have equal mean. + * This test does not assume equal variance between the two samples and does not assume equal + * sample size. + * + * @see http://en.wikipedia.org/wiki/Welch%27s_t_test + */ +private[stat] object WelchTTest extends StreamingTestMethod with Logging { + + override final val methodName = "Welch's 2-sample t-test" + override final val nullHypothesis = "Both groups have same mean" + + private final val tTester = MeatLocker(new TTest()) + + override def doTest(data: SummaryPairStream): DStream[StreamingTestResult] = + data.map[StreamingTestResult]((test _).tupled) + + private def test( + statsA: StatCounter, + statsB: StatCounter): StreamingTestResult = { + def welchDF(sample1: StatisticalSummaryValues, sample2: StatisticalSummaryValues): Double = { + val s1 = sample1.getVariance + val n1 = sample1.getN + val s2 = sample2.getVariance + val n2 = sample2.getN + + val a = pow(s1, 2) / n1 + val b = pow(s2, 2) / n2 + + pow(a + b, 2) / ((pow(a, 2) / (n1 - 1)) + (pow(b, 2) / (n2 - 1))) + } + + new StreamingTestResult( + tTester.get.tTest(statsA, statsB), + welchDF(statsA, statsB), + tTester.get.t(statsA, statsB), + methodName, + nullHypothesis + ) + } +} + +/** + * Performs Students's 2-sample t-test. The null hypothesis is that the two data sets have equal + * mean. This test assumes equal variance between the two samples and does not assume equal sample + * size. For unequal variances, Welch's t-test should be used instead. + * + * @see http://en.wikipedia.org/wiki/Student%27s_t-test + */ +private[stat] object StudentTTest extends StreamingTestMethod with Logging { + + override final val methodName = "Student's 2-sample t-test" + override final val nullHypothesis = "Both groups have same mean" + + private final val tTester = MeatLocker(new TTest()) + + override def doTest(data: SummaryPairStream): DStream[StreamingTestResult] = + data.map[StreamingTestResult]((test _).tupled) + + private def test( + statsA: StatCounter, + statsB: StatCounter): StreamingTestResult = { + def studentDF(sample1: StatisticalSummaryValues, sample2: StatisticalSummaryValues): Double = + sample1.getN + sample2.getN - 2 + + new StreamingTestResult( + tTester.get.homoscedasticTTest(statsA, statsB), + studentDF(statsA, statsB), + tTester.get.homoscedasticT(statsA, statsB), + methodName, + nullHypothesis + ) + } +} + +/** + * Companion object holding supported [[StreamingTestMethod]] names and handles conversion between + * strings used in [[StreamingTest]] configuration and actual method implementation. + * + * Currently supported tests: `welch`, `student`. + */ +private[stat] object StreamingTestMethod { + // Note: after new `StreamingTestMethod`s are implemented, please update this map. + private final val TEST_NAME_TO_OBJECT: Map[String, StreamingTestMethod] = Map( + "welch"->WelchTTest, + "student"->StudentTTest) + + def getTestMethodFromName(method: String): StreamingTestMethod = + TEST_NAME_TO_OBJECT.get(method) match { + case Some(test) => test + case None => + throw new IllegalArgumentException( + "Unrecognized method name. 
Supported streaming test methods: " + + TEST_NAME_TO_OBJECT.keys.mkString(", ")) + } +} + diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/TestResult.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/TestResult.scala index d01b3707be94..b0916d3e8465 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/TestResult.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/TestResult.scala @@ -115,3 +115,25 @@ class KolmogorovSmirnovTestResult private[stat] ( "Kolmogorov-Smirnov test summary:\n" + super.toString } } + +/** + * :: Experimental :: + * Object containing the test results for streaming testing. + */ +@Experimental +@Since("1.6.0") +private[stat] class StreamingTestResult @Since("1.6.0") ( + @Since("1.6.0") override val pValue: Double, + @Since("1.6.0") override val degreesOfFreedom: Double, + @Since("1.6.0") override val statistic: Double, + @Since("1.6.0") val method: String, + @Since("1.6.0") override val nullHypothesis: String) + extends TestResult[Double] with Serializable { + + override def toString: String = { + "Streaming test summary:\n" + + s"method: $method\n" + + super.toString + } +} + diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/StreamingTestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/StreamingTestSuite.scala new file mode 100644 index 000000000000..d3e9ef4ff079 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/StreamingTestSuite.scala @@ -0,0 +1,243 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.mllib.stat + +import org.apache.spark.SparkFunSuite +import org.apache.spark.mllib.stat.test.{StreamingTest, StreamingTestResult, StudentTTest, WelchTTest} +import org.apache.spark.streaming.TestSuiteBase +import org.apache.spark.streaming.dstream.DStream +import org.apache.spark.util.StatCounter +import org.apache.spark.util.random.XORShiftRandom + +class StreamingTestSuite extends SparkFunSuite with TestSuiteBase { + + override def maxWaitTimeMillis : Int = 30000 + + test("accuracy for null hypothesis using welch t-test") { + // set parameters + val testMethod = "welch" + val numBatches = 2 + val pointsPerBatch = 1000 + val meanA = 0 + val stdevA = 0.001 + val meanB = 0 + val stdevB = 0.001 + + val model = new StreamingTest() + .setWindowSize(0) + .setPeacePeriod(0) + .setTestMethod(testMethod) + + val input = generateTestData( + numBatches, pointsPerBatch, meanA, stdevA, meanB, stdevB, 42) + + // setup and run the model + val ssc = setupStreams( + input, (inputDStream: DStream[(Boolean, Double)]) => model.registerStream(inputDStream)) + val outputBatches = runStreams[StreamingTestResult](ssc, numBatches, numBatches) + + assert(outputBatches.flatten.forall(res => + res.pValue > 0.05 && res.method == WelchTTest.methodName)) + } + + test("accuracy for alternative hypothesis using welch t-test") { + // set parameters + val testMethod = "welch" + val numBatches = 2 + val pointsPerBatch = 1000 + val meanA = -10 + val stdevA = 1 + val meanB = 10 + val stdevB = 1 + + val model = new StreamingTest() + .setWindowSize(0) + .setPeacePeriod(0) + .setTestMethod(testMethod) + + val input = generateTestData( + numBatches, pointsPerBatch, meanA, stdevA, meanB, stdevB, 42) + + // setup and run the model + val ssc = setupStreams( + input, (inputDStream: DStream[(Boolean, Double)]) => model.registerStream(inputDStream)) + val outputBatches = runStreams[StreamingTestResult](ssc, numBatches, numBatches) + + assert(outputBatches.flatten.forall(res => + res.pValue < 0.05 && res.method == WelchTTest.methodName)) + } + + test("accuracy for null hypothesis using student t-test") { + // set parameters + val testMethod = "student" + val numBatches = 2 + val pointsPerBatch = 1000 + val meanA = 0 + val stdevA = 0.001 + val meanB = 0 + val stdevB = 0.001 + + val model = new StreamingTest() + .setWindowSize(0) + .setPeacePeriod(0) + .setTestMethod(testMethod) + + val input = generateTestData( + numBatches, pointsPerBatch, meanA, stdevA, meanB, stdevB, 42) + + // setup and run the model + val ssc = setupStreams( + input, (inputDStream: DStream[(Boolean, Double)]) => model.registerStream(inputDStream)) + val outputBatches = runStreams[StreamingTestResult](ssc, numBatches, numBatches) + + + assert(outputBatches.flatten.forall(res => + res.pValue > 0.05 && res.method == StudentTTest.methodName)) + } + + test("accuracy for alternative hypothesis using student t-test") { + // set parameters + val testMethod = "student" + val numBatches = 2 + val pointsPerBatch = 1000 + val meanA = -10 + val stdevA = 1 + val meanB = 10 + val stdevB = 1 + + val model = new StreamingTest() + .setWindowSize(0) + .setPeacePeriod(0) + .setTestMethod(testMethod) + + val input = generateTestData( + numBatches, pointsPerBatch, meanA, stdevA, meanB, stdevB, 42) + + // setup and run the model + val ssc = setupStreams( + input, (inputDStream: DStream[(Boolean, Double)]) => model.registerStream(inputDStream)) + val outputBatches = runStreams[StreamingTestResult](ssc, numBatches, numBatches) + + 
assert(outputBatches.flatten.forall(res => + res.pValue < 0.05 && res.method == StudentTTest.methodName)) + } + + test("batches within same test window are grouped") { + // set parameters + val testWindow = 3 + val numBatches = 5 + val pointsPerBatch = 100 + val meanA = -10 + val stdevA = 1 + val meanB = 10 + val stdevB = 1 + + val model = new StreamingTest() + .setWindowSize(testWindow) + .setPeacePeriod(0) + + val input = generateTestData( + numBatches, pointsPerBatch, meanA, stdevA, meanB, stdevB, 42) + + // setup and run the model + val ssc = setupStreams( + input, + (inputDStream: DStream[(Boolean, Double)]) => model.summarizeByKeyAndWindow(inputDStream)) + val outputBatches = runStreams[(Boolean, StatCounter)](ssc, numBatches, numBatches) + val outputCounts = outputBatches.flatten.map(_._2.count) + + // number of batches seen so far does not exceed testWindow, expect counts to continue growing + for (i <- 0 until testWindow) { + assert(outputCounts.drop(2 * i).take(2).forall(_ == (i + 1) * pointsPerBatch / 2)) + } + + // number of batches seen exceeds testWindow, expect counts to be constant + assert(outputCounts.drop(2 * (testWindow - 1)).forall(_ == testWindow * pointsPerBatch / 2)) + } + + + test("entries in peace period are dropped") { + // set parameters + val peacePeriod = 3 + val numBatches = 7 + val pointsPerBatch = 1000 + val meanA = -10 + val stdevA = 1 + val meanB = 10 + val stdevB = 1 + + val model = new StreamingTest() + .setWindowSize(0) + .setPeacePeriod(peacePeriod) + + val input = generateTestData( + numBatches, pointsPerBatch, meanA, stdevA, meanB, stdevB, 42) + + // setup and run the model + val ssc = setupStreams( + input, (inputDStream: DStream[(Boolean, Double)]) => model.dropPeacePeriod(inputDStream)) + val outputBatches = runStreams[(Boolean, Double)](ssc, numBatches, numBatches) + + assert(outputBatches.flatten.length == (numBatches - peacePeriod) * pointsPerBatch) + } + + test("null hypothesis when only data from one group is present") { + // set parameters + val numBatches = 2 + val pointsPerBatch = 1000 + val meanA = 0 + val stdevA = 0.001 + val meanB = 0 + val stdevB = 0.001 + + val model = new StreamingTest() + .setWindowSize(0) + .setPeacePeriod(0) + + val input = generateTestData(numBatches, pointsPerBatch, meanA, stdevA, meanB, stdevB, 42) + .map(batch => batch.filter(_._1)) // only keep one test group + + // setup and run the model + val ssc = setupStreams( + input, (inputDStream: DStream[(Boolean, Double)]) => model.registerStream(inputDStream)) + val outputBatches = runStreams[StreamingTestResult](ssc, numBatches, numBatches) + + assert(outputBatches.flatten.forall(result => (result.pValue - 1.0).abs < 0.001)) + } + + // Generate testing input with half of the entries in group A and half in group B + private def generateTestData( + numBatches: Int, + pointsPerBatch: Int, + meanA: Double, + stdevA: Double, + meanB: Double, + stdevB: Double, + seed: Int): (IndexedSeq[IndexedSeq[(Boolean, Double)]]) = { + val rand = new XORShiftRandom(seed) + val numTrues = pointsPerBatch / 2 + val data = (0 until numBatches).map { i => + (0 until numTrues).map { idx => (true, meanA + stdevA * rand.nextGaussian())} ++ + (pointsPerBatch / 2 until pointsPerBatch).map { idx => + (false, meanB + stdevB * rand.nextGaussian()) + } + } + + data + } +} From 97a99dde6e8d69a4c4c135dc1d9b1520b2548b5b Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Mon, 21 Sep 2015 13:15:44 -0700 Subject: [PATCH 0703/1824] [SPARK-10676] [DOCS] Add documentation for SASL encryption options. 
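A quick illustrative sketch (not part of this patch) of the two options being documented here: `spark.authenticate.enableSaslEncryption` turns on encryption for block transfer connections in applications that already enable `spark.authenticate`, while `spark.network.sasl.serverAlwaysEncrypt` belongs in the external shuffle service's own configuration. The application name below is a placeholder.

import org.apache.spark.{SparkConf, SparkContext}

// Sketch: an application opting in to SASL encryption for the block transfer service.
val conf = new SparkConf()
  .setAppName("sasl-encryption-demo")                      // placeholder name
  .set("spark.authenticate", "true")                       // encryption only applies when authentication is enabled
  .set("spark.authenticate.enableSaslEncryption", "true")  // encrypt block transfer connections
val sc = new SparkContext(conf)

// On the external shuffle service side, setting spark.network.sasl.serverAlwaysEncrypt=true makes the
// service reject applications that are not set up to use SASL encryption.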
Author: Marcelo Vanzin Closes #8803 from vanzin/SPARK-10676. --- docs/configuration.md | 16 ++++++++++++++++ docs/security.md | 22 ++++++++++++++++++++-- 2 files changed, 36 insertions(+), 2 deletions(-) diff --git a/docs/configuration.md b/docs/configuration.md index b22587c70316..284f97ad09ec 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -1285,6 +1285,22 @@ Apart from these, the following properties are also available, and may be useful not running on YARN and authentication is enabled. + + spark.authenticate.enableSaslEncryption + false + + Enable encrypted communication when authentication is enabled. This option is currently + only supported by the block transfer service. + + + + spark.network.sasl.serverAlwaysEncrypt + false + + Disable unencrypted connections for services that support SASL authentication. This is + currently supported by the external shuffle service. + + spark.core.connection.ack.wait.timeout 60s diff --git a/docs/security.md b/docs/security.md index d4ffa60e59a3..177109415180 100644 --- a/docs/security.md +++ b/docs/security.md @@ -23,9 +23,16 @@ If your applications are using event logging, the directory where the event logs ## Encryption -Spark supports SSL for Akka and HTTP (for broadcast and file server) protocols. However SSL is not supported yet for WebUI and block transfer service. +Spark supports SSL for Akka and HTTP (for broadcast and file server) protocols. SASL encryption is +supported for the block transfer service. Encryption is not yet supported for the WebUI. -Connection encryption (SSL) configuration is organized hierarchically. The user can configure the default SSL settings which will be used for all the supported communication protocols unless they are overwritten by protocol-specific settings. This way the user can easily provide the common settings for all the protocols without disabling the ability to configure each one individually. The common SSL settings are at `spark.ssl` namespace in Spark configuration, while Akka SSL configuration is at `spark.ssl.akka` and HTTP for broadcast and file server SSL configuration is at `spark.ssl.fs`. The full breakdown can be found on the [configuration page](configuration.html). +Encryption is not yet supported for data stored by Spark in temporary local storage, such as shuffle +files, cached data, and other application files. If encrypting this data is desired, a workaround is +to configure your cluster manager to store application data on encrypted disks. + +### SSL Configuration + +Configuration for SSL is organized hierarchically. The user can configure the default SSL settings which will be used for all the supported communication protocols unless they are overwritten by protocol-specific settings. This way the user can easily provide the common settings for all the protocols without disabling the ability to configure each one individually. The common SSL settings are at `spark.ssl` namespace in Spark configuration, while Akka SSL configuration is at `spark.ssl.akka` and HTTP for broadcast and file server SSL configuration is at `spark.ssl.fs`. The full breakdown can be found on the [configuration page](configuration.html). SSL must be configured on each node and configured for each component involved in communication using the particular protocol. 
@@ -47,6 +54,17 @@ follows: * Import all exported public keys into a single trust-store * Distribute the trust-store over the nodes +### Configuring SASL Encryption + +SASL encryption is currently supported for the block transfer service when authentication +(`spark.authenticate`) is enabled. To enable SASL encryption for an application, set +`spark.authenticate.enableSaslEncryption` to `true` in the application's configuration. + +When using an external shuffle service, it's possible to disable unencrypted connections by setting +`spark.network.sasl.serverAlwaysEncrypt` to `true` in the shuffle service's configuration. If that +option is enabled, applications that are not set up to use SASL encryption will fail to connect to +the shuffle service. + ## Configuring Ports for Network Security Spark makes heavy use of the network, and some environments have strict requirements for using tight From 362539f8d97f6bb67f0d0983f7dea36b77cc9d18 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Mon, 21 Sep 2015 13:33:10 -0700 Subject: [PATCH 0704/1824] [SPARK-10630] [SQL] Add a createDataFrame API that takes in a java list It would be nice to support creating a DataFrame directly from a Java List of Row. Author: Holden Karau Closes #8779 from holdenk/SPARK-10630-create-DataFrame-from-Java-List. --- .../scala/org/apache/spark/sql/SQLContext.scala | 14 ++++++++++++++ .../org/apache/spark/sql/JavaDataFrameSuite.java | 10 ++++++++++ 2 files changed, 24 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index f099940800cc..1bd4e26fb316 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -476,6 +476,20 @@ class SQLContext(@transient val sparkContext: SparkContext) createDataFrame(rowRDD.rdd, schema) } + /** + * :: DeveloperApi :: + * Creates a [[DataFrame]] from an [[java.util.List]] containing [[Row]]s using the given schema. + * It is important to make sure that the structure of every [[Row]] of the provided List matches + * the provided schema. Otherwise, there will be runtime exception. + * + * @group dataframes + * @since 1.6.0 + */ + @DeveloperApi + def createDataFrame(rows: java.util.List[Row], schema: StructType): DataFrame = { + DataFrame(self, LocalRelation.fromExternalRows(schema.toAttributes, rows.asScala)) + } + /** * Applies a schema to an RDD of Java Beans. 
* diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index 5f9abd4999ce..250ac2e1092d 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -37,6 +37,7 @@ import static org.apache.spark.sql.functions.*; import org.apache.spark.sql.test.TestSQLContext; import org.apache.spark.sql.types.*; +import static org.apache.spark.sql.types.DataTypes.*; public class JavaDataFrameSuite { private transient JavaSparkContext jsc; @@ -181,6 +182,15 @@ public void testCreateDataFrameFromJavaBeans() { } } + @Test + public void testCreateDataFromFromList() { + StructType schema = createStructType(Arrays.asList(createStructField("i", IntegerType, true))); + List rows = Arrays.asList(RowFactory.create(0)); + DataFrame df = context.createDataFrame(rows, schema); + Row[] result = df.collect(); + Assert.assertEquals(1, result.length); + } + private static final Comparator crosstabRowComparator = new Comparator() { @Override public int compare(Row row1, Row row2) { From 7c4f852bfc39537840f56cd8121457a0dc1ad7c1 Mon Sep 17 00:00:00 2001 From: noelsmith Date: Mon, 21 Sep 2015 14:24:19 -0700 Subject: [PATCH 0705/1824] [DOC] [PYSPARK] [MLLIB] Added newlines to docstrings to fix parameter formatting Added newlines before `:param ...:` and `:return:` markup. Without these, parameter lists aren't formatted correctly in the API docs. I.e: ![screen shot 2015-09-21 at 21 49 26](https://cloud.githubusercontent.com/assets/11915197/10004686/de3c41d4-60aa-11e5-9c50-a46dcb51243f.png) .. looks like this once newline is added: ![screen shot 2015-09-21 at 21 50 14](https://cloud.githubusercontent.com/assets/11915197/10004706/f86bfb08-60aa-11e5-8524-ae4436713502.png) Author: noelsmith Closes #8851 from noel-smith/docstring-missing-newline-fix. --- python/pyspark/ml/param/__init__.py | 4 ++++ python/pyspark/ml/pipeline.py | 1 + python/pyspark/ml/tuning.py | 2 ++ python/pyspark/ml/wrapper.py | 2 ++ python/pyspark/mllib/evaluation.py | 2 +- python/pyspark/mllib/linalg/__init__.py | 1 + python/pyspark/streaming/context.py | 2 ++ python/pyspark/streaming/mqtt.py | 1 + 8 files changed, 14 insertions(+), 1 deletion(-) diff --git a/python/pyspark/ml/param/__init__.py b/python/pyspark/ml/param/__init__.py index eeeac49b2198..2e0c63cb47b1 100644 --- a/python/pyspark/ml/param/__init__.py +++ b/python/pyspark/ml/param/__init__.py @@ -164,6 +164,7 @@ def extractParamMap(self, extra=None): a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra. + :param extra: extra param values :return: merged param map """ @@ -182,6 +183,7 @@ def copy(self, extra=None): embedded and extra parameters over and returns the copy. Subclasses should override this method if the default approach is not sufficient. + :param extra: Extra parameters to copy to the new instance :return: Copy of this instance """ @@ -201,6 +203,7 @@ def _shouldOwn(self, param): def _resolveParam(self, param): """ Resolves a param and validates the ownership. + :param param: param name or the param instance, which must belong to this Params instance :return: resolved param instance @@ -243,6 +246,7 @@ def _copyValues(self, to, extra=None): """ Copies param values from this instance to another instance for params shared by them. 
+ :param to: the target instance :param extra: extra params to be copied :return: the target instance with param values copied diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py index 13cf2b0f7bbd..312a8502b3a2 100644 --- a/python/pyspark/ml/pipeline.py +++ b/python/pyspark/ml/pipeline.py @@ -154,6 +154,7 @@ def __init__(self, stages=None): def setStages(self, value): """ Set pipeline stages. + :param value: a list of transformers or estimators :return: the pipeline instance """ diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py index ab5621f45c72..705ee5368575 100644 --- a/python/pyspark/ml/tuning.py +++ b/python/pyspark/ml/tuning.py @@ -254,6 +254,7 @@ def copy(self, extra=None): Creates a copy of this instance with a randomly generated uid and some extra params. This copies creates a deep copy of the embedded paramMap, and copies the embedded and extra parameters over. + :param extra: Extra parameters to copy to the new instance :return: Copy of this instance """ @@ -290,6 +291,7 @@ def copy(self, extra=None): and some extra params. This copies the underlying bestModel, creates a deep copy of the embedded paramMap, and copies the embedded and extra parameters over. + :param extra: Extra parameters to copy to the new instance :return: Copy of this instance """ diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py index 8218c7c5f801..4bcb4aaec89d 100644 --- a/python/pyspark/ml/wrapper.py +++ b/python/pyspark/ml/wrapper.py @@ -119,6 +119,7 @@ def _create_model(self, java_model): def _fit_java(self, dataset): """ Fits a Java model to the input dataset. + :param dataset: input dataset, which is an instance of :py:class:`pyspark.sql.DataFrame` :param params: additional params (overwriting embedded values) @@ -173,6 +174,7 @@ def copy(self, extra=None): extra params. This implementation first calls Params.copy and then make a copy of the companion Java model with extra params. So both the Python wrapper and the Java model get copied. + :param extra: Extra parameters to copy to the new instance :return: Copy of this instance """ diff --git a/python/pyspark/mllib/evaluation.py b/python/pyspark/mllib/evaluation.py index 4398ca86f2ec..a90e5c50e54b 100644 --- a/python/pyspark/mllib/evaluation.py +++ b/python/pyspark/mllib/evaluation.py @@ -147,7 +147,7 @@ class MulticlassMetrics(JavaModelWrapper): """ Evaluator for multiclass classification. - :param predictionAndLabels an RDD of (prediction, label) pairs. + :param predictionAndLabels: an RDD of (prediction, label) pairs. >>> predictionAndLabels = sc.parallelize([(0.0, 0.0), (0.0, 1.0), (0.0, 0.0), ... (1.0, 0.0), (1.0, 1.0), (1.0, 1.0), (1.0, 1.0), (2.0, 2.0), (2.0, 0.0)]) diff --git a/python/pyspark/mllib/linalg/__init__.py b/python/pyspark/mllib/linalg/__init__.py index f929e3e96fbe..ea42127f1651 100644 --- a/python/pyspark/mllib/linalg/__init__.py +++ b/python/pyspark/mllib/linalg/__init__.py @@ -240,6 +240,7 @@ class Vector(object): def toArray(self): """ Convert the vector into an numpy.ndarray + :return: numpy.ndarray """ raise NotImplementedError diff --git a/python/pyspark/streaming/context.py b/python/pyspark/streaming/context.py index 4069d7a14998..a8c9ffc235b9 100644 --- a/python/pyspark/streaming/context.py +++ b/python/pyspark/streaming/context.py @@ -240,6 +240,7 @@ def start(self): def awaitTermination(self, timeout=None): """ Wait for the execution to stop. 
+ @param timeout: time to wait in seconds """ if timeout is None: @@ -252,6 +253,7 @@ Wait for the execution to stop. Return `true` if it's stopped; or throw the reported error during the execution; or `false` if the waiting time elapsed before returning from the method. + @param timeout: time to wait in seconds """ self._jssc.awaitTerminationOrTimeout(int(timeout * 1000))
diff --git a/python/pyspark/streaming/mqtt.py b/python/pyspark/streaming/mqtt.py index f06598971c54..fa83006c36db 100644 --- a/python/pyspark/streaming/mqtt.py +++ b/python/pyspark/streaming/mqtt.py @@ -31,6 +31,7 @@ def createStream(ssc, brokerUrl, topic, storageLevel=StorageLevel.MEMORY_AND_DISK_SER_2): """ Create an input stream that pulls messages from a Mqtt Broker. + :param ssc: StreamingContext object :param brokerUrl: Url of remote mqtt publisher :param topic: topic name to subscribe to
From 72869883f12b6e0a4e5aad79c0ac2cfdb4d83f09 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Mon, 21 Sep 2015 16:47:52 -0700 Subject: [PATCH 0706/1824] [SPARK-10649] [STREAMING] Prevent inheriting job group and irrelevant job description in streaming jobs
Job group and job description information is passed through thread-local properties and gets inherited by child threads. In the case of Spark Streaming, the streaming jobs inherit these properties from the thread that called streamingContext.start(). This may not make sense.
1. Job group: This is mainly used for cancelling a group of jobs together. It does not make sense to cancel streaming jobs like this, as the effect will be unpredictable. And it's not a valid use case anyway; to cancel a streaming context, call streamingContext.stop().
2. Job description: This is used to pass on nice text descriptions for jobs to show up in the UI. The job description of the thread that calls streamingContext.start() is not useful for all the streaming jobs, as it does not make sense for all of the streaming jobs to have the same description, and the description may or may not be related to streaming.
The solution in this PR is meant for the Spark master branch, where local properties are inherited by cloning the properties. The job group and job description in the thread that starts the streaming scheduler are explicitly removed, so that subsequent child threads do not inherit them. Also, the starting is done in a new child thread, so that setting the job group and description for streaming does not change those properties in the thread that called streamingContext.start().
Author: Tathagata Das Closes #8781 from tdas/SPARK-10649.
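To make the inheritance issue above concrete, here is a minimal sketch (assuming an already-created SparkContext `sc`; it is not part of the patch): a job group set on the calling thread is visible to jobs launched from threads spawned afterwards unless it is explicitly cleared, which is what the change below does for the thread that starts the streaming scheduler.

// Sketch: thread-local job properties leak into child threads unless cleared.
sc.setJobGroup("batch-jobs", "nightly batch", interruptOnCancel = true)  // set on the caller thread

new Thread("child") {
  override def run(): Unit = {
    // Jobs submitted here would otherwise report group "batch-jobs" and its description.
    sc.clearJobGroup()
    sc.parallelize(1 to 10).count()
  }
}.start()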
--- .../org/apache/spark/util/ThreadUtils.scala | 59 +++++++++++++++++++ .../apache/spark/util/ThreadUtilsSuite.scala | 24 +++++++- .../spark/streaming/StreamingContext.scala | 15 ++++- .../streaming/StreamingContextSuite.scala | 32 ++++++++++ 4 files changed, 126 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala index ca5624a3d8b3..22e291a2b48d 100644 --- a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala @@ -21,6 +21,7 @@ package org.apache.spark.util import java.util.concurrent._ import scala.concurrent.{ExecutionContext, ExecutionContextExecutor} +import scala.util.control.NonFatal import com.google.common.util.concurrent.{MoreExecutors, ThreadFactoryBuilder} @@ -86,4 +87,62 @@ private[spark] object ThreadUtils { val threadFactory = new ThreadFactoryBuilder().setDaemon(true).setNameFormat(threadName).build() Executors.newSingleThreadScheduledExecutor(threadFactory) } + + /** + * Run a piece of code in a new thread and return the result. Exception in the new thread is + * thrown in the caller thread with an adjusted stack trace that removes references to this + * method for clarity. The exception stack traces will be like the following + * + * SomeException: exception-message + * at CallerClass.body-method (sourcefile.scala) + * at ... run in separate thread using org.apache.spark.util.ThreadUtils ... () + * at CallerClass.caller-method (sourcefile.scala) + * ... + */ + def runInNewThread[T]( + threadName: String, + isDaemon: Boolean = true)(body: => T): T = { + @volatile var exception: Option[Throwable] = None + @volatile var result: T = null.asInstanceOf[T] + + val thread = new Thread(threadName) { + override def run(): Unit = { + try { + result = body + } catch { + case NonFatal(e) => + exception = Some(e) + } + } + } + thread.setDaemon(isDaemon) + thread.start() + thread.join() + + exception match { + case Some(realException) => + // Remove the part of the stack that shows method calls into this helper method + // This means drop everything from the top until the stack element + // ThreadUtils.runInNewThread(), and then drop that as well (hence the `drop(1)`). + val baseStackTrace = Thread.currentThread().getStackTrace().dropWhile( + ! _.getClassName.contains(this.getClass.getSimpleName)).drop(1) + + // Remove the part of the new thread stack that shows methods call from this helper method + val extraStackTrace = realException.getStackTrace.takeWhile( + ! _.getClassName.contains(this.getClass.getSimpleName)) + + // Combine the two stack traces, with a place holder just specifying that there + // was a helper method used, without any further details of the helper + val placeHolderStackElem = new StackTraceElement( + s"... 
run in separate thread using ${ThreadUtils.getClass.getName.stripSuffix("$")} ..", + " ", "", -1) + val finalStackTrace = extraStackTrace ++ Seq(placeHolderStackElem) ++ baseStackTrace + + // Update the stack trace and rethrow the exception in the caller thread + realException.setStackTrace(finalStackTrace) + throw realException + case None => + result + } + } } diff --git a/core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala index 8c51e6b14b7f..620e4debf4e0 100644 --- a/core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala @@ -20,8 +20,9 @@ package org.apache.spark.util import java.util.concurrent.{CountDownLatch, TimeUnit} -import scala.concurrent.{Await, Future} import scala.concurrent.duration._ +import scala.concurrent.{Await, Future} +import scala.util.Random import org.apache.spark.SparkFunSuite @@ -66,4 +67,25 @@ class ThreadUtilsSuite extends SparkFunSuite { val futureThreadName = Await.result(f, 10.seconds) assert(futureThreadName === callerThreadName) } + + test("runInNewThread") { + import ThreadUtils._ + assert(runInNewThread("thread-name") { Thread.currentThread().getName } === "thread-name") + assert(runInNewThread("thread-name") { Thread.currentThread().isDaemon } === true) + assert( + runInNewThread("thread-name", isDaemon = false) { Thread.currentThread().isDaemon } === false + ) + val uniqueExceptionMessage = "test" + Random.nextInt() + val exception = intercept[IllegalArgumentException] { + runInNewThread("thread-name") { throw new IllegalArgumentException(uniqueExceptionMessage) } + } + assert(exception.asInstanceOf[IllegalArgumentException].getMessage === uniqueExceptionMessage) + assert(exception.getStackTrace.mkString("\n").contains( + "... run in separate thread using org.apache.spark.util.ThreadUtils ...") === true, + "stack trace does not contain expected place holder" + ) + assert(exception.getStackTrace.mkString("\n").contains("ThreadUtils.scala") === false, + "stack trace contains unexpected references to ThreadUtils" + ) + } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala index b496d1f341a0..6720ba4f72cf 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala @@ -44,7 +44,7 @@ import org.apache.spark.streaming.dstream._ import org.apache.spark.streaming.receiver.{ActorReceiver, ActorSupervisorStrategy, Receiver} import org.apache.spark.streaming.scheduler.{JobScheduler, StreamingListener} import org.apache.spark.streaming.ui.{StreamingJobProgressListener, StreamingTab} -import org.apache.spark.util.{CallSite, ShutdownHookManager, Utils} +import org.apache.spark.util.{CallSite, ShutdownHookManager, ThreadUtils} /** * Main entry point for Spark Streaming functionality. 
It provides methods used to create @@ -588,12 +588,20 @@ class StreamingContext private[streaming] ( state match { case INITIALIZED => startSite.set(DStream.getCreationSite()) - sparkContext.setCallSite(startSite.get) StreamingContext.ACTIVATION_LOCK.synchronized { StreamingContext.assertNoOtherContextIsActive() try { validate() - scheduler.start() + + // Start the streaming scheduler in a new thread, so that thread local properties + // like call sites and job groups can be reset without affecting those of the + // current thread. + ThreadUtils.runInNewThread("streaming-start") { + sparkContext.setCallSite(startSite.get) + sparkContext.clearJobGroup() + sparkContext.setLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL, "false") + scheduler.start() + } state = StreamingContextState.ACTIVE } catch { case NonFatal(e) => @@ -618,6 +626,7 @@ class StreamingContext private[streaming] ( } } + /** * Wait for the execution to stop. Any exceptions that occurs during the execution * will be thrown in this thread. diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala index d26894e88fc2..3b9d0d15ea04 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala @@ -180,6 +180,38 @@ class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with Timeo assert(ssc.scheduler.isStarted === false) } + test("start should set job group and description of streaming jobs correctly") { + ssc = new StreamingContext(conf, batchDuration) + ssc.sc.setJobGroup("non-streaming", "non-streaming", true) + val sc = ssc.sc + + @volatile var jobGroupFound: String = "" + @volatile var jobDescFound: String = "" + @volatile var jobInterruptFound: String = "" + @volatile var allFound: Boolean = false + + addInputStream(ssc).foreachRDD { rdd => + jobGroupFound = sc.getLocalProperty(SparkContext.SPARK_JOB_GROUP_ID) + jobDescFound = sc.getLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION) + jobInterruptFound = sc.getLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL) + allFound = true + } + ssc.start() + + eventually(timeout(10 seconds), interval(10 milliseconds)) { + assert(allFound === true) + } + + // Verify streaming jobs have expected thread-local properties + assert(jobGroupFound === null) + assert(jobDescFound === null) + assert(jobInterruptFound === "false") + + // Verify current thread's thread-local properties have not changed + assert(sc.getLocalProperty(SparkContext.SPARK_JOB_GROUP_ID) === "non-streaming") + assert(sc.getLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION) === "non-streaming") + assert(sc.getLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL) === "true") + } test("start multiple times") { ssc = new StreamingContext(master, appName, batchDuration) From 0494c80ef54f6f3a8c6f2d92abfe1a77a91df8b0 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Mon, 21 Sep 2015 18:06:45 -0700 Subject: [PATCH 0707/1824] [SPARK-10495] [SQL] Read date values in JSON data stored by Spark 1.5.0. https://issues.apache.org/jira/browse/SPARK-10681 Author: Yin Huai Closes #8806 from yhuai/SPARK-10495. 
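The parser change below has to accept two on-disk encodings for `DateType`. A hedged, standalone sketch of that rule (using `java.time` instead of Spark's `DateTimeUtils`, purely for illustration):

```scala
import java.time.LocalDate

object JsonDateCompatSketch {
  // Spark 1.5.0 stored dates as the day count since the epoch, serialized as
  // a string (e.g. "16436"); other writers produce "yyyy-MM-dd" strings.
  // Both forms must map to the same internal days-since-epoch Int.
  def parseJsonDateToDays(s: String): Int =
    if (s.contains("-")) LocalDate.parse(s).toEpochDay.toInt  // "2015-01-01"
    else s.toInt                                              // "16436"

  def main(args: Array[String]): Unit = {
    println(parseJsonDateToDays("2015-01-01")) // 16436
    println(parseJsonDateToDays("16436"))      // 16436
  }
}
```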
--- .../datasources/json/JacksonGenerator.scala | 36 ++++++ .../datasources/json/JacksonParser.scala | 15 ++- .../datasources/json/JsonSuite.scala | 103 +++++++++++++++++- 3 files changed, 152 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonGenerator.scala index f65c7bbd6e29..23bada1ddd92 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonGenerator.scala @@ -73,6 +73,38 @@ private[sql] object JacksonGenerator { valWriter(field.dataType, v) } gen.writeEndObject() + + // For UDT, udt.serialize will produce SQL types. So, we need the following three cases. + case (ArrayType(ty, _), v: ArrayData) => + gen.writeStartArray() + v.foreach(ty, (_, value) => valWriter(ty, value)) + gen.writeEndArray() + + case (MapType(kt, vt, _), v: MapData) => + gen.writeStartObject() + v.foreach(kt, vt, { (k, v) => + gen.writeFieldName(k.toString) + valWriter(vt, v) + }) + gen.writeEndObject() + + case (StructType(ty), v: InternalRow) => + gen.writeStartObject() + var i = 0 + while (i < ty.length) { + val field = ty(i) + val value = v.get(i, field.dataType) + if (value != null) { + gen.writeFieldName(field.name) + valWriter(field.dataType, value) + } + i += 1 + } + gen.writeEndObject() + + case (dt, v) => + sys.error( + s"Failed to convert value $v (class of ${v.getClass}}) with the type of $dt to JSON.") } valWriter(rowSchema, row) @@ -133,6 +165,10 @@ private[sql] object JacksonGenerator { i += 1 } gen.writeEndObject() + + case (dt, v) => + sys.error( + s"Failed to convert value $v (class of ${v.getClass}}) with the type of $dt to JSON.") } valWriter(rowSchema, row) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala index ff4d8c04e8ea..c51140749c8e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala @@ -62,10 +62,23 @@ private[sql] object JacksonParser { // guard the non string type null + case (VALUE_STRING, BinaryType) => + parser.getBinaryValue + case (VALUE_STRING, DateType) => - DateTimeUtils.millisToDays(DateTimeUtils.stringToTime(parser.getText).getTime) + val stringValue = parser.getText + if (stringValue.contains("-")) { + // The format of this string will probably be "yyyy-mm-dd". + DateTimeUtils.millisToDays(DateTimeUtils.stringToTime(parser.getText).getTime) + } else { + // In Spark 1.5.0, we store the data as number of days since epoch in string. + // So, we just convert it to Int. + stringValue.toInt + } case (VALUE_STRING, TimestampType) => + // This one will lose microseconds parts. + // See https://issues.apache.org/jira/browse/SPARK-10681. 
DateTimeUtils.stringToTime(parser.getText).getTime * 1000L case (VALUE_NUMBER_INT, TimestampType) => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 6a18cc6d2713..b614e6c4148f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -24,7 +24,7 @@ import com.fasterxml.jackson.core.JsonFactory import org.apache.spark.rdd.RDD import org.scalactic.Tolerance._ -import org.apache.spark.sql.{QueryTest, Row, SQLConf} +import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.datasources.{ResolvedDataSource, LogicalRelation} import org.apache.spark.sql.execution.datasources.json.InferSchema.compatibleType @@ -1159,4 +1159,105 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { "SELECT count(a) FROM test_myjson_with_part where d1 = 1"), Row(9)) }) } + + test("backward compatibility") { + // This test we make sure our JSON support can read JSON data generated by previous version + // of Spark generated through toJSON method and JSON data source. + // The data is generated by the following program. + // Here are a few notes: + // - Spark 1.5.0 cannot save timestamp data. So, we manually added timestamp field (col13) + // in the JSON object. + // - For Spark before 1.5.1, we do not generate UDTs. So, we manually added the UDT value to + // JSON objects generated by those Spark versions (col17). + // - If the type is NullType, we do not write data out. + + // Create the schema. + val struct = + StructType( + StructField("f1", FloatType, true) :: + StructField("f2", ArrayType(BooleanType), true) :: Nil) + + val dataTypes = + Seq( + StringType, BinaryType, NullType, BooleanType, + ByteType, ShortType, IntegerType, LongType, + FloatType, DoubleType, DecimalType(25, 5), DecimalType(6, 5), + DateType, TimestampType, + ArrayType(IntegerType), MapType(StringType, LongType), struct, + new MyDenseVectorUDT()) + val fields = dataTypes.zipWithIndex.map { case (dataType, index) => + StructField(s"col$index", dataType, nullable = true) + } + val schema = StructType(fields) + + val constantValues = + Seq( + "a string in binary".getBytes("UTF-8"), + null, + true, + 1.toByte, + 2.toShort, + 3, + Long.MaxValue, + 0.25.toFloat, + 0.75, + new java.math.BigDecimal(s"1234.23456"), + new java.math.BigDecimal(s"1.23456"), + java.sql.Date.valueOf("2015-01-01"), + java.sql.Timestamp.valueOf("2015-01-01 23:50:59.123"), + Seq(2, 3, 4), + Map("a string" -> 2000L), + Row(4.75.toFloat, Seq(false, true)), + new MyDenseVector(Array(0.25, 2.25, 4.25))) + val data = + Row.fromSeq(Seq("Spark " + sqlContext.sparkContext.version) ++ constantValues) :: Nil + + // Data generated by previous versions. 
+ // scalastyle:off + val existingJSONData = + """{"col0":"Spark 1.2.2","col1":"YSBzdHJpbmcgaW4gYmluYXJ5","col3":true,"col4":1,"col5":2,"col6":3,"col7":9223372036854775807,"col8":0.25,"col9":0.75,"col10":1234.23456,"col11":1.23456,"col12":"2015-01-01","col13":"2015-01-01 23:50:59.123","col14":[2,3,4],"col15":{"a string":2000},"col16":{"f1":4.75,"f2":[false,true]},"col17":[0.25,2.25,4.25]}""" :: + """{"col0":"Spark 1.3.1","col1":"YSBzdHJpbmcgaW4gYmluYXJ5","col3":true,"col4":1,"col5":2,"col6":3,"col7":9223372036854775807,"col8":0.25,"col9":0.75,"col10":1234.23456,"col11":1.23456,"col12":"2015-01-01","col13":"2015-01-01 23:50:59.123","col14":[2,3,4],"col15":{"a string":2000},"col16":{"f1":4.75,"f2":[false,true]},"col17":[0.25,2.25,4.25]}""" :: + """{"col0":"Spark 1.3.1","col1":"YSBzdHJpbmcgaW4gYmluYXJ5","col3":true,"col4":1,"col5":2,"col6":3,"col7":9223372036854775807,"col8":0.25,"col9":0.75,"col10":1234.23456,"col11":1.23456,"col12":"2015-01-01","col13":"2015-01-01 23:50:59.123","col14":[2,3,4],"col15":{"a string":2000},"col16":{"f1":4.75,"f2":[false,true]},"col17":[0.25,2.25,4.25]}""" :: + """{"col0":"Spark 1.4.1","col1":"YSBzdHJpbmcgaW4gYmluYXJ5","col3":true,"col4":1,"col5":2,"col6":3,"col7":9223372036854775807,"col8":0.25,"col9":0.75,"col10":1234.23456,"col11":1.23456,"col12":"2015-01-01","col13":"2015-01-01 23:50:59.123","col14":[2,3,4],"col15":{"a string":2000},"col16":{"f1":4.75,"f2":[false,true]},"col17":[0.25,2.25,4.25]}""" :: + """{"col0":"Spark 1.4.1","col1":"YSBzdHJpbmcgaW4gYmluYXJ5","col3":true,"col4":1,"col5":2,"col6":3,"col7":9223372036854775807,"col8":0.25,"col9":0.75,"col10":1234.23456,"col11":1.23456,"col12":"2015-01-01","col13":"2015-01-01 23:50:59.123","col14":[2,3,4],"col15":{"a string":2000},"col16":{"f1":4.75,"f2":[false,true]},"col17":[0.25,2.25,4.25]}""" :: + """{"col0":"Spark 1.5.0","col1":"YSBzdHJpbmcgaW4gYmluYXJ5","col3":true,"col4":1,"col5":2,"col6":3,"col7":9223372036854775807,"col8":0.25,"col9":0.75,"col10":1234.23456,"col11":1.23456,"col12":"2015-01-01","col13":"2015-01-01 23:50:59.123","col14":[2,3,4],"col15":{"a string":2000},"col16":{"f1":4.75,"f2":[false,true]},"col17":[0.25,2.25,4.25]}""" :: + """{"col0":"Spark 1.5.0","col1":"YSBzdHJpbmcgaW4gYmluYXJ5","col3":true,"col4":1,"col5":2,"col6":3,"col7":9223372036854775807,"col8":0.25,"col9":0.75,"col10":1234.23456,"col11":1.23456,"col12":"16436","col13":"2015-01-01 23:50:59.123","col14":[2,3,4],"col15":{"a string":2000},"col16":{"f1":4.75,"f2":[false,true]},"col17":[0.25,2.25,4.25]}""" :: Nil + // scalastyle:on + + // Generate data for the current version. + val df = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(data, 1), schema) + withTempPath { path => + df.write.format("json").mode("overwrite").save(path.getCanonicalPath) + + // df.toJSON will convert internal rows to external rows first and then generate + // JSON objects. While, df.write.format("json") will write internal rows directly. + val allJSON = + existingJSONData ++ + df.toJSON.collect() ++ + sparkContext.textFile(path.getCanonicalPath).collect() + + Utils.deleteRecursively(path) + sparkContext.parallelize(allJSON, 1).saveAsTextFile(path.getCanonicalPath) + + // Read data back with the schema specified. 
+ val col0Values = + Seq( + "Spark 1.2.2", + "Spark 1.3.1", + "Spark 1.3.1", + "Spark 1.4.1", + "Spark 1.4.1", + "Spark 1.5.0", + "Spark 1.5.0", + "Spark " + sqlContext.sparkContext.version, + "Spark " + sqlContext.sparkContext.version) + val expectedResult = col0Values.map { v => + Row.fromSeq(Seq(v) ++ constantValues) + } + checkAnswer( + sqlContext.read.format("json").schema(schema).load(path.getCanonicalPath), + expectedResult + ) + } + } } From c986e933a900602af47966bd41edb2116c421a39 Mon Sep 17 00:00:00 2001 From: Hossein Date: Mon, 21 Sep 2015 21:09:59 -0700 Subject: [PATCH 0708/1824] [SPARK-10711] [SPARKR] Do not assume spark.submit.deployMode is always set In ```RUtils.sparkRPackagePath()``` we 1. Call ``` sys.props("spark.submit.deployMode")``` which returns null if ```spark.submit.deployMode``` is not suet 2. Call ``` sparkConf.get("spark.submit.deployMode")``` which throws ```NoSuchElementException``` if ```spark.submit.deployMode``` is not set. This patch simply passes a default value ("cluster") for ```spark.submit.deployMode```. cc rxin Author: Hossein Closes #8832 from falaki/SPARK-10711. --- core/src/main/scala/org/apache/spark/api/r/RUtils.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/api/r/RUtils.scala b/core/src/main/scala/org/apache/spark/api/r/RUtils.scala index 9e807cc52f18..fd5646b5b637 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RUtils.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RUtils.scala @@ -44,7 +44,7 @@ private[spark] object RUtils { (sys.props("spark.master"), sys.props("spark.submit.deployMode")) } else { val sparkConf = SparkEnv.get.conf - (sparkConf.get("spark.master"), sparkConf.get("spark.submit.deployMode")) + (sparkConf.get("spark.master"), sparkConf.get("spark.submit.deployMode", "client")) } val isYarnCluster = master != null && master.contains("yarn") && deployMode == "cluster" From 1cd67415728e660a90e4dbe136272b5d6b8f1142 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Mon, 21 Sep 2015 23:21:24 -0700 Subject: [PATCH 0709/1824] [SPARK-9821] [PYSPARK] pyspark-reduceByKey-should-take-a-custom-partitioner from the issue: In Scala, I can supply a custom partitioner to reduceByKey (and other aggregation/repartitioning methods like aggregateByKey and combinedByKey), but as far as I can tell from the Pyspark API, there's no way to do the same in Python. Here's an example of my code in Scala: weblogs.map(s => (getFileType(s), 1)).reduceByKey(new FileTypePartitioner(),_+_) But I can't figure out how to do the same in Python. The closest I can get is to call repartition before reduceByKey like so: weblogs.map(lambda s: (getFileType(s), 1)).partitionBy(3,hash_filetype).reduceByKey(lambda v1,v2: v1+v2).collect() But that defeats the purpose, because I'm shuffling twice instead of once, so my performance is worse instead of better. Author: Holden Karau Closes #8569 from holdenk/SPARK-9821-pyspark-reduceByKey-should-take-a-custom-partitioner. 
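For comparison, a hedged Scala sketch of the single-shuffle pattern that the Scala API already allows and that the Python change below adds; `ByExtension` here is a made-up partitioner standing in for the `FileTypePartitioner` mentioned above.

```scala
import org.apache.spark.{Partitioner, SparkConf, SparkContext}

// Hypothetical partitioner: routes keys by their first character.
class ByExtension(parts: Int) extends Partitioner {
  override def numPartitions: Int = parts
  override def getPartition(key: Any): Int =
    math.abs(key.toString.headOption.getOrElse(' ').toInt) % parts
}

object CustomPartitionerSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("sketch"))
    val counts = sc.parallelize(Seq("a.html", "b.css", "c.html"))
      .map(path => (path.split('.').last, 1))
      // Passing the partitioner directly lets a single shuffle do both the
      // repartitioning and the aggregation, instead of partitionBy + reduceByKey.
      .reduceByKey(new ByExtension(3), _ + _)
    counts.collect().foreach(println)
    sc.stop()
  }
}
```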
--- python/pyspark/rdd.py | 29 ++++++++++++++++------------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 73d7d9a5692a..56e892243c79 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -686,7 +686,7 @@ def cartesian(self, other): other._jrdd_deserializer) return RDD(self._jrdd.cartesian(other._jrdd), self.ctx, deserializer) - def groupBy(self, f, numPartitions=None): + def groupBy(self, f, numPartitions=None, partitionFunc=portable_hash): """ Return an RDD of grouped items. @@ -695,7 +695,7 @@ def groupBy(self, f, numPartitions=None): >>> sorted([(x, sorted(y)) for (x, y) in result]) [(0, [2, 8]), (1, [1, 1, 3, 5])] """ - return self.map(lambda x: (f(x), x)).groupByKey(numPartitions) + return self.map(lambda x: (f(x), x)).groupByKey(numPartitions, partitionFunc) @ignore_unicode_prefix def pipe(self, command, env=None, checkCode=False): @@ -1539,22 +1539,23 @@ def values(self): """ return self.map(lambda x: x[1]) - def reduceByKey(self, func, numPartitions=None): + def reduceByKey(self, func, numPartitions=None, partitionFunc=portable_hash): """ Merge the values for each key using an associative reduce function. This will also perform the merging locally on each mapper before sending results to a reducer, similarly to a "combiner" in MapReduce. - Output will be hash-partitioned with C{numPartitions} partitions, or + Output will be partitioned with C{numPartitions} partitions, or the default parallelism level if C{numPartitions} is not specified. + Default partitioner is hash-partition. >>> from operator import add >>> rdd = sc.parallelize([("a", 1), ("b", 1), ("a", 1)]) >>> sorted(rdd.reduceByKey(add).collect()) [('a', 2), ('b', 1)] """ - return self.combineByKey(lambda x: x, func, func, numPartitions) + return self.combineByKey(lambda x: x, func, func, numPartitions, partitionFunc) def reduceByKeyLocally(self, func): """ @@ -1739,7 +1740,7 @@ def add_shuffle_key(split, iterator): # TODO: add control over map-side aggregation def combineByKey(self, createCombiner, mergeValue, mergeCombiners, - numPartitions=None): + numPartitions=None, partitionFunc=portable_hash): """ Generic function to combine the elements for each key using a custom set of aggregation functions. @@ -1777,7 +1778,7 @@ def combineLocally(iterator): return merger.items() locally_combined = self.mapPartitions(combineLocally, preservesPartitioning=True) - shuffled = locally_combined.partitionBy(numPartitions) + shuffled = locally_combined.partitionBy(numPartitions, partitionFunc) def _mergeCombiners(iterator): merger = ExternalMerger(agg, memory, serializer) @@ -1786,7 +1787,8 @@ def _mergeCombiners(iterator): return shuffled.mapPartitions(_mergeCombiners, preservesPartitioning=True) - def aggregateByKey(self, zeroValue, seqFunc, combFunc, numPartitions=None): + def aggregateByKey(self, zeroValue, seqFunc, combFunc, numPartitions=None, + partitionFunc=portable_hash): """ Aggregate the values of each key, using given combine functions and a neutral "zero value". 
This function can return a different result type, U, than the type @@ -1800,9 +1802,9 @@ def createZero(): return copy.deepcopy(zeroValue) return self.combineByKey( - lambda v: seqFunc(createZero(), v), seqFunc, combFunc, numPartitions) + lambda v: seqFunc(createZero(), v), seqFunc, combFunc, numPartitions, partitionFunc) - def foldByKey(self, zeroValue, func, numPartitions=None): + def foldByKey(self, zeroValue, func, numPartitions=None, partitionFunc=portable_hash): """ Merge the values for each key using an associative function "func" and a neutral "zeroValue" which may be added to the result an @@ -1817,13 +1819,14 @@ def foldByKey(self, zeroValue, func, numPartitions=None): def createZero(): return copy.deepcopy(zeroValue) - return self.combineByKey(lambda v: func(createZero(), v), func, func, numPartitions) + return self.combineByKey(lambda v: func(createZero(), v), func, func, numPartitions, + partitionFunc) def _memory_limit(self): return _parse_memory(self.ctx._conf.get("spark.python.worker.memory", "512m")) # TODO: support variant with custom partitioner - def groupByKey(self, numPartitions=None): + def groupByKey(self, numPartitions=None, partitionFunc=portable_hash): """ Group the values for each key in the RDD into a single sequence. Hash-partitions the resulting RDD with numPartitions partitions. @@ -1859,7 +1862,7 @@ def combine(iterator): return merger.items() locally_combined = self.mapPartitions(combine, preservesPartitioning=True) - shuffled = locally_combined.partitionBy(numPartitions) + shuffled = locally_combined.partitionBy(numPartitions, partitionFunc) def groupByKey(it): merger = ExternalGroupBy(agg, memory, serializer) From bf20d6c9f9e478a5de24b45bbafd4dd89666c4cf Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Mon, 21 Sep 2015 23:29:59 -0700 Subject: [PATCH 0710/1824] [SPARK-10716] [BUILD] spark-1.5.0-bin-hadoop2.6.tgz file doesn't uncompress on OS X due to hidden file Remove ._SUCCESS.crc hidden file that may cause problems in distribution tar archive, and is not used Author: Sean Owen Closes #8846 from srowen/SPARK-10716. --- .../test_support/sql/orc_partitioned/._SUCCESS.crc | Bin 8 -> 0 bytes 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 python/test_support/sql/orc_partitioned/._SUCCESS.crc diff --git a/python/test_support/sql/orc_partitioned/._SUCCESS.crc b/python/test_support/sql/orc_partitioned/._SUCCESS.crc deleted file mode 100644 index 3b7b044936a890cd8d651d349a752d819d71d22c..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 8 PcmYc;N@ieSU}69O2$TUk From 0180b849dbaf191826231eda7dfaaf146a19602b Mon Sep 17 00:00:00 2001 From: Jian Feng Date: Mon, 21 Sep 2015 23:36:41 -0700 Subject: [PATCH 0711/1824] [SPARK-10577] [PYSPARK] DataFrame hint for broadcast join https://issues.apache.org/jira/browse/SPARK-10577 Author: Jian Feng Closes #8801 from Jianfeng-chs/master. 
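The new Python `broadcast(df)` is a thin wrapper over the JVM-side hint. A hedged Scala sketch of that underlying usage (the `sqlContext` handle and the toy DataFrames are assumptions for illustration):

```scala
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.functions.broadcast

object BroadcastHintSketch {
  def example(sqlContext: SQLContext): Unit = {
    import sqlContext.implicits._
    val large = sqlContext.sparkContext.parallelize(1 to 100000)
      .map(i => (i % 100, i.toString)).toDF("key", "value")
    val small = sqlContext.sparkContext.parallelize(1 to 100)
      .map(i => (i, s"dim$i")).toDF("key", "dim")

    // Marking the small side as broadcastable steers the planner toward a
    // broadcast hash join instead of shuffling both sides.
    large.join(broadcast(small), "key").explain()
  }
}
```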
--- python/pyspark/sql/functions.py | 9 +++++++++ python/pyspark/sql/tests.py | 18 ++++++++++++++++++ 2 files changed, 27 insertions(+) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 26b8662718a6..fa04f4cd83b6 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -29,6 +29,7 @@ from pyspark.serializers import PickleSerializer, AutoBatchedSerializer from pyspark.sql.types import StringType from pyspark.sql.column import Column, _to_java_column, _to_seq +from pyspark.sql.dataframe import DataFrame def _create_function(name, doc=""): @@ -189,6 +190,14 @@ def approxCountDistinct(col, rsd=None): return Column(jc) +@since(1.6) +def broadcast(df): + """Marks a DataFrame as small enough for use in broadcast joins.""" + + sc = SparkContext._active_spark_context + return DataFrame(sc._jvm.functions.broadcast(df._jdf), df.sql_ctx) + + @since(1.4) def coalesce(*cols): """Returns the first column that is not null. diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 3e680f1030a7..645133b2b2d8 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -1075,6 +1075,24 @@ def foo(): self.assertRaises(TypeError, foo) + # add test for SPARK-10577 (test broadcast join hint) + def test_functions_broadcast(self): + from pyspark.sql.functions import broadcast + + df1 = self.sqlCtx.createDataFrame([(1, "1"), (2, "2")], ("key", "value")) + df2 = self.sqlCtx.createDataFrame([(1, "1"), (2, "2")], ("key", "value")) + + # equijoin - should be converted into broadcast join + plan1 = df1.join(broadcast(df2), "key")._jdf.queryExecution().executedPlan() + self.assertEqual(1, plan1.toString().count("BroadcastHashJoin")) + + # no join key -- should not be a broadcast join + plan2 = df1.join(broadcast(df2))._jdf.queryExecution().executedPlan() + self.assertEqual(0, plan2.toString().count("BroadcastHashJoin")) + + # planner should not crash without a join + broadcast(df1)._jdf.queryExecution().executedPlan() + class HiveContextSQLTests(ReusedPySparkTestCase): From 781b21ba2a873ed29394c8dbc74fc700e3e0d17e Mon Sep 17 00:00:00 2001 From: Ewan Leith Date: Mon, 21 Sep 2015 23:43:20 -0700 Subject: [PATCH 0712/1824] [SPARK-10419] [SQL] Adding SQLServer support for datetimeoffset types to JdbcDialects Reading from Microsoft SQL Server over jdbc fails when the table contains datetimeoffset types. This patch registers a SQLServer JDBC Dialect that maps datetimeoffset to a String, as Microsoft suggest. Author: Ewan Leith Closes #8575 from realitymine-coordinator/sqlserver. --- .../apache/spark/sql/jdbc/JdbcDialects.scala | 18 ++++++++++++++++++ .../org/apache/spark/sql/jdbc/JDBCSuite.scala | 1 + 2 files changed, 19 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index 68ebaaca6c53..c70fea1c3f50 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -137,6 +137,8 @@ object JdbcDialects { registerDialect(MySQLDialect) registerDialect(PostgresDialect) registerDialect(DB2Dialect) + registerDialect(MsSqlServerDialect) + /** * Fetch the JdbcDialect class corresponding to a given database url. @@ -260,3 +262,19 @@ case object DB2Dialect extends JdbcDialect { case _ => None } } + +/** + * :: DeveloperApi :: + * Default Microsoft SQL Server dialect, mapping the datetimeoffset types to a String on read. 
+ */ +@DeveloperApi +case object MsSqlServerDialect extends JdbcDialect { + override def canHandle(url: String): Boolean = url.startsWith("jdbc:sqlserver") + override def getCatalystType( + sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = { + if (typeName.contains("datetimeoffset")) { + // String is recommend by Microsoft SQL Server for datetimeoffset types in non-MS clients + Some(StringType) + } else None + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 5ab9381de4d6..c4b039a9c535 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -408,6 +408,7 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext assert(JdbcDialects.get("jdbc:mysql://127.0.0.1/db") == MySQLDialect) assert(JdbcDialects.get("jdbc:postgresql://127.0.0.1/db") == PostgresDialect) assert(JdbcDialects.get("jdbc:db2://127.0.0.1/db") == DB2Dialect) + assert(JdbcDialects.get("jdbc:sqlserver://127.0.0.1/db") == MsSqlServerDialect) assert(JdbcDialects.get("test.invalid") == NoopDialect) } From 1fcefef06950e2f03477282368ca835bbf40ff24 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 21 Sep 2015 23:46:00 -0700 Subject: [PATCH 0713/1824] [SPARK-10446][SQL] Support to specify join type when calling join with usingColumns JIRA: https://issues.apache.org/jira/browse/SPARK-10446 Currently the method `join(right: DataFrame, usingColumns: Seq[String])` only supports inner join. It is more convenient to have it support other join types. Author: Liang-Chi Hsieh Closes #8600 from viirya/usingcolumns_df. --- python/pyspark/sql/dataframe.py | 6 ++++- .../org/apache/spark/sql/DataFrame.scala | 22 ++++++++++++++++++- .../apache/spark/sql/DataFrameJoinSuite.scala | 13 +++++++++++ 3 files changed, 39 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index fb995fa3a76b..80f8d8a0eb37 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -567,7 +567,11 @@ def join(self, other, on=None, how=None): if on is None or len(on) == 0: jdf = self._jdf.join(other._jdf) elif isinstance(on[0], basestring): - jdf = self._jdf.join(other._jdf, self._jseq(on)) + if how is None: + jdf = self._jdf.join(other._jdf, self._jseq(on), "inner") + else: + assert isinstance(how, basestring), "how should be basestring" + jdf = self._jdf.join(other._jdf, self._jseq(on), how) else: assert isinstance(on[0], Column), "on should be Column or list of Column" if len(on) > 1: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 8f737c202393..ba94d77b2e60 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -484,6 +484,26 @@ class DataFrame private[sql]( * @since 1.4.0 */ def join(right: DataFrame, usingColumns: Seq[String]): DataFrame = { + join(right, usingColumns, "inner") + } + + /** + * Equi-join with another [[DataFrame]] using the given columns. + * + * Different from other join functions, the join columns will only appear once in the output, + * i.e. similar to SQL's `JOIN USING` syntax. 
+ * + * Note that if you perform a self-join using this function without aliasing the input + * [[DataFrame]]s, you will NOT be able to reference any columns after the join, since + * there is no way to disambiguate which side of the join you would like to reference. + * + * @param right Right side of the join operation. + * @param usingColumns Names of the columns to join on. This columns must exist on both sides. + * @param joinType One of: `inner`, `outer`, `left_outer`, `right_outer`, `leftsemi`. + * @group dfops + * @since 1.6.0 + */ + def join(right: DataFrame, usingColumns: Seq[String], joinType: String): DataFrame = { // Analyze the self join. The assumption is that the analyzer will disambiguate left vs right // by creating a new instance for one of the branch. val joined = sqlContext.executePlan( @@ -502,7 +522,7 @@ class DataFrame private[sql]( Join( joined.left, joined.right, - joinType = Inner, + joinType = JoinType(joinType), condition) ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala index e2716d7841d8..56ad71ea4f48 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala @@ -42,6 +42,19 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext { Row(1, 2, "1", "2") :: Row(2, 3, "2", "3") :: Row(3, 4, "3", "4") :: Nil) } + test("join - join using multiple columns and specifying join type") { + val df = Seq(1, 2, 3).map(i => (i, i + 1, i.toString)).toDF("int", "int2", "str") + val df2 = Seq(1, 2, 3).map(i => (i, i + 1, (i + 1).toString)).toDF("int", "int2", "str") + + checkAnswer( + df.join(df2, Seq("int", "str"), "left"), + Row(1, 2, "1", null) :: Row(2, 3, "2", null) :: Row(3, 4, "3", null) :: Nil) + + checkAnswer( + df.join(df2, Seq("int", "str"), "right"), + Row(null, null, null, 2) :: Row(null, null, null, 3) :: Row(null, null, null, 4) :: Nil) + } + test("join - join using self join") { val df = Seq(1, 2, 3).map(i => (i, i.toString)).toDF("int", "str") From f24316e6d928c263cbf3872edd97982059c3db22 Mon Sep 17 00:00:00 2001 From: Madhusudanan Kandasamy Date: Tue, 22 Sep 2015 00:03:48 -0700 Subject: [PATCH 0714/1824] [SPARK-10458] [SPARK CORE] Added isStopped() method in SparkContext Added isStopped() method in SparkContext Author: Madhusudanan Kandasamy Closes #8749 from kmadhugit/SPARK-10458. --- core/src/main/scala/org/apache/spark/SparkContext.scala | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index ebd8e946ee7a..967fec9f42bc 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -265,6 +265,10 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli val tachyonFolderName = externalBlockStoreFolderName def isLocal: Boolean = (master == "local" || master.startsWith("local[")) + /** + * @return true if context is stopped or in the midst of stopping. 
+ */ + def isStopped: Boolean = stopped.get() // An asynchronous listener bus for Spark events private[spark] val listenerBus = new LiveListenerBus From fd61b004877ba4d51c95cd0e08f53bffdf106395 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Tue, 22 Sep 2015 00:05:30 -0700 Subject: [PATCH 0715/1824] [Minor] style fix for previous commit f24316e --- core/src/main/scala/org/apache/spark/SparkContext.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 967fec9f42bc..bf3aeb488d59 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -265,6 +265,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli val tachyonFolderName = externalBlockStoreFolderName def isLocal: Boolean = (master == "local" || master.startsWith("local[")) + /** * @return true if context is stopped or in the midst of stopping. */ From 4da32bc0e747fefe847bffe493785d4d16069c04 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Tue, 22 Sep 2015 00:07:30 -0700 Subject: [PATCH 0716/1824] [SPARK-8567] [SQL] Increase the timeout of o.a.s.sql.hive.HiveSparkSubmitSuite to 5 minutes. https://issues.apache.org/jira/browse/SPARK-8567 Looks like "SPARK-8368: includes jars passed in through --jars" is pretty flaky now. Based on some history runs, the time spent on a successful run may be from 1.5 minutes to almost 3 minutes. Let's try to increase the timeout and see if we can fix this test. https://amplab.cs.berkeley.edu/jenkins/job/Spark-1.5-SBT/AMPLAB_JENKINS_BUILD_PROFILE=hadoop2.0,label=spark-test/385/testReport/junit/org.apache.spark.sql.hive/HiveSparkSubmitSuite/SPARK_8368__includes_jars_passed_in_through___jars/history/?start=25 Author: Yin Huai Closes #8850 from yhuai/SPARK-8567-anotherTry. --- .../scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala index 97df249bdb6d..5f1660b62d41 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala @@ -139,7 +139,7 @@ class HiveSparkSubmitSuite new ProcessOutputCapturer(process.getErrorStream, captureOutput("stderr")).start() try { - val exitCode = failAfter(180.seconds) { process.waitFor() } + val exitCode = failAfter(300.seconds) { process.waitFor() } if (exitCode != 0) { // include logs in output. Note that logging is async and may not have completed // at the time this exception is raised From f3b727c801408b1cd50e5d9463f2fe0fce654a16 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 22 Sep 2015 00:09:29 -0700 Subject: [PATCH 0717/1824] [SQL] [MINOR] map -> foreach. DataFrame.explain should use foreach to print the explain content. Author: Reynold Xin Closes #8862 from rxin/map-foreach. 
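The distinction, in a tiny standalone sketch (not Spark code): when the result is only printed, `map` needlessly materializes a collection of `Unit` values, while `foreach` runs the side effect directly.

```scala
object MapVsForeachSketch {
  def main(args: Array[String]): Unit = {
    val lines = Seq("== Physical Plan ==", "TungstenAggregate(...)")
    val wasted: Seq[Unit] = lines.map(println) // same output, plus a throwaway Seq((), ())
    lines.foreach(println)                     // side effect only, no extra result
  }
}
```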
--- sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index ba94d77b2e60..a11140b71736 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -320,9 +320,8 @@ class DataFrame private[sql]( * @since 1.3.0 */ def explain(extended: Boolean): Unit = { - ExplainCommand( - queryExecution.logical, - extended = extended).queryExecution.executedPlan.executeCollect().map { + val explain = ExplainCommand(queryExecution.logical, extended = extended) + explain.queryExecution.executedPlan.executeCollect().foreach { // scalastyle:off println r => println(r.getString(0)) // scalastyle:on println From 0bd0e5bed2176b119b3ada590993e153757ea09b Mon Sep 17 00:00:00 2001 From: Akash Mishra Date: Tue, 22 Sep 2015 00:14:27 -0700 Subject: [PATCH 0718/1824] =?UTF-8?q?[SPARK-10695]=20[DOCUMENTATION]=20[ME?= =?UTF-8?q?SOS]=20Fixing=20incorrect=20value=20informati=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …on for spark.mesos.constraints parameter. Author: Akash Mishra Closes #8816 from SleepyThread/constraint-fix. --- docs/running-on-mesos.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index 460a66f37dd6..ec5a44d79212 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -189,10 +189,10 @@ using `conf.set("spark.cores.max", "10")` (for example). You may also make use of `spark.mesos.constraints` to set attribute based constraints on mesos resource offers. By default, all resource offers will be accepted. {% highlight scala %} -conf.set("spark.mesos.constraints", "tachyon=true;us-east-1=false") +conf.set("spark.mesos.constraints", "tachyon:true;us-east-1:false") {% endhighlight %} -For example, Let's say `spark.mesos.constraints` is set to `tachyon=true;us-east-1=false`, then the resource offers will be checked to see if they meet both these constraints and only then will be accepted to start new executors. +For example, Let's say `spark.mesos.constraints` is set to `tachyon:true;us-east-1:false`, then the resource offers will be checked to see if they meet both these constraints and only then will be accepted to start new executors. # Mesos Docker Support From 7278f792a73bbcf8d68f38dc2d87cf722693c4cf Mon Sep 17 00:00:00 2001 From: Rekha Joshi Date: Tue, 22 Sep 2015 11:03:21 +0100 Subject: [PATCH 0719/1824] [SPARK-10718] [BUILD] Update License on conf files and corresponding excludes file update Update License on conf files and corresponding excludes file update Author: Rekha Joshi Author: Joshi Closes #8842 from rekhajoshm/SPARK-10718. 
--- .rat-excludes | 12 ------------ conf/docker.properties.template | 17 +++++++++++++++++ conf/fairscheduler.xml.template | 18 ++++++++++++++++++ conf/log4j.properties.template | 17 +++++++++++++++++ conf/metrics.properties.template | 17 +++++++++++++++++ conf/slaves.template | 17 +++++++++++++++++ conf/spark-defaults.conf.template | 17 +++++++++++++++++ conf/spark-env.sh.template | 17 +++++++++++++++++ .../spark/log4j-defaults-repl.properties | 17 +++++++++++++++++ .../org/apache/spark/log4j-defaults.properties | 17 +++++++++++++++++ 10 files changed, 154 insertions(+), 12 deletions(-) diff --git a/.rat-excludes b/.rat-excludes index 9165872b9fb2..08fba6d351d6 100644 --- a/.rat-excludes +++ b/.rat-excludes @@ -15,20 +15,8 @@ TAGS RELEASE control docs -docker.properties.template -fairscheduler.xml.template -spark-defaults.conf.template -log4j.properties -log4j.properties.template -metrics.properties -metrics.properties.template slaves -slaves.template -spark-env.sh spark-env.cmd -spark-env.sh.template -log4j-defaults.properties -log4j-defaults-repl.properties bootstrap-tooltip.js jquery-1.11.1.min.js d3.min.js diff --git a/conf/docker.properties.template b/conf/docker.properties.template index 26e3bfd9c5b9..55cb094b4af4 100644 --- a/conf/docker.properties.template +++ b/conf/docker.properties.template @@ -1,3 +1,20 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + spark.mesos.executor.docker.image: spark.mesos.executor.docker.volumes: /usr/local/lib:/host/usr/local/lib:ro spark.mesos.executor.home: /opt/spark diff --git a/conf/fairscheduler.xml.template b/conf/fairscheduler.xml.template index acf59e2a3598..385b2e772d2c 100644 --- a/conf/fairscheduler.xml.template +++ b/conf/fairscheduler.xml.template @@ -1,4 +1,22 @@ + + + FAIR diff --git a/conf/log4j.properties.template b/conf/log4j.properties.template index 74c5cea94403..f3046be54d7c 100644 --- a/conf/log4j.properties.template +++ b/conf/log4j.properties.template @@ -1,3 +1,20 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + # Set everything to be logged to the console log4j.rootCategory=INFO, console log4j.appender.console=org.apache.log4j.ConsoleAppender diff --git a/conf/metrics.properties.template b/conf/metrics.properties.template index 7f17bc7eea4f..d6962e0da2f3 100644 --- a/conf/metrics.properties.template +++ b/conf/metrics.properties.template @@ -1,3 +1,20 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + # syntax: [instance].sink|source.[name].[options]=[value] # This file configures Spark's internal metrics system. The metrics system is diff --git a/conf/slaves.template b/conf/slaves.template index da0a01343d20..be42a638230b 100644 --- a/conf/slaves.template +++ b/conf/slaves.template @@ -1,2 +1,19 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + # A Spark Worker will be started on each of the machines listed below. localhost \ No newline at end of file diff --git a/conf/spark-defaults.conf.template b/conf/spark-defaults.conf.template index a48dcc70e136..19cba6e71ed1 100644 --- a/conf/spark-defaults.conf.template +++ b/conf/spark-defaults.conf.template @@ -1,3 +1,20 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + # Default system properties included when running spark-submit. # This is useful for setting default environmental settings. 
diff --git a/conf/spark-env.sh.template b/conf/spark-env.sh.template index c05fe381a36a..990ded420be7 100755 --- a/conf/spark-env.sh.template +++ b/conf/spark-env.sh.template @@ -1,5 +1,22 @@ #!/usr/bin/env bash +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + # This file is sourced when running various Spark programs. # Copy it as spark-env.sh and edit that to configure Spark for your site. diff --git a/core/src/main/resources/org/apache/spark/log4j-defaults-repl.properties b/core/src/main/resources/org/apache/spark/log4j-defaults-repl.properties index 689afea64f8d..c85abc35b93b 100644 --- a/core/src/main/resources/org/apache/spark/log4j-defaults-repl.properties +++ b/core/src/main/resources/org/apache/spark/log4j-defaults-repl.properties @@ -1,3 +1,20 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + # Set everything to be logged to the console log4j.rootCategory=WARN, console log4j.appender.console=org.apache.log4j.ConsoleAppender diff --git a/core/src/main/resources/org/apache/spark/log4j-defaults.properties b/core/src/main/resources/org/apache/spark/log4j-defaults.properties index 27006e45e932..d44cc85dcbd8 100644 --- a/core/src/main/resources/org/apache/spark/log4j-defaults.properties +++ b/core/src/main/resources/org/apache/spark/log4j-defaults.properties @@ -1,3 +1,20 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# + # Set everything to be logged to the console log4j.rootCategory=INFO, console log4j.appender.console=org.apache.log4j.ConsoleAppender From 870b8a2edd44c9274c43ca0db4ef5b0998e16fd8 Mon Sep 17 00:00:00 2001 From: Meihua Wu Date: Tue, 22 Sep 2015 11:05:24 +0100 Subject: [PATCH 0720/1824] [SPARK-10706] [MLLIB] Add java wrapper for random vector rdd Add java wrapper for random vector rdd holdenk srowen Author: Meihua Wu Closes #8841 from rotationsymmetry/SPARK-10706. --- .../spark/mllib/random/RandomRDDs.scala | 42 +++++++++++++++++++ .../mllib/random/JavaRandomRDDsSuite.java | 17 ++++++++ 2 files changed, 59 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala index f8ff26b5795b..41d7c4d355f6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala @@ -855,6 +855,48 @@ object RandomRDDs { sc, numRows, numCols, numPartitionsOrDefault(sc, numPartitions), generator, seed) } + /** + * Java-friendly version of [[RandomRDDs#randomVectorRDD]]. + */ + @DeveloperApi + @Since("1.6.0") + def randomJavaVectorRDD( + jsc: JavaSparkContext, + generator: RandomDataGenerator[Double], + numRows: Long, + numCols: Int, + numPartitions: Int, + seed: Long): JavaRDD[Vector] = { + randomVectorRDD(jsc.sc, generator, numRows, numCols, numPartitions, seed).toJavaRDD() + } + + /** + * [[RandomRDDs#randomJavaVectorRDD]] with the default seed. + */ + @DeveloperApi + @Since("1.6.0") + def randomJavaVectorRDD( + jsc: JavaSparkContext, + generator: RandomDataGenerator[Double], + numRows: Long, + numCols: Int, + numPartitions: Int): JavaRDD[Vector] = { + randomVectorRDD(jsc.sc, generator, numRows, numCols, numPartitions).toJavaRDD() + } + + /** + * [[RandomRDDs#randomJavaVectorRDD]] with the default number of partitions and the default seed. + */ + @DeveloperApi + @Since("1.6.0") + def randomJavaVectorRDD( + jsc: JavaSparkContext, + generator: RandomDataGenerator[Double], + numRows: Long, + numCols: Int): JavaRDD[Vector] = { + randomVectorRDD(jsc.sc, generator, numRows, numCols).toJavaRDD() + } + /** * Returns `numPartitions` if it is positive, or `sc.defaultParallelism` otherwise. 
*/ diff --git a/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java b/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java index fce5f6712f46..5728df5aeebd 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java @@ -246,6 +246,23 @@ public void testArbitrary() { Assert.assertEquals(2, rdd.first().length()); } } + + @Test + @SuppressWarnings("unchecked") + public void testRandomVectorRDD() { + UniformGenerator generator = new UniformGenerator(); + long m = 100L; + int n = 10; + int p = 2; + long seed = 1L; + JavaRDD rdd1 = randomJavaVectorRDD(sc, generator, m, n); + JavaRDD rdd2 = randomJavaVectorRDD(sc, generator, m, n, p); + JavaRDD rdd3 = randomJavaVectorRDD(sc, generator, m, n, p, seed); + for (JavaRDD rdd: Arrays.asList(rdd1, rdd2, rdd3)) { + Assert.assertEquals(m, rdd.count()); + Assert.assertEquals(n, rdd.first().size()); + } + } } // This is just a test generator, it always returns a string of 42 From f4a3c4e34ce93bcaf29c0a35573932880a8b792b Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Tue, 22 Sep 2015 10:19:08 -0700 Subject: [PATCH 0721/1824] [SPARK-9962] [ML] Decision Tree training: prevNodeIdsForInstances.unpersist() at end of training NodeIdCache: prevNodeIdsForInstances.unpersist() needs to be called at end of training. Author: Holden Karau Closes #8541 from holdenk/SPARK-9962-decission-tree-training-prevNodeIdsForiNstances-unpersist-at-end-of-training. --- .../scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala | 8 ++++---- .../org/apache/spark/mllib/tree/impl/NodeIdCache.scala | 4 ++++ 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala index 488e8e4fb5dc..c5ad8df73fac 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala @@ -164,10 +164,10 @@ private[spark] class NodeIdCache( } } } - } - if (prevNodeIdsForInstances != null) { - // Unpersist the previous one if one exists. - prevNodeIdsForInstances.unpersist() + if (prevNodeIdsForInstances != null) { + // Unpersist the previous one if one exists. + prevNodeIdsForInstances.unpersist() + } } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala index 8f9eb24b57b5..0abed5411143 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala @@ -166,6 +166,10 @@ private[spark] class NodeIdCache( fs.delete(new Path(old.getCheckpointFile.get), true) } } + if (prevNodeIdsForInstances != null) { + // Unpersist the previous one if one exists. + prevNodeIdsForInstances.unpersist() + } } } From 7104ee0e5dc1290b8b845a0a8ddcdb1875cfd060 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Tue, 22 Sep 2015 11:00:33 -0700 Subject: [PATCH 0722/1824] [SPARK-10750] [ML] ML Param validate should print better error information Currently when you set illegal value for params of array type (such as IntArrayParam, DoubleArrayParam, StringArrayParam), it will throw IllegalArgumentException but with incomprehensible error information. 
Take ```VectorSlicer.setNames``` as an example: ```scala val vectorSlicer = new VectorSlicer().setInputCol("features").setOutputCol("result") // The value of setNames must be contain distinct elements, so the next line will throw exception. vectorSlicer.setIndices(Array.empty).setNames(Array("f1", "f4", "f1")) ``` It will throw IllegalArgumentException as: ``` vectorSlicer_b3b4d1a10f43 parameter names given invalid value [Ljava.lang.String;798256c5. java.lang.IllegalArgumentException: vectorSlicer_b3b4d1a10f43 parameter names given invalid value [Ljava.lang.String;798256c5. ``` We should distinguish the value of array type from primitive type at Param.validate(value: T), and we will get better error information. ``` vectorSlicer_3b744ea277b2 parameter names given invalid value [f1,f4,f1]. java.lang.IllegalArgumentException: vectorSlicer_3b744ea277b2 parameter names given invalid value [f1,f4,f1]. ``` Author: Yanbo Liang Closes #8863 from yanboliang/spark-10750. --- .../src/main/scala/org/apache/spark/ml/param/params.scala | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index de32b7218c27..48f6269e57e9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -65,7 +65,12 @@ class Param[T](val parent: String, val name: String, val doc: String, val isVali */ private[param] def validate(value: T): Unit = { if (!isValid(value)) { - throw new IllegalArgumentException(s"$parent parameter $name given invalid value $value.") + val valueToString = value match { + case v: Array[_] => v.mkString("[", ",", "]") + case _ => value.toString + } + throw new IllegalArgumentException( + s"$parent parameter $name given invalid value $valueToString.") } } From 2ea0f2e11b82ef4817c7e6a162ea23da7860b893 Mon Sep 17 00:00:00 2001 From: xutingjun Date: Tue, 22 Sep 2015 11:01:32 -0700 Subject: [PATCH 0723/1824] [SPARK-9585] Delete the input format caching because some input format are non thread safe If we cache the InputFormat, all tasks on the same executor will share it. Some InputFormat is thread safety, but some are not, such as HiveHBaseTableInputFormat. If tasks share a non thread safe InputFormat, unexpected error may be occurs. To avoid it, I think we should delete the input format caching. Author: xutingjun Author: meiyoula <1039320815@qq.com> Author: Xutingjun Closes #7918 from XuTingjun/cached_inputFormat. --- core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala | 6 ------ 1 file changed, 6 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index 8f2655d63b79..77b57132b9f1 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -182,17 +182,11 @@ class HadoopRDD[K, V]( } protected def getInputFormat(conf: JobConf): InputFormat[K, V] = { - if (HadoopRDD.containsCachedMetadata(inputFormatCacheKey)) { - return HadoopRDD.getCachedMetadata(inputFormatCacheKey).asInstanceOf[InputFormat[K, V]] - } - // Once an InputFormat for this RDD is created, cache it so that only one reflection call is - // done in each local process. 
val newInputFormat = ReflectionUtils.newInstance(inputFormatClass.asInstanceOf[Class[_]], conf) .asInstanceOf[InputFormat[K, V]] if (newInputFormat.isInstanceOf[Configurable]) { newInputFormat.asInstanceOf[Configurable].setConf(conf) } - HadoopRDD.putCachedMetadata(inputFormatCacheKey, newInputFormat) newInputFormat } From 22d40159e60dd27a428e4051ef607292cbffbff3 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 22 Sep 2015 11:07:01 -0700 Subject: [PATCH 0724/1824] [SPARK-10593] [SQL] fix resolve output of Generate The output of Generate should not be resolved as Reference. Author: Davies Liu Closes #8755 from davies/view. --- .../spark/sql/catalyst/analysis/Analyzer.scala | 16 ++++++++++++++++ .../spark/sql/catalyst/plans/QueryPlan.scala | 1 - .../catalyst/plans/logical/basicOperators.scala | 2 +- .../spark/sql/hive/execution/SQLQuerySuite.scala | 14 ++++++++++++++ 4 files changed, 31 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 02f34cbf58ad..bf72d47ce1ea 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -378,6 +378,22 @@ class Analyzer( val newOrdering = resolveSortOrders(ordering, child, throws = false) Sort(newOrdering, global, child) + // A special case for Generate, because the output of Generate should not be resolved by + // ResolveReferences. Attributes in the output will be resolved by ResolveGenerate. + case g @ Generate(generator, join, outer, qualifier, output, child) + if child.resolved && !generator.resolved => + val newG = generator transformUp { + case u @ UnresolvedAttribute(nameParts) => + withPosition(u) { child.resolve(nameParts, resolver).getOrElse(u) } + case UnresolvedExtractValue(child, fieldExpr) => + ExtractValue(child, fieldExpr, resolver) + } + if (newG.fastEquals(generator)) { + g + } else { + Generate(newG.asInstanceOf[Generator], join, outer, qualifier, output, child) + } + case q: LogicalPlan => logTrace(s"Attempting to resolve ${q.simpleString}") q transformExpressionsUp { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index 55286f9f2fc5..0ec9f0857108 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.catalyst.plans import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression, VirtualColumn} -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.types.{DataType, StructType} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 722f69cdca82..ae9482c10f12 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -68,7 +68,7 @@ case class Generate( generator.resolved && childrenResolved && generator.elementTypes.length == generatorOutput.length && - 
!generatorOutput.exists(!_.resolved) + generatorOutput.forall(_.resolved) } // we don't want the gOutput to be taken as part of the expressions diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 8126d0233521..bb02473dd17c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -1170,4 +1170,18 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { checkAnswer(sqlContext.table("`db.t`"), df) } } + + test("SPARK-10593 same column names in lateral view") { + val df = sqlContext.sql( + """ + |select + |insideLayer2.json as a2 + |from (select '{"layer1": {"layer2": "text inside layer 2"}}' json) test + |lateral view json_tuple(json, 'layer1') insideLayer1 as json + |lateral view json_tuple(insideLayer1.json, 'layer2') insideLayer2 as json + """.stripMargin + ) + + checkAnswer(df, Row("text inside layer 2") :: Nil) + } } From 1ca5e2e0b8d8d406c02a74c76ae9d7fc5637c8d3 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 22 Sep 2015 11:50:22 -0700 Subject: [PATCH 0725/1824] [SPARK-10704] Rename HashShuffleReader to BlockStoreShuffleReader The current shuffle code has an interface named ShuffleReader with only one implementation, HashShuffleReader. This naming is confusing, since the same read path code is used for both sort- and hash-based shuffle. This patch addresses this by renaming HashShuffleReader to BlockStoreShuffleReader. Author: Josh Rosen Closes #8825 from JoshRosen/shuffle-reader-cleanup. --- ...shShuffleReader.scala => BlockStoreShuffleReader.scala} | 5 ++--- .../org/apache/spark/shuffle/hash/HashShuffleManager.scala | 2 +- .../org/apache/spark/shuffle/sort/SortShuffleManager.scala | 3 +-- ...eaderSuite.scala => BlockStoreShuffleReaderSuite.scala} | 7 +++---- 4 files changed, 7 insertions(+), 10 deletions(-) rename core/src/main/scala/org/apache/spark/shuffle/{hash/HashShuffleReader.scala => BlockStoreShuffleReader.scala} (97%) rename core/src/test/scala/org/apache/spark/shuffle/{hash/HashShuffleReaderSuite.scala => BlockStoreShuffleReaderSuite.scala} (96%) diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala similarity index 97% rename from core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala rename to core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala index 0c8f08f0f3b1..6dc9a16e5853 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala @@ -15,16 +15,15 @@ * limitations under the License. 
*/ -package org.apache.spark.shuffle.hash +package org.apache.spark.shuffle import org.apache.spark._ import org.apache.spark.serializer.Serializer -import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReader} import org.apache.spark.storage.{BlockManager, ShuffleBlockFetcherIterator} import org.apache.spark.util.CompletionIterator import org.apache.spark.util.collection.ExternalSorter -private[spark] class HashShuffleReader[K, C]( +private[spark] class BlockStoreShuffleReader[K, C]( handle: BaseShuffleHandle[K, _, C], startPartition: Int, endPartition: Int, diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala index 0b46634b8b46..d2e2fc4c110a 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala @@ -51,7 +51,7 @@ private[spark] class HashShuffleManager(conf: SparkConf) extends ShuffleManager startPartition: Int, endPartition: Int, context: TaskContext): ShuffleReader[K, C] = { - new HashShuffleReader( + new BlockStoreShuffleReader( handle.asInstanceOf[BaseShuffleHandle[K, _, C]], startPartition, endPartition, context) } diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala index 476cc1f303da..9df4e551669c 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala @@ -21,7 +21,6 @@ import java.util.concurrent.ConcurrentHashMap import org.apache.spark.{Logging, SparkConf, TaskContext, ShuffleDependency} import org.apache.spark.shuffle._ -import org.apache.spark.shuffle.hash.HashShuffleReader private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager with Logging { @@ -54,7 +53,7 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager endPartition: Int, context: TaskContext): ShuffleReader[K, C] = { // We currently use the same block store shuffle fetcher as the hash-based shuffle. - new HashShuffleReader( + new BlockStoreShuffleReader( handle.asInstanceOf[BaseShuffleHandle[K, _, C]], startPartition, endPartition, context) } diff --git a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala similarity index 96% rename from core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala rename to core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala index 05b3afef5b83..a5eafb1b5529 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala @@ -15,7 +15,7 @@ * limitations under the License. 
*/ -package org.apache.spark.shuffle.hash +package org.apache.spark.shuffle import java.io.{ByteArrayOutputStream, InputStream} import java.nio.ByteBuffer @@ -28,7 +28,6 @@ import org.mockito.stubbing.Answer import org.apache.spark._ import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} import org.apache.spark.serializer.JavaSerializer -import org.apache.spark.shuffle.BaseShuffleHandle import org.apache.spark.storage.{BlockManager, BlockManagerId, ShuffleBlockId} /** @@ -56,7 +55,7 @@ class RecordingManagedBuffer(underlyingBuffer: NioManagedBuffer) extends Managed } } -class HashShuffleReaderSuite extends SparkFunSuite with LocalSparkContext { +class BlockStoreShuffleReaderSuite extends SparkFunSuite with LocalSparkContext { /** * This test makes sure that, when data is read from a HashShuffleReader, the underlying @@ -134,7 +133,7 @@ class HashShuffleReaderSuite extends SparkFunSuite with LocalSparkContext { new BaseShuffleHandle(shuffleId, numMaps, dependency) } - val shuffleReader = new HashShuffleReader( + val shuffleReader = new BlockStoreShuffleReader( shuffleHandle, reduceId, reduceId + 1, From 5017c685f484ec256101d1d33bad11d9e0c0f641 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 22 Sep 2015 12:14:15 -0700 Subject: [PATCH 0726/1824] [SPARK-10740] [SQL] handle nondeterministic expressions correctly for set operations https://issues.apache.org/jira/browse/SPARK-10740 Author: Wenchen Fan Closes #8858 from cloud-fan/non-deter. --- .../sql/catalyst/optimizer/Optimizer.scala | 69 ++++++++++++++----- .../optimizer/SetOperationPushDownSuite.scala | 3 +- .../org/apache/spark/sql/DataFrameSuite.scala | 41 +++++++++++ 3 files changed, 93 insertions(+), 20 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 324f40a051c3..63602eaa8ccd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -95,14 +95,14 @@ object SamplePushDown extends Rule[LogicalPlan] { * Intersect: * It is not safe to pushdown Projections through it because we need to get the * intersect of rows by comparing the entire rows. It is fine to pushdown Filters - * because we will not have non-deterministic expressions. + * with deterministic condition. * * Except: * It is not safe to pushdown Projections through it because we need to get the * intersect of rows by comparing the entire rows. It is fine to pushdown Filters - * because we will not have non-deterministic expressions. + * with deterministic condition. */ -object SetOperationPushDown extends Rule[LogicalPlan] { +object SetOperationPushDown extends Rule[LogicalPlan] with PredicateHelper { /** * Maps Attributes from the left side to the corresponding Attribute on the right side. @@ -129,34 +129,65 @@ object SetOperationPushDown extends Rule[LogicalPlan] { result.asInstanceOf[A] } + /** + * Splits the condition expression into small conditions by `And`, and partition them by + * deterministic, and finally recombine them by `And`. It returns an expression containing + * all deterministic expressions (the first field of the returned Tuple2) and an expression + * containing all non-deterministic expressions (the second field of the returned Tuple2). 
+ */ + private def partitionByDeterministic(condition: Expression): (Expression, Expression) = { + val andConditions = splitConjunctivePredicates(condition) + andConditions.partition(_.deterministic) match { + case (deterministic, nondeterministic) => + deterministic.reduceOption(And).getOrElse(Literal(true)) -> + nondeterministic.reduceOption(And).getOrElse(Literal(true)) + } + } + def apply(plan: LogicalPlan): LogicalPlan = plan transform { // Push down filter into union case Filter(condition, u @ Union(left, right)) => + val (deterministic, nondeterministic) = partitionByDeterministic(condition) val rewrites = buildRewrites(u) - Union( - Filter(condition, left), - Filter(pushToRight(condition, rewrites), right)) - - // Push down projection through UNION ALL - case Project(projectList, u @ Union(left, right)) => - val rewrites = buildRewrites(u) - Union( - Project(projectList, left), - Project(projectList.map(pushToRight(_, rewrites)), right)) + Filter(nondeterministic, + Union( + Filter(deterministic, left), + Filter(pushToRight(deterministic, rewrites), right) + ) + ) + + // Push down deterministic projection through UNION ALL + case p @ Project(projectList, u @ Union(left, right)) => + if (projectList.forall(_.deterministic)) { + val rewrites = buildRewrites(u) + Union( + Project(projectList, left), + Project(projectList.map(pushToRight(_, rewrites)), right)) + } else { + p + } // Push down filter through INTERSECT case Filter(condition, i @ Intersect(left, right)) => + val (deterministic, nondeterministic) = partitionByDeterministic(condition) val rewrites = buildRewrites(i) - Intersect( - Filter(condition, left), - Filter(pushToRight(condition, rewrites), right)) + Filter(nondeterministic, + Intersect( + Filter(deterministic, left), + Filter(pushToRight(deterministic, rewrites), right) + ) + ) // Push down filter through EXCEPT case Filter(condition, e @ Except(left, right)) => + val (deterministic, nondeterministic) = partitionByDeterministic(condition) val rewrites = buildRewrites(e) - Except( - Filter(condition, left), - Filter(pushToRight(condition, rewrites), right)) + Filter(nondeterministic, + Except( + Filter(deterministic, left), + Filter(pushToRight(deterministic, rewrites), right) + ) + ) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationPushDownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationPushDownSuite.scala index 3fca47a023dc..1595ad932742 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationPushDownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationPushDownSuite.scala @@ -30,7 +30,8 @@ class SetOperationPushDownSuite extends PlanTest { Batch("Subqueries", Once, EliminateSubQueries) :: Batch("Union Pushdown", Once, - SetOperationPushDown) :: Nil + SetOperationPushDown, + SimplifyFilters) :: Nil } val testRelation = LocalRelation('a.int, 'b.int, 'c.int) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 1370713975f2..d919877746c7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -916,4 +916,45 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { assert(intersect.count() === 30) assert(except.count() === 70) } + + test("SPARK-10740: handle nondeterministic expressions correctly 
for set operations") { + val df1 = (1 to 20).map(Tuple1.apply).toDF("i") + val df2 = (1 to 10).map(Tuple1.apply).toDF("i") + + // When generating expected results at here, we need to follow the implementation of + // Rand expression. + def expected(df: DataFrame): Seq[Row] = { + df.rdd.collectPartitions().zipWithIndex.flatMap { + case (data, index) => + val rng = new org.apache.spark.util.random.XORShiftRandom(7 + index) + data.filter(_.getInt(0) < rng.nextDouble() * 10) + } + } + + val union = df1.unionAll(df2) + checkAnswer( + union.filter('i < rand(7) * 10), + expected(union) + ) + checkAnswer( + union.select(rand(7)), + union.rdd.collectPartitions().zipWithIndex.flatMap { + case (data, index) => + val rng = new org.apache.spark.util.random.XORShiftRandom(7 + index) + data.map(_ => rng.nextDouble()).map(i => Row(i)) + } + ) + + val intersect = df1.intersect(df2) + checkAnswer( + intersect.filter('i < rand(7) * 10), + expected(intersect) + ) + + val except = df1.except(df2) + checkAnswer( + except.filter('i < rand(7) * 10), + expected(except) + ) + } } From 2204cdb28483b249616068085d4e88554fe6acef Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Tue, 22 Sep 2015 13:29:39 -0700 Subject: [PATCH 0727/1824] [SPARK-10672] [SQL] Do not fail when we cannot save the metadata of a data source table in a hive compatible way https://issues.apache.org/jira/browse/SPARK-10672 With changes in this PR, we will fallback to same the metadata of a table in Spark SQL specific way if we fail to save it in a hive compatible way (Hive throws an exception because of its internal restrictions, e.g. binary and decimal types cannot be saved to parquet if the metastore is running Hive 0.13). I manually tested the fix with the following test in `DataSourceWithHiveMetastoreCatalogSuite` (`spark.sql.hive.metastore.version=0.13` and `spark.sql.hive.metastore.jars`=`maven`). ``` test(s"fail to save metadata of a parquet table in hive 0.13") { withTempPath { dir => withTable("t") { val path = dir.getCanonicalPath sql( s"""CREATE TABLE t USING $provider |OPTIONS (path '$path') |AS SELECT 1 AS d1, cast("val_1" as binary) AS d2 """.stripMargin) sql( s"""describe formatted t """.stripMargin).collect.foreach(println) sqlContext.table("t").show } } } } ``` Without this fix, we will fail with the following error. 
``` org.apache.hadoop.hive.ql.metadata.HiveException: java.lang.UnsupportedOperationException: Unknown field type: binary at org.apache.hadoop.hive.ql.metadata.Hive.createTable(Hive.java:619) at org.apache.hadoop.hive.ql.metadata.Hive.createTable(Hive.java:576) at org.apache.spark.sql.hive.client.ClientWrapper$$anonfun$createTable$1.apply$mcV$sp(ClientWrapper.scala:359) at org.apache.spark.sql.hive.client.ClientWrapper$$anonfun$createTable$1.apply(ClientWrapper.scala:357) at org.apache.spark.sql.hive.client.ClientWrapper$$anonfun$createTable$1.apply(ClientWrapper.scala:357) at org.apache.spark.sql.hive.client.ClientWrapper$$anonfun$withHiveState$1.apply(ClientWrapper.scala:256) at org.apache.spark.sql.hive.client.ClientWrapper.retryLocked(ClientWrapper.scala:211) at org.apache.spark.sql.hive.client.ClientWrapper.withHiveState(ClientWrapper.scala:248) at org.apache.spark.sql.hive.client.ClientWrapper.createTable(ClientWrapper.scala:357) at org.apache.spark.sql.hive.HiveMetastoreCatalog.createDataSourceTable(HiveMetastoreCatalog.scala:358) at org.apache.spark.sql.hive.execution.CreateMetastoreDataSourceAsSelect.run(commands.scala:285) at org.apache.spark.sql.execution.ExecutedCommand.sideEffectResult$lzycompute(commands.scala:57) at org.apache.spark.sql.execution.ExecutedCommand.sideEffectResult(commands.scala:57) at org.apache.spark.sql.execution.ExecutedCommand.doExecute(commands.scala:69) at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$5.apply(SparkPlan.scala:140) at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$5.apply(SparkPlan.scala:138) at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:150) at org.apache.spark.sql.execution.SparkPlan.execute(SparkPlan.scala:138) at org.apache.spark.sql.execution.QueryExecution.toRdd$lzycompute(QueryExecution.scala:58) at org.apache.spark.sql.execution.QueryExecution.toRdd(QueryExecution.scala:58) at org.apache.spark.sql.DataFrame.(DataFrame.scala:144) at org.apache.spark.sql.DataFrame.(DataFrame.scala:129) at org.apache.spark.sql.DataFrame$.apply(DataFrame.scala:51) at org.apache.spark.sql.SQLContext.sql(SQLContext.scala:725) at org.apache.spark.sql.test.SQLTestUtils$$anonfun$sql$1.apply(SQLTestUtils.scala:56) at org.apache.spark.sql.test.SQLTestUtils$$anonfun$sql$1.apply(SQLTestUtils.scala:56) at org.apache.spark.sql.hive.DataSourceWithHiveMetastoreCatalogSuite$$anonfun$4$$anonfun$apply$1$$anonfun$apply$mcV$sp$2$$anonfun$apply$2.apply$mcV$sp(HiveMetastoreCatalogSuite.scala:165) at org.apache.spark.sql.test.SQLTestUtils$class.withTable(SQLTestUtils.scala:150) at org.apache.spark.sql.hive.DataSourceWithHiveMetastoreCatalogSuite.withTable(HiveMetastoreCatalogSuite.scala:52) at org.apache.spark.sql.hive.DataSourceWithHiveMetastoreCatalogSuite$$anonfun$4$$anonfun$apply$1$$anonfun$apply$mcV$sp$2.apply(HiveMetastoreCatalogSuite.scala:162) at org.apache.spark.sql.hive.DataSourceWithHiveMetastoreCatalogSuite$$anonfun$4$$anonfun$apply$1$$anonfun$apply$mcV$sp$2.apply(HiveMetastoreCatalogSuite.scala:161) at org.apache.spark.sql.test.SQLTestUtils$class.withTempPath(SQLTestUtils.scala:125) at org.apache.spark.sql.hive.DataSourceWithHiveMetastoreCatalogSuite.withTempPath(HiveMetastoreCatalogSuite.scala:52) at org.apache.spark.sql.hive.DataSourceWithHiveMetastoreCatalogSuite$$anonfun$4$$anonfun$apply$1.apply$mcV$sp(HiveMetastoreCatalogSuite.scala:161) at org.apache.spark.sql.hive.DataSourceWithHiveMetastoreCatalogSuite$$anonfun$4$$anonfun$apply$1.apply(HiveMetastoreCatalogSuite.scala:161) at 
org.apache.spark.sql.hive.DataSourceWithHiveMetastoreCatalogSuite$$anonfun$4$$anonfun$apply$1.apply(HiveMetastoreCatalogSuite.scala:161) at org.scalatest.Transformer$$anonfun$apply$1.apply$mcV$sp(Transformer.scala:22) at org.scalatest.OutcomeOf$class.outcomeOf(OutcomeOf.scala:85) at org.scalatest.OutcomeOf$.outcomeOf(OutcomeOf.scala:104) at org.scalatest.Transformer.apply(Transformer.scala:22) at org.scalatest.Transformer.apply(Transformer.scala:20) at org.scalatest.FunSuiteLike$$anon$1.apply(FunSuiteLike.scala:166) at org.apache.spark.SparkFunSuite.withFixture(SparkFunSuite.scala:42) at org.scalatest.FunSuiteLike$class.invokeWithFixture$1(FunSuiteLike.scala:163) at org.scalatest.FunSuiteLike$$anonfun$runTest$1.apply(FunSuiteLike.scala:175) at org.scalatest.FunSuiteLike$$anonfun$runTest$1.apply(FunSuiteLike.scala:175) at org.scalatest.SuperEngine.runTestImpl(Engine.scala:306) at org.scalatest.FunSuiteLike$class.runTest(FunSuiteLike.scala:175) at org.scalatest.FunSuite.runTest(FunSuite.scala:1555) at org.scalatest.FunSuiteLike$$anonfun$runTests$1.apply(FunSuiteLike.scala:208) at org.scalatest.FunSuiteLike$$anonfun$runTests$1.apply(FunSuiteLike.scala:208) at org.scalatest.SuperEngine$$anonfun$traverseSubNodes$1$1.apply(Engine.scala:413) at org.scalatest.SuperEngine$$anonfun$traverseSubNodes$1$1.apply(Engine.scala:401) at scala.collection.immutable.List.foreach(List.scala:318) at org.scalatest.SuperEngine.traverseSubNodes$1(Engine.scala:401) at org.scalatest.SuperEngine.org$scalatest$SuperEngine$$runTestsInBranch(Engine.scala:396) at org.scalatest.SuperEngine.runTestsImpl(Engine.scala:483) at org.scalatest.FunSuiteLike$class.runTests(FunSuiteLike.scala:208) at org.scalatest.FunSuite.runTests(FunSuite.scala:1555) at org.scalatest.Suite$class.run(Suite.scala:1424) at org.scalatest.FunSuite.org$scalatest$FunSuiteLike$$super$run(FunSuite.scala:1555) at org.scalatest.FunSuiteLike$$anonfun$run$1.apply(FunSuiteLike.scala:212) at org.scalatest.FunSuiteLike$$anonfun$run$1.apply(FunSuiteLike.scala:212) at org.scalatest.SuperEngine.runImpl(Engine.scala:545) at org.scalatest.FunSuiteLike$class.run(FunSuiteLike.scala:212) at org.apache.spark.sql.hive.DataSourceWithHiveMetastoreCatalogSuite.org$scalatest$BeforeAndAfterAll$$super$run(HiveMetastoreCatalogSuite.scala:52) at org.scalatest.BeforeAndAfterAll$class.liftedTree1$1(BeforeAndAfterAll.scala:257) at org.scalatest.BeforeAndAfterAll$class.run(BeforeAndAfterAll.scala:256) at org.apache.spark.sql.hive.DataSourceWithHiveMetastoreCatalogSuite.run(HiveMetastoreCatalogSuite.scala:52) at org.scalatest.tools.Framework.org$scalatest$tools$Framework$$runSuite(Framework.scala:462) at org.scalatest.tools.Framework$ScalaTestTask.execute(Framework.scala:671) at sbt.ForkMain$Run$2.call(ForkMain.java:294) at sbt.ForkMain$Run$2.call(ForkMain.java:284) at java.util.concurrent.FutureTask.run(FutureTask.java:262) at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1145) at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:615) at java.lang.Thread.run(Thread.java:745) Caused by: java.lang.UnsupportedOperationException: Unknown field type: binary at org.apache.hadoop.hive.ql.io.parquet.serde.ArrayWritableObjectInspector.getObjectInspector(ArrayWritableObjectInspector.java:108) at org.apache.hadoop.hive.ql.io.parquet.serde.ArrayWritableObjectInspector.(ArrayWritableObjectInspector.java:60) at org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe.initialize(ParquetHiveSerDe.java:113) at 
org.apache.hadoop.hive.metastore.MetaStoreUtils.getDeserializer(MetaStoreUtils.java:339) at org.apache.hadoop.hive.ql.metadata.Table.getDeserializerFromMetaStore(Table.java:288) at org.apache.hadoop.hive.ql.metadata.Table.checkValidity(Table.java:194) at org.apache.hadoop.hive.ql.metadata.Hive.createTable(Hive.java:597) ... 76 more ``` Author: Yin Huai Closes #8824 from yhuai/datasourceMetadata. --- .../spark/sql/hive/HiveMetastoreCatalog.scala | 101 +++++++++--------- 1 file changed, 50 insertions(+), 51 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 0c1b41e3377e..012634cb5aeb 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -309,69 +309,68 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive } // TODO: Support persisting partitioned data source relations in Hive compatible format - val hiveTable = (maybeSerDe, dataSource.relation) match { + val qualifiedTableName = tableIdent.quotedString + val (hiveCompitiableTable, logMessage) = (maybeSerDe, dataSource.relation) match { case (Some(serde), relation: HadoopFsRelation) - if relation.paths.length == 1 && relation.partitionColumns.isEmpty => - // Hive ParquetSerDe doesn't support decimal type until 1.2.0. - val isParquetSerDe = serde.inputFormat.exists(_.toLowerCase.contains("parquet")) - val hasDecimalFields = relation.schema.existsRecursively(_.isInstanceOf[DecimalType]) - - val hiveParquetSupportsDecimal = client.version match { - case org.apache.spark.sql.hive.client.hive.v1_2 => true - case _ => false - } - - if (isParquetSerDe && !hiveParquetSupportsDecimal && hasDecimalFields) { - // If Hive version is below 1.2.0, we cannot save Hive compatible schema to - // metastore when the file format is Parquet and the schema has DecimalType. - logWarning { - "Persisting Parquet relation with decimal field(s) into Hive metastore in Spark SQL " + - "specific format, which is NOT compatible with Hive. Because ParquetHiveSerDe in " + - s"Hive ${client.version.fullVersion} doesn't support decimal type. See HIVE-6384." - } - newSparkSQLSpecificMetastoreTable() - } else { - logInfo { - "Persisting data source relation with a single input path into Hive metastore in " + - s"Hive compatible format. Input path: ${relation.paths.head}" - } - newHiveCompatibleMetastoreTable(relation, serde) - } + if relation.paths.length == 1 && relation.partitionColumns.isEmpty => + val hiveTable = newHiveCompatibleMetastoreTable(relation, serde) + val message = + s"Persisting data source relation $qualifiedTableName with a single input path " + + s"into Hive metastore in Hive compatible format. Input path: ${relation.paths.head}." + (Some(hiveTable), message) case (Some(serde), relation: HadoopFsRelation) if relation.partitionColumns.nonEmpty => - logWarning { - "Persisting partitioned data source relation into Hive metastore in " + - s"Spark SQL specific format, which is NOT compatible with Hive. Input path(s): " + - relation.paths.mkString("\n", "\n", "") - } - newSparkSQLSpecificMetastoreTable() + val message = + s"Persisting partitioned data source relation $qualifiedTableName into " + + "Hive metastore in Spark SQL specific format, which is NOT compatible with Hive. 
" + + "Input path(s): " + relation.paths.mkString("\n", "\n", "") + (None, message) case (Some(serde), relation: HadoopFsRelation) => - logWarning { - "Persisting data source relation with multiple input paths into Hive metastore in " + - s"Spark SQL specific format, which is NOT compatible with Hive. Input paths: " + - relation.paths.mkString("\n", "\n", "") - } - newSparkSQLSpecificMetastoreTable() + val message = + s"Persisting data source relation $qualifiedTableName with multiple input paths into " + + "Hive metastore in Spark SQL specific format, which is NOT compatible with Hive. " + + s"Input paths: " + relation.paths.mkString("\n", "\n", "") + (None, message) case (Some(serde), _) => - logWarning { - s"Data source relation is not a ${classOf[HadoopFsRelation].getSimpleName}. " + - "Persisting it into Hive metastore in Spark SQL specific format, " + - "which is NOT compatible with Hive." - } - newSparkSQLSpecificMetastoreTable() + val message = + s"Data source relation $qualifiedTableName is not a " + + s"${classOf[HadoopFsRelation].getSimpleName}. Persisting it into Hive metastore " + + "in Spark SQL specific format, which is NOT compatible with Hive." + (None, message) case _ => - logWarning { + val message = s"Couldn't find corresponding Hive SerDe for data source provider $provider. " + - "Persisting data source relation into Hive metastore in Spark SQL specific format, " + - "which is NOT compatible with Hive." - } - newSparkSQLSpecificMetastoreTable() + s"Persisting data source relation $qualifiedTableName into Hive metastore in " + + s"Spark SQL specific format, which is NOT compatible with Hive." + (None, message) } - client.createTable(hiveTable) + (hiveCompitiableTable, logMessage) match { + case (Some(table), message) => + // We first try to save the metadata of the table in a Hive compatiable way. + // If Hive throws an error, we fall back to save its metadata in the Spark SQL + // specific way. + try { + logInfo(message) + client.createTable(table) + } catch { + case throwable: Throwable => + val warningMessage = + s"Could not persist $qualifiedTableName in a Hive compatible way. Persisting " + + s"it into Hive metastore in Spark SQL specific format." + logWarning(warningMessage, throwable) + val sparkSqlSpecificTable = newSparkSQLSpecificMetastoreTable() + client.createTable(sparkSqlSpecificTable) + } + + case (None, message) => + logWarning(message) + val hiveTable = newSparkSQLSpecificMetastoreTable() + client.createTable(hiveTable) + } } def hiveDefaultTableFilePath(tableName: String): String = { From 5aea987c904b281d7952ad8db40a32561b4ec5cf Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Tue, 22 Sep 2015 13:31:35 -0700 Subject: [PATCH 0728/1824] [SPARK-10737] [SQL] When using UnsafeRows, SortMergeJoin may return wrong results https://issues.apache.org/jira/browse/SPARK-10737 Author: Yin Huai Closes #8854 from yhuai/SMJBug. 
--- .../codegen/GenerateProjection.scala | 2 ++ .../apache/spark/sql/execution/Window.scala | 9 ++++-- .../sql/execution/joins/SortMergeJoin.scala | 25 +++++++++++++++-- .../org/apache/spark/sql/SQLQuerySuite.scala | 28 +++++++++++++++++++ 4 files changed, 59 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala index 2164ddf03d1b..75524b568d68 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala @@ -171,6 +171,8 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { @Override public Object apply(Object r) { + // GenerateProjection does not work with UnsafeRows. + assert(!(r instanceof ${classOf[UnsafeRow].getName})); return new SpecificRow((InternalRow) r); } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala index 0269d6d4b7a1..f8929530c503 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala @@ -253,7 +253,11 @@ case class Window( // Get all relevant projections. val result = createResultProjection(unboundExpressions) - val grouping = newProjection(partitionSpec, child.output) + val grouping = if (child.outputsUnsafeRows) { + UnsafeProjection.create(partitionSpec, child.output) + } else { + newProjection(partitionSpec, child.output) + } // Manage the stream and the grouping. var nextRow: InternalRow = EmptyRow @@ -277,7 +281,8 @@ case class Window( val numFrames = frames.length private[this] def fetchNextPartition() { // Collect all the rows in the current partition. - val currentGroup = nextGroup + // Before we start to fetch new input rows, make a copy of nextGroup. + val currentGroup = nextGroup.copy() rows = new CompactBuffer while (nextRowAvailable && nextGroup == currentGroup) { rows += nextRow.copy() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala index 906f20d2a728..70a1af6a7063 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala @@ -56,9 +56,6 @@ case class SortMergeJoin( override def requiredChildOrdering: Seq[Seq[SortOrder]] = requiredOrders(leftKeys) :: requiredOrders(rightKeys) :: Nil - @transient protected lazy val leftKeyGenerator = newProjection(leftKeys, left.output) - @transient protected lazy val rightKeyGenerator = newProjection(rightKeys, right.output) - protected[this] def isUnsafeMode: Boolean = { (codegenEnabled && unsafeEnabled && UnsafeProjection.canSupport(leftKeys) @@ -82,6 +79,28 @@ case class SortMergeJoin( left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) => new RowIterator { + // The projection used to extract keys from input rows of the left child. + private[this] val leftKeyGenerator = { + if (isUnsafeMode) { + // It is very important to use UnsafeProjection if input rows are UnsafeRows. + // Otherwise, GenerateProjection will cause wrong results. 
+ UnsafeProjection.create(leftKeys, left.output) + } else { + newProjection(leftKeys, left.output) + } + } + + // The projection used to extract keys from input rows of the right child. + private[this] val rightKeyGenerator = { + if (isUnsafeMode) { + // It is very important to use UnsafeProjection if input rows are UnsafeRows. + // Otherwise, GenerateProjection will cause wrong results. + UnsafeProjection.create(rightKeys, right.output) + } else { + newProjection(rightKeys, right.output) + } + } + // An ordering that can be used to compare keys from both sides. private[this] val keyOrdering = newNaturalAscendingOrdering(leftKeys.map(_.dataType)) private[this] var currentLeftRow: InternalRow = _ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 05b4127cbcaf..eca6f1073889 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -1781,4 +1781,32 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { Seq(Row(1), Row(1))) } } + + test("SortMergeJoin returns wrong results when using UnsafeRows") { + // This test is for the fix of https://issues.apache.org/jira/browse/SPARK-10737. + // This bug will be triggered when Tungsten is enabled and there are multiple + // SortMergeJoin operators executed in the same task. + val confs = + SQLConf.SORTMERGE_JOIN.key -> "true" :: + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1" :: + SQLConf.TUNGSTEN_ENABLED.key -> "true" :: Nil + withSQLConf(confs: _*) { + val df1 = (1 to 50).map(i => (s"str_$i", i)).toDF("i", "j") + val df2 = + df1 + .join(df1.select(df1("i")), "i") + .select(df1("i"), df1("j")) + + val df3 = df2.withColumnRenamed("i", "i1").withColumnRenamed("j", "j1") + val df4 = + df2 + .join(df3, df2("i") === df3("i1")) + .withColumn("diff", $"j" - $"j1") + .select(df2("i"), df2("j"), $"diff") + + checkAnswer( + df4, + df1.withColumn("diff", lit(0))) + } + } } From a96ba40f7ee1352288ea676d8844e1c8174202eb Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 22 Sep 2015 14:11:46 -0700 Subject: [PATCH 0729/1824] [SPARK-10714] [SPARK-8632] [SPARK-10685] [SQL] Refactor Python UDF handling This patch refactors Python UDF handling: 1. Extract the per-partition Python UDF calling logic from PythonRDD into a PythonRunner. PythonRunner itself expects iterator as input/output, and thus has no dependency on RDD. This way, we can use PythonRunner directly in a mapPartitions call, or in the future in an environment without RDDs. 2. Use PythonRunner in Spark SQL's BatchPythonEvaluation. 3. Updated BatchPythonEvaluation to only use its input once, rather than twice. This should fix Python UDF performance regression in Spark 1.5. There are a number of small cleanups I wanted to do when I looked at the code, but I kept most of those out so the diff looks small. This basically implements the approach in https://github.com/apache/spark/pull/8833, but with some code moving around so the correctness doesn't depend on the inner workings of Spark serialization and task execution. Author: Reynold Xin Closes #8835 from rxin/python-iter-refactor. 
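To make the shape of the new abstraction concrete: the key property of `PythonRunner` is that `compute` maps an input iterator to an output iterator, so it composes directly with `mapPartitions` and has no dependency on how the input RDD was produced. A minimal usage sketch follows; the RDD `input` and the already-constructed `runner` are placeholders, and the `compute` signature is the one introduced in the diff below:

```scala
// Sketch: an iterator-in/iterator-out runner plugs straight into mapPartitions.
val output = input.mapPartitions { iter =>
  val context = org.apache.spark.TaskContext.get()
  // `runner` stands for a fully constructed PythonRunner (command, env vars, includes, etc.).
  runner.compute(iter, context.partitionId(), context)
}
```

This mirrors how `BatchPythonEvaluation` drives the runner in the SQL changes further down in this patch.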
--- .../apache/spark/api/python/PythonRDD.scala | 54 ++++++++++--- .../spark/sql/execution/pythonUDFs.scala | 80 +++++++++++-------- 2 files changed, 89 insertions(+), 45 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 69da180593bb..3788d1829758 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -24,6 +24,7 @@ import java.util.{Collections, ArrayList => JArrayList, List => JList, Map => JM import scala.collection.JavaConverters._ import scala.collection.mutable import scala.language.existentials +import scala.util.control.NonFatal import com.google.common.base.Charsets.UTF_8 import org.apache.hadoop.conf.Configuration @@ -38,7 +39,6 @@ import org.apache.spark.input.PortableDataStream import org.apache.spark.rdd.RDD import org.apache.spark.util.{SerializableConfiguration, Utils} -import scala.util.control.NonFatal private[spark] class PythonRDD( parent: RDD[_], @@ -61,11 +61,39 @@ private[spark] class PythonRDD( if (preservePartitoning) firstParent.partitioner else None } + val asJavaRDD: JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this) + override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = { + val runner = new PythonRunner( + command, envVars, pythonIncludes, pythonExec, pythonVer, broadcastVars, accumulator, + bufferSize, reuse_worker) + runner.compute(firstParent.iterator(split, context), split.index, context) + } +} + + +/** + * A helper class to run Python UDFs in Spark. + */ +private[spark] class PythonRunner( + command: Array[Byte], + envVars: JMap[String, String], + pythonIncludes: JList[String], + pythonExec: String, + pythonVer: String, + broadcastVars: JList[Broadcast[PythonBroadcast]], + accumulator: Accumulator[JList[Array[Byte]]], + bufferSize: Int, + reuse_worker: Boolean) + extends Logging { + + def compute( + inputIterator: Iterator[_], + partitionIndex: Int, + context: TaskContext): Iterator[Array[Byte]] = { val startTime = System.currentTimeMillis val env = SparkEnv.get - val localdir = env.blockManager.diskBlockManager.localDirs.map( - f => f.getPath()).mkString(",") + val localdir = env.blockManager.diskBlockManager.localDirs.map(f => f.getPath()).mkString(",") envVars.put("SPARK_LOCAL_DIRS", localdir) // it's also used in monitor thread if (reuse_worker) { envVars.put("SPARK_REUSE_WORKER", "1") @@ -75,7 +103,7 @@ private[spark] class PythonRDD( @volatile var released = false // Start a thread to feed the process input from our parent's iterator - val writerThread = new WriterThread(env, worker, split, context) + val writerThread = new WriterThread(env, worker, inputIterator, partitionIndex, context) context.addTaskCompletionListener { context => writerThread.shutdownOnTaskCompletion() @@ -183,13 +211,16 @@ private[spark] class PythonRDD( new InterruptibleIterator(context, stdoutIterator) } - val asJavaRDD : JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this) - /** * The thread responsible for writing the data from the PythonRDD's parent iterator to the * Python process. 
*/ - class WriterThread(env: SparkEnv, worker: Socket, split: Partition, context: TaskContext) + class WriterThread( + env: SparkEnv, + worker: Socket, + inputIterator: Iterator[_], + partitionIndex: Int, + context: TaskContext) extends Thread(s"stdout writer for $pythonExec") { @volatile private var _exception: Exception = null @@ -211,11 +242,11 @@ private[spark] class PythonRDD( val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize) val dataOut = new DataOutputStream(stream) // Partition index - dataOut.writeInt(split.index) + dataOut.writeInt(partitionIndex) // Python version of driver PythonRDD.writeUTF(pythonVer, dataOut) // sparkFilesDir - PythonRDD.writeUTF(SparkFiles.getRootDirectory, dataOut) + PythonRDD.writeUTF(SparkFiles.getRootDirectory(), dataOut) // Python includes (*.zip and *.egg files) dataOut.writeInt(pythonIncludes.size()) for (include <- pythonIncludes.asScala) { @@ -246,7 +277,7 @@ private[spark] class PythonRDD( dataOut.writeInt(command.length) dataOut.write(command) // Data values - PythonRDD.writeIteratorToStream(firstParent.iterator(split, context), dataOut) + PythonRDD.writeIteratorToStream(inputIterator, dataOut) dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION) dataOut.writeInt(SpecialLengths.END_OF_STREAM) dataOut.flush() @@ -327,7 +358,8 @@ private[spark] object PythonRDD extends Logging { // remember the broadcasts sent to each worker private val workerBroadcasts = new mutable.WeakHashMap[Socket, mutable.Set[Long]]() - private def getWorkerBroadcasts(worker: Socket) = { + + def getWorkerBroadcasts(worker: Socket): mutable.Set[Long] = { synchronized { workerBroadcasts.getOrElseUpdate(worker, new mutable.HashSet[Long]()) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala index d0411da6fdf5..c35c726bfc50 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala @@ -25,7 +25,7 @@ import scala.collection.JavaConverters._ import net.razorvine.pickle._ import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.api.python.{PythonBroadcast, PythonRDD, SerDeUtil} +import org.apache.spark.api.python.{PythonRunner, PythonBroadcast, PythonRDD, SerDeUtil} import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow @@ -35,7 +35,7 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String -import org.apache.spark.{Accumulator, Logging => SparkLogging} +import org.apache.spark.{Logging => SparkLogging, TaskContext, Accumulator} /** * A serialized version of a Python lambda function. Suitable for use in a [[PythonRDD]]. @@ -329,7 +329,13 @@ case class EvaluatePython( /** * :: DeveloperApi :: * Uses PythonRDD to evaluate a [[PythonUDF]], one partition of tuples at a time. - * The input data is zipped with the result of the udf evaluation. + * + * Python evaluation works by sending the necessary (projected) input data via a socket to an + * external Python process, and combine the result from the Python process with the original row. + * + * For each row we send to Python, we also put it in a queue. For each output row from Python, + * we drain the queue to find the original input row. 
Note that if the Python process is way too + * slow, this could lead to the queue growing unbounded and eventually run out of memory. */ @DeveloperApi case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child: SparkPlan) @@ -342,51 +348,57 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child: override def canProcessSafeRows: Boolean = true protected override def doExecute(): RDD[InternalRow] = { - val childResults = child.execute().map(_.copy()) + val inputRDD = child.execute().map(_.copy()) + val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536) + val reuseWorker = inputRDD.conf.getBoolean("spark.python.worker.reuse", defaultValue = true) - val parent = childResults.mapPartitions { iter => + inputRDD.mapPartitions { iter => EvaluatePython.registerPicklers() // register pickler for Row + + // The queue used to buffer input rows so we can drain it to + // combine input with output from Python. + val queue = new java.util.concurrent.ConcurrentLinkedQueue[InternalRow]() + val pickle = new Pickler val currentRow = newMutableProjection(udf.children, child.output)() val fields = udf.children.map(_.dataType) val schema = new StructType(fields.map(t => new StructField("", t, true)).toArray) - iter.grouped(100).map { inputRows => + + // Input iterator to Python: input rows are grouped so we send them in batches to Python. + // For each row, add it to the queue. + val inputIterator = iter.grouped(100).map { inputRows => val toBePickled = inputRows.map { row => + queue.add(row) EvaluatePython.toJava(currentRow(row), schema) }.toArray pickle.dumps(toBePickled) } - } - val pyRDD = new PythonRDD( - parent, - udf.command, - udf.envVars, - udf.pythonIncludes, - false, - udf.pythonExec, - udf.pythonVer, - udf.broadcastVars, - udf.accumulator - ).mapPartitions { iter => - val pickle = new Unpickler - iter.flatMap { pickedResult => - val unpickledBatch = pickle.loads(pickedResult) - unpickledBatch.asInstanceOf[java.util.ArrayList[Any]].asScala - } - }.mapPartitions { iter => + val context = TaskContext.get() + + // Output iterator for results from Python. + val outputIterator = new PythonRunner( + udf.command, + udf.envVars, + udf.pythonIncludes, + udf.pythonExec, + udf.pythonVer, + udf.broadcastVars, + udf.accumulator, + bufferSize, + reuseWorker + ).compute(inputIterator, context.partitionId(), context) + + val unpickle = new Unpickler val row = new GenericMutableRow(1) - iter.map { result => - row(0) = EvaluatePython.fromJava(result, udf.dataType) - row: InternalRow - } - } + val joined = new JoinedRow - childResults.zip(pyRDD).mapPartitions { iter => - val joinedRow = new JoinedRow() - iter.map { - case (row, udfResult) => - joinedRow(row, udfResult) + outputIterator.flatMap { pickedResult => + val unpickledBatch = unpickle.loads(pickedResult) + unpickledBatch.asInstanceOf[java.util.ArrayList[Any]].asScala + }.map { result => + row(0) = EvaluatePython.fromJava(result, udf.dataType) + joined(queue.poll(), row) } } } From 61d4c07f4becb42f054e588be56ed13239644410 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Tue, 22 Sep 2015 16:35:43 -0700 Subject: [PATCH 0730/1824] [SPARK-10640] History server fails to parse TaskCommitDenied ... simply because the code is missing! Author: Andrew Or Closes #8828 from andrewor14/task-end-reason-json. 
--- .../scala/org/apache/spark/TaskEndReason.scala | 6 +++++- .../org/apache/spark/util/JsonProtocol.scala | 13 +++++++++++++ .../apache/spark/util/JsonProtocolSuite.scala | 17 +++++++++++++++++ 3 files changed, 35 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/TaskEndReason.scala b/core/src/main/scala/org/apache/spark/TaskEndReason.scala index 7137246bc34f..9335c5f4160b 100644 --- a/core/src/main/scala/org/apache/spark/TaskEndReason.scala +++ b/core/src/main/scala/org/apache/spark/TaskEndReason.scala @@ -17,13 +17,17 @@ package org.apache.spark -import java.io.{IOException, ObjectInputStream, ObjectOutputStream} +import java.io.{ObjectInputStream, ObjectOutputStream} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.executor.TaskMetrics import org.apache.spark.storage.BlockManagerId import org.apache.spark.util.Utils +// ============================================================================================== +// NOTE: new task end reasons MUST be accompanied with serialization logic in util.JsonProtocol! +// ============================================================================================== + /** * :: DeveloperApi :: * Various possible reasons why a task ended. The low-level TaskScheduler is supposed to retry diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index 99614a786bd9..40729fa5a4ff 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -362,6 +362,10 @@ private[spark] object JsonProtocol { ("Stack Trace" -> stackTrace) ~ ("Full Stack Trace" -> exceptionFailure.fullStackTrace) ~ ("Metrics" -> metrics) + case taskCommitDenied: TaskCommitDenied => + ("Job ID" -> taskCommitDenied.jobID) ~ + ("Partition ID" -> taskCommitDenied.partitionID) ~ + ("Attempt Number" -> taskCommitDenied.attemptNumber) case ExecutorLostFailure(executorId, isNormalExit) => ("Executor ID" -> executorId) ~ ("Normal Exit" -> isNormalExit) @@ -770,6 +774,7 @@ private[spark] object JsonProtocol { val exceptionFailure = Utils.getFormattedClassName(ExceptionFailure) val taskResultLost = Utils.getFormattedClassName(TaskResultLost) val taskKilled = Utils.getFormattedClassName(TaskKilled) + val taskCommitDenied = Utils.getFormattedClassName(TaskCommitDenied) val executorLostFailure = Utils.getFormattedClassName(ExecutorLostFailure) val unknownReason = Utils.getFormattedClassName(UnknownReason) @@ -794,6 +799,14 @@ private[spark] object JsonProtocol { ExceptionFailure(className, description, stackTrace, fullStackTrace, metrics, None) case `taskResultLost` => TaskResultLost case `taskKilled` => TaskKilled + case `taskCommitDenied` => + // Unfortunately, the `TaskCommitDenied` message was introduced in 1.3.0 but the JSON + // de/serialization logic was not added until 1.5.1. To provide backward compatibility + // for reading those logs, we need to provide default values for all the fields. + val jobId = Utils.jsonOption(json \ "Job ID").map(_.extract[Int]).getOrElse(-1) + val partitionId = Utils.jsonOption(json \ "Partition ID").map(_.extract[Int]).getOrElse(-1) + val attemptNo = Utils.jsonOption(json \ "Attempt Number").map(_.extract[Int]).getOrElse(-1) + TaskCommitDenied(jobId, partitionId, attemptNo) case `executorLostFailure` => val isNormalExit = Utils.jsonOption(json \ "Normal Exit"). 
map(_.extract[Boolean]) diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index 143c1b901df1..a24bf2931cca 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -151,6 +151,7 @@ class JsonProtocolSuite extends SparkFunSuite { testTaskEndReason(exceptionFailure) testTaskEndReason(TaskResultLost) testTaskEndReason(TaskKilled) + testTaskEndReason(TaskCommitDenied(2, 3, 4)) testTaskEndReason(ExecutorLostFailure("100", true)) testTaskEndReason(UnknownReason) @@ -352,6 +353,17 @@ class JsonProtocolSuite extends SparkFunSuite { assertEquals(expectedStageInfo, JsonProtocol.stageInfoFromJson(oldStageInfo)) } + // `TaskCommitDenied` was added in 1.3.0 but JSON de/serialization logic was added in 1.5.1 + test("TaskCommitDenied backward compatibility") { + val denied = TaskCommitDenied(1, 2, 3) + val oldDenied = JsonProtocol.taskEndReasonToJson(denied) + .removeField({ _._1 == "Job ID" }) + .removeField({ _._1 == "Partition ID" }) + .removeField({ _._1 == "Attempt Number" }) + val expectedDenied = TaskCommitDenied(-1, -1, -1) + assertEquals(expectedDenied, JsonProtocol.taskEndReasonFromJson(oldDenied)) + } + /** -------------------------- * | Helper test running methods | * --------------------------- */ @@ -577,6 +589,11 @@ class JsonProtocolSuite extends SparkFunSuite { assertOptionEquals(r1.metrics, r2.metrics, assertTaskMetricsEquals) case (TaskResultLost, TaskResultLost) => case (TaskKilled, TaskKilled) => + case (TaskCommitDenied(jobId1, partitionId1, attemptNumber1), + TaskCommitDenied(jobId2, partitionId2, attemptNumber2)) => + assert(jobId1 === jobId2) + assert(partitionId1 === partitionId2) + assert(attemptNumber1 === attemptNumber2) case (ExecutorLostFailure(execId1, isNormalExit1), ExecutorLostFailure(execId2, isNormalExit2)) => assert(execId1 === execId2) From 84f81e035e1dab1b42c36563041df6ba16e7b287 Mon Sep 17 00:00:00 2001 From: Zhichao Li Date: Tue, 22 Sep 2015 19:41:57 -0700 Subject: [PATCH 0731/1824] [SPARK-10310] [SQL] Fixes script transformation field/line delimiters **Please attribute this PR to `Zhichao Li `.** This PR is based on PR #8476 authored by zhichao-li. It fixes SPARK-10310 by adding field delimiter SerDe property to the default `LazySimpleSerDe`, and enabling default record reader/writer classes. Currently, we only support `LazySimpleSerDe`, used together with `TextRecordReader` and `TextRecordWriter`, and don't support customizing record reader/writer using `RECORDREADER`/`RECORDWRITER` clauses. This should be addressed in separate PR(s). Author: Cheng Lian Closes #8860 from liancheng/spark-10310/fix-script-trans-delimiters. 
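As a concrete illustration of what this enables, a script transformation that sticks to the default row format now goes through `LazySimpleSerDe` with tab-delimited fields and the default text record reader/writer. A hedged sketch (the table `src` and the use of `cat` as the script are illustrative only; `RECORDREADER`/`RECORDWRITER` clauses remain unsupported, as noted above):

```scala
// Sketch: TRANSFORM with no explicit ROW FORMAT now defaults to LazySimpleSerDe
// with field.delim = "\t" plus TextRecordReader/TextRecordWriter.
val transformed = sqlContext.sql(
  """SELECT TRANSFORM (key, value)
    |USING 'cat' AS (tKey, tValue)
    |FROM src
  """.stripMargin)
```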
--- .../org/apache/spark/sql/hive/HiveQl.scala | 52 ++++++++++--- .../hive/execution/ScriptTransformation.scala | 75 +++++++++++++++---- .../resources/data/scripts/test_transform.py | 6 ++ .../sql/hive/execution/SQLQuerySuite.scala | 39 ++++++++++ .../execution/ScriptTransformationSuite.scala | 2 + 5 files changed, 152 insertions(+), 22 deletions(-) create mode 100755 sql/hive/src/test/resources/data/scripts/test_transform.py diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index d5cd7e98b526..256440a9a2e9 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -32,6 +32,7 @@ import org.apache.hadoop.hive.ql.lib.Node import org.apache.hadoop.hive.ql.parse._ import org.apache.hadoop.hive.ql.plan.PlanUtils import org.apache.hadoop.hive.ql.session.SessionState +import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe import org.apache.spark.Logging import org.apache.spark.sql.AnalysisException @@ -884,16 +885,22 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C AttributeReference("value", StringType)()), true) } - def matchSerDe(clause: Seq[ASTNode]) - : (Seq[(String, String)], Option[String], Seq[(String, String)]) = clause match { + type SerDeInfo = ( + Seq[(String, String)], // Input row format information + Option[String], // Optional input SerDe class + Seq[(String, String)], // Input SerDe properties + Boolean // Whether to use default record reader/writer + ) + + def matchSerDe(clause: Seq[ASTNode]): SerDeInfo = clause match { case Token("TOK_SERDEPROPS", propsClause) :: Nil => val rowFormat = propsClause.map { case Token(name, Token(value, Nil) :: Nil) => (name, value) } - (rowFormat, None, Nil) + (rowFormat, None, Nil, false) case Token("TOK_SERDENAME", Token(serdeClass, Nil) :: Nil) :: Nil => - (Nil, Some(BaseSemanticAnalyzer.unescapeSQLString(serdeClass)), Nil) + (Nil, Some(BaseSemanticAnalyzer.unescapeSQLString(serdeClass)), Nil, false) case Token("TOK_SERDENAME", Token(serdeClass, Nil) :: Token("TOK_TABLEPROPERTIES", @@ -903,20 +910,47 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C (BaseSemanticAnalyzer.unescapeSQLString(name), BaseSemanticAnalyzer.unescapeSQLString(value)) } - (Nil, Some(BaseSemanticAnalyzer.unescapeSQLString(serdeClass)), serdeProps) - case Nil => (Nil, Option(hiveConf.getVar(ConfVars.HIVESCRIPTSERDE)), Nil) + // SPARK-10310: Special cases LazySimpleSerDe + // TODO Fully supports user-defined record reader/writer classes + val unescapedSerDeClass = BaseSemanticAnalyzer.unescapeSQLString(serdeClass) + val useDefaultRecordReaderWriter = + unescapedSerDeClass == classOf[LazySimpleSerDe].getCanonicalName + (Nil, Some(unescapedSerDeClass), serdeProps, useDefaultRecordReaderWriter) + + case Nil => + // Uses default TextRecordReader/TextRecordWriter, sets field delimiter here + val serdeProps = Seq(serdeConstants.FIELD_DELIM -> "\t") + (Nil, Option(hiveConf.getVar(ConfVars.HIVESCRIPTSERDE)), serdeProps, true) } - val (inRowFormat, inSerdeClass, inSerdeProps) = matchSerDe(inputSerdeClause) - val (outRowFormat, outSerdeClass, outSerdeProps) = matchSerDe(outputSerdeClause) + val (inRowFormat, inSerdeClass, inSerdeProps, useDefaultRecordReader) = + matchSerDe(inputSerdeClause) + + val (outRowFormat, outSerdeClass, outSerdeProps, useDefaultRecordWriter) = + matchSerDe(outputSerdeClause) val unescapedScript = 
BaseSemanticAnalyzer.unescapeSQLString(script) + // TODO Adds support for user-defined record reader/writer classes + val recordReaderClass = if (useDefaultRecordReader) { + Option(hiveConf.getVar(ConfVars.HIVESCRIPTRECORDREADER)) + } else { + None + } + + val recordWriterClass = if (useDefaultRecordWriter) { + Option(hiveConf.getVar(ConfVars.HIVESCRIPTRECORDWRITER)) + } else { + None + } + val schema = HiveScriptIOSchema( inRowFormat, outRowFormat, inSerdeClass, outSerdeClass, - inSerdeProps, outSerdeProps, schemaLess) + inSerdeProps, outSerdeProps, + recordReaderClass, recordWriterClass, + schemaLess) Some( logical.ScriptTransformation( diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala index 32bddbaeaeaf..b30117f0de99 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala @@ -24,20 +24,22 @@ import javax.annotation.Nullable import scala.collection.JavaConverters._ import scala.util.control.NonFatal +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.hive.ql.exec.{RecordReader, RecordWriter} import org.apache.hadoop.hive.serde.serdeConstants import org.apache.hadoop.hive.serde2.AbstractSerDe import org.apache.hadoop.hive.serde2.objectinspector._ import org.apache.hadoop.io.Writable import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.ScriptInputOutputSchema +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.execution._ import org.apache.spark.sql.hive.HiveShim._ import org.apache.spark.sql.hive.{HiveContext, HiveInspectors} import org.apache.spark.sql.types.DataType -import org.apache.spark.util.{CircularBuffer, RedirectThread, Utils} +import org.apache.spark.util.{CircularBuffer, RedirectThread, SerializableConfiguration, Utils} import org.apache.spark.{Logging, TaskContext} /** @@ -58,6 +60,8 @@ case class ScriptTransformation( override def otherCopyArgs: Seq[HiveContext] = sc :: Nil + private val serializedHiveConf = new SerializableConfiguration(sc.hiveconf) + protected override def doExecute(): RDD[InternalRow] = { def processIterator(inputIterator: Iterator[InternalRow]): Iterator[InternalRow] = { val cmd = List("/bin/bash", "-c", script) @@ -67,6 +71,7 @@ case class ScriptTransformation( val inputStream = proc.getInputStream val outputStream = proc.getOutputStream val errorStream = proc.getErrorStream + val localHiveConf = serializedHiveConf.value // In order to avoid deadlocks, we need to consume the error output of the child process. 
// To avoid issues caused by large error output, we use a circular buffer to limit the amount @@ -96,7 +101,8 @@ case class ScriptTransformation( outputStream, proc, stderrBuffer, - TaskContext.get() + TaskContext.get(), + localHiveConf ) // This nullability is a performance optimization in order to avoid an Option.foreach() call @@ -109,6 +115,10 @@ case class ScriptTransformation( val outputIterator: Iterator[InternalRow] = new Iterator[InternalRow] with HiveInspectors { var curLine: String = null val scriptOutputStream = new DataInputStream(inputStream) + + @Nullable val scriptOutputReader = + ioschema.recordReader(scriptOutputStream, localHiveConf).orNull + var scriptOutputWritable: Writable = null val reusedWritableObject: Writable = if (null != outputSerde) { outputSerde.getSerializedClass().newInstance @@ -134,15 +144,25 @@ case class ScriptTransformation( } } else if (scriptOutputWritable == null) { scriptOutputWritable = reusedWritableObject - try { - scriptOutputWritable.readFields(scriptOutputStream) - true - } catch { - case _: EOFException => - if (writerThread.exception.isDefined) { - throw writerThread.exception.get - } + + if (scriptOutputReader != null) { + if (scriptOutputReader.next(scriptOutputWritable) <= 0) { + writerThread.exception.foreach(throw _) false + } else { + true + } + } else { + try { + scriptOutputWritable.readFields(scriptOutputStream) + true + } catch { + case _: EOFException => + if (writerThread.exception.isDefined) { + throw writerThread.exception.get + } + false + } } } else { true @@ -210,7 +230,8 @@ private class ScriptTransformationWriterThread( outputStream: OutputStream, proc: Process, stderrBuffer: CircularBuffer, - taskContext: TaskContext + taskContext: TaskContext, + conf: Configuration ) extends Thread("Thread-ScriptTransformation-Feed") with Logging { setDaemon(true) @@ -224,6 +245,7 @@ private class ScriptTransformationWriterThread( TaskContext.setTaskContext(taskContext) val dataOutputStream = new DataOutputStream(outputStream) + @Nullable val scriptInputWriter = ioschema.recordWriter(dataOutputStream, conf).orNull // We can't use Utils.tryWithSafeFinally here because we also need a `catch` block, so // let's use a variable to record whether the `finally` block was hit due to an exception @@ -250,7 +272,12 @@ private class ScriptTransformationWriterThread( } else { val writable = inputSerde.serialize( row.asInstanceOf[GenericInternalRow].values, inputSoi) - prepareWritable(writable, ioschema.outputSerdeProps).write(dataOutputStream) + + if (scriptInputWriter != null) { + scriptInputWriter.write(writable) + } else { + prepareWritable(writable, ioschema.outputSerdeProps).write(dataOutputStream) + } } } outputStream.close() @@ -290,6 +317,8 @@ case class HiveScriptIOSchema ( outputSerdeClass: Option[String], inputSerdeProps: Seq[(String, String)], outputSerdeProps: Seq[(String, String)], + recordReaderClass: Option[String], + recordWriterClass: Option[String], schemaLess: Boolean) extends ScriptInputOutputSchema with HiveInspectors { private val defaultFormat = Map( @@ -347,4 +376,24 @@ case class HiveScriptIOSchema ( serde } + + def recordReader( + inputStream: InputStream, + conf: Configuration): Option[RecordReader] = { + recordReaderClass.map { klass => + val instance = Utils.classForName(klass).newInstance().asInstanceOf[RecordReader] + val props = new Properties() + props.putAll(outputSerdeProps.toMap.asJava) + instance.initialize(inputStream, conf, props) + instance + } + } + + def recordWriter(outputStream: OutputStream, conf: 
Configuration): Option[RecordWriter] = { + recordWriterClass.map { klass => + val instance = Utils.classForName(klass).newInstance().asInstanceOf[RecordWriter] + instance.initialize(outputStream, conf) + instance + } + } } diff --git a/sql/hive/src/test/resources/data/scripts/test_transform.py b/sql/hive/src/test/resources/data/scripts/test_transform.py new file mode 100755 index 000000000000..ac6d11d8b919 --- /dev/null +++ b/sql/hive/src/test/resources/data/scripts/test_transform.py @@ -0,0 +1,6 @@ +import sys + +delim = sys.argv[1] + +for row in sys.stdin: + print(delim.join([w + '#' for w in row[:-1].split(delim)])) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index bb02473dd17c..71823e32ad38 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -1184,4 +1184,43 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { checkAnswer(df, Row("text inside layer 2") :: Nil) } + + test("SPARK-10310: " + + "script transformation using default input/output SerDe and record reader/writer") { + sqlContext + .range(5) + .selectExpr("id AS a", "id AS b") + .registerTempTable("test") + + checkAnswer( + sql( + """FROM( + | FROM test SELECT TRANSFORM(a, b) + | USING 'python src/test/resources/data/scripts/test_transform.py "\t"' + | AS (c STRING, d STRING) + |) t + |SELECT c + """.stripMargin), + (0 until 5).map(i => Row(i + "#"))) + } + + test("SPARK-10310: script transformation using LazySimpleSerDe") { + sqlContext + .range(5) + .selectExpr("id AS a", "id AS b") + .registerTempTable("test") + + val df = sql( + """FROM test + |SELECT TRANSFORM(a, b) + |ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe' + |WITH SERDEPROPERTIES('field.delim' = '|') + |USING 'python src/test/resources/data/scripts/test_transform.py "|"' + |AS (c STRING, d STRING) + |ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe' + |WITH SERDEPROPERTIES('field.delim' = '|') + """.stripMargin) + + checkAnswer(df, (0 until 5).map(i => Row(i + "#", i + "#"))) + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala index cb8d0fca8e69..7cfdb886b585 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala @@ -38,6 +38,8 @@ class ScriptTransformationSuite extends SparkPlanTest with TestHiveSingleton { outputSerdeClass = None, inputSerdeProps = Seq.empty, outputSerdeProps = Seq.empty, + recordReaderClass = None, + recordWriterClass = None, schemaLess = false ) From 558e9c7e60a7c0d85ba26634e97562ad2163e91d Mon Sep 17 00:00:00 2001 From: Matt Hagen Date: Tue, 22 Sep 2015 21:14:25 -0700 Subject: [PATCH 0732/1824] [SPARK-10663] Removed unnecessary invocation of DataFrame.toDF method. The Scala example under the "Example: Pipeline" heading in this document initializes the "test" variable to a DataFrame. Because test is already a DF, there is not need to call test.toDF as the example does in a subsequent line: model.transform(test.toDF). So, I removed the extraneous toDF invocation. Author: Matt Hagen Closes #8875 from hagenhaus/SPARK-10663. 
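The underlying point, as a small standalone sketch (the column names match the guide, but the rows are placeholders): `toDF()` with no arguments on something that is already a DataFrame only returns an equivalent DataFrame, so the extra call in the example did no work.

```scala
val test = sqlContext.createDataFrame(Seq(
  (4L, "spark i j k"),
  (5L, "l m n")
)).toDF("id", "text")

// toDF() on an existing DataFrame is a no-op conversion; the schema is unchanged.
assert(test.toDF().schema == test.schema)
```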
--- docs/ml-guide.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/ml-guide.md b/docs/ml-guide.md index 0427ac6695aa..fd3a6167bc65 100644 --- a/docs/ml-guide.md +++ b/docs/ml-guide.md @@ -475,7 +475,7 @@ val test = sqlContext.createDataFrame(Seq( )).toDF("id", "text") // Make predictions on test documents. -model.transform(test.toDF) +model.transform(test) .select("id", "text", "probability", "prediction") .collect() .foreach { case Row(id: Long, text: String, prob: Vector, prediction: Double) => From 5548a254755bb84edae2768b94ab1816e1b49b91 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 22 Sep 2015 22:44:09 -0700 Subject: [PATCH 0733/1824] [SPARK-10652] [SPARK-10742] [STREAMING] Set meaningful job descriptions for all streaming jobs Here is the screenshot after adding the job descriptions to threads that run receivers and the scheduler thread running the batch jobs. ## All jobs page * Added job descriptions with links to relevant batch details page ![image](https://cloud.githubusercontent.com/assets/663212/9924165/cda4a372-5cb1-11e5-91ca-d43a32c699e9.png) ## All stages page * Added stage descriptions with links to relevant batch details page ![image](https://cloud.githubusercontent.com/assets/663212/9923814/2cce266a-5cae-11e5-8a3f-dad84d06c50e.png) ## Streaming batch details page * Added the +details link ![image](https://cloud.githubusercontent.com/assets/663212/9921977/24014a32-5c98-11e5-958e-457b6c38065b.png) Author: Tathagata Das Closes #8791 from tdas/SPARK-10652. --- .../scala/org/apache/spark/ui/UIUtils.scala | 62 ++++++++++++++++- .../apache/spark/ui/jobs/AllJobsPage.scala | 14 ++-- .../org/apache/spark/ui/jobs/StageTable.scala | 7 +- .../org/apache/spark/ui/UIUtilsSuite.scala | 66 +++++++++++++++++++ .../spark/streaming/StreamingContext.scala | 4 +- .../streaming/scheduler/JobScheduler.scala | 15 ++++- .../streaming/scheduler/ReceiverTracker.scala | 5 +- .../apache/spark/streaming/ui/BatchPage.scala | 33 ++++++---- .../streaming/StreamingContextSuite.scala | 2 +- 9 files changed, 179 insertions(+), 29 deletions(-) create mode 100644 core/src/test/scala/org/apache/spark/ui/UIUtilsSuite.scala diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala index f2da41772410..21dc8f0b6548 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala @@ -18,9 +18,11 @@ package org.apache.spark.ui import java.text.SimpleDateFormat -import java.util.{Locale, Date} +import java.util.{Date, Locale} -import scala.xml.{Node, Text, Unparsed} +import scala.util.control.NonFatal +import scala.xml._ +import scala.xml.transform.{RewriteRule, RuleTransformer} import org.apache.spark.Logging import org.apache.spark.ui.scope.RDDOperationGraph @@ -395,4 +397,60 @@ private[spark] object UIUtils extends Logging { } + /** + * Returns HTML rendering of a job or stage description. It will try to parse the string as HTML + * and make sure that it only contains anchors with root-relative links. Otherwise, + * the whole string will rendered as a simple escaped text. + * + * Note: In terms of security, only anchor tags with root relative links are supported. 
So any + * attempts to embed links outside Spark UI, or other tags like } private def createExecutorTable() : Seq[Node] = { From 9631ca35275b0ce8a5219f975907ac36ed11f528 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Wed, 18 Nov 2015 08:59:20 +0000 Subject: [PATCH 1410/1824] [SPARK-11652][CORE] Remote code execution with InvokerTransformer Update to Commons Collections 3.2.2 to avoid any potential remote code execution vulnerability Author: Sean Owen Closes #9731 from srowen/SPARK-11652. --- pom.xml | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/pom.xml b/pom.xml index 940e2d8740bf..ad849112ce76 100644 --- a/pom.xml +++ b/pom.xml @@ -162,6 +162,8 @@ 3.1 3.4.1 + + 3.2.2 2.10.5 2.10 ${scala.version} @@ -475,6 +477,11 @@ commons-math3 ${commons.math3.version} + + org.apache.commons + commons-collections + ${commons.collections.version} + org.apache.ivy ivy From 1429e0a2b562469146b6fa06051c85a00092e5b8 Mon Sep 17 00:00:00 2001 From: Viveka Kulharia Date: Wed, 18 Nov 2015 09:10:15 +0000 Subject: [PATCH 1411/1824] rmse was wrongly calculated It was multiplying with U instaed of dividing by U Author: Viveka Kulharia Closes #9771 from vivkul/patch-1. --- examples/src/main/python/als.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/src/main/python/als.py b/examples/src/main/python/als.py index 1c3a787bd0e9..205ca02962be 100755 --- a/examples/src/main/python/als.py +++ b/examples/src/main/python/als.py @@ -36,7 +36,7 @@ def rmse(R, ms, us): diff = R - ms * us.T - return np.sqrt(np.sum(np.power(diff, 2)) / M * U) + return np.sqrt(np.sum(np.power(diff, 2)) / (M * U)) def update(i, vec, mat, ratings): From 3a6807fdf07b0e73d76502a6bd91ad979fde8b61 Mon Sep 17 00:00:00 2001 From: Jeff Zhang Date: Wed, 18 Nov 2015 08:18:54 -0800 Subject: [PATCH 1412/1824] =?UTF-8?q?[SPARK-11804]=20[PYSPARK]=20Exception?= =?UTF-8?q?=20raise=20when=20using=20Jdbc=20predicates=20opt=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …ion in PySpark Author: Jeff Zhang Closes #9791 from zjffdu/SPARK-11804. 
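On the JVM side, the call the fixed wrapper now reaches is the `DataFrameReader.jdbc` overload that takes an array of predicate strings, with one partition built per predicate. A minimal Scala sketch of that overload; the JDBC URL, table name, credentials, and predicates below are placeholders, not values from this patch:

```scala
import java.util.Properties

val props = new Properties()
props.setProperty("user", "sa")

// Each predicate string becomes the WHERE clause of one partition's query.
val predicates = Array("id < 100", "id >= 100")

val people = sqlContext.read.jdbc("jdbc:h2:mem:testdb", "PEOPLE", predicates, props)
```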
--- python/pyspark/sql/readwriter.py | 10 +++++----- python/pyspark/sql/utils.py | 13 +++++++++++++ 2 files changed, 18 insertions(+), 5 deletions(-) diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 7b8ddb9feba3..e8f0d7ec7703 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -26,6 +26,7 @@ from pyspark.rdd import ignore_unicode_prefix from pyspark.sql.column import _to_seq from pyspark.sql.types import * +from pyspark.sql import utils __all__ = ["DataFrameReader", "DataFrameWriter"] @@ -131,9 +132,7 @@ def load(self, path=None, format=None, schema=None, **options): if type(path) == list: paths = path gateway = self._sqlContext._sc._gateway - jpaths = gateway.new_array(gateway.jvm.java.lang.String, len(paths)) - for i in range(0, len(paths)): - jpaths[i] = paths[i] + jpaths = utils.toJArray(gateway, gateway.jvm.java.lang.String, paths) return self._df(self._jreader.load(jpaths)) else: return self._df(self._jreader.load(path)) @@ -269,8 +268,9 @@ def jdbc(self, url, table, column=None, lowerBound=None, upperBound=None, numPar return self._df(self._jreader.jdbc(url, table, column, int(lowerBound), int(upperBound), int(numPartitions), jprop)) if predicates is not None: - arr = self._sqlContext._sc._jvm.PythonUtils.toArray(predicates) - return self._df(self._jreader.jdbc(url, table, arr, jprop)) + gateway = self._sqlContext._sc._gateway + jpredicates = utils.toJArray(gateway, gateway.jvm.java.lang.String, predicates) + return self._df(self._jreader.jdbc(url, table, jpredicates, jprop)) return self._df(self._jreader.jdbc(url, table, jprop)) diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py index c4fda8bd3b89..b0a0373372d2 100644 --- a/python/pyspark/sql/utils.py +++ b/python/pyspark/sql/utils.py @@ -71,3 +71,16 @@ def install_exception_handler(): patched = capture_sql_exception(original) # only patch the one used in in py4j.java_gateway (call Java API) py4j.java_gateway.get_return_value = patched + + +def toJArray(gateway, jtype, arr): + """ + Convert python list to java type array + :param gateway: Py4j Gateway + :param jtype: java type of element in array + :param arr: python type list + """ + jarr = gateway.new_array(jtype, len(arr)) + for i in range(0, len(arr)): + jarr[i] = arr[i] + return jarr From a97d6f3a5861e9f2bbe36957e3b39f835f3e214c Mon Sep 17 00:00:00 2001 From: zero323 Date: Wed, 18 Nov 2015 08:32:03 -0800 Subject: [PATCH 1413/1824] [SPARK-11281][SPARKR] Add tests covering the issue. The goal of this PR is to add tests covering the issue to ensure that is was resolved by [SPARK-11086](https://issues.apache.org/jira/browse/SPARK-11086). Author: zero323 Closes #9743 from zero323/SPARK-11281-tests. 
--- R/pkg/inst/tests/test_sparkSQL.R | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index 8ff06276599e..87ab33f6384b 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -229,7 +229,7 @@ test_that("create DataFrame from list or data.frame", { df <- createDataFrame(sqlContext, l, c("a", "b")) expect_equal(columns(df), c("a", "b")) - l <- list(list(a=1, b=2), list(a=3, b=4)) + l <- list(list(a = 1, b = 2), list(a = 3, b = 4)) df <- createDataFrame(sqlContext, l) expect_equal(columns(df), c("a", "b")) @@ -292,11 +292,15 @@ test_that("create DataFrame with complex types", { }) test_that("create DataFrame from a data.frame with complex types", { - ldf <- data.frame(row.names=1:2) + ldf <- data.frame(row.names = 1:2) ldf$a_list <- list(list(1, 2), list(3, 4)) + ldf$an_envir <- c(as.environment(list(a = 1, b = 2)), as.environment(list(c = 3))) + sdf <- createDataFrame(sqlContext, ldf) + collected <- collect(sdf) - expect_equivalent(ldf, collect(sdf)) + expect_identical(ldf[, 1, FALSE], collected[, 1, FALSE]) + expect_equal(ldf$an_envir, collected$an_envir) }) # For test map type and struct type in DataFrame From 224723e6a8b198ef45d6c5ca5d2f9c61188ada8f Mon Sep 17 00:00:00 2001 From: Sun Rui Date: Wed, 18 Nov 2015 08:41:45 -0800 Subject: [PATCH 1414/1824] [SPARK-11773][SPARKR] Implement collection functions in SparkR. Author: Sun Rui Closes #9764 from sun-rui/SPARK-11773. --- R/pkg/NAMESPACE | 2 + R/pkg/R/DataFrame.R | 2 +- R/pkg/R/functions.R | 109 ++++++++++++++++++++++--------- R/pkg/R/generics.R | 10 ++- R/pkg/R/utils.R | 2 +- R/pkg/inst/tests/test_sparkSQL.R | 10 +++ 6 files changed, 100 insertions(+), 35 deletions(-) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 2ee7d6f94f1b..260c9edce62e 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -98,6 +98,7 @@ exportMethods("%in%", "add_months", "alias", "approxCountDistinct", + "array_contains", "asc", "ascii", "asin", @@ -215,6 +216,7 @@ exportMethods("%in%", "sinh", "size", "skewness", + "sort_array", "soundex", "stddev", "stddev_pop", diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index fd105ba5bc9b..34177e3cdd94 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -2198,4 +2198,4 @@ setMethod("coltypes", rTypes[naIndices] <- types[naIndices] rTypes - }) \ No newline at end of file + }) diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 3d0255a62f15..ff0f438045c1 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -373,22 +373,6 @@ setMethod("exp", column(jc) }) -#' explode -#' -#' Creates a new row for each element in the given array or map column. -#' -#' @rdname explode -#' @name explode -#' @family collection_funcs -#' @export -#' @examples \dontrun{explode(df$c)} -setMethod("explode", - signature(x = "Column"), - function(x) { - jc <- callJStatic("org.apache.spark.sql.functions", "explode", x@jc) - column(jc) - }) - #' expm1 #' #' Computes the exponential of the given value minus one. @@ -980,22 +964,6 @@ setMethod("sinh", column(jc) }) -#' size -#' -#' Returns length of array or map. -#' -#' @rdname size -#' @name size -#' @family collection_funcs -#' @export -#' @examples \dontrun{size(df$c)} -setMethod("size", - signature(x = "Column"), - function(x) { - jc <- callJStatic("org.apache.spark.sql.functions", "size", x@jc) - column(jc) - }) - #' skewness #' #' Aggregate function: returns the skewness of the values in a group. 
@@ -2365,3 +2333,80 @@ setMethod("rowNumber", jc <- callJStatic("org.apache.spark.sql.functions", "rowNumber") column(jc) }) + +###################### Collection functions###################### + +#' array_contains +#' +#' Returns true if the array contain the value. +#' +#' @param x A Column +#' @param value A value to be checked if contained in the column +#' @rdname array_contains +#' @name array_contains +#' @family collection_funcs +#' @export +#' @examples \dontrun{array_contains(df$c, 1)} +setMethod("array_contains", + signature(x = "Column", value = "ANY"), + function(x, value) { + jc <- callJStatic("org.apache.spark.sql.functions", "array_contains", x@jc, value) + column(jc) + }) + +#' explode +#' +#' Creates a new row for each element in the given array or map column. +#' +#' @rdname explode +#' @name explode +#' @family collection_funcs +#' @export +#' @examples \dontrun{explode(df$c)} +setMethod("explode", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "explode", x@jc) + column(jc) + }) + +#' size +#' +#' Returns length of array or map. +#' +#' @rdname size +#' @name size +#' @family collection_funcs +#' @export +#' @examples \dontrun{size(df$c)} +setMethod("size", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "size", x@jc) + column(jc) + }) + +#' sort_array +#' +#' Sorts the input array for the given column in ascending order, +#' according to the natural ordering of the array elements. +#' +#' @param x A Column to sort +#' @param asc A logical flag indicating the sorting order. +#' TRUE, sorting is in ascending order. +#' FALSE, sorting is in descending order. +#' @rdname sort_array +#' @name sort_array +#' @family collection_funcs +#' @export +#' @examples +#' \dontrun{ +#' sort_array(df$c) +#' sort_array(df$c, FALSE) +#' } +setMethod("sort_array", + signature(x = "Column"), + function(x, asc = TRUE) { + jc <- callJStatic("org.apache.spark.sql.functions", "sort_array", x@jc, asc) + column(jc) + }) diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index afdeffc2abd8..0dcd05438222 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -644,6 +644,10 @@ setGeneric("add_months", function(y, x) { standardGeneric("add_months") }) #' @export setGeneric("approxCountDistinct", function(x, ...) 
{ standardGeneric("approxCountDistinct") }) +#' @rdname array_contains +#' @export +setGeneric("array_contains", function(x, value) { standardGeneric("array_contains") }) + #' @rdname ascii #' @export setGeneric("ascii", function(x) { standardGeneric("ascii") }) @@ -961,6 +965,10 @@ setGeneric("size", function(x) { standardGeneric("size") }) #' @export setGeneric("skewness", function(x) { standardGeneric("skewness") }) +#' @rdname sort_array +#' @export +setGeneric("sort_array", function(x, asc = TRUE) { standardGeneric("sort_array") }) + #' @rdname soundex #' @export setGeneric("soundex", function(x) { standardGeneric("soundex") }) @@ -1076,4 +1084,4 @@ setGeneric("with") #' @rdname coltypes #' @export -setGeneric("coltypes", function(x) { standardGeneric("coltypes") }) \ No newline at end of file +setGeneric("coltypes", function(x) { standardGeneric("coltypes") }) diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R index db3b2c4bbd79..45c77a86c958 100644 --- a/R/pkg/R/utils.R +++ b/R/pkg/R/utils.R @@ -635,4 +635,4 @@ assignNewEnv <- function(data) { assign(x = cols[i], value = data[, cols[i]], envir = env) } env -} \ No newline at end of file +} diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index 87ab33f6384b..d9a94faff7ac 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -878,6 +878,16 @@ test_that("column functions", { df4 <- createDataFrame(sqlContext, list(list(a = "010101"))) expect_equal(collect(select(df4, conv(df4$a, 2, 16)))[1, 1], "15") + + # Test array_contains() and sort_array() + df <- createDataFrame(sqlContext, list(list(list(1L, 2L, 3L)), list(list(6L, 5L, 4L)))) + result <- collect(select(df, array_contains(df[[1]], 1L)))[[1]] + expect_equal(result, c(TRUE, FALSE)) + + result <- collect(select(df, sort_array(df[[1]], FALSE)))[[1]] + expect_equal(result, list(list(3L, 2L, 1L), list(6L, 5L, 4L))) + result <- collect(select(df, sort_array(df[[1]])))[[1]] + expect_equal(result, list(list(1L, 2L, 3L), list(4L, 5L, 6L))) }) # test_that("column binary mathfunctions", { From 3cca5ffb3d60d5de9a54bc71cf0b8279898936d2 Mon Sep 17 00:00:00 2001 From: Hurshal Patel Date: Wed, 18 Nov 2015 09:28:59 -0800 Subject: [PATCH 1415/1824] [SPARK-11195][CORE] Use correct classloader for TaskResultGetter Make sure we are using the context classloader when deserializing failed TaskResults instead of the Spark classloader. The issue is that `enqueueFailedTask` was using the incorrect classloader which results in `ClassNotFoundException`. Adds a test in TaskResultGetterSuite that compiles a custom exception, throws it on the executor, and asserts that Spark handles the TaskResult deserialization instead of returning `UnknownReason`. See #9367 for previous comments See SPARK-11195 for a full repro Author: Hurshal Patel Closes #9779 from choochootrain/spark-11195-master. 
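The pattern being applied, shown as a standalone sketch rather than the Spark code itself: prefer the thread's context classloader, which can see user-supplied jars, and only fall back to the loader that loaded the framework classes.

```scala
def contextOrDefaultClassLoader: ClassLoader =
  Option(Thread.currentThread().getContextClassLoader)
    .getOrElse(getClass.getClassLoader)

// A user-defined class such as the test's repro.MyException then resolves via:
//   Class.forName(className, /* initialize = */ true, contextOrDefaultClassLoader)
// instead of failing with ClassNotFoundException under the framework loader.
```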
--- .../scala/org/apache/spark/TestUtils.scala | 11 ++-- .../spark/scheduler/TaskResultGetter.scala | 4 +- .../scheduler/TaskResultGetterSuite.scala | 65 ++++++++++++++++++- 3 files changed, 72 insertions(+), 8 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/TestUtils.scala b/core/src/main/scala/org/apache/spark/TestUtils.scala index acfe751f6c74..43c89b258f2f 100644 --- a/core/src/main/scala/org/apache/spark/TestUtils.scala +++ b/core/src/main/scala/org/apache/spark/TestUtils.scala @@ -20,6 +20,7 @@ package org.apache.spark import java.io.{ByteArrayInputStream, File, FileInputStream, FileOutputStream} import java.net.{URI, URL} import java.nio.charset.StandardCharsets +import java.nio.file.Paths import java.util.Arrays import java.util.jar.{JarEntry, JarOutputStream} @@ -83,15 +84,15 @@ private[spark] object TestUtils { } /** - * Create a jar file that contains this set of files. All files will be located at the root - * of the jar. + * Create a jar file that contains this set of files. All files will be located in the specified + * directory or at the root of the jar. */ - def createJar(files: Seq[File], jarFile: File): URL = { + def createJar(files: Seq[File], jarFile: File, directoryPrefix: Option[String] = None): URL = { val jarFileStream = new FileOutputStream(jarFile) val jarStream = new JarOutputStream(jarFileStream, new java.util.jar.Manifest()) for (file <- files) { - val jarEntry = new JarEntry(file.getName) + val jarEntry = new JarEntry(Paths.get(directoryPrefix.getOrElse(""), file.getName).toString) jarStream.putNextEntry(jarEntry) val in = new FileInputStream(file) @@ -123,7 +124,7 @@ private[spark] object TestUtils { classpathUrls: Seq[URL]): File = { val compiler = ToolProvider.getSystemJavaCompiler - // Calling this outputs a class file in pwd. It's easier to just rename the file than + // Calling this outputs a class file in pwd. It's easier to just rename the files than // build a custom FileManager that controls the output location. val options = if (classpathUrls.nonEmpty) { Seq("-classpath", classpathUrls.map { _.getFile }.mkString(File.pathSeparator)) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala index 46a6f6537e2e..f4965994d827 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala @@ -103,16 +103,16 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul try { getTaskResultExecutor.execute(new Runnable { override def run(): Unit = Utils.logUncaughtExceptions { + val loader = Utils.getContextOrSparkClassLoader try { if (serializedData != null && serializedData.limit() > 0) { reason = serializer.get().deserialize[TaskEndReason]( - serializedData, Utils.getSparkClassLoader) + serializedData, loader) } } catch { case cnd: ClassNotFoundException => // Log an error but keep going here -- the task failed, so not catastrophic // if we can't deserialize the reason. 
- val loader = Utils.getContextOrSparkClassLoader logError( "Could not deserialize TaskEndReason: ClassNotFound with classloader " + loader) case ex: Exception => {} diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala index 815caa79ff52..bc72c3685e8c 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.scheduler +import java.io.File +import java.net.URL import java.nio.ByteBuffer import scala.concurrent.duration._ @@ -26,8 +28,10 @@ import scala.util.control.NonFatal import org.scalatest.BeforeAndAfter import org.scalatest.concurrent.Eventually._ -import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkEnv, SparkFunSuite} +import org.apache.spark._ import org.apache.spark.storage.TaskResultBlockId +import org.apache.spark.TestUtils.JavaSourceFromString +import org.apache.spark.util.{MutableURLClassLoader, Utils} /** * Removes the TaskResult from the BlockManager before delegating to a normal TaskResultGetter. @@ -119,5 +123,64 @@ class TaskResultGetterSuite extends SparkFunSuite with BeforeAndAfter with Local // Make sure two tasks were run (one failed one, and a second retried one). assert(scheduler.nextTaskId.get() === 2) } + + /** + * Make sure we are using the context classloader when deserializing failed TaskResults instead + * of the Spark classloader. + + * This test compiles a jar containing an exception and tests that when it is thrown on the + * executor, enqueueFailedTask can correctly deserialize the failure and identify the thrown + * exception as the cause. + + * Before this fix, enqueueFailedTask would throw a ClassNotFoundException when deserializing + * the exception, resulting in an UnknownReason for the TaskEndResult. + */ + test("failed task deserialized with the correct classloader (SPARK-11195)") { + // compile a small jar containing an exception that will be thrown on an executor. + val tempDir = Utils.createTempDir() + val srcDir = new File(tempDir, "repro/") + srcDir.mkdirs() + val excSource = new JavaSourceFromString(new File(srcDir, "MyException").getAbsolutePath, + """package repro; + | + |public class MyException extends Exception { + |} + """.stripMargin) + val excFile = TestUtils.createCompiledClass("MyException", srcDir, excSource, Seq.empty) + val jarFile = new File(tempDir, "testJar-%s.jar".format(System.currentTimeMillis())) + TestUtils.createJar(Seq(excFile), jarFile, directoryPrefix = Some("repro")) + + // ensure we reset the classloader after the test completes + val originalClassLoader = Thread.currentThread.getContextClassLoader + try { + // load the exception from the jar + val loader = new MutableURLClassLoader(new Array[URL](0), originalClassLoader) + loader.addURL(jarFile.toURI.toURL) + Thread.currentThread().setContextClassLoader(loader) + val excClass: Class[_] = Utils.classForName("repro.MyException") + + // NOTE: we must run the cluster with "local" so that the executor can load the compiled + // jar. + sc = new SparkContext("local", "test", conf) + val rdd = sc.parallelize(Seq(1), 1).map { _ => + val exc = excClass.newInstance().asInstanceOf[Exception] + throw exc + } + + // the driver should not have any problems resolving the exception class and determining + // why the task failed. 
+ val exceptionMessage = intercept[SparkException] { + rdd.collect() + }.getMessage + + val expectedFailure = """(?s).*Lost task.*: repro.MyException.*""".r + val unknownFailure = """(?s).*Lost task.*: UnknownReason.*""".r + + assert(expectedFailure.findFirstMatchIn(exceptionMessage).isDefined) + assert(unknownFailure.findFirstMatchIn(exceptionMessage).isEmpty) + } finally { + Thread.currentThread.setContextClassLoader(originalClassLoader) + } + } } From cffb899c4397ecccedbcc41e7cf3da91f953435a Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 18 Nov 2015 10:15:50 -0800 Subject: [PATCH 1416/1824] [SPARK-11803][SQL] fix Dataset self-join When we resolve the join operator, we may change the output of right side if self-join is detected. So in `Dataset.joinWith`, we should resolve the join operator first, and then get the left output and right output from it, instead of using `left.output` and `right.output` directly. Author: Wenchen Fan Closes #9806 from cloud-fan/self-join. --- .../main/scala/org/apache/spark/sql/Dataset.scala | 14 +++++++++----- .../scala/org/apache/spark/sql/DatasetSuite.scala | 8 ++++---- 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 817c20fdbb9f..b644f6ad3096 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -498,13 +498,17 @@ class Dataset[T] private[sql]( val left = this.logicalPlan val right = other.logicalPlan + val joined = sqlContext.executePlan(Join(left, right, Inner, Some(condition.expr))) + val leftOutput = joined.analyzed.output.take(left.output.length) + val rightOutput = joined.analyzed.output.takeRight(right.output.length) + val leftData = this.unresolvedTEncoder match { - case e if e.flat => Alias(left.output.head, "_1")() - case _ => Alias(CreateStruct(left.output), "_1")() + case e if e.flat => Alias(leftOutput.head, "_1")() + case _ => Alias(CreateStruct(leftOutput), "_1")() } val rightData = other.unresolvedTEncoder match { - case e if e.flat => Alias(right.output.head, "_2")() - case _ => Alias(CreateStruct(right.output), "_2")() + case e if e.flat => Alias(rightOutput.head, "_2")() + case _ => Alias(CreateStruct(rightOutput), "_2")() } @@ -513,7 +517,7 @@ class Dataset[T] private[sql]( withPlan[(T, U)](other) { (left, right) => Project( leftData :: rightData :: Nil, - Join(left, right, Inner, Some(condition.expr))) + joined.analyzed) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index a522894c374f..198962b8fb75 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -347,7 +347,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { checkAnswer(joined, ("2", 2)) } - ignore("self join") { + test("self join") { val ds = Seq("1", "2").toDS().as("a") val joined = ds.joinWith(ds, lit(true)) checkAnswer(joined, ("1", "1"), ("1", "2"), ("2", "1"), ("2", "2")) @@ -360,15 +360,15 @@ class DatasetSuite extends QueryTest with SharedSQLContext { test("kryo encoder") { implicit val kryoEncoder = Encoders.kryo[KryoData] - val ds = sqlContext.createDataset(Seq(KryoData(1), KryoData(2))) + val ds = Seq(KryoData(1), KryoData(2)).toDS() assert(ds.groupBy(p => p).count().collect().toSeq == Seq((KryoData(1), 1L), (KryoData(2), 1L))) } 
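The user-visible query shape this fixes, as a short sketch that mirrors the test re-enabled below (assumes `import sqlContext.implicits._` and `org.apache.spark.sql.functions.lit` are in scope):

```scala
val ds = Seq("1", "2").toDS()

// Joining a Dataset with itself: analysis may rewrite the right side's attributes,
// which is why the joined schema must come from the analyzed plan rather than from
// left.output/right.output directly.
val joined = ds.joinWith(ds, lit(true))
joined.collect()   // four pairs: the full cross product of the two-element Dataset
```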
- ignore("kryo encoder self join") { + test("kryo encoder self join") { implicit val kryoEncoder = Encoders.kryo[KryoData] - val ds = sqlContext.createDataset(Seq(KryoData(1), KryoData(2))) + val ds = Seq(KryoData(1), KryoData(2)).toDS() assert(ds.joinWith(ds, lit(true)).collect().toSet == Set( (KryoData(1), KryoData(1)), From 33b837333435ceb0c04d1f361a5383c4fe6a5a75 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 18 Nov 2015 10:23:12 -0800 Subject: [PATCH 1417/1824] [SPARK-11725][SQL] correctly handle null inputs for UDF If user use primitive parameters in UDF, there is no way for him to do the null-check for primitive inputs, so we are assuming the primitive input is null-propagatable for this case and return null if the input is null. Author: Wenchen Fan Closes #9770 from cloud-fan/udf. --- .../spark/sql/catalyst/ScalaReflection.scala | 9 ++++ .../sql/catalyst/analysis/Analyzer.scala | 32 +++++++++++++- .../sql/catalyst/expressions/ScalaUDF.scala | 6 +++ .../sql/catalyst/ScalaReflectionSuite.scala | 17 +++++++ .../sql/catalyst/analysis/AnalysisSuite.scala | 44 +++++++++++++++++++ .../org/apache/spark/sql/DataFrameSuite.scala | 14 ++++++ 6 files changed, 121 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 0b3dd351e38e..38828e59a215 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -719,6 +719,15 @@ trait ScalaReflection { } } + /** + * Returns classes of input parameters of scala function object. + */ + def getParameterTypes(func: AnyRef): Seq[Class[_]] = { + val methods = func.getClass.getMethods.filter(m => m.getName == "apply" && !m.isBridge) + assert(methods.length == 1) + methods.head.getParameterTypes + } + def typeOfObject: PartialFunction[Any, DataType] = { // The data type can be determined without ambiguity. case obj: Boolean => BooleanType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 2f4670b55bdb..f00c451b5981 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.catalyst.trees.TreeNodeRef -import org.apache.spark.sql.catalyst.{SimpleCatalystConf, CatalystConf} +import org.apache.spark.sql.catalyst.{ScalaReflection, SimpleCatalystConf, CatalystConf} import org.apache.spark.sql.types._ /** @@ -85,6 +85,8 @@ class Analyzer( extendedResolutionRules : _*), Batch("Nondeterministic", Once, PullOutNondeterministic), + Batch("UDF", Once, + HandleNullInputsForUDF), Batch("Cleanup", fixedPoint, CleanupAliases) ) @@ -1063,6 +1065,34 @@ class Analyzer( Project(p.output, newPlan.withNewChildren(newChild :: Nil)) } } + + /** + * Correctly handle null primitive inputs for UDF by adding extra [[If]] expression to do the + * null check. 
When user defines a UDF with primitive parameters, there is no way to tell if the + * primitive parameter is null or not, so here we assume the primitive input is null-propagatable + * and we should return null if the input is null. + */ + object HandleNullInputsForUDF extends Rule[LogicalPlan] { + override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + case p if !p.resolved => p // Skip unresolved nodes. + + case plan => plan transformExpressionsUp { + + case udf @ ScalaUDF(func, _, inputs, _) => + val parameterTypes = ScalaReflection.getParameterTypes(func) + assert(parameterTypes.length == inputs.length) + + val inputsNullCheck = parameterTypes.zip(inputs) + // TODO: skip null handling for not-nullable primitive inputs after we can completely + // trust the `nullable` information. + // .filter { case (cls, expr) => cls.isPrimitive && expr.nullable } + .filter { case (cls, _) => cls.isPrimitive } + .map { case (_, expr) => IsNull(expr) } + .reduceLeftOption[Expression]((e1, e2) => Or(e1, e2)) + inputsNullCheck.map(If(_, Literal.create(null, udf.dataType), udf)).getOrElse(udf) + } + } + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala index 3388cc20a980..03b89221ef2d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala @@ -24,7 +24,13 @@ import org.apache.spark.sql.types.DataType /** * User-defined function. + * @param function The user defined scala function to run. + * Note that if you use primitive parameters, you are not able to check if it is + * null or not, and the UDF will return null for you if the primitive input is + * null. Use boxed type or [[Option]] if you wanna do the null-handling yourself. * @param dataType Return type of function. + * @param children The input expressions of this UDF. + * @param inputTypes The expected input types of this UDF. 
*/ case class ScalaUDF( function: AnyRef, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala index 3b848cfdf737..4ea410d492b0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala @@ -280,4 +280,21 @@ class ScalaReflectionSuite extends SparkFunSuite { assert(s.fields.map(_.dataType) === Seq(IntegerType, StringType, DoubleType)) } } + + test("get parameter type from a function object") { + val primitiveFunc = (i: Int, j: Long) => "x" + val primitiveTypes = getParameterTypes(primitiveFunc) + assert(primitiveTypes.forall(_.isPrimitive)) + assert(primitiveTypes === Seq(classOf[Int], classOf[Long])) + + val boxedFunc = (i: java.lang.Integer, j: java.lang.Long) => "x" + val boxedTypes = getParameterTypes(boxedFunc) + assert(boxedTypes.forall(!_.isPrimitive)) + assert(boxedTypes === Seq(classOf[java.lang.Integer], classOf[java.lang.Long])) + + val anyFunc = (i: Any, j: AnyRef) => "x" + val anyTypes = getParameterTypes(anyFunc) + assert(anyTypes.forall(!_.isPrimitive)) + assert(anyTypes === Seq(classOf[java.lang.Object], classOf[java.lang.Object])) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 65f09b46afae..08586a97411a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -174,4 +174,48 @@ class AnalysisSuite extends AnalysisTest { ) assertAnalysisError(plan, Seq("data type mismatch: Arguments must be same type")) } + + test("SPARK-11725: correctly handle null inputs for ScalaUDF") { + val string = testRelation2.output(0) + val double = testRelation2.output(2) + val short = testRelation2.output(4) + val nullResult = Literal.create(null, StringType) + + def checkUDF(udf: Expression, transformed: Expression): Unit = { + checkAnalysis( + Project(Alias(udf, "")() :: Nil, testRelation2), + Project(Alias(transformed, "")() :: Nil, testRelation2) + ) + } + + // non-primitive parameters do not need special null handling + val udf1 = ScalaUDF((s: String) => "x", StringType, string :: Nil) + val expected1 = udf1 + checkUDF(udf1, expected1) + + // only primitive parameter needs special null handling + val udf2 = ScalaUDF((s: String, d: Double) => "x", StringType, string :: double :: Nil) + val expected2 = If(IsNull(double), nullResult, udf2) + checkUDF(udf2, expected2) + + // special null handling should apply to all primitive parameters + val udf3 = ScalaUDF((s: Short, d: Double) => "x", StringType, short :: double :: Nil) + val expected3 = If( + IsNull(short) || IsNull(double), + nullResult, + udf3) + checkUDF(udf3, expected3) + + // we can skip special null handling for primitive parameters that are not nullable + // TODO: this is disabled for now as we can not completely trust `nullable`. 
+ val udf4 = ScalaUDF( + (s: Short, d: Double) => "x", + StringType, + short :: double.withNullability(false) :: Nil) + val expected4 = If( + IsNull(short), + nullResult, + udf4) + // checkUDF(udf4, expected4) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 35cdab50bdec..5a7f24684d1b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -1115,4 +1115,18 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { checkAnswer(df.select(df("*")), Row(1, "a")) checkAnswer(df.withColumnRenamed("d^'a.", "a"), Row(1, "a")) } + + test("SPARK-11725: correctly handle null inputs for ScalaUDF") { + val df = Seq( + new java.lang.Integer(22) -> "John", + null.asInstanceOf[java.lang.Integer] -> "Lucy").toDF("age", "name") + + val boxedUDF = udf[java.lang.Integer, java.lang.Integer] { + (i: java.lang.Integer) => if (i == null) null else i * 2 + } + checkAnswer(df.select(boxedUDF($"age")), Row(44) :: Row(null) :: Nil) + + val primitiveUDF = udf((i: Int) => i * 2) + checkAnswer(df.select(primitiveUDF($"age")), Row(44) :: Row(null) :: Nil) + } } From dbf428c87ab34b6f76c75946043bdf5f60c9b1b3 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 18 Nov 2015 10:33:17 -0800 Subject: [PATCH 1418/1824] [SPARK-11795][SQL] combine grouping attributes into a single NamedExpression MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit we use `ExpressionEncoder.tuple` to build the result encoder, which assumes the input encoder should point to a struct type field if it’s non-flat. However, our keyEncoder always point to a flat field/fields: `groupingAttributes`, we should combine them into a single `NamedExpression`. Author: Wenchen Fan Closes #9792 from cloud-fan/agg. 
--- .../main/scala/org/apache/spark/sql/GroupedDataset.scala | 9 +++++++-- .../test/scala/org/apache/spark/sql/DatasetSuite.scala | 5 ++--- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala index c66162ee2148..3f84e22a1025 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala @@ -22,7 +22,7 @@ import scala.collection.JavaConverters._ import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.function._ import org.apache.spark.sql.catalyst.encoders.{FlatEncoder, ExpressionEncoder, encoderFor} -import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.expressions.{Alias, CreateStruct, Attribute} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.QueryExecution @@ -187,7 +187,12 @@ class GroupedDataset[K, T] private[sql]( val namedColumns = columns.map( _.withInputType(resolvedTEncoder, dataAttributes).named) - val aggregate = Aggregate(groupingAttributes, groupingAttributes ++ namedColumns, logicalPlan) + val keyColumn = if (groupingAttributes.length > 1) { + Alias(CreateStruct(groupingAttributes), "key")() + } else { + groupingAttributes.head + } + val aggregate = Aggregate(groupingAttributes, keyColumn +: namedColumns, logicalPlan) val execution = new QueryExecution(sqlContext, aggregate) new Dataset( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 198962b8fb75..b6db583dfe01 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -84,8 +84,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { ("a", 2), ("b", 3), ("c", 4)) } - ignore("Dataset should set the resolved encoders internally for maps") { - // TODO: Enable this once we fix SPARK-11793. + test("map and group by with class data") { // We inject a group by here to make sure this test case is future proof // when we implement better pipelining and local execution mode. val ds: Dataset[(ClassData, Long)] = Seq(ClassData("one", 1), ClassData("two", 2)).toDS() @@ -94,7 +93,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { checkAnswer( ds, - (ClassData("one", 1), 1L), (ClassData("two", 2), 1L)) + (ClassData("one", 2), 1L), (ClassData("two", 3), 1L)) } test("select") { From 90a7519daaa7f4ee3be7c5a9aa244120811ff6eb Mon Sep 17 00:00:00 2001 From: Jakob Odersky Date: Wed, 18 Nov 2015 11:35:41 -0800 Subject: [PATCH 1419/1824] [MINOR][BUILD] Ignore ensime cache Using ENSIME, I often have `.ensime_cache` polluting my source tree. This PR simply adds the cache directory to `.gitignore` Author: Jakob Odersky Closes #9708 from jodersky/master. 
--- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index 08f2d8f7543f..07524bc429e9 100644 --- a/.gitignore +++ b/.gitignore @@ -50,6 +50,7 @@ spark-tests.log streaming-tests.log dependency-reduced-pom.xml .ensime +.ensime_cache/ .ensime_lucene checkpoint derby.log From 6f99522d13d8db9fcc767f7c3189557b9a53d283 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Wed, 18 Nov 2015 11:49:12 -0800 Subject: [PATCH 1420/1824] [SPARK-11792] [SQL] [FOLLOW-UP] Change SizeEstimation to KnownSizeEstimation and make estimatedSize return Long instead of Option[Long] https://issues.apache.org/jira/browse/SPARK-11792 The main changes include: * Renaming `SizeEstimation` to `KnownSizeEstimation`. Hopefully this new name has more information. * Making `estimatedSize` return `Long` instead of `Option[Long]`. * In `UnsaveHashedRelation`, `estimatedSize` will delegate the work to `SizeEstimator` if we have not created a `BytesToBytesMap`. Since we will put `UnsaveHashedRelation` to `BlockManager`, it is generally good to let it provide a more accurate size estimation. Also, if we do not put `BytesToBytesMap` directly into `BlockerManager`, I feel it is not really necessary to make `BytesToBytesMap` extends `KnownSizeEstimation`. Author: Yin Huai Closes #9813 from yhuai/SPARK-11792-followup. --- .../org/apache/spark/util/SizeEstimator.scala | 30 ++++++++++--------- .../spark/util/SizeEstimatorSuite.scala | 14 ++------- .../sql/execution/joins/HashedRelation.scala | 12 +++++--- 3 files changed, 26 insertions(+), 30 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala index c3a2675ee5f4..09864e3f8392 100644 --- a/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala +++ b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala @@ -36,9 +36,14 @@ import org.apache.spark.util.collection.OpenHashSet * When a class extends it, [[SizeEstimator]] will query the `estimatedSize` first. * If `estimatedSize` does not return [[None]], [[SizeEstimator]] will use the returned size * as the size of the object. Otherwise, [[SizeEstimator]] will do the estimation work. + * The difference between a [[KnownSizeEstimation]] and + * [[org.apache.spark.util.collection.SizeTracker]] is that, a + * [[org.apache.spark.util.collection.SizeTracker]] still uses [[SizeEstimator]] to + * estimate the size. However, a [[KnownSizeEstimation]] can provide a better estimation without + * using [[SizeEstimator]]. */ -private[spark] trait SizeEstimation { - def estimatedSize: Option[Long] +private[spark] trait KnownSizeEstimation { + def estimatedSize: Long } /** @@ -209,18 +214,15 @@ object SizeEstimator extends Logging { // the size estimator since it references the whole REPL. Do nothing in this case. In // general all ClassLoaders and Classes will be shared between objects anyway. 
} else { - val estimatedSize = obj match { - case s: SizeEstimation => s.estimatedSize - case _ => None - } - if (estimatedSize.isDefined) { - state.size += estimatedSize.get - } else { - val classInfo = getClassInfo(cls) - state.size += alignSize(classInfo.shellSize) - for (field <- classInfo.pointerFields) { - state.enqueue(field.get(obj)) - } + obj match { + case s: KnownSizeEstimation => + state.size += s.estimatedSize + case _ => + val classInfo = getClassInfo(cls) + state.size += alignSize(classInfo.shellSize) + for (field <- classInfo.pointerFields) { + state.enqueue(field.get(obj)) + } } } } diff --git a/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala b/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala index 9b6261af123e..101610e38014 100644 --- a/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala @@ -60,16 +60,10 @@ class DummyString(val arr: Array[Char]) { @transient val hash32: Int = 0 } -class DummyClass8 extends SizeEstimation { +class DummyClass8 extends KnownSizeEstimation { val x: Int = 0 - override def estimatedSize: Option[Long] = Some(2015) -} - -class DummyClass9 extends SizeEstimation { - val x: Int = 0 - - override def estimatedSize: Option[Long] = None + override def estimatedSize: Long = 2015 } class SizeEstimatorSuite @@ -231,9 +225,5 @@ class SizeEstimatorSuite // DummyClass8 provides its size estimation. assertResult(2015)(SizeEstimator.estimate(new DummyClass8)) assertResult(20206)(SizeEstimator.estimate(Array.fill(10)(new DummyClass8))) - - // DummyClass9 does not provide its size estimation. - assertResult(16)(SizeEstimator.estimate(new DummyClass9)) - assertResult(216)(SizeEstimator.estimate(Array.fill(10)(new DummyClass9))) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index 49ae09bf5378..aebfea583240 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.execution.metric.{LongSQLMetric, SQLMetrics} import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.map.BytesToBytesMap import org.apache.spark.unsafe.memory.MemoryLocation -import org.apache.spark.util.{SizeEstimation, Utils} +import org.apache.spark.util.{SizeEstimator, KnownSizeEstimation, Utils} import org.apache.spark.util.collection.CompactBuffer import org.apache.spark.{SparkConf, SparkEnv} @@ -190,7 +190,7 @@ private[execution] object HashedRelation { private[joins] final class UnsafeHashedRelation( private var hashTable: JavaHashMap[UnsafeRow, CompactBuffer[UnsafeRow]]) extends HashedRelation - with SizeEstimation + with KnownSizeEstimation with Externalizable { private[joins] def this() = this(null) // Needed for serialization @@ -217,8 +217,12 @@ private[joins] final class UnsafeHashedRelation( } } - override def estimatedSize: Option[Long] = { - Option(binaryMap).map(_.getTotalMemoryConsumption) + override def estimatedSize: Long = { + if (binaryMap != null) { + binaryMap.getTotalMemoryConsumption + } else { + SizeEstimator.estimate(hashTable) + } } override def get(key: InternalRow): Seq[InternalRow] = { From 94624eacb0fdbbe210894151a956f8150cdf527e Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 18 Nov 2015 11:53:28 -0800 Subject: [PATCH 
1421/1824] [SPARK-11739][SQL] clear the instantiated SQLContext Currently, if the first SQLContext is not removed after stopping SparkContext, a SQLContext could set there forever. This patch make this more robust. Author: Davies Liu Closes #9706 from davies/clear_context. --- .../scala/org/apache/spark/sql/SQLContext.scala | 17 +++++++++++------ .../spark/sql/MultiSQLContextsSuite.scala | 5 ++--- .../execution/ExchangeCoordinatorSuite.scala | 2 +- 3 files changed, 14 insertions(+), 10 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index cd1fdc4edb39..39471d2fb79a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -1229,7 +1229,7 @@ class SQLContext private[sql]( // construction of the instance. sparkContext.addSparkListener(new SparkListener { override def onApplicationEnd(applicationEnd: SparkListenerApplicationEnd): Unit = { - SQLContext.clearInstantiatedContext(self) + SQLContext.clearInstantiatedContext() } }) @@ -1270,13 +1270,13 @@ object SQLContext { */ def getOrCreate(sparkContext: SparkContext): SQLContext = { val ctx = activeContext.get() - if (ctx != null) { + if (ctx != null && !ctx.sparkContext.isStopped) { return ctx } synchronized { val ctx = instantiatedContext.get() - if (ctx == null) { + if (ctx == null || ctx.sparkContext.isStopped) { new SQLContext(sparkContext) } else { ctx @@ -1284,12 +1284,17 @@ object SQLContext { } } - private[sql] def clearInstantiatedContext(sqlContext: SQLContext): Unit = { - instantiatedContext.compareAndSet(sqlContext, null) + private[sql] def clearInstantiatedContext(): Unit = { + instantiatedContext.set(null) } private[sql] def setInstantiatedContext(sqlContext: SQLContext): Unit = { - instantiatedContext.compareAndSet(null, sqlContext) + synchronized { + val ctx = instantiatedContext.get() + if (ctx == null || ctx.sparkContext.isStopped) { + instantiatedContext.set(sqlContext) + } + } } private[sql] def getInstantiatedContextOption(): Option[SQLContext] = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MultiSQLContextsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MultiSQLContextsSuite.scala index 0e8fcb6a858b..34c5c68fd1c1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MultiSQLContextsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MultiSQLContextsSuite.scala @@ -31,7 +31,7 @@ class MultiSQLContextsSuite extends SparkFunSuite with BeforeAndAfterAll { originalInstantiatedSQLContext = SQLContext.getInstantiatedContextOption() SQLContext.clearActive() - originalInstantiatedSQLContext.foreach(ctx => SQLContext.clearInstantiatedContext(ctx)) + SQLContext.clearInstantiatedContext() sparkConf = new SparkConf(false) .setMaster("local[*]") @@ -89,10 +89,9 @@ class MultiSQLContextsSuite extends SparkFunSuite with BeforeAndAfterAll { testNewSession(rootSQLContext) testNewSession(rootSQLContext) testCreatingNewSQLContext(allowMultipleSQLContexts) - - SQLContext.clearInstantiatedContext(rootSQLContext) } finally { sc.stop() + SQLContext.clearInstantiatedContext() } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala index 25f2f5caeed1..b96d50a70b85 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala @@ -34,7 +34,7 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { originalInstantiatedSQLContext = SQLContext.getInstantiatedContextOption() SQLContext.clearActive() - originalInstantiatedSQLContext.foreach(ctx => SQLContext.clearInstantiatedContext(ctx)) + SQLContext.clearInstantiatedContext() } override protected def afterAll(): Unit = { From 31921e0f0bd559d042148d1ea32f865fb3068f38 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Wed, 18 Nov 2015 12:09:54 -0800 Subject: [PATCH 1422/1824] [SPARK-4557][STREAMING] Spark Streaming foreachRDD Java API method should accept a VoidFunction<...> Currently streaming foreachRDD Java API uses a function prototype requiring a return value of null. This PR deprecates the old method and uses VoidFunction to allow for more concise declaration. Also added VoidFunction2 to Java API in order to use in Streaming methods. Unit test is added for using foreachRDD with VoidFunction, and changes have been tested with Java 7 and Java 8 using lambdas. Author: Bryan Cutler Closes #9488 from BryanCutler/foreachRDD-VoidFunction-SPARK-4557. --- .../api/java/function/VoidFunction2.java | 27 ++++++++++++ .../apache/spark/streaming/Java8APISuite.java | 26 ++++++++++++ project/MimaExcludes.scala | 4 ++ .../streaming/api/java/JavaDStreamLike.scala | 24 ++++++++++- .../apache/spark/streaming/JavaAPISuite.java | 41 ++++++++++++++++++- 5 files changed, 120 insertions(+), 2 deletions(-) create mode 100644 core/src/main/java/org/apache/spark/api/java/function/VoidFunction2.java diff --git a/core/src/main/java/org/apache/spark/api/java/function/VoidFunction2.java b/core/src/main/java/org/apache/spark/api/java/function/VoidFunction2.java new file mode 100644 index 000000000000..6c576ab67845 --- /dev/null +++ b/core/src/main/java/org/apache/spark/api/java/function/VoidFunction2.java @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.java.function; + +import java.io.Serializable; + +/** + * A two-argument function that takes arguments of type T1 and T2 with no return value. 
+ */ +public interface VoidFunction2 extends Serializable { + public void call(T1 v1, T2 v2) throws Exception; +} diff --git a/extras/java8-tests/src/test/java/org/apache/spark/streaming/Java8APISuite.java b/extras/java8-tests/src/test/java/org/apache/spark/streaming/Java8APISuite.java index 163ae92c12c6..4eee97bc8961 100644 --- a/extras/java8-tests/src/test/java/org/apache/spark/streaming/Java8APISuite.java +++ b/extras/java8-tests/src/test/java/org/apache/spark/streaming/Java8APISuite.java @@ -28,6 +28,7 @@ import org.junit.Assert; import org.junit.Test; +import org.apache.spark.Accumulator; import org.apache.spark.HashPartitioner; import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaRDD; @@ -360,6 +361,31 @@ public void testFlatMap() { assertOrderInvariantEquals(expected, result); } + @Test + public void testForeachRDD() { + final Accumulator accumRdd = ssc.sc().accumulator(0); + final Accumulator accumEle = ssc.sc().accumulator(0); + List> inputData = Arrays.asList( + Arrays.asList(1,1,1), + Arrays.asList(1,1,1)); + + JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaTestUtils.attachTestOutputStream(stream.count()); // dummy output + + stream.foreachRDD(rdd -> { + accumRdd.add(1); + rdd.foreach(x -> accumEle.add(1)); + }); + + // This is a test to make sure foreachRDD(VoidFunction2) can be called from Java + stream.foreachRDD((rdd, time) -> null); + + JavaTestUtils.runStreams(ssc, 2, 2); + + Assert.assertEquals(2, accumRdd.value().intValue()); + Assert.assertEquals(6, accumEle.value().intValue()); + } + @Test public void testPairFlatMap() { List> inputData = Arrays.asList( diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index eb70d27c34c2..bb45d1bb1214 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -142,6 +142,10 @@ object MimaExcludes { "org.apache.spark.streaming.kafka.KafkaUtilsPythonHelper.createDirectStream"), ProblemFilters.exclude[MissingMethodProblem]( "org.apache.spark.streaming.kafka.KafkaUtilsPythonHelper.createRDD") + ) ++ Seq( + // SPARK-4557 Changed foreachRDD to use VoidFunction + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.streaming.api.java.JavaDStreamLike.foreachRDD") ) case v if v.startsWith("1.5") => Seq( diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala index edfa474677f1..84acec7d8e33 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala @@ -27,7 +27,7 @@ import scala.reflect.ClassTag import org.apache.spark.api.java.{JavaPairRDD, JavaRDD, JavaRDDLike} import org.apache.spark.api.java.JavaPairRDD._ import org.apache.spark.api.java.JavaSparkContext.fakeClassTag -import org.apache.spark.api.java.function.{Function => JFunction, Function2 => JFunction2, Function3 => JFunction3, _} +import org.apache.spark.api.java.function.{Function => JFunction, Function2 => JFunction2, Function3 => JFunction3, VoidFunction => JVoidFunction, VoidFunction2 => JVoidFunction2, _} import org.apache.spark.rdd.RDD import org.apache.spark.streaming._ import org.apache.spark.streaming.api.java.JavaDStream._ @@ -308,7 +308,10 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T /** * Apply a function to each RDD in this DStream. 
This is an output operator, so * 'this' DStream will be registered as an output stream and therefore materialized. + * + * @deprecated As of release 1.6.0, replaced by foreachRDD(JVoidFunction) */ + @deprecated("Use foreachRDD(foreachFunc: JVoidFunction[R])", "1.6.0") def foreachRDD(foreachFunc: JFunction[R, Void]) { dstream.foreachRDD(rdd => foreachFunc.call(wrapRDD(rdd))) } @@ -316,11 +319,30 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T /** * Apply a function to each RDD in this DStream. This is an output operator, so * 'this' DStream will be registered as an output stream and therefore materialized. + * + * @deprecated As of release 1.6.0, replaced by foreachRDD(JVoidFunction2) */ + @deprecated("Use foreachRDD(foreachFunc: JVoidFunction2[R, Time])", "1.6.0") def foreachRDD(foreachFunc: JFunction2[R, Time, Void]) { dstream.foreachRDD((rdd, time) => foreachFunc.call(wrapRDD(rdd), time)) } + /** + * Apply a function to each RDD in this DStream. This is an output operator, so + * 'this' DStream will be registered as an output stream and therefore materialized. + */ + def foreachRDD(foreachFunc: JVoidFunction[R]) { + dstream.foreachRDD(rdd => foreachFunc.call(wrapRDD(rdd))) + } + + /** + * Apply a function to each RDD in this DStream. This is an output operator, so + * 'this' DStream will be registered as an output stream and therefore materialized. + */ + def foreachRDD(foreachFunc: JVoidFunction2[R, Time]) { + dstream.foreachRDD((rdd, time) => foreachFunc.call(wrapRDD(rdd), time)) + } + /** * Return a new DStream in which each RDD is generated by applying a function * on each RDD of 'this' DStream. diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java index c5217149224e..609bb4413b6b 100644 --- a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java +++ b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java @@ -37,7 +37,9 @@ import com.google.common.io.Files; import com.google.common.collect.Sets; +import org.apache.spark.Accumulator; import org.apache.spark.HashPartitioner; +import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; @@ -45,7 +47,6 @@ import org.apache.spark.storage.StorageLevel; import org.apache.spark.streaming.api.java.*; import org.apache.spark.util.Utils; -import org.apache.spark.SparkConf; // The test suite itself is Serializable so that anonymous Function implementations can be // serialized, as an alternative to converting these anonymous classes to static inner classes; @@ -768,6 +769,44 @@ public Iterable call(String x) { assertOrderInvariantEquals(expected, result); } + @SuppressWarnings("unchecked") + @Test + public void testForeachRDD() { + final Accumulator accumRdd = ssc.sc().accumulator(0); + final Accumulator accumEle = ssc.sc().accumulator(0); + List> inputData = Arrays.asList( + Arrays.asList(1,1,1), + Arrays.asList(1,1,1)); + + JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaTestUtils.attachTestOutputStream(stream.count()); // dummy output + + stream.foreachRDD(new VoidFunction>() { + @Override + public void call(JavaRDD rdd) { + accumRdd.add(1); + rdd.foreach(new VoidFunction() { + @Override + public void call(Integer i) { + accumEle.add(1); + } + }); + } + }); + + // This is a test to make sure foreachRDD(VoidFunction2) can be called from 
Java + stream.foreachRDD(new VoidFunction2, Time>() { + @Override + public void call(JavaRDD rdd, Time time) { + } + }); + + JavaTestUtils.runStreams(ssc, 2, 2); + + Assert.assertEquals(2, accumRdd.value().intValue()); + Assert.assertEquals(6, accumEle.value().intValue()); + } + @SuppressWarnings("unchecked") @Test public void testPairFlatMap() { From a416e41e285700f861559d710dbf857405bfddf6 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 18 Nov 2015 12:50:29 -0800 Subject: [PATCH 1423/1824] [SPARK-11809] Switch the default Mesos mode to coarse-grained mode Based on my conversions with people, I believe the consensus is that the coarse-grained mode is more stable and easier to reason about. It is best to use that as the default rather than the more flaky fine-grained mode. Author: Reynold Xin Closes #9795 from rxin/SPARK-11809. --- .../scala/org/apache/spark/SparkContext.scala | 2 +- docs/job-scheduling.md | 2 +- docs/running-on-mesos.md | 27 ++++++++++++------- 3 files changed, 19 insertions(+), 12 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index b5645b08f92d..ab374cb71286 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -2710,7 +2710,7 @@ object SparkContext extends Logging { case mesosUrl @ MESOS_REGEX(_) => MesosNativeLibrary.load() val scheduler = new TaskSchedulerImpl(sc) - val coarseGrained = sc.conf.getBoolean("spark.mesos.coarse", false) + val coarseGrained = sc.conf.getBoolean("spark.mesos.coarse", defaultValue = true) val url = mesosUrl.stripPrefix("mesos://") // strip scheme from raw Mesos URLs val backend = if (coarseGrained) { new CoarseMesosSchedulerBackend(scheduler, sc, url, sc.env.securityManager) diff --git a/docs/job-scheduling.md b/docs/job-scheduling.md index a3c34cb6796f..36327c6efeaf 100644 --- a/docs/job-scheduling.md +++ b/docs/job-scheduling.md @@ -47,7 +47,7 @@ application is not running tasks on a machine, other applications may run tasks is useful when you expect large numbers of not overly active applications, such as shell sessions from separate users. However, it comes with a risk of less predictable latency, because it may take a while for an application to gain back cores on one node when it has work to do. To use this mode, simply use a -`mesos://` URL without setting `spark.mesos.coarse` to true. +`mesos://` URL and set `spark.mesos.coarse` to false. Note that none of the modes currently provide memory sharing across applications. If you would like to share data this way, we recommend running a single server application that can serve multiple requests by querying diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index 5be208cf3461..a197d0e37302 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -161,21 +161,15 @@ Note that jars or python files that are passed to spark-submit should be URIs re # Mesos Run Modes -Spark can run over Mesos in two modes: "fine-grained" (default) and "coarse-grained". +Spark can run over Mesos in two modes: "coarse-grained" (default) and "fine-grained". -In "fine-grained" mode (default), each Spark task runs as a separate Mesos task. This allows -multiple instances of Spark (and other frameworks) to share machines at a very fine granularity, -where each application gets more or fewer machines as it ramps up and down, but it comes with an -additional overhead in launching each task. 
This mode may be inappropriate for low-latency -requirements like interactive queries or serving web requests. - -The "coarse-grained" mode will instead launch only *one* long-running Spark task on each Mesos +The "coarse-grained" mode will launch only *one* long-running Spark task on each Mesos machine, and dynamically schedule its own "mini-tasks" within it. The benefit is much lower startup overhead, but at the cost of reserving the Mesos resources for the complete duration of the application. -To run in coarse-grained mode, set the `spark.mesos.coarse` property in your -[SparkConf](configuration.html#spark-properties): +Coarse-grained is the default mode. You can also set the `spark.mesos.coarse` property to true +to turn it on explicitly in [SparkConf](configuration.html#spark-properties): {% highlight scala %} conf.set("spark.mesos.coarse", "true") @@ -186,6 +180,19 @@ acquire. By default, it will acquire *all* cores in the cluster (that get offere only makes sense if you run just one application at a time. You can cap the maximum number of cores using `conf.set("spark.cores.max", "10")` (for example). +In "fine-grained" mode, each Spark task runs as a separate Mesos task. This allows +multiple instances of Spark (and other frameworks) to share machines at a very fine granularity, +where each application gets more or fewer machines as it ramps up and down, but it comes with an +additional overhead in launching each task. This mode may be inappropriate for low-latency +requirements like interactive queries or serving web requests. + +To run in fine-grained mode, set the `spark.mesos.coarse` property to false in your +[SparkConf](configuration.html#spark-properties): + +{% highlight scala %} +conf.set("spark.mesos.coarse", "false") +{% endhighlight %} + You may also make use of `spark.mesos.constraints` to set attribute based constraints on mesos resource offers. By default, all resource offers will be accepted. {% highlight scala %} From 7c5b641808740ba5eed05ba8204cdbaf3fc579f5 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Wed, 18 Nov 2015 12:53:22 -0800 Subject: [PATCH 1424/1824] [SPARK-10745][CORE] Separate configs between shuffle and RPC [SPARK-6028](https://issues.apache.org/jira/browse/SPARK-6028) uses the network module to implement RPC. However, there are some configurations named with the `spark.shuffle` prefix in the network module. This PR refactors them to make sure the user can control them in shuffle and RPC separately. The user can use `spark.rpc.*` to set the configuration for netty RPC. Author: Shixiong Zhu Closes #9481 from zsxwing/SPARK-10745. 
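The patch below threads a module name ("shuffle" or "rpc") through the transport configuration so that the same networking code resolves its settings under different key prefixes. As a rough illustration of that key-prefixing idea only, here is a hedged sketch built around a hypothetical ModuleConf class; it is not Spark's TransportConf, and the key names shown are examples taken from the diff.

object ModuleConfSketch {
  // Hypothetical stand-in for a module-scoped transport configuration.
  class ModuleConf(module: String, settings: Map[String, String]) {
    // Every setting is looked up under a per-module prefix, e.g.
    // spark.shuffle.io.connectionTimeout vs. spark.rpc.io.connectionTimeout.
    private def key(suffix: String): String = s"spark.$module.$suffix"

    private def get(suffix: String, default: String): String =
      settings.getOrElse(key(suffix), default)

    def connectionTimeout: String = get("io.connectionTimeout", "120s")
    def numConnectionsPerPeer: Int = get("io.numConnectionsPerPeer", "1").toInt
  }

  def main(args: Array[String]): Unit = {
    val settings = Map(
      "spark.shuffle.io.connectionTimeout" -> "60s",
      "spark.rpc.io.numConnectionsPerPeer" -> "1")

    val shuffle = new ModuleConf("shuffle", settings)
    val rpc = new ModuleConf("rpc", settings)

    println(shuffle.connectionTimeout)   // 60s: shuffle-scoped override
    println(rpc.connectionTimeout)       // 120s: falls back to the default
    println(rpc.numConnectionsPerPeer)   // 1
  }
}

Passing the prefix in as a constructor argument keeps the transport layer itself module-agnostic, while each subsystem (shuffle, RPC) overrides only the settings it cares about; unset keys fall back to shared defaults, as the real TransportConf does with spark.network.timeout.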
--- .../spark/deploy/ExternalShuffleService.scala | 3 +- .../netty/NettyBlockTransferService.scala | 2 +- .../network/netty/SparkTransportConf.scala | 12 ++-- .../apache/spark/rpc/netty/NettyRpcEnv.scala | 8 +-- .../mesos/CoarseMesosSchedulerBackend.scala | 2 +- .../shuffle/FileShuffleBlockResolver.scala | 2 +- .../shuffle/IndexShuffleBlockResolver.scala | 2 +- .../apache/spark/storage/BlockManager.scala | 2 +- .../spark/ExternalShuffleServiceSuite.scala | 2 +- .../spark/network/util/TransportConf.java | 65 ++++++++++++++----- .../network/ChunkFetchIntegrationSuite.java | 2 +- .../RequestTimeoutIntegrationSuite.java | 2 +- .../spark/network/RpcIntegrationSuite.java | 2 +- .../org/apache/spark/network/StreamSuite.java | 2 +- .../network/TransportClientFactorySuite.java | 6 +- .../spark/network/sasl/SparkSaslSuite.java | 6 +- .../network/sasl/SaslIntegrationSuite.java | 2 +- .../ExternalShuffleBlockResolverSuite.java | 2 +- .../shuffle/ExternalShuffleCleanupSuite.java | 2 +- .../ExternalShuffleIntegrationSuite.java | 2 +- .../shuffle/ExternalShuffleSecuritySuite.java | 2 +- .../shuffle/RetryingBlockFetcherSuite.java | 2 +- .../network/yarn/YarnShuffleService.java | 2 +- 23 files changed, 84 insertions(+), 50 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala b/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala index a039d543c35e..e8a1e35c3fc4 100644 --- a/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala +++ b/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala @@ -45,7 +45,8 @@ class ExternalShuffleService(sparkConf: SparkConf, securityManager: SecurityMana private val port = sparkConf.getInt("spark.shuffle.service.port", 7337) private val useSasl: Boolean = securityManager.isAuthenticationEnabled() - private val transportConf = SparkTransportConf.fromSparkConf(sparkConf, numUsableCores = 0) + private val transportConf = + SparkTransportConf.fromSparkConf(sparkConf, "shuffle", numUsableCores = 0) private val blockHandler = newShuffleBlockHandler(transportConf) private val transportContext: TransportContext = new TransportContext(transportConf, blockHandler, true) diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala index 70a42f9045e6..b0694e3c6c8a 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala @@ -41,7 +41,7 @@ class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManage // TODO: Don't use Java serialization, use a more cross-version compatible serialization format. 
private val serializer = new JavaSerializer(conf) private val authEnabled = securityManager.isAuthenticationEnabled() - private val transportConf = SparkTransportConf.fromSparkConf(conf, numCores) + private val transportConf = SparkTransportConf.fromSparkConf(conf, "shuffle", numCores) private[this] var transportContext: TransportContext = _ private[this] var server: TransportServer = _ diff --git a/core/src/main/scala/org/apache/spark/network/netty/SparkTransportConf.scala b/core/src/main/scala/org/apache/spark/network/netty/SparkTransportConf.scala index cef203006d68..84833f59d7af 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/SparkTransportConf.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/SparkTransportConf.scala @@ -40,23 +40,23 @@ object SparkTransportConf { /** * Utility for creating a [[TransportConf]] from a [[SparkConf]]. + * @param _conf the [[SparkConf]] + * @param module the module name * @param numUsableCores if nonzero, this will restrict the server and client threads to only * use the given number of cores, rather than all of the machine's cores. * This restriction will only occur if these properties are not already set. */ - def fromSparkConf(_conf: SparkConf, numUsableCores: Int = 0): TransportConf = { + def fromSparkConf(_conf: SparkConf, module: String, numUsableCores: Int = 0): TransportConf = { val conf = _conf.clone // Specify thread configuration based on our JVM's allocation of cores (rather than necessarily // assuming we have all the machine's cores). // NB: Only set if serverThreads/clientThreads not already set. val numThreads = defaultNumThreads(numUsableCores) - conf.set("spark.shuffle.io.serverThreads", - conf.get("spark.shuffle.io.serverThreads", numThreads.toString)) - conf.set("spark.shuffle.io.clientThreads", - conf.get("spark.shuffle.io.clientThreads", numThreads.toString)) + conf.setIfMissing(s"spark.$module.io.serverThreads", numThreads.toString) + conf.setIfMissing(s"spark.$module.io.clientThreads", numThreads.toString) - new TransportConf(new ConfigProvider { + new TransportConf(module, new ConfigProvider { override def get(name: String): String = conf.get(name) }) } diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala index 09093819bb22..3e0c49796950 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala @@ -22,16 +22,13 @@ import java.net.{InetSocketAddress, URI} import java.nio.ByteBuffer import java.util.concurrent._ import java.util.concurrent.atomic.AtomicBoolean -import javax.annotation.Nullable; -import javax.annotation.concurrent.GuardedBy +import javax.annotation.Nullable -import scala.collection.mutable import scala.concurrent.{Future, Promise} import scala.reflect.ClassTag import scala.util.{DynamicVariable, Failure, Success} import scala.util.control.NonFatal -import com.google.common.base.Preconditions import org.apache.spark.{Logging, SecurityManager, SparkConf} import org.apache.spark.network.TransportContext import org.apache.spark.network.client._ @@ -49,7 +46,8 @@ private[netty] class NettyRpcEnv( securityManager: SecurityManager) extends RpcEnv(conf) with Logging { private val transportConf = SparkTransportConf.fromSparkConf( - conf.clone.set("spark.shuffle.io.numConnectionsPerPeer", "1"), + conf.clone.set("spark.rpc.io.numConnectionsPerPeer", "1"), + "rpc", conf.getInt("spark.rpc.io.threads", 0)) private val 
dispatcher: Dispatcher = new Dispatcher(this) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala index 2de9b6a65169..7d08eae0b487 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala @@ -109,7 +109,7 @@ private[spark] class CoarseMesosSchedulerBackend( private val mesosExternalShuffleClient: Option[MesosExternalShuffleClient] = { if (shuffleServiceEnabled) { Some(new MesosExternalShuffleClient( - SparkTransportConf.fromSparkConf(conf), + SparkTransportConf.fromSparkConf(conf, "shuffle"), securityManager, securityManager.isAuthenticationEnabled(), securityManager.isSaslEncryptionEnabled())) diff --git a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala index 39fadd878351..cc5f933393ad 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala @@ -46,7 +46,7 @@ private[spark] trait ShuffleWriterGroup { private[spark] class FileShuffleBlockResolver(conf: SparkConf) extends ShuffleBlockResolver with Logging { - private val transportConf = SparkTransportConf.fromSparkConf(conf) + private val transportConf = SparkTransportConf.fromSparkConf(conf, "shuffle") private lazy val blockManager = SparkEnv.get.blockManager diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala index 05b1eed7f3be..fadb8fe7ed0a 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala @@ -47,7 +47,7 @@ private[spark] class IndexShuffleBlockResolver( private lazy val blockManager = Option(_blockManager).getOrElse(SparkEnv.get.blockManager) - private val transportConf = SparkTransportConf.fromSparkConf(conf) + private val transportConf = SparkTransportConf.fromSparkConf(conf, "shuffle") def getDataFile(shuffleId: Int, mapId: Int): File = { blockManager.diskBlockManager.getFile(ShuffleDataBlockId(shuffleId, mapId, NOOP_REDUCE_ID)) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 661c706af32b..ab0007fb7899 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -122,7 +122,7 @@ private[spark] class BlockManager( // Client to read other executors' shuffle files. This is either an external service, or just the // standard BlockTransferService to directly connect to other Executors. 
private[spark] val shuffleClient = if (externalShuffleServiceEnabled) { - val transConf = SparkTransportConf.fromSparkConf(conf, numUsableCores) + val transConf = SparkTransportConf.fromSparkConf(conf, "shuffle", numUsableCores) new ExternalShuffleClient(transConf, securityManager, securityManager.isAuthenticationEnabled(), securityManager.isSaslEncryptionEnabled()) } else { diff --git a/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala b/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala index 231f4631e0a4..1c775bcb3d9c 100644 --- a/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala +++ b/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala @@ -35,7 +35,7 @@ class ExternalShuffleServiceSuite extends ShuffleSuite with BeforeAndAfterAll { var rpcHandler: ExternalShuffleBlockHandler = _ override def beforeAll() { - val transportConf = SparkTransportConf.fromSparkConf(conf, numUsableCores = 2) + val transportConf = SparkTransportConf.fromSparkConf(conf, "shuffle", numUsableCores = 2) rpcHandler = new ExternalShuffleBlockHandler(transportConf, null) val transportContext = new TransportContext(transportConf, rpcHandler) server = transportContext.createServer() diff --git a/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java b/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java index 3b2eff377955..115135d44adb 100644 --- a/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java +++ b/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java @@ -23,18 +23,53 @@ * A central location that tracks all the settings we expose to users. */ public class TransportConf { + + private final String SPARK_NETWORK_IO_MODE_KEY; + private final String SPARK_NETWORK_IO_PREFERDIRECTBUFS_KEY; + private final String SPARK_NETWORK_IO_CONNECTIONTIMEOUT_KEY; + private final String SPARK_NETWORK_IO_BACKLOG_KEY; + private final String SPARK_NETWORK_IO_NUMCONNECTIONSPERPEER_KEY; + private final String SPARK_NETWORK_IO_SERVERTHREADS_KEY; + private final String SPARK_NETWORK_IO_CLIENTTHREADS_KEY; + private final String SPARK_NETWORK_IO_RECEIVEBUFFER_KEY; + private final String SPARK_NETWORK_IO_SENDBUFFER_KEY; + private final String SPARK_NETWORK_SASL_TIMEOUT_KEY; + private final String SPARK_NETWORK_IO_MAXRETRIES_KEY; + private final String SPARK_NETWORK_IO_RETRYWAIT_KEY; + private final String SPARK_NETWORK_IO_LAZYFD_KEY; + private final ConfigProvider conf; - public TransportConf(ConfigProvider conf) { + private final String module; + + public TransportConf(String module, ConfigProvider conf) { + this.module = module; this.conf = conf; + SPARK_NETWORK_IO_MODE_KEY = getConfKey("io.mode"); + SPARK_NETWORK_IO_PREFERDIRECTBUFS_KEY = getConfKey("io.preferDirectBufs"); + SPARK_NETWORK_IO_CONNECTIONTIMEOUT_KEY = getConfKey("io.connectionTimeout"); + SPARK_NETWORK_IO_BACKLOG_KEY = getConfKey("io.backLog"); + SPARK_NETWORK_IO_NUMCONNECTIONSPERPEER_KEY = getConfKey("io.numConnectionsPerPeer"); + SPARK_NETWORK_IO_SERVERTHREADS_KEY = getConfKey("io.serverThreads"); + SPARK_NETWORK_IO_CLIENTTHREADS_KEY = getConfKey("io.clientThreads"); + SPARK_NETWORK_IO_RECEIVEBUFFER_KEY = getConfKey("io.receiveBuffer"); + SPARK_NETWORK_IO_SENDBUFFER_KEY = getConfKey("io.sendBuffer"); + SPARK_NETWORK_SASL_TIMEOUT_KEY = getConfKey("sasl.timeout"); + SPARK_NETWORK_IO_MAXRETRIES_KEY = getConfKey("io.maxRetries"); + SPARK_NETWORK_IO_RETRYWAIT_KEY = getConfKey("io.retryWait"); + 
SPARK_NETWORK_IO_LAZYFD_KEY = getConfKey("io.lazyFD"); + } + + private String getConfKey(String suffix) { + return "spark." + module + "." + suffix; } /** IO mode: nio or epoll */ - public String ioMode() { return conf.get("spark.shuffle.io.mode", "NIO").toUpperCase(); } + public String ioMode() { return conf.get(SPARK_NETWORK_IO_MODE_KEY, "NIO").toUpperCase(); } /** If true, we will prefer allocating off-heap byte buffers within Netty. */ public boolean preferDirectBufs() { - return conf.getBoolean("spark.shuffle.io.preferDirectBufs", true); + return conf.getBoolean(SPARK_NETWORK_IO_PREFERDIRECTBUFS_KEY, true); } /** Connect timeout in milliseconds. Default 120 secs. */ @@ -42,23 +77,23 @@ public int connectionTimeoutMs() { long defaultNetworkTimeoutS = JavaUtils.timeStringAsSec( conf.get("spark.network.timeout", "120s")); long defaultTimeoutMs = JavaUtils.timeStringAsSec( - conf.get("spark.shuffle.io.connectionTimeout", defaultNetworkTimeoutS + "s")) * 1000; + conf.get(SPARK_NETWORK_IO_CONNECTIONTIMEOUT_KEY, defaultNetworkTimeoutS + "s")) * 1000; return (int) defaultTimeoutMs; } /** Number of concurrent connections between two nodes for fetching data. */ public int numConnectionsPerPeer() { - return conf.getInt("spark.shuffle.io.numConnectionsPerPeer", 1); + return conf.getInt(SPARK_NETWORK_IO_NUMCONNECTIONSPERPEER_KEY, 1); } /** Requested maximum length of the queue of incoming connections. Default -1 for no backlog. */ - public int backLog() { return conf.getInt("spark.shuffle.io.backLog", -1); } + public int backLog() { return conf.getInt(SPARK_NETWORK_IO_BACKLOG_KEY, -1); } /** Number of threads used in the server thread pool. Default to 0, which is 2x#cores. */ - public int serverThreads() { return conf.getInt("spark.shuffle.io.serverThreads", 0); } + public int serverThreads() { return conf.getInt(SPARK_NETWORK_IO_SERVERTHREADS_KEY, 0); } /** Number of threads used in the client thread pool. Default to 0, which is 2x#cores. */ - public int clientThreads() { return conf.getInt("spark.shuffle.io.clientThreads", 0); } + public int clientThreads() { return conf.getInt(SPARK_NETWORK_IO_CLIENTTHREADS_KEY, 0); } /** * Receive buffer size (SO_RCVBUF). @@ -67,28 +102,28 @@ public int numConnectionsPerPeer() { * Assuming latency = 1ms, network_bandwidth = 10Gbps * buffer size should be ~ 1.25MB */ - public int receiveBuf() { return conf.getInt("spark.shuffle.io.receiveBuffer", -1); } + public int receiveBuf() { return conf.getInt(SPARK_NETWORK_IO_RECEIVEBUFFER_KEY, -1); } /** Send buffer size (SO_SNDBUF). */ - public int sendBuf() { return conf.getInt("spark.shuffle.io.sendBuffer", -1); } + public int sendBuf() { return conf.getInt(SPARK_NETWORK_IO_SENDBUFFER_KEY, -1); } /** Timeout for a single round trip of SASL token exchange, in milliseconds. */ public int saslRTTimeoutMs() { - return (int) JavaUtils.timeStringAsSec(conf.get("spark.shuffle.sasl.timeout", "30s")) * 1000; + return (int) JavaUtils.timeStringAsSec(conf.get(SPARK_NETWORK_SASL_TIMEOUT_KEY, "30s")) * 1000; } /** * Max number of times we will try IO exceptions (such as connection timeouts) per request. * If set to 0, we will not do any retries. */ - public int maxIORetries() { return conf.getInt("spark.shuffle.io.maxRetries", 3); } + public int maxIORetries() { return conf.getInt(SPARK_NETWORK_IO_MAXRETRIES_KEY, 3); } /** * Time (in milliseconds) that we will wait in order to perform a retry after an IOException. * Only relevant if maxIORetries > 0. 
*/ public int ioRetryWaitTimeMs() { - return (int) JavaUtils.timeStringAsSec(conf.get("spark.shuffle.io.retryWait", "5s")) * 1000; + return (int) JavaUtils.timeStringAsSec(conf.get(SPARK_NETWORK_IO_RETRYWAIT_KEY, "5s")) * 1000; } /** @@ -101,11 +136,11 @@ public int memoryMapBytes() { } /** - * Whether to initialize shuffle FileDescriptor lazily or not. If true, file descriptors are + * Whether to initialize FileDescriptor lazily or not. If true, file descriptors are * created only when data is going to be transferred. This can reduce the number of open files. */ public boolean lazyFileDescriptor() { - return conf.getBoolean("spark.shuffle.io.lazyFD", true); + return conf.getBoolean(SPARK_NETWORK_IO_LAZYFD_KEY, true); } /** diff --git a/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java b/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java index dfb7740344ed..dc5fa1cee69b 100644 --- a/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java @@ -83,7 +83,7 @@ public static void setUp() throws Exception { fp.write(fileContent); fp.close(); - final TransportConf conf = new TransportConf(new SystemPropertyConfigProvider()); + final TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); fileChunk = new FileSegmentManagedBuffer(conf, testFile, 10, testFile.length() - 25); streamManager = new StreamManager() { diff --git a/network/common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java b/network/common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java index 84ebb337e6d5..42955ef69235 100644 --- a/network/common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java @@ -60,7 +60,7 @@ public class RequestTimeoutIntegrationSuite { public void setUp() throws Exception { Map configMap = Maps.newHashMap(); configMap.put("spark.shuffle.io.connectionTimeout", "2s"); - conf = new TransportConf(new MapConfigProvider(configMap)); + conf = new TransportConf("shuffle", new MapConfigProvider(configMap)); defaultManager = new StreamManager() { @Override diff --git a/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java b/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java index 64b457b4b3f0..8eb56bdd9846 100644 --- a/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java @@ -49,7 +49,7 @@ public class RpcIntegrationSuite { @BeforeClass public static void setUp() throws Exception { - TransportConf conf = new TransportConf(new SystemPropertyConfigProvider()); + TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); rpcHandler = new RpcHandler() { @Override public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) { diff --git a/network/common/src/test/java/org/apache/spark/network/StreamSuite.java b/network/common/src/test/java/org/apache/spark/network/StreamSuite.java index 6dcec831dec7..00158fd08162 100644 --- a/network/common/src/test/java/org/apache/spark/network/StreamSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/StreamSuite.java @@ -89,7 +89,7 @@ public static void setUp() throws 
Exception { fp.close(); } - final TransportConf conf = new TransportConf(new SystemPropertyConfigProvider()); + final TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); final StreamManager streamManager = new StreamManager() { @Override public ManagedBuffer getChunk(long streamId, int chunkIndex) { diff --git a/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java b/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java index f44713741930..dac7d4a5b0a0 100644 --- a/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java +++ b/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java @@ -52,7 +52,7 @@ public class TransportClientFactorySuite { @Before public void setUp() { - conf = new TransportConf(new SystemPropertyConfigProvider()); + conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); RpcHandler rpcHandler = new NoOpRpcHandler(); context = new TransportContext(conf, rpcHandler); server1 = context.createServer(); @@ -76,7 +76,7 @@ private void testClientReuse(final int maxConnections, boolean concurrent) Map configMap = Maps.newHashMap(); configMap.put("spark.shuffle.io.numConnectionsPerPeer", Integer.toString(maxConnections)); - TransportConf conf = new TransportConf(new MapConfigProvider(configMap)); + TransportConf conf = new TransportConf("shuffle", new MapConfigProvider(configMap)); RpcHandler rpcHandler = new NoOpRpcHandler(); TransportContext context = new TransportContext(conf, rpcHandler); @@ -182,7 +182,7 @@ public void closeBlockClientsWithFactory() throws IOException { @Test public void closeIdleConnectionForRequestTimeOut() throws IOException, InterruptedException { - TransportConf conf = new TransportConf(new ConfigProvider() { + TransportConf conf = new TransportConf("shuffle", new ConfigProvider() { @Override public String get(String name) { diff --git a/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java b/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java index 3469e84e7f4d..b14689967018 100644 --- a/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java @@ -207,7 +207,7 @@ public void testEncryptedMessage() throws Exception { public void testEncryptedMessageChunking() throws Exception { File file = File.createTempFile("sasltest", ".txt"); try { - TransportConf conf = new TransportConf(new SystemPropertyConfigProvider()); + TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); byte[] data = new byte[8 * 1024]; new Random().nextBytes(data); @@ -242,7 +242,7 @@ public void testFileRegionEncryption() throws Exception { final File file = File.createTempFile("sasltest", ".txt"); SaslTestCtx ctx = null; try { - final TransportConf conf = new TransportConf(new SystemPropertyConfigProvider()); + final TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); StreamManager sm = mock(StreamManager.class); when(sm.getChunk(anyLong(), anyInt())).thenAnswer(new Answer() { @Override @@ -368,7 +368,7 @@ private static class SaslTestCtx { boolean disableClientEncryption) throws Exception { - TransportConf conf = new TransportConf(new SystemPropertyConfigProvider()); + TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); SecretKeyHolder keyHolder = 
mock(SecretKeyHolder.class); when(keyHolder.getSaslUser(anyString())).thenReturn("user"); diff --git a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java index c393a5e1e681..1c2fa4d0d462 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java @@ -70,7 +70,7 @@ public class SaslIntegrationSuite { @BeforeClass public static void beforeAll() throws IOException { - conf = new TransportConf(new SystemPropertyConfigProvider()); + conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); context = new TransportContext(conf, new TestRpcHandler()); secretKeyHolder = mock(SecretKeyHolder.class); diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java index 3c6cb367dea4..a9958232a1d2 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java @@ -42,7 +42,7 @@ public class ExternalShuffleBlockResolverSuite { static TestShuffleDataContext dataContext; - static TransportConf conf = new TransportConf(new SystemPropertyConfigProvider()); + static TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); @BeforeClass public static void beforeAll() throws IOException { diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java index 2f4f1d0df478..532d7ab8d01b 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java @@ -35,7 +35,7 @@ public class ExternalShuffleCleanupSuite { // Same-thread Executor used to ensure cleanup happens synchronously in test thread. 
Executor sameThreadExecutor = MoreExecutors.sameThreadExecutor(); - TransportConf conf = new TransportConf(new SystemPropertyConfigProvider()); + TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); @Test public void noCleanupAndCleanup() throws IOException { diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java index a3f9a38b1aeb..2095f41d79c1 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java @@ -91,7 +91,7 @@ public static void beforeAll() throws IOException { dataContext1.create(); dataContext1.insertHashShuffleData(1, 0, exec1Blocks); - conf = new TransportConf(new SystemPropertyConfigProvider()); + conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); handler = new ExternalShuffleBlockHandler(conf, null); TransportContext transportContext = new TransportContext(conf, handler); server = transportContext.createServer(); diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java index aa99efda9494..08ddb3755bd0 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java @@ -39,7 +39,7 @@ public class ExternalShuffleSecuritySuite { - TransportConf conf = new TransportConf(new SystemPropertyConfigProvider()); + TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); TransportServer server; @Before diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockFetcherSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockFetcherSuite.java index 06e46f924109..3a6ef0d3f847 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockFetcherSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockFetcherSuite.java @@ -254,7 +254,7 @@ private static void performInteractions(List> inte BlockFetchingListener listener) throws IOException { - TransportConf conf = new TransportConf(new SystemPropertyConfigProvider()); + TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); BlockFetchStarter fetchStarter = mock(BlockFetchStarter.class); Stubber stub = null; diff --git a/network/yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java b/network/yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java index 11ea7f3fd3cf..ba6d30a74c67 100644 --- a/network/yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java +++ b/network/yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java @@ -120,7 +120,7 @@ protected void serviceInit(Configuration conf) { registeredExecutorFile = findRegisteredExecutorFile(conf.getStrings("yarn.nodemanager.local-dirs")); - TransportConf transportConf = new TransportConf(new HadoopConfigProvider(conf)); + TransportConf transportConf = new TransportConf("shuffle", new HadoopConfigProvider(conf)); // If authentication is enabled, set up the shuffle server to 
use a // special RPC handler that filters out unauthenticated fetch requests boolean authEnabled = conf.getBoolean(SPARK_AUTHENTICATE_KEY, DEFAULT_SPARK_AUTHENTICATE); From 09ad9533d5760652de59fa4830c24cb8667958ac Mon Sep 17 00:00:00 2001 From: JihongMa Date: Wed, 18 Nov 2015 13:03:37 -0800 Subject: [PATCH 1425/1824] [SPARK-11720][SQL][ML] Handle edge cases when count = 0 or 1 for Stats function return Double.NaN for mean/average when count == 0 for all numeric types that is converted to Double, Decimal type continue to return null. Author: JihongMa Closes #9705 from JihongMA/SPARK-11720. --- python/pyspark/sql/dataframe.py | 2 +- .../aggregate/CentralMomentAgg.scala | 2 +- .../expressions/aggregate/Kurtosis.scala | 9 +++++---- .../expressions/aggregate/Skewness.scala | 9 +++++---- .../expressions/aggregate/Stddev.scala | 18 ++++++++++++++---- .../expressions/aggregate/Variance.scala | 18 ++++++++++++++---- .../spark/sql/DataFrameAggregateSuite.scala | 18 ++++++++++++------ .../org/apache/spark/sql/DataFrameSuite.scala | 2 +- 8 files changed, 53 insertions(+), 25 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index ad6ad0235a90..0dd75ba7ca82 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -761,7 +761,7 @@ def describe(self, *cols): +-------+------------------+-----+ | count| 2| 2| | mean| 3.5| null| - | stddev|2.1213203435596424| NaN| + | stddev|2.1213203435596424| null| | min| 2|Alice| | max| 5| Bob| +-------+------------------+-----+ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala index de5872ab11eb..d07d4c338cdf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala @@ -206,7 +206,7 @@ abstract class CentralMomentAgg(child: Expression) extends ImperativeAggregate w * @param centralMoments Length `momentOrder + 1` array of central moments (un-normalized) * needed to compute the aggregate stat. 
*/ - def getStatistic(n: Double, mean: Double, centralMoments: Array[Double]): Double + def getStatistic(n: Double, mean: Double, centralMoments: Array[Double]): Any override final def eval(buffer: InternalRow): Any = { val n = buffer.getDouble(nOffset) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Kurtosis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Kurtosis.scala index 8fa3aac9f1a5..c2bf2cb94116 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Kurtosis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Kurtosis.scala @@ -37,16 +37,17 @@ case class Kurtosis(child: Expression, override protected val momentOrder = 4 // NOTE: this is the formula for excess kurtosis, which is default for R and SciPy - override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Double = { + override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Any = { require(moments.length == momentOrder + 1, s"$prettyName requires ${momentOrder + 1} central moments, received: ${moments.length}") val m2 = moments(2) val m4 = moments(4) - if (n == 0.0 || m2 == 0.0) { + if (n == 0.0) { + null + } else if (m2 == 0.0) { Double.NaN - } - else { + } else { n * m4 / (m2 * m2) - 3.0 } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Skewness.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Skewness.scala index e1c01a5b8278..9411bcea2539 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Skewness.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Skewness.scala @@ -36,16 +36,17 @@ case class Skewness(child: Expression, override protected val momentOrder = 3 - override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Double = { + override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Any = { require(moments.length == momentOrder + 1, s"$prettyName requires ${momentOrder + 1} central moments, received: ${moments.length}") val m2 = moments(2) val m3 = moments(3) - if (n == 0.0 || m2 == 0.0) { + if (n == 0.0) { + null + } else if (m2 == 0.0) { Double.NaN - } - else { + } else { math.sqrt(n) * m3 / math.sqrt(m2 * m2 * m2) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala index 05dd5e3b2254..eec79a9033e3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala @@ -36,11 +36,17 @@ case class StddevSamp(child: Expression, override protected val momentOrder = 2 - override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Double = { + override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Any = { require(moments.length == momentOrder + 1, s"$prettyName requires ${momentOrder + 1} central moment, received: ${moments.length}") - if (n == 0.0 || n == 1.0) Double.NaN else math.sqrt(moments(2) / (n - 1.0)) + if (n == 0.0) { + null + } else if (n == 1.0) { + Double.NaN + } else { + math.sqrt(moments(2) / (n - 1.0)) + } } } @@ -62,10 +68,14 @@ case class StddevPop( override protected val momentOrder = 2 
- override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Double = { + override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Any = { require(moments.length == momentOrder + 1, s"$prettyName requires ${momentOrder + 1} central moment, received: ${moments.length}") - if (n == 0.0) Double.NaN else math.sqrt(moments(2) / n) + if (n == 0.0) { + null + } else { + math.sqrt(moments(2) / n) + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Variance.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Variance.scala index ede2da280596..cf3a74030539 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Variance.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Variance.scala @@ -36,11 +36,17 @@ case class VarianceSamp(child: Expression, override protected val momentOrder = 2 - override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Double = { + override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Any = { require(moments.length == momentOrder + 1, s"$prettyName requires ${momentOrder + 1} central moment, received: ${moments.length}") - if (n == 0.0 || n == 1.0) Double.NaN else moments(2) / (n - 1.0) + if (n == 0.0) { + null + } else if (n == 1.0) { + Double.NaN + } else { + moments(2) / (n - 1.0) + } } } @@ -62,10 +68,14 @@ case class VariancePop( override protected val momentOrder = 2 - override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Double = { + override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Any = { require(moments.length == momentOrder + 1, s"$prettyName requires ${momentOrder + 1} central moment, received: ${moments.length}") - if (n == 0.0) Double.NaN else moments(2) / n + if (n == 0.0) { + null + } else { + moments(2) / n + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 432e8d17623a..71adf2148a40 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -205,7 +205,7 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b") checkAnswer( emptyTableData.agg(stddev('a), stddev_pop('a), stddev_samp('a)), - Row(Double.NaN, Double.NaN, Double.NaN)) + Row(null, null, null)) } test("zero sum") { @@ -244,17 +244,23 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { test("zero moments") { val input = Seq((1, 2)).toDF("a", "b") checkAnswer( - input.agg(variance('a), var_samp('a), var_pop('a), skewness('a), kurtosis('a)), - Row(Double.NaN, Double.NaN, 0.0, Double.NaN, Double.NaN)) + input.agg(stddev('a), stddev_samp('a), stddev_pop('a), variance('a), + var_samp('a), var_pop('a), skewness('a), kurtosis('a)), + Row(Double.NaN, Double.NaN, 0.0, Double.NaN, Double.NaN, 0.0, + Double.NaN, Double.NaN)) checkAnswer( input.agg( + expr("stddev(a)"), + expr("stddev_samp(a)"), + expr("stddev_pop(a)"), expr("variance(a)"), expr("var_samp(a)"), expr("var_pop(a)"), expr("skewness(a)"), expr("kurtosis(a)")), - Row(Double.NaN, Double.NaN, 0.0, Double.NaN, Double.NaN)) + Row(Double.NaN, Double.NaN, 0.0, Double.NaN, Double.NaN, 0.0, + Double.NaN, Double.NaN)) } test("null moments") { 
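For reference, a minimal sketch of the behavior these updated tests pin down; the table contents are illustrative and a SQLContext named sqlContext is assumed to be in scope:

    import sqlContext.implicits._
    import org.apache.spark.sql.functions._

    // Empty input: the central-moment aggregates now return null instead of NaN.
    val empty = Seq.empty[(Int, Int)].toDF("a", "b")
    empty.agg(stddev('a), variance('a), skewness('a), kurtosis('a)).collect()
    // => Row(null, null, null, null)

    // Single-row input: sample statistics are still NaN, population statistics are 0.0.
    val single = Seq((1, 2)).toDF("a", "b")
    single.agg(var_samp('a), var_pop('a), stddev_samp('a), stddev_pop('a)).collect()
    // => Row(NaN, 0.0, NaN, 0.0)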
@@ -262,7 +268,7 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { checkAnswer( emptyTableData.agg(variance('a), var_samp('a), var_pop('a), skewness('a), kurtosis('a)), - Row(Double.NaN, Double.NaN, Double.NaN, Double.NaN, Double.NaN)) + Row(null, null, null, null, null)) checkAnswer( emptyTableData.agg( @@ -271,6 +277,6 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { expr("var_pop(a)"), expr("skewness(a)"), expr("kurtosis(a)")), - Row(Double.NaN, Double.NaN, Double.NaN, Double.NaN, Double.NaN)) + Row(null, null, null, null, null)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 5a7f24684d1b..6399b0165c4c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -459,7 +459,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { val emptyDescribeResult = Seq( Row("count", "0", "0"), Row("mean", null, null), - Row("stddev", "NaN", "NaN"), + Row("stddev", null, null), Row("min", null, null), Row("max", null, null)) From 045a4f045821dcf60442f0600c2df1b79bddb536 Mon Sep 17 00:00:00 2001 From: Wenjian Huang Date: Wed, 18 Nov 2015 13:06:25 -0800 Subject: [PATCH 1426/1824] [SPARK-6790][ML] Add spark.ml LinearRegression import/export This replaces [https://github.com/apache/spark/pull/9656] with updates. fayeshine should be the main author when this PR is committed. CC: mengxr fayeshine Author: Wenjian Huang Author: Joseph K. Bradley Closes #9814 from jkbradley/fayeshine-patch-6790. --- .../ml/regression/LinearRegression.scala | 77 ++++++++++++++++++- .../ml/regression/LinearRegressionSuite.scala | 34 +++++++- 2 files changed, 106 insertions(+), 5 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 913140e58198..ca55d5915e68 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -22,6 +22,7 @@ import scala.collection.mutable import breeze.linalg.{DenseVector => BDV} import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS, OWLQN => BreezeOWLQN} import breeze.stats.distributions.StudentsT +import org.apache.hadoop.fs.Path import org.apache.spark.{Logging, SparkException} import org.apache.spark.ml.feature.Instance @@ -30,7 +31,7 @@ import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.PredictorParams import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.param.shared._ -import org.apache.spark.ml.util.Identifiable +import org.apache.spark.ml.util._ import org.apache.spark.mllib.evaluation.RegressionMetrics import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.linalg.BLAS._ @@ -65,7 +66,7 @@ private[regression] trait LinearRegressionParams extends PredictorParams @Experimental class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String) extends Regressor[Vector, LinearRegression, LinearRegressionModel] - with LinearRegressionParams with Logging { + with LinearRegressionParams with Writable with Logging { @Since("1.4.0") def this() = this(Identifiable.randomUID("linReg")) @@ -341,6 +342,19 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String 
@Since("1.4.0") override def copy(extra: ParamMap): LinearRegression = defaultCopy(extra) + + @Since("1.6.0") + override def write: Writer = new DefaultParamsWriter(this) +} + +@Since("1.6.0") +object LinearRegression extends Readable[LinearRegression] { + + @Since("1.6.0") + override def read: Reader[LinearRegression] = new DefaultParamsReader[LinearRegression] + + @Since("1.6.0") + override def load(path: String): LinearRegression = read.load(path) } /** @@ -354,7 +368,7 @@ class LinearRegressionModel private[ml] ( val coefficients: Vector, val intercept: Double) extends RegressionModel[Vector, LinearRegressionModel] - with LinearRegressionParams { + with LinearRegressionParams with Writable { private var trainingSummary: Option[LinearRegressionTrainingSummary] = None @@ -422,6 +436,63 @@ class LinearRegressionModel private[ml] ( if (trainingSummary.isDefined) newModel.setSummary(trainingSummary.get) newModel.setParent(parent) } + + /** + * Returns a [[Writer]] instance for this ML instance. + * + * For [[LinearRegressionModel]], this does NOT currently save the training [[summary]]. + * An option to save [[summary]] may be added in the future. + * + * This also does not save the [[parent]] currently. + */ + @Since("1.6.0") + override def write: Writer = new LinearRegressionModel.LinearRegressionModelWriter(this) +} + +@Since("1.6.0") +object LinearRegressionModel extends Readable[LinearRegressionModel] { + + @Since("1.6.0") + override def read: Reader[LinearRegressionModel] = new LinearRegressionModelReader + + @Since("1.6.0") + override def load(path: String): LinearRegressionModel = read.load(path) + + /** [[Writer]] instance for [[LinearRegressionModel]] */ + private[LinearRegressionModel] class LinearRegressionModelWriter(instance: LinearRegressionModel) + extends Writer with Logging { + + private case class Data(intercept: Double, coefficients: Vector) + + override protected def saveImpl(path: String): Unit = { + // Save metadata and Params + DefaultParamsWriter.saveMetadata(instance, path, sc) + // Save model data: intercept, coefficients + val data = Data(instance.intercept, instance.coefficients) + val dataPath = new Path(path, "data").toString + sqlContext.createDataFrame(Seq(data)).write.format("parquet").save(dataPath) + } + } + + private class LinearRegressionModelReader extends Reader[LinearRegressionModel] { + + /** Checked against metadata when loading model */ + private val className = "org.apache.spark.ml.regression.LinearRegressionModel" + + override def load(path: String): LinearRegressionModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + + val dataPath = new Path(path, "data").toString + val data = sqlContext.read.format("parquet").load(dataPath) + .select("intercept", "coefficients").head() + val intercept = data.getDouble(0) + val coefficients = data.getAs[Vector](1) + val model = new LinearRegressionModel(metadata.uid, coefficients, intercept) + + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } } /** diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index a1d86fe8feda..2bdc0e184d73 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -22,14 +22,15 @@ import scala.util.Random import org.apache.spark.SparkFunSuite import org.apache.spark.ml.feature.Instance import 
org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.MLTestingUtils +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.linalg.{Vector, DenseVector, Vectors} import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext} import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.{DataFrame, Row} -class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { +class LinearRegressionSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { private val seed: Int = 42 @transient var datasetWithDenseFeature: DataFrame = _ @@ -854,4 +855,33 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { model.summary.tValues.zip(tValsR).foreach{ x => assert(x._1 ~== x._2 absTol 1E-3) } model.summary.pValues.zip(pValsR).foreach{ x => assert(x._1 ~== x._2 absTol 1E-3) } } + + test("read/write") { + def checkModelData(model: LinearRegressionModel, model2: LinearRegressionModel): Unit = { + assert(model.intercept === model2.intercept) + assert(model.coefficients === model2.coefficients) + } + val lr = new LinearRegression() + testEstimatorAndModelReadWrite(lr, datasetWithWeight, LinearRegressionSuite.allParamSettings, + checkModelData) + } +} + +object LinearRegressionSuite { + + /** + * Mapping from all Params to valid settings which differ from the defaults. + * This is useful for tests which need to exercise all Params, such as save/load. + * This excludes input columns to simplify some tests. + */ + val allParamSettings: Map[String, Any] = Map( + "predictionCol" -> "myPrediction", + "regParam" -> 0.01, + "elasticNetParam" -> 0.1, + "maxIter" -> 2, // intentionally small + "fitIntercept" -> true, + "tol" -> 0.8, + "standardization" -> false, + "solver" -> "l-bfgs" + ) } From 2acdf10b1f3bb1242dba64efa798c672fde9f0d2 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Wed, 18 Nov 2015 13:16:31 -0800 Subject: [PATCH 1427/1824] [SPARK-6789][ML] Add Readable, Writable support for spark.ml ALS, ALSModel Also modifies DefaultParamsWriter.saveMetadata to take optional extra metadata. CC: mengxr yanboliang Author: Joseph K. Bradley Closes #9786 from jkbradley/als-io. 
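For reference, a minimal usage sketch of the API this adds, following the same Writable/Readable pattern as the LinearRegression patch above. The toy ratings, parameter values, and save path are illustrative (not from the patch), and a SQLContext named sqlContext is assumed to be in scope:

    import org.apache.spark.ml.recommendation.{ALS, ALSModel}

    // Tiny ratings DataFrame using ALS's default column names.
    val ratings = sqlContext.createDataFrame(Seq(
      (0, 0, 4.0f), (0, 1, 2.0f), (1, 1, 3.0f)
    )).toDF("user", "item", "rating")

    val model = new ALS().setRank(1).setMaxIter(1).fit(ratings)
    model.write.save("/tmp/als-model")     // metadata JSON plus user/item factors as Parquet
    val restored = ALSModel.load("/tmp/als-model")
    assert(restored.rank == model.rank)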
--- .../apache/spark/ml/recommendation/ALS.scala | 75 ++++++++++++++++-- .../org/apache/spark/ml/util/ReadWrite.scala | 14 +++- .../spark/ml/recommendation/ALSSuite.scala | 78 ++++++++++++++++--- 3 files changed, 150 insertions(+), 17 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index 535f266b9a94..d92514d2e239 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -27,13 +27,16 @@ import scala.util.hashing.byteswap64 import com.github.fommil.netlib.BLAS.{getInstance => blas} import org.apache.hadoop.fs.{FileSystem, Path} +import org.json4s.{DefaultFormats, JValue} +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods._ import org.apache.spark.{Logging, Partitioner} -import org.apache.spark.annotation.{DeveloperApi, Experimental} +import org.apache.spark.annotation.{Since, DeveloperApi, Experimental} import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ -import org.apache.spark.ml.util.{Identifiable, SchemaUtils} +import org.apache.spark.ml.util._ import org.apache.spark.mllib.linalg.CholeskyDecomposition import org.apache.spark.mllib.optimization.NNLS import org.apache.spark.rdd.RDD @@ -182,7 +185,7 @@ class ALSModel private[ml] ( val rank: Int, @transient val userFactors: DataFrame, @transient val itemFactors: DataFrame) - extends Model[ALSModel] with ALSModelParams { + extends Model[ALSModel] with ALSModelParams with Writable { /** @group setParam */ def setUserCol(value: String): this.type = set(userCol, value) @@ -220,8 +223,60 @@ class ALSModel private[ml] ( val copied = new ALSModel(uid, rank, userFactors, itemFactors) copyValues(copied, extra).setParent(parent) } + + @Since("1.6.0") + override def write: Writer = new ALSModel.ALSModelWriter(this) } +@Since("1.6.0") +object ALSModel extends Readable[ALSModel] { + + @Since("1.6.0") + override def read: Reader[ALSModel] = new ALSModelReader + + @Since("1.6.0") + override def load(path: String): ALSModel = read.load(path) + + private[recommendation] class ALSModelWriter(instance: ALSModel) extends Writer { + + override protected def saveImpl(path: String): Unit = { + val extraMetadata = render("rank" -> instance.rank) + DefaultParamsWriter.saveMetadata(instance, path, sc, Some(extraMetadata)) + val userPath = new Path(path, "userFactors").toString + instance.userFactors.write.format("parquet").save(userPath) + val itemPath = new Path(path, "itemFactors").toString + instance.itemFactors.write.format("parquet").save(itemPath) + } + } + + private[recommendation] class ALSModelReader extends Reader[ALSModel] { + + /** Checked against metadata when loading model */ + private val className = "org.apache.spark.ml.recommendation.ALSModel" + + override def load(path: String): ALSModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + implicit val format = DefaultFormats + val rank: Int = metadata.extraMetadata match { + case Some(m: JValue) => + (m \ "rank").extract[Int] + case None => + throw new RuntimeException(s"ALSModel loader could not read rank from JSON metadata:" + + s" ${metadata.metadataStr}") + } + + val userPath = new Path(path, "userFactors").toString + val userFactors = sqlContext.read.format("parquet").load(userPath) + val itemPath = new Path(path, "itemFactors").toString + val itemFactors = 
sqlContext.read.format("parquet").load(itemPath) + + val model = new ALSModel(metadata.uid, rank, userFactors, itemFactors) + + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } +} /** * :: Experimental :: @@ -254,7 +309,7 @@ class ALSModel private[ml] ( * preferences rather than explicit ratings given to items. */ @Experimental -class ALS(override val uid: String) extends Estimator[ALSModel] with ALSParams { +class ALS(override val uid: String) extends Estimator[ALSModel] with ALSParams with Writable { import org.apache.spark.ml.recommendation.ALS.Rating @@ -336,8 +391,12 @@ class ALS(override val uid: String) extends Estimator[ALSModel] with ALSParams { } override def copy(extra: ParamMap): ALS = defaultCopy(extra) + + @Since("1.6.0") + override def write: Writer = new DefaultParamsWriter(this) } + /** * :: DeveloperApi :: * An implementation of ALS that supports generic ID types, specialized for Int and Long. This is @@ -347,7 +406,7 @@ class ALS(override val uid: String) extends Estimator[ALSModel] with ALSParams { * than 2 billion. */ @DeveloperApi -object ALS extends Logging { +object ALS extends Readable[ALS] with Logging { /** * :: DeveloperApi :: @@ -356,6 +415,12 @@ object ALS extends Logging { @DeveloperApi case class Rating[@specialized(Int, Long) ID](user: ID, item: ID, rating: Float) + @Since("1.6.0") + override def read: Reader[ALS] = new DefaultParamsReader[ALS] + + @Since("1.6.0") + override def load(path: String): ALS = read.load(path) + /** Trait for least squares solvers applied to the normal equation. */ private[recommendation] trait LeastSquaresNESolver extends Serializable { /** Solves a least squares problem with regularization (possibly with other constraints). */ diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala index dddb72af5ba7..d8ce907af532 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala @@ -194,7 +194,11 @@ private[ml] object DefaultParamsWriter { * - uid * - paramMap: These must be encodable using [[org.apache.spark.ml.param.Param.jsonEncode()]]. */ - def saveMetadata(instance: Params, path: String, sc: SparkContext): Unit = { + def saveMetadata( + instance: Params, + path: String, + sc: SparkContext, + extraMetadata: Option[JValue] = None): Unit = { val uid = instance.uid val cls = instance.getClass.getName val params = instance.extractParamMap().toSeq.asInstanceOf[Seq[ParamPair[Any]]] @@ -205,7 +209,8 @@ private[ml] object DefaultParamsWriter { ("timestamp" -> System.currentTimeMillis()) ~ ("sparkVersion" -> sc.version) ~ ("uid" -> uid) ~ - ("paramMap" -> jsonParams) + ("paramMap" -> jsonParams) ~ + ("extraMetadata" -> extraMetadata) val metadataPath = new Path(path, "metadata").toString val metadataJson = compact(render(metadata)) sc.parallelize(Seq(metadataJson), 1).saveAsTextFile(metadataPath) @@ -236,6 +241,7 @@ private[ml] object DefaultParamsReader { /** * All info from metadata file. 
* @param params paramMap, as a [[JValue]] + * @param extraMetadata Extra metadata saved by [[DefaultParamsWriter.saveMetadata()]] * @param metadataStr Full metadata file String (for debugging) */ case class Metadata( @@ -244,6 +250,7 @@ private[ml] object DefaultParamsReader { timestamp: Long, sparkVersion: String, params: JValue, + extraMetadata: Option[JValue], metadataStr: String) /** @@ -262,12 +269,13 @@ private[ml] object DefaultParamsReader { val timestamp = (metadata \ "timestamp").extract[Long] val sparkVersion = (metadata \ "sparkVersion").extract[String] val params = metadata \ "paramMap" + val extraMetadata = (metadata \ "extraMetadata").extract[Option[JValue]] if (expectedClassName.nonEmpty) { require(className == expectedClassName, s"Error loading metadata: Expected class name" + s" $expectedClassName but found class name $className") } - Metadata(className, uid, timestamp, sparkVersion, params, metadataStr) + Metadata(className, uid, timestamp, sparkVersion, params, extraMetadata, metadataStr) } /** diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala index eadc80e0e62b..2c3fb84160dc 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala @@ -17,7 +17,6 @@ package org.apache.spark.ml.recommendation -import java.io.File import java.util.Random import scala.collection.mutable @@ -26,28 +25,26 @@ import scala.language.existentials import com.github.fommil.netlib.BLAS.{getInstance => blas} +import org.apache.spark.util.Utils import org.apache.spark.{Logging, SparkException, SparkFunSuite} import org.apache.spark.ml.recommendation.ALS._ -import org.apache.spark.ml.util.MLTestingUtils +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{Row, SQLContext} -import org.apache.spark.util.Utils +import org.apache.spark.sql.{DataFrame, Row} -class ALSSuite extends SparkFunSuite with MLlibTestSparkContext with Logging { - private var tempDir: File = _ +class ALSSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest with Logging { override def beforeAll(): Unit = { super.beforeAll() - tempDir = Utils.createTempDir() sc.setCheckpointDir(tempDir.getAbsolutePath) } override def afterAll(): Unit = { - Utils.deleteRecursively(tempDir) super.afterAll() } @@ -186,7 +183,7 @@ class ALSSuite extends SparkFunSuite with MLlibTestSparkContext with Logging { assert(compressed.dstPtrs.toSeq === Seq(0, 2, 3, 4, 5)) var decompressed = ArrayBuffer.empty[(Int, Int, Int, Float)] var i = 0 - while (i < compressed.srcIds.size) { + while (i < compressed.srcIds.length) { var j = compressed.dstPtrs(i) while (j < compressed.dstPtrs(i + 1)) { val dstEncodedIndex = compressed.dstEncodedIndices(j) @@ -483,4 +480,67 @@ class ALSSuite extends SparkFunSuite with MLlibTestSparkContext with Logging { ALS.train(ratings, rank = 1, maxIter = 50, numUserBlocks = 2, numItemBlocks = 2, implicitPrefs = true, seed = 0) } + + test("read/write") { + import ALSSuite._ + val (ratings, _) = genExplicitTestData(numUsers = 4, numItems = 4, rank = 1) + val als = new ALS() + allEstimatorParamSettings.foreach { case (p, v) => + als.set(als.getParam(p), v) + } + val sqlContext = 
this.sqlContext + import sqlContext.implicits._ + val model = als.fit(ratings.toDF()) + + // Test Estimator save/load + val als2 = testDefaultReadWrite(als) + allEstimatorParamSettings.foreach { case (p, v) => + val param = als.getParam(p) + assert(als.get(param).get === als2.get(param).get) + } + + // Test Model save/load + val model2 = testDefaultReadWrite(model) + allModelParamSettings.foreach { case (p, v) => + val param = model.getParam(p) + assert(model.get(param).get === model2.get(param).get) + } + assert(model.rank === model2.rank) + def getFactors(df: DataFrame): Set[(Int, Array[Float])] = { + df.select("id", "features").collect().map { case r => + (r.getInt(0), r.getAs[Array[Float]](1)) + }.toSet + } + assert(getFactors(model.userFactors) === getFactors(model2.userFactors)) + assert(getFactors(model.itemFactors) === getFactors(model2.itemFactors)) + } +} + +object ALSSuite { + + /** + * Mapping from all Params to valid settings which differ from the defaults. + * This is useful for tests which need to exercise all Params, such as save/load. + * This excludes input columns to simplify some tests. + */ + val allModelParamSettings: Map[String, Any] = Map( + "predictionCol" -> "myPredictionCol" + ) + + /** + * Mapping from all Params to valid settings which differ from the defaults. + * This is useful for tests which need to exercise all Params, such as save/load. + * This excludes input columns to simplify some tests. + */ + val allEstimatorParamSettings: Map[String, Any] = allModelParamSettings ++ Map( + "maxIter" -> 1, + "rank" -> 1, + "regParam" -> 0.01, + "numUserBlocks" -> 2, + "numItemBlocks" -> 2, + "implicitPrefs" -> true, + "alpha" -> 0.9, + "nonnegative" -> true, + "checkpointInterval" -> 20 + ) } From e391abdf2cb6098a35347bd123b815ee9ac5b689 Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Wed, 18 Nov 2015 13:25:15 -0800 Subject: [PATCH 1428/1824] [SPARK-11813][MLLIB] Avoid serialization of vocab in Word2Vec jira: https://issues.apache.org/jira/browse/SPARK-11813 I found the problem during training a large corpus. Avoid serialization of vocab in Word2Vec has 2 benefits. 1. Performance improvement for less serialization. 2. Increase the capacity of Word2Vec a lot. Currently in the fit of word2vec, the closure mainly includes serialization of Word2Vec and 2 global table. the main part of Word2vec is the vocab of size: vocab * 40 * 2 * 4 = 320 vocab 2 global table: vocab * vectorSize * 8. If vectorSize = 20, that's 160 vocab. Their sum cannot exceed Int.max due to the restriction of ByteArrayOutputStream. In any case, avoiding serialization of vocab helps decrease the size of the closure serialization, especially when vectorSize is small, thus to allow larger vocabulary. Actually there's another possible fix, make local copy of fields to avoid including Word2Vec in the closure. Let me know if that's preferred. Author: Yuhao Yang Closes #9803 from hhbyyh/w2vVocab. 
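To make the sizing argument concrete, a rough back-of-the-envelope sketch that takes the per-word figures above at face value; the vocabulary size is hypothetical:

    // Approximate task-closure size before this change:
    // ~320 bytes per word for the serialized vocab array, plus the two global
    // float tables at vectorSize * 8 bytes per word combined.
    val vocabSize  = 5000000L                      // hypothetical 5M-word vocabulary
    val vectorSize = 20
    val vocabBytes = 320L * vocabSize              // ~1.6e9 bytes
    val tableBytes = 8L * vectorSize * vocabSize   // ~0.8e9 bytes
    val total      = vocabBytes + tableBytes       // ~2.4e9 bytes > Int.MaxValue (~2.15e9)
    // Serializing such a closure overflows ByteArrayOutputStream; with vocab and
    // vocabHash marked @transient, the vocab term drops out of the closure entirely.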
--- .../main/scala/org/apache/spark/mllib/feature/Word2Vec.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index f3e4d346e358..7ab0d89d23a3 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -145,8 +145,8 @@ class Word2Vec extends Serializable with Logging { private var trainWordsCount = 0 private var vocabSize = 0 - private var vocab: Array[VocabWord] = null - private var vocabHash = mutable.HashMap.empty[String, Int] + @transient private var vocab: Array[VocabWord] = null + @transient private var vocabHash = mutable.HashMap.empty[String, Int] private def learnVocab(words: RDD[String]): Unit = { vocab = words.map(w => (w, 1)) From e222d758499ad2609046cc1a2cc8afb45c5bccbb Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Wed, 18 Nov 2015 13:30:29 -0800 Subject: [PATCH 1429/1824] [SPARK-11684][R][ML][DOC] Update SparkR glm API doc, user guide and example codes This PR includes: * Update SparkR:::glm, SparkR:::summary API docs. * Update SparkR machine learning user guide and example codes to show: * supporting feature interaction in R formula. * summary for gaussian GLM model. * coefficients for binomial GLM model. mengxr Author: Yanbo Liang Closes #9727 from yanboliang/spark-11684. --- R/pkg/R/mllib.R | 18 +++++-- docs/sparkr.md | 50 ++++++++++++++++--- .../ml/regression/LinearRegression.scala | 3 ++ 3 files changed, 60 insertions(+), 11 deletions(-) diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R index f23e1c7f1fce..8d3b4388ae57 100644 --- a/R/pkg/R/mllib.R +++ b/R/pkg/R/mllib.R @@ -32,6 +32,12 @@ setClass("PipelineModel", representation(model = "jobj")) #' @param family Error distribution. "gaussian" -> linear regression, "binomial" -> logistic reg. #' @param lambda Regularization parameter #' @param alpha Elastic-net mixing parameter (see glmnet's documentation for details) +#' @param standardize Whether to standardize features before training +#' @param solver The solver algorithm used for optimization, this can be "l-bfgs", "normal" and +#' "auto". "l-bfgs" denotes Limited-memory BFGS which is a limited-memory +#' quasi-Newton optimization method. "normal" denotes using Normal Equation as an +#' analytical solution to the linear regression problem. The default value is "auto" +#' which means that the solver algorithm is selected automatically. #' @return a fitted MLlib model #' @rdname glm #' @export @@ -79,9 +85,15 @@ setMethod("predict", signature(object = "PipelineModel"), #' #' Returns the summary of a model produced by glm(), similarly to R's summary(). #' -#' @param x A fitted MLlib model -#' @return a list with a 'coefficient' component, which is the matrix of coefficients. See -#' summary.glm for more information. +#' @param object A fitted MLlib model +#' @return a list with 'devianceResiduals' and 'coefficients' components for gaussian family +#' or a list with 'coefficients' component for binomial family. \cr +#' For gaussian family: the 'devianceResiduals' gives the min/max deviance residuals +#' of the estimation, the 'coefficients' gives the estimated coefficients and their +#' estimated standard errors, t values and p-values. (It only available when model +#' fitted by normal solver.) \cr +#' For binomial family: the 'coefficients' gives the estimated coefficients. +#' See summary.glm for more information. 
\cr #' @rdname summary #' @export #' @examples diff --git a/docs/sparkr.md b/docs/sparkr.md index 437bd4756c27..a744b76be746 100644 --- a/docs/sparkr.md +++ b/docs/sparkr.md @@ -286,24 +286,37 @@ head(teenagers) # Machine Learning -SparkR allows the fitting of generalized linear models over DataFrames using the [glm()](api/R/glm.html) function. Under the hood, SparkR uses MLlib to train a model of the specified family. Currently the gaussian and binomial families are supported. We support a subset of the available R formula operators for model fitting, including '~', '.', '+', and '-'. The example below shows the use of building a gaussian GLM model using SparkR. +SparkR allows the fitting of generalized linear models over DataFrames using the [glm()](api/R/glm.html) function. Under the hood, SparkR uses MLlib to train a model of the specified family. Currently the gaussian and binomial families are supported. We support a subset of the available R formula operators for model fitting, including '~', '.', ':', '+', and '-'. + +The [summary()](api/R/summary.html) function gives the summary of a model produced by [glm()](api/R/glm.html). + +* For gaussian GLM model, it returns a list with 'devianceResiduals' and 'coefficients' components. The 'devianceResiduals' gives the min/max deviance residuals of the estimation; the 'coefficients' gives the estimated coefficients and their estimated standard errors, t values and p-values. (It only available when model fitted by normal solver.) +* For binomial GLM model, it returns a list with 'coefficients' component which gives the estimated coefficients. + +The examples below show the use of building gaussian GLM model and binomial GLM model using SparkR. + +## Gaussian GLM model
    {% highlight r %} # Create the DataFrame df <- createDataFrame(sqlContext, iris) -# Fit a linear model over the dataset. +# Fit a gaussian GLM model over the dataset. model <- glm(Sepal_Length ~ Sepal_Width + Species, data = df, family = "gaussian") -# Model coefficients are returned in a similar format to R's native glm(). +# Model summary are returned in a similar format to R's native glm(). summary(model) +##$devianceResiduals +## Min Max +## -1.307112 1.412532 +## ##$coefficients -## Estimate -##(Intercept) 2.2513930 -##Sepal_Width 0.8035609 -##Species_versicolor 1.4587432 -##Species_virginica 1.9468169 +## Estimate Std. Error t value Pr(>|t|) +##(Intercept) 2.251393 0.3697543 6.08889 9.568102e-09 +##Sepal_Width 0.8035609 0.106339 7.556598 4.187317e-12 +##Species_versicolor 1.458743 0.1121079 13.01195 0 +##Species_virginica 1.946817 0.100015 19.46525 0 # Make predictions based on the model. predictions <- predict(model, newData = df) @@ -317,3 +330,24 @@ head(select(predictions, "Sepal_Length", "prediction")) ##6 5.4 5.385281 {% endhighlight %}
    + +## Binomial GLM model + +
    +{% highlight r %} +# Create the DataFrame +df <- createDataFrame(sqlContext, iris) +training <- filter(df, df$Species != "setosa") + +# Fit a binomial GLM model over the dataset. +model <- glm(Species ~ Sepal_Length + Sepal_Width, data = training, family = "binomial") + +# Model coefficients are returned in a similar format to R's native glm(). +summary(model) +##$coefficients +## Estimate +##(Intercept) -13.046005 +##Sepal_Length 1.902373 +##Sepal_Width 0.404655 +{% endhighlight %} +
    diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index ca55d5915e68..f7c44f0a51b8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -145,6 +145,9 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String /** * Set the solver algorithm used for optimization. * In case of linear regression, this can be "l-bfgs", "normal" and "auto". + * "l-bfgs" denotes Limited-memory BFGS which is a limited-memory quasi-Newton + * optimization method. "normal" denotes using Normal Equation as an analytical + * solution to the linear regression problem. * The default value is "auto" which means that the solver algorithm is * selected automatically. * @group setParam From 603a721c21488e17c15c45ce1de893e6b3d02274 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Wed, 18 Nov 2015 13:32:06 -0800 Subject: [PATCH 1430/1824] [SPARK-11820][ML][PYSPARK] PySpark LiR & LoR should support weightCol [SPARK-7685](https://issues.apache.org/jira/browse/SPARK-7685) and [SPARK-9642](https://issues.apache.org/jira/browse/SPARK-9642) have already supported setting weight column for ```LogisticRegression``` and ```LinearRegression```. It's a very important feature, PySpark should also support. mengxr Author: Yanbo Liang Closes #9811 from yanboliang/spark-11820. --- python/pyspark/ml/classification.py | 17 +++++++++-------- python/pyspark/ml/regression.py | 16 ++++++++-------- 2 files changed, 17 insertions(+), 16 deletions(-) diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 603f2c7f798d..4a2982e2047f 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -36,7 +36,8 @@ @inherit_doc class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter, HasRegParam, HasTol, HasProbabilityCol, HasRawPredictionCol, - HasElasticNetParam, HasFitIntercept, HasStandardization, HasThresholds): + HasElasticNetParam, HasFitIntercept, HasStandardization, HasThresholds, + HasWeightCol): """ Logistic regression. Currently, this class only supports binary classification. @@ -44,9 +45,9 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti >>> from pyspark.sql import Row >>> from pyspark.mllib.linalg import Vectors >>> df = sc.parallelize([ - ... Row(label=1.0, features=Vectors.dense(1.0)), - ... Row(label=0.0, features=Vectors.sparse(1, [], []))]).toDF() - >>> lr = LogisticRegression(maxIter=5, regParam=0.01) + ... Row(label=1.0, weight=2.0, features=Vectors.dense(1.0)), + ... 
Row(label=0.0, weight=2.0, features=Vectors.sparse(1, [], []))]).toDF() + >>> lr = LogisticRegression(maxIter=5, regParam=0.01, weightCol="weight") >>> model = lr.fit(df) >>> model.weights DenseVector([5.5...]) @@ -80,12 +81,12 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, threshold=0.5, thresholds=None, probabilityCol="probability", - rawPredictionCol="rawPrediction", standardization=True): + rawPredictionCol="rawPrediction", standardization=True, weightCol=None): """ __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \ threshold=0.5, thresholds=None, probabilityCol="probability", \ - rawPredictionCol="rawPrediction", standardization=True) + rawPredictionCol="rawPrediction", standardization=True, weightCol=None) If the threshold and thresholds Params are both set, they must be equivalent. """ super(LogisticRegression, self).__init__() @@ -105,12 +106,12 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, threshold=0.5, thresholds=None, probabilityCol="probability", - rawPredictionCol="rawPrediction", standardization=True): + rawPredictionCol="rawPrediction", standardization=True, weightCol=None): """ setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \ threshold=0.5, thresholds=None, probabilityCol="probability", \ - rawPredictionCol="rawPrediction", standardization=True) + rawPredictionCol="rawPrediction", standardization=True, weightCol=None) Sets params for logistic regression. If the threshold and thresholds Params are both set, they must be equivalent. """ diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index 7648bf13266b..944e648ec880 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -35,7 +35,7 @@ @inherit_doc class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter, HasRegParam, HasTol, HasElasticNetParam, HasFitIntercept, - HasStandardization, HasSolver): + HasStandardization, HasSolver, HasWeightCol): """ Linear regression. @@ -50,9 +50,9 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction >>> from pyspark.mllib.linalg import Vectors >>> df = sqlContext.createDataFrame([ - ... (1.0, Vectors.dense(1.0)), - ... (0.0, Vectors.sparse(1, [], []))], ["label", "features"]) - >>> lr = LinearRegression(maxIter=5, regParam=0.0, solver="normal") + ... (1.0, 2.0, Vectors.dense(1.0)), + ... 
(0.0, 2.0, Vectors.sparse(1, [], []))], ["label", "weight", "features"]) + >>> lr = LinearRegression(maxIter=5, regParam=0.0, solver="normal", weightCol="weight") >>> model = lr.fit(df) >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"]) >>> abs(model.transform(test0).head().prediction - (-1.0)) < 0.001 @@ -75,11 +75,11 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction @keyword_only def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, - standardization=True, solver="auto"): + standardization=True, solver="auto", weightCol=None): """ __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \ - standardization=True, solver="auto") + standardization=True, solver="auto", weightCol=None) """ super(LinearRegression, self).__init__() self._java_obj = self._new_java_obj( @@ -92,11 +92,11 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred @since("1.4.0") def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, - standardization=True, solver="auto"): + standardization=True, solver="auto", weightCol=None): """ setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \ - standardization=True, solver="auto") + standardization=True, solver="auto", weightCol=None) Sets params for linear regression. """ kwargs = self.setParams._input_kwargs From 54db79702513e11335c33bcf3a03c59e965e6f16 Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Wed, 18 Nov 2015 14:05:18 -0800 Subject: [PATCH 1431/1824] [SPARK-11544][SQL] sqlContext doesn't use PathFilter Apply the user supplied pathfilter while retrieving the files from fs. Author: Dilip Biswal Closes #9652 from dilipbiswal/spark-11544. 
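For context, a sketch of the user-facing setup this fixes, mirroring the regression test added below; the filter class name and input path are illustrative:

    import org.apache.hadoop.fs.{Path, PathFilter}

    // A user-supplied filter that skips every file under a directory named "p=2".
    class SkipP2Filter extends PathFilter {
      override def accept(path: Path): Boolean = path.getParent.getName != "p=2"
    }

    // Register it on the Hadoop configuration used by the SQLContext ...
    sqlContext.sparkContext.hadoopConfiguration.setClass(
      "mapreduce.input.pathFilter.class", classOf[SkipP2Filter], classOf[PathFilter])

    // ... and file listing in HadoopFsRelation-based sources (json, parquet, ...) now
    // honors it, so sqlContext.read.json("/data/with/p=1/and/p=2") skips files under p=2.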
--- .../apache/spark/sql/sources/interfaces.scala | 25 ++++++++++--- .../datasources/json/JsonSuite.scala | 36 +++++++++++++++++-- 2 files changed, 54 insertions(+), 7 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index b3d3bdf50df6..f9465157c936 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -21,7 +21,8 @@ import scala.collection.mutable import scala.util.Try import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FileStatus, FileSystem, Path} +import org.apache.hadoop.fs.{PathFilter, FileStatus, FileSystem, Path} +import org.apache.hadoop.mapred.{JobConf, FileInputFormat} import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext} import org.apache.spark.{Logging, SparkContext} @@ -447,9 +448,15 @@ abstract class HadoopFsRelation private[sql]( val hdfsPath = new Path(path) val fs = hdfsPath.getFileSystem(hadoopConf) val qualified = hdfsPath.makeQualified(fs.getUri, fs.getWorkingDirectory) - logInfo(s"Listing $qualified on driver") - Try(fs.listStatus(qualified)).getOrElse(Array.empty) + // Dummy jobconf to get to the pathFilter defined in configuration + val jobConf = new JobConf(hadoopConf, this.getClass()) + val pathFilter = FileInputFormat.getInputPathFilter(jobConf) + if (pathFilter != null) { + Try(fs.listStatus(qualified, pathFilter)).getOrElse(Array.empty) + } else { + Try(fs.listStatus(qualified)).getOrElse(Array.empty) + } }.filterNot { status => val name = status.getPath.getName name.toLowerCase == "_temporary" || name.startsWith(".") @@ -847,8 +854,16 @@ private[sql] object HadoopFsRelation extends Logging { if (name == "_temporary" || name.startsWith(".")) { Array.empty } else { - val (dirs, files) = fs.listStatus(status.getPath).partition(_.isDir) - files ++ dirs.flatMap(dir => listLeafFiles(fs, dir)) + // Dummy jobconf to get to the pathFilter defined in configuration + val jobConf = new JobConf(fs.getConf, this.getClass()) + val pathFilter = FileInputFormat.getInputPathFilter(jobConf) + if (pathFilter != null) { + val (dirs, files) = fs.listStatus(status.getPath, pathFilter).partition(_.isDir) + files ++ dirs.flatMap(dir => listLeafFiles(fs, dir)) + } else { + val (dirs, files) = fs.listStatus(status.getPath).partition(_.isDir) + files ++ dirs.flatMap(dir => listLeafFiles(fs, dir)) + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 6042b1178aff..f09b61e83815 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -19,19 +19,27 @@ package org.apache.spark.sql.execution.datasources.json import java.io.{File, StringWriter} import java.sql.{Date, Timestamp} +import scala.collection.JavaConverters._ import com.fasterxml.jackson.core.JsonFactory -import org.apache.spark.rdd.RDD +import org.apache.commons.io.FileUtils +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{Path, PathFilter} import org.scalactic.Tolerance._ +import org.apache.spark.rdd.RDD import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.execution.datasources.{ResolvedDataSource, 
LogicalRelation} +import org.apache.spark.sql.execution.datasources.{LogicalRelation, ResolvedDataSource} import org.apache.spark.sql.execution.datasources.json.InferSchema.compatibleType import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.util.Utils +class TestFileFilter extends PathFilter { + override def accept(path: Path): Boolean = path.getParent.getName != "p=2" +} + class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { import testImplicits._ @@ -1390,4 +1398,28 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { ) } } + + test("SPARK-11544 test pathfilter") { + withTempPath { dir => + val path = dir.getCanonicalPath + + val df = sqlContext.range(2) + df.write.json(path + "/p=1") + df.write.json(path + "/p=2") + assert(sqlContext.read.json(path).count() === 4) + + val clonedConf = new Configuration(hadoopConfiguration) + try { + hadoopConfiguration.setClass( + "mapreduce.input.pathFilter.class", + classOf[TestFileFilter], + classOf[PathFilter]) + assert(sqlContext.read.json(path).count() === 2) + } finally { + // Hadoop 1 doesn't have `Configuration.unset` + hadoopConfiguration.clear() + clonedConf.asScala.foreach(entry => hadoopConfiguration.set(entry.getKey, entry.getValue)) + } + } + } } From 5df08949f5d9e5b4b0e9c2db50c1b4eb93383de3 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 18 Nov 2015 15:42:07 -0800 Subject: [PATCH 1432/1824] [SPARK-11810][SQL] Java-based encoder for opaque types in Datasets. This patch refactors the existing Kryo encoder expressions and adds support for Java serialization. Author: Reynold Xin Closes #9802 from rxin/SPARK-11810. --- .../scala/org/apache/spark/sql/Encoder.scala | 41 +++++++++--- .../sql/catalyst/expressions/objects.scala | 67 ++++++++++++------- .../catalyst/encoders/FlatEncoderSuite.scala | 27 ++++++-- .../org/apache/spark/sql/DatasetSuite.scala | 36 +++++++++- 4 files changed, 130 insertions(+), 41 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala index 79c2255641c0..1ed5111440c8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql import scala.reflect.{ClassTag, classTag} import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, encoderFor} -import org.apache.spark.sql.catalyst.expressions.{DeserializeWithKryo, BoundReference, SerializeWithKryo} +import org.apache.spark.sql.catalyst.expressions.{DecodeUsingSerializer, BoundReference, EncodeUsingSerializer} import org.apache.spark.sql.types._ /** @@ -43,28 +43,49 @@ trait Encoder[T] extends Serializable { */ object Encoders { - /** - * (Scala-specific) Creates an encoder that serializes objects of type T using Kryo. - * This encoder maps T into a single byte array (binary) field. - */ - def kryo[T: ClassTag]: Encoder[T] = { - val ser = SerializeWithKryo(BoundReference(0, ObjectType(classOf[AnyRef]), nullable = true)) - val deser = DeserializeWithKryo[T](BoundReference(0, BinaryType, nullable = true), classTag[T]) + /** A way to construct encoders using generic serializers. 
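For reference, a short usage sketch of the new encoder. The Point class is illustrative; unlike the Kryo variant, the type must be java.io.Serializable. A SQLContext named sqlContext is assumed to be in scope:

    import org.apache.spark.sql.Encoders
    import sqlContext.implicits._

    class Point(val x: Int, val y: Int) extends Serializable

    implicit val pointEncoder = Encoders.javaSerialization[Point]
    val ds = Seq(new Point(1, 2), new Point(3, 4)).toDS()   // stored as a single binary column
    assert(ds.count() == 2)
    // Encoders.kryo[Point] remains available and is generally far more efficient;
    // Java serialization is meant only as a last resort.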
*/ + private def genericSerializer[T: ClassTag](useKryo: Boolean): Encoder[T] = { ExpressionEncoder[T]( schema = new StructType().add("value", BinaryType), flat = true, - toRowExpressions = Seq(ser), - fromRowExpression = deser, + toRowExpressions = Seq( + EncodeUsingSerializer( + BoundReference(0, ObjectType(classOf[AnyRef]), nullable = true), kryo = useKryo)), + fromRowExpression = + DecodeUsingSerializer[T]( + BoundReference(0, BinaryType, nullable = true), classTag[T], kryo = useKryo), clsTag = classTag[T] ) } + /** + * (Scala-specific) Creates an encoder that serializes objects of type T using Kryo. + * This encoder maps T into a single byte array (binary) field. + */ + def kryo[T: ClassTag]: Encoder[T] = genericSerializer(useKryo = true) + /** * Creates an encoder that serializes objects of type T using Kryo. * This encoder maps T into a single byte array (binary) field. */ def kryo[T](clazz: Class[T]): Encoder[T] = kryo(ClassTag[T](clazz)) + /** + * (Scala-specific) Creates an encoder that serializes objects of type T using generic Java + * serialization. This encoder maps T into a single byte array (binary) field. + * + * Note that this is extremely inefficient and should only be used as the last resort. + */ + def javaSerialization[T: ClassTag]: Encoder[T] = genericSerializer(useKryo = false) + + /** + * Creates an encoder that serializes objects of type T using generic Java serialization. + * This encoder maps T into a single byte array (binary) field. + * + * Note that this is extremely inefficient and should only be used as the last resort. + */ + def javaSerialization[T](clazz: Class[T]): Encoder[T] = javaSerialization(ClassTag[T](clazz)) + def BOOLEAN: Encoder[java.lang.Boolean] = ExpressionEncoder(flat = true) def BYTE: Encoder[java.lang.Byte] = ExpressionEncoder(flat = true) def SHORT: Encoder[java.lang.Short] = ExpressionEncoder(flat = true) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala index 489c6126f8cd..acf0da240051 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala @@ -21,7 +21,7 @@ import scala.language.existentials import scala.reflect.ClassTag import org.apache.spark.SparkConf -import org.apache.spark.serializer.{KryoSerializerInstance, KryoSerializer} +import org.apache.spark.serializer._ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer import org.apache.spark.sql.catalyst.plans.logical.{Project, LocalRelation} @@ -517,29 +517,39 @@ case class GetInternalRowField(child: Expression, ordinal: Int, dataType: DataTy } } -/** Serializes an input object using Kryo serializer. */ -case class SerializeWithKryo(child: Expression) extends UnaryExpression { +/** + * Serializes an input object using a generic serializer (Kryo or Java). + * @param kryo if true, use Kryo. Otherwise, use Java. 
+ */ +case class EncodeUsingSerializer(child: Expression, kryo: Boolean) extends UnaryExpression { override def eval(input: InternalRow): Any = throw new UnsupportedOperationException("Only code-generated evaluation is supported") override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val input = child.gen(ctx) - val kryo = ctx.freshName("kryoSerializer") - val kryoClass = classOf[KryoSerializer].getName - val kryoInstanceClass = classOf[KryoSerializerInstance].getName - val sparkConfClass = classOf[SparkConf].getName + // Code to initialize the serializer. + val serializer = ctx.freshName("serializer") + val (serializerClass, serializerInstanceClass) = { + if (kryo) { + (classOf[KryoSerializer].getName, classOf[KryoSerializerInstance].getName) + } else { + (classOf[JavaSerializer].getName, classOf[JavaSerializerInstance].getName) + } + } + val sparkConf = s"new ${classOf[SparkConf].getName}()" ctx.addMutableState( - kryoInstanceClass, - kryo, - s"$kryo = ($kryoInstanceClass) new $kryoClass(new $sparkConfClass()).newInstance();") + serializerInstanceClass, + serializer, + s"$serializer = ($serializerInstanceClass) new $serializerClass($sparkConf).newInstance();") + // Code to serialize. + val input = child.gen(ctx) s""" ${input.code} final boolean ${ev.isNull} = ${input.isNull}; ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; if (!${ev.isNull}) { - ${ev.value} = $kryo.serialize(${input.value}, null).array(); + ${ev.value} = $serializer.serialize(${input.value}, null).array(); } """ } @@ -548,29 +558,38 @@ case class SerializeWithKryo(child: Expression) extends UnaryExpression { } /** - * Deserializes an input object using Kryo serializer. Note that the ClassTag is not an implicit - * parameter because TreeNode cannot copy implicit parameters. + * Serializes an input object using a generic serializer (Kryo or Java). Note that the ClassTag + * is not an implicit parameter because TreeNode cannot copy implicit parameters. + * @param kryo if true, use Kryo. Otherwise, use Java. */ -case class DeserializeWithKryo[T](child: Expression, tag: ClassTag[T]) extends UnaryExpression { +case class DecodeUsingSerializer[T](child: Expression, tag: ClassTag[T], kryo: Boolean) + extends UnaryExpression { override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val input = child.gen(ctx) - val kryo = ctx.freshName("kryoSerializer") - val kryoClass = classOf[KryoSerializer].getName - val kryoInstanceClass = classOf[KryoSerializerInstance].getName - val sparkConfClass = classOf[SparkConf].getName + // Code to initialize the serializer. + val serializer = ctx.freshName("serializer") + val (serializerClass, serializerInstanceClass) = { + if (kryo) { + (classOf[KryoSerializer].getName, classOf[KryoSerializerInstance].getName) + } else { + (classOf[JavaSerializer].getName, classOf[JavaSerializerInstance].getName) + } + } + val sparkConf = s"new ${classOf[SparkConf].getName}()" ctx.addMutableState( - kryoInstanceClass, - kryo, - s"$kryo = ($kryoInstanceClass) new $kryoClass(new $sparkConfClass()).newInstance();") + serializerInstanceClass, + serializer, + s"$serializer = ($serializerInstanceClass) new $serializerClass($sparkConf).newInstance();") + // Code to serialize. 
+ val input = child.gen(ctx) s""" ${input.code} final boolean ${ev.isNull} = ${input.isNull}; ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; if (!${ev.isNull}) { ${ev.value} = (${ctx.javaType(dataType)}) - $kryo.deserialize(java.nio.ByteBuffer.wrap(${input.value}), null); + $serializer.deserialize(java.nio.ByteBuffer.wrap(${input.value}), null); } """ } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoderSuite.scala index 2729db84897a..6e0322fb6e01 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoderSuite.scala @@ -76,17 +76,34 @@ class FlatEncoderSuite extends ExpressionEncoderSuite { // Kryo encoders encodeDecodeTest( "hello", - Encoders.kryo[String].asInstanceOf[ExpressionEncoder[String]], + encoderFor(Encoders.kryo[String]), "kryo string") encodeDecodeTest( - new NotJavaSerializable(15), - Encoders.kryo[NotJavaSerializable].asInstanceOf[ExpressionEncoder[NotJavaSerializable]], + new KryoSerializable(15), + encoderFor(Encoders.kryo[KryoSerializable]), "kryo object serialization") + + // Java encoders + encodeDecodeTest( + "hello", + encoderFor(Encoders.javaSerialization[String]), + "java string") + encodeDecodeTest( + new JavaSerializable(15), + encoderFor(Encoders.javaSerialization[JavaSerializable]), + "java object serialization") } +/** For testing Kryo serialization based encoder. */ +class KryoSerializable(val value: Int) { + override def equals(other: Any): Boolean = { + this.value == other.asInstanceOf[KryoSerializable].value + } +} -class NotJavaSerializable(val value: Int) { +/** For testing Java serialization based encoder. 
*/ +class JavaSerializable(val value: Int) extends Serializable { override def equals(other: Any): Boolean = { - this.value == other.asInstanceOf[NotJavaSerializable].value + this.value == other.asInstanceOf[JavaSerializable].value } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index b6db583dfe01..89d964aa3e46 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -357,7 +357,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { assert(ds.toString == "[_1: int, _2: int]") } - test("kryo encoder") { + test("Kryo encoder") { implicit val kryoEncoder = Encoders.kryo[KryoData] val ds = Seq(KryoData(1), KryoData(2)).toDS() @@ -365,7 +365,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { Seq((KryoData(1), 1L), (KryoData(2), 1L))) } - test("kryo encoder self join") { + test("Kryo encoder self join") { implicit val kryoEncoder = Encoders.kryo[KryoData] val ds = Seq(KryoData(1), KryoData(2)).toDS() assert(ds.joinWith(ds, lit(true)).collect().toSet == @@ -375,6 +375,25 @@ class DatasetSuite extends QueryTest with SharedSQLContext { (KryoData(2), KryoData(1)), (KryoData(2), KryoData(2)))) } + + test("Java encoder") { + implicit val kryoEncoder = Encoders.javaSerialization[JavaData] + val ds = Seq(JavaData(1), JavaData(2)).toDS() + + assert(ds.groupBy(p => p).count().collect().toSeq == + Seq((JavaData(1), 1L), (JavaData(2), 1L))) + } + + ignore("Java encoder self join") { + implicit val kryoEncoder = Encoders.javaSerialization[JavaData] + val ds = Seq(JavaData(1), JavaData(2)).toDS() + assert(ds.joinWith(ds, lit(true)).collect().toSet == + Set( + (JavaData(1), JavaData(1)), + (JavaData(1), JavaData(2)), + (JavaData(2), JavaData(1)), + (JavaData(2), JavaData(2)))) + } } @@ -406,3 +425,16 @@ class KryoData(val a: Int) { object KryoData { def apply(a: Int): KryoData = new KryoData(a) } + +/** Used to test Java encoder. */ +class JavaData(val a: Int) extends Serializable { + override def equals(other: Any): Boolean = { + a == other.asInstanceOf[JavaData].a + } + override def hashCode: Int = a + override def toString: String = s"JavaData($a)" +} + +object JavaData { + def apply(a: Int): JavaData = new JavaData(a) +} From 7e987de1770f4ab3d54bc05db8de0a1ef035941d Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Wed, 18 Nov 2015 15:47:49 -0800 Subject: [PATCH 1433/1824] [SPARK-6787][ML] add read/write to estimators under ml.feature (1) Add read/write support to the following estimators under spark.ml: * CountVectorizer * IDF * MinMaxScaler * StandardScaler (a little awkward because we store some params in spark.mllib model) * StringIndexer Added some necessary method for read/write. Maybe we should add `private[ml] trait DefaultParamsReadable` and `DefaultParamsWritable` to save some boilerplate code, though we still need to override `load` for Java compatibility. jkbradley Author: Xiangrui Meng Closes #9798 from mengxr/SPARK-6787. 
--- .../spark/ml/feature/CountVectorizer.scala | 72 +++++++++++++++-- .../org/apache/spark/ml/feature/IDF.scala | 71 ++++++++++++++++- .../spark/ml/feature/MinMaxScaler.scala | 72 +++++++++++++++-- .../spark/ml/feature/StandardScaler.scala | 78 ++++++++++++++++++- .../spark/ml/feature/StringIndexer.scala | 70 +++++++++++++++-- .../ml/feature/CountVectorizerSuite.scala | 24 +++++- .../apache/spark/ml/feature/IDFSuite.scala | 19 ++++- .../spark/ml/feature/MinMaxScalerSuite.scala | 25 +++++- .../ml/feature/StandardScalerSuite.scala | 64 +++++++++++---- .../spark/ml/feature/StringIndexerSuite.scala | 19 ++++- 10 files changed, 467 insertions(+), 47 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala index 49028e4b8506..5ff9bfb7d111 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala @@ -16,17 +16,19 @@ */ package org.apache.spark.ml.feature -import org.apache.spark.annotation.Experimental +import org.apache.hadoop.fs.Path + +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.broadcast.Broadcast +import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} -import org.apache.spark.ml.util.{Identifiable, SchemaUtils} -import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.ml.util._ import org.apache.spark.mllib.linalg.{VectorUDT, Vectors} import org.apache.spark.rdd.RDD +import org.apache.spark.sql.DataFrame import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ -import org.apache.spark.sql.DataFrame import org.apache.spark.util.collection.OpenHashMap /** @@ -105,7 +107,7 @@ private[feature] trait CountVectorizerParams extends Params with HasInputCol wit */ @Experimental class CountVectorizer(override val uid: String) - extends Estimator[CountVectorizerModel] with CountVectorizerParams { + extends Estimator[CountVectorizerModel] with CountVectorizerParams with Writable { def this() = this(Identifiable.randomUID("cntVec")) @@ -169,6 +171,19 @@ class CountVectorizer(override val uid: String) } override def copy(extra: ParamMap): CountVectorizer = defaultCopy(extra) + + @Since("1.6.0") + override def write: Writer = new DefaultParamsWriter(this) +} + +@Since("1.6.0") +object CountVectorizer extends Readable[CountVectorizer] { + + @Since("1.6.0") + override def read: Reader[CountVectorizer] = new DefaultParamsReader + + @Since("1.6.0") + override def load(path: String): CountVectorizer = super.load(path) } /** @@ -178,7 +193,9 @@ class CountVectorizer(override val uid: String) */ @Experimental class CountVectorizerModel(override val uid: String, val vocabulary: Array[String]) - extends Model[CountVectorizerModel] with CountVectorizerParams { + extends Model[CountVectorizerModel] with CountVectorizerParams with Writable { + + import CountVectorizerModel._ def this(vocabulary: Array[String]) = { this(Identifiable.randomUID("cntVecModel"), vocabulary) @@ -232,4 +249,47 @@ class CountVectorizerModel(override val uid: String, val vocabulary: Array[Strin val copied = new CountVectorizerModel(uid, vocabulary).setParent(parent) copyValues(copied, extra) } + + @Since("1.6.0") + override def write: Writer = new CountVectorizerModelWriter(this) +} + +@Since("1.6.0") +object CountVectorizerModel extends Readable[CountVectorizerModel] { 
+ + private[CountVectorizerModel] + class CountVectorizerModelWriter(instance: CountVectorizerModel) extends Writer { + + private case class Data(vocabulary: Seq[String]) + + override protected def saveImpl(path: String): Unit = { + DefaultParamsWriter.saveMetadata(instance, path, sc) + val data = Data(instance.vocabulary) + val dataPath = new Path(path, "data").toString + sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + } + } + + private class CountVectorizerModelReader extends Reader[CountVectorizerModel] { + + private val className = "org.apache.spark.ml.feature.CountVectorizerModel" + + override def load(path: String): CountVectorizerModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val dataPath = new Path(path, "data").toString + val data = sqlContext.read.parquet(dataPath) + .select("vocabulary") + .head() + val vocabulary = data.getAs[Seq[String]](0).toArray + val model = new CountVectorizerModel(metadata.uid, vocabulary) + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } + + @Since("1.6.0") + override def read: Reader[CountVectorizerModel] = new CountVectorizerModelReader + + @Since("1.6.0") + override def load(path: String): CountVectorizerModel = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala index 4c36df75d8aa..53ad34ef1264 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala @@ -17,11 +17,13 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.Experimental +import org.apache.hadoop.fs.Path + +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml._ import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ -import org.apache.spark.ml.util.{Identifiable, SchemaUtils} +import org.apache.spark.ml.util._ import org.apache.spark.mllib.feature import org.apache.spark.mllib.linalg.{Vector, VectorUDT} import org.apache.spark.sql._ @@ -60,7 +62,7 @@ private[feature] trait IDFBase extends Params with HasInputCol with HasOutputCol * Compute the Inverse Document Frequency (IDF) given a collection of documents. 
*/ @Experimental -final class IDF(override val uid: String) extends Estimator[IDFModel] with IDFBase { +final class IDF(override val uid: String) extends Estimator[IDFModel] with IDFBase with Writable { def this() = this(Identifiable.randomUID("idf")) @@ -85,6 +87,19 @@ final class IDF(override val uid: String) extends Estimator[IDFModel] with IDFBa } override def copy(extra: ParamMap): IDF = defaultCopy(extra) + + @Since("1.6.0") + override def write: Writer = new DefaultParamsWriter(this) +} + +@Since("1.6.0") +object IDF extends Readable[IDF] { + + @Since("1.6.0") + override def read: Reader[IDF] = new DefaultParamsReader + + @Since("1.6.0") + override def load(path: String): IDF = super.load(path) } /** @@ -95,7 +110,9 @@ final class IDF(override val uid: String) extends Estimator[IDFModel] with IDFBa class IDFModel private[ml] ( override val uid: String, idfModel: feature.IDFModel) - extends Model[IDFModel] with IDFBase { + extends Model[IDFModel] with IDFBase with Writable { + + import IDFModel._ /** @group setParam */ def setInputCol(value: String): this.type = set(inputCol, value) @@ -117,4 +134,50 @@ class IDFModel private[ml] ( val copied = new IDFModel(uid, idfModel) copyValues(copied, extra).setParent(parent) } + + /** Returns the IDF vector. */ + @Since("1.6.0") + def idf: Vector = idfModel.idf + + @Since("1.6.0") + override def write: Writer = new IDFModelWriter(this) +} + +@Since("1.6.0") +object IDFModel extends Readable[IDFModel] { + + private[IDFModel] class IDFModelWriter(instance: IDFModel) extends Writer { + + private case class Data(idf: Vector) + + override protected def saveImpl(path: String): Unit = { + DefaultParamsWriter.saveMetadata(instance, path, sc) + val data = Data(instance.idf) + val dataPath = new Path(path, "data").toString + sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + } + } + + private class IDFModelReader extends Reader[IDFModel] { + + private val className = "org.apache.spark.ml.feature.IDFModel" + + override def load(path: String): IDFModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val dataPath = new Path(path, "data").toString + val data = sqlContext.read.parquet(dataPath) + .select("idf") + .head() + val idf = data.getAs[Vector](0) + val model = new IDFModel(metadata.uid, new feature.IDFModel(idf)) + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } + + @Since("1.6.0") + override def read: Reader[IDFModel] = new IDFModelReader + + @Since("1.6.0") + override def load(path: String): IDFModel = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala index 1b494ec8b172..24d964fae834 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala @@ -17,11 +17,14 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.Experimental -import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} -import org.apache.spark.ml.param.{ParamMap, DoubleParam, Params} -import org.apache.spark.ml.util.Identifiable + +import org.apache.hadoop.fs.Path + +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.ml.param.{DoubleParam, ParamMap, Params} +import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} +import org.apache.spark.ml.util._ import 
org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors} import org.apache.spark.mllib.stat.Statistics import org.apache.spark.sql._ @@ -85,7 +88,7 @@ private[feature] trait MinMaxScalerParams extends Params with HasInputCol with H */ @Experimental class MinMaxScaler(override val uid: String) - extends Estimator[MinMaxScalerModel] with MinMaxScalerParams { + extends Estimator[MinMaxScalerModel] with MinMaxScalerParams with Writable { def this() = this(Identifiable.randomUID("minMaxScal")) @@ -115,6 +118,19 @@ class MinMaxScaler(override val uid: String) } override def copy(extra: ParamMap): MinMaxScaler = defaultCopy(extra) + + @Since("1.6.0") + override def write: Writer = new DefaultParamsWriter(this) +} + +@Since("1.6.0") +object MinMaxScaler extends Readable[MinMaxScaler] { + + @Since("1.6.0") + override def read: Reader[MinMaxScaler] = new DefaultParamsReader + + @Since("1.6.0") + override def load(path: String): MinMaxScaler = super.load(path) } /** @@ -131,7 +147,9 @@ class MinMaxScalerModel private[ml] ( override val uid: String, val originalMin: Vector, val originalMax: Vector) - extends Model[MinMaxScalerModel] with MinMaxScalerParams { + extends Model[MinMaxScalerModel] with MinMaxScalerParams with Writable { + + import MinMaxScalerModel._ /** @group setParam */ def setInputCol(value: String): this.type = set(inputCol, value) @@ -175,4 +193,46 @@ class MinMaxScalerModel private[ml] ( val copied = new MinMaxScalerModel(uid, originalMin, originalMax) copyValues(copied, extra).setParent(parent) } + + @Since("1.6.0") + override def write: Writer = new MinMaxScalerModelWriter(this) +} + +@Since("1.6.0") +object MinMaxScalerModel extends Readable[MinMaxScalerModel] { + + private[MinMaxScalerModel] + class MinMaxScalerModelWriter(instance: MinMaxScalerModel) extends Writer { + + private case class Data(originalMin: Vector, originalMax: Vector) + + override protected def saveImpl(path: String): Unit = { + DefaultParamsWriter.saveMetadata(instance, path, sc) + val data = new Data(instance.originalMin, instance.originalMax) + val dataPath = new Path(path, "data").toString + sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + } + } + + private class MinMaxScalerModelReader extends Reader[MinMaxScalerModel] { + + private val className = "org.apache.spark.ml.feature.MinMaxScalerModel" + + override def load(path: String): MinMaxScalerModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val dataPath = new Path(path, "data").toString + val Row(originalMin: Vector, originalMax: Vector) = sqlContext.read.parquet(dataPath) + .select("originalMin", "originalMax") + .head() + val model = new MinMaxScalerModel(metadata.uid, originalMin, originalMax) + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } + + @Since("1.6.0") + override def read: Reader[MinMaxScalerModel] = new MinMaxScalerModelReader + + @Since("1.6.0") + override def load(path: String): MinMaxScalerModel = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala index f6d0b0c0e9e7..ab04e5418dd4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala @@ -17,11 +17,13 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.Experimental +import org.apache.hadoop.fs.Path + +import org.apache.spark.annotation.{Experimental, Since} 
import org.apache.spark.ml._ import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ -import org.apache.spark.ml.util.Identifiable +import org.apache.spark.ml.util._ import org.apache.spark.mllib.feature import org.apache.spark.mllib.linalg.{Vector, VectorUDT} import org.apache.spark.sql._ @@ -57,7 +59,7 @@ private[feature] trait StandardScalerParams extends Params with HasInputCol with */ @Experimental class StandardScaler(override val uid: String) extends Estimator[StandardScalerModel] - with StandardScalerParams { + with StandardScalerParams with Writable { def this() = this(Identifiable.randomUID("stdScal")) @@ -94,6 +96,19 @@ class StandardScaler(override val uid: String) extends Estimator[StandardScalerM } override def copy(extra: ParamMap): StandardScaler = defaultCopy(extra) + + @Since("1.6.0") + override def write: Writer = new DefaultParamsWriter(this) +} + +@Since("1.6.0") +object StandardScaler extends Readable[StandardScaler] { + + @Since("1.6.0") + override def read: Reader[StandardScaler] = new DefaultParamsReader + + @Since("1.6.0") + override def load(path: String): StandardScaler = super.load(path) } /** @@ -104,7 +119,9 @@ class StandardScaler(override val uid: String) extends Estimator[StandardScalerM class StandardScalerModel private[ml] ( override val uid: String, scaler: feature.StandardScalerModel) - extends Model[StandardScalerModel] with StandardScalerParams { + extends Model[StandardScalerModel] with StandardScalerParams with Writable { + + import StandardScalerModel._ /** Standard deviation of the StandardScalerModel */ val std: Vector = scaler.std @@ -112,6 +129,14 @@ class StandardScalerModel private[ml] ( /** Mean of the StandardScalerModel */ val mean: Vector = scaler.mean + /** Whether to scale to unit standard deviation. */ + @Since("1.6.0") + def getWithStd: Boolean = scaler.withStd + + /** Whether to center data with mean. 
*/ + @Since("1.6.0") + def getWithMean: Boolean = scaler.withMean + /** @group setParam */ def setInputCol(value: String): this.type = set(inputCol, value) @@ -138,4 +163,49 @@ class StandardScalerModel private[ml] ( val copied = new StandardScalerModel(uid, scaler) copyValues(copied, extra).setParent(parent) } + + @Since("1.6.0") + override def write: Writer = new StandardScalerModelWriter(this) +} + +@Since("1.6.0") +object StandardScalerModel extends Readable[StandardScalerModel] { + + private[StandardScalerModel] + class StandardScalerModelWriter(instance: StandardScalerModel) extends Writer { + + private case class Data(std: Vector, mean: Vector, withStd: Boolean, withMean: Boolean) + + override protected def saveImpl(path: String): Unit = { + DefaultParamsWriter.saveMetadata(instance, path, sc) + val data = Data(instance.std, instance.mean, instance.getWithStd, instance.getWithMean) + val dataPath = new Path(path, "data").toString + sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + } + } + + private class StandardScalerModelReader extends Reader[StandardScalerModel] { + + private val className = "org.apache.spark.ml.feature.StandardScalerModel" + + override def load(path: String): StandardScalerModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val dataPath = new Path(path, "data").toString + val Row(std: Vector, mean: Vector, withStd: Boolean, withMean: Boolean) = + sqlContext.read.parquet(dataPath) + .select("std", "mean", "withStd", "withMean") + .head() + // This is very likely to change in the future because withStd and withMean should be params. + val oldModel = new feature.StandardScalerModel(std, mean, withStd, withMean) + val model = new StandardScalerModel(metadata.uid, oldModel) + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } + + @Since("1.6.0") + override def read: Reader[StandardScalerModel] = new StandardScalerModelReader + + @Since("1.6.0") + override def load(path: String): StandardScalerModel = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index f782a272d11d..f16f6afc002d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -17,13 +17,14 @@ package org.apache.spark.ml.feature +import org.apache.hadoop.fs.Path + import org.apache.spark.SparkException -import org.apache.spark.annotation.{Since, Experimental} -import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.ml.{Estimator, Model, Transformer} import org.apache.spark.ml.attribute.{Attribute, NominalAttribute} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ -import org.apache.spark.ml.Transformer import org.apache.spark.ml.util._ import org.apache.spark.sql.DataFrame import org.apache.spark.sql.functions._ @@ -64,7 +65,7 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha */ @Experimental class StringIndexer(override val uid: String) extends Estimator[StringIndexerModel] - with StringIndexerBase { + with StringIndexerBase with Writable { def this() = this(Identifiable.randomUID("strIdx")) @@ -92,6 +93,19 @@ class StringIndexer(override val uid: String) extends Estimator[StringIndexerMod } override def copy(extra: ParamMap): StringIndexer = defaultCopy(extra) + + 
@Since("1.6.0") + override def write: Writer = new DefaultParamsWriter(this) +} + +@Since("1.6.0") +object StringIndexer extends Readable[StringIndexer] { + + @Since("1.6.0") + override def read: Reader[StringIndexer] = new DefaultParamsReader + + @Since("1.6.0") + override def load(path: String): StringIndexer = super.load(path) } /** @@ -107,7 +121,10 @@ class StringIndexer(override val uid: String) extends Estimator[StringIndexerMod @Experimental class StringIndexerModel ( override val uid: String, - val labels: Array[String]) extends Model[StringIndexerModel] with StringIndexerBase { + val labels: Array[String]) + extends Model[StringIndexerModel] with StringIndexerBase with Writable { + + import StringIndexerModel._ def this(labels: Array[String]) = this(Identifiable.randomUID("strIdx"), labels) @@ -176,6 +193,49 @@ class StringIndexerModel ( val copied = new StringIndexerModel(uid, labels) copyValues(copied, extra).setParent(parent) } + + @Since("1.6.0") + override def write: StringIndexModelWriter = new StringIndexModelWriter(this) +} + +@Since("1.6.0") +object StringIndexerModel extends Readable[StringIndexerModel] { + + private[StringIndexerModel] + class StringIndexModelWriter(instance: StringIndexerModel) extends Writer { + + private case class Data(labels: Array[String]) + + override protected def saveImpl(path: String): Unit = { + DefaultParamsWriter.saveMetadata(instance, path, sc) + val data = Data(instance.labels) + val dataPath = new Path(path, "data").toString + sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + } + } + + private class StringIndexerModelReader extends Reader[StringIndexerModel] { + + private val className = "org.apache.spark.ml.feature.StringIndexerModel" + + override def load(path: String): StringIndexerModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val dataPath = new Path(path, "data").toString + val data = sqlContext.read.parquet(dataPath) + .select("labels") + .head() + val labels = data.getAs[Seq[String]](0).toArray + val model = new StringIndexerModel(metadata.uid, labels) + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } + + @Since("1.6.0") + override def read: Reader[StringIndexerModel] = new StringIndexerModelReader + + @Since("1.6.0") + override def load(path: String): StringIndexerModel = super.load(path) } /** diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala index e192fa4850af..9c9999017317 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala @@ -18,14 +18,17 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.Row -class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext { +class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext + with DefaultReadWriteTest { test("params") { + ParamsSuite.checkParams(new CountVectorizer) ParamsSuite.checkParams(new CountVectorizerModel(Array("empty"))) } @@ -164,4 +167,23 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext 
{ assert(features ~== expected absTol 1e-14) } } + + test("CountVectorizer read/write") { + val t = new CountVectorizer() + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + .setMinDF(0.5) + .setMinTF(3.0) + .setVocabSize(10) + testDefaultReadWrite(t) + } + + test("CountVectorizerModel read/write") { + val instance = new CountVectorizerModel("myCountVectorizerModel", Array("a", "b", "c")) + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + .setMinTF(3.0) + val newInstance = testDefaultReadWrite(instance) + assert(newInstance.vocabulary === instance.vocabulary) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala index 08f80af03429..bc958c15857b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala @@ -19,13 +19,14 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.feature.{IDFModel => OldIDFModel} import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.Row -class IDFSuite extends SparkFunSuite with MLlibTestSparkContext { +class IDFSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { def scaleDataWithIDF(dataSet: Array[Vector], model: Vector): Array[Vector] = { dataSet.map { @@ -98,4 +99,20 @@ class IDFSuite extends SparkFunSuite with MLlibTestSparkContext { assert(x ~== y absTol 1e-5, "Transformed vector is different with expected vector.") } } + + test("IDF read/write") { + val t = new IDF() + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + .setMinDocFreq(5) + testDefaultReadWrite(t) + } + + test("IDFModel read/write") { + val instance = new IDFModel("myIDFModel", new OldIDFModel(Vectors.dense(1.0, 2.0))) + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + val newInstance = testDefaultReadWrite(instance) + assert(newInstance.idf === instance.idf) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala index c04dda41eea3..09183fe65b72 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala @@ -18,12 +18,12 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite -import org.apache.spark.ml.util.MLTestingUtils +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{Row, SQLContext} -class MinMaxScalerSuite extends SparkFunSuite with MLlibTestSparkContext { +class MinMaxScalerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { test("MinMaxScaler fit basic case") { val sqlContext = new SQLContext(sc) @@ -69,4 +69,25 @@ class MinMaxScalerSuite extends SparkFunSuite with MLlibTestSparkContext { } } } + + test("MinMaxScaler read/write") { + val t = new MinMaxScaler() + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + .setMax(1.0) + .setMin(-1.0) + testDefaultReadWrite(t) + } + + 
test("MinMaxScalerModel read/write") { + val instance = new MinMaxScalerModel( + "myMinMaxScalerModel", Vectors.dense(-1.0, 0.0), Vectors.dense(1.0, 10.0)) + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + .setMin(-1.0) + .setMax(1.0) + val newInstance = testDefaultReadWrite(instance) + assert(newInstance.originalMin === instance.originalMin) + assert(newInstance.originalMax === instance.originalMax) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala index 879a3ae87500..49a4b2efe0c2 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala @@ -19,12 +19,16 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite -import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} +import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.mllib.feature +import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -import org.apache.spark.sql.{DataFrame, Row, SQLContext} +import org.apache.spark.sql.{DataFrame, Row} -class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext{ +class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext + with DefaultReadWriteTest { @transient var data: Array[Vector] = _ @transient var resWithStd: Array[Vector] = _ @@ -56,23 +60,29 @@ class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext{ ) } - def assertResult(dataframe: DataFrame): Unit = { - dataframe.select("standarded_features", "expected").collect().foreach { + def assertResult(df: DataFrame): Unit = { + df.select("standardized_features", "expected").collect().foreach { case Row(vector1: Vector, vector2: Vector) => assert(vector1 ~== vector2 absTol 1E-5, "The vector value is not correct after standardization.") } } + test("params") { + ParamsSuite.checkParams(new StandardScaler) + val oldModel = new feature.StandardScalerModel(Vectors.dense(1.0), Vectors.dense(2.0)) + ParamsSuite.checkParams(new StandardScalerModel("empty", oldModel)) + } + test("Standardization with default parameter") { val df0 = sqlContext.createDataFrame(data.zip(resWithStd)).toDF("features", "expected") - val standardscaler0 = new StandardScaler() + val standardScaler0 = new StandardScaler() .setInputCol("features") - .setOutputCol("standarded_features") + .setOutputCol("standardized_features") .fit(df0) - assertResult(standardscaler0.transform(df0)) + assertResult(standardScaler0.transform(df0)) } test("Standardization with setter") { @@ -80,29 +90,49 @@ class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext{ val df2 = sqlContext.createDataFrame(data.zip(resWithMean)).toDF("features", "expected") val df3 = sqlContext.createDataFrame(data.zip(data)).toDF("features", "expected") - val standardscaler1 = new StandardScaler() + val standardScaler1 = new StandardScaler() .setInputCol("features") - .setOutputCol("standarded_features") + .setOutputCol("standardized_features") .setWithMean(true) .setWithStd(true) .fit(df1) - val standardscaler2 = new StandardScaler() + val standardScaler2 = new StandardScaler() .setInputCol("features") - .setOutputCol("standarded_features") + .setOutputCol("standardized_features") .setWithMean(true) 
.setWithStd(false) .fit(df2) - val standardscaler3 = new StandardScaler() + val standardScaler3 = new StandardScaler() .setInputCol("features") - .setOutputCol("standarded_features") + .setOutputCol("standardized_features") .setWithMean(false) .setWithStd(false) .fit(df3) - assertResult(standardscaler1.transform(df1)) - assertResult(standardscaler2.transform(df2)) - assertResult(standardscaler3.transform(df3)) + assertResult(standardScaler1.transform(df1)) + assertResult(standardScaler2.transform(df2)) + assertResult(standardScaler3.transform(df3)) + } + + test("StandardScaler read/write") { + val t = new StandardScaler() + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + .setWithStd(false) + .setWithMean(true) + testDefaultReadWrite(t) + } + + test("StandardScalerModel read/write") { + val oldModel = new feature.StandardScalerModel( + Vectors.dense(1.0, 2.0), Vectors.dense(3.0, 4.0), false, true) + val instance = new StandardScalerModel("myStandardScalerModel", oldModel) + val newInstance = testDefaultReadWrite(instance) + assert(newInstance.std === instance.std) + assert(newInstance.mean === instance.mean) + assert(newInstance.getWithStd === instance.getWithStd) + assert(newInstance.getWithMean === instance.getWithMean) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala index be37bfb43883..749bfac74782 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala @@ -118,6 +118,23 @@ class StringIndexerSuite assert(indexerModel.transform(df).eq(df)) } + test("StringIndexer read/write") { + val t = new StringIndexer() + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + .setHandleInvalid("skip") + testDefaultReadWrite(t) + } + + test("StringIndexerModel read/write") { + val instance = new StringIndexerModel("myStringIndexerModel", Array("a", "b", "c")) + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + .setHandleInvalid("skip") + val newInstance = testDefaultReadWrite(instance) + assert(newInstance.labels === instance.labels) + } + test("IndexToString params") { val idxToStr = new IndexToString() ParamsSuite.checkParams(idxToStr) @@ -175,7 +192,7 @@ class StringIndexerSuite assert(outSchema("output").dataType === StringType) } - test("read/write") { + test("IndexToString read/write") { val t = new IndexToString() .setInputCol("myInputCol") .setOutputCol("myOutputCol") From 3a9851936ddfe5bcb6a7f364d535fac977551f5d Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 18 Nov 2015 15:55:41 -0800 Subject: [PATCH 1434/1824] [SPARK-11649] Properly set Akka frame size in SparkListenerSuite test SparkListenerSuite's _"onTaskGettingResult() called when result fetched remotely"_ test was extremely slow (1 to 4 minutes to run) and recently became extremely flaky, frequently failing with OutOfMemoryError. The root cause was the fact that this was using `System.setProperty` to set the Akka frame size, which was not actually modifying the frame size. As a result, this test would allocate much more data than necessary. The fix here is to simply use SparkConf in order to configure the frame size. Author: Josh Rosen Closes #9822 from JoshRosen/SPARK-11649. 
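The distinction the fix below relies on, sketched with illustrative values: Akka reads its frame size when the SparkContext (and its ActorSystem) is created, so the setting has to arrive through SparkConf; a System.setProperty call made after the context exists is ignored.

import org.apache.spark.{SparkConf, SparkContext}

val conf = new SparkConf()
  .setMaster("local")
  .setAppName("frame-size-demo")              // hypothetical app name
  .set("spark.akka.frameSize", "1")           // in MB, read at context creation
val sc = new SparkContext(conf)

// Too late to have any effect -- the ActorSystem already holds its settings:
// System.setProperty("spark.akka.frameSize", "1")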
--- .../org/apache/spark/scheduler/SparkListenerSuite.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala index 53102b9f1c93..84e545851f49 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala @@ -269,14 +269,15 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match } test("onTaskGettingResult() called when result fetched remotely") { - sc = new SparkContext("local", "SparkListenerSuite") + val conf = new SparkConf().set("spark.akka.frameSize", "1") + sc = new SparkContext("local", "SparkListenerSuite", conf) val listener = new SaveTaskEvents sc.addSparkListener(listener) // Make a task whose result is larger than the akka frame size - System.setProperty("spark.akka.frameSize", "1") val akkaFrameSize = sc.env.actorSystem.settings.config.getBytes("akka.remote.netty.tcp.maximum-frame-size").toInt + assert(akkaFrameSize === 1024 * 1024) val result = sc.parallelize(Seq(1), 1) .map { x => 1.to(akkaFrameSize).toArray } .reduce { case (x, y) => x } From c07a50b86254578625be777b1890ff95e832ac6e Mon Sep 17 00:00:00 2001 From: Derek Dagit Date: Wed, 18 Nov 2015 15:56:54 -0800 Subject: [PATCH 1435/1824] [SPARK-10930] History "Stages" page "duration" can be confusing Author: Derek Dagit Closes #9051 from d2r/spark-10930-ui-max-task-dur. --- .../org/apache/spark/ui/jobs/StageTable.scala | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala index ea806d09b600..2a1c3c1a50ec 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala @@ -145,9 +145,22 @@ private[ui] class StageTableBase( case None => "Unknown" } val finishTime = s.completionTime.getOrElse(System.currentTimeMillis) - val duration = s.submissionTime.map { t => - if (finishTime > t) finishTime - t else System.currentTimeMillis - t - } + + // The submission time for a stage is misleading because it counts the time + // the stage waits to be launched. (SPARK-10930) + val taskLaunchTimes = + stageData.taskData.values.map(_.taskInfo.launchTime).filter(_ > 0) + val duration: Option[Long] = + if (taskLaunchTimes.nonEmpty) { + val startTime = taskLaunchTimes.min + if (finishTime > startTime) { + Some(finishTime - startTime) + } else { + Some(System.currentTimeMillis() - startTime) + } + } else { + None + } val formattedDuration = duration.map(d => UIUtils.formatDuration(d)).getOrElse("Unknown") val inputRead = stageData.inputBytes From 4b117121900e5f242e7c8f46a69164385f0da7cc Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 18 Nov 2015 16:00:35 -0800 Subject: [PATCH 1436/1824] [SPARK-11495] Fix potential socket / file handle leaks that were found via static analysis The HP Fortify Opens Source Review team (https://www.hpfod.com/open-source-review-project) reported a handful of potential resource leaks that were discovered using their static analysis tool. We should fix the issues identified by their scan. Author: Josh Rosen Closes #9455 from JoshRosen/fix-potential-resource-leaks. 
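The pattern applied throughout the diff below, reduced to a minimal Scala sketch (file name hypothetical): open the resource inside try and close it in finally via Guava's Closeables, so an exception between open and close cannot leak the handle.

import java.io.{BufferedReader, FileReader}
import com.google.common.io.Closeables

def firstLine(path: String): String = {
  var reader: BufferedReader = null
  try {
    reader = new BufferedReader(new FileReader(path))
    reader.readLine()
  } finally {
    // Closeables.close tolerates null and can swallow the secondary IOException
    // so it does not mask an exception thrown from the try block.
    Closeables.close(reader, /* swallowIOException = */ true)
  }
}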
--- .../spark/unsafe/map/BytesToBytesMap.java | 7 ++++ .../unsafe/sort/UnsafeSorterSpillReader.java | 38 +++++++++++-------- .../streaming/JavaCustomReceiver.java | 31 +++++++-------- .../network/ChunkFetchIntegrationSuite.java | 15 ++++++-- .../shuffle/TestShuffleDataContext.java | 32 ++++++++++------ .../spark/streaming/JavaReceiverAPISuite.java | 20 ++++++---- 6 files changed, 90 insertions(+), 53 deletions(-) diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java index 04694dc54418..3387f9a4177c 100644 --- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java +++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java @@ -24,6 +24,7 @@ import java.util.LinkedList; import com.google.common.annotations.VisibleForTesting; +import com.google.common.io.Closeables; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -272,6 +273,7 @@ private void advanceToNextPage() { } } try { + Closeables.close(reader, /* swallowIOException = */ false); reader = spillWriters.getFirst().getReader(blockManager); recordsInPage = -1; } catch (IOException e) { @@ -318,6 +320,11 @@ public Location next() { try { reader.loadNext(); } catch (IOException e) { + try { + reader.close(); + } catch(IOException e2) { + logger.error("Error while closing spill reader", e2); + } // Scala iterator does not handle exception Platform.throwException(e); } diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java index 039e940a357e..dcb13e6581e5 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java @@ -20,8 +20,7 @@ import java.io.*; import com.google.common.io.ByteStreams; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; +import com.google.common.io.Closeables; import org.apache.spark.storage.BlockId; import org.apache.spark.storage.BlockManager; @@ -31,10 +30,8 @@ * Reads spill files written by {@link UnsafeSorterSpillWriter} (see that class for a description * of the file format). 
*/ -public final class UnsafeSorterSpillReader extends UnsafeSorterIterator { - private static final Logger logger = LoggerFactory.getLogger(UnsafeSorterSpillReader.class); +public final class UnsafeSorterSpillReader extends UnsafeSorterIterator implements Closeable { - private final File file; private InputStream in; private DataInputStream din; @@ -52,11 +49,15 @@ public UnsafeSorterSpillReader( File file, BlockId blockId) throws IOException { assert (file.length() > 0); - this.file = file; final BufferedInputStream bs = new BufferedInputStream(new FileInputStream(file)); - this.in = blockManager.wrapForCompression(blockId, bs); - this.din = new DataInputStream(this.in); - numRecordsRemaining = din.readInt(); + try { + this.in = blockManager.wrapForCompression(blockId, bs); + this.din = new DataInputStream(this.in); + numRecordsRemaining = din.readInt(); + } catch (IOException e) { + Closeables.close(bs, /* swallowIOException = */ true); + throw e; + } } @Override @@ -75,12 +76,7 @@ public void loadNext() throws IOException { ByteStreams.readFully(in, arr, 0, recordLength); numRecordsRemaining--; if (numRecordsRemaining == 0) { - in.close(); - if (!file.delete() && file.exists()) { - logger.warn("Unable to delete spill file {}", file.getPath()); - } - in = null; - din = null; + close(); } } @@ -103,4 +99,16 @@ public int getRecordLength() { public long getKeyPrefix() { return keyPrefix; } + + @Override + public void close() throws IOException { + if (in != null) { + try { + in.close(); + } finally { + in = null; + din = null; + } + } + } } diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaCustomReceiver.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaCustomReceiver.java index 99df259b4e8e..4b50fbf59f80 100644 --- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaCustomReceiver.java +++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaCustomReceiver.java @@ -18,6 +18,7 @@ package org.apache.spark.examples.streaming; import com.google.common.collect.Lists; +import com.google.common.io.Closeables; import org.apache.spark.SparkConf; import org.apache.spark.api.java.function.FlatMapFunction; @@ -121,23 +122,23 @@ public void onStop() { /** Create a socket connection and receive data until receiver is stopped */ private void receive() { - Socket socket = null; - String userInput = null; - try { - // connect to the server - socket = new Socket(host, port); - - BufferedReader reader = new BufferedReader(new InputStreamReader(socket.getInputStream())); - - // Until stopped or connection broken continue reading - while (!isStopped() && (userInput = reader.readLine()) != null) { - System.out.println("Received data '" + userInput + "'"); - store(userInput); + Socket socket = null; + BufferedReader reader = null; + String userInput = null; + try { + // connect to the server + socket = new Socket(host, port); + reader = new BufferedReader(new InputStreamReader(socket.getInputStream())); + // Until stopped or connection broken continue reading + while (!isStopped() && (userInput = reader.readLine()) != null) { + System.out.println("Received data '" + userInput + "'"); + store(userInput); + } + } finally { + Closeables.close(reader, /* swallowIOException = */ true); + Closeables.close(socket, /* swallowIOException = */ true); } - reader.close(); - socket.close(); - // Restart in an attempt to connect again when server is active again restart("Trying to connect again"); } catch(ConnectException ce) { diff --git 
a/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java b/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java index dc5fa1cee69b..50a324e29338 100644 --- a/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java @@ -31,6 +31,7 @@ import com.google.common.collect.Lists; import com.google.common.collect.Sets; +import com.google.common.io.Closeables; import org.junit.AfterClass; import org.junit.BeforeClass; import org.junit.Test; @@ -78,10 +79,15 @@ public static void setUp() throws Exception { testFile = File.createTempFile("shuffle-test-file", "txt"); testFile.deleteOnExit(); RandomAccessFile fp = new RandomAccessFile(testFile, "rw"); - byte[] fileContent = new byte[1024]; - new Random().nextBytes(fileContent); - fp.write(fileContent); - fp.close(); + boolean shouldSuppressIOException = true; + try { + byte[] fileContent = new byte[1024]; + new Random().nextBytes(fileContent); + fp.write(fileContent); + shouldSuppressIOException = false; + } finally { + Closeables.close(fp, shouldSuppressIOException); + } final TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); fileChunk = new FileSegmentManagedBuffer(conf, testFile, 10, testFile.length() - 25); @@ -117,6 +123,7 @@ public StreamManager getStreamManager() { @AfterClass public static void tearDown() { + bufferChunk.release(); server.close(); clientFactory.close(); testFile.delete(); diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java index 3fdde054ab6c..7ac1ca128aed 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java @@ -23,6 +23,7 @@ import java.io.IOException; import java.io.OutputStream; +import com.google.common.io.Closeables; import com.google.common.io.Files; import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo; @@ -60,21 +61,28 @@ public void cleanup() { public void insertSortShuffleData(int shuffleId, int mapId, byte[][] blocks) throws IOException { String blockId = "shuffle_" + shuffleId + "_" + mapId + "_0"; - OutputStream dataStream = new FileOutputStream( - ExternalShuffleBlockResolver.getFile(localDirs, subDirsPerLocalDir, blockId + ".data")); - DataOutputStream indexStream = new DataOutputStream(new FileOutputStream( - ExternalShuffleBlockResolver.getFile(localDirs, subDirsPerLocalDir, blockId + ".index"))); + OutputStream dataStream = null; + DataOutputStream indexStream = null; + boolean suppressExceptionsDuringClose = true; - long offset = 0; - indexStream.writeLong(offset); - for (byte[] block : blocks) { - offset += block.length; - dataStream.write(block); + try { + dataStream = new FileOutputStream( + ExternalShuffleBlockResolver.getFile(localDirs, subDirsPerLocalDir, blockId + ".data")); + indexStream = new DataOutputStream(new FileOutputStream( + ExternalShuffleBlockResolver.getFile(localDirs, subDirsPerLocalDir, blockId + ".index"))); + + long offset = 0; indexStream.writeLong(offset); + for (byte[] block : blocks) { + offset += block.length; + dataStream.write(block); + indexStream.writeLong(offset); + } + suppressExceptionsDuringClose = false; + } finally { + Closeables.close(dataStream, 
suppressExceptionsDuringClose); + Closeables.close(indexStream, suppressExceptionsDuringClose); } - - dataStream.close(); - indexStream.close(); } /** Creates reducer blocks in a hash-based data format within our local dirs. */ diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaReceiverAPISuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaReceiverAPISuite.java index ec2bffd6a5b9..7a8ef9d14784 100644 --- a/streaming/src/test/java/org/apache/spark/streaming/JavaReceiverAPISuite.java +++ b/streaming/src/test/java/org/apache/spark/streaming/JavaReceiverAPISuite.java @@ -23,6 +23,7 @@ import org.apache.spark.streaming.api.java.JavaStreamingContext; import static org.junit.Assert.*; +import com.google.common.io.Closeables; import org.junit.After; import org.junit.Before; import org.junit.Test; @@ -121,14 +122,19 @@ public void onStop() { private void receive() { try { - Socket socket = new Socket(host, port); - BufferedReader in = new BufferedReader(new InputStreamReader(socket.getInputStream())); - String userInput; - while ((userInput = in.readLine()) != null) { - store(userInput); + Socket socket = null; + BufferedReader in = null; + try { + socket = new Socket(host, port); + in = new BufferedReader(new InputStreamReader(socket.getInputStream())); + String userInput; + while ((userInput = in.readLine()) != null) { + store(userInput); + } + } finally { + Closeables.close(in, /* swallowIOException = */ true); + Closeables.close(socket, /* swallowIOException = */ true); } - in.close(); - socket.close(); } catch(ConnectException ce) { ce.printStackTrace(); restart("Could not connect", ce); From a402c92c92b2e1c85d264f6077aec8f6d6a08270 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Wed, 18 Nov 2015 16:08:06 -0800 Subject: [PATCH 1437/1824] [SPARK-11814][STREAMING] Add better default checkpoint duration DStream checkpoint interval is by default set at max(10 second, batch interval). That's bad for large batch intervals where the checkpoint interval = batch interval, and RDDs get checkpointed every batch. This PR is to set the checkpoint interval of trackStateByKey to 10 * batch duration. Author: Tathagata Das Closes #9805 from tdas/SPARK-11814. 
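A sketch of the resulting behaviour, with illustrative durations and a hypothetical socket source: with 30-second batches the internal trackStateByKey stream now checkpoints every 10 batches (5 minutes) rather than every batch, and an explicit checkpoint interval set on the stream still takes precedence.

import org.apache.spark.SparkConf
import org.apache.spark.streaming.{Seconds, State, StateSpec, StreamingContext}

val conf = new SparkConf().setMaster("local[2]").setAppName("checkpoint-demo")
val ssc = new StreamingContext(conf, Seconds(30))        // 30s batch interval
ssc.checkpoint("/tmp/streaming-checkpoint")              // hypothetical directory

val pairs = ssc.socketTextStream("localhost", 9999)      // hypothetical source
  .flatMap(_.split(" "))
  .map(word => (word, 1))

val trackingFunc = (value: Option[Int], state: State[Int]) => {
  val sum = value.getOrElse(0) + (if (state.exists) state.get else 0)
  state.update(sum)
  sum
}

val stateStream = pairs.trackStateByKey(StateSpec.function(trackingFunc))
// Default after this change: checkpointed every 10 batches, i.e. Seconds(300).
// stateStream.checkpoint(Seconds(60))                   // explicit interval wins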
--- .../streaming/dstream/TrackStateDStream.scala | 13 ++++++ .../streaming/TrackStateByKeySuite.scala | 44 ++++++++++++++++++- 2 files changed, 56 insertions(+), 1 deletion(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackStateDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackStateDStream.scala index 98e881e6ae11..0ada1111ce30 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackStateDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackStateDStream.scala @@ -25,6 +25,7 @@ import org.apache.spark.rdd.{EmptyRDD, RDD} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming._ import org.apache.spark.streaming.rdd.{TrackStateRDD, TrackStateRDDRecord} +import org.apache.spark.streaming.dstream.InternalTrackStateDStream._ /** * :: Experimental :: @@ -120,6 +121,14 @@ class InternalTrackStateDStream[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassT /** Enable automatic checkpointing */ override val mustCheckpoint = true + /** Override the default checkpoint duration */ + override def initialize(time: Time): Unit = { + if (checkpointDuration == null) { + checkpointDuration = slideDuration * DEFAULT_CHECKPOINT_DURATION_MULTIPLIER + } + super.initialize(time) + } + /** Method that generates a RDD for the given time */ override def compute(validTime: Time): Option[RDD[TrackStateRDDRecord[K, S, E]]] = { // Get the previous state or create a new empty state RDD @@ -141,3 +150,7 @@ class InternalTrackStateDStream[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassT } } } + +private[streaming] object InternalTrackStateDStream { + private val DEFAULT_CHECKPOINT_DURATION_MULTIPLIER = 10 +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala index e3072b444284..58aef74c0040 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala @@ -22,9 +22,10 @@ import java.io.File import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer} import scala.reflect.ClassTag +import org.scalatest.PrivateMethodTester._ import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} -import org.apache.spark.streaming.dstream.{TrackStateDStream, TrackStateDStreamImpl} +import org.apache.spark.streaming.dstream.{InternalTrackStateDStream, TrackStateDStream, TrackStateDStreamImpl} import org.apache.spark.util.{ManualClock, Utils} import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} @@ -57,6 +58,12 @@ class TrackStateByKeySuite extends SparkFunSuite with BeforeAndAfterAll with Bef sc = new SparkContext(conf) } + override def afterAll(): Unit = { + if (sc != null) { + sc.stop() + } + } + test("state - get, exists, update, remove, ") { var state: StateImpl[Int] = null @@ -436,6 +443,41 @@ class TrackStateByKeySuite extends SparkFunSuite with BeforeAndAfterAll with Bef assert(collectedStateSnapshots.last.toSet === Set(("a", 1))) } + test("trackStateByKey - checkpoint durations") { + val privateMethod = PrivateMethod[InternalTrackStateDStream[_, _, _, _]]('internalStream) + + def testCheckpointDuration( + batchDuration: Duration, + expectedCheckpointDuration: Duration, + explicitCheckpointDuration: Option[Duration] = None + ): Unit = { + try { + ssc = new StreamingContext(sc, batchDuration) + val inputStream = new TestInputStream(ssc, 
Seq.empty[Seq[Int]], 2).map(_ -> 1) + val dummyFunc = (value: Option[Int], state: State[Int]) => 0 + val trackStateStream = inputStream.trackStateByKey(StateSpec.function(dummyFunc)) + val internalTrackStateStream = trackStateStream invokePrivate privateMethod() + + explicitCheckpointDuration.foreach { d => + trackStateStream.checkpoint(d) + } + trackStateStream.register() + ssc.start() // should initialize all the checkpoint durations + assert(trackStateStream.checkpointDuration === null) + assert(internalTrackStateStream.checkpointDuration === expectedCheckpointDuration) + } finally { + StreamingContext.getActive().foreach { _.stop(stopSparkContext = false) } + } + } + + testCheckpointDuration(Milliseconds(100), Seconds(1)) + testCheckpointDuration(Seconds(1), Seconds(10)) + testCheckpointDuration(Seconds(10), Seconds(100)) + + testCheckpointDuration(Milliseconds(100), Seconds(2), Some(Seconds(2))) + testCheckpointDuration(Seconds(1), Seconds(2), Some(Seconds(2))) + testCheckpointDuration(Seconds(10), Seconds(20), Some(Seconds(20))) + } private def testOperation[K: ClassTag, S: ClassTag, T: ClassTag]( input: Seq[Seq[K]], From 921900fd06362474f8caac675803d526a0986d70 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Wed, 18 Nov 2015 16:19:00 -0800 Subject: [PATCH 1438/1824] [SPARK-11791] Fix flaky test in BatchedWriteAheadLogSuite stack trace of failure: ``` org.scalatest.exceptions.TestFailedDueToTimeoutException: The code passed to eventually never returned normally. Attempted 62 times over 1.006322071 seconds. Last failure message: Argument(s) are different! Wanted: writeAheadLog.write( java.nio.HeapByteBuffer[pos=0 lim=124 cap=124], 10 ); -> at org.apache.spark.streaming.util.BatchedWriteAheadLogSuite$$anonfun$23$$anonfun$apply$mcV$sp$15.apply(WriteAheadLogSuite.scala:518) Actual invocation has different arguments: writeAheadLog.write( java.nio.HeapByteBuffer[pos=0 lim=124 cap=124], 10 ); -> at org.apache.spark.streaming.util.WriteAheadLogSuite$BlockingWriteAheadLog.write(WriteAheadLogSuite.scala:756) ``` I believe the issue was that due to a race condition, the ordering of the events could be messed up in the final ByteBuffer, therefore the comparison fails. By adding eventually between the requests, we make sure the ordering is preserved. Note that in real life situations, the ordering across threads will not matter. Another solution would be to implement a custom mockito matcher that sorts and then compares the results, but that kind of sounds like overkill to me. Let me know what you think tdas zsxwing Author: Burak Yavuz Closes #9790 from brkyvz/fix-flaky-2. 
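A standalone sketch of the order-insensitive verification the fix below adopts, using a hypothetical Recorder trait: capture the argument handed to the mock with an ArgumentCaptor and compare its contents as a set, so the interleaving of concurrent writers cannot fail the assertion.

import org.mockito.ArgumentCaptor
import org.mockito.Mockito._

trait Recorder { def write(items: Seq[String]): Unit }

val recorder = mock(classOf[Recorder])
recorder.write(Seq("b", "a", "c"))                   // arrival order is unpredictable

val captor = ArgumentCaptor.forClass(classOf[Seq[String]])
verify(recorder, times(1)).write(captor.capture())
assert(captor.getValue.toSet == Set("a", "b", "c"))  // compare contents, not ordering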
--- .../spark/streaming/util/WriteAheadLogSuite.scala | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala index 7f80d6ecdbbb..eaa88ea3cd38 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala @@ -30,6 +30,7 @@ import scala.language.{implicitConversions, postfixOps} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path +import org.mockito.ArgumentCaptor import org.mockito.Matchers.{eq => meq} import org.mockito.Matchers._ import org.mockito.Mockito._ @@ -507,15 +508,18 @@ class BatchedWriteAheadLogSuite extends CommonWriteAheadLogTests( } blockingWal.allowWrite() - val buffer1 = wrapArrayArrayByte(Array(event1)) - val buffer2 = wrapArrayArrayByte(Array(event2, event3, event4, event5)) + val buffer = wrapArrayArrayByte(Array(event1)) + val queuedEvents = Set(event2, event3, event4, event5) eventually(timeout(1 second)) { assert(batchedWal.invokePrivate(queueLength()) === 0) - verify(wal, times(1)).write(meq(buffer1), meq(3L)) + verify(wal, times(1)).write(meq(buffer), meq(3L)) // the file name should be the timestamp of the last record, as events should be naturally // in order of timestamp, and we need the last element. - verify(wal, times(1)).write(meq(buffer2), meq(10L)) + val bufferCaptor = ArgumentCaptor.forClass(classOf[ByteBuffer]) + verify(wal, times(1)).write(bufferCaptor.capture(), meq(10L)) + val records = BatchedWriteAheadLog.deaggregate(bufferCaptor.getValue).map(byteBufferToString) + assert(records.toSet === queuedEvents) } } From 59a501359a267fbdb7689058693aa788703e54b1 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Wed, 18 Nov 2015 16:48:09 -0800 Subject: [PATCH 1439/1824] [SPARK-11636][SQL] Support classes defined in the REPL with Encoders Before this PR there were two things that would blow up if you called `df.as[MyClass]` if `MyClass` was defined in the REPL: - [x] Because `classForName` doesn't work on the munged names returned by `tpe.erasure.typeSymbol.asClass.fullName` - [x] Because we don't have anything to pass into the constructor for the `$outer` pointer. Note that this PR is just adding the infrastructure for working with inner classes in encoder and is not yet sufficient to make them work in the REPL. Currently, the implementation show in https://github.com/marmbrus/spark/commit/95cec7d413b930b36420724fafd829bef8c732ab is causing a bug that breaks code gen due to some interaction between janino and the `ExecutorClassLoader`. This will be addressed in a follow-up PR. Author: Michael Armbrust Closes #9602 from marmbrus/dataset-replClasses. 
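A small sketch of the `$outer` problem the change below addresses, with hypothetical class names: a class declared inside another class (which is what REPL-defined classes compile to) can only be constructed reflectively if the enclosing instance is supplied first, and that enclosing instance is what the OuterScopes registry makes available to the deserializer.

class Wrapper {
  case class Inner(value: Int)
}

val outer = new Wrapper
// The compiler-generated constructor is Inner(Wrapper $outer, int value),
// so the outer instance must be passed as the first argument.
val ctor = classOf[Wrapper#Inner].getConstructors.head
val inner = ctor.newInstance(outer, Int.box(1)).asInstanceOf[Wrapper#Inner]
assert(inner.value == 1)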
--- .../spark/sql/catalyst/ScalaReflection.scala | 81 ++++++++++--------- .../catalyst/encoders/ExpressionEncoder.scala | 26 +++++- .../sql/catalyst/encoders/OuterScopes.scala | 42 ++++++++++ .../catalyst/encoders/ProductEncoder.scala | 6 +- .../expressions/codegen/CodegenFallback.scala | 2 +- .../codegen/GenerateMutableProjection.scala | 4 +- .../codegen/GenerateProjection.scala | 10 +-- .../codegen/GenerateSafeProjection.scala | 4 +- .../codegen/GenerateUnsafeProjection.scala | 4 +- .../codegen/GenerateUnsafeRowJoiner.scala | 6 +- .../sql/catalyst/expressions/literals.scala | 6 ++ .../sql/catalyst/expressions/objects.scala | 42 ++++++++-- .../encoders/ExpressionEncoderSuite.scala | 7 +- .../encoders/ProductEncoderSuite.scala | 4 + .../scala/org/apache/spark/sql/Dataset.scala | 4 +- .../org/apache/spark/sql/GroupedDataset.scala | 8 +- .../aggregate/TypedAggregateExpression.scala | 19 ++--- 17 files changed, 193 insertions(+), 82 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/OuterScopes.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 38828e59a215..59ccf356f2c4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -35,17 +35,6 @@ object ScalaReflection extends ScalaReflection { // class loader of the current thread. override def mirror: universe.Mirror = universe.runtimeMirror(Thread.currentThread().getContextClassLoader) -} - -/** - * Support for generating catalyst schemas for scala objects. - */ -trait ScalaReflection { - /** The universe we work in (runtime or macro) */ - val universe: scala.reflect.api.Universe - - /** The mirror used to access types in the universe */ - def mirror: universe.Mirror import universe._ @@ -53,30 +42,6 @@ trait ScalaReflection { // Since the map values can be mutable, we explicitly import scala.collection.Map at here. import scala.collection.Map - case class Schema(dataType: DataType, nullable: Boolean) - - /** Returns a Sequence of attributes for the given case class type. */ - def attributesFor[T: TypeTag]: Seq[Attribute] = schemaFor[T] match { - case Schema(s: StructType, _) => - s.toAttributes - } - - /** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */ - def schemaFor[T: TypeTag]: Schema = - ScalaReflectionLock.synchronized { schemaFor(localTypeOf[T]) } - - /** - * Return the Scala Type for `T` in the current classloader mirror. - * - * Use this method instead of the convenience method `universe.typeOf`, which - * assumes that all types can be found in the classloader that loaded scala-reflect classes. - * That's not necessarily the case when running using Eclipse launchers or even - * Sbt console or test (without `fork := true`). - * - * @see SPARK-5281 - */ - def localTypeOf[T: TypeTag]: `Type` = typeTag[T].in(mirror).tpe - /** * Returns the Spark SQL DataType for a given scala type. Where this is not an exact mapping * to a native type, an ObjectType is returned. 
Special handling is also used for Arrays including @@ -114,7 +79,9 @@ trait ScalaReflection { } ObjectType(cls) - case other => ObjectType(Utils.classForName(className)) + case other => + val clazz = mirror.runtimeClass(tpe.erasure.typeSymbol.asClass) + ObjectType(clazz) } } @@ -640,6 +607,48 @@ trait ScalaReflection { } } } +} + +/** + * Support for generating catalyst schemas for scala objects. Note that unlike its companion + * object, this trait able to work in both the runtime and the compile time (macro) universe. + */ +trait ScalaReflection { + /** The universe we work in (runtime or macro) */ + val universe: scala.reflect.api.Universe + + /** The mirror used to access types in the universe */ + def mirror: universe.Mirror + + import universe._ + + // The Predef.Map is scala.collection.immutable.Map. + // Since the map values can be mutable, we explicitly import scala.collection.Map at here. + import scala.collection.Map + + case class Schema(dataType: DataType, nullable: Boolean) + + /** Returns a Sequence of attributes for the given case class type. */ + def attributesFor[T: TypeTag]: Seq[Attribute] = schemaFor[T] match { + case Schema(s: StructType, _) => + s.toAttributes + } + + /** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */ + def schemaFor[T: TypeTag]: Schema = + ScalaReflectionLock.synchronized { schemaFor(localTypeOf[T]) } + + /** + * Return the Scala Type for `T` in the current classloader mirror. + * + * Use this method instead of the convenience method `universe.typeOf`, which + * assumes that all types can be found in the classloader that loaded scala-reflect classes. + * That's not necessarily the case when running using Eclipse launchers or even + * Sbt console or test (without `fork := true`). + * + * @see SPARK-5281 + */ + def localTypeOf[T: TypeTag]: `Type` = typeTag[T].in(mirror).tpe /** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */ def schemaFor(tpe: `Type`): Schema = ScalaReflectionLock.synchronized { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index b977f278c5b5..456b59500847 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -17,11 +17,13 @@ package org.apache.spark.sql.catalyst.encoders +import java.util.concurrent.ConcurrentMap + import scala.reflect.ClassTag import scala.reflect.runtime.universe.{typeTag, TypeTag} import org.apache.spark.util.Utils -import org.apache.spark.sql.Encoder +import org.apache.spark.sql.{AnalysisException, Encoder} import org.apache.spark.sql.catalyst.analysis.{SimpleAnalyzer, UnresolvedExtractValue, UnresolvedAttribute} import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project} import org.apache.spark.sql.catalyst.expressions._ @@ -211,7 +213,9 @@ case class ExpressionEncoder[T]( * Returns a new copy of this encoder, where the expressions used by `fromRow` are resolved to the * given schema. 
*/ - def resolve(schema: Seq[Attribute]): ExpressionEncoder[T] = { + def resolve( + schema: Seq[Attribute], + outerScopes: ConcurrentMap[String, AnyRef]): ExpressionEncoder[T] = { val positionToAttribute = AttributeMap.toIndex(schema) val unbound = fromRowExpression transform { case b: BoundReference => positionToAttribute(b.ordinal) @@ -219,7 +223,23 @@ case class ExpressionEncoder[T]( val plan = Project(Alias(unbound, "")() :: Nil, LocalRelation(schema)) val analyzedPlan = SimpleAnalyzer.execute(plan) - copy(fromRowExpression = analyzedPlan.expressions.head.children.head) + + // In order to construct instances of inner classes (for example those declared in a REPL cell), + // we need an instance of the outer scope. This rule substitues those outer objects into + // expressions that are missing them by looking up the name in the SQLContexts `outerScopes` + // registry. + copy(fromRowExpression = analyzedPlan.expressions.head.children.head transform { + case n: NewInstance if n.outerPointer.isEmpty && n.cls.isMemberClass => + val outer = outerScopes.get(n.cls.getDeclaringClass.getName) + if (outer == null) { + throw new AnalysisException( + s"Unable to generate an encoder for inner class `${n.cls.getName}` without access " + + s"to the scope that this class was defined in. " + "" + + "Try moving this class out of its parent class.") + } + + n.copy(outerPointer = Some(Literal.fromObject(outer))) + }) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/OuterScopes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/OuterScopes.scala new file mode 100644 index 000000000000..a753b187bcd3 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/OuterScopes.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.encoders + +import java.util.concurrent.ConcurrentMap + +import com.google.common.collect.MapMaker + +object OuterScopes { + @transient + lazy val outerScopes: ConcurrentMap[String, AnyRef] = + new MapMaker().weakValues().makeMap() + + /** + * Adds a new outer scope to this context that can be used when instantiating an `inner class` + * during deserialialization. Inner classes are created when a case class is defined in the + * Spark REPL and registering the outer scope that this class was defined in allows us to create + * new instances on the spark executors. In normal use, users should not need to call this + * function. + * + * Warning: this function operates on the assumption that there is only ever one instance of any + * given wrapper class. 
+ */ + def addOuterScope(outer: AnyRef): Unit = { + outerScopes.putIfAbsent(outer.getClass.getName, outer) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala index 55c4ee11b20f..2914c6ee790c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala @@ -31,6 +31,7 @@ import scala.reflect.ClassTag object ProductEncoder { import ScalaReflection.universe._ + import ScalaReflection.mirror import ScalaReflection.localTypeOf import ScalaReflection.dataTypeFor import ScalaReflection.Schema @@ -420,8 +421,7 @@ object ProductEncoder { } } - val className: String = t.erasure.typeSymbol.asClass.fullName - val cls = Utils.classForName(className) + val cls = mirror.runtimeClass(tpe.erasure.typeSymbol.asClass) val arguments = params.head.zipWithIndex.map { case (p, i) => val fieldName = p.name.toString @@ -429,7 +429,7 @@ object ProductEncoder { val dataType = schemaFor(fieldType).dataType // For tuples, we based grab the inner fields by ordinal instead of name. - if (className startsWith "scala.Tuple") { + if (cls.getName startsWith "scala.Tuple") { constructorFor(fieldType, Some(addToPathOrdinal(i, dataType))) } else { constructorFor(fieldType, Some(addToPath(fieldName))) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala index d51a8dede7f3..a31574c251af 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala @@ -34,7 +34,7 @@ trait CodegenFallback extends Expression { val objectTerm = ctx.freshName("obj") s""" /* expression: ${this} */ - Object $objectTerm = expressions[${ctx.references.size - 1}].eval(${ctx.INPUT_ROW}); + java.lang.Object $objectTerm = expressions[${ctx.references.size - 1}].eval(${ctx.INPUT_ROW}); boolean ${ev.isNull} = $objectTerm == null; ${ctx.javaType(this.dataType)} ${ev.value} = ${ctx.defaultValue(this.dataType)}; if (!${ev.isNull}) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index 4b66069b5f55..40189f087776 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -82,7 +82,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu val allUpdates = ctx.splitExpressions(ctx.INPUT_ROW, updates) val code = s""" - public Object generate($exprType[] expr) { + public java.lang.Object generate($exprType[] expr) { return new SpecificMutableProjection(expr); } @@ -109,7 +109,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu return (InternalRow) mutableRow; } - public Object apply(Object _i) { + public java.lang.Object apply(java.lang.Object _i) { InternalRow ${ctx.INPUT_ROW} = (InternalRow) _i; $allProjections // copy all the results into MutableRow diff 
--git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala index c0d313b2e130..f229f2000d8e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala @@ -167,7 +167,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { ${initMutableStates(ctx)} } - public Object apply(Object r) { + public java.lang.Object apply(java.lang.Object r) { // GenerateProjection does not work with UnsafeRows. assert(!(r instanceof ${classOf[UnsafeRow].getName})); return new SpecificRow((InternalRow) r); @@ -186,14 +186,14 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { public void setNullAt(int i) { nullBits[i] = true; } public boolean isNullAt(int i) { return nullBits[i]; } - public Object genericGet(int i) { + public java.lang.Object genericGet(int i) { if (isNullAt(i)) return null; switch (i) { $getCases } return null; } - public void update(int i, Object value) { + public void update(int i, java.lang.Object value) { if (value == null) { setNullAt(i); return; @@ -212,7 +212,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { return result; } - public boolean equals(Object other) { + public boolean equals(java.lang.Object other) { if (other instanceof SpecificRow) { SpecificRow row = (SpecificRow) other; $columnChecks @@ -222,7 +222,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { } public InternalRow copy() { - Object[] arr = new Object[${expressions.length}]; + java.lang.Object[] arr = new java.lang.Object[${expressions.length}]; ${copyColumns} return new ${classOf[GenericInternalRow].getName}(arr); } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala index f0ed8645d923..b7926bda3de1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala @@ -148,7 +148,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] } val allExpressions = ctx.splitExpressions(ctx.INPUT_ROW, expressionCodes) val code = s""" - public Object generate($exprType[] expr) { + public java.lang.Object generate($exprType[] expr) { return new SpecificSafeProjection(expr); } @@ -165,7 +165,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] ${initMutableStates(ctx)} } - public Object apply(Object _i) { + public java.lang.Object apply(java.lang.Object _i) { InternalRow ${ctx.INPUT_ROW} = (InternalRow) _i; $allExpressions return mutableRow; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 4c17d02a2372..7b6c9373ebe3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -324,7 +324,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val eval = createCode(ctx, expressions, subexpressionEliminationEnabled) val code = s""" - public Object generate($exprType[] exprs) { + public java.lang.Object generate($exprType[] exprs) { return new SpecificUnsafeProjection(exprs); } @@ -342,7 +342,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro } // Scala.Function1 need this - public Object apply(Object row) { + public java.lang.Object apply(java.lang.Object row) { return apply((InternalRow) row); } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala index da91ff29537b..da602d9b4bce 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala @@ -159,7 +159,7 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U // ------------------------ Finally, put everything together --------------------------- // val code = s""" - |public Object generate($exprType[] exprs) { + |public java.lang.Object generate($exprType[] exprs) { | return new SpecificUnsafeRowJoiner(); |} | @@ -176,9 +176,9 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U | buf = new byte[sizeInBytes]; | } | - | final Object obj1 = row1.getBaseObject(); + | final java.lang.Object obj1 = row1.getBaseObject(); | final long offset1 = row1.getBaseOffset(); - | final Object obj2 = row2.getBaseObject(); + | final java.lang.Object obj2 = row2.getBaseObject(); | final long offset2 = row2.getBaseOffset(); | | $copyBitset diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index 455fa2427c26..e34fd49be838 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -48,6 +48,12 @@ object Literal { throw new RuntimeException("Unsupported literal type " + v.getClass + " " + v) } + /** + * Constructs a [[Literal]] of [[ObjectType]], for example when you need to pass an object + * into code generation. 
+ */ + def fromObject(obj: AnyRef): Literal = new Literal(obj, ObjectType(obj.getClass)) + def create(v: Any, dataType: DataType): Literal = { Literal(CatalystTypeConverters.convertToCatalyst(v), dataType) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala index acf0da240051..f865a9408ef4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala @@ -24,6 +24,7 @@ import org.apache.spark.SparkConf import org.apache.spark.serializer._ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer +import org.apache.spark.sql.catalyst.encoders.ProductEncoder import org.apache.spark.sql.catalyst.plans.logical.{Project, LocalRelation} import org.apache.spark.sql.catalyst.util.GenericArrayData import org.apache.spark.sql.catalyst.InternalRow @@ -178,6 +179,15 @@ case class Invoke( } } +object NewInstance { + def apply( + cls: Class[_], + arguments: Seq[Expression], + propagateNull: Boolean = false, + dataType: DataType): NewInstance = + new NewInstance(cls, arguments, propagateNull, dataType, None) +} + /** * Constructs a new instance of the given class, using the result of evaluating the specified * expressions as arguments. @@ -189,12 +199,15 @@ case class Invoke( * @param dataType The type of object being constructed, as a Spark SQL datatype. This allows you * to manually specify the type when the object in question is a valid internal * representation (i.e. ArrayData) instead of an object. + * @param outerPointer If the object being constructed is an inner class the outerPointer must + * for the containing class must be specified. 
*/ case class NewInstance( cls: Class[_], arguments: Seq[Expression], - propagateNull: Boolean = true, - dataType: DataType) extends Expression { + propagateNull: Boolean, + dataType: DataType, + outerPointer: Option[Literal]) extends Expression { private val className = cls.getName override def nullable: Boolean = propagateNull @@ -209,30 +222,43 @@ case class NewInstance( val argGen = arguments.map(_.gen(ctx)) val argString = argGen.map(_.value).mkString(", ") + val outer = outerPointer.map(_.gen(ctx)) + + val setup = + s""" + ${argGen.map(_.code).mkString("\n")} + ${outer.map(_.code.mkString("")).getOrElse("")} + """.stripMargin + + val constructorCall = outer.map { gen => + s"""${gen.value}.new ${cls.getSimpleName}($argString)""" + }.getOrElse { + s"new $className($argString)" + } + if (propagateNull) { val objNullCheck = if (ctx.defaultValue(dataType) == "null") { s"${ev.isNull} = ${ev.value} == null;" } else { "" } - val argsNonNull = s"!(${argGen.map(_.isNull).mkString(" || ")})" + s""" - ${argGen.map(_.code).mkString("\n")} + $setup boolean ${ev.isNull} = true; $javaType ${ev.value} = ${ctx.defaultValue(dataType)}; - if ($argsNonNull) { - ${ev.value} = new $className($argString); + ${ev.value} = $constructorCall; ${ev.isNull} = false; } """ } else { s""" - ${argGen.map(_.code).mkString("\n")} + $setup - $javaType ${ev.value} = new $className($argString); + $javaType ${ev.value} = $constructorCall; final boolean ${ev.isNull} = ${ev.value} == null; """ } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala index 9fe64b4cf10e..cde0364f3dd9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala @@ -18,6 +18,9 @@ package org.apache.spark.sql.catalyst.encoders import java.util.Arrays +import java.util.concurrent.ConcurrentMap + +import com.google.common.collect.MapMaker import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.AttributeReference @@ -25,6 +28,8 @@ import org.apache.spark.sql.catalyst.util.ArrayData import org.apache.spark.sql.types.ArrayType abstract class ExpressionEncoderSuite extends SparkFunSuite { + val outers: ConcurrentMap[String, AnyRef] = new MapMaker().weakValues().makeMap() + protected def encodeDecodeTest[T]( input: T, encoder: ExpressionEncoder[T], @@ -32,7 +37,7 @@ abstract class ExpressionEncoderSuite extends SparkFunSuite { test(s"encode/decode for $testName: $input") { val row = encoder.toRow(input) val schema = encoder.schema.toAttributes - val boundEncoder = encoder.resolve(schema).bind(schema) + val boundEncoder = encoder.resolve(schema, outers).bind(schema) val convertedBack = try boundEncoder.fromRow(row) catch { case e: Exception => fail( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala index bc539d62c537..1798514c5c38 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala @@ -53,6 +53,10 @@ case class RepeatedData( case class SpecificCollection(l: List[Int]) class ProductEncoderSuite extends ExpressionEncoderSuite { + 
outers.put(getClass.getName, this) + + case class InnerClass(i: Int) + productTest(InnerClass(1)) productTest(PrimitiveData(1, 1, 1, 1, 1, 1, true)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index b644f6ad3096..bdcdc5d47cba 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -74,7 +74,7 @@ class Dataset[T] private[sql]( /** The encoder for this [[Dataset]] that has been resolved to its output schema. */ private[sql] val resolvedTEncoder: ExpressionEncoder[T] = - unresolvedTEncoder.resolve(queryExecution.analyzed.output) + unresolvedTEncoder.resolve(queryExecution.analyzed.output, OuterScopes.outerScopes) private implicit def classTag = resolvedTEncoder.clsTag @@ -375,7 +375,7 @@ class Dataset[T] private[sql]( sqlContext, Project( c1.withInputType( - resolvedTEncoder, + resolvedTEncoder.bind(queryExecution.analyzed.output), queryExecution.analyzed.output).named :: Nil, logicalPlan)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala index 3f84e22a1025..7e5acbe8517d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala @@ -21,7 +21,7 @@ import scala.collection.JavaConverters._ import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.function._ -import org.apache.spark.sql.catalyst.encoders.{FlatEncoder, ExpressionEncoder, encoderFor} +import org.apache.spark.sql.catalyst.encoders.{FlatEncoder, ExpressionEncoder, encoderFor, OuterScopes} import org.apache.spark.sql.catalyst.expressions.{Alias, CreateStruct, Attribute} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.QueryExecution @@ -52,8 +52,10 @@ class GroupedDataset[K, T] private[sql]( private implicit val unresolvedKEncoder = encoderFor(kEncoder) private implicit val unresolvedTEncoder = encoderFor(tEncoder) - private val resolvedKEncoder = unresolvedKEncoder.resolve(groupingAttributes) - private val resolvedTEncoder = unresolvedTEncoder.resolve(dataAttributes) + private val resolvedKEncoder = + unresolvedKEncoder.resolve(groupingAttributes, OuterScopes.outerScopes) + private val resolvedTEncoder = + unresolvedTEncoder.resolve(dataAttributes, OuterScopes.outerScopes) private def logicalPlan = queryExecution.analyzed private def sqlContext = queryExecution.sqlContext diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala index 3f2775896bb8..6ce41aaf01e2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala @@ -52,8 +52,8 @@ object TypedAggregateExpression { */ case class TypedAggregateExpression( aggregator: Aggregator[Any, Any, Any], - aEncoder: Option[ExpressionEncoder[Any]], - bEncoder: ExpressionEncoder[Any], + aEncoder: Option[ExpressionEncoder[Any]], // Should be bound. + bEncoder: ExpressionEncoder[Any], // Should be bound. 
cEncoder: ExpressionEncoder[Any], children: Seq[Attribute], mutableAggBufferOffset: Int, @@ -92,9 +92,6 @@ case class TypedAggregateExpression( // We let the dataset do the binding for us. lazy val boundA = aEncoder.get - val bAttributes = bEncoder.schema.toAttributes - lazy val boundB = bEncoder.resolve(bAttributes).bind(bAttributes) - private def updateBuffer(buffer: MutableRow, value: InternalRow): Unit = { // todo: need a more neat way to assign the value. var i = 0 @@ -114,24 +111,24 @@ case class TypedAggregateExpression( override def update(buffer: MutableRow, input: InternalRow): Unit = { val inputA = boundA.fromRow(input) - val currentB = boundB.shift(mutableAggBufferOffset).fromRow(buffer) + val currentB = bEncoder.shift(mutableAggBufferOffset).fromRow(buffer) val merged = aggregator.reduce(currentB, inputA) - val returned = boundB.toRow(merged) + val returned = bEncoder.toRow(merged) updateBuffer(buffer, returned) } override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = { - val b1 = boundB.shift(mutableAggBufferOffset).fromRow(buffer1) - val b2 = boundB.shift(inputAggBufferOffset).fromRow(buffer2) + val b1 = bEncoder.shift(mutableAggBufferOffset).fromRow(buffer1) + val b2 = bEncoder.shift(inputAggBufferOffset).fromRow(buffer2) val merged = aggregator.merge(b1, b2) - val returned = boundB.toRow(merged) + val returned = bEncoder.toRow(merged) updateBuffer(buffer1, returned) } override def eval(buffer: InternalRow): Any = { - val b = boundB.shift(mutableAggBufferOffset).fromRow(buffer) + val b = bEncoder.shift(mutableAggBufferOffset).fromRow(buffer) val result = cEncoder.toRow(aggregator.finish(b)) dataType match { case _: StructType => result From e99d3392068bc929c900a4cc7b50e9e2b437a23a Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Wed, 18 Nov 2015 18:34:01 -0800 Subject: [PATCH 1440/1824] [SPARK-11839][ML] refactor save/write traits * add "ML" prefix to reader/writer/readable/writable to avoid name collision with java.util.* * define `DefaultParamsReadable/Writable` and use them to save some code * use `super.load` instead so people can jump directly to the doc of `Readable.load`, which documents the Java compatibility issues jkbradley Author: Xiangrui Meng Closes #9827 from mengxr/SPARK-11839. 
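To make the shape of this refactor easier to follow before the per-class diffs, here is a toy, self-contained Scala sketch of the trait layering the commit message describes. The names mirror the commit for readability, but this is an illustrative reduction under simplifying assumptions (a plain factory function standing in for the metadata/reflection machinery), not the real Spark API.

// Toy reduction of the MLWritable/MLReadable layering (illustrative, not the Spark API).
object ReadWriteSketch {

  abstract class MLWriter { def save(path: String): Unit }
  abstract class MLReader[T] { def load(path: String): T }

  trait MLWritable {
    def write: MLWriter
    def save(path: String): Unit = write.save(path)
  }

  trait MLReadable[T] {
    def read: MLReader[T]
    // Concrete companions override this only to delegate to super.load, so users land on a
    // single documented entry point instead of a re-implementation per class.
    def load(path: String): T = read.load(path)
  }

  // Mixing this in replaces a hand-written `override def write` on simple, params-only stages.
  trait DefaultParamsWritable extends MLWritable {
    override def write: MLWriter = new MLWriter {
      def save(path: String): Unit =
        println(s"saving ${DefaultParamsWritable.this.getClass.getSimpleName} params to $path")
    }
  }

  // ...and this replaces a hand-written reader; the factory stands in for metadata + reflection.
  abstract class DefaultParamsReadable[T](make: () => T) extends MLReadable[T] {
    override def read: MLReader[T] = new MLReader[T] {
      def load(path: String): T = { println(s"reading params from $path"); make() }
    }
  }

  class Binarizer extends DefaultParamsWritable
  object Binarizer extends DefaultParamsReadable[Binarizer](() => new Binarizer) {
    override def load(path: String): Binarizer = super.load(path)
  }

  def main(args: Array[String]): Unit = {
    new Binarizer().write.save("/tmp/binarizer")
    val reloaded: Binarizer = Binarizer.load("/tmp/binarizer")
  }
}

With this layering, a stage only declares the mixins and its companion's load becomes a one-line delegation, which is what most of the per-file deletions in the diffs below amount to.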
--- .../scala/org/apache/spark/ml/Pipeline.scala | 40 +++++++++---------- .../classification/LogisticRegression.scala | 29 +++++++------- .../apache/spark/ml/feature/Binarizer.scala | 12 ++---- .../apache/spark/ml/feature/Bucketizer.scala | 12 ++---- .../spark/ml/feature/CountVectorizer.scala | 22 ++++------ .../org/apache/spark/ml/feature/DCT.scala | 12 ++---- .../apache/spark/ml/feature/HashingTF.scala | 12 ++---- .../org/apache/spark/ml/feature/IDF.scala | 23 +++++------ .../apache/spark/ml/feature/Interaction.scala | 12 ++---- .../spark/ml/feature/MinMaxScaler.scala | 22 ++++------ .../org/apache/spark/ml/feature/NGram.scala | 12 ++---- .../apache/spark/ml/feature/Normalizer.scala | 12 ++---- .../spark/ml/feature/OneHotEncoder.scala | 12 ++---- .../ml/feature/PolynomialExpansion.scala | 12 ++---- .../ml/feature/QuantileDiscretizer.scala | 12 ++---- .../spark/ml/feature/SQLTransformer.scala | 13 ++---- .../spark/ml/feature/StandardScaler.scala | 22 ++++------ .../spark/ml/feature/StopWordsRemover.scala | 12 ++---- .../spark/ml/feature/StringIndexer.scala | 32 +++++---------- .../apache/spark/ml/feature/Tokenizer.scala | 24 +++-------- .../spark/ml/feature/VectorAssembler.scala | 12 ++---- .../spark/ml/feature/VectorSlicer.scala | 12 ++---- .../apache/spark/ml/recommendation/ALS.scala | 27 +++++-------- .../ml/regression/LinearRegression.scala | 30 ++++++-------- .../org/apache/spark/ml/util/ReadWrite.scala | 40 ++++++++++++------- .../org/apache/spark/ml/PipelineSuite.scala | 14 +++---- .../spark/ml/util/DefaultReadWriteTest.scala | 17 ++++---- 27 files changed, 190 insertions(+), 321 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala index 25f0c696f42b..b0f22e042ec5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala @@ -29,8 +29,8 @@ import org.json4s.jackson.JsonMethods._ import org.apache.spark.{SparkContext, Logging} import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.ml.param.{Param, ParamMap, Params} -import org.apache.spark.ml.util.Reader -import org.apache.spark.ml.util.Writer +import org.apache.spark.ml.util.MLReader +import org.apache.spark.ml.util.MLWriter import org.apache.spark.ml.util._ import org.apache.spark.sql.DataFrame import org.apache.spark.sql.types.StructType @@ -89,7 +89,7 @@ abstract class PipelineStage extends Params with Logging { * an identity transformer. 
*/ @Experimental -class Pipeline(override val uid: String) extends Estimator[PipelineModel] with Writable { +class Pipeline(override val uid: String) extends Estimator[PipelineModel] with MLWritable { def this() = this(Identifiable.randomUID("pipeline")) @@ -174,16 +174,16 @@ class Pipeline(override val uid: String) extends Estimator[PipelineModel] with W theStages.foldLeft(schema)((cur, stage) => stage.transformSchema(cur)) } - override def write: Writer = new Pipeline.PipelineWriter(this) + override def write: MLWriter = new Pipeline.PipelineWriter(this) } -object Pipeline extends Readable[Pipeline] { +object Pipeline extends MLReadable[Pipeline] { - override def read: Reader[Pipeline] = new PipelineReader + override def read: MLReader[Pipeline] = new PipelineReader - override def load(path: String): Pipeline = read.load(path) + override def load(path: String): Pipeline = super.load(path) - private[ml] class PipelineWriter(instance: Pipeline) extends Writer { + private[ml] class PipelineWriter(instance: Pipeline) extends MLWriter { SharedReadWrite.validateStages(instance.getStages) @@ -191,7 +191,7 @@ object Pipeline extends Readable[Pipeline] { SharedReadWrite.saveImpl(instance, instance.getStages, sc, path) } - private[ml] class PipelineReader extends Reader[Pipeline] { + private[ml] class PipelineReader extends MLReader[Pipeline] { /** Checked against metadata when loading model */ private val className = "org.apache.spark.ml.Pipeline" @@ -202,7 +202,7 @@ object Pipeline extends Readable[Pipeline] { } } - /** Methods for [[Reader]] and [[Writer]] shared between [[Pipeline]] and [[PipelineModel]] */ + /** Methods for [[MLReader]] and [[MLWriter]] shared between [[Pipeline]] and [[PipelineModel]] */ private[ml] object SharedReadWrite { import org.json4s.JsonDSL._ @@ -210,7 +210,7 @@ object Pipeline extends Readable[Pipeline] { /** Check that all stages are Writable */ def validateStages(stages: Array[PipelineStage]): Unit = { stages.foreach { - case stage: Writable => // good + case stage: MLWritable => // good case other => throw new UnsupportedOperationException("Pipeline write will fail on this Pipeline" + s" because it contains a stage which does not implement Writable. Non-Writable stage:" + @@ -245,7 +245,7 @@ object Pipeline extends Readable[Pipeline] { // Save stages val stagesDir = new Path(path, "stages").toString - stages.zipWithIndex.foreach { case (stage: Writable, idx: Int) => + stages.zipWithIndex.foreach { case (stage: MLWritable, idx: Int) => stage.write.save(getStagePath(stage.uid, idx, stages.length, stagesDir)) } } @@ -285,7 +285,7 @@ object Pipeline extends Readable[Pipeline] { val stagePath = SharedReadWrite.getStagePath(stageUid, idx, stageUids.length, stagesDir) val stageMetadata = DefaultParamsReader.loadMetadata(stagePath, sc) val cls = Utils.classForName(stageMetadata.className) - cls.getMethod("read").invoke(null).asInstanceOf[Reader[PipelineStage]].load(stagePath) + cls.getMethod("read").invoke(null).asInstanceOf[MLReader[PipelineStage]].load(stagePath) } (metadata.uid, stages) } @@ -308,7 +308,7 @@ object Pipeline extends Readable[Pipeline] { class PipelineModel private[ml] ( override val uid: String, val stages: Array[Transformer]) - extends Model[PipelineModel] with Writable with Logging { + extends Model[PipelineModel] with MLWritable with Logging { /** A Java/Python-friendly auxiliary constructor. 
*/ private[ml] def this(uid: String, stages: ju.List[Transformer]) = { @@ -333,18 +333,18 @@ class PipelineModel private[ml] ( new PipelineModel(uid, stages.map(_.copy(extra))).setParent(parent) } - override def write: Writer = new PipelineModel.PipelineModelWriter(this) + override def write: MLWriter = new PipelineModel.PipelineModelWriter(this) } -object PipelineModel extends Readable[PipelineModel] { +object PipelineModel extends MLReadable[PipelineModel] { import Pipeline.SharedReadWrite - override def read: Reader[PipelineModel] = new PipelineModelReader + override def read: MLReader[PipelineModel] = new PipelineModelReader - override def load(path: String): PipelineModel = read.load(path) + override def load(path: String): PipelineModel = super.load(path) - private[ml] class PipelineModelWriter(instance: PipelineModel) extends Writer { + private[ml] class PipelineModelWriter(instance: PipelineModel) extends MLWriter { SharedReadWrite.validateStages(instance.stages.asInstanceOf[Array[PipelineStage]]) @@ -352,7 +352,7 @@ object PipelineModel extends Readable[PipelineModel] { instance.stages.asInstanceOf[Array[PipelineStage]], sc, path) } - private[ml] class PipelineModelReader extends Reader[PipelineModel] { + private[ml] class PipelineModelReader extends MLReader[PipelineModel] { /** Checked against metadata when loading model */ private val className = "org.apache.spark.ml.PipelineModel" diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 71c2533bcbf4..a3cc49f7f018 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -29,9 +29,9 @@ import org.apache.spark.ml.feature.Instance import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ +import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics import org.apache.spark.mllib.linalg._ import org.apache.spark.mllib.linalg.BLAS._ -import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD @@ -157,7 +157,7 @@ private[classification] trait LogisticRegressionParams extends ProbabilisticClas @Experimental class LogisticRegression(override val uid: String) extends ProbabilisticClassifier[Vector, LogisticRegression, LogisticRegressionModel] - with LogisticRegressionParams with Writable with Logging { + with LogisticRegressionParams with DefaultParamsWritable with Logging { def this() = this(Identifiable.randomUID("logreg")) @@ -385,12 +385,11 @@ class LogisticRegression(override val uid: String) } override def copy(extra: ParamMap): LogisticRegression = defaultCopy(extra) - - override def write: Writer = new DefaultParamsWriter(this) } -object LogisticRegression extends Readable[LogisticRegression] { - override def read: Reader[LogisticRegression] = new DefaultParamsReader[LogisticRegression] +object LogisticRegression extends DefaultParamsReadable[LogisticRegression] { + + override def load(path: String): LogisticRegression = super.load(path) } /** @@ -403,7 +402,7 @@ class LogisticRegressionModel private[ml] ( val coefficients: Vector, val intercept: Double) extends ProbabilisticClassificationModel[Vector, LogisticRegressionModel] - with LogisticRegressionParams with Writable { + 
with LogisticRegressionParams with MLWritable { @deprecated("Use coefficients instead.", "1.6.0") def weights: Vector = coefficients @@ -519,26 +518,26 @@ class LogisticRegressionModel private[ml] ( } /** - * Returns a [[Writer]] instance for this ML instance. + * Returns a [[MLWriter]] instance for this ML instance. * * For [[LogisticRegressionModel]], this does NOT currently save the training [[summary]]. * An option to save [[summary]] may be added in the future. * * This also does not save the [[parent]] currently. */ - override def write: Writer = new LogisticRegressionModel.LogisticRegressionModelWriter(this) + override def write: MLWriter = new LogisticRegressionModel.LogisticRegressionModelWriter(this) } -object LogisticRegressionModel extends Readable[LogisticRegressionModel] { +object LogisticRegressionModel extends MLReadable[LogisticRegressionModel] { - override def read: Reader[LogisticRegressionModel] = new LogisticRegressionModelReader + override def read: MLReader[LogisticRegressionModel] = new LogisticRegressionModelReader - override def load(path: String): LogisticRegressionModel = read.load(path) + override def load(path: String): LogisticRegressionModel = super.load(path) - /** [[Writer]] instance for [[LogisticRegressionModel]] */ + /** [[MLWriter]] instance for [[LogisticRegressionModel]] */ private[classification] class LogisticRegressionModelWriter(instance: LogisticRegressionModel) - extends Writer with Logging { + extends MLWriter with Logging { private case class Data( numClasses: Int, @@ -558,7 +557,7 @@ object LogisticRegressionModel extends Readable[LogisticRegressionModel] { } private[classification] class LogisticRegressionModelReader - extends Reader[LogisticRegressionModel] { + extends MLReader[LogisticRegressionModel] { /** Checked against metadata when loading model */ private val className = "org.apache.spark.ml.classification.LogisticRegressionModel" diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala index e2be6547d8f0..63c06581482e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala @@ -33,7 +33,7 @@ import org.apache.spark.sql.types.{DoubleType, StructType} */ @Experimental final class Binarizer(override val uid: String) - extends Transformer with Writable with HasInputCol with HasOutputCol { + extends Transformer with HasInputCol with HasOutputCol with DefaultParamsWritable { def this() = this(Identifiable.randomUID("binarizer")) @@ -86,17 +86,11 @@ final class Binarizer(override val uid: String) } override def copy(extra: ParamMap): Binarizer = defaultCopy(extra) - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @Since("1.6.0") -object Binarizer extends Readable[Binarizer] { - - @Since("1.6.0") - override def read: Reader[Binarizer] = new DefaultParamsReader[Binarizer] +object Binarizer extends DefaultParamsReadable[Binarizer] { @Since("1.6.0") - override def load(path: String): Binarizer = read.load(path) + override def load(path: String): Binarizer = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala index 7095fbd70aa0..324353a96afb 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala @@ -36,7 +36,7 @@ import 
org.apache.spark.sql.types.{DoubleType, StructField, StructType} */ @Experimental final class Bucketizer(override val uid: String) - extends Model[Bucketizer] with HasInputCol with HasOutputCol with Writable { + extends Model[Bucketizer] with HasInputCol with HasOutputCol with DefaultParamsWritable { def this() = this(Identifiable.randomUID("bucketizer")) @@ -93,12 +93,9 @@ final class Bucketizer(override val uid: String) override def copy(extra: ParamMap): Bucketizer = { defaultCopy[Bucketizer](extra).setParent(parent) } - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } -object Bucketizer extends Readable[Bucketizer] { +object Bucketizer extends DefaultParamsReadable[Bucketizer] { /** We require splits to be of length >= 3 and to be in strictly increasing order. */ private[feature] def checkSplits(splits: Array[Double]): Boolean = { @@ -140,8 +137,5 @@ object Bucketizer extends Readable[Bucketizer] { } @Since("1.6.0") - override def read: Reader[Bucketizer] = new DefaultParamsReader[Bucketizer] - - @Since("1.6.0") - override def load(path: String): Bucketizer = read.load(path) + override def load(path: String): Bucketizer = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala index 5ff9bfb7d111..4969cf42450d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala @@ -107,7 +107,7 @@ private[feature] trait CountVectorizerParams extends Params with HasInputCol wit */ @Experimental class CountVectorizer(override val uid: String) - extends Estimator[CountVectorizerModel] with CountVectorizerParams with Writable { + extends Estimator[CountVectorizerModel] with CountVectorizerParams with DefaultParamsWritable { def this() = this(Identifiable.randomUID("cntVec")) @@ -171,16 +171,10 @@ class CountVectorizer(override val uid: String) } override def copy(extra: ParamMap): CountVectorizer = defaultCopy(extra) - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @Since("1.6.0") -object CountVectorizer extends Readable[CountVectorizer] { - - @Since("1.6.0") - override def read: Reader[CountVectorizer] = new DefaultParamsReader +object CountVectorizer extends DefaultParamsReadable[CountVectorizer] { @Since("1.6.0") override def load(path: String): CountVectorizer = super.load(path) @@ -193,7 +187,7 @@ object CountVectorizer extends Readable[CountVectorizer] { */ @Experimental class CountVectorizerModel(override val uid: String, val vocabulary: Array[String]) - extends Model[CountVectorizerModel] with CountVectorizerParams with Writable { + extends Model[CountVectorizerModel] with CountVectorizerParams with MLWritable { import CountVectorizerModel._ @@ -251,14 +245,14 @@ class CountVectorizerModel(override val uid: String, val vocabulary: Array[Strin } @Since("1.6.0") - override def write: Writer = new CountVectorizerModelWriter(this) + override def write: MLWriter = new CountVectorizerModelWriter(this) } @Since("1.6.0") -object CountVectorizerModel extends Readable[CountVectorizerModel] { +object CountVectorizerModel extends MLReadable[CountVectorizerModel] { private[CountVectorizerModel] - class CountVectorizerModelWriter(instance: CountVectorizerModel) extends Writer { + class CountVectorizerModelWriter(instance: CountVectorizerModel) extends MLWriter { private case class Data(vocabulary: Seq[String]) @@ -270,7 
+264,7 @@ object CountVectorizerModel extends Readable[CountVectorizerModel] { } } - private class CountVectorizerModelReader extends Reader[CountVectorizerModel] { + private class CountVectorizerModelReader extends MLReader[CountVectorizerModel] { private val className = "org.apache.spark.ml.feature.CountVectorizerModel" @@ -288,7 +282,7 @@ object CountVectorizerModel extends Readable[CountVectorizerModel] { } @Since("1.6.0") - override def read: Reader[CountVectorizerModel] = new CountVectorizerModelReader + override def read: MLReader[CountVectorizerModel] = new CountVectorizerModelReader @Since("1.6.0") override def load(path: String): CountVectorizerModel = super.load(path) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala index 6ea5a616173e..6bed72164a1d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala @@ -37,7 +37,7 @@ import org.apache.spark.sql.types.DataType */ @Experimental class DCT(override val uid: String) - extends UnaryTransformer[Vector, Vector, DCT] with Writable { + extends UnaryTransformer[Vector, Vector, DCT] with DefaultParamsWritable { def this() = this(Identifiable.randomUID("dct")) @@ -69,17 +69,11 @@ class DCT(override val uid: String) } override protected def outputDataType: DataType = new VectorUDT - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @Since("1.6.0") -object DCT extends Readable[DCT] { - - @Since("1.6.0") - override def read: Reader[DCT] = new DefaultParamsReader[DCT] +object DCT extends DefaultParamsReadable[DCT] { @Since("1.6.0") - override def load(path: String): DCT = read.load(path) + override def load(path: String): DCT = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala index 6d2ea675f561..9e15835429a3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala @@ -34,7 +34,7 @@ import org.apache.spark.sql.types.{ArrayType, StructType} */ @Experimental class HashingTF(override val uid: String) - extends Transformer with HasInputCol with HasOutputCol with Writable { + extends Transformer with HasInputCol with HasOutputCol with DefaultParamsWritable { def this() = this(Identifiable.randomUID("hashingTF")) @@ -77,17 +77,11 @@ class HashingTF(override val uid: String) } override def copy(extra: ParamMap): HashingTF = defaultCopy(extra) - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @Since("1.6.0") -object HashingTF extends Readable[HashingTF] { - - @Since("1.6.0") - override def read: Reader[HashingTF] = new DefaultParamsReader[HashingTF] +object HashingTF extends DefaultParamsReadable[HashingTF] { @Since("1.6.0") - override def load(path: String): HashingTF = read.load(path) + override def load(path: String): HashingTF = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala index 53ad34ef1264..0e00ef6f2ee2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala @@ -62,7 +62,8 @@ private[feature] trait IDFBase extends Params with HasInputCol with HasOutputCol * Compute the Inverse Document Frequency (IDF) given a collection of 
documents. */ @Experimental -final class IDF(override val uid: String) extends Estimator[IDFModel] with IDFBase with Writable { +final class IDF(override val uid: String) extends Estimator[IDFModel] with IDFBase + with DefaultParamsWritable { def this() = this(Identifiable.randomUID("idf")) @@ -87,16 +88,10 @@ final class IDF(override val uid: String) extends Estimator[IDFModel] with IDFBa } override def copy(extra: ParamMap): IDF = defaultCopy(extra) - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @Since("1.6.0") -object IDF extends Readable[IDF] { - - @Since("1.6.0") - override def read: Reader[IDF] = new DefaultParamsReader +object IDF extends DefaultParamsReadable[IDF] { @Since("1.6.0") override def load(path: String): IDF = super.load(path) @@ -110,7 +105,7 @@ object IDF extends Readable[IDF] { class IDFModel private[ml] ( override val uid: String, idfModel: feature.IDFModel) - extends Model[IDFModel] with IDFBase with Writable { + extends Model[IDFModel] with IDFBase with MLWritable { import IDFModel._ @@ -140,13 +135,13 @@ class IDFModel private[ml] ( def idf: Vector = idfModel.idf @Since("1.6.0") - override def write: Writer = new IDFModelWriter(this) + override def write: MLWriter = new IDFModelWriter(this) } @Since("1.6.0") -object IDFModel extends Readable[IDFModel] { +object IDFModel extends MLReadable[IDFModel] { - private[IDFModel] class IDFModelWriter(instance: IDFModel) extends Writer { + private[IDFModel] class IDFModelWriter(instance: IDFModel) extends MLWriter { private case class Data(idf: Vector) @@ -158,7 +153,7 @@ object IDFModel extends Readable[IDFModel] { } } - private class IDFModelReader extends Reader[IDFModel] { + private class IDFModelReader extends MLReader[IDFModel] { private val className = "org.apache.spark.ml.feature.IDFModel" @@ -176,7 +171,7 @@ object IDFModel extends Readable[IDFModel] { } @Since("1.6.0") - override def read: Reader[IDFModel] = new IDFModelReader + override def read: MLReader[IDFModel] = new IDFModelReader @Since("1.6.0") override def load(path: String): IDFModel = super.load(path) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala index 9df6b311cc9d..2181119f04a5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala @@ -45,7 +45,7 @@ import org.apache.spark.sql.types._ @Since("1.6.0") @Experimental class Interaction @Since("1.6.0") (override val uid: String) extends Transformer - with HasInputCols with HasOutputCol with Writable { + with HasInputCols with HasOutputCol with DefaultParamsWritable { @Since("1.6.0") def this() = this(Identifiable.randomUID("interaction")) @@ -224,19 +224,13 @@ class Interaction @Since("1.6.0") (override val uid: String) extends Transformer require($(inputCols).length > 0, "Input cols must have non-zero length.") require($(inputCols).distinct.length == $(inputCols).length, "Input cols must be distinct.") } - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @Since("1.6.0") -object Interaction extends Readable[Interaction] { - - @Since("1.6.0") - override def read: Reader[Interaction] = new DefaultParamsReader[Interaction] +object Interaction extends DefaultParamsReadable[Interaction] { @Since("1.6.0") - override def load(path: String): Interaction = read.load(path) + override def load(path: String): Interaction = super.load(path) } /** diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala index 24d964fae834..ed24eabb5044 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala @@ -88,7 +88,7 @@ private[feature] trait MinMaxScalerParams extends Params with HasInputCol with H */ @Experimental class MinMaxScaler(override val uid: String) - extends Estimator[MinMaxScalerModel] with MinMaxScalerParams with Writable { + extends Estimator[MinMaxScalerModel] with MinMaxScalerParams with DefaultParamsWritable { def this() = this(Identifiable.randomUID("minMaxScal")) @@ -118,16 +118,10 @@ class MinMaxScaler(override val uid: String) } override def copy(extra: ParamMap): MinMaxScaler = defaultCopy(extra) - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @Since("1.6.0") -object MinMaxScaler extends Readable[MinMaxScaler] { - - @Since("1.6.0") - override def read: Reader[MinMaxScaler] = new DefaultParamsReader +object MinMaxScaler extends DefaultParamsReadable[MinMaxScaler] { @Since("1.6.0") override def load(path: String): MinMaxScaler = super.load(path) @@ -147,7 +141,7 @@ class MinMaxScalerModel private[ml] ( override val uid: String, val originalMin: Vector, val originalMax: Vector) - extends Model[MinMaxScalerModel] with MinMaxScalerParams with Writable { + extends Model[MinMaxScalerModel] with MinMaxScalerParams with MLWritable { import MinMaxScalerModel._ @@ -195,14 +189,14 @@ class MinMaxScalerModel private[ml] ( } @Since("1.6.0") - override def write: Writer = new MinMaxScalerModelWriter(this) + override def write: MLWriter = new MinMaxScalerModelWriter(this) } @Since("1.6.0") -object MinMaxScalerModel extends Readable[MinMaxScalerModel] { +object MinMaxScalerModel extends MLReadable[MinMaxScalerModel] { private[MinMaxScalerModel] - class MinMaxScalerModelWriter(instance: MinMaxScalerModel) extends Writer { + class MinMaxScalerModelWriter(instance: MinMaxScalerModel) extends MLWriter { private case class Data(originalMin: Vector, originalMax: Vector) @@ -214,7 +208,7 @@ object MinMaxScalerModel extends Readable[MinMaxScalerModel] { } } - private class MinMaxScalerModelReader extends Reader[MinMaxScalerModel] { + private class MinMaxScalerModelReader extends MLReader[MinMaxScalerModel] { private val className = "org.apache.spark.ml.feature.MinMaxScalerModel" @@ -231,7 +225,7 @@ object MinMaxScalerModel extends Readable[MinMaxScalerModel] { } @Since("1.6.0") - override def read: Reader[MinMaxScalerModel] = new MinMaxScalerModelReader + override def read: MLReader[MinMaxScalerModel] = new MinMaxScalerModelReader @Since("1.6.0") override def load(path: String): MinMaxScalerModel = super.load(path) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala index 4a17acd95199..65414ecbefbb 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala @@ -36,7 +36,7 @@ import org.apache.spark.sql.types.{ArrayType, DataType, StringType} */ @Experimental class NGram(override val uid: String) - extends UnaryTransformer[Seq[String], Seq[String], NGram] with Writable { + extends UnaryTransformer[Seq[String], Seq[String], NGram] with DefaultParamsWritable { def this() = this(Identifiable.randomUID("ngram")) @@ -66,17 +66,11 @@ class NGram(override val uid: String) } 
override protected def outputDataType: DataType = new ArrayType(StringType, false) - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @Since("1.6.0") -object NGram extends Readable[NGram] { - - @Since("1.6.0") - override def read: Reader[NGram] = new DefaultParamsReader[NGram] +object NGram extends DefaultParamsReadable[NGram] { @Since("1.6.0") - override def load(path: String): NGram = read.load(path) + override def load(path: String): NGram = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala index 9df6a091d505..c2d514fd9629 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.types.DataType */ @Experimental class Normalizer(override val uid: String) - extends UnaryTransformer[Vector, Vector, Normalizer] with Writable { + extends UnaryTransformer[Vector, Vector, Normalizer] with DefaultParamsWritable { def this() = this(Identifiable.randomUID("normalizer")) @@ -56,17 +56,11 @@ class Normalizer(override val uid: String) } override protected def outputDataType: DataType = new VectorUDT() - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @Since("1.6.0") -object Normalizer extends Readable[Normalizer] { - - @Since("1.6.0") - override def read: Reader[Normalizer] = new DefaultParamsReader[Normalizer] +object Normalizer extends DefaultParamsReadable[Normalizer] { @Since("1.6.0") - override def load(path: String): Normalizer = read.load(path) + override def load(path: String): Normalizer = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala index 4e2adfaafa21..d70164eaf022 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala @@ -44,7 +44,7 @@ import org.apache.spark.sql.types.{DoubleType, StructType} */ @Experimental class OneHotEncoder(override val uid: String) extends Transformer - with HasInputCol with HasOutputCol with Writable { + with HasInputCol with HasOutputCol with DefaultParamsWritable { def this() = this(Identifiable.randomUID("oneHot")) @@ -165,17 +165,11 @@ class OneHotEncoder(override val uid: String) extends Transformer } override def copy(extra: ParamMap): OneHotEncoder = defaultCopy(extra) - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @Since("1.6.0") -object OneHotEncoder extends Readable[OneHotEncoder] { - - @Since("1.6.0") - override def read: Reader[OneHotEncoder] = new DefaultParamsReader[OneHotEncoder] +object OneHotEncoder extends DefaultParamsReadable[OneHotEncoder] { @Since("1.6.0") - override def load(path: String): OneHotEncoder = read.load(path) + override def load(path: String): OneHotEncoder = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala index 49415398325f..08610593fadd 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala @@ -36,7 +36,7 @@ import org.apache.spark.sql.types.DataType */ @Experimental class PolynomialExpansion(override val 
uid: String) - extends UnaryTransformer[Vector, Vector, PolynomialExpansion] with Writable { + extends UnaryTransformer[Vector, Vector, PolynomialExpansion] with DefaultParamsWritable { def this() = this(Identifiable.randomUID("poly")) @@ -63,9 +63,6 @@ class PolynomialExpansion(override val uid: String) override protected def outputDataType: DataType = new VectorUDT() override def copy(extra: ParamMap): PolynomialExpansion = defaultCopy(extra) - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } /** @@ -81,7 +78,7 @@ class PolynomialExpansion(override val uid: String) * current index and increment it properly for sparse input. */ @Since("1.6.0") -object PolynomialExpansion extends Readable[PolynomialExpansion] { +object PolynomialExpansion extends DefaultParamsReadable[PolynomialExpansion] { private def choose(n: Int, k: Int): Int = { Range(n, n - k, -1).product / Range(k, 1, -1).product @@ -182,8 +179,5 @@ object PolynomialExpansion extends Readable[PolynomialExpansion] { } @Since("1.6.0") - override def read: Reader[PolynomialExpansion] = new DefaultParamsReader[PolynomialExpansion] - - @Since("1.6.0") - override def load(path: String): PolynomialExpansion = read.load(path) + override def load(path: String): PolynomialExpansion = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala index 2da5c966d296..7bf67c6325a3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala @@ -60,7 +60,7 @@ private[feature] trait QuantileDiscretizerBase extends Params with HasInputCol w */ @Experimental final class QuantileDiscretizer(override val uid: String) - extends Estimator[Bucketizer] with QuantileDiscretizerBase with Writable { + extends Estimator[Bucketizer] with QuantileDiscretizerBase with DefaultParamsWritable { def this() = this(Identifiable.randomUID("quantileDiscretizer")) @@ -93,13 +93,10 @@ final class QuantileDiscretizer(override val uid: String) } override def copy(extra: ParamMap): QuantileDiscretizer = defaultCopy(extra) - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @Since("1.6.0") -object QuantileDiscretizer extends Readable[QuantileDiscretizer] with Logging { +object QuantileDiscretizer extends DefaultParamsReadable[QuantileDiscretizer] with Logging { /** * Sampling from the given dataset to collect quantile statistics. 
*/ @@ -179,8 +176,5 @@ object QuantileDiscretizer extends Readable[QuantileDiscretizer] with Logging { } @Since("1.6.0") - override def read: Reader[QuantileDiscretizer] = new DefaultParamsReader[QuantileDiscretizer] - - @Since("1.6.0") - override def load(path: String): QuantileDiscretizer = read.load(path) + override def load(path: String): QuantileDiscretizer = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala index c115064ff301..3a735017ba83 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala @@ -33,7 +33,8 @@ import org.apache.spark.sql.types.StructType */ @Experimental @Since("1.6.0") -class SQLTransformer @Since("1.6.0") (override val uid: String) extends Transformer with Writable { +class SQLTransformer @Since("1.6.0") (override val uid: String) extends Transformer + with DefaultParamsWritable { @Since("1.6.0") def this() = this(Identifiable.randomUID("sql")) @@ -77,17 +78,11 @@ class SQLTransformer @Since("1.6.0") (override val uid: String) extends Transfor @Since("1.6.0") override def copy(extra: ParamMap): SQLTransformer = defaultCopy(extra) - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @Since("1.6.0") -object SQLTransformer extends Readable[SQLTransformer] { - - @Since("1.6.0") - override def read: Reader[SQLTransformer] = new DefaultParamsReader[SQLTransformer] +object SQLTransformer extends DefaultParamsReadable[SQLTransformer] { @Since("1.6.0") - override def load(path: String): SQLTransformer = read.load(path) + override def load(path: String): SQLTransformer = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala index ab04e5418dd4..1f689c1da1ba 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala @@ -59,7 +59,7 @@ private[feature] trait StandardScalerParams extends Params with HasInputCol with */ @Experimental class StandardScaler(override val uid: String) extends Estimator[StandardScalerModel] - with StandardScalerParams with Writable { + with StandardScalerParams with DefaultParamsWritable { def this() = this(Identifiable.randomUID("stdScal")) @@ -96,16 +96,10 @@ class StandardScaler(override val uid: String) extends Estimator[StandardScalerM } override def copy(extra: ParamMap): StandardScaler = defaultCopy(extra) - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @Since("1.6.0") -object StandardScaler extends Readable[StandardScaler] { - - @Since("1.6.0") - override def read: Reader[StandardScaler] = new DefaultParamsReader +object StandardScaler extends DefaultParamsReadable[StandardScaler] { @Since("1.6.0") override def load(path: String): StandardScaler = super.load(path) @@ -119,7 +113,7 @@ object StandardScaler extends Readable[StandardScaler] { class StandardScalerModel private[ml] ( override val uid: String, scaler: feature.StandardScalerModel) - extends Model[StandardScalerModel] with StandardScalerParams with Writable { + extends Model[StandardScalerModel] with StandardScalerParams with MLWritable { import StandardScalerModel._ @@ -165,14 +159,14 @@ class StandardScalerModel private[ml] ( } @Since("1.6.0") - override def write: Writer = new 
StandardScalerModelWriter(this) + override def write: MLWriter = new StandardScalerModelWriter(this) } @Since("1.6.0") -object StandardScalerModel extends Readable[StandardScalerModel] { +object StandardScalerModel extends MLReadable[StandardScalerModel] { private[StandardScalerModel] - class StandardScalerModelWriter(instance: StandardScalerModel) extends Writer { + class StandardScalerModelWriter(instance: StandardScalerModel) extends MLWriter { private case class Data(std: Vector, mean: Vector, withStd: Boolean, withMean: Boolean) @@ -184,7 +178,7 @@ object StandardScalerModel extends Readable[StandardScalerModel] { } } - private class StandardScalerModelReader extends Reader[StandardScalerModel] { + private class StandardScalerModelReader extends MLReader[StandardScalerModel] { private val className = "org.apache.spark.ml.feature.StandardScalerModel" @@ -204,7 +198,7 @@ object StandardScalerModel extends Readable[StandardScalerModel] { } @Since("1.6.0") - override def read: Reader[StandardScalerModel] = new StandardScalerModelReader + override def read: MLReader[StandardScalerModel] = new StandardScalerModelReader @Since("1.6.0") override def load(path: String): StandardScalerModel = super.load(path) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala index f1146988dcc7..318808596dc6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala @@ -86,7 +86,7 @@ private[spark] object StopWords { */ @Experimental class StopWordsRemover(override val uid: String) - extends Transformer with HasInputCol with HasOutputCol with Writable { + extends Transformer with HasInputCol with HasOutputCol with DefaultParamsWritable { def this() = this(Identifiable.randomUID("stopWords")) @@ -154,17 +154,11 @@ class StopWordsRemover(override val uid: String) } override def copy(extra: ParamMap): StopWordsRemover = defaultCopy(extra) - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @Since("1.6.0") -object StopWordsRemover extends Readable[StopWordsRemover] { - - @Since("1.6.0") - override def read: Reader[StopWordsRemover] = new DefaultParamsReader[StopWordsRemover] +object StopWordsRemover extends DefaultParamsReadable[StopWordsRemover] { @Since("1.6.0") - override def load(path: String): StopWordsRemover = read.load(path) + override def load(path: String): StopWordsRemover = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index f16f6afc002d..97a2e4f6d6ca 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -65,7 +65,7 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha */ @Experimental class StringIndexer(override val uid: String) extends Estimator[StringIndexerModel] - with StringIndexerBase with Writable { + with StringIndexerBase with DefaultParamsWritable { def this() = this(Identifiable.randomUID("strIdx")) @@ -93,16 +93,10 @@ class StringIndexer(override val uid: String) extends Estimator[StringIndexerMod } override def copy(extra: ParamMap): StringIndexer = defaultCopy(extra) - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @Since("1.6.0") -object StringIndexer 
extends Readable[StringIndexer] { - - @Since("1.6.0") - override def read: Reader[StringIndexer] = new DefaultParamsReader +object StringIndexer extends DefaultParamsReadable[StringIndexer] { @Since("1.6.0") override def load(path: String): StringIndexer = super.load(path) @@ -122,7 +116,7 @@ object StringIndexer extends Readable[StringIndexer] { class StringIndexerModel ( override val uid: String, val labels: Array[String]) - extends Model[StringIndexerModel] with StringIndexerBase with Writable { + extends Model[StringIndexerModel] with StringIndexerBase with MLWritable { import StringIndexerModel._ @@ -199,10 +193,10 @@ class StringIndexerModel ( } @Since("1.6.0") -object StringIndexerModel extends Readable[StringIndexerModel] { +object StringIndexerModel extends MLReadable[StringIndexerModel] { private[StringIndexerModel] - class StringIndexModelWriter(instance: StringIndexerModel) extends Writer { + class StringIndexModelWriter(instance: StringIndexerModel) extends MLWriter { private case class Data(labels: Array[String]) @@ -214,7 +208,7 @@ object StringIndexerModel extends Readable[StringIndexerModel] { } } - private class StringIndexerModelReader extends Reader[StringIndexerModel] { + private class StringIndexerModelReader extends MLReader[StringIndexerModel] { private val className = "org.apache.spark.ml.feature.StringIndexerModel" @@ -232,7 +226,7 @@ object StringIndexerModel extends Readable[StringIndexerModel] { } @Since("1.6.0") - override def read: Reader[StringIndexerModel] = new StringIndexerModelReader + override def read: MLReader[StringIndexerModel] = new StringIndexerModelReader @Since("1.6.0") override def load(path: String): StringIndexerModel = super.load(path) @@ -249,7 +243,7 @@ object StringIndexerModel extends Readable[StringIndexerModel] { */ @Experimental class IndexToString private[ml] (override val uid: String) - extends Transformer with HasInputCol with HasOutputCol with Writable { + extends Transformer with HasInputCol with HasOutputCol with DefaultParamsWritable { def this() = this(Identifiable.randomUID("idxToStr")) @@ -316,17 +310,11 @@ class IndexToString private[ml] (override val uid: String) override def copy(extra: ParamMap): IndexToString = { defaultCopy(extra) } - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @Since("1.6.0") -object IndexToString extends Readable[IndexToString] { - - @Since("1.6.0") - override def read: Reader[IndexToString] = new DefaultParamsReader[IndexToString] +object IndexToString extends DefaultParamsReadable[IndexToString] { @Since("1.6.0") - override def load(path: String): IndexToString = read.load(path) + override def load(path: String): IndexToString = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala index 0e4445d1e2fa..8ad7bbedaab5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.types.{ArrayType, DataType, StringType} */ @Experimental class Tokenizer(override val uid: String) - extends UnaryTransformer[String, Seq[String], Tokenizer] with Writable { + extends UnaryTransformer[String, Seq[String], Tokenizer] with DefaultParamsWritable { def this() = this(Identifiable.randomUID("tok")) @@ -46,19 +46,13 @@ class Tokenizer(override val uid: String) override protected def outputDataType: DataType = new ArrayType(StringType, true) 
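The same pattern repeats in each of these files: the class mixes in DefaultParamsWritable instead of overriding `write`, and the companion object extends DefaultParamsReadable instead of overriding `read` (both traits are added in ReadWrite.scala further below and are currently private[ml], so this shape is only available inside org.apache.spark.ml). A sketch of the resulting shape for a hypothetical params-only transformer named EchoTransformer:

    package org.apache.spark.ml.feature

    import org.apache.spark.ml.Transformer
    import org.apache.spark.ml.param.ParamMap
    import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable}
    import org.apache.spark.sql.DataFrame
    import org.apache.spark.sql.types.StructType

    // Illustrative only: no model data, just (json4s-serializable) params, so the
    // default writer/reader are sufficient.
    class EchoTransformer(override val uid: String)
      extends Transformer with DefaultParamsWritable {

      def this() = this(Identifiable.randomUID("echo"))

      override def transform(dataset: DataFrame): DataFrame = dataset
      override def transformSchema(schema: StructType): StructType = schema
      override def copy(extra: ParamMap): EchoTransformer = defaultCopy(extra)
    }

    // read and load come for free from DefaultParamsReadable.
    object EchoTransformer extends DefaultParamsReadable[EchoTransformer]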
override def copy(extra: ParamMap): Tokenizer = defaultCopy(extra) - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @Since("1.6.0") -object Tokenizer extends Readable[Tokenizer] { - - @Since("1.6.0") - override def read: Reader[Tokenizer] = new DefaultParamsReader[Tokenizer] +object Tokenizer extends DefaultParamsReadable[Tokenizer] { @Since("1.6.0") - override def load(path: String): Tokenizer = read.load(path) + override def load(path: String): Tokenizer = super.load(path) } /** @@ -70,7 +64,7 @@ object Tokenizer extends Readable[Tokenizer] { */ @Experimental class RegexTokenizer(override val uid: String) - extends UnaryTransformer[String, Seq[String], RegexTokenizer] with Writable { + extends UnaryTransformer[String, Seq[String], RegexTokenizer] with DefaultParamsWritable { def this() = this(Identifiable.randomUID("regexTok")) @@ -145,17 +139,11 @@ class RegexTokenizer(override val uid: String) override protected def outputDataType: DataType = new ArrayType(StringType, true) override def copy(extra: ParamMap): RegexTokenizer = defaultCopy(extra) - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @Since("1.6.0") -object RegexTokenizer extends Readable[RegexTokenizer] { - - @Since("1.6.0") - override def read: Reader[RegexTokenizer] = new DefaultParamsReader[RegexTokenizer] +object RegexTokenizer extends DefaultParamsReadable[RegexTokenizer] { @Since("1.6.0") - override def load(path: String): RegexTokenizer = read.load(path) + override def load(path: String): RegexTokenizer = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala index 7e54205292ca..0feec0549852 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala @@ -37,7 +37,7 @@ import org.apache.spark.sql.types._ */ @Experimental class VectorAssembler(override val uid: String) - extends Transformer with HasInputCols with HasOutputCol with Writable { + extends Transformer with HasInputCols with HasOutputCol with DefaultParamsWritable { def this() = this(Identifiable.randomUID("vecAssembler")) @@ -120,19 +120,13 @@ class VectorAssembler(override val uid: String) } override def copy(extra: ParamMap): VectorAssembler = defaultCopy(extra) - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @Since("1.6.0") -object VectorAssembler extends Readable[VectorAssembler] { - - @Since("1.6.0") - override def read: Reader[VectorAssembler] = new DefaultParamsReader[VectorAssembler] +object VectorAssembler extends DefaultParamsReadable[VectorAssembler] { @Since("1.6.0") - override def load(path: String): VectorAssembler = read.load(path) + override def load(path: String): VectorAssembler = super.load(path) private[feature] def assemble(vv: Any*): Vector = { val indices = ArrayBuilder.make[Int] diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala index 911582b55b57..5410a50bc2e4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala @@ -42,7 +42,7 @@ import org.apache.spark.sql.types.StructType */ @Experimental final class VectorSlicer(override val uid: String) - extends Transformer with HasInputCol with HasOutputCol with Writable { + 
extends Transformer with HasInputCol with HasOutputCol with DefaultParamsWritable { def this() = this(Identifiable.randomUID("vectorSlicer")) @@ -151,13 +151,10 @@ final class VectorSlicer(override val uid: String) } override def copy(extra: ParamMap): VectorSlicer = defaultCopy(extra) - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @Since("1.6.0") -object VectorSlicer extends Readable[VectorSlicer] { +object VectorSlicer extends DefaultParamsReadable[VectorSlicer] { /** Return true if given feature indices are valid */ private[feature] def validIndices(indices: Array[Int]): Boolean = { @@ -174,8 +171,5 @@ object VectorSlicer extends Readable[VectorSlicer] { } @Since("1.6.0") - override def read: Reader[VectorSlicer] = new DefaultParamsReader[VectorSlicer] - - @Since("1.6.0") - override def load(path: String): VectorSlicer = read.load(path) + override def load(path: String): VectorSlicer = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index d92514d2e239..795b73c4c212 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -185,7 +185,7 @@ class ALSModel private[ml] ( val rank: Int, @transient val userFactors: DataFrame, @transient val itemFactors: DataFrame) - extends Model[ALSModel] with ALSModelParams with Writable { + extends Model[ALSModel] with ALSModelParams with MLWritable { /** @group setParam */ def setUserCol(value: String): this.type = set(userCol, value) @@ -225,19 +225,19 @@ class ALSModel private[ml] ( } @Since("1.6.0") - override def write: Writer = new ALSModel.ALSModelWriter(this) + override def write: MLWriter = new ALSModel.ALSModelWriter(this) } @Since("1.6.0") -object ALSModel extends Readable[ALSModel] { +object ALSModel extends MLReadable[ALSModel] { @Since("1.6.0") - override def read: Reader[ALSModel] = new ALSModelReader + override def read: MLReader[ALSModel] = new ALSModelReader @Since("1.6.0") - override def load(path: String): ALSModel = read.load(path) + override def load(path: String): ALSModel = super.load(path) - private[recommendation] class ALSModelWriter(instance: ALSModel) extends Writer { + private[recommendation] class ALSModelWriter(instance: ALSModel) extends MLWriter { override protected def saveImpl(path: String): Unit = { val extraMetadata = render("rank" -> instance.rank) @@ -249,7 +249,7 @@ object ALSModel extends Readable[ALSModel] { } } - private[recommendation] class ALSModelReader extends Reader[ALSModel] { + private[recommendation] class ALSModelReader extends MLReader[ALSModel] { /** Checked against metadata when loading model */ private val className = "org.apache.spark.ml.recommendation.ALSModel" @@ -309,7 +309,8 @@ object ALSModel extends Readable[ALSModel] { * preferences rather than explicit ratings given to items. 
*/ @Experimental -class ALS(override val uid: String) extends Estimator[ALSModel] with ALSParams with Writable { +class ALS(override val uid: String) extends Estimator[ALSModel] with ALSParams + with DefaultParamsWritable { import org.apache.spark.ml.recommendation.ALS.Rating @@ -391,9 +392,6 @@ class ALS(override val uid: String) extends Estimator[ALSModel] with ALSParams w } override def copy(extra: ParamMap): ALS = defaultCopy(extra) - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @@ -406,7 +404,7 @@ class ALS(override val uid: String) extends Estimator[ALSModel] with ALSParams w * than 2 billion. */ @DeveloperApi -object ALS extends Readable[ALS] with Logging { +object ALS extends DefaultParamsReadable[ALS] with Logging { /** * :: DeveloperApi :: @@ -416,10 +414,7 @@ object ALS extends Readable[ALS] with Logging { case class Rating[@specialized(Int, Long) ID](user: ID, item: ID, rating: Float) @Since("1.6.0") - override def read: Reader[ALS] = new DefaultParamsReader[ALS] - - @Since("1.6.0") - override def load(path: String): ALS = read.load(path) + override def load(path: String): ALS = super.load(path) /** Trait for least squares solvers applied to the normal equation. */ private[recommendation] trait LeastSquaresNESolver extends Serializable { diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index f7c44f0a51b8..7ba1a60edaf7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -66,7 +66,7 @@ private[regression] trait LinearRegressionParams extends PredictorParams @Experimental class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String) extends Regressor[Vector, LinearRegression, LinearRegressionModel] - with LinearRegressionParams with Writable with Logging { + with LinearRegressionParams with DefaultParamsWritable with Logging { @Since("1.4.0") def this() = this(Identifiable.randomUID("linReg")) @@ -345,19 +345,13 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String @Since("1.4.0") override def copy(extra: ParamMap): LinearRegression = defaultCopy(extra) - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @Since("1.6.0") -object LinearRegression extends Readable[LinearRegression] { - - @Since("1.6.0") - override def read: Reader[LinearRegression] = new DefaultParamsReader[LinearRegression] +object LinearRegression extends DefaultParamsReadable[LinearRegression] { @Since("1.6.0") - override def load(path: String): LinearRegression = read.load(path) + override def load(path: String): LinearRegression = super.load(path) } /** @@ -371,7 +365,7 @@ class LinearRegressionModel private[ml] ( val coefficients: Vector, val intercept: Double) extends RegressionModel[Vector, LinearRegressionModel] - with LinearRegressionParams with Writable { + with LinearRegressionParams with MLWritable { private var trainingSummary: Option[LinearRegressionTrainingSummary] = None @@ -441,7 +435,7 @@ class LinearRegressionModel private[ml] ( } /** - * Returns a [[Writer]] instance for this ML instance. + * Returns a [[MLWriter]] instance for this ML instance. * * For [[LinearRegressionModel]], this does NOT currently save the training [[summary]]. * An option to save [[summary]] may be added in the future. 
@@ -449,21 +443,21 @@ class LinearRegressionModel private[ml] ( * This also does not save the [[parent]] currently. */ @Since("1.6.0") - override def write: Writer = new LinearRegressionModel.LinearRegressionModelWriter(this) + override def write: MLWriter = new LinearRegressionModel.LinearRegressionModelWriter(this) } @Since("1.6.0") -object LinearRegressionModel extends Readable[LinearRegressionModel] { +object LinearRegressionModel extends MLReadable[LinearRegressionModel] { @Since("1.6.0") - override def read: Reader[LinearRegressionModel] = new LinearRegressionModelReader + override def read: MLReader[LinearRegressionModel] = new LinearRegressionModelReader @Since("1.6.0") - override def load(path: String): LinearRegressionModel = read.load(path) + override def load(path: String): LinearRegressionModel = super.load(path) - /** [[Writer]] instance for [[LinearRegressionModel]] */ + /** [[MLWriter]] instance for [[LinearRegressionModel]] */ private[LinearRegressionModel] class LinearRegressionModelWriter(instance: LinearRegressionModel) - extends Writer with Logging { + extends MLWriter with Logging { private case class Data(intercept: Double, coefficients: Vector) @@ -477,7 +471,7 @@ object LinearRegressionModel extends Readable[LinearRegressionModel] { } } - private class LinearRegressionModelReader extends Reader[LinearRegressionModel] { + private class LinearRegressionModelReader extends MLReader[LinearRegressionModel] { /** Checked against metadata when loading model */ private val className = "org.apache.spark.ml.regression.LinearRegressionModel" diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala index d8ce907af532..ff9322dba122 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.SQLContext import org.apache.spark.util.Utils /** - * Trait for [[Writer]] and [[Reader]]. + * Trait for [[MLWriter]] and [[MLReader]]. */ private[util] sealed trait BaseReadWrite { private var optionSQLContext: Option[SQLContext] = None @@ -64,7 +64,7 @@ private[util] sealed trait BaseReadWrite { */ @Experimental @Since("1.6.0") -abstract class Writer extends BaseReadWrite with Logging { +abstract class MLWriter extends BaseReadWrite with Logging { protected var shouldOverwrite: Boolean = false @@ -111,16 +111,16 @@ abstract class Writer extends BaseReadWrite with Logging { } /** - * Trait for classes that provide [[Writer]]. + * Trait for classes that provide [[MLWriter]]. */ @Since("1.6.0") -trait Writable { +trait MLWritable { /** - * Returns a [[Writer]] instance for this ML instance. + * Returns an [[MLWriter]] instance for this ML instance. */ @Since("1.6.0") - def write: Writer + def write: MLWriter /** * Saves this ML instance to the input path, a shortcut of `write.save(path)`. @@ -130,13 +130,18 @@ trait Writable { def save(path: String): Unit = write.save(path) } +private[ml] trait DefaultParamsWritable extends MLWritable { self: Params => + + override def write: MLWriter = new DefaultParamsWriter(this) +} + /** * Abstract class for utility classes that can load ML instances. * @tparam T ML instance type */ @Experimental @Since("1.6.0") -abstract class Reader[T] extends BaseReadWrite { +abstract class MLReader[T] extends BaseReadWrite { /** * Loads the ML component from the input path. 
@@ -149,18 +154,18 @@ abstract class Reader[T] extends BaseReadWrite { } /** - * Trait for objects that provide [[Reader]]. + * Trait for objects that provide [[MLReader]]. * @tparam T ML instance type */ @Experimental @Since("1.6.0") -trait Readable[T] { +trait MLReadable[T] { /** - * Returns a [[Reader]] instance for this class. + * Returns an [[MLReader]] instance for this class. */ @Since("1.6.0") - def read: Reader[T] + def read: MLReader[T] /** * Reads an ML instance from the input path, a shortcut of `read.load(path)`. @@ -171,13 +176,18 @@ trait Readable[T] { def load(path: String): T = read.load(path) } +private[ml] trait DefaultParamsReadable[T] extends MLReadable[T] { + + override def read: MLReader[T] = new DefaultParamsReader +} + /** - * Default [[Writer]] implementation for transformers and estimators that contain basic + * Default [[MLWriter]] implementation for transformers and estimators that contain basic * (json4s-serializable) params and no data. This will not handle more complex params or types with * data (e.g., models with coefficients). * @param instance object to save */ -private[ml] class DefaultParamsWriter(instance: Params) extends Writer { +private[ml] class DefaultParamsWriter(instance: Params) extends MLWriter { override protected def saveImpl(path: String): Unit = { DefaultParamsWriter.saveMetadata(instance, path, sc) @@ -218,13 +228,13 @@ private[ml] object DefaultParamsWriter { } /** - * Default [[Reader]] implementation for transformers and estimators that contain basic + * Default [[MLReader]] implementation for transformers and estimators that contain basic * (json4s-serializable) params and no data. This will not handle more complex params or types with * data (e.g., models with coefficients). * @tparam T ML instance type * TODO: Consider adding check for correct class name. 
*/ -private[ml] class DefaultParamsReader[T] extends Reader[T] { +private[ml] class DefaultParamsReader[T] extends MLReader[T] { override def load(path: String): T = { val metadata = DefaultParamsReader.loadMetadata(path, sc) diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala index 7f5c3895acb0..12aba6bc6dbe 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala @@ -179,8 +179,8 @@ class PipelineSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul } -/** Used to test [[Pipeline]] with [[Writable]] stages */ -class WritableStage(override val uid: String) extends Transformer with Writable { +/** Used to test [[Pipeline]] with [[MLWritable]] stages */ +class WritableStage(override val uid: String) extends Transformer with MLWritable { final val intParam: IntParam = new IntParam(this, "intParam", "doc") @@ -192,21 +192,21 @@ class WritableStage(override val uid: String) extends Transformer with Writable override def copy(extra: ParamMap): WritableStage = defaultCopy(extra) - override def write: Writer = new DefaultParamsWriter(this) + override def write: MLWriter = new DefaultParamsWriter(this) override def transform(dataset: DataFrame): DataFrame = dataset override def transformSchema(schema: StructType): StructType = schema } -object WritableStage extends Readable[WritableStage] { +object WritableStage extends MLReadable[WritableStage] { - override def read: Reader[WritableStage] = new DefaultParamsReader[WritableStage] + override def read: MLReader[WritableStage] = new DefaultParamsReader[WritableStage] - override def load(path: String): WritableStage = read.load(path) + override def load(path: String): WritableStage = super.load(path) } -/** Used to test [[Pipeline]] with non-[[Writable]] stages */ +/** Used to test [[Pipeline]] with non-[[MLWritable]] stages */ class UnWritableStage(override val uid: String) extends Transformer { final val intParam: IntParam = new IntParam(this, "intParam", "doc") diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala index dd1e8acce941..84d06b43d622 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala @@ -38,7 +38,7 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite => * @tparam T ML instance type * @return Instance loaded from file */ - def testDefaultReadWrite[T <: Params with Writable]( + def testDefaultReadWrite[T <: Params with MLWritable]( instance: T, testParams: Boolean = true): T = { val uid = instance.uid @@ -52,7 +52,7 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite => instance.save(path) } instance.write.overwrite().save(path) - val loader = instance.getClass.getMethod("read").invoke(null).asInstanceOf[Reader[T]] + val loader = instance.getClass.getMethod("read").invoke(null).asInstanceOf[MLReader[T]] val newInstance = loader.load(path) assert(newInstance.uid === instance.uid) @@ -92,7 +92,8 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite => * @tparam E Type of [[Estimator]] * @tparam M Type of [[Model]] produced by estimator */ - def testEstimatorAndModelReadWrite[E <: Estimator[M] with Writable, M <: Model[M] with Writable]( + def testEstimatorAndModelReadWrite[ + E <: Estimator[M] with 
MLWritable, M <: Model[M] with MLWritable]( estimator: E, dataset: DataFrame, testParams: Map[String, Any], @@ -119,7 +120,7 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite => } } -class MyParams(override val uid: String) extends Params with Writable { +class MyParams(override val uid: String) extends Params with MLWritable { final val intParamWithDefault: IntParam = new IntParam(this, "intParamWithDefault", "doc") final val intParam: IntParam = new IntParam(this, "intParam", "doc") @@ -145,14 +146,14 @@ class MyParams(override val uid: String) extends Params with Writable { override def copy(extra: ParamMap): Params = defaultCopy(extra) - override def write: Writer = new DefaultParamsWriter(this) + override def write: MLWriter = new DefaultParamsWriter(this) } -object MyParams extends Readable[MyParams] { +object MyParams extends MLReadable[MyParams] { - override def read: Reader[MyParams] = new DefaultParamsReader[MyParams] + override def read: MLReader[MyParams] = new DefaultParamsReader[MyParams] - override def load(path: String): MyParams = read.load(path) + override def load(path: String): MyParams = super.load(path) } class DefaultReadWriteSuite extends SparkFunSuite with MLlibTestSparkContext From e61367b9f9bfc8e123369d55d7ca5925568b98a7 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 18 Nov 2015 18:34:36 -0800 Subject: [PATCH 1441/1824] [SPARK-11833][SQL] Add Java tests for Kryo/Java Dataset encoders Also added some nicer error messages for incompatible types (private types and primitive types) for Kryo/Java encoder. Author: Reynold Xin Closes #9823 from rxin/SPARK-11833. --- .../scala/org/apache/spark/sql/Encoder.scala | 69 +++++++++++------ .../encoders/EncoderErrorMessageSuite.scala | 40 ++++++++++ .../catalyst/encoders/FlatEncoderSuite.scala | 22 ++---- .../apache/spark/sql/JavaDatasetSuite.java | 75 ++++++++++++++++++- 4 files changed, 166 insertions(+), 40 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderErrorMessageSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala index 1ed5111440c8..d54f2854fb33 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql +import java.lang.reflect.Modifier + import scala.reflect.{ClassTag, classTag} import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, encoderFor} @@ -43,30 +45,28 @@ trait Encoder[T] extends Serializable { */ object Encoders { - /** A way to construct encoders using generic serializers. 
*/ - private def genericSerializer[T: ClassTag](useKryo: Boolean): Encoder[T] = { - ExpressionEncoder[T]( - schema = new StructType().add("value", BinaryType), - flat = true, - toRowExpressions = Seq( - EncodeUsingSerializer( - BoundReference(0, ObjectType(classOf[AnyRef]), nullable = true), kryo = useKryo)), - fromRowExpression = - DecodeUsingSerializer[T]( - BoundReference(0, BinaryType, nullable = true), classTag[T], kryo = useKryo), - clsTag = classTag[T] - ) - } + def BOOLEAN: Encoder[java.lang.Boolean] = ExpressionEncoder(flat = true) + def BYTE: Encoder[java.lang.Byte] = ExpressionEncoder(flat = true) + def SHORT: Encoder[java.lang.Short] = ExpressionEncoder(flat = true) + def INT: Encoder[java.lang.Integer] = ExpressionEncoder(flat = true) + def LONG: Encoder[java.lang.Long] = ExpressionEncoder(flat = true) + def FLOAT: Encoder[java.lang.Float] = ExpressionEncoder(flat = true) + def DOUBLE: Encoder[java.lang.Double] = ExpressionEncoder(flat = true) + def STRING: Encoder[java.lang.String] = ExpressionEncoder(flat = true) /** * (Scala-specific) Creates an encoder that serializes objects of type T using Kryo. * This encoder maps T into a single byte array (binary) field. + * + * T must be publicly accessible. */ def kryo[T: ClassTag]: Encoder[T] = genericSerializer(useKryo = true) /** * Creates an encoder that serializes objects of type T using Kryo. * This encoder maps T into a single byte array (binary) field. + * + * T must be publicly accessible. */ def kryo[T](clazz: Class[T]): Encoder[T] = kryo(ClassTag[T](clazz)) @@ -75,6 +75,8 @@ object Encoders { * serialization. This encoder maps T into a single byte array (binary) field. * * Note that this is extremely inefficient and should only be used as the last resort. + * + * T must be publicly accessible. */ def javaSerialization[T: ClassTag]: Encoder[T] = genericSerializer(useKryo = false) @@ -83,17 +85,40 @@ object Encoders { * This encoder maps T into a single byte array (binary) field. * * Note that this is extremely inefficient and should only be used as the last resort. + * + * T must be publicly accessible. */ def javaSerialization[T](clazz: Class[T]): Encoder[T] = javaSerialization(ClassTag[T](clazz)) - def BOOLEAN: Encoder[java.lang.Boolean] = ExpressionEncoder(flat = true) - def BYTE: Encoder[java.lang.Byte] = ExpressionEncoder(flat = true) - def SHORT: Encoder[java.lang.Short] = ExpressionEncoder(flat = true) - def INT: Encoder[java.lang.Integer] = ExpressionEncoder(flat = true) - def LONG: Encoder[java.lang.Long] = ExpressionEncoder(flat = true) - def FLOAT: Encoder[java.lang.Float] = ExpressionEncoder(flat = true) - def DOUBLE: Encoder[java.lang.Double] = ExpressionEncoder(flat = true) - def STRING: Encoder[java.lang.String] = ExpressionEncoder(flat = true) + /** Throws an exception if T is not a public class. */ + private def validatePublicClass[T: ClassTag](): Unit = { + if (!Modifier.isPublic(classTag[T].runtimeClass.getModifiers)) { + throw new UnsupportedOperationException( + s"${classTag[T].runtimeClass.getName} is not a public class. " + + "Only public classes are supported.") + } + } + + /** A way to construct encoders using generic serializers. 
*/ + private def genericSerializer[T: ClassTag](useKryo: Boolean): Encoder[T] = { + if (classTag[T].runtimeClass.isPrimitive) { + throw new UnsupportedOperationException("Primitive types are not supported.") + } + + validatePublicClass[T]() + + ExpressionEncoder[T]( + schema = new StructType().add("value", BinaryType), + flat = true, + toRowExpressions = Seq( + EncodeUsingSerializer( + BoundReference(0, ObjectType(classOf[AnyRef]), nullable = true), kryo = useKryo)), + fromRowExpression = + DecodeUsingSerializer[T]( + BoundReference(0, BinaryType, nullable = true), classTag[T], kryo = useKryo), + clsTag = classTag[T] + ) + } def tuple[T1, T2]( e1: Encoder[T1], diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderErrorMessageSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderErrorMessageSuite.scala new file mode 100644 index 000000000000..0b2a10bb04c1 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderErrorMessageSuite.scala @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.encoders + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.Encoders + + +class EncoderErrorMessageSuite extends SparkFunSuite { + + // Note: we also test error messages for encoders for private classes in JavaDatasetSuite. + // That is done in Java because Scala cannot create truly private classes. 
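For context alongside these tests: the generic encoders map a whole object into a single binary column, and after this change unsupported types are rejected eagerly with a descriptive message instead of failing later. A brief Scala sketch (the throwing calls are left commented out; SomePrivateClass is a hypothetical name):

    import org.apache.spark.sql.Encoders

    // Supported: public, serializable types.
    val kryoStringEncoder = Encoders.kryo[String]
    val javaStringEncoder = Encoders.javaSerialization[String]

    // Rejected with UnsupportedOperationException after this change:
    // Encoders.kryo[Int]                        // primitive type
    // Encoders.javaSerialization[Char]          // primitive type
    // Encoders.kryo(classOf[SomePrivateClass])  // non-public class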
+ + test("primitive types in encoders using Kryo serialization") { + intercept[UnsupportedOperationException] { Encoders.kryo[Int] } + intercept[UnsupportedOperationException] { Encoders.kryo[Long] } + intercept[UnsupportedOperationException] { Encoders.kryo[Char] } + } + + test("primitive types in encoders using Java serialization") { + intercept[UnsupportedOperationException] { Encoders.javaSerialization[Int] } + intercept[UnsupportedOperationException] { Encoders.javaSerialization[Long] } + intercept[UnsupportedOperationException] { Encoders.javaSerialization[Char] } + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoderSuite.scala index 6e0322fb6e01..07523d49f426 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoderSuite.scala @@ -74,24 +74,14 @@ class FlatEncoderSuite extends ExpressionEncoderSuite { FlatEncoder[Map[Int, Map[String, Int]]], "map of map") // Kryo encoders - encodeDecodeTest( - "hello", - encoderFor(Encoders.kryo[String]), - "kryo string") - encodeDecodeTest( - new KryoSerializable(15), - encoderFor(Encoders.kryo[KryoSerializable]), - "kryo object serialization") + encodeDecodeTest("hello", encoderFor(Encoders.kryo[String]), "kryo string") + encodeDecodeTest(new KryoSerializable(15), + encoderFor(Encoders.kryo[KryoSerializable]), "kryo object") // Java encoders - encodeDecodeTest( - "hello", - encoderFor(Encoders.javaSerialization[String]), - "java string") - encodeDecodeTest( - new JavaSerializable(15), - encoderFor(Encoders.javaSerialization[JavaSerializable]), - "java object serialization") + encodeDecodeTest("hello", encoderFor(Encoders.javaSerialization[String]), "java string") + encodeDecodeTest(new JavaSerializable(15), + encoderFor(Encoders.javaSerialization[JavaSerializable]), "java object") } /** For testing Kryo serialization based encoder. 
*/ diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index d9b22506fbd3..ce40dd856f67 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -24,6 +24,7 @@ import scala.Tuple3; import scala.Tuple4; import scala.Tuple5; + import org.junit.*; import org.apache.spark.Accumulator; @@ -410,8 +411,8 @@ public String call(Tuple2 value) throws Exception { .as(Encoders.tuple(Encoders.STRING(), Encoders.INT(), Encoders.LONG(), Encoders.LONG())); Assert.assertEquals( Arrays.asList( - new Tuple4("a", 3, 3L, 2L), - new Tuple4("b", 3, 3L, 1L)), + new Tuple4<>("a", 3, 3L, 2L), + new Tuple4<>("b", 3, 3L, 1L)), agged2.collectAsList()); } @@ -437,4 +438,74 @@ public Integer finish(Integer reduction) { return reduction; } } + + public static class KryoSerializable { + String value; + + KryoSerializable(String value) { + this.value = value; + } + + @Override + public boolean equals(Object other) { + return this.value.equals(((KryoSerializable) other).value); + } + + @Override + public int hashCode() { + return this.value.hashCode(); + } + } + + public static class JavaSerializable implements Serializable { + String value; + + JavaSerializable(String value) { + this.value = value; + } + + @Override + public boolean equals(Object other) { + return this.value.equals(((JavaSerializable) other).value); + } + + @Override + public int hashCode() { + return this.value.hashCode(); + } + } + + @Test + public void testKryoEncoder() { + Encoder encoder = Encoders.kryo(KryoSerializable.class); + List data = Arrays.asList( + new KryoSerializable("hello"), new KryoSerializable("world")); + Dataset ds = context.createDataset(data, encoder); + Assert.assertEquals(data, ds.collectAsList()); + } + + @Test + public void testJavaEncoder() { + Encoder encoder = Encoders.javaSerialization(JavaSerializable.class); + List data = Arrays.asList( + new JavaSerializable("hello"), new JavaSerializable("world")); + Dataset ds = context.createDataset(data, encoder); + Assert.assertEquals(data, ds.collectAsList()); + } + + /** + * For testing error messages when creating an encoder on a private class. This is done + * here since we cannot create truly private classes in Scala. + */ + private static class PrivateClassTest { } + + @Test(expected = UnsupportedOperationException.class) + public void testJavaEncoderErrorMessageForPrivateClass() { + Encoders.javaSerialization(PrivateClassTest.class); + } + + @Test(expected = UnsupportedOperationException.class) + public void testKryoEncoderErrorMessageForPrivateClass() { + Encoders.kryo(PrivateClassTest.class); + } } From 6d0848b53bbe6c5acdcf5c033cd396b1ae6e293d Mon Sep 17 00:00:00 2001 From: Nong Li Date: Wed, 18 Nov 2015 18:38:45 -0800 Subject: [PATCH 1442/1824] [SPARK-11787][SQL] Improve Parquet scan performance when using flat schemas. This patch adds an alternate to the Parquet RecordReader from the parquet-mr project that is much faster for flat schemas. Instead of using the general converter mechanism from parquet-mr, this directly uses the lower level APIs from parquet-columnar and a customer RecordReader that directly assembles into UnsafeRows. This is optionally disabled and only used for supported schemas. 
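The "optionally disabled" behaviour is controlled by a new flag read in SqlNewHadoopRDD (see the diff below): the optimized reader is on by default, is only used for supported (flat, no complex types) schemas, and otherwise falls back to the regular parquet-mr RecordReader. A sketch of forcing the old path, e.g. for comparison (application setup details are assumed):

    import org.apache.spark.{SparkConf, SparkContext}

    val conf = new SparkConf()
      .setAppName("parquet-scan-comparison")
      // Disable the new UnsafeRow-based Parquet record reader.
      .set("spark.parquet.enableUnsafeRowRecordReader", "false")
    val sc = new SparkContext(conf)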
Using the tpcds store sales table and doing a sum of increasingly more columns, the results are: For 1 Column: Before: 11.3M rows/second After: 18.2M rows/second For 2 Columns: Before: 7.2M rows/second After: 11.2M rows/second For 5 Columns: Before: 2.9M rows/second After: 4.5M rows/second Author: Nong Li Closes #9774 from nongli/parquet. --- .../apache/spark/rdd/SqlNewHadoopRDD.scala | 41 +- .../sql/catalyst/expressions/UnsafeRow.java | 9 + .../expressions/codegen/BufferHolder.java | 32 +- .../expressions/codegen/UnsafeRowWriter.java | 20 +- .../SpecificParquetRecordReaderBase.java | 240 +++++++ .../parquet/UnsafeRowParquetRecordReader.java | 593 ++++++++++++++++++ .../parquet/CatalystRowConverter.scala | 48 +- .../parquet/ParquetFilterSuite.scala | 4 +- 8 files changed, 944 insertions(+), 43 deletions(-) create mode 100644 sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java diff --git a/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala index 264dae7f3908..4d176332b69c 100644 --- a/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala @@ -20,8 +20,6 @@ package org.apache.spark.rdd import java.text.SimpleDateFormat import java.util.Date -import scala.reflect.ClassTag - import org.apache.hadoop.conf.{Configurable, Configuration} import org.apache.hadoop.io.Writable import org.apache.hadoop.mapreduce._ @@ -30,10 +28,12 @@ import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.executor.DataReadMethod import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil +import org.apache.spark.storage.StorageLevel import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.util.{Utils, SerializableConfiguration, ShutdownHookManager} import org.apache.spark.{Partition => SparkPartition, _} -import org.apache.spark.storage.StorageLevel -import org.apache.spark.util.{SerializableConfiguration, ShutdownHookManager, Utils} + +import scala.reflect.ClassTag private[spark] class SqlNewHadoopPartition( @@ -96,6 +96,11 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag]( @transient protected val jobId = new JobID(jobTrackerId, id) + // If true, enable using the custom RecordReader for parquet. This only works for + // a subset of the types (no complex types). + protected val enableUnsafeRowParquetReader: Boolean = + sc.conf.getBoolean("spark.parquet.enableUnsafeRowRecordReader", true) + override def getPartitions: Array[SparkPartition] = { val conf = getConf(isDriverSide = true) val inputFormat = inputFormatClass.newInstance @@ -150,9 +155,31 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag]( configurable.setConf(conf) case _ => } - private[this] var reader = format.createRecordReader( - split.serializableHadoopSplit.value, hadoopAttemptContext) - reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext) + private[this] var reader: RecordReader[Void, V] = null + + /** + * If the format is ParquetInputFormat, try to create the optimized RecordReader. If this + * fails (for example, unsupported schema), try with the normal reader. + * TODO: plumb this through a different way? 
+ */ + if (enableUnsafeRowParquetReader && + format.getClass.getName == "org.apache.parquet.hadoop.ParquetInputFormat") { + // TODO: move this class to sql.execution and remove this. + reader = Utils.classForName( + "org.apache.spark.sql.execution.datasources.parquet.UnsafeRowParquetRecordReader") + .newInstance().asInstanceOf[RecordReader[Void, V]] + try { + reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext) + } catch { + case e: Exception => reader = null + } + } + + if (reader == null) { + reader = format.createRecordReader( + split.serializableHadoopSplit.value, hadoopAttemptContext) + reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext) + } // Register an on-task-completion callback to close the input stream. context.addTaskCompletionListener(context => close()) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index 5ba14ebdb62a..33769363a0ed 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -178,6 +178,15 @@ public void pointTo(byte[] buf, int numFields, int sizeInBytes) { pointTo(buf, Platform.BYTE_ARRAY_OFFSET, numFields, sizeInBytes); } + /** + * Updates this UnsafeRow preserving the number of fields. + * @param buf byte array to point to + * @param sizeInBytes the number of bytes valid in the byte array + */ + public void pointTo(byte[] buf, int sizeInBytes) { + pointTo(buf, numFields, sizeInBytes); + } + @Override public void setNullAt(int i) { assertIndexIsValid(i); diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java index 9c9468678065..d26b1b187c27 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java @@ -17,19 +17,28 @@ package org.apache.spark.sql.catalyst.expressions.codegen; +import org.apache.spark.sql.catalyst.expressions.UnsafeRow; import org.apache.spark.unsafe.Platform; /** - * A helper class to manage the row buffer used in `GenerateUnsafeProjection`. - * - * Note that it is only used in `GenerateUnsafeProjection`, so it's safe to mark member variables - * public for ease of use. + * A helper class to manage the row buffer when construct unsafe rows. */ public class BufferHolder { - public byte[] buffer = new byte[64]; + public byte[] buffer; public int cursor = Platform.BYTE_ARRAY_OFFSET; - public void grow(int neededSize) { + public BufferHolder() { + this(64); + } + + public BufferHolder(int size) { + buffer = new byte[size]; + } + + /** + * Grows the buffer to at least neededSize. If row is non-null, points the row to the buffer. + */ + public void grow(int neededSize, UnsafeRow row) { final int length = totalSize() + neededSize; if (buffer.length < length) { // This will not happen frequently, because the buffer is re-used. 
@@ -41,12 +50,23 @@ public void grow(int neededSize) { Platform.BYTE_ARRAY_OFFSET, totalSize()); buffer = tmp; + if (row != null) { + row.pointTo(buffer, length * 2); + } } } + public void grow(int neededSize) { + grow(neededSize, null); + } + public void reset() { cursor = Platform.BYTE_ARRAY_OFFSET; } + public void resetTo(int offset) { + assert(offset <= buffer.length); + cursor = Platform.BYTE_ARRAY_OFFSET + offset; + } public int totalSize() { return cursor - Platform.BYTE_ARRAY_OFFSET; diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java index 048b7749d8fb..e227c0dec974 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java @@ -35,6 +35,7 @@ public class UnsafeRowWriter { // The offset of the global buffer where we start to write this row. private int startingOffset; private int nullBitsSize; + private UnsafeRow row; public void initialize(BufferHolder holder, int numFields) { this.holder = holder; @@ -43,7 +44,7 @@ public void initialize(BufferHolder holder, int numFields) { // grow the global buffer to make sure it has enough space to write fixed-length data. final int fixedSize = nullBitsSize + 8 * numFields; - holder.grow(fixedSize); + holder.grow(fixedSize, row); holder.cursor += fixedSize; // zero-out the null bits region @@ -52,12 +53,19 @@ public void initialize(BufferHolder holder, int numFields) { } } + public void initialize(UnsafeRow row, BufferHolder holder, int numFields) { + initialize(holder, numFields); + this.row = row; + } + private void zeroOutPaddingBytes(int numBytes) { if ((numBytes & 0x07) > 0) { Platform.putLong(holder.buffer, holder.cursor + ((numBytes >> 3) << 3), 0L); } } + public BufferHolder holder() { return holder; } + public boolean isNullAt(int ordinal) { return BitSetMethods.isSet(holder.buffer, startingOffset, ordinal); } @@ -90,7 +98,7 @@ public void alignToWords(int numBytes) { if (remainder > 0) { final int paddingBytes = 8 - remainder; - holder.grow(paddingBytes); + holder.grow(paddingBytes, row); for (int i = 0; i < paddingBytes; i++) { Platform.putByte(holder.buffer, holder.cursor, (byte) 0); @@ -153,7 +161,7 @@ public void write(int ordinal, Decimal input, int precision, int scale) { } } else { // grow the global buffer before writing data. - holder.grow(16); + holder.grow(16, row); // zero-out the bytes Platform.putLong(holder.buffer, holder.cursor, 0L); @@ -185,7 +193,7 @@ public void write(int ordinal, UTF8String input) { final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes); // grow the global buffer before writing data. - holder.grow(roundedSize); + holder.grow(roundedSize, row); zeroOutPaddingBytes(numBytes); @@ -206,7 +214,7 @@ public void write(int ordinal, byte[] input, int offset, int numBytes) { final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes); // grow the global buffer before writing data. - holder.grow(roundedSize); + holder.grow(roundedSize, row); zeroOutPaddingBytes(numBytes); @@ -222,7 +230,7 @@ public void write(int ordinal, byte[] input, int offset, int numBytes) { public void write(int ordinal, CalendarInterval input) { // grow the global buffer before writing data. 
- holder.grow(16); + holder.grow(16, row); // Write the months and microseconds fields of Interval to the variable length portion. Platform.putLong(holder.buffer, holder.cursor, input.months); diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java new file mode 100644 index 000000000000..2ed30c1f5a8d --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java @@ -0,0 +1,240 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +package org.apache.spark.sql.execution.datasources.parquet; + +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import static org.apache.parquet.filter2.compat.RowGroupFilter.filterRowGroups; +import static org.apache.parquet.format.converter.ParquetMetadataConverter.NO_FILTER; +import static org.apache.parquet.format.converter.ParquetMetadataConverter.range; +import static org.apache.parquet.hadoop.ParquetFileReader.readFooter; +import static org.apache.parquet.hadoop.ParquetInputFormat.getFilter; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.mapreduce.InputSplit; +import org.apache.hadoop.mapreduce.RecordReader; +import org.apache.hadoop.mapreduce.TaskAttemptContext; +import org.apache.parquet.bytes.BytesInput; +import org.apache.parquet.bytes.BytesUtils; +import org.apache.parquet.column.ColumnDescriptor; +import org.apache.parquet.column.values.ValuesReader; +import org.apache.parquet.column.values.rle.RunLengthBitPackingHybridDecoder; +import org.apache.parquet.filter2.compat.FilterCompat; +import org.apache.parquet.hadoop.BadConfigurationException; +import org.apache.parquet.hadoop.ParquetFileReader; +import org.apache.parquet.hadoop.ParquetInputFormat; +import org.apache.parquet.hadoop.ParquetInputSplit; +import org.apache.parquet.hadoop.api.InitContext; +import org.apache.parquet.hadoop.api.ReadSupport; +import org.apache.parquet.hadoop.metadata.BlockMetaData; +import org.apache.parquet.hadoop.metadata.ParquetMetadata; +import org.apache.parquet.hadoop.util.ConfigurationUtil; +import org.apache.parquet.schema.MessageType; + +/** + * Base class for custom RecordReaaders for Parquet that directly materialize to `T`. + * This class handles computing row groups, filtering on them, setting up the column readers, + * etc. + * This is heavily based on parquet-mr's RecordReader. 
+ * TODO: move this to the parquet-mr project. There are performance benefits of doing it + * this way, albeit at a higher cost to implement. This base class is reusable. + */ +public abstract class SpecificParquetRecordReaderBase extends RecordReader { + protected Path file; + protected MessageType fileSchema; + protected MessageType requestedSchema; + protected ReadSupport readSupport; + + /** + * The total number of rows this RecordReader will eventually read. The sum of the + * rows of all the row groups. + */ + protected long totalRowCount; + + protected ParquetFileReader reader; + + public void initialize(InputSplit inputSplit, TaskAttemptContext taskAttemptContext) + throws IOException, InterruptedException { + Configuration configuration = taskAttemptContext.getConfiguration(); + ParquetInputSplit split = (ParquetInputSplit)inputSplit; + this.file = split.getPath(); + long[] rowGroupOffsets = split.getRowGroupOffsets(); + + ParquetMetadata footer; + List blocks; + + // if task.side.metadata is set, rowGroupOffsets is null + if (rowGroupOffsets == null) { + // then we need to apply the predicate push down filter + footer = readFooter(configuration, file, range(split.getStart(), split.getEnd())); + MessageType fileSchema = footer.getFileMetaData().getSchema(); + FilterCompat.Filter filter = getFilter(configuration); + blocks = filterRowGroups(filter, footer.getBlocks(), fileSchema); + } else { + // otherwise we find the row groups that were selected on the client + footer = readFooter(configuration, file, NO_FILTER); + Set offsets = new HashSet<>(); + for (long offset : rowGroupOffsets) { + offsets.add(offset); + } + blocks = new ArrayList<>(); + for (BlockMetaData block : footer.getBlocks()) { + if (offsets.contains(block.getStartingPos())) { + blocks.add(block); + } + } + // verify we found them all + if (blocks.size() != rowGroupOffsets.length) { + long[] foundRowGroupOffsets = new long[footer.getBlocks().size()]; + for (int i = 0; i < foundRowGroupOffsets.length; i++) { + foundRowGroupOffsets[i] = footer.getBlocks().get(i).getStartingPos(); + } + // this should never happen. + // provide a good error message in case there's a bug + throw new IllegalStateException( + "All the offsets listed in the split should be found in the file." + + " expected: " + Arrays.toString(rowGroupOffsets) + + " found: " + blocks + + " out of: " + Arrays.toString(foundRowGroupOffsets) + + " in range " + split.getStart() + ", " + split.getEnd()); + } + } + MessageType fileSchema = footer.getFileMetaData().getSchema(); + Map fileMetadata = footer.getFileMetaData().getKeyValueMetaData(); + this.readSupport = getReadSupportInstance( + (Class>) getReadSupportClass(configuration)); + ReadSupport.ReadContext readContext = readSupport.init(new InitContext( + taskAttemptContext.getConfiguration(), toSetMultiMap(fileMetadata), fileSchema)); + this.requestedSchema = readContext.getRequestedSchema(); + this.fileSchema = fileSchema; + this.reader = new ParquetFileReader(configuration, file, blocks, requestedSchema.getColumns()); + for (BlockMetaData block : blocks) { + this.totalRowCount += block.getRowCount(); + } + } + + @Override + public Void getCurrentKey() throws IOException, InterruptedException { + return null; + } + + @Override + public void close() throws IOException { + if (reader != null) { + reader.close(); + reader = null; + } + } + + /** + * Utility classes to abstract over different way to read ints with different encodings. + * TODO: remove this layer of abstraction? 
+ */ + abstract static class IntIterator { + abstract int nextInt() throws IOException; + } + + protected static final class ValuesReaderIntIterator extends IntIterator { + ValuesReader delegate; + + public ValuesReaderIntIterator(ValuesReader delegate) { + this.delegate = delegate; + } + + @Override + int nextInt() throws IOException { + return delegate.readInteger(); + } + } + + protected static final class RLEIntIterator extends IntIterator { + RunLengthBitPackingHybridDecoder delegate; + + public RLEIntIterator(RunLengthBitPackingHybridDecoder delegate) { + this.delegate = delegate; + } + + @Override + int nextInt() throws IOException { + return delegate.readInt(); + } + } + + protected static final class NullIntIterator extends IntIterator { + @Override + int nextInt() throws IOException { return 0; } + } + + /** + * Creates a reader for definition and repetition levels, returning an optimized one if + * the levels are not needed. + */ + static protected IntIterator createRLEIterator(int maxLevel, BytesInput bytes, + ColumnDescriptor descriptor) throws IOException { + try { + if (maxLevel == 0) return new NullIntIterator(); + return new RLEIntIterator( + new RunLengthBitPackingHybridDecoder( + BytesUtils.getWidthFromMaxInt(maxLevel), + new ByteArrayInputStream(bytes.toByteArray()))); + } catch (IOException e) { + throw new IOException("could not read levels in page for col " + descriptor, e); + } + } + + private static Map> toSetMultiMap(Map map) { + Map> setMultiMap = new HashMap<>(); + for (Map.Entry entry : map.entrySet()) { + Set set = new HashSet<>(); + set.add(entry.getValue()); + setMultiMap.put(entry.getKey(), Collections.unmodifiableSet(set)); + } + return Collections.unmodifiableMap(setMultiMap); + } + + private static Class getReadSupportClass(Configuration configuration) { + return ConfigurationUtil.getClassFromConfig(configuration, + ParquetInputFormat.READ_SUPPORT_CLASS, ReadSupport.class); + } + + /** + * @param readSupportClass to instantiate + * @return the configured read support + */ + private static ReadSupport getReadSupportInstance( + Class> readSupportClass){ + try { + return readSupportClass.newInstance(); + } catch (InstantiationException e) { + throw new BadConfigurationException("could not instantiate read support class", e); + } catch (IllegalAccessException e) { + throw new BadConfigurationException("could not instantiate read support class", e); + } + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java new file mode 100644 index 000000000000..8a92e489ccb7 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java @@ -0,0 +1,593 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.parquet; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.List; + +import org.apache.spark.sql.catalyst.expressions.UnsafeRow; +import org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder; +import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter; +import org.apache.spark.sql.types.Decimal; +import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.types.UTF8String; + +import static org.apache.parquet.column.ValuesType.DEFINITION_LEVEL; +import static org.apache.parquet.column.ValuesType.REPETITION_LEVEL; +import static org.apache.parquet.column.ValuesType.VALUES; + +import org.apache.hadoop.mapreduce.InputSplit; +import org.apache.hadoop.mapreduce.TaskAttemptContext; +import org.apache.parquet.Preconditions; +import org.apache.parquet.column.ColumnDescriptor; +import org.apache.parquet.column.Dictionary; +import org.apache.parquet.column.Encoding; +import org.apache.parquet.column.page.DataPage; +import org.apache.parquet.column.page.DataPageV1; +import org.apache.parquet.column.page.DataPageV2; +import org.apache.parquet.column.page.DictionaryPage; +import org.apache.parquet.column.page.PageReadStore; +import org.apache.parquet.column.page.PageReader; +import org.apache.parquet.column.values.ValuesReader; +import org.apache.parquet.io.api.Binary; +import org.apache.parquet.schema.OriginalType; +import org.apache.parquet.schema.PrimitiveType; +import org.apache.parquet.schema.Type; + +/** + * A specialized RecordReader that reads into UnsafeRows directly using the Parquet column APIs. + * + * This is somewhat based on parquet-mr's ColumnReader. + * + * TODO: handle complex types, decimal requiring more than 8 bytes, INT96. Schema mismatch. + * All of these can be handled efficiently and easily with codegen. + */ +public class UnsafeRowParquetRecordReader extends SpecificParquetRecordReaderBase<UnsafeRow> { + /** + * Batch of unsafe rows that we assemble and the current index we've returned. Every time this + * batch is used up (batchIdx == numBatched), we populate the batch. + */ + private UnsafeRow[] rows = new UnsafeRow[64]; + private int batchIdx = 0; + private int numBatched = 0; + + /** + * Used to write variable length columns. Same length as `rows`. + */ + private UnsafeRowWriter[] rowWriters = null; + /** + * True if the row contains variable length fields. + */ + private boolean containsVarLenFields; + + /** + * The number of bytes in the fixed length portion of the row. + */ + private int fixedSizeBytes; + + /** + * For each requested column, the reader to read this column. + * columnReaders[i] populates the UnsafeRow's attribute at i. + */ + private ColumnReader[] columnReaders; + + /** + * The number of rows that have been returned. + */ + private long rowsReturned; + + /** + * The number of rows that have been read, including the current in-flight row group. 
+ */ + private OriginalType[] originalTypes; + + /** + * The default size for varlen columns. The row grows as necessary to accommodate the + * largest column. + */ + private static final int DEFAULT_VAR_LEN_SIZE = 32; + + /** + * Implementation of RecordReader API. + */ + @Override + public void initialize(InputSplit inputSplit, TaskAttemptContext taskAttemptContext) + throws IOException, InterruptedException { + super.initialize(inputSplit, taskAttemptContext); + + /** + * Check that the requested schema is supported. + */ + if (requestedSchema.getFieldCount() == 0) { + // TODO: what does this mean? + throw new IOException("Empty request schema not supported."); + } + int numVarLenFields = 0; + originalTypes = new OriginalType[requestedSchema.getFieldCount()]; + for (int i = 0; i < requestedSchema.getFieldCount(); ++i) { + Type t = requestedSchema.getFields().get(i); + if (!t.isPrimitive() || t.isRepetition(Type.Repetition.REPEATED)) { + throw new IOException("Complex types not supported."); + } + PrimitiveType primitiveType = t.asPrimitiveType(); + + originalTypes[i] = t.getOriginalType(); + + // TODO: Be extremely cautious in what is supported. Expand this. + if (originalTypes[i] != null && originalTypes[i] != OriginalType.DECIMAL && + originalTypes[i] != OriginalType.UTF8 && originalTypes[i] != OriginalType.DATE) { + throw new IOException("Unsupported type: " + t); + } + if (originalTypes[i] == OriginalType.DECIMAL && + primitiveType.getDecimalMetadata().getPrecision() > + CatalystSchemaConverter.MAX_PRECISION_FOR_INT64()) { + throw new IOException("Decimal with high precision is not supported."); + } + if (primitiveType.getPrimitiveTypeName() == PrimitiveType.PrimitiveTypeName.INT96) { + throw new IOException("Int96 not supported."); + } + ColumnDescriptor fd = fileSchema.getColumnDescription(requestedSchema.getPaths().get(i)); + if (!fd.equals(requestedSchema.getColumns().get(i))) { + throw new IOException("Schema evolution not supported."); + } + + if (primitiveType.getPrimitiveTypeName() == PrimitiveType.PrimitiveTypeName.BINARY) { + ++numVarLenFields; + } + } + + /** + * Initialize rows and rowWriters. These objects are reused across all rows in the relation. + */ + int rowByteSize = UnsafeRow.calculateBitSetWidthInBytes(requestedSchema.getFieldCount()); + rowByteSize += 8 * requestedSchema.getFieldCount(); + fixedSizeBytes = rowByteSize; + rowByteSize += numVarLenFields * DEFAULT_VAR_LEN_SIZE; + containsVarLenFields = numVarLenFields > 0; + rowWriters = new UnsafeRowWriter[rows.length]; + + for (int i = 0; i < rows.length; ++i) { + rows[i] = new UnsafeRow(); + rowWriters[i] = new UnsafeRowWriter(); + BufferHolder holder = new BufferHolder(rowByteSize); + rowWriters[i].initialize(rows[i], holder, requestedSchema.getFieldCount()); + rows[i].pointTo(holder.buffer, Platform.BYTE_ARRAY_OFFSET, requestedSchema.getFieldCount(), + holder.buffer.length); + } + } + + @Override + public boolean nextKeyValue() throws IOException, InterruptedException { + if (batchIdx >= numBatched) { + if (!loadBatch()) return false; + } + ++batchIdx; + return true; + } + + @Override + public UnsafeRow getCurrentValue() throws IOException, InterruptedException { + return rows[batchIdx - 1]; + } + + @Override + public float getProgress() throws IOException, InterruptedException { + return (float) rowsReturned / totalRowCount; + } + + /** + * Decodes a batch of values into `rows`. This function is the hot path. 
+ */ + private boolean loadBatch() throws IOException { + // no more records left + if (rowsReturned >= totalRowCount) { return false; } + checkEndOfRowGroup(); + + int num = (int)Math.min(rows.length, totalCountLoadedSoFar - rowsReturned); + rowsReturned += num; + + if (containsVarLenFields) { + for (int i = 0; i < rowWriters.length; ++i) { + rowWriters[i].holder().resetTo(fixedSizeBytes); + } + } + + for (int i = 0; i < columnReaders.length; ++i) { + switch (columnReaders[i].descriptor.getType()) { + case BOOLEAN: + decodeBooleanBatch(i, num); + break; + case INT32: + if (originalTypes[i] == OriginalType.DECIMAL) { + decodeIntAsDecimalBatch(i, num); + } else { + decodeIntBatch(i, num); + } + break; + case INT64: + Preconditions.checkState(originalTypes[i] == null + || originalTypes[i] == OriginalType.DECIMAL, + "Unexpected original type: " + originalTypes[i]); + decodeLongBatch(i, num); + break; + case FLOAT: + decodeFloatBatch(i, num); + break; + case DOUBLE: + decodeDoubleBatch(i, num); + break; + case BINARY: + decodeBinaryBatch(i, num); + break; + case FIXED_LEN_BYTE_ARRAY: + Preconditions.checkState(originalTypes[i] == OriginalType.DECIMAL, + "Unexpected original type: " + originalTypes[i]); + decodeFixedLenArrayAsDecimalBatch(i, num); + break; + case INT96: + throw new IOException("Unsupported " + columnReaders[i].descriptor.getType()); + } + numBatched = num; + batchIdx = 0; + } + return true; + } + + private void decodeBooleanBatch(int col, int num) throws IOException { + for (int n = 0; n < num; ++n) { + if (columnReaders[col].next()) { + rows[n].setBoolean(col, columnReaders[col].nextBoolean()); + } else { + rows[n].setNullAt(col); + } + } + } + + private void decodeIntBatch(int col, int num) throws IOException { + for (int n = 0; n < num; ++n) { + if (columnReaders[col].next()) { + rows[n].setInt(col, columnReaders[col].nextInt()); + } else { + rows[n].setNullAt(col); + } + } + } + + private void decodeIntAsDecimalBatch(int col, int num) throws IOException { + for (int n = 0; n < num; ++n) { + if (columnReaders[col].next()) { + // Since this is stored as an INT, it is always a compact decimal. Just set it as a long. 
+ rows[n].setLong(col, columnReaders[col].nextInt()); + } else { + rows[n].setNullAt(col); + } + } + } + + private void decodeLongBatch(int col, int num) throws IOException { + for (int n = 0; n < num; ++n) { + if (columnReaders[col].next()) { + rows[n].setLong(col, columnReaders[col].nextLong()); + } else { + rows[n].setNullAt(col); + } + } + } + + private void decodeFloatBatch(int col, int num) throws IOException { + for (int n = 0; n < num; ++n) { + if (columnReaders[col].next()) { + rows[n].setFloat(col, columnReaders[col].nextFloat()); + } else { + rows[n].setNullAt(col); + } + } + } + + private void decodeDoubleBatch(int col, int num) throws IOException { + for (int n = 0; n < num; ++n) { + if (columnReaders[col].next()) { + rows[n].setDouble(col, columnReaders[col].nextDouble()); + } else { + rows[n].setNullAt(col); + } + } + } + + private void decodeBinaryBatch(int col, int num) throws IOException { + for (int n = 0; n < num; ++n) { + if (columnReaders[col].next()) { + ByteBuffer bytes = columnReaders[col].nextBinary().toByteBuffer(); + int len = bytes.limit() - bytes.position(); + if (originalTypes[col] == OriginalType.UTF8) { + UTF8String str = UTF8String.fromBytes(bytes.array(), bytes.position(), len); + rowWriters[n].write(col, str); + } else { + rowWriters[n].write(col, bytes.array(), bytes.position(), len); + } + } else { + rows[n].setNullAt(col); + } + } + } + + private void decodeFixedLenArrayAsDecimalBatch(int col, int num) throws IOException { + PrimitiveType type = requestedSchema.getFields().get(col).asPrimitiveType(); + int precision = type.getDecimalMetadata().getPrecision(); + int scale = type.getDecimalMetadata().getScale(); + Preconditions.checkState(precision <= CatalystSchemaConverter.MAX_PRECISION_FOR_INT64(), + "Unsupported precision."); + + for (int n = 0; n < num; ++n) { + if (columnReaders[col].next()) { + Binary v = columnReaders[col].nextBinary(); + // Constructs a `Decimal` with an unscaled `Long` value if possible. + long unscaled = CatalystRowConverter.binaryToUnscaledLong(v); + rows[n].setDecimal(col, Decimal.apply(unscaled, precision, scale), precision); + } else { + rows[n].setNullAt(col); + } + } + } + + /** + * + * Decoder to return values from a single column. + */ + private static final class ColumnReader { + /** + * Total number of values read. + */ + private long valuesRead; + + /** + * value that indicates the end of the current page. That is, + * if valuesRead == endOfPageValueCount, we are at the end of the page. + */ + private long endOfPageValueCount; + + /** + * The dictionary, if this column has dictionary encoding. + */ + private final Dictionary dictionary; + + /** + * If true, the current page is dictionary encoded. + */ + private boolean useDictionary; + + /** + * Maximum definition level for this column. + */ + private final int maxDefLevel; + + /** + * Repetition/Definition/Value readers. + */ + private IntIterator repetitionLevelColumn; + private IntIterator definitionLevelColumn; + private ValuesReader dataColumn; + + /** + * Total number of values in this column (in this row group). + */ + private final long totalValueCount; + + /** + * Total values in the current page. 
+ */ + private int pageValueCount; + + private final PageReader pageReader; + private final ColumnDescriptor descriptor; + + public ColumnReader(ColumnDescriptor descriptor, PageReader pageReader) + throws IOException { + this.descriptor = descriptor; + this.pageReader = pageReader; + this.maxDefLevel = descriptor.getMaxDefinitionLevel(); + + DictionaryPage dictionaryPage = pageReader.readDictionaryPage(); + if (dictionaryPage != null) { + try { + this.dictionary = dictionaryPage.getEncoding().initDictionary(descriptor, dictionaryPage); + this.useDictionary = true; + } catch (IOException e) { + throw new IOException("could not decode the dictionary for " + descriptor, e); + } + } else { + this.dictionary = null; + this.useDictionary = false; + } + this.totalValueCount = pageReader.getTotalValueCount(); + if (totalValueCount == 0) { + throw new IOException("totalValueCount == 0"); + } + } + + /** + * TODO: Hoist the useDictionary branch to decode*Batch and make the batch page aligned. + */ + public boolean nextBoolean() { + if (!useDictionary) { + return dataColumn.readBoolean(); + } else { + return dictionary.decodeToBoolean(dataColumn.readValueDictionaryId()); + } + } + + public int nextInt() { + if (!useDictionary) { + return dataColumn.readInteger(); + } else { + return dictionary.decodeToInt(dataColumn.readValueDictionaryId()); + } + } + + public long nextLong() { + if (!useDictionary) { + return dataColumn.readLong(); + } else { + return dictionary.decodeToLong(dataColumn.readValueDictionaryId()); + } + } + + public float nextFloat() { + if (!useDictionary) { + return dataColumn.readFloat(); + } else { + return dictionary.decodeToFloat(dataColumn.readValueDictionaryId()); + } + } + + public double nextDouble() { + if (!useDictionary) { + return dataColumn.readDouble(); + } else { + return dictionary.decodeToDouble(dataColumn.readValueDictionaryId()); + } + } + + public Binary nextBinary() { + if (!useDictionary) { + return dataColumn.readBytes(); + } else { + return dictionary.decodeToBinary(dataColumn.readValueDictionaryId()); + } + } + + /** + * Advances to the next value. Returns true if the value is non-null. + */ + private boolean next() throws IOException { + if (valuesRead >= endOfPageValueCount) { + if (valuesRead >= totalValueCount) { + // How do we get here? Throw end of stream exception? + return false; + } + readPage(); + } + ++valuesRead; + // TODO: Don't read for flat schemas + //repetitionLevel = repetitionLevelColumn.nextInt(); + return definitionLevelColumn.nextInt() == maxDefLevel; + } + + private void readPage() throws IOException { + DataPage page = pageReader.readPage(); + // TODO: Why is this a visitor? 
+ page.accept(new DataPage.Visitor() { + @Override + public Void visit(DataPageV1 dataPageV1) { + try { + readPageV1(dataPageV1); + return null; + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + @Override + public Void visit(DataPageV2 dataPageV2) { + try { + readPageV2(dataPageV2); + return null; + } catch (IOException e) { + throw new RuntimeException(e); + } + } + }); + } + + private void initDataReader(Encoding dataEncoding, byte[] bytes, int offset, int valueCount) + throws IOException { + this.pageValueCount = valueCount; + this.endOfPageValueCount = valuesRead + pageValueCount; + if (dataEncoding.usesDictionary()) { + if (dictionary == null) { + throw new IOException( + "could not read page in col " + descriptor + + " as the dictionary was missing for encoding " + dataEncoding); + } + this.dataColumn = dataEncoding.getDictionaryBasedValuesReader( + descriptor, VALUES, dictionary); + this.useDictionary = true; + } else { + this.dataColumn = dataEncoding.getValuesReader(descriptor, VALUES); + this.useDictionary = false; + } + + try { + dataColumn.initFromPage(pageValueCount, bytes, offset); + } catch (IOException e) { + throw new IOException("could not read page in col " + descriptor, e); + } + } + + private void readPageV1(DataPageV1 page) throws IOException { + ValuesReader rlReader = page.getRlEncoding().getValuesReader(descriptor, REPETITION_LEVEL); + ValuesReader dlReader = page.getDlEncoding().getValuesReader(descriptor, DEFINITION_LEVEL); + this.repetitionLevelColumn = new ValuesReaderIntIterator(rlReader); + this.definitionLevelColumn = new ValuesReaderIntIterator(dlReader); + try { + byte[] bytes = page.getBytes().toByteArray(); + rlReader.initFromPage(pageValueCount, bytes, 0); + int next = rlReader.getNextOffset(); + dlReader.initFromPage(pageValueCount, bytes, next); + next = dlReader.getNextOffset(); + initDataReader(page.getValueEncoding(), bytes, next, page.getValueCount()); + } catch (IOException e) { + throw new IOException("could not read page " + page + " in col " + descriptor, e); + } + } + + private void readPageV2(DataPageV2 page) throws IOException { + this.repetitionLevelColumn = createRLEIterator(descriptor.getMaxRepetitionLevel(), + page.getRepetitionLevels(), descriptor); + this.definitionLevelColumn = createRLEIterator(descriptor.getMaxDefinitionLevel(), + page.getDefinitionLevels(), descriptor); + try { + initDataReader(page.getDataEncoding(), page.getData().toByteArray(), 0, + page.getValueCount()); + } catch (IOException e) { + throw new IOException("could not read page " + page + " in col " + descriptor, e); + } + } + } + + private void checkEndOfRowGroup() throws IOException { + if (rowsReturned != totalCountLoadedSoFar) return; + PageReadStore pages = reader.readNextRowGroup(); + if (pages == null) { + throw new IOException("expecting more rows but reached last block. 
Read " + + rowsReturned + " out of " + totalRowCount); + } + List columns = requestedSchema.getColumns(); + columnReaders = new ColumnReader[columns.size()]; + for (int i = 0; i < columns.size(); ++i) { + columnReaders[i] = new ColumnReader(columns.get(i), pages.getPageReader(columns.get(i))); + } + totalCountLoadedSoFar += pages.getRowCount(); + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala index 1f653cd3d3cb..94298fae2d69 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala @@ -370,35 +370,13 @@ private[parquet] class CatalystRowConverter( protected def decimalFromBinary(value: Binary): Decimal = { if (precision <= CatalystSchemaConverter.MAX_PRECISION_FOR_INT64) { // Constructs a `Decimal` with an unscaled `Long` value if possible. - val unscaled = binaryToUnscaledLong(value) + val unscaled = CatalystRowConverter.binaryToUnscaledLong(value) Decimal(unscaled, precision, scale) } else { // Otherwise, resorts to an unscaled `BigInteger` instead. Decimal(new BigDecimal(new BigInteger(value.getBytes), scale), precision, scale) } } - - private def binaryToUnscaledLong(binary: Binary): Long = { - // The underlying `ByteBuffer` implementation is guaranteed to be `HeapByteBuffer`, so here - // we are using `Binary.toByteBuffer.array()` to steal the underlying byte array without - // copying it. - val buffer = binary.toByteBuffer - val bytes = buffer.array() - val start = buffer.position() - val end = buffer.limit() - - var unscaled = 0L - var i = start - - while (i < end) { - unscaled = (unscaled << 8) | (bytes(i) & 0xff) - i += 1 - } - - val bits = 8 * (end - start) - unscaled = (unscaled << (64 - bits)) >> (64 - bits) - unscaled - } } private class CatalystIntDictionaryAwareDecimalConverter( @@ -658,3 +636,27 @@ private[parquet] class CatalystRowConverter( override def start(): Unit = elementConverter.start() } } + +private[parquet] object CatalystRowConverter { + def binaryToUnscaledLong(binary: Binary): Long = { + // The underlying `ByteBuffer` implementation is guaranteed to be `HeapByteBuffer`, so here + // we are using `Binary.toByteBuffer.array()` to steal the underlying byte array without + // copying it. 
+ val buffer = binary.toByteBuffer + val bytes = buffer.array() + val start = buffer.position() + val end = buffer.limit() + + var unscaled = 0L + var i = start + + while (i < end) { + unscaled = (unscaled << 8) | (bytes(i) & 0xff) + i += 1 + } + + val bits = 8 * (end - start) + unscaled = (unscaled << (64 - bits)) >> (64 - bits) + unscaled + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index 458786f77af3..c8028a5ef552 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -337,7 +337,9 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex } } - test("SPARK-11661 Still pushdown filters returned by unhandledFilters") { + // Renable when we can toggle custom ParquetRecordReader on/off. The custom reader does + // not do row by row filtering (and we probably don't want to push that). + ignore("SPARK-11661 Still pushdown filters returned by unhandledFilters") { import testImplicits._ withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true") { withTempPath { dir => From 9c0654d36c6d171dd273850c2cc2f415cc2a5a6b Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Wed, 18 Nov 2015 18:41:40 -0800 Subject: [PATCH 1443/1824] Revert "[SPARK-11544][SQL] sqlContext doesn't use PathFilter" This reverts commit 54db79702513e11335c33bcf3a03c59e965e6f16. --- .../apache/spark/sql/sources/interfaces.scala | 25 +++---------- .../datasources/json/JsonSuite.scala | 36 ++----------------- 2 files changed, 7 insertions(+), 54 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index f9465157c936..b3d3bdf50df6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -21,8 +21,7 @@ import scala.collection.mutable import scala.util.Try import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{PathFilter, FileStatus, FileSystem, Path} -import org.apache.hadoop.mapred.{JobConf, FileInputFormat} +import org.apache.hadoop.fs.{FileStatus, FileSystem, Path} import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext} import org.apache.spark.{Logging, SparkContext} @@ -448,15 +447,9 @@ abstract class HadoopFsRelation private[sql]( val hdfsPath = new Path(path) val fs = hdfsPath.getFileSystem(hadoopConf) val qualified = hdfsPath.makeQualified(fs.getUri, fs.getWorkingDirectory) + logInfo(s"Listing $qualified on driver") - // Dummy jobconf to get to the pathFilter defined in configuration - val jobConf = new JobConf(hadoopConf, this.getClass()) - val pathFilter = FileInputFormat.getInputPathFilter(jobConf) - if (pathFilter != null) { - Try(fs.listStatus(qualified, pathFilter)).getOrElse(Array.empty) - } else { - Try(fs.listStatus(qualified)).getOrElse(Array.empty) - } + Try(fs.listStatus(qualified)).getOrElse(Array.empty) }.filterNot { status => val name = status.getPath.getName name.toLowerCase == "_temporary" || name.startsWith(".") @@ -854,16 +847,8 @@ private[sql] object HadoopFsRelation extends Logging { if (name == "_temporary" || name.startsWith(".")) { Array.empty } else { - // Dummy jobconf to get to the pathFilter 
defined in configuration - val jobConf = new JobConf(fs.getConf, this.getClass()) - val pathFilter = FileInputFormat.getInputPathFilter(jobConf) - if (pathFilter != null) { - val (dirs, files) = fs.listStatus(status.getPath, pathFilter).partition(_.isDir) - files ++ dirs.flatMap(dir => listLeafFiles(fs, dir)) - } else { - val (dirs, files) = fs.listStatus(status.getPath).partition(_.isDir) - files ++ dirs.flatMap(dir => listLeafFiles(fs, dir)) - } + val (dirs, files) = fs.listStatus(status.getPath).partition(_.isDir) + files ++ dirs.flatMap(dir => listLeafFiles(fs, dir)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index f09b61e83815..6042b1178aff 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -19,27 +19,19 @@ package org.apache.spark.sql.execution.datasources.json import java.io.{File, StringWriter} import java.sql.{Date, Timestamp} -import scala.collection.JavaConverters._ import com.fasterxml.jackson.core.JsonFactory -import org.apache.commons.io.FileUtils -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{Path, PathFilter} +import org.apache.spark.rdd.RDD import org.scalactic.Tolerance._ -import org.apache.spark.rdd.RDD import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.execution.datasources.{LogicalRelation, ResolvedDataSource} +import org.apache.spark.sql.execution.datasources.{ResolvedDataSource, LogicalRelation} import org.apache.spark.sql.execution.datasources.json.InferSchema.compatibleType import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.util.Utils -class TestFileFilter extends PathFilter { - override def accept(path: Path): Boolean = path.getParent.getName != "p=2" -} - class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { import testImplicits._ @@ -1398,28 +1390,4 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { ) } } - - test("SPARK-11544 test pathfilter") { - withTempPath { dir => - val path = dir.getCanonicalPath - - val df = sqlContext.range(2) - df.write.json(path + "/p=1") - df.write.json(path + "/p=2") - assert(sqlContext.read.json(path).count() === 4) - - val clonedConf = new Configuration(hadoopConfiguration) - try { - hadoopConfiguration.setClass( - "mapreduce.input.pathFilter.class", - classOf[TestFileFilter], - classOf[PathFilter]) - assert(sqlContext.read.json(path).count() === 2) - } finally { - // Hadoop 1 doesn't have `Configuration.unset` - hadoopConfiguration.clear() - clonedConf.asScala.foreach(entry => hadoopConfiguration.set(entry.getKey, entry.getValue)) - } - } - } } From 67c75828ff4df2e305bdf5d6be5a11201d1da3f3 Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Wed, 18 Nov 2015 18:49:46 -0800 Subject: [PATCH 1444/1824] [SPARK-11816][ML] fix some style issue in ML/MLlib examples jira: https://issues.apache.org/jira/browse/SPARK-11816 Currently I only fixed some obvious comments issue like // scalastyle:off println on the bottom. Yet the style in examples is not quite consistent, like only half of the examples are with // Example usage: ./bin/run-example mllib.FPGrowthExample \, Author: Yuhao Yang Closes #9808 from hhbyyh/exampleStyle. 
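For reference, the convention these example files aim for is to switch the `println` scalastyle check off near the top of the file and switch it back on at the very end; the files touched below either ended with `// scalastyle:off println` again or were missing the closing directive. A minimal sketch of the intended pairing (an illustrative object, not one of the actual examples):

```
// scalastyle:off println
package org.apache.spark.examples.ml

// Hypothetical example object, shown only to illustrate the off/on pairing.
object ScalastyleSketchExample {
  def main(args: Array[String]): Unit = {
    // println is permitted between the off/on markers without a style violation.
    println("example output")
  }
}
// scalastyle:on println
```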
--- .../java/org/apache/spark/examples/ml/JavaKMeansExample.java | 2 +- .../apache/spark/examples/ml/AFTSurvivalRegressionExample.scala | 2 +- .../spark/examples/ml/DecisionTreeClassificationExample.scala | 1 + .../spark/examples/ml/DecisionTreeRegressionExample.scala | 1 + .../examples/ml/MultilayerPerceptronClassifierExample.scala | 2 +- 5 files changed, 5 insertions(+), 3 deletions(-) diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaKMeansExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaKMeansExample.java index be2bf0c7b465..47665ff2b1f3 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaKMeansExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaKMeansExample.java @@ -41,7 +41,7 @@ * An example demonstrating a k-means clustering. * Run with *
    - * bin/run-example ml.JavaSimpleParamsExample  
    + * bin/run-example ml.JavaKMeansExample  
      * 
    */ public class JavaKMeansExample { diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/AFTSurvivalRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/AFTSurvivalRegressionExample.scala index 5da285e83681..f4b3613ccb94 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/AFTSurvivalRegressionExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/AFTSurvivalRegressionExample.scala @@ -59,4 +59,4 @@ object AFTSurvivalRegressionExample { sc.stop() } } -// scalastyle:off println +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeClassificationExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeClassificationExample.scala index ff8a0a90f1e4..db024b5cad93 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeClassificationExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeClassificationExample.scala @@ -90,3 +90,4 @@ object DecisionTreeClassificationExample { // $example off$ } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeRegressionExample.scala index fc402724d215..ad01f55df72b 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeRegressionExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeRegressionExample.scala @@ -78,3 +78,4 @@ object DecisionTreeRegressionExample { // $example off$ } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/MultilayerPerceptronClassifierExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/MultilayerPerceptronClassifierExample.scala index 146b83c8be49..9c98076bd24b 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/MultilayerPerceptronClassifierExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/MultilayerPerceptronClassifierExample.scala @@ -66,4 +66,4 @@ object MultilayerPerceptronClassifierExample { sc.stop() } } -// scalastyle:off println +// scalastyle:on println From fc3f77b42d62ca789d0ee07403795978961991c7 Mon Sep 17 00:00:00 2001 From: "navis.ryu" Date: Wed, 18 Nov 2015 19:37:14 -0800 Subject: [PATCH 1445/1824] [SPARK-11614][SQL] serde parameters should be set only when all params are ready see HIVE-7975 and HIVE-12373 With the changed semantics of setters on thrift objects in Hive, setters should be called only after all parameters are set. This is not a problem in the current state, but it may become one some day. Author: navis.ryu Closes #9580 from navis/SPARK-11614. 
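To make the ordering constraint concrete, a small sketch of the pattern this change enforces: build the complete parameter map first and only then hand it to the thrift-generated `SerDeInfo`. The serde class and parameter keys below are illustrative, not taken from the patch.

```
import java.util.{HashMap => JHashMap}

import org.apache.hadoop.hive.metastore.api.SerDeInfo

object SerdeParamsSketch {
  def main(args: Array[String]): Unit = {
    val serdeInfo = new SerDeInfo()
    serdeInfo.setSerializationLib("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe")

    // Populate the map completely before attaching it...
    val serdeParameters = new JHashMap[String, String]()
    serdeParameters.put("serialization.format", "1")
    serdeParameters.put("field.delim", ",")

    // ...and call the setter last, since maps and lists should be set only
    // after all elements are ready (see HIVE-7975).
    serdeInfo.setParameters(serdeParameters)

    println(serdeInfo.getParameters)
  }
}
```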
--- .../scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index f4d45714fae4..9a981d02ad67 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -804,12 +804,13 @@ private[hive] case class MetastoreRelation val serdeInfo = new org.apache.hadoop.hive.metastore.api.SerDeInfo sd.setSerdeInfo(serdeInfo) + // maps and lists should be set only after all elements are ready (see HIVE-7975) serdeInfo.setSerializationLib(p.storage.serde) val serdeParameters = new java.util.HashMap[String, String]() - serdeInfo.setParameters(serdeParameters) table.serdeProperties.foreach { case (k, v) => serdeParameters.put(k, v) } p.storage.serdeProperties.foreach { case (k, v) => serdeParameters.put(k, v) } + serdeInfo.setParameters(serdeParameters) new Partition(hiveQlTable, tPartition) } From d02d5b9295b169c3ebb0967453b2835edb8a121f Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Wed, 18 Nov 2015 21:44:01 -0800 Subject: [PATCH 1446/1824] [SPARK-11842][ML] Small cleanups to existing Readers and Writers Updates: * Add repartition(1) to save() methods' saving of data for LogisticRegressionModel, LinearRegressionModel. * Strengthen privacy to class and companion object for Writers and Readers * Change LogisticRegressionSuite read/write test to fit intercept * Add Since versions for read/write methods in Pipeline, LogisticRegression * Switch from hand-written class names in Readers to using getClass CC: mengxr CC: yanboliang Would you mind taking a look at this PR? mengxr might not be able to soon. Thank you! Author: Joseph K. Bradley Closes #9829 from jkbradley/ml-io-cleanups. 
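As a usage-level illustration of the writer/reader pairs being tightened up in this patch, a minimal sketch of saving and loading a fitted model through the save/load API; the dataset, output path, and local master setting are placeholders for running the sketch, not part of the change.

```
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.sql.SQLContext

object ModelPersistenceSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("persistence-sketch").setMaster("local[2]"))
    val sqlContext = new SQLContext(sc)
    import sqlContext.implicits._

    // Toy training data; "label" and "features" are the default ML column names.
    val training = sc.parallelize(Seq(
      (1.0, Vectors.dense(0.0, 1.1)),
      (0.0, Vectors.dense(2.0, 1.0)),
      (1.0, Vectors.dense(0.1, 1.2))
    )).toDF("label", "features")

    val model = new LogisticRegression().setMaxIter(5).fit(training)

    // MLWriter: persists params/metadata plus the model data (written as a
    // single-partition Parquet file, per this patch).
    val path = "/tmp/spark-lr-model-sketch" // placeholder path
    model.write.overwrite().save(path)

    // MLReader: load the model back through the companion object.
    val restored = LogisticRegressionModel.load(path)
    println(restored.coefficients)

    sc.stop()
  }
}
```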
--- .../scala/org/apache/spark/ml/Pipeline.scala | 22 +++++++++++++------ .../classification/LogisticRegression.scala | 19 ++++++++++------ .../spark/ml/feature/CountVectorizer.scala | 2 +- .../org/apache/spark/ml/feature/IDF.scala | 2 +- .../spark/ml/feature/MinMaxScaler.scala | 2 +- .../spark/ml/feature/StandardScaler.scala | 2 +- .../spark/ml/feature/StringIndexer.scala | 2 +- .../apache/spark/ml/recommendation/ALS.scala | 6 ++--- .../ml/regression/LinearRegression.scala | 4 ++-- .../LogisticRegressionSuite.scala | 2 +- 10 files changed, 38 insertions(+), 25 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala index b0f22e042ec5..6f15b37abcb3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala @@ -27,7 +27,7 @@ import org.json4s._ import org.json4s.jackson.JsonMethods._ import org.apache.spark.{SparkContext, Logging} -import org.apache.spark.annotation.{DeveloperApi, Experimental} +import org.apache.spark.annotation.{Since, DeveloperApi, Experimental} import org.apache.spark.ml.param.{Param, ParamMap, Params} import org.apache.spark.ml.util.MLReader import org.apache.spark.ml.util.MLWriter @@ -174,16 +174,20 @@ class Pipeline(override val uid: String) extends Estimator[PipelineModel] with M theStages.foldLeft(schema)((cur, stage) => stage.transformSchema(cur)) } + @Since("1.6.0") override def write: MLWriter = new Pipeline.PipelineWriter(this) } +@Since("1.6.0") object Pipeline extends MLReadable[Pipeline] { + @Since("1.6.0") override def read: MLReader[Pipeline] = new PipelineReader + @Since("1.6.0") override def load(path: String): Pipeline = super.load(path) - private[ml] class PipelineWriter(instance: Pipeline) extends MLWriter { + private[Pipeline] class PipelineWriter(instance: Pipeline) extends MLWriter { SharedReadWrite.validateStages(instance.getStages) @@ -191,10 +195,10 @@ object Pipeline extends MLReadable[Pipeline] { SharedReadWrite.saveImpl(instance, instance.getStages, sc, path) } - private[ml] class PipelineReader extends MLReader[Pipeline] { + private class PipelineReader extends MLReader[Pipeline] { /** Checked against metadata when loading model */ - private val className = "org.apache.spark.ml.Pipeline" + private val className = classOf[Pipeline].getName override def load(path: String): Pipeline = { val (uid: String, stages: Array[PipelineStage]) = SharedReadWrite.load(className, sc, path) @@ -333,18 +337,22 @@ class PipelineModel private[ml] ( new PipelineModel(uid, stages.map(_.copy(extra))).setParent(parent) } + @Since("1.6.0") override def write: MLWriter = new PipelineModel.PipelineModelWriter(this) } +@Since("1.6.0") object PipelineModel extends MLReadable[PipelineModel] { import Pipeline.SharedReadWrite + @Since("1.6.0") override def read: MLReader[PipelineModel] = new PipelineModelReader + @Since("1.6.0") override def load(path: String): PipelineModel = super.load(path) - private[ml] class PipelineModelWriter(instance: PipelineModel) extends MLWriter { + private[PipelineModel] class PipelineModelWriter(instance: PipelineModel) extends MLWriter { SharedReadWrite.validateStages(instance.stages.asInstanceOf[Array[PipelineStage]]) @@ -352,10 +360,10 @@ object PipelineModel extends MLReadable[PipelineModel] { instance.stages.asInstanceOf[Array[PipelineStage]], sc, path) } - private[ml] class PipelineModelReader extends MLReader[PipelineModel] { + private class PipelineModelReader extends 
MLReader[PipelineModel] { /** Checked against metadata when loading model */ - private val className = "org.apache.spark.ml.PipelineModel" + private val className = classOf[PipelineModel].getName override def load(path: String): PipelineModel = { val (uid: String, stages: Array[PipelineStage]) = SharedReadWrite.load(className, sc, path) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index a3cc49f7f018..418bbdc9a058 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -24,7 +24,7 @@ import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS, import org.apache.hadoop.fs.Path import org.apache.spark.{Logging, SparkException} -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Since, Experimental} import org.apache.spark.ml.feature.Instance import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ @@ -525,18 +525,23 @@ class LogisticRegressionModel private[ml] ( * * This also does not save the [[parent]] currently. */ + @Since("1.6.0") override def write: MLWriter = new LogisticRegressionModel.LogisticRegressionModelWriter(this) } +@Since("1.6.0") object LogisticRegressionModel extends MLReadable[LogisticRegressionModel] { + @Since("1.6.0") override def read: MLReader[LogisticRegressionModel] = new LogisticRegressionModelReader + @Since("1.6.0") override def load(path: String): LogisticRegressionModel = super.load(path) /** [[MLWriter]] instance for [[LogisticRegressionModel]] */ - private[classification] class LogisticRegressionModelWriter(instance: LogisticRegressionModel) + private[LogisticRegressionModel] + class LogisticRegressionModelWriter(instance: LogisticRegressionModel) extends MLWriter with Logging { private case class Data( @@ -552,15 +557,15 @@ object LogisticRegressionModel extends MLReadable[LogisticRegressionModel] { val data = Data(instance.numClasses, instance.numFeatures, instance.intercept, instance.coefficients) val dataPath = new Path(path, "data").toString - sqlContext.createDataFrame(Seq(data)).write.format("parquet").save(dataPath) + sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) } } - private[classification] class LogisticRegressionModelReader + private class LogisticRegressionModelReader extends MLReader[LogisticRegressionModel] { /** Checked against metadata when loading model */ - private val className = "org.apache.spark.ml.classification.LogisticRegressionModel" + private val className = classOf[LogisticRegressionModel].getName override def load(path: String): LogisticRegressionModel = { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) @@ -603,7 +608,7 @@ private[classification] class MultiClassSummarizer extends Serializable { * @return This MultilabelSummarizer */ def add(label: Double, weight: Double = 1.0): this.type = { - require(weight >= 0.0, s"instance weight, ${weight} has to be >= 0.0") + require(weight >= 0.0, s"instance weight, $weight has to be >= 0.0") if (weight == 0.0) return this @@ -839,7 +844,7 @@ private class LogisticAggregator( instance match { case Instance(label, weight, features) => require(dim == features.size, s"Dimensions mismatch when adding new instance." 
+ s" Expecting $dim but got ${features.size}.") - require(weight >= 0.0, s"instance weight, ${weight} has to be >= 0.0") + require(weight >= 0.0, s"instance weight, $weight has to be >= 0.0") if (weight == 0.0) return this diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala index 4969cf42450d..b9e2144c0ad4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala @@ -266,7 +266,7 @@ object CountVectorizerModel extends MLReadable[CountVectorizerModel] { private class CountVectorizerModelReader extends MLReader[CountVectorizerModel] { - private val className = "org.apache.spark.ml.feature.CountVectorizerModel" + private val className = classOf[CountVectorizerModel].getName override def load(path: String): CountVectorizerModel = { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala index 0e00ef6f2ee2..f7b0f29a27c2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala @@ -155,7 +155,7 @@ object IDFModel extends MLReadable[IDFModel] { private class IDFModelReader extends MLReader[IDFModel] { - private val className = "org.apache.spark.ml.feature.IDFModel" + private val className = classOf[IDFModel].getName override def load(path: String): IDFModel = { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala index ed24eabb5044..c2866f5eceff 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala @@ -210,7 +210,7 @@ object MinMaxScalerModel extends MLReadable[MinMaxScalerModel] { private class MinMaxScalerModelReader extends MLReader[MinMaxScalerModel] { - private val className = "org.apache.spark.ml.feature.MinMaxScalerModel" + private val className = classOf[MinMaxScalerModel].getName override def load(path: String): MinMaxScalerModel = { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala index 1f689c1da1ba..6d545219ebf4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala @@ -180,7 +180,7 @@ object StandardScalerModel extends MLReadable[StandardScalerModel] { private class StandardScalerModelReader extends MLReader[StandardScalerModel] { - private val className = "org.apache.spark.ml.feature.StandardScalerModel" + private val className = classOf[StandardScalerModel].getName override def load(path: String): StandardScalerModel = { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index 97a2e4f6d6ca..5c40c35eeaa4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ 
b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -210,7 +210,7 @@ object StringIndexerModel extends MLReadable[StringIndexerModel] { private class StringIndexerModelReader extends MLReader[StringIndexerModel] { - private val className = "org.apache.spark.ml.feature.StringIndexerModel" + private val className = classOf[StringIndexerModel].getName override def load(path: String): StringIndexerModel = { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index 795b73c4c212..4d35177ad9b0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -237,7 +237,7 @@ object ALSModel extends MLReadable[ALSModel] { @Since("1.6.0") override def load(path: String): ALSModel = super.load(path) - private[recommendation] class ALSModelWriter(instance: ALSModel) extends MLWriter { + private[ALSModel] class ALSModelWriter(instance: ALSModel) extends MLWriter { override protected def saveImpl(path: String): Unit = { val extraMetadata = render("rank" -> instance.rank) @@ -249,10 +249,10 @@ object ALSModel extends MLReadable[ALSModel] { } } - private[recommendation] class ALSModelReader extends MLReader[ALSModel] { + private class ALSModelReader extends MLReader[ALSModel] { /** Checked against metadata when loading model */ - private val className = "org.apache.spark.ml.recommendation.ALSModel" + private val className = classOf[ALSModel].getName override def load(path: String): ALSModel = { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 7ba1a60edaf7..70ccec766c47 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -467,14 +467,14 @@ object LinearRegressionModel extends MLReadable[LinearRegressionModel] { // Save model data: intercept, coefficients val data = Data(instance.intercept, instance.coefficients) val dataPath = new Path(path, "data").toString - sqlContext.createDataFrame(Seq(data)).write.format("parquet").save(dataPath) + sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) } } private class LinearRegressionModelReader extends MLReader[LinearRegressionModel] { /** Checked against metadata when loading model */ - private val className = "org.apache.spark.ml.regression.LinearRegressionModel" + private val className = classOf[LinearRegressionModel].getName override def load(path: String): LinearRegressionModel = { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index 48ce1bb63068..a9a6ff8a783d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -898,7 +898,7 @@ object LogisticRegressionSuite { "regParam" -> 0.01, "elasticNetParam" -> 0.1, "maxIter" -> 2, // intentionally small - "fitIntercept" -> false, + "fitIntercept" -> true, "tol" -> 0.8, 
"standardization" -> false, "threshold" -> 0.6 From 1a93323c5bab18ed7e55bf6f7b13aae88cb9721c Mon Sep 17 00:00:00 2001 From: felixcheung Date: Wed, 18 Nov 2015 23:32:49 -0800 Subject: [PATCH 1447/1824] [SPARK-11339][SPARKR] Document the list of functions in R base package that are masked by functions with same name in SparkR Added tests for function that are reported as masked, to make sure the base:: or stats:: function can be called. For those we can't call, added them to SparkR programming guide. It would seem to me `table, sample, subset, filter, cov` not working are not actually expected - I investigated/experimented with them but couldn't get them to work. It looks like as they are defined in base or stats they are missing the S3 generic, eg. ``` > methods("transform") [1] transform,ANY-method transform.data.frame [3] transform,DataFrame-method transform.default see '?methods' for accessing help and source code > methods("subset") [1] subset.data.frame subset,DataFrame-method subset.default [4] subset.matrix see '?methods' for accessing help and source code Warning message: In .S3methods(generic.function, class, parent.frame()) : function 'subset' appears not to be S3 generic; found functions that look like S3 methods ``` Any idea? More information on masking: http://www.ats.ucla.edu/stat/r/faq/referencing_objects.htm http://www.sfu.ca/~sweldon/howTo/guide4.pdf This is what the output doc looks like (minus css): ![image](https://cloud.githubusercontent.com/assets/8969467/11229714/2946e5de-8d4d-11e5-94b0-dda9696b6fdd.png) Author: felixcheung Closes #9785 from felixcheung/rmasked. --- R/pkg/R/DataFrame.R | 2 +- R/pkg/R/functions.R | 2 +- R/pkg/R/generics.R | 4 ++-- R/pkg/inst/tests/test_mllib.R | 5 +++++ R/pkg/inst/tests/test_sparkSQL.R | 33 +++++++++++++++++++++++++++- docs/sparkr.md | 37 +++++++++++++++++++++++++++++++- 6 files changed, 77 insertions(+), 6 deletions(-) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 34177e3cdd94..06b0108b1389 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -2152,7 +2152,7 @@ setMethod("with", }) #' Returns the column types of a DataFrame. -#' +#' #' @name coltypes #' @title Get column types of a DataFrame #' @family dataframe_funcs diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index ff0f438045c1..25a1f2210149 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -2204,7 +2204,7 @@ setMethod("denseRank", #' @export #' @examples \dontrun{lag(df$c)} setMethod("lag", - signature(x = "characterOrColumn", offset = "numeric", defaultValue = "ANY"), + signature(x = "characterOrColumn"), function(x, offset, defaultValue = NULL) { col <- if (class(x) == "Column") { x@jc diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 0dcd05438222..71004a05ba61 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -539,7 +539,7 @@ setGeneric("showDF", function(x,...) { standardGeneric("showDF") }) # @rdname subset # @export -setGeneric("subset", function(x, subset, select, ...) { standardGeneric("subset") }) +setGeneric("subset", function(x, ...) { standardGeneric("subset") }) #' @rdname agg #' @export @@ -790,7 +790,7 @@ setGeneric("kurtosis", function(x) { standardGeneric("kurtosis") }) #' @rdname lag #' @export -setGeneric("lag", function(x, offset, defaultValue = NULL) { standardGeneric("lag") }) +setGeneric("lag", function(x, ...) 
{ standardGeneric("lag") }) #' @rdname last #' @export diff --git a/R/pkg/inst/tests/test_mllib.R b/R/pkg/inst/tests/test_mllib.R index d497ad8c9daa..e0667e5e22c1 100644 --- a/R/pkg/inst/tests/test_mllib.R +++ b/R/pkg/inst/tests/test_mllib.R @@ -31,6 +31,11 @@ test_that("glm and predict", { model <- glm(Sepal_Width ~ Sepal_Length, training, family = "gaussian") prediction <- predict(model, test) expect_equal(typeof(take(select(prediction, "prediction"), 1)$prediction), "double") + + # Test stats::predict is working + x <- rnorm(15) + y <- x + rnorm(15) + expect_equal(length(predict(lm(y ~ x))), 15) }) test_that("glm should work with long formula", { diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index d9a94faff7ac..3f4f319fe745 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -433,6 +433,10 @@ test_that("table() returns a new DataFrame", { expect_is(tabledf, "DataFrame") expect_equal(count(tabledf), 3) dropTempTable(sqlContext, "table1") + + # Test base::table is working + #a <- letters[1:3] + #expect_equal(class(table(a, sample(a))), "table") }) test_that("toRDD() returns an RRDD", { @@ -673,6 +677,9 @@ test_that("sample on a DataFrame", { # Also test sample_frac sampled3 <- sample_frac(df, FALSE, 0.1, 0) # set seed for predictable result expect_true(count(sampled3) < 3) + + # Test base::sample is working + #expect_equal(length(sample(1:12)), 12) }) test_that("select operators", { @@ -753,6 +760,9 @@ test_that("subsetting", { df6 <- subset(df, df$age %in% c(30), c(1,2)) expect_equal(count(df6), 1) expect_equal(columns(df6), c("name", "age")) + + # Test base::subset is working + expect_equal(nrow(subset(airquality, Temp > 80, select = c(Ozone, Temp))), 68) }) test_that("selectExpr() on a DataFrame", { @@ -888,6 +898,9 @@ test_that("column functions", { expect_equal(result, list(list(3L, 2L, 1L), list(6L, 5L, 4L))) result <- collect(select(df, sort_array(df[[1]])))[[1]] expect_equal(result, list(list(1L, 2L, 3L), list(4L, 5L, 6L))) + + # Test that stats::lag is working + expect_equal(length(lag(ldeaths, 12)), 72) }) # test_that("column binary mathfunctions", { @@ -1086,7 +1099,7 @@ test_that("group by, agg functions", { gd3_local <- collect(agg(gd3, var(df8$age))) expect_equal(162, gd3_local[gd3_local$name == "Justin",][1, 2]) - # make sure base:: or stats::sd, var are working + # Test stats::sd, stats::var are working expect_true(abs(sd(1:2) - 0.7071068) < 1e-6) expect_true(abs(var(1:5, 1:5) - 2.5) < 1e-6) @@ -1138,6 +1151,9 @@ test_that("filter() on a DataFrame", { expect_equal(count(filtered5), 1) filtered6 <- where(df, df$age %in% c(19, 30)) expect_equal(count(filtered6), 2) + + # Test stats::filter is working + #expect_true(is.ts(filter(1:100, rep(1, 3)))) }) test_that("join() and merge() on a DataFrame", { @@ -1284,6 +1300,12 @@ test_that("unionAll(), rbind(), except(), and intersect() on a DataFrame", { expect_is(unioned, "DataFrame") expect_equal(count(intersected), 1) expect_equal(first(intersected)$name, "Andy") + + # Test base::rbind is working + expect_equal(length(rbind(1:4, c = 2, a = 10, 10, deparse.level = 0)), 16) + + # Test base::intersect is working + expect_equal(length(intersect(1:20, 3:23)), 18) }) test_that("withColumn() and withColumnRenamed()", { @@ -1365,6 +1387,9 @@ test_that("describe() and summarize() on a DataFrame", { stats2 <- summary(df) expect_equal(collect(stats2)[4, "name"], "Andy") expect_equal(collect(stats2)[5, "age"], "30") + + # Test base::summary is working + 
expect_equal(length(summary(attenu, digits = 4)), 35) }) test_that("dropna() and na.omit() on a DataFrame", { @@ -1448,6 +1473,9 @@ test_that("dropna() and na.omit() on a DataFrame", { expect_identical(expected, actual) actual <- collect(na.omit(df, minNonNulls = 3, cols = c("name", "age", "height"))) expect_identical(expected, actual) + + # Test stats::na.omit is working + expect_equal(nrow(na.omit(data.frame(x = c(0, 10, NA)))), 2) }) test_that("fillna() on a DataFrame", { @@ -1510,6 +1538,9 @@ test_that("cov() and corr() on a DataFrame", { expect_true(abs(result - 1.0) < 1e-12) result <- corr(df, "singles", "doubles", "pearson") expect_true(abs(result - 1.0) < 1e-12) + + # Test stats::cov is working + #expect_true(abs(max(cov(swiss)) - 1739.295) < 1e-3) }) test_that("freqItems() on a DataFrame", { diff --git a/docs/sparkr.md b/docs/sparkr.md index a744b76be746..cfb9b41350f4 100644 --- a/docs/sparkr.md +++ b/docs/sparkr.md @@ -286,7 +286,7 @@ head(teenagers) # Machine Learning -SparkR allows the fitting of generalized linear models over DataFrames using the [glm()](api/R/glm.html) function. Under the hood, SparkR uses MLlib to train a model of the specified family. Currently the gaussian and binomial families are supported. We support a subset of the available R formula operators for model fitting, including '~', '.', ':', '+', and '-'. +SparkR allows the fitting of generalized linear models over DataFrames using the [glm()](api/R/glm.html) function. Under the hood, SparkR uses MLlib to train a model of the specified family. Currently the gaussian and binomial families are supported. We support a subset of the available R formula operators for model fitting, including '~', '.', ':', '+', and '-'. The [summary()](api/R/summary.html) function gives the summary of a model produced by [glm()](api/R/glm.html). @@ -351,3 +351,38 @@ summary(model) ##Sepal_Width 0.404655 {% endhighlight %}
+
+# R Function Name Conflicts
+
+When loading and attaching a new package in R, it is possible to have a name [conflict](https://stat.ethz.ch/R-manual/R-devel/library/base/html/library.html), where a
+function is masking another function.
+
+The following functions are masked by the SparkR package:
+
+<table class="table">
+  <tr><th>Masked function</th><th>How to Access</th></tr>
+  <tr>
+    <td><code>cov</code> in <code>package:stats</code></td>
+    <td><code><pre>stats::cov(x, y = NULL, use = "everything",
+           method = c("pearson", "kendall", "spearman"))</pre></code></td>
+  </tr>
+  <tr>
+    <td><code>filter</code> in <code>package:stats</code></td>
+    <td><code><pre>stats::filter(x, filter, method = c("convolution", "recursive"),
+              sides = 2, circular = FALSE, init)</pre></code></td>
+  </tr>
+  <tr>
+    <td><code>sample</code> in <code>package:base</code></td>
+    <td><code>base::sample(x, size, replace = FALSE, prob = NULL)</code></td>
+  </tr>
+  <tr>
+    <td><code>table</code> in <code>package:base</code></td>
+    <td><code><pre>base::table(...,
+            exclude = if (useNA == "no") c(NA, NaN),
+            useNA = c("no", "ifany", "always"),
+            dnn = list.names(...), deparse.level = 1)</pre></code></td>
+  </tr>
+</table>
    + +You can inspect the search path in R with [`search()`](https://stat.ethz.ch/R-manual/R-devel/library/base/html/search.html) + From f449992009becc8f7c7f06cda522b9beaa1e263c Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 19 Nov 2015 10:48:04 -0800 Subject: [PATCH 1448/1824] [SPARK-11849][SQL] Analyzer should replace current_date and current_timestamp with literals We currently rely on the optimizer's constant folding to replace current_timestamp and current_date. However, this can still result in different values for different instances of current_timestamp/current_date if the optimizer is not running fast enough. A better solution is to replace these functions in the analyzer in one shot. Author: Reynold Xin Closes #9833 from rxin/SPARK-11849. --- .../sql/catalyst/analysis/Analyzer.scala | 27 ++++++++++--- .../sql/catalyst/analysis/AnalysisSuite.scala | 38 +++++++++++++++++++ 2 files changed, 60 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index f00c451b5981..84781cd57f3d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -65,9 +65,8 @@ class Analyzer( lazy val batches: Seq[Batch] = Seq( Batch("Substitution", fixedPoint, - CTESubstitution :: - WindowsSubstitution :: - Nil : _*), + CTESubstitution, + WindowsSubstitution), Batch("Resolution", fixedPoint, ResolveRelations :: ResolveReferences :: @@ -84,7 +83,8 @@ class Analyzer( HiveTypeCoercion.typeCoercionRules ++ extendedResolutionRules : _*), Batch("Nondeterministic", Once, - PullOutNondeterministic), + PullOutNondeterministic, + ComputeCurrentTime), Batch("UDF", Once, HandleNullInputsForUDF), Batch("Cleanup", fixedPoint, @@ -1076,7 +1076,7 @@ class Analyzer( override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case p if !p.resolved => p // Skip unresolved nodes. - case plan => plan transformExpressionsUp { + case p => p transformExpressionsUp { case udf @ ScalaUDF(func, _, inputs, _) => val parameterTypes = ScalaReflection.getParameterTypes(func) @@ -1162,3 +1162,20 @@ object CleanupAliases extends Rule[LogicalPlan] { } } } + +/** + * Computes the current date and time to make sure we return the same result in a single query. 
+ */ +object ComputeCurrentTime extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = { + val dateExpr = CurrentDate() + val timeExpr = CurrentTimestamp() + val currentDate = Literal.create(dateExpr.eval(EmptyRow), dateExpr.dataType) + val currentTime = Literal.create(timeExpr.eval(EmptyRow), timeExpr.dataType) + + plan transformAllExpressions { + case CurrentDate() => currentDate + case CurrentTimestamp() => currentTime + } + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 08586a97411a..e05106995188 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ class AnalysisSuite extends AnalysisTest { @@ -218,4 +219,41 @@ class AnalysisSuite extends AnalysisTest { udf4) // checkUDF(udf4, expected4) } + + test("analyzer should replace current_timestamp with literals") { + val in = Project(Seq(Alias(CurrentTimestamp(), "a")(), Alias(CurrentTimestamp(), "b")()), + LocalRelation()) + + val min = System.currentTimeMillis() * 1000 + val plan = in.analyze.asInstanceOf[Project] + val max = (System.currentTimeMillis() + 1) * 1000 + + val lits = new scala.collection.mutable.ArrayBuffer[Long] + plan.transformAllExpressions { case e: Literal => + lits += e.value.asInstanceOf[Long] + e + } + assert(lits.size == 2) + assert(lits(0) >= min && lits(0) <= max) + assert(lits(1) >= min && lits(1) <= max) + assert(lits(0) == lits(1)) + } + + test("analyzer should replace current_date with literals") { + val in = Project(Seq(Alias(CurrentDate(), "a")(), Alias(CurrentDate(), "b")()), LocalRelation()) + + val min = DateTimeUtils.millisToDays(System.currentTimeMillis()) + val plan = in.analyze.asInstanceOf[Project] + val max = DateTimeUtils.millisToDays(System.currentTimeMillis()) + + val lits = new scala.collection.mutable.ArrayBuffer[Int] + plan.transformAllExpressions { case e: Literal => + lits += e.value.asInstanceOf[Int] + e + } + assert(lits.size == 2) + assert(lits(0) >= min && lits(0) <= max) + assert(lits(1) >= min && lits(1) <= max) + assert(lits(0) == lits(1)) + } } From 962878843b611fa6229e3ee67bb22e2a4bc283cd Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Thu, 19 Nov 2015 11:02:17 -0800 Subject: [PATCH 1449/1824] [SPARK-11840][SQL] Restore the 1.5's behavior of planning a single distinct aggregation. The impact of this change is for a query that has a single distinct column and does not have any grouping expression like `SELECT COUNT(DISTINCT a) FROM table` The plan will be changed from ``` AGG-2 (count distinct) Shuffle to a single reducer Partial-AGG-2 (count distinct) AGG-1 (grouping on a) Shuffle by a Partial-AGG-1 (grouping on 1) ``` to the following one (1.5 uses this) ``` AGG-2 AGG-1 (grouping on a) Shuffle to a single reducer Partial-AGG-1(grouping on a) ``` The first plan is more robust. However, to better benchmark the impact of this change, we should use 1.5's plan and use the conf of `spark.sql.specializeSingleDistinctAggPlanning` to control the plan. 
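As a quick way to compare the two plan shapes, the sketch below toggles the flag and prints both physical plans. It is a hypothetical snippet, not part of the patch: the `events` temp table and its data are made up, and only the `spark.sql.specializeSingleDistinctAggPlanning` key comes from this change.

```scala
// Hypothetical data set: a single column "a" registered as a temp table.
val df = sqlContext.range(0, 1000).selectExpr("id % 10 AS a")
df.registerTempTable("events")

// 1.5-style plan: the single distinct aggregate is left to the Aggregation
// strategy (no rewrite), i.e. the second plan shown above.
sqlContext.setConf("spark.sql.specializeSingleDistinctAggPlanning", "true")
sqlContext.sql("SELECT COUNT(DISTINCT a) FROM events").explain()

// Rewritten plan: the distinct aggregate is expanded into the more robust
// two-shuffle shape, i.e. the first plan shown above.
sqlContext.sql("SELECT COUNT(DISTINCT a) FROM events").explain()
sqlContext.setConf("spark.sql.specializeSingleDistinctAggPlanning", "false")
```
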
Author: Yin Huai Closes #9828 from yhuai/distinctRewriter. --- .../sql/catalyst/analysis/DistinctAggregationRewriter.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala index c0c960471a61..9c78f6d4cc71 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala @@ -126,8 +126,8 @@ case class DistinctAggregationRewriter(conf: CatalystConf) extends Rule[LogicalP val shouldRewrite = if (conf.specializeSingleDistinctAggPlanning) { // When the flag is set to specialize single distinct agg planning, // we will rely on our Aggregation strategy to handle queries with a single - // distinct column and this aggregate operator does have grouping expressions. - distinctAggGroups.size > 1 || (distinctAggGroups.size == 1 && a.groupingExpressions.isEmpty) + // distinct column. + distinctAggGroups.size > 1 } else { distinctAggGroups.size >= 1 } From 72d150c271d2b206148fd0917a0def263445121b Mon Sep 17 00:00:00 2001 From: zsxwing Date: Thu, 19 Nov 2015 11:57:50 -0800 Subject: [PATCH 1450/1824] [SPARK-11830][CORE] Make NettyRpcEnv bind to the specified host This PR includes the following change: 1. Bind NettyRpcEnv to the specified host 2. Fix the port information in the log for NettyRpcEnv. 3. Fix the service name of NettyRpcEnv. Author: zsxwing Author: Shixiong Zhu Closes #9821 from zsxwing/SPARK-11830. --- .../src/main/scala/org/apache/spark/SparkEnv.scala | 9 ++++++++- .../org/apache/spark/rpc/netty/NettyRpcEnv.scala | 7 +++---- .../org/apache/spark/network/TransportContext.java | 8 +++++++- .../spark/network/server/TransportServer.java | 14 ++++++++++---- 4 files changed, 28 insertions(+), 10 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 4474a83bedbd..88df27f733f2 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -258,8 +258,15 @@ object SparkEnv extends Logging { if (rpcEnv.isInstanceOf[AkkaRpcEnv]) { rpcEnv.asInstanceOf[AkkaRpcEnv].actorSystem } else { + val actorSystemPort = if (port == 0) 0 else rpcEnv.address.port + 1 // Create a ActorSystem for legacy codes - AkkaUtils.createActorSystem(actorSystemName, hostname, port, conf, securityManager)._1 + AkkaUtils.createActorSystem( + actorSystemName + "ActorSystem", + hostname, + actorSystemPort, + conf, + securityManager + )._1 } // Figure out which port Akka actually bound to in case the original port is 0 or occupied. 
diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala index 3e0c49796950..3ce359868039 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala @@ -102,7 +102,7 @@ private[netty] class NettyRpcEnv( } else { java.util.Collections.emptyList() } - server = transportContext.createServer(port, bootstraps) + server = transportContext.createServer(host, port, bootstraps) dispatcher.registerRpcEndpoint( RpcEndpointVerifier.NAME, new RpcEndpointVerifier(this, dispatcher)) } @@ -337,10 +337,10 @@ private[netty] class NettyRpcEnvFactory extends RpcEnvFactory with Logging { if (!config.clientMode) { val startNettyRpcEnv: Int => (NettyRpcEnv, Int) = { actualPort => nettyEnv.startServer(actualPort) - (nettyEnv, actualPort) + (nettyEnv, nettyEnv.address.port) } try { - Utils.startServiceOnPort(config.port, startNettyRpcEnv, sparkConf, "NettyRpcEnv")._1 + Utils.startServiceOnPort(config.port, startNettyRpcEnv, sparkConf, config.name)._1 } catch { case NonFatal(e) => nettyEnv.shutdown() @@ -370,7 +370,6 @@ private[netty] class NettyRpcEnvFactory extends RpcEnvFactory with Logging { * @param conf Spark configuration. * @param endpointAddress The address where the endpoint is listening. * @param nettyEnv The RpcEnv associated with this ref. - * @param local Whether the referenced endpoint lives in the same process. */ private[netty] class NettyRpcEndpointRef( @transient private val conf: SparkConf, diff --git a/network/common/src/main/java/org/apache/spark/network/TransportContext.java b/network/common/src/main/java/org/apache/spark/network/TransportContext.java index 1b64b863a9fe..238710d17249 100644 --- a/network/common/src/main/java/org/apache/spark/network/TransportContext.java +++ b/network/common/src/main/java/org/apache/spark/network/TransportContext.java @@ -94,7 +94,13 @@ public TransportClientFactory createClientFactory() { /** Create a server which will attempt to bind to a specific port. */ public TransportServer createServer(int port, List bootstraps) { - return new TransportServer(this, port, rpcHandler, bootstraps); + return new TransportServer(this, null, port, rpcHandler, bootstraps); + } + + /** Create a server which will attempt to bind to a specific host and port. */ + public TransportServer createServer( + String host, int port, List bootstraps) { + return new TransportServer(this, host, port, rpcHandler, bootstraps); } /** Creates a new server, binding to any available ephemeral port. */ diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java b/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java index f4fadb1ee3b8..baae235e0220 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java +++ b/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java @@ -55,9 +55,13 @@ public class TransportServer implements Closeable { private ChannelFuture channelFuture; private int port = -1; - /** Creates a TransportServer that binds to the given port, or to any available if 0. */ + /** + * Creates a TransportServer that binds to the given host and the given port, or to any available + * if 0. If you don't want to bind to any special host, set "hostToBind" to null. 
+ * */ public TransportServer( TransportContext context, + String hostToBind, int portToBind, RpcHandler appRpcHandler, List bootstraps) { @@ -67,7 +71,7 @@ public TransportServer( this.bootstraps = Lists.newArrayList(Preconditions.checkNotNull(bootstraps)); try { - init(portToBind); + init(hostToBind, portToBind); } catch (RuntimeException e) { JavaUtils.closeQuietly(this); throw e; @@ -81,7 +85,7 @@ public int getPort() { return port; } - private void init(int portToBind) { + private void init(String hostToBind, int portToBind) { IOMode ioMode = IOMode.valueOf(conf.ioMode()); EventLoopGroup bossGroup = @@ -120,7 +124,9 @@ protected void initChannel(SocketChannel ch) throws Exception { } }); - channelFuture = bootstrap.bind(new InetSocketAddress(portToBind)); + InetSocketAddress address = hostToBind == null ? + new InetSocketAddress(portToBind): new InetSocketAddress(hostToBind, portToBind); + channelFuture = bootstrap.bind(address); channelFuture.syncUninterruptibly(); port = ((InetSocketAddress) channelFuture.channel().localAddress()).getPort(); From 276a7e130252c0e7aba702ae5570b3c4f424b23b Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Thu, 19 Nov 2015 12:45:04 -0800 Subject: [PATCH 1451/1824] [SPARK-11633][SQL] LogicalRDD throws TreeNode Exception : Failed to Copy Node When handling self joins, the implementation did not consider the case insensitivity of HiveContext. It could cause an exception as shown in the JIRA: ``` TreeNodeException: Failed to copy node. ``` The fix is low risk. It avoids unnecessary attribute replacement. It should not affect the existing behavior of self joins. Also added the test case to cover this case. Author: gatorsmile Closes #9762 from gatorsmile/joinMakeCopy. --- .../apache/spark/sql/execution/ExistingRDD.scala | 4 ++++ .../org/apache/spark/sql/DataFrameSuite.scala | 14 ++++++++++++++ 2 files changed, 18 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index 62620ec642c7..623348f6768a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -74,6 +74,10 @@ private[sql] case class LogicalRDD( override def children: Seq[LogicalPlan] = Nil + override protected final def otherCopyArgs: Seq[AnyRef] = { + sqlContext :: Nil + } + override def newInstance(): LogicalRDD.this.type = LogicalRDD(output.map(_.newInstance()), rdd)(sqlContext).asInstanceOf[this.type] diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 6399b0165c4c..dd6d06512ff6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -1110,6 +1110,20 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } } + // This test case is to verify a bug when making a new instance of LogicalRDD. 
+ test("SPARK-11633: LogicalRDD throws TreeNode Exception: Failed to Copy Node") { + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { + val rdd = sparkContext.makeRDD(Seq(Row(1, 3), Row(2, 1))) + val df = sqlContext.createDataFrame( + rdd, + new StructType().add("f1", IntegerType).add("f2", IntegerType), + needsConversion = false).select($"F1", $"f2".as("f2")) + val df1 = df.as("a") + val df2 = df.as("b") + checkAnswer(df1.join(df2, $"a.f2" === $"b.f2"), Row(1, 3, 1, 3) :: Row(2, 1, 2, 1) :: Nil) + } + } + test("SPARK-10656: completely support special chars") { val df = Seq(1 -> "a").toDF("i_$.a", "d^'a.") checkAnswer(df.select(df("*")), Row(1, "a")) From 7d4aba18722727c85893ad8d8f07d4494665dcfc Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Thu, 19 Nov 2015 12:46:36 -0800 Subject: [PATCH 1452/1824] [SPARK-11848][SQL] Support EXPLAIN in DataSet APIs When debugging DataSet API, I always need to print the logical and physical plans. I am wondering if we should provide a simple API for EXPLAIN? Author: gatorsmile Closes #9832 from gatorsmile/explainDS. --- .../org/apache/spark/sql/DataFrame.scala | 23 +------------------ .../spark/sql/execution/Queryable.scala | 21 +++++++++++++++++ 2 files changed, 22 insertions(+), 22 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 3ba4ba18d212..98358127e270 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -37,7 +37,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.{Inner, JoinType} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection, SqlParser} -import org.apache.spark.sql.execution.{EvaluatePython, ExplainCommand, FileRelation, LogicalRDD, QueryExecution, Queryable, SQLExecution} +import org.apache.spark.sql.execution.{EvaluatePython, FileRelation, LogicalRDD, QueryExecution, Queryable, SQLExecution} import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, LogicalRelation} import org.apache.spark.sql.execution.datasources.json.JacksonGenerator import org.apache.spark.sql.sources.HadoopFsRelation @@ -308,27 +308,6 @@ class DataFrame private[sql]( def printSchema(): Unit = println(schema.treeString) // scalastyle:on println - /** - * Prints the plans (logical and physical) to the console for debugging purposes. - * @group basic - * @since 1.3.0 - */ - def explain(extended: Boolean): Unit = { - val explain = ExplainCommand(queryExecution.logical, extended = extended) - withPlan(explain).queryExecution.executedPlan.executeCollect().foreach { - // scalastyle:off println - r => println(r.getString(0)) - // scalastyle:on println - } - } - - /** - * Only prints the physical plan to the console for debugging purposes. - * @group basic - * @since 1.3.0 - */ - def explain(): Unit = explain(extended = false) - /** * Returns true if the `collect` and `take` methods can be run locally * (without any Spark executors). 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Queryable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Queryable.scala index 9ca383896a09..e86a52c149a2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Queryable.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Queryable.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution +import org.apache.spark.sql.SQLContext import org.apache.spark.sql.types.StructType import scala.util.control.NonFatal @@ -25,6 +26,7 @@ import scala.util.control.NonFatal private[sql] trait Queryable { def schema: StructType def queryExecution: QueryExecution + def sqlContext: SQLContext override def toString: String = { try { @@ -34,4 +36,23 @@ private[sql] trait Queryable { s"Invalid tree; ${e.getMessage}:\n$queryExecution" } } + + /** + * Prints the plans (logical and physical) to the console for debugging purposes. + * @since 1.3.0 + */ + def explain(extended: Boolean): Unit = { + val explain = ExplainCommand(queryExecution.logical, extended = extended) + sqlContext.executePlan(explain).executedPlan.executeCollect().foreach { + // scalastyle:off println + r => println(r.getString(0)) + // scalastyle:on println + } + } + + /** + * Only prints the physical plan to the console for debugging purposes. + * @since 1.3.0 + */ + def explain(): Unit = explain(extended = false) } From 47d1c2325caaf9ffe31695b6fff529314b8582f7 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 19 Nov 2015 12:54:25 -0800 Subject: [PATCH 1453/1824] [SPARK-11750][SQL] revert SPARK-11727 and code clean up After some experiment, I found it's not convenient to have separate encoder builders: `FlatEncoder` and `ProductEncoder`. For example, when create encoders for `ScalaUDF`, we have no idea if the type `T` is flat or not. So I revert the splitting change in https://github.com/apache/spark/pull/9693, while still keeping the bug fixes and tests. Author: Wenchen Fan Closes #9726 from cloud-fan/follow. 
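To illustrate the effect of the unified builder, here is a minimal sketch; the `Point` case class is made up, and the snippet assumes the catalyst `ExpressionEncoder` API as it stands after this change.

```scala
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder

case class Point(x: Int, y: Double)

// A single entry point now covers both cases: non-Product types become
// single-column ("flat") encoders, while Products map each constructor
// field to a column, with no FlatEncoder/ProductEncoder choice to make.
val intEncoder = ExpressionEncoder[Int]()
assert(intEncoder.flat)

val pointEncoder = ExpressionEncoder[Point]()
assert(!pointEncoder.flat)
assert(pointEncoder.schema.fieldNames.toSeq == Seq("x", "y"))

// Convert an object into the internal row format.
val internalRow = pointEncoder.toRow(Point(1, 2.0))
```
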
--- .../scala/org/apache/spark/sql/Encoder.scala | 16 +- .../spark/sql/catalyst/ScalaReflection.scala | 354 +++++--------- .../catalyst/encoders/ExpressionEncoder.scala | 19 +- .../sql/catalyst/encoders/FlatEncoder.scala | 50 -- .../catalyst/encoders/ProductEncoder.scala | 452 ------------------ .../sql/catalyst/encoders/RowEncoder.scala | 12 +- .../sql/catalyst/expressions/objects.scala | 7 +- .../sql/catalyst/ScalaReflectionSuite.scala | 68 --- .../encoders/ExpressionEncoderSuite.scala | 218 ++++++++- .../catalyst/encoders/FlatEncoderSuite.scala | 99 ---- .../encoders/ProductEncoderSuite.scala | 156 ------ .../org/apache/spark/sql/GroupedDataset.scala | 4 +- .../org/apache/spark/sql/SQLImplicits.scala | 23 +- .../org/apache/spark/sql/functions.scala | 4 +- 14 files changed, 364 insertions(+), 1118 deletions(-) delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoder.scala delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala delete mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoderSuite.scala delete mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala index d54f2854fb33..86bb53645903 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala @@ -45,14 +45,14 @@ trait Encoder[T] extends Serializable { */ object Encoders { - def BOOLEAN: Encoder[java.lang.Boolean] = ExpressionEncoder(flat = true) - def BYTE: Encoder[java.lang.Byte] = ExpressionEncoder(flat = true) - def SHORT: Encoder[java.lang.Short] = ExpressionEncoder(flat = true) - def INT: Encoder[java.lang.Integer] = ExpressionEncoder(flat = true) - def LONG: Encoder[java.lang.Long] = ExpressionEncoder(flat = true) - def FLOAT: Encoder[java.lang.Float] = ExpressionEncoder(flat = true) - def DOUBLE: Encoder[java.lang.Double] = ExpressionEncoder(flat = true) - def STRING: Encoder[java.lang.String] = ExpressionEncoder(flat = true) + def BOOLEAN: Encoder[java.lang.Boolean] = ExpressionEncoder() + def BYTE: Encoder[java.lang.Byte] = ExpressionEncoder() + def SHORT: Encoder[java.lang.Short] = ExpressionEncoder() + def INT: Encoder[java.lang.Integer] = ExpressionEncoder() + def LONG: Encoder[java.lang.Long] = ExpressionEncoder() + def FLOAT: Encoder[java.lang.Float] = ExpressionEncoder() + def DOUBLE: Encoder[java.lang.Double] = ExpressionEncoder() + def STRING: Encoder[java.lang.String] = ExpressionEncoder() /** * (Scala-specific) Creates an encoder that serializes objects of type T using Kryo. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 59ccf356f2c4..33ae700706da 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -50,39 +50,29 @@ object ScalaReflection extends ScalaReflection { * Unlike `schemaFor`, this function doesn't do any massaging of types into the Spark SQL type * system. 
As a result, ObjectType will be returned for things like boxed Integers */ - def dataTypeFor(tpe: `Type`): DataType = tpe match { - case t if t <:< definitions.IntTpe => IntegerType - case t if t <:< definitions.LongTpe => LongType - case t if t <:< definitions.DoubleTpe => DoubleType - case t if t <:< definitions.FloatTpe => FloatType - case t if t <:< definitions.ShortTpe => ShortType - case t if t <:< definitions.ByteTpe => ByteType - case t if t <:< definitions.BooleanTpe => BooleanType - case t if t <:< localTypeOf[Array[Byte]] => BinaryType - case _ => - val className: String = tpe.erasure.typeSymbol.asClass.fullName - className match { - case "scala.Array" => - val TypeRef(_, _, Seq(arrayType)) = tpe - val cls = arrayType match { - case t if t <:< definitions.IntTpe => classOf[Array[Int]] - case t if t <:< definitions.LongTpe => classOf[Array[Long]] - case t if t <:< definitions.DoubleTpe => classOf[Array[Double]] - case t if t <:< definitions.FloatTpe => classOf[Array[Float]] - case t if t <:< definitions.ShortTpe => classOf[Array[Short]] - case t if t <:< definitions.ByteTpe => classOf[Array[Byte]] - case t if t <:< definitions.BooleanTpe => classOf[Array[Boolean]] - case other => - // There is probably a better way to do this, but I couldn't find it... - val elementType = dataTypeFor(other).asInstanceOf[ObjectType].cls - java.lang.reflect.Array.newInstance(elementType, 1).getClass + def dataTypeFor[T : TypeTag]: DataType = dataTypeFor(localTypeOf[T]) - } - ObjectType(cls) - case other => - val clazz = mirror.runtimeClass(tpe.erasure.typeSymbol.asClass) - ObjectType(clazz) - } + private def dataTypeFor(tpe: `Type`): DataType = ScalaReflectionLock.synchronized { + tpe match { + case t if t <:< definitions.IntTpe => IntegerType + case t if t <:< definitions.LongTpe => LongType + case t if t <:< definitions.DoubleTpe => DoubleType + case t if t <:< definitions.FloatTpe => FloatType + case t if t <:< definitions.ShortTpe => ShortType + case t if t <:< definitions.ByteTpe => ByteType + case t if t <:< definitions.BooleanTpe => BooleanType + case t if t <:< localTypeOf[Array[Byte]] => BinaryType + case _ => + val className: String = tpe.erasure.typeSymbol.asClass.fullName + className match { + case "scala.Array" => + val TypeRef(_, _, Seq(elementType)) = tpe + arrayClassFor(elementType) + case other => + val clazz = mirror.runtimeClass(tpe.erasure.typeSymbol.asClass) + ObjectType(clazz) + } + } } /** @@ -90,7 +80,7 @@ object ScalaReflection extends ScalaReflection { * Array[T]. Special handling is performed for primitive types to map them back to their raw * JVM form instead of the Scala Array that handles auto boxing. */ - def arrayClassFor(tpe: `Type`): DataType = { + private def arrayClassFor(tpe: `Type`): DataType = ScalaReflectionLock.synchronized { val cls = tpe match { case t if t <:< definitions.IntTpe => classOf[Array[Int]] case t if t <:< definitions.LongTpe => classOf[Array[Long]] @@ -108,6 +98,15 @@ object ScalaReflection extends ScalaReflection { ObjectType(cls) } + /** + * Returns true if the value of this data type is same between internal and external. + */ + def isNativeType(dt: DataType): Boolean = dt match { + case BooleanType | ByteType | ShortType | IntegerType | LongType | + FloatType | DoubleType | BinaryType => true + case _ => false + } + /** * Returns an expression that can be used to construct an object of type `T` given an input * row with a compatible schema. 
Fields of the row will be extracted using UnresolvedAttributes @@ -116,63 +115,33 @@ object ScalaReflection extends ScalaReflection { * * When used on a primitive type, the constructor will instead default to extracting the value * from ordinal 0 (since there are no names to map to). The actual location can be moved by - * calling unbind/bind with a new schema. + * calling resolve/bind with a new schema. */ - def constructorFor[T : TypeTag]: Expression = constructorFor(typeOf[T], None) + def constructorFor[T : TypeTag]: Expression = constructorFor(localTypeOf[T], None) private def constructorFor( tpe: `Type`, path: Option[Expression]): Expression = ScalaReflectionLock.synchronized { /** Returns the current path with a sub-field extracted. */ - def addToPath(part: String): Expression = - path - .map(p => UnresolvedExtractValue(p, expressions.Literal(part))) - .getOrElse(UnresolvedAttribute(part)) + def addToPath(part: String): Expression = path + .map(p => UnresolvedExtractValue(p, expressions.Literal(part))) + .getOrElse(UnresolvedAttribute(part)) /** Returns the current path with a field at ordinal extracted. */ - def addToPathOrdinal(ordinal: Int, dataType: DataType): Expression = - path - .map(p => GetStructField(p, StructField(s"_$ordinal", dataType), ordinal)) - .getOrElse(BoundReference(ordinal, dataType, false)) + def addToPathOrdinal(ordinal: Int, dataType: DataType): Expression = path + .map(p => GetInternalRowField(p, ordinal, dataType)) + .getOrElse(BoundReference(ordinal, dataType, false)) - /** Returns the current path or throws an error. */ - def getPath = path.getOrElse(BoundReference(0, schemaFor(tpe).dataType, true)) + /** Returns the current path or `BoundReference`. */ + def getPath: Expression = path.getOrElse(BoundReference(0, schemaFor(tpe).dataType, true)) tpe match { - case t if !dataTypeFor(t).isInstanceOf[ObjectType] => - getPath + case t if !dataTypeFor(t).isInstanceOf[ObjectType] => getPath case t if t <:< localTypeOf[Option[_]] => val TypeRef(_, _, Seq(optType)) = t - val boxedType = optType match { - // For primitive types we must manually box the primitive value. 
- case t if t <:< definitions.IntTpe => Some(classOf[java.lang.Integer]) - case t if t <:< definitions.LongTpe => Some(classOf[java.lang.Long]) - case t if t <:< definitions.DoubleTpe => Some(classOf[java.lang.Double]) - case t if t <:< definitions.FloatTpe => Some(classOf[java.lang.Float]) - case t if t <:< definitions.ShortTpe => Some(classOf[java.lang.Short]) - case t if t <:< definitions.ByteTpe => Some(classOf[java.lang.Byte]) - case t if t <:< definitions.BooleanTpe => Some(classOf[java.lang.Boolean]) - case _ => None - } - - boxedType.map { boxedType => - val objectType = ObjectType(boxedType) - WrapOption( - objectType, - NewInstance( - boxedType, - getPath :: Nil, - propagateNull = true, - objectType)) - }.getOrElse { - val className: String = optType.erasure.typeSymbol.asClass.fullName - val cls = Utils.classForName(className) - val objectType = ObjectType(cls) - - WrapOption(objectType, constructorFor(optType, path)) - } + WrapOption(constructorFor(optType, path)) case t if t <:< localTypeOf[java.lang.Integer] => val boxedType = classOf[java.lang.Integer] @@ -231,11 +200,11 @@ object ScalaReflection extends ScalaReflection { case t if t <:< localTypeOf[java.math.BigDecimal] => Invoke(getPath, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal])) + case t if t <:< localTypeOf[BigDecimal] => + Invoke(getPath, "toBigDecimal", ObjectType(classOf[BigDecimal])) + case t if t <:< localTypeOf[Array[_]] => val TypeRef(_, _, Seq(elementType)) = t - val elementDataType = dataTypeFor(elementType) - val Schema(dataType, nullable) = schemaFor(elementType) - val primitiveMethod = elementType match { case t if t <:< definitions.IntTpe => Some("toIntArray") case t if t <:< definitions.LongTpe => Some("toLongArray") @@ -248,57 +217,52 @@ object ScalaReflection extends ScalaReflection { } primitiveMethod.map { method => - Invoke(getPath, method, dataTypeFor(t)) + Invoke(getPath, method, arrayClassFor(elementType)) }.getOrElse { - val returnType = dataTypeFor(t) Invoke( - MapObjects(p => constructorFor(elementType, Some(p)), getPath, dataType), + MapObjects( + p => constructorFor(elementType, Some(p)), + getPath, + schemaFor(elementType).dataType), "array", - returnType) + arrayClassFor(elementType)) } + case t if t <:< localTypeOf[Seq[_]] => + val TypeRef(_, _, Seq(elementType)) = t + val arrayData = + Invoke( + MapObjects( + p => constructorFor(elementType, Some(p)), + getPath, + schemaFor(elementType).dataType), + "array", + ObjectType(classOf[Array[Any]])) + + StaticInvoke( + scala.collection.mutable.WrappedArray, + ObjectType(classOf[Seq[_]]), + "make", + arrayData :: Nil) + case t if t <:< localTypeOf[Map[_, _]] => val TypeRef(_, _, Seq(keyType, valueType)) = t - val Schema(keyDataType, _) = schemaFor(keyType) - val Schema(valueDataType, valueNullable) = schemaFor(valueType) - - val primitiveMethodKey = keyType match { - case t if t <:< definitions.IntTpe => Some("toIntArray") - case t if t <:< definitions.LongTpe => Some("toLongArray") - case t if t <:< definitions.DoubleTpe => Some("toDoubleArray") - case t if t <:< definitions.FloatTpe => Some("toFloatArray") - case t if t <:< definitions.ShortTpe => Some("toShortArray") - case t if t <:< definitions.ByteTpe => Some("toByteArray") - case t if t <:< definitions.BooleanTpe => Some("toBooleanArray") - case _ => None - } val keyData = Invoke( MapObjects( p => constructorFor(keyType, Some(p)), - Invoke(getPath, "keyArray", ArrayType(keyDataType)), - keyDataType), + Invoke(getPath, "keyArray", ArrayType(schemaFor(keyType).dataType)), + 
schemaFor(keyType).dataType), "array", ObjectType(classOf[Array[Any]])) - val primitiveMethodValue = valueType match { - case t if t <:< definitions.IntTpe => Some("toIntArray") - case t if t <:< definitions.LongTpe => Some("toLongArray") - case t if t <:< definitions.DoubleTpe => Some("toDoubleArray") - case t if t <:< definitions.FloatTpe => Some("toFloatArray") - case t if t <:< definitions.ShortTpe => Some("toShortArray") - case t if t <:< definitions.ByteTpe => Some("toByteArray") - case t if t <:< definitions.BooleanTpe => Some("toBooleanArray") - case _ => None - } - val valueData = Invoke( MapObjects( p => constructorFor(valueType, Some(p)), - Invoke(getPath, "valueArray", ArrayType(valueDataType)), - valueDataType), + Invoke(getPath, "valueArray", ArrayType(schemaFor(valueType).dataType)), + schemaFor(valueType).dataType), "array", ObjectType(classOf[Array[Any]])) @@ -308,40 +272,6 @@ object ScalaReflection extends ScalaReflection { "toScalaMap", keyData :: valueData :: Nil) - case t if t <:< localTypeOf[Seq[_]] => - val TypeRef(_, _, Seq(elementType)) = t - val elementDataType = dataTypeFor(elementType) - val Schema(dataType, nullable) = schemaFor(elementType) - - // Avoid boxing when possible by just wrapping a primitive array. - val primitiveMethod = elementType match { - case _ if nullable => None - case t if t <:< definitions.IntTpe => Some("toIntArray") - case t if t <:< definitions.LongTpe => Some("toLongArray") - case t if t <:< definitions.DoubleTpe => Some("toDoubleArray") - case t if t <:< definitions.FloatTpe => Some("toFloatArray") - case t if t <:< definitions.ShortTpe => Some("toShortArray") - case t if t <:< definitions.ByteTpe => Some("toByteArray") - case t if t <:< definitions.BooleanTpe => Some("toBooleanArray") - case _ => None - } - - val arrayData = primitiveMethod.map { method => - Invoke(getPath, method, arrayClassFor(elementType)) - }.getOrElse { - Invoke( - MapObjects(p => constructorFor(elementType, Some(p)), getPath, dataType), - "array", - arrayClassFor(elementType)) - } - - StaticInvoke( - scala.collection.mutable.WrappedArray, - ObjectType(classOf[Seq[_]]), - "make", - arrayData :: Nil) - - case t if t <:< localTypeOf[Product] => val formalTypeArgs = t.typeSymbol.asClass.typeParams val TypeRef(_, _, actualTypeArgs) = t @@ -361,8 +291,7 @@ object ScalaReflection extends ScalaReflection { } } - val className: String = t.erasure.typeSymbol.asClass.fullName - val cls = Utils.classForName(className) + val cls = mirror.runtimeClass(tpe.erasure.typeSymbol.asClass) val arguments = params.head.zipWithIndex.map { case (p, i) => val fieldName = p.name.toString @@ -370,7 +299,7 @@ object ScalaReflection extends ScalaReflection { val dataType = schemaFor(fieldType).dataType // For tuples, we based grab the inner fields by ordinal instead of name. - if (className startsWith "scala.Tuple") { + if (cls.getName startsWith "scala.Tuple") { constructorFor(fieldType, Some(addToPathOrdinal(i, dataType))) } else { constructorFor(fieldType, Some(addToPath(fieldName))) @@ -388,22 +317,19 @@ object ScalaReflection extends ScalaReflection { } else { newInstance } - } } /** Returns expressions for extracting all the fields from the given type. 
*/ def extractorsFor[T : TypeTag](inputObject: Expression): CreateNamedStruct = { - ScalaReflectionLock.synchronized { - extractorFor(inputObject, typeTag[T].tpe) match { - case s: CreateNamedStruct => s - case o => CreateNamedStruct(expressions.Literal("value") :: o :: Nil) - } + extractorFor(inputObject, localTypeOf[T]) match { + case s: CreateNamedStruct => s + case other => CreateNamedStruct(expressions.Literal("value") :: other :: Nil) } } /** Helper for extracting internal fields from a case class. */ - protected def extractorFor( + private def extractorFor( inputObject: Expression, tpe: `Type`): Expression = ScalaReflectionLock.synchronized { if (!inputObject.dataType.isInstanceOf[ObjectType]) { @@ -491,51 +417,36 @@ object ScalaReflection extends ScalaReflection { case t if t <:< localTypeOf[Array[_]] => val TypeRef(_, _, Seq(elementType)) = t - val elementDataType = dataTypeFor(elementType) - val Schema(dataType, nullable) = schemaFor(elementType) - - if (!elementDataType.isInstanceOf[AtomicType]) { - MapObjects(extractorFor(_, elementType), inputObject, elementDataType) - } else { - NewInstance( - classOf[GenericArrayData], - inputObject :: Nil, - dataType = ArrayType(dataType, nullable)) - } + toCatalystArray(inputObject, elementType) case t if t <:< localTypeOf[Seq[_]] => val TypeRef(_, _, Seq(elementType)) = t - val elementDataType = dataTypeFor(elementType) - val Schema(dataType, nullable) = schemaFor(elementType) - - if (dataType.isInstanceOf[AtomicType]) { - NewInstance( - classOf[GenericArrayData], - inputObject :: Nil, - dataType = ArrayType(dataType, nullable)) - } else { - MapObjects(extractorFor(_, elementType), inputObject, elementDataType) - } + toCatalystArray(inputObject, elementType) case t if t <:< localTypeOf[Map[_, _]] => val TypeRef(_, _, Seq(keyType, valueType)) = t - val Schema(keyDataType, _) = schemaFor(keyType) - val Schema(valueDataType, valueNullable) = schemaFor(valueType) - val rawMap = inputObject val keys = - NewInstance( - classOf[GenericArrayData], - Invoke(rawMap, "keys", ObjectType(classOf[scala.collection.GenIterable[_]])) :: Nil, - dataType = ObjectType(classOf[ArrayData])) + Invoke( + Invoke(inputObject, "keysIterator", + ObjectType(classOf[scala.collection.Iterator[_]])), + "toSeq", + ObjectType(classOf[scala.collection.Seq[_]])) + val convertedKeys = toCatalystArray(keys, keyType) + val values = - NewInstance( - classOf[GenericArrayData], - Invoke(rawMap, "values", ObjectType(classOf[scala.collection.GenIterable[_]])) :: Nil, - dataType = ObjectType(classOf[ArrayData])) + Invoke( + Invoke(inputObject, "valuesIterator", + ObjectType(classOf[scala.collection.Iterator[_]])), + "toSeq", + ObjectType(classOf[scala.collection.Seq[_]])) + val convertedValues = toCatalystArray(values, valueType) + + val Schema(keyDataType, _) = schemaFor(keyType) + val Schema(valueDataType, valueNullable) = schemaFor(valueType) NewInstance( classOf[ArrayBasedMapData], - keys :: values :: Nil, + convertedKeys :: convertedValues :: Nil, dataType = MapType(keyDataType, valueDataType, valueNullable)) case t if t <:< localTypeOf[String] => @@ -558,6 +469,7 @@ object ScalaReflection extends ScalaReflection { DateType, "fromJavaDate", inputObject :: Nil) + case t if t <:< localTypeOf[BigDecimal] => StaticInvoke( Decimal, @@ -587,26 +499,24 @@ object ScalaReflection extends ScalaReflection { case t if t <:< localTypeOf[java.lang.Boolean] => Invoke(inputObject, "booleanValue", BooleanType) - case t if t <:< definitions.IntTpe => - BoundReference(0, IntegerType, false) - 
case t if t <:< definitions.LongTpe => - BoundReference(0, LongType, false) - case t if t <:< definitions.DoubleTpe => - BoundReference(0, DoubleType, false) - case t if t <:< definitions.FloatTpe => - BoundReference(0, FloatType, false) - case t if t <:< definitions.ShortTpe => - BoundReference(0, ShortType, false) - case t if t <:< definitions.ByteTpe => - BoundReference(0, ByteType, false) - case t if t <:< definitions.BooleanTpe => - BoundReference(0, BooleanType, false) - case other => throw new UnsupportedOperationException(s"Extractor for type $other is not supported") } } } + + private def toCatalystArray(input: Expression, elementType: `Type`): Expression = { + val externalDataType = dataTypeFor(elementType) + val Schema(catalystType, nullable) = schemaFor(elementType) + if (isNativeType(catalystType)) { + NewInstance( + classOf[GenericArrayData], + input :: Nil, + dataType = ArrayType(catalystType, nullable)) + } else { + MapObjects(extractorFor(_, elementType), input, externalDataType) + } + } } /** @@ -635,8 +545,7 @@ trait ScalaReflection { } /** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */ - def schemaFor[T: TypeTag]: Schema = - ScalaReflectionLock.synchronized { schemaFor(localTypeOf[T]) } + def schemaFor[T: TypeTag]: Schema = schemaFor(localTypeOf[T]) /** * Return the Scala Type for `T` in the current classloader mirror. @@ -736,39 +645,4 @@ trait ScalaReflection { assert(methods.length == 1) methods.head.getParameterTypes } - - def typeOfObject: PartialFunction[Any, DataType] = { - // The data type can be determined without ambiguity. - case obj: Boolean => BooleanType - case obj: Array[Byte] => BinaryType - case obj: String => StringType - case obj: UTF8String => StringType - case obj: Byte => ByteType - case obj: Short => ShortType - case obj: Int => IntegerType - case obj: Long => LongType - case obj: Float => FloatType - case obj: Double => DoubleType - case obj: java.sql.Date => DateType - case obj: java.math.BigDecimal => DecimalType.SYSTEM_DEFAULT - case obj: Decimal => DecimalType.SYSTEM_DEFAULT - case obj: java.sql.Timestamp => TimestampType - case null => NullType - // For other cases, there is no obvious mapping from the type of the given object to a - // Catalyst data type. A user should provide his/her specific rules - // (in a user-defined PartialFunction) to infer the Catalyst data type for other types of - // objects and then compose the user-defined PartialFunction with this one. - } - - implicit class CaseClassRelation[A <: Product : TypeTag](data: Seq[A]) { - - /** - * Implicitly added to Sequences of case class objects. Returns a catalyst logical relation - * for the the data in the sequence. 
- */ - def asRelation: LocalRelation = { - val output = attributesFor[A] - LocalRelation.fromProduct(output, data) - } - } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index 456b59500847..6eeba1442c1f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -30,10 +30,10 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.ScalaReflection -import org.apache.spark.sql.types.{NullType, StructField, ObjectType, StructType} +import org.apache.spark.sql.types.{StructField, ObjectType, StructType} /** - * A factory for constructing encoders that convert objects and primitves to and from the + * A factory for constructing encoders that convert objects and primitives to and from the * internal row format using catalyst expressions and code generation. By default, the * expressions used to retrieve values from an input row when producing an object will be created as * follows: @@ -44,20 +44,21 @@ import org.apache.spark.sql.types.{NullType, StructField, ObjectType, StructType * to the name `value`. */ object ExpressionEncoder { - def apply[T : TypeTag](flat: Boolean = false): ExpressionEncoder[T] = { + def apply[T : TypeTag](): ExpressionEncoder[T] = { // We convert the not-serializable TypeTag into StructType and ClassTag. val mirror = typeTag[T].mirror val cls = mirror.runtimeClass(typeTag[T].tpe) + val flat = !classOf[Product].isAssignableFrom(cls) - val inputObject = BoundReference(0, ObjectType(cls), nullable = true) - val extractExpression = ScalaReflection.extractorsFor[T](inputObject) - val constructExpression = ScalaReflection.constructorFor[T] + val inputObject = BoundReference(0, ScalaReflection.dataTypeFor[T], nullable = true) + val toRowExpression = ScalaReflection.extractorsFor[T](inputObject) + val fromRowExpression = ScalaReflection.constructorFor[T] new ExpressionEncoder[T]( - extractExpression.dataType, + toRowExpression.dataType, flat, - extractExpression.flatten, - constructExpression, + toRowExpression.flatten, + fromRowExpression, ClassTag[T](cls)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoder.scala deleted file mode 100644 index 6d307ab13a9f..000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoder.scala +++ /dev/null @@ -1,50 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.encoders - -import scala.reflect.ClassTag -import scala.reflect.runtime.universe.{typeTag, TypeTag} - -import org.apache.spark.sql.types.StructType -import org.apache.spark.sql.catalyst.expressions.{Literal, CreateNamedStruct, BoundReference} -import org.apache.spark.sql.catalyst.ScalaReflection - -object FlatEncoder { - import ScalaReflection.schemaFor - import ScalaReflection.dataTypeFor - - def apply[T : TypeTag]: ExpressionEncoder[T] = { - // We convert the not-serializable TypeTag into StructType and ClassTag. - val tpe = typeTag[T].tpe - val mirror = typeTag[T].mirror - val cls = mirror.runtimeClass(tpe) - assert(!schemaFor(tpe).dataType.isInstanceOf[StructType]) - - val input = BoundReference(0, dataTypeFor(tpe), nullable = true) - val toRowExpression = CreateNamedStruct( - Literal("value") :: ProductEncoder.extractorFor(input, tpe) :: Nil) - val fromRowExpression = ProductEncoder.constructorFor(tpe) - - new ExpressionEncoder[T]( - toRowExpression.dataType, - flat = true, - toRowExpression.flatten, - fromRowExpression, - ClassTag[T](cls)) - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala deleted file mode 100644 index 2914c6ee790c..000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala +++ /dev/null @@ -1,452 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.catalyst.encoders - -import org.apache.spark.util.Utils -import org.apache.spark.unsafe.types.UTF8String -import org.apache.spark.sql.types._ -import org.apache.spark.sql.catalyst.ScalaReflectionLock -import org.apache.spark.sql.catalyst.ScalaReflection -import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedExtractValue} -import org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.util.{DateTimeUtils, ArrayBasedMapData, GenericArrayData} - -import scala.reflect.ClassTag - -object ProductEncoder { - import ScalaReflection.universe._ - import ScalaReflection.mirror - import ScalaReflection.localTypeOf - import ScalaReflection.dataTypeFor - import ScalaReflection.Schema - import ScalaReflection.schemaFor - import ScalaReflection.arrayClassFor - - def apply[T <: Product : TypeTag]: ExpressionEncoder[T] = { - // We convert the not-serializable TypeTag into StructType and ClassTag. - val tpe = typeTag[T].tpe - val mirror = typeTag[T].mirror - val cls = mirror.runtimeClass(tpe) - - val inputObject = BoundReference(0, ObjectType(cls), nullable = true) - val toRowExpression = extractorFor(inputObject, tpe).asInstanceOf[CreateNamedStruct] - val fromRowExpression = constructorFor(tpe) - - new ExpressionEncoder[T]( - toRowExpression.dataType, - flat = false, - toRowExpression.flatten, - fromRowExpression, - ClassTag[T](cls)) - } - - // The Predef.Map is scala.collection.immutable.Map. - // Since the map values can be mutable, we explicitly import scala.collection.Map at here. - import scala.collection.Map - - def extractorFor( - inputObject: Expression, - tpe: `Type`): Expression = ScalaReflectionLock.synchronized { - if (!inputObject.dataType.isInstanceOf[ObjectType]) { - inputObject - } else { - tpe match { - case t if t <:< localTypeOf[Option[_]] => - val TypeRef(_, _, Seq(optType)) = t - optType match { - // For primitive types we must manually unbox the value of the object. - case t if t <:< definitions.IntTpe => - Invoke( - UnwrapOption(ObjectType(classOf[java.lang.Integer]), inputObject), - "intValue", - IntegerType) - case t if t <:< definitions.LongTpe => - Invoke( - UnwrapOption(ObjectType(classOf[java.lang.Long]), inputObject), - "longValue", - LongType) - case t if t <:< definitions.DoubleTpe => - Invoke( - UnwrapOption(ObjectType(classOf[java.lang.Double]), inputObject), - "doubleValue", - DoubleType) - case t if t <:< definitions.FloatTpe => - Invoke( - UnwrapOption(ObjectType(classOf[java.lang.Float]), inputObject), - "floatValue", - FloatType) - case t if t <:< definitions.ShortTpe => - Invoke( - UnwrapOption(ObjectType(classOf[java.lang.Short]), inputObject), - "shortValue", - ShortType) - case t if t <:< definitions.ByteTpe => - Invoke( - UnwrapOption(ObjectType(classOf[java.lang.Byte]), inputObject), - "byteValue", - ByteType) - case t if t <:< definitions.BooleanTpe => - Invoke( - UnwrapOption(ObjectType(classOf[java.lang.Boolean]), inputObject), - "booleanValue", - BooleanType) - - // For non-primitives, we can just extract the object from the Option and then recurse. 
- case other => - val className: String = optType.erasure.typeSymbol.asClass.fullName - val classObj = Utils.classForName(className) - val optionObjectType = ObjectType(classObj) - - val unwrapped = UnwrapOption(optionObjectType, inputObject) - expressions.If( - IsNull(unwrapped), - expressions.Literal.create(null, schemaFor(optType).dataType), - extractorFor(unwrapped, optType)) - } - - case t if t <:< localTypeOf[Product] => - val formalTypeArgs = t.typeSymbol.asClass.typeParams - val TypeRef(_, _, actualTypeArgs) = t - val constructorSymbol = t.member(nme.CONSTRUCTOR) - val params = if (constructorSymbol.isMethod) { - constructorSymbol.asMethod.paramss - } else { - // Find the primary constructor, and use its parameter ordering. - val primaryConstructorSymbol: Option[Symbol] = - constructorSymbol.asTerm.alternatives.find(s => - s.isMethod && s.asMethod.isPrimaryConstructor) - - if (primaryConstructorSymbol.isEmpty) { - sys.error("Internal SQL error: Product object did not have a primary constructor.") - } else { - primaryConstructorSymbol.get.asMethod.paramss - } - } - - CreateNamedStruct(params.head.flatMap { p => - val fieldName = p.name.toString - val fieldType = p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs) - val fieldValue = Invoke(inputObject, fieldName, dataTypeFor(fieldType)) - expressions.Literal(fieldName) :: extractorFor(fieldValue, fieldType) :: Nil - }) - - case t if t <:< localTypeOf[Array[_]] => - val TypeRef(_, _, Seq(elementType)) = t - toCatalystArray(inputObject, elementType) - - case t if t <:< localTypeOf[Seq[_]] => - val TypeRef(_, _, Seq(elementType)) = t - toCatalystArray(inputObject, elementType) - - case t if t <:< localTypeOf[Map[_, _]] => - val TypeRef(_, _, Seq(keyType, valueType)) = t - - val keys = - Invoke( - Invoke(inputObject, "keysIterator", - ObjectType(classOf[scala.collection.Iterator[_]])), - "toSeq", - ObjectType(classOf[scala.collection.Seq[_]])) - val convertedKeys = toCatalystArray(keys, keyType) - - val values = - Invoke( - Invoke(inputObject, "valuesIterator", - ObjectType(classOf[scala.collection.Iterator[_]])), - "toSeq", - ObjectType(classOf[scala.collection.Seq[_]])) - val convertedValues = toCatalystArray(values, valueType) - - val Schema(keyDataType, _) = schemaFor(keyType) - val Schema(valueDataType, valueNullable) = schemaFor(valueType) - NewInstance( - classOf[ArrayBasedMapData], - convertedKeys :: convertedValues :: Nil, - dataType = MapType(keyDataType, valueDataType, valueNullable)) - - case t if t <:< localTypeOf[String] => - StaticInvoke( - classOf[UTF8String], - StringType, - "fromString", - inputObject :: Nil) - - case t if t <:< localTypeOf[java.sql.Timestamp] => - StaticInvoke( - DateTimeUtils, - TimestampType, - "fromJavaTimestamp", - inputObject :: Nil) - - case t if t <:< localTypeOf[java.sql.Date] => - StaticInvoke( - DateTimeUtils, - DateType, - "fromJavaDate", - inputObject :: Nil) - - case t if t <:< localTypeOf[BigDecimal] => - StaticInvoke( - Decimal, - DecimalType.SYSTEM_DEFAULT, - "apply", - inputObject :: Nil) - - case t if t <:< localTypeOf[java.math.BigDecimal] => - StaticInvoke( - Decimal, - DecimalType.SYSTEM_DEFAULT, - "apply", - inputObject :: Nil) - - case t if t <:< localTypeOf[java.lang.Integer] => - Invoke(inputObject, "intValue", IntegerType) - case t if t <:< localTypeOf[java.lang.Long] => - Invoke(inputObject, "longValue", LongType) - case t if t <:< localTypeOf[java.lang.Double] => - Invoke(inputObject, "doubleValue", DoubleType) - case t if t <:< localTypeOf[java.lang.Float] => - 
Invoke(inputObject, "floatValue", FloatType) - case t if t <:< localTypeOf[java.lang.Short] => - Invoke(inputObject, "shortValue", ShortType) - case t if t <:< localTypeOf[java.lang.Byte] => - Invoke(inputObject, "byteValue", ByteType) - case t if t <:< localTypeOf[java.lang.Boolean] => - Invoke(inputObject, "booleanValue", BooleanType) - - case other => - throw new UnsupportedOperationException(s"Encoder for type $other is not supported") - } - } - } - - private def toCatalystArray(input: Expression, elementType: `Type`): Expression = { - val externalDataType = dataTypeFor(elementType) - val Schema(catalystType, nullable) = schemaFor(elementType) - if (RowEncoder.isNativeType(catalystType)) { - NewInstance( - classOf[GenericArrayData], - input :: Nil, - dataType = ArrayType(catalystType, nullable)) - } else { - MapObjects(extractorFor(_, elementType), input, externalDataType) - } - } - - def constructorFor( - tpe: `Type`, - path: Option[Expression] = None): Expression = ScalaReflectionLock.synchronized { - - /** Returns the current path with a sub-field extracted. */ - def addToPath(part: String): Expression = path - .map(p => UnresolvedExtractValue(p, expressions.Literal(part))) - .getOrElse(UnresolvedAttribute(part)) - - /** Returns the current path with a field at ordinal extracted. */ - def addToPathOrdinal(ordinal: Int, dataType: DataType): Expression = path - .map(p => GetInternalRowField(p, ordinal, dataType)) - .getOrElse(BoundReference(ordinal, dataType, false)) - - /** Returns the current path or `BoundReference`. */ - def getPath: Expression = path.getOrElse(BoundReference(0, schemaFor(tpe).dataType, true)) - - tpe match { - case t if !dataTypeFor(t).isInstanceOf[ObjectType] => getPath - - case t if t <:< localTypeOf[Option[_]] => - val TypeRef(_, _, Seq(optType)) = t - WrapOption(null, constructorFor(optType, path)) - - case t if t <:< localTypeOf[java.lang.Integer] => - val boxedType = classOf[java.lang.Integer] - val objectType = ObjectType(boxedType) - NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType) - - case t if t <:< localTypeOf[java.lang.Long] => - val boxedType = classOf[java.lang.Long] - val objectType = ObjectType(boxedType) - NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType) - - case t if t <:< localTypeOf[java.lang.Double] => - val boxedType = classOf[java.lang.Double] - val objectType = ObjectType(boxedType) - NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType) - - case t if t <:< localTypeOf[java.lang.Float] => - val boxedType = classOf[java.lang.Float] - val objectType = ObjectType(boxedType) - NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType) - - case t if t <:< localTypeOf[java.lang.Short] => - val boxedType = classOf[java.lang.Short] - val objectType = ObjectType(boxedType) - NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType) - - case t if t <:< localTypeOf[java.lang.Byte] => - val boxedType = classOf[java.lang.Byte] - val objectType = ObjectType(boxedType) - NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType) - - case t if t <:< localTypeOf[java.lang.Boolean] => - val boxedType = classOf[java.lang.Boolean] - val objectType = ObjectType(boxedType) - NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType) - - case t if t <:< localTypeOf[java.sql.Date] => - StaticInvoke( - DateTimeUtils, - ObjectType(classOf[java.sql.Date]), - "toJavaDate", - getPath :: Nil, - propagateNull = true) - - case t if t <:< 
localTypeOf[java.sql.Timestamp] => - StaticInvoke( - DateTimeUtils, - ObjectType(classOf[java.sql.Timestamp]), - "toJavaTimestamp", - getPath :: Nil, - propagateNull = true) - - case t if t <:< localTypeOf[java.lang.String] => - Invoke(getPath, "toString", ObjectType(classOf[String])) - - case t if t <:< localTypeOf[java.math.BigDecimal] => - Invoke(getPath, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal])) - - case t if t <:< localTypeOf[BigDecimal] => - Invoke(getPath, "toBigDecimal", ObjectType(classOf[BigDecimal])) - - case t if t <:< localTypeOf[Array[_]] => - val TypeRef(_, _, Seq(elementType)) = t - val primitiveMethod = elementType match { - case t if t <:< definitions.IntTpe => Some("toIntArray") - case t if t <:< definitions.LongTpe => Some("toLongArray") - case t if t <:< definitions.DoubleTpe => Some("toDoubleArray") - case t if t <:< definitions.FloatTpe => Some("toFloatArray") - case t if t <:< definitions.ShortTpe => Some("toShortArray") - case t if t <:< definitions.ByteTpe => Some("toByteArray") - case t if t <:< definitions.BooleanTpe => Some("toBooleanArray") - case _ => None - } - - primitiveMethod.map { method => - Invoke(getPath, method, arrayClassFor(elementType)) - }.getOrElse { - Invoke( - MapObjects( - p => constructorFor(elementType, Some(p)), - getPath, - schemaFor(elementType).dataType), - "array", - arrayClassFor(elementType)) - } - - case t if t <:< localTypeOf[Seq[_]] => - val TypeRef(_, _, Seq(elementType)) = t - val arrayData = - Invoke( - MapObjects( - p => constructorFor(elementType, Some(p)), - getPath, - schemaFor(elementType).dataType), - "array", - ObjectType(classOf[Array[Any]])) - - StaticInvoke( - scala.collection.mutable.WrappedArray, - ObjectType(classOf[Seq[_]]), - "make", - arrayData :: Nil) - - case t if t <:< localTypeOf[Map[_, _]] => - val TypeRef(_, _, Seq(keyType, valueType)) = t - - val keyData = - Invoke( - MapObjects( - p => constructorFor(keyType, Some(p)), - Invoke(getPath, "keyArray", ArrayType(schemaFor(keyType).dataType)), - schemaFor(keyType).dataType), - "array", - ObjectType(classOf[Array[Any]])) - - val valueData = - Invoke( - MapObjects( - p => constructorFor(valueType, Some(p)), - Invoke(getPath, "valueArray", ArrayType(schemaFor(valueType).dataType)), - schemaFor(valueType).dataType), - "array", - ObjectType(classOf[Array[Any]])) - - StaticInvoke( - ArrayBasedMapData, - ObjectType(classOf[Map[_, _]]), - "toScalaMap", - keyData :: valueData :: Nil) - - case t if t <:< localTypeOf[Product] => - val formalTypeArgs = t.typeSymbol.asClass.typeParams - val TypeRef(_, _, actualTypeArgs) = t - val constructorSymbol = t.member(nme.CONSTRUCTOR) - val params = if (constructorSymbol.isMethod) { - constructorSymbol.asMethod.paramss - } else { - // Find the primary constructor, and use its parameter ordering. - val primaryConstructorSymbol: Option[Symbol] = - constructorSymbol.asTerm.alternatives.find(s => - s.isMethod && s.asMethod.isPrimaryConstructor) - - if (primaryConstructorSymbol.isEmpty) { - sys.error("Internal SQL error: Product object did not have a primary constructor.") - } else { - primaryConstructorSymbol.get.asMethod.paramss - } - } - - val cls = mirror.runtimeClass(tpe.erasure.typeSymbol.asClass) - - val arguments = params.head.zipWithIndex.map { case (p, i) => - val fieldName = p.name.toString - val fieldType = p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs) - val dataType = schemaFor(fieldType).dataType - - // For tuples, we based grab the inner fields by ordinal instead of name. 
- if (cls.getName startsWith "scala.Tuple") { - constructorFor(fieldType, Some(addToPathOrdinal(i, dataType))) - } else { - constructorFor(fieldType, Some(addToPath(fieldName))) - } - } - - val newInstance = NewInstance(cls, arguments, propagateNull = false, ObjectType(cls)) - - if (path.nonEmpty) { - expressions.If( - IsNull(getPath), - expressions.Literal.create(null, ObjectType(cls)), - newInstance - ) - } else { - newInstance - } - } - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index 9bb1602494b6..4cda4824acdc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -23,6 +23,7 @@ import scala.reflect.ClassTag import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayBasedMapData, DateTimeUtils} +import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -132,17 +133,8 @@ object RowEncoder { CreateStruct(convertedFields) } - /** - * Returns true if the value of this data type is same between internal and external. - */ - def isNativeType(dt: DataType): Boolean = dt match { - case BooleanType | ByteType | ShortType | IntegerType | LongType | - FloatType | DoubleType | BinaryType => true - case _ => false - } - private def externalDataTypeFor(dt: DataType): DataType = dt match { - case _ if isNativeType(dt) => dt + case _ if ScalaReflection.isNativeType(dt) => dt case TimestampType => ObjectType(classOf[java.sql.Timestamp]) case DateType => ObjectType(classOf[java.sql.Date]) case _: DecimalType => ObjectType(classOf[java.math.BigDecimal]) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala index f865a9408ef4..ef7399e0196a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala @@ -24,7 +24,6 @@ import org.apache.spark.SparkConf import org.apache.spark.serializer._ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer -import org.apache.spark.sql.catalyst.encoders.ProductEncoder import org.apache.spark.sql.catalyst.plans.logical.{Project, LocalRelation} import org.apache.spark.sql.catalyst.util.GenericArrayData import org.apache.spark.sql.catalyst.InternalRow @@ -300,10 +299,9 @@ case class UnwrapOption( /** * Converts the result of evaluating `child` into an option, checking both the isNull bit and * (in the case of reference types) equality with null. - * @param optionType The datatype to be held inside of the Option. * @param child The expression to evaluate and wrap. 
*/ -case class WrapOption(optionType: DataType, child: Expression) +case class WrapOption(child: Expression) extends UnaryExpression with ExpectsInputTypes { override def dataType: DataType = ObjectType(classOf[Option[_]]) @@ -316,14 +314,13 @@ case class WrapOption(optionType: DataType, child: Expression) throw new UnsupportedOperationException("Only code-generated evaluation is supported") override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val javaType = ctx.javaType(optionType) val inputObject = child.gen(ctx) s""" ${inputObject.code} boolean ${ev.isNull} = false; - scala.Option<$javaType> ${ev.value} = + scala.Option ${ev.value} = ${inputObject.isNull} ? scala.Option$$.MODULE$$.apply(null) : new scala.Some(${inputObject.value}); """ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala index 4ea410d492b0..c2aace1ef238 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala @@ -186,74 +186,6 @@ class ScalaReflectionSuite extends SparkFunSuite { nullable = true)) } - test("get data type of a value") { - // BooleanType - assert(BooleanType === typeOfObject(true)) - assert(BooleanType === typeOfObject(false)) - - // BinaryType - assert(BinaryType === typeOfObject("string".getBytes)) - - // StringType - assert(StringType === typeOfObject("string")) - - // ByteType - assert(ByteType === typeOfObject(127.toByte)) - - // ShortType - assert(ShortType === typeOfObject(32767.toShort)) - - // IntegerType - assert(IntegerType === typeOfObject(2147483647)) - - // LongType - assert(LongType === typeOfObject(9223372036854775807L)) - - // FloatType - assert(FloatType === typeOfObject(3.4028235E38.toFloat)) - - // DoubleType - assert(DoubleType === typeOfObject(1.7976931348623157E308)) - - // DecimalType - assert(DecimalType.SYSTEM_DEFAULT === - typeOfObject(new java.math.BigDecimal("1.7976931348623157E318"))) - - // DateType - assert(DateType === typeOfObject(Date.valueOf("2014-07-25"))) - - // TimestampType - assert(TimestampType === typeOfObject(Timestamp.valueOf("2014-07-25 10:26:00"))) - - // NullType - assert(NullType === typeOfObject(null)) - - def typeOfObject1: PartialFunction[Any, DataType] = typeOfObject orElse { - case value: java.math.BigInteger => DecimalType.SYSTEM_DEFAULT - case value: java.math.BigDecimal => DecimalType.SYSTEM_DEFAULT - case _ => StringType - } - - assert(DecimalType.SYSTEM_DEFAULT === typeOfObject1( - new BigInteger("92233720368547758070"))) - assert(DecimalType.SYSTEM_DEFAULT === typeOfObject1( - new java.math.BigDecimal("1.7976931348623157E318"))) - assert(StringType === typeOfObject1(BigInt("92233720368547758070"))) - - def typeOfObject2: PartialFunction[Any, DataType] = typeOfObject orElse { - case value: java.math.BigInteger => DecimalType.SYSTEM_DEFAULT - } - - intercept[MatchError](typeOfObject2(BigInt("92233720368547758070"))) - - def typeOfObject3: PartialFunction[Any, DataType] = typeOfObject orElse { - case c: Seq[_] => ArrayType(typeOfObject3(c.head)) - } - - assert(ArrayType(IntegerType) === typeOfObject3(Seq(1, 2, 3))) - assert(ArrayType(ArrayType(IntegerType)) === typeOfObject3(Seq(Seq(1, 2, 3)))) - } - test("convert PrimitiveData to catalyst") { val data = PrimitiveData(1, 1, 1, 1, 1, 1, true) val convertedData = InternalRow(1, 1.toLong, 1.toDouble, 1.toFloat, 
1.toShort, 1.toByte, true) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala index cde0364f3dd9..76459b34a484 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala @@ -17,24 +17,234 @@ package org.apache.spark.sql.catalyst.encoders +import java.sql.{Timestamp, Date} import java.util.Arrays import java.util.concurrent.ConcurrentMap +import scala.collection.mutable.ArrayBuffer +import scala.reflect.runtime.universe.TypeTag import com.google.common.collect.MapMaker import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.Encoders import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.catalyst.util.ArrayData +import org.apache.spark.sql.catalyst.{OptionalData, PrimitiveData} import org.apache.spark.sql.types.ArrayType -abstract class ExpressionEncoderSuite extends SparkFunSuite { - val outers: ConcurrentMap[String, AnyRef] = new MapMaker().weakValues().makeMap() +case class RepeatedStruct(s: Seq[PrimitiveData]) - protected def encodeDecodeTest[T]( +case class NestedArray(a: Array[Array[Int]]) { + override def equals(other: Any): Boolean = other match { + case NestedArray(otherArray) => + java.util.Arrays.deepEquals( + a.asInstanceOf[Array[AnyRef]], + otherArray.asInstanceOf[Array[AnyRef]]) + case _ => false + } +} + +case class BoxedData( + intField: java.lang.Integer, + longField: java.lang.Long, + doubleField: java.lang.Double, + floatField: java.lang.Float, + shortField: java.lang.Short, + byteField: java.lang.Byte, + booleanField: java.lang.Boolean) + +case class RepeatedData( + arrayField: Seq[Int], + arrayFieldContainsNull: Seq[java.lang.Integer], + mapField: scala.collection.Map[Int, Long], + mapFieldNull: scala.collection.Map[Int, java.lang.Long], + structField: PrimitiveData) + +case class SpecificCollection(l: List[Int]) + +/** For testing Kryo serialization based encoder. */ +class KryoSerializable(val value: Int) { + override def equals(other: Any): Boolean = { + this.value == other.asInstanceOf[KryoSerializable].value + } +} + +/** For testing Java serialization based encoder. 
*/ +class JavaSerializable(val value: Int) extends Serializable { + override def equals(other: Any): Boolean = { + this.value == other.asInstanceOf[JavaSerializable].value + } +} + +class ExpressionEncoderSuite extends SparkFunSuite { + implicit def encoder[T : TypeTag]: ExpressionEncoder[T] = ExpressionEncoder() + + // test flat encoders + encodeDecodeTest(false, "primitive boolean") + encodeDecodeTest(-3.toByte, "primitive byte") + encodeDecodeTest(-3.toShort, "primitive short") + encodeDecodeTest(-3, "primitive int") + encodeDecodeTest(-3L, "primitive long") + encodeDecodeTest(-3.7f, "primitive float") + encodeDecodeTest(-3.7, "primitive double") + + encodeDecodeTest(new java.lang.Boolean(false), "boxed boolean") + encodeDecodeTest(new java.lang.Byte(-3.toByte), "boxed byte") + encodeDecodeTest(new java.lang.Short(-3.toShort), "boxed short") + encodeDecodeTest(new java.lang.Integer(-3), "boxed int") + encodeDecodeTest(new java.lang.Long(-3L), "boxed long") + encodeDecodeTest(new java.lang.Float(-3.7f), "boxed float") + encodeDecodeTest(new java.lang.Double(-3.7), "boxed double") + + encodeDecodeTest(BigDecimal("32131413.211321313"), "scala decimal") + // encodeDecodeTest(new java.math.BigDecimal("231341.23123"), "java decimal") + + encodeDecodeTest("hello", "string") + encodeDecodeTest(Date.valueOf("2012-12-23"), "date") + encodeDecodeTest(Timestamp.valueOf("2016-01-29 10:00:00"), "timestamp") + encodeDecodeTest(Array[Byte](13, 21, -23), "binary") + + encodeDecodeTest(Seq(31, -123, 4), "seq of int") + encodeDecodeTest(Seq("abc", "xyz"), "seq of string") + encodeDecodeTest(Seq("abc", null, "xyz"), "seq of string with null") + encodeDecodeTest(Seq.empty[Int], "empty seq of int") + encodeDecodeTest(Seq.empty[String], "empty seq of string") + + encodeDecodeTest(Seq(Seq(31, -123), null, Seq(4, 67)), "seq of seq of int") + encodeDecodeTest(Seq(Seq("abc", "xyz"), Seq[String](null), null, Seq("1", null, "2")), + "seq of seq of string") + + encodeDecodeTest(Array(31, -123, 4), "array of int") + encodeDecodeTest(Array("abc", "xyz"), "array of string") + encodeDecodeTest(Array("a", null, "x"), "array of string with null") + encodeDecodeTest(Array.empty[Int], "empty array of int") + encodeDecodeTest(Array.empty[String], "empty array of string") + + encodeDecodeTest(Array(Array(31, -123), null, Array(4, 67)), "array of array of int") + encodeDecodeTest(Array(Array("abc", "xyz"), Array[String](null), null, Array("1", null, "2")), + "array of array of string") + + encodeDecodeTest(Map(1 -> "a", 2 -> "b"), "map") + encodeDecodeTest(Map(1 -> "a", 2 -> null), "map with null") + encodeDecodeTest(Map(1 -> Map("a" -> 1), 2 -> Map("b" -> 2)), "map of map") + + // Kryo encoders + encodeDecodeTest("hello", "kryo string")(encoderFor(Encoders.kryo[String])) + encodeDecodeTest(new KryoSerializable(15), "kryo object")( + encoderFor(Encoders.kryo[KryoSerializable])) + + // Java encoders + encodeDecodeTest("hello", "java string")(encoderFor(Encoders.javaSerialization[String])) + encodeDecodeTest(new JavaSerializable(15), "java object")( + encoderFor(Encoders.javaSerialization[JavaSerializable])) + + // test product encoders + private def productTest[T <: Product : ExpressionEncoder](input: T): Unit = { + encodeDecodeTest(input, input.getClass.getSimpleName) + } + + case class InnerClass(i: Int) + productTest(InnerClass(1)) + + productTest(PrimitiveData(1, 1, 1, 1, 1, 1, true)) + + productTest( + OptionalData(Some(2), Some(2), Some(2), Some(2), Some(2), Some(2), Some(true), + Some(PrimitiveData(1, 1, 1, 1, 1, 1, 
true)))) + + productTest(OptionalData(None, None, None, None, None, None, None, None)) + + productTest(BoxedData(1, 1L, 1.0, 1.0f, 1.toShort, 1.toByte, true)) + + productTest(BoxedData(null, null, null, null, null, null, null)) + + productTest(RepeatedStruct(PrimitiveData(1, 1, 1, 1, 1, 1, true) :: Nil)) + + productTest((1, "test", PrimitiveData(1, 1, 1, 1, 1, 1, true))) + + productTest( + RepeatedData( + Seq(1, 2), + Seq(new Integer(1), null, new Integer(2)), + Map(1 -> 2L), + Map(1 -> null), + PrimitiveData(1, 1, 1, 1, 1, 1, true))) + + productTest(NestedArray(Array(Array(1, -2, 3), null, Array(4, 5, -6)))) + + productTest(("Seq[(String, String)]", + Seq(("a", "b")))) + productTest(("Seq[(Int, Int)]", + Seq((1, 2)))) + productTest(("Seq[(Long, Long)]", + Seq((1L, 2L)))) + productTest(("Seq[(Float, Float)]", + Seq((1.toFloat, 2.toFloat)))) + productTest(("Seq[(Double, Double)]", + Seq((1.toDouble, 2.toDouble)))) + productTest(("Seq[(Short, Short)]", + Seq((1.toShort, 2.toShort)))) + productTest(("Seq[(Byte, Byte)]", + Seq((1.toByte, 2.toByte)))) + productTest(("Seq[(Boolean, Boolean)]", + Seq((true, false)))) + + productTest(("ArrayBuffer[(String, String)]", + ArrayBuffer(("a", "b")))) + productTest(("ArrayBuffer[(Int, Int)]", + ArrayBuffer((1, 2)))) + productTest(("ArrayBuffer[(Long, Long)]", + ArrayBuffer((1L, 2L)))) + productTest(("ArrayBuffer[(Float, Float)]", + ArrayBuffer((1.toFloat, 2.toFloat)))) + productTest(("ArrayBuffer[(Double, Double)]", + ArrayBuffer((1.toDouble, 2.toDouble)))) + productTest(("ArrayBuffer[(Short, Short)]", + ArrayBuffer((1.toShort, 2.toShort)))) + productTest(("ArrayBuffer[(Byte, Byte)]", + ArrayBuffer((1.toByte, 2.toByte)))) + productTest(("ArrayBuffer[(Boolean, Boolean)]", + ArrayBuffer((true, false)))) + + productTest(("Seq[Seq[(Int, Int)]]", + Seq(Seq((1, 2))))) + + // test for ExpressionEncoder.tuple + encodeDecodeTest( + 1 -> 10L, + "tuple with 2 flat encoders")( + ExpressionEncoder.tuple(ExpressionEncoder[Int], ExpressionEncoder[Long])) + + encodeDecodeTest( + (PrimitiveData(1, 1, 1, 1, 1, 1, true), (3, 30L)), + "tuple with 2 product encoders")( + ExpressionEncoder.tuple(ExpressionEncoder[PrimitiveData], ExpressionEncoder[(Int, Long)])) + + encodeDecodeTest( + (PrimitiveData(1, 1, 1, 1, 1, 1, true), 3), + "tuple with flat encoder and product encoder")( + ExpressionEncoder.tuple(ExpressionEncoder[PrimitiveData], ExpressionEncoder[Int])) + + encodeDecodeTest( + (3, PrimitiveData(1, 1, 1, 1, 1, 1, true)), + "tuple with product encoder and flat encoder")( + ExpressionEncoder.tuple(ExpressionEncoder[Int], ExpressionEncoder[PrimitiveData])) + + encodeDecodeTest( + (1, (10, 100L)), + "nested tuple encoder") { + val intEnc = ExpressionEncoder[Int] + val longEnc = ExpressionEncoder[Long] + ExpressionEncoder.tuple(intEnc, ExpressionEncoder.tuple(intEnc, longEnc)) + } + + private val outers: ConcurrentMap[String, AnyRef] = new MapMaker().weakValues().makeMap() + outers.put(getClass.getName, this) + private def encodeDecodeTest[T : ExpressionEncoder]( input: T, - encoder: ExpressionEncoder[T], testName: String): Unit = { test(s"encode/decode for $testName: $input") { + val encoder = implicitly[ExpressionEncoder[T]] val row = encoder.toRow(input) val schema = encoder.schema.toAttributes val boundEncoder = encoder.resolve(schema, outers).bind(schema) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoderSuite.scala deleted file mode 
100644 index 07523d49f426..000000000000 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoderSuite.scala +++ /dev/null @@ -1,99 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.encoders - -import java.sql.{Date, Timestamp} -import org.apache.spark.sql.Encoders - -class FlatEncoderSuite extends ExpressionEncoderSuite { - encodeDecodeTest(false, FlatEncoder[Boolean], "primitive boolean") - encodeDecodeTest(-3.toByte, FlatEncoder[Byte], "primitive byte") - encodeDecodeTest(-3.toShort, FlatEncoder[Short], "primitive short") - encodeDecodeTest(-3, FlatEncoder[Int], "primitive int") - encodeDecodeTest(-3L, FlatEncoder[Long], "primitive long") - encodeDecodeTest(-3.7f, FlatEncoder[Float], "primitive float") - encodeDecodeTest(-3.7, FlatEncoder[Double], "primitive double") - - encodeDecodeTest(new java.lang.Boolean(false), FlatEncoder[java.lang.Boolean], "boxed boolean") - encodeDecodeTest(new java.lang.Byte(-3.toByte), FlatEncoder[java.lang.Byte], "boxed byte") - encodeDecodeTest(new java.lang.Short(-3.toShort), FlatEncoder[java.lang.Short], "boxed short") - encodeDecodeTest(new java.lang.Integer(-3), FlatEncoder[java.lang.Integer], "boxed int") - encodeDecodeTest(new java.lang.Long(-3L), FlatEncoder[java.lang.Long], "boxed long") - encodeDecodeTest(new java.lang.Float(-3.7f), FlatEncoder[java.lang.Float], "boxed float") - encodeDecodeTest(new java.lang.Double(-3.7), FlatEncoder[java.lang.Double], "boxed double") - - encodeDecodeTest(BigDecimal("32131413.211321313"), FlatEncoder[BigDecimal], "scala decimal") - type JDecimal = java.math.BigDecimal - // encodeDecodeTest(new JDecimal("231341.23123"), FlatEncoder[JDecimal], "java decimal") - - encodeDecodeTest("hello", FlatEncoder[String], "string") - encodeDecodeTest(Date.valueOf("2012-12-23"), FlatEncoder[Date], "date") - encodeDecodeTest(Timestamp.valueOf("2016-01-29 10:00:00"), FlatEncoder[Timestamp], "timestamp") - encodeDecodeTest(Array[Byte](13, 21, -23), FlatEncoder[Array[Byte]], "binary") - - encodeDecodeTest(Seq(31, -123, 4), FlatEncoder[Seq[Int]], "seq of int") - encodeDecodeTest(Seq("abc", "xyz"), FlatEncoder[Seq[String]], "seq of string") - encodeDecodeTest(Seq("abc", null, "xyz"), FlatEncoder[Seq[String]], "seq of string with null") - encodeDecodeTest(Seq.empty[Int], FlatEncoder[Seq[Int]], "empty seq of int") - encodeDecodeTest(Seq.empty[String], FlatEncoder[Seq[String]], "empty seq of string") - - encodeDecodeTest(Seq(Seq(31, -123), null, Seq(4, 67)), - FlatEncoder[Seq[Seq[Int]]], "seq of seq of int") - encodeDecodeTest(Seq(Seq("abc", "xyz"), Seq[String](null), null, Seq("1", null, "2")), - FlatEncoder[Seq[Seq[String]]], "seq of seq of string") - - encodeDecodeTest(Array(31, -123, 4), FlatEncoder[Array[Int]], 
"array of int") - encodeDecodeTest(Array("abc", "xyz"), FlatEncoder[Array[String]], "array of string") - encodeDecodeTest(Array("a", null, "x"), FlatEncoder[Array[String]], "array of string with null") - encodeDecodeTest(Array.empty[Int], FlatEncoder[Array[Int]], "empty array of int") - encodeDecodeTest(Array.empty[String], FlatEncoder[Array[String]], "empty array of string") - - encodeDecodeTest(Array(Array(31, -123), null, Array(4, 67)), - FlatEncoder[Array[Array[Int]]], "array of array of int") - encodeDecodeTest(Array(Array("abc", "xyz"), Array[String](null), null, Array("1", null, "2")), - FlatEncoder[Array[Array[String]]], "array of array of string") - - encodeDecodeTest(Map(1 -> "a", 2 -> "b"), FlatEncoder[Map[Int, String]], "map") - encodeDecodeTest(Map(1 -> "a", 2 -> null), FlatEncoder[Map[Int, String]], "map with null") - encodeDecodeTest(Map(1 -> Map("a" -> 1), 2 -> Map("b" -> 2)), - FlatEncoder[Map[Int, Map[String, Int]]], "map of map") - - // Kryo encoders - encodeDecodeTest("hello", encoderFor(Encoders.kryo[String]), "kryo string") - encodeDecodeTest(new KryoSerializable(15), - encoderFor(Encoders.kryo[KryoSerializable]), "kryo object") - - // Java encoders - encodeDecodeTest("hello", encoderFor(Encoders.javaSerialization[String]), "java string") - encodeDecodeTest(new JavaSerializable(15), - encoderFor(Encoders.javaSerialization[JavaSerializable]), "java object") -} - -/** For testing Kryo serialization based encoder. */ -class KryoSerializable(val value: Int) { - override def equals(other: Any): Boolean = { - this.value == other.asInstanceOf[KryoSerializable].value - } -} - -/** For testing Java serialization based encoder. */ -class JavaSerializable(val value: Int) extends Serializable { - override def equals(other: Any): Boolean = { - this.value == other.asInstanceOf[JavaSerializable].value - } -} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala deleted file mode 100644 index 1798514c5c38..000000000000 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala +++ /dev/null @@ -1,156 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.catalyst.encoders - -import scala.collection.mutable.ArrayBuffer -import scala.reflect.runtime.universe.TypeTag - -import org.apache.spark.sql.catalyst.{OptionalData, PrimitiveData} - -case class RepeatedStruct(s: Seq[PrimitiveData]) - -case class NestedArray(a: Array[Array[Int]]) { - override def equals(other: Any): Boolean = other match { - case NestedArray(otherArray) => - java.util.Arrays.deepEquals( - a.asInstanceOf[Array[AnyRef]], - otherArray.asInstanceOf[Array[AnyRef]]) - case _ => false - } -} - -case class BoxedData( - intField: java.lang.Integer, - longField: java.lang.Long, - doubleField: java.lang.Double, - floatField: java.lang.Float, - shortField: java.lang.Short, - byteField: java.lang.Byte, - booleanField: java.lang.Boolean) - -case class RepeatedData( - arrayField: Seq[Int], - arrayFieldContainsNull: Seq[java.lang.Integer], - mapField: scala.collection.Map[Int, Long], - mapFieldNull: scala.collection.Map[Int, java.lang.Long], - structField: PrimitiveData) - -case class SpecificCollection(l: List[Int]) - -class ProductEncoderSuite extends ExpressionEncoderSuite { - outers.put(getClass.getName, this) - - case class InnerClass(i: Int) - productTest(InnerClass(1)) - - productTest(PrimitiveData(1, 1, 1, 1, 1, 1, true)) - - productTest( - OptionalData(Some(2), Some(2), Some(2), Some(2), Some(2), Some(2), Some(true), - Some(PrimitiveData(1, 1, 1, 1, 1, 1, true)))) - - productTest(OptionalData(None, None, None, None, None, None, None, None)) - - productTest(BoxedData(1, 1L, 1.0, 1.0f, 1.toShort, 1.toByte, true)) - - productTest(BoxedData(null, null, null, null, null, null, null)) - - productTest(RepeatedStruct(PrimitiveData(1, 1, 1, 1, 1, 1, true) :: Nil)) - - productTest((1, "test", PrimitiveData(1, 1, 1, 1, 1, 1, true))) - - productTest( - RepeatedData( - Seq(1, 2), - Seq(new Integer(1), null, new Integer(2)), - Map(1 -> 2L), - Map(1 -> null), - PrimitiveData(1, 1, 1, 1, 1, 1, true))) - - productTest(NestedArray(Array(Array(1, -2, 3), null, Array(4, 5, -6)))) - - productTest(("Seq[(String, String)]", - Seq(("a", "b")))) - productTest(("Seq[(Int, Int)]", - Seq((1, 2)))) - productTest(("Seq[(Long, Long)]", - Seq((1L, 2L)))) - productTest(("Seq[(Float, Float)]", - Seq((1.toFloat, 2.toFloat)))) - productTest(("Seq[(Double, Double)]", - Seq((1.toDouble, 2.toDouble)))) - productTest(("Seq[(Short, Short)]", - Seq((1.toShort, 2.toShort)))) - productTest(("Seq[(Byte, Byte)]", - Seq((1.toByte, 2.toByte)))) - productTest(("Seq[(Boolean, Boolean)]", - Seq((true, false)))) - - productTest(("ArrayBuffer[(String, String)]", - ArrayBuffer(("a", "b")))) - productTest(("ArrayBuffer[(Int, Int)]", - ArrayBuffer((1, 2)))) - productTest(("ArrayBuffer[(Long, Long)]", - ArrayBuffer((1L, 2L)))) - productTest(("ArrayBuffer[(Float, Float)]", - ArrayBuffer((1.toFloat, 2.toFloat)))) - productTest(("ArrayBuffer[(Double, Double)]", - ArrayBuffer((1.toDouble, 2.toDouble)))) - productTest(("ArrayBuffer[(Short, Short)]", - ArrayBuffer((1.toShort, 2.toShort)))) - productTest(("ArrayBuffer[(Byte, Byte)]", - ArrayBuffer((1.toByte, 2.toByte)))) - productTest(("ArrayBuffer[(Boolean, Boolean)]", - ArrayBuffer((true, false)))) - - productTest(("Seq[Seq[(Int, Int)]]", - Seq(Seq((1, 2))))) - - encodeDecodeTest( - 1 -> 10L, - ExpressionEncoder.tuple(FlatEncoder[Int], FlatEncoder[Long]), - "tuple with 2 flat encoders") - - encodeDecodeTest( - (PrimitiveData(1, 1, 1, 1, 1, 1, true), (3, 30L)), - ExpressionEncoder.tuple(ProductEncoder[PrimitiveData], ProductEncoder[(Int, Long)]), - 
"tuple with 2 product encoders") - - encodeDecodeTest( - (PrimitiveData(1, 1, 1, 1, 1, 1, true), 3), - ExpressionEncoder.tuple(ProductEncoder[PrimitiveData], FlatEncoder[Int]), - "tuple with flat encoder and product encoder") - - encodeDecodeTest( - (3, PrimitiveData(1, 1, 1, 1, 1, 1, true)), - ExpressionEncoder.tuple(FlatEncoder[Int], ProductEncoder[PrimitiveData]), - "tuple with product encoder and flat encoder") - - encodeDecodeTest( - (1, (10, 100L)), - { - val intEnc = FlatEncoder[Int] - val longEnc = FlatEncoder[Long] - ExpressionEncoder.tuple(intEnc, ExpressionEncoder.tuple(intEnc, longEnc)) - }, - "nested tuple encoder") - - private def productTest[T <: Product : TypeTag](input: T): Unit = { - encodeDecodeTest(input, ProductEncoder[T], input.getClass.getSimpleName) - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala index 7e5acbe8517d..6de3dd626576 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala @@ -21,7 +21,7 @@ import scala.collection.JavaConverters._ import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.function._ -import org.apache.spark.sql.catalyst.encoders.{FlatEncoder, ExpressionEncoder, encoderFor, OuterScopes} +import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, encoderFor, OuterScopes} import org.apache.spark.sql.catalyst.expressions.{Alias, CreateStruct, Attribute} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.QueryExecution @@ -242,7 +242,7 @@ class GroupedDataset[K, T] private[sql]( * Returns a [[Dataset]] that contains a tuple with each key and the number of items present * for that key. */ - def count(): Dataset[(K, Long)] = agg(functions.count("*").as(FlatEncoder[Long])) + def count(): Dataset[(K, Long)] = agg(functions.count("*").as(ExpressionEncoder[Long])) /** * Applies the given function to each cogrouped data. 
For each unique group, the function will diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala index 8471eea1b7d9..25ffdcde1771 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala @@ -17,10 +17,6 @@ package org.apache.spark.sql -import org.apache.spark.sql.catalyst.encoders._ -import org.apache.spark.sql.catalyst.plans.logical.LocalRelation -import org.apache.spark.sql.execution.datasources.LogicalRelation - import scala.language.implicitConversions import scala.reflect.runtime.universe.TypeTag @@ -28,6 +24,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.types._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.SpecificMutableRow +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.types.StructField import org.apache.spark.unsafe.types.UTF8String @@ -37,16 +34,16 @@ import org.apache.spark.unsafe.types.UTF8String abstract class SQLImplicits { protected def _sqlContext: SQLContext - implicit def newProductEncoder[T <: Product : TypeTag]: Encoder[T] = ProductEncoder[T] + implicit def newProductEncoder[T <: Product : TypeTag]: Encoder[T] = ExpressionEncoder() - implicit def newIntEncoder: Encoder[Int] = FlatEncoder[Int] - implicit def newLongEncoder: Encoder[Long] = FlatEncoder[Long] - implicit def newDoubleEncoder: Encoder[Double] = FlatEncoder[Double] - implicit def newFloatEncoder: Encoder[Float] = FlatEncoder[Float] - implicit def newByteEncoder: Encoder[Byte] = FlatEncoder[Byte] - implicit def newShortEncoder: Encoder[Short] = FlatEncoder[Short] - implicit def newBooleanEncoder: Encoder[Boolean] = FlatEncoder[Boolean] - implicit def newStringEncoder: Encoder[String] = FlatEncoder[String] + implicit def newIntEncoder: Encoder[Int] = ExpressionEncoder() + implicit def newLongEncoder: Encoder[Long] = ExpressionEncoder() + implicit def newDoubleEncoder: Encoder[Double] = ExpressionEncoder() + implicit def newFloatEncoder: Encoder[Float] = ExpressionEncoder() + implicit def newByteEncoder: Encoder[Byte] = ExpressionEncoder() + implicit def newShortEncoder: Encoder[Short] = ExpressionEncoder() + implicit def newBooleanEncoder: Encoder[Boolean] = ExpressionEncoder() + implicit def newStringEncoder: Encoder[String] = ExpressionEncoder() /** * Creates a [[Dataset]] from an RDD. 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 95158de710ac..b27b1340cce4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -26,7 +26,7 @@ import scala.util.Try import org.apache.spark.annotation.Experimental import org.apache.spark.sql.catalyst.{SqlParser, ScalaReflection} import org.apache.spark.sql.catalyst.analysis.{UnresolvedFunction, Star} -import org.apache.spark.sql.catalyst.encoders.FlatEncoder +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical.BroadcastHint @@ -267,7 +267,7 @@ object functions extends LegacyFunctions { * @since 1.3.0 */ def count(columnName: String): TypedColumn[Any, Long] = - count(Column(columnName)).as(FlatEncoder[Long]) + count(Column(columnName)).as(ExpressionEncoder[Long]) /** * Aggregate function: returns the number of distinct items in a group. From 4700074530d9a398843e13f0ef514be97a237cea Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Thu, 19 Nov 2015 13:08:01 -0800 Subject: [PATCH 1454/1824] [SPARK-11778][SQL] parse table name before it is passed to lookupRelation Fix a bug in DataFrameReader.table (table with schema name such as "db_name.table" doesn't work) Use SqlParser.parseTableIdentifier to parse the table name before lookupRelation. Author: Huaxin Gao Closes #9773 from huaxingao/spark-11778. --- .../scala/org/apache/spark/sql/DataFrameReader.scala | 3 ++- .../spark/sql/hive/HiveDataFrameAnalyticsSuite.scala | 10 ++++++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 5872fbded383..dcb3737b70fb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -313,7 +313,8 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging { * @since 1.4.0 */ def table(tableName: String): DataFrame = { - DataFrame(sqlContext, sqlContext.catalog.lookupRelation(TableIdentifier(tableName))) + DataFrame(sqlContext, + sqlContext.catalog.lookupRelation(SqlParser.parseTableIdentifier(tableName))) } /** diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala index 9864acf76526..f19a74d4b372 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala @@ -34,10 +34,14 @@ class HiveDataFrameAnalyticsSuite extends QueryTest with TestHiveSingleton with override def beforeAll() { testData = Seq((1, 2), (2, 2), (3, 4)).toDF("a", "b") hiveContext.registerDataFrameAsTable(testData, "mytable") + hiveContext.sql("create schema usrdb") + hiveContext.sql("create table usrdb.test(c1 int)") } override def afterAll(): Unit = { hiveContext.dropTempTable("mytable") + hiveContext.sql("drop table usrdb.test") + hiveContext.sql("drop schema usrdb") } test("rollup") { @@ -74,4 +78,10 @@ class HiveDataFrameAnalyticsSuite extends QueryTest with TestHiveSingleton with sql("select a, b, sum(b) from mytable group by 
a, b with cube").collect() ) } + + // There was a bug in DataFrameFrameReader.table and it has problem for table with schema name, + // Before fix, it throw Exceptionorg.apache.spark.sql.catalyst.analysis.NoSuchTableException + test("table name with schema") { + hiveContext.read.table("usrdb.test") + } } From 599a8c6e2bf7da70b20ef3046f5ce099dfd637f8 Mon Sep 17 00:00:00 2001 From: David Tolpin Date: Thu, 19 Nov 2015 13:57:23 -0800 Subject: [PATCH 1455/1824] [SPARK-11812][PYSPARK] invFunc=None works properly with python's reduceByKeyAndWindow invFunc is optional and can be None. Instead of invFunc (the parameter) invReduceFunc (a local function) was checked for trueness (that is, not None, in this context). A local function is never None, thus the case of invFunc=None (a common one when inverse reduction is not defined) was treated incorrectly, resulting in loss of data. In addition, the docstring used wrong parameter names, also fixed. Author: David Tolpin Closes #9775 from dtolpin/master. --- python/pyspark/streaming/dstream.py | 6 +++--- python/pyspark/streaming/tests.py | 11 +++++++++++ 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/python/pyspark/streaming/dstream.py b/python/pyspark/streaming/dstream.py index 698336cfce18..acec850f02c2 100644 --- a/python/pyspark/streaming/dstream.py +++ b/python/pyspark/streaming/dstream.py @@ -524,8 +524,8 @@ def reduceByKeyAndWindow(self, func, invFunc, windowDuration, slideDuration=None `invFunc` can be None, then it will reduce all the RDDs in window, could be slower than having `invFunc`. - @param reduceFunc: associative reduce function - @param invReduceFunc: inverse function of `reduceFunc` + @param func: associative reduce function + @param invFunc: inverse function of `reduceFunc` @param windowDuration: width of the window; must be a multiple of this DStream's batching interval @param slideDuration: sliding interval of the window (i.e., the interval after which @@ -556,7 +556,7 @@ def invReduceFunc(t, a, b): if kv[1] is not None else kv[0]) jreduceFunc = TransformFunction(self._sc, reduceFunc, reduced._jrdd_deserializer) - if invReduceFunc: + if invFunc: jinvReduceFunc = TransformFunction(self._sc, invReduceFunc, reduced._jrdd_deserializer) else: jinvReduceFunc = None diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 0bcd1f15532b..3403f6d20d78 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -582,6 +582,17 @@ def test_reduce_by_invalid_window(self): self.assertRaises(ValueError, lambda: d1.reduceByKeyAndWindow(None, None, 0.1, 0.1)) self.assertRaises(ValueError, lambda: d1.reduceByKeyAndWindow(None, None, 1, 0.1)) + def test_reduce_by_key_and_window_with_none_invFunc(self): + input = [range(1), range(2), range(3), range(4), range(5), range(6)] + + def func(dstream): + return dstream.map(lambda x: (x, 1))\ + .reduceByKeyAndWindow(operator.add, None, 5, 1)\ + .filter(lambda kv: kv[1] > 0).count() + + expected = [[2], [4], [6], [6], [6], [6]] + self._test_func(input, func, expected) + class StreamingContextTests(PySparkStreamingTestCase): From 014c0f7a9dfdb1686fa9aeacaadb2a17a855a943 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 19 Nov 2015 14:48:18 -0800 Subject: [PATCH 1456/1824] [SPARK-11858][SQL] Move sql.columnar into sql.execution. In addition, tightened visibility of a lot of classes in the columnar package from private[sql] to private[columnar]. Author: Reynold Xin Closes #9842 from rxin/SPARK-11858. 
--- .../spark/sql/execution/CacheManager.scala | 2 +- .../spark/sql/execution/SparkStrategies.scala | 2 +- .../columnar/ColumnAccessor.scala | 42 +++++++-------- .../columnar/ColumnBuilder.scala | 51 ++++++++++--------- .../columnar/ColumnStats.scala | 34 ++++++------- .../{ => execution}/columnar/ColumnType.scala | 48 ++++++++--------- .../columnar/GenerateColumnAccessor.scala | 4 +- .../columnar/InMemoryColumnarTableScan.scala | 5 +- .../columnar/NullableColumnAccessor.scala | 4 +- .../columnar/NullableColumnBuilder.scala | 4 +- .../CompressibleColumnAccessor.scala | 6 +-- .../CompressibleColumnBuilder.scala | 6 +-- .../compression/CompressionScheme.scala | 16 +++--- .../compression/compressionSchemes.scala | 16 +++--- .../apache/spark/sql/execution/package.scala | 2 + .../apache/spark/sql/CachedTableSuite.scala | 4 +- .../org/apache/spark/sql/QueryTest.scala | 2 +- .../columnar/ColumnStatsSuite.scala | 6 +-- .../columnar/ColumnTypeSuite.scala | 4 +- .../columnar/ColumnarTestUtils.scala | 2 +- .../columnar/InMemoryColumnarQuerySuite.scala | 2 +- .../NullableColumnAccessorSuite.scala | 4 +- .../columnar/NullableColumnBuilderSuite.scala | 4 +- .../columnar/PartitionBatchPruningSuite.scala | 2 +- .../compression/BooleanBitSetSuite.scala | 6 +-- .../compression/DictionaryEncodingSuite.scala | 6 +-- .../compression/IntegralDeltaSuite.scala | 6 +-- .../compression/RunLengthEncodingSuite.scala | 6 +-- .../TestCompressibleColumnBuilder.scala | 4 +- .../spark/sql/hive/CachedTableSuite.scala | 2 +- 30 files changed, 155 insertions(+), 147 deletions(-) rename sql/core/src/main/scala/org/apache/spark/sql/{ => execution}/columnar/ColumnAccessor.scala (75%) rename sql/core/src/main/scala/org/apache/spark/sql/{ => execution}/columnar/ColumnBuilder.scala (74%) rename sql/core/src/main/scala/org/apache/spark/sql/{ => execution}/columnar/ColumnStats.scala (88%) rename sql/core/src/main/scala/org/apache/spark/sql/{ => execution}/columnar/ColumnType.scala (93%) rename sql/core/src/main/scala/org/apache/spark/sql/{ => execution}/columnar/GenerateColumnAccessor.scala (98%) rename sql/core/src/main/scala/org/apache/spark/sql/{ => execution}/columnar/InMemoryColumnarTableScan.scala (98%) rename sql/core/src/main/scala/org/apache/spark/sql/{ => execution}/columnar/NullableColumnAccessor.scala (94%) rename sql/core/src/main/scala/org/apache/spark/sql/{ => execution}/columnar/NullableColumnBuilder.scala (95%) rename sql/core/src/main/scala/org/apache/spark/sql/{ => execution}/columnar/compression/CompressibleColumnAccessor.scala (84%) rename sql/core/src/main/scala/org/apache/spark/sql/{ => execution}/columnar/compression/CompressibleColumnBuilder.scala (94%) rename sql/core/src/main/scala/org/apache/spark/sql/{ => execution}/columnar/compression/CompressionScheme.scala (83%) rename sql/core/src/main/scala/org/apache/spark/sql/{ => execution}/columnar/compression/compressionSchemes.scala (96%) rename sql/core/src/test/scala/org/apache/spark/sql/{ => execution}/columnar/ColumnStatsSuite.scala (96%) rename sql/core/src/test/scala/org/apache/spark/sql/{ => execution}/columnar/ColumnTypeSuite.scala (97%) rename sql/core/src/test/scala/org/apache/spark/sql/{ => execution}/columnar/ColumnarTestUtils.scala (98%) rename sql/core/src/test/scala/org/apache/spark/sql/{ => execution}/columnar/InMemoryColumnarQuerySuite.scala (99%) rename sql/core/src/test/scala/org/apache/spark/sql/{ => execution}/columnar/NullableColumnAccessorSuite.scala (96%) rename sql/core/src/test/scala/org/apache/spark/sql/{ => 
execution}/columnar/NullableColumnBuilderSuite.scala (96%) rename sql/core/src/test/scala/org/apache/spark/sql/{ => execution}/columnar/PartitionBatchPruningSuite.scala (99%) rename sql/core/src/test/scala/org/apache/spark/sql/{ => execution}/columnar/compression/BooleanBitSetSuite.scala (94%) rename sql/core/src/test/scala/org/apache/spark/sql/{ => execution}/columnar/compression/DictionaryEncodingSuite.scala (96%) rename sql/core/src/test/scala/org/apache/spark/sql/{ => execution}/columnar/compression/IntegralDeltaSuite.scala (96%) rename sql/core/src/test/scala/org/apache/spark/sql/{ => execution}/columnar/compression/RunLengthEncodingSuite.scala (95%) rename sql/core/src/test/scala/org/apache/spark/sql/{ => execution}/columnar/compression/TestCompressibleColumnBuilder.scala (93%) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala index f85aeb1b0269..293fcfe96e67 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala @@ -22,7 +22,7 @@ import java.util.concurrent.locks.ReentrantReadWriteLock import org.apache.spark.Logging import org.apache.spark.sql.DataFrame import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.columnar.InMemoryRelation +import org.apache.spark.sql.execution.columnar.InMemoryRelation import org.apache.spark.storage.StorageLevel import org.apache.spark.storage.StorageLevel.MEMORY_AND_DISK diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 3d4ce633c07c..f67c951bc066 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.{BroadcastHint, LogicalPlan} import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.columnar.{InMemoryColumnarTableScan, InMemoryRelation} +import org.apache.spark.sql.execution.columnar.{InMemoryColumnarTableScan, InMemoryRelation} import org.apache.spark.sql.execution.datasources.{CreateTableUsing, CreateTempTableUsing, DescribeCommand => LogicalDescribeCommand, _} import org.apache.spark.sql.execution.{DescribeCommand => RunnableDescribeCommand} import org.apache.spark.sql.{Strategy, execution} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala similarity index 75% rename from sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala index 42ec4d3433f1..fee36f602389 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala @@ -15,12 +15,12 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.columnar +package org.apache.spark.sql.execution.columnar import java.nio.{ByteBuffer, ByteOrder} import org.apache.spark.sql.catalyst.expressions.{MutableRow, UnsafeArrayData, UnsafeMapData, UnsafeRow} -import org.apache.spark.sql.columnar.compression.CompressibleColumnAccessor +import org.apache.spark.sql.execution.columnar.compression.CompressibleColumnAccessor import org.apache.spark.sql.types._ /** @@ -29,7 +29,7 @@ import org.apache.spark.sql.types._ * a [[MutableRow]]. In this way, boxing cost can be avoided by leveraging the setter methods * for primitive values provided by [[MutableRow]]. */ -private[sql] trait ColumnAccessor { +private[columnar] trait ColumnAccessor { initialize() protected def initialize() @@ -41,7 +41,7 @@ private[sql] trait ColumnAccessor { protected def underlyingBuffer: ByteBuffer } -private[sql] abstract class BasicColumnAccessor[JvmType]( +private[columnar] abstract class BasicColumnAccessor[JvmType]( protected val buffer: ByteBuffer, protected val columnType: ColumnType[JvmType]) extends ColumnAccessor { @@ -61,65 +61,65 @@ private[sql] abstract class BasicColumnAccessor[JvmType]( protected def underlyingBuffer = buffer } -private[sql] class NullColumnAccessor(buffer: ByteBuffer) +private[columnar] class NullColumnAccessor(buffer: ByteBuffer) extends BasicColumnAccessor[Any](buffer, NULL) with NullableColumnAccessor -private[sql] abstract class NativeColumnAccessor[T <: AtomicType]( +private[columnar] abstract class NativeColumnAccessor[T <: AtomicType]( override protected val buffer: ByteBuffer, override protected val columnType: NativeColumnType[T]) extends BasicColumnAccessor(buffer, columnType) with NullableColumnAccessor with CompressibleColumnAccessor[T] -private[sql] class BooleanColumnAccessor(buffer: ByteBuffer) +private[columnar] class BooleanColumnAccessor(buffer: ByteBuffer) extends NativeColumnAccessor(buffer, BOOLEAN) -private[sql] class ByteColumnAccessor(buffer: ByteBuffer) +private[columnar] class ByteColumnAccessor(buffer: ByteBuffer) extends NativeColumnAccessor(buffer, BYTE) -private[sql] class ShortColumnAccessor(buffer: ByteBuffer) +private[columnar] class ShortColumnAccessor(buffer: ByteBuffer) extends NativeColumnAccessor(buffer, SHORT) -private[sql] class IntColumnAccessor(buffer: ByteBuffer) +private[columnar] class IntColumnAccessor(buffer: ByteBuffer) extends NativeColumnAccessor(buffer, INT) -private[sql] class LongColumnAccessor(buffer: ByteBuffer) +private[columnar] class LongColumnAccessor(buffer: ByteBuffer) extends NativeColumnAccessor(buffer, LONG) -private[sql] class FloatColumnAccessor(buffer: ByteBuffer) +private[columnar] class FloatColumnAccessor(buffer: ByteBuffer) extends NativeColumnAccessor(buffer, FLOAT) -private[sql] class DoubleColumnAccessor(buffer: ByteBuffer) +private[columnar] class DoubleColumnAccessor(buffer: ByteBuffer) extends NativeColumnAccessor(buffer, DOUBLE) -private[sql] class StringColumnAccessor(buffer: ByteBuffer) +private[columnar] class StringColumnAccessor(buffer: ByteBuffer) extends NativeColumnAccessor(buffer, STRING) -private[sql] class BinaryColumnAccessor(buffer: ByteBuffer) +private[columnar] class BinaryColumnAccessor(buffer: ByteBuffer) extends BasicColumnAccessor[Array[Byte]](buffer, BINARY) with NullableColumnAccessor -private[sql] class CompactDecimalColumnAccessor(buffer: ByteBuffer, dataType: DecimalType) +private[columnar] class CompactDecimalColumnAccessor(buffer: ByteBuffer, dataType: DecimalType) extends NativeColumnAccessor(buffer, 
COMPACT_DECIMAL(dataType)) -private[sql] class DecimalColumnAccessor(buffer: ByteBuffer, dataType: DecimalType) +private[columnar] class DecimalColumnAccessor(buffer: ByteBuffer, dataType: DecimalType) extends BasicColumnAccessor[Decimal](buffer, LARGE_DECIMAL(dataType)) with NullableColumnAccessor -private[sql] class StructColumnAccessor(buffer: ByteBuffer, dataType: StructType) +private[columnar] class StructColumnAccessor(buffer: ByteBuffer, dataType: StructType) extends BasicColumnAccessor[UnsafeRow](buffer, STRUCT(dataType)) with NullableColumnAccessor -private[sql] class ArrayColumnAccessor(buffer: ByteBuffer, dataType: ArrayType) +private[columnar] class ArrayColumnAccessor(buffer: ByteBuffer, dataType: ArrayType) extends BasicColumnAccessor[UnsafeArrayData](buffer, ARRAY(dataType)) with NullableColumnAccessor -private[sql] class MapColumnAccessor(buffer: ByteBuffer, dataType: MapType) +private[columnar] class MapColumnAccessor(buffer: ByteBuffer, dataType: MapType) extends BasicColumnAccessor[UnsafeMapData](buffer, MAP(dataType)) with NullableColumnAccessor -private[sql] object ColumnAccessor { +private[columnar] object ColumnAccessor { def apply(dataType: DataType, buffer: ByteBuffer): ColumnAccessor = { val buf = buffer.order(ByteOrder.nativeOrder) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnBuilder.scala similarity index 74% rename from sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnBuilder.scala index 599f30f2d73b..7e26f19bb744 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnBuilder.scala @@ -15,16 +15,16 @@ * limitations under the License. */ -package org.apache.spark.sql.columnar +package org.apache.spark.sql.execution.columnar import java.nio.{ByteBuffer, ByteOrder} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.columnar.ColumnBuilder._ -import org.apache.spark.sql.columnar.compression.{AllCompressionSchemes, CompressibleColumnBuilder} +import org.apache.spark.sql.execution.columnar.ColumnBuilder._ +import org.apache.spark.sql.execution.columnar.compression.{AllCompressionSchemes, CompressibleColumnBuilder} import org.apache.spark.sql.types._ -private[sql] trait ColumnBuilder { +private[columnar] trait ColumnBuilder { /** * Initializes with an approximate lower bound on the expected number of elements in this column. 
*/ @@ -46,7 +46,7 @@ private[sql] trait ColumnBuilder { def build(): ByteBuffer } -private[sql] class BasicColumnBuilder[JvmType]( +private[columnar] class BasicColumnBuilder[JvmType]( val columnStats: ColumnStats, val columnType: ColumnType[JvmType]) extends ColumnBuilder { @@ -84,17 +84,17 @@ private[sql] class BasicColumnBuilder[JvmType]( } } -private[sql] class NullColumnBuilder +private[columnar] class NullColumnBuilder extends BasicColumnBuilder[Any](new ObjectColumnStats(NullType), NULL) with NullableColumnBuilder -private[sql] abstract class ComplexColumnBuilder[JvmType]( +private[columnar] abstract class ComplexColumnBuilder[JvmType]( columnStats: ColumnStats, columnType: ColumnType[JvmType]) extends BasicColumnBuilder[JvmType](columnStats, columnType) with NullableColumnBuilder -private[sql] abstract class NativeColumnBuilder[T <: AtomicType]( +private[columnar] abstract class NativeColumnBuilder[T <: AtomicType]( override val columnStats: ColumnStats, override val columnType: NativeColumnType[T]) extends BasicColumnBuilder[T#InternalType](columnStats, columnType) @@ -102,40 +102,45 @@ private[sql] abstract class NativeColumnBuilder[T <: AtomicType]( with AllCompressionSchemes with CompressibleColumnBuilder[T] -private[sql] class BooleanColumnBuilder extends NativeColumnBuilder(new BooleanColumnStats, BOOLEAN) +private[columnar] +class BooleanColumnBuilder extends NativeColumnBuilder(new BooleanColumnStats, BOOLEAN) -private[sql] class ByteColumnBuilder extends NativeColumnBuilder(new ByteColumnStats, BYTE) +private[columnar] +class ByteColumnBuilder extends NativeColumnBuilder(new ByteColumnStats, BYTE) -private[sql] class ShortColumnBuilder extends NativeColumnBuilder(new ShortColumnStats, SHORT) +private[columnar] class ShortColumnBuilder extends NativeColumnBuilder(new ShortColumnStats, SHORT) -private[sql] class IntColumnBuilder extends NativeColumnBuilder(new IntColumnStats, INT) +private[columnar] class IntColumnBuilder extends NativeColumnBuilder(new IntColumnStats, INT) -private[sql] class LongColumnBuilder extends NativeColumnBuilder(new LongColumnStats, LONG) +private[columnar] class LongColumnBuilder extends NativeColumnBuilder(new LongColumnStats, LONG) -private[sql] class FloatColumnBuilder extends NativeColumnBuilder(new FloatColumnStats, FLOAT) +private[columnar] class FloatColumnBuilder extends NativeColumnBuilder(new FloatColumnStats, FLOAT) -private[sql] class DoubleColumnBuilder extends NativeColumnBuilder(new DoubleColumnStats, DOUBLE) +private[columnar] +class DoubleColumnBuilder extends NativeColumnBuilder(new DoubleColumnStats, DOUBLE) -private[sql] class StringColumnBuilder extends NativeColumnBuilder(new StringColumnStats, STRING) +private[columnar] +class StringColumnBuilder extends NativeColumnBuilder(new StringColumnStats, STRING) -private[sql] class BinaryColumnBuilder extends ComplexColumnBuilder(new BinaryColumnStats, BINARY) +private[columnar] +class BinaryColumnBuilder extends ComplexColumnBuilder(new BinaryColumnStats, BINARY) -private[sql] class CompactDecimalColumnBuilder(dataType: DecimalType) +private[columnar] class CompactDecimalColumnBuilder(dataType: DecimalType) extends NativeColumnBuilder(new DecimalColumnStats(dataType), COMPACT_DECIMAL(dataType)) -private[sql] class DecimalColumnBuilder(dataType: DecimalType) +private[columnar] class DecimalColumnBuilder(dataType: DecimalType) extends ComplexColumnBuilder(new DecimalColumnStats(dataType), LARGE_DECIMAL(dataType)) -private[sql] class StructColumnBuilder(dataType: StructType) 
+private[columnar] class StructColumnBuilder(dataType: StructType) extends ComplexColumnBuilder(new ObjectColumnStats(dataType), STRUCT(dataType)) -private[sql] class ArrayColumnBuilder(dataType: ArrayType) +private[columnar] class ArrayColumnBuilder(dataType: ArrayType) extends ComplexColumnBuilder(new ObjectColumnStats(dataType), ARRAY(dataType)) -private[sql] class MapColumnBuilder(dataType: MapType) +private[columnar] class MapColumnBuilder(dataType: MapType) extends ComplexColumnBuilder(new ObjectColumnStats(dataType), MAP(dataType)) -private[sql] object ColumnBuilder { +private[columnar] object ColumnBuilder { val DEFAULT_INITIAL_BUFFER_SIZE = 128 * 1024 val MAX_BATCH_SIZE_IN_BYTE = 4 * 1024 * 1024L diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala similarity index 88% rename from sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala index 91a05650585c..c52ee9ffd6d2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala @@ -15,14 +15,14 @@ * limitations under the License. */ -package org.apache.spark.sql.columnar +package org.apache.spark.sql.execution.columnar import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, Attribute, AttributeMap, AttributeReference} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String -private[sql] class ColumnStatisticsSchema(a: Attribute) extends Serializable { +private[columnar] class ColumnStatisticsSchema(a: Attribute) extends Serializable { val upperBound = AttributeReference(a.name + ".upperBound", a.dataType, nullable = true)() val lowerBound = AttributeReference(a.name + ".lowerBound", a.dataType, nullable = true)() val nullCount = AttributeReference(a.name + ".nullCount", IntegerType, nullable = false)() @@ -32,7 +32,7 @@ private[sql] class ColumnStatisticsSchema(a: Attribute) extends Serializable { val schema = Seq(lowerBound, upperBound, nullCount, count, sizeInBytes) } -private[sql] class PartitionStatistics(tableSchema: Seq[Attribute]) extends Serializable { +private[columnar] class PartitionStatistics(tableSchema: Seq[Attribute]) extends Serializable { val (forAttribute, schema) = { val allStats = tableSchema.map(a => a -> new ColumnStatisticsSchema(a)) (AttributeMap(allStats), allStats.map(_._2.schema).foldLeft(Seq.empty[Attribute])(_ ++ _)) @@ -45,10 +45,10 @@ private[sql] class PartitionStatistics(tableSchema: Seq[Attribute]) extends Seri * NOTE: we intentionally avoid using `Ordering[T]` to compare values here because `Ordering[T]` * brings significant performance penalty. */ -private[sql] sealed trait ColumnStats extends Serializable { +private[columnar] sealed trait ColumnStats extends Serializable { protected var count = 0 protected var nullCount = 0 - private[sql] var sizeInBytes = 0L + private[columnar] var sizeInBytes = 0L /** * Gathers statistics information from `row(ordinal)`. @@ -72,14 +72,14 @@ private[sql] sealed trait ColumnStats extends Serializable { /** * A no-op ColumnStats only used for testing purposes. 
*/ -private[sql] class NoopColumnStats extends ColumnStats { +private[columnar] class NoopColumnStats extends ColumnStats { override def gatherStats(row: InternalRow, ordinal: Int): Unit = super.gatherStats(row, ordinal) override def collectedStatistics: GenericInternalRow = new GenericInternalRow(Array[Any](null, null, nullCount, count, 0L)) } -private[sql] class BooleanColumnStats extends ColumnStats { +private[columnar] class BooleanColumnStats extends ColumnStats { protected var upper = false protected var lower = true @@ -97,7 +97,7 @@ private[sql] class BooleanColumnStats extends ColumnStats { new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) } -private[sql] class ByteColumnStats extends ColumnStats { +private[columnar] class ByteColumnStats extends ColumnStats { protected var upper = Byte.MinValue protected var lower = Byte.MaxValue @@ -115,7 +115,7 @@ private[sql] class ByteColumnStats extends ColumnStats { new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) } -private[sql] class ShortColumnStats extends ColumnStats { +private[columnar] class ShortColumnStats extends ColumnStats { protected var upper = Short.MinValue protected var lower = Short.MaxValue @@ -133,7 +133,7 @@ private[sql] class ShortColumnStats extends ColumnStats { new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) } -private[sql] class IntColumnStats extends ColumnStats { +private[columnar] class IntColumnStats extends ColumnStats { protected var upper = Int.MinValue protected var lower = Int.MaxValue @@ -151,7 +151,7 @@ private[sql] class IntColumnStats extends ColumnStats { new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) } -private[sql] class LongColumnStats extends ColumnStats { +private[columnar] class LongColumnStats extends ColumnStats { protected var upper = Long.MinValue protected var lower = Long.MaxValue @@ -169,7 +169,7 @@ private[sql] class LongColumnStats extends ColumnStats { new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) } -private[sql] class FloatColumnStats extends ColumnStats { +private[columnar] class FloatColumnStats extends ColumnStats { protected var upper = Float.MinValue protected var lower = Float.MaxValue @@ -187,7 +187,7 @@ private[sql] class FloatColumnStats extends ColumnStats { new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) } -private[sql] class DoubleColumnStats extends ColumnStats { +private[columnar] class DoubleColumnStats extends ColumnStats { protected var upper = Double.MinValue protected var lower = Double.MaxValue @@ -205,7 +205,7 @@ private[sql] class DoubleColumnStats extends ColumnStats { new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) } -private[sql] class StringColumnStats extends ColumnStats { +private[columnar] class StringColumnStats extends ColumnStats { protected var upper: UTF8String = null protected var lower: UTF8String = null @@ -223,7 +223,7 @@ private[sql] class StringColumnStats extends ColumnStats { new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) } -private[sql] class BinaryColumnStats extends ColumnStats { +private[columnar] class BinaryColumnStats extends ColumnStats { override def gatherStats(row: InternalRow, ordinal: Int): Unit = { super.gatherStats(row, ordinal) if (!row.isNullAt(ordinal)) { @@ -235,7 +235,7 @@ private[sql] class BinaryColumnStats extends ColumnStats { new GenericInternalRow(Array[Any](null, 
null, nullCount, count, sizeInBytes)) } -private[sql] class DecimalColumnStats(precision: Int, scale: Int) extends ColumnStats { +private[columnar] class DecimalColumnStats(precision: Int, scale: Int) extends ColumnStats { def this(dt: DecimalType) = this(dt.precision, dt.scale) protected var upper: Decimal = null @@ -256,7 +256,7 @@ private[sql] class DecimalColumnStats(precision: Int, scale: Int) extends Column new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) } -private[sql] class ObjectColumnStats(dataType: DataType) extends ColumnStats { +private[columnar] class ObjectColumnStats(dataType: DataType) extends ColumnStats { val columnType = ColumnType(dataType) override def gatherStats(row: InternalRow, ordinal: Int): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala similarity index 93% rename from sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala index 68e509eb5047..c9f2329db4b6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.columnar +package org.apache.spark.sql.execution.columnar import java.math.{BigDecimal, BigInteger} import java.nio.ByteBuffer @@ -41,7 +41,7 @@ import org.apache.spark.unsafe.types.UTF8String * * WARNNING: This only works with HeapByteBuffer */ -object ByteBufferHelper { +private[columnar] object ByteBufferHelper { def getInt(buffer: ByteBuffer): Int = { val pos = buffer.position() buffer.position(pos + 4) @@ -73,7 +73,7 @@ object ByteBufferHelper { * * @tparam JvmType Underlying Java type to represent the elements. */ -private[sql] sealed abstract class ColumnType[JvmType] { +private[columnar] sealed abstract class ColumnType[JvmType] { // The catalyst data type of this column. 
def dataType: DataType @@ -142,7 +142,7 @@ private[sql] sealed abstract class ColumnType[JvmType] { override def toString: String = getClass.getSimpleName.stripSuffix("$") } -private[sql] object NULL extends ColumnType[Any] { +private[columnar] object NULL extends ColumnType[Any] { override def dataType: DataType = NullType override def defaultSize: Int = 0 @@ -152,7 +152,7 @@ private[sql] object NULL extends ColumnType[Any] { override def getField(row: InternalRow, ordinal: Int): Any = null } -private[sql] abstract class NativeColumnType[T <: AtomicType]( +private[columnar] abstract class NativeColumnType[T <: AtomicType]( val dataType: T, val defaultSize: Int) extends ColumnType[T#InternalType] { @@ -163,7 +163,7 @@ private[sql] abstract class NativeColumnType[T <: AtomicType]( def scalaTag: TypeTag[dataType.InternalType] = dataType.tag } -private[sql] object INT extends NativeColumnType(IntegerType, 4) { +private[columnar] object INT extends NativeColumnType(IntegerType, 4) { override def append(v: Int, buffer: ByteBuffer): Unit = { buffer.putInt(v) } @@ -192,7 +192,7 @@ private[sql] object INT extends NativeColumnType(IntegerType, 4) { } } -private[sql] object LONG extends NativeColumnType(LongType, 8) { +private[columnar] object LONG extends NativeColumnType(LongType, 8) { override def append(v: Long, buffer: ByteBuffer): Unit = { buffer.putLong(v) } @@ -220,7 +220,7 @@ private[sql] object LONG extends NativeColumnType(LongType, 8) { } } -private[sql] object FLOAT extends NativeColumnType(FloatType, 4) { +private[columnar] object FLOAT extends NativeColumnType(FloatType, 4) { override def append(v: Float, buffer: ByteBuffer): Unit = { buffer.putFloat(v) } @@ -248,7 +248,7 @@ private[sql] object FLOAT extends NativeColumnType(FloatType, 4) { } } -private[sql] object DOUBLE extends NativeColumnType(DoubleType, 8) { +private[columnar] object DOUBLE extends NativeColumnType(DoubleType, 8) { override def append(v: Double, buffer: ByteBuffer): Unit = { buffer.putDouble(v) } @@ -276,7 +276,7 @@ private[sql] object DOUBLE extends NativeColumnType(DoubleType, 8) { } } -private[sql] object BOOLEAN extends NativeColumnType(BooleanType, 1) { +private[columnar] object BOOLEAN extends NativeColumnType(BooleanType, 1) { override def append(v: Boolean, buffer: ByteBuffer): Unit = { buffer.put(if (v) 1: Byte else 0: Byte) } @@ -302,7 +302,7 @@ private[sql] object BOOLEAN extends NativeColumnType(BooleanType, 1) { } } -private[sql] object BYTE extends NativeColumnType(ByteType, 1) { +private[columnar] object BYTE extends NativeColumnType(ByteType, 1) { override def append(v: Byte, buffer: ByteBuffer): Unit = { buffer.put(v) } @@ -330,7 +330,7 @@ private[sql] object BYTE extends NativeColumnType(ByteType, 1) { } } -private[sql] object SHORT extends NativeColumnType(ShortType, 2) { +private[columnar] object SHORT extends NativeColumnType(ShortType, 2) { override def append(v: Short, buffer: ByteBuffer): Unit = { buffer.putShort(v) } @@ -362,7 +362,7 @@ private[sql] object SHORT extends NativeColumnType(ShortType, 2) { * A fast path to copy var-length bytes between ByteBuffer and UnsafeRow without creating wrapper * objects. 
*/ -private[sql] trait DirectCopyColumnType[JvmType] extends ColumnType[JvmType] { +private[columnar] trait DirectCopyColumnType[JvmType] extends ColumnType[JvmType] { // copy the bytes from ByteBuffer to UnsafeRow override def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int): Unit = { @@ -387,7 +387,7 @@ private[sql] trait DirectCopyColumnType[JvmType] extends ColumnType[JvmType] { } } -private[sql] object STRING +private[columnar] object STRING extends NativeColumnType(StringType, 8) with DirectCopyColumnType[UTF8String] { override def actualSize(row: InternalRow, ordinal: Int): Int = { @@ -425,7 +425,7 @@ private[sql] object STRING override def clone(v: UTF8String): UTF8String = v.clone() } -private[sql] case class COMPACT_DECIMAL(precision: Int, scale: Int) +private[columnar] case class COMPACT_DECIMAL(precision: Int, scale: Int) extends NativeColumnType(DecimalType(precision, scale), 8) { override def extract(buffer: ByteBuffer): Decimal = { @@ -467,13 +467,13 @@ private[sql] case class COMPACT_DECIMAL(precision: Int, scale: Int) } } -private[sql] object COMPACT_DECIMAL { +private[columnar] object COMPACT_DECIMAL { def apply(dt: DecimalType): COMPACT_DECIMAL = { COMPACT_DECIMAL(dt.precision, dt.scale) } } -private[sql] sealed abstract class ByteArrayColumnType[JvmType](val defaultSize: Int) +private[columnar] sealed abstract class ByteArrayColumnType[JvmType](val defaultSize: Int) extends ColumnType[JvmType] with DirectCopyColumnType[JvmType] { def serialize(value: JvmType): Array[Byte] @@ -492,7 +492,7 @@ private[sql] sealed abstract class ByteArrayColumnType[JvmType](val defaultSize: } } -private[sql] object BINARY extends ByteArrayColumnType[Array[Byte]](16) { +private[columnar] object BINARY extends ByteArrayColumnType[Array[Byte]](16) { def dataType: DataType = BinaryType @@ -512,7 +512,7 @@ private[sql] object BINARY extends ByteArrayColumnType[Array[Byte]](16) { def deserialize(bytes: Array[Byte]): Array[Byte] = bytes } -private[sql] case class LARGE_DECIMAL(precision: Int, scale: Int) +private[columnar] case class LARGE_DECIMAL(precision: Int, scale: Int) extends ByteArrayColumnType[Decimal](12) { override val dataType: DataType = DecimalType(precision, scale) @@ -539,13 +539,13 @@ private[sql] case class LARGE_DECIMAL(precision: Int, scale: Int) } } -private[sql] object LARGE_DECIMAL { +private[columnar] object LARGE_DECIMAL { def apply(dt: DecimalType): LARGE_DECIMAL = { LARGE_DECIMAL(dt.precision, dt.scale) } } -private[sql] case class STRUCT(dataType: StructType) +private[columnar] case class STRUCT(dataType: StructType) extends ColumnType[UnsafeRow] with DirectCopyColumnType[UnsafeRow] { private val numOfFields: Int = dataType.fields.size @@ -586,7 +586,7 @@ private[sql] case class STRUCT(dataType: StructType) override def clone(v: UnsafeRow): UnsafeRow = v.copy() } -private[sql] case class ARRAY(dataType: ArrayType) +private[columnar] case class ARRAY(dataType: ArrayType) extends ColumnType[UnsafeArrayData] with DirectCopyColumnType[UnsafeArrayData] { override def defaultSize: Int = 16 @@ -625,7 +625,7 @@ private[sql] case class ARRAY(dataType: ArrayType) override def clone(v: UnsafeArrayData): UnsafeArrayData = v.copy() } -private[sql] case class MAP(dataType: MapType) +private[columnar] case class MAP(dataType: MapType) extends ColumnType[UnsafeMapData] with DirectCopyColumnType[UnsafeMapData] { override def defaultSize: Int = 32 @@ -663,7 +663,7 @@ private[sql] case class MAP(dataType: MapType) override def clone(v: UnsafeMapData): UnsafeMapData = v.copy() } 
-private[sql] object ColumnType { +private[columnar] object ColumnType { def apply(dataType: DataType): ColumnType[_] = { dataType match { case NullType => NULL diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/GenerateColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala similarity index 98% rename from sql/core/src/main/scala/org/apache/spark/sql/columnar/GenerateColumnAccessor.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala index ff9393b465b7..eaafc96e4d2e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/GenerateColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.columnar +package org.apache.spark.sql.execution.columnar import org.apache.spark.Logging import org.apache.spark.sql.catalyst.InternalRow @@ -121,7 +121,7 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera import org.apache.spark.sql.types.DataType; import org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder; import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter; - import org.apache.spark.sql.columnar.MutableUnsafeRow; + import org.apache.spark.sql.execution.columnar.MutableUnsafeRow; public SpecificColumnarIterator generate($exprType[] expr) { return new SpecificColumnarIterator(); diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala similarity index 98% rename from sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala index ae77298e6da2..ce701fb3a7f2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.columnar +package org.apache.spark.sql.execution.columnar import scala.collection.mutable.ArrayBuffer @@ -50,7 +50,8 @@ private[sql] object InMemoryRelation { * @param buffers The buffers for serialized columns * @param stats The stat of columns */ -private[sql] case class CachedBatch(numRows: Int, buffers: Array[Array[Byte]], stats: InternalRow) +private[columnar] +case class CachedBatch(numRows: Int, buffers: Array[Array[Byte]], stats: InternalRow) private[sql] case class InMemoryRelation( output: Seq[Attribute], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/NullableColumnAccessor.scala similarity index 94% rename from sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnAccessor.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/NullableColumnAccessor.scala index 7eaecfe047c3..8d99546924de 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/NullableColumnAccessor.scala @@ -15,13 +15,13 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.columnar +package org.apache.spark.sql.execution.columnar import java.nio.{ByteOrder, ByteBuffer} import org.apache.spark.sql.catalyst.expressions.MutableRow -private[sql] trait NullableColumnAccessor extends ColumnAccessor { +private[columnar] trait NullableColumnAccessor extends ColumnAccessor { private var nullsBuffer: ByteBuffer = _ private var nullCount: Int = _ private var seenNulls: Int = 0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/NullableColumnBuilder.scala similarity index 95% rename from sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/NullableColumnBuilder.scala index 76cfddf1cd01..3a1931bfb5c8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/NullableColumnBuilder.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.columnar +package org.apache.spark.sql.execution.columnar import java.nio.{ByteBuffer, ByteOrder} @@ -34,7 +34,7 @@ import org.apache.spark.sql.catalyst.InternalRow * +---+-----+---------+ * }}} */ -private[sql] trait NullableColumnBuilder extends ColumnBuilder { +private[columnar] trait NullableColumnBuilder extends ColumnBuilder { protected var nulls: ByteBuffer = _ protected var nullCount: Int = _ private var pos: Int = _ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/CompressibleColumnAccessor.scala similarity index 84% rename from sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnAccessor.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/CompressibleColumnAccessor.scala index cb205defbb1a..6579b5068e65 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/CompressibleColumnAccessor.scala @@ -15,13 +15,13 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.columnar.compression +package org.apache.spark.sql.execution.columnar.compression import org.apache.spark.sql.catalyst.expressions.MutableRow -import org.apache.spark.sql.columnar.{ColumnAccessor, NativeColumnAccessor} +import org.apache.spark.sql.execution.columnar.{ColumnAccessor, NativeColumnAccessor} import org.apache.spark.sql.types.AtomicType -private[sql] trait CompressibleColumnAccessor[T <: AtomicType] extends ColumnAccessor { +private[columnar] trait CompressibleColumnAccessor[T <: AtomicType] extends ColumnAccessor { this: NativeColumnAccessor[T] => private var decoder: Decoder[T] = _ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/CompressibleColumnBuilder.scala similarity index 94% rename from sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/CompressibleColumnBuilder.scala index 161021ff9615..b0e216feb559 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/CompressibleColumnBuilder.scala @@ -15,13 +15,13 @@ * limitations under the License. */ -package org.apache.spark.sql.columnar.compression +package org.apache.spark.sql.execution.columnar.compression import java.nio.{ByteBuffer, ByteOrder} import org.apache.spark.Logging import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.columnar.{ColumnBuilder, NativeColumnBuilder} +import org.apache.spark.sql.execution.columnar.{ColumnBuilder, NativeColumnBuilder} import org.apache.spark.sql.types.AtomicType /** @@ -40,7 +40,7 @@ import org.apache.spark.sql.types.AtomicType * header body * }}} */ -private[sql] trait CompressibleColumnBuilder[T <: AtomicType] +private[columnar] trait CompressibleColumnBuilder[T <: AtomicType] extends ColumnBuilder with Logging { this: NativeColumnBuilder[T] with WithCompressionSchemes => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/CompressionScheme.scala similarity index 83% rename from sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/CompressionScheme.scala index 9322b772fd89..920381f9c63d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/CompressionScheme.scala @@ -15,15 +15,15 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.columnar.compression +package org.apache.spark.sql.execution.columnar.compression import java.nio.{ByteBuffer, ByteOrder} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.MutableRow -import org.apache.spark.sql.columnar.{ColumnType, NativeColumnType} +import org.apache.spark.sql.execution.columnar.{ColumnType, NativeColumnType} import org.apache.spark.sql.types.AtomicType -private[sql] trait Encoder[T <: AtomicType] { +private[columnar] trait Encoder[T <: AtomicType] { def gatherCompressibilityStats(row: InternalRow, ordinal: Int): Unit = {} def compressedSize: Int @@ -37,13 +37,13 @@ private[sql] trait Encoder[T <: AtomicType] { def compress(from: ByteBuffer, to: ByteBuffer): ByteBuffer } -private[sql] trait Decoder[T <: AtomicType] { +private[columnar] trait Decoder[T <: AtomicType] { def next(row: MutableRow, ordinal: Int): Unit def hasNext: Boolean } -private[sql] trait CompressionScheme { +private[columnar] trait CompressionScheme { def typeId: Int def supports(columnType: ColumnType[_]): Boolean @@ -53,15 +53,15 @@ private[sql] trait CompressionScheme { def decoder[T <: AtomicType](buffer: ByteBuffer, columnType: NativeColumnType[T]): Decoder[T] } -private[sql] trait WithCompressionSchemes { +private[columnar] trait WithCompressionSchemes { def schemes: Seq[CompressionScheme] } -private[sql] trait AllCompressionSchemes extends WithCompressionSchemes { +private[columnar] trait AllCompressionSchemes extends WithCompressionSchemes { override val schemes: Seq[CompressionScheme] = CompressionScheme.all } -private[sql] object CompressionScheme { +private[columnar] object CompressionScheme { val all: Seq[CompressionScheme] = Seq(PassThrough, RunLengthEncoding, DictionaryEncoding, BooleanBitSet, IntDelta, LongDelta) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/compressionSchemes.scala similarity index 96% rename from sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/compressionSchemes.scala index 41c9a284e3e4..941f03b745a0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/compressionSchemes.scala @@ -15,7 +15,7 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.columnar.compression +package org.apache.spark.sql.execution.columnar.compression import java.nio.ByteBuffer @@ -23,11 +23,11 @@ import scala.collection.mutable import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{MutableRow, SpecificMutableRow} -import org.apache.spark.sql.columnar._ +import org.apache.spark.sql.execution.columnar._ import org.apache.spark.sql.types._ -private[sql] case object PassThrough extends CompressionScheme { +private[columnar] case object PassThrough extends CompressionScheme { override val typeId = 0 override def supports(columnType: ColumnType[_]): Boolean = true @@ -64,7 +64,7 @@ private[sql] case object PassThrough extends CompressionScheme { } } -private[sql] case object RunLengthEncoding extends CompressionScheme { +private[columnar] case object RunLengthEncoding extends CompressionScheme { override val typeId = 1 override def encoder[T <: AtomicType](columnType: NativeColumnType[T]): Encoder[T] = { @@ -172,7 +172,7 @@ private[sql] case object RunLengthEncoding extends CompressionScheme { } } -private[sql] case object DictionaryEncoding extends CompressionScheme { +private[columnar] case object DictionaryEncoding extends CompressionScheme { override val typeId = 2 // 32K unique values allowed @@ -281,7 +281,7 @@ private[sql] case object DictionaryEncoding extends CompressionScheme { } } -private[sql] case object BooleanBitSet extends CompressionScheme { +private[columnar] case object BooleanBitSet extends CompressionScheme { override val typeId = 3 val BITS_PER_LONG = 64 @@ -371,7 +371,7 @@ private[sql] case object BooleanBitSet extends CompressionScheme { } } -private[sql] case object IntDelta extends CompressionScheme { +private[columnar] case object IntDelta extends CompressionScheme { override def typeId: Int = 4 override def decoder[T <: AtomicType](buffer: ByteBuffer, columnType: NativeColumnType[T]) @@ -451,7 +451,7 @@ private[sql] case object IntDelta extends CompressionScheme { } } -private[sql] case object LongDelta extends CompressionScheme { +private[columnar] case object LongDelta extends CompressionScheme { override def typeId: Int = 5 override def decoder[T <: AtomicType](buffer: ByteBuffer, columnType: NativeColumnType[T]) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/package.scala index 28fa231e722d..c912734bba9e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/package.scala @@ -19,5 +19,7 @@ package org.apache.spark.sql /** * The physical execution component of Spark SQL. Note that this is a private package. + * All classes in catalyst are considered an internal API to Spark SQL and are subject + * to change between minor releases. 
*/ package object execution diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index bce94dafad75..d86df4cfb9b4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -27,7 +27,7 @@ import scala.language.postfixOps import org.scalatest.concurrent.Eventually._ import org.apache.spark.Accumulators -import org.apache.spark.sql.columnar._ +import org.apache.spark.sql.execution.columnar._ import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.{SQLTestUtils, SharedSQLContext} import org.apache.spark.storage.{StorageLevel, RDDBlockId} @@ -280,7 +280,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext sql("CACHE TABLE testData") sqlContext.table("testData").queryExecution.withCachedData.collect { case cached: InMemoryRelation => - val actualSizeInBytes = (1 to 100).map(i => INT.defaultSize + i.toString.length + 4).sum + val actualSizeInBytes = (1 to 100).map(i => 4 + i.toString.length + 4).sum assert(cached.statistics.sizeInBytes === actualSizeInBytes) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index b5417b195f39..6ea1fe4ccfd8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -23,7 +23,7 @@ import scala.collection.JavaConverters._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.util._ -import org.apache.spark.sql.columnar.InMemoryRelation +import org.apache.spark.sql.execution.columnar.InMemoryRelation abstract class QueryTest extends PlanTest { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnStatsSuite.scala similarity index 96% rename from sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnStatsSuite.scala index 89a664001bdd..b2d04f7c5a6e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnStatsSuite.scala @@ -15,7 +15,7 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.columnar +package org.apache.spark.sql.execution.columnar import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.GenericInternalRow @@ -50,7 +50,7 @@ class ColumnStatsSuite extends SparkFunSuite { } test(s"$columnStatsName: non-empty") { - import org.apache.spark.sql.columnar.ColumnarTestUtils._ + import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._ val columnStats = columnStatsClass.newInstance() val rows = Seq.fill(10)(makeRandomRow(columnType)) ++ Seq.fill(10)(makeNullRow(1)) @@ -86,7 +86,7 @@ class ColumnStatsSuite extends SparkFunSuite { } test(s"$columnStatsName: non-empty") { - import org.apache.spark.sql.columnar.ColumnarTestUtils._ + import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._ val columnStats = new DecimalColumnStats(15, 10) val rows = Seq.fill(10)(makeRandomRow(columnType)) ++ Seq.fill(10)(makeNullRow(1)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnTypeSuite.scala similarity index 97% rename from sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnTypeSuite.scala index 63bc39bfa030..34dd96929e6c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnTypeSuite.scala @@ -15,14 +15,14 @@ * limitations under the License. */ -package org.apache.spark.sql.columnar +package org.apache.spark.sql.execution.columnar import java.nio.{ByteOrder, ByteBuffer} import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, GenericMutableRow} -import org.apache.spark.sql.columnar.ColumnarTestUtils._ +import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._ import org.apache.spark.sql.types._ import org.apache.spark.{Logging, SparkFunSuite} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnarTestUtils.scala similarity index 98% rename from sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnarTestUtils.scala index a5882f7870e3..9cae65ef6f5d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnarTestUtils.scala @@ -15,7 +15,7 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.columnar +package org.apache.spark.sql.execution.columnar import scala.collection.immutable.HashSet import scala.util.Random diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala similarity index 99% rename from sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala index 6265e40a0a07..25afed25c897 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.columnar +package org.apache.spark.sql.execution.columnar import java.sql.{Date, Timestamp} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/NullableColumnAccessorSuite.scala similarity index 96% rename from sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/NullableColumnAccessorSuite.scala index aa1605fee8c7..35dc9a276cef 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/NullableColumnAccessorSuite.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.columnar +package org.apache.spark.sql.execution.columnar import java.nio.ByteBuffer @@ -38,7 +38,7 @@ object TestNullableColumnAccessor { } class NullableColumnAccessorSuite extends SparkFunSuite { - import org.apache.spark.sql.columnar.ColumnarTestUtils._ + import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._ Seq( NULL, BOOLEAN, BYTE, SHORT, INT, LONG, FLOAT, DOUBLE, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/NullableColumnBuilderSuite.scala similarity index 96% rename from sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/NullableColumnBuilderSuite.scala index 91404577832a..93be3e16a5ed 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/NullableColumnBuilderSuite.scala @@ -15,7 +15,7 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.columnar +package org.apache.spark.sql.execution.columnar import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.CatalystTypeConverters @@ -36,7 +36,7 @@ object TestNullableColumnBuilder { } class NullableColumnBuilderSuite extends SparkFunSuite { - import org.apache.spark.sql.columnar.ColumnarTestUtils._ + import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._ Seq( BOOLEAN, BYTE, SHORT, INT, LONG, FLOAT, DOUBLE, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala similarity index 99% rename from sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala index 6b7401464f46..d762f7bfe914 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.columnar +package org.apache.spark.sql.execution.columnar import org.apache.spark.SparkFunSuite import org.apache.spark.sql._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/BooleanBitSetSuite.scala similarity index 94% rename from sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/BooleanBitSetSuite.scala index 9a2948c59ba4..ccbddef0fad3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/BooleanBitSetSuite.scala @@ -15,13 +15,13 @@ * limitations under the License. */ -package org.apache.spark.sql.columnar.compression +package org.apache.spark.sql.execution.columnar.compression import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.GenericMutableRow -import org.apache.spark.sql.columnar.ColumnarTestUtils._ -import org.apache.spark.sql.columnar.{BOOLEAN, NoopColumnStats} +import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._ +import org.apache.spark.sql.execution.columnar.{BOOLEAN, NoopColumnStats} class BooleanBitSetSuite extends SparkFunSuite { import BooleanBitSet._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/DictionaryEncodingSuite.scala similarity index 96% rename from sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/DictionaryEncodingSuite.scala index acfab6586c0d..830ca0294e1b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/DictionaryEncodingSuite.scala @@ -15,14 +15,14 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.columnar.compression +package org.apache.spark.sql.execution.columnar.compression import java.nio.ByteBuffer import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.GenericMutableRow -import org.apache.spark.sql.columnar._ -import org.apache.spark.sql.columnar.ColumnarTestUtils._ +import org.apache.spark.sql.execution.columnar._ +import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._ import org.apache.spark.sql.types.AtomicType class DictionaryEncodingSuite extends SparkFunSuite { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/IntegralDeltaSuite.scala similarity index 96% rename from sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/IntegralDeltaSuite.scala index 2111e9fbe62c..988a577a7b4d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/IntegralDeltaSuite.scala @@ -15,12 +15,12 @@ * limitations under the License. */ -package org.apache.spark.sql.columnar.compression +package org.apache.spark.sql.execution.columnar.compression import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.GenericMutableRow -import org.apache.spark.sql.columnar._ -import org.apache.spark.sql.columnar.ColumnarTestUtils._ +import org.apache.spark.sql.execution.columnar._ +import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._ import org.apache.spark.sql.types.IntegralType class IntegralDeltaSuite extends SparkFunSuite { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/RunLengthEncodingSuite.scala similarity index 95% rename from sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/RunLengthEncodingSuite.scala index 67ec08f594a4..ce3affba55c7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/RunLengthEncodingSuite.scala @@ -15,12 +15,12 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.columnar.compression +package org.apache.spark.sql.execution.columnar.compression import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.GenericMutableRow -import org.apache.spark.sql.columnar._ -import org.apache.spark.sql.columnar.ColumnarTestUtils._ +import org.apache.spark.sql.execution.columnar._ +import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._ import org.apache.spark.sql.types.AtomicType class RunLengthEncodingSuite extends SparkFunSuite { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/TestCompressibleColumnBuilder.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/TestCompressibleColumnBuilder.scala similarity index 93% rename from sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/TestCompressibleColumnBuilder.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/TestCompressibleColumnBuilder.scala index 5268dfe0aa03..5e078f251375 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/TestCompressibleColumnBuilder.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/TestCompressibleColumnBuilder.scala @@ -15,9 +15,9 @@ * limitations under the License. */ -package org.apache.spark.sql.columnar.compression +package org.apache.spark.sql.execution.columnar.compression -import org.apache.spark.sql.columnar._ +import org.apache.spark.sql.execution.columnar._ import org.apache.spark.sql.types.AtomicType class TestCompressibleColumnBuilder[T <: AtomicType]( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala index 5c2fc7d82ffb..99478e82d419 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.hive import java.io.File -import org.apache.spark.sql.columnar.InMemoryColumnarTableScan +import org.apache.spark.sql.execution.columnar.InMemoryColumnarTableScan import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.{AnalysisException, QueryTest, SaveMode} From 90d384dcbc1d1a3466cf8bae570a26f23012c102 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Thu, 19 Nov 2015 14:49:25 -0800 Subject: [PATCH 1457/1824] [SPARK-11831][CORE][TESTS] Use port 0 to avoid port conflicts in tests Use port 0 to fix port-contention-related flakiness Author: Shixiong Zhu Closes #9841 from zsxwing/SPARK-11831. 
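For context on the fix above: binding to port 0 asks the operating system to assign any currently free ephemeral port, which the caller then reads back from the bound socket, so concurrently running suites cannot collide the way they could on a hard-coded port such as 12345. A minimal standalone sketch of that behaviour using plain java.net (illustrative only, not Spark's RpcEnv API):

    import java.net.ServerSocket

    object PortZeroExample {
      def main(args: Array[String]): Unit = {
        // Port 0 tells the OS to pick any free ephemeral port.
        val socket = new ServerSocket(0)
        try {
          // The concrete port is only known after binding; read it back.
          println(s"OS assigned port ${socket.getLocalPort}")
        } finally {
          socket.close()
        }
      }
    }

The same idea drives the test changes below: the RpcEnv is created with port 0, and later lookups use the address the env actually bound to (env.address) rather than a fixed number.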
--- .../org/apache/spark/rpc/RpcEnvSuite.scala | 24 +++++++++---------- .../spark/rpc/akka/AkkaRpcEnvSuite.scala | 4 ++-- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala index 834e4743df86..2f55006420ce 100644 --- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala @@ -39,7 +39,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { override def beforeAll(): Unit = { val conf = new SparkConf() - env = createRpcEnv(conf, "local", 12345) + env = createRpcEnv(conf, "local", 0) } override def afterAll(): Unit = { @@ -76,7 +76,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { } }) - val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345, clientMode = true) + val anotherEnv = createRpcEnv(new SparkConf(), "remote", 0, clientMode = true) // Use anotherEnv to find out the RpcEndpointRef val rpcEndpointRef = anotherEnv.setupEndpointRef("local", env.address, "send-remotely") try { @@ -130,7 +130,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { } }) - val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345, clientMode = true) + val anotherEnv = createRpcEnv(new SparkConf(), "remote", 0, clientMode = true) // Use anotherEnv to find out the RpcEndpointRef val rpcEndpointRef = anotherEnv.setupEndpointRef("local", env.address, "ask-remotely") try { @@ -158,7 +158,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { val shortProp = "spark.rpc.short.timeout" conf.set("spark.rpc.retry.wait", "0") conf.set("spark.rpc.numRetries", "1") - val anotherEnv = createRpcEnv(conf, "remote", 13345, clientMode = true) + val anotherEnv = createRpcEnv(conf, "remote", 0, clientMode = true) // Use anotherEnv to find out the RpcEndpointRef val rpcEndpointRef = anotherEnv.setupEndpointRef("local", env.address, "ask-timeout") try { @@ -417,7 +417,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { } }) - val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345, clientMode = true) + val anotherEnv = createRpcEnv(new SparkConf(), "remote", 0, clientMode = true) // Use anotherEnv to find out the RpcEndpointRef val rpcEndpointRef = anotherEnv.setupEndpointRef("local", env.address, "sendWithReply-remotely") try { @@ -457,7 +457,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { } }) - val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345, clientMode = true) + val anotherEnv = createRpcEnv(new SparkConf(), "remote", 0, clientMode = true) // Use anotherEnv to find out the RpcEndpointRef val rpcEndpointRef = anotherEnv.setupEndpointRef( "local", env.address, "sendWithReply-remotely-error") @@ -497,7 +497,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { }) - val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345, clientMode = true) + val anotherEnv = createRpcEnv(new SparkConf(), "remote", 0, clientMode = true) // Use anotherEnv to find out the RpcEndpointRef val rpcEndpointRef = anotherEnv.setupEndpointRef( "local", env.address, "network-events") @@ -543,7 +543,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { } }) - val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345, clientMode = true) + val anotherEnv = createRpcEnv(new SparkConf(), "remote", 0, clientMode = 
true) // Use anotherEnv to find out the RpcEndpointRef val rpcEndpointRef = anotherEnv.setupEndpointRef( "local", env.address, "sendWithReply-unserializable-error") @@ -571,8 +571,8 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { conf.set("spark.authenticate", "true") conf.set("spark.authenticate.secret", "good") - val localEnv = createRpcEnv(conf, "authentication-local", 13345) - val remoteEnv = createRpcEnv(conf, "authentication-remote", 14345, clientMode = true) + val localEnv = createRpcEnv(conf, "authentication-local", 0) + val remoteEnv = createRpcEnv(conf, "authentication-remote", 0, clientMode = true) try { @volatile var message: String = null @@ -602,8 +602,8 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { conf.set("spark.authenticate", "true") conf.set("spark.authenticate.secret", "good") - val localEnv = createRpcEnv(conf, "authentication-local", 13345) - val remoteEnv = createRpcEnv(conf, "authentication-remote", 14345, clientMode = true) + val localEnv = createRpcEnv(conf, "authentication-local", 0) + val remoteEnv = createRpcEnv(conf, "authentication-remote", 0, clientMode = true) try { localEnv.setupEndpoint("ask-authentication", new RpcEndpoint { diff --git a/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala index 6478ab51c4da..7aac02775e1b 100644 --- a/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala @@ -40,7 +40,7 @@ class AkkaRpcEnvSuite extends RpcEnvSuite { }) val conf = new SparkConf() val newRpcEnv = new AkkaRpcEnvFactory().create( - RpcEnvConfig(conf, "test", "localhost", 12346, new SecurityManager(conf), false)) + RpcEnvConfig(conf, "test", "localhost", 0, new SecurityManager(conf), false)) try { val newRef = newRpcEnv.setupEndpointRef("local", ref.address, "test_endpoint") assert(s"akka.tcp://local@${env.address}/user/test_endpoint" === @@ -59,7 +59,7 @@ class AkkaRpcEnvSuite extends RpcEnvSuite { val conf = SSLSampleConfigs.sparkSSLConfig() val securityManager = new SecurityManager(conf) val rpcEnv = new AkkaRpcEnvFactory().create( - RpcEnvConfig(conf, "test", "localhost", 12346, securityManager, false)) + RpcEnvConfig(conf, "test", "localhost", 0, securityManager, false)) try { val uri = rpcEnv.uriOf("local", RpcAddress("1.2.3.4", 12345), "test_endpoint") assert("akka.ssl.tcp://local@1.2.3.4:12345/user/test_endpoint" === uri) From 3bd77b213a9cd177c3ea3c61d37e5098e55f75a5 Mon Sep 17 00:00:00 2001 From: Srinivasa Reddy Vundela Date: Thu, 19 Nov 2015 14:51:40 -0800 Subject: [PATCH 1458/1824] =?UTF-8?q?[SPARK-11799][CORE]=20Make=20it=20exp?= =?UTF-8?q?licit=20in=20executor=20logs=20that=20uncaught=20e=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …xceptions are thrown during executor shutdown This commit will make sure that when uncaught exceptions are prepended with [Container in shutdown] when JVM is shutting down. Author: Srinivasa Reddy Vundela Closes #9809 from vundela/master_11799. 
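The pattern in this change is to check a "JVM is shutting down" flag inside the default uncaught-exception handler and prepend a marker to the log line when the flag is set; the actual Spark change (using ShutdownHookManager.inShutdown() and logError) follows in the diff below. A generic, self-contained sketch of the pattern, with a hypothetical shutdown flag standing in for ShutdownHookManager:

    import java.util.concurrent.atomic.AtomicBoolean

    object ShutdownAwareHandler extends Thread.UncaughtExceptionHandler {
      // Hypothetical stand-in for ShutdownHookManager.inShutdown().
      private val shuttingDown = new AtomicBoolean(false)
      sys.addShutdownHook { shuttingDown.set(true) }

      override def uncaughtException(thread: Thread, exception: Throwable): Unit = {
        val prefix = if (shuttingDown.get()) "[Container in shutdown] " else ""
        // A real handler would go through a logger; stderr keeps the sketch self-contained.
        System.err.println(prefix + "Uncaught exception in thread " + thread)
        exception.printStackTrace()
      }
    }

    // Installed once at startup, e.g.:
    // Thread.setDefaultUncaughtExceptionHandler(ShutdownAwareHandler)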
--- .../apache/spark/util/SparkUncaughtExceptionHandler.scala | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/util/SparkUncaughtExceptionHandler.scala b/core/src/main/scala/org/apache/spark/util/SparkUncaughtExceptionHandler.scala index 724818724733..5e322557e964 100644 --- a/core/src/main/scala/org/apache/spark/util/SparkUncaughtExceptionHandler.scala +++ b/core/src/main/scala/org/apache/spark/util/SparkUncaughtExceptionHandler.scala @@ -29,7 +29,11 @@ private[spark] object SparkUncaughtExceptionHandler override def uncaughtException(thread: Thread, exception: Throwable) { try { - logError("Uncaught exception in thread " + thread, exception) + // Make it explicit that uncaught exceptions are thrown when container is shutting down. + // It will help users when they analyze the executor logs + val inShutdownMsg = if (ShutdownHookManager.inShutdown()) "[Container in shutdown] " else "" + val errMsg = "Uncaught exception in thread " + logError(inShutdownMsg + errMsg + thread, exception) // We may have been called from a shutdown hook. If so, we must not call System.exit(). // (If we do, we will deadlock.) From f7135ed7194d4f936f6f58e14f02b1ed93f68ad1 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Thu, 19 Nov 2015 14:53:58 -0800 Subject: [PATCH 1459/1824] [SPARK-11828][CORE] Register DAGScheduler metrics source after app id is known. Author: Marcelo Vanzin Closes #9820 from vanzin/SPARK-11828. --- core/src/main/scala/org/apache/spark/SparkContext.scala | 1 + .../main/scala/org/apache/spark/scheduler/DAGScheduler.scala | 4 +--- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index ab374cb71286..af4456c05b0a 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -581,6 +581,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli // Post init _taskScheduler.postStartHook() + _env.metricsSystem.registerSource(_dagScheduler.metricsSource) _env.metricsSystem.registerSource(new BlockManagerSource(_env.blockManager)) _executorAllocationManager.foreach { e => _env.metricsSystem.registerSource(e.executorAllocationManagerSource) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 4a9518fff4e7..ae725b467d8c 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -130,7 +130,7 @@ class DAGScheduler( def this(sc: SparkContext) = this(sc, sc.taskScheduler) - private[scheduler] val metricsSource: DAGSchedulerSource = new DAGSchedulerSource(this) + private[spark] val metricsSource: DAGSchedulerSource = new DAGSchedulerSource(this) private[scheduler] val nextJobId = new AtomicInteger(0) private[scheduler] def numTotalJobs: Int = nextJobId.get() @@ -1580,8 +1580,6 @@ class DAGScheduler( taskScheduler.stop() } - // Start the event thread and register the metrics source at the end of the constructor - env.metricsSystem.registerSource(metricsSource) eventProcessLoop.start() } From 01403aa97b6aaab9b86ae806b5ea9e82690a741f Mon Sep 17 00:00:00 2001 From: hushan Date: Thu, 19 Nov 2015 14:56:00 -0800 Subject: [PATCH 1460/1824] [SPARK-11746][CORE] Use cache-aware method dependencies a small change Author: hushan Closes #9691 from 
suyanNone/unify-getDependency. --- .../main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala index d6a37e8cc5da..0c6ddda52cee 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala @@ -65,7 +65,7 @@ class PartitionPruningRDD[T: ClassTag]( } override protected def getPartitions: Array[Partition] = - getDependencies.head.asInstanceOf[PruneDependency[T]].partitions + dependencies.head.asInstanceOf[PruneDependency[T]].partitions } From 37cff1b1a79cad11277612cb9bc8bc2365cf5ff2 Mon Sep 17 00:00:00 2001 From: Andrew Ray Date: Thu, 19 Nov 2015 15:11:30 -0800 Subject: [PATCH 1461/1824] [SPARK-11275][SQL] Incorrect results when using rollup/cube Fixes bug with grouping sets (including cube/rollup) where aggregates that included grouping expressions would return the wrong (null) result. Also simplifies the analyzer rule a bit and leaves column pruning to the optimizer. Added multiple unit tests to DataFrameAggregateSuite and verified it passes hive compatibility suite: ``` build/sbt -Phive -Dspark.hive.whitelist='groupby.*_grouping.*' 'test-only org.apache.spark.sql.hive.execution.HiveCompatibilitySuite' ``` This is an alternative to pr https://github.com/apache/spark/pull/9419 but I think its better as it simplifies the analyzer rule instead of adding another special case to it. Author: Andrew Ray Closes #9815 from aray/groupingset-agg-fix. --- .../sql/catalyst/analysis/Analyzer.scala | 58 +++++++---------- .../plans/logical/basicOperators.scala | 4 ++ .../spark/sql/DataFrameAggregateSuite.scala | 62 +++++++++++++++++++ 3 files changed, 90 insertions(+), 34 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 84781cd57f3d..47962ebe6ef8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -213,45 +213,35 @@ class Analyzer( GroupingSets(bitmasks(a), a.groupByExprs, a.child, a.aggregations) case x: GroupingSets => val gid = AttributeReference(VirtualColumn.groupingIdName, IntegerType, false)() - // We will insert another Projection if the GROUP BY keys contains the - // non-attribute expressions. And the top operators can references those - // expressions by its alias. - // e.g. SELECT key%5 as c1 FROM src GROUP BY key%5 ==> - // SELECT a as c1 FROM (SELECT key%5 AS a FROM src) GROUP BY a - - // find all of the non-attribute expressions in the GROUP BY keys - val nonAttributeGroupByExpressions = new ArrayBuffer[Alias]() - - // The pair of (the original GROUP BY key, associated attribute) - val groupByExprPairs = x.groupByExprs.map(_ match { - case e: NamedExpression => (e, e.toAttribute) - case other => { - val alias = Alias(other, other.toString)() - nonAttributeGroupByExpressions += alias // add the non-attributes expression alias - (other, alias.toAttribute) - } - }) - - // substitute the non-attribute expressions for aggregations. 
- val aggregation = x.aggregations.map(expr => expr.transformDown { - case e => groupByExprPairs.find(_._1.semanticEquals(e)).map(_._2).getOrElse(e) - }.asInstanceOf[NamedExpression]) - // substitute the group by expressions. - val newGroupByExprs = groupByExprPairs.map(_._2) + // Expand works by setting grouping expressions to null as determined by the bitmasks. To + // prevent these null values from being used in an aggregate instead of the original value + // we need to create new aliases for all group by expressions that will only be used for + // the intended purpose. + val groupByAliases: Seq[Alias] = x.groupByExprs.map { + case e: NamedExpression => Alias(e, e.name)() + case other => Alias(other, other.toString)() + } - val child = if (nonAttributeGroupByExpressions.length > 0) { - // insert additional projection if contains the - // non-attribute expressions in the GROUP BY keys - Project(x.child.output ++ nonAttributeGroupByExpressions, x.child) - } else { - x.child + val aggregations: Seq[NamedExpression] = x.aggregations.map { + // If an expression is an aggregate (contains a AggregateExpression) then we dont change + // it so that the aggregation is computed on the unmodified value of its argument + // expressions. + case expr if expr.find(_.isInstanceOf[AggregateExpression]).nonEmpty => expr + // If not then its a grouping expression and we need to use the modified (with nulls from + // Expand) value of the expression. + case expr => expr.transformDown { + case e => groupByAliases.find(_.child.semanticEquals(e)).map(_.toAttribute).getOrElse(e) + }.asInstanceOf[NamedExpression] } + val child = Project(x.child.output ++ groupByAliases, x.child) + val groupByAttributes = groupByAliases.map(_.toAttribute) + Aggregate( - newGroupByExprs :+ VirtualColumn.groupingIdAttribute, - aggregation, - Expand(x.bitmasks, newGroupByExprs, gid, child)) + groupByAttributes :+ VirtualColumn.groupingIdAttribute, + aggregations, + Expand(x.bitmasks, groupByAttributes, gid, child)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 45630a591d34..0c444482c5e4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -323,6 +323,10 @@ trait GroupingAnalytics extends UnaryNode { override def output: Seq[Attribute] = aggregations.map(_.toAttribute) + // Needs to be unresolved before its translated to Aggregate + Expand because output attributes + // will change in analysis. 
+ override lazy val resolved: Boolean = false + def withNewAggs(aggs: Seq[NamedExpression]): GroupingAnalytics } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 71adf2148a40..9c42f65bb6f5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -60,6 +60,68 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { ) } + test("rollup") { + checkAnswer( + courseSales.rollup("course", "year").sum("earnings"), + Row("Java", 2012, 20000.0) :: + Row("Java", 2013, 30000.0) :: + Row("Java", null, 50000.0) :: + Row("dotNET", 2012, 15000.0) :: + Row("dotNET", 2013, 48000.0) :: + Row("dotNET", null, 63000.0) :: + Row(null, null, 113000.0) :: Nil + ) + } + + test("cube") { + checkAnswer( + courseSales.cube("course", "year").sum("earnings"), + Row("Java", 2012, 20000.0) :: + Row("Java", 2013, 30000.0) :: + Row("Java", null, 50000.0) :: + Row("dotNET", 2012, 15000.0) :: + Row("dotNET", 2013, 48000.0) :: + Row("dotNET", null, 63000.0) :: + Row(null, 2012, 35000.0) :: + Row(null, 2013, 78000.0) :: + Row(null, null, 113000.0) :: Nil + ) + } + + test("rollup overlapping columns") { + checkAnswer( + testData2.rollup($"a" + $"b" as "foo", $"b" as "bar").agg(sum($"a" - $"b") as "foo"), + Row(2, 1, 0) :: Row(3, 2, -1) :: Row(3, 1, 1) :: Row(4, 2, 0) :: Row(4, 1, 2) :: Row(5, 2, 1) + :: Row(2, null, 0) :: Row(3, null, 0) :: Row(4, null, 2) :: Row(5, null, 1) + :: Row(null, null, 3) :: Nil + ) + + checkAnswer( + testData2.rollup("a", "b").agg(sum("b")), + Row(1, 1, 1) :: Row(1, 2, 2) :: Row(2, 1, 1) :: Row(2, 2, 2) :: Row(3, 1, 1) :: Row(3, 2, 2) + :: Row(1, null, 3) :: Row(2, null, 3) :: Row(3, null, 3) + :: Row(null, null, 9) :: Nil + ) + } + + test("cube overlapping columns") { + checkAnswer( + testData2.cube($"a" + $"b", $"b").agg(sum($"a" - $"b")), + Row(2, 1, 0) :: Row(3, 2, -1) :: Row(3, 1, 1) :: Row(4, 2, 0) :: Row(4, 1, 2) :: Row(5, 2, 1) + :: Row(2, null, 0) :: Row(3, null, 0) :: Row(4, null, 2) :: Row(5, null, 1) + :: Row(null, 1, 3) :: Row(null, 2, 0) + :: Row(null, null, 3) :: Nil + ) + + checkAnswer( + testData2.cube("a", "b").agg(sum("b")), + Row(1, 1, 1) :: Row(1, 2, 2) :: Row(2, 1, 1) :: Row(2, 2, 2) :: Row(3, 1, 1) :: Row(3, 2, 2) + :: Row(1, null, 3) :: Row(2, null, 3) :: Row(3, null, 3) + :: Row(null, 1, 3) :: Row(null, 2, 6) + :: Row(null, null, 9) :: Nil + ) + } + test("spark.sql.retainGroupColumns config") { checkAnswer( testData2.groupBy("a").agg(sum($"b")), From 880128f37e1bc0b9d98d1786670be62a06c648f2 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Thu, 19 Nov 2015 16:49:18 -0800 Subject: [PATCH 1462/1824] [SPARK-4134][CORE] Lower severity of some executor loss logs. Don't log ERROR messages when executors are explicitly killed or when the exit reason is not yet known. Author: Marcelo Vanzin Closes #9780 from vanzin/SPARK-11789. 
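As a rough sketch of the severity policy this change introduces (the `Reason` hierarchy and helper below are illustrative stand-ins, not the real `ExecutorLossReason` classes; the actual logic lives in `TaskSchedulerImpl` in the hunks that follow):

```scala
// Sketch: choose a log level per loss reason instead of always logging ERROR.
object ExecutorLossLogLevelSketch {
  sealed trait Reason
  case object ReasonPending extends Reason     // loss reason not yet known
  case object KilledByDriver extends Reason    // executor explicitly killed by the driver
  final case class Exited(message: String) extends Reason

  def levelFor(reason: Reason): String = reason match {
    case ReasonPending  => "DEBUG" // don't alarm users before the real reason arrives
    case KilledByDriver => "INFO"  // expected, driver-initiated
    case Exited(_)      => "ERROR" // genuine, unexpected executor loss
  }

  def main(args: Array[String]): Unit = {
    assert(levelFor(ReasonPending) == "DEBUG")
    assert(levelFor(KilledByDriver) == "INFO")
    assert(levelFor(Exited("container lost")) == "ERROR")
  }
}
```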
--- .../spark/scheduler/ExecutorLossReason.scala | 2 + .../spark/scheduler/TaskSchedulerImpl.scala | 44 ++++++++++++------- .../spark/scheduler/TaskSetManager.scala | 1 + .../CoarseGrainedSchedulerBackend.scala | 18 +++++--- .../spark/deploy/yarn/YarnAllocator.scala | 4 +- 5 files changed, 45 insertions(+), 24 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/ExecutorLossReason.scala b/core/src/main/scala/org/apache/spark/scheduler/ExecutorLossReason.scala index 47a5cbff4930..7e1197d74280 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ExecutorLossReason.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ExecutorLossReason.scala @@ -40,6 +40,8 @@ private[spark] object ExecutorExited { } } +private[spark] object ExecutorKilled extends ExecutorLossReason("Executor killed by driver.") + /** * A loss reason that means we don't yet know why the executor exited. * diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index bf0419db1f75..bdf19f9f277d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -470,25 +470,25 @@ private[spark] class TaskSchedulerImpl( synchronized { if (executorIdToTaskCount.contains(executorId)) { val hostPort = executorIdToHost(executorId) - logError("Lost executor %s on %s: %s".format(executorId, hostPort, reason)) + logExecutorLoss(executorId, hostPort, reason) removeExecutor(executorId, reason) failedExecutor = Some(executorId) } else { - executorIdToHost.get(executorId) match { - case Some(_) => - // If the host mapping still exists, it means we don't know the loss reason for the - // executor. So call removeExecutor() to update tasks running on that executor when - // the real loss reason is finally known. - logError(s"Actual reason for lost executor $executorId: ${reason.message}") - removeExecutor(executorId, reason) - - case None => - // We may get multiple executorLost() calls with different loss reasons. For example, - // one may be triggered by a dropped connection from the slave while another may be a - // report of executor termination from Mesos. We produce log messages for both so we - // eventually report the termination reason. - logError("Lost an executor " + executorId + " (already removed): " + reason) - } + executorIdToHost.get(executorId) match { + case Some(hostPort) => + // If the host mapping still exists, it means we don't know the loss reason for the + // executor. So call removeExecutor() to update tasks running on that executor when + // the real loss reason is finally known. + logExecutorLoss(executorId, hostPort, reason) + removeExecutor(executorId, reason) + + case None => + // We may get multiple executorLost() calls with different loss reasons. For example, + // one may be triggered by a dropped connection from the slave while another may be a + // report of executor termination from Mesos. We produce log messages for both so we + // eventually report the termination reason. 
+ logError(s"Lost an executor $executorId (already removed): $reason") + } } } // Call dagScheduler.executorLost without holding the lock on this to prevent deadlock @@ -498,6 +498,18 @@ private[spark] class TaskSchedulerImpl( } } + private def logExecutorLoss( + executorId: String, + hostPort: String, + reason: ExecutorLossReason): Unit = reason match { + case LossReasonPending => + logDebug(s"Executor $executorId on $hostPort lost, but reason not yet known.") + case ExecutorKilled => + logInfo(s"Executor $executorId on $hostPort killed by driver.") + case _ => + logError(s"Lost executor $executorId on $hostPort: $reason") + } + /** * Remove an executor from all our data structures and mark it as lost. If the executor's loss * reason is not yet known, do not yet remove its association with its host nor update the status diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 114468c48c44..a02f3017cb6e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -800,6 +800,7 @@ private[spark] class TaskSetManager( for ((tid, info) <- taskInfos if info.running && info.executorId == execId) { val exitCausedByApp: Boolean = reason match { case exited: ExecutorExited => exited.exitCausedByApp + case ExecutorKilled => false case _ => true } handleFailedTask(tid, TaskState.FAILED, ExecutorLostFailure(info.executorId, exitCausedByApp, diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 6f0c910c009a..505c161141c8 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -64,8 +64,10 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp private val listenerBus = scheduler.sc.listenerBus - // Executors we have requested the cluster manager to kill that have not died yet - private val executorsPendingToRemove = new HashSet[String] + // Executors we have requested the cluster manager to kill that have not died yet; maps + // the executor ID to whether it was explicitly killed by the driver (and thus shouldn't + // be considered an app-related failure). 
+ private val executorsPendingToRemove = new HashMap[String, Boolean] // A map to store hostname with its possible task number running on it protected var hostToLocalTaskCount: Map[String, Int] = Map.empty @@ -250,15 +252,15 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp case Some(executorInfo) => // This must be synchronized because variables mutated // in this block are read when requesting executors - CoarseGrainedSchedulerBackend.this.synchronized { + val killed = CoarseGrainedSchedulerBackend.this.synchronized { addressToExecutorId -= executorInfo.executorAddress executorDataMap -= executorId - executorsPendingToRemove -= executorId executorsPendingLossReason -= executorId + executorsPendingToRemove.remove(executorId).getOrElse(false) } totalCoreCount.addAndGet(-executorInfo.totalCores) totalRegisteredExecutors.addAndGet(-1) - scheduler.executorLost(executorId, reason) + scheduler.executorLost(executorId, if (killed) ExecutorKilled else reason) listenerBus.post( SparkListenerExecutorRemoved(System.currentTimeMillis(), executorId, reason.toString)) case None => logInfo(s"Asked to remove non-existent executor $executorId") @@ -459,6 +461,10 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp /** * Request that the cluster manager kill the specified executors. * + * When asking the executor to be replaced, the executor loss is considered a failure, and + * killed tasks that are running on the executor will count towards the failure limits. If no + * replacement is being requested, then the tasks will not count towards the limit. + * * @param executorIds identifiers of executors to kill * @param replace whether to replace the killed executors with new ones * @param force whether to force kill busy executors @@ -479,7 +485,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp val executorsToKill = knownExecutors .filter { id => !executorsPendingToRemove.contains(id) } .filter { id => force || !scheduler.isExecutorBusy(id) } - executorsPendingToRemove ++= executorsToKill + executorsToKill.foreach { id => executorsPendingToRemove(id) = !replace } // If we do not wish to replace the executors we kill, sync the target number of executors // with the cluster manager to avoid allocating new ones. When computing the new target, diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala index 7e39c3ea56af..73cd9031f025 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala @@ -481,7 +481,7 @@ private[yarn] class YarnAllocator( (true, memLimitExceededLogMessage( completedContainer.getDiagnostics, PMEM_EXCEEDED_PATTERN)) - case unknown => + case _ => numExecutorsFailed += 1 (true, "Container marked as failed: " + containerId + onHostStr + ". 
Exit status: " + completedContainer.getExitStatus + @@ -493,7 +493,7 @@ private[yarn] class YarnAllocator( } else { logInfo(containerExitReason) } - ExecutorExited(0, exitCausedByApp, containerExitReason) + ExecutorExited(exitStatus, exitCausedByApp, containerExitReason) } else { // If we have already released this container, then it must mean // that the driver has explicitly requested it to be killed From b2cecb80ece59a1c086d4ae7aeebef445a4e7299 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Thu, 19 Nov 2015 16:50:08 -0800 Subject: [PATCH 1463/1824] [SPARK-11845][STREAMING][TEST] Added unit test to verify TrackStateRDD is correctly checkpointed To make sure that all lineage is correctly truncated for TrackStateRDD when checkpointed. Author: Tathagata Das Closes #9831 from tdas/SPARK-11845. --- .../org/apache/spark/CheckpointSuite.scala | 411 +++++++++--------- .../streaming/rdd/TrackStateRDDSuite.scala | 60 ++- 2 files changed, 267 insertions(+), 204 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala index 119e5fc28e41..ab23326c6c25 100644 --- a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala +++ b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala @@ -21,17 +21,223 @@ import java.io.File import scala.reflect.ClassTag +import org.apache.spark.CheckpointSuite._ import org.apache.spark.rdd._ import org.apache.spark.storage.{BlockId, StorageLevel, TestBlockId} import org.apache.spark.util.Utils +trait RDDCheckpointTester { self: SparkFunSuite => + + protected val partitioner = new HashPartitioner(2) + + private def defaultCollectFunc[T](rdd: RDD[T]): Any = rdd.collect() + + /** Implementations of this trait must implement this method */ + protected def sparkContext: SparkContext + + /** + * Test checkpointing of the RDD generated by the given operation. It tests whether the + * serialized size of the RDD is reduce after checkpointing or not. This function should be called + * on all RDDs that have a parent RDD (i.e., do not call on ParallelCollection, BlockRDD, etc.). 
+ * + * @param op an operation to run on the RDD + * @param reliableCheckpoint if true, use reliable checkpoints, otherwise use local checkpoints + * @param collectFunc a function for collecting the values in the RDD, in case there are + * non-comparable types like arrays that we want to convert to something + * that supports == + */ + protected def testRDD[U: ClassTag]( + op: (RDD[Int]) => RDD[U], + reliableCheckpoint: Boolean, + collectFunc: RDD[U] => Any = defaultCollectFunc[U] _): Unit = { + // Generate the final RDD using given RDD operation + val baseRDD = generateFatRDD() + val operatedRDD = op(baseRDD) + val parentRDD = operatedRDD.dependencies.headOption.orNull + val rddType = operatedRDD.getClass.getSimpleName + val numPartitions = operatedRDD.partitions.length + + // Force initialization of all the data structures in RDDs + // Without this, serializing the RDD will give a wrong estimate of the size of the RDD + initializeRdd(operatedRDD) + + val partitionsBeforeCheckpoint = operatedRDD.partitions + + // Find serialized sizes before and after the checkpoint + logInfo("RDD before checkpoint: " + operatedRDD + "\n" + operatedRDD.toDebugString) + val (rddSizeBeforeCheckpoint, partitionSizeBeforeCheckpoint) = getSerializedSizes(operatedRDD) + checkpoint(operatedRDD, reliableCheckpoint) + val result = collectFunc(operatedRDD) + operatedRDD.collect() // force re-initialization of post-checkpoint lazy variables + val (rddSizeAfterCheckpoint, partitionSizeAfterCheckpoint) = getSerializedSizes(operatedRDD) + logInfo("RDD after checkpoint: " + operatedRDD + "\n" + operatedRDD.toDebugString) + + // Test whether the checkpoint file has been created + if (reliableCheckpoint) { + assert( + collectFunc(sparkContext.checkpointFile[U](operatedRDD.getCheckpointFile.get)) === result) + } + + // Test whether dependencies have been changed from its earlier parent RDD + assert(operatedRDD.dependencies.head.rdd != parentRDD) + + // Test whether the partitions have been changed from its earlier partitions + assert(operatedRDD.partitions.toList != partitionsBeforeCheckpoint.toList) + + // Test whether the partitions have been changed to the new Hadoop partitions + assert(operatedRDD.partitions.toList === operatedRDD.checkpointData.get.getPartitions.toList) + + // Test whether the number of partitions is same as before + assert(operatedRDD.partitions.length === numPartitions) + + // Test whether the data in the checkpointed RDD is same as original + assert(collectFunc(operatedRDD) === result) + + // Test whether serialized size of the RDD has reduced. + logInfo("Size of " + rddType + + " [" + rddSizeBeforeCheckpoint + " --> " + rddSizeAfterCheckpoint + "]") + assert( + rddSizeAfterCheckpoint < rddSizeBeforeCheckpoint, + "Size of " + rddType + " did not reduce after checkpointing " + + " [" + rddSizeBeforeCheckpoint + " --> " + rddSizeAfterCheckpoint + "]" + ) + } + + /** + * Test whether checkpointing of the parent of the generated RDD also + * truncates the lineage or not. Some RDDs like CoGroupedRDD hold on to its parent + * RDDs partitions. So even if the parent RDD is checkpointed and its partitions changed, + * the generated RDD will remember the partitions and therefore potentially the whole lineage. + * This function should be called only those RDD whose partitions refer to parent RDD's + * partitions (i.e., do not call it on simple RDD like MappedRDD). 
+ * + * @param op an operation to run on the RDD + * @param reliableCheckpoint if true, use reliable checkpoints, otherwise use local checkpoints + * @param collectFunc a function for collecting the values in the RDD, in case there are + * non-comparable types like arrays that we want to convert to something + * that supports == + */ + protected def testRDDPartitions[U: ClassTag]( + op: (RDD[Int]) => RDD[U], + reliableCheckpoint: Boolean, + collectFunc: RDD[U] => Any = defaultCollectFunc[U] _): Unit = { + // Generate the final RDD using given RDD operation + val baseRDD = generateFatRDD() + val operatedRDD = op(baseRDD) + val parentRDDs = operatedRDD.dependencies.map(_.rdd) + val rddType = operatedRDD.getClass.getSimpleName + + // Force initialization of all the data structures in RDDs + // Without this, serializing the RDD will give a wrong estimate of the size of the RDD + initializeRdd(operatedRDD) + + // Find serialized sizes before and after the checkpoint + logInfo("RDD after checkpoint: " + operatedRDD + "\n" + operatedRDD.toDebugString) + val (rddSizeBeforeCheckpoint, partitionSizeBeforeCheckpoint) = getSerializedSizes(operatedRDD) + // checkpoint the parent RDD, not the generated one + parentRDDs.foreach { rdd => + checkpoint(rdd, reliableCheckpoint) + } + val result = collectFunc(operatedRDD) // force checkpointing + operatedRDD.collect() // force re-initialization of post-checkpoint lazy variables + val (rddSizeAfterCheckpoint, partitionSizeAfterCheckpoint) = getSerializedSizes(operatedRDD) + logInfo("RDD after checkpoint: " + operatedRDD + "\n" + operatedRDD.toDebugString) + + // Test whether the data in the checkpointed RDD is same as original + assert(collectFunc(operatedRDD) === result) + + // Test whether serialized size of the partitions has reduced + logInfo("Size of partitions of " + rddType + + " [" + partitionSizeBeforeCheckpoint + " --> " + partitionSizeAfterCheckpoint + "]") + assert( + partitionSizeAfterCheckpoint < partitionSizeBeforeCheckpoint, + "Size of " + rddType + " partitions did not reduce after checkpointing parent RDDs" + + " [" + partitionSizeBeforeCheckpoint + " --> " + partitionSizeAfterCheckpoint + "]" + ) + } + + /** + * Get serialized sizes of the RDD and its partitions, in order to test whether the size shrinks + * upon checkpointing. Ignores the checkpointData field, which may grow when we checkpoint. + */ + private def getSerializedSizes(rdd: RDD[_]): (Int, Int) = { + val rddSize = Utils.serialize(rdd).size + val rddCpDataSize = Utils.serialize(rdd.checkpointData).size + val rddPartitionSize = Utils.serialize(rdd.partitions).size + val rddDependenciesSize = Utils.serialize(rdd.dependencies).size + + // Print detailed size, helps in debugging + logInfo("Serialized sizes of " + rdd + + ": RDD = " + rddSize + + ", RDD checkpoint data = " + rddCpDataSize + + ", RDD partitions = " + rddPartitionSize + + ", RDD dependencies = " + rddDependenciesSize + ) + // this makes sure that serializing the RDD's checkpoint data does not + // serialize the whole RDD as well + assert( + rddSize > rddCpDataSize, + "RDD's checkpoint data (" + rddCpDataSize + ") is equal or larger than the " + + "whole RDD with checkpoint data (" + rddSize + ")" + ) + (rddSize - rddCpDataSize, rddPartitionSize) + } + + /** + * Serialize and deserialize an object. 
This is useful to verify the objects + * contents after deserialization (e.g., the contents of an RDD split after + * it is sent to a slave along with a task) + */ + protected def serializeDeserialize[T](obj: T): T = { + val bytes = Utils.serialize(obj) + Utils.deserialize[T](bytes) + } + + /** + * Recursively force the initialization of the all members of an RDD and it parents. + */ + private def initializeRdd(rdd: RDD[_]): Unit = { + rdd.partitions // forces the initialization of the partitions + rdd.dependencies.map(_.rdd).foreach(initializeRdd) + } + + /** Checkpoint the RDD either locally or reliably. */ + protected def checkpoint(rdd: RDD[_], reliableCheckpoint: Boolean): Unit = { + if (reliableCheckpoint) { + rdd.checkpoint() + } else { + rdd.localCheckpoint() + } + } + + /** Run a test twice, once for local checkpointing and once for reliable checkpointing. */ + protected def runTest(name: String)(body: Boolean => Unit): Unit = { + test(name + " [reliable checkpoint]")(body(true)) + test(name + " [local checkpoint]")(body(false)) + } + + /** + * Generate an RDD such that both the RDD and its partitions have large size. + */ + protected def generateFatRDD(): RDD[Int] = { + new FatRDD(sparkContext.makeRDD(1 to 100, 4)).map(x => x) + } + + /** + * Generate an pair RDD (with partitioner) such that both the RDD and its partitions + * have large size. + */ + protected def generateFatPairRDD(): RDD[(Int, Int)] = { + new FatPairRDD(sparkContext.makeRDD(1 to 100, 4), partitioner).mapValues(x => x) + } +} + /** * Test suite for end-to-end checkpointing functionality. * This tests both reliable checkpoints and local checkpoints. */ -class CheckpointSuite extends SparkFunSuite with LocalSparkContext with Logging { +class CheckpointSuite extends SparkFunSuite with RDDCheckpointTester with LocalSparkContext { private var checkpointDir: File = _ - private val partitioner = new HashPartitioner(2) override def beforeEach(): Unit = { super.beforeEach() @@ -46,6 +252,8 @@ class CheckpointSuite extends SparkFunSuite with LocalSparkContext with Logging Utils.deleteRecursively(checkpointDir) } + override def sparkContext: SparkContext = sc + runTest("basic checkpointing") { reliableCheckpoint: Boolean => val parCollection = sc.makeRDD(1 to 4) val flatMappedRDD = parCollection.flatMap(x => 1 to x) @@ -250,204 +458,6 @@ class CheckpointSuite extends SparkFunSuite with LocalSparkContext with Logging assert(rdd.isCheckpointedAndMaterialized === true) assert(rdd.partitions.size === 0) } - - // Utility test methods - - /** Checkpoint the RDD either locally or reliably. */ - private def checkpoint(rdd: RDD[_], reliableCheckpoint: Boolean): Unit = { - if (reliableCheckpoint) { - rdd.checkpoint() - } else { - rdd.localCheckpoint() - } - } - - /** Run a test twice, once for local checkpointing and once for reliable checkpointing. */ - private def runTest(name: String)(body: Boolean => Unit): Unit = { - test(name + " [reliable checkpoint]")(body(true)) - test(name + " [local checkpoint]")(body(false)) - } - - private def defaultCollectFunc[T](rdd: RDD[T]): Any = rdd.collect() - - /** - * Test checkpointing of the RDD generated by the given operation. It tests whether the - * serialized size of the RDD is reduce after checkpointing or not. This function should be called - * on all RDDs that have a parent RDD (i.e., do not call on ParallelCollection, BlockRDD, etc.). 
- * - * @param op an operation to run on the RDD - * @param reliableCheckpoint if true, use reliable checkpoints, otherwise use local checkpoints - * @param collectFunc a function for collecting the values in the RDD, in case there are - * non-comparable types like arrays that we want to convert to something that supports == - */ - private def testRDD[U: ClassTag]( - op: (RDD[Int]) => RDD[U], - reliableCheckpoint: Boolean, - collectFunc: RDD[U] => Any = defaultCollectFunc[U] _): Unit = { - // Generate the final RDD using given RDD operation - val baseRDD = generateFatRDD() - val operatedRDD = op(baseRDD) - val parentRDD = operatedRDD.dependencies.headOption.orNull - val rddType = operatedRDD.getClass.getSimpleName - val numPartitions = operatedRDD.partitions.length - - // Force initialization of all the data structures in RDDs - // Without this, serializing the RDD will give a wrong estimate of the size of the RDD - initializeRdd(operatedRDD) - - val partitionsBeforeCheckpoint = operatedRDD.partitions - - // Find serialized sizes before and after the checkpoint - logInfo("RDD after checkpoint: " + operatedRDD + "\n" + operatedRDD.toDebugString) - val (rddSizeBeforeCheckpoint, partitionSizeBeforeCheckpoint) = getSerializedSizes(operatedRDD) - checkpoint(operatedRDD, reliableCheckpoint) - val result = collectFunc(operatedRDD) - operatedRDD.collect() // force re-initialization of post-checkpoint lazy variables - val (rddSizeAfterCheckpoint, partitionSizeAfterCheckpoint) = getSerializedSizes(operatedRDD) - logInfo("RDD after checkpoint: " + operatedRDD + "\n" + operatedRDD.toDebugString) - - // Test whether the checkpoint file has been created - if (reliableCheckpoint) { - assert(collectFunc(sc.checkpointFile[U](operatedRDD.getCheckpointFile.get)) === result) - } - - // Test whether dependencies have been changed from its earlier parent RDD - assert(operatedRDD.dependencies.head.rdd != parentRDD) - - // Test whether the partitions have been changed from its earlier partitions - assert(operatedRDD.partitions.toList != partitionsBeforeCheckpoint.toList) - - // Test whether the partitions have been changed to the new Hadoop partitions - assert(operatedRDD.partitions.toList === operatedRDD.checkpointData.get.getPartitions.toList) - - // Test whether the number of partitions is same as before - assert(operatedRDD.partitions.length === numPartitions) - - // Test whether the data in the checkpointed RDD is same as original - assert(collectFunc(operatedRDD) === result) - - // Test whether serialized size of the RDD has reduced. - logInfo("Size of " + rddType + - " [" + rddSizeBeforeCheckpoint + " --> " + rddSizeAfterCheckpoint + "]") - assert( - rddSizeAfterCheckpoint < rddSizeBeforeCheckpoint, - "Size of " + rddType + " did not reduce after checkpointing " + - " [" + rddSizeBeforeCheckpoint + " --> " + rddSizeAfterCheckpoint + "]" - ) - } - - /** - * Test whether checkpointing of the parent of the generated RDD also - * truncates the lineage or not. Some RDDs like CoGroupedRDD hold on to its parent - * RDDs partitions. So even if the parent RDD is checkpointed and its partitions changed, - * the generated RDD will remember the partitions and therefore potentially the whole lineage. - * This function should be called only those RDD whose partitions refer to parent RDD's - * partitions (i.e., do not call it on simple RDD like MappedRDD). 
- * - * @param op an operation to run on the RDD - * @param reliableCheckpoint if true, use reliable checkpoints, otherwise use local checkpoints - * @param collectFunc a function for collecting the values in the RDD, in case there are - * non-comparable types like arrays that we want to convert to something that supports == - */ - private def testRDDPartitions[U: ClassTag]( - op: (RDD[Int]) => RDD[U], - reliableCheckpoint: Boolean, - collectFunc: RDD[U] => Any = defaultCollectFunc[U] _): Unit = { - // Generate the final RDD using given RDD operation - val baseRDD = generateFatRDD() - val operatedRDD = op(baseRDD) - val parentRDDs = operatedRDD.dependencies.map(_.rdd) - val rddType = operatedRDD.getClass.getSimpleName - - // Force initialization of all the data structures in RDDs - // Without this, serializing the RDD will give a wrong estimate of the size of the RDD - initializeRdd(operatedRDD) - - // Find serialized sizes before and after the checkpoint - logInfo("RDD after checkpoint: " + operatedRDD + "\n" + operatedRDD.toDebugString) - val (rddSizeBeforeCheckpoint, partitionSizeBeforeCheckpoint) = getSerializedSizes(operatedRDD) - // checkpoint the parent RDD, not the generated one - parentRDDs.foreach { rdd => - checkpoint(rdd, reliableCheckpoint) - } - val result = collectFunc(operatedRDD) // force checkpointing - operatedRDD.collect() // force re-initialization of post-checkpoint lazy variables - val (rddSizeAfterCheckpoint, partitionSizeAfterCheckpoint) = getSerializedSizes(operatedRDD) - logInfo("RDD after checkpoint: " + operatedRDD + "\n" + operatedRDD.toDebugString) - - // Test whether the data in the checkpointed RDD is same as original - assert(collectFunc(operatedRDD) === result) - - // Test whether serialized size of the partitions has reduced - logInfo("Size of partitions of " + rddType + - " [" + partitionSizeBeforeCheckpoint + " --> " + partitionSizeAfterCheckpoint + "]") - assert( - partitionSizeAfterCheckpoint < partitionSizeBeforeCheckpoint, - "Size of " + rddType + " partitions did not reduce after checkpointing parent RDDs" + - " [" + partitionSizeBeforeCheckpoint + " --> " + partitionSizeAfterCheckpoint + "]" - ) - } - - /** - * Generate an RDD such that both the RDD and its partitions have large size. - */ - private def generateFatRDD(): RDD[Int] = { - new FatRDD(sc.makeRDD(1 to 100, 4)).map(x => x) - } - - /** - * Generate an pair RDD (with partitioner) such that both the RDD and its partitions - * have large size. - */ - private def generateFatPairRDD(): RDD[(Int, Int)] = { - new FatPairRDD(sc.makeRDD(1 to 100, 4), partitioner).mapValues(x => x) - } - - /** - * Get serialized sizes of the RDD and its partitions, in order to test whether the size shrinks - * upon checkpointing. Ignores the checkpointData field, which may grow when we checkpoint. 
- */ - private def getSerializedSizes(rdd: RDD[_]): (Int, Int) = { - val rddSize = Utils.serialize(rdd).size - val rddCpDataSize = Utils.serialize(rdd.checkpointData).size - val rddPartitionSize = Utils.serialize(rdd.partitions).size - val rddDependenciesSize = Utils.serialize(rdd.dependencies).size - - // Print detailed size, helps in debugging - logInfo("Serialized sizes of " + rdd + - ": RDD = " + rddSize + - ", RDD checkpoint data = " + rddCpDataSize + - ", RDD partitions = " + rddPartitionSize + - ", RDD dependencies = " + rddDependenciesSize - ) - // this makes sure that serializing the RDD's checkpoint data does not - // serialize the whole RDD as well - assert( - rddSize > rddCpDataSize, - "RDD's checkpoint data (" + rddCpDataSize + ") is equal or larger than the " + - "whole RDD with checkpoint data (" + rddSize + ")" - ) - (rddSize - rddCpDataSize, rddPartitionSize) - } - - /** - * Serialize and deserialize an object. This is useful to verify the objects - * contents after deserialization (e.g., the contents of an RDD split after - * it is sent to a slave along with a task) - */ - private def serializeDeserialize[T](obj: T): T = { - val bytes = Utils.serialize(obj) - Utils.deserialize[T](bytes) - } - - /** - * Recursively force the initialization of the all members of an RDD and it parents. - */ - private def initializeRdd(rdd: RDD[_]): Unit = { - rdd.partitions // forces the - rdd.dependencies.map(_.rdd).foreach(initializeRdd) - } - } /** RDD partition that has large serialized size. */ @@ -494,5 +504,4 @@ object CheckpointSuite { part ).asInstanceOf[RDD[(K, Array[Iterable[V]])]] } - } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala index 19ef5a14f8ab..0feb3af1abb0 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala @@ -17,31 +17,40 @@ package org.apache.spark.streaming.rdd +import java.io.File + import scala.collection.mutable.ArrayBuffer import scala.reflect.ClassTag import org.scalatest.BeforeAndAfterAll +import org.apache.spark._ import org.apache.spark.rdd.RDD import org.apache.spark.streaming.util.OpenHashMapBasedStateMap -import org.apache.spark.streaming.{Time, State} -import org.apache.spark.{HashPartitioner, SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.streaming.{State, Time} +import org.apache.spark.util.Utils -class TrackStateRDDSuite extends SparkFunSuite with BeforeAndAfterAll { +class TrackStateRDDSuite extends SparkFunSuite with RDDCheckpointTester with BeforeAndAfterAll { private var sc: SparkContext = null + private var checkpointDir: File = _ override def beforeAll(): Unit = { sc = new SparkContext( new SparkConf().setMaster("local").setAppName("TrackStateRDDSuite")) + checkpointDir = Utils.createTempDir() + sc.setCheckpointDir(checkpointDir.toString) } override def afterAll(): Unit = { if (sc != null) { sc.stop() } + Utils.deleteRecursively(checkpointDir) } + override def sparkContext: SparkContext = sc + test("creation from pair RDD") { val data = Seq((1, "1"), (2, "2"), (3, "3")) val partitioner = new HashPartitioner(10) @@ -278,6 +287,51 @@ class TrackStateRDDSuite extends SparkFunSuite with BeforeAndAfterAll { rdd7, Seq(("k3", 2)), Set()) } + test("checkpointing") { + /** + * This tests whether the TrackStateRDD correctly truncates any references to its parent RDDs - + * the data 
RDD and the parent TrackStateRDD. + */ + def rddCollectFunc(rdd: RDD[TrackStateRDDRecord[Int, Int, Int]]) + : Set[(List[(Int, Int, Long)], List[Int])] = { + rdd.map { record => (record.stateMap.getAll().toList, record.emittedRecords.toList) } + .collect.toSet + } + + /** Generate TrackStateRDD with data RDD having a long lineage */ + def makeStateRDDWithLongLineageDataRDD(longLineageRDD: RDD[Int]) + : TrackStateRDD[Int, Int, Int, Int] = { + TrackStateRDD.createFromPairRDD(longLineageRDD.map { _ -> 1}, partitioner, Time(0)) + } + + testRDD( + makeStateRDDWithLongLineageDataRDD, reliableCheckpoint = true, rddCollectFunc _) + testRDDPartitions( + makeStateRDDWithLongLineageDataRDD, reliableCheckpoint = true, rddCollectFunc _) + + /** Generate TrackStateRDD with parent state RDD having a long lineage */ + def makeStateRDDWithLongLineageParenttateRDD( + longLineageRDD: RDD[Int]): TrackStateRDD[Int, Int, Int, Int] = { + + // Create a TrackStateRDD that has a long lineage using the data RDD with a long lineage + val stateRDDWithLongLineage = makeStateRDDWithLongLineageDataRDD(longLineageRDD) + + // Create a new TrackStateRDD, with the lineage lineage TrackStateRDD as the parent + new TrackStateRDD[Int, Int, Int, Int]( + stateRDDWithLongLineage, + stateRDDWithLongLineage.sparkContext.emptyRDD[(Int, Int)].partitionBy(partitioner), + (time: Time, key: Int, value: Option[Int], state: State[Int]) => None, + Time(10), + None + ) + } + + testRDD( + makeStateRDDWithLongLineageParenttateRDD, reliableCheckpoint = true, rddCollectFunc _) + testRDDPartitions( + makeStateRDDWithLongLineageParenttateRDD, reliableCheckpoint = true, rddCollectFunc _) + } + /** Assert whether the `trackStateByKey` operation generates expected results */ private def assertOperation[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag]( testStateRDD: TrackStateRDD[K, V, S, T], From ee21407747fb00db2f26d1119446ccbb20c19232 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 19 Nov 2015 17:14:10 -0800 Subject: [PATCH 1464/1824] [SPARK-11864][SQL] Improve performance of max/min This PR makes the following optimizations: 1) Greatest/Least already do the null check, so the `If` and `IsNull` wrappers are not necessary. 2) In Greatest/Least, the result should be initialized from the first child (removing one block). 3) For primitive types, the generated greater-than expression is too complicated (`(a > b ? 1 : ((a < b) ? -1 : 0)) > 0`); it should be as simple as `a > b`. Combined, these optimizations could improve the performance of the `ss_max` query by 30%. Author: Davies Liu Closes #9846 from davies/improve_max.
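To make point 3 concrete, a hand-written comparison in plain Scala (illustrative only, not the actual generated Java code) of the old comparator-based test against the direct primitive comparison that `genGreater` now emits:

```scala
// Sketch: the old codegen asked "is the three-way comparison result > 0?",
// while the new codegen compares the primitives directly.
object GenGreaterSketch {
  def oldGreater(a: Long, b: Long): Boolean =
    (if (a > b) 1 else if (a < b) -1 else 0) > 0 // comparator, then sign test

  def newGreater(a: Long, b: Long): Boolean =
    a > b                                        // single primitive comparison

  def main(args: Array[String]): Unit = {
    for ((a, b) <- Seq((3L, 2L), (2L, 3L), (5L, 5L))) {
      assert(oldGreater(a, b) == newGreater(a, b)) // both agree; the new form is just cheaper
    }
  }
}
```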
--- .../catalyst/expressions/aggregate/Max.scala | 5 +-- .../catalyst/expressions/aggregate/Min.scala | 5 +-- .../expressions/codegen/CodeGenerator.scala | 12 ++++++ .../expressions/conditionalExpressions.scala | 38 +++++++++++-------- .../expressions/nullExpressions.scala | 10 +++-- 5 files changed, 45 insertions(+), 25 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala index 61cae44cd0f5..906003188d4f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala @@ -46,13 +46,12 @@ case class Max(child: Expression) extends DeclarativeAggregate { ) override lazy val updateExpressions: Seq[Expression] = Seq( - /* max = */ If(IsNull(child), max, If(IsNull(max), child, Greatest(Seq(max, child)))) + /* max = */ Greatest(Seq(max, child)) ) override lazy val mergeExpressions: Seq[Expression] = { - val greatest = Greatest(Seq(max.left, max.right)) Seq( - /* max = */ If(IsNull(max.right), max.left, If(IsNull(max.left), max.right, greatest)) + /* max = */ Greatest(Seq(max.left, max.right)) ) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala index 242456d9e2e1..39f7afbd081c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala @@ -47,13 +47,12 @@ case class Min(child: Expression) extends DeclarativeAggregate { ) override lazy val updateExpressions: Seq[Expression] = Seq( - /* min = */ If(IsNull(child), min, If(IsNull(min), child, Least(Seq(min, child)))) + /* min = */ Least(Seq(min, child)) ) override lazy val mergeExpressions: Seq[Expression] = { - val least = Least(Seq(min.left, min.right)) Seq( - /* min = */ If(IsNull(min.right), min.left, If(IsNull(min.left), min.right, least)) + /* min = */ Least(Seq(min.left, min.right)) ) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 1718cfbd3533..1b7260cdfe51 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -329,6 +329,18 @@ class CodeGenContext { throw new IllegalArgumentException("cannot generate compare code for un-comparable type") } + /** + * Generates code for greater of two expressions. + * + * @param dataType data type of the expressions + * @param c1 name of the variable of expression 1's output + * @param c2 name of the variable of expression 2's output + */ + def genGreater(dataType: DataType, c1: String, c2: String): String = javaType(dataType) match { + case JAVA_BYTE | JAVA_SHORT | JAVA_INT | JAVA_LONG => s"$c1 > $c2" + case _ => s"(${genComp(dataType, c1, c2)}) > 0" + } + /** * List of java data types that have special accessors and setters in [[InternalRow]]. 
*/ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index 0d4af43978ea..694a2a7c54a9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -348,19 +348,22 @@ case class Least(children: Seq[Expression]) extends Expression { override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val evalChildren = children.map(_.gen(ctx)) - def updateEval(i: Int): String = + val first = evalChildren(0) + val rest = evalChildren.drop(1) + def updateEval(eval: GeneratedExpressionCode): String = s""" - if (!${evalChildren(i).isNull} && (${ev.isNull} || - ${ctx.genComp(dataType, evalChildren(i).value, ev.value)} < 0)) { + ${eval.code} + if (!${eval.isNull} && (${ev.isNull} || + ${ctx.genGreater(dataType, ev.value, eval.value)})) { ${ev.isNull} = false; - ${ev.value} = ${evalChildren(i).value}; + ${ev.value} = ${eval.value}; } """ s""" - ${evalChildren.map(_.code).mkString("\n")} - boolean ${ev.isNull} = true; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - ${children.indices.map(updateEval).mkString("\n")} + ${first.code} + boolean ${ev.isNull} = ${first.isNull}; + ${ctx.javaType(dataType)} ${ev.value} = ${first.value}; + ${rest.map(updateEval).mkString("\n")} """ } } @@ -403,19 +406,22 @@ case class Greatest(children: Seq[Expression]) extends Expression { override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val evalChildren = children.map(_.gen(ctx)) - def updateEval(i: Int): String = + val first = evalChildren(0) + val rest = evalChildren.drop(1) + def updateEval(eval: GeneratedExpressionCode): String = s""" - if (!${evalChildren(i).isNull} && (${ev.isNull} || - ${ctx.genComp(dataType, evalChildren(i).value, ev.value)} > 0)) { + ${eval.code} + if (!${eval.isNull} && (${ev.isNull} || + ${ctx.genGreater(dataType, eval.value, ev.value)})) { ${ev.isNull} = false; - ${ev.value} = ${evalChildren(i).value}; + ${ev.value} = ${eval.value}; } """ s""" - ${evalChildren.map(_.code).mkString("\n")} - boolean ${ev.isNull} = true; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - ${children.indices.map(updateEval).mkString("\n")} + ${first.code} + boolean ${ev.isNull} = ${first.isNull}; + ${ctx.javaType(dataType)} ${ev.value} = ${first.value}; + ${rest.map(updateEval).mkString("\n")} """ } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala index 94deafb75b69..df4747d4e6f7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala @@ -62,11 +62,15 @@ case class Coalesce(children: Seq[Expression]) extends Expression { } override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val first = children(0) + val rest = children.drop(1) + val firstEval = first.gen(ctx) s""" - boolean ${ev.isNull} = true; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + ${firstEval.code} + boolean ${ev.isNull} = ${firstEval.isNull}; + ${ctx.javaType(dataType)} ${ev.value} = 
${firstEval.value}; """ + - children.map { e => + rest.map { e => val eval = e.gen(ctx) s""" if (${ev.isNull}) { From 7ee7d5a3c4ff77d2cee2afce36ff41f6302e6315 Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Thu, 19 Nov 2015 19:46:10 -0800 Subject: [PATCH 1465/1824] [SPARK-11544][SQL][TEST-HADOOP1.0] sqlContext doesn't use PathFilter Apply the user supplied pathfilter while retrieving the files from fs. Author: Dilip Biswal Closes #9830 from dilipbiswal/spark-11544. --- .../apache/spark/sql/sources/interfaces.scala | 25 ++++++++--- .../datasources/json/JsonSuite.scala | 41 ++++++++++++++++++- 2 files changed, 59 insertions(+), 7 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index b3d3bdf50df6..f9465157c936 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -21,7 +21,8 @@ import scala.collection.mutable import scala.util.Try import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FileStatus, FileSystem, Path} +import org.apache.hadoop.fs.{PathFilter, FileStatus, FileSystem, Path} +import org.apache.hadoop.mapred.{JobConf, FileInputFormat} import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext} import org.apache.spark.{Logging, SparkContext} @@ -447,9 +448,15 @@ abstract class HadoopFsRelation private[sql]( val hdfsPath = new Path(path) val fs = hdfsPath.getFileSystem(hadoopConf) val qualified = hdfsPath.makeQualified(fs.getUri, fs.getWorkingDirectory) - logInfo(s"Listing $qualified on driver") - Try(fs.listStatus(qualified)).getOrElse(Array.empty) + // Dummy jobconf to get to the pathFilter defined in configuration + val jobConf = new JobConf(hadoopConf, this.getClass()) + val pathFilter = FileInputFormat.getInputPathFilter(jobConf) + if (pathFilter != null) { + Try(fs.listStatus(qualified, pathFilter)).getOrElse(Array.empty) + } else { + Try(fs.listStatus(qualified)).getOrElse(Array.empty) + } }.filterNot { status => val name = status.getPath.getName name.toLowerCase == "_temporary" || name.startsWith(".") @@ -847,8 +854,16 @@ private[sql] object HadoopFsRelation extends Logging { if (name == "_temporary" || name.startsWith(".")) { Array.empty } else { - val (dirs, files) = fs.listStatus(status.getPath).partition(_.isDir) - files ++ dirs.flatMap(dir => listLeafFiles(fs, dir)) + // Dummy jobconf to get to the pathFilter defined in configuration + val jobConf = new JobConf(fs.getConf, this.getClass()) + val pathFilter = FileInputFormat.getInputPathFilter(jobConf) + if (pathFilter != null) { + val (dirs, files) = fs.listStatus(status.getPath, pathFilter).partition(_.isDir) + files ++ dirs.flatMap(dir => listLeafFiles(fs, dir)) + } else { + val (dirs, files) = fs.listStatus(status.getPath).partition(_.isDir) + files ++ dirs.flatMap(dir => listLeafFiles(fs, dir)) + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 6042b1178aff..ba7718c86463 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -19,19 +19,27 @@ package org.apache.spark.sql.execution.datasources.json import java.io.{File, StringWriter} import java.sql.{Date, Timestamp} +import 
scala.collection.JavaConverters._ import com.fasterxml.jackson.core.JsonFactory -import org.apache.spark.rdd.RDD +import org.apache.commons.io.FileUtils +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{Path, PathFilter} import org.scalactic.Tolerance._ +import org.apache.spark.rdd.RDD import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.execution.datasources.{ResolvedDataSource, LogicalRelation} +import org.apache.spark.sql.execution.datasources.{LogicalRelation, ResolvedDataSource} import org.apache.spark.sql.execution.datasources.json.InferSchema.compatibleType import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.util.Utils +class TestFileFilter extends PathFilter { + override def accept(path: Path): Boolean = path.getParent.getName != "p=2" +} + class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { import testImplicits._ @@ -1390,4 +1398,33 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { ) } } + + test("SPARK-11544 test pathfilter") { + withTempPath { dir => + val path = dir.getCanonicalPath + + val df = sqlContext.range(2) + df.write.json(path + "/p=1") + df.write.json(path + "/p=2") + assert(sqlContext.read.json(path).count() === 4) + + val clonedConf = new Configuration(hadoopConfiguration) + try { + // Setting it twice as the name of the propery has changed between hadoop versions. + hadoopConfiguration.setClass( + "mapred.input.pathFilter.class", + classOf[TestFileFilter], + classOf[PathFilter]) + hadoopConfiguration.setClass( + "mapreduce.input.pathFilter.class", + classOf[TestFileFilter], + classOf[PathFilter]) + assert(sqlContext.read.json(path).count() === 2) + } finally { + // Hadoop 1 doesn't have `Configuration.unset` + hadoopConfiguration.clear() + clonedConf.asScala.foreach(entry => hadoopConfiguration.set(entry.getKey, entry.getValue)) + } + } + } } From 4114ce20fbe820f111e55e891ae3889b0e6e0006 Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Thu, 19 Nov 2015 22:01:02 -0800 Subject: [PATCH 1466/1824] [SPARK-11846] Add save/load for AFTSurvivalRegression and IsotonicRegression https://issues.apache.org/jira/browse/SPARK-11846 mengxr Author: Xusen Yin Closes #9836 from yinxusen/SPARK-11846. 
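A usage sketch of the persistence API this patch wires up (spark-shell style; the paths are placeholders, `training` is assumed to be a DataFrame with the default `label`/`features`/`censor` columns, and `overwrite()` is the standard `MLWriter` helper assumed available — only the `write`/`save`/`load` entry points come from the patch itself):

```scala
import org.apache.spark.ml.regression.{AFTSurvivalRegression, AFTSurvivalRegressionModel}

val aft = new AFTSurvivalRegression()
aft.write.overwrite().save("/tmp/aft-estimator")               // persists only the Params
val sameAft = AFTSurvivalRegression.load("/tmp/aft-estimator") // via DefaultParamsReadable

val model = aft.fit(training)                                  // training: assumed DataFrame
model.write.overwrite().save("/tmp/aft-model")                 // coefficients, intercept, scale
val sameModel = AFTSurvivalRegressionModel.load("/tmp/aft-model")
```

IsotonicRegression and IsotonicRegressionModel follow the same save/load pattern.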
--- .../ml/regression/AFTSurvivalRegression.scala | 78 +++++++++++++++-- .../ml/regression/IsotonicRegression.scala | 83 +++++++++++++++++-- .../AFTSurvivalRegressionSuite.scala | 37 ++++++++- .../regression/IsotonicRegressionSuite.scala | 34 +++++++- 4 files changed, 210 insertions(+), 22 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala index b7d095872ffa..aedfb48058dc 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala @@ -21,20 +21,20 @@ import scala.collection.mutable import breeze.linalg.{DenseVector => BDV} import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS} +import org.apache.hadoop.fs.Path -import org.apache.spark.{SparkException, Logging} -import org.apache.spark.annotation.{Since, Experimental} -import org.apache.spark.ml.{Model, Estimator} +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ -import org.apache.spark.ml.util.{SchemaUtils, Identifiable} -import org.apache.spark.mllib.linalg.{Vector, Vectors, VectorUDT} -import org.apache.spark.mllib.linalg.BLAS +import org.apache.spark.ml.util._ +import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.mllib.linalg.{BLAS, Vector, VectorUDT, Vectors} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{Row, DataFrame} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{DoubleType, StructType} +import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.storage.StorageLevel +import org.apache.spark.{Logging, SparkException} /** * Params for accelerated failure time (AFT) regression. @@ -120,7 +120,8 @@ private[regression] trait AFTSurvivalRegressionParams extends Params @Experimental @Since("1.6.0") class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: String) - extends Estimator[AFTSurvivalRegressionModel] with AFTSurvivalRegressionParams with Logging { + extends Estimator[AFTSurvivalRegressionModel] with AFTSurvivalRegressionParams + with DefaultParamsWritable with Logging { @Since("1.6.0") def this() = this(Identifiable.randomUID("aftSurvReg")) @@ -243,6 +244,13 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S override def copy(extra: ParamMap): AFTSurvivalRegression = defaultCopy(extra) } +@Since("1.6.0") +object AFTSurvivalRegression extends DefaultParamsReadable[AFTSurvivalRegression] { + + @Since("1.6.0") + override def load(path: String): AFTSurvivalRegression = super.load(path) +} + /** * :: Experimental :: * Model produced by [[AFTSurvivalRegression]]. 
@@ -254,7 +262,7 @@ class AFTSurvivalRegressionModel private[ml] ( @Since("1.6.0") val coefficients: Vector, @Since("1.6.0") val intercept: Double, @Since("1.6.0") val scale: Double) - extends Model[AFTSurvivalRegressionModel] with AFTSurvivalRegressionParams { + extends Model[AFTSurvivalRegressionModel] with AFTSurvivalRegressionParams with MLWritable { /** @group setParam */ @Since("1.6.0") @@ -312,6 +320,58 @@ class AFTSurvivalRegressionModel private[ml] ( copyValues(new AFTSurvivalRegressionModel(uid, coefficients, intercept, scale), extra) .setParent(parent) } + + @Since("1.6.0") + override def write: MLWriter = + new AFTSurvivalRegressionModel.AFTSurvivalRegressionModelWriter(this) +} + +@Since("1.6.0") +object AFTSurvivalRegressionModel extends MLReadable[AFTSurvivalRegressionModel] { + + @Since("1.6.0") + override def read: MLReader[AFTSurvivalRegressionModel] = new AFTSurvivalRegressionModelReader + + @Since("1.6.0") + override def load(path: String): AFTSurvivalRegressionModel = super.load(path) + + /** [[MLWriter]] instance for [[AFTSurvivalRegressionModel]] */ + private[AFTSurvivalRegressionModel] class AFTSurvivalRegressionModelWriter ( + instance: AFTSurvivalRegressionModel + ) extends MLWriter with Logging { + + private case class Data(coefficients: Vector, intercept: Double, scale: Double) + + override protected def saveImpl(path: String): Unit = { + // Save metadata and Params + DefaultParamsWriter.saveMetadata(instance, path, sc) + // Save model data: coefficients, intercept, scale + val data = Data(instance.coefficients, instance.intercept, instance.scale) + val dataPath = new Path(path, "data").toString + sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + } + } + + private class AFTSurvivalRegressionModelReader extends MLReader[AFTSurvivalRegressionModel] { + + /** Checked against metadata when loading model */ + private val className = classOf[AFTSurvivalRegressionModel].getName + + override def load(path: String): AFTSurvivalRegressionModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + + val dataPath = new Path(path, "data").toString + val data = sqlContext.read.parquet(dataPath) + .select("coefficients", "intercept", "scale").head() + val coefficients = data.getAs[Vector](0) + val intercept = data.getDouble(1) + val scale = data.getDouble(2) + val model = new AFTSurvivalRegressionModel(metadata.uid, coefficients, intercept, scale) + + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } } /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala index a1fe01b04710..bbb1c7ac0a51 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala @@ -17,18 +17,22 @@ package org.apache.spark.ml.regression +import org.apache.hadoop.fs.Path + import org.apache.spark.Logging import org.apache.spark.annotation.{Experimental, Since} -import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param._ -import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasLabelCol, HasPredictionCol, HasWeightCol} -import org.apache.spark.ml.util.{Identifiable, SchemaUtils} +import org.apache.spark.ml.param.shared._ +import org.apache.spark.ml.regression.IsotonicRegressionModel.IsotonicRegressionModelWriter +import org.apache.spark.ml.util._ +import org.apache.spark.ml.{Estimator, 
Model} import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors} -import org.apache.spark.mllib.regression.{IsotonicRegression => MLlibIsotonicRegression, IsotonicRegressionModel => MLlibIsotonicRegressionModel} +import org.apache.spark.mllib.regression.{IsotonicRegression => MLlibIsotonicRegression} +import org.apache.spark.mllib.regression.{IsotonicRegressionModel => MLlibIsotonicRegressionModel} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.functions.{col, lit, udf} import org.apache.spark.sql.types.{DoubleType, StructType} +import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.storage.StorageLevel /** @@ -127,7 +131,8 @@ private[regression] trait IsotonicRegressionBase extends Params with HasFeatures @Since("1.5.0") @Experimental class IsotonicRegression @Since("1.5.0") (@Since("1.5.0") override val uid: String) - extends Estimator[IsotonicRegressionModel] with IsotonicRegressionBase { + extends Estimator[IsotonicRegressionModel] + with IsotonicRegressionBase with DefaultParamsWritable { @Since("1.5.0") def this() = this(Identifiable.randomUID("isoReg")) @@ -179,6 +184,13 @@ class IsotonicRegression @Since("1.5.0") (@Since("1.5.0") override val uid: Stri } } +@Since("1.6.0") +object IsotonicRegression extends DefaultParamsReadable[IsotonicRegression] { + + @Since("1.6.0") + override def load(path: String): IsotonicRegression = super.load(path) +} + /** * :: Experimental :: * Model fitted by IsotonicRegression. @@ -194,7 +206,7 @@ class IsotonicRegression @Since("1.5.0") (@Since("1.5.0") override val uid: Stri class IsotonicRegressionModel private[ml] ( override val uid: String, private val oldModel: MLlibIsotonicRegressionModel) - extends Model[IsotonicRegressionModel] with IsotonicRegressionBase { + extends Model[IsotonicRegressionModel] with IsotonicRegressionBase with MLWritable { /** @group setParam */ @Since("1.5.0") @@ -240,4 +252,61 @@ class IsotonicRegressionModel private[ml] ( override def transformSchema(schema: StructType): StructType = { validateAndTransformSchema(schema, fitting = false) } + + @Since("1.6.0") + override def write: MLWriter = + new IsotonicRegressionModelWriter(this) +} + +@Since("1.6.0") +object IsotonicRegressionModel extends MLReadable[IsotonicRegressionModel] { + + @Since("1.6.0") + override def read: MLReader[IsotonicRegressionModel] = new IsotonicRegressionModelReader + + @Since("1.6.0") + override def load(path: String): IsotonicRegressionModel = super.load(path) + + /** [[MLWriter]] instance for [[IsotonicRegressionModel]] */ + private[IsotonicRegressionModel] class IsotonicRegressionModelWriter ( + instance: IsotonicRegressionModel + ) extends MLWriter with Logging { + + private case class Data( + boundaries: Array[Double], + predictions: Array[Double], + isotonic: Boolean) + + override protected def saveImpl(path: String): Unit = { + // Save metadata and Params + DefaultParamsWriter.saveMetadata(instance, path, sc) + // Save model data: boundaries, predictions, isotonic + val data = Data( + instance.oldModel.boundaries, instance.oldModel.predictions, instance.oldModel.isotonic) + val dataPath = new Path(path, "data").toString + sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + } + } + + private class IsotonicRegressionModelReader extends MLReader[IsotonicRegressionModel] { + + /** Checked against metadata when loading model */ + private val className = classOf[IsotonicRegressionModel].getName + + override def load(path: String): 
IsotonicRegressionModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + + val dataPath = new Path(path, "data").toString + val data = sqlContext.read.parquet(dataPath) + .select("boundaries", "predictions", "isotonic").head() + val boundaries = data.getAs[Seq[Double]](0).toArray + val predictions = data.getAs[Seq[Double]](1).toArray + val isotonic = data.getBoolean(2) + val model = new IsotonicRegressionModel( + metadata.uid, new MLlibIsotonicRegressionModel(boundaries, predictions, isotonic)) + + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala index 359f31027172..d718ef63b531 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala @@ -21,14 +21,15 @@ import scala.util.Random import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.MLTestingUtils +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.random.{ExponentialGenerator, WeibullGenerator} -import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.{Row, DataFrame} +import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.sql.{DataFrame, Row} -class AFTSurvivalRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { +class AFTSurvivalRegressionSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { @transient var datasetUnivariate: DataFrame = _ @transient var datasetMultivariate: DataFrame = _ @@ -332,4 +333,32 @@ class AFTSurvivalRegressionSuite extends SparkFunSuite with MLlibTestSparkContex assert(prediction ~== model.predict(features) relTol 1E-5) } } + + test("read/write") { + def checkModelData( + model: AFTSurvivalRegressionModel, + model2: AFTSurvivalRegressionModel): Unit = { + assert(model.intercept === model2.intercept) + assert(model.coefficients === model2.coefficients) + assert(model.scale === model2.scale) + } + val aft = new AFTSurvivalRegression() + testEstimatorAndModelReadWrite(aft, datasetMultivariate, + AFTSurvivalRegressionSuite.allParamSettings, checkModelData) + } +} + +object AFTSurvivalRegressionSuite { + + /** + * Mapping from all Params to valid settings which differ from the defaults. + * This is useful for tests which need to exercise all Params, such as save/load. + * This excludes input columns to simplify some tests. 
+ */ + val allParamSettings: Map[String, Any] = Map( + "predictionCol" -> "myPrediction", + "fitIntercept" -> true, + "maxIter" -> 2, + "tol" -> 0.01 + ) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala index 59f4193abc8f..f067c29d27a7 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala @@ -19,12 +19,14 @@ package org.apache.spark.ml.regression import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.MLTestingUtils +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{DataFrame, Row} -class IsotonicRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { +class IsotonicRegressionSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + private def generateIsotonicInput(labels: Seq[Double]): DataFrame = { sqlContext.createDataFrame( labels.zipWithIndex.map { case (label, i) => (label, i.toDouble, 1.0) } @@ -164,4 +166,32 @@ class IsotonicRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { assert(predictions === Array(3.5, 5.0, 5.0, 5.0)) } + + test("read/write") { + val dataset = generateIsotonicInput(Seq(1, 2, 3, 1, 6, 17, 16, 17, 18)) + + def checkModelData(model: IsotonicRegressionModel, model2: IsotonicRegressionModel): Unit = { + assert(model.boundaries === model2.boundaries) + assert(model.predictions === model2.predictions) + assert(model.isotonic === model2.isotonic) + } + + val ir = new IsotonicRegression() + testEstimatorAndModelReadWrite(ir, dataset, IsotonicRegressionSuite.allParamSettings, + checkModelData) + } +} + +object IsotonicRegressionSuite { + + /** + * Mapping from all Params to valid settings which differ from the defaults. + * This is useful for tests which need to exercise all Params, such as save/load. + * This excludes input columns to simplify some tests. + */ + val allParamSettings: Map[String, Any] = Map( + "predictionCol" -> "myPrediction", + "isotonic" -> true, + "featureIndex" -> 0 + ) } From 3b7f056da87a23f3a96f0311b3a947a9b698f38b Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Thu, 19 Nov 2015 22:02:17 -0800 Subject: [PATCH 1467/1824] [SPARK-11829][ML] Add read/write to estimators under ml.feature (II) Add read/write support to the following estimators under spark.ml: * ChiSqSelector * PCA * VectorIndexer * Word2Vec Author: Yanbo Liang Closes #9838 from yanboliang/spark-11829. 
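The new readers and writers follow the same pattern for each of the four feature algorithms; a minimal sketch using PCA (illustrative only, not part of the diff; the input DataFrame and paths are assumed):

    import org.apache.spark.ml.feature.{PCA, PCAModel}

    // featuresDF is an assumed DataFrame with a Vector column named "features"
    val pca = new PCA().setInputCol("features").setOutputCol("pcaFeatures").setK(3)
    val model = pca.fit(featuresDF)
    pca.write.save("/tmp/pca-estimator")    // estimator side: DefaultParamsWritable
    model.write.save("/tmp/pca-model")      // model side: MLWritable, stores k and pc
    val restored = PCAModel.load("/tmp/pca-model")
    assert(restored.pc == model.pc)
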
--- .../spark/ml/feature/ChiSqSelector.scala | 65 ++++++++++++++++-- .../org/apache/spark/ml/feature/PCA.scala | 67 +++++++++++++++++-- .../spark/ml/feature/VectorIndexer.scala | 66 ++++++++++++++++-- .../apache/spark/ml/feature/Word2Vec.scala | 67 +++++++++++++++++-- .../apache/spark/mllib/feature/Word2Vec.scala | 6 +- .../spark/ml/feature/ChiSqSelectorSuite.scala | 22 +++++- .../apache/spark/ml/feature/PCASuite.scala | 26 ++++++- .../spark/ml/feature/VectorIndexerSuite.scala | 22 +++++- .../spark/ml/feature/Word2VecSuite.scala | 30 ++++++++- 9 files changed, 338 insertions(+), 33 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala index 5e4061fba549..dfec03828f4b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala @@ -17,13 +17,14 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.Experimental +import org.apache.hadoop.fs.Path + +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml._ import org.apache.spark.ml.attribute.{AttributeGroup, _} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ -import org.apache.spark.ml.util.Identifiable -import org.apache.spark.ml.util.SchemaUtils +import org.apache.spark.ml.util._ import org.apache.spark.mllib.feature import org.apache.spark.mllib.linalg.{Vector, VectorUDT} import org.apache.spark.mllib.regression.LabeledPoint @@ -60,7 +61,7 @@ private[feature] trait ChiSqSelectorParams extends Params */ @Experimental final class ChiSqSelector(override val uid: String) - extends Estimator[ChiSqSelectorModel] with ChiSqSelectorParams { + extends Estimator[ChiSqSelectorModel] with ChiSqSelectorParams with DefaultParamsWritable { def this() = this(Identifiable.randomUID("chiSqSelector")) @@ -95,6 +96,13 @@ final class ChiSqSelector(override val uid: String) override def copy(extra: ParamMap): ChiSqSelector = defaultCopy(extra) } +@Since("1.6.0") +object ChiSqSelector extends DefaultParamsReadable[ChiSqSelector] { + + @Since("1.6.0") + override def load(path: String): ChiSqSelector = super.load(path) +} + /** * :: Experimental :: * Model fitted by [[ChiSqSelector]]. @@ -103,7 +111,12 @@ final class ChiSqSelector(override val uid: String) final class ChiSqSelectorModel private[ml] ( override val uid: String, private val chiSqSelector: feature.ChiSqSelectorModel) - extends Model[ChiSqSelectorModel] with ChiSqSelectorParams { + extends Model[ChiSqSelectorModel] with ChiSqSelectorParams with MLWritable { + + import ChiSqSelectorModel._ + + /** list of indices to select (filter). 
Must be ordered asc */ + val selectedFeatures: Array[Int] = chiSqSelector.selectedFeatures /** @group setParam */ def setFeaturesCol(value: String): this.type = set(featuresCol, value) @@ -147,4 +160,46 @@ final class ChiSqSelectorModel private[ml] ( val copied = new ChiSqSelectorModel(uid, chiSqSelector) copyValues(copied, extra).setParent(parent) } + + @Since("1.6.0") + override def write: MLWriter = new ChiSqSelectorModelWriter(this) +} + +@Since("1.6.0") +object ChiSqSelectorModel extends MLReadable[ChiSqSelectorModel] { + + private[ChiSqSelectorModel] + class ChiSqSelectorModelWriter(instance: ChiSqSelectorModel) extends MLWriter { + + private case class Data(selectedFeatures: Seq[Int]) + + override protected def saveImpl(path: String): Unit = { + DefaultParamsWriter.saveMetadata(instance, path, sc) + val data = Data(instance.selectedFeatures.toSeq) + val dataPath = new Path(path, "data").toString + sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + } + } + + private class ChiSqSelectorModelReader extends MLReader[ChiSqSelectorModel] { + + private val className = classOf[ChiSqSelectorModel].getName + + override def load(path: String): ChiSqSelectorModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val dataPath = new Path(path, "data").toString + val data = sqlContext.read.parquet(dataPath).select("selectedFeatures").head() + val selectedFeatures = data.getAs[Seq[Int]](0).toArray + val oldModel = new feature.ChiSqSelectorModel(selectedFeatures) + val model = new ChiSqSelectorModel(metadata.uid, oldModel) + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } + + @Since("1.6.0") + override def read: MLReader[ChiSqSelectorModel] = new ChiSqSelectorModelReader + + @Since("1.6.0") + override def load(path: String): ChiSqSelectorModel = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala index 539084704b65..32d7afee6e73 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala @@ -17,13 +17,15 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.Experimental +import org.apache.hadoop.fs.Path + +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml._ import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ -import org.apache.spark.ml.util.Identifiable +import org.apache.spark.ml.util._ import org.apache.spark.mllib.feature -import org.apache.spark.mllib.linalg.{Vector, VectorUDT} +import org.apache.spark.mllib.linalg._ import org.apache.spark.sql._ import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{StructField, StructType} @@ -49,7 +51,8 @@ private[feature] trait PCAParams extends Params with HasInputCol with HasOutputC * PCA trains a model to project vectors to a low-dimensional space using PCA. 
*/ @Experimental -class PCA (override val uid: String) extends Estimator[PCAModel] with PCAParams { +class PCA (override val uid: String) extends Estimator[PCAModel] with PCAParams + with DefaultParamsWritable { def this() = this(Identifiable.randomUID("pca")) @@ -86,6 +89,13 @@ class PCA (override val uid: String) extends Estimator[PCAModel] with PCAParams override def copy(extra: ParamMap): PCA = defaultCopy(extra) } +@Since("1.6.0") +object PCA extends DefaultParamsReadable[PCA] { + + @Since("1.6.0") + override def load(path: String): PCA = super.load(path) +} + /** * :: Experimental :: * Model fitted by [[PCA]]. @@ -94,7 +104,12 @@ class PCA (override val uid: String) extends Estimator[PCAModel] with PCAParams class PCAModel private[ml] ( override val uid: String, pcaModel: feature.PCAModel) - extends Model[PCAModel] with PCAParams { + extends Model[PCAModel] with PCAParams with MLWritable { + + import PCAModel._ + + /** a principal components Matrix. Each column is one principal component. */ + val pc: DenseMatrix = pcaModel.pc /** @group setParam */ def setInputCol(value: String): this.type = set(inputCol, value) @@ -127,4 +142,46 @@ class PCAModel private[ml] ( val copied = new PCAModel(uid, pcaModel) copyValues(copied, extra).setParent(parent) } + + @Since("1.6.0") + override def write: MLWriter = new PCAModelWriter(this) +} + +@Since("1.6.0") +object PCAModel extends MLReadable[PCAModel] { + + private[PCAModel] class PCAModelWriter(instance: PCAModel) extends MLWriter { + + private case class Data(k: Int, pc: DenseMatrix) + + override protected def saveImpl(path: String): Unit = { + DefaultParamsWriter.saveMetadata(instance, path, sc) + val data = Data(instance.getK, instance.pc) + val dataPath = new Path(path, "data").toString + sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + } + } + + private class PCAModelReader extends MLReader[PCAModel] { + + private val className = classOf[PCAModel].getName + + override def load(path: String): PCAModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val dataPath = new Path(path, "data").toString + val Row(k: Int, pc: DenseMatrix) = sqlContext.read.parquet(dataPath) + .select("k", "pc") + .head() + val oldModel = new feature.PCAModel(k, pc) + val model = new PCAModel(metadata.uid, oldModel) + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } + + @Since("1.6.0") + override def read: MLReader[PCAModel] = new PCAModelReader + + @Since("1.6.0") + override def load(path: String): PCAModel = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala index 52e0599e38d8..a637a6f2881d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala @@ -22,12 +22,14 @@ import java.util.{Map => JMap} import scala.collection.JavaConverters._ -import org.apache.spark.annotation.Experimental +import org.apache.hadoop.fs.Path + +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.attribute._ -import org.apache.spark.ml.param.{IntParam, ParamMap, ParamValidators, Params} +import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ -import org.apache.spark.ml.util.{Identifiable, SchemaUtils} +import org.apache.spark.ml.util._ import org.apache.spark.mllib.linalg.{DenseVector, 
SparseVector, Vector, VectorUDT} import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.functions.udf @@ -93,7 +95,7 @@ private[ml] trait VectorIndexerParams extends Params with HasInputCol with HasOu */ @Experimental class VectorIndexer(override val uid: String) extends Estimator[VectorIndexerModel] - with VectorIndexerParams { + with VectorIndexerParams with DefaultParamsWritable { def this() = this(Identifiable.randomUID("vecIdx")) @@ -136,7 +138,11 @@ class VectorIndexer(override val uid: String) extends Estimator[VectorIndexerMod override def copy(extra: ParamMap): VectorIndexer = defaultCopy(extra) } -private object VectorIndexer { +@Since("1.6.0") +object VectorIndexer extends DefaultParamsReadable[VectorIndexer] { + + @Since("1.6.0") + override def load(path: String): VectorIndexer = super.load(path) /** * Helper class for tracking unique values for each feature. @@ -146,7 +152,7 @@ private object VectorIndexer { * @param numFeatures This class fails if it encounters a Vector whose length is not numFeatures. * @param maxCategories This class caps the number of unique values collected at maxCategories. */ - class CategoryStats(private val numFeatures: Int, private val maxCategories: Int) + private class CategoryStats(private val numFeatures: Int, private val maxCategories: Int) extends Serializable { /** featureValueSets[feature index] = set of unique values */ @@ -252,7 +258,9 @@ class VectorIndexerModel private[ml] ( override val uid: String, val numFeatures: Int, val categoryMaps: Map[Int, Map[Double, Int]]) - extends Model[VectorIndexerModel] with VectorIndexerParams { + extends Model[VectorIndexerModel] with VectorIndexerParams with MLWritable { + + import VectorIndexerModel._ /** Java-friendly version of [[categoryMaps]] */ def javaCategoryMaps: JMap[JInt, JMap[JDouble, JInt]] = { @@ -408,4 +416,48 @@ class VectorIndexerModel private[ml] ( val copied = new VectorIndexerModel(uid, numFeatures, categoryMaps) copyValues(copied, extra).setParent(parent) } + + @Since("1.6.0") + override def write: MLWriter = new VectorIndexerModelWriter(this) +} + +@Since("1.6.0") +object VectorIndexerModel extends MLReadable[VectorIndexerModel] { + + private[VectorIndexerModel] + class VectorIndexerModelWriter(instance: VectorIndexerModel) extends MLWriter { + + private case class Data(numFeatures: Int, categoryMaps: Map[Int, Map[Double, Int]]) + + override protected def saveImpl(path: String): Unit = { + DefaultParamsWriter.saveMetadata(instance, path, sc) + val data = Data(instance.numFeatures, instance.categoryMaps) + val dataPath = new Path(path, "data").toString + sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + } + } + + private class VectorIndexerModelReader extends MLReader[VectorIndexerModel] { + + private val className = classOf[VectorIndexerModel].getName + + override def load(path: String): VectorIndexerModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val dataPath = new Path(path, "data").toString + val data = sqlContext.read.parquet(dataPath) + .select("numFeatures", "categoryMaps") + .head() + val numFeatures = data.getAs[Int](0) + val categoryMaps = data.getAs[Map[Int, Map[Double, Int]]](1) + val model = new VectorIndexerModel(metadata.uid, numFeatures, categoryMaps) + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } + + @Since("1.6.0") + override def read: MLReader[VectorIndexerModel] = new VectorIndexerModelReader + + @Since("1.6.0") + override def load(path: String): 
VectorIndexerModel = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala index 708dbeef84db..a8d61b6dea00 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala @@ -17,15 +17,17 @@ package org.apache.spark.ml.feature +import org.apache.hadoop.fs.Path + import org.apache.spark.SparkContext -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ -import org.apache.spark.ml.util.{Identifiable, SchemaUtils} +import org.apache.spark.ml.util._ import org.apache.spark.mllib.feature import org.apache.spark.mllib.linalg.{BLAS, Vector, VectorUDT, Vectors} -import org.apache.spark.sql.{DataFrame, SQLContext} +import org.apache.spark.sql.{DataFrame, Row, SQLContext} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ @@ -90,7 +92,8 @@ private[feature] trait Word2VecBase extends Params * natural language processing or machine learning process. */ @Experimental -final class Word2Vec(override val uid: String) extends Estimator[Word2VecModel] with Word2VecBase { +final class Word2Vec(override val uid: String) extends Estimator[Word2VecModel] with Word2VecBase + with DefaultParamsWritable { def this() = this(Identifiable.randomUID("w2v")) @@ -139,6 +142,13 @@ final class Word2Vec(override val uid: String) extends Estimator[Word2VecModel] override def copy(extra: ParamMap): Word2Vec = defaultCopy(extra) } +@Since("1.6.0") +object Word2Vec extends DefaultParamsReadable[Word2Vec] { + + @Since("1.6.0") + override def load(path: String): Word2Vec = super.load(path) +} + /** * :: Experimental :: * Model fitted by [[Word2Vec]]. 
@@ -147,7 +157,9 @@ final class Word2Vec(override val uid: String) extends Estimator[Word2VecModel] class Word2VecModel private[ml] ( override val uid: String, @transient private val wordVectors: feature.Word2VecModel) - extends Model[Word2VecModel] with Word2VecBase { + extends Model[Word2VecModel] with Word2VecBase with MLWritable { + + import Word2VecModel._ /** * Returns a dataframe with two fields, "word" and "vector", with "word" being a String and @@ -224,4 +236,49 @@ class Word2VecModel private[ml] ( val copied = new Word2VecModel(uid, wordVectors) copyValues(copied, extra).setParent(parent) } + + @Since("1.6.0") + override def write: MLWriter = new Word2VecModelWriter(this) +} + +@Since("1.6.0") +object Word2VecModel extends MLReadable[Word2VecModel] { + + private[Word2VecModel] + class Word2VecModelWriter(instance: Word2VecModel) extends MLWriter { + + private case class Data(wordIndex: Map[String, Int], wordVectors: Seq[Float]) + + override protected def saveImpl(path: String): Unit = { + DefaultParamsWriter.saveMetadata(instance, path, sc) + val data = Data(instance.wordVectors.wordIndex, instance.wordVectors.wordVectors.toSeq) + val dataPath = new Path(path, "data").toString + sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + } + } + + private class Word2VecModelReader extends MLReader[Word2VecModel] { + + private val className = classOf[Word2VecModel].getName + + override def load(path: String): Word2VecModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val dataPath = new Path(path, "data").toString + val data = sqlContext.read.parquet(dataPath) + .select("wordIndex", "wordVectors") + .head() + val wordIndex = data.getAs[Map[String, Int]](0) + val wordVectors = data.getAs[Seq[Float]](1).toArray + val oldModel = new feature.Word2VecModel(wordIndex, wordVectors) + val model = new Word2VecModel(metadata.uid, oldModel) + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } + + @Since("1.6.0") + override def read: MLReader[Word2VecModel] = new Word2VecModelReader + + @Since("1.6.0") + override def load(path: String): Word2VecModel = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index 7ab0d89d23a3..a47f27b0afb1 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -432,9 +432,9 @@ class Word2Vec extends Serializable with Logging { * (i * vectorSize, i * vectorSize + vectorSize) */ @Since("1.1.0") -class Word2VecModel private[mllib] ( - private val wordIndex: Map[String, Int], - private val wordVectors: Array[Float]) extends Serializable with Saveable { +class Word2VecModel private[spark] ( + private[spark] val wordIndex: Map[String, Int], + private[spark] val wordVectors: Array[Float]) extends Serializable with Saveable { private val numWords = wordIndex.size // vectorSize: Dimension of each word's vector. 
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala index e5a42967bd2c..7827db2794cf 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala @@ -18,13 +18,17 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.mllib.feature import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.{Row, SQLContext} -class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext { +class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext + with DefaultReadWriteTest { + test("Test Chi-Square selector") { val sqlContext = SQLContext.getOrCreate(sc) import sqlContext.implicits._ @@ -58,4 +62,20 @@ class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext { assert(vec1 ~== vec2 absTol 1e-1) } } + + test("ChiSqSelector read/write") { + val t = new ChiSqSelector() + .setFeaturesCol("myFeaturesCol") + .setLabelCol("myLabelCol") + .setOutputCol("myOutputCol") + .setNumTopFeatures(2) + testDefaultReadWrite(t) + } + + test("ChiSqSelectorModel read/write") { + val oldModel = new feature.ChiSqSelectorModel(Array(1, 3)) + val instance = new ChiSqSelectorModel("myChiSqSelectorModel", oldModel) + val newInstance = testDefaultReadWrite(instance) + assert(newInstance.selectedFeatures === instance.selectedFeatures) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala index 30c500f87a76..5a21cd20ceed 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala @@ -19,15 +19,15 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.MLTestingUtils +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.linalg.distributed.RowMatrix -import org.apache.spark.mllib.linalg.{Vector, Vectors, DenseMatrix, Matrices} +import org.apache.spark.mllib.linalg._ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.mllib.feature.{PCAModel => OldPCAModel} import org.apache.spark.sql.Row -class PCASuite extends SparkFunSuite with MLlibTestSparkContext { +class PCASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { test("params") { ParamsSuite.checkParams(new PCA) @@ -65,4 +65,24 @@ class PCASuite extends SparkFunSuite with MLlibTestSparkContext { assert(x ~== y absTol 1e-5, "Transformed vector is different with expected vector.") } } + + test("read/write") { + + def checkModelData(model1: PCAModel, model2: PCAModel): Unit = { + assert(model1.pc === model2.pc) + } + val allParams: Map[String, Any] = Map( + "k" -> 3, + "inputCol" -> "features", + "outputCol" -> "pca_features" + ) + val data = Seq( + (0.0, Vectors.sparse(5, Seq((1, 1.0), (3, 7.0)))), + (1.0, Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0)), + (2.0, Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0)) + ) + val df = 
sqlContext.createDataFrame(data).toDF("id", "features") + val pca = new PCA().setK(3) + testEstimatorAndModelReadWrite(pca, df, allParams, checkModelData) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala index 8cb0a2cf14d3..67817fa4baf5 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala @@ -22,13 +22,14 @@ import scala.beans.{BeanInfo, BeanProperty} import org.apache.spark.{Logging, SparkException, SparkFunSuite} import org.apache.spark.ml.attribute._ import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.MLTestingUtils +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.linalg.{SparseVector, Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame -class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext with Logging { +class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext + with DefaultReadWriteTest with Logging { import VectorIndexerSuite.FeatureData @@ -251,6 +252,23 @@ class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext with L } } } + + test("VectorIndexer read/write") { + val t = new VectorIndexer() + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + .setMaxCategories(30) + testDefaultReadWrite(t) + } + + test("VectorIndexerModel read/write") { + val categoryMaps = Map(0 -> Map(0.0 -> 0, 1.0 -> 1), 1 -> Map(0.0 -> 0, 1.0 -> 1, + 2.0 -> 2, 3.0 -> 3), 2 -> Map(0.0 -> 0, -1.0 -> 1, 2.0 -> 2)) + val instance = new VectorIndexerModel("myVectorIndexerModel", 3, categoryMaps) + val newInstance = testDefaultReadWrite(instance) + assert(newInstance.numFeatures === instance.numFeatures) + assert(newInstance.categoryMaps === instance.categoryMaps) + } } private[feature] object VectorIndexerSuite { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala index 23dfdaa9f8fc..a773244cd735 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala @@ -19,14 +19,14 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.MLTestingUtils +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.{Row, SQLContext} import org.apache.spark.mllib.feature.{Word2VecModel => OldWord2VecModel} -class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext { +class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { test("params") { ParamsSuite.checkParams(new Word2Vec) @@ -143,5 +143,31 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext { } } + + test("Word2Vec read/write") { + val t = new Word2Vec() + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + .setMaxIter(2) + .setMinCount(8) + .setNumPartitions(1) + .setSeed(42L) + .setStepSize(0.01) + .setVectorSize(100) + testDefaultReadWrite(t) + 
} + + test("Word2VecModel read/write") { + val word2VecMap = Map( + ("china", Array(0.50f, 0.50f, 0.50f, 0.50f)), + ("japan", Array(0.40f, 0.50f, 0.50f, 0.50f)), + ("taiwan", Array(0.60f, 0.50f, 0.50f, 0.50f)), + ("korea", Array(0.45f, 0.60f, 0.60f, 0.60f)) + ) + val oldModel = new OldWord2VecModel(word2VecMap) + val instance = new Word2VecModel("myWord2VecModel", oldModel) + val newInstance = testDefaultReadWrite(instance) + assert(newInstance.getVectors.collect() === instance.getVectors.collect()) + } } From 7216f405454f6f3557b5b1f72df8f393605faf60 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Thu, 19 Nov 2015 22:14:01 -0800 Subject: [PATCH 1468/1824] [SPARK-11875][ML][PYSPARK] Update doc for PySpark HasCheckpointInterval * Update doc for PySpark ```HasCheckpointInterval``` that users can understand how to disable checkpoint. * Update doc for PySpark ```cacheNodeIds``` of ```DecisionTreeParams``` to notify the relationship between ```cacheNodeIds``` and ```checkpointInterval```. Author: Yanbo Liang Closes #9856 from yanboliang/spark-11875. --- python/pyspark/ml/param/_shared_params_code_gen.py | 6 ++++-- python/pyspark/ml/param/shared.py | 14 +++++++------- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/python/pyspark/ml/param/_shared_params_code_gen.py b/python/pyspark/ml/param/_shared_params_code_gen.py index 070c5db01ae7..0528dc1e3a6b 100644 --- a/python/pyspark/ml/param/_shared_params_code_gen.py +++ b/python/pyspark/ml/param/_shared_params_code_gen.py @@ -118,7 +118,8 @@ def get$Name(self): ("inputCols", "input column names.", None), ("outputCol", "output column name.", "self.uid + '__output'"), ("numFeatures", "number of features.", None), - ("checkpointInterval", "checkpoint interval (>= 1).", None), + ("checkpointInterval", "set checkpoint interval (>= 1) or disable checkpoint (-1). " + + "E.g. 10 means that the cache will get checkpointed every 10 iterations.", None), ("seed", "random seed.", "hash(type(self).__name__)"), ("tol", "the convergence tolerance for iterative algorithms.", None), ("stepSize", "Step size to be used for each iteration of optimization.", None), @@ -157,7 +158,8 @@ def get$Name(self): ("maxMemoryInMB", "Maximum memory in MB allocated to histogram aggregation."), ("cacheNodeIds", "If false, the algorithm will pass trees to executors to match " + "instances with nodes. If true, the algorithm will cache node IDs for each instance. " + - "Caching can speed up training of deeper trees.")] + "Caching can speed up training of deeper trees. Users can set how often should the " + + "cache be checkpointed or disable it by setting checkpointInterval.")] decisionTreeCode = '''class DecisionTreeParams(Params): """ diff --git a/python/pyspark/ml/param/shared.py b/python/pyspark/ml/param/shared.py index 4bdf2a8cc563..4d960801502c 100644 --- a/python/pyspark/ml/param/shared.py +++ b/python/pyspark/ml/param/shared.py @@ -325,16 +325,16 @@ def getNumFeatures(self): class HasCheckpointInterval(Params): """ - Mixin for param checkpointInterval: checkpoint interval (>= 1). + Mixin for param checkpointInterval: set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations. """ # a placeholder to make it appear in the generated doc - checkpointInterval = Param(Params._dummy(), "checkpointInterval", "checkpoint interval (>= 1).") + checkpointInterval = Param(Params._dummy(), "checkpointInterval", "set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 
10 means that the cache will get checkpointed every 10 iterations.") def __init__(self): super(HasCheckpointInterval, self).__init__() - #: param for checkpoint interval (>= 1). - self.checkpointInterval = Param(self, "checkpointInterval", "checkpoint interval (>= 1).") + #: param for set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations. + self.checkpointInterval = Param(self, "checkpointInterval", "set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations.") def setCheckpointInterval(self, value): """ @@ -636,7 +636,7 @@ class DecisionTreeParams(Params): minInstancesPerNode = Param(Params._dummy(), "minInstancesPerNode", "Minimum number of instances each child must have after split. If a split causes the left or right child to have fewer than minInstancesPerNode, the split will be discarded as invalid. Should be >= 1.") minInfoGain = Param(Params._dummy(), "minInfoGain", "Minimum information gain for a split to be considered at a tree node.") maxMemoryInMB = Param(Params._dummy(), "maxMemoryInMB", "Maximum memory in MB allocated to histogram aggregation.") - cacheNodeIds = Param(Params._dummy(), "cacheNodeIds", "If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees.") + cacheNodeIds = Param(Params._dummy(), "cacheNodeIds", "If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees. Users can set how often should the cache be checkpointed or disable it by setting checkpointInterval.") def __init__(self): @@ -651,8 +651,8 @@ def __init__(self): self.minInfoGain = Param(self, "minInfoGain", "Minimum information gain for a split to be considered at a tree node.") #: param for Maximum memory in MB allocated to histogram aggregation. self.maxMemoryInMB = Param(self, "maxMemoryInMB", "Maximum memory in MB allocated to histogram aggregation.") - #: param for If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees. - self.cacheNodeIds = Param(self, "cacheNodeIds", "If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees.") + #: param for If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees. Users can set how often should the cache be checkpointed or disable it by setting checkpointInterval. + self.cacheNodeIds = Param(self, "cacheNodeIds", "If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees. Users can set how often should the cache be checkpointed or disable it by setting checkpointInterval.") def setMaxDepth(self, value): """ From 0fff8eb3e476165461658d4e16682ec64269fdfe Mon Sep 17 00:00:00 2001 From: "Joseph K. 
Bradley" Date: Thu, 19 Nov 2015 23:42:24 -0800 Subject: [PATCH 1469/1824] [SPARK-11869][ML] Clean up TempDirectory properly in ML tests Need to remove parent directory (```className```) rather than just tempDir (```className/random_name```) I tested this with IDFSuite, which has 2 read/write tests, and it fixes the problem. CC: mengxr Can you confirm this is fine? I believe it is since the same ```random_name``` is used for all tests in a suite; we basically have an extra unneeded level of nesting. Author: Joseph K. Bradley Closes #9851 from jkbradley/tempdir-cleanup. --- .../src/test/scala/org/apache/spark/ml/util/TempDirectory.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/TempDirectory.scala b/mllib/src/test/scala/org/apache/spark/ml/util/TempDirectory.scala index 2742026a69c2..c8a0bb16247b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/TempDirectory.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/TempDirectory.scala @@ -35,7 +35,7 @@ trait TempDirectory extends BeforeAndAfterAll { self: Suite => override def beforeAll(): Unit = { super.beforeAll() - _tempDir = Utils.createTempDir(this.getClass.getName) + _tempDir = Utils.createTempDir(namePrefix = this.getClass.getName) } override def afterAll(): Unit = { From 3e1d120cedb4bd9e1595e95d4d531cf61da6684d Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Thu, 19 Nov 2015 23:43:18 -0800 Subject: [PATCH 1470/1824] [SPARK-11867] Add save/load for kmeans and naive bayes https://issues.apache.org/jira/browse/SPARK-11867 Author: Xusen Yin Closes #9849 from yinxusen/SPARK-11867. --- .../spark/ml/classification/NaiveBayes.scala | 68 +++++++++++++++++-- .../apache/spark/ml/clustering/KMeans.scala | 67 ++++++++++++++++-- .../ml/classification/NaiveBayesSuite.scala | 47 +++++++++++-- .../spark/ml/clustering/KMeansSuite.scala | 41 ++++++++--- 4 files changed, 195 insertions(+), 28 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala index a14dcecbaf5b..c512a2cb8bf3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala @@ -17,12 +17,15 @@ package org.apache.spark.ml.classification +import org.apache.hadoop.fs.Path + import org.apache.spark.SparkException -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.PredictorParams import org.apache.spark.ml.param.{DoubleParam, Param, ParamMap, ParamValidators} -import org.apache.spark.ml.util.Identifiable -import org.apache.spark.mllib.classification.{NaiveBayes => OldNaiveBayes, NaiveBayesModel => OldNaiveBayesModel} +import org.apache.spark.ml.util._ +import org.apache.spark.mllib.classification.{NaiveBayes => OldNaiveBayes} +import org.apache.spark.mllib.classification.{NaiveBayesModel => OldNaiveBayesModel} import org.apache.spark.mllib.linalg._ import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.rdd.RDD @@ -72,7 +75,7 @@ private[ml] trait NaiveBayesParams extends PredictorParams { @Experimental class NaiveBayes(override val uid: String) extends ProbabilisticClassifier[Vector, NaiveBayes, NaiveBayesModel] - with NaiveBayesParams { + with NaiveBayesParams with DefaultParamsWritable { def this() = this(Identifiable.randomUID("nb")) @@ -102,6 +105,13 @@ class NaiveBayes(override val uid: 
String) override def copy(extra: ParamMap): NaiveBayes = defaultCopy(extra) } +@Since("1.6.0") +object NaiveBayes extends DefaultParamsReadable[NaiveBayes] { + + @Since("1.6.0") + override def load(path: String): NaiveBayes = super.load(path) +} + /** * :: Experimental :: * Model produced by [[NaiveBayes]] @@ -114,7 +124,8 @@ class NaiveBayesModel private[ml] ( override val uid: String, val pi: Vector, val theta: Matrix) - extends ProbabilisticClassificationModel[Vector, NaiveBayesModel] with NaiveBayesParams { + extends ProbabilisticClassificationModel[Vector, NaiveBayesModel] + with NaiveBayesParams with MLWritable { import OldNaiveBayes.{Bernoulli, Multinomial} @@ -203,12 +214,15 @@ class NaiveBayesModel private[ml] ( s"NaiveBayesModel (uid=$uid) with ${pi.size} classes" } + @Since("1.6.0") + override def write: MLWriter = new NaiveBayesModel.NaiveBayesModelWriter(this) } -private[ml] object NaiveBayesModel { +@Since("1.6.0") +object NaiveBayesModel extends MLReadable[NaiveBayesModel] { /** Convert a model from the old API */ - def fromOld( + private[ml] def fromOld( oldModel: OldNaiveBayesModel, parent: NaiveBayes): NaiveBayesModel = { val uid = if (parent != null) parent.uid else Identifiable.randomUID("nb") @@ -218,4 +232,44 @@ private[ml] object NaiveBayesModel { oldModel.theta.flatten, true) new NaiveBayesModel(uid, pi, theta) } + + @Since("1.6.0") + override def read: MLReader[NaiveBayesModel] = new NaiveBayesModelReader + + @Since("1.6.0") + override def load(path: String): NaiveBayesModel = super.load(path) + + /** [[MLWriter]] instance for [[NaiveBayesModel]] */ + private[NaiveBayesModel] class NaiveBayesModelWriter(instance: NaiveBayesModel) extends MLWriter { + + private case class Data(pi: Vector, theta: Matrix) + + override protected def saveImpl(path: String): Unit = { + // Save metadata and Params + DefaultParamsWriter.saveMetadata(instance, path, sc) + // Save model data: pi, theta + val data = Data(instance.pi, instance.theta) + val dataPath = new Path(path, "data").toString + sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + } + } + + private class NaiveBayesModelReader extends MLReader[NaiveBayesModel] { + + /** Checked against metadata when loading model */ + private val className = classOf[NaiveBayesModel].getName + + override def load(path: String): NaiveBayesModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + + val dataPath = new Path(path, "data").toString + val data = sqlContext.read.parquet(dataPath).select("pi", "theta").head() + val pi = data.getAs[Vector](0) + val theta = data.getAs[Matrix](1) + val model = new NaiveBayesModel(metadata.uid, pi, theta) + + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index 509be6300239..71e968497500 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -17,10 +17,12 @@ package org.apache.spark.ml.clustering -import org.apache.spark.annotation.{Since, Experimental} -import org.apache.spark.ml.param.{Param, Params, IntParam, ParamMap} +import org.apache.hadoop.fs.Path + +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.param.shared._ -import org.apache.spark.ml.util.{Identifiable, SchemaUtils} +import org.apache.spark.ml.param.{IntParam, Param, ParamMap, Params} +import 
org.apache.spark.ml.util._ import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans, KMeansModel => MLlibKMeansModel} import org.apache.spark.mllib.linalg.{Vector, VectorUDT} @@ -28,7 +30,6 @@ import org.apache.spark.sql.functions.{col, udf} import org.apache.spark.sql.types.{IntegerType, StructType} import org.apache.spark.sql.{DataFrame, Row} - /** * Common params for KMeans and KMeansModel */ @@ -94,7 +95,8 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe @Experimental class KMeansModel private[ml] ( @Since("1.5.0") override val uid: String, - private val parentModel: MLlibKMeansModel) extends Model[KMeansModel] with KMeansParams { + private val parentModel: MLlibKMeansModel) + extends Model[KMeansModel] with KMeansParams with MLWritable { @Since("1.5.0") override def copy(extra: ParamMap): KMeansModel = { @@ -129,6 +131,52 @@ class KMeansModel private[ml] ( val data = dataset.select(col($(featuresCol))).map { case Row(point: Vector) => point } parentModel.computeCost(data) } + + @Since("1.6.0") + override def write: MLWriter = new KMeansModel.KMeansModelWriter(this) +} + +@Since("1.6.0") +object KMeansModel extends MLReadable[KMeansModel] { + + @Since("1.6.0") + override def read: MLReader[KMeansModel] = new KMeansModelReader + + @Since("1.6.0") + override def load(path: String): KMeansModel = super.load(path) + + /** [[MLWriter]] instance for [[KMeansModel]] */ + private[KMeansModel] class KMeansModelWriter(instance: KMeansModel) extends MLWriter { + + private case class Data(clusterCenters: Array[Vector]) + + override protected def saveImpl(path: String): Unit = { + // Save metadata and Params + DefaultParamsWriter.saveMetadata(instance, path, sc) + // Save model data: cluster centers + val data = Data(instance.clusterCenters) + val dataPath = new Path(path, "data").toString + sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + } + } + + private class KMeansModelReader extends MLReader[KMeansModel] { + + /** Checked against metadata when loading model */ + private val className = classOf[KMeansModel].getName + + override def load(path: String): KMeansModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + + val dataPath = new Path(path, "data").toString + val data = sqlContext.read.parquet(dataPath).select("clusterCenters").head() + val clusterCenters = data.getAs[Seq[Vector]](0).toArray + val model = new KMeansModel(metadata.uid, new MLlibKMeansModel(clusterCenters)) + + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } } /** @@ -141,7 +189,7 @@ class KMeansModel private[ml] ( @Experimental class KMeans @Since("1.5.0") ( @Since("1.5.0") override val uid: String) - extends Estimator[KMeansModel] with KMeansParams { + extends Estimator[KMeansModel] with KMeansParams with DefaultParamsWritable { setDefault( k -> 2, @@ -210,3 +258,10 @@ class KMeans @Since("1.5.0") ( } } +@Since("1.6.0") +object KMeans extends DefaultParamsReadable[KMeans] { + + @Since("1.6.0") + override def load(path: String): KMeans = super.load(path) +} + diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala index 98bc9511163e..082a6bcd211a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala @@ -21,15 +21,30 @@ import 
breeze.linalg.{Vector => BV} import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.mllib.classification.NaiveBayes.{Multinomial, Bernoulli} +import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.mllib.classification.NaiveBayes.{Bernoulli, Multinomial} +import org.apache.spark.mllib.classification.NaiveBayesSuite._ import org.apache.spark.mllib.linalg._ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -import org.apache.spark.mllib.classification.NaiveBayesSuite._ -import org.apache.spark.sql.DataFrame -import org.apache.spark.sql.Row +import org.apache.spark.sql.{DataFrame, Row} + +class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + + @transient var dataset: DataFrame = _ + + override def beforeAll(): Unit = { + super.beforeAll() + + val pi = Array(0.5, 0.1, 0.4).map(math.log) + val theta = Array( + Array(0.70, 0.10, 0.10, 0.10), // label 0 + Array(0.10, 0.70, 0.10, 0.10), // label 1 + Array(0.10, 0.10, 0.70, 0.10) // label 2 + ).map(_.map(math.log)) -class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext { + dataset = sqlContext.createDataFrame(generateNaiveBayesInput(pi, theta, 100, 42)) + } def validatePrediction(predictionAndLabels: DataFrame): Unit = { val numOfErrorPredictions = predictionAndLabels.collect().count { @@ -161,4 +176,26 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext { .select("features", "probability") validateProbabilities(featureAndProbabilities, model, "bernoulli") } + + test("read/write") { + def checkModelData(model: NaiveBayesModel, model2: NaiveBayesModel): Unit = { + assert(model.pi === model2.pi) + assert(model.theta === model2.theta) + } + val nb = new NaiveBayes() + testEstimatorAndModelReadWrite(nb, dataset, NaiveBayesSuite.allParamSettings, checkModelData) + } +} + +object NaiveBayesSuite { + + /** + * Mapping from all Params to valid settings which differ from the defaults. + * This is useful for tests which need to exercise all Params, such as save/load. + * This excludes input columns to simplify some tests. 
+ */ + val allParamSettings: Map[String, Any] = Map( + "predictionCol" -> "myPrediction", + "smoothing" -> 0.1 + ) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala index c05f90550d16..2724e51f31aa 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.ml.clustering import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans} import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext @@ -25,16 +26,7 @@ import org.apache.spark.sql.{DataFrame, SQLContext} private[clustering] case class TestRow(features: Vector) -object KMeansSuite { - def generateKMeansData(sql: SQLContext, rows: Int, dim: Int, k: Int): DataFrame = { - val sc = sql.sparkContext - val rdd = sc.parallelize(1 to rows).map(i => Vectors.dense(Array.fill(dim)((i % k).toDouble))) - .map(v => new TestRow(v)) - sql.createDataFrame(rdd) - } -} - -class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext { +class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { final val k = 5 @transient var dataset: DataFrame = _ @@ -106,4 +98,33 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext { assert(clusters === Set(0, 1, 2, 3, 4)) assert(model.computeCost(dataset) < 0.1) } + + test("read/write") { + def checkModelData(model: KMeansModel, model2: KMeansModel): Unit = { + assert(model.clusterCenters === model2.clusterCenters) + } + val kmeans = new KMeans() + testEstimatorAndModelReadWrite(kmeans, dataset, KMeansSuite.allParamSettings, checkModelData) + } +} + +object KMeansSuite { + def generateKMeansData(sql: SQLContext, rows: Int, dim: Int, k: Int): DataFrame = { + val sc = sql.sparkContext + val rdd = sc.parallelize(1 to rows).map(i => Vectors.dense(Array.fill(dim)((i % k).toDouble))) + .map(v => new TestRow(v)) + sql.createDataFrame(rdd) + } + + /** + * Mapping from all Params to valid settings which differ from the defaults. + * This is useful for tests which need to exercise all Params, such as save/load. + * This excludes input columns to simplify some tests. + */ + val allParamSettings: Map[String, Any] = Map( + "predictionCol" -> "myPrediction", + "k" -> 3, + "maxIter" -> 2, + "tol" -> 0.01 + ) } From a66142decee48bf5689fb7f4f33646d7bb1ac08d Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Fri, 20 Nov 2015 00:46:29 -0800 Subject: [PATCH 1471/1824] [SPARK-11877] Prevent agg. fallback conf. from leaking across test suites This patch fixes an issue where the `spark.sql.TungstenAggregate.testFallbackStartsAt` SQLConf setting was not properly reset / cleared at the end of `TungstenAggregationQueryWithControlledFallbackSuite`. This ended up causing test failures in HiveCompatibilitySuite in Maven builds by causing spilling to occur way too frequently. This configuration leak was inadvertently introduced during test cleanup in #9618. Author: Josh Rosen Closes #9857 from JoshRosen/clear-fallback-prop-in-test-teardown. 
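The core of the change is scoping the configuration override so it is always undone, even when an assertion fails. A minimal sketch of that pattern follows; the helper name `withScopedConf` and the `runTest()` call in the usage comment are illustrative only (the actual patch reuses the suite's existing `withSQLConf` utility), and the sketch relies only on the public `getConf`/`setConf` methods of `SQLContext`:

```scala
import org.apache.spark.sql.SQLContext

// Sketch only: run `body` with `key` temporarily set to `value`, restoring the previous
// value in a finally block so the setting cannot leak into later test suites.
// The real withSQLConf helper additionally clears keys that had no previous value.
def withScopedConf[T](sqlContext: SQLContext, key: String, value: String)(body: => T): T = {
  val previous = sqlContext.getConf(key, null) // null marks "key was not set before"
  sqlContext.setConf(key, value)
  try {
    body
  } finally {
    if (previous != null) {
      sqlContext.setConf(key, previous)
    }
  }
}

// e.g. withScopedConf(sqlContext, "spark.sql.TungstenAggregate.testFallbackStartsAt", "2") {
//   runTest() // hypothetical test body
// }
```

The diff below applies exactly this idea by wrapping the existing check in `withSQLConf`.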
--- .../execution/AggregationQuerySuite.scala | 44 +++++++++---------- 1 file changed, 21 insertions(+), 23 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala index 6dde79f74d3d..39c0a2a0de04 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala @@ -868,29 +868,27 @@ class TungstenAggregationQueryWithControlledFallbackSuite extends AggregationQue override protected def checkAnswer(actual: => DataFrame, expectedAnswer: Seq[Row]): Unit = { (0 to 2).foreach { fallbackStartsAt => - sqlContext.setConf( - "spark.sql.TungstenAggregate.testFallbackStartsAt", - fallbackStartsAt.toString) - - // Create a new df to make sure its physical operator picks up - // spark.sql.TungstenAggregate.testFallbackStartsAt. - // todo: remove it? - val newActual = DataFrame(sqlContext, actual.logicalPlan) - - QueryTest.checkAnswer(newActual, expectedAnswer) match { - case Some(errorMessage) => - val newErrorMessage = - s""" - |The following aggregation query failed when using TungstenAggregate with - |controlled fallback (it falls back to sort-based aggregation once it has processed - |$fallbackStartsAt input rows). The query is - |${actual.queryExecution} - | - |$errorMessage - """.stripMargin - - fail(newErrorMessage) - case None => + withSQLConf("spark.sql.TungstenAggregate.testFallbackStartsAt" -> fallbackStartsAt.toString) { + // Create a new df to make sure its physical operator picks up + // spark.sql.TungstenAggregate.testFallbackStartsAt. + // todo: remove it? + val newActual = DataFrame(sqlContext, actual.logicalPlan) + + QueryTest.checkAnswer(newActual, expectedAnswer) match { + case Some(errorMessage) => + val newErrorMessage = + s""" + |The following aggregation query failed when using TungstenAggregate with + |controlled fallback (it falls back to sort-based aggregation once it has processed + |$fallbackStartsAt input rows). The query is + |${actual.queryExecution} + | + |$errorMessage + """.stripMargin + + fail(newErrorMessage) + case None => + } } } } From 9ace2e5c8d7fbd360a93bc5fc4eace64a697b44f Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Fri, 20 Nov 2015 09:55:53 -0800 Subject: [PATCH 1472/1824] [SPARK-11852][ML] StandardScaler minor refactor ```withStd``` and ```withMean``` should be params of ```StandardScaler``` and ```StandardScalerModel```. Author: Yanbo Liang Closes #9839 from yanboliang/standardScaler-refactor. --- .../spark/ml/feature/StandardScaler.scala | 60 +++++++++---------- .../ml/feature/StandardScalerSuite.scala | 11 ++-- 2 files changed, 32 insertions(+), 39 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala index 6d545219ebf4..d76a9c6275e6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala @@ -36,20 +36,30 @@ import org.apache.spark.sql.types.{StructField, StructType} private[feature] trait StandardScalerParams extends Params with HasInputCol with HasOutputCol { /** - * Centers the data with mean before scaling. + * Whether to center the data with mean before scaling. * It will build a dense output, so this does not work on sparse input * and will raise an exception. 
* Default: false * @group param */ - val withMean: BooleanParam = new BooleanParam(this, "withMean", "Center data with mean") + val withMean: BooleanParam = new BooleanParam(this, "withMean", + "Whether to center data with mean") + + /** @group getParam */ + def getWithMean: Boolean = $(withMean) /** - * Scales the data to unit standard deviation. + * Whether to scale the data to unit standard deviation. * Default: true * @group param */ - val withStd: BooleanParam = new BooleanParam(this, "withStd", "Scale to unit standard deviation") + val withStd: BooleanParam = new BooleanParam(this, "withStd", + "Whether to scale the data to unit standard deviation") + + /** @group getParam */ + def getWithStd: Boolean = $(withStd) + + setDefault(withMean -> false, withStd -> true) } /** @@ -63,8 +73,6 @@ class StandardScaler(override val uid: String) extends Estimator[StandardScalerM def this() = this(Identifiable.randomUID("stdScal")) - setDefault(withMean -> false, withStd -> true) - /** @group setParam */ def setInputCol(value: String): this.type = set(inputCol, value) @@ -82,7 +90,7 @@ class StandardScaler(override val uid: String) extends Estimator[StandardScalerM val input = dataset.select($(inputCol)).map { case Row(v: Vector) => v } val scaler = new feature.StandardScaler(withMean = $(withMean), withStd = $(withStd)) val scalerModel = scaler.fit(input) - copyValues(new StandardScalerModel(uid, scalerModel).setParent(this)) + copyValues(new StandardScalerModel(uid, scalerModel.std, scalerModel.mean).setParent(this)) } override def transformSchema(schema: StructType): StructType = { @@ -108,29 +116,19 @@ object StandardScaler extends DefaultParamsReadable[StandardScaler] { /** * :: Experimental :: * Model fitted by [[StandardScaler]]. + * + * @param std Standard deviation of the StandardScalerModel + * @param mean Mean of the StandardScalerModel */ @Experimental class StandardScalerModel private[ml] ( override val uid: String, - scaler: feature.StandardScalerModel) + val std: Vector, + val mean: Vector) extends Model[StandardScalerModel] with StandardScalerParams with MLWritable { import StandardScalerModel._ - /** Standard deviation of the StandardScalerModel */ - val std: Vector = scaler.std - - /** Mean of the StandardScalerModel */ - val mean: Vector = scaler.mean - - /** Whether to scale to unit standard deviation. */ - @Since("1.6.0") - def getWithStd: Boolean = scaler.withStd - - /** Whether to center data with mean. 
*/ - @Since("1.6.0") - def getWithMean: Boolean = scaler.withMean - /** @group setParam */ def setInputCol(value: String): this.type = set(inputCol, value) @@ -139,6 +137,7 @@ class StandardScalerModel private[ml] ( override def transform(dataset: DataFrame): DataFrame = { transformSchema(dataset.schema, logging = true) + val scaler = new feature.StandardScalerModel(std, mean, $(withStd), $(withMean)) val scale = udf { scaler.transform _ } dataset.withColumn($(outputCol), scale(col($(inputCol)))) } @@ -154,7 +153,7 @@ class StandardScalerModel private[ml] ( } override def copy(extra: ParamMap): StandardScalerModel = { - val copied = new StandardScalerModel(uid, scaler) + val copied = new StandardScalerModel(uid, std, mean) copyValues(copied, extra).setParent(parent) } @@ -168,11 +167,11 @@ object StandardScalerModel extends MLReadable[StandardScalerModel] { private[StandardScalerModel] class StandardScalerModelWriter(instance: StandardScalerModel) extends MLWriter { - private case class Data(std: Vector, mean: Vector, withStd: Boolean, withMean: Boolean) + private case class Data(std: Vector, mean: Vector) override protected def saveImpl(path: String): Unit = { DefaultParamsWriter.saveMetadata(instance, path, sc) - val data = Data(instance.std, instance.mean, instance.getWithStd, instance.getWithMean) + val data = Data(instance.std, instance.mean) val dataPath = new Path(path, "data").toString sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) } @@ -185,13 +184,10 @@ object StandardScalerModel extends MLReadable[StandardScalerModel] { override def load(path: String): StandardScalerModel = { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val dataPath = new Path(path, "data").toString - val Row(std: Vector, mean: Vector, withStd: Boolean, withMean: Boolean) = - sqlContext.read.parquet(dataPath) - .select("std", "mean", "withStd", "withMean") - .head() - // This is very likely to change in the future because withStd and withMean should be params. 
- val oldModel = new feature.StandardScalerModel(std, mean, withStd, withMean) - val model = new StandardScalerModel(metadata.uid, oldModel) + val Row(std: Vector, mean: Vector) = sqlContext.read.parquet(dataPath) + .select("std", "mean") + .head() + val model = new StandardScalerModel(metadata.uid, std, mean) DefaultParamsReader.getAndSetParams(model, metadata) model } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala index 49a4b2efe0c2..1eae125a524e 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala @@ -70,8 +70,8 @@ class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext test("params") { ParamsSuite.checkParams(new StandardScaler) - val oldModel = new feature.StandardScalerModel(Vectors.dense(1.0), Vectors.dense(2.0)) - ParamsSuite.checkParams(new StandardScalerModel("empty", oldModel)) + ParamsSuite.checkParams(new StandardScalerModel("empty", + Vectors.dense(1.0), Vectors.dense(2.0))) } test("Standardization with default parameter") { @@ -126,13 +126,10 @@ class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext } test("StandardScalerModel read/write") { - val oldModel = new feature.StandardScalerModel( - Vectors.dense(1.0, 2.0), Vectors.dense(3.0, 4.0), false, true) - val instance = new StandardScalerModel("myStandardScalerModel", oldModel) + val instance = new StandardScalerModel("myStandardScalerModel", + Vectors.dense(1.0, 2.0), Vectors.dense(3.0, 4.0)) val newInstance = testDefaultReadWrite(instance) assert(newInstance.std === instance.std) assert(newInstance.mean === instance.mean) - assert(newInstance.getWithStd === instance.getWithStd) - assert(newInstance.getWithMean === instance.getWithMean) } } From e359d5dcf5bd300213054ebeae9fe75c4f7eb9e7 Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Fri, 20 Nov 2015 09:57:09 -0800 Subject: [PATCH 1473/1824] [SPARK-11689][ML] Add user guide and example code for LDA under spark.ml jira: https://issues.apache.org/jira/browse/SPARK-11689 Add simple user guide for LDA under spark.ml and example code under examples/. Use include_example to include example code in the user guide markdown. Check SPARK-11606 for instructions. Author: Yuhao Yang Closes #9722 from hhbyyh/ldaMLExample. --- docs/ml-clustering.md | 30 ++++++ docs/ml-guide.md | 3 +- docs/mllib-guide.md | 1 + .../spark/examples/ml/JavaLDAExample.java | 94 +++++++++++++++++++ .../apache/spark/examples/ml/LDAExample.scala | 77 +++++++++++++++ 5 files changed, 204 insertions(+), 1 deletion(-) create mode 100644 docs/ml-clustering.md create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaLDAExample.java create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/LDAExample.scala diff --git a/docs/ml-clustering.md b/docs/ml-clustering.md new file mode 100644 index 000000000000..1743ef43a6dd --- /dev/null +++ b/docs/ml-clustering.md @@ -0,0 +1,30 @@ +--- +layout: global +title: Clustering - ML +displayTitle: ML - Clustering +--- + +In this section, we introduce the pipeline API for [clustering in mllib](mllib-clustering.html). + +## Latent Dirichlet allocation (LDA) + +`LDA` is implemented as an `Estimator` that supports both `EMLDAOptimizer` and `OnlineLDAOptimizer`, +and generates a `LDAModel` as the base models. 
Expert users may cast a `LDAModel` generated by
+`EMLDAOptimizer` to a `DistributedLDAModel` if needed.
+
+<div class="codetabs">
+<div data-lang="scala" markdown="1">
+
+Refer to the [Scala API docs](api/scala/index.html#org.apache.spark.ml.clustering.LDA) for more details.
+
+{% include_example scala/org/apache/spark/examples/ml/LDAExample.scala %}
+</div>
+
+<div data-lang="java" markdown="1">
+
+Refer to the [Java API docs](api/java/org/apache/spark/ml/clustering/LDA.html) for more details.
+
+{% include_example java/org/apache/spark/examples/ml/JavaLDAExample.java %}
+</div>
+
+</div>
    \ No newline at end of file diff --git a/docs/ml-guide.md b/docs/ml-guide.md index be18a05361a1..6f35b30c3d4d 100644 --- a/docs/ml-guide.md +++ b/docs/ml-guide.md @@ -40,6 +40,7 @@ Also, some algorithms have additional capabilities in the `spark.ml` API; e.g., provide class probabilities, and linear models provide model summaries. * [Feature extraction, transformation, and selection](ml-features.html) +* [Clustering](ml-clustering.html) * [Decision Trees for classification and regression](ml-decision-tree.html) * [Ensembles](ml-ensembles.html) * [Linear methods with elastic net regularization](ml-linear-methods.html) @@ -950,4 +951,4 @@ model.transform(test) {% endhighlight %}
 </div>
 
-</div>
+</div>
    \ No newline at end of file diff --git a/docs/mllib-guide.md b/docs/mllib-guide.md index 91e50ccfecec..54e35fcbb15a 100644 --- a/docs/mllib-guide.md +++ b/docs/mllib-guide.md @@ -69,6 +69,7 @@ We list major functionality from both below, with links to detailed guides. concepts. It also contains sections on using algorithms within the Pipelines API, for example: * [Feature extraction, transformation, and selection](ml-features.html) +* [Clustering](ml-clustering.html) * [Decision trees for classification and regression](ml-decision-tree.html) * [Ensembles](ml-ensembles.html) * [Linear methods with elastic net regularization](ml-linear-methods.html) diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaLDAExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaLDAExample.java new file mode 100644 index 000000000000..b3a7d2eb2978 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaLDAExample.java @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import java.util.regex.Pattern; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.ml.clustering.LDA; +import org.apache.spark.ml.clustering.LDAModel; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.linalg.VectorUDT; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.catalyst.expressions.GenericRow; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +/** + * An example demonstrating LDA + * Run with + *
    + * bin/run-example ml.JavaLDAExample
    + * 
    + */ +public class JavaLDAExample { + + private static class ParseVector implements Function { + private static final Pattern separator = Pattern.compile(" "); + + @Override + public Row call(String line) { + String[] tok = separator.split(line); + double[] point = new double[tok.length]; + for (int i = 0; i < tok.length; ++i) { + point[i] = Double.parseDouble(tok[i]); + } + Vector[] points = {Vectors.dense(point)}; + return new GenericRow(points); + } + } + + public static void main(String[] args) { + + String inputFile = "data/mllib/sample_lda_data.txt"; + + // Parses the arguments + SparkConf conf = new SparkConf().setAppName("JavaLDAExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext sqlContext = new SQLContext(jsc); + + // Loads data + JavaRDD points = jsc.textFile(inputFile).map(new ParseVector()); + StructField[] fields = {new StructField("features", new VectorUDT(), false, Metadata.empty())}; + StructType schema = new StructType(fields); + DataFrame dataset = sqlContext.createDataFrame(points, schema); + + // Trains a LDA model + LDA lda = new LDA() + .setK(10) + .setMaxIter(10); + LDAModel model = lda.fit(dataset); + + System.out.println(model.logLikelihood(dataset)); + System.out.println(model.logPerplexity(dataset)); + + // Shows the result + DataFrame topics = model.describeTopics(3); + topics.show(false); + model.transform(dataset).show(false); + + jsc.stop(); + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/LDAExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/LDAExample.scala new file mode 100644 index 000000000000..419ce3d87a6a --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/LDAExample.scala @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml + +// scalastyle:off println +import org.apache.spark.{SparkContext, SparkConf} +import org.apache.spark.mllib.linalg.{VectorUDT, Vectors} +// $example on$ +import org.apache.spark.ml.clustering.LDA +import org.apache.spark.sql.{Row, SQLContext} +import org.apache.spark.sql.types.{StructField, StructType} +// $example off$ + +/** + * An example demonstrating a LDA of ML pipeline. 
+ * Run with + * {{{ + * bin/run-example ml.LDAExample + * }}} + */ +object LDAExample { + + final val FEATURES_COL = "features" + + def main(args: Array[String]): Unit = { + + val input = "data/mllib/sample_lda_data.txt" + // Creates a Spark context and a SQL context + val conf = new SparkConf().setAppName(s"${this.getClass.getSimpleName}") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + // Loads data + val rowRDD = sc.textFile(input).filter(_.nonEmpty) + .map(_.split(" ").map(_.toDouble)).map(Vectors.dense).map(Row(_)) + val schema = StructType(Array(StructField(FEATURES_COL, new VectorUDT, false))) + val dataset = sqlContext.createDataFrame(rowRDD, schema) + + // Trains a LDA model + val lda = new LDA() + .setK(10) + .setMaxIter(10) + .setFeaturesCol(FEATURES_COL) + val model = lda.fit(dataset) + val transformed = model.transform(dataset) + + val ll = model.logLikelihood(dataset) + val lp = model.logPerplexity(dataset) + + // describeTopics + val topics = model.describeTopics(3) + + // Shows the result + topics.show(false) + transformed.show(false) + + // $example off$ + sc.stop() + } +} +// scalastyle:on println From bef361c589c0a38740232fd8d0a45841e4fc969a Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Fri, 20 Nov 2015 11:20:47 -0800 Subject: [PATCH 1474/1824] [SPARK-11876][SQL] Support printSchema in DataSet API DataSet APIs look great! However, I am lost when doing multiple level joins. For example, ``` val ds1 = Seq(("a", 1), ("b", 2)).toDS().as("a") val ds2 = Seq(("a", 1), ("b", 2)).toDS().as("b") val ds3 = Seq(("a", 1), ("b", 2)).toDS().as("c") ds1.joinWith(ds2, $"a._2" === $"b._2").as("ab").joinWith(ds3, $"ab._1._2" === $"c._2").printSchema() ``` The printed schema is like ``` root |-- _1: struct (nullable = true) | |-- _1: struct (nullable = true) | | |-- _1: string (nullable = true) | | |-- _2: integer (nullable = true) | |-- _2: struct (nullable = true) | | |-- _1: string (nullable = true) | | |-- _2: integer (nullable = true) |-- _2: struct (nullable = true) | |-- _1: string (nullable = true) | |-- _2: integer (nullable = true) ``` Personally, I think we need the printSchema function. Sometimes, I do not know how to specify the column, especially when their data types are mixed. For example, if I want to write the following select for the above multi-level join, I have to know the schema: ``` newDS.select(expr("_1._2._2 + 1").as[Int]).collect() ``` marmbrus rxin cloud-fan Do you have the same feeling? Author: gatorsmile Closes #9855 from gatorsmile/printSchemaDataSet. --- .../src/main/scala/org/apache/spark/sql/DataFrame.scala | 9 --------- .../scala/org/apache/spark/sql/execution/Queryable.scala | 9 +++++++++ 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 98358127e270..7abcecaa2880 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -299,15 +299,6 @@ class DataFrame private[sql]( */ def columns: Array[String] = schema.fields.map(_.name) - /** - * Prints the schema to the console in a nice tree format. - * @group basic - * @since 1.3.0 - */ - // scalastyle:off println - def printSchema(): Unit = println(schema.treeString) - // scalastyle:on println - /** * Returns true if the `collect` and `take` methods can be run locally * (without any Spark executors). 
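With `printSchema` available on both `DataFrame` and `Dataset` through `Queryable`, the workflow described above can be done entirely against the typed API. A short sketch using the same illustrative data as the description (it assumes a running `SparkContext` named `sc`; the select expression is an example path, not taken from the patch):

```scala
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.functions.expr

val sqlContext = new SQLContext(sc)
import sqlContext.implicits._

val ds1 = Seq(("a", 1), ("b", 2)).toDS().as("a")
val ds2 = Seq(("a", 1), ("b", 2)).toDS().as("b")
val joined = ds1.joinWith(ds2, $"a._2" === $"b._2")

// Inspect the nested tuple layout first...
joined.printSchema()

// ...then address the fields by the printed paths in a typed select.
joined.select(expr("_1._2 + _2._2").as[Int]).collect()
```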
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Queryable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Queryable.scala index e86a52c149a2..321e2c783537 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Queryable.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Queryable.scala @@ -37,6 +37,15 @@ private[sql] trait Queryable { } } + /** + * Prints the schema to the console in a nice tree format. + * @group basic + * @since 1.3.0 + */ + // scalastyle:off println + def printSchema(): Unit = println(schema.treeString) + // scalastyle:on println + /** * Prints the plans (logical and physical) to the console for debugging purposes. * @since 1.3.0 From 60bfb113325c71491f8dcf98b6036b0caa2144fe Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 20 Nov 2015 11:43:45 -0800 Subject: [PATCH 1475/1824] [SPARK-11817][SQL] Truncating the fractional seconds to prevent inserting a NULL JIRA: https://issues.apache.org/jira/browse/SPARK-11817 Instead of return None, we should truncate the fractional seconds to prevent inserting NULL. Author: Liang-Chi Hsieh Closes #9834 from viirya/truncate-fractional-sec. --- .../apache/spark/sql/catalyst/util/DateTimeUtils.scala | 5 +++++ .../spark/sql/catalyst/util/DateTimeUtilsSuite.scala | 8 ++++++++ 2 files changed, 13 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index 17a5527f3fb2..2b9388291948 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -327,6 +327,11 @@ object DateTimeUtils { return None } + // Instead of return None, we truncate the fractional seconds to prevent inserting NULL + if (segments(6) > 999999) { + segments(6) = segments(6).toString.take(6).toInt + } + if (segments(3) < 0 || segments(3) > 23 || segments(4) < 0 || segments(4) > 59 || segments(5) < 0 || segments(5) > 59 || segments(6) < 0 || segments(6) > 999999 || segments(7) < 0 || segments(7) > 23 || segments(8) < 0 || segments(8) > 59) { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala index faca128badfd..0ce5a2fb6950 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala @@ -343,6 +343,14 @@ class DateTimeUtilsSuite extends SparkFunSuite { UTF8String.fromString("2015-03-18T12:03.17-0:70")).isEmpty) assert(stringToTimestamp( UTF8String.fromString("2015-03-18T12:03.17-1:0:0")).isEmpty) + + // Truncating the fractional seconds + c = Calendar.getInstance(TimeZone.getTimeZone("GMT+00:00")) + c.set(2015, 2, 18, 12, 3, 17) + c.set(Calendar.MILLISECOND, 0) + assert(stringToTimestamp( + UTF8String.fromString("2015-03-18T12:03:17.123456789+0:00")).get === + c.getTimeInMillis * 1000 + 123456) } test("hours") { From 3b9d2a347f9c796b90852173d84189834e499e25 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 20 Nov 2015 12:04:42 -0800 Subject: [PATCH 1476/1824] [SPARK-11819][SQL] nice error message for missing encoder before this PR, when users try to get an encoder for an un-supported class, they will only get a very simple error message like `Encoder for type xxx is 
not supported`. After this PR, the error message become more friendly, for example: ``` No Encoder found for abc.xyz.NonEncodable - array element class: "abc.xyz.NonEncodable" - field (class: "scala.Array", name: "arrayField") - root class: "abc.xyz.AnotherClass" ``` Author: Wenchen Fan Closes #9810 from cloud-fan/error-message. --- .../spark/sql/catalyst/ScalaReflection.scala | 90 ++++++++++++++----- .../encoders/EncoderErrorMessageSuite.scala | 62 +++++++++++++ 2 files changed, 129 insertions(+), 23 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 33ae700706da..918050b531c0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -63,7 +63,7 @@ object ScalaReflection extends ScalaReflection { case t if t <:< definitions.BooleanTpe => BooleanType case t if t <:< localTypeOf[Array[Byte]] => BinaryType case _ => - val className: String = tpe.erasure.typeSymbol.asClass.fullName + val className = getClassNameFromType(tpe) className match { case "scala.Array" => val TypeRef(_, _, Seq(elementType)) = tpe @@ -320,9 +320,23 @@ object ScalaReflection extends ScalaReflection { } } - /** Returns expressions for extracting all the fields from the given type. */ + /** + * Returns expressions for extracting all the fields from the given type. + * + * If the given type is not supported, i.e. there is no encoder can be built for this type, + * an [[UnsupportedOperationException]] will be thrown with detailed error message to explain + * the type path walked so far and which class we are not supporting. + * There are 4 kinds of type path: + * * the root type: `root class: "abc.xyz.MyClass"` + * * the value type of [[Option]]: `option value class: "abc.xyz.MyClass"` + * * the element type of [[Array]] or [[Seq]]: `array element class: "abc.xyz.MyClass"` + * * the field of [[Product]]: `field (class: "abc.xyz.MyClass", name: "myField")` + */ def extractorsFor[T : TypeTag](inputObject: Expression): CreateNamedStruct = { - extractorFor(inputObject, localTypeOf[T]) match { + val tpe = localTypeOf[T] + val clsName = getClassNameFromType(tpe) + val walkedTypePath = s"""- root class: "${clsName}"""" :: Nil + extractorFor(inputObject, tpe, walkedTypePath) match { case s: CreateNamedStruct => s case other => CreateNamedStruct(expressions.Literal("value") :: other :: Nil) } @@ -331,7 +345,28 @@ object ScalaReflection extends ScalaReflection { /** Helper for extracting internal fields from a case class. */ private def extractorFor( inputObject: Expression, - tpe: `Type`): Expression = ScalaReflectionLock.synchronized { + tpe: `Type`, + walkedTypePath: Seq[String]): Expression = ScalaReflectionLock.synchronized { + + def toCatalystArray(input: Expression, elementType: `Type`): Expression = { + val externalDataType = dataTypeFor(elementType) + val Schema(catalystType, nullable) = silentSchemaFor(elementType) + if (isNativeType(catalystType)) { + NewInstance( + classOf[GenericArrayData], + input :: Nil, + dataType = ArrayType(catalystType, nullable)) + } else { + val clsName = getClassNameFromType(elementType) + val newPath = s"""- array element class: "$clsName"""" +: walkedTypePath + // `MapObjects` will run `extractorFor` lazily, we need to eagerly call `extractorFor` here + // to trigger the type check. 
+ extractorFor(inputObject, elementType, newPath) + + MapObjects(extractorFor(_, elementType, newPath), input, externalDataType) + } + } + if (!inputObject.dataType.isInstanceOf[ObjectType]) { inputObject } else { @@ -378,15 +413,16 @@ object ScalaReflection extends ScalaReflection { // For non-primitives, we can just extract the object from the Option and then recurse. case other => - val className: String = optType.erasure.typeSymbol.asClass.fullName + val className = getClassNameFromType(optType) val classObj = Utils.classForName(className) val optionObjectType = ObjectType(classObj) + val newPath = s"""- option value class: "$className"""" +: walkedTypePath val unwrapped = UnwrapOption(optionObjectType, inputObject) expressions.If( IsNull(unwrapped), - expressions.Literal.create(null, schemaFor(optType).dataType), - extractorFor(unwrapped, optType)) + expressions.Literal.create(null, silentSchemaFor(optType).dataType), + extractorFor(unwrapped, optType, newPath)) } case t if t <:< localTypeOf[Product] => @@ -412,7 +448,10 @@ object ScalaReflection extends ScalaReflection { val fieldName = p.name.toString val fieldType = p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs) val fieldValue = Invoke(inputObject, fieldName, dataTypeFor(fieldType)) - expressions.Literal(fieldName) :: extractorFor(fieldValue, fieldType) :: Nil + val clsName = getClassNameFromType(fieldType) + val newPath = s"""- field (class: "$clsName", name: "$fieldName")""" +: walkedTypePath + + expressions.Literal(fieldName) :: extractorFor(fieldValue, fieldType, newPath) :: Nil }) case t if t <:< localTypeOf[Array[_]] => @@ -500,23 +539,11 @@ object ScalaReflection extends ScalaReflection { Invoke(inputObject, "booleanValue", BooleanType) case other => - throw new UnsupportedOperationException(s"Extractor for type $other is not supported") + throw new UnsupportedOperationException( + s"No Encoder found for $tpe\n" + walkedTypePath.mkString("\n")) } } } - - private def toCatalystArray(input: Expression, elementType: `Type`): Expression = { - val externalDataType = dataTypeFor(elementType) - val Schema(catalystType, nullable) = schemaFor(elementType) - if (isNativeType(catalystType)) { - NewInstance( - classOf[GenericArrayData], - input :: Nil, - dataType = ArrayType(catalystType, nullable)) - } else { - MapObjects(extractorFor(_, elementType), input, externalDataType) - } - } } /** @@ -561,7 +588,7 @@ trait ScalaReflection { /** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */ def schemaFor(tpe: `Type`): Schema = ScalaReflectionLock.synchronized { - val className: String = tpe.erasure.typeSymbol.asClass.fullName + val className = getClassNameFromType(tpe) tpe match { case t if Utils.classIsLoadable(className) && Utils.classForName(className).isAnnotationPresent(classOf[SQLUserDefinedType]) => @@ -637,6 +664,23 @@ trait ScalaReflection { } } + /** + * Returns a catalyst DataType and its nullability for the given Scala Type using reflection. + * + * Unlike `schemaFor`, this method won't throw exception for un-supported type, it will return + * `NullType` silently instead. + */ + private def silentSchemaFor(tpe: `Type`): Schema = try { + schemaFor(tpe) + } catch { + case _: UnsupportedOperationException => Schema(NullType, nullable = true) + } + + /** Returns the full class name for a type. */ + private def getClassNameFromType(tpe: `Type`): String = { + tpe.erasure.typeSymbol.asClass.fullName + } + /** * Returns classes of input parameters of scala function object. 
*/ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderErrorMessageSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderErrorMessageSuite.scala index 0b2a10bb04c1..8c766ef82992 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderErrorMessageSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderErrorMessageSuite.scala @@ -17,9 +17,22 @@ package org.apache.spark.sql.catalyst.encoders +import scala.reflect.ClassTag + import org.apache.spark.SparkFunSuite import org.apache.spark.sql.Encoders +class NonEncodable(i: Int) + +case class ComplexNonEncodable1(name1: NonEncodable) + +case class ComplexNonEncodable2(name2: ComplexNonEncodable1) + +case class ComplexNonEncodable3(name3: Option[NonEncodable]) + +case class ComplexNonEncodable4(name4: Array[NonEncodable]) + +case class ComplexNonEncodable5(name5: Option[Array[NonEncodable]]) class EncoderErrorMessageSuite extends SparkFunSuite { @@ -37,4 +50,53 @@ class EncoderErrorMessageSuite extends SparkFunSuite { intercept[UnsupportedOperationException] { Encoders.javaSerialization[Long] } intercept[UnsupportedOperationException] { Encoders.javaSerialization[Char] } } + + test("nice error message for missing encoder") { + val errorMsg1 = + intercept[UnsupportedOperationException](ExpressionEncoder[ComplexNonEncodable1]).getMessage + assert(errorMsg1.contains( + s"""root class: "${clsName[ComplexNonEncodable1]}"""")) + assert(errorMsg1.contains( + s"""field (class: "${clsName[NonEncodable]}", name: "name1")""")) + + val errorMsg2 = + intercept[UnsupportedOperationException](ExpressionEncoder[ComplexNonEncodable2]).getMessage + assert(errorMsg2.contains( + s"""root class: "${clsName[ComplexNonEncodable2]}"""")) + assert(errorMsg2.contains( + s"""field (class: "${clsName[ComplexNonEncodable1]}", name: "name2")""")) + assert(errorMsg1.contains( + s"""field (class: "${clsName[NonEncodable]}", name: "name1")""")) + + val errorMsg3 = + intercept[UnsupportedOperationException](ExpressionEncoder[ComplexNonEncodable3]).getMessage + assert(errorMsg3.contains( + s"""root class: "${clsName[ComplexNonEncodable3]}"""")) + assert(errorMsg3.contains( + s"""field (class: "scala.Option", name: "name3")""")) + assert(errorMsg3.contains( + s"""option value class: "${clsName[NonEncodable]}"""")) + + val errorMsg4 = + intercept[UnsupportedOperationException](ExpressionEncoder[ComplexNonEncodable4]).getMessage + assert(errorMsg4.contains( + s"""root class: "${clsName[ComplexNonEncodable4]}"""")) + assert(errorMsg4.contains( + s"""field (class: "scala.Array", name: "name4")""")) + assert(errorMsg4.contains( + s"""array element class: "${clsName[NonEncodable]}"""")) + + val errorMsg5 = + intercept[UnsupportedOperationException](ExpressionEncoder[ComplexNonEncodable5]).getMessage + assert(errorMsg5.contains( + s"""root class: "${clsName[ComplexNonEncodable5]}"""")) + assert(errorMsg5.contains( + s"""field (class: "scala.Option", name: "name5")""")) + assert(errorMsg5.contains( + s"""option value class: "scala.Array"""")) + assert(errorMsg5.contains( + s"""array element class: "${clsName[NonEncodable]}"""")) + } + + private def clsName[T : ClassTag]: String = implicitly[ClassTag[T]].runtimeClass.getName } From 652def318e47890bd0a0977dc982cc07f99fb06a Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Fri, 20 Nov 2015 13:17:35 -0800 Subject: [PATCH 1477/1824] [SPARK-11650] Reduce RPC timeouts to speed up slow AkkaUtilsSuite test This patch 
reduces some RPC timeouts in order to speed up the slow "AkkaUtilsSuite.remote fetch ssl on - untrusted server", which used to take two minutes to run. Author: Josh Rosen Closes #9869 from JoshRosen/SPARK-11650. --- core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala index 61601016e005..0af4b6098bb0 100644 --- a/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala @@ -340,10 +340,11 @@ class AkkaUtilsSuite extends SparkFunSuite with LocalSparkContext with ResetSyst new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, conf)) val slaveConf = sparkSSLConfig() + .set("spark.rpc.askTimeout", "5s") + .set("spark.rpc.lookupTimeout", "5s") val securityManagerBad = new SecurityManager(slaveConf) val slaveRpcEnv = RpcEnv.create("spark-slave", hostname, 0, slaveConf, securityManagerBad) - val slaveTracker = new MapOutputTrackerWorker(conf) try { slaveRpcEnv.setupEndpointRef("spark", rpcEnv.address, MapOutputTracker.ENDPOINT_NAME) fail("should receive either ActorNotFound or TimeoutException") From 9ed4ad4265cf9d3135307eb62dae6de0b220fc21 Mon Sep 17 00:00:00 2001 From: Nong Li Date: Fri, 20 Nov 2015 14:19:34 -0800 Subject: [PATCH 1478/1824] [SPARK-11724][SQL] Change casting between int and timestamp to consistently treat int in seconds. Hive has since changed this behavior as well. https://issues.apache.org/jira/browse/HIVE-3454 Author: Nong Li Author: Nong Li Author: Yin Huai Closes #9685 from nongli/spark-11724. --- .../spark/sql/catalyst/expressions/Cast.scala | 6 ++-- .../sql/catalyst/expressions/CastSuite.scala | 16 +++++---- .../apache/spark/sql/DateFunctionsSuite.scala | 3 ++ ...esting-0-237a6af90a857da1efcbe98f6bbbf9d6} | 2 +- ... 
cast #3-0-76ee270337f664b36cacfc6528ac109 | 1 - ...cast #5-0-dbd7bcd167d322d6617b884c02c7f247 | 1 - ...cast #7-0-1d70654217035f8ce5f64344f4c5a80f | 1 - .../sql/hive/execution/HiveQuerySuite.scala | 34 +++++++++++++------ 8 files changed, 39 insertions(+), 25 deletions(-) rename sql/hive/src/test/resources/golden/{constant null testing-0-9a02bc7de09bcabcbd4c91f54a814c20 => constant null testing-0-237a6af90a857da1efcbe98f6bbbf9d6} (52%) delete mode 100644 sql/hive/src/test/resources/golden/timestamp cast #3-0-76ee270337f664b36cacfc6528ac109 delete mode 100644 sql/hive/src/test/resources/golden/timestamp cast #5-0-dbd7bcd167d322d6617b884c02c7f247 delete mode 100644 sql/hive/src/test/resources/golden/timestamp cast #7-0-1d70654217035f8ce5f64344f4c5a80f diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 5564e242b047..533d17ea5c17 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -204,8 +204,8 @@ case class Cast(child: Expression, dataType: DataType) if (d.isNaN || d.isInfinite) null else (d * 1000000L).toLong } - // converting milliseconds to us - private[this] def longToTimestamp(t: Long): Long = t * 1000L + // converting seconds to us + private[this] def longToTimestamp(t: Long): Long = t * 1000000L // converting us to seconds private[this] def timestampToLong(ts: Long): Long = math.floor(ts.toDouble / 1000000L).toLong // converting us to seconds in double @@ -647,7 +647,7 @@ case class Cast(child: Expression, dataType: DataType) private[this] def decimalToTimestampCode(d: String): String = s"($d.toBigDecimal().bigDecimal().multiply(new java.math.BigDecimal(1000000L))).longValue()" - private[this] def longToTimeStampCode(l: String): String = s"$l * 1000L" + private[this] def longToTimeStampCode(l: String): String = s"$l * 1000000L" private[this] def timestampToIntegerCode(ts: String): String = s"java.lang.Math.floor((double) $ts / 1000000L)" private[this] def timestampToDoubleCode(ts: String): String = s"$ts / 1000000.0" diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala index f4db4da7646f..ab77a764483e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala @@ -258,8 +258,8 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { test("cast from int 2") { checkEvaluation(cast(1, LongType), 1.toLong) - checkEvaluation(cast(cast(1000, TimestampType), LongType), 1.toLong) - checkEvaluation(cast(cast(-1200, TimestampType), LongType), -2.toLong) + checkEvaluation(cast(cast(1000, TimestampType), LongType), 1000.toLong) + checkEvaluation(cast(cast(-1200, TimestampType), LongType), -1200.toLong) checkEvaluation(cast(123, DecimalType.USER_DEFAULT), Decimal(123)) checkEvaluation(cast(123, DecimalType(3, 0)), Decimal(123)) @@ -348,14 +348,14 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation( cast(cast(cast(cast(cast(cast("5", ByteType), TimestampType), DecimalType.SYSTEM_DEFAULT), LongType), StringType), ShortType), - 0.toShort) + 5.toShort) checkEvaluation( cast(cast(cast(cast(cast(cast("5", TimestampType), ByteType), 
DecimalType.SYSTEM_DEFAULT), LongType), StringType), ShortType), null) checkEvaluation(cast(cast(cast(cast(cast(cast("5", DecimalType.SYSTEM_DEFAULT), ByteType), TimestampType), LongType), StringType), ShortType), - 0.toShort) + 5.toShort) checkEvaluation(cast("23", DoubleType), 23d) checkEvaluation(cast("23", IntegerType), 23) @@ -479,10 +479,12 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(cast(ts, LongType), 15.toLong) checkEvaluation(cast(ts, FloatType), 15.003f) checkEvaluation(cast(ts, DoubleType), 15.003) - checkEvaluation(cast(cast(tss, ShortType), TimestampType), DateTimeUtils.fromJavaTimestamp(ts)) + checkEvaluation(cast(cast(tss, ShortType), TimestampType), + DateTimeUtils.fromJavaTimestamp(ts) * 1000) checkEvaluation(cast(cast(tss, IntegerType), TimestampType), - DateTimeUtils.fromJavaTimestamp(ts)) - checkEvaluation(cast(cast(tss, LongType), TimestampType), DateTimeUtils.fromJavaTimestamp(ts)) + DateTimeUtils.fromJavaTimestamp(ts) * 1000) + checkEvaluation(cast(cast(tss, LongType), TimestampType), + DateTimeUtils.fromJavaTimestamp(ts) * 1000) checkEvaluation( cast(cast(millis.toFloat / 1000, TimestampType), FloatType), millis.toFloat / 1000) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala index 241cbd011507..a61c3aa48a73 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala @@ -448,6 +448,9 @@ class DateFunctionsSuite extends QueryTest with SharedSQLContext { Row(date1.getTime / 1000L), Row(date2.getTime / 1000L))) checkAnswer(df.selectExpr(s"unix_timestamp(s, '$fmt')"), Seq( Row(ts1.getTime / 1000L), Row(ts2.getTime / 1000L))) + + val now = sql("select unix_timestamp()").collect().head.getLong(0) + checkAnswer(sql(s"select cast ($now as timestamp)"), Row(new java.util.Date(now * 1000))) } test("to_unix_timestamp") { diff --git a/sql/hive/src/test/resources/golden/constant null testing-0-9a02bc7de09bcabcbd4c91f54a814c20 b/sql/hive/src/test/resources/golden/constant null testing-0-237a6af90a857da1efcbe98f6bbbf9d6 similarity index 52% rename from sql/hive/src/test/resources/golden/constant null testing-0-9a02bc7de09bcabcbd4c91f54a814c20 rename to sql/hive/src/test/resources/golden/constant null testing-0-237a6af90a857da1efcbe98f6bbbf9d6 index 7c41615f8c18..a01c2622c68e 100644 --- a/sql/hive/src/test/resources/golden/constant null testing-0-9a02bc7de09bcabcbd4c91f54a814c20 +++ b/sql/hive/src/test/resources/golden/constant null testing-0-237a6af90a857da1efcbe98f6bbbf9d6 @@ -1 +1 @@ -1 NULL 1 NULL 1.0 NULL true NULL 1 NULL 1.0 NULL 1 NULL 1 NULL 1 NULL 1970-01-01 NULL 1969-12-31 16:00:00.001 NULL 1 NULL +1 NULL 1 NULL 1.0 NULL true NULL 1 NULL 1.0 NULL 1 NULL 1 NULL 1 NULL 1970-01-01 NULL NULL 1 NULL diff --git a/sql/hive/src/test/resources/golden/timestamp cast #3-0-76ee270337f664b36cacfc6528ac109 b/sql/hive/src/test/resources/golden/timestamp cast #3-0-76ee270337f664b36cacfc6528ac109 deleted file mode 100644 index d00491fd7e5b..000000000000 --- a/sql/hive/src/test/resources/golden/timestamp cast #3-0-76ee270337f664b36cacfc6528ac109 +++ /dev/null @@ -1 +0,0 @@ -1 diff --git a/sql/hive/src/test/resources/golden/timestamp cast #5-0-dbd7bcd167d322d6617b884c02c7f247 b/sql/hive/src/test/resources/golden/timestamp cast #5-0-dbd7bcd167d322d6617b884c02c7f247 deleted file mode 100644 index 84a31a5a6970..000000000000 --- 
a/sql/hive/src/test/resources/golden/timestamp cast #5-0-dbd7bcd167d322d6617b884c02c7f247 +++ /dev/null @@ -1 +0,0 @@ --0.001 diff --git a/sql/hive/src/test/resources/golden/timestamp cast #7-0-1d70654217035f8ce5f64344f4c5a80f b/sql/hive/src/test/resources/golden/timestamp cast #7-0-1d70654217035f8ce5f64344f4c5a80f deleted file mode 100644 index 3fbedf693b51..000000000000 --- a/sql/hive/src/test/resources/golden/timestamp cast #7-0-1d70654217035f8ce5f64344f4c5a80f +++ /dev/null @@ -1 +0,0 @@ --2 diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index f0a7a6cc7a1e..8a5acaf3e10b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.hive.execution import java.io.File +import java.sql.Timestamp import java.util.{Locale, TimeZone} import scala.util.Try @@ -248,12 +249,17 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { |IF(TRUE, CAST(NULL AS BINARY), CAST("1" AS BINARY)) AS COL18, |IF(FALSE, CAST(NULL AS DATE), CAST("1970-01-01" AS DATE)) AS COL19, |IF(TRUE, CAST(NULL AS DATE), CAST("1970-01-01" AS DATE)) AS COL20, - |IF(FALSE, CAST(NULL AS TIMESTAMP), CAST(1 AS TIMESTAMP)) AS COL21, - |IF(TRUE, CAST(NULL AS TIMESTAMP), CAST(1 AS TIMESTAMP)) AS COL22, - |IF(FALSE, CAST(NULL AS DECIMAL), CAST(1 AS DECIMAL)) AS COL23, - |IF(TRUE, CAST(NULL AS DECIMAL), CAST(1 AS DECIMAL)) AS COL24 + |IF(TRUE, CAST(NULL AS TIMESTAMP), CAST(1 AS TIMESTAMP)) AS COL21, + |IF(FALSE, CAST(NULL AS DECIMAL), CAST(1 AS DECIMAL)) AS COL22, + |IF(TRUE, CAST(NULL AS DECIMAL), CAST(1 AS DECIMAL)) AS COL23 |FROM src LIMIT 1""".stripMargin) + test("constant null testing timestamp") { + val r1 = sql("SELECT IF(FALSE, CAST(NULL AS TIMESTAMP), CAST(1 AS TIMESTAMP)) AS COL20") + .collect().head + assert(new Timestamp(1000) == r1.getTimestamp(0)) + } + createQueryTest("constant array", """ |SELECT sort_array( @@ -603,26 +609,32 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { // Jdk version leads to different query output for double, so not use createQueryTest here test("timestamp cast #1") { val res = sql("SELECT CAST(CAST(1 AS TIMESTAMP) AS DOUBLE) FROM src LIMIT 1").collect().head - assert(0.001 == res.getDouble(0)) + assert(1 == res.getDouble(0)) } createQueryTest("timestamp cast #2", "SELECT CAST(CAST(1.2 AS TIMESTAMP) AS DOUBLE) FROM src LIMIT 1") - createQueryTest("timestamp cast #3", - "SELECT CAST(CAST(1200 AS TIMESTAMP) AS INT) FROM src LIMIT 1") + test("timestamp cast #3") { + val res = sql("SELECT CAST(CAST(1200 AS TIMESTAMP) AS INT) FROM src LIMIT 1").collect().head + assert(1200 == res.getInt(0)) + } createQueryTest("timestamp cast #4", "SELECT CAST(CAST(1.2 AS TIMESTAMP) AS DOUBLE) FROM src LIMIT 1") - createQueryTest("timestamp cast #5", - "SELECT CAST(CAST(-1 AS TIMESTAMP) AS DOUBLE) FROM src LIMIT 1") + test("timestamp cast #5") { + val res = sql("SELECT CAST(CAST(-1 AS TIMESTAMP) AS DOUBLE) FROM src LIMIT 1").collect().head + assert(-1 == res.get(0)) + } createQueryTest("timestamp cast #6", "SELECT CAST(CAST(-1.2 AS TIMESTAMP) AS DOUBLE) FROM src LIMIT 1") - createQueryTest("timestamp cast #7", - "SELECT CAST(CAST(-1200 AS TIMESTAMP) AS INT) FROM src LIMIT 1") + test("timestamp cast #7") { + val res = sql("SELECT CAST(CAST(-1200 AS TIMESTAMP) AS INT) FROM src LIMIT 
1").collect().head + assert(-1200 == res.getInt(0)) + } createQueryTest("timestamp cast #8", "SELECT CAST(CAST(-1.2 AS TIMESTAMP) AS DOUBLE) FROM src LIMIT 1") From be7a2cfd978143f6f265eca63e9e24f755bc9f22 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Fri, 20 Nov 2015 14:23:01 -0800 Subject: [PATCH 1479/1824] [SPARK-11870][STREAMING][PYSPARK] Rethrow the exceptions in TransformFunction and TransformFunctionSerializer TransformFunction and TransformFunctionSerializer don't rethrow the exception, so when any exception happens, it just return None. This will cause some weird NPE and confuse people. Author: Shixiong Zhu Closes #9847 from zsxwing/pyspark-streaming-exception. --- python/pyspark/streaming/tests.py | 16 ++++++++++++++++ python/pyspark/streaming/util.py | 3 +++ 2 files changed, 19 insertions(+) diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 3403f6d20d78..a0e0267cafa5 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -403,6 +403,22 @@ def func(dstream): expected = [[('k', v)] for v in expected] self._test_func(input, func, expected) + def test_failed_func(self): + input = [self.sc.parallelize([d], 1) for d in range(4)] + input_stream = self.ssc.queueStream(input) + + def failed_func(i): + raise ValueError("failed") + + input_stream.map(failed_func).pprint() + self.ssc.start() + try: + self.ssc.awaitTerminationOrTimeout(10) + except: + return + + self.fail("a failed func should throw an error") + class StreamingListenerTests(PySparkStreamingTestCase): diff --git a/python/pyspark/streaming/util.py b/python/pyspark/streaming/util.py index b20613b1283b..767c732eb90b 100644 --- a/python/pyspark/streaming/util.py +++ b/python/pyspark/streaming/util.py @@ -64,6 +64,7 @@ def call(self, milliseconds, jrdds): return r._jrdd except Exception: traceback.print_exc() + raise def __repr__(self): return "TransformFunction(%s)" % self.func @@ -95,6 +96,7 @@ def dumps(self, id): return bytearray(self.serializer.dumps((func.func, func.deserializers))) except Exception: traceback.print_exc() + raise def loads(self, data): try: @@ -102,6 +104,7 @@ def loads(self, data): return TransformFunction(self.ctx, f, *deserializers) except Exception: traceback.print_exc() + raise def __repr__(self): return "TransformFunctionSerializer(%s)" % self.serializer From 89fd9bd06160fa89dedbf685bfe159ffe4a06ec6 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Fri, 20 Nov 2015 14:31:26 -0800 Subject: [PATCH 1480/1824] [SPARK-11887] Close PersistenceEngine at the end of PersistenceEngineSuite tests In PersistenceEngineSuite, we do not call `close()` on the PersistenceEngine at the end of the test. For the ZooKeeperPersistenceEngine, this causes us to leak a ZooKeeper client, causing the logs of unrelated tests to be periodically spammed with connection error messages from that client: ``` 15/11/20 05:13:35.789 pool-1-thread-1-ScalaTest-running-PersistenceEngineSuite-SendThread(localhost:15741) INFO ClientCnxn: Opening socket connection to server localhost/127.0.0.1:15741. 
Will not attempt to authenticate using SASL (unknown error) 15/11/20 05:13:35.790 pool-1-thread-1-ScalaTest-running-PersistenceEngineSuite-SendThread(localhost:15741) WARN ClientCnxn: Session 0x15124ff48dd0000 for server null, unexpected error, closing socket connection and attempting reconnect java.net.ConnectException: Connection refused at sun.nio.ch.SocketChannelImpl.checkConnect(Native Method) at sun.nio.ch.SocketChannelImpl.finishConnect(SocketChannelImpl.java:739) at org.apache.zookeeper.ClientCnxnSocketNIO.doTransport(ClientCnxnSocketNIO.java:350) at org.apache.zookeeper.ClientCnxn$SendThread.run(ClientCnxn.java:1068) ``` This patch fixes this by using a `finally` block. Author: Josh Rosen Closes #9864 from JoshRosen/close-zookeeper-client-in-tests. --- .../master/PersistenceEngineSuite.scala | 100 +++++++++--------- 1 file changed, 52 insertions(+), 48 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/deploy/master/PersistenceEngineSuite.scala b/core/src/test/scala/org/apache/spark/deploy/master/PersistenceEngineSuite.scala index 34775577de8a..7a4472867568 100644 --- a/core/src/test/scala/org/apache/spark/deploy/master/PersistenceEngineSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/master/PersistenceEngineSuite.scala @@ -63,56 +63,60 @@ class PersistenceEngineSuite extends SparkFunSuite { conf: SparkConf, persistenceEngineCreator: Serializer => PersistenceEngine): Unit = { val serializer = new JavaSerializer(conf) val persistenceEngine = persistenceEngineCreator(serializer) - persistenceEngine.persist("test_1", "test_1_value") - assert(Seq("test_1_value") === persistenceEngine.read[String]("test_")) - persistenceEngine.persist("test_2", "test_2_value") - assert(Set("test_1_value", "test_2_value") === persistenceEngine.read[String]("test_").toSet) - persistenceEngine.unpersist("test_1") - assert(Seq("test_2_value") === persistenceEngine.read[String]("test_")) - persistenceEngine.unpersist("test_2") - assert(persistenceEngine.read[String]("test_").isEmpty) - - // Test deserializing objects that contain RpcEndpointRef - val testRpcEnv = RpcEnv.create("test", "localhost", 12345, conf, new SecurityManager(conf)) try { - // Create a real endpoint so that we can test RpcEndpointRef deserialization - val workerEndpoint = testRpcEnv.setupEndpoint("worker", new RpcEndpoint { - override val rpcEnv: RpcEnv = testRpcEnv - }) - - val workerToPersist = new WorkerInfo( - id = "test_worker", - host = "127.0.0.1", - port = 10000, - cores = 0, - memory = 0, - endpoint = workerEndpoint, - webUiPort = 0, - publicAddress = "" - ) - - persistenceEngine.addWorker(workerToPersist) - - val (storedApps, storedDrivers, storedWorkers) = - persistenceEngine.readPersistedData(testRpcEnv) - - assert(storedApps.isEmpty) - assert(storedDrivers.isEmpty) - - // Check deserializing WorkerInfo - assert(storedWorkers.size == 1) - val recoveryWorkerInfo = storedWorkers.head - assert(workerToPersist.id === recoveryWorkerInfo.id) - assert(workerToPersist.host === recoveryWorkerInfo.host) - assert(workerToPersist.port === recoveryWorkerInfo.port) - assert(workerToPersist.cores === recoveryWorkerInfo.cores) - assert(workerToPersist.memory === recoveryWorkerInfo.memory) - assert(workerToPersist.endpoint === recoveryWorkerInfo.endpoint) - assert(workerToPersist.webUiPort === recoveryWorkerInfo.webUiPort) - assert(workerToPersist.publicAddress === recoveryWorkerInfo.publicAddress) + persistenceEngine.persist("test_1", "test_1_value") + assert(Seq("test_1_value") === 
persistenceEngine.read[String]("test_")) + persistenceEngine.persist("test_2", "test_2_value") + assert(Set("test_1_value", "test_2_value") === persistenceEngine.read[String]("test_").toSet) + persistenceEngine.unpersist("test_1") + assert(Seq("test_2_value") === persistenceEngine.read[String]("test_")) + persistenceEngine.unpersist("test_2") + assert(persistenceEngine.read[String]("test_").isEmpty) + + // Test deserializing objects that contain RpcEndpointRef + val testRpcEnv = RpcEnv.create("test", "localhost", 12345, conf, new SecurityManager(conf)) + try { + // Create a real endpoint so that we can test RpcEndpointRef deserialization + val workerEndpoint = testRpcEnv.setupEndpoint("worker", new RpcEndpoint { + override val rpcEnv: RpcEnv = testRpcEnv + }) + + val workerToPersist = new WorkerInfo( + id = "test_worker", + host = "127.0.0.1", + port = 10000, + cores = 0, + memory = 0, + endpoint = workerEndpoint, + webUiPort = 0, + publicAddress = "" + ) + + persistenceEngine.addWorker(workerToPersist) + + val (storedApps, storedDrivers, storedWorkers) = + persistenceEngine.readPersistedData(testRpcEnv) + + assert(storedApps.isEmpty) + assert(storedDrivers.isEmpty) + + // Check deserializing WorkerInfo + assert(storedWorkers.size == 1) + val recoveryWorkerInfo = storedWorkers.head + assert(workerToPersist.id === recoveryWorkerInfo.id) + assert(workerToPersist.host === recoveryWorkerInfo.host) + assert(workerToPersist.port === recoveryWorkerInfo.port) + assert(workerToPersist.cores === recoveryWorkerInfo.cores) + assert(workerToPersist.memory === recoveryWorkerInfo.memory) + assert(workerToPersist.endpoint === recoveryWorkerInfo.endpoint) + assert(workerToPersist.webUiPort === recoveryWorkerInfo.webUiPort) + assert(workerToPersist.publicAddress === recoveryWorkerInfo.publicAddress) + } finally { + testRpcEnv.shutdown() + testRpcEnv.awaitTermination() + } } finally { - testRpcEnv.shutdown() - testRpcEnv.awaitTermination() + persistenceEngine.close() } } From 03ba56d78f50747710d01c27d409ba2be42ae557 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jean-Baptiste=20Onofr=C3=A9?= Date: Fri, 20 Nov 2015 14:45:40 -0800 Subject: [PATCH 1481/1824] [SPARK-11716][SQL] UDFRegistration just drops the input type when re-creating the UserDefinedFunction MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit https://issues.apache.org/jira/browse/SPARK-11716 This is one is #9739 and a regression test. When commit it, please make sure the author is jbonofre. You can find the original PR at https://github.com/apache/spark/pull/9739 closes #9739 Author: Jean-Baptiste Onofré Author: Yin Huai Closes #9868 from yhuai/SPARK-11716. 
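The user-visible symptom is easiest to see outside the diff itself. The short Scala sketch below mirrors the regression test added in UDFSuite.scala further down; it assumes an already-running SparkContext named `sc` and the 1.6-era SQLContext API, and the DataFrame contents and column names are illustrative rather than part of this change.

```
import org.apache.spark.sql.SQLContext

// Sketch only: assumes an existing SparkContext `sc`.
val sqlContext = new SQLContext(sc)
import sqlContext.implicits._

// Two integer columns, mirroring testData2 in the regression test below.
val df = Seq((1, 1), (1, 2), (2, 1)).toDF("a", "b")

// register(...) also returns a UserDefinedFunction for use in the DataFrame DSL.
val myUDF = sqlContext.udf.register("testDataFunc", (n: Int, s: String) => (n, s.toInt))

// The function expects (Int, String), so column `b` has to be cast from Int to String.
// Without this fix the returned UserDefinedFunction carried no input types, so the
// analyzer could not insert that cast and the expression failed to resolve.
df.select(myUDF($"a", $"b").as("t")).show()
```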
--- .../apache/spark/sql/UDFRegistration.scala | 48 +++++++++---------- .../scala/org/apache/spark/sql/UDFSuite.scala | 15 ++++++ 2 files changed, 39 insertions(+), 24 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala index fc4d0938c533..051694c0d43a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala @@ -88,7 +88,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { val inputTypes = Try($inputTypes).getOrElse(Nil) def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + UserDefinedFunction(func, dataType, inputTypes) }""") } @@ -120,7 +120,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { val inputTypes = Try(Nil).getOrElse(Nil) def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -133,7 +133,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: Nil).getOrElse(Nil) def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -146,7 +146,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: Nil).getOrElse(Nil) def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -159,7 +159,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: Nil).getOrElse(Nil) def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -172,7 +172,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: Nil).getOrElse(Nil) def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -185,7 +185,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: Nil).getOrElse(Nil) def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) - 
UserDefinedFunction(func, dataType) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -198,7 +198,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: Nil).getOrElse(Nil) def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -211,7 +211,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: Nil).getOrElse(Nil) def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -224,7 +224,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: Nil).getOrElse(Nil) def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -237,7 +237,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: Nil).getOrElse(Nil) def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -250,7 +250,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: Nil).getOrElse(Nil) def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + UserDefinedFunction(func, dataType, inputTypes) } /** 
@@ -263,7 +263,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: Nil).getOrElse(Nil) def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -276,7 +276,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: Nil).getOrElse(Nil) def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -289,7 +289,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: Nil).getOrElse(Nil) def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -302,7 +302,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: Nil).getOrElse(Nil) def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -315,7 +315,7 
@@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: Nil).getOrElse(Nil) def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -328,7 +328,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: Nil).getOrElse(Nil) def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -341,7 +341,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: Nil).getOrElse(Nil) def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -354,7 +354,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: 
ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: Nil).getOrElse(Nil) def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -367,7 +367,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: ScalaReflection.schemaFor[A19].dataType :: Nil).getOrElse(Nil) def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -380,7 +380,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: ScalaReflection.schemaFor[A19].dataType :: ScalaReflection.schemaFor[A20].dataType :: Nil).getOrElse(Nil) def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -393,7 +393,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: 
ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: ScalaReflection.schemaFor[A19].dataType :: ScalaReflection.schemaFor[A20].dataType :: ScalaReflection.schemaFor[A21].dataType :: Nil).getOrElse(Nil) def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -406,7 +406,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: ScalaReflection.schemaFor[A19].dataType :: ScalaReflection.schemaFor[A20].dataType :: ScalaReflection.schemaFor[A21].dataType :: ScalaReflection.schemaFor[A22].dataType :: Nil).getOrElse(Nil) def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + UserDefinedFunction(func, dataType, inputTypes) } ////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index 9837fa6bdb35..fd736718af12 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -232,4 +232,19 @@ class UDFSuite extends QueryTest with SharedSQLContext { | (SELECT complexDataFunc(m, a, b) AS t FROM complexData) tmp """.stripMargin).toDF(), complexData.select("m", "a", "b")) } + + test("SPARK-11716 UDFRegistration does not include the input data type in returned UDF") { + val myUDF = sqlContext.udf.register("testDataFunc", (n: Int, s: String) => { (n, s.toInt) }) + + // Without the fix, this will fail because we fail to cast data type of b to string + // because myUDF does not know its input data type. With the fix, this query should not + // fail. 
+ checkAnswer( + testData2.select(myUDF($"a", $"b").as("t")), + testData2.selectExpr("struct(a, b)")) + + checkAnswer( + sql("SELECT tmp.t.* FROM (SELECT testDataFunc(a, b) AS t from testData2) tmp").toDF(), + testData2) + } } From a6239d587c638691f52eca3eee905c53fbf35a12 Mon Sep 17 00:00:00 2001 From: felixcheung Date: Fri, 20 Nov 2015 15:10:55 -0800 Subject: [PATCH 1482/1824] [SPARK-11756][SPARKR] Fix use of aliases - SparkR can not output help information for SparkR:::summary correctly Fix use of aliases and changes uses of rdname and seealso `aliases` is the hint for `?` - it should not be linked to some other name - those should be seealso https://cran.r-project.org/web/packages/roxygen2/vignettes/rd.html Clean up usage on family, as multiple use of family with the same rdname is causing duplicated See Also html blocks (like http://spark.apache.org/docs/latest/api/R/count.html) Also changing some rdname for dplyr-like variant for better R user visibility in R doc, eg. rbind, summary, mutate, summarize shivaram yanboliang Author: felixcheung Closes #9750 from felixcheung/rdocaliases. --- R/pkg/R/DataFrame.R | 96 ++++++++++++--------------------------------- R/pkg/R/broadcast.R | 1 - R/pkg/R/generics.R | 12 +++--- R/pkg/R/group.R | 12 +++--- 4 files changed, 37 insertions(+), 84 deletions(-) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 06b0108b1389..8a13e7a36766 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -254,7 +254,6 @@ setMethod("dtypes", #' @family DataFrame functions #' @rdname columns #' @name columns -#' @aliases names #' @export #' @examples #'\dontrun{ @@ -272,7 +271,6 @@ setMethod("columns", }) }) -#' @family DataFrame functions #' @rdname columns #' @name names setMethod("names", @@ -281,7 +279,6 @@ setMethod("names", columns(x) }) -#' @family DataFrame functions #' @rdname columns #' @name names<- setMethod("names<-", @@ -533,14 +530,8 @@ setMethod("distinct", dataFrame(sdf) }) -#' @title Distinct rows in a DataFrame -# -#' @description Returns a new DataFrame containing distinct rows in this DataFrame -#' -#' @family DataFrame functions -#' @rdname unique +#' @rdname distinct #' @name unique -#' @aliases distinct setMethod("unique", signature(x = "DataFrame"), function(x) { @@ -557,7 +548,7 @@ setMethod("unique", #' #' @family DataFrame functions #' @rdname sample -#' @aliases sample_frac +#' @name sample #' @export #' @examples #'\dontrun{ @@ -579,7 +570,6 @@ setMethod("sample", dataFrame(sdf) }) -#' @family DataFrame functions #' @rdname sample #' @name sample_frac setMethod("sample_frac", @@ -589,16 +579,15 @@ setMethod("sample_frac", sample(x, withReplacement, fraction) }) -#' Count +#' nrow #' #' Returns the number of rows in a DataFrame #' #' @param x A SparkSQL DataFrame #' #' @family DataFrame functions -#' @rdname count +#' @rdname nrow #' @name count -#' @aliases nrow #' @export #' @examples #'\dontrun{ @@ -614,14 +603,8 @@ setMethod("count", callJMethod(x@sdf, "count") }) -#' @title Number of rows for a DataFrame -#' @description Returns number of rows in a DataFrames -#' #' @name nrow -#' -#' @family DataFrame functions #' @rdname nrow -#' @aliases count setMethod("nrow", signature(x = "DataFrame"), function(x) { @@ -870,7 +853,6 @@ setMethod("toRDD", #' @param x a DataFrame #' @return a GroupedData #' @seealso GroupedData -#' @aliases group_by #' @family DataFrame functions #' @rdname groupBy #' @name groupBy @@ -896,7 +878,6 @@ setMethod("groupBy", groupedData(sgd) }) -#' @family DataFrame functions #' @rdname groupBy #' @name 
group_by setMethod("group_by", @@ -913,7 +894,6 @@ setMethod("group_by", #' @family DataFrame functions #' @rdname agg #' @name agg -#' @aliases summarize #' @export setMethod("agg", signature(x = "DataFrame"), @@ -921,7 +901,6 @@ setMethod("agg", agg(groupBy(x), ...) }) -#' @family DataFrame functions #' @rdname agg #' @name summarize setMethod("summarize", @@ -1092,7 +1071,6 @@ setMethod("[", signature(x = "DataFrame", i = "Column"), #' @family DataFrame functions #' @rdname subset #' @name subset -#' @aliases [ #' @family subsetting functions #' @examples #' \dontrun{ @@ -1216,7 +1194,7 @@ setMethod("selectExpr", #' @family DataFrame functions #' @rdname withColumn #' @name withColumn -#' @aliases mutate transform +#' @seealso \link{rename} \link{mutate} #' @export #' @examples #'\dontrun{ @@ -1231,7 +1209,6 @@ setMethod("withColumn", function(x, colName, col) { select(x, x$"*", alias(col, colName)) }) - #' Mutate #' #' Return a new DataFrame with the specified columns added. @@ -1240,9 +1217,9 @@ setMethod("withColumn", #' @param col a named argument of the form name = col #' @return A new DataFrame with the new columns added. #' @family DataFrame functions -#' @rdname withColumn +#' @rdname mutate #' @name mutate -#' @aliases withColumn transform +#' @seealso \link{rename} \link{withColumn} #' @export #' @examples #'\dontrun{ @@ -1273,17 +1250,15 @@ setMethod("mutate", }) #' @export -#' @family DataFrame functions -#' @rdname withColumn +#' @rdname mutate #' @name transform -#' @aliases withColumn mutate setMethod("transform", signature(`_data` = "DataFrame"), function(`_data`, ...) { mutate(`_data`, ...) }) -#' WithColumnRenamed +#' rename #' #' Rename an existing column in a DataFrame. #' @@ -1292,8 +1267,9 @@ setMethod("transform", #' @param newCol The new column name. #' @return A DataFrame with the column name changed. #' @family DataFrame functions -#' @rdname withColumnRenamed +#' @rdname rename #' @name withColumnRenamed +#' @seealso \link{mutate} #' @export #' @examples #'\dontrun{ @@ -1316,17 +1292,9 @@ setMethod("withColumnRenamed", select(x, cols) }) -#' Rename -#' -#' Rename an existing column in a DataFrame. -#' -#' @param x A DataFrame -#' @param newCol A named pair of the form new_column_name = existing_column -#' @return A DataFrame with the column name changed. 
-#' @family DataFrame functions -#' @rdname withColumnRenamed +#' @param newColPair A named pair of the form new_column_name = existing_column +#' @rdname rename #' @name rename -#' @aliases withColumnRenamed #' @export #' @examples #'\dontrun{ @@ -1371,7 +1339,6 @@ setClassUnion("characterOrColumn", c("character", "Column")) #' @family DataFrame functions #' @rdname arrange #' @name arrange -#' @aliases orderby #' @export #' @examples #'\dontrun{ @@ -1395,8 +1362,8 @@ setMethod("arrange", dataFrame(sdf) }) -#' @family DataFrame functions #' @rdname arrange +#' @name arrange #' @export setMethod("arrange", signature(x = "DataFrame", col = "character"), @@ -1427,9 +1394,9 @@ setMethod("arrange", do.call("arrange", c(x, jcols)) }) -#' @family DataFrame functions #' @rdname arrange -#' @name orderby +#' @name orderBy +#' @export setMethod("orderBy", signature(x = "DataFrame", col = "characterOrColumn"), function(x, col) { @@ -1492,6 +1459,7 @@ setMethod("where", #' @family DataFrame functions #' @rdname join #' @name join +#' @seealso \link{merge} #' @export #' @examples #'\dontrun{ @@ -1528,9 +1496,7 @@ setMethod("join", dataFrame(sdf) }) -#' #' @name merge -#' @aliases join #' @title Merges two data frames #' @param x the first data frame to be joined #' @param y the second data frame to be joined @@ -1550,6 +1516,7 @@ setMethod("join", #' outer join will be returned. #' @family DataFrame functions #' @rdname merge +#' @seealso \link{join} #' @export #' @examples #'\dontrun{ @@ -1671,7 +1638,7 @@ generateAliasesForIntersectedCols <- function (x, intersectedColNames, suffix) { cols } -#' UnionAll +#' rbind #' #' Return a new DataFrame containing the union of rows in this DataFrame #' and another DataFrame. This is equivalent to `UNION ALL` in SQL. @@ -1681,7 +1648,7 @@ generateAliasesForIntersectedCols <- function (x, intersectedColNames, suffix) { #' @param y A Spark DataFrame #' @return A DataFrame containing the result of the union. #' @family DataFrame functions -#' @rdname unionAll +#' @rdname rbind #' @name unionAll #' @export #' @examples @@ -1700,13 +1667,11 @@ setMethod("unionAll", }) #' @title Union two or more DataFrames -#' #' @description Returns a new DataFrame containing rows of all parameters. #' -#' @family DataFrame functions #' @rdname rbind #' @name rbind -#' @aliases unionAll +#' @export setMethod("rbind", signature(... = "DataFrame"), function(x, ..., deparse.level = 1) { @@ -1795,7 +1760,6 @@ setMethod("except", #' @family DataFrame functions #' @rdname write.df #' @name write.df -#' @aliases saveDF #' @export #' @examples #'\dontrun{ @@ -1828,7 +1792,6 @@ setMethod("write.df", callJMethod(df@sdf, "save", source, jmode, options) }) -#' @family DataFrame functions #' @rdname write.df #' @name saveDF #' @export @@ -1891,7 +1854,7 @@ setMethod("saveAsTable", callJMethod(df@sdf, "saveAsTable", tableName, source, jmode, options) }) -#' describe +#' summary #' #' Computes statistics for numeric columns. #' If no columns are given, this function computes statistics for all numerical columns. @@ -1901,9 +1864,8 @@ setMethod("saveAsTable", #' @param ... 
Additional expressions #' @return A DataFrame #' @family DataFrame functions -#' @rdname describe +#' @rdname summary #' @name describe -#' @aliases summary #' @export #' @examples #'\dontrun{ @@ -1923,8 +1885,7 @@ setMethod("describe", dataFrame(sdf) }) -#' @family DataFrame functions -#' @rdname describe +#' @rdname summary #' @name describe setMethod("describe", signature(x = "DataFrame"), @@ -1934,11 +1895,6 @@ setMethod("describe", dataFrame(sdf) }) -#' @title Summary -#' -#' @description Computes statistics for numeric columns of the DataFrame -#' -#' @family DataFrame functions #' @rdname summary #' @name summary setMethod("summary", @@ -1966,7 +1922,6 @@ setMethod("summary", #' @family DataFrame functions #' @rdname nafunctions #' @name dropna -#' @aliases na.omit #' @export #' @examples #'\dontrun{ @@ -1993,7 +1948,6 @@ setMethod("dropna", dataFrame(sdf) }) -#' @family DataFrame functions #' @rdname nafunctions #' @name na.omit #' @export @@ -2019,9 +1973,7 @@ setMethod("na.omit", #' type are ignored. For example, if value is a character, and #' subset contains a non-character column, then the non-character #' column is simply ignored. -#' @return A DataFrame #' -#' @family DataFrame functions #' @rdname nafunctions #' @name fillna #' @export diff --git a/R/pkg/R/broadcast.R b/R/pkg/R/broadcast.R index 2403925b267c..38f0eed95e06 100644 --- a/R/pkg/R/broadcast.R +++ b/R/pkg/R/broadcast.R @@ -51,7 +51,6 @@ Broadcast <- function(id, value, jBroadcastRef, objName) { # # @param bcast The broadcast variable to get # @rdname broadcast -# @aliases value,Broadcast-method setMethod("value", signature(bcast = "Broadcast"), function(bcast) { diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 71004a05ba61..1b3f10ea0464 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -397,7 +397,7 @@ setGeneric("cov", function(x, col1, col2) {standardGeneric("cov") }) #' @export setGeneric("corr", function(x, col1, col2, method = "pearson") {standardGeneric("corr") }) -#' @rdname describe +#' @rdname summary #' @export setGeneric("describe", function(x, col, ...) { standardGeneric("describe") }) @@ -459,11 +459,11 @@ setGeneric("isLocal", function(x) { standardGeneric("isLocal") }) #' @export setGeneric("limit", function(x, num) {standardGeneric("limit") }) -#' rdname merge +#' @rdname merge #' @export setGeneric("merge") -#' @rdname withColumn +#' @rdname mutate #' @export setGeneric("mutate", function(.data, ...) {standardGeneric("mutate") }) @@ -475,7 +475,7 @@ setGeneric("orderBy", function(x, col) { standardGeneric("orderBy") }) #' @export setGeneric("printSchema", function(x) { standardGeneric("printSchema") }) -#' @rdname withColumnRenamed +#' @rdname rename #' @export setGeneric("rename", function(x, ...) 
{ standardGeneric("rename") }) @@ -553,7 +553,7 @@ setGeneric("toJSON", function(x) { standardGeneric("toJSON") }) setGeneric("toRDD", function(x) { standardGeneric("toRDD") }) -#' @rdname unionAll +#' @rdname rbind #' @export setGeneric("unionAll", function(x, y) { standardGeneric("unionAll") }) @@ -565,7 +565,7 @@ setGeneric("where", function(x, condition) { standardGeneric("where") }) #' @export setGeneric("withColumn", function(x, colName, col) { standardGeneric("withColumn") }) -#' @rdname withColumnRenamed +#' @rdname rename #' @export setGeneric("withColumnRenamed", function(x, existingCol, newCol) { standardGeneric("withColumnRenamed") }) diff --git a/R/pkg/R/group.R b/R/pkg/R/group.R index e5f702faee65..23b49aebda05 100644 --- a/R/pkg/R/group.R +++ b/R/pkg/R/group.R @@ -68,7 +68,7 @@ setMethod("count", dataFrame(callJMethod(x@sgd, "count")) }) -#' Agg +#' summarize #' #' Aggregates on the entire DataFrame without groups. #' The resulting DataFrame will also contain the grouping columns. @@ -78,12 +78,14 @@ setMethod("count", #' #' @param x a GroupedData #' @return a DataFrame -#' @rdname agg +#' @rdname summarize +#' @name agg #' @family agg_funcs #' @examples #' \dontrun{ #' df2 <- agg(df, age = "sum") # new column name will be created as 'SUM(age#0)' -#' df2 <- agg(df, ageSum = sum(df$age)) # Creates a new column named ageSum +#' df3 <- agg(df, ageSum = sum(df$age)) # Creates a new column named ageSum +#' df4 <- summarize(df, ageSum = max(df$age)) #' } setMethod("agg", signature(x = "GroupedData"), @@ -110,8 +112,8 @@ setMethod("agg", dataFrame(sdf) }) -#' @rdname agg -#' @aliases agg +#' @rdname summarize +#' @name summarize setMethod("summarize", signature(x = "GroupedData"), function(x, ...) { From 4b84c72dfbb9ddb415fee35f69305b5d7b280891 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Fri, 20 Nov 2015 15:17:17 -0800 Subject: [PATCH 1483/1824] [SPARK-11636][SQL] Support classes defined in the REPL with Encoders #theScaryParts (i.e. changes to the repl, executor classloaders and codegen)... Author: Michael Armbrust Author: Yin Huai Closes #9825 from marmbrus/dataset-replClasses2. --- .../org/apache/spark/repl/SparkIMain.scala | 14 +++++++---- .../org/apache/spark/repl/ReplSuite.scala | 24 +++++++++++++++++++ .../spark/repl/ExecutorClassLoader.scala | 8 ++++++- .../expressions/codegen/CodeGenerator.scala | 4 ++-- 4 files changed, 43 insertions(+), 7 deletions(-) diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala index 4ee605fd7f11..829b12269fd2 100644 --- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala +++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala @@ -1221,10 +1221,16 @@ import org.apache.spark.annotation.DeveloperApi ) } - val preamble = """ - |class %s extends Serializable { - | %s%s%s - """.stripMargin.format(lineRep.readName, envLines.map(" " + _ + ";\n").mkString, importsPreamble, indentCode(toCompute)) + val preamble = s""" + |class ${lineRep.readName} extends Serializable { + | ${envLines.map(" " + _ + ";\n").mkString} + | $importsPreamble + | + | // If we need to construct any objects defined in the REPL on an executor we will need + | // to pass the outer scope to the appropriate encoder. 
+ | org.apache.spark.sql.catalyst.encoders.OuterScopes.addOuterScope(this) + | ${indentCode(toCompute)} + """.stripMargin val postamble = importsTrailer + "\n}" + "\n" + "object " + lineRep.readName + " {\n" + " val INSTANCE = new " + lineRep.readName + "();\n" + diff --git a/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala index 5674dcd669be..081aa03002cc 100644 --- a/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala +++ b/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala @@ -262,6 +262,9 @@ class ReplSuite extends SparkFunSuite { |import sqlContext.implicits._ |case class TestCaseClass(value: Int) |sc.parallelize(1 to 10).map(x => TestCaseClass(x)).toDF().collect() + | + |// Test Dataset Serialization in the REPL + |Seq(TestCaseClass(1)).toDS().collect() """.stripMargin) assertDoesNotContain("error:", output) assertDoesNotContain("Exception", output) @@ -278,6 +281,27 @@ class ReplSuite extends SparkFunSuite { assertDoesNotContain("java.lang.ClassNotFoundException", output) } + test("Datasets and encoders") { + val output = runInterpreter("local", + """ + |import org.apache.spark.sql.functions._ + |import org.apache.spark.sql.Encoder + |import org.apache.spark.sql.expressions.Aggregator + |import org.apache.spark.sql.TypedColumn + |val simpleSum = new Aggregator[Int, Int, Int] with Serializable { + | def zero: Int = 0 // The initial value. + | def reduce(b: Int, a: Int) = b + a // Add an element to the running total + | def merge(b1: Int, b2: Int) = b1 + b2 // Merge intermediate values. + | def finish(b: Int) = b // Return the final result. + |}.toColumn + | + |val ds = Seq(1, 2, 3, 4).toDS() + |ds.select(simpleSum).collect + """.stripMargin) + assertDoesNotContain("error:", output) + assertDoesNotContain("Exception", output) + } + test("SPARK-2632 importing a method from non serializable class and not using it.") { val output = runInterpreter("local", """ diff --git a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala index 3d2d235a00c9..a976e96809cb 100644 --- a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala +++ b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala @@ -65,7 +65,13 @@ class ExecutorClassLoader(conf: SparkConf, classUri: String, parent: ClassLoader case e: ClassNotFoundException => { val classOption = findClassLocally(name) classOption match { - case None => throw new ClassNotFoundException(name, e) + case None => + // If this class has a cause, it will break the internal assumption of Janino + // (the compiler used for Spark SQL code-gen). + // See org.codehaus.janino.ClassLoaderIClassLoader's findIClass, you will see + // its behavior will be changed if there is a cause and the compilation + // of generated class will fail. 
+ throw new ClassNotFoundException(name) case Some(a) => a } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 1b7260cdfe51..2f3d6aeb86c5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.util.{MapData, ArrayData} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.types._ - +import org.apache.spark.util.Utils /** * Java source for evaluating an [[Expression]] given a [[InternalRow]] of input. @@ -536,7 +536,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin */ private[this] def doCompile(code: String): GeneratedClass = { val evaluator = new ClassBodyEvaluator() - evaluator.setParentClassLoader(getClass.getClassLoader) + evaluator.setParentClassLoader(Utils.getContextOrSparkClassLoader) // Cannot be under package codegen, or fail with java.lang.InstantiationException evaluator.setClassName("org.apache.spark.sql.catalyst.expressions.GeneratedClass") evaluator.setDefaultImports(Array( From ed47b1e660b830e2d4fac8d6df93f634b260393c Mon Sep 17 00:00:00 2001 From: Vikas Nelamangala Date: Fri, 20 Nov 2015 15:18:41 -0800 Subject: [PATCH 1484/1824] [SPARK-11549][DOCS] Replace example code in mllib-evaluation-metrics.md using include_example Author: Vikas Nelamangala Closes #9689 from vikasnp/master. --- docs/mllib-evaluation-metrics.md | 940 +----------------- ...avaBinaryClassificationMetricsExample.java | 113 +++ ...ultiLabelClassificationMetricsExample.java | 80 ++ ...ulticlassClassificationMetricsExample.java | 97 ++ .../mllib/JavaRankingMetricsExample.java | 176 ++++ .../mllib/JavaRegressionMetricsExample.java | 91 ++ .../binary_classification_metrics_example.py | 55 + .../mllib/multi_class_metrics_example.py | 69 ++ .../mllib/multi_label_metrics_example.py | 61 ++ .../python/mllib/ranking_metrics_example.py | 55 + .../mllib/regression_metrics_example.py | 59 ++ .../BinaryClassificationMetricsExample.scala | 103 ++ .../mllib/MultiLabelMetricsExample.scala | 69 ++ .../mllib/MulticlassMetricsExample.scala | 99 ++ .../mllib/RankingMetricsExample.scala | 110 ++ .../mllib/RegressionMetricsExample.scala | 67 ++ 16 files changed, 1319 insertions(+), 925 deletions(-) create mode 100644 examples/src/main/java/org/apache/spark/examples/mllib/JavaBinaryClassificationMetricsExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/mllib/JavaMultiLabelClassificationMetricsExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/mllib/JavaMulticlassClassificationMetricsExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/mllib/JavaRankingMetricsExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/mllib/JavaRegressionMetricsExample.java create mode 100644 examples/src/main/python/mllib/binary_classification_metrics_example.py create mode 100644 examples/src/main/python/mllib/multi_class_metrics_example.py create mode 100644 examples/src/main/python/mllib/multi_label_metrics_example.py create mode 100644 examples/src/main/python/mllib/ranking_metrics_example.py create mode 100644 
examples/src/main/python/mllib/regression_metrics_example.py create mode 100644 examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassificationMetricsExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/mllib/MultiLabelMetricsExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/mllib/MulticlassMetricsExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/mllib/RankingMetricsExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/mllib/RegressionMetricsExample.scala diff --git a/docs/mllib-evaluation-metrics.md b/docs/mllib-evaluation-metrics.md index f73eff637dc3..6924037b941f 100644 --- a/docs/mllib-evaluation-metrics.md +++ b/docs/mllib-evaluation-metrics.md @@ -104,214 +104,21 @@ data, and evaluate the performance of the algorithm by several binary evaluation
    Refer to the [`LogisticRegressionWithLBFGS` Scala docs](api/scala/index.html#org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS) and [`BinaryClassificationMetrics` Scala docs](api/scala/index.html#org.apache.spark.mllib.evaluation.BinaryClassificationMetrics) for details on the API. -{% highlight scala %} -import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS -import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics -import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.util.MLUtils - -// Load training data in LIBSVM format -val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_binary_classification_data.txt") - -// Split data into training (60%) and test (40%) -val Array(training, test) = data.randomSplit(Array(0.6, 0.4), seed = 11L) -training.cache() - -// Run training algorithm to build the model -val model = new LogisticRegressionWithLBFGS() - .setNumClasses(2) - .run(training) - -// Clear the prediction threshold so the model will return probabilities -model.clearThreshold - -// Compute raw scores on the test set -val predictionAndLabels = test.map { case LabeledPoint(label, features) => - val prediction = model.predict(features) - (prediction, label) -} - -// Instantiate metrics object -val metrics = new BinaryClassificationMetrics(predictionAndLabels) - -// Precision by threshold -val precision = metrics.precisionByThreshold -precision.foreach { case (t, p) => - println(s"Threshold: $t, Precision: $p") -} - -// Recall by threshold -val recall = metrics.recallByThreshold -recall.foreach { case (t, r) => - println(s"Threshold: $t, Recall: $r") -} - -// Precision-Recall Curve -val PRC = metrics.pr - -// F-measure -val f1Score = metrics.fMeasureByThreshold -f1Score.foreach { case (t, f) => - println(s"Threshold: $t, F-score: $f, Beta = 1") -} - -val beta = 0.5 -val fScore = metrics.fMeasureByThreshold(beta) -f1Score.foreach { case (t, f) => - println(s"Threshold: $t, F-score: $f, Beta = 0.5") -} - -// AUPRC -val auPRC = metrics.areaUnderPR -println("Area under precision-recall curve = " + auPRC) - -// Compute thresholds used in ROC and PR curves -val thresholds = precision.map(_._1) - -// ROC Curve -val roc = metrics.roc - -// AUROC -val auROC = metrics.areaUnderROC -println("Area under ROC = " + auROC) - -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/mllib/BinaryClassificationMetricsExample.scala %}
    Refer to the [`LogisticRegressionModel` Java docs](api/java/org/apache/spark/mllib/classification/LogisticRegressionModel.html) and [`LogisticRegressionWithLBFGS` Java docs](api/java/org/apache/spark/mllib/classification/LogisticRegressionWithLBFGS.html) for details on the API. -{% highlight java %} -import scala.Tuple2; - -import org.apache.spark.api.java.*; -import org.apache.spark.rdd.RDD; -import org.apache.spark.api.java.function.Function; -import org.apache.spark.mllib.classification.LogisticRegressionModel; -import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS; -import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.util.MLUtils; -import org.apache.spark.SparkConf; -import org.apache.spark.SparkContext; - -public class BinaryClassification { - public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("Binary Classification Metrics"); - SparkContext sc = new SparkContext(conf); - String path = "data/mllib/sample_binary_classification_data.txt"; - JavaRDD data = MLUtils.loadLibSVMFile(sc, path).toJavaRDD(); - - // Split initial RDD into two... [60% training data, 40% testing data]. - JavaRDD[] splits = data.randomSplit(new double[] {0.6, 0.4}, 11L); - JavaRDD training = splits[0].cache(); - JavaRDD test = splits[1]; - - // Run training algorithm to build the model. - final LogisticRegressionModel model = new LogisticRegressionWithLBFGS() - .setNumClasses(2) - .run(training.rdd()); - - // Clear the prediction threshold so the model will return probabilities - model.clearThreshold(); - - // Compute raw scores on the test set. - JavaRDD> predictionAndLabels = test.map( - new Function>() { - public Tuple2 call(LabeledPoint p) { - Double prediction = model.predict(p.features()); - return new Tuple2(prediction, p.label()); - } - } - ); - - // Get evaluation metrics. 
- BinaryClassificationMetrics metrics = new BinaryClassificationMetrics(predictionAndLabels.rdd()); - - // Precision by threshold - JavaRDD> precision = metrics.precisionByThreshold().toJavaRDD(); - System.out.println("Precision by threshold: " + precision.toArray()); - - // Recall by threshold - JavaRDD> recall = metrics.recallByThreshold().toJavaRDD(); - System.out.println("Recall by threshold: " + recall.toArray()); - - // F Score by threshold - JavaRDD> f1Score = metrics.fMeasureByThreshold().toJavaRDD(); - System.out.println("F1 Score by threshold: " + f1Score.toArray()); - - JavaRDD> f2Score = metrics.fMeasureByThreshold(2.0).toJavaRDD(); - System.out.println("F2 Score by threshold: " + f2Score.toArray()); - - // Precision-recall curve - JavaRDD> prc = metrics.pr().toJavaRDD(); - System.out.println("Precision-recall curve: " + prc.toArray()); - - // Thresholds - JavaRDD thresholds = precision.map( - new Function, Double>() { - public Double call (Tuple2 t) { - return new Double(t._1().toString()); - } - } - ); - - // ROC Curve - JavaRDD> roc = metrics.roc().toJavaRDD(); - System.out.println("ROC curve: " + roc.toArray()); - - // AUPRC - System.out.println("Area under precision-recall curve = " + metrics.areaUnderPR()); - - // AUROC - System.out.println("Area under ROC = " + metrics.areaUnderROC()); - - // Save and load model - model.save(sc, "myModelPath"); - LogisticRegressionModel sameModel = LogisticRegressionModel.load(sc, "myModelPath"); - } -} - -{% endhighlight %} +{% include_example java/org/apache/spark/examples/mllib/JavaBinaryClassificationMetricsExample.java %}
    Refer to the [`BinaryClassificationMetrics` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.evaluation.BinaryClassificationMetrics) and [`LogisticRegressionWithLBFGS` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.classification.LogisticRegressionWithLBFGS) for more details on the API. -{% highlight python %} -from pyspark.mllib.classification import LogisticRegressionWithLBFGS -from pyspark.mllib.evaluation import BinaryClassificationMetrics -from pyspark.mllib.regression import LabeledPoint -from pyspark.mllib.util import MLUtils - -# Several of the methods available in scala are currently missing from pyspark - -# Load training data in LIBSVM format -data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_binary_classification_data.txt") - -# Split data into training (60%) and test (40%) -training, test = data.randomSplit([0.6, 0.4], seed = 11L) -training.cache() - -# Run training algorithm to build the model -model = LogisticRegressionWithLBFGS.train(training) - -# Compute raw scores on the test set -predictionAndLabels = test.map(lambda lp: (float(model.predict(lp.features)), lp.label)) - -# Instantiate metrics object -metrics = BinaryClassificationMetrics(predictionAndLabels) - -# Area under precision-recall curve -print("Area under PR = %s" % metrics.areaUnderPR) - -# Area under ROC curve -print("Area under ROC = %s" % metrics.areaUnderROC) - -{% endhighlight %} - +{% include_example python/mllib/binary_classification_metrics_example.py %}
    @@ -433,204 +240,21 @@ the data, and evaluate the performance of the algorithm by several multiclass cl
    Refer to the [`MulticlassMetrics` Scala docs](api/scala/index.html#org.apache.spark.mllib.evaluation.MulticlassMetrics) for details on the API. -{% highlight scala %} -import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS -import org.apache.spark.mllib.evaluation.MulticlassMetrics -import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.util.MLUtils - -// Load training data in LIBSVM format -val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_multiclass_classification_data.txt") - -// Split data into training (60%) and test (40%) -val Array(training, test) = data.randomSplit(Array(0.6, 0.4), seed = 11L) -training.cache() - -// Run training algorithm to build the model -val model = new LogisticRegressionWithLBFGS() - .setNumClasses(3) - .run(training) - -// Compute raw scores on the test set -val predictionAndLabels = test.map { case LabeledPoint(label, features) => - val prediction = model.predict(features) - (prediction, label) -} - -// Instantiate metrics object -val metrics = new MulticlassMetrics(predictionAndLabels) - -// Confusion matrix -println("Confusion matrix:") -println(metrics.confusionMatrix) - -// Overall Statistics -val precision = metrics.precision -val recall = metrics.recall // same as true positive rate -val f1Score = metrics.fMeasure -println("Summary Statistics") -println(s"Precision = $precision") -println(s"Recall = $recall") -println(s"F1 Score = $f1Score") - -// Precision by label -val labels = metrics.labels -labels.foreach { l => - println(s"Precision($l) = " + metrics.precision(l)) -} - -// Recall by label -labels.foreach { l => - println(s"Recall($l) = " + metrics.recall(l)) -} - -// False positive rate by label -labels.foreach { l => - println(s"FPR($l) = " + metrics.falsePositiveRate(l)) -} - -// F-measure by label -labels.foreach { l => - println(s"F1-Score($l) = " + metrics.fMeasure(l)) -} - -// Weighted stats -println(s"Weighted precision: ${metrics.weightedPrecision}") -println(s"Weighted recall: ${metrics.weightedRecall}") -println(s"Weighted F1 score: ${metrics.weightedFMeasure}") -println(s"Weighted false positive rate: ${metrics.weightedFalsePositiveRate}") - -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/mllib/MulticlassMetricsExample.scala %}
    Refer to the [`MulticlassMetrics` Java docs](api/java/org/apache/spark/mllib/evaluation/MulticlassMetrics.html) for details on the API. -{% highlight java %} -import scala.Tuple2; - -import org.apache.spark.api.java.*; -import org.apache.spark.rdd.RDD; -import org.apache.spark.api.java.function.Function; -import org.apache.spark.mllib.classification.LogisticRegressionModel; -import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS; -import org.apache.spark.mllib.evaluation.MulticlassMetrics; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.util.MLUtils; -import org.apache.spark.mllib.linalg.Matrix; -import org.apache.spark.SparkConf; -import org.apache.spark.SparkContext; - -public class MulticlassClassification { - public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("Multiclass Classification Metrics"); - SparkContext sc = new SparkContext(conf); - String path = "data/mllib/sample_multiclass_classification_data.txt"; - JavaRDD data = MLUtils.loadLibSVMFile(sc, path).toJavaRDD(); - - // Split initial RDD into two... [60% training data, 40% testing data]. - JavaRDD[] splits = data.randomSplit(new double[] {0.6, 0.4}, 11L); - JavaRDD training = splits[0].cache(); - JavaRDD test = splits[1]; - - // Run training algorithm to build the model. - final LogisticRegressionModel model = new LogisticRegressionWithLBFGS() - .setNumClasses(3) - .run(training.rdd()); - - // Compute raw scores on the test set. - JavaRDD> predictionAndLabels = test.map( - new Function>() { - public Tuple2 call(LabeledPoint p) { - Double prediction = model.predict(p.features()); - return new Tuple2(prediction, p.label()); - } - } - ); - - // Get evaluation metrics. - MulticlassMetrics metrics = new MulticlassMetrics(predictionAndLabels.rdd()); - - // Confusion matrix - Matrix confusion = metrics.confusionMatrix(); - System.out.println("Confusion matrix: \n" + confusion); - - // Overall statistics - System.out.println("Precision = " + metrics.precision()); - System.out.println("Recall = " + metrics.recall()); - System.out.println("F1 Score = " + metrics.fMeasure()); - - // Stats by labels - for (int i = 0; i < metrics.labels().length; i++) { - System.out.format("Class %f precision = %f\n", metrics.labels()[i], metrics.precision(metrics.labels()[i])); - System.out.format("Class %f recall = %f\n", metrics.labels()[i], metrics.recall(metrics.labels()[i])); - System.out.format("Class %f F1 score = %f\n", metrics.labels()[i], metrics.fMeasure(metrics.labels()[i])); - } - - //Weighted stats - System.out.format("Weighted precision = %f\n", metrics.weightedPrecision()); - System.out.format("Weighted recall = %f\n", metrics.weightedRecall()); - System.out.format("Weighted F1 score = %f\n", metrics.weightedFMeasure()); - System.out.format("Weighted false positive rate = %f\n", metrics.weightedFalsePositiveRate()); - - // Save and load model - model.save(sc, "myModelPath"); - LogisticRegressionModel sameModel = LogisticRegressionModel.load(sc, "myModelPath"); - } -} - -{% endhighlight %} + {% include_example java/org/apache/spark/examples/mllib/JavaMulticlassClassificationMetricsExample.java %}
    Refer to the [`MulticlassMetrics` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.evaluation.MulticlassMetrics) for more details on the API. -{% highlight python %} -from pyspark.mllib.classification import LogisticRegressionWithLBFGS -from pyspark.mllib.util import MLUtils -from pyspark.mllib.evaluation import MulticlassMetrics - -# Load training data in LIBSVM format -data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_multiclass_classification_data.txt") - -# Split data into training (60%) and test (40%) -training, test = data.randomSplit([0.6, 0.4], seed = 11L) -training.cache() - -# Run training algorithm to build the model -model = LogisticRegressionWithLBFGS.train(training, numClasses=3) - -# Compute raw scores on the test set -predictionAndLabels = test.map(lambda lp: (float(model.predict(lp.features)), lp.label)) - -# Instantiate metrics object -metrics = MulticlassMetrics(predictionAndLabels) - -# Overall statistics -precision = metrics.precision() -recall = metrics.recall() -f1Score = metrics.fMeasure() -print("Summary Stats") -print("Precision = %s" % precision) -print("Recall = %s" % recall) -print("F1 Score = %s" % f1Score) - -# Statistics by class -labels = data.map(lambda lp: lp.label).distinct().collect() -for label in sorted(labels): - print("Class %s precision = %s" % (label, metrics.precision(label))) - print("Class %s recall = %s" % (label, metrics.recall(label))) - print("Class %s F1 Measure = %s" % (label, metrics.fMeasure(label, beta=1.0))) - -# Weighted stats -print("Weighted recall = %s" % metrics.weightedRecall) -print("Weighted precision = %s" % metrics.weightedPrecision) -print("Weighted F(1) Score = %s" % metrics.weightedFMeasure()) -print("Weighted F(0.5) Score = %s" % metrics.weightedFMeasure(beta=0.5)) -print("Weighted false positive rate = %s" % metrics.weightedFalsePositiveRate) -{% endhighlight %} +{% include_example python/mllib/multi_class_metrics_example.py %}
    @@ -766,154 +390,21 @@ True classes:
    Refer to the [`MultilabelMetrics` Scala docs](api/scala/index.html#org.apache.spark.mllib.evaluation.MultilabelMetrics) for details on the API. -{% highlight scala %} -import org.apache.spark.mllib.evaluation.MultilabelMetrics -import org.apache.spark.rdd.RDD; - -val scoreAndLabels: RDD[(Array[Double], Array[Double])] = sc.parallelize( - Seq((Array(0.0, 1.0), Array(0.0, 2.0)), - (Array(0.0, 2.0), Array(0.0, 1.0)), - (Array(), Array(0.0)), - (Array(2.0), Array(2.0)), - (Array(2.0, 0.0), Array(2.0, 0.0)), - (Array(0.0, 1.0, 2.0), Array(0.0, 1.0)), - (Array(1.0), Array(1.0, 2.0))), 2) - -// Instantiate metrics object -val metrics = new MultilabelMetrics(scoreAndLabels) - -// Summary stats -println(s"Recall = ${metrics.recall}") -println(s"Precision = ${metrics.precision}") -println(s"F1 measure = ${metrics.f1Measure}") -println(s"Accuracy = ${metrics.accuracy}") - -// Individual label stats -metrics.labels.foreach(label => println(s"Class $label precision = ${metrics.precision(label)}")) -metrics.labels.foreach(label => println(s"Class $label recall = ${metrics.recall(label)}")) -metrics.labels.foreach(label => println(s"Class $label F1-score = ${metrics.f1Measure(label)}")) - -// Micro stats -println(s"Micro recall = ${metrics.microRecall}") -println(s"Micro precision = ${metrics.microPrecision}") -println(s"Micro F1 measure = ${metrics.microF1Measure}") - -// Hamming loss -println(s"Hamming loss = ${metrics.hammingLoss}") - -// Subset accuracy -println(s"Subset accuracy = ${metrics.subsetAccuracy}") - -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/mllib/MultiLabelMetricsExample.scala %}
    Refer to the [`MultilabelMetrics` Java docs](api/java/org/apache/spark/mllib/evaluation/MultilabelMetrics.html) for details on the API. -{% highlight java %} -import scala.Tuple2; - -import org.apache.spark.api.java.*; -import org.apache.spark.rdd.RDD; -import org.apache.spark.mllib.evaluation.MultilabelMetrics; -import org.apache.spark.SparkConf; -import java.util.Arrays; -import java.util.List; - -public class MultilabelClassification { - public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("Multilabel Classification Metrics"); - JavaSparkContext sc = new JavaSparkContext(conf); - - List> data = Arrays.asList( - new Tuple2(new double[]{0.0, 1.0}, new double[]{0.0, 2.0}), - new Tuple2(new double[]{0.0, 2.0}, new double[]{0.0, 1.0}), - new Tuple2(new double[]{}, new double[]{0.0}), - new Tuple2(new double[]{2.0}, new double[]{2.0}), - new Tuple2(new double[]{2.0, 0.0}, new double[]{2.0, 0.0}), - new Tuple2(new double[]{0.0, 1.0, 2.0}, new double[]{0.0, 1.0}), - new Tuple2(new double[]{1.0}, new double[]{1.0, 2.0}) - ); - JavaRDD> scoreAndLabels = sc.parallelize(data); - - // Instantiate metrics object - MultilabelMetrics metrics = new MultilabelMetrics(scoreAndLabels.rdd()); - - // Summary stats - System.out.format("Recall = %f\n", metrics.recall()); - System.out.format("Precision = %f\n", metrics.precision()); - System.out.format("F1 measure = %f\n", metrics.f1Measure()); - System.out.format("Accuracy = %f\n", metrics.accuracy()); - - // Stats by labels - for (int i = 0; i < metrics.labels().length - 1; i++) { - System.out.format("Class %1.1f precision = %f\n", metrics.labels()[i], metrics.precision(metrics.labels()[i])); - System.out.format("Class %1.1f recall = %f\n", metrics.labels()[i], metrics.recall(metrics.labels()[i])); - System.out.format("Class %1.1f F1 score = %f\n", metrics.labels()[i], metrics.f1Measure(metrics.labels()[i])); - } - - // Micro stats - System.out.format("Micro recall = %f\n", metrics.microRecall()); - System.out.format("Micro precision = %f\n", metrics.microPrecision()); - System.out.format("Micro F1 measure = %f\n", metrics.microF1Measure()); - - // Hamming loss - System.out.format("Hamming loss = %f\n", metrics.hammingLoss()); - - // Subset accuracy - System.out.format("Subset accuracy = %f\n", metrics.subsetAccuracy()); - - } -} - -{% endhighlight %} +{% include_example java/org/apache/spark/examples/mllib/JavaMultiLabelClassificationMetricsExample.java %}
    Refer to the [`MultilabelMetrics` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.evaluation.MultilabelMetrics) for more details on the API. -{% highlight python %} -from pyspark.mllib.evaluation import MultilabelMetrics - -scoreAndLabels = sc.parallelize([ - ([0.0, 1.0], [0.0, 2.0]), - ([0.0, 2.0], [0.0, 1.0]), - ([], [0.0]), - ([2.0], [2.0]), - ([2.0, 0.0], [2.0, 0.0]), - ([0.0, 1.0, 2.0], [0.0, 1.0]), - ([1.0], [1.0, 2.0])]) - -# Instantiate metrics object -metrics = MultilabelMetrics(scoreAndLabels) - -# Summary stats -print("Recall = %s" % metrics.recall()) -print("Precision = %s" % metrics.precision()) -print("F1 measure = %s" % metrics.f1Measure()) -print("Accuracy = %s" % metrics.accuracy) - -# Individual label stats -labels = scoreAndLabels.flatMap(lambda x: x[1]).distinct().collect() -for label in labels: - print("Class %s precision = %s" % (label, metrics.precision(label))) - print("Class %s recall = %s" % (label, metrics.recall(label))) - print("Class %s F1 Measure = %s" % (label, metrics.f1Measure(label))) - -# Micro stats -print("Micro precision = %s" % metrics.microPrecision) -print("Micro recall = %s" % metrics.microRecall) -print("Micro F1 measure = %s" % metrics.microF1Measure) - -# Hamming loss -print("Hamming loss = %s" % metrics.hammingLoss) - -# Subset accuracy -print("Subset accuracy = %s" % metrics.subsetAccuracy) - -{% endhighlight %} +{% include_example python/mllib/multi_label_metrics_example.py %}
    @@ -1027,280 +518,21 @@ expanded world of non-positive weights are "the same as never having interacted
    Refer to the [`RegressionMetrics` Scala docs](api/scala/index.html#org.apache.spark.mllib.evaluation.RegressionMetrics) and [`RankingMetrics` Scala docs](api/scala/index.html#org.apache.spark.mllib.evaluation.RankingMetrics) for details on the API. -{% highlight scala %} -import org.apache.spark.mllib.evaluation.{RegressionMetrics, RankingMetrics} -import org.apache.spark.mllib.recommendation.{ALS, Rating} - -// Read in the ratings data -val ratings = sc.textFile("data/mllib/sample_movielens_data.txt").map { line => - val fields = line.split("::") - Rating(fields(0).toInt, fields(1).toInt, fields(2).toDouble - 2.5) -}.cache() - -// Map ratings to 1 or 0, 1 indicating a movie that should be recommended -val binarizedRatings = ratings.map(r => Rating(r.user, r.product, if (r.rating > 0) 1.0 else 0.0)).cache() - -// Summarize ratings -val numRatings = ratings.count() -val numUsers = ratings.map(_.user).distinct().count() -val numMovies = ratings.map(_.product).distinct().count() -println(s"Got $numRatings ratings from $numUsers users on $numMovies movies.") - -// Build the model -val numIterations = 10 -val rank = 10 -val lambda = 0.01 -val model = ALS.train(ratings, rank, numIterations, lambda) - -// Define a function to scale ratings from 0 to 1 -def scaledRating(r: Rating): Rating = { - val scaledRating = math.max(math.min(r.rating, 1.0), 0.0) - Rating(r.user, r.product, scaledRating) -} - -// Get sorted top ten predictions for each user and then scale from [0, 1] -val userRecommended = model.recommendProductsForUsers(10).map{ case (user, recs) => - (user, recs.map(scaledRating)) -} - -// Assume that any movie a user rated 3 or higher (which maps to a 1) is a relevant document -// Compare with top ten most relevant documents -val userMovies = binarizedRatings.groupBy(_.user) -val relevantDocuments = userMovies.join(userRecommended).map{ case (user, (actual, predictions)) => - (predictions.map(_.product), actual.filter(_.rating > 0.0).map(_.product).toArray) -} - -// Instantiate metrics object -val metrics = new RankingMetrics(relevantDocuments) - -// Precision at K -Array(1, 3, 5).foreach{ k => - println(s"Precision at $k = ${metrics.precisionAt(k)}") -} - -// Mean average precision -println(s"Mean average precision = ${metrics.meanAveragePrecision}") - -// Normalized discounted cumulative gain -Array(1, 3, 5).foreach{ k => - println(s"NDCG at $k = ${metrics.ndcgAt(k)}") -} - -// Get predictions for each data point -val allPredictions = model.predict(ratings.map(r => (r.user, r.product))).map(r => ((r.user, r.product), r.rating)) -val allRatings = ratings.map(r => ((r.user, r.product), r.rating)) -val predictionsAndLabels = allPredictions.join(allRatings).map{ case ((user, product), (predicted, actual)) => - (predicted, actual) -} - -// Get the RMSE using regression metrics -val regressionMetrics = new RegressionMetrics(predictionsAndLabels) -println(s"RMSE = ${regressionMetrics.rootMeanSquaredError}") - -// R-squared -println(s"R-squared = ${regressionMetrics.r2}") - -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/mllib/RankingMetricsExample.scala %}
    Refer to the [`RegressionMetrics` Java docs](api/java/org/apache/spark/mllib/evaluation/RegressionMetrics.html) and [`RankingMetrics` Java docs](api/java/org/apache/spark/mllib/evaluation/RankingMetrics.html) for details on the API. -{% highlight java %} -import scala.Tuple2; - -import org.apache.spark.api.java.*; -import org.apache.spark.rdd.RDD; -import org.apache.spark.mllib.recommendation.MatrixFactorizationModel; -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.function.Function; -import java.util.*; -import org.apache.spark.mllib.evaluation.RegressionMetrics; -import org.apache.spark.mllib.evaluation.RankingMetrics; -import org.apache.spark.mllib.recommendation.ALS; -import org.apache.spark.mllib.recommendation.Rating; - -// Read in the ratings data -public class Ranking { - public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("Ranking Metrics"); - JavaSparkContext sc = new JavaSparkContext(conf); - String path = "data/mllib/sample_movielens_data.txt"; - JavaRDD data = sc.textFile(path); - JavaRDD ratings = data.map( - new Function() { - public Rating call(String line) { - String[] parts = line.split("::"); - return new Rating(Integer.parseInt(parts[0]), Integer.parseInt(parts[1]), Double.parseDouble(parts[2]) - 2.5); - } - } - ); - ratings.cache(); - - // Train an ALS model - final MatrixFactorizationModel model = ALS.train(JavaRDD.toRDD(ratings), 10, 10, 0.01); - - // Get top 10 recommendations for every user and scale ratings from 0 to 1 - JavaRDD> userRecs = model.recommendProductsForUsers(10).toJavaRDD(); - JavaRDD> userRecsScaled = userRecs.map( - new Function, Tuple2>() { - public Tuple2 call(Tuple2 t) { - Rating[] scaledRatings = new Rating[t._2().length]; - for (int i = 0; i < scaledRatings.length; i++) { - double newRating = Math.max(Math.min(t._2()[i].rating(), 1.0), 0.0); - scaledRatings[i] = new Rating(t._2()[i].user(), t._2()[i].product(), newRating); - } - return new Tuple2(t._1(), scaledRatings); - } - } - ); - JavaPairRDD userRecommended = JavaPairRDD.fromJavaRDD(userRecsScaled); - - // Map ratings to 1 or 0, 1 indicating a movie that should be recommended - JavaRDD binarizedRatings = ratings.map( - new Function() { - public Rating call(Rating r) { - double binaryRating; - if (r.rating() > 0.0) { - binaryRating = 1.0; - } - else { - binaryRating = 0.0; - } - return new Rating(r.user(), r.product(), binaryRating); - } - } - ); - - // Group ratings by common user - JavaPairRDD> userMovies = binarizedRatings.groupBy( - new Function() { - public Object call(Rating r) { - return r.user(); - } - } - ); - - // Get true relevant documents from all user ratings - JavaPairRDD> userMoviesList = userMovies.mapValues( - new Function, List>() { - public List call(Iterable docs) { - List products = new ArrayList(); - for (Rating r : docs) { - if (r.rating() > 0.0) { - products.add(r.product()); - } - } - return products; - } - } - ); - - // Extract the product id from each recommendation - JavaPairRDD> userRecommendedList = userRecommended.mapValues( - new Function>() { - public List call(Rating[] docs) { - List products = new ArrayList(); - for (Rating r : docs) { - products.add(r.product()); - } - return products; - } - } - ); - JavaRDD, List>> relevantDocs = userMoviesList.join(userRecommendedList).values(); - - // Instantiate the metrics object - RankingMetrics metrics = RankingMetrics.of(relevantDocs); - - // Precision and NDCG at k - Integer[] kVector = {1, 3, 5}; - for (Integer k : kVector) { - 
System.out.format("Precision at %d = %f\n", k, metrics.precisionAt(k)); - System.out.format("NDCG at %d = %f\n", k, metrics.ndcgAt(k)); - } - - // Mean average precision - System.out.format("Mean average precision = %f\n", metrics.meanAveragePrecision()); - - // Evaluate the model using numerical ratings and regression metrics - JavaRDD> userProducts = ratings.map( - new Function>() { - public Tuple2 call(Rating r) { - return new Tuple2(r.user(), r.product()); - } - } - ); - JavaPairRDD, Object> predictions = JavaPairRDD.fromJavaRDD( - model.predict(JavaRDD.toRDD(userProducts)).toJavaRDD().map( - new Function, Object>>() { - public Tuple2, Object> call(Rating r){ - return new Tuple2, Object>( - new Tuple2(r.user(), r.product()), r.rating()); - } - } - )); - JavaRDD> ratesAndPreds = - JavaPairRDD.fromJavaRDD(ratings.map( - new Function, Object>>() { - public Tuple2, Object> call(Rating r){ - return new Tuple2, Object>( - new Tuple2(r.user(), r.product()), r.rating()); - } - } - )).join(predictions).values(); - - // Create regression metrics object - RegressionMetrics regressionMetrics = new RegressionMetrics(ratesAndPreds.rdd()); - - // Root mean squared error - System.out.format("RMSE = %f\n", regressionMetrics.rootMeanSquaredError()); - - // R-squared - System.out.format("R-squared = %f\n", regressionMetrics.r2()); - } -} - -{% endhighlight %} +{% include_example java/org/apache/spark/examples/mllib/JavaRankingMetricsExample.java %}
    Refer to the [`RegressionMetrics` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.evaluation.RegressionMetrics) and [`RankingMetrics` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.evaluation.RankingMetrics) for more details on the API. -{% highlight python %} -from pyspark.mllib.recommendation import ALS, Rating -from pyspark.mllib.evaluation import RegressionMetrics, RankingMetrics - -# Read in the ratings data -lines = sc.textFile("data/mllib/sample_movielens_data.txt") - -def parseLine(line): - fields = line.split("::") - return Rating(int(fields[0]), int(fields[1]), float(fields[2]) - 2.5) -ratings = lines.map(lambda r: parseLine(r)) - -# Train a model on to predict user-product ratings -model = ALS.train(ratings, 10, 10, 0.01) - -# Get predicted ratings on all existing user-product pairs -testData = ratings.map(lambda p: (p.user, p.product)) -predictions = model.predictAll(testData).map(lambda r: ((r.user, r.product), r.rating)) - -ratingsTuple = ratings.map(lambda r: ((r.user, r.product), r.rating)) -scoreAndLabels = predictions.join(ratingsTuple).map(lambda tup: tup[1]) - -# Instantiate regression metrics to compare predicted and actual ratings -metrics = RegressionMetrics(scoreAndLabels) - -# Root mean sqaured error -print("RMSE = %s" % metrics.rootMeanSquaredError) - -# R-squared -print("R-squared = %s" % metrics.r2) - -{% endhighlight %} +{% include_example python/mllib/ranking_metrics_example.py %}
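The Python example referenced above imports `RankingMetrics` but, as its own comment notes, only exercises `RegressionMetrics` on the predicted ratings. For readers who also want the ranking-specific measures from Python, a minimal sketch built directly on the `RankingMetrics` Python API linked above might look like the following; the toy (predicted ranking, ground-truth labels) pairs and the application name are illustrative only and are not part of the example files added by this patch.

{% highlight python %}
from pyspark import SparkContext
# Assumption: pyspark.mllib.evaluation.RankingMetrics as linked in the docs above.
from pyspark.mllib.evaluation import RankingMetrics

if __name__ == "__main__":
    sc = SparkContext(appName="RankingMetricsSketch")  # illustrative app name

    # Each element pairs a predicted ranking with the list of truly relevant ids (toy data).
    predictionAndLabels = sc.parallelize([
        ([1, 6, 2, 7, 8, 3, 9, 10, 4, 5], [1, 2, 3, 4, 5]),
        ([4, 1, 5, 6, 2, 7, 3, 8, 9, 10], [1, 2, 3]),
        ([1, 2, 3, 4, 5], [])])

    metrics = RankingMetrics(predictionAndLabels)

    # Precision and NDCG at a few cutoffs, plus mean average precision.
    for k in [1, 3, 5]:
        print("Precision at %d = %s" % (k, metrics.precisionAt(k)))
        print("NDCG at %d = %s" % (k, metrics.ndcgAt(k)))
    print("Mean average precision = %s" % metrics.meanAveragePrecision)

    sc.stop()
{% endhighlight %}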
    @@ -1350,163 +582,21 @@ and evaluate the performance of the algorithm by several regression metrics.
    Refer to the [`RegressionMetrics` Scala docs](api/scala/index.html#org.apache.spark.mllib.evaluation.RegressionMetrics) for details on the API. -{% highlight scala %} -import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.regression.LinearRegressionModel -import org.apache.spark.mllib.regression.LinearRegressionWithSGD -import org.apache.spark.mllib.linalg.Vectors -import org.apache.spark.mllib.evaluation.RegressionMetrics -import org.apache.spark.mllib.util.MLUtils - -// Load the data -val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_linear_regression_data.txt").cache() - -// Build the model -val numIterations = 100 -val model = LinearRegressionWithSGD.train(data, numIterations) - -// Get predictions -val valuesAndPreds = data.map{ point => - val prediction = model.predict(point.features) - (prediction, point.label) -} - -// Instantiate metrics object -val metrics = new RegressionMetrics(valuesAndPreds) - -// Squared error -println(s"MSE = ${metrics.meanSquaredError}") -println(s"RMSE = ${metrics.rootMeanSquaredError}") - -// R-squared -println(s"R-squared = ${metrics.r2}") - -// Mean absolute error -println(s"MAE = ${metrics.meanAbsoluteError}") - -// Explained variance -println(s"Explained variance = ${metrics.explainedVariance}") - -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/mllib/RegressionMetricsExample.scala %}
    Refer to the [`RegressionMetrics` Java docs](api/java/org/apache/spark/mllib/evaluation/RegressionMetrics.html) for details on the API. -{% highlight java %} -import scala.Tuple2; - -import org.apache.spark.api.java.*; -import org.apache.spark.api.java.function.Function; -import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.regression.LinearRegressionModel; -import org.apache.spark.mllib.regression.LinearRegressionWithSGD; -import org.apache.spark.mllib.evaluation.RegressionMetrics; -import org.apache.spark.SparkConf; - -public class LinearRegression { - public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("Linear Regression Example"); - JavaSparkContext sc = new JavaSparkContext(conf); - - // Load and parse the data - String path = "data/mllib/sample_linear_regression_data.txt"; - JavaRDD data = sc.textFile(path); - JavaRDD parsedData = data.map( - new Function() { - public LabeledPoint call(String line) { - String[] parts = line.split(" "); - double[] v = new double[parts.length - 1]; - for (int i = 1; i < parts.length - 1; i++) - v[i - 1] = Double.parseDouble(parts[i].split(":")[1]); - return new LabeledPoint(Double.parseDouble(parts[0]), Vectors.dense(v)); - } - } - ); - parsedData.cache(); - - // Building the model - int numIterations = 100; - final LinearRegressionModel model = - LinearRegressionWithSGD.train(JavaRDD.toRDD(parsedData), numIterations); - - // Evaluate model on training examples and compute training error - JavaRDD> valuesAndPreds = parsedData.map( - new Function>() { - public Tuple2 call(LabeledPoint point) { - double prediction = model.predict(point.features()); - return new Tuple2(prediction, point.label()); - } - } - ); - - // Instantiate metrics object - RegressionMetrics metrics = new RegressionMetrics(valuesAndPreds.rdd()); - - // Squared error - System.out.format("MSE = %f\n", metrics.meanSquaredError()); - System.out.format("RMSE = %f\n", metrics.rootMeanSquaredError()); - - // R-squared - System.out.format("R Squared = %f\n", metrics.r2()); - - // Mean absolute error - System.out.format("MAE = %f\n", metrics.meanAbsoluteError()); - - // Explained variance - System.out.format("Explained Variance = %f\n", metrics.explainedVariance()); - - // Save and load model - model.save(sc.sc(), "myModelPath"); - LinearRegressionModel sameModel = LinearRegressionModel.load(sc.sc(), "myModelPath"); - } -} - -{% endhighlight %} +{% include_example java/org/apache/spark/examples/mllib/JavaRegressionMetricsExample.java %}
    Refer to the [`RegressionMetrics` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.evaluation.RegressionMetrics) for more details on the API. -{% highlight python %} -from pyspark.mllib.regression import LabeledPoint, LinearRegressionWithSGD -from pyspark.mllib.evaluation import RegressionMetrics -from pyspark.mllib.linalg import DenseVector - -# Load and parse the data -def parsePoint(line): - values = line.split() - return LabeledPoint(float(values[0]), DenseVector([float(x.split(':')[1]) for x in values[1:]])) - -data = sc.textFile("data/mllib/sample_linear_regression_data.txt") -parsedData = data.map(parsePoint) - -# Build the model -model = LinearRegressionWithSGD.train(parsedData) - -# Get predictions -valuesAndPreds = parsedData.map(lambda p: (float(model.predict(p.features)), p.label)) - -# Instantiate metrics object -metrics = RegressionMetrics(valuesAndPreds) - -# Squared Error -print("MSE = %s" % metrics.meanSquaredError) -print("RMSE = %s" % metrics.rootMeanSquaredError) - -# R-squared -print("R-squared = %s" % metrics.r2) - -# Mean absolute error -print("MAE = %s" % metrics.meanAbsoluteError) - -# Explained variance -print("Explained variance = %s" % metrics.explainedVariance) - -{% endhighlight %} +{% include_example python/mllib/regression_metrics_example.py %}
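Each `include_example` tag above pulls its snippet from one of the example files added below; judging from those sources, only the region between the `$example on$` and `$example off$` markers is rendered into the page, so boilerplate such as the license header and `SparkContext` setup stays out of the docs. A minimal sketch of how a Python example is marked up for this mechanism follows; the file contents are illustrative only, not one of the files added by this patch.

{% highlight python %}
from pyspark import SparkContext
# $example on$
from pyspark.mllib.evaluation import RegressionMetrics
# $example off$

if __name__ == "__main__":
    # Everything outside the markers (context creation, teardown) is omitted from the page.
    sc = SparkContext(appName="IncludeExampleSketch")  # illustrative app name
    # $example on$
    # Toy (prediction, observation) pairs just to have something to evaluate.
    valuesAndPreds = sc.parallelize([(2.5, 3.0), (0.0, -0.5), (2.0, 2.0)])
    metrics = RegressionMetrics(valuesAndPreds)
    print("RMSE = %s" % metrics.rootMeanSquaredError)
    # $example off$
    sc.stop()
{% endhighlight %}

Assuming the standard examples layout, the new Java examples can presumably be run with something like `./bin/run-example mllib.JavaBinaryClassificationMetricsExample`, and the Python ones submitted directly, e.g. `./bin/spark-submit examples/src/main/python/mllib/regression_metrics_example.py`.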
    diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaBinaryClassificationMetricsExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaBinaryClassificationMetricsExample.java new file mode 100644 index 000000000000..980a9108af53 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaBinaryClassificationMetricsExample.java @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib; + +// $example on$ +import scala.Tuple2; + +import org.apache.spark.api.java.*; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.mllib.classification.LogisticRegressionModel; +import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS; +import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.util.MLUtils; +// $example off$ +import org.apache.spark.SparkConf; +import org.apache.spark.SparkContext; + +public class JavaBinaryClassificationMetricsExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("Java Binary Classification Metrics Example"); + SparkContext sc = new SparkContext(conf); + // $example on$ + String path = "data/mllib/sample_binary_classification_data.txt"; + JavaRDD data = MLUtils.loadLibSVMFile(sc, path).toJavaRDD(); + + // Split initial RDD into two... [60% training data, 40% testing data]. + JavaRDD[] splits = + data.randomSplit(new double[]{0.6, 0.4}, 11L); + JavaRDD training = splits[0].cache(); + JavaRDD test = splits[1]; + + // Run training algorithm to build the model. + final LogisticRegressionModel model = new LogisticRegressionWithLBFGS() + .setNumClasses(2) + .run(training.rdd()); + + // Clear the prediction threshold so the model will return probabilities + model.clearThreshold(); + + // Compute raw scores on the test set. + JavaRDD> predictionAndLabels = test.map( + new Function>() { + public Tuple2 call(LabeledPoint p) { + Double prediction = model.predict(p.features()); + return new Tuple2(prediction, p.label()); + } + } + ); + + // Get evaluation metrics. 
+ BinaryClassificationMetrics metrics = new BinaryClassificationMetrics(predictionAndLabels.rdd()); + + // Precision by threshold + JavaRDD> precision = metrics.precisionByThreshold().toJavaRDD(); + System.out.println("Precision by threshold: " + precision.toArray()); + + // Recall by threshold + JavaRDD> recall = metrics.recallByThreshold().toJavaRDD(); + System.out.println("Recall by threshold: " + recall.toArray()); + + // F Score by threshold + JavaRDD> f1Score = metrics.fMeasureByThreshold().toJavaRDD(); + System.out.println("F1 Score by threshold: " + f1Score.toArray()); + + JavaRDD> f2Score = metrics.fMeasureByThreshold(2.0).toJavaRDD(); + System.out.println("F2 Score by threshold: " + f2Score.toArray()); + + // Precision-recall curve + JavaRDD> prc = metrics.pr().toJavaRDD(); + System.out.println("Precision-recall curve: " + prc.toArray()); + + // Thresholds + JavaRDD thresholds = precision.map( + new Function, Double>() { + public Double call(Tuple2 t) { + return new Double(t._1().toString()); + } + } + ); + + // ROC Curve + JavaRDD> roc = metrics.roc().toJavaRDD(); + System.out.println("ROC curve: " + roc.toArray()); + + // AUPRC + System.out.println("Area under precision-recall curve = " + metrics.areaUnderPR()); + + // AUROC + System.out.println("Area under ROC = " + metrics.areaUnderROC()); + + // Save and load model + model.save(sc, "target/tmp/LogisticRegressionModel"); + LogisticRegressionModel sameModel = LogisticRegressionModel.load(sc, + "target/tmp/LogisticRegressionModel"); + // $example off$ + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaMultiLabelClassificationMetricsExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaMultiLabelClassificationMetricsExample.java new file mode 100644 index 000000000000..b54e1ea3f2bc --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaMultiLabelClassificationMetricsExample.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.examples.mllib; + +// $example on$ +import java.util.Arrays; +import java.util.List; + +import scala.Tuple2; + +import org.apache.spark.api.java.*; +import org.apache.spark.mllib.evaluation.MultilabelMetrics; +import org.apache.spark.rdd.RDD; +import org.apache.spark.SparkConf; +// $example off$ +import org.apache.spark.SparkContext; + +public class JavaMultiLabelClassificationMetricsExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("Multilabel Classification Metrics Example"); + JavaSparkContext sc = new JavaSparkContext(conf); + // $example on$ + List> data = Arrays.asList( + new Tuple2(new double[]{0.0, 1.0}, new double[]{0.0, 2.0}), + new Tuple2(new double[]{0.0, 2.0}, new double[]{0.0, 1.0}), + new Tuple2(new double[]{}, new double[]{0.0}), + new Tuple2(new double[]{2.0}, new double[]{2.0}), + new Tuple2(new double[]{2.0, 0.0}, new double[]{2.0, 0.0}), + new Tuple2(new double[]{0.0, 1.0, 2.0}, new double[]{0.0, 1.0}), + new Tuple2(new double[]{1.0}, new double[]{1.0, 2.0}) + ); + JavaRDD> scoreAndLabels = sc.parallelize(data); + + // Instantiate metrics object + MultilabelMetrics metrics = new MultilabelMetrics(scoreAndLabels.rdd()); + + // Summary stats + System.out.format("Recall = %f\n", metrics.recall()); + System.out.format("Precision = %f\n", metrics.precision()); + System.out.format("F1 measure = %f\n", metrics.f1Measure()); + System.out.format("Accuracy = %f\n", metrics.accuracy()); + + // Stats by labels + for (int i = 0; i < metrics.labels().length - 1; i++) { + System.out.format("Class %1.1f precision = %f\n", metrics.labels()[i], metrics.precision + (metrics.labels()[i])); + System.out.format("Class %1.1f recall = %f\n", metrics.labels()[i], metrics.recall(metrics + .labels()[i])); + System.out.format("Class %1.1f F1 score = %f\n", metrics.labels()[i], metrics.f1Measure + (metrics.labels()[i])); + } + + // Micro stats + System.out.format("Micro recall = %f\n", metrics.microRecall()); + System.out.format("Micro precision = %f\n", metrics.microPrecision()); + System.out.format("Micro F1 measure = %f\n", metrics.microF1Measure()); + + // Hamming loss + System.out.format("Hamming loss = %f\n", metrics.hammingLoss()); + + // Subset accuracy + System.out.format("Subset accuracy = %f\n", metrics.subsetAccuracy()); + // $example off$ + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaMulticlassClassificationMetricsExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaMulticlassClassificationMetricsExample.java new file mode 100644 index 000000000000..21f628fb51b6 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaMulticlassClassificationMetricsExample.java @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib; + +// $example on$ +import scala.Tuple2; + +import org.apache.spark.api.java.*; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.mllib.classification.LogisticRegressionModel; +import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS; +import org.apache.spark.mllib.evaluation.MulticlassMetrics; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.util.MLUtils; +import org.apache.spark.mllib.linalg.Matrix; +// $example off$ +import org.apache.spark.SparkConf; +import org.apache.spark.SparkContext; + +public class JavaMulticlassClassificationMetricsExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("Multi class Classification Metrics Example"); + SparkContext sc = new SparkContext(conf); + // $example on$ + String path = "data/mllib/sample_multiclass_classification_data.txt"; + JavaRDD data = MLUtils.loadLibSVMFile(sc, path).toJavaRDD(); + + // Split initial RDD into two... [60% training data, 40% testing data]. + JavaRDD[] splits = data.randomSplit(new double[]{0.6, 0.4}, 11L); + JavaRDD training = splits[0].cache(); + JavaRDD test = splits[1]; + + // Run training algorithm to build the model. + final LogisticRegressionModel model = new LogisticRegressionWithLBFGS() + .setNumClasses(3) + .run(training.rdd()); + + // Compute raw scores on the test set. + JavaRDD> predictionAndLabels = test.map( + new Function>() { + public Tuple2 call(LabeledPoint p) { + Double prediction = model.predict(p.features()); + return new Tuple2(prediction, p.label()); + } + } + ); + + // Get evaluation metrics. 
+ MulticlassMetrics metrics = new MulticlassMetrics(predictionAndLabels.rdd()); + + // Confusion matrix + Matrix confusion = metrics.confusionMatrix(); + System.out.println("Confusion matrix: \n" + confusion); + + // Overall statistics + System.out.println("Precision = " + metrics.precision()); + System.out.println("Recall = " + metrics.recall()); + System.out.println("F1 Score = " + metrics.fMeasure()); + + // Stats by labels + for (int i = 0; i < metrics.labels().length; i++) { + System.out.format("Class %f precision = %f\n", metrics.labels()[i],metrics.precision + (metrics.labels()[i])); + System.out.format("Class %f recall = %f\n", metrics.labels()[i], metrics.recall(metrics + .labels()[i])); + System.out.format("Class %f F1 score = %f\n", metrics.labels()[i], metrics.fMeasure + (metrics.labels()[i])); + } + + //Weighted stats + System.out.format("Weighted precision = %f\n", metrics.weightedPrecision()); + System.out.format("Weighted recall = %f\n", metrics.weightedRecall()); + System.out.format("Weighted F1 score = %f\n", metrics.weightedFMeasure()); + System.out.format("Weighted false positive rate = %f\n", metrics.weightedFalsePositiveRate()); + + // Save and load model + model.save(sc, "target/tmp/LogisticRegressionModel"); + LogisticRegressionModel sameModel = LogisticRegressionModel.load(sc, + "target/tmp/LogisticRegressionModel"); + // $example off$ + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaRankingMetricsExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaRankingMetricsExample.java new file mode 100644 index 000000000000..7c4c97e74681 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaRankingMetricsExample.java @@ -0,0 +1,176 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.examples.mllib; + +// $example on$ +import java.util.*; + +import scala.Tuple2; + +import org.apache.spark.api.java.*; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.mllib.evaluation.RegressionMetrics; +import org.apache.spark.mllib.evaluation.RankingMetrics; +import org.apache.spark.mllib.recommendation.ALS; +import org.apache.spark.mllib.recommendation.MatrixFactorizationModel; +import org.apache.spark.mllib.recommendation.Rating; +// $example off$ +import org.apache.spark.SparkConf; + +public class JavaRankingMetricsExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("Java Ranking Metrics Example"); + JavaSparkContext sc = new JavaSparkContext(conf); + // $example on$ + String path = "data/mllib/sample_movielens_data.txt"; + JavaRDD data = sc.textFile(path); + JavaRDD ratings = data.map( + new Function() { + public Rating call(String line) { + String[] parts = line.split("::"); + return new Rating(Integer.parseInt(parts[0]), Integer.parseInt(parts[1]), Double + .parseDouble(parts[2]) - 2.5); + } + } + ); + ratings.cache(); + + // Train an ALS model + final MatrixFactorizationModel model = ALS.train(JavaRDD.toRDD(ratings), 10, 10, 0.01); + + // Get top 10 recommendations for every user and scale ratings from 0 to 1 + JavaRDD> userRecs = model.recommendProductsForUsers(10).toJavaRDD(); + JavaRDD> userRecsScaled = userRecs.map( + new Function, Tuple2>() { + public Tuple2 call(Tuple2 t) { + Rating[] scaledRatings = new Rating[t._2().length]; + for (int i = 0; i < scaledRatings.length; i++) { + double newRating = Math.max(Math.min(t._2()[i].rating(), 1.0), 0.0); + scaledRatings[i] = new Rating(t._2()[i].user(), t._2()[i].product(), newRating); + } + return new Tuple2(t._1(), scaledRatings); + } + } + ); + JavaPairRDD userRecommended = JavaPairRDD.fromJavaRDD(userRecsScaled); + + // Map ratings to 1 or 0, 1 indicating a movie that should be recommended + JavaRDD binarizedRatings = ratings.map( + new Function() { + public Rating call(Rating r) { + double binaryRating; + if (r.rating() > 0.0) { + binaryRating = 1.0; + } else { + binaryRating = 0.0; + } + return new Rating(r.user(), r.product(), binaryRating); + } + } + ); + + // Group ratings by common user + JavaPairRDD> userMovies = binarizedRatings.groupBy( + new Function() { + public Object call(Rating r) { + return r.user(); + } + } + ); + + // Get true relevant documents from all user ratings + JavaPairRDD> userMoviesList = userMovies.mapValues( + new Function, List>() { + public List call(Iterable docs) { + List products = new ArrayList(); + for (Rating r : docs) { + if (r.rating() > 0.0) { + products.add(r.product()); + } + } + return products; + } + } + ); + + // Extract the product id from each recommendation + JavaPairRDD> userRecommendedList = userRecommended.mapValues( + new Function>() { + public List call(Rating[] docs) { + List products = new ArrayList(); + for (Rating r : docs) { + products.add(r.product()); + } + return products; + } + } + ); + JavaRDD, List>> relevantDocs = userMoviesList.join + (userRecommendedList).values(); + + // Instantiate the metrics object + RankingMetrics metrics = RankingMetrics.of(relevantDocs); + + // Precision and NDCG at k + Integer[] kVector = {1, 3, 5}; + for (Integer k : kVector) { + System.out.format("Precision at %d = %f\n", k, metrics.precisionAt(k)); + System.out.format("NDCG at %d = %f\n", k, metrics.ndcgAt(k)); + } + + // Mean average precision + System.out.format("Mean average 
precision = %f\n", metrics.meanAveragePrecision()); + + // Evaluate the model using numerical ratings and regression metrics + JavaRDD> userProducts = ratings.map( + new Function>() { + public Tuple2 call(Rating r) { + return new Tuple2(r.user(), r.product()); + } + } + ); + JavaPairRDD, Object> predictions = JavaPairRDD.fromJavaRDD( + model.predict(JavaRDD.toRDD(userProducts)).toJavaRDD().map( + new Function, Object>>() { + public Tuple2, Object> call(Rating r) { + return new Tuple2, Object>( + new Tuple2(r.user(), r.product()), r.rating()); + } + } + )); + JavaRDD> ratesAndPreds = + JavaPairRDD.fromJavaRDD(ratings.map( + new Function, Object>>() { + public Tuple2, Object> call(Rating r) { + return new Tuple2, Object>( + new Tuple2(r.user(), r.product()), r.rating()); + } + } + )).join(predictions).values(); + + // Create regression metrics object + RegressionMetrics regressionMetrics = new RegressionMetrics(ratesAndPreds.rdd()); + + // Root mean squared error + System.out.format("RMSE = %f\n", regressionMetrics.rootMeanSquaredError()); + + // R-squared + System.out.format("R-squared = %f\n", regressionMetrics.r2()); + // $example off$ + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaRegressionMetricsExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaRegressionMetricsExample.java new file mode 100644 index 000000000000..d2efc6bf9777 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaRegressionMetricsExample.java @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.examples.mllib; + +// $example on$ +import scala.Tuple2; + +import org.apache.spark.api.java.*; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.regression.LinearRegressionModel; +import org.apache.spark.mllib.regression.LinearRegressionWithSGD; +import org.apache.spark.mllib.evaluation.RegressionMetrics; +import org.apache.spark.SparkConf; +// $example off$ + +public class JavaRegressionMetricsExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("Java Regression Metrics Example"); + JavaSparkContext sc = new JavaSparkContext(conf); + // $example on$ + // Load and parse the data + String path = "data/mllib/sample_linear_regression_data.txt"; + JavaRDD data = sc.textFile(path); + JavaRDD parsedData = data.map( + new Function() { + public LabeledPoint call(String line) { + String[] parts = line.split(" "); + double[] v = new double[parts.length - 1]; + for (int i = 1; i < parts.length - 1; i++) + v[i - 1] = Double.parseDouble(parts[i].split(":")[1]); + return new LabeledPoint(Double.parseDouble(parts[0]), Vectors.dense(v)); + } + } + ); + parsedData.cache(); + + // Building the model + int numIterations = 100; + final LinearRegressionModel model = LinearRegressionWithSGD.train(JavaRDD.toRDD(parsedData), + numIterations); + + // Evaluate model on training examples and compute training error + JavaRDD> valuesAndPreds = parsedData.map( + new Function>() { + public Tuple2 call(LabeledPoint point) { + double prediction = model.predict(point.features()); + return new Tuple2(prediction, point.label()); + } + } + ); + + // Instantiate metrics object + RegressionMetrics metrics = new RegressionMetrics(valuesAndPreds.rdd()); + + // Squared error + System.out.format("MSE = %f\n", metrics.meanSquaredError()); + System.out.format("RMSE = %f\n", metrics.rootMeanSquaredError()); + + // R-squared + System.out.format("R Squared = %f\n", metrics.r2()); + + // Mean absolute error + System.out.format("MAE = %f\n", metrics.meanAbsoluteError()); + + // Explained variance + System.out.format("Explained Variance = %f\n", metrics.explainedVariance()); + + // Save and load model + model.save(sc.sc(), "target/tmp/LogisticRegressionModel"); + LinearRegressionModel sameModel = LinearRegressionModel.load(sc.sc(), + "target/tmp/LogisticRegressionModel"); + // $example off$ + } +} diff --git a/examples/src/main/python/mllib/binary_classification_metrics_example.py b/examples/src/main/python/mllib/binary_classification_metrics_example.py new file mode 100644 index 000000000000..437acb998acc --- /dev/null +++ b/examples/src/main/python/mllib/binary_classification_metrics_example.py @@ -0,0 +1,55 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +Binary Classification Metrics Example. +""" +from __future__ import print_function +import sys +from pyspark import SparkContext, SQLContext +# $example on$ +from pyspark.mllib.classification import LogisticRegressionWithLBFGS +from pyspark.mllib.evaluation import BinaryClassificationMetrics +from pyspark.mllib.util import MLUtils +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="BinaryClassificationMetricsExample") + sqlContext = SQLContext(sc) + # $example on$ + # Several of the methods available in scala are currently missing from pyspark + # Load training data in LIBSVM format + data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_binary_classification_data.txt") + + # Split data into training (60%) and test (40%) + training, test = data.randomSplit([0.6, 0.4], seed=11L) + training.cache() + + # Run training algorithm to build the model + model = LogisticRegressionWithLBFGS.train(training) + + # Compute raw scores on the test set + predictionAndLabels = test.map(lambda lp: (float(model.predict(lp.features)), lp.label)) + + # Instantiate metrics object + metrics = BinaryClassificationMetrics(predictionAndLabels) + + # Area under precision-recall curve + print("Area under PR = %s" % metrics.areaUnderPR) + + # Area under ROC curve + print("Area under ROC = %s" % metrics.areaUnderROC) + # $example off$ diff --git a/examples/src/main/python/mllib/multi_class_metrics_example.py b/examples/src/main/python/mllib/multi_class_metrics_example.py new file mode 100644 index 000000000000..cd56b3c97c77 --- /dev/null +++ b/examples/src/main/python/mllib/multi_class_metrics_example.py @@ -0,0 +1,69 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +# $example on$ +from pyspark.mllib.classification import LogisticRegressionWithLBFGS +from pyspark.mllib.util import MLUtils +from pyspark.mllib.evaluation import MulticlassMetrics +# $example off$ + +from pyspark import SparkContext + +if __name__ == "__main__": + sc = SparkContext(appName="MultiClassMetricsExample") + + # Several of the methods available in scala are currently missing from pyspark + # $example on$ + # Load training data in LIBSVM format + data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_multiclass_classification_data.txt") + + # Split data into training (60%) and test (40%) + training, test = data.randomSplit([0.6, 0.4], seed=11L) + training.cache() + + # Run training algorithm to build the model + model = LogisticRegressionWithLBFGS.train(training, numClasses=3) + + # Compute raw scores on the test set + predictionAndLabels = test.map(lambda lp: (float(model.predict(lp.features)), lp.label)) + + # Instantiate metrics object + metrics = MulticlassMetrics(predictionAndLabels) + + # Overall statistics + precision = metrics.precision() + recall = metrics.recall() + f1Score = metrics.fMeasure() + print("Summary Stats") + print("Precision = %s" % precision) + print("Recall = %s" % recall) + print("F1 Score = %s" % f1Score) + + # Statistics by class + labels = data.map(lambda lp: lp.label).distinct().collect() + for label in sorted(labels): + print("Class %s precision = %s" % (label, metrics.precision(label))) + print("Class %s recall = %s" % (label, metrics.recall(label))) + print("Class %s F1 Measure = %s" % (label, metrics.fMeasure(label, beta=1.0))) + + # Weighted stats + print("Weighted recall = %s" % metrics.weightedRecall) + print("Weighted precision = %s" % metrics.weightedPrecision) + print("Weighted F(1) Score = %s" % metrics.weightedFMeasure()) + print("Weighted F(0.5) Score = %s" % metrics.weightedFMeasure(beta=0.5)) + print("Weighted false positive rate = %s" % metrics.weightedFalsePositiveRate) + # $example off$ diff --git a/examples/src/main/python/mllib/multi_label_metrics_example.py b/examples/src/main/python/mllib/multi_label_metrics_example.py new file mode 100644 index 000000000000..960ade659737 --- /dev/null +++ b/examples/src/main/python/mllib/multi_label_metrics_example.py @@ -0,0 +1,61 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +# $example on$ +from pyspark.mllib.evaluation import MultilabelMetrics +# $example off$ +from pyspark import SparkContext + +if __name__ == "__main__": + sc = SparkContext(appName="MultiLabelMetricsExample") + # $example on$ + scoreAndLabels = sc.parallelize([ + ([0.0, 1.0], [0.0, 2.0]), + ([0.0, 2.0], [0.0, 1.0]), + ([], [0.0]), + ([2.0], [2.0]), + ([2.0, 0.0], [2.0, 0.0]), + ([0.0, 1.0, 2.0], [0.0, 1.0]), + ([1.0], [1.0, 2.0])]) + + # Instantiate metrics object + metrics = MultilabelMetrics(scoreAndLabels) + + # Summary stats + print("Recall = %s" % metrics.recall()) + print("Precision = %s" % metrics.precision()) + print("F1 measure = %s" % metrics.f1Measure()) + print("Accuracy = %s" % metrics.accuracy) + + # Individual label stats + labels = scoreAndLabels.flatMap(lambda x: x[1]).distinct().collect() + for label in labels: + print("Class %s precision = %s" % (label, metrics.precision(label))) + print("Class %s recall = %s" % (label, metrics.recall(label))) + print("Class %s F1 Measure = %s" % (label, metrics.f1Measure(label))) + + # Micro stats + print("Micro precision = %s" % metrics.microPrecision) + print("Micro recall = %s" % metrics.microRecall) + print("Micro F1 measure = %s" % metrics.microF1Measure) + + # Hamming loss + print("Hamming loss = %s" % metrics.hammingLoss) + + # Subset accuracy + print("Subset accuracy = %s" % metrics.subsetAccuracy) + # $example off$ diff --git a/examples/src/main/python/mllib/ranking_metrics_example.py b/examples/src/main/python/mllib/ranking_metrics_example.py new file mode 100644 index 000000000000..327791966c90 --- /dev/null +++ b/examples/src/main/python/mllib/ranking_metrics_example.py @@ -0,0 +1,55 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+#
+
+# $example on$
+from pyspark.mllib.recommendation import ALS, Rating
+from pyspark.mllib.evaluation import RegressionMetrics, RankingMetrics
+# $example off$
+from pyspark import SparkContext
+
+if __name__ == "__main__":
+    sc = SparkContext(appName="Ranking Metrics Example")
+
+    # Several of the methods available in scala are currently missing from pyspark
+    # $example on$
+    # Read in the ratings data
+    lines = sc.textFile("data/mllib/sample_movielens_data.txt")
+
+    def parseLine(line):
+        fields = line.split("::")
+        return Rating(int(fields[0]), int(fields[1]), float(fields[2]) - 2.5)
+    ratings = lines.map(lambda r: parseLine(r))
+
+    # Train a model to predict user-product ratings
+    model = ALS.train(ratings, 10, 10, 0.01)
+
+    # Get predicted ratings on all existing user-product pairs
+    testData = ratings.map(lambda p: (p.user, p.product))
+    predictions = model.predictAll(testData).map(lambda r: ((r.user, r.product), r.rating))
+
+    ratingsTuple = ratings.map(lambda r: ((r.user, r.product), r.rating))
+    scoreAndLabels = predictions.join(ratingsTuple).map(lambda tup: tup[1])
+
+    # Instantiate regression metrics to compare predicted and actual ratings
+    metrics = RegressionMetrics(scoreAndLabels)
+
+    # Root mean squared error
+    print("RMSE = %s" % metrics.rootMeanSquaredError)
+
+    # R-squared
+    print("R-squared = %s" % metrics.r2)
+    # $example off$
diff --git a/examples/src/main/python/mllib/regression_metrics_example.py b/examples/src/main/python/mllib/regression_metrics_example.py
new file mode 100644
index 000000000000..a3a83aafd7a1
--- /dev/null
+++ b/examples/src/main/python/mllib/regression_metrics_example.py
@@ -0,0 +1,59 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# +# $example on$ +from pyspark.mllib.regression import LabeledPoint, LinearRegressionWithSGD +from pyspark.mllib.evaluation import RegressionMetrics +from pyspark.mllib.linalg import DenseVector +# $example off$ + +from pyspark import SparkContext + +if __name__ == "__main__": + sc = SparkContext(appName="Regression Metrics Example") + + # $example on$ + # Load and parse the data + def parsePoint(line): + values = line.split() + return LabeledPoint(float(values[0]), + DenseVector([float(x.split(':')[1]) for x in values[1:]])) + + data = sc.textFile("data/mllib/sample_linear_regression_data.txt") + parsedData = data.map(parsePoint) + + # Build the model + model = LinearRegressionWithSGD.train(parsedData) + + # Get predictions + valuesAndPreds = parsedData.map(lambda p: (float(model.predict(p.features)), p.label)) + + # Instantiate metrics object + metrics = RegressionMetrics(valuesAndPreds) + + # Squared Error + print("MSE = %s" % metrics.meanSquaredError) + print("RMSE = %s" % metrics.rootMeanSquaredError) + + # R-squared + print("R-squared = %s" % metrics.r2) + + # Mean absolute error + print("MAE = %s" % metrics.meanAbsoluteError) + + # Explained variance + print("Explained variance = %s" % metrics.explainedVariance) + # $example off$ diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassificationMetricsExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassificationMetricsExample.scala new file mode 100644 index 000000000000..13a37827ab93 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassificationMetricsExample.scala @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+// scalastyle:off println
+package org.apache.spark.examples.mllib
+
+// $example on$
+import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
+import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.util.MLUtils
+// $example off$
+import org.apache.spark.{SparkContext, SparkConf}
+
+object BinaryClassificationMetricsExample {
+
+  def main(args: Array[String]): Unit = {
+
+    val conf = new SparkConf().setAppName("BinaryClassificationMetricsExample")
+    val sc = new SparkContext(conf)
+    // $example on$
+    // Load training data in LIBSVM format
+    val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_binary_classification_data.txt")
+
+    // Split data into training (60%) and test (40%)
+    val Array(training, test) = data.randomSplit(Array(0.6, 0.4), seed = 11L)
+    training.cache()
+
+    // Run training algorithm to build the model
+    val model = new LogisticRegressionWithLBFGS()
+      .setNumClasses(2)
+      .run(training)
+
+    // Clear the prediction threshold so the model will return probabilities
+    model.clearThreshold
+
+    // Compute raw scores on the test set
+    val predictionAndLabels = test.map { case LabeledPoint(label, features) =>
+      val prediction = model.predict(features)
+      (prediction, label)
+    }
+
+    // Instantiate metrics object
+    val metrics = new BinaryClassificationMetrics(predictionAndLabels)
+
+    // Precision by threshold
+    val precision = metrics.precisionByThreshold
+    precision.foreach { case (t, p) =>
+      println(s"Threshold: $t, Precision: $p")
+    }
+
+    // Recall by threshold
+    val recall = metrics.recallByThreshold
+    recall.foreach { case (t, r) =>
+      println(s"Threshold: $t, Recall: $r")
+    }
+
+    // Precision-Recall Curve
+    val PRC = metrics.pr
+
+    // F-measure
+    val f1Score = metrics.fMeasureByThreshold
+    f1Score.foreach { case (t, f) =>
+      println(s"Threshold: $t, F-score: $f, Beta = 1")
+    }
+
+    val beta = 0.5
+    val fScore = metrics.fMeasureByThreshold(beta)
+    fScore.foreach { case (t, f) =>
+      println(s"Threshold: $t, F-score: $f, Beta = 0.5")
+    }
+
+    // AUPRC
+    val auPRC = metrics.areaUnderPR
+    println("Area under precision-recall curve = " + auPRC)
+
+    // Compute thresholds used in ROC and PR curves
+    val thresholds = precision.map(_._1)
+
+    // ROC Curve
+    val roc = metrics.roc
+
+    // AUROC
+    val auROC = metrics.areaUnderROC
+    println("Area under ROC = " + auROC)
+    // $example off$
+  }
+}
+// scalastyle:on println
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/MultiLabelMetricsExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/MultiLabelMetricsExample.scala
new file mode 100644
index 000000000000..4503c15360ad
--- /dev/null
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/MultiLabelMetricsExample.scala
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.mllib + +// $example on$ +import org.apache.spark.mllib.evaluation.MultilabelMetrics +import org.apache.spark.rdd.RDD +// $example off$ +import org.apache.spark.{SparkContext, SparkConf} + +object MultiLabelMetricsExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("MultiLabelMetricsExample") + val sc = new SparkContext(conf) + // $example on$ + val scoreAndLabels: RDD[(Array[Double], Array[Double])] = sc.parallelize( + Seq((Array(0.0, 1.0), Array(0.0, 2.0)), + (Array(0.0, 2.0), Array(0.0, 1.0)), + (Array.empty[Double], Array(0.0)), + (Array(2.0), Array(2.0)), + (Array(2.0, 0.0), Array(2.0, 0.0)), + (Array(0.0, 1.0, 2.0), Array(0.0, 1.0)), + (Array(1.0), Array(1.0, 2.0))), 2) + + // Instantiate metrics object + val metrics = new MultilabelMetrics(scoreAndLabels) + + // Summary stats + println(s"Recall = ${metrics.recall}") + println(s"Precision = ${metrics.precision}") + println(s"F1 measure = ${metrics.f1Measure}") + println(s"Accuracy = ${metrics.accuracy}") + + // Individual label stats + metrics.labels.foreach(label => + println(s"Class $label precision = ${metrics.precision(label)}")) + metrics.labels.foreach(label => println(s"Class $label recall = ${metrics.recall(label)}")) + metrics.labels.foreach(label => println(s"Class $label F1-score = ${metrics.f1Measure(label)}")) + + // Micro stats + println(s"Micro recall = ${metrics.microRecall}") + println(s"Micro precision = ${metrics.microPrecision}") + println(s"Micro F1 measure = ${metrics.microF1Measure}") + + // Hamming loss + println(s"Hamming loss = ${metrics.hammingLoss}") + + // Subset accuracy + println(s"Subset accuracy = ${metrics.subsetAccuracy}") + // $example off$ + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/MulticlassMetricsExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/MulticlassMetricsExample.scala new file mode 100644 index 000000000000..090444924598 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/MulticlassMetricsExample.scala @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +// scalastyle:off println +package org.apache.spark.examples.mllib + +// $example on$ +import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS +import org.apache.spark.mllib.evaluation.MulticlassMetrics +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.util.MLUtils +// $example off$ +import org.apache.spark.{SparkContext, SparkConf} + +object MulticlassMetricsExample { + + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("MulticlassMetricsExample") + val sc = new SparkContext(conf) + + // $example on$ + // Load training data in LIBSVM format + val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_multiclass_classification_data.txt") + + // Split data into training (60%) and test (40%) + val Array(training, test) = data.randomSplit(Array(0.6, 0.4), seed = 11L) + training.cache() + + // Run training algorithm to build the model + val model = new LogisticRegressionWithLBFGS() + .setNumClasses(3) + .run(training) + + // Compute raw scores on the test set + val predictionAndLabels = test.map { case LabeledPoint(label, features) => + val prediction = model.predict(features) + (prediction, label) + } + + // Instantiate metrics object + val metrics = new MulticlassMetrics(predictionAndLabels) + + // Confusion matrix + println("Confusion matrix:") + println(metrics.confusionMatrix) + + // Overall Statistics + val precision = metrics.precision + val recall = metrics.recall // same as true positive rate + val f1Score = metrics.fMeasure + println("Summary Statistics") + println(s"Precision = $precision") + println(s"Recall = $recall") + println(s"F1 Score = $f1Score") + + // Precision by label + val labels = metrics.labels + labels.foreach { l => + println(s"Precision($l) = " + metrics.precision(l)) + } + + // Recall by label + labels.foreach { l => + println(s"Recall($l) = " + metrics.recall(l)) + } + + // False positive rate by label + labels.foreach { l => + println(s"FPR($l) = " + metrics.falsePositiveRate(l)) + } + + // F-measure by label + labels.foreach { l => + println(s"F1-Score($l) = " + metrics.fMeasure(l)) + } + + // Weighted stats + println(s"Weighted precision: ${metrics.weightedPrecision}") + println(s"Weighted recall: ${metrics.weightedRecall}") + println(s"Weighted F1 score: ${metrics.weightedFMeasure}") + println(s"Weighted false positive rate: ${metrics.weightedFalsePositiveRate}") + // $example off$ + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/RankingMetricsExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/RankingMetricsExample.scala new file mode 100644 index 000000000000..cffa03d5cc9f --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/RankingMetricsExample.scala @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.mllib + +// $example on$ +import org.apache.spark.mllib.evaluation.{RegressionMetrics, RankingMetrics} +import org.apache.spark.mllib.recommendation.{ALS, Rating} +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkContext, SparkConf} + +object RankingMetricsExample { + def main(args: Array[String]) { + val conf = new SparkConf().setAppName("RankingMetricsExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + import sqlContext.implicits._ + // $example on$ + // Read in the ratings data + val ratings = sc.textFile("data/mllib/sample_movielens_data.txt").map { line => + val fields = line.split("::") + Rating(fields(0).toInt, fields(1).toInt, fields(2).toDouble - 2.5) + }.cache() + + // Map ratings to 1 or 0, 1 indicating a movie that should be recommended + val binarizedRatings = ratings.map(r => Rating(r.user, r.product, + if (r.rating > 0) 1.0 else 0.0)).cache() + + // Summarize ratings + val numRatings = ratings.count() + val numUsers = ratings.map(_.user).distinct().count() + val numMovies = ratings.map(_.product).distinct().count() + println(s"Got $numRatings ratings from $numUsers users on $numMovies movies.") + + // Build the model + val numIterations = 10 + val rank = 10 + val lambda = 0.01 + val model = ALS.train(ratings, rank, numIterations, lambda) + + // Define a function to scale ratings from 0 to 1 + def scaledRating(r: Rating): Rating = { + val scaledRating = math.max(math.min(r.rating, 1.0), 0.0) + Rating(r.user, r.product, scaledRating) + } + + // Get sorted top ten predictions for each user and then scale from [0, 1] + val userRecommended = model.recommendProductsForUsers(10).map { case (user, recs) => + (user, recs.map(scaledRating)) + } + + // Assume that any movie a user rated 3 or higher (which maps to a 1) is a relevant document + // Compare with top ten most relevant documents + val userMovies = binarizedRatings.groupBy(_.user) + val relevantDocuments = userMovies.join(userRecommended).map { case (user, (actual, + predictions)) => + (predictions.map(_.product), actual.filter(_.rating > 0.0).map(_.product).toArray) + } + + // Instantiate metrics object + val metrics = new RankingMetrics(relevantDocuments) + + // Precision at K + Array(1, 3, 5).foreach { k => + println(s"Precision at $k = ${metrics.precisionAt(k)}") + } + + // Mean average precision + println(s"Mean average precision = ${metrics.meanAveragePrecision}") + + // Normalized discounted cumulative gain + Array(1, 3, 5).foreach { k => + println(s"NDCG at $k = ${metrics.ndcgAt(k)}") + } + + // Get predictions for each data point + val allPredictions = model.predict(ratings.map(r => (r.user, r.product))).map(r => ((r.user, + r.product), r.rating)) + val allRatings = ratings.map(r => ((r.user, r.product), r.rating)) + val predictionsAndLabels = allPredictions.join(allRatings).map { case ((user, product), + (predicted, actual)) => + (predicted, actual) + } + + // Get the RMSE using regression metrics + val regressionMetrics = new RegressionMetrics(predictionsAndLabels) + println(s"RMSE = ${regressionMetrics.rootMeanSquaredError}") + + // R-squared + println(s"R-squared = ${regressionMetrics.r2}") + // $example off$ + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/RegressionMetricsExample.scala 
b/examples/src/main/scala/org/apache/spark/examples/mllib/RegressionMetricsExample.scala new file mode 100644 index 000000000000..47d44532521c --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/RegressionMetricsExample.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +// scalastyle:off println + +package org.apache.spark.examples.mllib + +// $example on$ +import org.apache.spark.mllib.regression.LinearRegressionWithSGD +import org.apache.spark.mllib.evaluation.RegressionMetrics +import org.apache.spark.mllib.util.MLUtils +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object RegressionMetricsExample { + def main(args: Array[String]) : Unit = { + val conf = new SparkConf().setAppName("RegressionMetricsExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + // $example on$ + // Load the data + val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_linear_regression_data.txt").cache() + + // Build the model + val numIterations = 100 + val model = LinearRegressionWithSGD.train(data, numIterations) + + // Get predictions + val valuesAndPreds = data.map{ point => + val prediction = model.predict(point.features) + (prediction, point.label) + } + + // Instantiate metrics object + val metrics = new RegressionMetrics(valuesAndPreds) + + // Squared error + println(s"MSE = ${metrics.meanSquaredError}") + println(s"RMSE = ${metrics.rootMeanSquaredError}") + + // R-squared + println(s"R-squared = ${metrics.r2}") + + // Mean absolute error + println(s"MAE = ${metrics.meanAbsoluteError}") + + // Explained variance + println(s"Explained variance = ${metrics.explainedVariance}") + // $example off$ + } +} +// scalastyle:on println + From 58b4e4f88a330135c4cec04a30d24ef91bc61d91 Mon Sep 17 00:00:00 2001 From: Nong Li Date: Fri, 20 Nov 2015 15:30:53 -0800 Subject: [PATCH 1485/1824] [SPARK-11787][SPARK-11883][SQL][FOLLOW-UP] Cleanup for this patch. This mainly moves SqlNewHadoopRDD to the sql package. There is some state that is shared between core and I've left that in core. This allows some other associated minor cleanup. Author: Nong Li Closes #9845 from nongli/spark-11787. 
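The shared state mentioned above is essentially a thread-local variable: the task that opens a Hadoop split records the current file name before iterating it, and SQL expressions such as InputFileName read that value back on the same thread. A minimal, self-contained Scala sketch of just that pattern (not the code in this patch) is shown below; the object name, method names and file path are hypothetical, and it uses a plain String where SqlNewHadoopRDDState keeps a UTF8String.

```
// A sketch of the thread-local "current input file" pattern, with hypothetical names.
object CurrentInputFile {
  // One value per thread; an empty string means "no file is being read".
  private val holder: ThreadLocal[String] = new ThreadLocal[String] {
    override protected def initialValue(): String = ""
  }

  def get(): String = holder.get()

  // Called by the reader before handing out an iterator for a split.
  def set(path: String): Unit = holder.set(path)

  // Called when the reader for a split is closed.
  def unset(): Unit = holder.remove()
}

object CurrentInputFileDemo {
  def main(args: Array[String]): Unit = {
    CurrentInputFile.set("file:///tmp/part-00000.parquet") // hypothetical path
    // Anything that later runs on the same thread can report which file is being read.
    println(s"Reading: ${CurrentInputFile.get()}")
    CurrentInputFile.unset()
  }
}
```

Keeping the state holder in core while the RDD itself moves into the sql package is what lets both sides read and write the same value without a circular dependency.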
--- .../org/apache/spark/rdd/HadoopRDD.scala | 6 +- .../spark/rdd/SqlNewHadoopRDDState.scala | 41 +++++++++++++ .../sql/catalyst/expressions/UnsafeRow.java | 59 ++++++++++++++---- .../catalyst/expressions/InputFileName.scala | 6 +- .../parquet/UnsafeRowParquetRecordReader.java | 14 +++++ .../scala/org/apache/spark/sql/SQLConf.scala | 5 ++ .../datasources}/SqlNewHadoopRDD.scala | 60 +++++++------------ .../datasources/parquet/ParquetRelation.scala | 2 +- .../parquet/ParquetFilterSuite.scala | 43 ++++++------- .../datasources/parquet/ParquetIOSuite.scala | 19 ++++++ 10 files changed, 175 insertions(+), 80 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDDState.scala rename {core/src/main/scala/org/apache/spark/rdd => sql/core/src/main/scala/org/apache/spark/sql/execution/datasources}/SqlNewHadoopRDD.scala (86%) diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index 7db583468792..f37c95bedc0a 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -215,8 +215,8 @@ class HadoopRDD[K, V]( // Sets the thread local variable for the file's name split.inputSplit.value match { - case fs: FileSplit => SqlNewHadoopRDD.setInputFileName(fs.getPath.toString) - case _ => SqlNewHadoopRDD.unsetInputFileName() + case fs: FileSplit => SqlNewHadoopRDDState.setInputFileName(fs.getPath.toString) + case _ => SqlNewHadoopRDDState.unsetInputFileName() } // Find a function that will return the FileSystem bytes read by this thread. Do this before @@ -256,7 +256,7 @@ class HadoopRDD[K, V]( override def close() { if (reader != null) { - SqlNewHadoopRDD.unsetInputFileName() + SqlNewHadoopRDDState.unsetInputFileName() // Close the reader and release it. Note: it's very important that we don't close the // reader more than once, since that exposes us to MAPREDUCE-5918 when running against // Hadoop 1.x and older Hadoop 2.x releases. That bug can lead to non-deterministic diff --git a/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDDState.scala b/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDDState.scala new file mode 100644 index 000000000000..3f15fff79366 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDDState.scala @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rdd + +import org.apache.spark.unsafe.types.UTF8String + +/** + * State for SqlNewHadoopRDD objects. This is split this way because of the package splits. 
+ * TODO: Move/Combine this with org.apache.spark.sql.datasources.SqlNewHadoopRDD + */ +private[spark] object SqlNewHadoopRDDState { + /** + * The thread variable for the name of the current file being read. This is used by + * the InputFileName function in Spark SQL. + */ + private[this] val inputFileName: ThreadLocal[UTF8String] = new ThreadLocal[UTF8String] { + override protected def initialValue(): UTF8String = UTF8String.fromString("") + } + + def getInputFileName(): UTF8String = inputFileName.get() + + private[spark] def setInputFileName(file: String) = inputFileName.set(UTF8String.fromString(file)) + + private[spark] def unsetInputFileName(): Unit = inputFileName.remove() + +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index 33769363a0ed..b6979d0c8297 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -17,7 +17,11 @@ package org.apache.spark.sql.catalyst.expressions; -import java.io.*; +import java.io.Externalizable; +import java.io.IOException; +import java.io.ObjectInput; +import java.io.ObjectOutput; +import java.io.OutputStream; import java.math.BigDecimal; import java.math.BigInteger; import java.nio.ByteBuffer; @@ -26,12 +30,26 @@ import java.util.HashSet; import java.util.Set; -import com.esotericsoftware.kryo.Kryo; -import com.esotericsoftware.kryo.KryoSerializable; -import com.esotericsoftware.kryo.io.Input; -import com.esotericsoftware.kryo.io.Output; - -import org.apache.spark.sql.types.*; +import org.apache.spark.sql.types.ArrayType; +import org.apache.spark.sql.types.BinaryType; +import org.apache.spark.sql.types.BooleanType; +import org.apache.spark.sql.types.ByteType; +import org.apache.spark.sql.types.CalendarIntervalType; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DateType; +import org.apache.spark.sql.types.Decimal; +import org.apache.spark.sql.types.DecimalType; +import org.apache.spark.sql.types.DoubleType; +import org.apache.spark.sql.types.FloatType; +import org.apache.spark.sql.types.IntegerType; +import org.apache.spark.sql.types.LongType; +import org.apache.spark.sql.types.MapType; +import org.apache.spark.sql.types.NullType; +import org.apache.spark.sql.types.ShortType; +import org.apache.spark.sql.types.StringType; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.types.TimestampType; +import org.apache.spark.sql.types.UserDefinedType; import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.bitset.BitSetMethods; @@ -39,9 +57,23 @@ import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; -import static org.apache.spark.sql.types.DataTypes.*; +import static org.apache.spark.sql.types.DataTypes.BooleanType; +import static org.apache.spark.sql.types.DataTypes.ByteType; +import static org.apache.spark.sql.types.DataTypes.DateType; +import static org.apache.spark.sql.types.DataTypes.DoubleType; +import static org.apache.spark.sql.types.DataTypes.FloatType; +import static org.apache.spark.sql.types.DataTypes.IntegerType; +import static org.apache.spark.sql.types.DataTypes.LongType; +import static org.apache.spark.sql.types.DataTypes.NullType; +import static org.apache.spark.sql.types.DataTypes.ShortType; +import 
static org.apache.spark.sql.types.DataTypes.TimestampType; import static org.apache.spark.unsafe.Platform.BYTE_ARRAY_OFFSET; +import com.esotericsoftware.kryo.Kryo; +import com.esotericsoftware.kryo.KryoSerializable; +import com.esotericsoftware.kryo.io.Input; +import com.esotericsoftware.kryo.io.Output; + /** * An Unsafe implementation of Row which is backed by raw memory instead of Java objects. * @@ -116,11 +148,6 @@ public static boolean isMutable(DataType dt) { /** The size of this row's backing data, in bytes) */ private int sizeInBytes; - private void setNotNullAt(int i) { - assertIndexIsValid(i); - BitSetMethods.unset(baseObject, baseOffset, i); - } - /** The width of the null tracking bit set, in bytes */ private int bitSetWidthInBytes; @@ -187,6 +214,12 @@ public void pointTo(byte[] buf, int sizeInBytes) { pointTo(buf, numFields, sizeInBytes); } + + public void setNotNullAt(int i) { + assertIndexIsValid(i); + BitSetMethods.unset(baseObject, baseOffset, i); + } + @Override public void setNullAt(int i) { assertIndexIsValid(i); diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala index d809877817a5..bf215783fc27 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.rdd.SqlNewHadoopRDD +import org.apache.spark.rdd.SqlNewHadoopRDDState import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} import org.apache.spark.sql.types.{DataType, StringType} @@ -37,13 +37,13 @@ case class InputFileName() extends LeafExpression with Nondeterministic { override protected def initInternal(): Unit = {} override protected def evalInternal(input: InternalRow): UTF8String = { - SqlNewHadoopRDD.getInputFileName() + SqlNewHadoopRDDState.getInputFileName() } override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { ev.isNull = "false" s"final ${ctx.javaType(dataType)} ${ev.value} = " + - "org.apache.spark.rdd.SqlNewHadoopRDD.getInputFileName();" + "org.apache.spark.rdd.SqlNewHadoopRDDState.getInputFileName();" } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java index 8a92e489ccb7..dade488ca281 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java @@ -108,6 +108,19 @@ public class UnsafeRowParquetRecordReader extends SpecificParquetRecordReaderBas */ private static final int DEFAULT_VAR_LEN_SIZE = 32; + /** + * Tries to initialize the reader for this split. Returns true if this reader supports reading + * this split and false otherwise. + */ + public boolean tryInitialize(InputSplit inputSplit, TaskAttemptContext taskAttemptContext) { + try { + initialize(inputSplit, taskAttemptContext); + return true; + } catch (Exception e) { + return false; + } + } + /** * Implementation of RecordReader API. 
*/ @@ -326,6 +339,7 @@ private void decodeBinaryBatch(int col, int num) throws IOException { } else { rowWriters[n].write(col, bytes.array(), bytes.position(), len); } + rows[n].setNotNullAt(col); } else { rows[n].setNullAt(col); } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index f40e603cd193..5ef3a48c56a8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -323,6 +323,11 @@ private[spark] object SQLConf { "option must be set in Hadoop Configuration. 2. This option overrides " + "\"spark.sql.sources.outputCommitterClass\".") + val PARQUET_UNSAFE_ROW_RECORD_READER_ENABLED = booleanConf( + key = "spark.sql.parquet.enableUnsafeRowRecordReader", + defaultValue = Some(true), + doc = "Enables using the custom ParquetUnsafeRowRecordReader.") + val ORC_FILTER_PUSHDOWN_ENABLED = booleanConf("spark.sql.orc.filterPushdown", defaultValue = Some(false), doc = "When true, enable filter pushdown for ORC files.") diff --git a/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala similarity index 86% rename from core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala index 4d176332b69c..56cb63d9eff2 100644 --- a/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala @@ -20,6 +20,8 @@ package org.apache.spark.rdd import java.text.SimpleDateFormat import java.util.Date +import scala.reflect.ClassTag + import org.apache.hadoop.conf.{Configurable, Configuration} import org.apache.hadoop.io.Writable import org.apache.hadoop.mapreduce._ @@ -28,13 +30,12 @@ import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.executor.DataReadMethod import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil +import org.apache.spark.sql.{SQLConf, SQLContext} +import org.apache.spark.sql.execution.datasources.parquet.UnsafeRowParquetRecordReader import org.apache.spark.storage.StorageLevel -import org.apache.spark.unsafe.types.UTF8String -import org.apache.spark.util.{Utils, SerializableConfiguration, ShutdownHookManager} +import org.apache.spark.util.{SerializableConfiguration, ShutdownHookManager} import org.apache.spark.{Partition => SparkPartition, _} -import scala.reflect.ClassTag - private[spark] class SqlNewHadoopPartition( rddId: Int, @@ -61,13 +62,13 @@ private[spark] class SqlNewHadoopPartition( * changes based on [[org.apache.spark.rdd.HadoopRDD]]. */ private[spark] class SqlNewHadoopRDD[V: ClassTag]( - sc : SparkContext, + sqlContext: SQLContext, broadcastedConf: Broadcast[SerializableConfiguration], @transient private val initDriverSideJobFuncOpt: Option[Job => Unit], initLocalJobFuncOpt: Option[Job => Unit], inputFormatClass: Class[_ <: InputFormat[Void, V]], valueClass: Class[V]) - extends RDD[V](sc, Nil) + extends RDD[V](sqlContext.sparkContext, Nil) with SparkHadoopMapReduceUtil with Logging { @@ -99,7 +100,7 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag]( // If true, enable using the custom RecordReader for parquet. This only works for // a subset of the types (no complex types). 
protected val enableUnsafeRowParquetReader: Boolean = - sc.conf.getBoolean("spark.parquet.enableUnsafeRowRecordReader", true) + sqlContext.getConf(SQLConf.PARQUET_UNSAFE_ROW_RECORD_READER_ENABLED.key).toBoolean override def getPartitions: Array[SparkPartition] = { val conf = getConf(isDriverSide = true) @@ -120,8 +121,8 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag]( } override def compute( - theSplit: SparkPartition, - context: TaskContext): Iterator[V] = { + theSplit: SparkPartition, + context: TaskContext): Iterator[V] = { val iter = new Iterator[V] { val split = theSplit.asInstanceOf[SqlNewHadoopPartition] logInfo("Input split: " + split.serializableHadoopSplit) @@ -132,8 +133,8 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag]( // Sets the thread local variable for the file's name split.serializableHadoopSplit.value match { - case fs: FileSplit => SqlNewHadoopRDD.setInputFileName(fs.getPath.toString) - case _ => SqlNewHadoopRDD.unsetInputFileName() + case fs: FileSplit => SqlNewHadoopRDDState.setInputFileName(fs.getPath.toString) + case _ => SqlNewHadoopRDDState.unsetInputFileName() } // Find a function that will return the FileSystem bytes read by this thread. Do this before @@ -163,15 +164,13 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag]( * TODO: plumb this through a different way? */ if (enableUnsafeRowParquetReader && - format.getClass.getName == "org.apache.parquet.hadoop.ParquetInputFormat") { - // TODO: move this class to sql.execution and remove this. - reader = Utils.classForName( - "org.apache.spark.sql.execution.datasources.parquet.UnsafeRowParquetRecordReader") - .newInstance().asInstanceOf[RecordReader[Void, V]] - try { - reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext) - } catch { - case e: Exception => reader = null + format.getClass.getName == "org.apache.parquet.hadoop.ParquetInputFormat") { + val parquetReader: UnsafeRowParquetRecordReader = new UnsafeRowParquetRecordReader() + if (!parquetReader.tryInitialize( + split.serializableHadoopSplit.value, hadoopAttemptContext)) { + parquetReader.close() + } else { + reader = parquetReader.asInstanceOf[RecordReader[Void, V]] } } @@ -217,7 +216,7 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag]( private def close() { if (reader != null) { - SqlNewHadoopRDD.unsetInputFileName() + SqlNewHadoopRDDState.unsetInputFileName() // Close the reader and release it. Note: it's very important that we don't close the // reader more than once, since that exposes us to MAPREDUCE-5918 when running against // Hadoop 1.x and older Hadoop 2.x releases. That bug can lead to non-deterministic @@ -235,7 +234,7 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag]( if (bytesReadCallback.isDefined) { inputMetrics.updateBytesRead() } else if (split.serializableHadoopSplit.value.isInstanceOf[FileSplit] || - split.serializableHadoopSplit.value.isInstanceOf[CombineFileSplit]) { + split.serializableHadoopSplit.value.isInstanceOf[CombineFileSplit]) { // If we can't get the bytes read from the FS stats, fall back to the split size, // which may be inaccurate. try { @@ -276,23 +275,6 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag]( } super.persist(storageLevel) } -} - -private[spark] object SqlNewHadoopRDD { - - /** - * The thread variable for the name of the current file being read. This is used by - * the InputFileName function in Spark SQL. 
- */ - private[this] val inputFileName: ThreadLocal[UTF8String] = new ThreadLocal[UTF8String] { - override protected def initialValue(): UTF8String = UTF8String.fromString("") - } - - def getInputFileName(): UTF8String = inputFileName.get() - - private[spark] def setInputFileName(file: String) = inputFileName.set(UTF8String.fromString(file)) - - private[spark] def unsetInputFileName(): Unit = inputFileName.remove() /** * Analogous to [[org.apache.spark.rdd.MapPartitionsRDD]], but passes in an InputSplit to diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala index cb0aab8cc0d0..fdd745f48e97 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala @@ -319,7 +319,7 @@ private[sql] class ParquetRelation( Utils.withDummyCallSite(sqlContext.sparkContext) { new SqlNewHadoopRDD( - sc = sqlContext.sparkContext, + sqlContext = sqlContext, broadcastedConf = broadcastedConf, initDriverSideJobFuncOpt = Some(setInputPaths), initLocalJobFuncOpt = Some(initLocalJobFuncOpt), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index c8028a5ef552..cc5aae03d551 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -337,29 +337,30 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex } } - // Renable when we can toggle custom ParquetRecordReader on/off. The custom reader does - // not do row by row filtering (and we probably don't want to push that). - ignore("SPARK-11661 Still pushdown filters returned by unhandledFilters") { + // The unsafe row RecordReader does not support row by row filtering so run it with it disabled. + test("SPARK-11661 Still pushdown filters returned by unhandledFilters") { import testImplicits._ withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true") { - withTempPath { dir => - val path = s"${dir.getCanonicalPath}/part=1" - (1 to 3).map(i => (i, i.toString)).toDF("a", "b").write.parquet(path) - val df = sqlContext.read.parquet(path).filter("a = 2") - - // This is the source RDD without Spark-side filtering. - val childRDD = - df - .queryExecution - .executedPlan.asInstanceOf[org.apache.spark.sql.execution.Filter] - .child - .execute() - - // The result should be single row. - // When a filter is pushed to Parquet, Parquet can apply it to every row. - // So, we can check the number of rows returned from the Parquet - // to make sure our filter pushdown work. - assert(childRDD.count == 1) + withSQLConf(SQLConf.PARQUET_UNSAFE_ROW_RECORD_READER_ENABLED.key -> "false") { + withTempPath { dir => + val path = s"${dir.getCanonicalPath}/part=1" + (1 to 3).map(i => (i, i.toString)).toDF("a", "b").write.parquet(path) + val df = sqlContext.read.parquet(path).filter("a = 2") + + // This is the source RDD without Spark-side filtering. + val childRDD = + df + .queryExecution + .executedPlan.asInstanceOf[org.apache.spark.sql.execution.Filter] + .child + .execute() + + // The result should be single row. 
+ // When a filter is pushed to Parquet, Parquet can apply it to every row. + // So, we can check the number of rows returned from the Parquet + // to make sure our filter pushdown work. + assert(childRDD.count == 1) + } } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala index 177ab42f7767..0c5d4887ed79 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala @@ -579,6 +579,25 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { } } + test("null and non-null strings") { + // Create a dataset where the first values are NULL and then some non-null values. The + // number of non-nulls needs to be bigger than the ParquetReader batch size. + val data = sqlContext.range(200).map { i => + if (i.getLong(0) < 150) Row(None) + else Row("a") + } + val df = sqlContext.createDataFrame(data, StructType(StructField("col", StringType) :: Nil)) + assert(df.agg("col" -> "count").collect().head.getLong(0) == 50) + + withTempPath { dir => + val path = s"${dir.getCanonicalPath}/data" + df.write.parquet(path) + + val df2 = sqlContext.read.parquet(path) + assert(df2.agg("col" -> "count").collect().head.getLong(0) == 50) + } + } + test("read dictionary encoded decimals written as INT32") { checkAnswer( // Decimal column in this file is encoded using plain dictionary From 968acf3bd9a502fcad15df3e53e359695ae702cc Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Fri, 20 Nov 2015 15:36:30 -0800 Subject: [PATCH 1486/1824] [SPARK-11889][SQL] Fix type inference for GroupedDataset.agg in REPL In this PR I delete a method that breaks type inference for aggregators (only in the REPL) The error when this method is present is: ``` :38: error: missing parameter type for expanded function ((x$2) => x$2._2) ds.groupBy(_._1).agg(sum(_._2), sum(_._3)).collect() ``` Author: Michael Armbrust Closes #9870 from marmbrus/dataset-repl-agg. --- .../org/apache/spark/repl/ReplSuite.scala | 24 +++++++++++++++++ .../org/apache/spark/sql/GroupedDataset.scala | 27 +++---------------- .../apache/spark/sql/JavaDatasetSuite.java | 8 +++--- 3 files changed, 30 insertions(+), 29 deletions(-) diff --git a/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala index 081aa03002cc..cbcccb11f14a 100644 --- a/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala +++ b/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala @@ -339,6 +339,30 @@ class ReplSuite extends SparkFunSuite { } } + test("Datasets agg type-inference") { + val output = runInterpreter("local", + """ + |import org.apache.spark.sql.functions._ + |import org.apache.spark.sql.Encoder + |import org.apache.spark.sql.expressions.Aggregator + |import org.apache.spark.sql.TypedColumn + |/** An `Aggregator` that adds up any numeric type returned by the given function. 
*/ + |class SumOf[I, N : Numeric](f: I => N) extends Aggregator[I, N, N] with Serializable { + | val numeric = implicitly[Numeric[N]] + | override def zero: N = numeric.zero + | override def reduce(b: N, a: I): N = numeric.plus(b, f(a)) + | override def merge(b1: N,b2: N): N = numeric.plus(b1, b2) + | override def finish(reduction: N): N = reduction + |} + | + |def sum[I, N : Numeric : Encoder](f: I => N): TypedColumn[I, N] = new SumOf(f).toColumn + |val ds = Seq((1, 1, 2L), (1, 2, 3L), (1, 3, 4L), (2, 1, 5L)).toDS() + |ds.groupBy(_._1).agg(sum(_._2), sum(_._3)).collect() + """.stripMargin) + assertDoesNotContain("error:", output) + assertDoesNotContain("Exception", output) + } + test("collecting objects of class defined in repl") { val output = runInterpreter("local[2]", """ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala index 6de3dd626576..263f04910476 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala @@ -146,31 +146,10 @@ class GroupedDataset[K, T] private[sql]( reduce(f.call _) } - /** - * Compute aggregates by specifying a series of aggregate columns, and return a [[DataFrame]]. - * We can call `as[T : Encoder]` to turn the returned [[DataFrame]] to [[Dataset]] again. - * - * The available aggregate methods are defined in [[org.apache.spark.sql.functions]]. - * - * {{{ - * // Selects the age of the oldest employee and the aggregate expense for each department - * - * // Scala: - * import org.apache.spark.sql.functions._ - * df.groupBy("department").agg(max("age"), sum("expense")) - * - * // Java: - * import static org.apache.spark.sql.functions.*; - * df.groupBy("department").agg(max("age"), sum("expense")); - * }}} - * - * We can also use `Aggregator.toColumn` to pass in typed aggregate functions. - * - * @since 1.6.0 - */ + // This is here to prevent us from adding overloads that would be ambiguous. 
@scala.annotation.varargs - def agg(expr: Column, exprs: Column*): DataFrame = - groupedData.agg(withEncoder(expr), exprs.map(withEncoder): _*) + private def agg(exprs: Column*): DataFrame = + groupedData.agg(withEncoder(exprs.head), exprs.tail.map(withEncoder): _*) private def withEncoder(c: Column): Column = c match { case tc: TypedColumn[_, _] => diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index ce40dd856f67..f7249b8945c4 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -404,11 +404,9 @@ public String call(Tuple2 value) throws Exception { grouped.agg(new IntSumOf().toColumn(Encoders.INT(), Encoders.INT())); Assert.assertEquals(Arrays.asList(tuple2("a", 3), tuple2("b", 3)), agged.collectAsList()); - Dataset> agged2 = grouped.agg( - new IntSumOf().toColumn(Encoders.INT(), Encoders.INT()), - expr("sum(_2)"), - count("*")) - .as(Encoders.tuple(Encoders.STRING(), Encoders.INT(), Encoders.LONG(), Encoders.LONG())); + Dataset> agged2 = grouped.agg( + new IntSumOf().toColumn(Encoders.INT(), Encoders.INT())) + .as(Encoders.tuple(Encoders.STRING(), Encoders.INT())); Assert.assertEquals( Arrays.asList( new Tuple4<>("a", 3, 3L, 2L), From 68ed046836975b492b594967256d3c7951b568a5 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Fri, 20 Nov 2015 15:38:04 -0800 Subject: [PATCH 1487/1824] [SPARK-11890][SQL] Fix compilation for Scala 2.11 Author: Michael Armbrust Closes #9871 from marmbrus/scala211-break. --- .../scala/org/apache/spark/sql/catalyst/ScalaReflection.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 918050b531c0..4a4a62ed1a46 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -670,14 +670,14 @@ trait ScalaReflection { * Unlike `schemaFor`, this method won't throw exception for un-supported type, it will return * `NullType` silently instead. */ - private def silentSchemaFor(tpe: `Type`): Schema = try { + protected def silentSchemaFor(tpe: `Type`): Schema = try { schemaFor(tpe) } catch { case _: UnsupportedOperationException => Schema(NullType, nullable = true) } /** Returns the full class name for a type. 
*/ - private def getClassNameFromType(tpe: `Type`): String = { + protected def getClassNameFromType(tpe: `Type`): String = { tpe.erasure.typeSymbol.asClass.fullName } From 47815878ad5e47e89bfbd57acb848be2ce67a4a5 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Fri, 20 Nov 2015 16:02:03 -0800 Subject: [PATCH 1488/1824] [HOTFIX] Fix Java Dataset Tests --- .../test/java/test/org/apache/spark/sql/JavaDatasetSuite.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index f7249b8945c4..f32374b4c04d 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -409,8 +409,8 @@ public String call(Tuple2 value) throws Exception { .as(Encoders.tuple(Encoders.STRING(), Encoders.INT())); Assert.assertEquals( Arrays.asList( - new Tuple4<>("a", 3, 3L, 2L), - new Tuple4<>("b", 3, 3L, 1L)), + new Tuple2<>("a", 3), + new Tuple2<>("b", 3)), agged2.collectAsList()); } From a2dce22e0a25922e2052318d32f32877b7c27ec2 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Fri, 20 Nov 2015 16:51:47 -0800 Subject: [PATCH 1489/1824] Revert "[SPARK-11689][ML] Add user guide and example code for LDA under spark.ml" This reverts commit e359d5dcf5bd300213054ebeae9fe75c4f7eb9e7. --- docs/ml-clustering.md | 30 ------ docs/ml-guide.md | 3 +- docs/mllib-guide.md | 1 - .../spark/examples/ml/JavaLDAExample.java | 94 ------------------- .../apache/spark/examples/ml/LDAExample.scala | 77 --------------- 5 files changed, 1 insertion(+), 204 deletions(-) delete mode 100644 docs/ml-clustering.md delete mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaLDAExample.java delete mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/LDAExample.scala diff --git a/docs/ml-clustering.md b/docs/ml-clustering.md deleted file mode 100644 index 1743ef43a6dd..000000000000 --- a/docs/ml-clustering.md +++ /dev/null @@ -1,30 +0,0 @@ ---- -layout: global -title: Clustering - ML -displayTitle: ML - Clustering ---- - -In this section, we introduce the pipeline API for [clustering in mllib](mllib-clustering.html). - -## Latent Dirichlet allocation (LDA) - -`LDA` is implemented as an `Estimator` that supports both `EMLDAOptimizer` and `OnlineLDAOptimizer`, -and generates a `LDAModel` as the base models. Expert users may cast a `LDAModel` generated by -`EMLDAOptimizer` to a `DistributedLDAModel` if needed. - -
-<div data-lang="scala" markdown="1">
-
-Refer to the [Scala API docs](api/scala/index.html#org.apache.spark.ml.clustering.LDA) for more details.
-
-{% include_example scala/org/apache/spark/examples/ml/LDAExample.scala %}
-</div>
-
-<div data-lang="java" markdown="1">
-
-Refer to the [Java API docs](api/java/org/apache/spark/ml/clustering/LDA.html) for more details.
-
-{% include_example java/org/apache/spark/examples/ml/JavaLDAExample.java %}
-</div>
-
-</div>
    \ No newline at end of file diff --git a/docs/ml-guide.md b/docs/ml-guide.md index 6f35b30c3d4d..be18a05361a1 100644 --- a/docs/ml-guide.md +++ b/docs/ml-guide.md @@ -40,7 +40,6 @@ Also, some algorithms have additional capabilities in the `spark.ml` API; e.g., provide class probabilities, and linear models provide model summaries. * [Feature extraction, transformation, and selection](ml-features.html) -* [Clustering](ml-clustering.html) * [Decision Trees for classification and regression](ml-decision-tree.html) * [Ensembles](ml-ensembles.html) * [Linear methods with elastic net regularization](ml-linear-methods.html) @@ -951,4 +950,4 @@ model.transform(test) {% endhighlight %} - \ No newline at end of file + diff --git a/docs/mllib-guide.md b/docs/mllib-guide.md index 54e35fcbb15a..91e50ccfecec 100644 --- a/docs/mllib-guide.md +++ b/docs/mllib-guide.md @@ -69,7 +69,6 @@ We list major functionality from both below, with links to detailed guides. concepts. It also contains sections on using algorithms within the Pipelines API, for example: * [Feature extraction, transformation, and selection](ml-features.html) -* [Clustering](ml-clustering.html) * [Decision trees for classification and regression](ml-decision-tree.html) * [Ensembles](ml-ensembles.html) * [Linear methods with elastic net regularization](ml-linear-methods.html) diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaLDAExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaLDAExample.java deleted file mode 100644 index b3a7d2eb2978..000000000000 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaLDAExample.java +++ /dev/null @@ -1,94 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.examples.ml; - -import java.util.regex.Pattern; - -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.api.java.function.Function; -import org.apache.spark.ml.clustering.LDA; -import org.apache.spark.ml.clustering.LDAModel; -import org.apache.spark.mllib.linalg.Vector; -import org.apache.spark.mllib.linalg.VectorUDT; -import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.SQLContext; -import org.apache.spark.sql.catalyst.expressions.GenericRow; -import org.apache.spark.sql.types.Metadata; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; - -/** - * An example demonstrating LDA - * Run with - *
- * bin/run-example ml.JavaLDAExample
- *
    - */ -public class JavaLDAExample { - - private static class ParseVector implements Function { - private static final Pattern separator = Pattern.compile(" "); - - @Override - public Row call(String line) { - String[] tok = separator.split(line); - double[] point = new double[tok.length]; - for (int i = 0; i < tok.length; ++i) { - point[i] = Double.parseDouble(tok[i]); - } - Vector[] points = {Vectors.dense(point)}; - return new GenericRow(points); - } - } - - public static void main(String[] args) { - - String inputFile = "data/mllib/sample_lda_data.txt"; - - // Parses the arguments - SparkConf conf = new SparkConf().setAppName("JavaLDAExample"); - JavaSparkContext jsc = new JavaSparkContext(conf); - SQLContext sqlContext = new SQLContext(jsc); - - // Loads data - JavaRDD points = jsc.textFile(inputFile).map(new ParseVector()); - StructField[] fields = {new StructField("features", new VectorUDT(), false, Metadata.empty())}; - StructType schema = new StructType(fields); - DataFrame dataset = sqlContext.createDataFrame(points, schema); - - // Trains a LDA model - LDA lda = new LDA() - .setK(10) - .setMaxIter(10); - LDAModel model = lda.fit(dataset); - - System.out.println(model.logLikelihood(dataset)); - System.out.println(model.logPerplexity(dataset)); - - // Shows the result - DataFrame topics = model.describeTopics(3); - topics.show(false); - model.transform(dataset).show(false); - - jsc.stop(); - } -} diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/LDAExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/LDAExample.scala deleted file mode 100644 index 419ce3d87a6a..000000000000 --- a/examples/src/main/scala/org/apache/spark/examples/ml/LDAExample.scala +++ /dev/null @@ -1,77 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.examples.ml - -// scalastyle:off println -import org.apache.spark.{SparkContext, SparkConf} -import org.apache.spark.mllib.linalg.{VectorUDT, Vectors} -// $example on$ -import org.apache.spark.ml.clustering.LDA -import org.apache.spark.sql.{Row, SQLContext} -import org.apache.spark.sql.types.{StructField, StructType} -// $example off$ - -/** - * An example demonstrating a LDA of ML pipeline. 
- * Run with - * {{{ - * bin/run-example ml.LDAExample - * }}} - */ -object LDAExample { - - final val FEATURES_COL = "features" - - def main(args: Array[String]): Unit = { - - val input = "data/mllib/sample_lda_data.txt" - // Creates a Spark context and a SQL context - val conf = new SparkConf().setAppName(s"${this.getClass.getSimpleName}") - val sc = new SparkContext(conf) - val sqlContext = new SQLContext(sc) - - // $example on$ - // Loads data - val rowRDD = sc.textFile(input).filter(_.nonEmpty) - .map(_.split(" ").map(_.toDouble)).map(Vectors.dense).map(Row(_)) - val schema = StructType(Array(StructField(FEATURES_COL, new VectorUDT, false))) - val dataset = sqlContext.createDataFrame(rowRDD, schema) - - // Trains a LDA model - val lda = new LDA() - .setK(10) - .setMaxIter(10) - .setFeaturesCol(FEATURES_COL) - val model = lda.fit(dataset) - val transformed = model.transform(dataset) - - val ll = model.logLikelihood(dataset) - val lp = model.logPerplexity(dataset) - - // describeTopics - val topics = model.describeTopics(3) - - // Shows the result - topics.show(false) - transformed.show(false) - - // $example off$ - sc.stop() - } -} -// scalastyle:on println From 7d3f922c4ba76c4193f98234ae662065c39cdfb1 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 20 Nov 2015 23:31:19 -0800 Subject: [PATCH 1490/1824] [SPARK-11819][SQL][FOLLOW-UP] fix scala 2.11 build seems scala 2.11 doesn't support: define private methods in `trait xxx` and use it in `object xxx extend xxx`. Author: Wenchen Fan Closes #9879 from cloud-fan/follow. --- .../scala/org/apache/spark/sql/catalyst/ScalaReflection.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 4a4a62ed1a46..476becec4dd5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -670,14 +670,14 @@ trait ScalaReflection { * Unlike `schemaFor`, this method won't throw exception for un-supported type, it will return * `NullType` silently instead. */ - protected def silentSchemaFor(tpe: `Type`): Schema = try { + def silentSchemaFor(tpe: `Type`): Schema = try { schemaFor(tpe) } catch { case _: UnsupportedOperationException => Schema(NullType, nullable = true) } /** Returns the full class name for a type. */ - protected def getClassNameFromType(tpe: `Type`): String = { + def getClassNameFromType(tpe: `Type`): String = { tpe.erasure.typeSymbol.asClass.fullName } From 54328b6d862fe62ae01bdd87df4798ceb9d506d6 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sat, 21 Nov 2015 00:10:13 -0800 Subject: [PATCH 1491/1824] [SPARK-11900][SQL] Add since version for all encoders Author: Reynold Xin Closes #9881 from rxin/SPARK-11900. --- .../scala/org/apache/spark/sql/Encoder.scala | 63 +++++++++++++++++++ 1 file changed, 63 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala index 86bb53645903..5cb8edf64e87 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala @@ -45,13 +45,52 @@ trait Encoder[T] extends Serializable { */ object Encoders { + /** + * An encoder for nullable boolean type. 
+ * @since 1.6.0 + */ def BOOLEAN: Encoder[java.lang.Boolean] = ExpressionEncoder() + + /** + * An encoder for nullable byte type. + * @since 1.6.0 + */ def BYTE: Encoder[java.lang.Byte] = ExpressionEncoder() + + /** + * An encoder for nullable short type. + * @since 1.6.0 + */ def SHORT: Encoder[java.lang.Short] = ExpressionEncoder() + + /** + * An encoder for nullable int type. + * @since 1.6.0 + */ def INT: Encoder[java.lang.Integer] = ExpressionEncoder() + + /** + * An encoder for nullable long type. + * @since 1.6.0 + */ def LONG: Encoder[java.lang.Long] = ExpressionEncoder() + + /** + * An encoder for nullable float type. + * @since 1.6.0 + */ def FLOAT: Encoder[java.lang.Float] = ExpressionEncoder() + + /** + * An encoder for nullable double type. + * @since 1.6.0 + */ def DOUBLE: Encoder[java.lang.Double] = ExpressionEncoder() + + /** + * An encoder for nullable string type. + * @since 1.6.0 + */ def STRING: Encoder[java.lang.String] = ExpressionEncoder() /** @@ -59,6 +98,8 @@ object Encoders { * This encoder maps T into a single byte array (binary) field. * * T must be publicly accessible. + * + * @since 1.6.0 */ def kryo[T: ClassTag]: Encoder[T] = genericSerializer(useKryo = true) @@ -67,6 +108,8 @@ object Encoders { * This encoder maps T into a single byte array (binary) field. * * T must be publicly accessible. + * + * @since 1.6.0 */ def kryo[T](clazz: Class[T]): Encoder[T] = kryo(ClassTag[T](clazz)) @@ -77,6 +120,8 @@ object Encoders { * Note that this is extremely inefficient and should only be used as the last resort. * * T must be publicly accessible. + * + * @since 1.6.0 */ def javaSerialization[T: ClassTag]: Encoder[T] = genericSerializer(useKryo = false) @@ -87,6 +132,8 @@ object Encoders { * Note that this is extremely inefficient and should only be used as the last resort. * * T must be publicly accessible. + * + * @since 1.6.0 */ def javaSerialization[T](clazz: Class[T]): Encoder[T] = javaSerialization(ClassTag[T](clazz)) @@ -120,12 +167,20 @@ object Encoders { ) } + /** + * An encoder for 2-ary tuples. + * @since 1.6.0 + */ def tuple[T1, T2]( e1: Encoder[T1], e2: Encoder[T2]): Encoder[(T1, T2)] = { ExpressionEncoder.tuple(encoderFor(e1), encoderFor(e2)) } + /** + * An encoder for 3-ary tuples. + * @since 1.6.0 + */ def tuple[T1, T2, T3]( e1: Encoder[T1], e2: Encoder[T2], @@ -133,6 +188,10 @@ object Encoders { ExpressionEncoder.tuple(encoderFor(e1), encoderFor(e2), encoderFor(e3)) } + /** + * An encoder for 4-ary tuples. + * @since 1.6.0 + */ def tuple[T1, T2, T3, T4]( e1: Encoder[T1], e2: Encoder[T2], @@ -141,6 +200,10 @@ object Encoders { ExpressionEncoder.tuple(encoderFor(e1), encoderFor(e2), encoderFor(e3), encoderFor(e4)) } + /** + * An encoder for 5-ary tuples. + * @since 1.6.0 + */ def tuple[T1, T2, T3, T4, T5]( e1: Encoder[T1], e2: Encoder[T2], From 596710268e29e8f624c3ba2fade08b66ec7084eb Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sat, 21 Nov 2015 00:54:18 -0800 Subject: [PATCH 1492/1824] [SPARK-11901][SQL] API audit for Aggregator. Author: Reynold Xin Closes #9882 from rxin/SPARK-11901. 
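The `Aggregator` contract that the diff below documents (`zero`, `reduce`, `merge`, `finish`, plus `toColumn`) can be exercised roughly as in the following sketch. This is a minimal standalone illustration rather than code from the patch: the `Data` case class mirrors the class used in the scaladoc example, while the object name, application name, local master, and sample data are assumptions added only to make it self-contained.

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.{Dataset, SQLContext}
import org.apache.spark.sql.expressions.Aggregator

case class Data(i: Int)

object AggregatorSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("AggregatorSketch").setMaster("local[2]"))
    val sqlContext = new SQLContext(sc)
    import sqlContext.implicits._

    // A typed aggregation that sums the `i` field of Data values.
    val sumOfI = new Aggregator[Data, Int, Int] {
      def zero: Int = 0                            // identity: merging any buffer with zero leaves it unchanged
      def reduce(b: Int, a: Data): Int = b + a.i   // fold one input value into the buffer
      def merge(b1: Int, b2: Int): Int = b1 + b2   // combine two partial buffers
      def finish(reduction: Int): Int = reduction  // turn the final buffer into the output
    }.toColumn

    val ds: Dataset[Data] = Seq(Data(1), Data(2), Data(3)).toDS()
    ds.select(sumOfI).show()                       // expected to show a single aggregated value, 6
    sc.stop()
  }
}
```

The resulting typed column can also be passed to `GroupedDataset.agg` for a per-key version of the same aggregation.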
--- .../scala/org/apache/spark/sql/Dataset.scala | 1 - .../spark/sql/expressions/Aggregator.scala | 39 ++++++++++++------- 2 files changed, 24 insertions(+), 16 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index bdcdc5d47cba..07647508421a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -22,7 +22,6 @@ import scala.collection.JavaConverters._ import org.apache.spark.annotation.Experimental import org.apache.spark.rdd.RDD import org.apache.spark.api.java.function._ -import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders._ import org.apache.spark.sql.catalyst.expressions._ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala index 72610e735f78..b0cd32b5f73e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala @@ -17,11 +17,10 @@ package org.apache.spark.sql.expressions -import org.apache.spark.sql.Encoder import org.apache.spark.sql.catalyst.encoders.encoderFor import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete} import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression -import org.apache.spark.sql.{Dataset, DataFrame, TypedColumn} +import org.apache.spark.sql.{DataFrame, Dataset, Encoder, TypedColumn} /** * A base class for user-defined aggregations, which can be used in [[DataFrame]] and [[Dataset]] @@ -32,55 +31,65 @@ import org.apache.spark.sql.{Dataset, DataFrame, TypedColumn} * case class Data(i: Int) * * val customSummer = new Aggregator[Data, Int, Int] { - * def zero = 0 - * def reduce(b: Int, a: Data) = b + a.i - * def present(r: Int) = r + * def zero: Int = 0 + * def reduce(b: Int, a: Data): Int = b + a.i + * def merge(b1: Int, b2: Int): Int = b1 + b2 + * def present(r: Int): Int = r * }.toColumn() * - * val ds: Dataset[Data] + * val ds: Dataset[Data] = ... * val aggregated = ds.select(customSummer) * }}} * * Based loosely on Aggregator from Algebird: https://github.com/twitter/algebird * - * @tparam A The input type for the aggregation. + * @tparam I The input type for the aggregation. * @tparam B The type of the intermediate value of the reduction. - * @tparam C The type of the final result. + * @tparam O The type of the final output result. + * + * @since 1.6.0 */ -abstract class Aggregator[-A, B, C] extends Serializable { +abstract class Aggregator[-I, B, O] extends Serializable { - /** A zero value for this aggregation. Should satisfy the property that any b + zero = b */ + /** + * A zero value for this aggregation. Should satisfy the property that any b + zero = b. + * @since 1.6.0 + */ def zero: B /** * Combine two values to produce a new value. For performance, the function may modify `b` and * return it instead of constructing new object for b. + * @since 1.6.0 */ - def reduce(b: B, a: A): B + def reduce(b: B, a: I): B /** - * Merge two intermediate values + * Merge two intermediate values. + * @since 1.6.0 */ def merge(b1: B, b2: B): B /** * Transform the output of the reduction. + * @since 1.6.0 */ - def finish(reduction: B): C + def finish(reduction: B): O /** * Returns this `Aggregator` as a [[TypedColumn]] that can be used in [[Dataset]] or [[DataFrame]] * operations. 
+ * @since 1.6.0 */ def toColumn( implicit bEncoder: Encoder[B], - cEncoder: Encoder[C]): TypedColumn[A, C] = { + cEncoder: Encoder[O]): TypedColumn[I, O] = { val expr = new AggregateExpression( TypedAggregateExpression(this), Complete, false) - new TypedColumn[A, C](expr, encoderFor[C]) + new TypedColumn[I, O](expr, encoderFor[O]) } } From ff442bbcffd4f93cfcc2f76d160011e725d2fb3f Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sat, 21 Nov 2015 15:00:37 -0800 Subject: [PATCH 1493/1824] [SPARK-11899][SQL] API audit for GroupedDataset. 1. Renamed map to mapGroup, flatMap to flatMapGroup. 2. Renamed asKey -> keyAs. 3. Added more documentation. 4. Changed type parameter T to V on GroupedDataset. 5. Added since versions for all functions. Author: Reynold Xin Closes #9880 from rxin/SPARK-11899. --- .../api/java/function/MapGroupFunction.java | 2 +- .../scala/org/apache/spark/sql/Encoder.scala | 4 + .../sql/catalyst/JavaTypeInference.scala | 3 +- .../scala/org/apache/spark/sql/Column.scala | 2 + .../org/apache/spark/sql/DataFrame.scala | 1 - .../org/apache/spark/sql/GroupedDataset.scala | 132 ++++++++++++++---- .../apache/spark/sql/JavaDatasetSuite.java | 8 +- .../spark/sql/DatasetPrimitiveSuite.scala | 4 +- .../org/apache/spark/sql/DatasetSuite.scala | 20 +-- 9 files changed, 131 insertions(+), 45 deletions(-) diff --git a/core/src/main/java/org/apache/spark/api/java/function/MapGroupFunction.java b/core/src/main/java/org/apache/spark/api/java/function/MapGroupFunction.java index 2935f9986a56..4f3f222e064b 100644 --- a/core/src/main/java/org/apache/spark/api/java/function/MapGroupFunction.java +++ b/core/src/main/java/org/apache/spark/api/java/function/MapGroupFunction.java @@ -21,7 +21,7 @@ import java.util.Iterator; /** - * Base interface for a map function used in GroupedDataset's map function. + * Base interface for a map function used in GroupedDataset's mapGroup function. */ public interface MapGroupFunction extends Serializable { R call(K key, Iterator values) throws Exception; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala index 5cb8edf64e87..03aa25eda807 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala @@ -30,6 +30,8 @@ import org.apache.spark.sql.types._ * * Encoders are not intended to be thread-safe and thus they are allow to avoid internal locking * and reuse internal buffers to improve performance. + * + * @since 1.6.0 */ trait Encoder[T] extends Serializable { @@ -42,6 +44,8 @@ trait Encoder[T] extends Serializable { /** * Methods for creating encoders. + * + * @since 1.6.0 */ object Encoders { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala index 88a457f87ce4..7d4cfbe6faec 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.types._ /** * Type-inference utilities for POJOs and Java collections. 
*/ -private [sql] object JavaTypeInference { +object JavaTypeInference { private val iterableType = TypeToken.of(classOf[JIterable[_]]) private val mapType = TypeToken.of(classOf[JMap[_, _]]) @@ -53,7 +53,6 @@ private [sql] object JavaTypeInference { * @return (SQL data type, nullable) */ private def inferDataType(typeToken: TypeToken[_]): (DataType, Boolean) = { - // TODO: All of this could probably be moved to Catalyst as it is mostly not Spark specific. typeToken.getRawType match { case c: Class[_] if c.isAnnotationPresent(classOf[SQLUserDefinedType]) => (c.getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance(), true) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 82e9cd7f50a3..30c554a85e69 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -46,6 +46,8 @@ private[sql] object Column { * @tparam T The input type expected for this expression. Can be `Any` if the expression is type * checked by the analyzer instead of the compiler (i.e. `expr("sum(...)")`). * @tparam U The output type of this column. + * + * @since 1.6.0 */ class TypedColumn[-T, U]( expr: Expression, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 7abcecaa2880..5586fc994b98 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -110,7 +110,6 @@ private[sql] object DataFrame { * @groupname action Actions * @since 1.3.0 */ -// TODO: Improve documentation. @Experimental class DataFrame private[sql]( @transient val sqlContext: SQLContext, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala index 263f04910476..7f43ce16901b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, encoderFor, Ou import org.apache.spark.sql.catalyst.expressions.{Alias, CreateStruct, Attribute} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.QueryExecution +import org.apache.spark.sql.expressions.Aggregator /** * :: Experimental :: @@ -36,11 +37,13 @@ import org.apache.spark.sql.execution.QueryExecution * making this change to the class hierarchy would break some function signatures. As such, this * class should be considered a preview of the final API. Changes will be made to the interface * after Spark 1.6. + * + * @since 1.6.0 */ @Experimental -class GroupedDataset[K, T] private[sql]( +class GroupedDataset[K, V] private[sql]( kEncoder: Encoder[K], - tEncoder: Encoder[T], + tEncoder: Encoder[V], val queryExecution: QueryExecution, private val dataAttributes: Seq[Attribute], private val groupingAttributes: Seq[Attribute]) extends Serializable { @@ -67,8 +70,10 @@ class GroupedDataset[K, T] private[sql]( /** * Returns a new [[GroupedDataset]] where the type of the key has been mapped to the specified * type. The mapping of key columns to the type follows the same rules as `as` on [[Dataset]]. 
+ * + * @since 1.6.0 */ - def asKey[L : Encoder]: GroupedDataset[L, T] = + def keyAs[L : Encoder]: GroupedDataset[L, V] = new GroupedDataset( encoderFor[L], unresolvedTEncoder, @@ -78,6 +83,8 @@ class GroupedDataset[K, T] private[sql]( /** * Returns a [[Dataset]] that contains each unique key. + * + * @since 1.6.0 */ def keys: Dataset[K] = { new Dataset[K]( @@ -92,12 +99,18 @@ class GroupedDataset[K, T] private[sql]( * function can return an iterator containing elements of an arbitrary type which will be returned * as a new [[Dataset]]. * + * This function does not support partial aggregation, and as a result requires shuffling all + * the data in the [[Dataset]]. If an application intends to perform an aggregation over each + * key, it is best to use the reduce function or an [[Aggregator]]. + * * Internally, the implementation will spill to disk if any given group is too large to fit into * memory. However, users must take care to avoid materializing the whole iterator for a group * (for example, by calling `toList`) unless they are sure that this is possible given the memory * constraints of their cluster. + * + * @since 1.6.0 */ - def flatMap[U : Encoder](f: (K, Iterator[T]) => TraversableOnce[U]): Dataset[U] = { + def flatMapGroup[U : Encoder](f: (K, Iterator[V]) => TraversableOnce[U]): Dataset[U] = { new Dataset[U]( sqlContext, MapGroups( @@ -108,8 +121,25 @@ class GroupedDataset[K, T] private[sql]( logicalPlan)) } - def flatMap[U](f: FlatMapGroupFunction[K, T, U], encoder: Encoder[U]): Dataset[U] = { - flatMap((key, data) => f.call(key, data.asJava).asScala)(encoder) + /** + * Applies the given function to each group of data. For each unique group, the function will + * be passed the group key and an iterator that contains all of the elements in the group. The + * function can return an iterator containing elements of an arbitrary type which will be returned + * as a new [[Dataset]]. + * + * This function does not support partial aggregation, and as a result requires shuffling all + * the data in the [[Dataset]]. If an application intends to perform an aggregation over each + * key, it is best to use the reduce function or an [[Aggregator]]. + * + * Internally, the implementation will spill to disk if any given group is too large to fit into + * memory. However, users must take care to avoid materializing the whole iterator for a group + * (for example, by calling `toList`) unless they are sure that this is possible given the memory + * constraints of their cluster. + * + * @since 1.6.0 + */ + def flatMapGroup[U](f: FlatMapGroupFunction[K, V, U], encoder: Encoder[U]): Dataset[U] = { + flatMapGroup((key, data) => f.call(key, data.asJava).asScala)(encoder) } /** @@ -117,32 +147,62 @@ class GroupedDataset[K, T] private[sql]( * be passed the group key and an iterator that contains all of the elements in the group. The * function can return an element of arbitrary type which will be returned as a new [[Dataset]]. * + * This function does not support partial aggregation, and as a result requires shuffling all + * the data in the [[Dataset]]. If an application intends to perform an aggregation over each + * key, it is best to use the reduce function or an [[Aggregator]]. + * * Internally, the implementation will spill to disk if any given group is too large to fit into * memory. However, users must take care to avoid materializing the whole iterator for a group * (for example, by calling `toList`) unless they are sure that this is possible given the memory * constraints of their cluster. 
+ * + * @since 1.6.0 */ - def map[U : Encoder](f: (K, Iterator[T]) => U): Dataset[U] = { - val func = (key: K, it: Iterator[T]) => Iterator(f(key, it)) - flatMap(func) + def mapGroup[U : Encoder](f: (K, Iterator[V]) => U): Dataset[U] = { + val func = (key: K, it: Iterator[V]) => Iterator(f(key, it)) + flatMapGroup(func) } - def map[U](f: MapGroupFunction[K, T, U], encoder: Encoder[U]): Dataset[U] = { - map((key, data) => f.call(key, data.asJava))(encoder) + /** + * Applies the given function to each group of data. For each unique group, the function will + * be passed the group key and an iterator that contains all of the elements in the group. The + * function can return an element of arbitrary type which will be returned as a new [[Dataset]]. + * + * This function does not support partial aggregation, and as a result requires shuffling all + * the data in the [[Dataset]]. If an application intends to perform an aggregation over each + * key, it is best to use the reduce function or an [[Aggregator]]. + * + * Internally, the implementation will spill to disk if any given group is too large to fit into + * memory. However, users must take care to avoid materializing the whole iterator for a group + * (for example, by calling `toList`) unless they are sure that this is possible given the memory + * constraints of their cluster. + * + * @since 1.6.0 + */ + def mapGroup[U](f: MapGroupFunction[K, V, U], encoder: Encoder[U]): Dataset[U] = { + mapGroup((key, data) => f.call(key, data.asJava))(encoder) } /** * Reduces the elements of each group of data using the specified binary function. * The given function must be commutative and associative or the result may be non-deterministic. + * + * @since 1.6.0 */ - def reduce(f: (T, T) => T): Dataset[(K, T)] = { - val func = (key: K, it: Iterator[T]) => Iterator(key -> it.reduce(f)) + def reduce(f: (V, V) => V): Dataset[(K, V)] = { + val func = (key: K, it: Iterator[V]) => Iterator((key, it.reduce(f))) implicit val resultEncoder = ExpressionEncoder.tuple(unresolvedKEncoder, unresolvedTEncoder) - flatMap(func) + flatMapGroup(func) } - def reduce(f: ReduceFunction[T]): Dataset[(K, T)] = { + /** + * Reduces the elements of each group of data using the specified binary function. + * The given function must be commutative and associative or the result may be non-deterministic. + * + * @since 1.6.0 + */ + def reduce(f: ReduceFunction[V]): Dataset[(K, V)] = { reduce(f.call _) } @@ -185,41 +245,51 @@ class GroupedDataset[K, T] private[sql]( /** * Computes the given aggregation, returning a [[Dataset]] of tuples for each unique key * and the result of computing this aggregation over all elements in the group. + * + * @since 1.6.0 */ - def agg[U1](col1: TypedColumn[T, U1]): Dataset[(K, U1)] = + def agg[U1](col1: TypedColumn[V, U1]): Dataset[(K, U1)] = aggUntyped(col1).asInstanceOf[Dataset[(K, U1)]] /** * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key * and the result of computing these aggregations over all elements in the group. + * + * @since 1.6.0 */ - def agg[U1, U2](col1: TypedColumn[T, U1], col2: TypedColumn[T, U2]): Dataset[(K, U1, U2)] = + def agg[U1, U2](col1: TypedColumn[V, U1], col2: TypedColumn[V, U2]): Dataset[(K, U1, U2)] = aggUntyped(col1, col2).asInstanceOf[Dataset[(K, U1, U2)]] /** * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key * and the result of computing these aggregations over all elements in the group. 
+ * + * @since 1.6.0 */ def agg[U1, U2, U3]( - col1: TypedColumn[T, U1], - col2: TypedColumn[T, U2], - col3: TypedColumn[T, U3]): Dataset[(K, U1, U2, U3)] = + col1: TypedColumn[V, U1], + col2: TypedColumn[V, U2], + col3: TypedColumn[V, U3]): Dataset[(K, U1, U2, U3)] = aggUntyped(col1, col2, col3).asInstanceOf[Dataset[(K, U1, U2, U3)]] /** * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key * and the result of computing these aggregations over all elements in the group. + * + * @since 1.6.0 */ def agg[U1, U2, U3, U4]( - col1: TypedColumn[T, U1], - col2: TypedColumn[T, U2], - col3: TypedColumn[T, U3], - col4: TypedColumn[T, U4]): Dataset[(K, U1, U2, U3, U4)] = + col1: TypedColumn[V, U1], + col2: TypedColumn[V, U2], + col3: TypedColumn[V, U3], + col4: TypedColumn[V, U4]): Dataset[(K, U1, U2, U3, U4)] = aggUntyped(col1, col2, col3, col4).asInstanceOf[Dataset[(K, U1, U2, U3, U4)]] /** * Returns a [[Dataset]] that contains a tuple with each key and the number of items present * for that key. + * + * @since 1.6.0 */ def count(): Dataset[(K, Long)] = agg(functions.count("*").as(ExpressionEncoder[Long])) @@ -228,10 +298,12 @@ class GroupedDataset[K, T] private[sql]( * be passed the grouping key and 2 iterators containing all elements in the group from * [[Dataset]] `this` and `other`. The function can return an iterator containing elements of an * arbitrary type which will be returned as a new [[Dataset]]. + * + * @since 1.6.0 */ def cogroup[U, R : Encoder]( other: GroupedDataset[K, U])( - f: (K, Iterator[T], Iterator[U]) => TraversableOnce[R]): Dataset[R] = { + f: (K, Iterator[V], Iterator[U]) => TraversableOnce[R]): Dataset[R] = { implicit def uEnc: Encoder[U] = other.unresolvedTEncoder new Dataset[R]( sqlContext, @@ -243,9 +315,17 @@ class GroupedDataset[K, T] private[sql]( other.logicalPlan)) } + /** + * Applies the given function to each cogrouped data. For each unique group, the function will + * be passed the grouping key and 2 iterators containing all elements in the group from + * [[Dataset]] `this` and `other`. The function can return an iterator containing elements of an + * arbitrary type which will be returned as a new [[Dataset]]. 
+ * + * @since 1.6.0 + */ def cogroup[U, R]( other: GroupedDataset[K, U], - f: CoGroupFunction[K, T, U, R], + f: CoGroupFunction[K, V, U, R], encoder: Encoder[R]): Dataset[R] = { cogroup(other)((key, left, right) => f.call(key, left.asJava, right.asJava).asScala)(encoder) } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index f32374b4c04d..cf335efdd23b 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -170,7 +170,7 @@ public Integer call(String v) throws Exception { } }, Encoders.INT()); - Dataset mapped = grouped.map(new MapGroupFunction() { + Dataset mapped = grouped.mapGroup(new MapGroupFunction() { @Override public String call(Integer key, Iterator values) throws Exception { StringBuilder sb = new StringBuilder(key.toString()); @@ -183,7 +183,7 @@ public String call(Integer key, Iterator values) throws Exception { Assert.assertEquals(Arrays.asList("1a", "3foobar"), mapped.collectAsList()); - Dataset flatMapped = grouped.flatMap( + Dataset flatMapped = grouped.flatMapGroup( new FlatMapGroupFunction() { @Override public Iterable call(Integer key, Iterator values) throws Exception { @@ -247,9 +247,9 @@ public void testGroupByColumn() { List data = Arrays.asList("a", "foo", "bar"); Dataset ds = context.createDataset(data, Encoders.STRING()); GroupedDataset grouped = - ds.groupBy(length(col("value"))).asKey(Encoders.INT()); + ds.groupBy(length(col("value"))).keyAs(Encoders.INT()); - Dataset mapped = grouped.map( + Dataset mapped = grouped.mapGroup( new MapGroupFunction() { @Override public String call(Integer key, Iterator data) throws Exception { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala index 63b00975e4eb..d387710357be 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala @@ -86,7 +86,7 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext { test("groupBy function, map") { val ds = Seq(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11).toDS() val grouped = ds.groupBy(_ % 2) - val agged = grouped.map { case (g, iter) => + val agged = grouped.mapGroup { case (g, iter) => val name = if (g == 0) "even" else "odd" (name, iter.size) } @@ -99,7 +99,7 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext { test("groupBy function, flatMap") { val ds = Seq("a", "b", "c", "xyz", "hello").toDS() val grouped = ds.groupBy(_.length) - val agged = grouped.flatMap { case (g, iter) => Iterator(g.toString, iter.mkString) } + val agged = grouped.flatMapGroup { case (g, iter) => Iterator(g.toString, iter.mkString) } checkAnswer( agged, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 89d964aa3e46..9da02550b39c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -224,7 +224,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { test("groupBy function, map") { val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() val grouped = ds.groupBy(v => (v._1, "word")) - val agged = grouped.map { case (g, iter) => (g._1, 
iter.map(_._2).sum) } + val agged = grouped.mapGroup { case (g, iter) => (g._1, iter.map(_._2).sum) } checkAnswer( agged, @@ -234,7 +234,9 @@ class DatasetSuite extends QueryTest with SharedSQLContext { test("groupBy function, flatMap") { val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() val grouped = ds.groupBy(v => (v._1, "word")) - val agged = grouped.flatMap { case (g, iter) => Iterator(g._1, iter.map(_._2).sum.toString) } + val agged = grouped.flatMapGroup { case (g, iter) => + Iterator(g._1, iter.map(_._2).sum.toString) + } checkAnswer( agged, @@ -253,7 +255,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { test("groupBy columns, map") { val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() val grouped = ds.groupBy($"_1") - val agged = grouped.map { case (g, iter) => (g.getString(0), iter.map(_._2).sum) } + val agged = grouped.mapGroup { case (g, iter) => (g.getString(0), iter.map(_._2).sum) } checkAnswer( agged, @@ -262,8 +264,8 @@ class DatasetSuite extends QueryTest with SharedSQLContext { test("groupBy columns asKey, map") { val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() - val grouped = ds.groupBy($"_1").asKey[String] - val agged = grouped.map { case (g, iter) => (g, iter.map(_._2).sum) } + val grouped = ds.groupBy($"_1").keyAs[String] + val agged = grouped.mapGroup { case (g, iter) => (g, iter.map(_._2).sum) } checkAnswer( agged, @@ -272,8 +274,8 @@ class DatasetSuite extends QueryTest with SharedSQLContext { test("groupBy columns asKey tuple, map") { val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() - val grouped = ds.groupBy($"_1", lit(1)).asKey[(String, Int)] - val agged = grouped.map { case (g, iter) => (g, iter.map(_._2).sum) } + val grouped = ds.groupBy($"_1", lit(1)).keyAs[(String, Int)] + val agged = grouped.mapGroup { case (g, iter) => (g, iter.map(_._2).sum) } checkAnswer( agged, @@ -282,8 +284,8 @@ class DatasetSuite extends QueryTest with SharedSQLContext { test("groupBy columns asKey class, map") { val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() - val grouped = ds.groupBy($"_1".as("a"), lit(1).as("b")).asKey[ClassData] - val agged = grouped.map { case (g, iter) => (g, iter.map(_._2).sum) } + val grouped = ds.groupBy($"_1".as("a"), lit(1).as("b")).keyAs[ClassData] + val agged = grouped.mapGroup { case (g, iter) => (g, iter.map(_._2).sum) } checkAnswer( agged, From 426004a9c9a864f90494d08601e6974709091a56 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sun, 22 Nov 2015 10:36:47 -0800 Subject: [PATCH 1494/1824] [SPARK-11908][SQL] Add NullType support to RowEncoder JIRA: https://issues.apache.org/jira/browse/SPARK-11908 We should add NullType support to RowEncoder. Author: Liang-Chi Hsieh Closes #9891 from viirya/rowencoder-nulltype. 
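For the `GroupedDataset` renames introduced by the SPARK-11899 patch above (`map` becomes `mapGroup`, `flatMap` becomes `flatMapGroup`, `asKey` becomes `keyAs`), updated call sites look roughly like the sketch below. It mirrors the `DatasetSuite` changes in that patch; the object name, application name, local master, and `show()` calls are assumptions added only to make the fragment self-contained.

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext

object GroupedDatasetSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("GroupedDatasetSketch").setMaster("local[2]"))
    val sqlContext = new SQLContext(sc)
    import sqlContext.implicits._

    val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()

    // Group by a typed function; mapGroup (formerly map) reduces each group to a single record.
    val byFunction = ds.groupBy(_._1)
      .mapGroup { case (key, values) => (key, values.map(_._2).sum) }

    // Group by a column, then give the key a type with keyAs (formerly asKey).
    val byColumn = ds.groupBy($"_1").keyAs[String]
      .mapGroup { case (key, values) => (key, values.map(_._2).sum) }

    byFunction.show()   // expected rows, in some order: (a,30), (b,3), (c,1)
    byColumn.show()     // same result, built from the untyped column grouping
    sc.stop()
  }
}
```

As the new scaladoc in that patch notes, these group-wise functions shuffle all of the data, so for simple per-key aggregation the `reduce` method or an `Aggregator` remains the better choice.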
--- .../org/apache/spark/sql/catalyst/encoders/RowEncoder.scala | 5 +++-- .../org/apache/spark/sql/catalyst/expressions/objects.scala | 3 +++ .../apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala | 3 +++ 3 files changed, 9 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index 4cda4824acdc..fa553e7c5324 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -48,7 +48,7 @@ object RowEncoder { private def extractorsFor( inputObject: Expression, inputType: DataType): Expression = inputType match { - case BooleanType | ByteType | ShortType | IntegerType | LongType | + case NullType | BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType | BinaryType => inputObject case udt: UserDefinedType[_] => @@ -143,6 +143,7 @@ object RowEncoder { case _: MapType => ObjectType(classOf[scala.collection.Map[_, _]]) case _: StructType => ObjectType(classOf[Row]) case udt: UserDefinedType[_] => ObjectType(udt.userClass) + case _: NullType => ObjectType(classOf[java.lang.Object]) } private def constructorFor(schema: StructType): Expression = { @@ -158,7 +159,7 @@ object RowEncoder { } private def constructorFor(input: Expression): Expression = input.dataType match { - case BooleanType | ByteType | ShortType | IntegerType | LongType | + case NullType | BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType | BinaryType => input case udt: UserDefinedType[_] => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala index ef7399e0196a..82317d338516 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala @@ -369,6 +369,9 @@ case class MapObjects( private lazy val completeFunction = function(loopAttribute) private def itemAccessorMethod(dataType: DataType): String => String = dataType match { + case NullType => + val nullTypeClassName = NullType.getClass.getName + ".MODULE$" + (i: String) => s".get($i, $nullTypeClassName)" case IntegerType => (i: String) => s".getInt($i)" case LongType => (i: String) => s".getLong($i)" case FloatType => (i: String) => s".getFloat($i)" diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala index 46c6e0d98d34..0ea51ece4bc5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala @@ -80,11 +80,13 @@ class RowEncoderSuite extends SparkFunSuite { private val structOfString = new StructType().add("str", StringType) private val structOfUDT = new StructType().add("udt", new ExamplePointUDT, false) private val arrayOfString = ArrayType(StringType) + private val arrayOfNull = ArrayType(NullType) private val mapOfString = MapType(StringType, StringType) private val arrayOfUDT = ArrayType(new ExamplePointUDT, false) encodeDecodeTest( new StructType() + .add("null", NullType) .add("boolean", BooleanType) 
.add("byte", ByteType) .add("short", ShortType) @@ -101,6 +103,7 @@ class RowEncoderSuite extends SparkFunSuite { encodeDecodeTest( new StructType() + .add("arrayOfNull", arrayOfNull) .add("arrayOfString", arrayOfString) .add("arrayOfArrayOfString", ArrayType(arrayOfString)) .add("arrayOfArrayOfInt", ArrayType(ArrayType(IntegerType))) From fe89c1817d668e46adf70d0896c42c22a547c76a Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Sun, 22 Nov 2015 21:45:46 -0800 Subject: [PATCH 1495/1824] [SPARK-11895][ML] rename and refactor DatasetExample under mllib/examples We used the name `Dataset` to refer to `SchemaRDD` in 1.2 in ML pipelines and created this example file. Since `Dataset` has a new meaning in Spark 1.6, we should rename it to avoid confusion. This PR also removes support for dense format to simplify the example code. cc: yinxusen Author: Xiangrui Meng Closes #9873 from mengxr/SPARK-11895. --- .../DataFrameExample.scala} | 71 +++++++------------ 1 file changed, 26 insertions(+), 45 deletions(-) rename examples/src/main/scala/org/apache/spark/examples/{mllib/DatasetExample.scala => ml/DataFrameExample.scala} (51%) diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DataFrameExample.scala similarity index 51% rename from examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala rename to examples/src/main/scala/org/apache/spark/examples/ml/DataFrameExample.scala index dc13f82488af..424f00158c2f 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DataFrameExample.scala @@ -16,7 +16,7 @@ */ // scalastyle:off println -package org.apache.spark.examples.mllib +package org.apache.spark.examples.ml import java.io.File @@ -24,25 +24,22 @@ import com.google.common.io.Files import scopt.OptionParser import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.examples.mllib.AbstractParams import org.apache.spark.mllib.linalg.Vector -import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer -import org.apache.spark.mllib.util.MLUtils -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{Row, SQLContext, DataFrame} +import org.apache.spark.sql.{DataFrame, Row, SQLContext} /** - * An example of how to use [[org.apache.spark.sql.DataFrame]] as a Dataset for ML. Run with + * An example of how to use [[org.apache.spark.sql.DataFrame]] for ML. Run with * {{{ - * ./bin/run-example org.apache.spark.examples.mllib.DatasetExample [options] + * ./bin/run-example ml.DataFrameExample [options] * }}} * If you use it as a template to create your own app, please use `spark-submit` to submit your app. 
*/ -object DatasetExample { +object DataFrameExample { - case class Params( - input: String = "data/mllib/sample_libsvm_data.txt", - dataFormat: String = "libsvm") extends AbstractParams[Params] + case class Params(input: String = "data/mllib/sample_libsvm_data.txt") + extends AbstractParams[Params] def main(args: Array[String]) { val defaultParams = Params() @@ -52,9 +49,6 @@ object DatasetExample { opt[String]("input") .text(s"input path to dataset") .action((x, c) => c.copy(input = x)) - opt[String]("dataFormat") - .text("data format: libsvm (default), dense (deprecated in Spark v1.1)") - .action((x, c) => c.copy(input = x)) checkConfig { params => success } @@ -69,55 +63,42 @@ object DatasetExample { def run(params: Params) { - val conf = new SparkConf().setAppName(s"DatasetExample with $params") + val conf = new SparkConf().setAppName(s"DataFrameExample with $params") val sc = new SparkContext(conf) val sqlContext = new SQLContext(sc) - import sqlContext.implicits._ // for implicit conversions // Load input data - val origData: RDD[LabeledPoint] = params.dataFormat match { - case "dense" => MLUtils.loadLabeledPoints(sc, params.input) - case "libsvm" => MLUtils.loadLibSVMFile(sc, params.input) - } - println(s"Loaded ${origData.count()} instances from file: ${params.input}") - - // Convert input data to DataFrame explicitly. - val df: DataFrame = origData.toDF() - println(s"Inferred schema:\n${df.schema.prettyJson}") - println(s"Converted to DataFrame with ${df.count()} records") - - // Select columns - val labelsDf: DataFrame = df.select("label") - val labels: RDD[Double] = labelsDf.map { case Row(v: Double) => v } - val numLabels = labels.count() - val meanLabel = labels.fold(0.0)(_ + _) / numLabels - println(s"Selected label column with average value $meanLabel") - - val featuresDf: DataFrame = df.select("features") - val features: RDD[Vector] = featuresDf.map { case Row(v: Vector) => v } + println(s"Loading LIBSVM file with UDT from ${params.input}.") + val df: DataFrame = sqlContext.read.format("libsvm").load(params.input).cache() + println("Schema from LIBSVM:") + df.printSchema() + println(s"Loaded training data as a DataFrame with ${df.count()} records.") + + // Show statistical summary of labels. + val labelSummary = df.describe("label") + labelSummary.show() + + // Convert features column to an RDD of vectors. + val features = df.select("features").map { case Row(v: Vector) => v } val featureSummary = features.aggregate(new MultivariateOnlineSummarizer())( (summary, feat) => summary.add(feat), (sum1, sum2) => sum1.merge(sum2)) println(s"Selected features column with average values:\n ${featureSummary.mean.toString}") + // Save the records in a parquet file. val tmpDir = Files.createTempDir() tmpDir.deleteOnExit() val outputDir = new File(tmpDir, "dataset").toString println(s"Saving to $outputDir as Parquet file.") df.write.parquet(outputDir) + // Load the records back. 
println(s"Loading Parquet file with UDT from $outputDir.") - val newDataset = sqlContext.read.parquet(outputDir) - - println(s"Schema from Parquet: ${newDataset.schema.prettyJson}") - val newFeatures = newDataset.select("features").map { case Row(v: Vector) => v } - val newFeaturesSummary = newFeatures.aggregate(new MultivariateOnlineSummarizer())( - (summary, feat) => summary.add(feat), - (sum1, sum2) => sum1.merge(sum2)) - println(s"Selected features column with average values:\n ${newFeaturesSummary.mean.toString}") + val newDF = sqlContext.read.parquet(outputDir) + println(s"Schema from Parquet:") + newDF.printSchema() sc.stop() } - } // scalastyle:on println From a6fda0bfc16a13b28b1cecc96f1ff91363089144 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Sun, 22 Nov 2015 21:48:48 -0800 Subject: [PATCH 1496/1824] [SPARK-6791][ML] Add read/write for CrossValidator and Evaluators I believe this works for general estimators within CrossValidator, including compound estimators. (See the complex unit test.) Added read/write for all 3 Evaluators as well. CC: mengxr yanboliang Author: Joseph K. Bradley Closes #9848 from jkbradley/cv-io. --- .../scala/org/apache/spark/ml/Pipeline.scala | 38 +-- .../BinaryClassificationEvaluator.scala | 11 +- .../MulticlassClassificationEvaluator.scala | 12 +- .../ml/evaluation/RegressionEvaluator.scala | 11 +- .../apache/spark/ml/recommendation/ALS.scala | 14 +- .../spark/ml/tuning/CrossValidator.scala | 229 +++++++++++++++++- .../org/apache/spark/ml/util/ReadWrite.scala | 48 ++-- .../org/apache/spark/ml/PipelineSuite.scala | 4 +- .../BinaryClassificationEvaluatorSuite.scala | 13 +- ...lticlassClassificationEvaluatorSuite.scala | 13 +- .../evaluation/RegressionEvaluatorSuite.scala | 12 +- .../spark/ml/tuning/CrossValidatorSuite.scala | 202 ++++++++++++++- 12 files changed, 522 insertions(+), 85 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala index 6f15b37abcb3..4b2b3f8489fd 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala @@ -34,7 +34,6 @@ import org.apache.spark.ml.util.MLWriter import org.apache.spark.ml.util._ import org.apache.spark.sql.DataFrame import org.apache.spark.sql.types.StructType -import org.apache.spark.util.Utils /** * :: DeveloperApi :: @@ -232,20 +231,9 @@ object Pipeline extends MLReadable[Pipeline] { stages: Array[PipelineStage], sc: SparkContext, path: String): Unit = { - // Copied and edited from DefaultParamsWriter.saveMetadata - // TODO: modify DefaultParamsWriter.saveMetadata to avoid duplication - val uid = instance.uid - val cls = instance.getClass.getName val stageUids = stages.map(_.uid) val jsonParams = List("stageUids" -> parse(compact(render(stageUids.toSeq)))) - val metadata = ("class" -> cls) ~ - ("timestamp" -> System.currentTimeMillis()) ~ - ("sparkVersion" -> sc.version) ~ - ("uid" -> uid) ~ - ("paramMap" -> jsonParams) - val metadataPath = new Path(path, "metadata").toString - val metadataJson = compact(render(metadata)) - sc.parallelize(Seq(metadataJson), 1).saveAsTextFile(metadataPath) + DefaultParamsWriter.saveMetadata(instance, path, sc, paramMap = Some(jsonParams)) // Save stages val stagesDir = new Path(path, "stages").toString @@ -266,30 +254,10 @@ object Pipeline extends MLReadable[Pipeline] { implicit val format = DefaultFormats val stagesDir = new Path(path, "stages").toString - val stageUids: Array[String] = metadata.params match { - 
case JObject(pairs) => - if (pairs.length != 1) { - // Should not happen unless file is corrupted or we have a bug. - throw new RuntimeException( - s"Pipeline read expected 1 Param (stageUids), but found ${pairs.length}.") - } - pairs.head match { - case ("stageUids", jsonValue) => - jsonValue.extract[Seq[String]].toArray - case (paramName, jsonValue) => - // Should not happen unless file is corrupted or we have a bug. - throw new RuntimeException(s"Pipeline read encountered unexpected Param $paramName" + - s" in metadata: ${metadata.metadataStr}") - } - case _ => - throw new IllegalArgumentException( - s"Cannot recognize JSON metadata: ${metadata.metadataStr}.") - } + val stageUids: Array[String] = (metadata.params \ "stageUids").extract[Seq[String]].toArray val stages: Array[PipelineStage] = stageUids.zipWithIndex.map { case (stageUid, idx) => val stagePath = SharedReadWrite.getStagePath(stageUid, idx, stageUids.length, stagesDir) - val stageMetadata = DefaultParamsReader.loadMetadata(stagePath, sc) - val cls = Utils.classForName(stageMetadata.className) - cls.getMethod("read").invoke(null).asInstanceOf[MLReader[PipelineStage]].load(stagePath) + DefaultParamsReader.loadParamsInstance[PipelineStage](stagePath, sc) } (metadata.uid, stages) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala index 1fe3abaca81c..bfb70963b151 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala @@ -20,7 +20,7 @@ package org.apache.spark.ml.evaluation import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ -import org.apache.spark.ml.util.{Identifiable, SchemaUtils} +import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, SchemaUtils} import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics import org.apache.spark.mllib.linalg.{Vector, VectorUDT} import org.apache.spark.sql.{DataFrame, Row} @@ -33,7 +33,7 @@ import org.apache.spark.sql.types.DoubleType @Since("1.2.0") @Experimental class BinaryClassificationEvaluator @Since("1.4.0") (@Since("1.4.0") override val uid: String) - extends Evaluator with HasRawPredictionCol with HasLabelCol { + extends Evaluator with HasRawPredictionCol with HasLabelCol with DefaultParamsWritable { @Since("1.2.0") def this() = this(Identifiable.randomUID("binEval")) @@ -105,3 +105,10 @@ class BinaryClassificationEvaluator @Since("1.4.0") (@Since("1.4.0") override va @Since("1.4.1") override def copy(extra: ParamMap): BinaryClassificationEvaluator = defaultCopy(extra) } + +@Since("1.6.0") +object BinaryClassificationEvaluator extends DefaultParamsReadable[BinaryClassificationEvaluator] { + + @Since("1.6.0") + override def load(path: String): BinaryClassificationEvaluator = super.load(path) +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala index df5f04ca5a8d..c44db0ec595e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala @@ -20,7 +20,7 @@ package org.apache.spark.ml.evaluation 
import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.param.{ParamMap, ParamValidators, Param} import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol} -import org.apache.spark.ml.util.{SchemaUtils, Identifiable} +import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, SchemaUtils, Identifiable} import org.apache.spark.mllib.evaluation.MulticlassMetrics import org.apache.spark.sql.{Row, DataFrame} import org.apache.spark.sql.types.DoubleType @@ -32,7 +32,7 @@ import org.apache.spark.sql.types.DoubleType @Since("1.5.0") @Experimental class MulticlassClassificationEvaluator @Since("1.5.0") (@Since("1.5.0") override val uid: String) - extends Evaluator with HasPredictionCol with HasLabelCol { + extends Evaluator with HasPredictionCol with HasLabelCol with DefaultParamsWritable { @Since("1.5.0") def this() = this(Identifiable.randomUID("mcEval")) @@ -101,3 +101,11 @@ class MulticlassClassificationEvaluator @Since("1.5.0") (@Since("1.5.0") overrid @Since("1.5.0") override def copy(extra: ParamMap): MulticlassClassificationEvaluator = defaultCopy(extra) } + +@Since("1.6.0") +object MulticlassClassificationEvaluator + extends DefaultParamsReadable[MulticlassClassificationEvaluator] { + + @Since("1.6.0") + override def load(path: String): MulticlassClassificationEvaluator = super.load(path) +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala index ba012f444d3e..daaa174a086e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala @@ -20,7 +20,7 @@ package org.apache.spark.ml.evaluation import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators} import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol} -import org.apache.spark.ml.util.{Identifiable, SchemaUtils} +import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, SchemaUtils} import org.apache.spark.mllib.evaluation.RegressionMetrics import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.functions._ @@ -33,7 +33,7 @@ import org.apache.spark.sql.types.{DoubleType, FloatType} @Since("1.4.0") @Experimental final class RegressionEvaluator @Since("1.4.0") (@Since("1.4.0") override val uid: String) - extends Evaluator with HasPredictionCol with HasLabelCol { + extends Evaluator with HasPredictionCol with HasLabelCol with DefaultParamsWritable { @Since("1.4.0") def this() = this(Identifiable.randomUID("regEval")) @@ -104,3 +104,10 @@ final class RegressionEvaluator @Since("1.4.0") (@Since("1.4.0") override val ui @Since("1.5.0") override def copy(extra: ParamMap): RegressionEvaluator = defaultCopy(extra) } + +@Since("1.6.0") +object RegressionEvaluator extends DefaultParamsReadable[RegressionEvaluator] { + + @Since("1.6.0") + override def load(path: String): RegressionEvaluator = super.load(path) +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index 4d35177ad9b0..b798aa1fab76 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -27,9 +27,8 @@ import scala.util.hashing.byteswap64 import 
com.github.fommil.netlib.BLAS.{getInstance => blas} import org.apache.hadoop.fs.{FileSystem, Path} -import org.json4s.{DefaultFormats, JValue} +import org.json4s.DefaultFormats import org.json4s.JsonDSL._ -import org.json4s.jackson.JsonMethods._ import org.apache.spark.{Logging, Partitioner} import org.apache.spark.annotation.{Since, DeveloperApi, Experimental} @@ -240,7 +239,7 @@ object ALSModel extends MLReadable[ALSModel] { private[ALSModel] class ALSModelWriter(instance: ALSModel) extends MLWriter { override protected def saveImpl(path: String): Unit = { - val extraMetadata = render("rank" -> instance.rank) + val extraMetadata = "rank" -> instance.rank DefaultParamsWriter.saveMetadata(instance, path, sc, Some(extraMetadata)) val userPath = new Path(path, "userFactors").toString instance.userFactors.write.format("parquet").save(userPath) @@ -257,14 +256,7 @@ object ALSModel extends MLReadable[ALSModel] { override def load(path: String): ALSModel = { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) implicit val format = DefaultFormats - val rank: Int = metadata.extraMetadata match { - case Some(m: JValue) => - (m \ "rank").extract[Int] - case None => - throw new RuntimeException(s"ALSModel loader could not read rank from JSON metadata:" + - s" ${metadata.metadataStr}") - } - + val rank = (metadata.metadata \ "rank").extract[Int] val userPath = new Path(path, "userFactors").toString val userFactors = sqlContext.read.format("parquet").load(userPath) val itemPath = new Path(path, "itemFactors").toString diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index 77d9948ed86b..83a904837426 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -18,17 +18,24 @@ package org.apache.spark.ml.tuning import com.github.fommil.netlib.F2jBLAS +import org.apache.hadoop.fs.Path +import org.json4s.{JObject, DefaultFormats} +import org.json4s.jackson.JsonMethods._ -import org.apache.spark.Logging -import org.apache.spark.annotation.Experimental +import org.apache.spark.ml.classification.OneVsRestParams +import org.apache.spark.ml.feature.RFormulaModel +import org.apache.spark.{SparkContext, Logging} +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml._ import org.apache.spark.ml.evaluation.Evaluator import org.apache.spark.ml.param._ -import org.apache.spark.ml.util.Identifiable +import org.apache.spark.ml.util._ +import org.apache.spark.ml.util.DefaultParamsReader.Metadata import org.apache.spark.mllib.util.MLUtils import org.apache.spark.sql.DataFrame import org.apache.spark.sql.types.StructType + /** * Params for [[CrossValidator]] and [[CrossValidatorModel]]. */ @@ -53,7 +60,7 @@ private[ml] trait CrossValidatorParams extends ValidatorParams { */ @Experimental class CrossValidator(override val uid: String) extends Estimator[CrossValidatorModel] - with CrossValidatorParams with Logging { + with CrossValidatorParams with MLWritable with Logging { def this() = this(Identifiable.randomUID("cv")) @@ -131,6 +138,166 @@ class CrossValidator(override val uid: String) extends Estimator[CrossValidatorM } copied } + + // Currently, this only works if all [[Param]]s in [[estimatorParamMaps]] are simple types. + // E.g., this may fail if a [[Param]] is an instance of an [[Estimator]]. + // However, this case should be unusual. 
+ @Since("1.6.0") + override def write: MLWriter = new CrossValidator.CrossValidatorWriter(this) +} + +@Since("1.6.0") +object CrossValidator extends MLReadable[CrossValidator] { + + @Since("1.6.0") + override def read: MLReader[CrossValidator] = new CrossValidatorReader + + @Since("1.6.0") + override def load(path: String): CrossValidator = super.load(path) + + private[CrossValidator] class CrossValidatorWriter(instance: CrossValidator) extends MLWriter { + + SharedReadWrite.validateParams(instance) + + override protected def saveImpl(path: String): Unit = + SharedReadWrite.saveImpl(path, instance, sc) + } + + private class CrossValidatorReader extends MLReader[CrossValidator] { + + /** Checked against metadata when loading model */ + private val className = classOf[CrossValidator].getName + + override def load(path: String): CrossValidator = { + val (metadata, estimator, evaluator, estimatorParamMaps, numFolds) = + SharedReadWrite.load(path, sc, className) + new CrossValidator(metadata.uid) + .setEstimator(estimator) + .setEvaluator(evaluator) + .setEstimatorParamMaps(estimatorParamMaps) + .setNumFolds(numFolds) + } + } + + private object CrossValidatorReader { + /** + * Examine the given estimator (which may be a compound estimator) and extract a mapping + * from UIDs to corresponding [[Params]] instances. + */ + def getUidMap(instance: Params): Map[String, Params] = { + val uidList = getUidMapImpl(instance) + val uidMap = uidList.toMap + if (uidList.size != uidMap.size) { + throw new RuntimeException("CrossValidator.load found a compound estimator with stages" + + s" with duplicate UIDs. List of UIDs: ${uidList.map(_._1).mkString(", ")}") + } + uidMap + } + + def getUidMapImpl(instance: Params): List[(String, Params)] = { + val subStages: Array[Params] = instance match { + case p: Pipeline => p.getStages.asInstanceOf[Array[Params]] + case pm: PipelineModel => pm.stages.asInstanceOf[Array[Params]] + case v: ValidatorParams => Array(v.getEstimator, v.getEvaluator) + case ovr: OneVsRestParams => + // TODO: SPARK-11892: This case may require special handling. + throw new UnsupportedOperationException("CrossValidator write will fail because it" + + " cannot yet handle an estimator containing type: ${ovr.getClass.getName}") + case rform: RFormulaModel => + // TODO: SPARK-11891: This case may require special handling. + throw new UnsupportedOperationException("CrossValidator write will fail because it" + + " cannot yet handle an estimator containing an RFormulaModel") + case _: Params => Array() + } + val subStageMaps = subStages.map(getUidMapImpl).foldLeft(List.empty[(String, Params)])(_ ++ _) + List((instance.uid, instance)) ++ subStageMaps + } + } + + private[tuning] object SharedReadWrite { + + /** + * Check that [[CrossValidator.evaluator]] and [[CrossValidator.estimator]] are Writable. + * This does not check [[CrossValidator.estimatorParamMaps]]. + */ + def validateParams(instance: ValidatorParams): Unit = { + def checkElement(elem: Params, name: String): Unit = elem match { + case stage: MLWritable => // good + case other => + throw new UnsupportedOperationException("CrossValidator write will fail " + + s" because it contains $name which does not implement Writable." + + s" Non-Writable $name: ${other.uid} of type ${other.getClass}") + } + checkElement(instance.getEvaluator, "evaluator") + checkElement(instance.getEstimator, "estimator") + // Check to make sure all Params apply to this estimator. Throw an error if any do not. 
+ // Extraneous Params would cause problems when loading the estimatorParamMaps. + val uidToInstance: Map[String, Params] = CrossValidatorReader.getUidMap(instance) + instance.getEstimatorParamMaps.foreach { case pMap: ParamMap => + pMap.toSeq.foreach { case ParamPair(p, v) => + require(uidToInstance.contains(p.parent), s"CrossValidator save requires all Params in" + + s" estimatorParamMaps to apply to this CrossValidator, its Estimator, or its" + + s" Evaluator. An extraneous Param was found: $p") + } + } + } + + private[tuning] def saveImpl( + path: String, + instance: CrossValidatorParams, + sc: SparkContext, + extraMetadata: Option[JObject] = None): Unit = { + import org.json4s.JsonDSL._ + + val estimatorParamMapsJson = compact(render( + instance.getEstimatorParamMaps.map { case paramMap => + paramMap.toSeq.map { case ParamPair(p, v) => + Map("parent" -> p.parent, "name" -> p.name, "value" -> p.jsonEncode(v)) + } + }.toSeq + )) + val jsonParams = List( + "numFolds" -> parse(instance.numFolds.jsonEncode(instance.getNumFolds)), + "estimatorParamMaps" -> parse(estimatorParamMapsJson) + ) + DefaultParamsWriter.saveMetadata(instance, path, sc, extraMetadata, Some(jsonParams)) + + val evaluatorPath = new Path(path, "evaluator").toString + instance.getEvaluator.asInstanceOf[MLWritable].save(evaluatorPath) + val estimatorPath = new Path(path, "estimator").toString + instance.getEstimator.asInstanceOf[MLWritable].save(estimatorPath) + } + + private[tuning] def load[M <: Model[M]]( + path: String, + sc: SparkContext, + expectedClassName: String): (Metadata, Estimator[M], Evaluator, Array[ParamMap], Int) = { + + val metadata = DefaultParamsReader.loadMetadata(path, sc, expectedClassName) + + implicit val format = DefaultFormats + val evaluatorPath = new Path(path, "evaluator").toString + val evaluator = DefaultParamsReader.loadParamsInstance[Evaluator](evaluatorPath, sc) + val estimatorPath = new Path(path, "estimator").toString + val estimator = DefaultParamsReader.loadParamsInstance[Estimator[M]](estimatorPath, sc) + + val uidToParams = Map(evaluator.uid -> evaluator) ++ CrossValidatorReader.getUidMap(estimator) + + val numFolds = (metadata.params \ "numFolds").extract[Int] + val estimatorParamMaps: Array[ParamMap] = + (metadata.params \ "estimatorParamMaps").extract[Seq[Seq[Map[String, String]]]].map { + pMap => + val paramPairs = pMap.map { case pInfo: Map[String, String] => + val est = uidToParams(pInfo("parent")) + val param = est.getParam(pInfo("name")) + val value = param.jsonDecode(pInfo("value")) + param -> value + } + ParamMap(paramPairs: _*) + }.toArray + (metadata, estimator, evaluator, estimatorParamMaps, numFolds) + } + } } /** @@ -139,14 +306,14 @@ class CrossValidator(override val uid: String) extends Estimator[CrossValidatorM * * @param bestModel The best model selected from k-fold cross validation. * @param avgMetrics Average cross-validation metrics for each paramMap in - * [[estimatorParamMaps]], in the corresponding order. + * [[CrossValidator.estimatorParamMaps]], in the corresponding order. 
*/ @Experimental class CrossValidatorModel private[ml] ( override val uid: String, val bestModel: Model[_], val avgMetrics: Array[Double]) - extends Model[CrossValidatorModel] with CrossValidatorParams { + extends Model[CrossValidatorModel] with CrossValidatorParams with MLWritable { override def validateParams(): Unit = { bestModel.validateParams() @@ -168,4 +335,54 @@ class CrossValidatorModel private[ml] ( avgMetrics.clone()) copyValues(copied, extra).setParent(parent) } + + @Since("1.6.0") + override def write: MLWriter = new CrossValidatorModel.CrossValidatorModelWriter(this) +} + +@Since("1.6.0") +object CrossValidatorModel extends MLReadable[CrossValidatorModel] { + + import CrossValidator.SharedReadWrite + + @Since("1.6.0") + override def read: MLReader[CrossValidatorModel] = new CrossValidatorModelReader + + @Since("1.6.0") + override def load(path: String): CrossValidatorModel = super.load(path) + + private[CrossValidatorModel] + class CrossValidatorModelWriter(instance: CrossValidatorModel) extends MLWriter { + + SharedReadWrite.validateParams(instance) + + override protected def saveImpl(path: String): Unit = { + import org.json4s.JsonDSL._ + val extraMetadata = "avgMetrics" -> instance.avgMetrics.toSeq + SharedReadWrite.saveImpl(path, instance, sc, Some(extraMetadata)) + val bestModelPath = new Path(path, "bestModel").toString + instance.bestModel.asInstanceOf[MLWritable].save(bestModelPath) + } + } + + private class CrossValidatorModelReader extends MLReader[CrossValidatorModel] { + + /** Checked against metadata when loading model */ + private val className = classOf[CrossValidatorModel].getName + + override def load(path: String): CrossValidatorModel = { + implicit val format = DefaultFormats + + val (metadata, estimator, evaluator, estimatorParamMaps, numFolds) = + SharedReadWrite.load(path, sc, className) + val bestModelPath = new Path(path, "bestModel").toString + val bestModel = DefaultParamsReader.loadParamsInstance[Model[_]](bestModelPath, sc) + val avgMetrics = (metadata.metadata \ "avgMetrics").extract[Seq[Double]].toArray + val cv = new CrossValidatorModel(metadata.uid, bestModel, avgMetrics) + cv.set(cv.estimator, estimator) + .set(cv.evaluator, evaluator) + .set(cv.estimatorParamMaps, estimatorParamMaps) + .set(cv.numFolds, numFolds) + } + } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala index ff9322dba122..8484b1f80106 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala @@ -202,25 +202,36 @@ private[ml] object DefaultParamsWriter { * - timestamp * - sparkVersion * - uid - * - paramMap: These must be encodable using [[org.apache.spark.ml.param.Param.jsonEncode()]]. + * - paramMap + * - (optionally, extra metadata) + * @param extraMetadata Extra metadata to be saved at same level as uid, paramMap, etc. + * @param paramMap If given, this is saved in the "paramMap" field. + * Otherwise, all [[org.apache.spark.ml.param.Param]]s are encoded using + * [[org.apache.spark.ml.param.Param.jsonEncode()]]. 
*/ def saveMetadata( instance: Params, path: String, sc: SparkContext, - extraMetadata: Option[JValue] = None): Unit = { + extraMetadata: Option[JObject] = None, + paramMap: Option[JValue] = None): Unit = { val uid = instance.uid val cls = instance.getClass.getName val params = instance.extractParamMap().toSeq.asInstanceOf[Seq[ParamPair[Any]]] - val jsonParams = params.map { case ParamPair(p, v) => + val jsonParams = paramMap.getOrElse(render(params.map { case ParamPair(p, v) => p.name -> parse(p.jsonEncode(v)) - }.toList - val metadata = ("class" -> cls) ~ + }.toList)) + val basicMetadata = ("class" -> cls) ~ ("timestamp" -> System.currentTimeMillis()) ~ ("sparkVersion" -> sc.version) ~ ("uid" -> uid) ~ - ("paramMap" -> jsonParams) ~ - ("extraMetadata" -> extraMetadata) + ("paramMap" -> jsonParams) + val metadata = extraMetadata match { + case Some(jObject) => + basicMetadata ~ jObject + case None => + basicMetadata + } val metadataPath = new Path(path, "metadata").toString val metadataJson = compact(render(metadata)) sc.parallelize(Seq(metadataJson), 1).saveAsTextFile(metadataPath) @@ -251,8 +262,8 @@ private[ml] object DefaultParamsReader { /** * All info from metadata file. * @param params paramMap, as a [[JValue]] - * @param extraMetadata Extra metadata saved by [[DefaultParamsWriter.saveMetadata()]] - * @param metadataStr Full metadata file String (for debugging) + * @param metadata All metadata, including the other fields + * @param metadataJson Full metadata file String (for debugging) */ case class Metadata( className: String, @@ -260,8 +271,8 @@ private[ml] object DefaultParamsReader { timestamp: Long, sparkVersion: String, params: JValue, - extraMetadata: Option[JValue], - metadataStr: String) + metadata: JValue, + metadataJson: String) /** * Load metadata from file. @@ -279,13 +290,12 @@ private[ml] object DefaultParamsReader { val timestamp = (metadata \ "timestamp").extract[Long] val sparkVersion = (metadata \ "sparkVersion").extract[String] val params = metadata \ "paramMap" - val extraMetadata = (metadata \ "extraMetadata").extract[Option[JValue]] if (expectedClassName.nonEmpty) { require(className == expectedClassName, s"Error loading metadata: Expected class name" + s" $expectedClassName but found class name $className") } - Metadata(className, uid, timestamp, sparkVersion, params, extraMetadata, metadataStr) + Metadata(className, uid, timestamp, sparkVersion, params, metadata, metadataStr) } /** @@ -303,7 +313,17 @@ private[ml] object DefaultParamsReader { } case _ => throw new IllegalArgumentException( - s"Cannot recognize JSON metadata: ${metadata.metadataStr}.") + s"Cannot recognize JSON metadata: ${metadata.metadataJson}.") } } + + /** + * Load a [[Params]] instance from the given path, and return it. + * This assumes the instance implements [[MLReadable]]. 
+ */ + def loadParamsInstance[T](path: String, sc: SparkContext): T = { + val metadata = DefaultParamsReader.loadMetadata(path, sc) + val cls = Utils.classForName(metadata.className) + cls.getMethod("read").invoke(null).asInstanceOf[MLReader[T]].load(path) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala index 12aba6bc6dbe..8c8676745636 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala @@ -17,11 +17,9 @@ package org.apache.spark.ml -import java.io.File - import scala.collection.JavaConverters._ -import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.hadoop.fs.Path import org.mockito.Matchers.{any, eq => meq} import org.mockito.Mockito.when import org.scalatest.mock.MockitoSugar.mock diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala index def869fe6677..a535c1218ecf 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala @@ -19,10 +19,21 @@ package org.apache.spark.ml.evaluation import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.mllib.util.MLlibTestSparkContext -class BinaryClassificationEvaluatorSuite extends SparkFunSuite { +class BinaryClassificationEvaluatorSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { test("params") { ParamsSuite.checkParams(new BinaryClassificationEvaluator) } + + test("read/write") { + val evaluator = new BinaryClassificationEvaluator() + .setRawPredictionCol("myRawPrediction") + .setLabelCol("myLabel") + .setMetricName("areaUnderPR") + testDefaultReadWrite(evaluator) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala index 6d8412b0b370..7ee65975d22f 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala @@ -19,10 +19,21 @@ package org.apache.spark.ml.evaluation import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.mllib.util.MLlibTestSparkContext -class MulticlassClassificationEvaluatorSuite extends SparkFunSuite { +class MulticlassClassificationEvaluatorSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { test("params") { ParamsSuite.checkParams(new MulticlassClassificationEvaluator) } + + test("read/write") { + val evaluator = new MulticlassClassificationEvaluator() + .setPredictionCol("myPrediction") + .setLabelCol("myLabel") + .setMetricName("recall") + testDefaultReadWrite(evaluator) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala index aa722da32393..60886bf77d2f 100644 --- 
a/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala @@ -20,10 +20,12 @@ package org.apache.spark.ml.evaluation import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.regression.LinearRegression +import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext} import org.apache.spark.mllib.util.TestingUtils._ -class RegressionEvaluatorSuite extends SparkFunSuite with MLlibTestSparkContext { +class RegressionEvaluatorSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { test("params") { ParamsSuite.checkParams(new RegressionEvaluator) @@ -73,4 +75,12 @@ class RegressionEvaluatorSuite extends SparkFunSuite with MLlibTestSparkContext evaluator.setMetricName("mae") assert(evaluator.evaluate(predictions) ~== 0.08036075 absTol 0.001) } + + test("read/write") { + val evaluator = new RegressionEvaluator() + .setPredictionCol("myPrediction") + .setLabelCol("myLabel") + .setMetricName("r2") + testDefaultReadWrite(evaluator) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala index cbe09292a033..dd6366050c02 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala @@ -18,19 +18,22 @@ package org.apache.spark.ml.tuning import org.apache.spark.SparkFunSuite -import org.apache.spark.ml.util.MLTestingUtils -import org.apache.spark.ml.{Estimator, Model} -import org.apache.spark.ml.classification.LogisticRegression +import org.apache.spark.ml.feature.HashingTF +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.{Pipeline, Estimator, Model} +import org.apache.spark.ml.classification.{LogisticRegressionModel, LogisticRegression} import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator, RegressionEvaluator} -import org.apache.spark.ml.param.ParamMap +import org.apache.spark.ml.param.{ParamPair, ParamMap} import org.apache.spark.ml.param.shared.HasInputCol import org.apache.spark.ml.regression.LinearRegression import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput +import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext} import org.apache.spark.sql.{DataFrame, SQLContext} import org.apache.spark.sql.types.StructType -class CrossValidatorSuite extends SparkFunSuite with MLlibTestSparkContext { +class CrossValidatorSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { @transient var dataset: DataFrame = _ @@ -95,7 +98,7 @@ class CrossValidatorSuite extends SparkFunSuite with MLlibTestSparkContext { } test("validateParams should check estimatorParamMaps") { - import CrossValidatorSuite._ + import CrossValidatorSuite.{MyEstimator, MyEvaluator} val est = new MyEstimator("est") val eval = new MyEvaluator @@ -116,9 +119,194 @@ class CrossValidatorSuite extends SparkFunSuite with MLlibTestSparkContext { cv.validateParams() } } + + test("read/write: CrossValidator with simple estimator") { + val lr = new LogisticRegression().setMaxIter(3) + val evaluator = new BinaryClassificationEvaluator() + .setMetricName("areaUnderPR") 
// not default metric + val paramMaps = new ParamGridBuilder() + .addGrid(lr.regParam, Array(0.1, 0.2)) + .build() + val cv = new CrossValidator() + .setEstimator(lr) + .setEvaluator(evaluator) + .setNumFolds(20) + .setEstimatorParamMaps(paramMaps) + + val cv2 = testDefaultReadWrite(cv, testParams = false) + + assert(cv.uid === cv2.uid) + assert(cv.getNumFolds === cv2.getNumFolds) + + assert(cv2.getEvaluator.isInstanceOf[BinaryClassificationEvaluator]) + val evaluator2 = cv2.getEvaluator.asInstanceOf[BinaryClassificationEvaluator] + assert(evaluator.uid === evaluator2.uid) + assert(evaluator.getMetricName === evaluator2.getMetricName) + + cv2.getEstimator match { + case lr2: LogisticRegression => + assert(lr.uid === lr2.uid) + assert(lr.getMaxIter === lr2.getMaxIter) + case other => + throw new AssertionError(s"Loaded CrossValidator expected estimator of type" + + s" LogisticRegression but found ${other.getClass.getName}") + } + + CrossValidatorSuite.compareParamMaps(cv.getEstimatorParamMaps, cv2.getEstimatorParamMaps) + } + + test("read/write: CrossValidator with complex estimator") { + // workflow: CrossValidator[Pipeline[HashingTF, CrossValidator[LogisticRegression]]] + val lrEvaluator = new BinaryClassificationEvaluator() + .setMetricName("areaUnderPR") // not default metric + + val lr = new LogisticRegression().setMaxIter(3) + val lrParamMaps = new ParamGridBuilder() + .addGrid(lr.regParam, Array(0.1, 0.2)) + .build() + val lrcv = new CrossValidator() + .setEstimator(lr) + .setEvaluator(lrEvaluator) + .setEstimatorParamMaps(lrParamMaps) + + val hashingTF = new HashingTF() + val pipeline = new Pipeline().setStages(Array(hashingTF, lrcv)) + val paramMaps = new ParamGridBuilder() + .addGrid(hashingTF.numFeatures, Array(10, 20)) + .addGrid(lr.elasticNetParam, Array(0.0, 1.0)) + .build() + val evaluator = new BinaryClassificationEvaluator() + + val cv = new CrossValidator() + .setEstimator(pipeline) + .setEvaluator(evaluator) + .setNumFolds(20) + .setEstimatorParamMaps(paramMaps) + + val cv2 = testDefaultReadWrite(cv, testParams = false) + + assert(cv.uid === cv2.uid) + assert(cv.getNumFolds === cv2.getNumFolds) + + assert(cv2.getEvaluator.isInstanceOf[BinaryClassificationEvaluator]) + assert(cv.getEvaluator.uid === cv2.getEvaluator.uid) + + CrossValidatorSuite.compareParamMaps(cv.getEstimatorParamMaps, cv2.getEstimatorParamMaps) + + cv2.getEstimator match { + case pipeline2: Pipeline => + assert(pipeline.uid === pipeline2.uid) + pipeline2.getStages match { + case Array(hashingTF2: HashingTF, lrcv2: CrossValidator) => + assert(hashingTF.uid === hashingTF2.uid) + lrcv2.getEstimator match { + case lr2: LogisticRegression => + assert(lr.uid === lr2.uid) + assert(lr.getMaxIter === lr2.getMaxIter) + case other => + throw new AssertionError(s"Loaded internal CrossValidator expected to be" + + s" LogisticRegression but found type ${other.getClass.getName}") + } + assert(lrcv.uid === lrcv2.uid) + assert(lrcv2.getEvaluator.isInstanceOf[BinaryClassificationEvaluator]) + assert(lrEvaluator.uid === lrcv2.getEvaluator.uid) + CrossValidatorSuite.compareParamMaps(lrParamMaps, lrcv2.getEstimatorParamMaps) + case other => + throw new AssertionError("Loaded Pipeline expected stages (HashingTF, CrossValidator)" + + " but found: " + other.map(_.getClass.getName).mkString(", ")) + } + case other => + throw new AssertionError(s"Loaded CrossValidator expected estimator of type" + + s" CrossValidator but found ${other.getClass.getName}") + } + } + + test("read/write: CrossValidator fails for extraneous Param") { 
+ val lr = new LogisticRegression() + val lr2 = new LogisticRegression() + val evaluator = new BinaryClassificationEvaluator() + val paramMaps = new ParamGridBuilder() + .addGrid(lr.regParam, Array(0.1, 0.2)) + .addGrid(lr2.regParam, Array(0.1, 0.2)) + .build() + val cv = new CrossValidator() + .setEstimator(lr) + .setEvaluator(evaluator) + .setEstimatorParamMaps(paramMaps) + withClue("CrossValidator.write failed to catch extraneous Param error") { + intercept[IllegalArgumentException] { + cv.write + } + } + } + + test("read/write: CrossValidatorModel") { + val lr = new LogisticRegression() + .setThreshold(0.6) + val lrModel = new LogisticRegressionModel(lr.uid, Vectors.dense(1.0, 2.0), 1.2) + .setThreshold(0.6) + val evaluator = new BinaryClassificationEvaluator() + .setMetricName("areaUnderPR") // not default metric + val paramMaps = new ParamGridBuilder() + .addGrid(lr.regParam, Array(0.1, 0.2)) + .build() + val cv = new CrossValidatorModel("cvUid", lrModel, Array(0.3, 0.6)) + cv.set(cv.estimator, lr) + .set(cv.evaluator, evaluator) + .set(cv.numFolds, 20) + .set(cv.estimatorParamMaps, paramMaps) + + val cv2 = testDefaultReadWrite(cv, testParams = false) + + assert(cv.uid === cv2.uid) + assert(cv.getNumFolds === cv2.getNumFolds) + + assert(cv2.getEvaluator.isInstanceOf[BinaryClassificationEvaluator]) + val evaluator2 = cv2.getEvaluator.asInstanceOf[BinaryClassificationEvaluator] + assert(evaluator.uid === evaluator2.uid) + assert(evaluator.getMetricName === evaluator2.getMetricName) + + cv2.getEstimator match { + case lr2: LogisticRegression => + assert(lr.uid === lr2.uid) + assert(lr.getThreshold === lr2.getThreshold) + case other => + throw new AssertionError(s"Loaded CrossValidator expected estimator of type" + + s" LogisticRegression but found ${other.getClass.getName}") + } + + CrossValidatorSuite.compareParamMaps(cv.getEstimatorParamMaps, cv2.getEstimatorParamMaps) + + cv2.bestModel match { + case lrModel2: LogisticRegressionModel => + assert(lrModel.uid === lrModel2.uid) + assert(lrModel.getThreshold === lrModel2.getThreshold) + assert(lrModel.coefficients === lrModel2.coefficients) + assert(lrModel.intercept === lrModel2.intercept) + case other => + throw new AssertionError(s"Loaded CrossValidator expected bestModel of type" + + s" LogisticRegressionModel but found ${other.getClass.getName}") + } + assert(cv.avgMetrics === cv2.avgMetrics) + } } -object CrossValidatorSuite { +object CrossValidatorSuite extends SparkFunSuite { + + /** + * Assert sequences of estimatorParamMaps are identical. + * Params must be simple types comparable with `===`. + */ + def compareParamMaps(pMaps: Array[ParamMap], pMaps2: Array[ParamMap]): Unit = { + assert(pMaps.length === pMaps2.length) + pMaps.zip(pMaps2).foreach { case (pMap, pMap2) => + assert(pMap.size === pMap2.size) + pMap.toSeq.foreach { case ParamPair(p, v) => + assert(pMap2.contains(p)) + assert(pMap2(p) === v) + } + } + } abstract class MyModel extends Model[MyModel] From fc4b792d287095d70379a51f117c225d8d857078 Mon Sep 17 00:00:00 2001 From: Timothy Hunter Date: Sun, 22 Nov 2015 21:51:42 -0800 Subject: [PATCH 1497/1824] [SPARK-11835] Adds a sidebar menu to MLlib's documentation This PR adds a sidebar menu when browsing the user guide of MLlib. It uses a YAML file to describe the structure of the documentation. It should be trivial to adapt this to the other projects. 
![screen shot 2015-11-18 at 4 46 12 pm](https://cloud.githubusercontent.com/assets/7594753/11259591/a55173f4-8e17-11e5-9340-0aed79d66262.png) Author: Timothy Hunter Closes #9826 from thunterdb/spark-11835. --- docs/_data/menu-ml.yaml | 10 ++++ docs/_data/menu-mllib.yaml | 75 +++++++++++++++++++++++++ docs/_includes/nav-left-wrapper-ml.html | 8 +++ docs/_includes/nav-left.html | 17 ++++++ docs/_layouts/global.html | 24 +++++--- docs/css/main.css | 37 ++++++++++++ 6 files changed, 163 insertions(+), 8 deletions(-) create mode 100644 docs/_data/menu-ml.yaml create mode 100644 docs/_data/menu-mllib.yaml create mode 100644 docs/_includes/nav-left-wrapper-ml.html create mode 100644 docs/_includes/nav-left.html diff --git a/docs/_data/menu-ml.yaml b/docs/_data/menu-ml.yaml new file mode 100644 index 000000000000..dff3d33bf4ed --- /dev/null +++ b/docs/_data/menu-ml.yaml @@ -0,0 +1,10 @@ +- text: Feature extraction, transformation, and selection + url: ml-features.html +- text: Decision trees for classification and regression + url: ml-decision-tree.html +- text: Ensembles + url: ml-ensembles.html +- text: Linear methods with elastic-net regularization + url: ml-linear-methods.html +- text: Multilayer perceptron classifier + url: ml-ann.html diff --git a/docs/_data/menu-mllib.yaml b/docs/_data/menu-mllib.yaml new file mode 100644 index 000000000000..12d22abd5282 --- /dev/null +++ b/docs/_data/menu-mllib.yaml @@ -0,0 +1,75 @@ +- text: Data types + url: mllib-data-types.html +- text: Basic statistics + url: mllib-statistics.html + subitems: + - text: Summary statistics + url: mllib-statistics.html#summary-statistics + - text: Correlations + url: mllib-statistics.html#correlations + - text: Stratified sampling + url: mllib-statistics.html#stratified-sampling + - text: Hypothesis testing + url: mllib-statistics.html#hypothesis-testing + - text: Random data generation + url: mllib-statistics.html#random-data-generation +- text: Classification and regression + url: mllib-classification-regression.html + subitems: + - text: Linear models (SVMs, logistic regression, linear regression) + url: mllib-linear-methods.html + - text: Naive Bayes + url: mllib-naive-bayes.html + - text: decision trees + url: mllib-decision-tree.html + - text: ensembles of trees (Random Forests and Gradient-Boosted Trees) + url: mllib-ensembles.html + - text: isotonic regression + url: mllib-isotonic-regression.html +- text: Collaborative filtering + url: mllib-collaborative-filtering.html + subitems: + - text: alternating least squares (ALS) + url: mllib-collaborative-filtering.html#collaborative-filtering +- text: Clustering + url: mllib-clustering.html + subitems: + - text: k-means + url: mllib-clustering.html#k-means + - text: Gaussian mixture + url: mllib-clustering.html#gaussian-mixture + - text: power iteration clustering (PIC) + url: mllib-clustering.html#power-iteration-clustering-pic + - text: latent Dirichlet allocation (LDA) + url: mllib-clustering.html#latent-dirichlet-allocation-lda + - text: streaming k-means + url: mllib-clustering.html#streaming-k-means +- text: Dimensionality reduction + url: mllib-dimensionality-reduction.html + subitems: + - text: singular value decomposition (SVD) + url: mllib-dimensionality-reduction.html#singular-value-decomposition-svd + - text: principal component analysis (PCA) + url: mllib-dimensionality-reduction.html#principal-component-analysis-pca +- text: Feature extraction and transformation + url: mllib-feature-extraction.html +- text: Frequent pattern mining + url: 
mllib-frequent-pattern-mining.html + subitems: + - text: FP-growth + url: mllib-frequent-pattern-mining.html#fp-growth + - text: association rules + url: mllib-frequent-pattern-mining.html#association-rules + - text: PrefixSpan + url: mllib-frequent-pattern-mining.html#prefix-span +- text: Evaluation metrics + url: mllib-evaluation-metrics.html +- text: PMML model export + url: mllib-pmml-model-export.html +- text: Optimization (developer) + url: mllib-optimization.html + subitems: + - text: stochastic gradient descent + url: mllib-optimization.html#stochastic-gradient-descent-sgd + - text: limited-memory BFGS (L-BFGS) + url: mllib-optimization.html#limited-memory-bfgs-l-bfgs diff --git a/docs/_includes/nav-left-wrapper-ml.html b/docs/_includes/nav-left-wrapper-ml.html new file mode 100644 index 000000000000..0103e890cc21 --- /dev/null +++ b/docs/_includes/nav-left-wrapper-ml.html @@ -0,0 +1,8 @@ +
<div class="left-menu-wrapper">
+    <div class="left-menu">
+        <h3>spark.ml package</h3>
+        {% include nav-left.html nav=include.nav-ml %}
+        <h3>spark.mllib package</h3>
+        {% include nav-left.html nav=include.nav-mllib %}
+    </div>
+</div>
\ No newline at end of file
diff --git a/docs/_includes/nav-left.html b/docs/_includes/nav-left.html
new file mode 100644
index 000000000000..73176f413255
--- /dev/null
+++ b/docs/_includes/nav-left.html
@@ -0,0 +1,17 @@
+{% assign navurl = page.url | remove: 'index.html' %}
+[the remaining 16 added lines, the <ul>/<li> markup that renders each entry of include.nav and its subitems, were lost to HTML extraction]
diff --git a/docs/_layouts/global.html b/docs/_layouts/global.html
index 467ff7a03fb7..1b09e2221e17 100755
--- a/docs/_layouts/global.html
+++ b/docs/_layouts/global.html
@@ -124,16 +124,24 @@
[the surrounding <div>/<h1> markup in this hunk was lost to HTML extraction; the surviving template lines are:]
-          {% if page.displayTitle %}
-            {{ page.displayTitle }}
-          {% else %}
-            {{ page.title }}
-          {% endif %}
-
-          {{ content }}
-
+          {% if page.url contains "/ml" %}
+            {% include nav-left-wrapper-ml.html nav-mllib=site.data.menu-mllib nav-ml=site.data.menu-ml %}
+          {% endif %}
+
+            {% if page.displayTitle %}
+              {{ page.displayTitle }}
+            {% else %}
+              {{ page.title }}
+            {% endif %}
+
+            {{ content }}
+
    diff --git a/docs/css/main.css b/docs/css/main.css index d770173be101..356b324d6303 100755 --- a/docs/css/main.css +++ b/docs/css/main.css @@ -39,8 +39,18 @@ margin-left: 10px; } +body .container-wrapper { + position: absolute; + width: 100%; + display: flex; +} + body #content { + position: relative; + line-height: 1.6; /* Inspired by Github's wiki style */ + background-color: white; + padding-left: 15px; } .title { @@ -155,3 +165,30 @@ ul.nav li.dropdown ul.dropdown-menu li.dropdown-submenu ul.dropdown-menu { * AnchorJS (anchor links when hovering over headers) */ a.anchorjs-link:hover { text-decoration: none; } + + +/** + * The left navigation bar. + */ +.left-menu-wrapper { + position: absolute; + height: 100%; + + width: 256px; + margin-top: -20px; + padding-top: 20px; + background-color: #F0F8FC; +} + +.left-menu { + position: fixed; + max-width: 350px; + + padding-right: 10px; + width: 256px; +} + +.left-menu h3 { + margin-left: 10px; + line-height: 30px; +} \ No newline at end of file From d9cf9c21fc6b1aa22e68d66760afd42c4e1c18b8 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Sun, 22 Nov 2015 21:56:07 -0800 Subject: [PATCH 1498/1824] [SPARK-11912][ML] ml.feature.PCA minor refactor Like [SPARK-11852](https://issues.apache.org/jira/browse/SPARK-11852), ```k``` is params and we should save it under ```metadata/``` rather than both under ```data/``` and ```metadata/```. Refactor the constructor of ```ml.feature.PCAModel``` to take only ```pc``` but construct ```mllib.feature.PCAModel``` inside ```transform```. Author: Yanbo Liang Closes #9897 from yanboliang/spark-11912. --- .../org/apache/spark/ml/feature/PCA.scala | 23 +++++++------- .../apache/spark/ml/feature/PCASuite.scala | 31 ++++++++----------- 2 files changed, 24 insertions(+), 30 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala index 32d7afee6e73..aa88cb03d23c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala @@ -73,7 +73,7 @@ class PCA (override val uid: String) extends Estimator[PCAModel] with PCAParams val input = dataset.select($(inputCol)).map { case Row(v: Vector) => v} val pca = new feature.PCA(k = $(k)) val pcaModel = pca.fit(input) - copyValues(new PCAModel(uid, pcaModel).setParent(this)) + copyValues(new PCAModel(uid, pcaModel.pc).setParent(this)) } override def transformSchema(schema: StructType): StructType = { @@ -99,18 +99,17 @@ object PCA extends DefaultParamsReadable[PCA] { /** * :: Experimental :: * Model fitted by [[PCA]]. + * + * @param pc A principal components Matrix. Each column is one principal component. */ @Experimental class PCAModel private[ml] ( override val uid: String, - pcaModel: feature.PCAModel) + val pc: DenseMatrix) extends Model[PCAModel] with PCAParams with MLWritable { import PCAModel._ - /** a principal components Matrix. Each column is one principal component. 
*/ - val pc: DenseMatrix = pcaModel.pc - /** @group setParam */ def setInputCol(value: String): this.type = set(inputCol, value) @@ -124,6 +123,7 @@ class PCAModel private[ml] ( */ override def transform(dataset: DataFrame): DataFrame = { transformSchema(dataset.schema, logging = true) + val pcaModel = new feature.PCAModel($(k), pc) val pcaOp = udf { pcaModel.transform _ } dataset.withColumn($(outputCol), pcaOp(col($(inputCol)))) } @@ -139,7 +139,7 @@ class PCAModel private[ml] ( } override def copy(extra: ParamMap): PCAModel = { - val copied = new PCAModel(uid, pcaModel) + val copied = new PCAModel(uid, pc) copyValues(copied, extra).setParent(parent) } @@ -152,11 +152,11 @@ object PCAModel extends MLReadable[PCAModel] { private[PCAModel] class PCAModelWriter(instance: PCAModel) extends MLWriter { - private case class Data(k: Int, pc: DenseMatrix) + private case class Data(pc: DenseMatrix) override protected def saveImpl(path: String): Unit = { DefaultParamsWriter.saveMetadata(instance, path, sc) - val data = Data(instance.getK, instance.pc) + val data = Data(instance.pc) val dataPath = new Path(path, "data").toString sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) } @@ -169,11 +169,10 @@ object PCAModel extends MLReadable[PCAModel] { override def load(path: String): PCAModel = { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val dataPath = new Path(path, "data").toString - val Row(k: Int, pc: DenseMatrix) = sqlContext.read.parquet(dataPath) - .select("k", "pc") + val Row(pc: DenseMatrix) = sqlContext.read.parquet(dataPath) + .select("pc") .head() - val oldModel = new feature.PCAModel(k, pc) - val model = new PCAModel(metadata.uid, oldModel) + val model = new PCAModel(metadata.uid, pc) DefaultParamsReader.getAndSetParams(model, metadata) model } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala index 5a21cd20ceed..edab21e6c307 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala @@ -32,7 +32,7 @@ class PCASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead test("params") { ParamsSuite.checkParams(new PCA) val mat = Matrices.dense(2, 2, Array(0.0, 1.0, 2.0, 3.0)).asInstanceOf[DenseMatrix] - val model = new PCAModel("pca", new OldPCAModel(2, mat)) + val model = new PCAModel("pca", mat) ParamsSuite.checkParams(model) } @@ -66,23 +66,18 @@ class PCASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead } } - test("read/write") { + test("PCA read/write") { + val t = new PCA() + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + .setK(3) + testDefaultReadWrite(t) + } - def checkModelData(model1: PCAModel, model2: PCAModel): Unit = { - assert(model1.pc === model2.pc) - } - val allParams: Map[String, Any] = Map( - "k" -> 3, - "inputCol" -> "features", - "outputCol" -> "pca_features" - ) - val data = Seq( - (0.0, Vectors.sparse(5, Seq((1, 1.0), (3, 7.0)))), - (1.0, Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0)), - (2.0, Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0)) - ) - val df = sqlContext.createDataFrame(data).toDF("id", "features") - val pca = new PCA().setK(3) - testEstimatorAndModelReadWrite(pca, df, allParams, checkModelData) + test("PCAModel read/write") { + val instance = new PCAModel("myPCAModel", + Matrices.dense(2, 2, Array(0.0, 1.0, 2.0, 3.0)).asInstanceOf[DenseMatrix]) + val newInstance = testDefaultReadWrite(instance) 
+ assert(newInstance.pc === instance.pc) } } From 4be360d4ee6cdb4d06306feca38ddef5212608cf Mon Sep 17 00:00:00 2001 From: BenFradet Date: Sun, 22 Nov 2015 22:05:01 -0800 Subject: [PATCH 1499/1824] [SPARK-11902][ML] Unhandled case in VectorAssembler#transform There is an unhandled case in the transform method of VectorAssembler if one of the input columns doesn't have one of the supported type DoubleType, NumericType, BooleanType or VectorUDT. So, if you try to transform a column of StringType you get a cryptic "scala.MatchError: StringType". This PR aims to fix this, throwing a SparkException when dealing with an unknown column type. Author: BenFradet Closes #9885 from BenFradet/SPARK-11902. --- .../org/apache/spark/ml/feature/VectorAssembler.scala | 2 ++ .../spark/ml/feature/VectorAssemblerSuite.scala | 11 +++++++++++ 2 files changed, 13 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala index 0feec0549852..801096fed27b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala @@ -84,6 +84,8 @@ class VectorAssembler(override val uid: String) val numAttrs = group.numAttributes.getOrElse(first.getAs[Vector](index).size) Array.fill(numAttrs)(NumericAttribute.defaultAttr) } + case otherType => + throw new SparkException(s"VectorAssembler does not support the $otherType type") } } val metadata = new AttributeGroup($(outputCol), attrs).toMetadata() diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala index fb21ab6b9bf2..9c1c00f41ab1 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala @@ -69,6 +69,17 @@ class VectorAssemblerSuite } } + test("transform should throw an exception in case of unsupported type") { + val df = sqlContext.createDataFrame(Seq(("a", "b", "c"))).toDF("a", "b", "c") + val assembler = new VectorAssembler() + .setInputCols(Array("a", "b", "c")) + .setOutputCol("features") + val thrown = intercept[SparkException] { + assembler.transform(df) + } + assert(thrown.getMessage contains "VectorAssembler does not support the StringType type") + } + test("ML attributes") { val browser = NominalAttribute.defaultAttr.withValues("chrome", "firefox", "safari") val hour = NumericAttribute.defaultAttr.withMin(0.0).withMax(24.0) From 94ce65dfcbba1fe3a1fc9d8002c37d9cd1a11336 Mon Sep 17 00:00:00 2001 From: Xiu Guo Date: Mon, 23 Nov 2015 08:53:40 -0800 Subject: [PATCH 1500/1824] [SPARK-11628][SQL] support column datatype of char(x) to recognize HiveChar Can someone review my code to make sure I'm not missing anything? Thanks! Author: Xiu Guo Author: Xiu Guo Closes #9612 from xguo27/SPARK-11628. 
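Before the diff, a minimal sketch of what this parser change enables. Note that `DataTypeParser` is a `private[sql]` utility, so the snippet is illustrative and compiles only inside Spark's own sql packages, the same way the suite below exercises it:

```scala
import org.apache.spark.sql.catalyst.util.DataTypeParser
import org.apache.spark.sql.types.StringType

// char(n) column definitions are now parsed like varchar(n): both surface to
// Catalyst as StringType, and the Hive inspectors unwrap HiveChar values instead
// of hitting an unhandled case.
assert(DataTypeParser.parse("char(18)") == StringType)
assert(DataTypeParser.parse("ChaR(5)") == StringType)   // the parser is case-insensitive
assert(DataTypeParser.parse("varchar(12)") == StringType)
```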
--- .../sql/catalyst/util/DataTypeParser.scala | 6 ++++- .../catalyst/util/DataTypeParserSuite.scala | 8 ++++-- .../spark/sql/sources/TableScanSuite.scala | 5 ++++ .../spark/sql/hive/HiveInspectors.scala | 25 ++++++++++++++++--- .../apache/spark/sql/hive/TableReader.scala | 3 +++ .../spark/sql/hive/client/HiveShim.scala | 3 ++- 6 files changed, 43 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DataTypeParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DataTypeParser.scala index 2b83651f9086..515c071c283b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DataTypeParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DataTypeParser.scala @@ -52,7 +52,8 @@ private[sql] trait DataTypeParser extends StandardTokenParsers { "(?i)decimal".r ^^^ DecimalType.USER_DEFAULT | "(?i)date".r ^^^ DateType | "(?i)timestamp".r ^^^ TimestampType | - varchar + varchar | + char protected lazy val fixedDecimalType: Parser[DataType] = ("(?i)decimal".r ~> "(" ~> numericLit) ~ ("," ~> numericLit <~ ")") ^^ { @@ -60,6 +61,9 @@ private[sql] trait DataTypeParser extends StandardTokenParsers { DecimalType(precision.toInt, scale.toInt) } + protected lazy val char: Parser[DataType] = + "(?i)char".r ~> "(" ~> (numericLit <~ ")") ^^^ StringType + protected lazy val varchar: Parser[DataType] = "(?i)varchar".r ~> "(" ~> (numericLit <~ ")") ^^^ StringType diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DataTypeParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DataTypeParserSuite.scala index 1e3409a9db6e..bebf70896547 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DataTypeParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DataTypeParserSuite.scala @@ -49,7 +49,9 @@ class DataTypeParserSuite extends SparkFunSuite { checkDataType("DATE", DateType) checkDataType("timestamp", TimestampType) checkDataType("string", StringType) + checkDataType("ChaR(5)", StringType) checkDataType("varchAr(20)", StringType) + checkDataType("cHaR(27)", StringType) checkDataType("BINARY", BinaryType) checkDataType("array", ArrayType(DoubleType, true)) @@ -83,7 +85,8 @@ class DataTypeParserSuite extends SparkFunSuite { |struct< | struct:struct, | MAP:Map, - | arrAy:Array> + | arrAy:Array, + | anotherArray:Array> """.stripMargin, StructType( StructField("struct", @@ -91,7 +94,8 @@ class DataTypeParserSuite extends SparkFunSuite { StructField("deciMal", DecimalType.USER_DEFAULT, true) :: StructField("anotherDecimal", DecimalType(5, 2), true) :: Nil), true) :: StructField("MAP", MapType(TimestampType, StringType), true) :: - StructField("arrAy", ArrayType(DoubleType, true), true) :: Nil) + StructField("arrAy", ArrayType(DoubleType, true), true) :: + StructField("anotherArray", ArrayType(StringType, true), true) :: Nil) ) // A column name can be a reserved word in our DDL parser and SqlParser. 
checkDataType( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala index 12af8068c398..26c1ff520406 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala @@ -85,6 +85,7 @@ case class AllDataTypesScan( Date.valueOf("1970-01-01"), new Timestamp(20000 + i), s"varchar_$i", + s"char_$i", Seq(i, i + 1), Seq(Map(s"str_$i" -> Row(i.toLong))), Map(i -> i.toString), @@ -115,6 +116,7 @@ class TableScanSuite extends DataSourceTest with SharedSQLContext { Date.valueOf("1970-01-01"), new Timestamp(20000 + i), s"varchar_$i", + s"char_$i", Seq(i, i + 1), Seq(Map(s"str_$i" -> Row(i.toLong))), Map(i -> i.toString), @@ -154,6 +156,7 @@ class TableScanSuite extends DataSourceTest with SharedSQLContext { |dateField dAte, |timestampField tiMestamp, |varcharField varchaR(12), + |charField ChaR(18), |arrayFieldSimple Array, |arrayFieldComplex Array>>, |mapFieldSimple MAP, @@ -207,6 +210,7 @@ class TableScanSuite extends DataSourceTest with SharedSQLContext { StructField("dateField", DateType, true) :: StructField("timestampField", TimestampType, true) :: StructField("varcharField", StringType, true) :: + StructField("charField", StringType, true) :: StructField("arrayFieldSimple", ArrayType(IntegerType), true) :: StructField("arrayFieldComplex", ArrayType( @@ -248,6 +252,7 @@ class TableScanSuite extends DataSourceTest with SharedSQLContext { | dateField, | timestampField, | varcharField, + | charField, | arrayFieldSimple, | arrayFieldComplex, | mapFieldSimple, diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index 36f0708f9da3..95b57d6ad124 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.hive import scala.collection.JavaConverters._ -import org.apache.hadoop.hive.common.`type`.{HiveDecimal, HiveVarchar} +import org.apache.hadoop.hive.common.`type`.{HiveChar, HiveDecimal, HiveVarchar} import org.apache.hadoop.hive.serde2.objectinspector.primitive._ import org.apache.hadoop.hive.serde2.objectinspector.{StructField => HiveStructField, _} import org.apache.hadoop.hive.serde2.typeinfo.{DecimalTypeInfo, TypeInfoFactory} @@ -61,6 +61,7 @@ import org.apache.spark.unsafe.types.UTF8String * Primitive Type * Java Boxed Primitives: * org.apache.hadoop.hive.common.type.HiveVarchar + * org.apache.hadoop.hive.common.type.HiveChar * java.lang.String * java.lang.Integer * java.lang.Boolean @@ -75,6 +76,7 @@ import org.apache.spark.unsafe.types.UTF8String * java.sql.Timestamp * Writables: * org.apache.hadoop.hive.serde2.io.HiveVarcharWritable + * org.apache.hadoop.hive.serde2.io.HiveCharWritable * org.apache.hadoop.io.Text * org.apache.hadoop.io.IntWritable * org.apache.hadoop.hive.serde2.io.DoubleWritable @@ -93,7 +95,8 @@ import org.apache.spark.unsafe.types.UTF8String * Struct: Object[] / java.util.List / java POJO * Union: class StandardUnion { byte tag; Object object } * - * NOTICE: HiveVarchar is not supported by catalyst, it will be simply considered as String type. + * NOTICE: HiveVarchar/HiveChar is not supported by catalyst, it will be simply considered as + * String type. * * * 2. 
Hive ObjectInspector is a group of flexible APIs to inspect value in different data @@ -137,6 +140,7 @@ import org.apache.spark.unsafe.types.UTF8String * Primitive Object Inspectors: * WritableConstantStringObjectInspector * WritableConstantHiveVarcharObjectInspector + * WritableConstantHiveCharObjectInspector * WritableConstantHiveDecimalObjectInspector * WritableConstantTimestampObjectInspector * WritableConstantIntObjectInspector @@ -259,6 +263,8 @@ private[hive] trait HiveInspectors { UTF8String.fromString(poi.getWritableConstantValue.toString) case poi: WritableConstantHiveVarcharObjectInspector => UTF8String.fromString(poi.getWritableConstantValue.getHiveVarchar.getValue) + case poi: WritableConstantHiveCharObjectInspector => + UTF8String.fromString(poi.getWritableConstantValue.getHiveChar.getValue) case poi: WritableConstantHiveDecimalObjectInspector => HiveShim.toCatalystDecimal( PrimitiveObjectInspectorFactory.javaHiveDecimalObjectInspector, @@ -303,11 +309,15 @@ private[hive] trait HiveInspectors { case _ if data == null => null case poi: VoidObjectInspector => null // always be null for void object inspector case pi: PrimitiveObjectInspector => pi match { - // We think HiveVarchar is also a String + // We think HiveVarchar/HiveChar is also a String case hvoi: HiveVarcharObjectInspector if hvoi.preferWritable() => UTF8String.fromString(hvoi.getPrimitiveWritableObject(data).getHiveVarchar.getValue) case hvoi: HiveVarcharObjectInspector => UTF8String.fromString(hvoi.getPrimitiveJavaObject(data).getValue) + case hvoi: HiveCharObjectInspector if hvoi.preferWritable() => + UTF8String.fromString(hvoi.getPrimitiveWritableObject(data).getHiveChar.getValue) + case hvoi: HiveCharObjectInspector => + UTF8String.fromString(hvoi.getPrimitiveJavaObject(data).getValue) case x: StringObjectInspector if x.preferWritable() => UTF8String.fromString(x.getPrimitiveWritableObject(data).toString) case x: StringObjectInspector => @@ -377,6 +387,15 @@ private[hive] trait HiveInspectors { null } + case _: JavaHiveCharObjectInspector => + (o: Any) => + if (o != null) { + val s = o.asInstanceOf[UTF8String].toString + new HiveChar(s, s.size) + } else { + null + } + case _: JavaHiveDecimalObjectInspector => (o: Any) => if (o != null) { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala index 69f481c49a65..70ee02823eeb 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala @@ -382,6 +382,9 @@ private[hive] object HadoopTableReader extends HiveInspectors with Logging { case oi: HiveVarcharObjectInspector => (value: Any, row: MutableRow, ordinal: Int) => row.update(ordinal, UTF8String.fromString(oi.getPrimitiveJavaObject(value).getValue)) + case oi: HiveCharObjectInspector => + (value: Any, row: MutableRow, ordinal: Int) => + row.update(ordinal, UTF8String.fromString(oi.getPrimitiveJavaObject(value).getValue)) case oi: HiveDecimalObjectInspector => (value: Any, row: MutableRow, ordinal: Int) => row.update(ordinal, HiveShim.toCatalystDecimal(oi, value)) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala index 48bbb21e6c1d..346840079b85 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala @@ -321,7 
+321,8 @@ private[client] class Shim_v0_13 extends Shim_v0_12 { def convertFilters(table: Table, filters: Seq[Expression]): String = { // hive varchar is treated as catalyst string, but hive varchar can't be pushed down. val varcharKeys = table.getPartitionKeys.asScala - .filter(col => col.getType.startsWith(serdeConstants.VARCHAR_TYPE_NAME)) + .filter(col => col.getType.startsWith(serdeConstants.VARCHAR_TYPE_NAME) || + col.getType.startsWith(serdeConstants.CHAR_TYPE_NAME)) .map(col => col.getName).toSet filters.collect { From 1a5baaa6517872b9a4fd6cd41c4b2cf1e390f6d1 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 23 Nov 2015 10:13:59 -0800 Subject: [PATCH 1501/1824] [SPARK-11894][SQL] fix isNull for GetInternalRowField We should use `InternalRow.isNullAt` to check if the field is null before calling `InternalRow.getXXX` Thanks gatorsmile who discovered this bug. Author: Wenchen Fan Closes #9904 from cloud-fan/null. --- .../sql/catalyst/expressions/objects.scala | 23 ++++++++----------- .../org/apache/spark/sql/DatasetSuite.scala | 15 +++++++++++- 2 files changed, 23 insertions(+), 15 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala index 82317d338516..4a1f419f0ad8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala @@ -236,11 +236,6 @@ case class NewInstance( } if (propagateNull) { - val objNullCheck = if (ctx.defaultValue(dataType) == "null") { - s"${ev.isNull} = ${ev.value} == null;" - } else { - "" - } val argsNonNull = s"!(${argGen.map(_.isNull).mkString(" || ")})" s""" @@ -531,15 +526,15 @@ case class GetInternalRowField(child: Expression, ordinal: Int, dataType: DataTy throw new UnsupportedOperationException("Only code-generated evaluation is supported") override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val row = child.gen(ctx) - s""" - ${row.code} - final boolean ${ev.isNull} = ${row.isNull}; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - if (!${ev.isNull}) { - ${ev.value} = ${ctx.getValue(row.value, dataType, ordinal.toString)}; - } - """ + nullSafeCodeGen(ctx, ev, eval => { + s""" + if ($eval.isNullAt($ordinal)) { + ${ev.isNull} = true; + } else { + ${ev.value} = ${ctx.getValue(eval, dataType, ordinal.toString)}; + } + """ + }) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 9da02550b39c..cc8e4325fd2f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -386,7 +386,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { Seq((JavaData(1), 1L), (JavaData(2), 1L))) } - ignore("Java encoder self join") { + test("Java encoder self join") { implicit val kryoEncoder = Encoders.javaSerialization[JavaData] val ds = Seq(JavaData(1), JavaData(2)).toDS() assert(ds.joinWith(ds, lit(true)).collect().toSet == @@ -396,6 +396,19 @@ class DatasetSuite extends QueryTest with SharedSQLContext { (JavaData(2), JavaData(1)), (JavaData(2), JavaData(2)))) } + + test("SPARK-11894: Incorrect results are returned when using null") { + val nullInt = null.asInstanceOf[java.lang.Integer] + val ds1 = Seq((nullInt, "1"), (new java.lang.Integer(22), 
"2")).toDS() + val ds2 = Seq((nullInt, "1"), (new java.lang.Integer(22), "2")).toDS() + + checkAnswer( + ds1.joinWith(ds2, lit(true)), + ((nullInt, "1"), (nullInt, "1")), + ((new java.lang.Integer(22), "2"), (nullInt, "1")), + ((nullInt, "1"), (new java.lang.Integer(22), "2")), + ((new java.lang.Integer(22), "2"), (new java.lang.Integer(22), "2"))) + } } From f2996e0d12eeb989b1bfa51a3f6fa54ce1ed4fca Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 23 Nov 2015 10:15:40 -0800 Subject: [PATCH 1502/1824] [SPARK-11921][SQL] fix `nullable` of encoder schema Author: Wenchen Fan Closes #9906 from cloud-fan/nullable. --- .../catalyst/encoders/ExpressionEncoder.scala | 15 +++++++- .../encoders/ExpressionEncoderSuite.scala | 38 ++++++++++++++++++- 2 files changed, 50 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index 6eeba1442c1f..7bc9aed0b204 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -54,8 +54,13 @@ object ExpressionEncoder { val toRowExpression = ScalaReflection.extractorsFor[T](inputObject) val fromRowExpression = ScalaReflection.constructorFor[T] + val schema = ScalaReflection.schemaFor[T] match { + case ScalaReflection.Schema(s: StructType, _) => s + case ScalaReflection.Schema(dt, nullable) => new StructType().add("value", dt, nullable) + } + new ExpressionEncoder[T]( - toRowExpression.dataType, + schema, flat, toRowExpression.flatten, fromRowExpression, @@ -71,7 +76,13 @@ object ExpressionEncoder { encoders.foreach(_.assertUnresolved()) val schema = StructType(encoders.zipWithIndex.map { - case (e, i) => StructField(s"_${i + 1}", if (e.flat) e.schema.head.dataType else e.schema) + case (e, i) => + val (dataType, nullable) = if (e.flat) { + e.schema.head.dataType -> e.schema.head.nullable + } else { + e.schema -> true + } + StructField(s"_${i + 1}", dataType, nullable) }) val cls = Utils.getContextOrSparkClassLoader.loadClass(s"scala.Tuple${encoders.size}") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala index 76459b34a484..d6ca138672ef 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.Encoders import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.catalyst.util.ArrayData import org.apache.spark.sql.catalyst.{OptionalData, PrimitiveData} -import org.apache.spark.sql.types.ArrayType +import org.apache.spark.sql.types.{StructType, ArrayType} case class RepeatedStruct(s: Seq[PrimitiveData]) @@ -238,6 +238,42 @@ class ExpressionEncoderSuite extends SparkFunSuite { ExpressionEncoder.tuple(intEnc, ExpressionEncoder.tuple(intEnc, longEnc)) } + test("nullable of encoder schema") { + def checkNullable[T: ExpressionEncoder](nullable: Boolean*): Unit = { + assert(implicitly[ExpressionEncoder[T]].schema.map(_.nullable) === nullable.toSeq) + } + + // test for flat encoders + checkNullable[Int](false) + checkNullable[Option[Int]](true) + 
checkNullable[java.lang.Integer](true) + checkNullable[String](true) + + // test for product encoders + checkNullable[(String, Int)](true, false) + checkNullable[(Int, java.lang.Long)](false, true) + + // test for nested product encoders + { + val schema = ExpressionEncoder[(Int, (String, Int))].schema + assert(schema(0).nullable === false) + assert(schema(1).nullable === true) + assert(schema(1).dataType.asInstanceOf[StructType](0).nullable === true) + assert(schema(1).dataType.asInstanceOf[StructType](1).nullable === false) + } + + // test for tupled encoders + { + val schema = ExpressionEncoder.tuple( + ExpressionEncoder[Int], + ExpressionEncoder[(String, Int)]).schema + assert(schema(0).nullable === false) + assert(schema(1).nullable === true) + assert(schema(1).dataType.asInstanceOf[StructType](0).nullable === true) + assert(schema(1).dataType.asInstanceOf[StructType](1).nullable === false) + } + } + private val outers: ConcurrentMap[String, AnyRef] = new MapMaker().weakValues().makeMap() outers.put(getClass.getName, this) private def encodeDecodeTest[T : ExpressionEncoder]( From 946b406519af58c79041217e6f93854b6cf80acd Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 23 Nov 2015 10:39:33 -0800 Subject: [PATCH 1503/1824] [SPARK-11913][SQL] support typed aggregate with complex buffer schema Author: Wenchen Fan Closes #9898 from cloud-fan/agg. --- .../aggregate/TypedAggregateExpression.scala | 25 +++++++---- .../spark/sql/DatasetAggregatorSuite.scala | 41 ++++++++++++++++++- 2 files changed, 56 insertions(+), 10 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala index 6ce41aaf01e2..a9719128a626 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala @@ -23,9 +23,8 @@ import org.apache.spark.Logging import org.apache.spark.sql.Encoder import org.apache.spark.sql.expressions.Aggregator import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.encoders.encoderFor +import org.apache.spark.sql.catalyst.encoders.{OuterScopes, encoderFor, ExpressionEncoder} import org.apache.spark.sql.catalyst.expressions.aggregate.ImperativeAggregate -import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ @@ -46,14 +45,12 @@ object TypedAggregateExpression { /** * This class is a rough sketch of how to hook `Aggregator` into the Aggregation system. It has * the following limitations: - * - It assumes the aggregator reduces and returns a single column of type `long`. - * - It might only work when there is a single aggregator in the first column. * - It assumes the aggregator has a zero, `0`. */ case class TypedAggregateExpression( aggregator: Aggregator[Any, Any, Any], aEncoder: Option[ExpressionEncoder[Any]], // Should be bound. - bEncoder: ExpressionEncoder[Any], // Should be bound. 
+ unresolvedBEncoder: ExpressionEncoder[Any], cEncoder: ExpressionEncoder[Any], children: Seq[Attribute], mutableAggBufferOffset: Int, @@ -80,10 +77,14 @@ case class TypedAggregateExpression( override lazy val inputTypes: Seq[DataType] = Nil - override val aggBufferSchema: StructType = bEncoder.schema + override val aggBufferSchema: StructType = unresolvedBEncoder.schema override val aggBufferAttributes: Seq[AttributeReference] = aggBufferSchema.toAttributes + val bEncoder = unresolvedBEncoder + .resolve(aggBufferAttributes, OuterScopes.outerScopes) + .bind(aggBufferAttributes) + // Note: although this simply copies aggBufferAttributes, this common code can not be placed // in the superclass because that will lead to initialization ordering issues. override val inputAggBufferAttributes: Seq[AttributeReference] = @@ -93,12 +94,18 @@ case class TypedAggregateExpression( lazy val boundA = aEncoder.get private def updateBuffer(buffer: MutableRow, value: InternalRow): Unit = { - // todo: need a more neat way to assign the value. var i = 0 while (i < aggBufferAttributes.length) { + val offset = mutableAggBufferOffset + i aggBufferSchema(i).dataType match { - case IntegerType => buffer.setInt(mutableAggBufferOffset + i, value.getInt(i)) - case LongType => buffer.setLong(mutableAggBufferOffset + i, value.getLong(i)) + case BooleanType => buffer.setBoolean(offset, value.getBoolean(i)) + case ByteType => buffer.setByte(offset, value.getByte(i)) + case ShortType => buffer.setShort(offset, value.getShort(i)) + case IntegerType => buffer.setInt(offset, value.getInt(i)) + case LongType => buffer.setLong(offset, value.getLong(i)) + case FloatType => buffer.setFloat(offset, value.getFloat(i)) + case DoubleType => buffer.setDouble(offset, value.getDouble(i)) + case other => buffer.update(offset, value.get(i, other)) } i += 1 } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala index 937758979001..19dce5d1e2f3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala @@ -67,7 +67,7 @@ object ComplexResultAgg extends Aggregator[(String, Int), (Long, Long), (Long, L } case class AggData(a: Int, b: String) -object ClassInputAgg extends Aggregator[AggData, Int, Int] with Serializable { +object ClassInputAgg extends Aggregator[AggData, Int, Int] { /** A zero value for this aggregation. Should satisfy the property that any b + zero = b */ override def zero: Int = 0 @@ -88,6 +88,28 @@ object ClassInputAgg extends Aggregator[AggData, Int, Int] with Serializable { override def merge(b1: Int, b2: Int): Int = b1 + b2 } +object ComplexBufferAgg extends Aggregator[AggData, (Int, AggData), Int] { + /** A zero value for this aggregation. Should satisfy the property that any b + zero = b */ + override def zero: (Int, AggData) = 0 -> AggData(0, "0") + + /** + * Combine two values to produce a new value. For performance, the function may modify `b` and + * return it instead of constructing new object for b. + */ + override def reduce(b: (Int, AggData), a: AggData): (Int, AggData) = (b._1 + 1, a) + + /** + * Transform the output of the reduction. 
+ */ + override def finish(reduction: (Int, AggData)): Int = reduction._1 + + /** + * Merge two intermediate values + */ + override def merge(b1: (Int, AggData), b2: (Int, AggData)): (Int, AggData) = + (b1._1 + b2._1, b1._2) +} + class DatasetAggregatorSuite extends QueryTest with SharedSQLContext { import testImplicits._ @@ -168,4 +190,21 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext { ds.groupBy(_.b).agg(ClassInputAgg.toColumn), ("one", 1)) } + + test("typed aggregation: complex input") { + val ds = Seq(AggData(1, "one"), AggData(2, "two")).toDS() + + checkAnswer( + ds.select(ComplexBufferAgg.toColumn), + 2 + ) + + checkAnswer( + ds.select(expr("avg(a)").as[Double], ComplexBufferAgg.toColumn), + (1.5, 2)) + + checkAnswer( + ds.groupBy(_.b).agg(ComplexBufferAgg.toColumn), + ("one", 1), ("two", 1)) + } } From 5fd86e4fc2e06d2403ca538ae417580c93b69e06 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Mon, 23 Nov 2015 10:41:17 -0800 Subject: [PATCH 1504/1824] [SPARK-7173][YARN] Add label expression support for application master Add label expression support for AM to restrict it runs on the specific set of nodes. I tested it locally and works fine. sryza and vanzin please help to review, thanks a lot. Author: jerryshao Closes #9800 from jerryshao/SPARK-7173. --- docs/running-on-yarn.md | 9 +++++++ .../org/apache/spark/deploy/yarn/Client.scala | 26 ++++++++++++++++++- 2 files changed, 34 insertions(+), 1 deletion(-) diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index db6bfa69ee0f..925a1e0ba6fc 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -326,6 +326,15 @@ If you need a reference to the proper location to put log files in the YARN so t Otherwise, the client process will exit after submission. + + spark.yarn.am.nodeLabelExpression + (none) + + A YARN node label expression that restricts the set of nodes AM will be scheduled on. + Only versions of YARN greater than or equal to 2.6 support node label expressions, so when + running against earlier versions, this property will be ignored. 
+ + spark.yarn.executor.nodeLabelExpression (none) diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index ba799884f568..a77a3e2420e2 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -225,7 +225,31 @@ private[spark] class Client( val capability = Records.newRecord(classOf[Resource]) capability.setMemory(args.amMemory + amMemoryOverhead) capability.setVirtualCores(args.amCores) - appContext.setResource(capability) + + if (sparkConf.contains("spark.yarn.am.nodeLabelExpression")) { + try { + val amRequest = Records.newRecord(classOf[ResourceRequest]) + amRequest.setResourceName(ResourceRequest.ANY) + amRequest.setPriority(Priority.newInstance(0)) + amRequest.setCapability(capability) + amRequest.setNumContainers(1) + val amLabelExpression = sparkConf.get("spark.yarn.am.nodeLabelExpression") + val method = amRequest.getClass.getMethod("setNodeLabelExpression", classOf[String]) + method.invoke(amRequest, amLabelExpression) + + val setResourceRequestMethod = + appContext.getClass.getMethod("setAMContainerResourceRequest", classOf[ResourceRequest]) + setResourceRequestMethod.invoke(appContext, amRequest) + } catch { + case e: NoSuchMethodException => + logWarning("Ignoring spark.yarn.am.nodeLabelExpression because the version " + + "of YARN does not support it") + appContext.setResource(capability) + } + } else { + appContext.setResource(capability) + } + appContext } From 5231cd5acaae69d735ba3209531705cc222f3cfb Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Mon, 23 Nov 2015 10:45:23 -0800 Subject: [PATCH 1505/1824] [SPARK-11762][NETWORK] Account for active streams when couting outstanding requests. This way the timeout handling code can correctly close "hung" channels that are processing streams. Author: Marcelo Vanzin Closes #9747 from vanzin/SPARK-11762. 
--- .../network/client/StreamInterceptor.java | 12 ++++++++- .../client/TransportResponseHandler.java | 15 +++++++++-- .../TransportResponseHandlerSuite.java | 27 +++++++++++++++++++ 3 files changed, 51 insertions(+), 3 deletions(-) diff --git a/network/common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java b/network/common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java index 02230a00e69f..88ba3ccebdf2 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java +++ b/network/common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java @@ -30,13 +30,19 @@ */ class StreamInterceptor implements TransportFrameDecoder.Interceptor { + private final TransportResponseHandler handler; private final String streamId; private final long byteCount; private final StreamCallback callback; private volatile long bytesRead; - StreamInterceptor(String streamId, long byteCount, StreamCallback callback) { + StreamInterceptor( + TransportResponseHandler handler, + String streamId, + long byteCount, + StreamCallback callback) { + this.handler = handler; this.streamId = streamId; this.byteCount = byteCount; this.callback = callback; @@ -45,11 +51,13 @@ class StreamInterceptor implements TransportFrameDecoder.Interceptor { @Override public void exceptionCaught(Throwable cause) throws Exception { + handler.deactivateStream(); callback.onFailure(streamId, cause); } @Override public void channelInactive() throws Exception { + handler.deactivateStream(); callback.onFailure(streamId, new ClosedChannelException()); } @@ -65,8 +73,10 @@ public boolean handle(ByteBuf buf) throws Exception { RuntimeException re = new IllegalStateException(String.format( "Read too many bytes? Expected %d, but read %d.", byteCount, bytesRead)); callback.onFailure(streamId, re); + handler.deactivateStream(); throw re; } else if (bytesRead == byteCount) { + handler.deactivateStream(); callback.onComplete(streamId); } diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java b/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java index ed3f36af5804..cc88991b588c 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java @@ -24,6 +24,7 @@ import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.atomic.AtomicLong; +import com.google.common.annotations.VisibleForTesting; import io.netty.channel.Channel; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -56,6 +57,7 @@ public class TransportResponseHandler extends MessageHandler { private final Map outstandingRpcs; private final Queue streamCallbacks; + private volatile boolean streamActive; /** Records the time (in system nanoseconds) that the last fetch or RPC request was sent. */ private final AtomicLong timeOfLastRequestNs; @@ -87,9 +89,15 @@ public void removeRpcRequest(long requestId) { } public void addStreamCallback(StreamCallback callback) { + timeOfLastRequestNs.set(System.nanoTime()); streamCallbacks.offer(callback); } + @VisibleForTesting + public void deactivateStream() { + streamActive = false; + } + /** * Fire the failure callback for all outstanding requests. This is called when we have an * uncaught exception or pre-mature connection termination. 
@@ -177,14 +185,16 @@ public void handle(ResponseMessage message) { StreamResponse resp = (StreamResponse) message; StreamCallback callback = streamCallbacks.poll(); if (callback != null) { - StreamInterceptor interceptor = new StreamInterceptor(resp.streamId, resp.byteCount, + StreamInterceptor interceptor = new StreamInterceptor(this, resp.streamId, resp.byteCount, callback); try { TransportFrameDecoder frameDecoder = (TransportFrameDecoder) channel.pipeline().get(TransportFrameDecoder.HANDLER_NAME); frameDecoder.setInterceptor(interceptor); + streamActive = true; } catch (Exception e) { logger.error("Error installing stream handler.", e); + deactivateStream(); } } else { logger.error("Could not find callback for StreamResponse."); @@ -208,7 +218,8 @@ public void handle(ResponseMessage message) { /** Returns total number of outstanding requests (fetch requests + rpcs) */ public int numOutstandingRequests() { - return outstandingFetches.size() + outstandingRpcs.size(); + return outstandingFetches.size() + outstandingRpcs.size() + streamCallbacks.size() + + (streamActive ? 1 : 0); } /** Returns the time in nanoseconds of when the last request was sent out. */ diff --git a/network/common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java b/network/common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java index 17a03ebe88a9..30144f4a9fc7 100644 --- a/network/common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java @@ -17,6 +17,7 @@ package org.apache.spark.network; +import io.netty.channel.Channel; import io.netty.channel.local.LocalChannel; import org.junit.Test; @@ -28,12 +29,16 @@ import org.apache.spark.network.buffer.ManagedBuffer; import org.apache.spark.network.client.ChunkReceivedCallback; import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.client.StreamCallback; import org.apache.spark.network.client.TransportResponseHandler; import org.apache.spark.network.protocol.ChunkFetchFailure; import org.apache.spark.network.protocol.ChunkFetchSuccess; import org.apache.spark.network.protocol.RpcFailure; import org.apache.spark.network.protocol.RpcResponse; import org.apache.spark.network.protocol.StreamChunkId; +import org.apache.spark.network.protocol.StreamFailure; +import org.apache.spark.network.protocol.StreamResponse; +import org.apache.spark.network.util.TransportFrameDecoder; public class TransportResponseHandlerSuite { @Test @@ -112,4 +117,26 @@ public void handleFailedRPC() { verify(callback, times(1)).onFailure((Throwable) any()); assertEquals(0, handler.numOutstandingRequests()); } + + @Test + public void testActiveStreams() { + Channel c = new LocalChannel(); + c.pipeline().addLast(TransportFrameDecoder.HANDLER_NAME, new TransportFrameDecoder()); + TransportResponseHandler handler = new TransportResponseHandler(c); + + StreamResponse response = new StreamResponse("stream", 1234L, null); + StreamCallback cb = mock(StreamCallback.class); + handler.addStreamCallback(cb); + assertEquals(1, handler.numOutstandingRequests()); + handler.handle(response); + assertEquals(1, handler.numOutstandingRequests()); + handler.deactivateStream(); + assertEquals(0, handler.numOutstandingRequests()); + + StreamFailure failure = new StreamFailure("stream", "uh-oh"); + handler.addStreamCallback(cb); + assertEquals(1, handler.numOutstandingRequests()); + handler.handle(failure); + assertEquals(0, 
handler.numOutstandingRequests()); + } } From 98d7ec7df4bb115dbd84cb9acd744b6c8abfebd5 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Mon, 23 Nov 2015 11:51:29 -0800 Subject: [PATCH 1506/1824] [SPARK-11920][ML][DOC] ML LinearRegression should use correct dataset in examples and user guide doc ML ```LinearRegression``` use ```data/mllib/sample_libsvm_data.txt``` as dataset in examples and user guide doc, but it's actually classification dataset rather than regression dataset. We should use ```data/mllib/sample_linear_regression_data.txt``` instead. The deeper causes is that ```LinearRegression``` with "normal" solver can not solve this dataset correctly, may be due to the ill condition and unreasonable label. This issue has been reported at [SPARK-11918](https://issues.apache.org/jira/browse/SPARK-11918). It will confuse users if they run the example code but get exception, so we should make this change which can clearly illustrate the usage of ```LinearRegression``` algorithm. Author: Yanbo Liang Closes #9905 from yanboliang/spark-11920. --- .../examples/ml/JavaLinearRegressionWithElasticNetExample.java | 2 +- .../src/main/python/ml/linear_regression_with_elastic_net.py | 3 ++- .../examples/ml/LinearRegressionWithElasticNetExample.scala | 3 ++- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaLinearRegressionWithElasticNetExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaLinearRegressionWithElasticNetExample.java index 593f8fb3e9fe..4ad7676c8d32 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaLinearRegressionWithElasticNetExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaLinearRegressionWithElasticNetExample.java @@ -37,7 +37,7 @@ public static void main(String[] args) { // $example on$ // Load training data DataFrame training = sqlContext.read().format("libsvm") - .load("data/mllib/sample_libsvm_data.txt"); + .load("data/mllib/sample_linear_regression_data.txt"); LinearRegression lr = new LinearRegression() .setMaxIter(10) diff --git a/examples/src/main/python/ml/linear_regression_with_elastic_net.py b/examples/src/main/python/ml/linear_regression_with_elastic_net.py index b0278276330c..a4cd40cf2672 100644 --- a/examples/src/main/python/ml/linear_regression_with_elastic_net.py +++ b/examples/src/main/python/ml/linear_regression_with_elastic_net.py @@ -29,7 +29,8 @@ # $example on$ # Load training data - training = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + training = sqlContext.read.format("libsvm")\ + .load("data/mllib/sample_linear_regression_data.txt") lr = LinearRegression(maxIter=10, regParam=0.3, elasticNetParam=0.8) diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionWithElasticNetExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionWithElasticNetExample.scala index 5a51ece6f9ba..22c824cea84d 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionWithElasticNetExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionWithElasticNetExample.scala @@ -33,7 +33,8 @@ object LinearRegressionWithElasticNetExample { // $example on$ // Load training data - val training = sqlCtx.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + val training = sqlCtx.read.format("libsvm") + .load("data/mllib/sample_linear_regression_data.txt") val lr = new LinearRegression() .setMaxIter(10) From 
f6dcc6e96ad3f88563d717d5b6c45394b44db747 Mon Sep 17 00:00:00 2001 From: Mortada Mehyar Date: Mon, 23 Nov 2015 12:03:15 -0800 Subject: [PATCH 1507/1824] [SPARK-11837][EC2] python3 compatibility for launching ec2 m3 instances this currently breaks for python3 because `string` module doesn't have `letters` anymore, instead `ascii_letters` should be used Author: Mortada Mehyar Closes #9797 from mortada/python3_fix. --- ec2/spark_ec2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py index 9327e21e43db..9fd652a3df4c 100755 --- a/ec2/spark_ec2.py +++ b/ec2/spark_ec2.py @@ -595,7 +595,7 @@ def launch_cluster(conn, opts, cluster_name): dev = BlockDeviceType() dev.ephemeral_name = 'ephemeral%d' % i # The first ephemeral drive is /dev/sdb. - name = '/dev/sd' + string.letters[i + 1] + name = '/dev/sd' + string.ascii_letters[i + 1] block_map[name] = dev # Launch slaves From 1b6e938be836786bac542fa430580248161e5403 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 23 Nov 2015 13:19:10 -0800 Subject: [PATCH 1508/1824] [SPARK-4424] Remove spark.driver.allowMultipleContexts override in tests This patch removes `spark.driver.allowMultipleContexts=true` from our test configuration. The multiple SparkContexts check was originally disabled because certain tests suites in SQL needed to create multiple contexts. As far as I know, this configuration change is no longer necessary, so we should remove it in order to make it easier to find test cleanup bugs. Author: Josh Rosen Closes #9865 from JoshRosen/SPARK-4424. --- pom.xml | 2 -- project/SparkBuild.scala | 1 - 2 files changed, 3 deletions(-) diff --git a/pom.xml b/pom.xml index ad849112ce76..234fd5dea1a6 100644 --- a/pom.xml +++ b/pom.xml @@ -1958,7 +1958,6 @@ false false false - true true src @@ -1997,7 +1996,6 @@ 1 false false - true true __not_used__ diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 67724c4e9e41..f575f0012d59 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -632,7 +632,6 @@ object TestSettings { javaOptions in Test += "-Dspark.master.rest.enabled=false", javaOptions in Test += "-Dspark.ui.enabled=false", javaOptions in Test += "-Dspark.ui.showConsoleProgress=false", - javaOptions in Test += "-Dspark.driver.allowMultipleContexts=true", javaOptions in Test += "-Dspark.unsafe.exceptionOnMemoryLeak=true", javaOptions in Test += "-Dsun.io.serialization.extendedDebugInfo=true", javaOptions in Test += "-Dderby.system.durability=test", From 1d9120201012213edb1971a09e0849336dbb9415 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 23 Nov 2015 13:44:30 -0800 Subject: [PATCH 1509/1824] [SPARK-11836][SQL] udf/cast should not create new SQLContext They should use the existing SQLContext. Author: Davies Liu Closes #9914 from davies/create_udf. 
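The underlying idiom is to hand back the SQLContext already associated with the SparkContext instead of constructing a fresh one, so SQL configuration and registered temp tables remain shared. The sketch below shows that pattern with the Scala API (the patch itself changes the Python wrappers); ContextReuseSketch and the local-mode settings are illustrative assumptions only.

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext

object ContextReuseSketch {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setMaster("local[2]").setAppName("context-reuse")
    val sc = SparkContext.getOrCreate(conf)

    // Reuse the context that already exists for this SparkContext; repeated calls
    // return the same instance, so conf values and temp tables stay visible.
    val ctx1 = SQLContext.getOrCreate(sc)
    val ctx2 = SQLContext.getOrCreate(sc)
    assert(ctx1 eq ctx2)

    // Constructing `new SQLContext(sc)` here instead -- roughly what the Python
    // wrappers were doing via sc._jvm.SQLContext(...) -- would yield a separate
    // context whose SQL conf and temp-table registry are not shared with ctx1.

    sc.stop()
  }
}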
--- python/pyspark/sql/column.py | 7 ++++--- python/pyspark/sql/functions.py | 7 ++++--- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py index 9ca8e1f264cf..81fd4e782628 100644 --- a/python/pyspark/sql/column.py +++ b/python/pyspark/sql/column.py @@ -346,9 +346,10 @@ def cast(self, dataType): if isinstance(dataType, basestring): jc = self._jc.cast(dataType) elif isinstance(dataType, DataType): - sc = SparkContext._active_spark_context - ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc()) - jdt = ssql_ctx.parseDataType(dataType.json()) + from pyspark.sql import SQLContext + sc = SparkContext.getOrCreate() + ctx = SQLContext.getOrCreate(sc) + jdt = ctx._ssql_ctx.parseDataType(dataType.json()) jc = self._jc.cast(jdt) else: raise TypeError("unexpected type: %s" % type(dataType)) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index c3da513c1389..a1ca723bbd7a 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1457,14 +1457,15 @@ def __init__(self, func, returnType, name=None): self._judf = self._create_judf(name) def _create_judf(self, name): + from pyspark.sql import SQLContext f, returnType = self.func, self.returnType # put them in closure `func` func = lambda _, it: map(lambda x: returnType.toInternal(f(*x)), it) ser = AutoBatchedSerializer(PickleSerializer()) command = (func, None, ser, ser) - sc = SparkContext._active_spark_context + sc = SparkContext.getOrCreate() pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command, self) - ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc()) - jdt = ssql_ctx.parseDataType(self.returnType.json()) + ctx = SQLContext.getOrCreate(sc) + jdt = ctx._ssql_ctx.parseDataType(self.returnType.json()) if name is None: name = f.__name__ if hasattr(f, '__name__') else f.__class__.__name__ judf = sc._jvm.UserDefinedPythonFunction(name, bytearray(pickled_command), env, includes, From 242be7daed9b01d19794bb2cf1ac421fe5ab7262 Mon Sep 17 00:00:00 2001 From: Luciano Resende Date: Mon, 23 Nov 2015 13:46:34 -0800 Subject: [PATCH 1510/1824] [SPARK-11910][STREAMING][DOCS] Update twitter4j dependency version Author: Luciano Resende Closes #9892 from lresende/SPARK-11910. --- docs/streaming-programming-guide.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index 96b36b7a7320..ed6b28c28213 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -723,7 +723,7 @@ Some of these advanced sources are as follows. - **Kinesis:** Spark Streaming {{site.SPARK_VERSION_SHORT}} is compatible with Kinesis Client Library 1.2.1. See the [Kinesis Integration Guide](streaming-kinesis-integration.html) for more details. -- **Twitter:** Spark Streaming's TwitterUtils uses Twitter4j 3.0.3 to get the public stream of tweets using +- **Twitter:** Spark Streaming's TwitterUtils uses Twitter4j to get the public stream of tweets using [Twitter's Streaming API](https://dev.twitter.com/docs/streaming-apis). Authentication information can be provided by any of the [methods](http://twitter4j.org/en/configuration.html) supported by Twitter4J library. 
You can either get the public stream, or get the filtered stream based on a From 7cfa4c6bc36d97e459d4adee7b03d537d63c337e Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Mon, 23 Nov 2015 13:51:43 -0800 Subject: [PATCH 1511/1824] [SPARK-11865][NETWORK] Avoid returning inactive client in TransportClientFactory. There's a very narrow race here where it would be possible for the timeout handler to close a channel after the client factory verified that the channel was still active. This change makes sure the client is marked as being recently in use so that the timeout handler does not close it until a new timeout cycle elapses. Author: Marcelo Vanzin Closes #9853 from vanzin/SPARK-11865. --- .../spark/network/client/TransportClient.java | 9 ++++- .../client/TransportClientFactory.java | 15 ++++++-- .../client/TransportResponseHandler.java | 9 +++-- .../server/TransportChannelHandler.java | 36 ++++++++++++------- 4 files changed, 52 insertions(+), 17 deletions(-) diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java index a0ba223e340a..876fcd846791 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java @@ -73,10 +73,12 @@ public class TransportClient implements Closeable { private final Channel channel; private final TransportResponseHandler handler; @Nullable private String clientId; + private volatile boolean timedOut; public TransportClient(Channel channel, TransportResponseHandler handler) { this.channel = Preconditions.checkNotNull(channel); this.handler = Preconditions.checkNotNull(handler); + this.timedOut = false; } public Channel getChannel() { @@ -84,7 +86,7 @@ public Channel getChannel() { } public boolean isActive() { - return channel.isOpen() || channel.isActive(); + return !timedOut && (channel.isOpen() || channel.isActive()); } public SocketAddress getSocketAddress() { @@ -263,6 +265,11 @@ public void onFailure(Throwable e) { } } + /** Mark this channel as having timed out. */ + public void timeOut() { + this.timedOut = true; + } + @Override public void close() { // close is a local operation and should finish with milliseconds; timeout just to be safe diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java index 42a4f664e697..659c47160c7b 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java @@ -136,8 +136,19 @@ public TransportClient createClient(String remoteHost, int remotePort) throws IO TransportClient cachedClient = clientPool.clients[clientIndex]; if (cachedClient != null && cachedClient.isActive()) { - logger.trace("Returning cached connection to {}: {}", address, cachedClient); - return cachedClient; + // Make sure that the channel will not timeout by updating the last use time of the + // handler. Then check that the client is still alive, in case it timed out before + // this code was able to update things. 
+ TransportChannelHandler handler = cachedClient.getChannel().pipeline() + .get(TransportChannelHandler.class); + synchronized (handler) { + handler.getResponseHandler().updateTimeOfLastRequest(); + } + + if (cachedClient.isActive()) { + logger.trace("Returning cached connection to {}: {}", address, cachedClient); + return cachedClient; + } } // If we reach here, we don't have an existing connection open. Let's create a new one. diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java b/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java index cc88991b588c..be181e066082 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java @@ -71,7 +71,7 @@ public TransportResponseHandler(Channel channel) { } public void addFetchRequest(StreamChunkId streamChunkId, ChunkReceivedCallback callback) { - timeOfLastRequestNs.set(System.nanoTime()); + updateTimeOfLastRequest(); outstandingFetches.put(streamChunkId, callback); } @@ -80,7 +80,7 @@ public void removeFetchRequest(StreamChunkId streamChunkId) { } public void addRpcRequest(long requestId, RpcResponseCallback callback) { - timeOfLastRequestNs.set(System.nanoTime()); + updateTimeOfLastRequest(); outstandingRpcs.put(requestId, callback); } @@ -227,4 +227,9 @@ public long getTimeOfLastRequestNs() { return timeOfLastRequestNs.get(); } + /** Updates the time of the last request to the current system time. */ + public void updateTimeOfLastRequest() { + timeOfLastRequestNs.set(System.nanoTime()); + } + } diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java b/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java index f8fcd1c3d7d7..29d688a67578 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java @@ -116,20 +116,32 @@ public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exc // there are outstanding requests, we also do a secondary consistency check to ensure // there's no race between the idle timeout and incrementing the numOutstandingRequests // (see SPARK-7003). - boolean isActuallyOverdue = - System.nanoTime() - responseHandler.getTimeOfLastRequestNs() > requestTimeoutNs; - if (e.state() == IdleState.ALL_IDLE && isActuallyOverdue) { - if (responseHandler.numOutstandingRequests() > 0) { - String address = NettyUtils.getRemoteAddress(ctx.channel()); - logger.error("Connection to {} has been quiet for {} ms while there are outstanding " + - "requests. Assuming connection is dead; please adjust spark.network.timeout if this " + - "is wrong.", address, requestTimeoutNs / 1000 / 1000); - ctx.close(); - } else if (closeIdleConnections) { - // While CloseIdleConnections is enable, we also close idle connection - ctx.close(); + // + // To avoid a race between TransportClientFactory.createClient() and this code which could + // result in an inactive client being returned, this needs to run in a synchronized block. 
+ synchronized (this) { + boolean isActuallyOverdue = + System.nanoTime() - responseHandler.getTimeOfLastRequestNs() > requestTimeoutNs; + if (e.state() == IdleState.ALL_IDLE && isActuallyOverdue) { + if (responseHandler.numOutstandingRequests() > 0) { + String address = NettyUtils.getRemoteAddress(ctx.channel()); + logger.error("Connection to {} has been quiet for {} ms while there are outstanding " + + "requests. Assuming connection is dead; please adjust spark.network.timeout if this " + + "is wrong.", address, requestTimeoutNs / 1000 / 1000); + client.timeOut(); + ctx.close(); + } else if (closeIdleConnections) { + // While CloseIdleConnections is enable, we also close idle connection + client.timeOut(); + ctx.close(); + } } } } } + + public TransportResponseHandler getResponseHandler() { + return responseHandler; + } + } From c2467dadae8ce44010a912ee91c429310f8add65 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Mon, 23 Nov 2015 13:54:19 -0800 Subject: [PATCH 1512/1824] [SPARK-11140][CORE] Transfer files using network lib when using NettyRpcEnv. This change abstracts the code that serves jars / files to executors so that each RpcEnv can have its own implementation; the akka version uses the existing HTTP-based file serving mechanism, while the netty versions uses the new stream support added to the network lib, which makes file transfers benefit from the easier security configuration of the network library, and should also reduce overhead overall. The change includes a small fix to TransportChannelHandler so that it propagates user events to downstream handlers. Author: Marcelo Vanzin Closes #9530 from vanzin/SPARK-11140. --- .../scala/org/apache/spark/SparkContext.scala | 8 +- .../scala/org/apache/spark/SparkEnv.scala | 14 -- .../scala/org/apache/spark/rpc/RpcEnv.scala | 46 ++++++ .../apache/spark/rpc/akka/AkkaRpcEnv.scala | 60 +++++++- .../apache/spark/rpc/netty/NettyRpcEnv.scala | 138 ++++++++++++++++-- .../spark/rpc/netty/NettyStreamManager.scala | 63 ++++++++ .../scala/org/apache/spark/util/Utils.scala | 9 ++ .../org/apache/spark/rpc/RpcEnvSuite.scala | 39 ++++- .../rpc/netty/NettyRpcHandlerSuite.scala | 10 +- docs/configuration.md | 2 + docs/security.md | 5 +- .../launcher/AbstractCommandBuilder.java | 2 +- .../client/TransportClientFactory.java | 6 +- .../server/TransportChannelHandler.java | 1 + 14 files changed, 356 insertions(+), 47 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index af4456c05b0a..b153a7b08e59 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -1379,7 +1379,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli } val key = if (!isLocal && scheme == "file") { - env.httpFileServer.addFile(new File(uri.getPath)) + env.rpcEnv.fileServer.addFile(new File(uri.getPath)) } else { schemeCorrectedPath } @@ -1630,7 +1630,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli var key = "" if (path.contains("\\")) { // For local paths with backslashes on Windows, URI throws an exception - key = env.httpFileServer.addJar(new File(path)) + key = env.rpcEnv.fileServer.addJar(new File(path)) } else { val uri = new URI(path) key = uri.getScheme match { @@ -1644,7 +1644,7 @@ class SparkContext(config: SparkConf) extends Logging with 
ExecutorAllocationCli // of the AM to make it show up in the current working directory. val fileName = new Path(uri.getPath).getName() try { - env.httpFileServer.addJar(new File(fileName)) + env.rpcEnv.fileServer.addJar(new File(fileName)) } catch { case e: Exception => // For now just log an error but allow to go through so spark examples work. @@ -1655,7 +1655,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli } } else { try { - env.httpFileServer.addJar(new File(uri.getPath)) + env.rpcEnv.fileServer.addJar(new File(uri.getPath)) } catch { case exc: FileNotFoundException => logError(s"Jar not found at $path") diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 88df27f733f2..84230e32a446 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -66,7 +66,6 @@ class SparkEnv ( val blockTransferService: BlockTransferService, val blockManager: BlockManager, val securityManager: SecurityManager, - val httpFileServer: HttpFileServer, val sparkFilesDir: String, val metricsSystem: MetricsSystem, val memoryManager: MemoryManager, @@ -91,7 +90,6 @@ class SparkEnv ( if (!isStopped) { isStopped = true pythonWorkers.values.foreach(_.stop()) - Option(httpFileServer).foreach(_.stop()) mapOutputTracker.stop() shuffleManager.stop() broadcastManager.stop() @@ -367,17 +365,6 @@ object SparkEnv extends Logging { val cacheManager = new CacheManager(blockManager) - val httpFileServer = - if (isDriver) { - val fileServerPort = conf.getInt("spark.fileserver.port", 0) - val server = new HttpFileServer(conf, securityManager, fileServerPort) - server.initialize() - conf.set("spark.fileserver.uri", server.serverUri) - server - } else { - null - } - val metricsSystem = if (isDriver) { // Don't start metrics system right now for Driver. // We need to wait for the task scheduler to give us an app ID. @@ -422,7 +409,6 @@ object SparkEnv extends Logging { blockTransferService, blockManager, securityManager, - httpFileServer, sparkFilesDir, metricsSystem, memoryManager, diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala index a560fd10cdf7..3d7d281b0dd6 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala @@ -17,6 +17,9 @@ package org.apache.spark.rpc +import java.io.File +import java.nio.channels.ReadableByteChannel + import scala.concurrent.Future import org.apache.spark.{SecurityManager, SparkConf} @@ -132,8 +135,51 @@ private[spark] abstract class RpcEnv(conf: SparkConf) { * that contains [[RpcEndpointRef]]s, the deserialization codes should be wrapped by this method. */ def deserialize[T](deserializationAction: () => T): T + + /** + * Return the instance of the file server used to serve files. This may be `null` if the + * RpcEnv is not operating in server mode. + */ + def fileServer: RpcEnvFileServer + + /** + * Open a channel to download a file from the given URI. If the URIs returned by the + * RpcEnvFileServer use the "spark" scheme, this method will be called by the Utils class to + * retrieve the files. + * + * @param uri URI with location of the file. + */ + def openChannel(uri: String): ReadableByteChannel + } +/** + * A server used by the RpcEnv to server files to other processes owned by the application. 
+ * + * The file server can return URIs handled by common libraries (such as "http" or "hdfs"), or + * it can return "spark" URIs which will be handled by `RpcEnv#fetchFile`. + */ +private[spark] trait RpcEnvFileServer { + + /** + * Adds a file to be served by this RpcEnv. This is used to serve files from the driver + * to executors when they're stored on the driver's local file system. + * + * @param file Local file to serve. + * @return A URI for the location of the file. + */ + def addFile(file: File): String + + /** + * Adds a jar to be served by this RpcEnv. Similar to `addFile` but for jars added using + * `SparkContext.addJar`. + * + * @param file Local file to serve. + * @return A URI for the location of the file. + */ + def addJar(file: File): String + +} private[spark] case class RpcEnvConfig( conf: SparkConf, diff --git a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala index 059a7e10ec12..94dbec593c31 100644 --- a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala @@ -17,6 +17,8 @@ package org.apache.spark.rpc.akka +import java.io.File +import java.nio.channels.ReadableByteChannel import java.util.concurrent.ConcurrentHashMap import scala.concurrent.Future @@ -30,7 +32,7 @@ import akka.pattern.{ask => akkaAsk} import akka.remote.{AssociationEvent, AssociatedEvent, DisassociatedEvent, AssociationErrorEvent} import akka.serialization.JavaSerializer -import org.apache.spark.{SparkException, Logging, SparkConf} +import org.apache.spark.{HttpFileServer, Logging, SecurityManager, SparkConf, SparkException} import org.apache.spark.rpc._ import org.apache.spark.util.{ActorLogReceive, AkkaUtils, ThreadUtils} @@ -41,7 +43,10 @@ import org.apache.spark.util.{ActorLogReceive, AkkaUtils, ThreadUtils} * remove Akka from the dependencies. 
*/ private[spark] class AkkaRpcEnv private[akka] ( - val actorSystem: ActorSystem, conf: SparkConf, boundPort: Int) + val actorSystem: ActorSystem, + val securityManager: SecurityManager, + conf: SparkConf, + boundPort: Int) extends RpcEnv(conf) with Logging { private val defaultAddress: RpcAddress = { @@ -64,6 +69,8 @@ private[spark] class AkkaRpcEnv private[akka] ( */ private val refToEndpoint = new ConcurrentHashMap[RpcEndpointRef, RpcEndpoint]() + private val _fileServer = new AkkaFileServer(conf, securityManager) + private def registerEndpoint(endpoint: RpcEndpoint, endpointRef: RpcEndpointRef): Unit = { endpointToRef.put(endpoint, endpointRef) refToEndpoint.put(endpointRef, endpoint) @@ -223,6 +230,7 @@ private[spark] class AkkaRpcEnv private[akka] ( override def shutdown(): Unit = { actorSystem.shutdown() + _fileServer.shutdown() } override def stop(endpoint: RpcEndpointRef): Unit = { @@ -241,6 +249,52 @@ private[spark] class AkkaRpcEnv private[akka] ( deserializationAction() } } + + override def openChannel(uri: String): ReadableByteChannel = { + throw new UnsupportedOperationException( + "AkkaRpcEnv's files should be retrieved using an HTTP client.") + } + + override def fileServer: RpcEnvFileServer = _fileServer + +} + +private[akka] class AkkaFileServer( + conf: SparkConf, + securityManager: SecurityManager) extends RpcEnvFileServer { + + @volatile private var httpFileServer: HttpFileServer = _ + + override def addFile(file: File): String = { + getFileServer().addFile(file) + } + + override def addJar(file: File): String = { + getFileServer().addJar(file) + } + + def shutdown(): Unit = { + if (httpFileServer != null) { + httpFileServer.stop() + } + } + + private def getFileServer(): HttpFileServer = { + if (httpFileServer == null) synchronized { + if (httpFileServer == null) { + httpFileServer = startFileServer() + } + } + httpFileServer + } + + private def startFileServer(): HttpFileServer = { + val fileServerPort = conf.getInt("spark.fileserver.port", 0) + val server = new HttpFileServer(conf, securityManager, fileServerPort) + server.initialize() + server + } + } private[spark] class AkkaRpcEnvFactory extends RpcEnvFactory { @@ -249,7 +303,7 @@ private[spark] class AkkaRpcEnvFactory extends RpcEnvFactory { val (actorSystem, boundPort) = AkkaUtils.createActorSystem( config.name, config.host, config.port, config.conf, config.securityManager) actorSystem.actorOf(Props(classOf[ErrorMonitor]), "ErrorMonitor") - new AkkaRpcEnv(actorSystem, config.conf, boundPort) + new AkkaRpcEnv(actorSystem, config.securityManager, config.conf, boundPort) } } diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala index 3ce359868039..68701f609f77 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala @@ -20,6 +20,7 @@ import java.io._ import java.lang.{Boolean => JBoolean} import java.net.{InetSocketAddress, URI} import java.nio.ByteBuffer +import java.nio.channels.{Pipe, ReadableByteChannel, WritableByteChannel} import java.util.concurrent._ import java.util.concurrent.atomic.AtomicBoolean import javax.annotation.Nullable @@ -45,27 +46,39 @@ private[netty] class NettyRpcEnv( host: String, securityManager: SecurityManager) extends RpcEnv(conf) with Logging { - private val transportConf = SparkTransportConf.fromSparkConf( + private[netty] val transportConf = SparkTransportConf.fromSparkConf( 
conf.clone.set("spark.rpc.io.numConnectionsPerPeer", "1"), "rpc", conf.getInt("spark.rpc.io.threads", 0)) private val dispatcher: Dispatcher = new Dispatcher(this) + private val streamManager = new NettyStreamManager(this) + private val transportContext = new TransportContext(transportConf, - new NettyRpcHandler(dispatcher, this)) + new NettyRpcHandler(dispatcher, this, streamManager)) - private val clientFactory = { - val bootstraps: java.util.List[TransportClientBootstrap] = - if (securityManager.isAuthenticationEnabled()) { - java.util.Arrays.asList(new SaslClientBootstrap(transportConf, "", securityManager, - securityManager.isSaslEncryptionEnabled())) - } else { - java.util.Collections.emptyList[TransportClientBootstrap] - } - transportContext.createClientFactory(bootstraps) + private def createClientBootstraps(): java.util.List[TransportClientBootstrap] = { + if (securityManager.isAuthenticationEnabled()) { + java.util.Arrays.asList(new SaslClientBootstrap(transportConf, "", securityManager, + securityManager.isSaslEncryptionEnabled())) + } else { + java.util.Collections.emptyList[TransportClientBootstrap] + } } + private val clientFactory = transportContext.createClientFactory(createClientBootstraps()) + + /** + * A separate client factory for file downloads. This avoids using the same RPC handler as + * the main RPC context, so that events caused by these clients are kept isolated from the + * main RPC traffic. + * + * It also allows for different configuration of certain properties, such as the number of + * connections per peer. + */ + @volatile private var fileDownloadFactory: TransportClientFactory = _ + val timeoutScheduler = ThreadUtils.newDaemonSingleThreadScheduledExecutor("netty-rpc-env-timeout") // Because TransportClientFactory.createClient is blocking, we need to run it in this thread pool @@ -292,6 +305,9 @@ private[netty] class NettyRpcEnv( if (clientConnectionExecutor != null) { clientConnectionExecutor.shutdownNow() } + if (fileDownloadFactory != null) { + fileDownloadFactory.close() + } } override def deserialize[T](deserializationAction: () => T): T = { @@ -300,6 +316,96 @@ private[netty] class NettyRpcEnv( } } + override def fileServer: RpcEnvFileServer = streamManager + + override def openChannel(uri: String): ReadableByteChannel = { + val parsedUri = new URI(uri) + require(parsedUri.getHost() != null, "Host name must be defined.") + require(parsedUri.getPort() > 0, "Port must be defined.") + require(parsedUri.getPath() != null && parsedUri.getPath().nonEmpty, "Path must be defined.") + + val pipe = Pipe.open() + val source = new FileDownloadChannel(pipe.source()) + try { + val client = downloadClient(parsedUri.getHost(), parsedUri.getPort()) + val callback = new FileDownloadCallback(pipe.sink(), source, client) + client.stream(parsedUri.getPath(), callback) + } catch { + case e: Exception => + pipe.sink().close() + source.close() + throw e + } + + source + } + + private def downloadClient(host: String, port: Int): TransportClient = { + if (fileDownloadFactory == null) synchronized { + if (fileDownloadFactory == null) { + val module = "files" + val prefix = "spark.rpc.io." + val clone = conf.clone() + + // Copy any RPC configuration that is not overridden in the spark.files namespace. 
+ conf.getAll.foreach { case (key, value) => + if (key.startsWith(prefix)) { + val opt = key.substring(prefix.length()) + clone.setIfMissing(s"spark.$module.io.$opt", value) + } + } + + val ioThreads = clone.getInt("spark.files.io.threads", 1) + val downloadConf = SparkTransportConf.fromSparkConf(clone, module, ioThreads) + val downloadContext = new TransportContext(downloadConf, new NoOpRpcHandler(), true) + fileDownloadFactory = downloadContext.createClientFactory(createClientBootstraps()) + } + } + fileDownloadFactory.createClient(host, port) + } + + private class FileDownloadChannel(source: ReadableByteChannel) extends ReadableByteChannel { + + @volatile private var error: Throwable = _ + + def setError(e: Throwable): Unit = error = e + + override def read(dst: ByteBuffer): Int = { + if (error != null) { + throw error + } + source.read(dst) + } + + override def close(): Unit = source.close() + + override def isOpen(): Boolean = source.isOpen() + + } + + private class FileDownloadCallback( + sink: WritableByteChannel, + source: FileDownloadChannel, + client: TransportClient) extends StreamCallback { + + override def onData(streamId: String, buf: ByteBuffer): Unit = { + while (buf.remaining() > 0) { + sink.write(buf) + } + } + + override def onComplete(streamId: String): Unit = { + sink.close() + } + + override def onFailure(streamId: String, cause: Throwable): Unit = { + logError(s"Error downloading stream $streamId.", cause) + source.setError(cause) + sink.close() + } + + } + } private[netty] object NettyRpcEnv extends Logging { @@ -420,7 +526,7 @@ private[netty] class NettyRpcEndpointRef( override def toString: String = s"NettyRpcEndpointRef(${_address})" - def toURI: URI = new URI(s"spark://${_address}") + def toURI: URI = new URI(_address.toString) final override def equals(that: Any): Boolean = that match { case other: NettyRpcEndpointRef => _address == other._address @@ -471,7 +577,9 @@ private[netty] case class RpcFailure(e: Throwable) * with different `RpcAddress` information). */ private[netty] class NettyRpcHandler( - dispatcher: Dispatcher, nettyEnv: NettyRpcEnv) extends RpcHandler with Logging { + dispatcher: Dispatcher, + nettyEnv: NettyRpcEnv, + streamManager: StreamManager) extends RpcHandler with Logging { // TODO: Can we add connection callback (channel registered) to the underlying framework? // A variable to track whether we should dispatch the RemoteProcessConnected message. 
@@ -498,7 +606,7 @@ private[netty] class NettyRpcHandler( dispatcher.postRemoteMessage(messageToDispatch, callback) } - override def getStreamManager: StreamManager = new OneForOneStreamManager + override def getStreamManager: StreamManager = streamManager override def exceptionCaught(cause: Throwable, client: TransportClient): Unit = { val addr = client.getChannel.remoteAddress().asInstanceOf[InetSocketAddress] @@ -516,8 +624,8 @@ private[netty] class NettyRpcHandler( override def connectionTerminated(client: TransportClient): Unit = { val addr = client.getChannel.remoteAddress().asInstanceOf[InetSocketAddress] if (addr != null) { - val clientAddr = RpcAddress(addr.getHostName, addr.getPort) clients.remove(client) + val clientAddr = RpcAddress(addr.getHostName, addr.getPort) nettyEnv.removeOutbox(clientAddr) dispatcher.postToAll(RemoteProcessDisconnected(clientAddr)) } else { diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala new file mode 100644 index 000000000000..eb1d2604fb23 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.rpc.netty + +import java.io.File +import java.util.concurrent.ConcurrentHashMap + +import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} +import org.apache.spark.network.server.StreamManager +import org.apache.spark.rpc.RpcEnvFileServer + +/** + * StreamManager implementation for serving files from a NettyRpcEnv. 
+ */ +private[netty] class NettyStreamManager(rpcEnv: NettyRpcEnv) + extends StreamManager with RpcEnvFileServer { + + private val files = new ConcurrentHashMap[String, File]() + private val jars = new ConcurrentHashMap[String, File]() + + override def getChunk(streamId: Long, chunkIndex: Int): ManagedBuffer = { + throw new UnsupportedOperationException() + } + + override def openStream(streamId: String): ManagedBuffer = { + val Array(ftype, fname) = streamId.stripPrefix("/").split("/", 2) + val file = ftype match { + case "files" => files.get(fname) + case "jars" => jars.get(fname) + case _ => throw new IllegalArgumentException(s"Invalid file type: $ftype") + } + + require(file != null, s"File not found: $streamId") + new FileSegmentManagedBuffer(rpcEnv.transportConf, file, 0, file.length()) + } + + override def addFile(file: File): String = { + require(files.putIfAbsent(file.getName(), file) == null, + s"File ${file.getName()} already registered.") + s"${rpcEnv.address.toSparkURL}/files/${file.getName()}" + } + + override def addJar(file: File): String = { + require(jars.putIfAbsent(file.getName(), file) == null, + s"JAR ${file.getName()} already registered.") + s"${rpcEnv.address.toSparkURL}/jars/${file.getName()}" + } + +} diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 1b3acb8ef7f5..af632349c9ca 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -21,6 +21,7 @@ import java.io._ import java.lang.management.ManagementFactory import java.net._ import java.nio.ByteBuffer +import java.nio.channels.Channels import java.util.concurrent._ import java.util.{Locale, Properties, Random, UUID} import javax.net.ssl.HttpsURLConnection @@ -535,6 +536,14 @@ private[spark] object Utils extends Logging { val uri = new URI(url) val fileOverwrite = conf.getBoolean("spark.files.overwrite", defaultValue = false) Option(uri.getScheme).getOrElse("file") match { + case "spark" => + if (SparkEnv.get == null) { + throw new IllegalStateException( + "Cannot retrieve files with 'spark' scheme without an active SparkEnv.") + } + val source = SparkEnv.get.rpcEnv.openChannel(url) + val is = Channels.newInputStream(source) + downloadFile(url, is, targetFile, fileOverwrite) case "http" | "https" | "ftp" => var uc: URLConnection = null if (securityMgr.isAuthenticationEnabled()) { diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala index 2f55006420ce..2b664c6313ef 100644 --- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala @@ -17,7 +17,9 @@ package org.apache.spark.rpc -import java.io.NotSerializableException +import java.io.{File, NotSerializableException} +import java.util.UUID +import java.nio.charset.StandardCharsets.UTF_8 import java.util.concurrent.{TimeUnit, CountDownLatch, TimeoutException} import scala.collection.mutable @@ -25,10 +27,14 @@ import scala.concurrent.Await import scala.concurrent.duration._ import scala.language.postfixOps +import com.google.common.io.Files +import org.mockito.Mockito.{mock, when} import org.scalatest.BeforeAndAfterAll import org.scalatest.concurrent.Eventually._ -import org.apache.spark.{SparkConf, SparkException, SparkFunSuite} +import org.apache.spark.{SecurityManager, SparkConf, SparkEnv, SparkException, SparkFunSuite} +import org.apache.spark.deploy.SparkHadoopUtil 
+import org.apache.spark.util.Utils /** * Common tests for an RpcEnv implementation. @@ -40,12 +46,17 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { override def beforeAll(): Unit = { val conf = new SparkConf() env = createRpcEnv(conf, "local", 0) + + val sparkEnv = mock(classOf[SparkEnv]) + when(sparkEnv.rpcEnv).thenReturn(env) + SparkEnv.set(sparkEnv) } override def afterAll(): Unit = { if (env != null) { env.shutdown() } + SparkEnv.set(null) } def createRpcEnv(conf: SparkConf, name: String, port: Int, clientMode: Boolean = false): RpcEnv @@ -713,6 +724,30 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { assert(shortTimeout.timeoutProp.r.findAllIn(reply4).length === 1) } + test("file server") { + val conf = new SparkConf() + val tempDir = Utils.createTempDir() + val file = new File(tempDir, "file") + Files.write(UUID.randomUUID().toString(), file, UTF_8) + val jar = new File(tempDir, "jar") + Files.write(UUID.randomUUID().toString(), jar, UTF_8) + + val fileUri = env.fileServer.addFile(file) + val jarUri = env.fileServer.addJar(jar) + + val destDir = Utils.createTempDir() + val destFile = new File(destDir, file.getName()) + val destJar = new File(destDir, jar.getName()) + + val sm = new SecurityManager(conf) + val hc = SparkHadoopUtil.get.conf + Utils.fetchFile(fileUri, destDir, conf, sm, hc, 0L, false) + Utils.fetchFile(jarUri, destDir, conf, sm, hc, 0L, false) + + assert(Files.equal(file, destFile)) + assert(Files.equal(jar, destJar)) + } + } class UnserializableClass diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala index f9d8e80c98b6..ccca795683da 100644 --- a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala @@ -25,17 +25,19 @@ import org.mockito.Matchers._ import org.apache.spark.SparkFunSuite import org.apache.spark.network.client.{TransportResponseHandler, TransportClient} +import org.apache.spark.network.server.StreamManager import org.apache.spark.rpc._ class NettyRpcHandlerSuite extends SparkFunSuite { val env = mock(classOf[NettyRpcEnv]) - when(env.deserialize(any(classOf[TransportClient]), any(classOf[Array[Byte]]))(any())). 
- thenReturn(RequestMessage(RpcAddress("localhost", 12345), null, null, false)) + val sm = mock(classOf[StreamManager]) + when(env.deserialize(any(classOf[TransportClient]), any(classOf[Array[Byte]]))(any())) + .thenReturn(RequestMessage(RpcAddress("localhost", 12345), null, null, false)) test("receive") { val dispatcher = mock(classOf[Dispatcher]) - val nettyRpcHandler = new NettyRpcHandler(dispatcher, env) + val nettyRpcHandler = new NettyRpcHandler(dispatcher, env, sm) val channel = mock(classOf[Channel]) val client = new TransportClient(channel, mock(classOf[TransportResponseHandler])) @@ -47,7 +49,7 @@ class NettyRpcHandlerSuite extends SparkFunSuite { test("connectionTerminated") { val dispatcher = mock(classOf[Dispatcher]) - val nettyRpcHandler = new NettyRpcHandler(dispatcher, env) + val nettyRpcHandler = new NettyRpcHandler(dispatcher, env, sm) val channel = mock(classOf[Channel]) val client = new TransportClient(channel, mock(classOf[TransportResponseHandler])) diff --git a/docs/configuration.md b/docs/configuration.md index c496146e3ed6..4de202d7f763 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -1020,6 +1020,7 @@ Apart from these, the following properties are also available, and may be useful (random) Port for the executor to listen on. This is used for communicating with the driver. + This is only relevant when using the Akka RPC backend. @@ -1027,6 +1028,7 @@ Apart from these, the following properties are also available, and may be useful (random) Port for the driver's HTTP file server to listen on. + This is only relevant when using the Akka RPC backend. diff --git a/docs/security.md b/docs/security.md index 177109415180..e1af221d446b 100644 --- a/docs/security.md +++ b/docs/security.md @@ -149,7 +149,8 @@ configure those ports. (random) Schedule tasks spark.executor.port - Akka-based. Set to "0" to choose a port randomly. + Akka-based. Set to "0" to choose a port randomly. Only used if Akka RPC backend is + configured. Executor @@ -157,7 +158,7 @@ configure those ports. (random) File server for files and jars spark.fileserver.port - Jetty-based + Jetty-based. Only used if Akka RPC backend is configured. 
Executor diff --git a/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java index 3ee6bd92e47f..55fe156cf665 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java +++ b/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java @@ -148,7 +148,7 @@ List buildClassPath(String appClassPath) throws IOException { String scala = getScalaVersion(); List projects = Arrays.asList("core", "repl", "mllib", "bagel", "graphx", "streaming", "tools", "sql/catalyst", "sql/core", "sql/hive", "sql/hive-thriftserver", - "yarn", "launcher"); + "yarn", "launcher", "network/common", "network/shuffle", "network/yarn"); if (prependClasses) { if (!isTesting) { System.err.println( diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java index 659c47160c7b..61bafc838004 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java @@ -170,8 +170,10 @@ public TransportClient createClient(String remoteHost, int remotePort) throws IO } /** - * Create a completely new {@link TransportClient} to the given remote host / port - * But this connection is not pooled. + * Create a completely new {@link TransportClient} to the given remote host / port. + * This connection is not pooled. + * + * As with {@link #createClient(String, int)}, this method is blocking. */ public TransportClient createUnmanagedClient(String remoteHost, int remotePort) throws IOException { diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java b/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java index 29d688a67578..3164e0067903 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java @@ -138,6 +138,7 @@ public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exc } } } + ctx.fireUserEventTriggered(evt); } public TransportResponseHandler getResponseHandler() { From 9db5f601facfdaba6e4333a6b2d2e4a9f009c788 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 23 Nov 2015 16:33:26 -0800 Subject: [PATCH 1513/1824] [SPARK-9866][SQL] Speed up VersionsSuite by using persistent Ivy cache This patch attempts to speed up VersionsSuite by storing fetched Hive JARs in an Ivy cache that persists across tests runs. If `SPARK_VERSIONS_SUITE_IVY_PATH` is set, that path will be used for the cache; if it is not set, VersionsSuite will create a temporary Ivy cache which is deleted after the test completes. Author: Josh Rosen Closes #9624 from JoshRosen/SPARK-9866. 
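
Editor's note: the cache lookup described in the message above boils down to the following order. This is a minimal sketch that mirrors the resolution logic in the diff below, not the patch itself:

```scala
import java.io.File

// Prefer an externally managed Ivy cache when SPARK_VERSIONS_SUITE_IVY_PATH is set,
// otherwise fall back to a stable directory under java.io.tmpdir so repeated
// test runs can reuse previously fetched Hive JARs.
val ivyCachePath: Option[String] =
  sys.env.get("SPARK_VERSIONS_SUITE_IVY_PATH").orElse(
    Some(new File(sys.props("java.io.tmpdir"), "hive-ivy-cache").getAbsolutePath))
```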
--- .../apache/spark/sql/hive/client/VersionsSuite.scala | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala index c6d034a23a1c..7bc13bc60d30 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala @@ -36,10 +36,12 @@ import org.apache.spark.util.Utils @ExtendedHiveTest class VersionsSuite extends SparkFunSuite with Logging { - // Do not use a temp path here to speed up subsequent executions of the unit test during - // development. - private val ivyPath = Some( - new File(sys.props("java.io.tmpdir"), "hive-ivy-cache").getAbsolutePath()) + // In order to speed up test execution during development or in Jenkins, you can specify the path + // of an existing Ivy cache: + private val ivyPath: Option[String] = { + sys.env.get("SPARK_VERSIONS_SUITE_IVY_PATH").orElse( + Some(new File(sys.props("java.io.tmpdir"), "hive-ivy-cache").getAbsolutePath)) + } private def buildConf() = { lazy val warehousePath = Utils.createTempDir() From 105745645b12afbbc2a350518cb5853a88944183 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Mon, 23 Nov 2015 17:11:51 -0800 Subject: [PATCH 1514/1824] [SPARK-10560][PYSPARK][MLLIB][DOCS] Make StreamingLogisticRegressionWithSGD Python API equal to Scala one This is to bring the API documentation of StreamingLogisticReressionWithSGD and StreamingLinearRegressionWithSGC in line with the Scala versions. -Fixed the algorithm descriptions -Added default values to parameter descriptions -Changed StreamingLogisticRegressionWithSGD regParam to default to 0, as in the Scala version Author: Bryan Cutler Closes #9141 from BryanCutler/StreamingLogisticRegressionWithSGD-python-api-sync. --- python/pyspark/mllib/classification.py | 37 +++++++++++++++++--------- python/pyspark/mllib/regression.py | 32 ++++++++++++++-------- 2 files changed, 46 insertions(+), 23 deletions(-) diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py index aab4015ba80f..9e6f17ef6e94 100644 --- a/python/pyspark/mllib/classification.py +++ b/python/pyspark/mllib/classification.py @@ -652,21 +652,34 @@ def train(cls, data, lambda_=1.0): @inherit_doc class StreamingLogisticRegressionWithSGD(StreamingLinearAlgorithm): """ - Run LogisticRegression with SGD on a batch of data. - - The weights obtained at the end of training a stream are used as initial - weights for the next batch. - - :param stepSize: Step size for each iteration of gradient descent. - :param numIterations: Number of iterations run for each batch of data. - :param miniBatchFraction: Fraction of data on which SGD is run for each - iteration. - :param regParam: L2 Regularization parameter. - :param convergenceTol: A condition which decides iteration termination. + Train or predict a logistic regression model on streaming data. Training uses + Stochastic Gradient Descent to update the model based on each new batch of + incoming data from a DStream. + + Each batch of data is assumed to be an RDD of LabeledPoints. + The number of data points per batch can vary, but the number + of features must be constant. An initial weight + vector must be provided. + + :param stepSize: + Step size for each iteration of gradient descent. + (default: 0.1) + :param numIterations: + Number of iterations run for each batch of data. 
+ (default: 50) + :param miniBatchFraction: + Fraction of each batch of data to use for updates. + (default: 1.0) + :param regParam: + L2 Regularization parameter. + (default: 0.0) + :param convergenceTol: + Value used to determine when to terminate iterations. + (default: 0.001) .. versionadded:: 1.5.0 """ - def __init__(self, stepSize=0.1, numIterations=50, miniBatchFraction=1.0, regParam=0.01, + def __init__(self, stepSize=0.1, numIterations=50, miniBatchFraction=1.0, regParam=0.0, convergenceTol=0.001): self.stepSize = stepSize self.numIterations = numIterations diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py index 6f00d1df209c..13b3397501c0 100644 --- a/python/pyspark/mllib/regression.py +++ b/python/pyspark/mllib/regression.py @@ -734,17 +734,27 @@ def predictOnValues(self, dstream): @inherit_doc class StreamingLinearRegressionWithSGD(StreamingLinearAlgorithm): """ - Run LinearRegression with SGD on a batch of data. - - The problem minimized is (1 / n_samples) * (y - weights'X)**2. - After training on a batch of data, the weights obtained at the end of - training are used as initial weights for the next batch. - - :param stepSize: Step size for each iteration of gradient descent. - :param numIterations: Total number of iterations run. - :param miniBatchFraction: Fraction of data on which SGD is run for each - iteration. - :param convergenceTol: A condition which decides iteration termination. + Train or predict a linear regression model on streaming data. Training uses + Stochastic Gradient Descent to update the model based on each new batch of + incoming data from a DStream (see `LinearRegressionWithSGD` for model equation). + + Each batch of data is assumed to be an RDD of LabeledPoints. + The number of data points per batch can vary, but the number + of features must be constant. An initial weight + vector must be provided. + + :param stepSize: + Step size for each iteration of gradient descent. + (default: 0.1) + :param numIterations: + Number of iterations run for each batch of data. + (default: 50) + :param miniBatchFraction: + Fraction of each batch of data to use for updates. + (default: 1.0) + :param convergenceTol: + Value used to determine when to terminate iterations. + (default: 0.001) .. versionadded:: 1.5.0 """ From 026ea2eab1f3cde270e8a6391d002915f3e1c6e5 Mon Sep 17 00:00:00 2001 From: Stephen Samuel Date: Mon, 23 Nov 2015 19:52:12 -0800 Subject: [PATCH 1515/1824] Updated sql programming guide to include jdbc fetch size Author: Stephen Samuel Closes #9377 from sksamuel/master. --- docs/sql-programming-guide.md | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index e347754055e7..d7b205c2fa0d 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1820,6 +1820,7 @@ the Data Sources API. The following options are supported: register itself with the JDBC subsystem. + partitionColumn, lowerBound, upperBound, numPartitions @@ -1831,6 +1832,13 @@ the Data Sources API. The following options are supported: partitioned and returned. + + + fetchSize + + The JDBC fetch size, which determines how many rows to fetch per round trip. This can help performance on JDBC drivers which default to low fetch size (eg. Oracle with 10 rows). + +
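
Editor's note: a hedged usage sketch for the `fetchSize` option documented above. The JDBC URL, table name, and the in-scope `sqlContext` are placeholders, not part of this patch:

```scala
// Read a JDBC table with a larger fetch size than the driver default
// (helpful for drivers such as Oracle, which fetch only 10 rows per round trip by default).
val jdbcDF = sqlContext.read
  .format("jdbc")
  .option("url", "jdbc:oracle:thin:@//dbhost:1521/service") // placeholder URL
  .option("dbtable", "SCHEMA.MY_TABLE")                     // placeholder table
  .option("fetchSize", "100")                               // rows per round trip
  .load()
```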
    From 8d57524662fad4a0760f3bc924e690c2a110e7f7 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 23 Nov 2015 22:22:15 -0800 Subject: [PATCH 1516/1824] [SPARK-11933][SQL] Rename mapGroup -> mapGroups and flatMapGroup -> flatMapGroups. Based on feedback from Matei, this is more consistent with mapPartitions in Spark. Also addresses some of the cleanups from a previous commit that renames the type variables. Author: Reynold Xin Closes #9919 from rxin/SPARK-11933. --- ...nction.java => FlatMapGroupsFunction.java} | 2 +- ...upFunction.java => MapGroupsFunction.java} | 2 +- .../org/apache/spark/sql/GroupedDataset.scala | 36 +++++++++---------- .../apache/spark/sql/JavaDatasetSuite.java | 10 +++--- .../spark/sql/DatasetPrimitiveSuite.scala | 4 +-- .../org/apache/spark/sql/DatasetSuite.scala | 12 +++---- 6 files changed, 33 insertions(+), 33 deletions(-) rename core/src/main/java/org/apache/spark/api/java/function/{FlatMapGroupFunction.java => FlatMapGroupsFunction.java} (93%) rename core/src/main/java/org/apache/spark/api/java/function/{MapGroupFunction.java => MapGroupsFunction.java} (93%) diff --git a/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupFunction.java b/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsFunction.java similarity index 93% rename from core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupFunction.java rename to core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsFunction.java index 18a2d733ca70..d7a80e7b129b 100644 --- a/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupFunction.java +++ b/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsFunction.java @@ -23,6 +23,6 @@ /** * A function that returns zero or more output records from each grouping key and its values. */ -public interface FlatMapGroupFunction extends Serializable { +public interface FlatMapGroupsFunction extends Serializable { Iterable call(K key, Iterator values) throws Exception; } diff --git a/core/src/main/java/org/apache/spark/api/java/function/MapGroupFunction.java b/core/src/main/java/org/apache/spark/api/java/function/MapGroupsFunction.java similarity index 93% rename from core/src/main/java/org/apache/spark/api/java/function/MapGroupFunction.java rename to core/src/main/java/org/apache/spark/api/java/function/MapGroupsFunction.java index 4f3f222e064b..faa59eabc8b4 100644 --- a/core/src/main/java/org/apache/spark/api/java/function/MapGroupFunction.java +++ b/core/src/main/java/org/apache/spark/api/java/function/MapGroupsFunction.java @@ -23,6 +23,6 @@ /** * Base interface for a map function used in GroupedDataset's mapGroup function. 
*/ -public interface MapGroupFunction extends Serializable { +public interface MapGroupsFunction extends Serializable { R call(K key, Iterator values) throws Exception; } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala index 7f43ce16901b..793a86b13290 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala @@ -43,7 +43,7 @@ import org.apache.spark.sql.expressions.Aggregator @Experimental class GroupedDataset[K, V] private[sql]( kEncoder: Encoder[K], - tEncoder: Encoder[V], + vEncoder: Encoder[V], val queryExecution: QueryExecution, private val dataAttributes: Seq[Attribute], private val groupingAttributes: Seq[Attribute]) extends Serializable { @@ -53,12 +53,12 @@ class GroupedDataset[K, V] private[sql]( // queryexecution. private implicit val unresolvedKEncoder = encoderFor(kEncoder) - private implicit val unresolvedTEncoder = encoderFor(tEncoder) + private implicit val unresolvedVEncoder = encoderFor(vEncoder) private val resolvedKEncoder = unresolvedKEncoder.resolve(groupingAttributes, OuterScopes.outerScopes) - private val resolvedTEncoder = - unresolvedTEncoder.resolve(dataAttributes, OuterScopes.outerScopes) + private val resolvedVEncoder = + unresolvedVEncoder.resolve(dataAttributes, OuterScopes.outerScopes) private def logicalPlan = queryExecution.analyzed private def sqlContext = queryExecution.sqlContext @@ -76,7 +76,7 @@ class GroupedDataset[K, V] private[sql]( def keyAs[L : Encoder]: GroupedDataset[L, V] = new GroupedDataset( encoderFor[L], - unresolvedTEncoder, + unresolvedVEncoder, queryExecution, dataAttributes, groupingAttributes) @@ -110,13 +110,13 @@ class GroupedDataset[K, V] private[sql]( * * @since 1.6.0 */ - def flatMapGroup[U : Encoder](f: (K, Iterator[V]) => TraversableOnce[U]): Dataset[U] = { + def flatMapGroups[U : Encoder](f: (K, Iterator[V]) => TraversableOnce[U]): Dataset[U] = { new Dataset[U]( sqlContext, MapGroups( f, resolvedKEncoder, - resolvedTEncoder, + resolvedVEncoder, groupingAttributes, logicalPlan)) } @@ -138,8 +138,8 @@ class GroupedDataset[K, V] private[sql]( * * @since 1.6.0 */ - def flatMapGroup[U](f: FlatMapGroupFunction[K, V, U], encoder: Encoder[U]): Dataset[U] = { - flatMapGroup((key, data) => f.call(key, data.asJava).asScala)(encoder) + def flatMapGroups[U](f: FlatMapGroupsFunction[K, V, U], encoder: Encoder[U]): Dataset[U] = { + flatMapGroups((key, data) => f.call(key, data.asJava).asScala)(encoder) } /** @@ -158,9 +158,9 @@ class GroupedDataset[K, V] private[sql]( * * @since 1.6.0 */ - def mapGroup[U : Encoder](f: (K, Iterator[V]) => U): Dataset[U] = { + def mapGroups[U : Encoder](f: (K, Iterator[V]) => U): Dataset[U] = { val func = (key: K, it: Iterator[V]) => Iterator(f(key, it)) - flatMapGroup(func) + flatMapGroups(func) } /** @@ -179,8 +179,8 @@ class GroupedDataset[K, V] private[sql]( * * @since 1.6.0 */ - def mapGroup[U](f: MapGroupFunction[K, V, U], encoder: Encoder[U]): Dataset[U] = { - mapGroup((key, data) => f.call(key, data.asJava))(encoder) + def mapGroups[U](f: MapGroupsFunction[K, V, U], encoder: Encoder[U]): Dataset[U] = { + mapGroups((key, data) => f.call(key, data.asJava))(encoder) } /** @@ -192,8 +192,8 @@ class GroupedDataset[K, V] private[sql]( def reduce(f: (V, V) => V): Dataset[(K, V)] = { val func = (key: K, it: Iterator[V]) => Iterator((key, it.reduce(f))) - implicit val resultEncoder = 
ExpressionEncoder.tuple(unresolvedKEncoder, unresolvedTEncoder) - flatMapGroup(func) + implicit val resultEncoder = ExpressionEncoder.tuple(unresolvedKEncoder, unresolvedVEncoder) + flatMapGroups(func) } /** @@ -213,7 +213,7 @@ class GroupedDataset[K, V] private[sql]( private def withEncoder(c: Column): Column = c match { case tc: TypedColumn[_, _] => - tc.withInputType(resolvedTEncoder.bind(dataAttributes), dataAttributes) + tc.withInputType(resolvedVEncoder.bind(dataAttributes), dataAttributes) case _ => c } @@ -227,7 +227,7 @@ class GroupedDataset[K, V] private[sql]( val encoders = columns.map(_.encoder) val namedColumns = columns.map( - _.withInputType(resolvedTEncoder, dataAttributes).named) + _.withInputType(resolvedVEncoder, dataAttributes).named) val keyColumn = if (groupingAttributes.length > 1) { Alias(CreateStruct(groupingAttributes), "key")() } else { @@ -304,7 +304,7 @@ class GroupedDataset[K, V] private[sql]( def cogroup[U, R : Encoder]( other: GroupedDataset[K, U])( f: (K, Iterator[V], Iterator[U]) => TraversableOnce[R]): Dataset[R] = { - implicit def uEnc: Encoder[U] = other.unresolvedTEncoder + implicit def uEnc: Encoder[U] = other.unresolvedVEncoder new Dataset[R]( sqlContext, CoGroup( diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index cf335efdd23b..67a3190cb7d4 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -170,7 +170,7 @@ public Integer call(String v) throws Exception { } }, Encoders.INT()); - Dataset mapped = grouped.mapGroup(new MapGroupFunction() { + Dataset mapped = grouped.mapGroups(new MapGroupsFunction() { @Override public String call(Integer key, Iterator values) throws Exception { StringBuilder sb = new StringBuilder(key.toString()); @@ -183,8 +183,8 @@ public String call(Integer key, Iterator values) throws Exception { Assert.assertEquals(Arrays.asList("1a", "3foobar"), mapped.collectAsList()); - Dataset flatMapped = grouped.flatMapGroup( - new FlatMapGroupFunction() { + Dataset flatMapped = grouped.flatMapGroups( + new FlatMapGroupsFunction() { @Override public Iterable call(Integer key, Iterator values) throws Exception { StringBuilder sb = new StringBuilder(key.toString()); @@ -249,8 +249,8 @@ public void testGroupByColumn() { GroupedDataset grouped = ds.groupBy(length(col("value"))).keyAs(Encoders.INT()); - Dataset mapped = grouped.mapGroup( - new MapGroupFunction() { + Dataset mapped = grouped.mapGroups( + new MapGroupsFunction() { @Override public String call(Integer key, Iterator data) throws Exception { StringBuilder sb = new StringBuilder(key.toString()); diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala index d387710357be..f75d0961823c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala @@ -86,7 +86,7 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext { test("groupBy function, map") { val ds = Seq(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11).toDS() val grouped = ds.groupBy(_ % 2) - val agged = grouped.mapGroup { case (g, iter) => + val agged = grouped.mapGroups { case (g, iter) => val name = if (g == 0) "even" else "odd" (name, iter.size) } @@ -99,7 +99,7 @@ class DatasetPrimitiveSuite extends 
QueryTest with SharedSQLContext { test("groupBy function, flatMap") { val ds = Seq("a", "b", "c", "xyz", "hello").toDS() val grouped = ds.groupBy(_.length) - val agged = grouped.flatMapGroup { case (g, iter) => Iterator(g.toString, iter.mkString) } + val agged = grouped.flatMapGroups { case (g, iter) => Iterator(g.toString, iter.mkString) } checkAnswer( agged, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index cc8e4325fd2f..dbdd7ba14a5b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -224,7 +224,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { test("groupBy function, map") { val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() val grouped = ds.groupBy(v => (v._1, "word")) - val agged = grouped.mapGroup { case (g, iter) => (g._1, iter.map(_._2).sum) } + val agged = grouped.mapGroups { case (g, iter) => (g._1, iter.map(_._2).sum) } checkAnswer( agged, @@ -234,7 +234,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { test("groupBy function, flatMap") { val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() val grouped = ds.groupBy(v => (v._1, "word")) - val agged = grouped.flatMapGroup { case (g, iter) => + val agged = grouped.flatMapGroups { case (g, iter) => Iterator(g._1, iter.map(_._2).sum.toString) } @@ -255,7 +255,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { test("groupBy columns, map") { val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() val grouped = ds.groupBy($"_1") - val agged = grouped.mapGroup { case (g, iter) => (g.getString(0), iter.map(_._2).sum) } + val agged = grouped.mapGroups { case (g, iter) => (g.getString(0), iter.map(_._2).sum) } checkAnswer( agged, @@ -265,7 +265,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { test("groupBy columns asKey, map") { val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() val grouped = ds.groupBy($"_1").keyAs[String] - val agged = grouped.mapGroup { case (g, iter) => (g, iter.map(_._2).sum) } + val agged = grouped.mapGroups { case (g, iter) => (g, iter.map(_._2).sum) } checkAnswer( agged, @@ -275,7 +275,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { test("groupBy columns asKey tuple, map") { val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() val grouped = ds.groupBy($"_1", lit(1)).keyAs[(String, Int)] - val agged = grouped.mapGroup { case (g, iter) => (g, iter.map(_._2).sum) } + val agged = grouped.mapGroups { case (g, iter) => (g, iter.map(_._2).sum) } checkAnswer( agged, @@ -285,7 +285,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { test("groupBy columns asKey class, map") { val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() val grouped = ds.groupBy($"_1".as("a"), lit(1).as("b")).keyAs[ClassData] - val agged = grouped.mapGroup { case (g, iter) => (g, iter.map(_._2).sum) } + val agged = grouped.mapGroups { case (g, iter) => (g, iter.map(_._2).sum) } checkAnswer( agged, From 6cf51a7007bd72eb93ade149ca9fc53be5b32a17 Mon Sep 17 00:00:00 2001 From: Nicholas Chammas Date: Mon, 23 Nov 2015 22:22:50 -0800 Subject: [PATCH 1517/1824] [SPARK-11903] Remove --skip-java-test Per [pwendell's comments on 
SPARK-11903](https://issues.apache.org/jira/browse/SPARK-11903?focusedCommentId=15021511&page=com.atlassian.jira.plugin.system.issuetabpanels:comment-tabpanel#comment-15021511) I'm removing this dead code. If we are concerned about preserving compatibility, I can instead leave the option in and add a warning. For example: ```sh echo "Warning: '--skip-java-test' is deprecated and has no effect." ;; ``` cc pwendell, srowen Author: Nicholas Chammas Closes #9924 from nchammas/make-distribution. --- make-distribution.sh | 3 --- 1 file changed, 3 deletions(-) diff --git a/make-distribution.sh b/make-distribution.sh index d7d27e253f72..7b417fe7cf61 100755 --- a/make-distribution.sh +++ b/make-distribution.sh @@ -69,9 +69,6 @@ while (( "$#" )); do echo "Error: '--with-hive' is no longer supported, use Maven options -Phive and -Phive-thriftserver" exit_with_usage ;; - --skip-java-test) - SKIP_JAVA_TEST=true - ;; --with-tachyon) SPARK_TACHYON=true ;; From 4021a28ac30b65cb61cf1e041253847253a2d89f Mon Sep 17 00:00:00 2001 From: Mikhail Bautin Date: Mon, 23 Nov 2015 22:26:08 -0800 Subject: [PATCH 1518/1824] [SPARK-10707][SQL] Fix nullability computation in union output Author: Mikhail Bautin Closes #9308 from mbautin/SPARK-10707. --- .../plans/logical/basicOperators.scala | 11 +++++-- .../spark/sql/execution/basicOperators.scala | 9 ++++-- .../org/apache/spark/sql/SQLQuerySuite.scala | 31 +++++++++++++++++++ 3 files changed, 46 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 0c444482c5e4..737e62fd5921 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -92,8 +92,10 @@ case class Filter(condition: Expression, child: LogicalPlan) extends UnaryNode { } abstract class SetOperation(left: LogicalPlan, right: LogicalPlan) extends BinaryNode { - // TODO: These aren't really the same attributes as nullability etc might change. - final override def output: Seq[Attribute] = left.output + override def output: Seq[Attribute] = + left.output.zip(right.output).map { case (leftAttr, rightAttr) => + leftAttr.withNullability(leftAttr.nullable || rightAttr.nullable) + } final override lazy val resolved: Boolean = childrenResolved && @@ -115,7 +117,10 @@ case class Union(left: LogicalPlan, right: LogicalPlan) extends SetOperation(lef case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right) -case class Except(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right) +case class Except(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right) { + /** We don't use right.output because those rows get excluded from the set. */ + override def output: Seq[Attribute] = left.output +} case class Join( left: LogicalPlan, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index e79092efdaa3..d57b8e7a9ed6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -130,8 +130,13 @@ case class Sample( * Union two plans, without a distinct. This is UNION ALL in SQL. 
*/ case class Union(children: Seq[SparkPlan]) extends SparkPlan { - // TODO: attributes output by union should be distinct for nullability purposes - override def output: Seq[Attribute] = children.head.output + override def output: Seq[Attribute] = { + children.tail.foldLeft(children.head.output) { case (currentOutput, child) => + currentOutput.zip(child.output).map { case (a1, a2) => + a1.withNullability(a1.nullable || a2.nullable) + } + } + } override def outputsUnsafeRows: Boolean = children.forall(_.outputsUnsafeRows) override def canProcessUnsafeRows: Boolean = true override def canProcessSafeRows: Boolean = true diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 167aea87de07..bb82b562aaaa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -1997,4 +1997,35 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { sqlContext.setConf("spark.sql.subexpressionElimination.enabled", "true") verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 1) } + + test("SPARK-10707: nullability should be correctly propagated through set operations (1)") { + // This test produced an incorrect result of 1 before the SPARK-10707 fix because of the + // NullPropagation rule: COUNT(v) got replaced with COUNT(1) because the output column of + // UNION was incorrectly considered non-nullable: + checkAnswer( + sql("""SELECT count(v) FROM ( + | SELECT v FROM ( + | SELECT 'foo' AS v UNION ALL + | SELECT NULL AS v + | ) my_union WHERE isnull(v) + |) my_subview""".stripMargin), + Seq(Row(0))) + } + + test("SPARK-10707: nullability should be correctly propagated through set operations (2)") { + // This test uses RAND() to stop column pruning for Union and checks the resulting isnull + // value. This would produce an incorrect result before the fix in SPARK-10707 because the "v" + // column of the union was considered non-nullable. + checkAnswer( + sql( + """ + |SELECT a FROM ( + | SELECT ISNULL(v) AS a, RAND() FROM ( + | SELECT 'foo' AS v UNION ALL SELECT null AS v + | ) my_union + |) my_view + """.stripMargin), + Row(false) :: Row(true) :: Nil) + } + } From 12eea834d7382fbaa9c92182b682b8724049d7c1 Mon Sep 17 00:00:00 2001 From: Xiu Guo Date: Tue, 24 Nov 2015 00:07:40 -0800 Subject: [PATCH 1519/1824] [SPARK-11897][SQL] Add @scala.annotations.varargs to sql functions Author: Xiu Guo Closes #9918 from xguo27/SPARK-11897. 
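
Editor's note: as a standalone illustration of the annotation being added here (not Spark code), `@scala.annotation.varargs` asks scalac to emit an additional Java-style varargs bridge method, so Java callers can pass arguments directly instead of constructing a `Seq`:

```scala
object VarargsDemo {
  // With the annotation, Java code can call VarargsDemo.join(", ", "a", "b", "c");
  // Scala callers are unaffected and keep using the repeated parameter directly.
  @scala.annotation.varargs
  def join(sep: String, parts: String*): String = parts.mkString(sep)
}
```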
--- sql/core/src/main/scala/org/apache/spark/sql/functions.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index b27b1340cce4..6137ce3a70fd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -689,6 +689,7 @@ object functions extends LegacyFunctions { * @group normal_funcs * @since 1.4.0 */ + @scala.annotation.varargs def array(colName: String, colNames: String*): Column = { array((colName +: colNames).map(col) : _*) } @@ -871,6 +872,7 @@ object functions extends LegacyFunctions { * @group normal_funcs * @since 1.4.0 */ + @scala.annotation.varargs def struct(colName: String, colNames: String*): Column = { struct((colName +: colNames).map(col) : _*) } From 800bd799acf7f10a469d8d6537279953129eb2c6 Mon Sep 17 00:00:00 2001 From: Forest Fang Date: Tue, 24 Nov 2015 09:03:32 +0000 Subject: [PATCH 1520/1824] [SPARK-11906][WEB UI] Speculation Tasks Cause ProgressBar UI Overflow When there are speculative tasks in the stage, running progress bar could overflow and goes hidden on a new line: ![image](https://cloud.githubusercontent.com/assets/4317392/11326841/5fd3482e-9142-11e5-8ca5-cb2f0c0c8964.png) 3 completed / 2 running (including 1 speculative) out of 4 total tasks This is a simple fix by capping the started tasks at `total - completed` tasks ![image](https://cloud.githubusercontent.com/assets/4317392/11326842/6bb67260-9142-11e5-90f0-37f9174878ec.png) I should note my preferred way to fix it is via css style ```css .progress { display: flex; } ``` which shifts the correction burden from driver to web browser. However I couldn't get selenium test to measure the position/dimension of the progress bar correctly to get this unit tested. It also has the side effect that the width will be calibrated so the running occupies 2 / 5 instead of 1 / 4. ![image](https://cloud.githubusercontent.com/assets/4317392/11326848/7b03e9f0-9142-11e5-89ad-bd99cb0647cf.png) All in all, since this cosmetic bug is minor enough, I suppose the original simple fix should be good enough. Author: Forest Fang Closes #9896 from saurfang/progressbar. --- core/src/main/scala/org/apache/spark/ui/UIUtils.scala | 4 +++- .../test/scala/org/apache/spark/ui/UIUtilsSuite.scala | 10 ++++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala index 25dcb604d9e5..84a1116a5c49 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala @@ -319,7 +319,9 @@ private[spark] object UIUtils extends Logging { skipped: Int, total: Int): Seq[Node] = { val completeWidth = "width: %s%%".format((completed.toDouble/total)*100) - val startWidth = "width: %s%%".format((started.toDouble/total)*100) + // started + completed can be > total when there are speculative tasks + val boundedStarted = math.min(started, total - completed) + val startWidth = "width: %s%%".format((boundedStarted.toDouble/total)*100)
    diff --git a/core/src/test/scala/org/apache/spark/ui/UIUtilsSuite.scala b/core/src/test/scala/org/apache/spark/ui/UIUtilsSuite.scala index 2b693c165180..dd8d5ec27f87 100644 --- a/core/src/test/scala/org/apache/spark/ui/UIUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/UIUtilsSuite.scala @@ -57,6 +57,16 @@ class UIUtilsSuite extends SparkFunSuite { ) } + test("SPARK-11906: Progress bar should not overflow because of speculative tasks") { + val generated = makeProgressBar(2, 3, 0, 0, 4).head.child.filter(_.label == "div") + val expected = Seq( +
    , +
    + ) + assert(generated.sameElements(expected), + s"\nRunning progress bar should round down\n\nExpected:\n$expected\nGenerated:\n$generated") + } + private def verify( desc: String, expected: Elem, errorMsg: String = "", baseUrl: String = ""): Unit = { val generated = makeDescription(desc, baseUrl) From d4a5e6f719079639ffd38470f4d8d1f6fde3228d Mon Sep 17 00:00:00 2001 From: huangzhaowei Date: Tue, 24 Nov 2015 23:24:49 +0800 Subject: [PATCH 1521/1824] [SPARK-11043][SQL] BugFix:Set the operator log in the thrift server. `SessionManager` will set the `operationLog` if the configuration `hive.server2.logging.operation.enabled` is true in version of hive 1.2.1. But the spark did not adapt to this change, so no matter enabled the configuration or not, spark thrift server will always log the warn message. PS: if `hive.server2.logging.operation.enabled` is false, it should log the warn message (the same as hive thrift server). Author: huangzhaowei Closes #9056 from SaintBacchus/SPARK-11043. --- .../SparkExecuteStatementOperation.scala | 8 ++++---- .../thriftserver/SparkSQLSessionManager.scala | 5 +++++ .../thriftserver/HiveThriftServer2Suites.scala | 16 +++++++++++++++- 3 files changed, 24 insertions(+), 5 deletions(-) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala index 82fef92dcb73..e022ee86a763 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala @@ -134,12 +134,12 @@ private[hive] class SparkExecuteStatementOperation( def getResultSetSchema: TableSchema = resultSchema - override def run(): Unit = { + override def runInternal(): Unit = { setState(OperationState.PENDING) setHasResultSet(true) // avoid no resultset for async run if (!runInBackground) { - runInternal() + execute() } else { val sparkServiceUGI = Utils.getUGI() @@ -151,7 +151,7 @@ private[hive] class SparkExecuteStatementOperation( val doAsAction = new PrivilegedExceptionAction[Unit]() { override def run(): Unit = { try { - runInternal() + execute() } catch { case e: HiveSQLException => setOperationException(e) @@ -188,7 +188,7 @@ private[hive] class SparkExecuteStatementOperation( } } - override def runInternal(): Unit = { + private def execute(): Unit = { statementId = UUID.randomUUID().toString logInfo(s"Running query '$statement' with $statementId") setState(OperationState.RUNNING) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala index af4fcdf021bd..de4e9c62b57a 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala @@ -41,6 +41,11 @@ private[hive] class SparkSQLSessionManager(hiveServer: HiveServer2, hiveContext: override def init(hiveConf: HiveConf) { setSuperField(this, "hiveConf", hiveConf) + // Create operation log root directory, if operation logging is enabled + if (hiveConf.getBoolVar(ConfVars.HIVE_SERVER2_LOGGING_OPERATION_ENABLED)) { + invoke(classOf[SessionManager], this, 
"initOperationLogRootDir") + } + val backgroundPoolSize = hiveConf.getIntVar(ConfVars.HIVE_SERVER2_ASYNC_EXEC_THREADS) setSuperField(this, "backgroundOperationPool", Executors.newFixedThreadPool(backgroundPoolSize)) getAncestorField[Log](this, 3, "LOG").info( diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala index 1dd898aa3835..139d8e897ba1 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala @@ -26,6 +26,7 @@ import scala.collection.mutable.ArrayBuffer import scala.concurrent.ExecutionContext.Implicits.global import scala.concurrent.duration._ import scala.concurrent.{Await, Promise, future} +import scala.io.Source import scala.util.{Random, Try} import com.google.common.base.Charsets.UTF_8 @@ -507,6 +508,12 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest { assert(rs2.getInt(2) === 500) } } + + test("SPARK-11043 check operation log root directory") { + val expectedLine = + "Operation log root directory is created: " + operationLogPath.getAbsoluteFile + assert(Source.fromFile(logPath).getLines().exists(_.contains(expectedLine))) + } } class SingleSessionSuite extends HiveThriftJdbcTest { @@ -642,7 +649,8 @@ abstract class HiveThriftServer2Test extends SparkFunSuite with BeforeAndAfterAl protected def metastoreJdbcUri = s"jdbc:derby:;databaseName=$metastorePath;create=true" private val pidDir: File = Utils.createTempDir("thriftserver-pid") - private var logPath: File = _ + protected var logPath: File = _ + protected var operationLogPath: File = _ private var logTailingProcess: Process = _ private var diagnosisBuffer: ArrayBuffer[String] = ArrayBuffer.empty[String] @@ -679,6 +687,7 @@ abstract class HiveThriftServer2Test extends SparkFunSuite with BeforeAndAfterAl | --hiveconf ${ConfVars.METASTOREWAREHOUSE}=$warehousePath | --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_BIND_HOST}=localhost | --hiveconf ${ConfVars.HIVE_SERVER2_TRANSPORT_MODE}=$mode + | --hiveconf ${ConfVars.HIVE_SERVER2_LOGGING_OPERATION_LOG_LOCATION}=$operationLogPath | --hiveconf $portConf=$port | --driver-class-path $driverClassPath | --driver-java-options -Dlog4j.debug @@ -706,6 +715,8 @@ abstract class HiveThriftServer2Test extends SparkFunSuite with BeforeAndAfterAl warehousePath.delete() metastorePath = Utils.createTempDir() metastorePath.delete() + operationLogPath = Utils.createTempDir() + operationLogPath.delete() logPath = null logTailingProcess = null @@ -782,6 +793,9 @@ abstract class HiveThriftServer2Test extends SparkFunSuite with BeforeAndAfterAl metastorePath.delete() metastorePath = null + operationLogPath.delete() + operationLogPath = null + Option(logPath).foreach(_.delete()) logPath = null From 5889880fbe9628681042036892ef7ebd4f0857b4 Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Tue, 24 Nov 2015 23:32:05 +0800 Subject: [PATCH 1522/1824] [SPARK-11592][SQL] flush spark-sql command line history to history file Currently, `spark-sql` would not flush command history when exiting. Author: Daoyuan Wang Closes #9563 from adrian-wang/jline. 
--- .../hive/thriftserver/SparkSQLCLIDriver.scala | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala index 6419002a2aa8..4b928e600b35 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala @@ -194,6 +194,22 @@ private[hive] object SparkSQLCLIDriver extends Logging { logWarning(e.getMessage) } + // add shutdown hook to flush the history to history file + Runtime.getRuntime.addShutdownHook(new Thread(new Runnable() { + override def run() = { + reader.getHistory match { + case h: FileHistory => + try { + h.flush() + } catch { + case e: IOException => + logWarning("WARNING: Failed to write command history file: " + e.getMessage) + } + case _ => + } + } + })) + // TODO: missing /* val clientTransportTSocketField = classOf[CliSessionState].getDeclaredField("transport") From be9dd1550c1816559d3d418a19c692e715f1c94e Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Tue, 24 Nov 2015 09:20:09 -0800 Subject: [PATCH 1523/1824] =?UTF-8?q?[SPARK-11818][REPL]=20Fix=20ExecutorC?= =?UTF-8?q?lassLoader=20to=20lookup=20resources=20from=20=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …parent class loader Without patch, two additional tests of ExecutorClassLoaderSuite fails. - "resource from parent" - "resources from parent" Detailed explanation is here, https://issues.apache.org/jira/browse/SPARK-11818?focusedCommentId=15011202&page=com.atlassian.jira.plugin.system.issuetabpanels:comment-tabpanel#comment-15011202 Author: Jungtaek Lim Closes #9812 from HeartSaVioR/SPARK-11818. --- .../spark/repl/ExecutorClassLoader.scala | 12 +++++++- .../spark/repl/ExecutorClassLoaderSuite.scala | 29 +++++++++++++++++++ 2 files changed, 40 insertions(+), 1 deletion(-) diff --git a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala index a976e96809cb..a8859fcd4584 100644 --- a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala +++ b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala @@ -34,7 +34,9 @@ import org.apache.spark.util.ParentClassLoader /** * A ClassLoader that reads classes from a Hadoop FileSystem or HTTP URI, * used to load classes defined by the interpreter when the REPL is used. - * Allows the user to specify if user class path should be first + * Allows the user to specify if user class path should be first. + * This class loader delegates getting/finding resources to parent loader, + * which makes sense until REPL never provide resource dynamically. 
*/ class ExecutorClassLoader(conf: SparkConf, classUri: String, parent: ClassLoader, userClassPathFirst: Boolean) extends ClassLoader with Logging { @@ -55,6 +57,14 @@ class ExecutorClassLoader(conf: SparkConf, classUri: String, parent: ClassLoader } } + override def getResource(name: String): URL = { + parentLoader.getResource(name) + } + + override def getResources(name: String): java.util.Enumeration[URL] = { + parentLoader.getResources(name) + } + override def findClass(name: String): Class[_] = { userClassPathFirst match { case true => findClassLocally(name).getOrElse(parentLoader.loadClass(name)) diff --git a/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala b/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala index a58eda12b112..c1211f7596b9 100644 --- a/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala +++ b/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala @@ -19,8 +19,13 @@ package org.apache.spark.repl import java.io.File import java.net.{URL, URLClassLoader} +import java.nio.charset.StandardCharsets +import java.util + +import com.google.common.io.Files import scala.concurrent.duration._ +import scala.io.Source import scala.language.implicitConversions import scala.language.postfixOps @@ -41,6 +46,7 @@ class ExecutorClassLoaderSuite val childClassNames = List("ReplFakeClass1", "ReplFakeClass2") val parentClassNames = List("ReplFakeClass1", "ReplFakeClass2", "ReplFakeClass3") + val parentResourceNames = List("fake-resource.txt") var tempDir1: File = _ var tempDir2: File = _ var url1: String = _ @@ -54,6 +60,9 @@ class ExecutorClassLoaderSuite url1 = "file://" + tempDir1 urls2 = List(tempDir2.toURI.toURL).toArray childClassNames.foreach(TestUtils.createCompiledClass(_, tempDir1, "1")) + parentResourceNames.foreach { x => + Files.write("resource".getBytes(StandardCharsets.UTF_8), new File(tempDir2, x)) + } parentClassNames.foreach(TestUtils.createCompiledClass(_, tempDir2, "2")) } @@ -99,6 +108,26 @@ class ExecutorClassLoaderSuite } } + test("resource from parent") { + val parentLoader = new URLClassLoader(urls2, null) + val classLoader = new ExecutorClassLoader(new SparkConf(), url1, parentLoader, true) + val resourceName: String = parentResourceNames.head + val is = classLoader.getResourceAsStream(resourceName) + assert(is != null, s"Resource $resourceName not found") + val content = Source.fromInputStream(is, "UTF-8").getLines().next() + assert(content.contains("resource"), "File doesn't contain 'resource'") + } + + test("resources from parent") { + val parentLoader = new URLClassLoader(urls2, null) + val classLoader = new ExecutorClassLoader(new SparkConf(), url1, parentLoader, true) + val resourceName: String = parentResourceNames.head + val resources: util.Enumeration[URL] = classLoader.getResources(resourceName) + assert(resources.hasMoreElements, s"Resource $resourceName not found") + val fileReader = Source.fromInputStream(resources.nextElement().openStream()).bufferedReader() + assert(fileReader.readLine().contains("resource"), "File doesn't contain 'resource'") + } + test("failing to fetch classes from HTTP server should not leak resources (SPARK-6209)") { // This is a regression test for SPARK-6209, a bug where each failed attempt to load a class // from the driver's class server would leak a HTTP connection, causing the class server's From e5aaae6e1145b8c25c4872b2992ab425da9c6f9b Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 24 Nov 2015 09:28:39 -0800 Subject: [PATCH 
1524/1824] [SPARK-11942][SQL] fix encoder life cycle for CoGroup we should pass in resolved encodera to logical `CoGroup` and bind them in physical `CoGroup` Author: Wenchen Fan Closes #9928 from cloud-fan/cogroup. --- .../plans/logical/basicOperators.scala | 27 ++++++++++--------- .../org/apache/spark/sql/GroupedDataset.scala | 4 ++- .../spark/sql/execution/basicOperators.scala | 20 +++++++------- .../org/apache/spark/sql/DatasetSuite.scala | 12 +++++++++ 4 files changed, 41 insertions(+), 22 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 737e62fd5921..5665fd7e5f41 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -553,19 +553,22 @@ case class MapGroups[K, T, U]( /** Factory for constructing new `CoGroup` nodes. */ object CoGroup { - def apply[K : Encoder, Left : Encoder, Right : Encoder, R : Encoder]( - func: (K, Iterator[Left], Iterator[Right]) => TraversableOnce[R], + def apply[Key, Left, Right, Result : Encoder]( + func: (Key, Iterator[Left], Iterator[Right]) => TraversableOnce[Result], + keyEnc: ExpressionEncoder[Key], + leftEnc: ExpressionEncoder[Left], + rightEnc: ExpressionEncoder[Right], leftGroup: Seq[Attribute], rightGroup: Seq[Attribute], left: LogicalPlan, - right: LogicalPlan): CoGroup[K, Left, Right, R] = { + right: LogicalPlan): CoGroup[Key, Left, Right, Result] = { CoGroup( func, - encoderFor[K], - encoderFor[Left], - encoderFor[Right], - encoderFor[R], - encoderFor[R].schema.toAttributes, + keyEnc, + leftEnc, + rightEnc, + encoderFor[Result], + encoderFor[Result].schema.toAttributes, leftGroup, rightGroup, left, @@ -577,12 +580,12 @@ object CoGroup { * A relation produced by applying `func` to each grouping key and associated values from left and * right children. 
*/ -case class CoGroup[K, Left, Right, R]( - func: (K, Iterator[Left], Iterator[Right]) => TraversableOnce[R], - kEncoder: ExpressionEncoder[K], +case class CoGroup[Key, Left, Right, Result]( + func: (Key, Iterator[Left], Iterator[Right]) => TraversableOnce[Result], + keyEnc: ExpressionEncoder[Key], leftEnc: ExpressionEncoder[Left], rightEnc: ExpressionEncoder[Right], - rEncoder: ExpressionEncoder[R], + resultEnc: ExpressionEncoder[Result], output: Seq[Attribute], leftGroup: Seq[Attribute], rightGroup: Seq[Attribute], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala index 793a86b13290..a10a89342fb5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala @@ -304,11 +304,13 @@ class GroupedDataset[K, V] private[sql]( def cogroup[U, R : Encoder]( other: GroupedDataset[K, U])( f: (K, Iterator[V], Iterator[U]) => TraversableOnce[R]): Dataset[R] = { - implicit def uEnc: Encoder[U] = other.unresolvedVEncoder new Dataset[R]( sqlContext, CoGroup( f, + this.resolvedKEncoder, + this.resolvedVEncoder, + other.resolvedVEncoder, this.groupingAttributes, other.groupingAttributes, this.logicalPlan, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index d57b8e7a9ed6..a42aea0b96d4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -375,12 +375,12 @@ case class MapGroups[K, T, U]( * iterators containing all elements in the group from left and right side. * The result of this function is encoded and flattened before being output. 
*/ -case class CoGroup[K, Left, Right, R]( - func: (K, Iterator[Left], Iterator[Right]) => TraversableOnce[R], - kEncoder: ExpressionEncoder[K], +case class CoGroup[Key, Left, Right, Result]( + func: (Key, Iterator[Left], Iterator[Right]) => TraversableOnce[Result], + keyEnc: ExpressionEncoder[Key], leftEnc: ExpressionEncoder[Left], rightEnc: ExpressionEncoder[Right], - rEncoder: ExpressionEncoder[R], + resultEnc: ExpressionEncoder[Result], output: Seq[Attribute], leftGroup: Seq[Attribute], rightGroup: Seq[Attribute], @@ -397,15 +397,17 @@ case class CoGroup[K, Left, Right, R]( left.execute().zipPartitions(right.execute()) { (leftData, rightData) => val leftGrouped = GroupedIterator(leftData, leftGroup, left.output) val rightGrouped = GroupedIterator(rightData, rightGroup, right.output) - val groupKeyEncoder = kEncoder.bind(leftGroup) + val boundKeyEnc = keyEnc.bind(leftGroup) + val boundLeftEnc = leftEnc.bind(left.output) + val boundRightEnc = rightEnc.bind(right.output) new CoGroupedIterator(leftGrouped, rightGrouped, leftGroup).flatMap { case (key, leftResult, rightResult) => val result = func( - groupKeyEncoder.fromRow(key), - leftResult.map(leftEnc.fromRow), - rightResult.map(rightEnc.fromRow)) - result.map(rEncoder.toRow) + boundKeyEnc.fromRow(key), + leftResult.map(boundLeftEnc.fromRow), + rightResult.map(boundRightEnc.fromRow)) + result.map(resultEnc.toRow) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index dbdd7ba14a5b..13eede1b17d8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -340,6 +340,18 @@ class DatasetSuite extends QueryTest with SharedSQLContext { 1 -> "a#", 2 -> "#q", 3 -> "abcfoo#w", 5 -> "hello#er") } + test("cogroup with complex data") { + val ds1 = Seq(1 -> ClassData("a", 1), 2 -> ClassData("b", 2)).toDS() + val ds2 = Seq(2 -> ClassData("c", 3), 3 -> ClassData("d", 4)).toDS() + val cogrouped = ds1.groupBy(_._1).cogroup(ds2.groupBy(_._1)) { case (key, data1, data2) => + Iterator(key -> (data1.map(_._2.a).mkString + data2.map(_._2.a).mkString)) + } + + checkAnswer( + cogrouped, + 1 -> "a", 2 -> "bc", 3 -> "d") + } + test("SPARK-11436: we should rebind right encoder when join 2 datasets") { val ds1 = Seq("1", "2").toDS().as("a") val ds2 = Seq(2, 3).toDS().as("b") From 56a0aba0a60326ba026056c9a23f3f6ec7258c19 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Tue, 24 Nov 2015 09:52:53 -0800 Subject: [PATCH 1525/1824] [SPARK-11952][ML] Remove duplicate ml examples Remove duplicate ml examples (only for ml). mengxr Author: Yanbo Liang Closes #9933 from yanboliang/SPARK-11685. 
--- .../main/python/ml/gradient_boosted_trees.py | 82 ----------------- .../src/main/python/ml/logistic_regression.py | 66 -------------- .../main/python/ml/random_forest_example.py | 87 ------------------- 3 files changed, 235 deletions(-) delete mode 100644 examples/src/main/python/ml/gradient_boosted_trees.py delete mode 100644 examples/src/main/python/ml/logistic_regression.py delete mode 100644 examples/src/main/python/ml/random_forest_example.py diff --git a/examples/src/main/python/ml/gradient_boosted_trees.py b/examples/src/main/python/ml/gradient_boosted_trees.py deleted file mode 100644 index c3bf8aa2eb1e..000000000000 --- a/examples/src/main/python/ml/gradient_boosted_trees.py +++ /dev/null @@ -1,82 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -from __future__ import print_function - -import sys - -from pyspark import SparkContext -from pyspark.ml.classification import GBTClassifier -from pyspark.ml.feature import StringIndexer -from pyspark.ml.regression import GBTRegressor -from pyspark.mllib.evaluation import BinaryClassificationMetrics, RegressionMetrics -from pyspark.sql import Row, SQLContext - -""" -A simple example demonstrating a Gradient Boosted Trees Classification/Regression Pipeline. -Note: GBTClassifier only supports binary classification currently -Run with: - bin/spark-submit examples/src/main/python/ml/gradient_boosted_trees.py -""" - - -def testClassification(train, test): - # Train a GradientBoostedTrees model. - - rf = GBTClassifier(maxIter=30, maxDepth=4, labelCol="indexedLabel") - - model = rf.fit(train) - predictionAndLabels = model.transform(test).select("prediction", "indexedLabel") \ - .map(lambda x: (x.prediction, x.indexedLabel)) - - metrics = BinaryClassificationMetrics(predictionAndLabels) - print("AUC %.3f" % metrics.areaUnderROC) - - -def testRegression(train, test): - # Train a GradientBoostedTrees model. - - rf = GBTRegressor(maxIter=30, maxDepth=4, labelCol="indexedLabel") - - model = rf.fit(train) - predictionAndLabels = model.transform(test).select("prediction", "indexedLabel") \ - .map(lambda x: (x.prediction, x.indexedLabel)) - - metrics = RegressionMetrics(predictionAndLabels) - print("rmse %.3f" % metrics.rootMeanSquaredError) - print("r2 %.3f" % metrics.r2) - print("mae %.3f" % metrics.meanAbsoluteError) - - -if __name__ == "__main__": - if len(sys.argv) > 1: - print("Usage: gradient_boosted_trees", file=sys.stderr) - exit(1) - sc = SparkContext(appName="PythonGBTExample") - sqlContext = SQLContext(sc) - - # Load the data stored in LIBSVM format as a DataFrame. 
- df = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") - - # Map labels into an indexed column of labels in [0, numLabels) - stringIndexer = StringIndexer(inputCol="label", outputCol="indexedLabel") - si_model = stringIndexer.fit(df) - td = si_model.transform(df) - [train, test] = td.randomSplit([0.7, 0.3]) - testClassification(train, test) - testRegression(train, test) - sc.stop() diff --git a/examples/src/main/python/ml/logistic_regression.py b/examples/src/main/python/ml/logistic_regression.py deleted file mode 100644 index 4cd027fdfbe8..000000000000 --- a/examples/src/main/python/ml/logistic_regression.py +++ /dev/null @@ -1,66 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -from __future__ import print_function - -import sys - -from pyspark import SparkContext -from pyspark.ml.classification import LogisticRegression -from pyspark.mllib.evaluation import MulticlassMetrics -from pyspark.ml.feature import StringIndexer -from pyspark.sql import SQLContext - -""" -A simple example demonstrating a logistic regression with elastic net regularization Pipeline. -Run with: - bin/spark-submit examples/src/main/python/ml/logistic_regression.py -""" - -if __name__ == "__main__": - - if len(sys.argv) > 1: - print("Usage: logistic_regression", file=sys.stderr) - exit(-1) - - sc = SparkContext(appName="PythonLogisticRegressionExample") - sqlContext = SQLContext(sc) - - # Load the data stored in LIBSVM format as a DataFrame. - df = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") - - # Map labels into an indexed column of labels in [0, numLabels) - stringIndexer = StringIndexer(inputCol="label", outputCol="indexedLabel") - si_model = stringIndexer.fit(df) - td = si_model.transform(df) - [training, test] = td.randomSplit([0.7, 0.3]) - - lr = LogisticRegression(maxIter=100, regParam=0.3).setLabelCol("indexedLabel") - lr.setElasticNetParam(0.8) - - # Fit the model - lrModel = lr.fit(training) - - predictionAndLabels = lrModel.transform(test).select("prediction", "indexedLabel") \ - .map(lambda x: (x.prediction, x.indexedLabel)) - - metrics = MulticlassMetrics(predictionAndLabels) - print("weighted f-measure %.3f" % metrics.weightedFMeasure()) - print("precision %s" % metrics.precision()) - print("recall %s" % metrics.recall()) - - sc.stop() diff --git a/examples/src/main/python/ml/random_forest_example.py b/examples/src/main/python/ml/random_forest_example.py deleted file mode 100644 index dc6a77867019..000000000000 --- a/examples/src/main/python/ml/random_forest_example.py +++ /dev/null @@ -1,87 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. 
See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -from __future__ import print_function - -import sys - -from pyspark import SparkContext -from pyspark.ml.classification import RandomForestClassifier -from pyspark.ml.feature import StringIndexer -from pyspark.ml.regression import RandomForestRegressor -from pyspark.mllib.evaluation import MulticlassMetrics, RegressionMetrics -from pyspark.mllib.util import MLUtils -from pyspark.sql import Row, SQLContext - -""" -A simple example demonstrating a RandomForest Classification/Regression Pipeline. -Run with: - bin/spark-submit examples/src/main/python/ml/random_forest_example.py -""" - - -def testClassification(train, test): - # Train a RandomForest model. - # Setting featureSubsetStrategy="auto" lets the algorithm choose. - # Note: Use larger numTrees in practice. - - rf = RandomForestClassifier(labelCol="indexedLabel", numTrees=3, maxDepth=4) - - model = rf.fit(train) - predictionAndLabels = model.transform(test).select("prediction", "indexedLabel") \ - .map(lambda x: (x.prediction, x.indexedLabel)) - - metrics = MulticlassMetrics(predictionAndLabels) - print("weighted f-measure %.3f" % metrics.weightedFMeasure()) - print("precision %s" % metrics.precision()) - print("recall %s" % metrics.recall()) - - -def testRegression(train, test): - # Train a RandomForest model. - # Note: Use larger numTrees in practice. - - rf = RandomForestRegressor(labelCol="indexedLabel", numTrees=3, maxDepth=4) - - model = rf.fit(train) - predictionAndLabels = model.transform(test).select("prediction", "indexedLabel") \ - .map(lambda x: (x.prediction, x.indexedLabel)) - - metrics = RegressionMetrics(predictionAndLabels) - print("rmse %.3f" % metrics.rootMeanSquaredError) - print("r2 %.3f" % metrics.r2) - print("mae %.3f" % metrics.meanAbsoluteError) - - -if __name__ == "__main__": - if len(sys.argv) > 1: - print("Usage: random_forest_example", file=sys.stderr) - exit(1) - sc = SparkContext(appName="PythonRandomForestExample") - sqlContext = SQLContext(sc) - - # Load the data stored in LIBSVM format as a DataFrame. - df = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") - - # Map labels into an indexed column of labels in [0, numLabels) - stringIndexer = StringIndexer(inputCol="label", outputCol="indexedLabel") - si_model = stringIndexer.fit(df) - td = si_model.transform(df) - [train, test] = td.randomSplit([0.7, 0.3]) - testClassification(train, test) - testRegression(train, test) - sc.stop() From 9e24ba667e43290fbaa3cacb93cf5d9be790f1fd Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Tue, 24 Nov 2015 09:54:55 -0800 Subject: [PATCH 1526/1824] [SPARK-11521][ML][DOC] Document that Logistic, Linear Regression summaries ignore weight col Doc for 1.6 that the summaries mostly ignore the weight column. To be corrected for 1.7 CC: mengxr thunterdb Author: Joseph K. 
Bradley Closes #9927 from jkbradley/linregsummary-doc. --- .../ml/classification/LogisticRegression.scala | 18 ++++++++++++++++++ .../spark/ml/regression/LinearRegression.scala | 15 +++++++++++++++ 2 files changed, 33 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 418bbdc9a058..d320d64dd90d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -755,23 +755,35 @@ class BinaryLogisticRegressionSummary private[classification] ( * Returns the receiver operating characteristic (ROC) curve, * which is an Dataframe having two fields (FPR, TPR) * with (0.0, 0.0) prepended and (1.0, 1.0) appended to it. + * + * Note: This ignores instance weights (setting all to 1.0) from [[LogisticRegression.weightCol]]. + * This will change in later Spark versions. * @see http://en.wikipedia.org/wiki/Receiver_operating_characteristic */ @transient lazy val roc: DataFrame = binaryMetrics.roc().toDF("FPR", "TPR") /** * Computes the area under the receiver operating characteristic (ROC) curve. + * + * Note: This ignores instance weights (setting all to 1.0) from [[LogisticRegression.weightCol]]. + * This will change in later Spark versions. */ lazy val areaUnderROC: Double = binaryMetrics.areaUnderROC() /** * Returns the precision-recall curve, which is an Dataframe containing * two fields recall, precision with (0.0, 1.0) prepended to it. + * + * Note: This ignores instance weights (setting all to 1.0) from [[LogisticRegression.weightCol]]. + * This will change in later Spark versions. */ @transient lazy val pr: DataFrame = binaryMetrics.pr().toDF("recall", "precision") /** * Returns a dataframe with two fields (threshold, F-Measure) curve with beta = 1.0. + * + * Note: This ignores instance weights (setting all to 1.0) from [[LogisticRegression.weightCol]]. + * This will change in later Spark versions. */ @transient lazy val fMeasureByThreshold: DataFrame = { binaryMetrics.fMeasureByThreshold().toDF("threshold", "F-Measure") @@ -781,6 +793,9 @@ class BinaryLogisticRegressionSummary private[classification] ( * Returns a dataframe with two fields (threshold, precision) curve. * Every possible probability obtained in transforming the dataset are used * as thresholds used in calculating the precision. + * + * Note: This ignores instance weights (setting all to 1.0) from [[LogisticRegression.weightCol]]. + * This will change in later Spark versions. */ @transient lazy val precisionByThreshold: DataFrame = { binaryMetrics.precisionByThreshold().toDF("threshold", "precision") @@ -790,6 +805,9 @@ class BinaryLogisticRegressionSummary private[classification] ( * Returns a dataframe with two fields (threshold, recall) curve. * Every possible probability obtained in transforming the dataset are used * as thresholds used in calculating the recall. + * + * Note: This ignores instance weights (setting all to 1.0) from [[LogisticRegression.weightCol]]. + * This will change in later Spark versions. 
*/ @transient lazy val recallByThreshold: DataFrame = { binaryMetrics.recallByThreshold().toDF("threshold", "recall") diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 70ccec766c47..1db91666f21a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -540,6 +540,9 @@ class LinearRegressionSummary private[regression] ( * Returns the explained variance regression score. * explainedVariance = 1 - variance(y - \hat{y}) / variance(y) * Reference: [[http://en.wikipedia.org/wiki/Explained_variation]] + * + * Note: This ignores instance weights (setting all to 1.0) from [[LinearRegression.weightCol]]. + * This will change in later Spark versions. */ @Since("1.5.0") val explainedVariance: Double = metrics.explainedVariance @@ -547,6 +550,9 @@ class LinearRegressionSummary private[regression] ( /** * Returns the mean absolute error, which is a risk function corresponding to the * expected value of the absolute error loss or l1-norm loss. + * + * Note: This ignores instance weights (setting all to 1.0) from [[LinearRegression.weightCol]]. + * This will change in later Spark versions. */ @Since("1.5.0") val meanAbsoluteError: Double = metrics.meanAbsoluteError @@ -554,6 +560,9 @@ class LinearRegressionSummary private[regression] ( /** * Returns the mean squared error, which is a risk function corresponding to the * expected value of the squared error loss or quadratic loss. + * + * Note: This ignores instance weights (setting all to 1.0) from [[LinearRegression.weightCol]]. + * This will change in later Spark versions. */ @Since("1.5.0") val meanSquaredError: Double = metrics.meanSquaredError @@ -561,6 +570,9 @@ class LinearRegressionSummary private[regression] ( /** * Returns the root mean squared error, which is defined as the square root of * the mean squared error. + * + * Note: This ignores instance weights (setting all to 1.0) from [[LinearRegression.weightCol]]. + * This will change in later Spark versions. */ @Since("1.5.0") val rootMeanSquaredError: Double = metrics.rootMeanSquaredError @@ -568,6 +580,9 @@ class LinearRegressionSummary private[regression] ( /** * Returns R^2^, the coefficient of determination. * Reference: [[http://en.wikipedia.org/wiki/Coefficient_of_determination]] + * + * Note: This ignores instance weights (setting all to 1.0) from [[LinearRegression.weightCol]]. + * This will change in later Spark versions. */ @Since("1.5.0") val r2: Double = metrics.r2 From 52bc25c8e26d4be250d8ff7864067528f4f98592 Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Tue, 24 Nov 2015 09:56:17 -0800 Subject: [PATCH 1527/1824] [SPARK-11847][ML] Model export/import for spark.ml: LDA Add read/write support to LDA, similar to ALS. save/load for ml.LocalLDAModel is done. For DistributedLDAModel, I'm not sure if we can invoke save on the mllib.DistributedLDAModel directly. I'll send update after some test. Author: Yuhao Yang Closes #9894 from hhbyyh/ldaMLsave. 
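For orientation, a minimal sketch of how the new read/write support is meant to be used from application code. The dataset, path, and parameter values are illustrative, and the default online optimizer is assumed so that `fit` returns a `LocalLDAModel`:

```scala
import org.apache.spark.ml.clustering.{LDA, LocalLDAModel}

// `dataset` is assumed to be a DataFrame with a vector-typed "features" column.
val lda = new LDA().setK(10).setMaxIter(20)
val model = lda.fit(dataset)

// Persist the fitted model and read it back via the new MLWritable/MLReadable hooks.
model.write.overwrite().save("/tmp/lda-model")
val restored = LocalLDAModel.load("/tmp/lda-model")
```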
--- .../org/apache/spark/ml/clustering/LDA.scala | 110 +++++++++++++++++- .../spark/mllib/clustering/LDAModel.scala | 4 +- .../apache/spark/ml/clustering/LDASuite.scala | 44 ++++++- 3 files changed, 150 insertions(+), 8 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala index 92e05815d6a3..830510b1698d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala @@ -17,12 +17,13 @@ package org.apache.spark.ml.clustering +import org.apache.hadoop.fs.Path import org.apache.spark.Logging import org.apache.spark.annotation.{Experimental, Since} -import org.apache.spark.ml.util.{SchemaUtils, Identifiable} import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param.shared.{HasCheckpointInterval, HasFeaturesCol, HasSeed, HasMaxIter} import org.apache.spark.ml.param._ +import org.apache.spark.ml.util._ import org.apache.spark.mllib.clustering.{DistributedLDAModel => OldDistributedLDAModel, EMLDAOptimizer => OldEMLDAOptimizer, LDA => OldLDA, LDAModel => OldLDAModel, LDAOptimizer => OldLDAOptimizer, LocalLDAModel => OldLocalLDAModel, @@ -322,7 +323,7 @@ sealed abstract class LDAModel private[ml] ( @Since("1.6.0") override val uid: String, @Since("1.6.0") val vocabSize: Int, @Since("1.6.0") @transient protected val sqlContext: SQLContext) - extends Model[LDAModel] with LDAParams with Logging { + extends Model[LDAModel] with LDAParams with Logging with MLWritable { // NOTE to developers: // This abstraction should contain all important functionality for basic LDA usage. @@ -486,6 +487,64 @@ class LocalLDAModel private[ml] ( @Since("1.6.0") override def isDistributed: Boolean = false + + @Since("1.6.0") + override def write: MLWriter = new LocalLDAModel.LocalLDAModelWriter(this) +} + + +@Since("1.6.0") +object LocalLDAModel extends MLReadable[LocalLDAModel] { + + private[LocalLDAModel] + class LocalLDAModelWriter(instance: LocalLDAModel) extends MLWriter { + + private case class Data( + vocabSize: Int, + topicsMatrix: Matrix, + docConcentration: Vector, + topicConcentration: Double, + gammaShape: Double) + + override protected def saveImpl(path: String): Unit = { + DefaultParamsWriter.saveMetadata(instance, path, sc) + val oldModel = instance.oldLocalModel + val data = Data(instance.vocabSize, oldModel.topicsMatrix, oldModel.docConcentration, + oldModel.topicConcentration, oldModel.gammaShape) + val dataPath = new Path(path, "data").toString + sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + } + } + + private class LocalLDAModelReader extends MLReader[LocalLDAModel] { + + private val className = classOf[LocalLDAModel].getName + + override def load(path: String): LocalLDAModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val dataPath = new Path(path, "data").toString + val data = sqlContext.read.parquet(dataPath) + .select("vocabSize", "topicsMatrix", "docConcentration", "topicConcentration", + "gammaShape") + .head() + val vocabSize = data.getAs[Int](0) + val topicsMatrix = data.getAs[Matrix](1) + val docConcentration = data.getAs[Vector](2) + val topicConcentration = data.getAs[Double](3) + val gammaShape = data.getAs[Double](4) + val oldModel = new OldLocalLDAModel(topicsMatrix, docConcentration, topicConcentration, + gammaShape) + val model = new LocalLDAModel(metadata.uid, vocabSize, oldModel, sqlContext) + 
DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } + + @Since("1.6.0") + override def read: MLReader[LocalLDAModel] = new LocalLDAModelReader + + @Since("1.6.0") + override def load(path: String): LocalLDAModel = super.load(path) } @@ -562,6 +621,45 @@ class DistributedLDAModel private[ml] ( */ @Since("1.6.0") lazy val logPrior: Double = oldDistributedModel.logPrior + + @Since("1.6.0") + override def write: MLWriter = new DistributedLDAModel.DistributedWriter(this) +} + + +@Since("1.6.0") +object DistributedLDAModel extends MLReadable[DistributedLDAModel] { + + private[DistributedLDAModel] + class DistributedWriter(instance: DistributedLDAModel) extends MLWriter { + + override protected def saveImpl(path: String): Unit = { + DefaultParamsWriter.saveMetadata(instance, path, sc) + val modelPath = new Path(path, "oldModel").toString + instance.oldDistributedModel.save(sc, modelPath) + } + } + + private class DistributedLDAModelReader extends MLReader[DistributedLDAModel] { + + private val className = classOf[DistributedLDAModel].getName + + override def load(path: String): DistributedLDAModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val modelPath = new Path(path, "oldModel").toString + val oldModel = OldDistributedLDAModel.load(sc, modelPath) + val model = new DistributedLDAModel( + metadata.uid, oldModel.vocabSize, oldModel, sqlContext, None) + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } + + @Since("1.6.0") + override def read: MLReader[DistributedLDAModel] = new DistributedLDAModelReader + + @Since("1.6.0") + override def load(path: String): DistributedLDAModel = super.load(path) } @@ -593,7 +691,8 @@ class DistributedLDAModel private[ml] ( @Since("1.6.0") @Experimental class LDA @Since("1.6.0") ( - @Since("1.6.0") override val uid: String) extends Estimator[LDAModel] with LDAParams { + @Since("1.6.0") override val uid: String) + extends Estimator[LDAModel] with LDAParams with DefaultParamsWritable { @Since("1.6.0") def this() = this(Identifiable.randomUID("lda")) @@ -695,7 +794,7 @@ class LDA @Since("1.6.0") ( } -private[clustering] object LDA { +private[clustering] object LDA extends DefaultParamsReadable[LDA] { /** Get dataset for spark.mllib LDA */ def getOldDataset(dataset: DataFrame, featuresCol: String): RDD[(Long, Vector)] = { @@ -706,4 +805,7 @@ private[clustering] object LDA { (docId, features) } } + + @Since("1.6.0") + override def load(path: String): LDA = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala index cd520f09bd46..7384d065a2ea 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala @@ -187,11 +187,11 @@ abstract class LDAModel private[clustering] extends Saveable { * @param topics Inferred topics (vocabSize x k matrix). 
*/ @Since("1.3.0") -class LocalLDAModel private[clustering] ( +class LocalLDAModel private[spark] ( @Since("1.3.0") val topics: Matrix, @Since("1.5.0") override val docConcentration: Vector, @Since("1.5.0") override val topicConcentration: Double, - override protected[clustering] val gammaShape: Double = 100) + override protected[spark] val gammaShape: Double = 100) extends LDAModel with Serializable { @Since("1.3.0") diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala index b634d31cc34f..97dbfd9a4314 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala @@ -18,9 +18,10 @@ package org.apache.spark.ml.clustering import org.apache.spark.SparkFunSuite -import org.apache.spark.ml.util.MLTestingUtils +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.{DataFrame, Row, SQLContext} @@ -39,10 +40,24 @@ object LDASuite { }.map(v => new TestRow(v)) sql.createDataFrame(rdd) } + + /** + * Mapping from all Params to valid settings which differ from the defaults. + * This is useful for tests which need to exercise all Params, such as save/load. + * This excludes input columns to simplify some tests. + */ + val allParamSettings: Map[String, Any] = Map( + "k" -> 3, + "maxIter" -> 2, + "checkpointInterval" -> 30, + "learningOffset" -> 1023.0, + "learningDecay" -> 0.52, + "subsamplingRate" -> 0.051 + ) } -class LDASuite extends SparkFunSuite with MLlibTestSparkContext { +class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { val k: Int = 5 val vocabSize: Int = 30 @@ -218,4 +233,29 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext { val lp = model.logPrior assert(lp <= 0.0 && lp != Double.NegativeInfinity) } + + test("read/write LocalLDAModel") { + def checkModelData(model: LDAModel, model2: LDAModel): Unit = { + assert(model.vocabSize === model2.vocabSize) + assert(Vectors.dense(model.topicsMatrix.toArray) ~== + Vectors.dense(model2.topicsMatrix.toArray) absTol 1e-6) + assert(Vectors.dense(model.getDocConcentration) ~== + Vectors.dense(model2.getDocConcentration) absTol 1e-6) + } + val lda = new LDA() + testEstimatorAndModelReadWrite(lda, dataset, LDASuite.allParamSettings, checkModelData) + } + + test("read/write DistributedLDAModel") { + def checkModelData(model: LDAModel, model2: LDAModel): Unit = { + assert(model.vocabSize === model2.vocabSize) + assert(Vectors.dense(model.topicsMatrix.toArray) ~== + Vectors.dense(model2.topicsMatrix.toArray) absTol 1e-6) + assert(Vectors.dense(model.getDocConcentration) ~== + Vectors.dense(model2.getDocConcentration) absTol 1e-6) + } + val lda = new LDA() + testEstimatorAndModelReadWrite(lda, dataset, + LDASuite.allParamSettings ++ Map("optimizer" -> "em"), checkModelData) + } } From 19530da6903fa59b051eec69b9c17e231c68454b Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 24 Nov 2015 11:09:01 -0800 Subject: [PATCH 1528/1824] [SPARK-11926][SQL] unify GetStructField and GetInternalRowField Author: Wenchen Fan Closes #9909 from cloud-fan/get-struct. 
--- .../spark/sql/catalyst/ScalaReflection.scala | 2 +- .../sql/catalyst/analysis/unresolved.scala | 8 +++---- .../catalyst/encoders/ExpressionEncoder.scala | 2 +- .../sql/catalyst/encoders/RowEncoder.scala | 2 +- .../sql/catalyst/expressions/Expression.scala | 2 +- .../expressions/complexTypeExtractors.scala | 18 ++++++++-------- .../expressions/namedExpressions.scala | 4 ++-- .../sql/catalyst/expressions/objects.scala | 21 ------------------- .../expressions/ComplexTypeSuite.scala | 4 ++-- 9 files changed, 21 insertions(+), 42 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 476becec4dd5..d133ad3f0d89 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -130,7 +130,7 @@ object ScalaReflection extends ScalaReflection { /** Returns the current path with a field at ordinal extracted. */ def addToPathOrdinal(ordinal: Int, dataType: DataType): Expression = path - .map(p => GetInternalRowField(p, ordinal, dataType)) + .map(p => GetStructField(p, ordinal)) .getOrElse(BoundReference(ordinal, dataType, false)) /** Returns the current path or `BoundReference`. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index 6485bdfb3023..1b2a8dc4c7f1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -201,12 +201,12 @@ case class UnresolvedStar(target: Option[Seq[String]]) extends Star with Unevalu if (attribute.isDefined) { // This target resolved to an attribute in child. It must be a struct. Expand it. attribute.get.dataType match { - case s: StructType => { - s.fields.map( f => { - val extract = GetStructField(attribute.get, f, s.getFieldIndex(f.name).get) + case s: StructType => s.zipWithIndex.map { + case (f, i) => + val extract = GetStructField(attribute.get, i) Alias(extract, target.get + "." + f.name)() - }) } + case _ => { throw new AnalysisException("Can only star expand struct data types. 
Attribute: `" + target.get + "`") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index 7bc9aed0b204..0c10a56c555f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -111,7 +111,7 @@ object ExpressionEncoder { case UnresolvedAttribute(nameParts) => assert(nameParts.length == 1) UnresolvedExtractValue(input, Literal(nameParts.head)) - case BoundReference(ordinal, dt, _) => GetInternalRowField(input, ordinal, dt) + case BoundReference(ordinal, dt, _) => GetStructField(input, ordinal) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index fa553e7c5324..67518f52d4a5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -220,7 +220,7 @@ object RowEncoder { If( Invoke(input, "isNullAt", BooleanType, Literal(i) :: Nil), Literal.create(null, externalDataTypeFor(f.dataType)), - constructorFor(GetInternalRowField(input, i, f.dataType))) + constructorFor(GetStructField(input, i))) } CreateExternalRow(convertedFields) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 540ed3500616..169435a10ea2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -206,7 +206,7 @@ abstract class Expression extends TreeNode[Expression] { */ def prettyString: String = { transform { - case a: AttributeReference => PrettyAttribute(a.name) + case a: AttributeReference => PrettyAttribute(a.name, a.dataType) case u: UnresolvedAttribute => PrettyAttribute(u.name) }.toString } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index f871b737fff3..10ce10aaf6da 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -51,7 +51,7 @@ object ExtractValue { case (StructType(fields), NonNullLiteral(v, StringType)) => val fieldName = v.toString val ordinal = findField(fields, fieldName, resolver) - GetStructField(child, fields(ordinal).copy(name = fieldName), ordinal) + GetStructField(child, ordinal, Some(fieldName)) case (ArrayType(StructType(fields), containsNull), NonNullLiteral(v, StringType)) => val fieldName = v.toString @@ -97,18 +97,18 @@ object ExtractValue { * Returns the value of fields in the Struct `child`. * * No need to do type checking since it is handled by [[ExtractValue]]. - * TODO: Unify with [[GetInternalRowField]], remove the need to specify a [[StructField]]. + * + * Note that we can pass in the field name directly to keep case preserving in `toString`. + * For example, when get field `yEAr` from ``, we should pass in `yEAr`. 
*/ -case class GetStructField(child: Expression, field: StructField, ordinal: Int) +case class GetStructField(child: Expression, ordinal: Int, name: Option[String] = None) extends UnaryExpression { - override def dataType: DataType = child.dataType match { - case s: StructType => s(ordinal).dataType - // This is a hack to avoid breaking existing code until we remove the need for the struct field - case _ => field.dataType - } + private lazy val field = child.dataType.asInstanceOf[StructType](ordinal) + + override def dataType: DataType = field.dataType override def nullable: Boolean = child.nullable || field.nullable - override def toString: String = s"$child.${field.name}" + override def toString: String = s"$child.${name.getOrElse(field.name)}" protected override def nullSafeEval(input: Any): Any = input.asInstanceOf[InternalRow].get(ordinal, field.dataType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index 00b7970bd16c..26b6aca79971 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -273,7 +273,8 @@ case class AttributeReference( * A place holder used when printing expressions without debugging information such as the * expression id or the unresolved indicator. */ -case class PrettyAttribute(name: String) extends Attribute with Unevaluable { +case class PrettyAttribute(name: String, dataType: DataType = NullType) + extends Attribute with Unevaluable { override def toString: String = name @@ -286,7 +287,6 @@ case class PrettyAttribute(name: String) extends Attribute with Unevaluable { override def qualifiers: Seq[String] = throw new UnsupportedOperationException override def exprId: ExprId = throw new UnsupportedOperationException override def nullable: Boolean = throw new UnsupportedOperationException - override def dataType: DataType = NullType } object VirtualColumn { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala index 4a1f419f0ad8..62d09f0f5510 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala @@ -517,27 +517,6 @@ case class CreateExternalRow(children: Seq[Expression]) extends Expression { } } -case class GetInternalRowField(child: Expression, ordinal: Int, dataType: DataType) - extends UnaryExpression { - - override def nullable: Boolean = true - - override def eval(input: InternalRow): Any = - throw new UnsupportedOperationException("Only code-generated evaluation is supported") - - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - nullSafeCodeGen(ctx, ev, eval => { - s""" - if ($eval.isNullAt($ordinal)) { - ${ev.isNull} = true; - } else { - ${ev.value} = ${ctx.getValue(eval, dataType, ordinal.toString)}; - } - """ - }) - } -} - /** * Serializes an input object using a generic serializer (Kryo or Java). * @param kryo if true, use Kryo. Otherwise, use Java. 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala index e60990aeb423..62fd47234b33 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala @@ -79,8 +79,8 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { def getStructField(expr: Expression, fieldName: String): GetStructField = { expr.dataType match { case StructType(fields) => - val field = fields.find(_.name == fieldName).get - GetStructField(expr, field, fields.indexOf(field)) + val index = fields.indexWhere(_.name == fieldName) + GetStructField(expr, index) } } From 81012546ee5a80d2576740af0dad067b0f5962c5 Mon Sep 17 00:00:00 2001 From: tedyu Date: Tue, 24 Nov 2015 12:22:33 -0800 Subject: [PATCH 1529/1824] [SPARK-11872] Prevent the call to SparkContext#stop() in the listener bus's thread This is continuation of SPARK-11761 Andrew suggested adding this protection. See tail of https://github.com/apache/spark/pull/9741 Author: tedyu Closes #9852 from tedyu/master. --- .../scala/org/apache/spark/SparkContext.scala | 4 +++ .../spark/scheduler/SparkListenerSuite.scala | 31 +++++++++++++++++++ 2 files changed, 35 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index b153a7b08e59..e19ba113702c 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -1694,6 +1694,10 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli // Shut down the SparkContext. def stop() { + if (AsynchronousListenerBus.withinListenerThread.value) { + throw new SparkException("Cannot stop SparkContext within listener thread of" + + " AsynchronousListenerBus") + } // Use the stopping variable to ensure no contention for the stop scenario. // Still track the stopped variable for use elsewhere in the code. 
if (!stopped.compareAndSet(false, true)) { diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala index 84e545851f49..f20d5be7c0ee 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala @@ -24,6 +24,7 @@ import scala.collection.JavaConverters._ import org.scalatest.Matchers +import org.apache.spark.SparkException import org.apache.spark.executor.TaskMetrics import org.apache.spark.util.ResetSystemProperties import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkFunSuite} @@ -36,6 +37,21 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match val jobCompletionTime = 1421191296660L + test("don't call sc.stop in listener") { + sc = new SparkContext("local", "SparkListenerSuite") + val listener = new SparkContextStoppingListener(sc) + val bus = new LiveListenerBus + bus.addListener(listener) + + // Starting listener bus should flush all buffered events + bus.start(sc) + bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) + bus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) + + bus.stop() + assert(listener.sparkExSeen) + } + test("basic creation and shutdown of LiveListenerBus") { val counter = new BasicJobCounter val bus = new LiveListenerBus @@ -443,6 +459,21 @@ private class BasicJobCounter extends SparkListener { override def onJobEnd(job: SparkListenerJobEnd): Unit = count += 1 } +/** + * A simple listener that tries to stop SparkContext. + */ +private class SparkContextStoppingListener(val sc: SparkContext) extends SparkListener { + @volatile var sparkExSeen = false + override def onJobEnd(job: SparkListenerJobEnd): Unit = { + try { + sc.stop() + } catch { + case se: SparkException => + sparkExSeen = true + } + } +} + private class ListenerThatAcceptsSparkConf(conf: SparkConf) extends SparkListener { var count = 0 override def onJobEnd(job: SparkListenerJobEnd): Unit = count += 1 From f3152722791b163fa66597b3684009058195ba33 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 24 Nov 2015 12:54:37 -0800 Subject: [PATCH 1530/1824] [SPARK-11946][SQL] Audit pivot API for 1.6. Currently pivot's signature looks like ```scala scala.annotation.varargs def pivot(pivotColumn: Column, values: Column*): GroupedData scala.annotation.varargs def pivot(pivotColumn: String, values: Any*): GroupedData ``` I think we can remove the one that takes "Column" types, since callers should always be passing in literals. It'd also be more clear if the values are not varargs, but rather Seq or java.util.List. I also made similar changes for Python. Author: Reynold Xin Closes #9929 from rxin/SPARK-11946. 
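Concretely, a caller's code after this change reads as sketched below, assuming a DataFrame `df` with `year`, `course`, and `earnings` columns like the `courseSales` test fixture in the diff:

```scala
import org.apache.spark.sql.functions.sum

// Explicit values: avoids the extra job that computes the distinct courses.
df.groupBy("year").pivot("course", Seq("dotNET", "Java")).agg(sum("earnings"))

// Without values: Spark first collects the distinct values of "course"
// (capped by SQLConf.DATAFRAME_PIVOT_MAX_VALUES).
df.groupBy("year").pivot("course").sum("earnings")
```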
--- .../apache/spark/scheduler/DAGScheduler.scala | 1 - python/pyspark/sql/group.py | 12 +- .../sql/catalyst/expressions/literals.scala | 1 + .../org/apache/spark/sql/GroupedData.scala | 154 ++++++++++-------- .../apache/spark/sql/JavaDataFrameSuite.java | 16 ++ .../spark/sql/DataFramePivotSuite.scala | 21 +-- .../apache/spark/sql/test/SQLTestData.scala | 1 + 7 files changed, 125 insertions(+), 81 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index ae725b467d8c..77a184dfe4be 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -1574,7 +1574,6 @@ class DAGScheduler( } def stop() { - logInfo("Stopping DAGScheduler") messageScheduler.shutdownNow() eventProcessLoop.stop() taskScheduler.stop() diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py index 227f40bc3cf5..d8ed7eb2dda6 100644 --- a/python/pyspark/sql/group.py +++ b/python/pyspark/sql/group.py @@ -168,20 +168,24 @@ def sum(self, *cols): """ @since(1.6) - def pivot(self, pivot_col, *values): + def pivot(self, pivot_col, values=None): """Pivots a column of the current DataFrame and preform the specified aggregation. :param pivot_col: Column to pivot :param values: Optional list of values of pivotColumn that will be translated to columns in the output data frame. If values are not provided the method with do an immediate call to .distinct() on the pivot column. - >>> df4.groupBy("year").pivot("course", "dotNET", "Java").sum("earnings").collect() + + >>> df4.groupBy("year").pivot("course", ["dotNET", "Java"]).sum("earnings").collect() [Row(year=2012, dotNET=15000, Java=20000), Row(year=2013, dotNET=48000, Java=30000)] + >>> df4.groupBy("year").pivot("course").sum("earnings").collect() [Row(year=2012, Java=20000, dotNET=15000), Row(year=2013, Java=30000, dotNET=48000)] """ - jgd = self._jdf.pivot(_to_java_column(pivot_col), - _to_seq(self.sql_ctx._sc, values, _create_column_from_literal)) + if values is None: + jgd = self._jdf.pivot(pivot_col) + else: + jgd = self._jdf.pivot(pivot_col, values) return GroupedData(jgd, self.sql_ctx) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index e34fd49be838..68ec688c99f9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -44,6 +44,7 @@ object Literal { case a: Array[Byte] => Literal(a, BinaryType) case i: CalendarInterval => Literal(i, CalendarIntervalType) case null => Literal(null, NullType) + case v: Literal => v case _ => throw new RuntimeException("Unsupported literal type " + v.getClass + " " + v) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala index 63dd7fbcbe9e..ee7150cbbfbc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.analysis.{UnresolvedFunction, UnresolvedAli import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical.{Pivot, Rollup, Cube, 
Aggregate} -import org.apache.spark.sql.types.{StringType, NumericType} +import org.apache.spark.sql.types.NumericType /** @@ -282,74 +282,96 @@ class GroupedData protected[sql]( } /** - * (Scala-specific) Pivots a column of the current [[DataFrame]] and preform the specified - * aggregation. - * {{{ - * // Compute the sum of earnings for each year by course with each course as a separate column - * df.groupBy($"year").pivot($"course", "dotNET", "Java").agg(sum($"earnings")) - * // Or without specifying column values - * df.groupBy($"year").pivot($"course").agg(sum($"earnings")) - * }}} - * @param pivotColumn Column to pivot - * @param values Optional list of values of pivotColumn that will be translated to columns in the - * output data frame. If values are not provided the method with do an immediate - * call to .distinct() on the pivot column. - * @since 1.6.0 - */ - @scala.annotation.varargs - def pivot(pivotColumn: Column, values: Column*): GroupedData = groupType match { - case _: GroupedData.PivotType => - throw new UnsupportedOperationException("repeated pivots are not supported") - case GroupedData.GroupByType => - val pivotValues = if (values.nonEmpty) { - values.map { - case Column(literal: Literal) => literal - case other => - throw new UnsupportedOperationException( - s"The values of a pivot must be literals, found $other") - } - } else { - // This is to prevent unintended OOM errors when the number of distinct values is large - val maxValues = df.sqlContext.conf.getConf(SQLConf.DATAFRAME_PIVOT_MAX_VALUES) - // Get the distinct values of the column and sort them so its consistent - val values = df.select(pivotColumn) - .distinct() - .sort(pivotColumn) - .map(_.get(0)) - .take(maxValues + 1) - .map(Literal(_)).toSeq - if (values.length > maxValues) { - throw new RuntimeException( - s"The pivot column $pivotColumn has more than $maxValues distinct values, " + - "this could indicate an error. " + - "If this was intended, set \"" + SQLConf.DATAFRAME_PIVOT_MAX_VALUES.key + "\" " + - s"to at least the number of distinct values of the pivot column.") - } - values - } - new GroupedData(df, groupingExprs, GroupedData.PivotType(pivotColumn.expr, pivotValues)) - case _ => - throw new UnsupportedOperationException("pivot is only supported after a groupBy") + * Pivots a column of the current [[DataFrame]] and preform the specified aggregation. + * There are two versions of pivot function: one that requires the caller to specify the list + * of distinct values to pivot on, and one that does not. The latter is more concise but less + * efficient, because Spark needs to first compute the list of distinct values internally. + * + * {{{ + * // Compute the sum of earnings for each year by course with each course as a separate column + * df.groupBy("year").pivot("course", Seq("dotNET", "Java")).sum("earnings") + * + * // Or without specifying column values (less efficient) + * df.groupBy("year").pivot("course").sum("earnings") + * }}} + * + * @param pivotColumn Name of the column to pivot. 
+ * @since 1.6.0 + */ + def pivot(pivotColumn: String): GroupedData = { + // This is to prevent unintended OOM errors when the number of distinct values is large + val maxValues = df.sqlContext.conf.getConf(SQLConf.DATAFRAME_PIVOT_MAX_VALUES) + // Get the distinct values of the column and sort them so its consistent + val values = df.select(pivotColumn) + .distinct() + .sort(pivotColumn) + .map(_.get(0)) + .take(maxValues + 1) + .toSeq + + if (values.length > maxValues) { + throw new AnalysisException( + s"The pivot column $pivotColumn has more than $maxValues distinct values, " + + "this could indicate an error. " + + s"If this was intended, set ${SQLConf.DATAFRAME_PIVOT_MAX_VALUES.key} " + + "to at least the number of distinct values of the pivot column.") + } + + pivot(pivotColumn, values) } /** - * Pivots a column of the current [[DataFrame]] and preform the specified aggregation. - * {{{ - * // Compute the sum of earnings for each year by course with each course as a separate column - * df.groupBy("year").pivot("course", "dotNET", "Java").sum("earnings") - * // Or without specifying column values - * df.groupBy("year").pivot("course").sum("earnings") - * }}} - * @param pivotColumn Column to pivot - * @param values Optional list of values of pivotColumn that will be translated to columns in the - * output data frame. If values are not provided the method with do an immediate - * call to .distinct() on the pivot column. - * @since 1.6.0 - */ - @scala.annotation.varargs - def pivot(pivotColumn: String, values: Any*): GroupedData = { - val resolvedPivotColumn = Column(df.resolve(pivotColumn)) - pivot(resolvedPivotColumn, values.map(functions.lit): _*) + * Pivots a column of the current [[DataFrame]] and preform the specified aggregation. + * There are two versions of pivot function: one that requires the caller to specify the list + * of distinct values to pivot on, and one that does not. The latter is more concise but less + * efficient, because Spark needs to first compute the list of distinct values internally. + * + * {{{ + * // Compute the sum of earnings for each year by course with each course as a separate column + * df.groupBy("year").pivot("course", Seq("dotNET", "Java")).sum("earnings") + * + * // Or without specifying column values (less efficient) + * df.groupBy("year").pivot("course").sum("earnings") + * }}} + * + * @param pivotColumn Name of the column to pivot. + * @param values List of values that will be translated to columns in the output DataFrame. + * @since 1.6.0 + */ + def pivot(pivotColumn: String, values: Seq[Any]): GroupedData = { + groupType match { + case GroupedData.GroupByType => + new GroupedData( + df, + groupingExprs, + GroupedData.PivotType(df.resolve(pivotColumn), values.map(Literal.apply))) + case _: GroupedData.PivotType => + throw new UnsupportedOperationException("repeated pivots are not supported") + case _ => + throw new UnsupportedOperationException("pivot is only supported after a groupBy") + } + } + + /** + * Pivots a column of the current [[DataFrame]] and preform the specified aggregation. + * There are two versions of pivot function: one that requires the caller to specify the list + * of distinct values to pivot on, and one that does not. The latter is more concise but less + * efficient, because Spark needs to first compute the list of distinct values internally. 
+ * + * {{{ + * // Compute the sum of earnings for each year by course with each course as a separate column + * df.groupBy("year").pivot("course", Arrays.asList("dotNET", "Java")).sum("earnings"); + * + * // Or without specifying column values (less efficient) + * df.groupBy("year").pivot("course").sum("earnings"); + * }}} + * + * @param pivotColumn Name of the column to pivot. + * @param values List of values that will be translated to columns in the output DataFrame. + * @since 1.6.0 + */ + def pivot(pivotColumn: String, values: java.util.List[Any]): GroupedData = { + pivot(pivotColumn, values.asScala) } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index 567bdddece80..a12fed3c0c6a 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -282,4 +282,20 @@ public void testSampleBy() { Assert.assertEquals(1, actual[1].getLong(0)); Assert.assertTrue(2 <= actual[1].getLong(1) && actual[1].getLong(1) <= 13); } + + @Test + public void pivot() { + DataFrame df = context.table("courseSales"); + Row[] actual = df.groupBy("year") + .pivot("course", Arrays.asList("dotNET", "Java")) + .agg(sum("earnings")).orderBy("year").collect(); + + Assert.assertEquals(2012, actual[0].getInt(0)); + Assert.assertEquals(15000.0, actual[0].getDouble(1), 0.01); + Assert.assertEquals(20000.0, actual[0].getDouble(2), 0.01); + + Assert.assertEquals(2013, actual[1].getInt(0)); + Assert.assertEquals(48000.0, actual[1].getDouble(1), 0.01); + Assert.assertEquals(30000.0, actual[1].getDouble(2), 0.01); + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala index 0c23d142670c..fc53aba68ebb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala @@ -25,7 +25,7 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext{ test("pivot courses with literals") { checkAnswer( - courseSales.groupBy($"year").pivot($"course", lit("dotNET"), lit("Java")) + courseSales.groupBy("year").pivot("course", Seq("dotNET", "Java")) .agg(sum($"earnings")), Row(2012, 15000.0, 20000.0) :: Row(2013, 48000.0, 30000.0) :: Nil ) @@ -33,14 +33,15 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext{ test("pivot year with literals") { checkAnswer( - courseSales.groupBy($"course").pivot($"year", lit(2012), lit(2013)).agg(sum($"earnings")), + courseSales.groupBy("course").pivot("year", Seq(2012, 2013)).agg(sum($"earnings")), Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil ) } test("pivot courses with literals and multiple aggregations") { checkAnswer( - courseSales.groupBy($"year").pivot($"course", lit("dotNET"), lit("Java")) + courseSales.groupBy($"year") + .pivot("course", Seq("dotNET", "Java")) .agg(sum($"earnings"), avg($"earnings")), Row(2012, 15000.0, 7500.0, 20000.0, 20000.0) :: Row(2013, 48000.0, 48000.0, 30000.0, 30000.0) :: Nil @@ -49,14 +50,14 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext{ test("pivot year with string values (cast)") { checkAnswer( - courseSales.groupBy("course").pivot("year", "2012", "2013").sum("earnings"), + courseSales.groupBy("course").pivot("year", Seq("2012", "2013")).sum("earnings"), Row("dotNET", 15000.0, 48000.0) 
:: Row("Java", 20000.0, 30000.0) :: Nil ) } test("pivot year with int values") { checkAnswer( - courseSales.groupBy("course").pivot("year", 2012, 2013).sum("earnings"), + courseSales.groupBy("course").pivot("year", Seq(2012, 2013)).sum("earnings"), Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil ) } @@ -64,22 +65,22 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext{ test("pivot courses with no values") { // Note Java comes before dotNet in sorted order checkAnswer( - courseSales.groupBy($"year").pivot($"course").agg(sum($"earnings")), + courseSales.groupBy("year").pivot("course").agg(sum($"earnings")), Row(2012, 20000.0, 15000.0) :: Row(2013, 30000.0, 48000.0) :: Nil ) } test("pivot year with no values") { checkAnswer( - courseSales.groupBy($"course").pivot($"year").agg(sum($"earnings")), + courseSales.groupBy("course").pivot("year").agg(sum($"earnings")), Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil ) } - test("pivot max values inforced") { + test("pivot max values enforced") { sqlContext.conf.setConf(SQLConf.DATAFRAME_PIVOT_MAX_VALUES, 1) - intercept[RuntimeException]( - courseSales.groupBy($"year").pivot($"course") + intercept[AnalysisException]( + courseSales.groupBy("year").pivot("course") ) sqlContext.conf.setConf(SQLConf.DATAFRAME_PIVOT_MAX_VALUES, SQLConf.DATAFRAME_PIVOT_MAX_VALUES.defaultValue.get) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala index abad0d7eaaed..83c63e04f344 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala @@ -281,6 +281,7 @@ private[sql] trait SQLTestData { self => person salary complexData + courseSales } } From e6dd237463d2de8c506f0735dfdb3f43e8122513 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Tue, 24 Nov 2015 15:08:02 -0600 Subject: [PATCH 1531/1824] [SPARK-11929][CORE] Make the repl log4j configuration override the root logger. In the default Spark distribution, there are currently two separate log4j config files, with different default values for the root logger, so that when running the shell you have a different default log level. This makes the shell more usable, since the logs don't overwhelm the output. But if you install a custom log4j.properties, you lose that, because then it's going to be used no matter whether you're running a regular app or the shell. With this change, the overriding of the log level is done differently; the log level repl's main class (org.apache.spark.repl.Main) is used to define the root logger's level when running the shell, defaulting to WARN if it's not set explicitly. On a somewhat related change, the shell output about the "sc" variable was changed a bit to contain a little more useful information about the application, since when the root logger's log level is WARN, that information is never shown to the user. Author: Marcelo Vanzin Closes #9816 from vanzin/shell-logging. 
--- conf/log4j.properties.template | 5 +++ .../spark/log4j-defaults-repl.properties | 33 -------------- .../apache/spark/log4j-defaults.properties | 5 +++ .../main/scala/org/apache/spark/Logging.scala | 45 ++++++++++--------- .../apache/spark/repl/SparkILoopInit.scala | 21 ++++----- .../org/apache/spark/repl/SparkILoop.scala | 25 ++++++----- 6 files changed, 57 insertions(+), 77 deletions(-) delete mode 100644 core/src/main/resources/org/apache/spark/log4j-defaults-repl.properties diff --git a/conf/log4j.properties.template b/conf/log4j.properties.template index f3046be54d7c..9809b0c82848 100644 --- a/conf/log4j.properties.template +++ b/conf/log4j.properties.template @@ -22,6 +22,11 @@ log4j.appender.console.target=System.err log4j.appender.console.layout=org.apache.log4j.PatternLayout log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n +# Set the default spark-shell log level to WARN. When running the spark-shell, the +# log level for this class is used to overwrite the root logger's log level, so that +# the user can have different defaults for the shell and regular Spark apps. +log4j.logger.org.apache.spark.repl.Main=WARN + # Settings to quiet third party logs that are too verbose log4j.logger.org.spark-project.jetty=WARN log4j.logger.org.spark-project.jetty.util.component.AbstractLifeCycle=ERROR diff --git a/core/src/main/resources/org/apache/spark/log4j-defaults-repl.properties b/core/src/main/resources/org/apache/spark/log4j-defaults-repl.properties deleted file mode 100644 index c85abc35b93b..000000000000 --- a/core/src/main/resources/org/apache/spark/log4j-defaults-repl.properties +++ /dev/null @@ -1,33 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# - -# Set everything to be logged to the console -log4j.rootCategory=WARN, console -log4j.appender.console=org.apache.log4j.ConsoleAppender -log4j.appender.console.target=System.err -log4j.appender.console.layout=org.apache.log4j.PatternLayout -log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n - -# Settings to quiet third party logs that are too verbose -log4j.logger.org.spark-project.jetty=WARN -log4j.logger.org.spark-project.jetty.util.component.AbstractLifeCycle=ERROR -log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=INFO -log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=INFO - -# SPARK-9183: Settings to avoid annoying messages when looking up nonexistent UDFs in SparkSQL with Hive support -log4j.logger.org.apache.hadoop.hive.metastore.RetryingHMSHandler=FATAL -log4j.logger.org.apache.hadoop.hive.ql.exec.FunctionRegistry=ERROR diff --git a/core/src/main/resources/org/apache/spark/log4j-defaults.properties b/core/src/main/resources/org/apache/spark/log4j-defaults.properties index d44cc85dcbd8..0750488e4adf 100644 --- a/core/src/main/resources/org/apache/spark/log4j-defaults.properties +++ b/core/src/main/resources/org/apache/spark/log4j-defaults.properties @@ -22,6 +22,11 @@ log4j.appender.console.target=System.err log4j.appender.console.layout=org.apache.log4j.PatternLayout log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n +# Set the default spark-shell log level to WARN. When running the spark-shell, the +# log level for this class is used to overwrite the root logger's log level, so that +# the user can have different defaults for the shell and regular Spark apps. +log4j.logger.org.apache.spark.repl.Main=WARN + # Settings to quiet third party logs that are too verbose log4j.logger.org.spark-project.jetty=WARN log4j.logger.org.spark-project.jetty.util.component.AbstractLifeCycle=ERROR diff --git a/core/src/main/scala/org/apache/spark/Logging.scala b/core/src/main/scala/org/apache/spark/Logging.scala index 69f6e06ee005..e35e158c7e8a 100644 --- a/core/src/main/scala/org/apache/spark/Logging.scala +++ b/core/src/main/scala/org/apache/spark/Logging.scala @@ -17,7 +17,7 @@ package org.apache.spark -import org.apache.log4j.{LogManager, PropertyConfigurator} +import org.apache.log4j.{Level, LogManager, PropertyConfigurator} import org.slf4j.{Logger, LoggerFactory} import org.slf4j.impl.StaticLoggerBinder @@ -119,30 +119,31 @@ trait Logging { val usingLog4j12 = "org.slf4j.impl.Log4jLoggerFactory".equals(binderClass) if (usingLog4j12) { val log4j12Initialized = LogManager.getRootLogger.getAllAppenders.hasMoreElements + // scalastyle:off println if (!log4j12Initialized) { - // scalastyle:off println - if (Utils.isInInterpreter) { - val replDefaultLogProps = "org/apache/spark/log4j-defaults-repl.properties" - Option(Utils.getSparkClassLoader.getResource(replDefaultLogProps)) match { - case Some(url) => - PropertyConfigurator.configure(url) - System.err.println(s"Using Spark's repl log4j profile: $replDefaultLogProps") - System.err.println("To adjust logging level use sc.setLogLevel(\"INFO\")") - case None => - System.err.println(s"Spark was unable to load $replDefaultLogProps") - } - } else { - val defaultLogProps = "org/apache/spark/log4j-defaults.properties" - Option(Utils.getSparkClassLoader.getResource(defaultLogProps)) match { - case Some(url) => - PropertyConfigurator.configure(url) - System.err.println(s"Using Spark's default log4j profile: $defaultLogProps") - case None => - 
System.err.println(s"Spark was unable to load $defaultLogProps") - } + val defaultLogProps = "org/apache/spark/log4j-defaults.properties" + Option(Utils.getSparkClassLoader.getResource(defaultLogProps)) match { + case Some(url) => + PropertyConfigurator.configure(url) + System.err.println(s"Using Spark's default log4j profile: $defaultLogProps") + case None => + System.err.println(s"Spark was unable to load $defaultLogProps") } - // scalastyle:on println } + + if (Utils.isInInterpreter) { + // Use the repl's main class to define the default log level when running the shell, + // overriding the root logger's config if they're different. + val rootLogger = LogManager.getRootLogger() + val replLogger = LogManager.getLogger("org.apache.spark.repl.Main") + val replLevel = Option(replLogger.getLevel()).getOrElse(Level.WARN) + if (replLevel != rootLogger.getEffectiveLevel()) { + System.err.printf("Setting default log level to \"%s\".\n", replLevel) + System.err.println("To adjust logging level use sc.setLogLevel(newLevel).") + rootLogger.setLevel(replLevel) + } + } + // scalastyle:on println } Logging.initialized = true diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala index bd3314d94eed..99e1e1df33fd 100644 --- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala +++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala @@ -123,18 +123,19 @@ private[repl] trait SparkILoopInit { def initializeSpark() { intp.beQuietDuring { command(""" - @transient val sc = { - val _sc = org.apache.spark.repl.Main.interp.createSparkContext() - println("Spark context available as sc.") - _sc - } + @transient val sc = { + val _sc = org.apache.spark.repl.Main.interp.createSparkContext() + println("Spark context available as sc " + + s"(master = ${_sc.master}, app id = ${_sc.applicationId}).") + _sc + } """) command(""" - @transient val sqlContext = { - val _sqlContext = org.apache.spark.repl.Main.interp.createSQLContext() - println("SQL context available as sqlContext.") - _sqlContext - } + @transient val sqlContext = { + val _sqlContext = org.apache.spark.repl.Main.interp.createSQLContext() + println("SQL context available as sqlContext.") + _sqlContext + } """) command("import org.apache.spark.SparkContext._") command("import sqlContext.implicits._") diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala index 33d262558b1f..e91139fb29f6 100644 --- a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala +++ b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala @@ -37,18 +37,19 @@ class SparkILoop(in0: Option[BufferedReader], out: JPrintWriter) def initializeSpark() { intp.beQuietDuring { processLine(""" - @transient val sc = { - val _sc = org.apache.spark.repl.Main.createSparkContext() - println("Spark context available as sc.") - _sc - } + @transient val sc = { + val _sc = org.apache.spark.repl.Main.createSparkContext() + println("Spark context available as sc " + + s"(master = ${_sc.master}, app id = ${_sc.applicationId}).") + _sc + } """) processLine(""" - @transient val sqlContext = { - val _sqlContext = org.apache.spark.repl.Main.createSQLContext() - println("SQL context available as sqlContext.") - _sqlContext - } + @transient val sqlContext = { + val _sqlContext = org.apache.spark.repl.Main.createSQLContext() + 
println("SQL context available as sqlContext.") + _sqlContext + } """) processLine("import org.apache.spark.SparkContext._") processLine("import sqlContext.implicits._") @@ -85,7 +86,7 @@ class SparkILoop(in0: Option[BufferedReader], out: JPrintWriter) /** Available commands */ override def commands: List[LoopCommand] = sparkStandardCommands - /** + /** * We override `loadFiles` because we need to initialize Spark *before* the REPL * sees any files, so that the Spark context is visible in those files. This is a bit of a * hack, but there isn't another hook available to us at this point. @@ -98,7 +99,7 @@ class SparkILoop(in0: Option[BufferedReader], out: JPrintWriter) object SparkILoop { - /** + /** * Creates an interpreter loop with default settings and feeds * the given code to it as input. */ From 58d9b260556a89a3d0832d583acafba1df7c6751 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 24 Nov 2015 14:33:28 -0800 Subject: [PATCH 1532/1824] [SPARK-11805] free the array in UnsafeExternalSorter during spilling After calling spill() on SortedIterator, the array inside InMemorySorter is not needed, it should be freed during spilling, this could help to join multiple tables with limited memory. Author: Davies Liu Closes #9793 from davies/free_array. --- .../unsafe/sort/UnsafeExternalSorter.java | 10 +++--- .../unsafe/sort/UnsafeInMemorySorter.java | 31 ++++++++----------- 2 files changed, 19 insertions(+), 22 deletions(-) diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java index 9a7b2ad06cab..2e4031267473 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -468,6 +468,12 @@ public long spill() throws IOException { } allocatedPages.clear(); } + + // in-memory sorter will not be used after spilling + assert(inMemSorter != null); + released += inMemSorter.getMemoryUsage(); + inMemSorter.free(); + inMemSorter = null; return released; } } @@ -489,10 +495,6 @@ public void loadNext() throws IOException { } upstream = nextUpstream; nextUpstream = null; - - assert(inMemSorter != null); - inMemSorter.free(); - inMemSorter = null; } numRecords--; upstream.loadNext(); diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java index a218ad4623f4..dce1f15a2963 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java @@ -108,6 +108,7 @@ public UnsafeInMemorySorter( */ public void free() { consumer.freeArray(array); + array = null; } public void reset() { @@ -160,28 +161,22 @@ public void insertRecord(long recordPointer, long keyPrefix) { pos++; } - public static final class SortedIterator extends UnsafeSorterIterator { + public final class SortedIterator extends UnsafeSorterIterator { - private final TaskMemoryManager memoryManager; - private final int sortBufferInsertPosition; - private final LongArray sortBuffer; - private int position = 0; + private final int numRecords; + private int position; private Object baseObject; private long baseOffset; private long keyPrefix; private int recordLength; - private SortedIterator( - TaskMemoryManager 
memoryManager, - int sortBufferInsertPosition, - LongArray sortBuffer) { - this.memoryManager = memoryManager; - this.sortBufferInsertPosition = sortBufferInsertPosition; - this.sortBuffer = sortBuffer; + private SortedIterator(int numRecords) { + this.numRecords = numRecords; + this.position = 0; } public SortedIterator clone () { - SortedIterator iter = new SortedIterator(memoryManager, sortBufferInsertPosition, sortBuffer); + SortedIterator iter = new SortedIterator(numRecords); iter.position = position; iter.baseObject = baseObject; iter.baseOffset = baseOffset; @@ -192,21 +187,21 @@ public SortedIterator clone () { @Override public boolean hasNext() { - return position < sortBufferInsertPosition; + return position / 2 < numRecords; } public int numRecordsLeft() { - return (sortBufferInsertPosition - position) / 2; + return numRecords - position / 2; } @Override public void loadNext() { // This pointer points to a 4-byte record length, followed by the record's bytes - final long recordPointer = sortBuffer.get(position); + final long recordPointer = array.get(position); baseObject = memoryManager.getPage(recordPointer); baseOffset = memoryManager.getOffsetInPage(recordPointer) + 4; // Skip over record length recordLength = Platform.getInt(baseObject, baseOffset - 4); - keyPrefix = sortBuffer.get(position + 1); + keyPrefix = array.get(position + 1); position += 2; } @@ -229,6 +224,6 @@ public void loadNext() { */ public SortedIterator getSortedIterator() { sorter.sort(array, 0, pos / 2, sortComparator); - return new SortedIterator(memoryManager, pos, array); + return new SortedIterator(pos / 2); } } From 34ca392da7097a1fbe48cd6c3ebff51453ca26ca Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 24 Nov 2015 14:51:01 -0800 Subject: [PATCH 1533/1824] Added a line of comment to explain why the extra sort exists in pivot. --- sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala index ee7150cbbfbc..abd531c4ba54 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala @@ -304,7 +304,7 @@ class GroupedData protected[sql]( // Get the distinct values of the column and sort them so its consistent val values = df.select(pivotColumn) .distinct() - .sort(pivotColumn) + .sort(pivotColumn) // ensure that the output columns are in a consistent logical order .map(_.get(0)) .take(maxValues + 1) .toSeq From c7f95df5c6d8eb2e6f11cf58b704fea34326a5f2 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Tue, 24 Nov 2015 14:59:14 -0800 Subject: [PATCH 1534/1824] [SPARK-11783][SQL] Fixes execution Hive client when using remote Hive metastore When using remote Hive metastore, `hive.metastore.uris` is set to the metastore URI. However, it overrides `javax.jdo.option.ConnectionURL` unexpectedly, thus the execution Hive client connects to the actual remote Hive metastore instead of the Derby metastore created in the temporary directory. Cleaning this configuration for the execution Hive client fixes this issue. Author: Cheng Lian Closes #9895 from liancheng/spark-11783.clean-remote-metastore-config. 
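A hedged sketch of the configuration split described above, simplified to plain Scala maps rather than the actual HiveConf plumbing (the URI value is only illustrative): the user-facing metadata client keeps the remote URI, while the execution client gets an empty `hive.metastore.uris` so that local (embedded Derby) mode is assumed:

{{{
// Illustrative only: a remote metastore configured by the user...
val userHiveSettings = Map(
  "hive.metastore.uris" -> "thrift://metastore-host:9083")  // hypothetical host/port

// ...and what the patch below effectively does for the execution client: blank the URI,
// so that, per the Hive admin manual quoted in the code comment, local mode is assumed.
val executionClientSettings = userHiveSettings + ("hive.metastore.uris" -> "")
}}}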
--- .../org/apache/spark/sql/hive/HiveContext.scala | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index c0bb5af7d5c8..8a4264194ae8 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -736,6 +736,21 @@ private[hive] object HiveContext { s"jdbc:derby:;databaseName=${localMetastore.getAbsolutePath};create=true") propMap.put("datanucleus.rdbms.datastoreAdapterClassName", "org.datanucleus.store.rdbms.adapter.DerbyAdapter") + + // SPARK-11783: When "hive.metastore.uris" is set, the metastore connection mode will be + // remote (https://cwiki.apache.org/confluence/display/Hive/AdminManual+MetastoreAdmin + // mentions that "If hive.metastore.uris is empty local mode is assumed, remote otherwise"). + // Remote means that the metastore server is running in its own process. + // When the mode is remote, configurations like "javax.jdo.option.ConnectionURL" will not be + // used (because they are used by remote metastore server that talks to the database). + // Because execution Hive should always connects to a embedded derby metastore. + // We have to remove the value of hive.metastore.uris. So, the execution Hive client connects + // to the actual embedded derby metastore instead of the remote metastore. + // You can search HiveConf.ConfVars.METASTOREURIS in the code of HiveConf (in Hive's repo). + // Then, you will find that the local metastore mode is only set to true when + // hive.metastore.uris is not set. + propMap.put(ConfVars.METASTOREURIS.varname, "") + propMap.toMap } From 238ae51b66ac12d15fba6aff061804004c5ca6cb Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 24 Nov 2015 15:54:10 -0800 Subject: [PATCH 1535/1824] [SPARK-11914][SQL] Support coalesce and repartition in Dataset APIs This PR is to provide two common `coalesce` and `repartition` in Dataset APIs. After reading the comments of SPARK-9999, I am unclear about the plan for supporting re-partitioning in Dataset APIs. Currently, both RDD APIs and Dataframe APIs provide users such a flexibility to control the number of partitions. In most traditional RDBMS, they expose the number of partitions, the partitioning columns, the table partitioning methods to DBAs for performance tuning and storage planning. Normally, these parameters could largely affect the query performance. Since the actual performance depends on the workload types, I think it is almost impossible to automate the discovery of the best partitioning strategy for all the scenarios. I am wondering if Dataset APIs are planning to hide these APIs from users? Feel free to reject my PR if it does not match the plan. Thank you for your answers. marmbrus rxin cloud-fan Author: gatorsmile Closes #9899 from gatorsmile/coalesce. 
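A short usage sketch of the two methods added below, mirroring the new DatasetSuite test (a `sqlContext` and its implicits are assumed to be in scope, as in the test code):

{{{
import sqlContext.implicits._

val ds = (1 to 100).toDS()

val reshuffled = ds.repartition(10)   // full shuffle into exactly 10 partitions
val narrowed   = ds.coalesce(1)       // narrow dependency, no shuffle

assert(reshuffled.rdd.partitions.length == 10)
assert(narrowed.rdd.partitions.length == 1)
}}}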
--- .../scala/org/apache/spark/sql/Dataset.scala | 19 +++++++++++++++++++ .../org/apache/spark/sql/DatasetSuite.scala | 15 +++++++++++++++ 2 files changed, 34 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 07647508421a..17e2611790d5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -152,6 +152,25 @@ class Dataset[T] private[sql]( */ def count(): Long = toDF().count() + /** + * Returns a new [[Dataset]] that has exactly `numPartitions` partitions. + * @since 1.6.0 + */ + def repartition(numPartitions: Int): Dataset[T] = withPlan { + Repartition(numPartitions, shuffle = true, _) + } + + /** + * Returns a new [[Dataset]] that has exactly `numPartitions` partitions. + * Similar to coalesce defined on an [[RDD]], this operation results in a narrow dependency, e.g. + * if you go from 1000 partitions to 100 partitions, there will not be a shuffle, instead each of + * the 100 new partitions will claim 10 of the current partitions. + * @since 1.6.0 + */ + def coalesce(numPartitions: Int): Dataset[T] = withPlan { + Repartition(numPartitions, shuffle = false, _) + } + /* *********************** * * Functional Operations * * *********************** */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 13eede1b17d8..c253fdbb8c99 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -52,6 +52,21 @@ class DatasetSuite extends QueryTest with SharedSQLContext { assert(ds.takeAsList(1).get(0) == item) } + test("coalesce, repartition") { + val data = (1 to 100).map(i => ClassData(i.toString, i)) + val ds = data.toDS() + + assert(ds.repartition(10).rdd.partitions.length == 10) + checkAnswer( + ds.repartition(10), + data: _*) + + assert(ds.coalesce(1).rdd.partitions.length == 1) + checkAnswer( + ds.coalesce(1), + data: _*) + } + test("as tuple") { val data = Seq(("a", 1), ("b", 2)).toDF("a", "b") checkAnswer( From 25bbd3c16e8e8be4d2c43000223d54650e9a3696 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 24 Nov 2015 18:16:07 -0800 Subject: [PATCH 1536/1824] [SPARK-11967][SQL] Consistent use of varargs for multiple paths in DataFrameReader This patch makes it consistent to use varargs in all DataFrameReader methods, including Parquet, JSON, text, and the generic load function. Also added a few more API tests for the Java API. Author: Reynold Xin Closes #9945 from rxin/SPARK-11967. --- python/pyspark/sql/readwriter.py | 19 ++++++---- .../apache/spark/sql/DataFrameReader.scala | 36 +++++++++++++++---- .../apache/spark/sql/JavaDataFrameSuite.java | 23 ++++++++++++ sql/core/src/test/resources/text-suite2.txt | 1 + .../org/apache/spark/sql/DataFrameSuite.scala | 2 +- 5 files changed, 66 insertions(+), 15 deletions(-) create mode 100644 sql/core/src/test/resources/text-suite2.txt diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index e8f0d7ec7703..2e75f0c8a182 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -109,7 +109,7 @@ def options(self, **options): def load(self, path=None, format=None, schema=None, **options): """Loads data from a data source and returns it as a :class`DataFrame`. - :param path: optional string for file-system backed data sources. 
+ :param path: optional string or a list of string for file-system backed data sources. :param format: optional string for format of the data source. Default to 'parquet'. :param schema: optional :class:`StructType` for the input schema. :param options: all other string options @@ -118,6 +118,7 @@ def load(self, path=None, format=None, schema=None, **options): ... opt2=1, opt3='str') >>> df.dtypes [('name', 'string'), ('year', 'int'), ('month', 'int'), ('day', 'int')] + >>> df = sqlContext.read.format('json').load(['python/test_support/sql/people.json', ... 'python/test_support/sql/people1.json']) >>> df.dtypes @@ -130,10 +131,8 @@ def load(self, path=None, format=None, schema=None, **options): self.options(**options) if path is not None: if type(path) == list: - paths = path - gateway = self._sqlContext._sc._gateway - jpaths = utils.toJArray(gateway, gateway.jvm.java.lang.String, paths) - return self._df(self._jreader.load(jpaths)) + return self._df( + self._jreader.load(self._sqlContext._sc._jvm.PythonUtils.toSeq(path))) else: return self._df(self._jreader.load(path)) else: @@ -175,6 +174,8 @@ def json(self, path, schema=None): self.schema(schema) if isinstance(path, basestring): return self._df(self._jreader.json(path)) + elif type(path) == list: + return self._df(self._jreader.json(self._sqlContext._sc._jvm.PythonUtils.toSeq(path))) elif isinstance(path, RDD): return self._df(self._jreader.json(path._jrdd)) else: @@ -205,16 +206,20 @@ def parquet(self, *paths): @ignore_unicode_prefix @since(1.6) - def text(self, path): + def text(self, paths): """Loads a text file and returns a [[DataFrame]] with a single string column named "text". Each line in the text file is a new row in the resulting DataFrame. + :param paths: string, or list of strings, for input path(s). 
+ >>> df = sqlContext.read.text('python/test_support/sql/text-test.txt') >>> df.collect() [Row(value=u'hello'), Row(value=u'this')] """ - return self._df(self._jreader.text(path)) + if isinstance(paths, basestring): + paths = [paths] + return self._df(self._jreader.text(self._sqlContext._sc._jvm.PythonUtils.toSeq(paths))) @since(1.5) def orc(self, path): diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index dcb3737b70fb..3ed1e55adec6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -24,17 +24,17 @@ import scala.collection.JavaConverters._ import org.apache.hadoop.fs.Path import org.apache.hadoop.util.StringUtils +import org.apache.spark.{Logging, Partition} import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaRDD import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.SqlParser import org.apache.spark.sql.execution.datasources.jdbc.{JDBCPartition, JDBCPartitioningInfo, JDBCRelation} -import org.apache.spark.sql.execution.datasources.json.{JSONOptions, JSONRelation} +import org.apache.spark.sql.execution.datasources.json.JSONRelation import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation import org.apache.spark.sql.execution.datasources.{LogicalRelation, ResolvedDataSource} import org.apache.spark.sql.types.StructType -import org.apache.spark.{Logging, Partition} -import org.apache.spark.sql.catalyst.{SqlParser, TableIdentifier} /** * :: Experimental :: @@ -104,6 +104,7 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging { * * @since 1.4.0 */ + // TODO: Remove this one in Spark 2.0. def load(path: String): DataFrame = { option("path", path).load() } @@ -130,7 +131,8 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging { * * @since 1.6.0 */ - def load(paths: Array[String]): DataFrame = { + @scala.annotation.varargs + def load(paths: String*): DataFrame = { option("paths", paths.map(StringUtils.escapeString(_, '\\', ',')).mkString(",")).load() } @@ -236,11 +238,30 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging { *
  • `allowNumericLeadingZeros` (default `false`): allows leading zeros in numbers * (e.g. 00012)
  • * - * @param path input path * @since 1.4.0 */ + // TODO: Remove this one in Spark 2.0. def json(path: String): DataFrame = format("json").load(path) + /** + * Loads a JSON file (one object per line) and returns the result as a [[DataFrame]]. + * + * This function goes through the input once to determine the input schema. If you know the + * schema in advance, use the version that specifies the schema to avoid the extra scan. + * + * You can set the following JSON-specific options to deal with non-standard JSON files: + *
+ *   • `primitivesAsString` (default `false`): infers all primitive values as a string type
+ *   • `allowComments` (default `false`): ignores Java/C++ style comment in JSON records
+ *   • `allowUnquotedFieldNames` (default `false`): allows unquoted JSON field names
+ *   • `allowSingleQuotes` (default `true`): allows single quotes in addition to double quotes
+ *   • `allowNumericLeadingZeros` (default `false`): allows leading zeros in numbers (e.g. 00012)
  • + * + * @since 1.6.0 + */ + def json(paths: String*): DataFrame = format("json").load(paths : _*) + /** * Loads an `JavaRDD[String]` storing JSON objects (one object per record) and * returns the result as a [[DataFrame]]. @@ -328,10 +349,11 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging { * sqlContext.read().text("/path/to/spark/README.md") * }}} * - * @param path input path + * @param paths input path * @since 1.6.0 */ - def text(path: String): DataFrame = format("text").load(path) + @scala.annotation.varargs + def text(paths: String*): DataFrame = format("text").load(paths : _*) /////////////////////////////////////////////////////////////////////////////////////// // Builder pattern config options diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index a12fed3c0c6a..8e0b2dbca4a9 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -298,4 +298,27 @@ public void pivot() { Assert.assertEquals(48000.0, actual[1].getDouble(1), 0.01); Assert.assertEquals(30000.0, actual[1].getDouble(2), 0.01); } + + public void testGenericLoad() { + DataFrame df1 = context.read().format("text").load( + Thread.currentThread().getContextClassLoader().getResource("text-suite.txt").toString()); + Assert.assertEquals(4L, df1.count()); + + DataFrame df2 = context.read().format("text").load( + Thread.currentThread().getContextClassLoader().getResource("text-suite.txt").toString(), + Thread.currentThread().getContextClassLoader().getResource("text-suite2.txt").toString()); + Assert.assertEquals(5L, df2.count()); + } + + @Test + public void testTextLoad() { + DataFrame df1 = context.read().text( + Thread.currentThread().getContextClassLoader().getResource("text-suite.txt").toString()); + Assert.assertEquals(4L, df1.count()); + + DataFrame df2 = context.read().text( + Thread.currentThread().getContextClassLoader().getResource("text-suite.txt").toString(), + Thread.currentThread().getContextClassLoader().getResource("text-suite2.txt").toString()); + Assert.assertEquals(5L, df2.count()); + } } diff --git a/sql/core/src/test/resources/text-suite2.txt b/sql/core/src/test/resources/text-suite2.txt new file mode 100644 index 000000000000..f9d498c80493 --- /dev/null +++ b/sql/core/src/test/resources/text-suite2.txt @@ -0,0 +1 @@ +This is another file for testing multi path loading. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index dd6d06512ff6..76e9648aa753 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -897,7 +897,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { val dir2 = new File(dir, "dir2").getCanonicalPath df2.write.format("json").save(dir2) - checkAnswer(sqlContext.read.format("json").load(Array(dir1, dir2)), + checkAnswer(sqlContext.read.format("json").load(dir1, dir2), Row(1, 22) :: Row(2, 23) :: Nil) checkAnswer(sqlContext.read.format("json").load(dir1), From 4d6bbbc03ddb6650b00eb638e4876a196014c19c Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 24 Nov 2015 18:58:55 -0800 Subject: [PATCH 1537/1824] [SPARK-11947][SQL] Mark deprecated methods with "This will be removed in Spark 2.0." Also fixed some documentation as I saw them. 
Author: Reynold Xin Closes #9930 from rxin/SPARK-11947. --- project/MimaExcludes.scala | 3 +- .../scala/org/apache/spark/sql/Column.scala | 20 +++-- .../org/apache/spark/sql/DataFrame.scala | 72 +++++++++------ .../scala/org/apache/spark/sql/Dataset.scala | 1 + .../org/apache/spark/sql/SQLContext.scala | 88 ++++++++++--------- .../org/apache/spark/sql/SQLImplicits.scala | 25 +++++- .../sql/{ => execution}/SparkSQLParser.scala | 15 ++-- .../org/apache/spark/sql/functions.scala | 52 ++++++----- .../SimpleTextHadoopFsRelationSuite.scala | 6 +- 9 files changed, 172 insertions(+), 110 deletions(-) rename sql/core/src/main/scala/org/apache/spark/sql/{ => execution}/SparkSQLParser.scala (89%) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index bb45d1bb1214..54a9ad956d11 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -108,7 +108,8 @@ object MimaExcludes { ProblemFilters.exclude[MissingClassProblem]( "org.apache.spark.rdd.MapPartitionsWithPreparationRDD"), ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.rdd.MapPartitionsWithPreparationRDD$") + "org.apache.spark.rdd.MapPartitionsWithPreparationRDD$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.SparkSQLParser") ) ++ Seq( // SPARK-11485 ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.DataFrameHolder.df"), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 30c554a85e69..b3cd9e1eff14 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -42,7 +42,8 @@ private[sql] object Column { /** * A [[Column]] where an [[Encoder]] has been given for the expected input and return type. - * @since 1.6.0 + * To create a [[TypedColumn]], use the `as` function on a [[Column]]. + * * @tparam T The input type expected for this expression. Can be `Any` if the expression is type * checked by the analyzer instead of the compiler (i.e. `expr("sum(...)")`). * @tparam U The output type of this column. @@ -51,7 +52,8 @@ private[sql] object Column { */ class TypedColumn[-T, U]( expr: Expression, - private[sql] val encoder: ExpressionEncoder[U]) extends Column(expr) { + private[sql] val encoder: ExpressionEncoder[U]) + extends Column(expr) { /** * Inserts the specific input type and schema into any expressions that are expected to operate @@ -61,12 +63,11 @@ class TypedColumn[-T, U]( inputEncoder: ExpressionEncoder[_], schema: Seq[Attribute]): TypedColumn[T, U] = { val boundEncoder = inputEncoder.bind(schema).asInstanceOf[ExpressionEncoder[Any]] - new TypedColumn[T, U] (expr transform { - case ta: TypedAggregateExpression if ta.aEncoder.isEmpty => - ta.copy( - aEncoder = Some(boundEncoder), - children = schema) - }, encoder) + new TypedColumn[T, U]( + expr transform { case ta: TypedAggregateExpression if ta.aEncoder.isEmpty => + ta.copy(aEncoder = Some(boundEncoder), children = schema) + }, + encoder) } } @@ -691,8 +692,9 @@ class Column(protected[sql] val expr: Expression) extends Logging { * * @group expr_ops * @since 1.3.0 + * @deprecated As of 1.5.0. Use isin. This will be removed in Spark 2.0. */ - @deprecated("use isin", "1.5.0") + @deprecated("use isin. 
This will be removed in Spark 2.0.", "1.5.0") @scala.annotation.varargs def in(list: Any*): Column = isin(list : _*) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 5586fc994b98..5eca1db9525e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -1713,9 +1713,9 @@ class DataFrame private[sql]( //////////////////////////////////////////////////////////////////////////// /** - * @deprecated As of 1.3.0, replaced by `toDF()`. + * @deprecated As of 1.3.0, replaced by `toDF()`. This will be removed in Spark 2.0. */ - @deprecated("use toDF", "1.3.0") + @deprecated("Use toDF. This will be removed in Spark 2.0.", "1.3.0") def toSchemaRDD: DataFrame = this /** @@ -1725,9 +1725,9 @@ class DataFrame private[sql]( * given name; if you pass `false`, it will throw if the table already * exists. * @group output - * @deprecated As of 1.340, replaced by `write().jdbc()`. + * @deprecated As of 1.340, replaced by `write().jdbc()`. This will be removed in Spark 2.0. */ - @deprecated("Use write.jdbc()", "1.4.0") + @deprecated("Use write.jdbc(). This will be removed in Spark 2.0.", "1.4.0") def createJDBCTable(url: String, table: String, allowExisting: Boolean): Unit = { val w = if (allowExisting) write.mode(SaveMode.Overwrite) else write w.jdbc(url, table, new Properties) @@ -1744,9 +1744,9 @@ class DataFrame private[sql]( * the RDD in order via the simple statement * `INSERT INTO table VALUES (?, ?, ..., ?)` should not fail. * @group output - * @deprecated As of 1.4.0, replaced by `write().jdbc()`. + * @deprecated As of 1.4.0, replaced by `write().jdbc()`. This will be removed in Spark 2.0. */ - @deprecated("Use write.jdbc()", "1.4.0") + @deprecated("Use write.jdbc(). This will be removed in Spark 2.0.", "1.4.0") def insertIntoJDBC(url: String, table: String, overwrite: Boolean): Unit = { val w = if (overwrite) write.mode(SaveMode.Overwrite) else write.mode(SaveMode.Append) w.jdbc(url, table, new Properties) @@ -1757,9 +1757,9 @@ class DataFrame private[sql]( * Files that are written out using this method can be read back in as a [[DataFrame]] * using the `parquetFile` function in [[SQLContext]]. * @group output - * @deprecated As of 1.4.0, replaced by `write().parquet()`. + * @deprecated As of 1.4.0, replaced by `write().parquet()`. This will be removed in Spark 2.0. */ - @deprecated("Use write.parquet(path)", "1.4.0") + @deprecated("Use write.parquet(path). This will be removed in Spark 2.0.", "1.4.0") def saveAsParquetFile(path: String): Unit = { write.format("parquet").mode(SaveMode.ErrorIfExists).save(path) } @@ -1782,8 +1782,9 @@ class DataFrame private[sql]( * * @group output * @deprecated As of 1.4.0, replaced by `write().saveAsTable(tableName)`. + * This will be removed in Spark 2.0. */ - @deprecated("Use write.saveAsTable(tableName)", "1.4.0") + @deprecated("Use write.saveAsTable(tableName). This will be removed in Spark 2.0.", "1.4.0") def saveAsTable(tableName: String): Unit = { write.mode(SaveMode.ErrorIfExists).saveAsTable(tableName) } @@ -1805,8 +1806,10 @@ class DataFrame private[sql]( * * @group output * @deprecated As of 1.4.0, replaced by `write().mode(mode).saveAsTable(tableName)`. + * This will be removed in Spark 2.0. */ - @deprecated("Use write.mode(mode).saveAsTable(tableName)", "1.4.0") + @deprecated("Use write.mode(mode).saveAsTable(tableName). 
This will be removed in Spark 2.0.", + "1.4.0") def saveAsTable(tableName: String, mode: SaveMode): Unit = { write.mode(mode).saveAsTable(tableName) } @@ -1829,8 +1832,10 @@ class DataFrame private[sql]( * * @group output * @deprecated As of 1.4.0, replaced by `write().format(source).saveAsTable(tableName)`. + * This will be removed in Spark 2.0. */ - @deprecated("Use write.format(source).saveAsTable(tableName)", "1.4.0") + @deprecated("Use write.format(source).saveAsTable(tableName). This will be removed in Spark 2.0.", + "1.4.0") def saveAsTable(tableName: String, source: String): Unit = { write.format(source).saveAsTable(tableName) } @@ -1853,8 +1858,10 @@ class DataFrame private[sql]( * * @group output * @deprecated As of 1.4.0, replaced by `write().mode(mode).saveAsTable(tableName)`. + * This will be removed in Spark 2.0. */ - @deprecated("Use write.format(source).mode(mode).saveAsTable(tableName)", "1.4.0") + @deprecated("Use write.format(source).mode(mode).saveAsTable(tableName). " + + "This will be removed in Spark 2.0.", "1.4.0") def saveAsTable(tableName: String, source: String, mode: SaveMode): Unit = { write.format(source).mode(mode).saveAsTable(tableName) } @@ -1877,9 +1884,10 @@ class DataFrame private[sql]( * @group output * @deprecated As of 1.4.0, replaced by * `write().format(source).mode(mode).options(options).saveAsTable(tableName)`. + * This will be removed in Spark 2.0. */ - @deprecated("Use write.format(source).mode(mode).options(options).saveAsTable(tableName)", - "1.4.0") + @deprecated("Use write.format(source).mode(mode).options(options).saveAsTable(tableName). " + + "This will be removed in Spark 2.0.", "1.4.0") def saveAsTable( tableName: String, source: String, @@ -1907,9 +1915,10 @@ class DataFrame private[sql]( * @group output * @deprecated As of 1.4.0, replaced by * `write().format(source).mode(mode).options(options).saveAsTable(tableName)`. + * This will be removed in Spark 2.0. */ - @deprecated("Use write.format(source).mode(mode).options(options).saveAsTable(tableName)", - "1.4.0") + @deprecated("Use write.format(source).mode(mode).options(options).saveAsTable(tableName). " + + "This will be removed in Spark 2.0.", "1.4.0") def saveAsTable( tableName: String, source: String, @@ -1923,9 +1932,9 @@ class DataFrame private[sql]( * using the default data source configured by spark.sql.sources.default and * [[SaveMode.ErrorIfExists]] as the save mode. * @group output - * @deprecated As of 1.4.0, replaced by `write().save(path)`. + * @deprecated As of 1.4.0, replaced by `write().save(path)`. This will be removed in Spark 2.0. */ - @deprecated("Use write.save(path)", "1.4.0") + @deprecated("Use write.save(path). This will be removed in Spark 2.0.", "1.4.0") def save(path: String): Unit = { write.save(path) } @@ -1935,8 +1944,9 @@ class DataFrame private[sql]( * using the default data source configured by spark.sql.sources.default. * @group output * @deprecated As of 1.4.0, replaced by `write().mode(mode).save(path)`. + * This will be removed in Spark 2.0. */ - @deprecated("Use write.mode(mode).save(path)", "1.4.0") + @deprecated("Use write.mode(mode).save(path). This will be removed in Spark 2.0.", "1.4.0") def save(path: String, mode: SaveMode): Unit = { write.mode(mode).save(path) } @@ -1946,8 +1956,9 @@ class DataFrame private[sql]( * using [[SaveMode.ErrorIfExists]] as the save mode. * @group output * @deprecated As of 1.4.0, replaced by `write().format(source).save(path)`. + * This will be removed in Spark 2.0. 
*/ - @deprecated("Use write.format(source).save(path)", "1.4.0") + @deprecated("Use write.format(source).save(path). This will be removed in Spark 2.0.", "1.4.0") def save(path: String, source: String): Unit = { write.format(source).save(path) } @@ -1957,8 +1968,10 @@ class DataFrame private[sql]( * [[SaveMode]] specified by mode. * @group output * @deprecated As of 1.4.0, replaced by `write().format(source).mode(mode).save(path)`. + * This will be removed in Spark 2.0. */ - @deprecated("Use write.format(source).mode(mode).save(path)", "1.4.0") + @deprecated("Use write.format(source).mode(mode).save(path). " + + "This will be removed in Spark 2.0.", "1.4.0") def save(path: String, source: String, mode: SaveMode): Unit = { write.format(source).mode(mode).save(path) } @@ -1969,8 +1982,10 @@ class DataFrame private[sql]( * @group output * @deprecated As of 1.4.0, replaced by * `write().format(source).mode(mode).options(options).save(path)`. + * This will be removed in Spark 2.0. */ - @deprecated("Use write.format(source).mode(mode).options(options).save()", "1.4.0") + @deprecated("Use write.format(source).mode(mode).options(options).save(). " + + "This will be removed in Spark 2.0.", "1.4.0") def save( source: String, mode: SaveMode, @@ -1985,8 +2000,10 @@ class DataFrame private[sql]( * @group output * @deprecated As of 1.4.0, replaced by * `write().format(source).mode(mode).options(options).save(path)`. + * This will be removed in Spark 2.0. */ - @deprecated("Use write.format(source).mode(mode).options(options).save()", "1.4.0") + @deprecated("Use write.format(source).mode(mode).options(options).save(). " + + "This will be removed in Spark 2.0.", "1.4.0") def save( source: String, mode: SaveMode, @@ -1994,14 +2011,15 @@ class DataFrame private[sql]( write.format(source).mode(mode).options(options).save() } - /** * Adds the rows from this RDD to the specified table, optionally overwriting the existing data. * @group output * @deprecated As of 1.4.0, replaced by * `write().mode(SaveMode.Append|SaveMode.Overwrite).saveAsTable(tableName)`. + * This will be removed in Spark 2.0. */ - @deprecated("Use write.mode(SaveMode.Append|SaveMode.Overwrite).saveAsTable(tableName)", "1.4.0") + @deprecated("Use write.mode(SaveMode.Append|SaveMode.Overwrite).saveAsTable(tableName). " + + "This will be removed in Spark 2.0.", "1.4.0") def insertInto(tableName: String, overwrite: Boolean): Unit = { write.mode(if (overwrite) SaveMode.Overwrite else SaveMode.Append).insertInto(tableName) } @@ -2012,8 +2030,10 @@ class DataFrame private[sql]( * @group output * @deprecated As of 1.4.0, replaced by * `write().mode(SaveMode.Append).saveAsTable(tableName)`. + * This will be removed in Spark 2.0. */ - @deprecated("Use write.mode(SaveMode.Append).saveAsTable(tableName)", "1.4.0") + @deprecated("Use write.mode(SaveMode.Append).saveAsTable(tableName). 
" + + "This will be removed in Spark 2.0.", "1.4.0") def insertInto(tableName: String): Unit = { write.mode(SaveMode.Append).insertInto(tableName) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 17e2611790d5..dd84b8bc11e2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -32,6 +32,7 @@ import org.apache.spark.sql.execution.{Queryable, QueryExecution} import org.apache.spark.sql.types.StructType /** + * :: Experimental :: * A [[Dataset]] is a strongly typed collection of objects that can be transformed in parallel * using functional or relational operations. * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 39471d2fb79a..46bf544fd885 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -942,33 +942,33 @@ class SQLContext private[sql]( //////////////////////////////////////////////////////////////////////////// /** - * @deprecated As of 1.3.0, replaced by `createDataFrame()`. + * @deprecated As of 1.3.0, replaced by `createDataFrame()`. This will be removed in Spark 2.0. */ - @deprecated("use createDataFrame", "1.3.0") + @deprecated("Use createDataFrame. This will be removed in Spark 2.0.", "1.3.0") def applySchema(rowRDD: RDD[Row], schema: StructType): DataFrame = { createDataFrame(rowRDD, schema) } /** - * @deprecated As of 1.3.0, replaced by `createDataFrame()`. + * @deprecated As of 1.3.0, replaced by `createDataFrame()`. This will be removed in Spark 2.0. */ - @deprecated("use createDataFrame", "1.3.0") + @deprecated("Use createDataFrame. This will be removed in Spark 2.0.", "1.3.0") def applySchema(rowRDD: JavaRDD[Row], schema: StructType): DataFrame = { createDataFrame(rowRDD, schema) } /** - * @deprecated As of 1.3.0, replaced by `createDataFrame()`. + * @deprecated As of 1.3.0, replaced by `createDataFrame()`. This will be removed in Spark 2.0. */ - @deprecated("use createDataFrame", "1.3.0") + @deprecated("Use createDataFrame. This will be removed in Spark 2.0.", "1.3.0") def applySchema(rdd: RDD[_], beanClass: Class[_]): DataFrame = { createDataFrame(rdd, beanClass) } /** - * @deprecated As of 1.3.0, replaced by `createDataFrame()`. + * @deprecated As of 1.3.0, replaced by `createDataFrame()`. This will be removed in Spark 2.0. */ - @deprecated("use createDataFrame", "1.3.0") + @deprecated("Use createDataFrame. This will be removed in Spark 2.0.", "1.3.0") def applySchema(rdd: JavaRDD[_], beanClass: Class[_]): DataFrame = { createDataFrame(rdd, beanClass) } @@ -978,9 +978,9 @@ class SQLContext private[sql]( * [[DataFrame]] if no paths are passed in. * * @group specificdata - * @deprecated As of 1.4.0, replaced by `read().parquet()`. + * @deprecated As of 1.4.0, replaced by `read().parquet()`. This will be removed in Spark 2.0. */ - @deprecated("Use read.parquet()", "1.4.0") + @deprecated("Use read.parquet(). This will be removed in Spark 2.0.", "1.4.0") @scala.annotation.varargs def parquetFile(paths: String*): DataFrame = { if (paths.isEmpty) { @@ -995,9 +995,9 @@ class SQLContext private[sql]( * It goes through the entire dataset once to determine the schema. * * @group specificdata - * @deprecated As of 1.4.0, replaced by `read().json()`. + * @deprecated As of 1.4.0, replaced by `read().json()`. 
This will be removed in Spark 2.0. */ - @deprecated("Use read.json()", "1.4.0") + @deprecated("Use read.json(). This will be removed in Spark 2.0.", "1.4.0") def jsonFile(path: String): DataFrame = { read.json(path) } @@ -1007,18 +1007,18 @@ class SQLContext private[sql]( * returning the result as a [[DataFrame]]. * * @group specificdata - * @deprecated As of 1.4.0, replaced by `read().json()`. + * @deprecated As of 1.4.0, replaced by `read().json()`. This will be removed in Spark 2.0. */ - @deprecated("Use read.json()", "1.4.0") + @deprecated("Use read.json(). This will be removed in Spark 2.0.", "1.4.0") def jsonFile(path: String, schema: StructType): DataFrame = { read.schema(schema).json(path) } /** * @group specificdata - * @deprecated As of 1.4.0, replaced by `read().json()`. + * @deprecated As of 1.4.0, replaced by `read().json()`. This will be removed in Spark 2.0. */ - @deprecated("Use read.json()", "1.4.0") + @deprecated("Use read.json(). This will be removed in Spark 2.0.", "1.4.0") def jsonFile(path: String, samplingRatio: Double): DataFrame = { read.option("samplingRatio", samplingRatio.toString).json(path) } @@ -1029,9 +1029,9 @@ class SQLContext private[sql]( * It goes through the entire dataset once to determine the schema. * * @group specificdata - * @deprecated As of 1.4.0, replaced by `read().json()`. + * @deprecated As of 1.4.0, replaced by `read().json()`. This will be removed in Spark 2.0. */ - @deprecated("Use read.json()", "1.4.0") + @deprecated("Use read.json(). This will be removed in Spark 2.0.", "1.4.0") def jsonRDD(json: RDD[String]): DataFrame = read.json(json) /** @@ -1040,9 +1040,9 @@ class SQLContext private[sql]( * It goes through the entire dataset once to determine the schema. * * @group specificdata - * @deprecated As of 1.4.0, replaced by `read().json()`. + * @deprecated As of 1.4.0, replaced by `read().json()`. This will be removed in Spark 2.0. */ - @deprecated("Use read.json()", "1.4.0") + @deprecated("Use read.json(). This will be removed in Spark 2.0.", "1.4.0") def jsonRDD(json: JavaRDD[String]): DataFrame = read.json(json) /** @@ -1050,9 +1050,9 @@ class SQLContext private[sql]( * returning the result as a [[DataFrame]]. * * @group specificdata - * @deprecated As of 1.4.0, replaced by `read().json()`. + * @deprecated As of 1.4.0, replaced by `read().json()`. This will be removed in Spark 2.0. */ - @deprecated("Use read.json()", "1.4.0") + @deprecated("Use read.json(). This will be removed in Spark 2.0.", "1.4.0") def jsonRDD(json: RDD[String], schema: StructType): DataFrame = { read.schema(schema).json(json) } @@ -1062,9 +1062,9 @@ class SQLContext private[sql]( * schema, returning the result as a [[DataFrame]]. * * @group specificdata - * @deprecated As of 1.4.0, replaced by `read().json()`. + * @deprecated As of 1.4.0, replaced by `read().json()`. This will be removed in Spark 2.0. */ - @deprecated("Use read.json()", "1.4.0") + @deprecated("Use read.json(). This will be removed in Spark 2.0.", "1.4.0") def jsonRDD(json: JavaRDD[String], schema: StructType): DataFrame = { read.schema(schema).json(json) } @@ -1074,9 +1074,9 @@ class SQLContext private[sql]( * schema, returning the result as a [[DataFrame]]. * * @group specificdata - * @deprecated As of 1.4.0, replaced by `read().json()`. + * @deprecated As of 1.4.0, replaced by `read().json()`. This will be removed in Spark 2.0. */ - @deprecated("Use read.json()", "1.4.0") + @deprecated("Use read.json(). 
This will be removed in Spark 2.0.", "1.4.0") def jsonRDD(json: RDD[String], samplingRatio: Double): DataFrame = { read.option("samplingRatio", samplingRatio.toString).json(json) } @@ -1086,9 +1086,9 @@ class SQLContext private[sql]( * schema, returning the result as a [[DataFrame]]. * * @group specificdata - * @deprecated As of 1.4.0, replaced by `read().json()`. + * @deprecated As of 1.4.0, replaced by `read().json()`. This will be removed in Spark 2.0. */ - @deprecated("Use read.json()", "1.4.0") + @deprecated("Use read.json(). This will be removed in Spark 2.0.", "1.4.0") def jsonRDD(json: JavaRDD[String], samplingRatio: Double): DataFrame = { read.option("samplingRatio", samplingRatio.toString).json(json) } @@ -1098,9 +1098,9 @@ class SQLContext private[sql]( * using the default data source configured by spark.sql.sources.default. * * @group genericdata - * @deprecated As of 1.4.0, replaced by `read().load(path)`. + * @deprecated As of 1.4.0, replaced by `read().load(path)`. This will be removed in Spark 2.0. */ - @deprecated("Use read.load(path)", "1.4.0") + @deprecated("Use read.load(path). This will be removed in Spark 2.0.", "1.4.0") def load(path: String): DataFrame = { read.load(path) } @@ -1110,8 +1110,9 @@ class SQLContext private[sql]( * * @group genericdata * @deprecated As of 1.4.0, replaced by `read().format(source).load(path)`. + * This will be removed in Spark 2.0. */ - @deprecated("Use read.format(source).load(path)", "1.4.0") + @deprecated("Use read.format(source).load(path). This will be removed in Spark 2.0.", "1.4.0") def load(path: String, source: String): DataFrame = { read.format(source).load(path) } @@ -1122,8 +1123,10 @@ class SQLContext private[sql]( * * @group genericdata * @deprecated As of 1.4.0, replaced by `read().format(source).options(options).load()`. + * This will be removed in Spark 2.0. */ - @deprecated("Use read.format(source).options(options).load()", "1.4.0") + @deprecated("Use read.format(source).options(options).load(). " + + "This will be removed in Spark 2.0.", "1.4.0") def load(source: String, options: java.util.Map[String, String]): DataFrame = { read.options(options).format(source).load() } @@ -1135,7 +1138,8 @@ class SQLContext private[sql]( * @group genericdata * @deprecated As of 1.4.0, replaced by `read().format(source).options(options).load()`. */ - @deprecated("Use read.format(source).options(options).load()", "1.4.0") + @deprecated("Use read.format(source).options(options).load(). " + + "This will be removed in Spark 2.0.", "1.4.0") def load(source: String, options: Map[String, String]): DataFrame = { read.options(options).format(source).load() } @@ -1148,7 +1152,8 @@ class SQLContext private[sql]( * @deprecated As of 1.4.0, replaced by * `read().format(source).schema(schema).options(options).load()`. */ - @deprecated("Use read.format(source).schema(schema).options(options).load()", "1.4.0") + @deprecated("Use read.format(source).schema(schema).options(options).load(). " + + "This will be removed in Spark 2.0.", "1.4.0") def load(source: String, schema: StructType, options: java.util.Map[String, String]): DataFrame = { read.format(source).schema(schema).options(options).load() @@ -1162,7 +1167,8 @@ class SQLContext private[sql]( * @deprecated As of 1.4.0, replaced by * `read().format(source).schema(schema).options(options).load()`. */ - @deprecated("Use read.format(source).schema(schema).options(options).load()", "1.4.0") + @deprecated("Use read.format(source).schema(schema).options(options).load(). 
" + + "This will be removed in Spark 2.0.", "1.4.0") def load(source: String, schema: StructType, options: Map[String, String]): DataFrame = { read.format(source).schema(schema).options(options).load() } @@ -1172,9 +1178,9 @@ class SQLContext private[sql]( * url named table. * * @group specificdata - * @deprecated As of 1.4.0, replaced by `read().jdbc()`. + * @deprecated As of 1.4.0, replaced by `read().jdbc()`. This will be removed in Spark 2.0. */ - @deprecated("use read.jdbc()", "1.4.0") + @deprecated("Use read.jdbc(). This will be removed in Spark 2.0.", "1.4.0") def jdbc(url: String, table: String): DataFrame = { read.jdbc(url, table, new Properties) } @@ -1190,9 +1196,9 @@ class SQLContext private[sql]( * @param numPartitions the number of partitions. the range `minValue`-`maxValue` will be split * evenly into this many partitions * @group specificdata - * @deprecated As of 1.4.0, replaced by `read().jdbc()`. + * @deprecated As of 1.4.0, replaced by `read().jdbc()`. This will be removed in Spark 2.0. */ - @deprecated("use read.jdbc()", "1.4.0") + @deprecated("Use read.jdbc(). This will be removed in Spark 2.0.", "1.4.0") def jdbc( url: String, table: String, @@ -1210,9 +1216,9 @@ class SQLContext private[sql]( * of the [[DataFrame]]. * * @group specificdata - * @deprecated As of 1.4.0, replaced by `read().jdbc()`. + * @deprecated As of 1.4.0, replaced by `read().jdbc()`. This will be removed in Spark 2.0. */ - @deprecated("use read.jdbc()", "1.4.0") + @deprecated("Use read.jdbc(). This will be removed in Spark 2.0.", "1.4.0") def jdbc(url: String, table: String, theParts: Array[String]): DataFrame = { read.jdbc(url, table, theParts, new Properties) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala index 25ffdcde1771..6735d02954b8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala @@ -30,19 +30,38 @@ import org.apache.spark.unsafe.types.UTF8String /** * A collection of implicit methods for converting common Scala objects into [[DataFrame]]s. + * + * @since 1.6.0 */ abstract class SQLImplicits { + protected def _sqlContext: SQLContext + /** @since 1.6.0 */ implicit def newProductEncoder[T <: Product : TypeTag]: Encoder[T] = ExpressionEncoder() + /** @since 1.6.0 */ implicit def newIntEncoder: Encoder[Int] = ExpressionEncoder() + + /** @since 1.6.0 */ implicit def newLongEncoder: Encoder[Long] = ExpressionEncoder() + + /** @since 1.6.0 */ implicit def newDoubleEncoder: Encoder[Double] = ExpressionEncoder() + + /** @since 1.6.0 */ implicit def newFloatEncoder: Encoder[Float] = ExpressionEncoder() + + /** @since 1.6.0 */ implicit def newByteEncoder: Encoder[Byte] = ExpressionEncoder() + + /** @since 1.6.0 */ implicit def newShortEncoder: Encoder[Short] = ExpressionEncoder() + /** @since 1.6.0 */ + implicit def newBooleanEncoder: Encoder[Boolean] = ExpressionEncoder() + + /** @since 1.6.0 */ implicit def newStringEncoder: Encoder[String] = ExpressionEncoder() /** @@ -84,9 +103,9 @@ abstract class SQLImplicits { DataFrameHolder(_sqlContext.createDataFrame(data)) } - // Do NOT add more implicit conversions. They are likely to break source compatibility by - // making existing implicit conversions ambiguous. In particular, RDD[Double] is dangerous - // because of [[DoubleRDDFunctions]]. + // Do NOT add more implicit conversions for primitive types. 
+ // They are likely to break source compatibility by making existing implicit conversions + // ambiguous. In particular, RDD[Double] is dangerous because of [[DoubleRDDFunctions]]. /** * Creates a single column DataFrame from an RDD[Int]. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSQLParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSQLParser.scala similarity index 89% rename from sql/core/src/main/scala/org/apache/spark/sql/SparkSQLParser.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSQLParser.scala index ea8fce6ca9cf..b3e8d0d84937 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSQLParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSQLParser.scala @@ -15,24 +15,23 @@ * limitations under the License. */ -package org.apache.spark.sql +package org.apache.spark.sql.execution import scala.util.parsing.combinator.RegexParsers import org.apache.spark.sql.catalyst.AbstractSparkSQLParser import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} -import org.apache.spark.sql.catalyst.plans.logical.{DescribeFunction, LogicalPlan, ShowFunctions} -import org.apache.spark.sql.execution._ +import org.apache.spark.sql.catalyst.plans.logical +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.types.StringType - /** * The top level Spark SQL parser. This parser recognizes syntaxes that are available for all SQL * dialects supported by Spark SQL, and delegates all the other syntaxes to the `fallback` parser. * * @param fallback A function that parses an input string to a logical plan */ -private[sql] class SparkSQLParser(fallback: String => LogicalPlan) extends AbstractSparkSQLParser { +class SparkSQLParser(fallback: String => LogicalPlan) extends AbstractSparkSQLParser { // A parser for the key-value part of the "SET [key = [value ]]" syntax private object SetCommandParser extends RegexParsers { @@ -100,14 +99,14 @@ private[sql] class SparkSQLParser(fallback: String => LogicalPlan) extends Abstr case _ ~ dbName => ShowTablesCommand(dbName) } | SHOW ~ FUNCTIONS ~> ((ident <~ ".").? ~ (ident | stringLit)).? ^^ { - case Some(f) => ShowFunctions(f._1, Some(f._2)) - case None => ShowFunctions(None, None) + case Some(f) => logical.ShowFunctions(f._1, Some(f._2)) + case None => logical.ShowFunctions(None, None) } ) private lazy val desc: Parser[LogicalPlan] = DESCRIBE ~ FUNCTION ~> EXTENDED.? ~ (ident | stringLit) ^^ { - case isExtended ~ functionName => DescribeFunction(functionName, isExtended.isDefined) + case isExtended ~ functionName => logical.DescribeFunction(functionName, isExtended.isDefined) } private lazy val others: Parser[LogicalPlan] = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 6137ce3a70fd..77dd5bc72508 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql - - import scala.language.implicitConversions import scala.reflect.runtime.universe.{TypeTag, typeTag} import scala.util.Try @@ -39,11 +37,11 @@ import org.apache.spark.util.Utils * "bridge" methods due to the use of covariant return types. 
* * {{{ - * In LegacyFunctions: - * public abstract org.apache.spark.sql.Column avg(java.lang.String); + * // In LegacyFunctions: + * public abstract org.apache.spark.sql.Column avg(java.lang.String); * - * In functions: - * public static org.apache.spark.sql.TypedColumn avg(...); + * // In functions: + * public static org.apache.spark.sql.TypedColumn avg(...); * }}} * * This allows us to use the same functions both in typed [[Dataset]] operations and untyped @@ -2528,8 +2526,9 @@ object functions extends LegacyFunctions { * @group udf_funcs * @since 1.3.0 * @deprecated As of 1.5.0, since it's redundant with udf() + * This will be removed in Spark 2.0. */ - @deprecated("Use udf", "1.5.0") + @deprecated("Use udf. This will be removed in Spark 2.0.", "1.5.0") def callUDF(f: Function0[_], returnType: DataType): Column = withExpr { ScalaUDF(f, returnType, Seq()) } @@ -2541,8 +2540,9 @@ object functions extends LegacyFunctions { * @group udf_funcs * @since 1.3.0 * @deprecated As of 1.5.0, since it's redundant with udf() + * This will be removed in Spark 2.0. */ - @deprecated("Use udf", "1.5.0") + @deprecated("Use udf. This will be removed in Spark 2.0.", "1.5.0") def callUDF(f: Function1[_, _], returnType: DataType, arg1: Column): Column = withExpr { ScalaUDF(f, returnType, Seq(arg1.expr)) } @@ -2554,8 +2554,9 @@ object functions extends LegacyFunctions { * @group udf_funcs * @since 1.3.0 * @deprecated As of 1.5.0, since it's redundant with udf() + * This will be removed in Spark 2.0. */ - @deprecated("Use udf", "1.5.0") + @deprecated("Use udf. This will be removed in Spark 2.0.", "1.5.0") def callUDF(f: Function2[_, _, _], returnType: DataType, arg1: Column, arg2: Column): Column = withExpr { ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr)) } @@ -2567,8 +2568,9 @@ object functions extends LegacyFunctions { * @group udf_funcs * @since 1.3.0 * @deprecated As of 1.5.0, since it's redundant with udf() + * This will be removed in Spark 2.0. */ - @deprecated("Use udf", "1.5.0") + @deprecated("Use udf. This will be removed in Spark 2.0.", "1.5.0") def callUDF(f: Function3[_, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column): Column = withExpr { ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr)) } @@ -2580,8 +2582,9 @@ object functions extends LegacyFunctions { * @group udf_funcs * @since 1.3.0 * @deprecated As of 1.5.0, since it's redundant with udf() + * This will be removed in Spark 2.0. */ - @deprecated("Use udf", "1.5.0") + @deprecated("Use udf. This will be removed in Spark 2.0.", "1.5.0") def callUDF(f: Function4[_, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column): Column = withExpr { ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr)) } @@ -2593,8 +2596,9 @@ object functions extends LegacyFunctions { * @group udf_funcs * @since 1.3.0 * @deprecated As of 1.5.0, since it's redundant with udf() + * This will be removed in Spark 2.0. */ - @deprecated("Use udf", "1.5.0") + @deprecated("Use udf. This will be removed in Spark 2.0.", "1.5.0") def callUDF(f: Function5[_, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column): Column = withExpr { ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr)) } @@ -2606,8 +2610,9 @@ object functions extends LegacyFunctions { * @group udf_funcs * @since 1.3.0 * @deprecated As of 1.5.0, since it's redundant with udf() + * This will be removed in Spark 2.0. 
*/ - @deprecated("Use udf", "1.5.0") + @deprecated("Use udf. This will be removed in Spark 2.0.", "1.5.0") def callUDF(f: Function6[_, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column): Column = withExpr { ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr)) } @@ -2619,8 +2624,9 @@ object functions extends LegacyFunctions { * @group udf_funcs * @since 1.3.0 * @deprecated As of 1.5.0, since it's redundant with udf() + * This will be removed in Spark 2.0. */ - @deprecated("Use udf", "1.5.0") + @deprecated("Use udf. This will be removed in Spark 2.0.", "1.5.0") def callUDF(f: Function7[_, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column): Column = withExpr { ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr)) } @@ -2632,8 +2638,9 @@ object functions extends LegacyFunctions { * @group udf_funcs * @since 1.3.0 * @deprecated As of 1.5.0, since it's redundant with udf() + * This will be removed in Spark 2.0. */ - @deprecated("Use udf", "1.5.0") + @deprecated("Use udf. This will be removed in Spark 2.0.", "1.5.0") def callUDF(f: Function8[_, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column): Column = withExpr { ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr)) } @@ -2644,9 +2651,10 @@ object functions extends LegacyFunctions { * * @group udf_funcs * @since 1.3.0 - * @deprecated As of 1.5.0, since it's redundant with udf() + * @deprecated As of 1.5.0, since it's redundant with udf(). + * This will be removed in Spark 2.0. */ - @deprecated("Use udf", "1.5.0") + @deprecated("Use udf. This will be removed in Spark 2.0.", "1.5.0") def callUDF(f: Function9[_, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column): Column = withExpr { ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr)) } @@ -2657,9 +2665,10 @@ object functions extends LegacyFunctions { * * @group udf_funcs * @since 1.3.0 - * @deprecated As of 1.5.0, since it's redundant with udf() + * @deprecated As of 1.5.0, since it's redundant with udf(). + * This will be removed in Spark 2.0. */ - @deprecated("Use udf", "1.5.0") + @deprecated("Use udf. This will be removed in Spark 2.0.", "1.5.0") def callUDF(f: Function10[_, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column): Column = withExpr { ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr)) } @@ -2700,9 +2709,10 @@ object functions extends LegacyFunctions { * * @group udf_funcs * @since 1.4.0 - * @deprecated As of 1.5.0, since it was not coherent to have two functions callUdf and callUDF + * @deprecated As of 1.5.0, since it was not coherent to have two functions callUdf and callUDF. + * This will be removed in Spark 2.0. */ - @deprecated("Use callUDF", "1.5.0") + @deprecated("Use callUDF. 
This will be removed in Spark 2.0.", "1.5.0") def callUdf(udfName: String, cols: Column*): Column = withExpr { // Note: we avoid using closures here because on file systems that are case-insensitive, the // compiled class file for the closure here will conflict with the one in callUDF (upper case). diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala index 81af684ba0bf..b554d135e4b5 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala @@ -80,7 +80,11 @@ class SimpleTextHadoopFsRelationSuite extends HadoopFsRelationTest with Predicat private var partitionedDF: DataFrame = _ - private val partitionedDataSchema: StructType = StructType('a.int :: 'b.int :: 'c.string :: Nil) + private val partitionedDataSchema: StructType = + new StructType() + .add("a", IntegerType) + .add("b", IntegerType) + .add("c", StringType) protected override def beforeAll(): Unit = { this.tempPath = Utils.createTempDir() From a5d988763319f63a8e2b58673dd4f9098f17c835 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Tue, 24 Nov 2015 20:58:47 -0800 Subject: [PATCH 1538/1824] [STREAMING][FLAKY-TEST] Catch execution context race condition in `FileBasedWriteAheadLog.close()` There is a race condition in `FileBasedWriteAheadLog.close()`, where if delete's of old log files are in progress, the write ahead log may close, and result in a `RejectedExecutionException`. This is okay, and should be handled gracefully. Example test failures: https://amplab.cs.berkeley.edu/jenkins/job/Spark-1.6-SBT/AMPLAB_JENKINS_BUILD_PROFILE=hadoop1.0,label=spark-test/95/testReport/junit/org.apache.spark.streaming.util/BatchedWriteAheadLogWithCloseFileAfterWriteSuite/BatchedWriteAheadLog___clean_old_logs/ The reason the test fails is in `afterEach`, `writeAheadLog.close` is called, and there may still be async deletes in flight. tdas zsxwing Author: Burak Yavuz Closes #9953 from brkyvz/flaky-ss. 
--- .../streaming/util/FileBasedWriteAheadLog.scala | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala index 72705f1a9c01..f5165f7c3912 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala @@ -17,7 +17,7 @@ package org.apache.spark.streaming.util import java.nio.ByteBuffer -import java.util.concurrent.ThreadPoolExecutor +import java.util.concurrent.{RejectedExecutionException, ThreadPoolExecutor} import java.util.{Iterator => JIterator} import scala.collection.JavaConverters._ @@ -176,10 +176,16 @@ private[streaming] class FileBasedWriteAheadLog( } oldLogFiles.foreach { logInfo => if (!executionContext.isShutdown) { - val f = Future { deleteFile(logInfo) }(executionContext) - if (waitForCompletion) { - import scala.concurrent.duration._ - Await.ready(f, 1 second) + try { + val f = Future { deleteFile(logInfo) }(executionContext) + if (waitForCompletion) { + import scala.concurrent.duration._ + Await.ready(f, 1 second) + } + } catch { + case e: RejectedExecutionException => + logWarning("Execution context shutdown before deleting old WriteAheadLogs. " + + "This would not affect recovery correctness.", e) } } } From 151d7c2baf18403e6e59e97c80c8bcded6148038 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 24 Nov 2015 21:30:53 -0800 Subject: [PATCH 1539/1824] [SPARK-10621][SQL] Consistent naming for functions in SQL, Python, Scala Author: Reynold Xin Closes #9948 from rxin/SPARK-10621. --- python/pyspark/sql/functions.py | 111 +++++++++++++--- .../org/apache/spark/sql/functions.scala | 124 ++++++++++++++---- 2 files changed, 196 insertions(+), 39 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index a1ca723bbd7a..e3786e0fa5fb 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -150,18 +150,18 @@ def _(): _window_functions = { 'rowNumber': - """returns a sequential number starting at 1 within a window partition. - - This is equivalent to the ROW_NUMBER function in SQL.""", + """.. note:: Deprecated in 1.6, use row_number instead.""", + 'row_number': + """returns a sequential number starting at 1 within a window partition.""", 'denseRank': + """.. note:: Deprecated in 1.6, use dense_rank instead.""", + 'dense_rank': """returns the rank of rows within a window partition, without any gaps. The difference between rank and denseRank is that denseRank leaves no gaps in ranking sequence when there are ties. That is, if you were ranking a competition using denseRank and had three people tie for second place, you would say that all three were in second - place and that the next person came in third. - - This is equivalent to the DENSE_RANK function in SQL.""", + place and that the next person came in third.""", 'rank': """returns the rank of rows within a window partition. @@ -172,14 +172,14 @@ def _(): This is equivalent to the RANK function in SQL.""", 'cumeDist': + """.. note:: Deprecated in 1.6, use cume_dist instead.""", + 'cume_dist': """returns the cumulative distribution of values within a window partition, - i.e. the fraction of rows that are below the current row. - - This is equivalent to the CUME_DIST function in SQL.""", + i.e. 
the fraction of rows that are below the current row.""", 'percentRank': - """returns the relative rank (i.e. percentile) of rows within a window partition. - - This is equivalent to the PERCENT_RANK function in SQL.""", + """.. note:: Deprecated in 1.6, use percent_rank instead.""", + 'percent_rank': + """returns the relative rank (i.e. percentile) of rows within a window partition.""", } for _name, _doc in _functions.items(): @@ -189,7 +189,7 @@ def _(): for _name, _doc in _binary_mathfunctions.items(): globals()[_name] = since(1.4)(_create_binary_mathfunction(_name, _doc)) for _name, _doc in _window_functions.items(): - globals()[_name] = since(1.4)(_create_window_function(_name, _doc)) + globals()[_name] = since(1.6)(_create_window_function(_name, _doc)) for _name, _doc in _functions_1_6.items(): globals()[_name] = since(1.6)(_create_function(_name, _doc)) del _name, _doc @@ -288,6 +288,38 @@ def countDistinct(col, *cols): @since(1.4) def monotonicallyIncreasingId(): + """ + .. note:: Deprecated in 1.6, use monotonically_increasing_id instead. + """ + return monotonically_increasing_id() + + +@since(1.6) +def input_file_name(): + """Creates a string column for the file name of the current Spark task. + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.input_file_name()) + + +@since(1.6) +def isnan(col): + """An expression that returns true iff the column is NaN. + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.isnan(_to_java_column(col))) + + +@since(1.6) +def isnull(col): + """An expression that returns true iff the column is null. + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.isnull(_to_java_column(col))) + + +@since(1.6) +def monotonically_increasing_id(): """A column that generates monotonically increasing 64-bit integers. The generated ID is guaranteed to be monotonically increasing and unique, but not consecutive. @@ -300,11 +332,21 @@ def monotonicallyIncreasingId(): 0, 1, 2, 8589934592 (1L << 33), 8589934593, 8589934594. >>> df0 = sc.parallelize(range(2), 2).mapPartitions(lambda x: [(1,), (2,), (3,)]).toDF(['col1']) - >>> df0.select(monotonicallyIncreasingId().alias('id')).collect() + >>> df0.select(monotonically_increasing_id().alias('id')).collect() [Row(id=0), Row(id=1), Row(id=2), Row(id=8589934592), Row(id=8589934593), Row(id=8589934594)] """ sc = SparkContext._active_spark_context - return Column(sc._jvm.functions.monotonicallyIncreasingId()) + return Column(sc._jvm.functions.monotonically_increasing_id()) + + +@since(1.6) +def nanvl(col1, col2): + """Returns col1 if it is not NaN, or col2 if col1 is NaN. + + Both inputs should be floating point columns (DoubleType or FloatType). + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.nanvl(_to_java_column(col1), _to_java_column(col2))) @since(1.4) @@ -382,15 +424,23 @@ def shiftRightUnsigned(col, numBits): @since(1.4) def sparkPartitionId(): + """ + .. note:: Deprecated in 1.6, use spark_partition_id instead. + """ + return spark_partition_id() + + +@since(1.6) +def spark_partition_id(): """A column for partition ID of the Spark task. Note that this is indeterministic because it depends on data partitioning and task scheduling. 
- >>> df.repartition(1).select(sparkPartitionId().alias("pid")).collect() + >>> df.repartition(1).select(spark_partition_id().alias("pid")).collect() [Row(pid=0), Row(pid=0)] """ sc = SparkContext._active_spark_context - return Column(sc._jvm.functions.sparkPartitionId()) + return Column(sc._jvm.functions.spark_partition_id()) @since(1.5) @@ -1410,6 +1460,33 @@ def explode(col): return Column(jc) +@since(1.6) +def get_json_object(col, path): + """ + Extracts json object from a json string based on json path specified, and returns json string + of the extracted json object. It will return null if the input json string is invalid. + + :param col: string column in json format + :param path: path to the json object to extract + """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.get_json_object(_to_java_column(col), path) + return Column(jc) + + +@since(1.6) +def json_tuple(col, fields): + """Creates a new row for a json column according to the given field names. + + :param col: string column in json format + :param fields: list of fields to extract + + """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.json_tuple(_to_java_column(col), fields) + return Column(jc) + + @since(1.5) def size(col): """ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 77dd5bc72508..276c5dfc8b06 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -472,6 +472,13 @@ object functions extends LegacyFunctions { // Window functions ////////////////////////////////////////////////////////////////////////////////////////////// + /** + * @group window_funcs + * @deprecated As of 1.6.0, replaced by `cume_dist`. This will be removed in Spark 2.0. + */ + @deprecated("Use cume_dist. This will be removed in Spark 2.0.", "1.6.0") + def cumeDist(): Column = cume_dist() + /** * Window function: returns the cumulative distribution of values within a window partition, * i.e. the fraction of rows that are below the current row. @@ -481,13 +488,17 @@ object functions extends LegacyFunctions { * cumeDist(x) = number of values before (and including) x / N * }}} * - * - * This is equivalent to the CUME_DIST function in SQL. - * * @group window_funcs - * @since 1.4.0 + * @since 1.6.0 */ - def cumeDist(): Column = withExpr { UnresolvedWindowFunction("cume_dist", Nil) } + def cume_dist(): Column = withExpr { UnresolvedWindowFunction("cume_dist", Nil) } + + /** + * @group window_funcs + * @deprecated As of 1.6.0, replaced by `dense_rank`. This will be removed in Spark 2.0. + */ + @deprecated("Use dense_rank. This will be removed in Spark 2.0.", "1.6.0") + def denseRank(): Column = dense_rank() /** * Window function: returns the rank of rows within a window partition, without any gaps. @@ -497,12 +508,10 @@ object functions extends LegacyFunctions { * and had three people tie for second place, you would say that all three were in second * place and that the next person came in third. * - * This is equivalent to the DENSE_RANK function in SQL. 
- * * @group window_funcs - * @since 1.4.0 + * @since 1.6.0 */ - def denseRank(): Column = withExpr { UnresolvedWindowFunction("dense_rank", Nil) } + def dense_rank(): Column = withExpr { UnresolvedWindowFunction("dense_rank", Nil) } /** * Window function: returns the value that is `offset` rows before the current row, and @@ -620,6 +629,13 @@ object functions extends LegacyFunctions { */ def ntile(n: Int): Column = withExpr { UnresolvedWindowFunction("ntile", lit(n).expr :: Nil) } + /** + * @group window_funcs + * @deprecated As of 1.6.0, replaced by `percent_rank`. This will be removed in Spark 2.0. + */ + @deprecated("Use percent_rank. This will be removed in Spark 2.0.", "1.6.0") + def percentRank(): Column = percent_rank() + /** * Window function: returns the relative rank (i.e. percentile) of rows within a window partition. * @@ -631,9 +647,9 @@ object functions extends LegacyFunctions { * This is equivalent to the PERCENT_RANK function in SQL. * * @group window_funcs - * @since 1.4.0 + * @since 1.6.0 */ - def percentRank(): Column = withExpr { UnresolvedWindowFunction("percent_rank", Nil) } + def percent_rank(): Column = withExpr { UnresolvedWindowFunction("percent_rank", Nil) } /** * Window function: returns the rank of rows within a window partition. @@ -650,15 +666,20 @@ object functions extends LegacyFunctions { */ def rank(): Column = withExpr { UnresolvedWindowFunction("rank", Nil) } + /** + * @group window_funcs + * @deprecated As of 1.6.0, replaced by `row_number`. This will be removed in Spark 2.0. + */ + @deprecated("Use row_number. This will be removed in Spark 2.0.", "1.6.0") + def rowNumber(): Column = row_number() + /** * Window function: returns a sequential number starting at 1 within a window partition. * - * This is equivalent to the ROW_NUMBER function in SQL. - * * @group window_funcs - * @since 1.4.0 + * @since 1.6.0 */ - def rowNumber(): Column = withExpr { UnresolvedWindowFunction("row_number", Nil) } + def row_number(): Column = withExpr { UnresolvedWindowFunction("row_number", Nil) } ////////////////////////////////////////////////////////////////////////////////////////////// // Non-aggregate functions @@ -720,20 +741,43 @@ object functions extends LegacyFunctions { @scala.annotation.varargs def coalesce(e: Column*): Column = withExpr { Coalesce(e.map(_.expr)) } + /** + * @group normal_funcs + * @deprecated As of 1.6.0, replaced by `input_file_name`. This will be removed in Spark 2.0. + */ + @deprecated("Use input_file_name. This will be removed in Spark 2.0.", "1.6.0") + def inputFileName(): Column = input_file_name() + /** * Creates a string column for the file name of the current Spark task. * * @group normal_funcs + * @since 1.6.0 */ - def inputFileName(): Column = withExpr { InputFileName() } + def input_file_name(): Column = withExpr { InputFileName() } + + /** + * @group normal_funcs + * @deprecated As of 1.6.0, replaced by `isnan`. This will be removed in Spark 2.0. + */ + @deprecated("Use isnan. This will be removed in Spark 2.0.", "1.6.0") + def isNaN(e: Column): Column = isnan(e) /** * Return true iff the column is NaN. * * @group normal_funcs - * @since 1.5.0 + * @since 1.6.0 + */ + def isnan(e: Column): Column = withExpr { IsNaN(e.expr) } + + /** + * Return true iff the column is null. + * + * @group normal_funcs + * @since 1.6.0 */ - def isNaN(e: Column): Column = withExpr { IsNaN(e.expr) } + def isnull(e: Column): Column = withExpr { IsNull(e.expr) } /** * A column expression that generates monotonically increasing 64-bit integers. 
@@ -750,7 +794,24 @@ object functions extends LegacyFunctions { * @group normal_funcs * @since 1.4.0 */ - def monotonicallyIncreasingId(): Column = withExpr { MonotonicallyIncreasingID() } + def monotonicallyIncreasingId(): Column = monotonically_increasing_id() + + /** + * A column expression that generates monotonically increasing 64-bit integers. + * + * The generated ID is guaranteed to be monotonically increasing and unique, but not consecutive. + * The current implementation puts the partition ID in the upper 31 bits, and the record number + * within each partition in the lower 33 bits. The assumption is that the data frame has + * less than 1 billion partitions, and each partition has less than 8 billion records. + * + * As an example, consider a [[DataFrame]] with two partitions, each with 3 records. + * This expression would return the following IDs: + * 0, 1, 2, 8589934592 (1L << 33), 8589934593, 8589934594. + * + * @group normal_funcs + * @since 1.6.0 + */ + def monotonically_increasing_id(): Column = withExpr { MonotonicallyIncreasingID() } /** * Returns col1 if it is not NaN, or col2 if col1 is NaN. @@ -825,15 +886,23 @@ object functions extends LegacyFunctions { */ def randn(): Column = randn(Utils.random.nextLong) + /** + * @group normal_funcs + * @since 1.4.0 + * @deprecated As of 1.6.0, replaced by `spark_partition_id`. This will be removed in Spark 2.0. + */ + @deprecated("Use cume_dist. This will be removed in Spark 2.0.", "1.6.0") + def sparkPartitionId(): Column = spark_partition_id() + /** * Partition ID of the Spark task. * * Note that this is indeterministic because it depends on data partitioning and task scheduling. * * @group normal_funcs - * @since 1.4.0 + * @since 1.6.0 */ - def sparkPartitionId(): Column = withExpr { SparkPartitionID() } + def spark_partition_id(): Column = withExpr { SparkPartitionID() } /** * Computes the square root of the specified float value. @@ -2305,6 +2374,17 @@ object functions extends LegacyFunctions { */ def explode(e: Column): Column = withExpr { Explode(e.expr) } + /** + * Extracts json object from a json string based on json path specified, and returns json string + * of the extracted json object. It will return null if the input json string is invalid. + * + * @group collection_funcs + * @since 1.6.0 + */ + def get_json_object(e: Column, path: String): Column = withExpr { + GetJsonObject(e.expr, lit(path).expr) + } + /** * Creates a new row for a json column according to the given field names. * @@ -2313,7 +2393,7 @@ object functions extends LegacyFunctions { */ @scala.annotation.varargs def json_tuple(json: Column, fields: String*): Column = withExpr { - require(fields.length > 0, "at least 1 field name should be given.") + require(fields.nonEmpty, "at least 1 field name should be given.") JsonTuple(json.expr +: fields.map(Literal.apply)) } From 2169886883d33b33acf378ac42a626576b342df1 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 24 Nov 2015 23:13:01 -0800 Subject: [PATCH 1540/1824] [SPARK-11979][STREAMING] Empty TrackStateRDD cannot be checkpointed and recovered from checkpoint file This solves the following exception caused when empty state RDD is checkpointed and recovered. The root cause is that an empty OpenHashMapBasedStateMap cannot be deserialized as the initialCapacity is set to zero. 
``` Job aborted due to stage failure: Task 0 in stage 6.0 failed 1 times, most recent failure: Lost task 0.0 in stage 6.0 (TID 20, localhost): java.lang.IllegalArgumentException: requirement failed: Invalid initial capacity at scala.Predef$.require(Predef.scala:233) at org.apache.spark.streaming.util.OpenHashMapBasedStateMap.(StateMap.scala:96) at org.apache.spark.streaming.util.OpenHashMapBasedStateMap.(StateMap.scala:86) at org.apache.spark.streaming.util.OpenHashMapBasedStateMap.readObject(StateMap.scala:291) at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method) at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:57) at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43) at java.lang.reflect.Method.invoke(Method.java:606) at java.io.ObjectStreamClass.invokeReadObject(ObjectStreamClass.java:1017) at java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:1893) at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:1798) at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1350) at java.io.ObjectInputStream.defaultReadFields(ObjectInputStream.java:1990) at java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:1915) at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:1798) at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1350) at java.io.ObjectInputStream.readObject(ObjectInputStream.java:370) at org.apache.spark.serializer.JavaDeserializationStream.readObject(JavaSerializer.scala:76) at org.apache.spark.serializer.DeserializationStream$$anon$1.getNext(Serializer.scala:181) at org.apache.spark.util.NextIterator.hasNext(NextIterator.scala:73) at scala.collection.Iterator$$anon$13.hasNext(Iterator.scala:371) at scala.collection.Iterator$class.foreach(Iterator.scala:727) at scala.collection.AbstractIterator.foreach(Iterator.scala:1157) at scala.collection.generic.Growable$class.$plus$plus$eq(Growable.scala:48) at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:103) at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:47) at scala.collection.TraversableOnce$class.to(TraversableOnce.scala:273) at scala.collection.AbstractIterator.to(Iterator.scala:1157) at scala.collection.TraversableOnce$class.toBuffer(TraversableOnce.scala:265) at scala.collection.AbstractIterator.toBuffer(Iterator.scala:1157) at scala.collection.TraversableOnce$class.toArray(TraversableOnce.scala:252) at scala.collection.AbstractIterator.toArray(Iterator.scala:1157) at org.apache.spark.rdd.RDD$$anonfun$collect$1$$anonfun$12.apply(RDD.scala:921) at org.apache.spark.rdd.RDD$$anonfun$collect$1$$anonfun$12.apply(RDD.scala:921) at org.apache.spark.SparkContext$$anonfun$runJob$5.apply(SparkContext.scala:1858) at org.apache.spark.SparkContext$$anonfun$runJob$5.apply(SparkContext.scala:1858) at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:66) at org.apache.spark.scheduler.Task.run(Task.scala:88) at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:214) at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1145) at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:615) at java.lang.Thread.run(Thread.java:744) ``` Author: Tathagata Das Closes #9958 from tdas/SPARK-11979. 
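To make the root cause concrete, a small sketch of the clamp-style fix. `TinyOpenHashMap` is a hypothetical stand-in that models only the `require` check from Spark's `OpenHashMap`; the method and constant names are illustrative:

```scala
object InitialCapacitySketch {
  // Models just the capacity check that produces the reported IllegalArgumentException.
  final class TinyOpenHashMap(initialCapacity: Int) {
    require(initialCapacity >= 1, "Invalid initial capacity")
    private val buckets = new Array[Any](initialCapacity)
    def capacity: Int = buckets.length
  }

  val DefaultInitialCapacity = 64

  // Before the fix: the size hint read back from the checkpoint (0 for an empty
  // parent map) was used directly as the initial capacity, so deserialization failed.
  def rebuildUnsafe(sizeHint: Int): TinyOpenHashMap = new TinyOpenHashMap(sizeHint)

  // After the fix: clamp the hint to a sensible minimum before allocating.
  def rebuildSafe(sizeHint: Int): TinyOpenHashMap =
    new TinyOpenHashMap(math.max(sizeHint, DefaultInitialCapacity))

  def main(args: Array[String]): Unit = {
    println(rebuildSafe(0).capacity)    // 64 instead of an IllegalArgumentException
    println(rebuildSafe(1000).capacity) // 1000: a non-trivial hint is still honoured
    // rebuildUnsafe(0) would throw:
    //   java.lang.IllegalArgumentException: requirement failed: Invalid initial capacity
  }
}
```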
--- .../spark/streaming/util/StateMap.scala | 19 +++++++----- .../spark/streaming/StateMapSuite.scala | 30 ++++++++++++------- .../streaming/rdd/TrackStateRDDSuite.scala | 10 +++++++ 3 files changed, 42 insertions(+), 17 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala index 34287c3e0090..3f139ad138c8 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala @@ -59,7 +59,7 @@ private[streaming] object StateMap { def create[K: ClassTag, S: ClassTag](conf: SparkConf): StateMap[K, S] = { val deltaChainThreshold = conf.getInt("spark.streaming.sessionByKey.deltaChainThreshold", DELTA_CHAIN_LENGTH_THRESHOLD) - new OpenHashMapBasedStateMap[K, S](64, deltaChainThreshold) + new OpenHashMapBasedStateMap[K, S](deltaChainThreshold) } } @@ -79,7 +79,7 @@ private[streaming] class EmptyStateMap[K: ClassTag, S: ClassTag] extends StateMa /** Implementation of StateMap based on Spark's [[org.apache.spark.util.collection.OpenHashMap]] */ private[streaming] class OpenHashMapBasedStateMap[K: ClassTag, S: ClassTag]( @transient @volatile var parentStateMap: StateMap[K, S], - initialCapacity: Int = 64, + initialCapacity: Int = DEFAULT_INITIAL_CAPACITY, deltaChainThreshold: Int = DELTA_CHAIN_LENGTH_THRESHOLD ) extends StateMap[K, S] { self => @@ -89,12 +89,14 @@ private[streaming] class OpenHashMapBasedStateMap[K: ClassTag, S: ClassTag]( deltaChainThreshold = deltaChainThreshold) def this(deltaChainThreshold: Int) = this( - initialCapacity = 64, deltaChainThreshold = deltaChainThreshold) + initialCapacity = DEFAULT_INITIAL_CAPACITY, deltaChainThreshold = deltaChainThreshold) def this() = this(DELTA_CHAIN_LENGTH_THRESHOLD) - @transient @volatile private var deltaMap = - new OpenHashMap[K, StateInfo[S]](initialCapacity) + require(initialCapacity >= 1, "Invalid initial capacity") + require(deltaChainThreshold >= 1, "Invalid delta chain threshold") + + @transient @volatile private var deltaMap = new OpenHashMap[K, StateInfo[S]](initialCapacity) /** Get the session data if it exists */ override def get(key: K): Option[S] = { @@ -284,9 +286,10 @@ private[streaming] class OpenHashMapBasedStateMap[K: ClassTag, S: ClassTag]( // Read the data of the parent map. 
Keep reading records, until the limiter is reached // First read the approximate number of records to expect and allocate properly size // OpenHashMap - val parentSessionStoreSizeHint = inputStream.readInt() + val parentStateMapSizeHint = inputStream.readInt() + val newStateMapInitialCapacity = math.max(parentStateMapSizeHint, DEFAULT_INITIAL_CAPACITY) val newParentSessionStore = new OpenHashMapBasedStateMap[K, S]( - initialCapacity = parentSessionStoreSizeHint, deltaChainThreshold) + initialCapacity = newStateMapInitialCapacity, deltaChainThreshold) // Read the records until the limit marking object has been reached var parentSessionLoopDone = false @@ -338,4 +341,6 @@ private[streaming] object OpenHashMapBasedStateMap { class LimitMarker(val num: Int) extends Serializable val DELTA_CHAIN_LENGTH_THRESHOLD = 20 + + val DEFAULT_INITIAL_CAPACITY = 64 } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala index 48d3b41b66cb..c4a01eaea739 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala @@ -122,23 +122,27 @@ class StateMapSuite extends SparkFunSuite { test("OpenHashMapBasedStateMap - serializing and deserializing") { val map1 = new OpenHashMapBasedStateMap[Int, Int]() + testSerialization(map1, "error deserializing and serialized empty map") + map1.put(1, 100, 1) map1.put(2, 200, 2) + testSerialization(map1, "error deserializing and serialized map with data + no delta") val map2 = map1.copy() + // Do not test compaction + assert(map2.asInstanceOf[OpenHashMapBasedStateMap[_, _]].shouldCompact === false) + testSerialization(map2, "error deserializing and serialized map with 1 delta + no new data") + map2.put(3, 300, 3) map2.put(4, 400, 4) + testSerialization(map2, "error deserializing and serialized map with 1 delta + new data") val map3 = map2.copy() + assert(map3.asInstanceOf[OpenHashMapBasedStateMap[_, _]].shouldCompact === false) + testSerialization(map3, "error deserializing and serialized map with 2 delta + no new data") map3.put(3, 600, 3) map3.remove(2) - - // Do not test compaction - assert(map3.asInstanceOf[OpenHashMapBasedStateMap[_, _]].shouldCompact === false) - - val deser_map3 = Utils.deserialize[StateMap[Int, Int]]( - Utils.serialize(map3), Thread.currentThread().getContextClassLoader) - assertMap(deser_map3, map3, 1, "Deserialized map not same as original map") + testSerialization(map3, "error deserializing and serialized map with 2 delta + new data") } test("OpenHashMapBasedStateMap - serializing and deserializing with compaction") { @@ -156,11 +160,9 @@ class StateMapSuite extends SparkFunSuite { assert(map.deltaChainLength > deltaChainThreshold) assert(map.shouldCompact === true) - val deser_map = Utils.deserialize[OpenHashMapBasedStateMap[Int, Int]]( - Utils.serialize(map), Thread.currentThread().getContextClassLoader) + val deser_map = testSerialization(map, "Deserialized + compacted map not same as original map") assert(deser_map.deltaChainLength < deltaChainThreshold) assert(deser_map.shouldCompact === false) - assertMap(deser_map, map, 1, "Deserialized + compacted map not same as original map") } test("OpenHashMapBasedStateMap - all possible sequences of operations with copies ") { @@ -265,6 +267,14 @@ class StateMapSuite extends SparkFunSuite { assertMap(stateMap, refMap.toMap, time, "Final state map does not match reference map") } + private def 
testSerialization[MapType <: StateMap[Int, Int]]( + map: MapType, msg: String): MapType = { + val deserMap = Utils.deserialize[MapType]( + Utils.serialize(map), Thread.currentThread().getContextClassLoader) + assertMap(deserMap, map, 1, msg) + deserMap + } + // Assert whether all the data and operations on a state map matches that of a reference state map private def assertMap( mapToTest: StateMap[Int, Int], diff --git a/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala index 0feb3af1abb0..3b2d43f2ce58 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala @@ -332,6 +332,16 @@ class TrackStateRDDSuite extends SparkFunSuite with RDDCheckpointTester with Bef makeStateRDDWithLongLineageParenttateRDD, reliableCheckpoint = true, rddCollectFunc _) } + test("checkpointing empty state RDD") { + val emptyStateRDD = TrackStateRDD.createFromPairRDD[Int, Int, Int, Int]( + sc.emptyRDD[(Int, Int)], new HashPartitioner(10), Time(0)) + emptyStateRDD.checkpoint() + assert(emptyStateRDD.flatMap { _.stateMap.getAll() }.collect().isEmpty) + val cpRDD = sc.checkpointFile[TrackStateRDDRecord[Int, Int, Int]]( + emptyStateRDD.getCheckpointFile.get) + assert(cpRDD.flatMap { _.stateMap.getAll() }.collect().isEmpty) + } + /** Assert whether the `trackStateByKey` operation generates expected results */ private def assertOperation[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag]( testStateRDD: TrackStateRDD[K, V, S, T], From 2610e06124c7fc0b2b1cfb2e3050a35ab492fb71 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Wed, 25 Nov 2015 01:02:36 -0800 Subject: [PATCH 1541/1824] [SPARK-11970][SQL] Adding JoinType into JoinWith and support Sample in Dataset API Join types other than inner joins may also be useful when using the joinWith function, so this change adds a joinType parameter to the existing joinWith call in the Dataset API. It also provides another joinWith interface for cartesian-join-like functionality. Please provide your opinions. marmbrus rxin cloud-fan Thank you! Author: gatorsmile Closes #9921 from gatorsmile/joinWith.
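For reference, a hedged usage sketch of the API surface added here, written against the 1.6-era Dataset API. It assumes a `sqlContext` is already in scope, and the `NullableData` case class mirrors the suite's use of `java.lang.Integer` so that outer-join misses can be represented as nulls:

```scala
import sqlContext.implicits._

// Integer rather than Int so the unmatched side of an outer join can be null.
case class NullableData(a: String, b: Integer)

val left  = Seq(NullableData("a", 1), NullableData("c", 3)).toDS()
val right = Seq(("a", new Integer(1)), ("b", new Integer(2))).toDS()

// New three-argument joinWith; joinType is one of "inner", "outer",
// "left_outer", "right_outer", "leftsemi" (per the added scaladoc).
val outerJoined = left.joinWith(right, $"a" === $"_1", "outer")

// The existing two-argument form now delegates to the inner equi-join.
val innerJoined = left.joinWith(right, $"a" === $"_1")

// New sample support on Dataset, with and without an explicit seed.
val sampledFixedSeed  = left.sample(withReplacement = false, fraction = 0.5, seed = 13L)
val sampledRandomSeed = left.sample(withReplacement = false, fraction = 0.5)
```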
--- .../scala/org/apache/spark/sql/Dataset.scala | 45 +++++++++++++++---- .../org/apache/spark/sql/DatasetSuite.scala | 36 ++++++++++++--- 2 files changed, 65 insertions(+), 16 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index dd84b8bc11e2..97eb5b969280 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -20,16 +20,16 @@ package org.apache.spark.sql import scala.collection.JavaConverters._ import org.apache.spark.annotation.Experimental -import org.apache.spark.rdd.RDD import org.apache.spark.api.java.function._ - +import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.encoders._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.analysis.UnresolvedAlias -import org.apache.spark.sql.catalyst.plans.Inner +import org.apache.spark.sql.catalyst.plans.JoinType import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.{Queryable, QueryExecution} import org.apache.spark.sql.types.StructType +import org.apache.spark.util.Utils /** * :: Experimental :: @@ -83,7 +83,6 @@ class Dataset[T] private[sql]( /** * Returns the schema of the encoded form of the objects in this [[Dataset]]. - * * @since 1.6.0 */ def schema: StructType = resolvedTEncoder.schema @@ -185,7 +184,6 @@ class Dataset[T] private[sql]( * .transform(featurize) * .transform(...) * }}} - * * @since 1.6.0 */ def transform[U](t: Dataset[T] => Dataset[U]): Dataset[U] = t(this) @@ -453,6 +451,21 @@ class Dataset[T] private[sql]( c5: TypedColumn[T, U5]): Dataset[(U1, U2, U3, U4, U5)] = selectUntyped(c1, c2, c3, c4, c5).asInstanceOf[Dataset[(U1, U2, U3, U4, U5)]] + /** + * Returns a new [[Dataset]] by sampling a fraction of records. + * @since 1.6.0 + */ + def sample(withReplacement: Boolean, fraction: Double, seed: Long) : Dataset[T] = + withPlan(Sample(0.0, fraction, withReplacement, seed, _)) + + /** + * Returns a new [[Dataset]] by sampling a fraction of records, using a random seed. + * @since 1.6.0 + */ + def sample(withReplacement: Boolean, fraction: Double) : Dataset[T] = { + sample(withReplacement, fraction, Utils.random.nextLong) + } + /* **************** * * Set operations * * **************** */ @@ -511,13 +524,17 @@ class Dataset[T] private[sql]( * types as well as working with relational data where either side of the join has column * names in common. * + * @param other Right side of the join. + * @param condition Join expression. + * @param joinType One of: `inner`, `outer`, `left_outer`, `right_outer`, `leftsemi`. * @since 1.6.0 */ - def joinWith[U](other: Dataset[U], condition: Column): Dataset[(T, U)] = { + def joinWith[U](other: Dataset[U], condition: Column, joinType: String): Dataset[(T, U)] = { val left = this.logicalPlan val right = other.logicalPlan - val joined = sqlContext.executePlan(Join(left, right, Inner, Some(condition.expr))) + val joined = sqlContext.executePlan(Join(left, right, joinType = + JoinType(joinType), Some(condition.expr))) val leftOutput = joined.analyzed.output.take(left.output.length) val rightOutput = joined.analyzed.output.takeRight(right.output.length) @@ -540,6 +557,18 @@ class Dataset[T] private[sql]( } } + /** + * Using inner equi-join to join this [[Dataset]] returning a [[Tuple2]] for each pair + * where `condition` evaluates to true. + * + * @param other Right side of the join. + * @param condition Join expression. 
+ * @since 1.6.0 + */ + def joinWith[U](other: Dataset[U], condition: Column): Dataset[(T, U)] = { + joinWith(other, condition, "inner") + } + /* ************************** * * Gather to Driver Actions * * ************************** */ @@ -584,7 +613,6 @@ class Dataset[T] private[sql]( * * Running take requires moving data into the application's driver process, and doing so with * a very large `n` can crash the driver process with OutOfMemoryError. - * * @since 1.6.0 */ def take(num: Int): Array[T] = withPlan(Limit(Literal(num), _)).collect() @@ -594,7 +622,6 @@ class Dataset[T] private[sql]( * * Running take requires moving data into the application's driver process, and doing so with * a very large `n` can crash the driver process with OutOfMemoryError. - * * @since 1.6.0 */ def takeAsList(num: Int): java.util.List[T] = java.util.Arrays.asList(take(num) : _*) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index c253fdbb8c99..7d539180ded9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -185,17 +185,23 @@ class DatasetSuite extends QueryTest with SharedSQLContext { val ds2 = Seq(1, 2).toDS().as("b") checkAnswer( - ds1.joinWith(ds2, $"a.value" === $"b.value"), + ds1.joinWith(ds2, $"a.value" === $"b.value", "inner"), (1, 1), (2, 2)) } - test("joinWith, expression condition") { - val ds1 = Seq(ClassData("a", 1), ClassData("b", 2)).toDS() - val ds2 = Seq(("a", 1), ("b", 2)).toDS() + test("joinWith, expression condition, outer join") { + val nullInteger = null.asInstanceOf[Integer] + val nullString = null.asInstanceOf[String] + val ds1 = Seq(ClassNullableData("a", 1), + ClassNullableData("c", 3)).toDS() + val ds2 = Seq(("a", new Integer(1)), + ("b", new Integer(2))).toDS() checkAnswer( - ds1.joinWith(ds2, $"_1" === $"a"), - (ClassData("a", 1), ("a", 1)), (ClassData("b", 2), ("b", 2))) + ds1.joinWith(ds2, $"_1" === $"a", "outer"), + (ClassNullableData("a", 1), ("a", new Integer(1))), + (ClassNullableData("c", 3), (nullString, nullInteger)), + (ClassNullableData(nullString, nullInteger), ("b", new Integer(2)))) } test("joinWith tuple with primitive, expression") { @@ -225,7 +231,6 @@ class DatasetSuite extends QueryTest with SharedSQLContext { ds1.joinWith(ds2, $"a._2" === $"b._2").as("ab").joinWith(ds3, $"ab._1._2" === $"c._2"), ((("a", 1), ("a", 1)), ("a", 1)), ((("b", 2), ("b", 2)), ("b", 2))) - } test("groupBy function, keys") { @@ -367,6 +372,22 @@ class DatasetSuite extends QueryTest with SharedSQLContext { 1 -> "a", 2 -> "bc", 3 -> "d") } + test("sample with replacement") { + val n = 100 + val data = sparkContext.parallelize(1 to n, 2).toDS() + checkAnswer( + data.sample(withReplacement = true, 0.05, seed = 13), + 5, 10, 52, 73) + } + + test("sample without replacement") { + val n = 100 + val data = sparkContext.parallelize(1 to n, 2).toDS() + checkAnswer( + data.sample(withReplacement = false, 0.05, seed = 13), + 3, 17, 27, 58, 62) + } + test("SPARK-11436: we should rebind right encoder when join 2 datasets") { val ds1 = Seq("1", "2").toDS().as("a") val ds2 = Seq(2, 3).toDS().as("b") @@ -440,6 +461,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { case class ClassData(a: String, b: Int) +case class ClassNullableData(a: String, b: Integer) /** * A class used to test serialization using encoders. 
This class throws exceptions when using From a0f1a11837bfffb76582499d36fbaf21a1d628cb Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 25 Nov 2015 01:03:18 -0800 Subject: [PATCH 1542/1824] [SPARK-11981][SQL] Move implementations of methods back to DataFrame from Queryable Also added show methods to Dataset. Author: Reynold Xin Closes #9964 from rxin/SPARK-11981. --- .../org/apache/spark/sql/DataFrame.scala | 35 ++++++++- .../scala/org/apache/spark/sql/Dataset.scala | 77 ++++++++++++++++++- .../spark/sql/execution/Queryable.scala | 32 ++------ 3 files changed, 111 insertions(+), 33 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 5eca1db9525e..d8319b9a97fc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -37,7 +37,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.{Inner, JoinType} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection, SqlParser} -import org.apache.spark.sql.execution.{EvaluatePython, FileRelation, LogicalRDD, QueryExecution, Queryable, SQLExecution} +import org.apache.spark.sql.execution.{EvaluatePython, ExplainCommand, FileRelation, LogicalRDD, QueryExecution, Queryable, SQLExecution} import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, LogicalRelation} import org.apache.spark.sql.execution.datasources.json.JacksonGenerator import org.apache.spark.sql.sources.HadoopFsRelation @@ -112,8 +112,8 @@ private[sql] object DataFrame { */ @Experimental class DataFrame private[sql]( - @transient val sqlContext: SQLContext, - @DeveloperApi @transient val queryExecution: QueryExecution) + @transient override val sqlContext: SQLContext, + @DeveloperApi @transient override val queryExecution: QueryExecution) extends Queryable with Serializable { // Note for Spark contributors: if adding or updating any action in `DataFrame`, please make sure @@ -282,6 +282,35 @@ class DataFrame private[sql]( */ def schema: StructType = queryExecution.analyzed.schema + /** + * Prints the schema to the console in a nice tree format. + * @group basic + * @since 1.3.0 + */ + // scalastyle:off println + override def printSchema(): Unit = println(schema.treeString) + // scalastyle:on println + + /** + * Prints the plans (logical and physical) to the console for debugging purposes. + * @group basic + * @since 1.3.0 + */ + override def explain(extended: Boolean): Unit = { + val explain = ExplainCommand(queryExecution.logical, extended = extended) + sqlContext.executePlan(explain).executedPlan.executeCollect().foreach { + // scalastyle:off println + r => println(r.getString(0)) + // scalastyle:on println + } + } + + /** + * Prints the physical plan to the console for debugging purposes. + * @since 1.3.0 + */ + override def explain(): Unit = explain(extended = false) + /** * Returns all column names and their data types as an array. 
* @group basic diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 97eb5b969280..da4600133290 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -61,8 +61,8 @@ import org.apache.spark.util.Utils */ @Experimental class Dataset[T] private[sql]( - @transient val sqlContext: SQLContext, - @transient val queryExecution: QueryExecution, + @transient override val sqlContext: SQLContext, + @transient override val queryExecution: QueryExecution, tEncoder: Encoder[T]) extends Queryable with Serializable { /** @@ -85,7 +85,25 @@ class Dataset[T] private[sql]( * Returns the schema of the encoded form of the objects in this [[Dataset]]. * @since 1.6.0 */ - def schema: StructType = resolvedTEncoder.schema + override def schema: StructType = resolvedTEncoder.schema + + /** + * Prints the schema of the underlying [[DataFrame]] to the console in a nice tree format. + * @since 1.6.0 + */ + override def printSchema(): Unit = toDF().printSchema() + + /** + * Prints the plans (logical and physical) to the console for debugging purposes. + * @since 1.6.0 + */ + override def explain(extended: Boolean): Unit = toDF().explain(extended) + + /** + * Prints the physical plan to the console for debugging purposes. + * @since 1.6.0 + */ + override def explain(): Unit = toDF().explain() /* ************* * * Conversions * @@ -152,6 +170,59 @@ class Dataset[T] private[sql]( */ def count(): Long = toDF().count() + /** + * Displays the content of this [[Dataset]] in a tabular form. Strings more than 20 characters + * will be truncated, and all cells will be aligned right. For example: + * {{{ + * year month AVG('Adj Close) MAX('Adj Close) + * 1980 12 0.503218 0.595103 + * 1981 01 0.523289 0.570307 + * 1982 02 0.436504 0.475256 + * 1983 03 0.410516 0.442194 + * 1984 04 0.450090 0.483521 + * }}} + * @param numRows Number of rows to show + * + * @since 1.6.0 + */ + def show(numRows: Int): Unit = show(numRows, truncate = true) + + /** + * Displays the top 20 rows of [[DataFrame]] in a tabular form. Strings more than 20 characters + * will be truncated, and all cells will be aligned right. + * + * @since 1.6.0 + */ + def show(): Unit = show(20) + + /** + * Displays the top 20 rows of [[DataFrame]] in a tabular form. + * + * @param truncate Whether truncate long strings. If true, strings more than 20 characters will + * be truncated and all cells will be aligned right + * + * @since 1.6.0 + */ + def show(truncate: Boolean): Unit = show(20, truncate) + + /** + * Displays the [[DataFrame]] in a tabular form. For example: + * {{{ + * year month AVG('Adj Close) MAX('Adj Close) + * 1980 12 0.503218 0.595103 + * 1981 01 0.523289 0.570307 + * 1982 02 0.436504 0.475256 + * 1983 03 0.410516 0.442194 + * 1984 04 0.450090 0.483521 + * }}} + * @param numRows Number of rows to show + * @param truncate Whether truncate long strings. If true, strings more than 20 characters will + * be truncated and all cells will be aligned right + * + * @since 1.6.0 + */ + def show(numRows: Int, truncate: Boolean): Unit = toDF().show(numRows, truncate) + /** * Returns a new [[Dataset]] that has exactly `numPartitions` partitions. 
* @since 1.6.0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Queryable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Queryable.scala index 321e2c783537..f2f5997d1b7c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Queryable.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Queryable.scala @@ -17,11 +17,11 @@ package org.apache.spark.sql.execution +import scala.util.control.NonFatal + import org.apache.spark.sql.SQLContext import org.apache.spark.sql.types.StructType -import scala.util.control.NonFatal - /** A trait that holds shared code between DataFrames and Datasets. */ private[sql] trait Queryable { def schema: StructType @@ -37,31 +37,9 @@ private[sql] trait Queryable { } } - /** - * Prints the schema to the console in a nice tree format. - * @group basic - * @since 1.3.0 - */ - // scalastyle:off println - def printSchema(): Unit = println(schema.treeString) - // scalastyle:on println + def printSchema(): Unit - /** - * Prints the plans (logical and physical) to the console for debugging purposes. - * @since 1.3.0 - */ - def explain(extended: Boolean): Unit = { - val explain = ExplainCommand(queryExecution.logical, extended = extended) - sqlContext.executePlan(explain).executedPlan.executeCollect().foreach { - // scalastyle:off println - r => println(r.getString(0)) - // scalastyle:on println - } - } + def explain(extended: Boolean): Unit - /** - * Only prints the physical plan to the console for debugging purposes. - * @since 1.3.0 - */ - def explain(): Unit = explain(extended = false) + def explain(): Unit } From 63850026576b3ea7783f9d4b975171dc3cff6e4c Mon Sep 17 00:00:00 2001 From: Ashwin Swaroop Date: Wed, 25 Nov 2015 13:41:14 +0000 Subject: [PATCH 1543/1824] [SPARK-11686][CORE] Issue WARN when dynamic allocation is disabled due to spark.dynamicAllocation.enabled and spark.executor.instances both set Changed the log type to a 'warning' instead of 'info' as required. Author: Ashwin Swaroop Closes #9926 from ashwinswaroop/master. --- core/src/main/scala/org/apache/spark/SparkContext.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index e19ba113702c..2c10779f2b89 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -556,7 +556,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli // Optionally scale number of executors dynamically based on workload. Exposed for testing. val dynamicAllocationEnabled = Utils.isDynamicAllocationEnabled(_conf) if (!dynamicAllocationEnabled && _conf.getBoolean("spark.dynamicAllocation.enabled", false)) { - logInfo("Dynamic Allocation and num executors both set, thus dynamic allocation disabled.") + logWarning("Dynamic Allocation and num executors both set, thus dynamic allocation disabled.") } _executorAllocationManager = From b9b6fbe89b6d1a890faa02c1a53bb670a6255362 Mon Sep 17 00:00:00 2001 From: Jeff Zhang Date: Wed, 25 Nov 2015 13:49:58 +0000 Subject: [PATCH 1544/1824] =?UTF-8?q?[SPARK-11860][PYSAPRK][DOCUMENTATION]?= =?UTF-8?q?=20Invalid=20argument=20specification=20=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …for registerFunction [Python] Straightforward change on the python doc Author: Jeff Zhang Closes #9901 from zjffdu/SPARK-11860. 
--- python/pyspark/sql/context.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index 5a85ac31025e..a49c1b58d018 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -195,14 +195,15 @@ def range(self, start, end=None, step=1, numPartitions=None): @ignore_unicode_prefix @since(1.2) def registerFunction(self, name, f, returnType=StringType()): - """Registers a lambda function as a UDF so it can be used in SQL statements. + """Registers a python function (including lambda function) as a UDF + so it can be used in SQL statements. In addition to a name and the function itself, the return type can be optionally specified. When the return type is not given it default to a string and conversion will automatically be done. For any other return type, the produced object must match the specified type. :param name: name of the UDF - :param samplingRatio: lambda function + :param f: python function :param returnType: a :class:`DataType` object >>> sqlContext.registerFunction("stringLengthString", lambda x: len(x)) From 0a5aef753e70e93d7e56054f354a52e4d4e18932 Mon Sep 17 00:00:00 2001 From: Mark Hamstra Date: Wed, 25 Nov 2015 09:34:34 -0600 Subject: [PATCH 1545/1824] [SPARK-10666][SPARK-6880][CORE] Use properties from ActiveJob associated with a Stage This issue was addressed in https://github.com/apache/spark/pull/5494, but the fix in that PR, while safe in the sense that it will prevent the SparkContext from shutting down, misses the actual bug. The intent of `submitMissingTasks` should be understood as "submit the Tasks that are missing for the Stage, and run them as part of the ActiveJob identified by jobId". Because of a long-standing bug, the `jobId` parameter was never being used. Instead, we were trying to use the jobId with which the Stage was created -- which may no longer exist as an ActiveJob, hence the crash reported in SPARK-6880. The correct fix is to use the ActiveJob specified by the supplied jobId parameter, which is guaranteed to exist at the call sites of submitMissingTasks. This fix should be applied to all maintenance branches, since it has existed since 1.0. kayousterhout pankajarora12 Author: Mark Hamstra Author: Imran Rashid Closes #6291 from markhamstra/SPARK-6880. --- .../apache/spark/scheduler/DAGScheduler.scala | 6 +- .../spark/scheduler/DAGSchedulerSuite.scala | 107 +++++++++++++++++- 2 files changed, 109 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 77a184dfe4be..e01a9609b9a0 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -946,7 +946,9 @@ class DAGScheduler( stage.resetInternalAccumulators() } - val properties = jobIdToActiveJob.get(stage.firstJobId).map(_.properties).orNull + // Use the scheduling pool, job group, description, etc. 
from an ActiveJob associated + // with this Stage + val properties = jobIdToActiveJob(jobId).properties runningStages += stage // SparkListenerStageSubmitted should be posted before testing whether tasks are @@ -1047,7 +1049,7 @@ class DAGScheduler( stage.pendingPartitions ++= tasks.map(_.partitionId) logDebug("New pending partitions: " + stage.pendingPartitions) taskScheduler.submitTasks(new TaskSet( - tasks.toArray, stage.id, stage.latestInfo.attemptId, stage.firstJobId, properties)) + tasks.toArray, stage.id, stage.latestInfo.attemptId, jobId, properties)) stage.latestInfo.submissionTime = Some(clock.getTimeMillis()) } else { // Because we posted SparkListenerStageSubmitted earlier, we should mark diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 4d6b25455226..653d41fc053c 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.scheduler +import java.util.Properties + import scala.collection.mutable.{ArrayBuffer, HashSet, HashMap, Map} import scala.language.reflectiveCalls import scala.util.control.NonFatal @@ -262,9 +264,10 @@ class DAGSchedulerSuite rdd: RDD[_], partitions: Array[Int], func: (TaskContext, Iterator[_]) => _ = jobComputeFunc, - listener: JobListener = jobListener): Int = { + listener: JobListener = jobListener, + properties: Properties = null): Int = { val jobId = scheduler.nextJobId.getAndIncrement() - runEvent(JobSubmitted(jobId, rdd, func, partitions, CallSite("", ""), listener)) + runEvent(JobSubmitted(jobId, rdd, func, partitions, CallSite("", ""), listener, properties)) jobId } @@ -1322,6 +1325,106 @@ class DAGSchedulerSuite assertDataStructuresEmpty() } + def checkJobPropertiesAndPriority(taskSet: TaskSet, expected: String, priority: Int): Unit = { + assert(taskSet.properties != null) + assert(taskSet.properties.getProperty("testProperty") === expected) + assert(taskSet.priority === priority) + } + + def launchJobsThatShareStageAndCancelFirst(): ShuffleDependency[Int, Int, Nothing] = { + val baseRdd = new MyRDD(sc, 1, Nil) + val shuffleDep1 = new ShuffleDependency(baseRdd, new HashPartitioner(1)) + val intermediateRdd = new MyRDD(sc, 1, List(shuffleDep1)) + val shuffleDep2 = new ShuffleDependency(intermediateRdd, new HashPartitioner(1)) + val finalRdd1 = new MyRDD(sc, 1, List(shuffleDep2)) + val finalRdd2 = new MyRDD(sc, 1, List(shuffleDep2)) + val job1Properties = new Properties() + val job2Properties = new Properties() + job1Properties.setProperty("testProperty", "job1") + job2Properties.setProperty("testProperty", "job2") + + // Run jobs 1 & 2, both referencing the same stage, then cancel job1. + // Note that we have to submit job2 before we cancel job1 to have them actually share + // *Stages*, and not just shuffle dependencies, due to skipped stages (at least until + // we address SPARK-10193.) 
+ val jobId1 = submit(finalRdd1, Array(0), properties = job1Properties) + val jobId2 = submit(finalRdd2, Array(0), properties = job2Properties) + assert(scheduler.activeJobs.nonEmpty) + val testProperty1 = scheduler.jobIdToActiveJob(jobId1).properties.getProperty("testProperty") + + // remove job1 as an ActiveJob + cancel(jobId1) + + // job2 should still be running + assert(scheduler.activeJobs.nonEmpty) + val testProperty2 = scheduler.jobIdToActiveJob(jobId2).properties.getProperty("testProperty") + assert(testProperty1 != testProperty2) + // NB: This next assert isn't necessarily the "desired" behavior; it's just to document + // the current behavior. We've already submitted the TaskSet for stage 0 based on job1, but + // even though we have cancelled that job and are now running it because of job2, we haven't + // updated the TaskSet's properties. Changing the properties to "job2" is likely the more + // correct behavior. + val job1Id = 0 // TaskSet priority for Stages run with "job1" as the ActiveJob + checkJobPropertiesAndPriority(taskSets(0), "job1", job1Id) + complete(taskSets(0), Seq((Success, makeMapStatus("hostA", 1)))) + + shuffleDep1 + } + + /** + * Makes sure that tasks for a stage used by multiple jobs are submitted with the properties of a + * later, active job if they were previously run under a job that is no longer active + */ + test("stage used by two jobs, the first no longer active (SPARK-6880)") { + launchJobsThatShareStageAndCancelFirst() + + // The next check is the key for SPARK-6880. For the stage which was shared by both job1 and + // job2 but never had any tasks submitted for job1, the properties of job2 are now used to run + // the stage. + checkJobPropertiesAndPriority(taskSets(1), "job2", 1) + + complete(taskSets(1), Seq((Success, makeMapStatus("hostA", 1)))) + assert(taskSets(2).properties != null) + complete(taskSets(2), Seq((Success, 42))) + assert(results === Map(0 -> 42)) + assert(scheduler.activeJobs.isEmpty) + + assertDataStructuresEmpty() + } + + /** + * Makes sure that tasks for a stage used by multiple jobs are submitted with the properties of a + * later, active job if they were previously run under a job that is no longer active, even when + * there are fetch failures + */ + test("stage used by two jobs, some fetch failures, and the first job no longer active " + + "(SPARK-6880)") { + val shuffleDep1 = launchJobsThatShareStageAndCancelFirst() + val job2Id = 1 // TaskSet priority for Stages run with "job2" as the ActiveJob + + // lets say there is a fetch failure in this task set, which makes us go back and + // run stage 0, attempt 1 + complete(taskSets(1), Seq( + (FetchFailed(makeBlockManagerId("hostA"), shuffleDep1.shuffleId, 0, 0, "ignored"), null))) + scheduler.resubmitFailedStages() + + // stage 0, attempt 1 should have the properties of job2 + assert(taskSets(2).stageId === 0) + assert(taskSets(2).stageAttemptId === 1) + checkJobPropertiesAndPriority(taskSets(2), "job2", job2Id) + + // run the rest of the stages normally, checking that they have the correct properties + complete(taskSets(2), Seq((Success, makeMapStatus("hostA", 1)))) + checkJobPropertiesAndPriority(taskSets(3), "job2", job2Id) + complete(taskSets(3), Seq((Success, makeMapStatus("hostA", 1)))) + checkJobPropertiesAndPriority(taskSets(4), "job2", job2Id) + complete(taskSets(4), Seq((Success, 42))) + assert(results === Map(0 -> 42)) + assert(scheduler.activeJobs.isEmpty) + + assertDataStructuresEmpty() + } + test("run trivial shuffle with out-of-band failure and retry") { val 
shuffleMapRdd = new MyRDD(sc, 2, Nil) val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(2)) From c1f85fc71e71e07534b89c84677d977bb20994f8 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Wed, 25 Nov 2015 09:47:20 -0800 Subject: [PATCH 1546/1824] [SPARK-11956][CORE] Fix a few bugs in network lib-based file transfer. - NettyRpcEnv::openStream() now correctly propagates errors to the read side of the pipe. - NettyStreamManager now throws if the file being transferred does not exist. - The network library now correctly handles zero-sized streams. Author: Marcelo Vanzin Closes #9941 from vanzin/SPARK-11956. --- .../apache/spark/rpc/netty/NettyRpcEnv.scala | 19 +++++++++---- .../spark/rpc/netty/NettyStreamManager.scala | 2 +- .../org/apache/spark/rpc/RpcEnvSuite.scala | 27 +++++++++++++----- .../client/TransportResponseHandler.java | 28 ++++++++++++------- .../org/apache/spark/network/StreamSuite.java | 23 ++++++++++++++- 5 files changed, 75 insertions(+), 24 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala index 68701f609f77..c8fa870f50e6 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala @@ -27,7 +27,7 @@ import javax.annotation.Nullable import scala.concurrent.{Future, Promise} import scala.reflect.ClassTag -import scala.util.{DynamicVariable, Failure, Success} +import scala.util.{DynamicVariable, Failure, Success, Try} import scala.util.control.NonFatal import org.apache.spark.{Logging, SecurityManager, SparkConf} @@ -368,13 +368,22 @@ private[netty] class NettyRpcEnv( @volatile private var error: Throwable = _ - def setError(e: Throwable): Unit = error = e + def setError(e: Throwable): Unit = { + error = e + source.close() + } override def read(dst: ByteBuffer): Int = { - if (error != null) { - throw error + val result = if (error == null) { + Try(source.read(dst)) + } else { + Failure(error) + } + + result match { + case Success(bytesRead) => bytesRead + case Failure(error) => throw error } - source.read(dst) } override def close(): Unit = source.close() diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala index eb1d2604fb23..a2768b4252dc 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala @@ -44,7 +44,7 @@ private[netty] class NettyStreamManager(rpcEnv: NettyRpcEnv) case _ => throw new IllegalArgumentException(s"Invalid file type: $ftype") } - require(file != null, s"File not found: $streamId") + require(file != null && file.isFile(), s"File not found: $streamId") new FileSegmentManagedBuffer(rpcEnv.transportConf, file, 0, file.length()) } diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala index 2b664c6313ef..6cc958a5f6bc 100644 --- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala @@ -729,23 +729,36 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { val tempDir = Utils.createTempDir() val file = new File(tempDir, "file") Files.write(UUID.randomUUID().toString(), file, UTF_8) + val empty = new File(tempDir, "empty") + Files.write("", empty, UTF_8); val jar = new 
File(tempDir, "jar") Files.write(UUID.randomUUID().toString(), jar, UTF_8) val fileUri = env.fileServer.addFile(file) + val emptyUri = env.fileServer.addFile(empty) val jarUri = env.fileServer.addJar(jar) val destDir = Utils.createTempDir() - val destFile = new File(destDir, file.getName()) - val destJar = new File(destDir, jar.getName()) - val sm = new SecurityManager(conf) val hc = SparkHadoopUtil.get.conf - Utils.fetchFile(fileUri, destDir, conf, sm, hc, 0L, false) - Utils.fetchFile(jarUri, destDir, conf, sm, hc, 0L, false) - assert(Files.equal(file, destFile)) - assert(Files.equal(jar, destJar)) + val files = Seq( + (file, fileUri), + (empty, emptyUri), + (jar, jarUri)) + files.foreach { case (f, uri) => + val destFile = new File(destDir, f.getName()) + Utils.fetchFile(uri, destDir, conf, sm, hc, 0L, false) + assert(Files.equal(f, destFile)) + } + + // Try to download files that do not exist. + Seq("files", "jars").foreach { root => + intercept[Exception] { + val uri = env.address.toSparkURL + s"/$root/doesNotExist" + Utils.fetchFile(uri, destDir, conf, sm, hc, 0L, false) + } + } } } diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java b/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java index be181e066082..4c15045363b8 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java @@ -185,16 +185,24 @@ public void handle(ResponseMessage message) { StreamResponse resp = (StreamResponse) message; StreamCallback callback = streamCallbacks.poll(); if (callback != null) { - StreamInterceptor interceptor = new StreamInterceptor(this, resp.streamId, resp.byteCount, - callback); - try { - TransportFrameDecoder frameDecoder = (TransportFrameDecoder) - channel.pipeline().get(TransportFrameDecoder.HANDLER_NAME); - frameDecoder.setInterceptor(interceptor); - streamActive = true; - } catch (Exception e) { - logger.error("Error installing stream handler.", e); - deactivateStream(); + if (resp.byteCount > 0) { + StreamInterceptor interceptor = new StreamInterceptor(this, resp.streamId, resp.byteCount, + callback); + try { + TransportFrameDecoder frameDecoder = (TransportFrameDecoder) + channel.pipeline().get(TransportFrameDecoder.HANDLER_NAME); + frameDecoder.setInterceptor(interceptor); + streamActive = true; + } catch (Exception e) { + logger.error("Error installing stream handler.", e); + deactivateStream(); + } + } else { + try { + callback.onComplete(resp.streamId); + } catch (Exception e) { + logger.warn("Error in stream handler onComplete().", e); + } } } else { logger.error("Could not find callback for StreamResponse."); diff --git a/network/common/src/test/java/org/apache/spark/network/StreamSuite.java b/network/common/src/test/java/org/apache/spark/network/StreamSuite.java index 00158fd08162..538f3efe8d6f 100644 --- a/network/common/src/test/java/org/apache/spark/network/StreamSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/StreamSuite.java @@ -51,13 +51,14 @@ import org.apache.spark.network.util.TransportConf; public class StreamSuite { - private static final String[] STREAMS = { "largeBuffer", "smallBuffer", "file" }; + private static final String[] STREAMS = { "largeBuffer", "smallBuffer", "emptyBuffer", "file" }; private static TransportServer server; private static TransportClientFactory clientFactory; private static File testFile; 
private static File tempDir; + private static ByteBuffer emptyBuffer; private static ByteBuffer smallBuffer; private static ByteBuffer largeBuffer; @@ -73,6 +74,7 @@ private static ByteBuffer createBuffer(int bufSize) { @BeforeClass public static void setUp() throws Exception { tempDir = Files.createTempDir(); + emptyBuffer = createBuffer(0); smallBuffer = createBuffer(100); largeBuffer = createBuffer(100000); @@ -103,6 +105,8 @@ public ManagedBuffer openStream(String streamId) { return new NioManagedBuffer(largeBuffer); case "smallBuffer": return new NioManagedBuffer(smallBuffer); + case "emptyBuffer": + return new NioManagedBuffer(emptyBuffer); case "file": return new FileSegmentManagedBuffer(conf, testFile, 0, testFile.length()); default: @@ -138,6 +142,18 @@ public static void tearDown() { } } + @Test + public void testZeroLengthStream() throws Throwable { + TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); + try { + StreamTask task = new StreamTask(client, "emptyBuffer", TimeUnit.SECONDS.toMillis(5)); + task.run(); + task.check(); + } finally { + client.close(); + } + } + @Test public void testSingleStream() throws Throwable { TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); @@ -226,6 +242,11 @@ public void run() { outFile = File.createTempFile("data", ".tmp", tempDir); out = new FileOutputStream(outFile); break; + case "emptyBuffer": + baos = new ByteArrayOutputStream(); + out = baos; + srcBuffer = emptyBuffer; + break; default: throw new IllegalArgumentException(streamId); } From faabdfa2bd416ae514961535f1953e8e9e8b1f3f Mon Sep 17 00:00:00 2001 From: felixcheung Date: Wed, 25 Nov 2015 10:36:35 -0800 Subject: [PATCH 1547/1824] [SPARK-11984][SQL][PYTHON] Fix typos in doc for pivot for scala and python Author: felixcheung Closes #9967 from felixcheung/pypivotdoc. --- python/pyspark/sql/group.py | 6 +++--- .../src/main/scala/org/apache/spark/sql/GroupedData.scala | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py index d8ed7eb2dda6..1911588309af 100644 --- a/python/pyspark/sql/group.py +++ b/python/pyspark/sql/group.py @@ -169,11 +169,11 @@ def sum(self, *cols): @since(1.6) def pivot(self, pivot_col, values=None): - """Pivots a column of the current DataFrame and preform the specified aggregation. + """Pivots a column of the current DataFrame and perform the specified aggregation. :param pivot_col: Column to pivot - :param values: Optional list of values of pivotColumn that will be translated to columns in - the output data frame. If values are not provided the method with do an immediate call + :param values: Optional list of values of pivot column that will be translated to columns in + the output DataFrame. If values are not provided the method will do an immediate call to .distinct() on the pivot column. >>> df4.groupBy("year").pivot("course", ["dotNET", "Java"]).sum("earnings").collect() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala index abd531c4ba54..13341a88a6b7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala @@ -282,7 +282,7 @@ class GroupedData protected[sql]( } /** - * Pivots a column of the current [[DataFrame]] and preform the specified aggregation. 
+ * Pivots a column of the current [[DataFrame]] and perform the specified aggregation. * There are two versions of pivot function: one that requires the caller to specify the list * of distinct values to pivot on, and one that does not. The latter is more concise but less * efficient, because Spark needs to first compute the list of distinct values internally. @@ -321,7 +321,7 @@ class GroupedData protected[sql]( } /** - * Pivots a column of the current [[DataFrame]] and preform the specified aggregation. + * Pivots a column of the current [[DataFrame]] and perform the specified aggregation. * There are two versions of pivot function: one that requires the caller to specify the list * of distinct values to pivot on, and one that does not. The latter is more concise but less * efficient, because Spark needs to first compute the list of distinct values internally. @@ -353,7 +353,7 @@ class GroupedData protected[sql]( } /** - * Pivots a column of the current [[DataFrame]] and preform the specified aggregation. + * Pivots a column of the current [[DataFrame]] and perform the specified aggregation. * There are two versions of pivot function: one that requires the caller to specify the list * of distinct values to pivot on, and one that does not. The latter is more concise but less * efficient, because Spark needs to first compute the list of distinct values internally. From 6b781576a15d8d5c5fbed8bef1c5bda95b3d44ac Mon Sep 17 00:00:00 2001 From: Zhongshuai Pei Date: Wed, 25 Nov 2015 10:37:34 -0800 Subject: [PATCH 1548/1824] [SPARK-11974][CORE] Not all the temp dirs had been deleted when the JVM exits deleting the temp dir like that ``` scala> import scala.collection.mutable import scala.collection.mutable scala> val a = mutable.Set(1,2,3,4,7,0,8,98,9) a: scala.collection.mutable.Set[Int] = Set(0, 9, 1, 2, 3, 7, 4, 8, 98) scala> a.foreach(x => {a.remove(x) }) scala> a.foreach(println(_)) 98 ``` You may not modify a collection while traversing or iterating over it.This can not delete all element of the collection Author: Zhongshuai Pei Closes #9951 from DoingDone9/Bug_RemainDir. --- .../scala/org/apache/spark/util/ShutdownHookManager.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala b/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala index db4a8b304ec3..4012dca3ecdf 100644 --- a/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala +++ b/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala @@ -57,7 +57,9 @@ private[spark] object ShutdownHookManager extends Logging { // Add a shutdown hook to delete the temp dirs when the JVM exits addShutdownHook(TEMP_DIR_SHUTDOWN_PRIORITY) { () => logInfo("Shutdown hook called") - shutdownDeletePaths.foreach { dirPath => + // we need to materialize the paths to delete because deleteRecursively removes items from + // shutdownDeletePaths as we are traversing through it. + shutdownDeletePaths.toArray.foreach { dirPath => try { logInfo("Deleting directory " + dirPath) Utils.deleteRecursively(new File(dirPath)) From dc1d324fdf83e9f4b1debfb277533b002691d71f Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 25 Nov 2015 11:11:39 -0800 Subject: [PATCH 1549/1824] [SPARK-11969] [SQL] [PYSPARK] visualization of SQL query for pyspark Currently, we does not have visualization for SQL query from Python, this PR fix that. cc zsxwing Author: Davies Liu Closes #9949 from davies/pyspark_sql_ui. 
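The mechanism is small: the Python-facing collect and take paths are wrapped in an SQL execution scope, so the jobs they launch are grouped under a single query on the SQL tab instead of showing up as anonymous jobs. Below is a rough sketch of that scoping idea using only public `SparkContext` calls; the property key is assumed from the 1.6 internals and the helper name is made up, while the actual change is the `withNewExecutionId` wrapping shown in the diff that follows.

```
import org.apache.spark.SparkContext

// Sketch of execution-id scoping: tag every job started by `body` with a thread-local
// property that the SQL UI can group on, then restore the previous value. The real code
// path is SQLExecution.withNewExecutionId; this is not that implementation.
def withExecutionId[T](sc: SparkContext, id: String)(body: => T): T = {
  val key = "spark.sql.execution.id"   // assumed key used by the SQL tab in 1.6
  val previous = sc.getLocalProperty(key)
  sc.setLocalProperty(key, id)
  try body finally sc.setLocalProperty(key, previous)
}
```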
--- python/pyspark/sql/dataframe.py | 2 +- .../main/scala/org/apache/spark/sql/DataFrame.scala | 7 +++++++ .../org/apache/spark/sql/execution/python.scala | 12 +++++++----- 3 files changed, 15 insertions(+), 6 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 0dd75ba7ca82..746bb55e14f2 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -277,7 +277,7 @@ def collect(self): [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')] """ with SCCallSiteSync(self._sc) as css: - port = self._sc._jvm.PythonRDD.collectAndServe(self._jdf.javaToPython().rdd()) + port = self._jdf.collectToPython() return list(_load_from_socket(port, BatchedSerializer(PickleSerializer()))) @ignore_unicode_prefix diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index d8319b9a97fc..6197f10813a3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -29,6 +29,7 @@ import org.apache.commons.lang3.StringUtils import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.api.java.JavaRDD +import org.apache.spark.api.python.PythonRDD import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis._ @@ -1735,6 +1736,12 @@ class DataFrame private[sql]( EvaluatePython.javaToPython(rdd) } + protected[sql] def collectToPython(): Int = { + withNewExecutionId { + PythonRDD.collectAndServe(javaToPython.rdd) + } + } + //////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////// // Deprecated methods diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python.scala index d611b0011da1..defcec95fb55 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python.scala @@ -121,11 +121,13 @@ object EvaluatePython { def takeAndServe(df: DataFrame, n: Int): Int = { registerPicklers() - val iter = new SerDeUtil.AutoBatchedPickler( - df.queryExecution.executedPlan.executeTake(n).iterator.map { row => - EvaluatePython.toJava(row, df.schema) - }) - PythonRDD.serveIterator(iter, s"serve-DataFrame") + df.withNewExecutionId { + val iter = new SerDeUtil.AutoBatchedPickler( + df.queryExecution.executedPlan.executeTake(n).iterator.map { row => + EvaluatePython.toJava(row, df.schema) + }) + PythonRDD.serveIterator(iter, s"serve-DataFrame") + } } /** From 0dee44a6646daae0cc03dbc32125e080dff0f4ae Mon Sep 17 00:00:00 2001 From: Yu ISHIKAWA Date: Wed, 25 Nov 2015 11:35:52 -0800 Subject: [PATCH 1550/1824] [MINOR] Remove unnecessary spaces in `include_example.rb` Author: Yu ISHIKAWA Closes #9960 from yu-iskw/minor-remove-spaces. 
--- docs/_plugins/include_example.rb | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/_plugins/include_example.rb b/docs/_plugins/include_example.rb index 549f81fe1b1b..564c86680f68 100644 --- a/docs/_plugins/include_example.rb +++ b/docs/_plugins/include_example.rb @@ -20,12 +20,12 @@ module Jekyll class IncludeExampleTag < Liquid::Tag - + def initialize(tag_name, markup, tokens) @markup = markup super end - + def render(context) site = context.registers[:site] config_dir = '../examples/src/main' @@ -37,7 +37,7 @@ def render(context) code = File.open(@file).read.encode("UTF-8") code = select_lines(code) - + rendered_code = Pygments.highlight(code, :lexer => @lang) hint = "
    Find full example code at " \ @@ -45,7 +45,7 @@ def render(context) rendered_code + hint end - + # Trim the code block so as to have the same indention, regardless of their positions in the # code file. def trim_codeblock(lines) From 67b67320884282ccf3102e2af96f877e9b186517 Mon Sep 17 00:00:00 2001 From: Jeff Zhang Date: Wed, 25 Nov 2015 11:37:42 -0800 Subject: [PATCH 1551/1824] [DOCUMENTATION] Fix minor doc error Author: Jeff Zhang Closes #9956 from zjffdu/dev_typo. --- docs/configuration.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/configuration.md b/docs/configuration.md index 4de202d7f763..741d6b2b37a8 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -35,7 +35,7 @@ val sc = new SparkContext(conf) {% endhighlight %} Note that we can have more than 1 thread in local mode, and in cases like Spark Streaming, we may -actually require one to prevent any sort of starvation issues. +actually require more than 1 thread to prevent any sort of starvation issues. Properties that specify some time duration should be configured with a unit of time. The following format is accepted: From 83653ac5e71996c5a366a42170bed316b208f1b5 Mon Sep 17 00:00:00 2001 From: Alex Bozarth Date: Wed, 25 Nov 2015 11:39:00 -0800 Subject: [PATCH 1552/1824] [SPARK-10864][WEB UI] app name is hidden if window is resized Currently the Web UI navbar has a minimum width of 1200px; so if a window is resized smaller than that the app name goes off screen. The 1200px width seems to have been chosen since it fits the longest example app name without wrapping. To work with smaller window widths I made the tabs wrap since it looked better than wrapping the app name. This is a distinct change in how the navbar looks and I'm not sure if it's what we actually want to do. Other notes: - min-width set to 600px to keep the tabs from wrapping individually (will need to be adjusted if tabs are added) - app name will also wrap (making three levels) if a really really long app name is used Author: Alex Bozarth Closes #9874 from ajbozarth/spark10864. --- .../main/resources/org/apache/spark/ui/static/webui.css | 8 ++------ core/src/main/scala/org/apache/spark/ui/UIUtils.scala | 2 +- 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/core/src/main/resources/org/apache/spark/ui/static/webui.css b/core/src/main/resources/org/apache/spark/ui/static/webui.css index 04f3070d25b4..c628a0c70655 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/webui.css +++ b/core/src/main/resources/org/apache/spark/ui/static/webui.css @@ -16,14 +16,9 @@ */ .navbar { - height: 50px; font-size: 15px; margin-bottom: 15px; - min-width: 1200px -} - -.navbar .navbar-inner { - height: 50px; + min-width: 600px; } .navbar .brand { @@ -46,6 +41,7 @@ .navbar-text { height: 50px; line-height: 3.3; + white-space: nowrap; } table.sortable thead { diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala index 84a1116a5c49..1e8194f57888 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala @@ -210,10 +210,10 @@ private[spark] object UIUtils extends Logging { {org.apache.spark.SPARK_VERSION}
    - +
    From 9f3e59a16822fb61d60cf103bd4f7823552939c6 Mon Sep 17 00:00:00 2001 From: wangt Date: Wed, 25 Nov 2015 11:41:05 -0800 Subject: [PATCH 1553/1824] [SPARK-11880][WINDOWS][SPARK SUBMIT] bin/load-spark-env.cmd loads spark-env.cmd from wrong directory * On windows the `bin/load-spark-env.cmd` tries to load `spark-env.cmd` from `%~dp0..\..\conf`, where `~dp0` points to `bin` and `conf` is only one level up. * Updated `bin/load-spark-env.cmd` to load `spark-env.cmd` from `%~dp0..\conf`, instead of `%~dp0..\..\conf` Author: wangt Closes #9863 from toddwan/master. --- bin/load-spark-env.cmd | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bin/load-spark-env.cmd b/bin/load-spark-env.cmd index 36d932c453b6..59080edd294f 100644 --- a/bin/load-spark-env.cmd +++ b/bin/load-spark-env.cmd @@ -27,7 +27,7 @@ if [%SPARK_ENV_LOADED%] == [] ( if not [%SPARK_CONF_DIR%] == [] ( set user_conf_dir=%SPARK_CONF_DIR% ) else ( - set user_conf_dir=%~dp0..\..\conf + set user_conf_dir=%~dp0..\conf ) call :LoadSparkEnv From 88875d9413ec7d64a88d40857ffcf97b5853a7f2 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Wed, 25 Nov 2015 11:42:53 -0800 Subject: [PATCH 1554/1824] [SPARK-10558][CORE] Fix wrong executor state in Master `ExecutorAdded` can only be sent to `AppClient` when worker report back the executor state as `LOADING`, otherwise because of concurrency issue, `AppClient` will possibly receive `ExectuorAdded` at first, then `ExecutorStateUpdated` with `LOADING` state. Also Master will change the executor state from `LAUNCHING` to `RUNNING` (`AppClient` report back the state as `RUNNING`), then to `LOADING` (worker report back to state as `LOADING`), it should be `LAUNCHING` -> `LOADING` -> `RUNNING`. Also it is wrongly shown in master UI, the state of executor should be `RUNNING` rather than `LOADING`: ![screen shot 2015-09-11 at 2 30 28 pm](https://cloud.githubusercontent.com/assets/850797/9809254/3155d840-5899-11e5-8cdf-ad06fef75762.png) Author: jerryshao Closes #8714 from jerryshao/SPARK-10558. 
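The diff below drops the `LOADING` value, has the worker launch executors in the `RUNNING` state directly, and makes the master check that an executor only moves to `RUNNING` from `LAUNCHING`. A minimal sketch of that invariant (illustrative only, not the actual `Master` code):

```
// The state rule enforced in Master.scala below: RUNNING is only legal when the
// previous state recorded by the master is LAUNCHING.
object ExecutorStateSketch extends Enumeration {
  val LAUNCHING, RUNNING, KILLED, FAILED, LOST, EXITED = Value

  def transition(oldState: Value, newState: Value): Value = {
    if (newState == RUNNING) {
      require(oldState == LAUNCHING,
        s"executor state transfer from $oldState to RUNNING is illegal")
    }
    newState
  }
}
```

For example, `transition(ExecutorStateSketch.LAUNCHING, ExecutorStateSketch.RUNNING)` succeeds, while any other transition into `RUNNING` now fails fast rather than being silently accepted.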
--- .../org/apache/spark/deploy/ExecutorState.scala | 2 +- .../org/apache/spark/deploy/client/AppClient.scala | 3 --- .../org/apache/spark/deploy/master/Master.scala | 14 +++++++++++--- .../org/apache/spark/deploy/worker/Worker.scala | 2 +- 4 files changed, 13 insertions(+), 8 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/ExecutorState.scala b/core/src/main/scala/org/apache/spark/deploy/ExecutorState.scala index efa88c62e1f5..69c98e28931d 100644 --- a/core/src/main/scala/org/apache/spark/deploy/ExecutorState.scala +++ b/core/src/main/scala/org/apache/spark/deploy/ExecutorState.scala @@ -19,7 +19,7 @@ package org.apache.spark.deploy private[deploy] object ExecutorState extends Enumeration { - val LAUNCHING, LOADING, RUNNING, KILLED, FAILED, LOST, EXITED = Value + val LAUNCHING, RUNNING, KILLED, FAILED, LOST, EXITED = Value type ExecutorState = Value diff --git a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala index afab362e213b..df6ba7d669ce 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala @@ -178,9 +178,6 @@ private[spark] class AppClient( val fullId = appId + "/" + id logInfo("Executor added: %s on %s (%s) with %d cores".format(fullId, workerId, hostPort, cores)) - // FIXME if changing master and `ExecutorAdded` happen at the same time (the order is not - // guaranteed), `ExecutorStateChanged` may be sent to a dead master. - sendToMaster(ExecutorStateChanged(appId.get, id, ExecutorState.RUNNING, None, None)) listener.executorAdded(fullId, workerId, hostPort, cores, memory) case ExecutorUpdated(id, state, message, exitStatus) => diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index b25a487806c7..9952c97dbdff 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -253,9 +253,17 @@ private[deploy] class Master( execOption match { case Some(exec) => { val appInfo = idToApp(appId) + val oldState = exec.state exec.state = state - if (state == ExecutorState.RUNNING) { appInfo.resetRetryCount() } + + if (state == ExecutorState.RUNNING) { + assert(oldState == ExecutorState.LAUNCHING, + s"executor $execId state transfer from $oldState to RUNNING is illegal") + appInfo.resetRetryCount() + } + exec.application.driver.send(ExecutorUpdated(execId, state, message, exitStatus)) + if (ExecutorState.isFinished(state)) { // Remove this executor from the worker and app logInfo(s"Removing executor ${exec.fullId} because it is $state") @@ -702,8 +710,8 @@ private[deploy] class Master( worker.addExecutor(exec) worker.endpoint.send(LaunchExecutor(masterUrl, exec.application.id, exec.id, exec.application.desc, exec.cores, exec.memory)) - exec.application.driver.send(ExecutorAdded( - exec.id, worker.id, worker.hostPort, exec.cores, exec.memory)) + exec.application.driver.send( + ExecutorAdded(exec.id, worker.id, worker.hostPort, exec.cores, exec.memory)) } private def registerWorker(worker: WorkerInfo): Boolean = { diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index a45867e7680e..418faf8fc967 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ 
b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -469,7 +469,7 @@ private[deploy] class Worker( executorDir, workerUri, conf, - appLocalDirs, ExecutorState.LOADING) + appLocalDirs, ExecutorState.RUNNING) executors(appId + "/" + execId) = manager manager.start() coresUsed += cores_ From d29e2ef4cf43c7f7c5aa40d305cf02be44ce19e0 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Wed, 25 Nov 2015 11:47:21 -0800 Subject: [PATCH 1555/1824] [SPARK-11935][PYSPARK] Send the Python exceptions in TransformFunction and TransformFunctionSerializer to Java The Python exception track in TransformFunction and TransformFunctionSerializer is not sent back to Java. Py4j just throws a very general exception, which is hard to debug. This PRs adds `getFailure` method to get the failure message in Java side. Author: Shixiong Zhu Closes #9922 from zsxwing/SPARK-11935. --- python/pyspark/streaming/tests.py | 82 ++++++++++++++++++- python/pyspark/streaming/util.py | 29 +++++-- .../streaming/api/python/PythonDStream.scala | 52 ++++++++++-- 3 files changed, 144 insertions(+), 19 deletions(-) diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index a0e0267cafa5..d380d697bc51 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -404,17 +404,69 @@ def func(dstream): self._test_func(input, func, expected) def test_failed_func(self): + # Test failure in + # TransformFunction.apply(rdd: Option[RDD[_]], time: Time) input = [self.sc.parallelize([d], 1) for d in range(4)] input_stream = self.ssc.queueStream(input) def failed_func(i): - raise ValueError("failed") + raise ValueError("This is a special error") input_stream.map(failed_func).pprint() self.ssc.start() try: self.ssc.awaitTerminationOrTimeout(10) except: + import traceback + failure = traceback.format_exc() + self.assertTrue("This is a special error" in failure) + return + + self.fail("a failed func should throw an error") + + def test_failed_func2(self): + # Test failure in + # TransformFunction.apply(rdd: Option[RDD[_]], rdd2: Option[RDD[_]], time: Time) + input = [self.sc.parallelize([d], 1) for d in range(4)] + input_stream1 = self.ssc.queueStream(input) + input_stream2 = self.ssc.queueStream(input) + + def failed_func(rdd1, rdd2): + raise ValueError("This is a special error") + + input_stream1.transformWith(failed_func, input_stream2, True).pprint() + self.ssc.start() + try: + self.ssc.awaitTerminationOrTimeout(10) + except: + import traceback + failure = traceback.format_exc() + self.assertTrue("This is a special error" in failure) + return + + self.fail("a failed func should throw an error") + + def test_failed_func_with_reseting_failure(self): + input = [self.sc.parallelize([d], 1) for d in range(4)] + input_stream = self.ssc.queueStream(input) + + def failed_func(i): + if i == 1: + # Make it fail in the second batch + raise ValueError("This is a special error") + else: + return i + + # We should be able to see the results of the 3rd and 4th batches even if the second batch + # fails + expected = [[0], [2], [3]] + self.assertEqual(expected, self._collect(input_stream.map(failed_func), 3)) + try: + self.ssc.awaitTerminationOrTimeout(10) + except: + import traceback + failure = traceback.format_exc() + self.assertTrue("This is a special error" in failure) return self.fail("a failed func should throw an error") @@ -780,6 +832,34 @@ def tearDown(self): if self.cpd is not None: shutil.rmtree(self.cpd) + def test_transform_function_serializer_failure(self): + inputd = 
tempfile.mkdtemp() + self.cpd = tempfile.mkdtemp("test_transform_function_serializer_failure") + + def setup(): + conf = SparkConf().set("spark.default.parallelism", 1) + sc = SparkContext(conf=conf) + ssc = StreamingContext(sc, 0.5) + + # A function that cannot be serialized + def process(time, rdd): + sc.parallelize(range(1, 10)) + + ssc.textFileStream(inputd).foreachRDD(process) + return ssc + + self.ssc = StreamingContext.getOrCreate(self.cpd, setup) + try: + self.ssc.start() + except: + import traceback + failure = traceback.format_exc() + self.assertTrue( + "It appears that you are attempting to reference SparkContext" in failure) + return + + self.fail("using SparkContext in process should fail because it's not Serializable") + def test_get_or_create_and_get_active_or_create(self): inputd = tempfile.mkdtemp() outputd = tempfile.mkdtemp() + "/" diff --git a/python/pyspark/streaming/util.py b/python/pyspark/streaming/util.py index 767c732eb90b..c7f02bca2ae3 100644 --- a/python/pyspark/streaming/util.py +++ b/python/pyspark/streaming/util.py @@ -38,12 +38,15 @@ def __init__(self, ctx, func, *deserializers): self.func = func self.deserializers = deserializers self._rdd_wrapper = lambda jrdd, ctx, ser: RDD(jrdd, ctx, ser) + self.failure = None def rdd_wrapper(self, func): self._rdd_wrapper = func return self def call(self, milliseconds, jrdds): + # Clear the failure + self.failure = None try: if self.ctx is None: self.ctx = SparkContext._active_spark_context @@ -62,9 +65,11 @@ def call(self, milliseconds, jrdds): r = self.func(t, *rdds) if r: return r._jrdd - except Exception: - traceback.print_exc() - raise + except: + self.failure = traceback.format_exc() + + def getLastFailure(self): + return self.failure def __repr__(self): return "TransformFunction(%s)" % self.func @@ -89,22 +94,28 @@ def __init__(self, ctx, serializer, gateway=None): self.serializer = serializer self.gateway = gateway or self.ctx._gateway self.gateway.jvm.PythonDStream.registerSerializer(self) + self.failure = None def dumps(self, id): + # Clear the failure + self.failure = None try: func = self.gateway.gateway_property.pool[id] return bytearray(self.serializer.dumps((func.func, func.deserializers))) - except Exception: - traceback.print_exc() - raise + except: + self.failure = traceback.format_exc() def loads(self, data): + # Clear the failure + self.failure = None try: f, deserializers = self.serializer.loads(bytes(data)) return TransformFunction(self.ctx, f, *deserializers) - except Exception: - traceback.print_exc() - raise + except: + self.failure = traceback.format_exc() + + def getLastFailure(self): + return self.failure def __repr__(self): return "TransformFunctionSerializer(%s)" % self.serializer diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala index dfc569451df8..994309ddd0a3 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala @@ -26,6 +26,7 @@ import scala.language.existentials import py4j.GatewayServer +import org.apache.spark.SparkException import org.apache.spark.api.java._ import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel @@ -40,6 +41,13 @@ import org.apache.spark.util.Utils */ private[python] trait PythonTransformFunction { def call(time: Long, rdds: JList[_]): JavaRDD[Array[Byte]] + + /** + * Get the failure, if 
any, in the last call to `call`. + * + * @return the failure message if there was a failure, or `null` if there was no failure. + */ + def getLastFailure: String } /** @@ -48,6 +56,13 @@ private[python] trait PythonTransformFunction { private[python] trait PythonTransformFunctionSerializer { def dumps(id: String): Array[Byte] def loads(bytes: Array[Byte]): PythonTransformFunction + + /** + * Get the failure, if any, in the last call to `dumps` or `loads`. + * + * @return the failure message if there was a failure, or `null` if there was no failure. + */ + def getLastFailure: String } /** @@ -59,18 +74,27 @@ private[python] class TransformFunction(@transient var pfunc: PythonTransformFun extends function.Function2[JList[JavaRDD[_]], Time, JavaRDD[Array[Byte]]] { def apply(rdd: Option[RDD[_]], time: Time): Option[RDD[Array[Byte]]] = { - Option(pfunc.call(time.milliseconds, List(rdd.map(JavaRDD.fromRDD(_)).orNull).asJava)) - .map(_.rdd) + val rdds = List(rdd.map(JavaRDD.fromRDD(_)).orNull).asJava + Option(callPythonTransformFunction(time.milliseconds, rdds)).map(_.rdd) } def apply(rdd: Option[RDD[_]], rdd2: Option[RDD[_]], time: Time): Option[RDD[Array[Byte]]] = { val rdds = List(rdd.map(JavaRDD.fromRDD(_)).orNull, rdd2.map(JavaRDD.fromRDD(_)).orNull).asJava - Option(pfunc.call(time.milliseconds, rdds)).map(_.rdd) + Option(callPythonTransformFunction(time.milliseconds, rdds)).map(_.rdd) } // for function.Function2 def call(rdds: JList[JavaRDD[_]], time: Time): JavaRDD[Array[Byte]] = { - pfunc.call(time.milliseconds, rdds) + callPythonTransformFunction(time.milliseconds, rdds) + } + + private def callPythonTransformFunction(time: Long, rdds: JList[_]): JavaRDD[Array[Byte]] = { + val resultRDD = pfunc.call(time, rdds) + val failure = pfunc.getLastFailure + if (failure != null) { + throw new SparkException("An exception was raised by Python:\n" + failure) + } + resultRDD } private def writeObject(out: ObjectOutputStream): Unit = Utils.tryOrIOException { @@ -103,23 +127,33 @@ private[python] object PythonTransformFunctionSerializer { /* * Register a serializer from Python, should be called during initialization */ - def register(ser: PythonTransformFunctionSerializer): Unit = { + def register(ser: PythonTransformFunctionSerializer): Unit = synchronized { serializer = ser } - def serialize(func: PythonTransformFunction): Array[Byte] = { + def serialize(func: PythonTransformFunction): Array[Byte] = synchronized { require(serializer != null, "Serializer has not been registered!") // get the id of PythonTransformFunction in py4j val h = Proxy.getInvocationHandler(func.asInstanceOf[Proxy]) val f = h.getClass().getDeclaredField("id") f.setAccessible(true) val id = f.get(h).asInstanceOf[String] - serializer.dumps(id) + val results = serializer.dumps(id) + val failure = serializer.getLastFailure + if (failure != null) { + throw new SparkException("An exception was raised by Python:\n" + failure) + } + results } - def deserialize(bytes: Array[Byte]): PythonTransformFunction = { + def deserialize(bytes: Array[Byte]): PythonTransformFunction = synchronized { require(serializer != null, "Serializer has not been registered!") - serializer.loads(bytes) + val pfunc = serializer.loads(bytes) + val failure = serializer.getLastFailure + if (failure != null) { + throw new SparkException("An exception was raised by Python:\n" + failure) + } + pfunc } } From 4e81783e92f464d479baaf93eccc3adb1496989a Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Wed, 25 Nov 2015 12:58:18 -0800 Subject: [PATCH 1556/1824] 
[SPARK-11866][NETWORK][CORE] Make sure timed out RPCs are cleaned up. This change does a couple of different things to make sure that the RpcEnv-level code and the network library agree about the status of outstanding RPCs. For RPCs that do not expect a reply ("RpcEnv.send"), support for one way messages (hello CORBA!) was added to the network layer. This is a "fire and forget" message that does not require any state to be kept by the TransportClient; as a result, the RpcEnv 'Ack' message is not needed anymore. For RPCs that do expect a reply ("RpcEnv.ask"), the network library now returns the internal RPC id; if the RpcEnv layer decides to time out the RPC before the network layer does, it now asks the TransportClient to forget about the RPC, so that if the network-level timeout occurs, the client is not killed. As part of implementing the above, I cleaned up some of the code in the netty rpc backend, removing types that were not necessary and factoring out some common code. Of interest is a slight change in the exceptions when posting messages to a stopped RpcEnv; that's mostly to avoid nasty error messages from the local-cluster backend when shutting down, which pollutes the terminal output. Author: Marcelo Vanzin Closes #9917 from vanzin/SPARK-11866. --- .../spark/deploy/worker/ExecutorRunner.scala | 6 +- .../apache/spark/rpc/netty/Dispatcher.scala | 55 +++---- .../org/apache/spark/rpc/netty/Inbox.scala | 28 ++-- .../spark/rpc/netty/NettyRpcCallContext.scala | 35 +--- .../apache/spark/rpc/netty/NettyRpcEnv.scala | 153 +++++++----------- .../org/apache/spark/rpc/netty/Outbox.scala | 64 ++++++-- .../apache/spark/rpc/netty/InboxSuite.scala | 6 +- .../rpc/netty/NettyRpcHandlerSuite.scala | 2 +- .../spark/network/client/TransportClient.java | 34 +++- .../spark/network/protocol/Message.java | 4 +- .../network/protocol/MessageDecoder.java | 3 + .../spark/network/protocol/OneWayMessage.java | 75 +++++++++ .../spark/network/sasl/SaslRpcHandler.java | 5 + .../spark/network/server/RpcHandler.java | 36 +++++ .../server/TransportRequestHandler.java | 18 ++- .../apache/spark/network/ProtocolSuite.java | 2 + .../spark/network/RpcIntegrationSuite.java | 31 ++++ .../spark/network/sasl/SparkSaslSuite.java | 9 ++ 18 files changed, 374 insertions(+), 192 deletions(-) create mode 100644 network/common/src/main/java/org/apache/spark/network/protocol/OneWayMessage.java diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala index 3aef0515cbf6..25a17473e4b5 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala @@ -92,7 +92,11 @@ private[deploy] class ExecutorRunner( process.destroy() exitCode = Some(process.waitFor()) } - worker.send(ExecutorStateChanged(appId, execId, state, message, exitCode)) + try { + worker.send(ExecutorStateChanged(appId, execId, state, message, exitCode)) + } catch { + case e: IllegalStateException => logWarning(e.getMessage(), e) + } } /** Stop this executor runner, including killing the process it launched */ diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala index eb25d6c7b721..533c9847661b 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala @@ -106,44 +106,30 @@ private[netty] class 
Dispatcher(nettyEnv: NettyRpcEnv) extends Logging { val iter = endpoints.keySet().iterator() while (iter.hasNext) { val name = iter.next - postMessage( - name, - _ => message, - () => { logWarning(s"Drop $message because $name has been stopped") }) + postMessage(name, message, (e) => logWarning(s"Message $message dropped.", e)) } } /** Posts a message sent by a remote endpoint. */ def postRemoteMessage(message: RequestMessage, callback: RpcResponseCallback): Unit = { - def createMessage(sender: NettyRpcEndpointRef): InboxMessage = { - val rpcCallContext = - new RemoteNettyRpcCallContext( - nettyEnv, sender, callback, message.senderAddress, message.needReply) - ContentMessage(message.senderAddress, message.content, message.needReply, rpcCallContext) - } - - def onEndpointStopped(): Unit = { - callback.onFailure( - new SparkException(s"Could not find ${message.receiver.name} or it has been stopped")) - } - - postMessage(message.receiver.name, createMessage, onEndpointStopped) + val rpcCallContext = + new RemoteNettyRpcCallContext(nettyEnv, callback, message.senderAddress) + val rpcMessage = RpcMessage(message.senderAddress, message.content, rpcCallContext) + postMessage(message.receiver.name, rpcMessage, (e) => callback.onFailure(e)) } /** Posts a message sent by a local endpoint. */ def postLocalMessage(message: RequestMessage, p: Promise[Any]): Unit = { - def createMessage(sender: NettyRpcEndpointRef): InboxMessage = { - val rpcCallContext = - new LocalNettyRpcCallContext(sender, message.senderAddress, message.needReply, p) - ContentMessage(message.senderAddress, message.content, message.needReply, rpcCallContext) - } - - def onEndpointStopped(): Unit = { - p.tryFailure( - new SparkException(s"Could not find ${message.receiver.name} or it has been stopped")) - } + val rpcCallContext = + new LocalNettyRpcCallContext(message.senderAddress, p) + val rpcMessage = RpcMessage(message.senderAddress, message.content, rpcCallContext) + postMessage(message.receiver.name, rpcMessage, (e) => p.tryFailure(e)) + } - postMessage(message.receiver.name, createMessage, onEndpointStopped) + /** Posts a one-way message. 
*/ + def postOneWayMessage(message: RequestMessage): Unit = { + postMessage(message.receiver.name, OneWayMessage(message.senderAddress, message.content), + (e) => throw e) } /** @@ -155,21 +141,26 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging { */ private def postMessage( endpointName: String, - createMessageFn: NettyRpcEndpointRef => InboxMessage, - callbackIfStopped: () => Unit): Unit = { + message: InboxMessage, + callbackIfStopped: (Exception) => Unit): Unit = { val shouldCallOnStop = synchronized { val data = endpoints.get(endpointName) if (stopped || data == null) { true } else { - data.inbox.post(createMessageFn(data.ref)) + data.inbox.post(message) receivers.offer(data) false } } if (shouldCallOnStop) { // We don't need to call `onStop` in the `synchronized` block - callbackIfStopped() + val error = if (stopped) { + new IllegalStateException("RpcEnv already stopped.") + } else { + new SparkException(s"Could not find $endpointName or it has been stopped.") + } + callbackIfStopped(error) } } diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala index 464027f07cc8..175463cc1031 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala @@ -27,10 +27,13 @@ import org.apache.spark.rpc.{RpcAddress, RpcEndpoint, ThreadSafeRpcEndpoint} private[netty] sealed trait InboxMessage -private[netty] case class ContentMessage( +private[netty] case class OneWayMessage( + senderAddress: RpcAddress, + content: Any) extends InboxMessage + +private[netty] case class RpcMessage( senderAddress: RpcAddress, content: Any, - needReply: Boolean, context: NettyRpcCallContext) extends InboxMessage private[netty] case object OnStart extends InboxMessage @@ -96,29 +99,24 @@ private[netty] class Inbox( while (true) { safelyCall(endpoint) { message match { - case ContentMessage(_sender, content, needReply, context) => - // The partial function to call - val pf = if (needReply) endpoint.receiveAndReply(context) else endpoint.receive + case RpcMessage(_sender, content, context) => try { - pf.applyOrElse[Any, Unit](content, { msg => + endpoint.receiveAndReply(context).applyOrElse[Any, Unit](content, { msg => throw new SparkException(s"Unsupported message $message from ${_sender}") }) - if (!needReply) { - context.finish() - } } catch { case NonFatal(e) => - if (needReply) { - // If the sender asks a reply, we should send the error back to the sender - context.sendFailure(e) - } else { - context.finish() - } + context.sendFailure(e) // Throw the exception -- this exception will be caught by the safelyCall function. // The endpoint's onError function will be called. 
throw e } + case OneWayMessage(_sender, content) => + endpoint.receive.applyOrElse[Any, Unit](content, { msg => + throw new SparkException(s"Unsupported message $message from ${_sender}") + }) + case OnStart => endpoint.onStart() if (!endpoint.isInstanceOf[ThreadSafeRpcEndpoint]) { diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcCallContext.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcCallContext.scala index 21d5bb4923d1..6637e2321f67 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcCallContext.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcCallContext.scala @@ -23,49 +23,28 @@ import org.apache.spark.Logging import org.apache.spark.network.client.RpcResponseCallback import org.apache.spark.rpc.{RpcAddress, RpcCallContext} -private[netty] abstract class NettyRpcCallContext( - endpointRef: NettyRpcEndpointRef, - override val senderAddress: RpcAddress, - needReply: Boolean) +private[netty] abstract class NettyRpcCallContext(override val senderAddress: RpcAddress) extends RpcCallContext with Logging { protected def send(message: Any): Unit override def reply(response: Any): Unit = { - if (needReply) { - send(AskResponse(endpointRef, response)) - } else { - throw new IllegalStateException( - s"Cannot send $response to the sender because the sender does not expect a reply") - } + send(response) } override def sendFailure(e: Throwable): Unit = { - if (needReply) { - send(AskResponse(endpointRef, RpcFailure(e))) - } else { - logError(e.getMessage, e) - throw new IllegalStateException( - "Cannot send reply to the sender because the sender won't handle it") - } + send(RpcFailure(e)) } - def finish(): Unit = { - if (!needReply) { - send(Ack(endpointRef)) - } - } } /** * If the sender and the receiver are in the same process, the reply can be sent back via `Promise`. 
*/ private[netty] class LocalNettyRpcCallContext( - endpointRef: NettyRpcEndpointRef, senderAddress: RpcAddress, - needReply: Boolean, p: Promise[Any]) - extends NettyRpcCallContext(endpointRef, senderAddress, needReply) { + extends NettyRpcCallContext(senderAddress) { override protected def send(message: Any): Unit = { p.success(message) @@ -77,11 +56,9 @@ private[netty] class LocalNettyRpcCallContext( */ private[netty] class RemoteNettyRpcCallContext( nettyEnv: NettyRpcEnv, - endpointRef: NettyRpcEndpointRef, callback: RpcResponseCallback, - senderAddress: RpcAddress, - needReply: Boolean) - extends NettyRpcCallContext(endpointRef, senderAddress, needReply) { + senderAddress: RpcAddress) + extends NettyRpcCallContext(senderAddress) { override protected def send(message: Any): Unit = { val reply = nettyEnv.serialize(message) diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala index c8fa870f50e6..c7d74fa1d919 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala @@ -150,7 +150,7 @@ private[netty] class NettyRpcEnv( private def postToOutbox(receiver: NettyRpcEndpointRef, message: OutboxMessage): Unit = { if (receiver.client != null) { - receiver.client.sendRpc(message.content, message.createCallback(receiver.client)); + message.sendWith(receiver.client) } else { require(receiver.address != null, "Cannot send message to client endpoint with no listen address.") @@ -182,25 +182,10 @@ private[netty] class NettyRpcEnv( val remoteAddr = message.receiver.address if (remoteAddr == address) { // Message to a local RPC endpoint. - val promise = Promise[Any]() - dispatcher.postLocalMessage(message, promise) - promise.future.onComplete { - case Success(response) => - val ack = response.asInstanceOf[Ack] - logTrace(s"Received ack from ${ack.sender}") - case Failure(e) => - logWarning(s"Exception when sending $message", e) - }(ThreadUtils.sameThread) + dispatcher.postOneWayMessage(message) } else { // Message to a remote RPC endpoint. 
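
The ask() rewrite a little further below moves all timeout bookkeeping out of NettyRpcEndpointRef and into one place: a promise raced against a scheduled timeout, with the timer cancelled as soon as either side wins. A self-contained sketch of that pattern using plain Scala futures and a JDK scheduler (not Spark's RpcTimeout machinery) is shown here. One small thing to watch in the new ask() body: ${...} is only interpolated inside s-prefixed strings, so the timeout message there ("Cannot receive any reply in ${timeout.duration}") is emitted literally unless the s prefix is added, as the sketch does.

    import java.util.concurrent.{Executors, TimeUnit, TimeoutException}
    import scala.concurrent.{ExecutionContext, Future, Promise}
    import scala.concurrent.duration.FiniteDuration

    object AskPattern {
      private val scheduler = Executors.newSingleThreadScheduledExecutor()

      // `send` hands the promise to the transport layer, which completes it when a reply arrives.
      def askWithTimeout[T](send: Promise[T] => Unit, timeout: FiniteDuration)
                           (implicit ec: ExecutionContext): Future[T] = {
        val promise = Promise[T]()
        send(promise)
        val timeoutTask = scheduler.schedule(new Runnable {
          override def run(): Unit =
            promise.tryFailure(new TimeoutException(s"Cannot receive any reply in $timeout"))
        }, timeout.toNanos, TimeUnit.NANOSECONDS)
        // Cancel the timer once the promise is completed, successfully or not.
        promise.future.onComplete(_ => timeoutTask.cancel(true))
        promise.future
      }
    }
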
- postToOutbox(message.receiver, OutboxMessage(serialize(message), - (e) => { - logWarning(s"Exception when sending $message", e) - }, - (client, response) => { - val ack = deserialize[Ack](client, response) - logDebug(s"Receive ack from ${ack.sender}") - })) + postToOutbox(message.receiver, OneWayOutboxMessage(serialize(message))) } } @@ -208,46 +193,52 @@ private[netty] class NettyRpcEnv( clientFactory.createClient(address.host, address.port) } - private[netty] def ask(message: RequestMessage): Future[Any] = { + private[netty] def ask[T: ClassTag](message: RequestMessage, timeout: RpcTimeout): Future[T] = { val promise = Promise[Any]() val remoteAddr = message.receiver.address + + def onFailure(e: Throwable): Unit = { + if (!promise.tryFailure(e)) { + logWarning(s"Ignored failure: $e") + } + } + + def onSuccess(reply: Any): Unit = reply match { + case RpcFailure(e) => onFailure(e) + case rpcReply => + if (!promise.trySuccess(rpcReply)) { + logWarning(s"Ignored message: $reply") + } + } + if (remoteAddr == address) { val p = Promise[Any]() - dispatcher.postLocalMessage(message, p) p.future.onComplete { - case Success(response) => - val reply = response.asInstanceOf[AskResponse] - if (reply.reply.isInstanceOf[RpcFailure]) { - if (!promise.tryFailure(reply.reply.asInstanceOf[RpcFailure].e)) { - logWarning(s"Ignore failure: ${reply.reply}") - } - } else if (!promise.trySuccess(reply.reply)) { - logWarning(s"Ignore message: ${reply}") - } - case Failure(e) => - if (!promise.tryFailure(e)) { - logWarning("Ignore Exception", e) - } + case Success(response) => onSuccess(response) + case Failure(e) => onFailure(e) }(ThreadUtils.sameThread) + dispatcher.postLocalMessage(message, p) } else { - postToOutbox(message.receiver, OutboxMessage(serialize(message), - (e) => { - if (!promise.tryFailure(e)) { - logWarning("Ignore Exception", e) - } - }, - (client, response) => { - val reply = deserialize[AskResponse](client, response) - if (reply.reply.isInstanceOf[RpcFailure]) { - if (!promise.tryFailure(reply.reply.asInstanceOf[RpcFailure].e)) { - logWarning(s"Ignore failure: ${reply.reply}") - } - } else if (!promise.trySuccess(reply.reply)) { - logWarning(s"Ignore message: ${reply}") - } - })) + val rpcMessage = RpcOutboxMessage(serialize(message), + onFailure, + (client, response) => onSuccess(deserialize[Any](client, response))) + postToOutbox(message.receiver, rpcMessage) + promise.future.onFailure { + case _: TimeoutException => rpcMessage.onTimeout() + case _ => + }(ThreadUtils.sameThread) } - promise.future + + val timeoutCancelable = timeoutScheduler.schedule(new Runnable { + override def run(): Unit = { + promise.tryFailure( + new TimeoutException("Cannot receive any reply in ${timeout.duration}")) + } + }, timeout.duration.toNanos, TimeUnit.NANOSECONDS) + promise.future.onComplete { v => + timeoutCancelable.cancel(true) + }(ThreadUtils.sameThread) + promise.future.mapTo[T].recover(timeout.addMessageIfTimeout)(ThreadUtils.sameThread) } private[netty] def serialize(content: Any): Array[Byte] = { @@ -512,25 +503,12 @@ private[netty] class NettyRpcEndpointRef( override def name: String = _name override def ask[T: ClassTag](message: Any, timeout: RpcTimeout): Future[T] = { - val promise = Promise[Any]() - val timeoutCancelable = nettyEnv.timeoutScheduler.schedule(new Runnable { - override def run(): Unit = { - promise.tryFailure(new TimeoutException("Cannot receive any reply in " + timeout.duration)) - } - }, timeout.duration.toNanos, TimeUnit.NANOSECONDS) - val f = 
nettyEnv.ask(RequestMessage(nettyEnv.address, this, message, true)) - f.onComplete { v => - timeoutCancelable.cancel(true) - if (!promise.tryComplete(v)) { - logWarning(s"Ignore message $v") - } - }(ThreadUtils.sameThread) - promise.future.mapTo[T].recover(timeout.addMessageIfTimeout)(ThreadUtils.sameThread) + nettyEnv.ask(RequestMessage(nettyEnv.address, this, message), timeout) } override def send(message: Any): Unit = { require(message != null, "Message is null") - nettyEnv.send(RequestMessage(nettyEnv.address, this, message, false)) + nettyEnv.send(RequestMessage(nettyEnv.address, this, message)) } override def toString: String = s"NettyRpcEndpointRef(${_address})" @@ -549,24 +527,7 @@ private[netty] class NettyRpcEndpointRef( * The message that is sent from the sender to the receiver. */ private[netty] case class RequestMessage( - senderAddress: RpcAddress, receiver: NettyRpcEndpointRef, content: Any, needReply: Boolean) - -/** - * The base trait for all messages that are sent back from the receiver to the sender. - */ -private[netty] trait ResponseMessage - -/** - * The reply for `ask` from the receiver side. - */ -private[netty] case class AskResponse(sender: NettyRpcEndpointRef, reply: Any) - extends ResponseMessage - -/** - * A message to send back to the receiver side. It's necessary because [[TransportClient]] only - * clean the resources when it receives a reply. - */ -private[netty] case class Ack(sender: NettyRpcEndpointRef) extends ResponseMessage + senderAddress: RpcAddress, receiver: NettyRpcEndpointRef, content: Any) /** * A response that indicates some failure happens in the receiver side. @@ -598,6 +559,18 @@ private[netty] class NettyRpcHandler( client: TransportClient, message: Array[Byte], callback: RpcResponseCallback): Unit = { + val messageToDispatch = internalReceive(client, message) + dispatcher.postRemoteMessage(messageToDispatch, callback) + } + + override def receive( + client: TransportClient, + message: Array[Byte]): Unit = { + val messageToDispatch = internalReceive(client, message) + dispatcher.postOneWayMessage(messageToDispatch) + } + + private def internalReceive(client: TransportClient, message: Array[Byte]): RequestMessage = { val addr = client.getChannel().remoteAddress().asInstanceOf[InetSocketAddress] assert(addr != null) val clientAddr = RpcAddress(addr.getHostName, addr.getPort) @@ -605,14 +578,12 @@ private[netty] class NettyRpcHandler( dispatcher.postToAll(RemoteProcessConnected(clientAddr)) } val requestMessage = nettyEnv.deserialize[RequestMessage](client, message) - val messageToDispatch = if (requestMessage.senderAddress == null) { - // Create a new message with the socket address of the client as the sender. - RequestMessage(clientAddr, requestMessage.receiver, requestMessage.content, - requestMessage.needReply) - } else { - requestMessage - } - dispatcher.postRemoteMessage(messageToDispatch, callback) + if (requestMessage.senderAddress == null) { + // Create a new message with the socket address of the client as the sender. 
+ RequestMessage(clientAddr, requestMessage.receiver, requestMessage.content) + } else { + requestMessage + } } override def getStreamManager: StreamManager = streamManager diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala index 2f6817f2eb93..36fdd00bbc4c 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala @@ -22,22 +22,56 @@ import javax.annotation.concurrent.GuardedBy import scala.util.control.NonFatal -import org.apache.spark.SparkException +import org.apache.spark.{Logging, SparkException} import org.apache.spark.network.client.{RpcResponseCallback, TransportClient} import org.apache.spark.rpc.RpcAddress -private[netty] case class OutboxMessage(content: Array[Byte], - _onFailure: (Throwable) => Unit, - _onSuccess: (TransportClient, Array[Byte]) => Unit) { +private[netty] sealed trait OutboxMessage { - def createCallback(client: TransportClient): RpcResponseCallback = new RpcResponseCallback() { - override def onFailure(e: Throwable): Unit = { - _onFailure(e) - } + def sendWith(client: TransportClient): Unit - override def onSuccess(response: Array[Byte]): Unit = { - _onSuccess(client, response) - } + def onFailure(e: Throwable): Unit + +} + +private[netty] case class OneWayOutboxMessage(content: Array[Byte]) extends OutboxMessage + with Logging { + + override def sendWith(client: TransportClient): Unit = { + client.send(content) + } + + override def onFailure(e: Throwable): Unit = { + logWarning(s"Failed to send one-way RPC.", e) + } + +} + +private[netty] case class RpcOutboxMessage( + content: Array[Byte], + _onFailure: (Throwable) => Unit, + _onSuccess: (TransportClient, Array[Byte]) => Unit) + extends OutboxMessage with RpcResponseCallback { + + private var client: TransportClient = _ + private var requestId: Long = _ + + override def sendWith(client: TransportClient): Unit = { + this.client = client + this.requestId = client.sendRpc(content, this) + } + + def onTimeout(): Unit = { + require(client != null, "TransportClient has not yet been set.") + client.removeRpcRequest(requestId) + } + + override def onFailure(e: Throwable): Unit = { + _onFailure(e) + } + + override def onSuccess(response: Array[Byte]): Unit = { + _onSuccess(client, response) } } @@ -82,7 +116,7 @@ private[netty] class Outbox(nettyEnv: NettyRpcEnv, val address: RpcAddress) { } } if (dropped) { - message._onFailure(new SparkException("Message is dropped because Outbox is stopped")) + message.onFailure(new SparkException("Message is dropped because Outbox is stopped")) } else { drainOutbox() } @@ -122,7 +156,7 @@ private[netty] class Outbox(nettyEnv: NettyRpcEnv, val address: RpcAddress) { try { val _client = synchronized { client } if (_client != null) { - _client.sendRpc(message.content, message.createCallback(_client)) + message.sendWith(_client) } else { assert(stopped == true) } @@ -195,7 +229,7 @@ private[netty] class Outbox(nettyEnv: NettyRpcEnv, val address: RpcAddress) { // update messages and it's safe to just drain the queue. var message = messages.poll() while (message != null) { - message._onFailure(e) + message.onFailure(e) message = messages.poll() } assert(messages.isEmpty) @@ -229,7 +263,7 @@ private[netty] class Outbox(nettyEnv: NettyRpcEnv, val address: RpcAddress) { // update messages and it's safe to just drain the queue. 
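
The new OutboxMessage hierarchy above replaces a payload-plus-two-closures tuple with messages that know how to send themselves and act as their own response callback, which is what makes the later onTimeout()/removeRpcRequest cleanup possible. The shape is roughly as follows, sketched with invented names rather than the real TransportClient API:

    // Minimal stand-in for the transport client: send bytes, get a request id back.
    trait Client {
      def sendRpc(payload: Array[Byte], onDone: Either[Throwable, Array[Byte]] => Unit): Long
      def cancel(requestId: Long): Unit
    }

    sealed trait OutMsg {
      def sendWith(client: Client): Unit
      def onFailure(e: Throwable): Unit   // also invoked if the outbox drops the message
    }

    final class RpcOut(payload: Array[Byte], done: Either[Throwable, Array[Byte]] => Unit) extends OutMsg {
      private var client: Client = _
      private var requestId: Long = -1L
      override def sendWith(c: Client): Unit = { client = c; requestId = c.sendRpc(payload, done) }
      override def onFailure(e: Throwable): Unit = done(Left(e))
      // Because the message remembers its client and request id, a timeout can clean up after itself.
      def onTimeout(): Unit = if (client != null) client.cancel(requestId)
    }
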
var message = messages.poll() while (message != null) { - message._onFailure(new SparkException("Message is dropped because Outbox is stopped")) + message.onFailure(new SparkException("Message is dropped because Outbox is stopped")) message = messages.poll() } } diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/InboxSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/InboxSuite.scala index 276c077b3d13..2136795b1881 100644 --- a/core/src/test/scala/org/apache/spark/rpc/netty/InboxSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/netty/InboxSuite.scala @@ -35,7 +35,7 @@ class InboxSuite extends SparkFunSuite { val dispatcher = mock(classOf[Dispatcher]) val inbox = new Inbox(endpointRef, endpoint) - val message = ContentMessage(null, "hi", false, null) + val message = OneWayMessage(null, "hi") inbox.post(message) inbox.process(dispatcher) assert(inbox.isEmpty) @@ -55,7 +55,7 @@ class InboxSuite extends SparkFunSuite { val dispatcher = mock(classOf[Dispatcher]) val inbox = new Inbox(endpointRef, endpoint) - val message = ContentMessage(null, "hi", true, null) + val message = RpcMessage(null, "hi", null) inbox.post(message) inbox.process(dispatcher) assert(inbox.isEmpty) @@ -83,7 +83,7 @@ class InboxSuite extends SparkFunSuite { new Thread { override def run(): Unit = { for (_ <- 0 until 100) { - val message = ContentMessage(null, "hi", false, null) + val message = OneWayMessage(null, "hi") inbox.post(message) } exitLatch.countDown() diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala index ccca795683da..323184cdd9b6 100644 --- a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala @@ -33,7 +33,7 @@ class NettyRpcHandlerSuite extends SparkFunSuite { val env = mock(classOf[NettyRpcEnv]) val sm = mock(classOf[StreamManager]) when(env.deserialize(any(classOf[TransportClient]), any(classOf[Array[Byte]]))(any())) - .thenReturn(RequestMessage(RpcAddress("localhost", 12345), null, null, false)) + .thenReturn(RequestMessage(RpcAddress("localhost", 12345), null, null)) test("receive") { val dispatcher = mock(classOf[Dispatcher]) diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java index 876fcd846791..8a58e7b24585 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java @@ -25,6 +25,7 @@ import java.util.concurrent.TimeUnit; import javax.annotation.Nullable; +import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Objects; import com.google.common.base.Preconditions; import com.google.common.base.Throwables; @@ -36,6 +37,7 @@ import org.slf4j.LoggerFactory; import org.apache.spark.network.protocol.ChunkFetchRequest; +import org.apache.spark.network.protocol.OneWayMessage; import org.apache.spark.network.protocol.RpcRequest; import org.apache.spark.network.protocol.StreamChunkId; import org.apache.spark.network.protocol.StreamRequest; @@ -205,8 +207,12 @@ public void operationComplete(ChannelFuture future) throws Exception { /** * Sends an opaque message to the RpcHandler on the server-side. The callback will be invoked * with the server's response or upon any failure. 
+ * + * @param message The message to send. + * @param callback Callback to handle the RPC's reply. + * @return The RPC's id. */ - public void sendRpc(byte[] message, final RpcResponseCallback callback) { + public long sendRpc(byte[] message, final RpcResponseCallback callback) { final String serverAddr = NettyUtils.getRemoteAddress(channel); final long startTime = System.currentTimeMillis(); logger.trace("Sending RPC to {}", serverAddr); @@ -235,6 +241,8 @@ public void operationComplete(ChannelFuture future) throws Exception { } } }); + + return requestId; } /** @@ -265,11 +273,35 @@ public void onFailure(Throwable e) { } } + /** + * Sends an opaque message to the RpcHandler on the server-side. No reply is expected for the + * message, and no delivery guarantees are made. + * + * @param message The message to send. + */ + public void send(byte[] message) { + channel.writeAndFlush(new OneWayMessage(message)); + } + + /** + * Removes any state associated with the given RPC. + * + * @param requestId The RPC id returned by {@link #sendRpc(byte[], RpcResponseCallback)}. + */ + public void removeRpcRequest(long requestId) { + handler.removeRpcRequest(requestId); + } + /** Mark this channel as having timed out. */ public void timeOut() { this.timedOut = true; } + @VisibleForTesting + public TransportResponseHandler getHandler() { + return handler; + } + @Override public void close() { // close is a local operation and should finish with milliseconds; timeout just to be safe diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/Message.java b/network/common/src/main/java/org/apache/spark/network/protocol/Message.java index d01598c20f16..39afd03db60e 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/Message.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/Message.java @@ -28,7 +28,8 @@ public interface Message extends Encodable { public static enum Type implements Encodable { ChunkFetchRequest(0), ChunkFetchSuccess(1), ChunkFetchFailure(2), RpcRequest(3), RpcResponse(4), RpcFailure(5), - StreamRequest(6), StreamResponse(7), StreamFailure(8); + StreamRequest(6), StreamResponse(7), StreamFailure(8), + OneWayMessage(9); private final byte id; @@ -55,6 +56,7 @@ public static Type decode(ByteBuf buf) { case 6: return StreamRequest; case 7: return StreamResponse; case 8: return StreamFailure; + case 9: return OneWayMessage; default: throw new IllegalArgumentException("Unknown message type: " + id); } } diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java b/network/common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java index 3c04048f3821..074780f2b95c 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java @@ -63,6 +63,9 @@ private Message decode(Message.Type msgType, ByteBuf in) { case RpcFailure: return RpcFailure.decode(in); + case OneWayMessage: + return OneWayMessage.decode(in); + case StreamRequest: return StreamRequest.decode(in); diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/OneWayMessage.java b/network/common/src/main/java/org/apache/spark/network/protocol/OneWayMessage.java new file mode 100644 index 000000000000..95a0270be3da --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/protocol/OneWayMessage.java @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under 
one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol; + +import java.util.Arrays; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; + +/** + * A RPC that does not expect a reply, which is handled by a remote + * {@link org.apache.spark.network.server.RpcHandler}. + */ +public final class OneWayMessage implements RequestMessage { + /** Serialized message to send to remote RpcHandler. */ + public final byte[] message; + + public OneWayMessage(byte[] message) { + this.message = message; + } + + @Override + public Type type() { return Type.OneWayMessage; } + + @Override + public int encodedLength() { + return Encoders.ByteArrays.encodedLength(message); + } + + @Override + public void encode(ByteBuf buf) { + Encoders.ByteArrays.encode(buf, message); + } + + public static OneWayMessage decode(ByteBuf buf) { + byte[] message = Encoders.ByteArrays.decode(buf); + return new OneWayMessage(message); + } + + @Override + public int hashCode() { + return Arrays.hashCode(message); + } + + @Override + public boolean equals(Object other) { + if (other instanceof OneWayMessage) { + OneWayMessage o = (OneWayMessage) other; + return Arrays.equals(message, o.message); + } + return false; + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("message", message) + .toString(); + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java b/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java index 7033adb9cae6..830db94b890c 100644 --- a/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java @@ -108,6 +108,11 @@ public void receive(TransportClient client, byte[] message, RpcResponseCallback } } + @Override + public void receive(TransportClient client, byte[] message) { + delegate.receive(client, message); + } + @Override public StreamManager getStreamManager() { return delegate.getStreamManager(); diff --git a/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java b/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java index dbb7f95f55bc..65109ddfe13b 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java @@ -17,6 +17,9 @@ package org.apache.spark.network.server; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + import org.apache.spark.network.client.RpcResponseCallback; import org.apache.spark.network.client.TransportClient; @@ -24,6 +27,9 @@ * Handler for sendRPC() messages sent by {@link org.apache.spark.network.client.TransportClient}s. 
*/ public abstract class RpcHandler { + + private static final RpcResponseCallback ONE_WAY_CALLBACK = new OneWayRpcCallback(); + /** * Receive a single RPC message. Any exception thrown while in this method will be sent back to * the client in string form as a standard RPC failure. @@ -47,6 +53,19 @@ public abstract void receive( */ public abstract StreamManager getStreamManager(); + /** + * Receives an RPC message that does not expect a reply. The default implementation will + * call "{@link receive(TransportClient, byte[], RpcResponseCallback}" and log a warning if + * any of the callback methods are called. + * + * @param client A channel client which enables the handler to make requests back to the sender + * of this RPC. This will always be the exact same object for a particular channel. + * @param message The serialized bytes of the RPC. + */ + public void receive(TransportClient client, byte[] message) { + receive(client, message, ONE_WAY_CALLBACK); + } + /** * Invoked when the connection associated with the given client has been invalidated. * No further requests will come from this client. @@ -54,4 +73,21 @@ public abstract void receive( public void connectionTerminated(TransportClient client) { } public void exceptionCaught(Throwable cause, TransportClient client) { } + + private static class OneWayRpcCallback implements RpcResponseCallback { + + private final Logger logger = LoggerFactory.getLogger(OneWayRpcCallback.class); + + @Override + public void onSuccess(byte[] response) { + logger.warn("Response provided for one-way RPC."); + } + + @Override + public void onFailure(Throwable e) { + logger.error("Error response provided for one-way RPC.", e); + } + + } + } diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java index 4f67bd573be2..db18ea77d107 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java @@ -17,6 +17,7 @@ package org.apache.spark.network.server; +import com.google.common.base.Preconditions; import com.google.common.base.Throwables; import io.netty.channel.Channel; import io.netty.channel.ChannelFuture; @@ -27,13 +28,14 @@ import org.apache.spark.network.buffer.ManagedBuffer; import org.apache.spark.network.client.RpcResponseCallback; import org.apache.spark.network.client.TransportClient; -import org.apache.spark.network.protocol.Encodable; -import org.apache.spark.network.protocol.RequestMessage; import org.apache.spark.network.protocol.ChunkFetchRequest; -import org.apache.spark.network.protocol.RpcRequest; import org.apache.spark.network.protocol.ChunkFetchFailure; import org.apache.spark.network.protocol.ChunkFetchSuccess; +import org.apache.spark.network.protocol.Encodable; +import org.apache.spark.network.protocol.OneWayMessage; +import org.apache.spark.network.protocol.RequestMessage; import org.apache.spark.network.protocol.RpcFailure; +import org.apache.spark.network.protocol.RpcRequest; import org.apache.spark.network.protocol.RpcResponse; import org.apache.spark.network.protocol.StreamFailure; import org.apache.spark.network.protocol.StreamRequest; @@ -95,6 +97,8 @@ public void handle(RequestMessage request) { processFetchRequest((ChunkFetchRequest) request); } else if (request instanceof RpcRequest) { processRpcRequest((RpcRequest) request); + } else if (request instanceof 
OneWayMessage) { + processOneWayMessage((OneWayMessage) request); } else if (request instanceof StreamRequest) { processStreamRequest((StreamRequest) request); } else { @@ -156,6 +160,14 @@ public void onFailure(Throwable e) { } } + private void processOneWayMessage(OneWayMessage req) { + try { + rpcHandler.receive(reverseClient, req.message); + } catch (Exception e) { + logger.error("Error while invoking RpcHandler#receive() for one-way message.", e); + } + } + /** * Responds to a single message with some Encodable object. If a failure occurs while sending, * it will be logged and the channel closed. diff --git a/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java b/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java index 22b451fc0e60..1aa20900ffe7 100644 --- a/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java @@ -35,6 +35,7 @@ import org.apache.spark.network.protocol.Message; import org.apache.spark.network.protocol.MessageDecoder; import org.apache.spark.network.protocol.MessageEncoder; +import org.apache.spark.network.protocol.OneWayMessage; import org.apache.spark.network.protocol.RpcFailure; import org.apache.spark.network.protocol.RpcRequest; import org.apache.spark.network.protocol.RpcResponse; @@ -84,6 +85,7 @@ public void requests() { testClientToServer(new RpcRequest(12345, new byte[0])); testClientToServer(new RpcRequest(12345, new byte[100])); testClientToServer(new StreamRequest("abcde")); + testClientToServer(new OneWayMessage(new byte[100])); } @Test diff --git a/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java b/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java index 8eb56bdd9846..88fa2258bb79 100644 --- a/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java @@ -17,9 +17,11 @@ package org.apache.spark.network; +import java.util.ArrayList; import java.util.Collections; import java.util.HashSet; import java.util.Iterator; +import java.util.List; import java.util.Set; import java.util.concurrent.Semaphore; import java.util.concurrent.TimeUnit; @@ -46,6 +48,7 @@ public class RpcIntegrationSuite { static TransportServer server; static TransportClientFactory clientFactory; static RpcHandler rpcHandler; + static List oneWayMsgs; @BeforeClass public static void setUp() throws Exception { @@ -64,12 +67,19 @@ public void receive(TransportClient client, byte[] message, RpcResponseCallback } } + @Override + public void receive(TransportClient client, byte[] message) { + String msg = new String(message, Charsets.UTF_8); + oneWayMsgs.add(msg); + } + @Override public StreamManager getStreamManager() { return new OneForOneStreamManager(); } }; TransportContext context = new TransportContext(conf, rpcHandler); server = context.createServer(); clientFactory = context.createClientFactory(); + oneWayMsgs = new ArrayList<>(); } @AfterClass @@ -158,6 +168,27 @@ public void sendSuccessAndFailure() throws Exception { assertErrorsContain(res.errorMessages, Sets.newHashSet("Thrown: the", "Returned: !")); } + @Test + public void sendOneWayMessage() throws Exception { + final String message = "no reply"; + TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); + try { + client.send(message.getBytes(Charsets.UTF_8)); + assertEquals(0, 
client.getHandler().numOutstandingRequests()); + + // Make sure the message arrives. + long deadline = System.nanoTime() + TimeUnit.NANOSECONDS.convert(10, TimeUnit.SECONDS); + while (System.nanoTime() < deadline && oneWayMsgs.size() == 0) { + TimeUnit.MILLISECONDS.sleep(10); + } + + assertEquals(1, oneWayMsgs.size()); + assertEquals(message, oneWayMsgs.get(0)); + } finally { + client.close(); + } + } + private void assertErrorsContain(Set errors, Set contains) { assertEquals(contains.size(), errors.size()); diff --git a/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java b/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java index b14689967018..a6f180bc40c9 100644 --- a/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java @@ -21,6 +21,7 @@ import static org.mockito.Mockito.*; import java.io.File; +import java.lang.reflect.Method; import java.nio.charset.StandardCharsets; import java.util.Arrays; import java.util.List; @@ -353,6 +354,14 @@ public void testRpcHandlerDelegate() throws Exception { verify(handler).exceptionCaught(any(Throwable.class), any(TransportClient.class)); } + @Test + public void testDelegates() throws Exception { + Method[] rpcHandlerMethods = RpcHandler.class.getDeclaredMethods(); + for (Method m : rpcHandlerMethods) { + SaslRpcHandler.class.getDeclaredMethod(m.getName(), m.getParameterTypes()); + } + } + private static class SaslTestCtx { final TransportClient client; From ecac2835458bbf73fe59413d5bf921500c5b987d Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 25 Nov 2015 13:45:41 -0800 Subject: [PATCH 1557/1824] Fix Aggregator documentation (rename present to finish). --- .../scala/org/apache/spark/sql/expressions/Aggregator.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala index b0cd32b5f73e..65117d582475 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala @@ -34,7 +34,7 @@ import org.apache.spark.sql.{DataFrame, Dataset, Encoder, TypedColumn} * def zero: Int = 0 * def reduce(b: Int, a: Data): Int = b + a.i * def merge(b1: Int, b2: Int): Int = b1 + b2 - * def present(r: Int): Int = r + * def finish(r: Int): Int = r * }.toColumn() * * val ds: Dataset[Data] = ... From 21e5606419c4b7462d30580c549e9bfa0123ae23 Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Wed, 25 Nov 2015 13:51:30 -0800 Subject: [PATCH 1558/1824] [SPARK-11983][SQL] remove all unused codegen fallback trait Author: Daoyuan Wang Closes #9966 from adrian-wang/removeFallback. 
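
For context on the trait being removed: CodegenFallback is the escape hatch for expressions that have no code-generated implementation, so the "generated" code simply calls back into the interpreted eval(). Removing the mixin from Cast, Like and RLike is safe precisely because, per the commit title, it was unused there. The toy sketch below only illustrates the fallback idea; it is not Spark's actual trait and all names are invented.

    // A toy expression interface: every expression can be interpreted...
    trait Expr { def eval(input: Map[String, Any]): Any }

    // ...and is expected to emit a source fragment for expression codegen.
    trait HasCodegen extends Expr { def genCode(inputVar: String): String }

    // The fallback: instead of real generated code, emit a call back into eval().
    trait EvalFallback extends HasCodegen {
      override def genCode(inputVar: String): String = s"/* interpreted fallback */ this.eval($inputVar)"
    }

    final case class Upper(column: String) extends HasCodegen {
      override def eval(input: Map[String, Any]): Any = input(column).toString.toUpperCase
      // A real codegen implementation makes mixing in EvalFallback unnecessary.
      override def genCode(inputVar: String): String = s"""$inputVar("$column").toString.toUpperCase"""
    }
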
--- .../org/apache/spark/sql/catalyst/expressions/Cast.scala | 3 +-- .../spark/sql/catalyst/expressions/regexpExpressions.scala | 4 ++-- .../spark/sql/catalyst/expressions/NonFoldableLiteral.scala | 3 +-- 3 files changed, 4 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 533d17ea5c17..a2c6c39fd8ce 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -104,8 +104,7 @@ object Cast { } /** Cast the child expression to the target data type. */ -case class Cast(child: Expression, dataType: DataType) - extends UnaryExpression with CodegenFallback { +case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { override def toString: String = s"cast($child as ${dataType.simpleString})" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index 9e484c5ed83b..adef6050c356 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -66,7 +66,7 @@ trait StringRegexExpression extends ImplicitCastInputTypes { * Simple RegEx pattern matching function */ case class Like(left: Expression, right: Expression) - extends BinaryExpression with StringRegexExpression with CodegenFallback { + extends BinaryExpression with StringRegexExpression { override def escape(v: String): String = StringUtils.escapeLikeRegex(v) @@ -117,7 +117,7 @@ case class Like(left: Expression, right: Expression) case class RLike(left: Expression, right: Expression) - extends BinaryExpression with StringRegexExpression with CodegenFallback { + extends BinaryExpression with StringRegexExpression { override def escape(v: String): String = v override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).find(0) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NonFoldableLiteral.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NonFoldableLiteral.scala index 31ecf4a9e810..118fd695fe2f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NonFoldableLiteral.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NonFoldableLiteral.scala @@ -26,8 +26,7 @@ import org.apache.spark.sql.types._ * A literal value that is not foldable. Used in expression codegen testing to test code path * that behave differently based on foldable values. */ -case class NonFoldableLiteral(value: Any, dataType: DataType) - extends LeafExpression with CodegenFallback { +case class NonFoldableLiteral(value: Any, dataType: DataType) extends LeafExpression { override def foldable: Boolean = false override def nullable: Boolean = true From cc243a079b1c039d6e7f0b410d1654d94a090e14 Mon Sep 17 00:00:00 2001 From: Carson Wang Date: Wed, 25 Nov 2015 15:13:13 -0800 Subject: [PATCH 1559/1824] [SPARK-11206] Support SQL UI on the history server On the live web UI, there is a SQL tab which provides valuable information for the SQL query. But once the workload is finished, we won't see the SQL tab on the history server. 
It will be helpful if we support SQL UI on the history server so we can analyze it even after its execution. To support SQL UI on the history server: 1. I added an `onOtherEvent` method to the `SparkListener` trait and post all SQL related events to the same event bus. 2. Two SQL events `SparkListenerSQLExecutionStart` and `SparkListenerSQLExecutionEnd` are defined in the sql module. 3. The new SQL events are written to event log using Jackson. 4. A new trait `SparkHistoryListenerFactory` is added to allow the history server to feed events to the SQL history listener. The SQL implementation is loaded at runtime using `java.util.ServiceLoader`. Author: Carson Wang Closes #9297 from carsonwang/SqlHistoryUI. --- .rat-excludes | 1 + .../org/apache/spark/JavaSparkListener.java | 3 + .../apache/spark/SparkFirehoseListener.java | 4 + .../scheduler/EventLoggingListener.scala | 4 + .../spark/scheduler/SparkListener.scala | 24 ++- .../spark/scheduler/SparkListenerBus.scala | 1 + .../scala/org/apache/spark/ui/SparkUI.scala | 16 +- .../org/apache/spark/util/JsonProtocol.scala | 11 +- ...park.scheduler.SparkHistoryListenerFactory | 1 + .../org/apache/spark/sql/SQLContext.scala | 18 ++- .../spark/sql/execution/SQLExecution.scala | 24 +-- .../spark/sql/execution/SparkPlanInfo.scala | 46 ++++++ .../sql/execution/metric/SQLMetricInfo.scala | 30 ++++ .../sql/execution/metric/SQLMetrics.scala | 56 ++++--- .../sql/execution/ui/ExecutionPage.scala | 4 +- .../spark/sql/execution/ui/SQLListener.scala | 139 ++++++++++++------ .../spark/sql/execution/ui/SQLTab.scala | 12 +- .../sql/execution/ui/SparkPlanGraph.scala | 20 +-- .../execution/metric/SQLMetricsSuite.scala | 4 +- .../sql/execution/ui/SQLListenerSuite.scala | 43 +++--- .../spark/sql/test/SharedSQLContext.scala | 1 + 21 files changed, 327 insertions(+), 135 deletions(-) create mode 100644 sql/core/src/main/resources/META-INF/services/org.apache.spark.scheduler.SparkHistoryListenerFactory create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetricInfo.scala diff --git a/.rat-excludes b/.rat-excludes index 08fba6d351d6..7262c960ed6b 100644 --- a/.rat-excludes +++ b/.rat-excludes @@ -82,4 +82,5 @@ INDEX gen-java.* .*avpr org.apache.spark.sql.sources.DataSourceRegister +org.apache.spark.scheduler.SparkHistoryListenerFactory .*parquet diff --git a/core/src/main/java/org/apache/spark/JavaSparkListener.java b/core/src/main/java/org/apache/spark/JavaSparkListener.java index fa9acf0a15b8..23bc9a2e8172 100644 --- a/core/src/main/java/org/apache/spark/JavaSparkListener.java +++ b/core/src/main/java/org/apache/spark/JavaSparkListener.java @@ -82,4 +82,7 @@ public void onExecutorRemoved(SparkListenerExecutorRemoved executorRemoved) { } @Override public void onBlockUpdated(SparkListenerBlockUpdated blockUpdated) { } + @Override + public void onOtherEvent(SparkListenerEvent event) { } + } diff --git a/core/src/main/java/org/apache/spark/SparkFirehoseListener.java b/core/src/main/java/org/apache/spark/SparkFirehoseListener.java index 1214d05ba606..e6b24afd88ad 100644 --- a/core/src/main/java/org/apache/spark/SparkFirehoseListener.java +++ b/core/src/main/java/org/apache/spark/SparkFirehoseListener.java @@ -118,4 +118,8 @@ public void onBlockUpdated(SparkListenerBlockUpdated blockUpdated) { onEvent(blockUpdated); } + @Override + public void onOtherEvent(SparkListenerEvent event) { + onEvent(event); + } } diff --git 
a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala index 000a021a528c..eaa07acc5132 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala @@ -207,6 +207,10 @@ private[spark] class EventLoggingListener( // No-op because logging every update would be overkill override def onExecutorMetricsUpdate(event: SparkListenerExecutorMetricsUpdate): Unit = { } + override def onOtherEvent(event: SparkListenerEvent): Unit = { + logEvent(event, flushLogger = true) + } + /** * Stop logging events. The event log file will be renamed so that it loses the * ".inprogress" suffix. diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala index 896f1743332f..075a7f13172d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala @@ -22,15 +22,19 @@ import java.util.Properties import scala.collection.Map import scala.collection.mutable -import org.apache.spark.{Logging, TaskEndReason} +import com.fasterxml.jackson.annotation.JsonTypeInfo + +import org.apache.spark.{Logging, SparkConf, TaskEndReason} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler.cluster.ExecutorInfo import org.apache.spark.storage.{BlockManagerId, BlockUpdatedInfo} import org.apache.spark.util.{Distribution, Utils} +import org.apache.spark.ui.SparkUI @DeveloperApi -sealed trait SparkListenerEvent +@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "Event") +trait SparkListenerEvent @DeveloperApi case class SparkListenerStageSubmitted(stageInfo: StageInfo, properties: Properties = null) @@ -130,6 +134,17 @@ case class SparkListenerApplicationEnd(time: Long) extends SparkListenerEvent */ private[spark] case class SparkListenerLogStart(sparkVersion: String) extends SparkListenerEvent +/** + * Interface for creating history listeners defined in other modules like SQL, which are used to + * rebuild the history UI. + */ +private[spark] trait SparkHistoryListenerFactory { + /** + * Create listeners used to rebuild the history UI. + */ + def createListeners(conf: SparkConf, sparkUI: SparkUI): Seq[SparkListener] +} + /** * :: DeveloperApi :: * Interface for listening to events from the Spark scheduler. Note that this is an internal @@ -223,6 +238,11 @@ trait SparkListener { * Called when the driver receives a block update info. */ def onBlockUpdated(blockUpdated: SparkListenerBlockUpdated) { } + + /** + * Called when other events like SQL-specific events are posted. 
+ */ + def onOtherEvent(event: SparkListenerEvent) { } } /** diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala index 04afde33f5aa..95722a07144e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala @@ -61,6 +61,7 @@ private[spark] trait SparkListenerBus extends ListenerBus[SparkListener, SparkLi case blockUpdated: SparkListenerBlockUpdated => listener.onBlockUpdated(blockUpdated) case logStart: SparkListenerLogStart => // ignore event log metadata + case _ => listener.onOtherEvent(event) } } diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala index 4608bce202ec..8da6884a3853 100644 --- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala @@ -17,10 +17,13 @@ package org.apache.spark.ui -import java.util.Date +import java.util.{Date, ServiceLoader} + +import scala.collection.JavaConverters._ import org.apache.spark.status.api.v1.{ApiRootResource, ApplicationAttemptInfo, ApplicationInfo, UIRoot} +import org.apache.spark.util.Utils import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkContext} import org.apache.spark.scheduler._ import org.apache.spark.storage.StorageStatusListener @@ -154,7 +157,16 @@ private[spark] object SparkUI { appName: String, basePath: String, startTime: Long): SparkUI = { - create(None, conf, listenerBus, securityManager, appName, basePath, startTime = startTime) + val sparkUI = create( + None, conf, listenerBus, securityManager, appName, basePath, startTime = startTime) + + val listenerFactories = ServiceLoader.load(classOf[SparkHistoryListenerFactory], + Utils.getContextOrSparkClassLoader).asScala + listenerFactories.foreach { listenerFactory => + val listeners = listenerFactory.createListeners(conf, sparkUI) + listeners.foreach(listenerBus.addListener) + } + sparkUI } /** diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index c9beeb25e05a..7f5d713ec650 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -19,19 +19,21 @@ package org.apache.spark.util import java.util.{Properties, UUID} -import org.apache.spark.scheduler.cluster.ExecutorInfo - import scala.collection.JavaConverters._ import scala.collection.Map +import com.fasterxml.jackson.databind.ObjectMapper +import com.fasterxml.jackson.module.scala.DefaultScalaModule import org.json4s.DefaultFormats import org.json4s.JsonDSL._ import org.json4s.JsonAST._ +import org.json4s.jackson.JsonMethods._ import org.apache.spark._ import org.apache.spark.executor._ import org.apache.spark.rdd.RDDOperationScope import org.apache.spark.scheduler._ +import org.apache.spark.scheduler.cluster.ExecutorInfo import org.apache.spark.storage._ /** @@ -54,6 +56,8 @@ private[spark] object JsonProtocol { private implicit val format = DefaultFormats + private val mapper = new ObjectMapper().registerModule(DefaultScalaModule) + /** ------------------------------------------------- * * JSON serialization methods for SparkListenerEvents | * -------------------------------------------------- */ @@ -96,6 +100,7 @@ private[spark] object JsonProtocol { executorMetricsUpdateToJson(metricsUpdate) case blockUpdated: 
SparkListenerBlockUpdated => throw new MatchError(blockUpdated) // TODO(ekl) implement this + case _ => parse(mapper.writeValueAsString(event)) } } @@ -511,6 +516,8 @@ private[spark] object JsonProtocol { case `executorRemoved` => executorRemovedFromJson(json) case `logStart` => logStartFromJson(json) case `metricsUpdate` => executorMetricsUpdateFromJson(json) + case other => mapper.readValue(compact(render(json)), Utils.classForName(other)) + .asInstanceOf[SparkListenerEvent] } } diff --git a/sql/core/src/main/resources/META-INF/services/org.apache.spark.scheduler.SparkHistoryListenerFactory b/sql/core/src/main/resources/META-INF/services/org.apache.spark.scheduler.SparkHistoryListenerFactory new file mode 100644 index 000000000000..507100be9096 --- /dev/null +++ b/sql/core/src/main/resources/META-INF/services/org.apache.spark.scheduler.SparkHistoryListenerFactory @@ -0,0 +1 @@ +org.apache.spark.sql.execution.ui.SQLHistoryListenerFactory diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 46bf544fd885..1c2ac5f6f11b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -1263,6 +1263,8 @@ object SQLContext { */ @transient private val instantiatedContext = new AtomicReference[SQLContext]() + @transient private val sqlListener = new AtomicReference[SQLListener]() + /** * Get the singleton SQLContext if it exists or create a new one using the given SparkContext. * @@ -1307,6 +1309,10 @@ object SQLContext { Option(instantiatedContext.get()) } + private[sql] def clearSqlListener(): Unit = { + sqlListener.set(null) + } + /** * Changes the SQLContext that will be returned in this thread and its children when * SQLContext.getOrCreate() is called. This can be used to ensure that a given thread receives @@ -1355,9 +1361,13 @@ object SQLContext { * Create a SQLListener then add it into SparkContext, and create an SQLTab if there is SparkUI. 
*/ private[sql] def createListenerAndUI(sc: SparkContext): SQLListener = { - val listener = new SQLListener(sc.conf) - sc.addSparkListener(listener) - sc.ui.foreach(new SQLTab(listener, _)) - listener + if (sqlListener.get() == null) { + val listener = new SQLListener(sc.conf) + if (sqlListener.compareAndSet(null, listener)) { + sc.addSparkListener(listener) + sc.ui.foreach(new SQLTab(listener, _)) + } + } + sqlListener.get() } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala index 1422e15549c9..34971986261c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala @@ -21,7 +21,8 @@ import java.util.concurrent.atomic.AtomicLong import org.apache.spark.SparkContext import org.apache.spark.sql.SQLContext -import org.apache.spark.sql.execution.ui.SparkPlanGraph +import org.apache.spark.sql.execution.ui.{SparkListenerSQLExecutionStart, + SparkListenerSQLExecutionEnd} import org.apache.spark.util.Utils private[sql] object SQLExecution { @@ -45,25 +46,14 @@ private[sql] object SQLExecution { sc.setLocalProperty(EXECUTION_ID_KEY, executionId.toString) val r = try { val callSite = Utils.getCallSite() - sqlContext.listener.onExecutionStart( - executionId, - callSite.shortForm, - callSite.longForm, - queryExecution.toString, - SparkPlanGraph(queryExecution.executedPlan), - System.currentTimeMillis()) + sqlContext.sparkContext.listenerBus.post(SparkListenerSQLExecutionStart( + executionId, callSite.shortForm, callSite.longForm, queryExecution.toString, + SparkPlanInfo.fromSparkPlan(queryExecution.executedPlan), System.currentTimeMillis())) try { body } finally { - // Ideally, we need to make sure onExecutionEnd happens after onJobStart and onJobEnd. - // However, onJobStart and onJobEnd run in the listener thread. Because we cannot add new - // SQL event types to SparkListener since it's a public API, we cannot guarantee that. - // - // SQLListener should handle the case that onExecutionEnd happens before onJobEnd. - // - // The worst case is onExecutionEnd may happen before onJobStart when the listener thread - // is very busy. If so, we cannot track the jobs for the execution. It seems acceptable. - sqlContext.listener.onExecutionEnd(executionId, System.currentTimeMillis()) + sqlContext.sparkContext.listenerBus.post(SparkListenerSQLExecutionEnd( + executionId, System.currentTimeMillis())) } } finally { sc.setLocalProperty(EXECUTION_ID_KEY, null) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala new file mode 100644 index 000000000000..486ce34064e4 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.execution.metric.SQLMetricInfo +import org.apache.spark.util.Utils + +/** + * :: DeveloperApi :: + * Stores information about a SQL SparkPlan. + */ +@DeveloperApi +class SparkPlanInfo( + val nodeName: String, + val simpleString: String, + val children: Seq[SparkPlanInfo], + val metrics: Seq[SQLMetricInfo]) + +private[sql] object SparkPlanInfo { + + def fromSparkPlan(plan: SparkPlan): SparkPlanInfo = { + val metrics = plan.metrics.toSeq.map { case (key, metric) => + new SQLMetricInfo(metric.name.getOrElse(key), metric.id, + Utils.getFormattedClassName(metric.param)) + } + val children = plan.children.map(fromSparkPlan) + + new SparkPlanInfo(plan.nodeName, plan.simpleString, children, metrics) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetricInfo.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetricInfo.scala new file mode 100644 index 000000000000..2708219ad348 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetricInfo.scala @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.metric + +import org.apache.spark.annotation.DeveloperApi + +/** + * :: DeveloperApi :: + * Stores information about a SQL Metric. + */ +@DeveloperApi +class SQLMetricInfo( + val name: String, + val accumulatorId: Long, + val metricParam: String) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala index 1c253e3942e9..6c0f6f8a52dc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala @@ -104,21 +104,39 @@ private class LongSQLMetricParam(val stringValue: Seq[Long] => String, initialVa override def zero: LongSQLMetricValue = new LongSQLMetricValue(initialValue) } +private object LongSQLMetricParam extends LongSQLMetricParam(_.sum.toString, 0L) + +private object StaticsLongSQLMetricParam extends LongSQLMetricParam( + (values: Seq[Long]) => { + // This is a workaround for SPARK-11013. 
+ // We use -1 as initial value of the accumulator, if the accumulator is valid, we will update + // it at the end of task and the value will be at least 0. + val validValues = values.filter(_ >= 0) + val Seq(sum, min, med, max) = { + val metric = if (validValues.length == 0) { + Seq.fill(4)(0L) + } else { + val sorted = validValues.sorted + Seq(sorted.sum, sorted(0), sorted(validValues.length / 2), sorted(validValues.length - 1)) + } + metric.map(Utils.bytesToString) + } + s"\n$sum ($min, $med, $max)" + }, -1L) + private[sql] object SQLMetrics { private def createLongMetric( sc: SparkContext, name: String, - stringValue: Seq[Long] => String, - initialValue: Long): LongSQLMetric = { - val param = new LongSQLMetricParam(stringValue, initialValue) + param: LongSQLMetricParam): LongSQLMetric = { val acc = new LongSQLMetric(name, param) sc.cleaner.foreach(_.registerAccumulatorForCleanup(acc)) acc } def createLongMetric(sc: SparkContext, name: String): LongSQLMetric = { - createLongMetric(sc, name, _.sum.toString, 0L) + createLongMetric(sc, name, LongSQLMetricParam) } /** @@ -126,31 +144,25 @@ private[sql] object SQLMetrics { * spill size, etc. */ def createSizeMetric(sc: SparkContext, name: String): LongSQLMetric = { - val stringValue = (values: Seq[Long]) => { - // This is a workaround for SPARK-11013. - // We use -1 as initial value of the accumulator, if the accumulator is valid, we will update - // it at the end of task and the value will be at least 0. - val validValues = values.filter(_ >= 0) - val Seq(sum, min, med, max) = { - val metric = if (validValues.length == 0) { - Seq.fill(4)(0L) - } else { - val sorted = validValues.sorted - Seq(sorted.sum, sorted(0), sorted(validValues.length / 2), sorted(validValues.length - 1)) - } - metric.map(Utils.bytesToString) - } - s"\n$sum ($min, $med, $max)" - } // The final result of this metric in physical operator UI may looks like: // data size total (min, med, max): // 100GB (100MB, 1GB, 10GB) - createLongMetric(sc, s"$name total (min, med, max)", stringValue, -1L) + createLongMetric(sc, s"$name total (min, med, max)", StaticsLongSQLMetricParam) + } + + def getMetricParam(metricParamName: String): SQLMetricParam[SQLMetricValue[Any], Any] = { + val longSQLMetricParam = Utils.getFormattedClassName(LongSQLMetricParam) + val staticsSQLMetricParam = Utils.getFormattedClassName(StaticsLongSQLMetricParam) + val metricParam = metricParamName match { + case `longSQLMetricParam` => LongSQLMetricParam + case `staticsSQLMetricParam` => StaticsLongSQLMetricParam + } + metricParam.asInstanceOf[SQLMetricParam[SQLMetricValue[Any], Any]] } /** * A metric that its value will be ignored. Use this one when we need a metric parameter but don't * care about the value. 
*/ - val nullLongMetric = new LongSQLMetric("null", new LongSQLMetricParam(_.sum.toString, 0L)) + val nullLongMetric = new LongSQLMetric("null", LongSQLMetricParam) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala index e74d6fb396e1..c74ad4040699 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala @@ -19,9 +19,7 @@ package org.apache.spark.sql.execution.ui import javax.servlet.http.HttpServletRequest -import scala.xml.{Node, Unparsed} - -import org.apache.commons.lang3.StringEscapeUtils +import scala.xml.Node import org.apache.spark.Logging import org.apache.spark.ui.{UIUtils, WebUIPage} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala index 5a072de400b6..e19a1e3e5851 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala @@ -19,11 +19,34 @@ package org.apache.spark.sql.execution.ui import scala.collection.mutable -import org.apache.spark.executor.TaskMetrics +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.scheduler._ import org.apache.spark.sql.execution.SQLExecution -import org.apache.spark.sql.execution.metric.{SQLMetricParam, SQLMetricValue} +import org.apache.spark.sql.execution.SparkPlanInfo +import org.apache.spark.sql.execution.metric.{LongSQLMetricValue, SQLMetricValue, SQLMetricParam} import org.apache.spark.{JobExecutionStatus, Logging, SparkConf} +import org.apache.spark.ui.SparkUI + +@DeveloperApi +case class SparkListenerSQLExecutionStart( + executionId: Long, + description: String, + details: String, + physicalPlanDescription: String, + sparkPlanInfo: SparkPlanInfo, + time: Long) + extends SparkListenerEvent + +@DeveloperApi +case class SparkListenerSQLExecutionEnd(executionId: Long, time: Long) + extends SparkListenerEvent + +private[sql] class SQLHistoryListenerFactory extends SparkHistoryListenerFactory { + + override def createListeners(conf: SparkConf, sparkUI: SparkUI): Seq[SparkListener] = { + List(new SQLHistoryListener(conf, sparkUI)) + } +} private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Logging { @@ -118,7 +141,8 @@ private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Loggi override def onExecutorMetricsUpdate( executorMetricsUpdate: SparkListenerExecutorMetricsUpdate): Unit = synchronized { for ((taskId, stageId, stageAttemptID, metrics) <- executorMetricsUpdate.taskMetrics) { - updateTaskAccumulatorValues(taskId, stageId, stageAttemptID, metrics, finishTask = false) + updateTaskAccumulatorValues(taskId, stageId, stageAttemptID, metrics.accumulatorUpdates(), + finishTask = false) } } @@ -140,7 +164,7 @@ private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Loggi taskEnd.taskInfo.taskId, taskEnd.stageId, taskEnd.stageAttemptId, - taskEnd.taskMetrics, + taskEnd.taskMetrics.accumulatorUpdates(), finishTask = true) } @@ -148,15 +172,12 @@ private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Loggi * Update the accumulator values of a task with the latest metrics for this task. This is called * every time we receive an executor heartbeat or when a task finishes. 
*/ - private def updateTaskAccumulatorValues( + protected def updateTaskAccumulatorValues( taskId: Long, stageId: Int, stageAttemptID: Int, - metrics: TaskMetrics, + accumulatorUpdates: Map[Long, Any], finishTask: Boolean): Unit = { - if (metrics == null) { - return - } _stageIdToStageMetrics.get(stageId) match { case Some(stageMetrics) => @@ -174,9 +195,9 @@ private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Loggi case Some(taskMetrics) => if (finishTask) { taskMetrics.finished = true - taskMetrics.accumulatorUpdates = metrics.accumulatorUpdates() + taskMetrics.accumulatorUpdates = accumulatorUpdates } else if (!taskMetrics.finished) { - taskMetrics.accumulatorUpdates = metrics.accumulatorUpdates() + taskMetrics.accumulatorUpdates = accumulatorUpdates } else { // If a task is finished, we should not override with accumulator updates from // heartbeat reports @@ -185,7 +206,7 @@ private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Loggi // TODO Now just set attemptId to 0. Should fix here when we can get the attempt // id from SparkListenerExecutorMetricsUpdate stageMetrics.taskIdToMetricUpdates(taskId) = new SQLTaskMetrics( - attemptId = 0, finished = finishTask, metrics.accumulatorUpdates()) + attemptId = 0, finished = finishTask, accumulatorUpdates) } } case None => @@ -193,38 +214,40 @@ private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Loggi } } - def onExecutionStart( - executionId: Long, - description: String, - details: String, - physicalPlanDescription: String, - physicalPlanGraph: SparkPlanGraph, - time: Long): Unit = { - val sqlPlanMetrics = physicalPlanGraph.nodes.flatMap { node => - node.metrics.map(metric => metric.accumulatorId -> metric) - } - - val executionUIData = new SQLExecutionUIData(executionId, description, details, - physicalPlanDescription, physicalPlanGraph, sqlPlanMetrics.toMap, time) - synchronized { - activeExecutions(executionId) = executionUIData - _executionIdToData(executionId) = executionUIData - } - } - - def onExecutionEnd(executionId: Long, time: Long): Unit = synchronized { - _executionIdToData.get(executionId).foreach { executionUIData => - executionUIData.completionTime = Some(time) - if (!executionUIData.hasRunningJobs) { - // onExecutionEnd happens after all "onJobEnd"s - // So we should update the execution lists. - markExecutionFinished(executionId) - } else { - // There are some running jobs, onExecutionEnd happens before some "onJobEnd"s. - // Then we don't if the execution is successful, so let the last onJobEnd updates the - // execution lists. 
+ override def onOtherEvent(event: SparkListenerEvent): Unit = event match { + case SparkListenerSQLExecutionStart(executionId, description, details, + physicalPlanDescription, sparkPlanInfo, time) => + val physicalPlanGraph = SparkPlanGraph(sparkPlanInfo) + val sqlPlanMetrics = physicalPlanGraph.nodes.flatMap { node => + node.metrics.map(metric => metric.accumulatorId -> metric) + } + val executionUIData = new SQLExecutionUIData( + executionId, + description, + details, + physicalPlanDescription, + physicalPlanGraph, + sqlPlanMetrics.toMap, + time) + synchronized { + activeExecutions(executionId) = executionUIData + _executionIdToData(executionId) = executionUIData + } + case SparkListenerSQLExecutionEnd(executionId, time) => synchronized { + _executionIdToData.get(executionId).foreach { executionUIData => + executionUIData.completionTime = Some(time) + if (!executionUIData.hasRunningJobs) { + // onExecutionEnd happens after all "onJobEnd"s + // So we should update the execution lists. + markExecutionFinished(executionId) + } else { + // There are some running jobs, onExecutionEnd happens before some "onJobEnd"s. + // Then we don't if the execution is successful, so let the last onJobEnd updates the + // execution lists. + } } } + case _ => // Ignore } private def markExecutionFinished(executionId: Long): Unit = { @@ -289,6 +312,38 @@ private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Loggi } +private[spark] class SQLHistoryListener(conf: SparkConf, sparkUI: SparkUI) + extends SQLListener(conf) { + + private var sqlTabAttached = false + + override def onExecutorMetricsUpdate( + executorMetricsUpdate: SparkListenerExecutorMetricsUpdate): Unit = synchronized { + // Do nothing + } + + override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = synchronized { + updateTaskAccumulatorValues( + taskEnd.taskInfo.taskId, + taskEnd.stageId, + taskEnd.stageAttemptId, + taskEnd.taskInfo.accumulables.map { acc => + (acc.id, new LongSQLMetricValue(acc.update.getOrElse("0").toLong)) + }.toMap, + finishTask = true) + } + + override def onOtherEvent(event: SparkListenerEvent): Unit = event match { + case _: SparkListenerSQLExecutionStart => + if (!sqlTabAttached) { + new SQLTab(this, sparkUI) + sqlTabAttached = true + } + super.onOtherEvent(event) + case _ => super.onOtherEvent(event) + } +} + /** * Represent all necessary data for an execution that will be used in Web UI. 
*/ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLTab.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLTab.scala index 9c27944d42fc..4f50b2ecdc8f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLTab.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLTab.scala @@ -17,13 +17,11 @@ package org.apache.spark.sql.execution.ui -import java.util.concurrent.atomic.AtomicInteger - import org.apache.spark.Logging import org.apache.spark.ui.{SparkUI, SparkUITab} private[sql] class SQLTab(val listener: SQLListener, sparkUI: SparkUI) - extends SparkUITab(sparkUI, SQLTab.nextTabName) with Logging { + extends SparkUITab(sparkUI, "SQL") with Logging { val parent = sparkUI @@ -35,13 +33,5 @@ private[sql] class SQLTab(val listener: SQLListener, sparkUI: SparkUI) } private[sql] object SQLTab { - private val STATIC_RESOURCE_DIR = "org/apache/spark/sql/execution/ui/static" - - private val nextTabId = new AtomicInteger(0) - - private def nextTabName: String = { - val nextId = nextTabId.getAndIncrement() - if (nextId == 0) "SQL" else s"SQL$nextId" - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala index f1fce5478a3f..7af0ff09c5c6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala @@ -21,8 +21,8 @@ import java.util.concurrent.atomic.AtomicLong import scala.collection.mutable -import org.apache.spark.sql.execution.SparkPlan -import org.apache.spark.sql.execution.metric.{SQLMetricParam, SQLMetricValue} +import org.apache.spark.sql.execution.SparkPlanInfo +import org.apache.spark.sql.execution.metric.SQLMetrics /** * A graph used for storing information of an executionPlan of DataFrame. @@ -48,27 +48,27 @@ private[sql] object SparkPlanGraph { /** * Build a SparkPlanGraph from the root of a SparkPlan tree. 
*/ - def apply(plan: SparkPlan): SparkPlanGraph = { + def apply(planInfo: SparkPlanInfo): SparkPlanGraph = { val nodeIdGenerator = new AtomicLong(0) val nodes = mutable.ArrayBuffer[SparkPlanGraphNode]() val edges = mutable.ArrayBuffer[SparkPlanGraphEdge]() - buildSparkPlanGraphNode(plan, nodeIdGenerator, nodes, edges) + buildSparkPlanGraphNode(planInfo, nodeIdGenerator, nodes, edges) new SparkPlanGraph(nodes, edges) } private def buildSparkPlanGraphNode( - plan: SparkPlan, + planInfo: SparkPlanInfo, nodeIdGenerator: AtomicLong, nodes: mutable.ArrayBuffer[SparkPlanGraphNode], edges: mutable.ArrayBuffer[SparkPlanGraphEdge]): SparkPlanGraphNode = { - val metrics = plan.metrics.toSeq.map { case (key, metric) => - SQLPlanMetric(metric.name.getOrElse(key), metric.id, - metric.param.asInstanceOf[SQLMetricParam[SQLMetricValue[Any], Any]]) + val metrics = planInfo.metrics.map { metric => + SQLPlanMetric(metric.name, metric.accumulatorId, + SQLMetrics.getMetricParam(metric.metricParam)) } val node = SparkPlanGraphNode( - nodeIdGenerator.getAndIncrement(), plan.nodeName, plan.simpleString, metrics) + nodeIdGenerator.getAndIncrement(), planInfo.nodeName, planInfo.simpleString, metrics) nodes += node - val childrenNodes = plan.children.map( + val childrenNodes = planInfo.children.map( child => buildSparkPlanGraphNode(child, nodeIdGenerator, nodes, edges)) for (child <- childrenNodes) { edges += SparkPlanGraphEdge(child.id, node.id) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index 5e2b4154dd7c..ebfa1eaf3e5b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -26,6 +26,7 @@ import org.apache.xbean.asm5.Opcodes._ import org.apache.spark.SparkFunSuite import org.apache.spark.sql._ +import org.apache.spark.sql.execution.SparkPlanInfo import org.apache.spark.sql.execution.ui.SparkPlanGraph import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext @@ -82,7 +83,8 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { if (jobs.size == expectedNumOfJobs) { // If we can track all jobs, check the metric values val metricValues = sqlContext.listener.getExecutionMetrics(executionId) - val actualMetrics = SparkPlanGraph(df.queryExecution.executedPlan).nodes.filter { node => + val actualMetrics = SparkPlanGraph(SparkPlanInfo.fromSparkPlan( + df.queryExecution.executedPlan)).nodes.filter { node => expectedMetrics.contains(node.id) }.map { node => val nodeMetrics = node.metrics.map { metric => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala index c15aac775096..f93d081d0c30 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala @@ -21,10 +21,10 @@ import java.util.Properties import org.apache.spark.{SparkException, SparkContext, SparkConf, SparkFunSuite} import org.apache.spark.executor.TaskMetrics -import org.apache.spark.sql.execution.metric.LongSQLMetricValue import org.apache.spark.scheduler._ import org.apache.spark.sql.{DataFrame, SQLContext} -import org.apache.spark.sql.execution.SQLExecution +import 
org.apache.spark.sql.execution.{SparkPlanInfo, SQLExecution} +import org.apache.spark.sql.execution.metric.LongSQLMetricValue import org.apache.spark.sql.test.SharedSQLContext class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { @@ -82,7 +82,8 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { val executionId = 0 val df = createTestDataFrame val accumulatorIds = - SparkPlanGraph(df.queryExecution.executedPlan).nodes.flatMap(_.metrics.map(_.accumulatorId)) + SparkPlanGraph(SparkPlanInfo.fromSparkPlan(df.queryExecution.executedPlan)) + .nodes.flatMap(_.metrics.map(_.accumulatorId)) // Assume all accumulators are long var accumulatorValue = 0L val accumulatorUpdates = accumulatorIds.map { id => @@ -90,13 +91,13 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { (id, accumulatorValue) }.toMap - listener.onExecutionStart( + listener.onOtherEvent(SparkListenerSQLExecutionStart( executionId, "test", "test", df.queryExecution.toString, - SparkPlanGraph(df.queryExecution.executedPlan), - System.currentTimeMillis()) + SparkPlanInfo.fromSparkPlan(df.queryExecution.executedPlan), + System.currentTimeMillis())) val executionUIData = listener.executionIdToData(0) @@ -206,7 +207,8 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { time = System.currentTimeMillis(), JobSucceeded )) - listener.onExecutionEnd(executionId, System.currentTimeMillis()) + listener.onOtherEvent(SparkListenerSQLExecutionEnd( + executionId, System.currentTimeMillis())) assert(executionUIData.runningJobs.isEmpty) assert(executionUIData.succeededJobs === Seq(0)) @@ -219,19 +221,20 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { val listener = new SQLListener(sqlContext.sparkContext.conf) val executionId = 0 val df = createTestDataFrame - listener.onExecutionStart( + listener.onOtherEvent(SparkListenerSQLExecutionStart( executionId, "test", "test", df.queryExecution.toString, - SparkPlanGraph(df.queryExecution.executedPlan), - System.currentTimeMillis()) + SparkPlanInfo.fromSparkPlan(df.queryExecution.executedPlan), + System.currentTimeMillis())) listener.onJobStart(SparkListenerJobStart( jobId = 0, time = System.currentTimeMillis(), stageInfos = Nil, createProperties(executionId))) - listener.onExecutionEnd(executionId, System.currentTimeMillis()) + listener.onOtherEvent(SparkListenerSQLExecutionEnd( + executionId, System.currentTimeMillis())) listener.onJobEnd(SparkListenerJobEnd( jobId = 0, time = System.currentTimeMillis(), @@ -248,13 +251,13 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { val listener = new SQLListener(sqlContext.sparkContext.conf) val executionId = 0 val df = createTestDataFrame - listener.onExecutionStart( + listener.onOtherEvent(SparkListenerSQLExecutionStart( executionId, "test", "test", df.queryExecution.toString, - SparkPlanGraph(df.queryExecution.executedPlan), - System.currentTimeMillis()) + SparkPlanInfo.fromSparkPlan(df.queryExecution.executedPlan), + System.currentTimeMillis())) listener.onJobStart(SparkListenerJobStart( jobId = 0, time = System.currentTimeMillis(), @@ -271,7 +274,8 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { time = System.currentTimeMillis(), stageInfos = Nil, createProperties(executionId))) - listener.onExecutionEnd(executionId, System.currentTimeMillis()) + listener.onOtherEvent(SparkListenerSQLExecutionEnd( + executionId, System.currentTimeMillis())) listener.onJobEnd(SparkListenerJobEnd( jobId = 1, time = 
System.currentTimeMillis(), @@ -288,19 +292,20 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { val listener = new SQLListener(sqlContext.sparkContext.conf) val executionId = 0 val df = createTestDataFrame - listener.onExecutionStart( + listener.onOtherEvent(SparkListenerSQLExecutionStart( executionId, "test", "test", df.queryExecution.toString, - SparkPlanGraph(df.queryExecution.executedPlan), - System.currentTimeMillis()) + SparkPlanInfo.fromSparkPlan(df.queryExecution.executedPlan), + System.currentTimeMillis())) listener.onJobStart(SparkListenerJobStart( jobId = 0, time = System.currentTimeMillis(), stageInfos = Seq.empty, createProperties(executionId))) - listener.onExecutionEnd(executionId, System.currentTimeMillis()) + listener.onOtherEvent(SparkListenerSQLExecutionEnd( + executionId, System.currentTimeMillis())) listener.onJobEnd(SparkListenerJobEnd( jobId = 0, time = System.currentTimeMillis(), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala index 963d10eed62e..e7b376548787 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala @@ -42,6 +42,7 @@ trait SharedSQLContext extends SQLTestUtils { * Initialize the [[TestSQLContext]]. */ protected override def beforeAll(): Unit = { + SQLContext.clearSqlListener() if (_ctx == null) { _ctx = new TestSQLContext } From d1930ec01ab5a9d83f801f8ae8d4f15a38d98b76 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 25 Nov 2015 21:25:20 -0800 Subject: [PATCH 1560/1824] [SPARK-12003] [SQL] remove the prefix for name after expanded star Right now, the expended start will include the name of expression as prefix for column, that's not better than without expending, we should not have the prefix. Author: Davies Liu Closes #9984 from davies/expand_star. --- .../org/apache/spark/sql/catalyst/analysis/unresolved.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index 1b2a8dc4c7f1..4f89b462a6ce 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -204,7 +204,7 @@ case class UnresolvedStar(target: Option[Seq[String]]) extends Star with Unevalu case s: StructType => s.zipWithIndex.map { case (f, i) => val extract = GetStructField(attribute.get, i) - Alias(extract, target.get + "." + f.name)() + Alias(extract, f.name)() } case _ => { From 068b6438d6886ce5b4aa698383866f466d913d66 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Wed, 25 Nov 2015 23:24:33 -0800 Subject: [PATCH 1561/1824] [SPARK-11980][SPARK-10621][SQL] Fix json_tuple and add test cases for Added Python test cases for the function `isnan`, `isnull`, `nanvl` and `json_tuple`. Fixed a bug in the function `json_tuple` rxin , could you help me review my changes? Please let me know anything is missing. Thank you! Have a good Thanksgiving day! Author: gatorsmile Closes #9977 from gatorsmile/json_tuple. 
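For illustration, the same call pattern in the Scala DataFrame API, mirroring the doctest added in the diff below. This is only a sketch: the `sqlContext` value and the sample rows are assumptions, and the generated column names (c0, c1, ...) are taken from the doctest in this patch.

```scala
// Sketch only: assumes a Spark 1.6-era SQLContext named sqlContext.
import org.apache.spark.sql.functions.json_tuple

val df = sqlContext.createDataFrame(Seq(
  ("1", """{"f1": "value1", "f2": "value2"}"""),
  ("2", """{"f1": "value12"}""")
)).toDF("key", "jstring")

// json_tuple takes the JSON column plus a varargs list of field names and
// returns one generated column per field; missing fields come back as null.
df.select(df("key"), json_tuple(df("jstring"), "f1", "f2")).show()
// key=1 -> c0=value1, c1=value2
// key=2 -> c0=value12, c1=null
```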
--- python/pyspark/sql/functions.py | 44 +++++++++++++++++++++++++-------- 1 file changed, 34 insertions(+), 10 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index e3786e0fa5fb..90625949f747 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -286,14 +286,6 @@ def countDistinct(col, *cols): return Column(jc) -@since(1.4) -def monotonicallyIncreasingId(): - """ - .. note:: Deprecated in 1.6, use monotonically_increasing_id instead. - """ - return monotonically_increasing_id() - - @since(1.6) def input_file_name(): """Creates a string column for the file name of the current Spark task. @@ -305,6 +297,10 @@ def input_file_name(): @since(1.6) def isnan(col): """An expression that returns true iff the column is NaN. + + >>> df = sqlContext.createDataFrame([(1.0, float('nan')), (float('nan'), 2.0)], ("a", "b")) + >>> df.select(isnan("a").alias("r1"), isnan(df.a).alias("r2")).collect() + [Row(r1=False, r2=False), Row(r1=True, r2=True)] """ sc = SparkContext._active_spark_context return Column(sc._jvm.functions.isnan(_to_java_column(col))) @@ -313,11 +309,23 @@ def isnan(col): @since(1.6) def isnull(col): """An expression that returns true iff the column is null. + + >>> df = sqlContext.createDataFrame([(1, None), (None, 2)], ("a", "b")) + >>> df.select(isnull("a").alias("r1"), isnull(df.a).alias("r2")).collect() + [Row(r1=False, r2=False), Row(r1=True, r2=True)] """ sc = SparkContext._active_spark_context return Column(sc._jvm.functions.isnull(_to_java_column(col))) +@since(1.4) +def monotonicallyIncreasingId(): + """ + .. note:: Deprecated in 1.6, use monotonically_increasing_id instead. + """ + return monotonically_increasing_id() + + @since(1.6) def monotonically_increasing_id(): """A column that generates monotonically increasing 64-bit integers. @@ -344,6 +352,10 @@ def nanvl(col1, col2): """Returns col1 if it is not NaN, or col2 if col1 is NaN. Both inputs should be floating point columns (DoubleType or FloatType). + + >>> df = sqlContext.createDataFrame([(1.0, float('nan')), (float('nan'), 2.0)], ("a", "b")) + >>> df.select(nanvl("a", "b").alias("r1"), nanvl(df.a, df.b).alias("r2")).collect() + [Row(r1=1.0, r2=1.0), Row(r1=2.0, r2=2.0)] """ sc = SparkContext._active_spark_context return Column(sc._jvm.functions.nanvl(_to_java_column(col1), _to_java_column(col2))) @@ -1460,6 +1472,7 @@ def explode(col): return Column(jc) +@ignore_unicode_prefix @since(1.6) def get_json_object(col, path): """ @@ -1468,22 +1481,33 @@ def get_json_object(col, path): :param col: string column in json format :param path: path to the json object to extract + + >>> data = [("1", '''{"f1": "value1", "f2": "value2"}'''), ("2", '''{"f1": "value12"}''')] + >>> df = sqlContext.createDataFrame(data, ("key", "jstring")) + >>> df.select(df.key, get_json_object(df.jstring, '$.f1').alias("c0"), \ + get_json_object(df.jstring, '$.f2').alias("c1") ).collect() + [Row(key=u'1', c0=u'value1', c1=u'value2'), Row(key=u'2', c0=u'value12', c1=None)] """ sc = SparkContext._active_spark_context jc = sc._jvm.functions.get_json_object(_to_java_column(col), path) return Column(jc) +@ignore_unicode_prefix @since(1.6) -def json_tuple(col, fields): +def json_tuple(col, *fields): """Creates a new row for a json column according to the given field names. 
:param col: string column in json format :param fields: list of fields to extract + >>> data = [("1", '''{"f1": "value1", "f2": "value2"}'''), ("2", '''{"f1": "value12"}''')] + >>> df = sqlContext.createDataFrame(data, ("key", "jstring")) + >>> df.select(df.key, json_tuple(df.jstring, 'f1', 'f2')).collect() + [Row(key=u'1', c0=u'value1', c1=u'value2'), Row(key=u'2', c0=u'value12', c1=None)] """ sc = SparkContext._active_spark_context - jc = sc._jvm.functions.json_tuple(_to_java_column(col), fields) + jc = sc._jvm.functions.json_tuple(_to_java_column(col), _to_seq(sc, fields)) return Column(jc) From d3ef693325f91a1ed340c9756c81244a80398eb2 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Wed, 25 Nov 2015 23:31:21 -0800 Subject: [PATCH 1562/1824] [SPARK-11999][CORE] Fix the issue that ThreadUtils.newDaemonCachedThreadPool doesn't cache any task In the previous codes, `newDaemonCachedThreadPool` uses `SynchronousQueue`, which is wrong. `SynchronousQueue` is an empty queue that cannot cache any task. This patch uses `LinkedBlockingQueue` to fix it along with other fixes to make sure `newDaemonCachedThreadPool` can use at most `maxThreadNumber` threads, and after that, cache tasks to `LinkedBlockingQueue`. Author: Shixiong Zhu Closes #9978 from zsxwing/cached-threadpool. --- .../org/apache/spark/util/ThreadUtils.scala | 14 ++++-- .../apache/spark/util/ThreadUtilsSuite.scala | 45 +++++++++++++++++++ 2 files changed, 56 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala index 53283448c87b..f9fbe2ff858c 100644 --- a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala @@ -56,10 +56,18 @@ private[spark] object ThreadUtils { * Create a cached thread pool whose max number of threads is `maxThreadNumber`. Thread names * are formatted as prefix-ID, where ID is a unique, sequentially assigned integer. 
*/ - def newDaemonCachedThreadPool(prefix: String, maxThreadNumber: Int): ThreadPoolExecutor = { + def newDaemonCachedThreadPool( + prefix: String, maxThreadNumber: Int, keepAliveSeconds: Int = 60): ThreadPoolExecutor = { val threadFactory = namedThreadFactory(prefix) - new ThreadPoolExecutor( - 0, maxThreadNumber, 60L, TimeUnit.SECONDS, new SynchronousQueue[Runnable], threadFactory) + val threadPool = new ThreadPoolExecutor( + maxThreadNumber, // corePoolSize: the max number of threads to create before queuing the tasks + maxThreadNumber, // maximumPoolSize: because we use LinkedBlockingDeque, this one is not used + keepAliveSeconds, + TimeUnit.SECONDS, + new LinkedBlockingQueue[Runnable], + threadFactory) + threadPool.allowCoreThreadTimeOut(true) + threadPool } /** diff --git a/core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala index 620e4debf4e0..92ae03896752 100644 --- a/core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala @@ -24,6 +24,8 @@ import scala.concurrent.duration._ import scala.concurrent.{Await, Future} import scala.util.Random +import org.scalatest.concurrent.Eventually._ + import org.apache.spark.SparkFunSuite class ThreadUtilsSuite extends SparkFunSuite { @@ -59,6 +61,49 @@ class ThreadUtilsSuite extends SparkFunSuite { } } + test("newDaemonCachedThreadPool") { + val maxThreadNumber = 10 + val startThreadsLatch = new CountDownLatch(maxThreadNumber) + val latch = new CountDownLatch(1) + val cachedThreadPool = ThreadUtils.newDaemonCachedThreadPool( + "ThreadUtilsSuite-newDaemonCachedThreadPool", + maxThreadNumber, + keepAliveSeconds = 2) + try { + for (_ <- 1 to maxThreadNumber) { + cachedThreadPool.execute(new Runnable { + override def run(): Unit = { + startThreadsLatch.countDown() + latch.await(10, TimeUnit.SECONDS) + } + }) + } + startThreadsLatch.await(10, TimeUnit.SECONDS) + assert(cachedThreadPool.getActiveCount === maxThreadNumber) + assert(cachedThreadPool.getQueue.size === 0) + + // Submit a new task and it should be put into the queue since the thread number reaches the + // limitation + cachedThreadPool.execute(new Runnable { + override def run(): Unit = { + latch.await(10, TimeUnit.SECONDS) + } + }) + + assert(cachedThreadPool.getActiveCount === maxThreadNumber) + assert(cachedThreadPool.getQueue.size === 1) + + latch.countDown() + eventually(timeout(10.seconds)) { + // All threads should be stopped after keepAliveSeconds + assert(cachedThreadPool.getActiveCount === 0) + assert(cachedThreadPool.getPoolSize === 0) + } + } finally { + cachedThreadPool.shutdownNow() + } + } + test("sameThread") { val callerThreadName = Thread.currentThread().getName() val f = Future { From 27d69a0573ed55e916a464e268dcfd5ecc6ed849 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 26 Nov 2015 00:19:42 -0800 Subject: [PATCH 1563/1824] [SPARK-11973] [SQL] push filter through aggregation with alias and literals Currently, filter can't be pushed through aggregation with alias or literals, this patch fix that. After this patch, the time of TPC-DS query 4 go down to 13 seconds from 141 seconds (10x improvements). cc nongli yhuai Author: Davies Liu Closes #9959 from davies/push_filter2. 
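To make the new behavior concrete, here is a rough user-level sketch of the kind of query this rule now handles. The `sqlContext` value and the toy dataset are assumptions; the shape follows the new "push down filters with alias" test case added further down in this patch.

```scala
// Sketch only: assumes a SQLContext named sqlContext and a toy dataset.
import org.apache.spark.sql.functions.count

val df = sqlContext.createDataFrame(Seq(
  (1, "x"), (1, "y"), (2, "z"), (5, "w")
)).toDF("a", "b")

val agg = df.groupBy(df("a"))
  .agg((df("a") + 1).as("aa"), count(df("b")).as("c"))

// "aa < 3" only references the grouping column through an alias, so after this
// patch it is rewritten to "a + 1 < 3" and evaluated below the Aggregate,
// while "c === 2" depends on the aggregated count and must stay above it.
val filtered = agg.where(agg("aa") < 3 && agg("c") === 2L)
filtered.explain(true)
```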
--- .../sql/catalyst/expressions/predicates.scala | 9 ++++ .../sql/catalyst/optimizer/Optimizer.scala | 28 ++++++---- .../optimizer/FilterPushdownSuite.scala | 53 +++++++++++++++++++ 3 files changed, 79 insertions(+), 11 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 68557479a959..304b438c84ba 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -65,6 +65,15 @@ trait PredicateHelper { } } + // Substitute any known alias from a map. + protected def replaceAlias( + condition: Expression, + aliases: AttributeMap[Expression]): Expression = { + condition.transform { + case a: Attribute => aliases.getOrElse(a, a) + } + } + /** * Returns true if `expr` can be evaluated using only the output of `plan`. This method * can be used to determine when it is acceptable to move expression evaluation within a query diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index f4dba67f13b5..52f609bc158c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -640,20 +640,14 @@ object PushPredicateThroughProject extends Rule[LogicalPlan] with PredicateHelpe filter } else { // Push down the small conditions without nondeterministic expressions. - val pushedCondition = deterministic.map(replaceAlias(_, aliasMap)).reduce(And) + val pushedCondition = + deterministic.map(replaceAlias(_, aliasMap)).reduce(And) Filter(nondeterministic.reduce(And), project.copy(child = Filter(pushedCondition, grandChild))) } } } - // Substitute any attributes that are produced by the child projection, so that we safely - // eliminate it. 
- private def replaceAlias(condition: Expression, sourceAliases: AttributeMap[Expression]) = { - condition.transform { - case a: Attribute => sourceAliases.getOrElse(a, a) - } - } } /** @@ -690,12 +684,24 @@ object PushPredicateThroughAggregate extends Rule[LogicalPlan] with PredicateHel def apply(plan: LogicalPlan): LogicalPlan = plan transform { case filter @ Filter(condition, aggregate @ Aggregate(groupingExpressions, aggregateExpressions, grandChild)) => - val (pushDown, stayUp) = splitConjunctivePredicates(condition).partition { - conjunct => conjunct.references subsetOf AttributeSet(groupingExpressions) + + def hasAggregate(expression: Expression): Boolean = expression match { + case agg: AggregateExpression => true + case other => expression.children.exists(hasAggregate) + } + // Create a map of Alias for expressions that does not have AggregateExpression + val aliasMap = AttributeMap(aggregateExpressions.collect { + case a: Alias if !hasAggregate(a.child) => (a.toAttribute, a.child) + }) + + val (pushDown, stayUp) = splitConjunctivePredicates(condition).partition { conjunct => + val replaced = replaceAlias(conjunct, aliasMap) + replaced.references.subsetOf(grandChild.outputSet) && replaced.deterministic } if (pushDown.nonEmpty) { val pushDownPredicate = pushDown.reduce(And) - val withPushdown = aggregate.copy(child = Filter(pushDownPredicate, grandChild)) + val replaced = replaceAlias(pushDownPredicate, aliasMap) + val withPushdown = aggregate.copy(child = Filter(replaced, grandChild)) stayUp.reduceOption(And).map(Filter(_, withPushdown)).getOrElse(withPushdown) } else { filter diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index 0290fafe879f..0128c220baac 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -697,4 +697,57 @@ class FilterPushdownSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + + test("aggregate: push down filters with alias") { + val originalQuery = testRelation + .select('a, 'b) + .groupBy('a)(('a + 1) as 'aa, count('b) as 'c) + .where(('c === 2L || 'aa > 4) && 'aa < 3) + + val optimized = Optimize.execute(originalQuery.analyze) + + val correctAnswer = testRelation + .select('a, 'b) + .where('a + 1 < 3) + .groupBy('a)(('a + 1) as 'aa, count('b) as 'c) + .where('c === 2L || 'aa > 4) + .analyze + + comparePlans(optimized, correctAnswer) + } + + test("aggregate: push down filters with literal") { + val originalQuery = testRelation + .select('a, 'b) + .groupBy('a)('a, count('b) as 'c, "s" as 'd) + .where('c === 2L && 'd === "s") + + val optimized = Optimize.execute(originalQuery.analyze) + + val correctAnswer = testRelation + .select('a, 'b) + .where("s" === "s") + .groupBy('a)('a, count('b) as 'c, "s" as 'd) + .where('c === 2L) + .analyze + + comparePlans(optimized, correctAnswer) + } + + test("aggregate: don't push down filters which is nondeterministic") { + val originalQuery = testRelation + .select('a, 'b) + .groupBy('a)('a + Rand(10) as 'aa, count('b) as 'c, Rand(11).as("rnd")) + .where('c === 2L && 'aa + Rand(10).as("rnd") === 3 && 'rnd === 5) + + val optimized = Optimize.execute(originalQuery.analyze) + + val correctAnswer = testRelation + .select('a, 'b) + .groupBy('a)('a + Rand(10) as 'aa, count('b) as 'c, Rand(11).as("rnd")) + 
.where('c === 2L && 'aa + Rand(10).as("rnd") === 3 && 'rnd === 5) + .analyze + + comparePlans(optimized, correctAnswer) + } } From 001f0528a851ac314b390e65eb0583f89e69a949 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Thu, 26 Nov 2015 01:15:05 -0800 Subject: [PATCH 1564/1824] [SPARK-12005][SQL] Work around VerifyError in HyperLogLogPlusPlus. Just move the code around a bit; that seems to make the JVM happy. Author: Marcelo Vanzin Closes #9985 from vanzin/SPARK-12005. --- .../expressions/aggregate/HyperLogLogPlusPlus.scala | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala index 8a95c541f1e8..e1fd22e36764 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala @@ -63,11 +63,7 @@ case class HyperLogLogPlusPlus( def this(child: Expression, relativeSD: Expression) = { this( child = child, - relativeSD = relativeSD match { - case Literal(d: Double, DoubleType) => d - case _ => - throw new AnalysisException("The second argument should be a double literal.") - }, + relativeSD = HyperLogLogPlusPlus.validateDoubleLiteral(relativeSD), mutableAggBufferOffset = 0, inputAggBufferOffset = 0) } @@ -448,4 +444,11 @@ object HyperLogLogPlusPlus { Array(189083, 185696.913, 182348.774, 179035.946, 175762.762, 172526.444, 169329.754, 166166.099, 163043.269, 159958.91, 156907.912, 153906.845, 150924.199, 147996.568, 145093.457, 142239.233, 139421.475, 136632.27, 133889.588, 131174.2, 128511.619, 125868.621, 123265.385, 120721.061, 118181.769, 115709.456, 113252.446, 110840.198, 108465.099, 106126.164, 103823.469, 101556.618, 99308.004, 97124.508, 94937.803, 92833.731, 90745.061, 88677.627, 86617.47, 84650.442, 82697.833, 80769.132, 78879.629, 77014.432, 75215.626, 73384.587, 71652.482, 69895.93, 68209.301, 66553.669, 64921.981, 63310.323, 61742.115, 60205.018, 58698.658, 57190.657, 55760.865, 54331.169, 52908.167, 51550.273, 50225.254, 48922.421, 47614.533, 46362.049, 45098.569, 43926.083, 42736.03, 41593.473, 40425.26, 39316.237, 38243.651, 37170.617, 36114.609, 35084.19, 34117.233, 33206.509, 32231.505, 31318.728, 30403.404, 29540.0550000001, 28679.236, 27825.862, 26965.216, 26179.148, 25462.08, 24645.952, 23922.523, 23198.144, 22529.128, 21762.4179999999, 21134.779, 20459.117, 19840.818, 19187.04, 18636.3689999999, 17982.831, 17439.7389999999, 16874.547, 16358.2169999999, 15835.684, 15352.914, 14823.681, 14329.313, 13816.897, 13342.874, 12880.882, 12491.648, 12021.254, 11625.392, 11293.7610000001, 10813.697, 10456.209, 10099.074, 9755.39000000001, 9393.18500000006, 9047.57900000003, 8657.98499999999, 8395.85900000005, 8033, 7736.95900000003, 7430.59699999995, 7258.47699999996, 6924.58200000005, 6691.29399999999, 6357.92500000005, 6202.05700000003, 5921.19700000004, 5628.28399999999, 5404.96799999999, 5226.71100000001, 4990.75600000005, 4799.77399999998, 4622.93099999998, 4472.478, 4171.78700000001, 3957.46299999999, 3868.95200000005, 3691.14300000004, 3474.63100000005, 3341.67200000002, 3109.14000000001, 3071.97400000005, 2796.40399999998, 2756.17799999996, 2611.46999999997, 2471.93000000005, 2382.26399999997, 2209.22400000005, 2142.28399999999, 2013.96100000001, 1911.18999999994, 
1818.27099999995, 1668.47900000005, 1519.65800000005, 1469.67599999998, 1367.13800000004, 1248.52899999998, 1181.23600000003, 1022.71900000004, 1088.20700000005, 959.03600000008, 876.095999999903, 791.183999999892, 703.337000000058, 731.949999999953, 586.86400000006, 526.024999999907, 323.004999999888, 320.448000000091, 340.672999999952, 309.638999999966, 216.601999999955, 102.922999999952, 19.2399999999907, -0.114000000059605, -32.6240000000689, -89.3179999999702, -153.497999999905, -64.2970000000205, -143.695999999996, -259.497999999905, -253.017999999924, -213.948000000091, -397.590000000084, -434.006000000052, -403.475000000093, -297.958000000101, -404.317000000039, -528.898999999976, -506.621000000043, -513.205000000075, -479.351000000024, -596.139999999898, -527.016999999993, -664.681000000099, -680.306000000099, -704.050000000047, -850.486000000034, -757.43200000003, -713.308999999892) ) // scalastyle:on + + private def validateDoubleLiteral(exp: Expression): Double = exp match { + case Literal(d: Double, DoubleType) => d + case _ => + throw new AnalysisException("The second argument should be a double literal.") + } + } From bc16a67562560c732833260cbc34825f7e9dcb8f Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Thu, 26 Nov 2015 11:31:28 -0800 Subject: [PATCH 1565/1824] [SPARK-11863][SQL] Unable to resolve order by if it contains mixture of aliases and real columns this is based on https://github.com/apache/spark/pull/9844, with some bug fix and clean up. The problems is that, normal operator should be resolved based on its child, but `Sort` operator can also be resolved based on its grandchild. So we have 3 rules that can resolve `Sort`: `ResolveReferences`, `ResolveSortReferences`(if grandchild is `Project`) and `ResolveAggregateFunctions`(if grandchild is `Aggregate`). For example, `select c1 as a , c2 as b from tab group by c1, c2 order by a, c2`, we need to resolve `a` and `c2` for `Sort`. Firstly `a` will be resolved in `ResolveReferences` based on its child, and when we reach `ResolveAggregateFunctions`, we will try to resolve both `a` and `c2` based on its grandchild, but failed because `a` is not a legal aggregate expression. whoever merge this PR, please give the credit to dilipbiswal Author: Dilip Biswal Author: Wenchen Fan Closes #9961 from cloud-fan/sort. --- .../spark/sql/catalyst/analysis/Analyzer.scala | 13 ++++++++++--- .../sql/catalyst/analysis/AnalysisSuite.scala | 18 ++++++++++++++++++ 2 files changed, 28 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 47962ebe6ef8..94ffbbb2e5c6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -630,7 +630,8 @@ class Analyzer( // Try resolving the ordering as though it is in the aggregate clause. 
try { - val aliasedOrdering = sortOrder.map(o => Alias(o.child, "aggOrder")()) + val unresolvedSortOrders = sortOrder.filter(s => !s.resolved || containsAggregate(s)) + val aliasedOrdering = unresolvedSortOrders.map(o => Alias(o.child, "aggOrder")()) val aggregatedOrdering = aggregate.copy(aggregateExpressions = aliasedOrdering) val resolvedAggregate: Aggregate = execute(aggregatedOrdering).asInstanceOf[Aggregate] val resolvedAliasedOrdering: Seq[Alias] = @@ -663,13 +664,19 @@ class Analyzer( } } + val sortOrdersMap = unresolvedSortOrders + .map(new TreeNodeRef(_)) + .zip(evaluatedOrderings) + .toMap + val finalSortOrders = sortOrder.map(s => sortOrdersMap.getOrElse(new TreeNodeRef(s), s)) + // Since we don't rely on sort.resolved as the stop condition for this rule, // we need to check this and prevent applying this rule multiple times - if (sortOrder == evaluatedOrderings) { + if (sortOrder == finalSortOrders) { sort } else { Project(aggregate.output, - Sort(evaluatedOrderings, global, + Sort(finalSortOrders, global, aggregate.copy(aggregateExpressions = originalAggExprs ++ needsPushDown))) } } catch { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index e05106995188..aeeca802d8bb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -220,6 +220,24 @@ class AnalysisSuite extends AnalysisTest { // checkUDF(udf4, expected4) } + test("SPARK-11863 mixture of aliases and real columns in order by clause - tpcds 19,55,71") { + val a = testRelation2.output(0) + val c = testRelation2.output(2) + val alias1 = a.as("a1") + val alias2 = c.as("a2") + val alias3 = count(a).as("a3") + + val plan = testRelation2 + .groupBy('a, 'c)('a.as("a1"), 'c.as("a2"), count('a).as("a3")) + .orderBy('a1.asc, 'c.asc) + + val expected = testRelation2 + .groupBy(a, c)(alias1, alias2, alias3) + .orderBy(alias1.toAttribute.asc, alias2.toAttribute.asc) + .select(alias1.toAttribute, alias2.toAttribute, alias3.toAttribute) + checkAnalysis(plan, expected) + } + test("analyzer should replace current_timestamp with literals") { val in = Project(Seq(Alias(CurrentTimestamp(), "a")(), Alias(CurrentTimestamp(), "b")()), LocalRelation()) From ad76562390b81207f8f32491c0bd8ad0e020141f Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Thu, 26 Nov 2015 16:20:08 -0800 Subject: [PATCH 1566/1824] [SPARK-11998][SQL][TEST-HADOOP2.0] When downloading Hadoop artifacts from maven, we need to try to download the version that is used by Spark If we need to download Hive/Hadoop artifacts, try to download a Hadoop that matches the Hadoop used by Spark. If the Hadoop artifact cannot be resolved (e.g. Hadoop version is a vendor specific version like 2.0.0-cdh4.1.1), we will use Hadoop 2.4.0 (we used to hard code this version as the hadoop that we will download from maven) and we will not share Hadoop classes. I tested this match in my laptop with the following confs (these confs are used by our builds). All tests are good. 
``` build/sbt -Phadoop-1 -Dhadoop.version=1.2.1 -Pkinesis-asl -Phive-thriftserver -Phive build/sbt -Phadoop-1 -Dhadoop.version=2.0.0-mr1-cdh4.1.1 -Pkinesis-asl -Phive-thriftserver -Phive build/sbt -Pyarn -Phadoop-2.2 -Pkinesis-asl -Phive-thriftserver -Phive build/sbt -Pyarn -Phadoop-2.3 -Dhadoop.version=2.3.0 -Pkinesis-asl -Phive-thriftserver -Phive ``` Author: Yin Huai Closes #9979 from yhuai/versionsSuite. --- .../apache/spark/sql/hive/HiveContext.scala | 4 +- .../hive/client/IsolatedClientLoader.scala | 62 +++++++++++++++---- .../spark/sql/hive/client/VersionsSuite.scala | 23 +++++-- 3 files changed, 72 insertions(+), 17 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 8a4264194ae8..e83941c2ecf6 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -35,6 +35,7 @@ import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.hadoop.hive.ql.metadata.Table import org.apache.hadoop.hive.ql.parse.VariableSubstitution import org.apache.hadoop.hive.serde2.io.{DateWritable, TimestampWritable} +import org.apache.hadoop.util.VersionInfo import org.apache.spark.api.java.JavaSparkContext import org.apache.spark.sql.SQLConf.SQLConfEntry @@ -288,7 +289,8 @@ class HiveContext private[hive]( logInfo( s"Initializing HiveMetastoreConnection version $hiveMetastoreVersion using maven.") IsolatedClientLoader.forVersion( - version = hiveMetastoreVersion, + hiveMetastoreVersion = hiveMetastoreVersion, + hadoopVersion = VersionInfo.getVersion, config = allConfig, barrierPrefixes = hiveMetastoreBarrierPrefixes, sharedPrefixes = hiveMetastoreSharedPrefixes) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala index e041e0d8e5ae..010051d255fd 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala @@ -34,23 +34,51 @@ import org.apache.spark.sql.hive.HiveContext import org.apache.spark.util.{MutableURLClassLoader, Utils} /** Factory for `IsolatedClientLoader` with specific versions of hive. */ -private[hive] object IsolatedClientLoader { +private[hive] object IsolatedClientLoader extends Logging { /** * Creates isolated Hive client loaders by downloading the requested version from maven. */ def forVersion( - version: String, + hiveMetastoreVersion: String, + hadoopVersion: String, config: Map[String, String] = Map.empty, ivyPath: Option[String] = None, sharedPrefixes: Seq[String] = Seq.empty, barrierPrefixes: Seq[String] = Seq.empty): IsolatedClientLoader = synchronized { - val resolvedVersion = hiveVersion(version) - val files = resolvedVersions.getOrElseUpdate(resolvedVersion, - downloadVersion(resolvedVersion, ivyPath)) + val resolvedVersion = hiveVersion(hiveMetastoreVersion) + // We will first try to share Hadoop classes. If we cannot resolve the Hadoop artifact + // with the given version, we will use Hadoop 2.4.0 and then will not share Hadoop classes. 
+ var sharesHadoopClasses = true + val files = if (resolvedVersions.contains((resolvedVersion, hadoopVersion))) { + resolvedVersions((resolvedVersion, hadoopVersion)) + } else { + val (downloadedFiles, actualHadoopVersion) = + try { + (downloadVersion(resolvedVersion, hadoopVersion, ivyPath), hadoopVersion) + } catch { + case e: RuntimeException if e.getMessage.contains("hadoop") => + // If the error message contains hadoop, it is probably because the hadoop + // version cannot be resolved (e.g. it is a vendor specific version like + // 2.0.0-cdh4.1.1). If it is the case, we will try just + // "org.apache.hadoop:hadoop-client:2.4.0". "org.apache.hadoop:hadoop-client:2.4.0" + // is used just because we used to hard code it as the hadoop artifact to download. + logWarning(s"Failed to resolve Hadoop artifacts for the version ${hadoopVersion}. " + + s"We will change the hadoop version from ${hadoopVersion} to 2.4.0 and try again. " + + "Hadoop classes will not be shared between Spark and Hive metastore client. " + + "It is recommended to set jars used by Hive metastore client through " + + "spark.sql.hive.metastore.jars in the production environment.") + sharesHadoopClasses = false + (downloadVersion(resolvedVersion, "2.4.0", ivyPath), "2.4.0") + } + resolvedVersions.put((resolvedVersion, actualHadoopVersion), downloadedFiles) + resolvedVersions((resolvedVersion, actualHadoopVersion)) + } + new IsolatedClientLoader( - version = hiveVersion(version), + version = hiveVersion(hiveMetastoreVersion), execJars = files, config = config, + sharesHadoopClasses = sharesHadoopClasses, sharedPrefixes = sharedPrefixes, barrierPrefixes = barrierPrefixes) } @@ -64,12 +92,15 @@ private[hive] object IsolatedClientLoader { case "1.2" | "1.2.0" | "1.2.1" => hive.v1_2 } - private def downloadVersion(version: HiveVersion, ivyPath: Option[String]): Seq[URL] = { + private def downloadVersion( + version: HiveVersion, + hadoopVersion: String, + ivyPath: Option[String]): Seq[URL] = { val hiveArtifacts = version.extraDeps ++ Seq("hive-metastore", "hive-exec", "hive-common", "hive-serde") .map(a => s"org.apache.hive:$a:${version.fullVersion}") ++ Seq("com.google.guava:guava:14.0.1", - "org.apache.hadoop:hadoop-client:2.4.0") + s"org.apache.hadoop:hadoop-client:$hadoopVersion") val classpath = quietly { SparkSubmitUtils.resolveMavenCoordinates( @@ -86,7 +117,10 @@ private[hive] object IsolatedClientLoader { tempDir.listFiles().map(_.toURI.toURL) } - private def resolvedVersions = new scala.collection.mutable.HashMap[HiveVersion, Seq[URL]] + // A map from a given pair of HiveVersion and Hadoop version to jar files. + // It is only used by forVersion. + private val resolvedVersions = + new scala.collection.mutable.HashMap[(HiveVersion, String), Seq[URL]] } /** @@ -106,6 +140,7 @@ private[hive] object IsolatedClientLoader { * @param config A set of options that will be added to the HiveConf of the constructed client. * @param isolationOn When true, custom versions of barrier classes will be constructed. Must be * true unless loading the version of hive that is on Sparks classloader. + * @param sharesHadoopClasses When true, we will share Hadoop classes between Spark and * @param rootClassLoader The system root classloader. Must not know about Hive classes. * @param baseClassLoader The spark classloader that is used to load shared classes. 
*/ @@ -114,6 +149,7 @@ private[hive] class IsolatedClientLoader( val execJars: Seq[URL] = Seq.empty, val config: Map[String, String] = Map.empty, val isolationOn: Boolean = true, + val sharesHadoopClasses: Boolean = true, val rootClassLoader: ClassLoader = ClassLoader.getSystemClassLoader.getParent.getParent, val baseClassLoader: ClassLoader = Thread.currentThread().getContextClassLoader, val sharedPrefixes: Seq[String] = Seq.empty, @@ -126,16 +162,20 @@ private[hive] class IsolatedClientLoader( /** All jars used by the hive specific classloader. */ protected def allJars = execJars.toArray - protected def isSharedClass(name: String): Boolean = + protected def isSharedClass(name: String): Boolean = { + val isHadoopClass = + name.startsWith("org.apache.hadoop.") && !name.startsWith("org.apache.hadoop.hive.") + name.contains("slf4j") || name.contains("log4j") || name.startsWith("org.apache.spark.") || - (name.startsWith("org.apache.hadoop.") && !name.startsWith("org.apache.hadoop.hive.")) || + (sharesHadoopClasses && isHadoopClass) || name.startsWith("scala.") || (name.startsWith("com.google") && !name.startsWith("com.google.cloud")) || name.startsWith("java.lang.") || name.startsWith("java.net") || sharedPrefixes.exists(name.startsWith) + } /** True if `name` refers to a spark class that must see specific version of Hive. */ protected def isBarrierClass(name: String): Boolean = diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala index 7bc13bc60d30..502b240f3650 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.hive.client import java.io.File +import org.apache.hadoop.util.VersionInfo + import org.apache.spark.sql.hive.HiveContext import org.apache.spark.{Logging, SparkFunSuite} import org.apache.spark.sql.catalyst.expressions.{NamedExpression, Literal, AttributeReference, EqualTo} @@ -53,9 +55,11 @@ class VersionsSuite extends SparkFunSuite with Logging { } test("success sanity check") { - val badClient = IsolatedClientLoader.forVersion(HiveContext.hiveExecutionVersion, - buildConf(), - ivyPath).createClient() + val badClient = IsolatedClientLoader.forVersion( + hiveMetastoreVersion = HiveContext.hiveExecutionVersion, + hadoopVersion = VersionInfo.getVersion, + config = buildConf(), + ivyPath = ivyPath).createClient() val db = new HiveDatabase("default", "") badClient.createDatabase(db) } @@ -85,7 +89,11 @@ class VersionsSuite extends SparkFunSuite with Logging { ignore("failure sanity check") { val e = intercept[Throwable] { val badClient = quietly { - IsolatedClientLoader.forVersion("13", buildConf(), ivyPath).createClient() + IsolatedClientLoader.forVersion( + hiveMetastoreVersion = "13", + hadoopVersion = VersionInfo.getVersion, + config = buildConf(), + ivyPath = ivyPath).createClient() } } assert(getNestedMessages(e) contains "Unknown column 'A0.OWNER_NAME' in 'field list'") @@ -99,7 +107,12 @@ class VersionsSuite extends SparkFunSuite with Logging { test(s"$version: create client") { client = null System.gc() // Hack to avoid SEGV on some JVM versions. 
- client = IsolatedClientLoader.forVersion(version, buildConf(), ivyPath).createClient() + client = + IsolatedClientLoader.forVersion( + hiveMetastoreVersion = version, + hadoopVersion = VersionInfo.getVersion, + config = buildConf(), + ivyPath = ivyPath).createClient() } test(s"$version: createDatabase") { From de28e4d4deca385b7c40b3a6a1efcd6e2fec2f9b Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 26 Nov 2015 18:47:54 -0800 Subject: [PATCH 1567/1824] [SPARK-11973][SQL] Improve optimizer code readability. This is a followup for https://github.com/apache/spark/pull/9959. I added more documentation and rewrote some monadic code into simpler ifs. Author: Reynold Xin Closes #9995 from rxin/SPARK-11973. --- .../sql/catalyst/optimizer/Optimizer.scala | 50 +++++++++---------- .../optimizer/FilterPushdownSuite.scala | 2 +- 2 files changed, 26 insertions(+), 26 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 52f609bc158c..2901d8f2efdd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -59,7 +59,7 @@ object DefaultOptimizer extends Optimizer { ConstantFolding, LikeSimplification, BooleanSimplification, - RemoveDispensable, + RemoveDispensableExpressions, SimplifyFilters, SimplifyCasts, SimplifyCaseConversionExpressions) :: @@ -660,14 +660,14 @@ object PushPredicateThroughGenerate extends Rule[LogicalPlan] with PredicateHelp case filter @ Filter(condition, g: Generate) => // Predicates that reference attributes produced by the `Generate` operator cannot // be pushed below the operator. - val (pushDown, stayUp) = splitConjunctivePredicates(condition).partition { - conjunct => conjunct.references subsetOf g.child.outputSet + val (pushDown, stayUp) = splitConjunctivePredicates(condition).partition { cond => + cond.references subsetOf g.child.outputSet } if (pushDown.nonEmpty) { val pushDownPredicate = pushDown.reduce(And) - val withPushdown = Generate(g.generator, join = g.join, outer = g.outer, + val newGenerate = Generate(g.generator, join = g.join, outer = g.outer, g.qualifier, g.generatorOutput, Filter(pushDownPredicate, g.child)) - stayUp.reduceOption(And).map(Filter(_, withPushdown)).getOrElse(withPushdown) + if (stayUp.isEmpty) newGenerate else Filter(stayUp.reduce(And), newGenerate) } else { filter } @@ -675,34 +675,34 @@ object PushPredicateThroughGenerate extends Rule[LogicalPlan] with PredicateHelp } /** - * Push [[Filter]] operators through [[Aggregate]] operators. Parts of the predicate that reference - * attributes which are subset of group by attribute set of [[Aggregate]] will be pushed beneath, - * and the rest should remain above. + * Push [[Filter]] operators through [[Aggregate]] operators, iff the filters reference only + * non-aggregate attributes (typically literals or grouping expressions). 
*/ object PushPredicateThroughAggregate extends Rule[LogicalPlan] with PredicateHelper { def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case filter @ Filter(condition, - aggregate @ Aggregate(groupingExpressions, aggregateExpressions, grandChild)) => - - def hasAggregate(expression: Expression): Boolean = expression match { - case agg: AggregateExpression => true - case other => expression.children.exists(hasAggregate) - } - // Create a map of Alias for expressions that does not have AggregateExpression - val aliasMap = AttributeMap(aggregateExpressions.collect { - case a: Alias if !hasAggregate(a.child) => (a.toAttribute, a.child) + case filter @ Filter(condition, aggregate: Aggregate) => + // Find all the aliased expressions in the aggregate list that don't include any actual + // AggregateExpression, and create a map from the alias to the expression + val aliasMap = AttributeMap(aggregate.aggregateExpressions.collect { + case a: Alias if a.child.find(_.isInstanceOf[AggregateExpression]).isEmpty => + (a.toAttribute, a.child) }) - val (pushDown, stayUp) = splitConjunctivePredicates(condition).partition { conjunct => - val replaced = replaceAlias(conjunct, aliasMap) - replaced.references.subsetOf(grandChild.outputSet) && replaced.deterministic + // For each filter, expand the alias and check if the filter can be evaluated using + // attributes produced by the aggregate operator's child operator. + val (pushDown, stayUp) = splitConjunctivePredicates(condition).partition { cond => + val replaced = replaceAlias(cond, aliasMap) + replaced.references.subsetOf(aggregate.child.outputSet) && replaced.deterministic } + if (pushDown.nonEmpty) { val pushDownPredicate = pushDown.reduce(And) val replaced = replaceAlias(pushDownPredicate, aliasMap) - val withPushdown = aggregate.copy(child = Filter(replaced, grandChild)) - stayUp.reduceOption(And).map(Filter(_, withPushdown)).getOrElse(withPushdown) + val newAggregate = aggregate.copy(child = Filter(replaced, aggregate.child)) + // If there is no more filter to stay up, just eliminate the filter. + // Otherwise, create "Filter(stayUp) <- Aggregate <- Filter(pushDownPredicate)". + if (stayUp.isEmpty) newAggregate else Filter(stayUp.reduce(And), newAggregate) } else { filter } @@ -714,7 +714,7 @@ object PushPredicateThroughAggregate extends Rule[LogicalPlan] with PredicateHel * evaluated using only the attributes of the left or right side of a join. Other * [[Filter]] conditions are moved into the `condition` of the [[Join]]. * - * And also Pushes down the join filter, where the `condition` can be evaluated using only the + * And also pushes down the join filter, where the `condition` can be evaluated using only the * attributes of the left or right side of sub query when applicable. * * Check https://cwiki.apache.org/confluence/display/Hive/OuterJoinBehavior for more details @@ -821,7 +821,7 @@ object SimplifyCasts extends Rule[LogicalPlan] { /** * Removes nodes that are not necessary. 
*/ -object RemoveDispensable extends Rule[LogicalPlan] { +object RemoveDispensableExpressions extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { case UnaryPositive(child) => child case PromotePrecision(child) => child diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index 0128c220baac..fba4c5ca77d6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -734,7 +734,7 @@ class FilterPushdownSuite extends PlanTest { comparePlans(optimized, correctAnswer) } - test("aggregate: don't push down filters which is nondeterministic") { + test("aggregate: don't push down filters that are nondeterministic") { val originalQuery = testRelation .select('a, 'b) .groupBy('a)('a + Rand(10) as 'aa, count('b) as 'c, Rand(11).as("rnd")) From 4376b5bea8171e4e73b3dbabbfdf84fa1afd140b Mon Sep 17 00:00:00 2001 From: muxator Date: Thu, 26 Nov 2015 18:52:20 -0800 Subject: [PATCH 1568/1824] doc typo: "classificaion" -> "classification" Author: muxator Closes #10008 from muxator/patch-1. --- docs/mllib-linear-methods.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/mllib-linear-methods.md b/docs/mllib-linear-methods.md index 0c76e6e99946..132f8c354aa9 100644 --- a/docs/mllib-linear-methods.md +++ b/docs/mllib-linear-methods.md @@ -122,7 +122,7 @@ Under the hood, linear methods use convex optimization methods to optimize the o [Classification](http://en.wikipedia.org/wiki/Statistical_classification) aims to divide items into categories. The most common classification type is -[binary classificaion](http://en.wikipedia.org/wiki/Binary_classification), where there are two +[binary classification](http://en.wikipedia.org/wiki/Binary_classification), where there are two categories, usually named positive and negative. If there are more than two categories, it is called [multiclass classification](http://en.wikipedia.org/wiki/Multiclass_classification). From 0c1e72e7f79231e537299b57a1ab7cd843171923 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Thu, 26 Nov 2015 18:56:22 -0800 Subject: [PATCH 1569/1824] [SPARK-11996][CORE] Make the executor thread dump work again In the previous implementation, the driver needs to know the executor listening address to send the thread dump request. However, in Netty RPC, the executor doesn't listen to any port, so the executor thread dump feature is broken. This patch makes the driver use the endpointRef stored in BlockManagerMasterEndpoint to send the thread dump request to fix it. Author: Shixiong Zhu Closes #9976 from zsxwing/executor-thread-dump. 
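For readers skimming the diff, the essence of the change is the pair of hunks condensed below (taken from this patch's SparkContext and BlockManagerSlaveEndpoint changes; the surrounding class bodies are omitted, so treat this as an illustrative sketch rather than a standalone program):

```scala
// Driver side (SparkContext.getExecutorThreadDump): instead of resolving the
// executor's host/port and a dedicated ExecutorEndpoint, reuse the
// BlockManagerSlaveEndpoint reference that BlockManagerMasterEndpoint already tracks.
val endpointRef = env.blockManager.master.getExecutorEndpointRef(executorId).get
Some(endpointRef.askWithRetry[Array[ThreadStackTrace]](TriggerThreadDump))

// Executor side (BlockManagerSlaveEndpoint.receiveAndReply): answer the new
// ToBlockManagerSlave message by replying with a thread dump.
case TriggerThreadDump =>
  context.reply(Utils.getThreadDump())
```

Because each executor's BlockManager already registers a slave endpoint with the master, this path reaches the executor even though, under Netty RPC, executors no longer listen on a port of their own, which was the original problem described above.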
--- .../scala/org/apache/spark/SparkContext.scala | 10 ++--- .../org/apache/spark/executor/Executor.scala | 5 --- .../spark/executor/ExecutorEndpoint.scala | 43 ------------------- .../spark/storage/BlockManagerMaster.scala | 4 +- .../storage/BlockManagerMasterEndpoint.scala | 12 +++--- .../spark/storage/BlockManagerMessages.scala | 7 ++- .../storage/BlockManagerSlaveEndpoint.scala | 7 ++- project/MimaExcludes.scala | 8 ++++ 8 files changed, 29 insertions(+), 67 deletions(-) delete mode 100644 core/src/main/scala/org/apache/spark/executor/ExecutorEndpoint.scala diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 2c10779f2b89..b030d3c71dc2 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -48,20 +48,20 @@ import org.apache.mesos.MesosNativeLibrary import org.apache.spark.annotation.DeveloperApi import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.{LocalSparkCluster, SparkHadoopUtil} -import org.apache.spark.executor.{ExecutorEndpoint, TriggerThreadDump} import org.apache.spark.input.{StreamInputFormat, PortableDataStream, WholeTextFileInputFormat, FixedLengthBinaryInputFormat} import org.apache.spark.io.CompressionCodec import org.apache.spark.metrics.MetricsSystem import org.apache.spark.partial.{ApproximateEvaluator, PartialResult} import org.apache.spark.rdd._ -import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef} +import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.scheduler._ import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend, SparkDeploySchedulerBackend, SimrSchedulerBackend} import org.apache.spark.scheduler.cluster.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend} import org.apache.spark.scheduler.local.LocalBackend import org.apache.spark.storage._ +import org.apache.spark.storage.BlockManagerMessages.TriggerThreadDump import org.apache.spark.ui.{SparkUI, ConsoleProgressBar} import org.apache.spark.ui.jobs.JobProgressListener import org.apache.spark.util._ @@ -619,11 +619,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli if (executorId == SparkContext.DRIVER_IDENTIFIER) { Some(Utils.getThreadDump()) } else { - val (host, port) = env.blockManager.master.getRpcHostPortForExecutor(executorId).get - val endpointRef = env.rpcEnv.setupEndpointRef( - SparkEnv.executorActorSystemName, - RpcAddress(host, port), - ExecutorEndpoint.EXECUTOR_ENDPOINT_NAME) + val endpointRef = env.blockManager.master.getExecutorEndpointRef(executorId).get Some(endpointRef.askWithRetry[Array[ThreadStackTrace]](TriggerThreadDump)) } } catch { diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 9e88d488c037..6154f06e3ac1 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -85,10 +85,6 @@ private[spark] class Executor( env.blockManager.initialize(conf.getAppId) } - // Create an RpcEndpoint for receiving RPCs from the driver - private val executorEndpoint = env.rpcEnv.setupEndpoint( - ExecutorEndpoint.EXECUTOR_ENDPOINT_NAME, new ExecutorEndpoint(env.rpcEnv, executorId)) - // Whether to load classes in user jars before those in Spark jars private val userClassPathFirst = conf.getBoolean("spark.executor.userClassPathFirst", false) @@ -136,7 +132,6 @@ private[spark] class 
Executor( def stop(): Unit = { env.metricsSystem.report() - env.rpcEnv.stop(executorEndpoint) heartbeater.shutdown() heartbeater.awaitTermination(10, TimeUnit.SECONDS) threadPool.shutdown() diff --git a/core/src/main/scala/org/apache/spark/executor/ExecutorEndpoint.scala b/core/src/main/scala/org/apache/spark/executor/ExecutorEndpoint.scala deleted file mode 100644 index cf362f846473..000000000000 --- a/core/src/main/scala/org/apache/spark/executor/ExecutorEndpoint.scala +++ /dev/null @@ -1,43 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.executor - -import org.apache.spark.rpc.{RpcEnv, RpcCallContext, RpcEndpoint} -import org.apache.spark.util.Utils - -/** - * Driver -> Executor message to trigger a thread dump. - */ -private[spark] case object TriggerThreadDump - -/** - * [[RpcEndpoint]] that runs inside of executors to enable driver -> executor RPC. - */ -private[spark] -class ExecutorEndpoint(override val rpcEnv: RpcEnv, executorId: String) extends RpcEndpoint { - - override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { - case TriggerThreadDump => - context.reply(Utils.getThreadDump()) - } - -} - -object ExecutorEndpoint { - val EXECUTOR_ENDPOINT_NAME = "ExecutorEndpoint" -} diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala index f45bff34d4db..440c4c18aadd 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala @@ -87,8 +87,8 @@ class BlockManagerMaster( driverEndpoint.askWithRetry[Seq[BlockManagerId]](GetPeers(blockManagerId)) } - def getRpcHostPortForExecutor(executorId: String): Option[(String, Int)] = { - driverEndpoint.askWithRetry[Option[(String, Int)]](GetRpcHostPortForExecutor(executorId)) + def getExecutorEndpointRef(executorId: String): Option[RpcEndpointRef] = { + driverEndpoint.askWithRetry[Option[RpcEndpointRef]](GetExecutorEndpointRef(executorId)) } /** diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala index 7db6035553ae..41892b4ffce5 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala @@ -19,7 +19,6 @@ package org.apache.spark.storage import java.util.{HashMap => JHashMap} -import scala.collection.immutable.HashSet import scala.collection.mutable import scala.collection.JavaConverters._ import scala.concurrent.{ExecutionContext, Future} @@ -75,8 +74,8 @@ class BlockManagerMasterEndpoint( case 
GetPeers(blockManagerId) => context.reply(getPeers(blockManagerId)) - case GetRpcHostPortForExecutor(executorId) => - context.reply(getRpcHostPortForExecutor(executorId)) + case GetExecutorEndpointRef(executorId) => + context.reply(getExecutorEndpointRef(executorId)) case GetMemoryStatus => context.reply(memoryStatus) @@ -388,15 +387,14 @@ class BlockManagerMasterEndpoint( } /** - * Returns the hostname and port of an executor, based on the [[RpcEnv]] address of its - * [[BlockManagerSlaveEndpoint]]. + * Returns an [[RpcEndpointRef]] of the [[BlockManagerSlaveEndpoint]] for sending RPC messages. */ - private def getRpcHostPortForExecutor(executorId: String): Option[(String, Int)] = { + private def getExecutorEndpointRef(executorId: String): Option[RpcEndpointRef] = { for ( blockManagerId <- blockManagerIdByExecutor.get(executorId); info <- blockManagerInfo.get(blockManagerId) ) yield { - (info.slaveEndpoint.address.host, info.slaveEndpoint.address.port) + info.slaveEndpoint } } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala index 376e9eb48843..f392a4a0cd9b 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala @@ -42,6 +42,11 @@ private[spark] object BlockManagerMessages { case class RemoveBroadcast(broadcastId: Long, removeFromDriver: Boolean = true) extends ToBlockManagerSlave + /** + * Driver -> Executor message to trigger a thread dump. + */ + case object TriggerThreadDump extends ToBlockManagerSlave + ////////////////////////////////////////////////////////////////////////////////// // Messages from slaves to the master. ////////////////////////////////////////////////////////////////////////////////// @@ -90,7 +95,7 @@ private[spark] object BlockManagerMessages { case class GetPeers(blockManagerId: BlockManagerId) extends ToBlockManagerMaster - case class GetRpcHostPortForExecutor(executorId: String) extends ToBlockManagerMaster + case class GetExecutorEndpointRef(executorId: String) extends ToBlockManagerMaster case class RemoveExecutor(execId: String) extends ToBlockManagerMaster diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala index e749631bf6f1..9eca902f7454 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala @@ -19,10 +19,10 @@ package org.apache.spark.storage import scala.concurrent.{ExecutionContext, Future} -import org.apache.spark.rpc.{ThreadSafeRpcEndpoint, RpcEnv, RpcCallContext, RpcEndpoint} -import org.apache.spark.util.ThreadUtils import org.apache.spark.{Logging, MapOutputTracker, SparkEnv} +import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint} import org.apache.spark.storage.BlockManagerMessages._ +import org.apache.spark.util.{ThreadUtils, Utils} /** * An RpcEndpoint to take commands from the master to execute options. 
For example, @@ -70,6 +70,9 @@ class BlockManagerSlaveEndpoint( case GetMatchingBlockIds(filter, _) => context.reply(blockManager.getMatchingBlockIds(filter)) + + case TriggerThreadDump => + context.reply(Utils.getThreadDump()) } private def doAsync[T](actionMessage: String, context: RpcCallContext)(body: => T) { diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 54a9ad956d11..566bfe8efb7a 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -147,6 +147,14 @@ object MimaExcludes { // SPARK-4557 Changed foreachRDD to use VoidFunction ProblemFilters.exclude[MissingMethodProblem]( "org.apache.spark.streaming.api.java.JavaDStreamLike.foreachRDD") + ) ++ Seq( + // SPARK-11996 Make the executor thread dump work again + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.executor.ExecutorEndpoint"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.executor.ExecutorEndpoint$"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.storage.BlockManagerMessages$GetRpcHostPortForExecutor"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.storage.BlockManagerMessages$GetRpcHostPortForExecutor$") ) case v if v.startsWith("1.5") => Seq( From 6f6bb0e893c8370cbab4d63a56d74e00cb7f3cf6 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Thu, 26 Nov 2015 19:00:36 -0800 Subject: [PATCH 1570/1824] [SPARK-12011][SQL] Stddev/Variance etc should support columnName as arguments Spark SQL aggregate function: ```Java stddev stddev_pop stddev_samp variance var_pop var_samp skewness kurtosis collect_list collect_set ``` should support ```columnName``` as arguments like other aggregate function(max/min/count/sum). Author: Yanbo Liang Closes #9994 from yanboliang/SPARK-12011. --- .../org/apache/spark/sql/functions.scala | 86 +++++++++++++++++++ .../spark/sql/DataFrameAggregateSuite.scala | 3 + 2 files changed, 89 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 276c5dfc8b06..e79defbbbdee 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -214,6 +214,16 @@ object functions extends LegacyFunctions { */ def collect_list(e: Column): Column = callUDF("collect_list", e) + /** + * Aggregate function: returns a list of objects with duplicates. + * + * For now this is an alias for the collect_list Hive UDAF. + * + * @group agg_funcs + * @since 1.6.0 + */ + def collect_list(columnName: String): Column = collect_list(Column(columnName)) + /** * Aggregate function: returns a set of objects with duplicate elements eliminated. * @@ -224,6 +234,16 @@ object functions extends LegacyFunctions { */ def collect_set(e: Column): Column = callUDF("collect_set", e) + /** + * Aggregate function: returns a set of objects with duplicate elements eliminated. + * + * For now this is an alias for the collect_set Hive UDAF. + * + * @group agg_funcs + * @since 1.6.0 + */ + def collect_set(columnName: String): Column = collect_set(Column(columnName)) + /** * Aggregate function: returns the Pearson Correlation Coefficient for two columns. * @@ -312,6 +332,14 @@ object functions extends LegacyFunctions { */ def kurtosis(e: Column): Column = withAggregateFunction { Kurtosis(e.expr) } + /** + * Aggregate function: returns the kurtosis of the values in a group. 
+ * + * @group agg_funcs + * @since 1.6.0 + */ + def kurtosis(columnName: String): Column = kurtosis(Column(columnName)) + /** * Aggregate function: returns the last value in a group. * @@ -386,6 +414,14 @@ object functions extends LegacyFunctions { */ def skewness(e: Column): Column = withAggregateFunction { Skewness(e.expr) } + /** + * Aggregate function: returns the skewness of the values in a group. + * + * @group agg_funcs + * @since 1.6.0 + */ + def skewness(columnName: String): Column = skewness(Column(columnName)) + /** * Aggregate function: alias for [[stddev_samp]]. * @@ -394,6 +430,14 @@ object functions extends LegacyFunctions { */ def stddev(e: Column): Column = withAggregateFunction { StddevSamp(e.expr) } + /** + * Aggregate function: alias for [[stddev_samp]]. + * + * @group agg_funcs + * @since 1.6.0 + */ + def stddev(columnName: String): Column = stddev(Column(columnName)) + /** * Aggregate function: returns the sample standard deviation of * the expression in a group. @@ -403,6 +447,15 @@ object functions extends LegacyFunctions { */ def stddev_samp(e: Column): Column = withAggregateFunction { StddevSamp(e.expr) } + /** + * Aggregate function: returns the sample standard deviation of + * the expression in a group. + * + * @group agg_funcs + * @since 1.6.0 + */ + def stddev_samp(columnName: String): Column = stddev_samp(Column(columnName)) + /** * Aggregate function: returns the population standard deviation of * the expression in a group. @@ -412,6 +465,15 @@ object functions extends LegacyFunctions { */ def stddev_pop(e: Column): Column = withAggregateFunction { StddevPop(e.expr) } + /** + * Aggregate function: returns the population standard deviation of + * the expression in a group. + * + * @group agg_funcs + * @since 1.6.0 + */ + def stddev_pop(columnName: String): Column = stddev_pop(Column(columnName)) + /** * Aggregate function: returns the sum of all values in the expression. * @@ -452,6 +514,14 @@ object functions extends LegacyFunctions { */ def variance(e: Column): Column = withAggregateFunction { VarianceSamp(e.expr) } + /** + * Aggregate function: alias for [[var_samp]]. + * + * @group agg_funcs + * @since 1.6.0 + */ + def variance(columnName: String): Column = variance(Column(columnName)) + /** * Aggregate function: returns the unbiased variance of the values in a group. * @@ -460,6 +530,14 @@ object functions extends LegacyFunctions { */ def var_samp(e: Column): Column = withAggregateFunction { VarianceSamp(e.expr) } + /** + * Aggregate function: returns the unbiased variance of the values in a group. + * + * @group agg_funcs + * @since 1.6.0 + */ + def var_samp(columnName: String): Column = var_samp(Column(columnName)) + /** * Aggregate function: returns the population variance of the values in a group. * @@ -468,6 +546,14 @@ object functions extends LegacyFunctions { */ def var_pop(e: Column): Column = withAggregateFunction { VariancePop(e.expr) } + /** + * Aggregate function: returns the population variance of the values in a group. 
+ * + * @group agg_funcs + * @since 1.6.0 + */ + def var_pop(columnName: String): Column = var_pop(Column(columnName)) + ////////////////////////////////////////////////////////////////////////////////////////////// // Window functions ////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 9c42f65bb6f5..b5c636d0de1d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -261,6 +261,9 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { checkAnswer( testData2.agg(stddev('a), stddev_pop('a), stddev_samp('a)), Row(testData2ADev, math.sqrt(4 / 6.0), testData2ADev)) + checkAnswer( + testData2.agg(stddev("a"), stddev_pop("a"), stddev_samp("a")), + Row(testData2ADev, math.sqrt(4 / 6.0), testData2ADev)) } test("zero stddev") { From b63938a8b04a30feb6b2255c4d4e530a74855afc Mon Sep 17 00:00:00 2001 From: mariusvniekerk Date: Thu, 26 Nov 2015 19:13:16 -0800 Subject: [PATCH 1571/1824] [SPARK-11881][SQL] Fix for postgresql fetchsize > 0 Reference: https://jdbc.postgresql.org/documentation/head/query.html#query-with-cursor In order for PostgreSQL to honor the fetchSize non-zero setting, its Connection.autoCommit needs to be set to false. Otherwise, it will just quietly ignore the fetchSize setting. This adds a new side-effecting dialect specific beforeFetch method that will fire before a select query is ran. Author: mariusvniekerk Closes #9861 from mariusvniekerk/SPARK-11881. --- .../execution/datasources/jdbc/JDBCRDD.scala | 12 ++++++++++++ .../apache/spark/sql/jdbc/JdbcDialects.scala | 11 +++++++++++ .../apache/spark/sql/jdbc/PostgresDialect.scala | 17 ++++++++++++++++- 3 files changed, 39 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala index 89c850ce238d..f9b72597dd2a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala @@ -224,6 +224,7 @@ private[sql] object JDBCRDD extends Logging { quotedColumns, filters, parts, + url, properties) } } @@ -241,6 +242,7 @@ private[sql] class JDBCRDD( columns: Array[String], filters: Array[Filter], partitions: Array[Partition], + url: String, properties: Properties) extends RDD[InternalRow](sc, Nil) { @@ -361,6 +363,9 @@ private[sql] class JDBCRDD( context.addTaskCompletionListener{ context => close() } val part = thePart.asInstanceOf[JDBCPartition] val conn = getConnection() + val dialect = JdbcDialects.get(url) + import scala.collection.JavaConverters._ + dialect.beforeFetch(conn, properties.asScala.toMap) // H2's JDBC driver does not support the setSchema() method. We pass a // fully-qualified table name in the SELECT statement. 
I don't know how to @@ -489,6 +494,13 @@ private[sql] class JDBCRDD( } try { if (null != conn) { + if (!conn.getAutoCommit && !conn.isClosed) { + try { + conn.commit() + } catch { + case e: Throwable => logWarning("Exception committing transaction", e) + } + } conn.close() } logInfo("closed connection") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index b3b2cb6178c5..13db141f27db 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.jdbc +import java.sql.Connection + import org.apache.spark.sql.types._ import org.apache.spark.annotation.DeveloperApi @@ -97,6 +99,15 @@ abstract class JdbcDialect extends Serializable { s"SELECT * FROM $table WHERE 1=0" } + /** + * Override connection specific properties to run before a select is made. This is in place to + * allow dialects that need special treatment to optimize behavior. + * @param connection The connection object + * @param properties The connection properties. This is passed through from the relation. + */ + def beforeFetch(connection: Connection, properties: Map[String, String]): Unit = { + } + } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala index ed3faa126863..3cf80f576e92 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.jdbc -import java.sql.Types +import java.sql.{Connection, Types} import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils import org.apache.spark.sql.types._ @@ -70,4 +70,19 @@ private object PostgresDialect extends JdbcDialect { override def getTableExistsQuery(table: String): String = { s"SELECT 1 FROM $table LIMIT 1" } + + override def beforeFetch(connection: Connection, properties: Map[String, String]): Unit = { + super.beforeFetch(connection, properties) + + // According to the postgres jdbc documentation we need to be in autocommit=false if we actually + // want to have fetchsize be non 0 (all the rows). This allows us to not have to cache all the + // rows inside the driver when fetching. + // + // See: https://jdbc.postgresql.org/documentation/head/query.html#query-with-cursor + // + if (properties.getOrElse("fetchsize", "0").toInt > 0) { + connection.setAutoCommit(false) + } + + } } From d8220885c492141dfc61e8ffb92934f2339fe8d3 Mon Sep 17 00:00:00 2001 From: Jeff Zhang Date: Thu, 26 Nov 2015 19:15:22 -0800 Subject: [PATCH 1572/1824] [SPARK-11917][PYSPARK] Add SQLContext#dropTempTable to PySpark Author: Jeff Zhang Closes #9903 from zjffdu/SPARK-11917. --- python/pyspark/sql/context.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index a49c1b58d018..b05aa2f5c4cd 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -445,6 +445,15 @@ def registerDataFrameAsTable(self, df, tableName): else: raise ValueError("Can only register DataFrame as table") + @since(1.6) + def dropTempTable(self, tableName): + """ Remove the temp table from catalog. 
+ + >>> sqlContext.registerDataFrameAsTable(df, "table1") + >>> sqlContext.dropTempTable("table1") + """ + self._ssql_ctx.dropTempTable(tableName) + def parquetFile(self, *paths): """Loads a Parquet file, returning the result as a :class:`DataFrame`. From 4d4cbc034bef559f47f8b74cecd8196dc8a85348 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Thu, 26 Nov 2015 19:17:46 -0800 Subject: [PATCH 1573/1824] [SPARK-11778][SQL] add regression test Fix regression test for SPARK-11778. marmbrus Could you please take a look? Thank you very much!! Author: Huaxin Gao Closes #9890 from huaxingao/spark-11778-regression-test. --- .../hive/HiveDataFrameAnalyticsSuite.scala | 10 ------ .../spark/sql/hive/HiveDataFrameSuite.scala | 32 +++++++++++++++++++ 2 files changed, 32 insertions(+), 10 deletions(-) create mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameSuite.scala diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala index f19a74d4b372..9864acf76526 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala @@ -34,14 +34,10 @@ class HiveDataFrameAnalyticsSuite extends QueryTest with TestHiveSingleton with override def beforeAll() { testData = Seq((1, 2), (2, 2), (3, 4)).toDF("a", "b") hiveContext.registerDataFrameAsTable(testData, "mytable") - hiveContext.sql("create schema usrdb") - hiveContext.sql("create table usrdb.test(c1 int)") } override def afterAll(): Unit = { hiveContext.dropTempTable("mytable") - hiveContext.sql("drop table usrdb.test") - hiveContext.sql("drop schema usrdb") } test("rollup") { @@ -78,10 +74,4 @@ class HiveDataFrameAnalyticsSuite extends QueryTest with TestHiveSingleton with sql("select a, b, sum(b) from mytable group by a, b with cube").collect() ) } - - // There was a bug in DataFrameFrameReader.table and it has problem for table with schema name, - // Before fix, it throw Exceptionorg.apache.spark.sql.catalyst.analysis.NoSuchTableException - test("table name with schema") { - hiveContext.read.table("usrdb.test") - } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameSuite.scala new file mode 100644 index 000000000000..7fdc5d71937f --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameSuite.scala @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.hive + +import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.QueryTest + +class HiveDataFrameSuite extends QueryTest with TestHiveSingleton { + test("table name with schema") { + // regression test for SPARK-11778 + hiveContext.sql("create schema usrdb") + hiveContext.sql("create table usrdb.test(c int)") + hiveContext.read.table("usrdb.test") + hiveContext.sql("drop table usrdb.test") + hiveContext.sql("drop schema usrdb") + } +} From 5eaed4e45c6c57e995ac7438016fad545716e596 Mon Sep 17 00:00:00 2001 From: Jeremy Derr Date: Thu, 26 Nov 2015 19:25:13 -0800 Subject: [PATCH 1574/1824] [SPARK-11991] fixes If `--private-ips` is required but not provided, spark_ec2.py may behave inappropriately, including attempting to ssh to localhost in attempts to verify ssh connectivity to the cluster. This fixes that behavior by raising a `UsageError` exception if `get_dns_name` is unable to determine a hostname as a result. Author: Jeremy Derr Closes #9975 from jcderr/SPARK-11991/ec_spark.py_hostname_check. --- ec2/spark_ec2.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py index 9fd652a3df4c..84a950c9f652 100755 --- a/ec2/spark_ec2.py +++ b/ec2/spark_ec2.py @@ -1242,6 +1242,10 @@ def get_ip_address(instance, private_ips=False): def get_dns_name(instance, private_ips=False): dns = instance.public_dns_name if not private_ips else \ instance.private_ip_address + if not dns: + raise UsageError("Failed to determine hostname of {0}.\n" + "Please check that you provided --private-ips if " + "necessary".format(instance)) return dns From 10e315c28c933b967674ae51e1b2f24160c2e8a5 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 26 Nov 2015 19:36:43 -0800 Subject: [PATCH 1575/1824] Fix style violation for b63938a8b04 --- .../apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala index f9b72597dd2a..57a8a044a37c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala @@ -20,6 +20,8 @@ package org.apache.spark.sql.execution.datasources.jdbc import java.sql.{Connection, DriverManager, ResultSet, ResultSetMetaData, SQLException} import java.util.Properties +import scala.util.control.NonFatal + import org.apache.commons.lang3.StringUtils import org.apache.spark.rdd.RDD @@ -498,7 +500,7 @@ private[sql] class JDBCRDD( try { conn.commit() } catch { - case e: Throwable => logWarning("Exception committing transaction", e) + case NonFatal(e) => logWarning("Exception committing transaction", e) } } conn.close() From a374e20b5492c775f20d32e8fbddadbd8098a655 Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Thu, 26 Nov 2015 21:04:40 -0800 Subject: [PATCH 1576/1824] [SPARK-11997] [SQL] NPE when save a DataFrame as parquet and partitioned by long column Check for partition column null-ability while building the partition spec. Author: Dilip Biswal Closes #10001 from dilipbiswal/spark-11997. 
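A minimal way to hit the problem, adapted from the regression test added below (the output path here is illustrative, not part of the patch):

```scala
// Partition column "n" is a nullable LongType; before this patch the round trip
// below failed with a NullPointerException. The fix reads the raw partition value
// with getUTF8String instead of getString, so a null value survives the cast.
val path = "/tmp/spark-11997-example"  // hypothetical temp location, not from the patch
sqlContext.range(1, 3)
  .selectExpr("if(id % 2 = 0, null, id) AS n", "id")
  .write.partitionBy("n").parquet(path)

sqlContext.read.parquet(path).filter("n is null").collect()  // expected: Row(2, null)
```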
--- .../org/apache/spark/sql/sources/interfaces.scala | 2 +- .../datasources/parquet/ParquetQuerySuite.scala | 13 +++++++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index f9465157c936..9ace25dc7d21 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -607,7 +607,7 @@ abstract class HadoopFsRelation private[sql]( def castPartitionValuesToUserSchema(row: InternalRow) = { InternalRow((0 until row.numFields).map { i => Cast( - Literal.create(row.getString(i), StringType), + Literal.create(row.getUTF8String(i), StringType), userProvidedSchema.fields(i).dataType).eval() }: _*) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala index 70fae32b7e7a..f777e973052d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala @@ -252,6 +252,19 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext } } + test("SPARK-11997 parquet with null partition values") { + withTempPath { dir => + val path = dir.getCanonicalPath + sqlContext.range(1, 3) + .selectExpr("if(id % 2 = 0, null, id) AS n", "id") + .write.partitionBy("n").parquet(path) + + checkAnswer( + sqlContext.read.parquet(path).filter("n is null"), + Row(2, null)) + } + } + // This test case is ignored because of parquet-mr bug PARQUET-370 ignore("SPARK-10301 requested schema clipping - schemas with disjoint sets of fields") { withTempPath { dir => From ba02f6cb5a40511cefa511d410be93c035d43f23 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Fri, 27 Nov 2015 11:48:01 -0800 Subject: [PATCH 1577/1824] [SPARK-12025][SPARKR] Rename some window rank function names for SparkR Change ```cumeDist -> cume_dist, denseRank -> dense_rank, percentRank -> percent_rank, rowNumber -> row_number``` at SparkR side. There are two reasons that we should make this change: * We should follow the [naming convention rule of R](http://www.inside-r.org/node/230645) * Spark DataFrame has deprecated the old convention (such as ```cumeDist```) and will remove it in Spark 2.0. It's better to fix this issue before 1.6 release, otherwise we will make breaking API change. cc shivaram sun-rui Author: Yanbo Liang Closes #10016 from yanboliang/SPARK-12025. 
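The renamed R wrappers delegate to identically named JVM functions via `callJStatic("org.apache.spark.sql.functions", ...)`, as the functions.R hunks below show, so R now follows the same snake_case spelling as the Scala DataFrame API. For comparison, a rough Scala sketch of what the renamed functions expose (the DataFrame `df` and its `key`/`value` columns are placeholders for illustration; only the function names come from this patch):

```scala
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{dense_rank, row_number}

// `df`, "key" and "value" are placeholder names; rank rows within each "key"
// partition, ordered by "value" -- the same operations the renamed R functions expose.
val w = Window.partitionBy("key").orderBy("value")
df.select(df("key"), row_number().over(w), dense_rank().over(w))
```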
--- R/pkg/NAMESPACE | 8 ++--- R/pkg/R/functions.R | 54 ++++++++++++++++---------------- R/pkg/R/generics.R | 16 +++++----- R/pkg/inst/tests/test_sparkSQL.R | 4 +-- 4 files changed, 41 insertions(+), 41 deletions(-) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 260c9edce62e..5d04dd6acaab 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -123,14 +123,14 @@ exportMethods("%in%", "count", "countDistinct", "crc32", - "cumeDist", + "cume_dist", "date_add", "date_format", "date_sub", "datediff", "dayofmonth", "dayofyear", - "denseRank", + "dense_rank", "desc", "endsWith", "exp", @@ -188,7 +188,7 @@ exportMethods("%in%", "next_day", "ntile", "otherwise", - "percentRank", + "percent_rank", "pmod", "quarter", "rand", @@ -200,7 +200,7 @@ exportMethods("%in%", "rint", "rlike", "round", - "rowNumber", + "row_number", "rpad", "rtrim", "second", diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 25a1f2210149..e98e7a0117ca 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -2146,47 +2146,47 @@ setMethod("ifelse", ###################### Window functions###################### -#' cumeDist +#' cume_dist #' #' Window function: returns the cumulative distribution of values within a window partition, #' i.e. the fraction of rows that are below the current row. #' #' N = total number of rows in the partition -#' cumeDist(x) = number of values before (and including) x / N +#' cume_dist(x) = number of values before (and including) x / N #' #' This is equivalent to the CUME_DIST function in SQL. #' -#' @rdname cumeDist -#' @name cumeDist +#' @rdname cume_dist +#' @name cume_dist #' @family window_funcs #' @export -#' @examples \dontrun{cumeDist()} -setMethod("cumeDist", +#' @examples \dontrun{cume_dist()} +setMethod("cume_dist", signature(x = "missing"), function() { - jc <- callJStatic("org.apache.spark.sql.functions", "cumeDist") + jc <- callJStatic("org.apache.spark.sql.functions", "cume_dist") column(jc) }) -#' denseRank +#' dense_rank #' #' Window function: returns the rank of rows within a window partition, without any gaps. -#' The difference between rank and denseRank is that denseRank leaves no gaps in ranking -#' sequence when there are ties. That is, if you were ranking a competition using denseRank +#' The difference between rank and dense_rank is that dense_rank leaves no gaps in ranking +#' sequence when there are ties. That is, if you were ranking a competition using dense_rank #' and had three people tie for second place, you would say that all three were in second #' place and that the next person came in third. #' #' This is equivalent to the DENSE_RANK function in SQL. #' -#' @rdname denseRank -#' @name denseRank +#' @rdname dense_rank +#' @name dense_rank #' @family window_funcs #' @export -#' @examples \dontrun{denseRank()} -setMethod("denseRank", +#' @examples \dontrun{dense_rank()} +setMethod("dense_rank", signature(x = "missing"), function() { - jc <- callJStatic("org.apache.spark.sql.functions", "denseRank") + jc <- callJStatic("org.apache.spark.sql.functions", "dense_rank") column(jc) }) @@ -2264,7 +2264,7 @@ setMethod("ntile", column(jc) }) -#' percentRank +#' percent_rank #' #' Window function: returns the relative rank (i.e. percentile) of rows within a window partition. #' @@ -2274,15 +2274,15 @@ setMethod("ntile", #' #' This is equivalent to the PERCENT_RANK function in SQL. 
#' -#' @rdname percentRank -#' @name percentRank +#' @rdname percent_rank +#' @name percent_rank #' @family window_funcs #' @export -#' @examples \dontrun{percentRank()} -setMethod("percentRank", +#' @examples \dontrun{percent_rank()} +setMethod("percent_rank", signature(x = "missing"), function() { - jc <- callJStatic("org.apache.spark.sql.functions", "percentRank") + jc <- callJStatic("org.apache.spark.sql.functions", "percent_rank") column(jc) }) @@ -2316,21 +2316,21 @@ setMethod("rank", base::rank(x, ...) }) -#' rowNumber +#' row_number #' #' Window function: returns a sequential number starting at 1 within a window partition. #' #' This is equivalent to the ROW_NUMBER function in SQL. #' -#' @rdname rowNumber -#' @name rowNumber +#' @rdname row_number +#' @name row_number #' @family window_funcs #' @export -#' @examples \dontrun{rowNumber()} -setMethod("rowNumber", +#' @examples \dontrun{row_number()} +setMethod("row_number", signature(x = "missing"), function() { - jc <- callJStatic("org.apache.spark.sql.functions", "rowNumber") + jc <- callJStatic("org.apache.spark.sql.functions", "row_number") column(jc) }) diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 1b3f10ea0464..0c305441e043 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -700,9 +700,9 @@ setGeneric("countDistinct", function(x, ...) { standardGeneric("countDistinct") #' @export setGeneric("crc32", function(x) { standardGeneric("crc32") }) -#' @rdname cumeDist +#' @rdname cume_dist #' @export -setGeneric("cumeDist", function(x) { standardGeneric("cumeDist") }) +setGeneric("cume_dist", function(x) { standardGeneric("cume_dist") }) #' @rdname datediff #' @export @@ -728,9 +728,9 @@ setGeneric("dayofmonth", function(x) { standardGeneric("dayofmonth") }) #' @export setGeneric("dayofyear", function(x) { standardGeneric("dayofyear") }) -#' @rdname denseRank +#' @rdname dense_rank #' @export -setGeneric("denseRank", function(x) { standardGeneric("denseRank") }) +setGeneric("dense_rank", function(x) { standardGeneric("dense_rank") }) #' @rdname explode #' @export @@ -872,9 +872,9 @@ setGeneric("ntile", function(x) { standardGeneric("ntile") }) #' @export setGeneric("n_distinct", function(x, ...) { standardGeneric("n_distinct") }) -#' @rdname percentRank +#' @rdname percent_rank #' @export -setGeneric("percentRank", function(x) { standardGeneric("percentRank") }) +setGeneric("percent_rank", function(x) { standardGeneric("percent_rank") }) #' @rdname pmod #' @export @@ -913,9 +913,9 @@ setGeneric("reverse", function(x) { standardGeneric("reverse") }) #' @export setGeneric("rint", function(x, ...) 
{ standardGeneric("rint") }) -#' @rdname rowNumber +#' @rdname row_number #' @export -setGeneric("rowNumber", function(x) { standardGeneric("rowNumber") }) +setGeneric("row_number", function(x) { standardGeneric("row_number") }) #' @rdname rpad #' @export diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index 3f4f319fe745..0fbe0658265b 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -861,8 +861,8 @@ test_that("column functions", { c11 <- to_date(c) + trim(c) + unbase64(c) + unhex(c) + upper(c) c12 <- variance(c) c13 <- lead("col", 1) + lead(c, 1) + lag("col", 1) + lag(c, 1) - c14 <- cumeDist() + ntile(1) - c15 <- denseRank() + percentRank() + rank() + rowNumber() + c14 <- cume_dist() + ntile(1) + c15 <- dense_rank() + percent_rank() + rank() + row_number() # Test if base::rank() is exposed expect_equal(class(rank())[[1]], "Column") From f57e6c9effdb9e282fc8ae66dc30fe053fed5272 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Fri, 27 Nov 2015 11:50:18 -0800 Subject: [PATCH 1578/1824] [SPARK-12021][STREAMING][TESTS] Fix the potential dead-lock in StreamingListenerSuite In StreamingListenerSuite."don't call ssc.stop in listener", after the main thread calls `ssc.stop()`, `StreamingContextStoppingCollector` may call `ssc.stop()` in the listener bus thread, which is a dead-lock. This PR updated `StreamingContextStoppingCollector` to only call `ssc.stop()` in the first batch to avoid the dead-lock. Author: Shixiong Zhu Closes #10011 from zsxwing/fix-test-deadlock. --- .../streaming/StreamingListenerSuite.scala | 25 ++++++++++++++----- 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala index df4575ab25aa..04cd5bdc26be 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala @@ -222,7 +222,11 @@ class StreamingListenerSuite extends TestSuiteBase with Matchers { val batchCounter = new BatchCounter(_ssc) _ssc.start() // Make sure running at least one batch - batchCounter.waitUntilBatchesCompleted(expectedNumCompletedBatches = 1, timeout = 10000) + if (!batchCounter.waitUntilBatchesCompleted(expectedNumCompletedBatches = 1, timeout = 10000)) { + fail("The first batch cannot complete in 10 seconds") + } + // When reaching here, we can make sure `StreamingContextStoppingCollector` won't call + // `ssc.stop()`, so it's safe to call `_ssc.stop()` now. _ssc.stop() assert(contextStoppingCollector.sparkExSeen) } @@ -345,12 +349,21 @@ class FailureReasonsCollector extends StreamingListener { */ class StreamingContextStoppingCollector(val ssc: StreamingContext) extends StreamingListener { @volatile var sparkExSeen = false + + private var isFirstBatch = true + override def onBatchCompleted(batchCompleted: StreamingListenerBatchCompleted) { - try { - ssc.stop() - } catch { - case se: SparkException => - sparkExSeen = true + if (isFirstBatch) { + // We should only call `ssc.stop()` in the first batch. Otherwise, it's possible that the main + // thread is calling `ssc.stop()`, while StreamingContextStoppingCollector is also calling + // `ssc.stop()` in the listener thread, which becomes a dead-lock. 
+ isFirstBatch = false + try { + ssc.stop() + } catch { + case se: SparkException => + sparkExSeen = true + } } } } From b9921524d970f9413039967c1f17ae2e736982f0 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Fri, 27 Nov 2015 15:11:13 -0800 Subject: [PATCH 1579/1824] [SPARK-12020][TESTS][TEST-HADOOP2.0] PR builder cannot trigger hadoop 2.0 test https://issues.apache.org/jira/browse/SPARK-12020 Author: Yin Huai Closes #10010 from yhuai/SPARK-12020. --- dev/run-tests-jenkins.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dev/run-tests-jenkins.py b/dev/run-tests-jenkins.py index 623004310e18..4f390ef1eaa3 100755 --- a/dev/run-tests-jenkins.py +++ b/dev/run-tests-jenkins.py @@ -164,7 +164,7 @@ def main(): # Switch the Hadoop profile based on the PR title: if "test-hadoop1.0" in ghprb_pull_title: os.environ["AMPLAB_JENKINS_BUILD_PROFILE"] = "hadoop1.0" - if "test-hadoop2.2" in ghprb_pull_title: + if "test-hadoop2.0" in ghprb_pull_title: os.environ["AMPLAB_JENKINS_BUILD_PROFILE"] = "hadoop2.0" if "test-hadoop2.2" in ghprb_pull_title: os.environ["AMPLAB_JENKINS_BUILD_PROFILE"] = "hadoop2.2" From 149cd692ee2e127d79386fd8e584f4f70a2906ba Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Fri, 27 Nov 2015 22:44:08 -0800 Subject: [PATCH 1580/1824] [SPARK-12028] [SQL] get_json_object returns an incorrect result when the value is null literals When calling `get_json_object` for the following two cases, both results are `"null"`: ```scala val tuple: Seq[(String, String)] = ("5", """{"f1": null}""") :: Nil val df: DataFrame = tuple.toDF("key", "jstring") val res = df.select(functions.get_json_object($"jstring", "$.f1")).collect() ``` ```scala val tuple2: Seq[(String, String)] = ("5", """{"f1": "null"}""") :: Nil val df2: DataFrame = tuple2.toDF("key", "jstring") val res3 = df2.select(functions.get_json_object($"jstring", "$.f1")).collect() ``` Fixed the problem and also added a test case. Author: gatorsmile Closes #10018 from gatorsmile/get_json_object. 
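In terms of observable behaviour, the change below should make `get_json_object` distinguish a JSON null literal from the four-character string "null" (inputs taken from the two cases above; `toDF` and `$` assume the usual `sqlContext.implicits._` import, as in the test suite):

```scala
import org.apache.spark.sql.functions.get_json_object

// A JSON null literal now yields SQL NULL instead of the string "null" ...
Seq(("5", """{"f1": null}""")).toDF("key", "jstring")
  .select(get_json_object($"jstring", "$.f1"))   // value: NULL

// ... while a field whose value is the string "null" is still returned as that string.
Seq(("5", """{"f1": "null"}""")).toDF("key", "jstring")
  .select(get_json_object($"jstring", "$.f1"))   // value: "null"
```

The added guard in `evaluatePath` stops descending when the matched field's next token is `VALUE_NULL`, so the expression ends up producing null for that row instead of writing the literal.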
--- .../expressions/jsonExpressions.scala | 7 +++++-- .../apache/spark/sql/JsonFunctionsSuite.scala | 20 +++++++++++++++++++ 2 files changed, 25 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index 8cd73236a787..4991b9cb54e5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -298,8 +298,11 @@ case class GetJsonObject(json: Expression, path: Expression) case (FIELD_NAME, Named(name) :: xs) if p.getCurrentName == name => // exact field match - p.nextToken() - evaluatePath(p, g, style, xs) + if (p.nextToken() != JsonToken.VALUE_NULL) { + evaluatePath(p, g, style, xs) + } else { + false + } case (FIELD_NAME, Wildcard :: xs) => // wildcard field match diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala index 14fd56fc8c22..1f384edf321b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala @@ -39,6 +39,26 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext { ("6", "[invalid JSON string]") :: Nil + test("function get_json_object - null") { + val df: DataFrame = tuples.toDF("key", "jstring") + val expected = + Row("1", "value1", "value2", "3", null, "5.23") :: + Row("2", "value12", "2", "value3", "4.01", null) :: + Row("3", "value13", "2", "value33", "value44", "5.01") :: + Row("4", null, null, null, null, null) :: + Row("5", "", null, null, null, null) :: + Row("6", null, null, null, null, null) :: + Nil + + checkAnswer( + df.select($"key", functions.get_json_object($"jstring", "$.f1"), + functions.get_json_object($"jstring", "$.f2"), + functions.get_json_object($"jstring", "$.f3"), + functions.get_json_object($"jstring", "$.f4"), + functions.get_json_object($"jstring", "$.f5")), + expected) + } + test("json_tuple select") { val df: DataFrame = tuples.toDF("key", "jstring") val expected = From 28e46ab46368ea3833c8e805163893bbb6f2a265 Mon Sep 17 00:00:00 2001 From: felixcheung Date: Sat, 28 Nov 2015 21:02:05 -0800 Subject: [PATCH 1581/1824] [SPARK-12029][SPARKR] Improve column functions signature, param check, tests, fix doc and add examples shivaram sun-rui Author: felixcheung Closes #10019 from felixcheung/rfunctionsdoc. --- R/pkg/R/functions.R | 121 +++++++++++++++++++++++-------- R/pkg/inst/tests/test_sparkSQL.R | 9 ++- 2 files changed, 96 insertions(+), 34 deletions(-) diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index e98e7a0117ca..b30331c61c9a 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -878,7 +878,7 @@ setMethod("rtrim", #'} setMethod("sd", signature(x = "Column"), - function(x, na.rm = FALSE) { + function(x) { # In R, sample standard deviation is calculated with the sd() function. stddev_samp(x) }) @@ -1250,7 +1250,7 @@ setMethod("upper", #'} setMethod("var", signature(x = "Column"), - function(x, y = NULL, na.rm = FALSE, use) { + function(x) { # In R, sample variance is calculated with the var() function. var_samp(x) }) @@ -1467,6 +1467,7 @@ setMethod("pmod", signature(y = "Column"), #' @name approxCountDistinct #' @return the approximate number of distinct items in a group. 
#' @export +#' @examples \dontrun{approxCountDistinct(df$c, 0.02)} setMethod("approxCountDistinct", signature(x = "Column"), function(x, rsd = 0.05) { @@ -1481,14 +1482,16 @@ setMethod("approxCountDistinct", #' @name countDistinct #' @return the number of distinct items in a group. #' @export +#' @examples \dontrun{countDistinct(df$c)} setMethod("countDistinct", signature(x = "Column"), function(x, ...) { - jcol <- lapply(list(...), function (x) { + jcols <- lapply(list(...), function (x) { + stopifnot(class(x) == "Column") x@jc }) jc <- callJStatic("org.apache.spark.sql.functions", "countDistinct", x@jc, - jcol) + jcols) column(jc) }) @@ -1501,10 +1504,14 @@ setMethod("countDistinct", #' @rdname concat #' @name concat #' @export +#' @examples \dontrun{concat(df$strings, df$strings2)} setMethod("concat", signature(x = "Column"), function(x, ...) { - jcols <- lapply(list(x, ...), function(x) { x@jc }) + jcols <- lapply(list(x, ...), function (x) { + stopifnot(class(x) == "Column") + x@jc + }) jc <- callJStatic("org.apache.spark.sql.functions", "concat", jcols) column(jc) }) @@ -1518,11 +1525,15 @@ setMethod("concat", #' @rdname greatest #' @name greatest #' @export +#' @examples \dontrun{greatest(df$c, df$d)} setMethod("greatest", signature(x = "Column"), function(x, ...) { stopifnot(length(list(...)) > 0) - jcols <- lapply(list(x, ...), function(x) { x@jc }) + jcols <- lapply(list(x, ...), function (x) { + stopifnot(class(x) == "Column") + x@jc + }) jc <- callJStatic("org.apache.spark.sql.functions", "greatest", jcols) column(jc) }) @@ -1530,17 +1541,21 @@ setMethod("greatest", #' least #' #' Returns the least value of the list of column names, skipping null values. -#' This function takes at least 2 parameters. It will return null iff all parameters are null. +#' This function takes at least 2 parameters. It will return null if all parameters are null. #' #' @family normal_funcs #' @rdname least #' @name least #' @export +#' @examples \dontrun{least(df$c, df$d)} setMethod("least", signature(x = "Column"), function(x, ...) { stopifnot(length(list(...)) > 0) - jcols <- lapply(list(x, ...), function(x) { x@jc }) + jcols <- lapply(list(x, ...), function (x) { + stopifnot(class(x) == "Column") + x@jc + }) jc <- callJStatic("org.apache.spark.sql.functions", "least", jcols) column(jc) }) @@ -1549,11 +1564,10 @@ setMethod("least", #' #' Computes the ceiling of the given value. #' -#' @family math_funcs #' @rdname ceil -#' @name ceil -#' @aliases ceil +#' @name ceiling #' @export +#' @examples \dontrun{ceiling(df$c)} setMethod("ceiling", signature(x = "Column"), function(x) { @@ -1564,11 +1578,10 @@ setMethod("ceiling", #' #' Computes the signum of the given value. #' -#' @family math_funcs #' @rdname signum -#' @name signum -#' @aliases signum +#' @name sign #' @export +#' @examples \dontrun{sign(df$c)} setMethod("sign", signature(x = "Column"), function(x) { signum(x) @@ -1578,11 +1591,10 @@ setMethod("sign", signature(x = "Column"), #' #' Aggregate function: returns the number of distinct items in a group. #' -#' @family agg_funcs #' @rdname countDistinct -#' @name countDistinct -#' @aliases countDistinct +#' @name n_distinct #' @export +#' @examples \dontrun{n_distinct(df$c)} setMethod("n_distinct", signature(x = "Column"), function(x, ...) { countDistinct(x, ...) @@ -1592,11 +1604,10 @@ setMethod("n_distinct", signature(x = "Column"), #' #' Aggregate function: returns the number of items in a group. 
#' -#' @family agg_funcs #' @rdname count -#' @name count -#' @aliases count +#' @name n #' @export +#' @examples \dontrun{n(df$c)} setMethod("n", signature(x = "Column"), function(x) { count(x) @@ -1617,6 +1628,7 @@ setMethod("n", signature(x = "Column"), #' @rdname date_format #' @name date_format #' @export +#' @examples \dontrun{date_format(df$t, 'MM/dd/yyy')} setMethod("date_format", signature(y = "Column", x = "character"), function(y, x) { jc <- callJStatic("org.apache.spark.sql.functions", "date_format", y@jc, x) @@ -1631,6 +1643,7 @@ setMethod("date_format", signature(y = "Column", x = "character"), #' @rdname from_utc_timestamp #' @name from_utc_timestamp #' @export +#' @examples \dontrun{from_utc_timestamp(df$t, 'PST')} setMethod("from_utc_timestamp", signature(y = "Column", x = "character"), function(y, x) { jc <- callJStatic("org.apache.spark.sql.functions", "from_utc_timestamp", y@jc, x) @@ -1649,6 +1662,7 @@ setMethod("from_utc_timestamp", signature(y = "Column", x = "character"), #' @rdname instr #' @name instr #' @export +#' @examples \dontrun{instr(df$c, 'b')} setMethod("instr", signature(y = "Column", x = "character"), function(y, x) { jc <- callJStatic("org.apache.spark.sql.functions", "instr", y@jc, x) @@ -1663,13 +1677,18 @@ setMethod("instr", signature(y = "Column", x = "character"), #' For example, \code{next_day('2015-07-27', "Sunday")} returns 2015-08-02 because that is the first #' Sunday after 2015-07-27. #' -#' Day of the week parameter is case insensitive, and accepts: +#' Day of the week parameter is case insensitive, and accepts first three or two characters: #' "Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun". #' #' @family datetime_funcs #' @rdname next_day #' @name next_day #' @export +#' @examples +#'\dontrun{ +#'next_day(df$d, 'Sun') +#'next_day(df$d, 'Sunday') +#'} setMethod("next_day", signature(y = "Column", x = "character"), function(y, x) { jc <- callJStatic("org.apache.spark.sql.functions", "next_day", y@jc, x) @@ -1684,6 +1703,7 @@ setMethod("next_day", signature(y = "Column", x = "character"), #' @rdname to_utc_timestamp #' @name to_utc_timestamp #' @export +#' @examples \dontrun{to_utc_timestamp(df$t, 'PST')} setMethod("to_utc_timestamp", signature(y = "Column", x = "character"), function(y, x) { jc <- callJStatic("org.apache.spark.sql.functions", "to_utc_timestamp", y@jc, x) @@ -1697,8 +1717,8 @@ setMethod("to_utc_timestamp", signature(y = "Column", x = "character"), #' @name add_months #' @family datetime_funcs #' @rdname add_months -#' @name add_months #' @export +#' @examples \dontrun{add_months(df$d, 1)} setMethod("add_months", signature(y = "Column", x = "numeric"), function(y, x) { jc <- callJStatic("org.apache.spark.sql.functions", "add_months", y@jc, as.integer(x)) @@ -1713,6 +1733,7 @@ setMethod("add_months", signature(y = "Column", x = "numeric"), #' @rdname date_add #' @name date_add #' @export +#' @examples \dontrun{date_add(df$d, 1)} setMethod("date_add", signature(y = "Column", x = "numeric"), function(y, x) { jc <- callJStatic("org.apache.spark.sql.functions", "date_add", y@jc, as.integer(x)) @@ -1727,6 +1748,7 @@ setMethod("date_add", signature(y = "Column", x = "numeric"), #' @rdname date_sub #' @name date_sub #' @export +#' @examples \dontrun{date_sub(df$d, 1)} setMethod("date_sub", signature(y = "Column", x = "numeric"), function(y, x) { jc <- callJStatic("org.apache.spark.sql.functions", "date_sub", y@jc, as.integer(x)) @@ -1735,16 +1757,19 @@ setMethod("date_sub", signature(y = "Column", x = "numeric"), #' format_number #' 
-#' Formats numeric column x to a format like '#,###,###.##', rounded to d decimal places, +#' Formats numeric column y to a format like '#,###,###.##', rounded to x decimal places, #' and returns the result as a string column. #' -#' If d is 0, the result has no decimal point or fractional part. -#' If d < 0, the result will be null.' +#' If x is 0, the result has no decimal point or fractional part. +#' If x < 0, the result will be null. #' +#' @param y column to format +#' @param x number of decimal place to format to #' @family string_funcs #' @rdname format_number #' @name format_number #' @export +#' @examples \dontrun{format_number(df$n, 4)} setMethod("format_number", signature(y = "Column", x = "numeric"), function(y, x) { jc <- callJStatic("org.apache.spark.sql.functions", @@ -1764,6 +1789,7 @@ setMethod("format_number", signature(y = "Column", x = "numeric"), #' @rdname sha2 #' @name sha2 #' @export +#' @examples \dontrun{sha2(df$c, 256)} setMethod("sha2", signature(y = "Column", x = "numeric"), function(y, x) { jc <- callJStatic("org.apache.spark.sql.functions", "sha2", y@jc, as.integer(x)) @@ -1779,6 +1805,7 @@ setMethod("sha2", signature(y = "Column", x = "numeric"), #' @rdname shiftLeft #' @name shiftLeft #' @export +#' @examples \dontrun{shiftLeft(df$c, 1)} setMethod("shiftLeft", signature(y = "Column", x = "numeric"), function(y, x) { jc <- callJStatic("org.apache.spark.sql.functions", @@ -1796,6 +1823,7 @@ setMethod("shiftLeft", signature(y = "Column", x = "numeric"), #' @rdname shiftRight #' @name shiftRight #' @export +#' @examples \dontrun{shiftRight(df$c, 1)} setMethod("shiftRight", signature(y = "Column", x = "numeric"), function(y, x) { jc <- callJStatic("org.apache.spark.sql.functions", @@ -1813,6 +1841,7 @@ setMethod("shiftRight", signature(y = "Column", x = "numeric"), #' @rdname shiftRightUnsigned #' @name shiftRightUnsigned #' @export +#' @examples \dontrun{shiftRightUnsigned(df$c, 1)} setMethod("shiftRightUnsigned", signature(y = "Column", x = "numeric"), function(y, x) { jc <- callJStatic("org.apache.spark.sql.functions", @@ -1830,6 +1859,7 @@ setMethod("shiftRightUnsigned", signature(y = "Column", x = "numeric"), #' @rdname concat_ws #' @name concat_ws #' @export +#' @examples \dontrun{concat_ws('-', df$s, df$d)} setMethod("concat_ws", signature(sep = "character", x = "Column"), function(sep, x, ...) { jcols <- lapply(list(x, ...), function(x) { x@jc }) @@ -1845,6 +1875,7 @@ setMethod("concat_ws", signature(sep = "character", x = "Column"), #' @rdname conv #' @name conv #' @export +#' @examples \dontrun{conv(df$n, 2, 16)} setMethod("conv", signature(x = "Column", fromBase = "numeric", toBase = "numeric"), function(x, fromBase, toBase) { fromBase <- as.integer(fromBase) @@ -1864,6 +1895,7 @@ setMethod("conv", signature(x = "Column", fromBase = "numeric", toBase = "numeri #' @rdname expr #' @name expr #' @export +#' @examples \dontrun{expr('length(name)')} setMethod("expr", signature(x = "character"), function(x) { jc <- callJStatic("org.apache.spark.sql.functions", "expr", x) @@ -1878,6 +1910,7 @@ setMethod("expr", signature(x = "character"), #' @rdname format_string #' @name format_string #' @export +#' @examples \dontrun{format_string('%d %s', df$a, df$b)} setMethod("format_string", signature(format = "character", x = "Column"), function(format, x, ...) 
{ jcols <- lapply(list(x, ...), function(arg) { arg@jc }) @@ -1897,6 +1930,11 @@ setMethod("format_string", signature(format = "character", x = "Column"), #' @rdname from_unixtime #' @name from_unixtime #' @export +#' @examples +#'\dontrun{ +#'from_unixtime(df$t) +#'from_unixtime(df$t, 'yyyy/MM/dd HH') +#'} setMethod("from_unixtime", signature(x = "Column"), function(x, format = "yyyy-MM-dd HH:mm:ss") { jc <- callJStatic("org.apache.spark.sql.functions", @@ -1915,6 +1953,7 @@ setMethod("from_unixtime", signature(x = "Column"), #' @rdname locate #' @name locate #' @export +#' @examples \dontrun{locate('b', df$c, 1)} setMethod("locate", signature(substr = "character", str = "Column"), function(substr, str, pos = 0) { jc <- callJStatic("org.apache.spark.sql.functions", @@ -1931,6 +1970,7 @@ setMethod("locate", signature(substr = "character", str = "Column"), #' @rdname lpad #' @name lpad #' @export +#' @examples \dontrun{lpad(df$c, 6, '#')} setMethod("lpad", signature(x = "Column", len = "numeric", pad = "character"), function(x, len, pad) { jc <- callJStatic("org.apache.spark.sql.functions", @@ -1947,12 +1987,13 @@ setMethod("lpad", signature(x = "Column", len = "numeric", pad = "character"), #' @rdname rand #' @name rand #' @export +#' @examples \dontrun{rand()} setMethod("rand", signature(seed = "missing"), function(seed) { jc <- callJStatic("org.apache.spark.sql.functions", "rand") column(jc) }) -#' @family normal_funcs + #' @rdname rand #' @name rand #' @export @@ -1970,12 +2011,13 @@ setMethod("rand", signature(seed = "numeric"), #' @rdname randn #' @name randn #' @export +#' @examples \dontrun{randn()} setMethod("randn", signature(seed = "missing"), function(seed) { jc <- callJStatic("org.apache.spark.sql.functions", "randn") column(jc) }) -#' @family normal_funcs + #' @rdname randn #' @name randn #' @export @@ -1993,6 +2035,7 @@ setMethod("randn", signature(seed = "numeric"), #' @rdname regexp_extract #' @name regexp_extract #' @export +#' @examples \dontrun{regexp_extract(df$c, '(\d+)-(\d+)', 1)} setMethod("regexp_extract", signature(x = "Column", pattern = "character", idx = "numeric"), function(x, pattern, idx) { @@ -2010,6 +2053,7 @@ setMethod("regexp_extract", #' @rdname regexp_replace #' @name regexp_replace #' @export +#' @examples \dontrun{regexp_replace(df$c, '(\\d+)', '--')} setMethod("regexp_replace", signature(x = "Column", pattern = "character", replacement = "character"), function(x, pattern, replacement) { @@ -2027,6 +2071,7 @@ setMethod("regexp_replace", #' @rdname rpad #' @name rpad #' @export +#' @examples \dontrun{rpad(df$c, 6, '#')} setMethod("rpad", signature(x = "Column", len = "numeric", pad = "character"), function(x, len, pad) { jc <- callJStatic("org.apache.spark.sql.functions", @@ -2040,12 +2085,17 @@ setMethod("rpad", signature(x = "Column", len = "numeric", pad = "character"), #' Returns the substring from string str before count occurrences of the delimiter delim. #' If count is positive, everything the left of the final delimiter (counting from left) is #' returned. If count is negative, every to the right of the final delimiter (counting from the -#' right) is returned. substring <- index performs a case-sensitive match when searching for delim. +#' right) is returned. substring_index performs a case-sensitive match when searching for delim. 
#' #' @family string_funcs #' @rdname substring_index #' @name substring_index #' @export +#' @examples +#'\dontrun{ +#'substring_index(df$c, '.', 2) +#'substring_index(df$c, '.', -1) +#'} setMethod("substring_index", signature(x = "Column", delim = "character", count = "numeric"), function(x, delim, count) { @@ -2066,6 +2116,7 @@ setMethod("substring_index", #' @rdname translate #' @name translate #' @export +#' @examples \dontrun{translate(df$c, 'rnlt', '123')} setMethod("translate", signature(x = "Column", matchingString = "character", replaceString = "character"), function(x, matchingString, replaceString) { @@ -2082,12 +2133,18 @@ setMethod("translate", #' @rdname unix_timestamp #' @name unix_timestamp #' @export +#' @examples +#'\dontrun{ +#'unix_timestamp() +#'unix_timestamp(df$t) +#'unix_timestamp(df$t, 'yyyy-MM-dd HH') +#'} setMethod("unix_timestamp", signature(x = "missing", format = "missing"), function(x, format) { jc <- callJStatic("org.apache.spark.sql.functions", "unix_timestamp") column(jc) }) -#' @family datetime_funcs + #' @rdname unix_timestamp #' @name unix_timestamp #' @export @@ -2096,7 +2153,7 @@ setMethod("unix_timestamp", signature(x = "Column", format = "missing"), jc <- callJStatic("org.apache.spark.sql.functions", "unix_timestamp", x@jc) column(jc) }) -#' @family datetime_funcs + #' @rdname unix_timestamp #' @name unix_timestamp #' @export @@ -2113,7 +2170,9 @@ setMethod("unix_timestamp", signature(x = "Column", format = "character"), #' @family normal_funcs #' @rdname when #' @name when +#' @seealso \link{ifelse} #' @export +#' @examples \dontrun{when(df$age == 2, df$age + 1)} setMethod("when", signature(condition = "Column", value = "ANY"), function(condition, value) { condition <- condition@jc @@ -2130,7 +2189,9 @@ setMethod("when", signature(condition = "Column", value = "ANY"), #' @family normal_funcs #' @rdname ifelse #' @name ifelse +#' @seealso \link{when} #' @export +#' @examples \dontrun{ifelse(df$a > 1 & df$b > 2, 0, 1)} setMethod("ifelse", signature(test = "Column", yes = "ANY", no = "ANY"), function(test, yes, no) { diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index 0fbe0658265b..899fc3b97738 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -880,14 +880,15 @@ test_that("column functions", { expect_equal(collect(df3)[[2, 1]], FALSE) expect_equal(collect(df3)[[3, 1]], TRUE) - expect_equal(collect(select(df, sum(df$age)))[1, 1], 49) + df4 <- select(df, countDistinct(df$age, df$name)) + expect_equal(collect(df4)[[1, 1]], 2) + expect_equal(collect(select(df, sum(df$age)))[1, 1], 49) expect_true(abs(collect(select(df, stddev(df$age)))[1, 1] - 7.778175) < 1e-6) - expect_equal(collect(select(df, var_pop(df$age)))[1, 1], 30.25) - df4 <- createDataFrame(sqlContext, list(list(a = "010101"))) - expect_equal(collect(select(df4, conv(df4$a, 2, 16)))[1, 1], "15") + df5 <- createDataFrame(sqlContext, list(list(a = "010101"))) + expect_equal(collect(select(df5, conv(df5$a, 2, 16)))[1, 1], "15") # Test array_contains() and sort_array() df <- createDataFrame(sqlContext, list(list(list(1L, 2L, 3L)), list(list(6L, 5L, 4L)))) From c793d2d9a1ccc203fc103eb0636958fe8d71f471 Mon Sep 17 00:00:00 2001 From: felixcheung Date: Sat, 28 Nov 2015 21:16:21 -0800 Subject: [PATCH 1582/1824] [SPARK-9319][SPARKR] Add support for setting column names, types Add support for for colnames, colnames<-, coltypes<- Also added tests for names, names<- which have no test previously. I merged with PR 8984 (coltypes). 
Clicked the wrong thing, crewed up the PR. Recreated it here. Was #9218 shivaram sun-rui Author: felixcheung Closes #9654 from felixcheung/colnamescoltypes. --- R/pkg/NAMESPACE | 6 +- R/pkg/R/DataFrame.R | 166 ++++++++++++++++++++++--------- R/pkg/R/generics.R | 20 +++- R/pkg/R/types.R | 8 ++ R/pkg/inst/tests/test_sparkSQL.R | 40 +++++++- 5 files changed, 185 insertions(+), 55 deletions(-) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 5d04dd6acaab..43e5e0119e7f 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -27,7 +27,10 @@ exportMethods("arrange", "attach", "cache", "collect", + "colnames", + "colnames<-", "coltypes", + "coltypes<-", "columns", "count", "cov", @@ -56,6 +59,7 @@ exportMethods("arrange", "mutate", "na.omit", "names", + "names<-", "ncol", "nrow", "orderBy", @@ -276,4 +280,4 @@ export("structField", "structType", "structType.jobj", "structType.structField", - "print.structType") \ No newline at end of file + "print.structType") diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 8a13e7a36766..f89e2682d9e2 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -254,6 +254,7 @@ setMethod("dtypes", #' @family DataFrame functions #' @rdname columns #' @name columns + #' @export #' @examples #'\dontrun{ @@ -262,6 +263,7 @@ setMethod("dtypes", #' path <- "path/to/file.json" #' df <- jsonFile(sqlContext, path) #' columns(df) +#' colnames(df) #'} setMethod("columns", signature(x = "DataFrame"), @@ -290,6 +292,121 @@ setMethod("names<-", } }) +#' @rdname columns +#' @name colnames +setMethod("colnames", + signature(x = "DataFrame"), + function(x) { + columns(x) + }) + +#' @rdname columns +#' @name colnames<- +setMethod("colnames<-", + signature(x = "DataFrame", value = "character"), + function(x, value) { + sdf <- callJMethod(x@sdf, "toDF", as.list(value)) + dataFrame(sdf) + }) + +#' coltypes +#' +#' Get column types of a DataFrame +#' +#' @param x A SparkSQL DataFrame +#' @return value A character vector with the column types of the given DataFrame +#' @rdname coltypes +#' @name coltypes +#' @family DataFrame functions +#' @export +#' @examples +#'\dontrun{ +#' irisDF <- createDataFrame(sqlContext, iris) +#' coltypes(irisDF) +#'} +setMethod("coltypes", + signature(x = "DataFrame"), + function(x) { + # Get the data types of the DataFrame by invoking dtypes() function + types <- sapply(dtypes(x), function(x) {x[[2]]}) + + # Map Spark data types into R's data types using DATA_TYPES environment + rTypes <- sapply(types, USE.NAMES=F, FUN=function(x) { + # Check for primitive types + type <- PRIMITIVE_TYPES[[x]] + + if (is.null(type)) { + # Check for complex types + for (t in names(COMPLEX_TYPES)) { + if (substring(x, 1, nchar(t)) == t) { + type <- COMPLEX_TYPES[[t]] + break + } + } + + if (is.null(type)) { + stop(paste("Unsupported data type: ", x)) + } + } + type + }) + + # Find which types don't have mapping to R + naIndices <- which(is.na(rTypes)) + + # Assign the original scala data types to the unmatched ones + rTypes[naIndices] <- types[naIndices] + + rTypes + }) + +#' coltypes +#' +#' Set the column types of a DataFrame. +#' +#' @param x A SparkSQL DataFrame +#' @param value A character vector with the target column types for the given +#' DataFrame. Column types can be one of integer, numeric/double, character, logical, or NA +#' to keep that column as-is. 
+#' @rdname coltypes +#' @name coltypes<- +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlContext <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlContext, path) +#' coltypes(df) <- c("character", "integer") +#' coltypes(df) <- c(NA, "numeric") +#'} +setMethod("coltypes<-", + signature(x = "DataFrame", value = "character"), + function(x, value) { + cols <- columns(x) + ncols <- length(cols) + if (length(value) == 0) { + stop("Cannot set types of an empty DataFrame with no Column") + } + if (length(value) != ncols) { + stop("Length of type vector should match the number of columns for DataFrame") + } + newCols <- lapply(seq_len(ncols), function(i) { + col <- getColumn(x, cols[i]) + if (!is.na(value[i])) { + stype <- rToSQLTypes[[value[i]]] + if (is.null(stype)) { + stop("Only atomic type is supported for column types") + } + cast(col, stype) + } else { + col + } + }) + nx <- select(x, newCols) + dataFrame(nx@sdf) + }) + #' Register Temporary Table #' #' Registers a DataFrame as a Temporary Table in the SQLContext @@ -2102,52 +2219,3 @@ setMethod("with", newEnv <- assignNewEnv(data) eval(substitute(expr), envir = newEnv, enclos = newEnv) }) - -#' Returns the column types of a DataFrame. -#' -#' @name coltypes -#' @title Get column types of a DataFrame -#' @family dataframe_funcs -#' @param x (DataFrame) -#' @return value (character) A character vector with the column types of the given DataFrame -#' @rdname coltypes -#' @examples \dontrun{ -#' irisDF <- createDataFrame(sqlContext, iris) -#' coltypes(irisDF) -#' } -setMethod("coltypes", - signature(x = "DataFrame"), - function(x) { - # Get the data types of the DataFrame by invoking dtypes() function - types <- sapply(dtypes(x), function(x) {x[[2]]}) - - # Map Spark data types into R's data types using DATA_TYPES environment - rTypes <- sapply(types, USE.NAMES=F, FUN=function(x) { - - # Check for primitive types - type <- PRIMITIVE_TYPES[[x]] - - if (is.null(type)) { - # Check for complex types - for (t in names(COMPLEX_TYPES)) { - if (substring(x, 1, nchar(t)) == t) { - type <- COMPLEX_TYPES[[t]] - break - } - } - - if (is.null(type)) { - stop(paste("Unsupported data type: ", x)) - } - } - type - }) - - # Find which types don't have mapping to R - naIndices <- which(is.na(rTypes)) - - # Assign the original scala data types to the unmatched ones - rTypes[naIndices] <- types[naIndices] - - rTypes - }) diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 0c305441e043..711ce38f9e10 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -385,6 +385,22 @@ setGeneric("agg", function (x, ...) { standardGeneric("agg") }) #' @export setGeneric("arrange", function(x, col, ...) 
{ standardGeneric("arrange") }) +#' @rdname columns +#' @export +setGeneric("colnames", function(x, do.NULL = TRUE, prefix = "col") { standardGeneric("colnames") }) + +#' @rdname columns +#' @export +setGeneric("colnames<-", function(x, value) { standardGeneric("colnames<-") }) + +#' @rdname coltypes +#' @export +setGeneric("coltypes", function(x) { standardGeneric("coltypes") }) + +#' @rdname coltypes +#' @export +setGeneric("coltypes<-", function(x, value) { standardGeneric("coltypes<-") }) + #' @rdname schema #' @export setGeneric("columns", function(x) {standardGeneric("columns") }) @@ -1081,7 +1097,3 @@ setGeneric("attach") #' @rdname with #' @export setGeneric("with") - -#' @rdname coltypes -#' @export -setGeneric("coltypes", function(x) { standardGeneric("coltypes") }) diff --git a/R/pkg/R/types.R b/R/pkg/R/types.R index 1828c23ab0f6..dae4fe858bdb 100644 --- a/R/pkg/R/types.R +++ b/R/pkg/R/types.R @@ -41,3 +41,11 @@ COMPLEX_TYPES <- list( # The full list of data types. DATA_TYPES <- as.environment(c(as.list(PRIMITIVE_TYPES), COMPLEX_TYPES)) + +# An environment for mapping R to Scala, names are R types and values are Scala types. +rToSQLTypes <- as.environment(list( + "integer" = "integer", # in R, integer is 32bit + "numeric" = "double", # in R, numeric == double which is 64bit + "double" = "double", + "character" = "string", + "logical" = "boolean")) diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index 899fc3b97738..d3b2f20bf81c 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -622,6 +622,26 @@ test_that("schema(), dtypes(), columns(), names() return the correct values/form expect_equal(testNames[2], "name") }) +test_that("names() colnames() set the column names", { + df <- jsonFile(sqlContext, jsonPath) + names(df) <- c("col1", "col2") + expect_equal(colnames(df)[2], "col2") + + colnames(df) <- c("col3", "col4") + expect_equal(names(df)[1], "col3") + + # Test base::colnames base::names + m2 <- cbind(1, 1:4) + expect_equal(colnames(m2, do.NULL = FALSE), c("col1", "col2")) + colnames(m2) <- c("x","Y") + expect_equal(colnames(m2), c("x", "Y")) + + z <- list(a = 1, b = "c", c = 1:3) + expect_equal(names(z)[3], "c") + names(z)[3] <- "c2" + expect_equal(names(z)[3], "c2") +}) + test_that("head() and first() return the correct data", { df <- jsonFile(sqlContext, jsonPath) testHead <- head(df) @@ -1617,7 +1637,7 @@ test_that("with() on a DataFrame", { expect_equal(nrow(sum2), 35) }) -test_that("Method coltypes() to get R's data types of a DataFrame", { +test_that("Method coltypes() to get and set R's data types of a DataFrame", { expect_equal(coltypes(irisDF), c(rep("numeric", 4), "character")) data <- data.frame(c1=c(1,2,3), @@ -1636,6 +1656,24 @@ test_that("Method coltypes() to get R's data types of a DataFrame", { x <- createDataFrame(sqlContext, list(list(as.environment( list("a"="b", "c"="d", "e"="f"))))) expect_equal(coltypes(x), "map") + + df <- selectExpr(jsonFile(sqlContext, jsonPath), "name", "(age * 1.21) as age") + expect_equal(dtypes(df), list(c("name", "string"), c("age", "decimal(24,2)"))) + + df1 <- select(df, cast(df$age, "integer")) + coltypes(df) <- c("character", "integer") + expect_equal(dtypes(df), list(c("name", "string"), c("age", "int"))) + value <- collect(df[, 2])[[3, 1]] + expect_equal(value, collect(df1)[[3, 1]]) + expect_equal(value, 22) + + coltypes(df) <- c(NA, "numeric") + expect_equal(dtypes(df), list(c("name", "string"), c("age", "double"))) + + expect_error(coltypes(df) <- 
c("character"), + "Length of type vector should match the number of columns for DataFrame") + expect_error(coltypes(df) <- c("environment", "list"), + "Only atomic type is supported for column types") }) unlink(parquetPath) From cc7a1bc9370b163f51230e5ca4be612d133a5086 Mon Sep 17 00:00:00 2001 From: Sun Rui Date: Sun, 29 Nov 2015 11:08:26 -0800 Subject: [PATCH 1583/1824] [SPARK-11781][SPARKR] SparkR has problem in inferring type of raw type. Author: Sun Rui Closes #9769 from sun-rui/SPARK-11781. --- R/pkg/R/DataFrame.R | 34 ++++++++++++++++------------- R/pkg/R/SQLContext.R | 2 +- R/pkg/R/types.R | 37 ++++++++++++++++++-------------- R/pkg/inst/tests/test_sparkSQL.R | 6 ++++++ 4 files changed, 47 insertions(+), 32 deletions(-) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index f89e2682d9e2..a82ded9c51fa 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -793,8 +793,8 @@ setMethod("dim", setMethod("collect", signature(x = "DataFrame"), function(x, stringsAsFactors = FALSE) { - names <- columns(x) - ncol <- length(names) + dtypes <- dtypes(x) + ncol <- length(dtypes) if (ncol <= 0) { # empty data.frame with 0 columns and 0 rows data.frame() @@ -817,25 +817,29 @@ setMethod("collect", # data of complex type can be held. But getting a cell from a column # of list type returns a list instead of a vector. So for columns of # non-complex type, append them as vector. + # + # For columns of complex type, be careful to access them. + # Get a column of complex type returns a list. + # Get a cell from a column of complex type returns a list instead of a vector. col <- listCols[[colIndex]] + colName <- dtypes[[colIndex]][[1]] if (length(col) <= 0) { - df[[names[colIndex]]] <- col + df[[colName]] <- col } else { - # TODO: more robust check on column of primitive types - vec <- do.call(c, col) - if (class(vec) != "list") { - df[[names[colIndex]]] <- vec + colType <- dtypes[[colIndex]][[2]] + # Note that "binary" columns behave like complex types. + if (!is.null(PRIMITIVE_TYPES[[colType]]) && colType != "binary") { + vec <- do.call(c, col) + stopifnot(class(vec) != "list") + df[[colName]] <- vec } else { - # For columns of complex type, be careful to access them. - # Get a column of complex type returns a list. - # Get a cell from a column of complex type returns a list instead of a vector. - df[[names[colIndex]]] <- col - } + df[[colName]] <- col + } + } } + df } - df - } - }) + }) #' Limit #' diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R index a62b25fde926..85541c8e2244 100644 --- a/R/pkg/R/SQLContext.R +++ b/R/pkg/R/SQLContext.R @@ -63,7 +63,7 @@ infer_type <- function(x) { }) type <- Reduce(paste0, type) type <- paste0("struct<", substr(type, 1, nchar(type) - 1), ">") - } else if (length(x) > 1) { + } else if (length(x) > 1 && type != "binary") { paste0("array<", infer_type(x[[1]]), ">") } else { type diff --git a/R/pkg/R/types.R b/R/pkg/R/types.R index dae4fe858bdb..1f06af7e904f 100644 --- a/R/pkg/R/types.R +++ b/R/pkg/R/types.R @@ -19,25 +19,30 @@ # values are equivalent R types. This is stored in an environment to allow for # more efficient look up (environments use hashmaps). 
PRIMITIVE_TYPES <- as.environment(list( - "byte"="integer", - "tinyint"="integer", - "smallint"="integer", - "integer"="integer", - "bigint"="numeric", - "float"="numeric", - "double"="numeric", - "decimal"="numeric", - "string"="character", - "binary"="raw", - "boolean"="logical", - "timestamp"="POSIXct", - "date"="Date")) + "tinyint" = "integer", + "smallint" = "integer", + "int" = "integer", + "bigint" = "numeric", + "float" = "numeric", + "double" = "numeric", + "decimal" = "numeric", + "string" = "character", + "binary" = "raw", + "boolean" = "logical", + "timestamp" = "POSIXct", + "date" = "Date", + # following types are not SQL types returned by dtypes(). They are listed here for usage + # by checkType() in schema.R. + # TODO: refactor checkType() in schema.R. + "byte" = "integer", + "integer" = "integer" + )) # The complex data types. These do not have any direct mapping to R's types. COMPLEX_TYPES <- list( - "map"=NA, - "array"=NA, - "struct"=NA) + "map" = NA, + "array" = NA, + "struct" = NA) # The full list of data types. DATA_TYPES <- as.environment(c(as.list(PRIMITIVE_TYPES), COMPLEX_TYPES)) diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index d3b2f20bf81c..92ec82096c6d 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -72,6 +72,8 @@ test_that("infer types and check types", { expect_equal(infer_type(e), "map") expect_error(checkType("map"), "Key type in a map must be string or character") + + expect_equal(infer_type(as.raw(c(1, 2, 3))), "binary") }) test_that("structType and structField", { @@ -250,6 +252,10 @@ test_that("create DataFrame from list or data.frame", { mtcarsdf <- createDataFrame(sqlContext, mtcars) expect_equivalent(collect(mtcarsdf), mtcars) + + bytes <- as.raw(c(1, 2, 3)) + df <- createDataFrame(sqlContext, list(list(bytes))) + expect_equal(collect(df)[[1]][[1]], bytes) }) test_that("create DataFrame with different data types", { From 3d28081e53698ed77e93c04299957c02bcaba9bf Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Sun, 29 Nov 2015 14:13:11 -0800 Subject: [PATCH 1584/1824] [SPARK-12024][SQL] More efficient multi-column counting. In https://github.com/apache/spark/pull/9409 we enabled multi-column counting. The approach taken in that PR introduces a bit of overhead by first creating a row only to check if all of the columns are non-null. This PR fixes that technical debt. Count now takes multiple columns as its input. In order to make this work I have also added support for multiple columns in the single distinct code path. cc yhuai Author: Herman van Hovell Closes #10015 from hvanhovell/SPARK-12024. 
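For readers who want the semantics without digging through the Catalyst expressions below: the rewritten Count simply skips a row whenever any of the counted columns is null, and it checks the child expressions directly instead of first packing them into a struct row. A minimal stand-alone Scala sketch of that rule (illustrative only; the object and method names here are invented for this example and are not part of the patch):

    object MultiColumnCountSketch {
      // Mirrors the new update rule: If(children.map(IsNull).reduce(Or), count, count + 1L)
      def count(rows: Seq[Seq[Any]]): Long =
        rows.count(_.forall(_ != null)).toLong

      def main(args: Array[String]): Unit = {
        val rows = Seq(Seq("a", 1), Seq(null, 2), Seq("c", null), Seq("d", 4))
        println(count(rows)) // prints 2: a row with a null in any counted column is skipped
      }
    }
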
--- .../expressions/aggregate/Count.scala | 21 ++-------- .../expressions/conditionalExpressions.scala | 27 ------------- .../sql/catalyst/optimizer/Optimizer.scala | 14 ++++--- .../ConditionalExpressionSuite.scala | 14 ------- .../spark/sql/execution/aggregate/utils.scala | 39 +++++++++---------- .../spark/sql/expressions/WindowSpec.scala | 4 +- 6 files changed, 33 insertions(+), 86 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala index 09a1da9200df..441f52ab5ca5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala @@ -21,8 +21,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ -case class Count(child: Expression) extends DeclarativeAggregate { - override def children: Seq[Expression] = child :: Nil +case class Count(children: Seq[Expression]) extends DeclarativeAggregate { override def nullable: Boolean = false @@ -30,7 +29,7 @@ case class Count(child: Expression) extends DeclarativeAggregate { override def dataType: DataType = LongType // Expected input data type. - override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType) + override def inputTypes: Seq[AbstractDataType] = Seq.fill(children.size)(AnyDataType) private lazy val count = AttributeReference("count", LongType)() @@ -41,7 +40,7 @@ case class Count(child: Expression) extends DeclarativeAggregate { ) override lazy val updateExpressions = Seq( - /* count = */ If(IsNull(child), count, count + 1L) + /* count = */ If(children.map(IsNull).reduce(Or), count, count + 1L) ) override lazy val mergeExpressions = Seq( @@ -54,17 +53,5 @@ case class Count(child: Expression) extends DeclarativeAggregate { } object Count { - def apply(children: Seq[Expression]): Count = { - // This is used to deal with COUNT DISTINCT. When we have multiple - // children (COUNT(DISTINCT col1, col2, ...)), we wrap them in a STRUCT (i.e. a Row). - // Also, the semantic of COUNT(DISTINCT col1, col2, ...) is that if there is any - // null in the arguments, we will not count that row. So, we use DropAnyNull at here - // to return a null when any field of the created STRUCT is null. - val child = if (children.size > 1) { - DropAnyNull(CreateStruct(children)) - } else { - children.head - } - Count(child) - } + def apply(child: Expression): Count = Count(child :: Nil) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index 694a2a7c54a9..40b1eec63e55 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -426,30 +426,3 @@ case class Greatest(children: Seq[Expression]) extends Expression { } } -/** Operator that drops a row when it contains any nulls. 
*/ -case class DropAnyNull(child: Expression) extends UnaryExpression with ExpectsInputTypes { - override def nullable: Boolean = true - override def dataType: DataType = child.dataType - override def inputTypes: Seq[AbstractDataType] = Seq(StructType) - - protected override def nullSafeEval(input: Any): InternalRow = { - val row = input.asInstanceOf[InternalRow] - if (row.anyNull) { - null - } else { - row - } - } - - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - nullSafeCodeGen(ctx, ev, eval => { - s""" - if ($eval.anyNull()) { - ${ev.isNull} = true; - } else { - ${ev.value} = $eval; - } - """ - }) - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 2901d8f2efdd..06d14fcf8b9c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -362,9 +362,14 @@ object LikeSimplification extends Rule[LogicalPlan] { * Null value propagation from bottom to top of the expression tree. */ object NullPropagation extends Rule[LogicalPlan] { + def nonNullLiteral(e: Expression): Boolean = e match { + case Literal(null, _) => false + case _ => true + } + def apply(plan: LogicalPlan): LogicalPlan = plan transform { case q: LogicalPlan => q transformExpressionsUp { - case e @ AggregateExpression(Count(Literal(null, _)), _, _) => + case e @ AggregateExpression(Count(exprs), _, _) if !exprs.exists(nonNullLiteral) => Cast(Literal(0L), e.dataType) case e @ IsNull(c) if !c.nullable => Literal.create(false, BooleanType) case e @ IsNotNull(c) if !c.nullable => Literal.create(true, BooleanType) @@ -377,16 +382,13 @@ object NullPropagation extends Rule[LogicalPlan] { Literal.create(null, e.dataType) case e @ EqualNullSafe(Literal(null, _), r) => IsNull(r) case e @ EqualNullSafe(l, Literal(null, _)) => IsNull(l) - case e @ AggregateExpression(Count(expr), mode, false) if !expr.nullable => + case e @ AggregateExpression(Count(exprs), mode, false) if !exprs.exists(_.nullable) => // This rule should be only triggered when isDistinct field is false. AggregateExpression(Count(Literal(1)), mode, isDistinct = false) // For Coalesce, remove null literals. 
case e @ Coalesce(children) => - val newChildren = children.filter { - case Literal(null, _) => false - case _ => true - } + val newChildren = children.filter(nonNullLiteral) if (newChildren.length == 0) { Literal.create(null, e.dataType) } else if (newChildren.length == 1) { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala index c1e3c17b8710..0df673bb9fa0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala @@ -231,18 +231,4 @@ class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper checkConsistencyBetweenInterpretedAndCodegen(Greatest, dt, 2) } } - - test("function dropAnyNull") { - val drop = DropAnyNull(CreateStruct(Seq('a.string.at(0), 'b.string.at(1)))) - val a = create_row("a", "q") - val nullStr: String = null - checkEvaluation(drop, a, a) - checkEvaluation(drop, null, create_row("b", nullStr)) - checkEvaluation(drop, null, create_row(nullStr, nullStr)) - - val row = 'r.struct( - StructField("a", StringType, false), - StructField("b", StringType, true)).at(0) - checkEvaluation(DropAnyNull(row), null, create_row(null)) - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala index a70e41436c7a..76b938cdb694 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala @@ -146,20 +146,16 @@ object Utils { aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)) // functionsWithDistinct is guaranteed to be non-empty. Even though it may contain more than one - // DISTINCT aggregate function, all of those functions will have the same column expression. + // DISTINCT aggregate function, all of those functions will have the same column expressions. // For example, it would be valid for functionsWithDistinct to be // [COUNT(DISTINCT foo), MAX(DISTINCT foo)], but [COUNT(DISTINCT bar), COUNT(DISTINCT foo)] is // disallowed because those two distinct aggregates have different column expressions. - val distinctColumnExpression: Expression = { - val allDistinctColumnExpressions = functionsWithDistinct.head.aggregateFunction.children - assert(allDistinctColumnExpressions.length == 1) - allDistinctColumnExpressions.head - } - val namedDistinctColumnExpression: NamedExpression = distinctColumnExpression match { + val distinctColumnExpressions = functionsWithDistinct.head.aggregateFunction.children + val namedDistinctColumnExpressions = distinctColumnExpressions.map { case ne: NamedExpression => ne case other => Alias(other, other.toString)() } - val distinctColumnAttribute: Attribute = namedDistinctColumnExpression.toAttribute + val distinctColumnAttributes = namedDistinctColumnExpressions.map(_.toAttribute) val groupingAttributes = groupingExpressions.map(_.toAttribute) // 1. Create an Aggregate Operator for partial aggregations. @@ -170,10 +166,11 @@ object Utils { // We will group by the original grouping expression, plus an additional expression for the // DISTINCT column. For example, for AVG(DISTINCT value) GROUP BY key, the grouping // expressions will be [key, value]. 
- val partialAggregateGroupingExpressions = groupingExpressions :+ namedDistinctColumnExpression + val partialAggregateGroupingExpressions = + groupingExpressions ++ namedDistinctColumnExpressions val partialAggregateResult = groupingAttributes ++ - Seq(distinctColumnAttribute) ++ + distinctColumnAttributes ++ partialAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes) if (usesTungstenAggregate) { TungstenAggregate( @@ -208,28 +205,28 @@ object Utils { partialMergeAggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes) val partialMergeAggregateResult = groupingAttributes ++ - Seq(distinctColumnAttribute) ++ + distinctColumnAttributes ++ partialMergeAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes) if (usesTungstenAggregate) { TungstenAggregate( requiredChildDistributionExpressions = Some(groupingAttributes), - groupingExpressions = groupingAttributes :+ distinctColumnAttribute, + groupingExpressions = groupingAttributes ++ distinctColumnAttributes, nonCompleteAggregateExpressions = partialMergeAggregateExpressions, nonCompleteAggregateAttributes = partialMergeAggregateAttributes, completeAggregateExpressions = Nil, completeAggregateAttributes = Nil, - initialInputBufferOffset = (groupingAttributes :+ distinctColumnAttribute).length, + initialInputBufferOffset = (groupingAttributes ++ distinctColumnAttributes).length, resultExpressions = partialMergeAggregateResult, child = partialAggregate) } else { SortBasedAggregate( requiredChildDistributionExpressions = Some(groupingAttributes), - groupingExpressions = groupingAttributes :+ distinctColumnAttribute, + groupingExpressions = groupingAttributes ++ distinctColumnAttributes, nonCompleteAggregateExpressions = partialMergeAggregateExpressions, nonCompleteAggregateAttributes = partialMergeAggregateAttributes, completeAggregateExpressions = Nil, completeAggregateAttributes = Nil, - initialInputBufferOffset = (groupingAttributes :+ distinctColumnAttribute).length, + initialInputBufferOffset = (groupingAttributes ++ distinctColumnAttributes).length, resultExpressions = partialMergeAggregateResult, child = partialAggregate) } @@ -244,14 +241,16 @@ object Utils { expr => aggregateFunctionToAttribute(expr.aggregateFunction, expr.isDistinct) } + val distinctColumnAttributeLookup = + distinctColumnExpressions.zip(distinctColumnAttributes).toMap val (completeAggregateExpressions, completeAggregateAttributes) = functionsWithDistinct.map { // Children of an AggregateFunction with DISTINCT keyword has already // been evaluated. At here, we need to replace original children // to AttributeReferences. case agg @ AggregateExpression(aggregateFunction, mode, true) => - val rewrittenAggregateFunction = aggregateFunction.transformDown { - case expr if expr == distinctColumnExpression => distinctColumnAttribute - }.asInstanceOf[AggregateFunction] + val rewrittenAggregateFunction = aggregateFunction + .transformDown(distinctColumnAttributeLookup) + .asInstanceOf[AggregateFunction] // We rewrite the aggregate function to a non-distinct aggregation because // its input will have distinct arguments. 
// We just keep the isDistinct setting to true, so when users look at the query plan, @@ -270,7 +269,7 @@ object Utils { nonCompleteAggregateAttributes = finalAggregateAttributes, completeAggregateExpressions = completeAggregateExpressions, completeAggregateAttributes = completeAggregateAttributes, - initialInputBufferOffset = (groupingAttributes :+ distinctColumnAttribute).length, + initialInputBufferOffset = (groupingAttributes ++ distinctColumnAttributes).length, resultExpressions = resultExpressions, child = partialMergeAggregate) } else { @@ -281,7 +280,7 @@ object Utils { nonCompleteAggregateAttributes = finalAggregateAttributes, completeAggregateExpressions = completeAggregateExpressions, completeAggregateAttributes = completeAggregateAttributes, - initialInputBufferOffset = (groupingAttributes :+ distinctColumnAttribute).length, + initialInputBufferOffset = (groupingAttributes ++ distinctColumnAttributes).length, resultExpressions = resultExpressions, child = partialMergeAggregate) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala index fc873c04f88f..893e800a6143 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala @@ -152,8 +152,8 @@ class WindowSpec private[sql]( case Sum(child) => WindowExpression( UnresolvedWindowFunction("sum", child :: Nil), WindowSpecDefinition(partitionSpec, orderSpec, frame)) - case Count(child) => WindowExpression( - UnresolvedWindowFunction("count", child :: Nil), + case Count(children) => WindowExpression( + UnresolvedWindowFunction("count", children), WindowSpecDefinition(partitionSpec, orderSpec, frame)) case First(child, ignoreNulls) => WindowExpression( // TODO this is a hack for Hive UDAF first_value From 0ddfe7868948e302858a2b03b50762eaefbeb53e Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Sun, 29 Nov 2015 19:02:15 -0800 Subject: [PATCH 1585/1824] [SPARK-12039] [SQL] Ignore HiveSparkSubmitSuite's "SPARK-9757 Persist Parquet relation with decimal column". https://issues.apache.org/jira/browse/SPARK-12039 Since it is pretty flaky in hadoop 1 tests, we can disable it while we are investigating the cause. Author: Yin Huai Closes #10035 from yhuai/SPARK-12039-ignore. 
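The change itself is mechanical: ScalaTest lets a suite swap `test` for `ignore` with the same signature, so the body still compiles but is reported as ignored rather than run. A small hedged sketch of that pattern (the suite and test names below are made up and are not taken from HiveSparkSubmitSuite):

    import org.scalatest.FunSuite

    class IgnoreSketchSuite extends FunSuite {
      // Temporarily disabled: still type-checked, never executed.
      ignore("flaky case under investigation") {
        assert(1 + 1 === 2)
      }

      test("still runs") {
        assert("spark".length === 5)
      }
    }
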
--- .../scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala index 24a3afee148c..92962193311d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala @@ -101,7 +101,7 @@ class HiveSparkSubmitSuite runSparkSubmit(args) } - test("SPARK-9757 Persist Parquet relation with decimal column") { + ignore("SPARK-9757 Persist Parquet relation with decimal column") { val unusedJar = TestUtils.createJarWithClasses(Seq.empty) val args = Seq( "--class", SPARK_9757.getClass.getName.stripSuffix("$"), From e0749442051d6e29dae4f4cdcb2937c0b015f98f Mon Sep 17 00:00:00 2001 From: toddwan Date: Mon, 30 Nov 2015 09:26:29 +0000 Subject: [PATCH 1586/1824] [SPARK-11859][MESOS] SparkContext accepts invalid Master URLs in the form zk://host:port for a multi-master Mesos cluster using ZooKeeper * According to below doc and validation logic in [SparkSubmit.scala](https://github.com/apache/spark/blob/master/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala#L231), master URL for a mesos cluster should always start with `mesos://` http://spark.apache.org/docs/latest/running-on-mesos.html `The Master URLs for Mesos are in the form mesos://host:5050 for a single-master Mesos cluster, or mesos://zk://host:2181 for a multi-master Mesos cluster using ZooKeeper.` * However, [SparkContext.scala](https://github.com/apache/spark/blob/master/core/src/main/scala/org/apache/spark/SparkContext.scala#L2749) fails the validation and can receive master URL in the form `zk://host:port` * For the master URLs in the form `zk:host:port`, the valid form should be `mesos://zk://host:port` * This PR restrict the validation in `SparkContext.scala`, and now only mesos master URLs prefixed with `mesos://` can be accepted. * This PR also updated corresponding unit test. Author: toddwan Closes #9886 from toddwan/S11859. 
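In short, after this change only URLs of the form mesos://... are parsed as Mesos masters, while a bare zk://... URL is still accepted for now but rewritten to mesos://zk://... with a deprecation warning. A self-contained Scala sketch of that dispatch (names such as MesosUrlSketch and resolve are illustrative, not the actual SparkContext internals):

    object MesosUrlSketch {
      private val MesosRegex = """mesos://(.*)""".r

      def resolve(master: String): String = master match {
        case MesosRegex(url) => url // e.g. "host:5050" or "zk://host:2181"
        case zk if zk.startsWith("zk://") =>
          println(s"WARN: use mesos://$zk; bare zk:// master URLs are deprecated")
          resolve("mesos://" + zk)
        case other => sys.error(s"Could not parse Master URL: '$other'")
      }

      def main(args: Array[String]): Unit = {
        println(resolve("mesos://zk://host:2181,host:2345")) // zk://host:2181,host:2345
        println(resolve("zk://host:2181,host:2345"))         // warns, then same result
      }
    }
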
--- .../scala/org/apache/spark/SparkContext.scala | 16 ++++++++++------ .../SparkContextSchedulerCreationSuite.scala | 5 +++++ 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index b030d3c71dc2..8a62b71c3fa6 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -2708,15 +2708,14 @@ object SparkContext extends Logging { scheduler.initialize(backend) (backend, scheduler) - case mesosUrl @ MESOS_REGEX(_) => + case MESOS_REGEX(mesosUrl) => MesosNativeLibrary.load() val scheduler = new TaskSchedulerImpl(sc) val coarseGrained = sc.conf.getBoolean("spark.mesos.coarse", defaultValue = true) - val url = mesosUrl.stripPrefix("mesos://") // strip scheme from raw Mesos URLs val backend = if (coarseGrained) { - new CoarseMesosSchedulerBackend(scheduler, sc, url, sc.env.securityManager) + new CoarseMesosSchedulerBackend(scheduler, sc, mesosUrl, sc.env.securityManager) } else { - new MesosSchedulerBackend(scheduler, sc, url) + new MesosSchedulerBackend(scheduler, sc, mesosUrl) } scheduler.initialize(backend) (backend, scheduler) @@ -2727,6 +2726,11 @@ object SparkContext extends Logging { scheduler.initialize(backend) (backend, scheduler) + case zkUrl if zkUrl.startsWith("zk://") => + logWarning("Master URL for a multi-master Mesos cluster managed by ZooKeeper should be " + + "in the form mesos://zk://host:port. Current Master URL will stop working in Spark 2.0.") + createTaskScheduler(sc, "mesos://" + zkUrl) + case _ => throw new SparkException("Could not parse Master URL: '" + master + "'") } @@ -2745,8 +2749,8 @@ private object SparkMasterRegex { val LOCAL_CLUSTER_REGEX = """local-cluster\[\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*]""".r // Regular expression for connecting to Spark deploy clusters val SPARK_REGEX = """spark://(.*)""".r - // Regular expression for connection to Mesos cluster by mesos:// or zk:// url - val MESOS_REGEX = """(mesos|zk)://.*""".r + // Regular expression for connection to Mesos cluster by mesos:// or mesos://zk:// url + val MESOS_REGEX = """mesos://(.*)""".r // Regular expression for connection to Simr cluster val SIMR_REGEX = """simr://(.*)""".r } diff --git a/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala index e5a14a69ef05..d18e0782c039 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala @@ -175,6 +175,11 @@ class SparkContextSchedulerCreationSuite } test("mesos with zookeeper") { + testMesos("mesos://zk://localhost:1234,localhost:2345", + classOf[MesosSchedulerBackend], coarse = false) + } + + test("mesos with zookeeper and Master URL starting with zk://") { testMesos("zk://localhost:1234,localhost:2345", classOf[MesosSchedulerBackend], coarse = false) } } From 953e8e6dcb32cd88005834e9c3720740e201826c Mon Sep 17 00:00:00 2001 From: Prashant Sharma Date: Mon, 30 Nov 2015 09:30:58 +0000 Subject: [PATCH 1587/1824] [MINOR][BUILD] Changed the comment to reflect the plugin project is there to support SBT pom reader only. Author: Prashant Sharma Closes #10012 from ScrapCodes/minor-build-comment. 
--- project/project/SparkPluginBuild.scala | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/project/project/SparkPluginBuild.scala b/project/project/SparkPluginBuild.scala index 471d00bd8223..cbb88dc7dd1d 100644 --- a/project/project/SparkPluginBuild.scala +++ b/project/project/SparkPluginBuild.scala @@ -19,9 +19,8 @@ import sbt._ import sbt.Keys._ /** - * This plugin project is there to define new scala style rules for spark. This is - * a plugin project so that this gets compiled first and is put on the classpath and - * becomes available for scalastyle sbt plugin. + * This plugin project is there because we use our custom fork of sbt-pom-reader plugin. This is + * a plugin project so that this gets compiled first and is available on the classpath for SBT build. */ object SparkPluginDef extends Build { lazy val root = Project("plugins", file(".")) dependsOn(sbtPomReader) From 26c3581f17f475fab2f3b5301b8f253ff2fa6438 Mon Sep 17 00:00:00 2001 From: Wieland Hoffmann Date: Mon, 30 Nov 2015 09:32:48 +0000 Subject: [PATCH 1588/1824] [DOC] Explicitly state that top maintains the order of elements Top is implemented in terms of takeOrdered, which already maintains the order, so top should, too. Author: Wieland Hoffmann Closes #10013 from mineo/top-order. --- .../main/scala/org/apache/spark/api/java/JavaRDDLike.scala | 4 ++-- core/src/main/scala/org/apache/spark/rdd/RDD.scala | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala index 871be0b1f39e..1e9d4f1803a8 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala @@ -556,7 +556,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { /** * Returns the top k (largest) elements from this RDD as defined by - * the specified Comparator[T]. + * the specified Comparator[T] and maintains the order. * @param num k, the number of top elements to return * @param comp the comparator that defines the order * @return an array of top elements @@ -567,7 +567,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { /** * Returns the top k (largest) elements from this RDD using the - * natural ordering for T. + * natural ordering for T and maintains the order. * @param num k, the number of top elements to return * @return an array of top elements */ diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 2aeb5eeaad32..8b3731d93578 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -1327,7 +1327,8 @@ abstract class RDD[T: ClassTag]( /** * Returns the top k (largest) elements from this RDD as defined by the specified - * implicit Ordering[T]. This does the opposite of [[takeOrdered]]. For example: + * implicit Ordering[T] and maintains the ordering. This does the opposite of + * [[takeOrdered]]. For example: * {{{ * sc.parallelize(Seq(10, 4, 2, 12, 3)).top(1) * // returns Array(12) From bf0e85a70a54a2d7fd6804b6bd00c63c20e2bb00 Mon Sep 17 00:00:00 2001 From: Prashant Sharma Date: Mon, 30 Nov 2015 10:11:27 +0000 Subject: [PATCH 1589/1824] [SPARK-12023][BUILD] Fix warnings while packaging spark with maven. 
this is a trivial fix, discussed [here](http://stackoverflow.com/questions/28500401/maven-assembly-plugin-warning-the-assembly-descriptor-contains-a-filesystem-roo/). Author: Prashant Sharma Closes #10014 from ScrapCodes/assembly-warning. --- assembly/src/main/assembly/assembly.xml | 8 ++++---- external/mqtt/src/main/assembly/assembly.xml | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/assembly/src/main/assembly/assembly.xml b/assembly/src/main/assembly/assembly.xml index 711156337b7c..009d4b92f406 100644 --- a/assembly/src/main/assembly/assembly.xml +++ b/assembly/src/main/assembly/assembly.xml @@ -32,7 +32,7 @@ ${project.parent.basedir}/core/src/main/resources/org/apache/spark/ui/static/ - /ui-resources/org/apache/spark/ui/static + ui-resources/org/apache/spark/ui/static **/* @@ -41,7 +41,7 @@ ${project.parent.basedir}/sbin/ - /sbin + sbin **/* @@ -50,7 +50,7 @@ ${project.parent.basedir}/bin/ - /bin + bin **/* @@ -59,7 +59,7 @@ ${project.parent.basedir}/assembly/target/${spark.jar.dir} - / + ${spark.jar.basename} diff --git a/external/mqtt/src/main/assembly/assembly.xml b/external/mqtt/src/main/assembly/assembly.xml index ecab5b360eb3..c110b01b34e1 100644 --- a/external/mqtt/src/main/assembly/assembly.xml +++ b/external/mqtt/src/main/assembly/assembly.xml @@ -24,7 +24,7 @@ ${project.build.directory}/scala-${scala.binary.version}/test-classes - / + From 2db4662fe2d72749c06ad5f11f641a388343f77c Mon Sep 17 00:00:00 2001 From: CK50 Date: Mon, 30 Nov 2015 20:08:49 +0800 Subject: [PATCH 1590/1824] [SPARK-11989][SQL] Only use commit in JDBC data source if the underlying database supports transactions Fixes [SPARK-11989](https://issues.apache.org/jira/browse/SPARK-11989) Author: CK50 Author: Christian Kurz Closes #9973 from CK50/branch-1.6_non-transactional. (cherry picked from commit a589736a1b237ef2f3bd59fbaeefe143ddcc8f4e) Signed-off-by: Reynold Xin --- .../datasources/jdbc/JdbcUtils.scala | 22 ++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index 7375a5c09123..252f1cfd5d9c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -21,6 +21,7 @@ import java.sql.{Connection, PreparedStatement} import java.util.Properties import scala.util.Try +import scala.util.control.NonFatal import org.apache.spark.Logging import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcType, JdbcDialects} @@ -125,8 +126,19 @@ object JdbcUtils extends Logging { dialect: JdbcDialect): Iterator[Byte] = { val conn = getConnection() var committed = false + val supportsTransactions = try { + conn.getMetaData().supportsDataManipulationTransactionsOnly() || + conn.getMetaData().supportsDataDefinitionAndDataManipulationTransactions() + } catch { + case NonFatal(e) => + logWarning("Exception while detecting transaction support", e) + true + } + try { - conn.setAutoCommit(false) // Everything in the same db transaction. + if (supportsTransactions) { + conn.setAutoCommit(false) // Everything in the same db transaction. 
+ } val stmt = insertStatement(conn, table, rddSchema) try { var rowCount = 0 @@ -175,14 +187,18 @@ object JdbcUtils extends Logging { } finally { stmt.close() } - conn.commit() + if (supportsTransactions) { + conn.commit() + } committed = true } finally { if (!committed) { // The stage must fail. We got here through an exception path, so // let the exception through unless rollback() or close() want to // tell the user about another problem. - conn.rollback() + if (supportsTransactions) { + conn.rollback() + } conn.close() } else { // The stage must succeed. We cannot propagate any exception close() might throw. From 17275fa99c670537c52843df405279a52b5c9594 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 30 Nov 2015 10:32:13 -0800 Subject: [PATCH 1591/1824] [SPARK-11700] [SQL] Remove thread local SQLContext in SparkPlan In 1.6, we introduce a public API to have a SQLContext for current thread, SparkPlan should use that. Author: Davies Liu Closes #9990 from davies/leak_context. --- .../scala/org/apache/spark/sql/SQLContext.scala | 10 +++++----- .../spark/sql/execution/QueryExecution.scala | 3 +-- .../org/apache/spark/sql/execution/SparkPlan.scala | 14 ++++---------- .../apache/spark/sql/MultiSQLContextsSuite.scala | 2 +- .../sql/execution/ExchangeCoordinatorSuite.scala | 2 +- .../sql/execution/RowFormatConvertersSuite.scala | 4 ++-- 6 files changed, 14 insertions(+), 21 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 1c2ac5f6f11b..8d2783952532 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -26,7 +26,6 @@ import scala.collection.immutable import scala.reflect.runtime.universe.TypeTag import scala.util.control.NonFatal -import org.apache.spark.{SparkException, SparkContext} import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} import org.apache.spark.rdd.RDD @@ -45,9 +44,10 @@ import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.ui.{SQLListener, SQLTab} import org.apache.spark.sql.sources.BaseRelation import org.apache.spark.sql.types._ -import org.apache.spark.sql.{execution => sparkexecution} import org.apache.spark.sql.util.ExecutionListenerManager +import org.apache.spark.sql.{execution => sparkexecution} import org.apache.spark.util.Utils +import org.apache.spark.{SparkContext, SparkException} /** * The entry point for working with structured data (rows and columns) in Spark. 
Allows the @@ -401,7 +401,7 @@ class SQLContext private[sql]( */ @Experimental def createDataFrame[A <: Product : TypeTag](rdd: RDD[A]): DataFrame = { - SparkPlan.currentContext.set(self) + SQLContext.setActive(self) val schema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType] val attributeSeq = schema.toAttributes val rowRDD = RDDConversions.productToRowRdd(rdd, schema.map(_.dataType)) @@ -417,7 +417,7 @@ class SQLContext private[sql]( */ @Experimental def createDataFrame[A <: Product : TypeTag](data: Seq[A]): DataFrame = { - SparkPlan.currentContext.set(self) + SQLContext.setActive(self) val schema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType] val attributeSeq = schema.toAttributes DataFrame(self, LocalRelation.fromProduct(attributeSeq, data)) @@ -1334,7 +1334,7 @@ object SQLContext { activeContext.remove() } - private[sql] def getActiveContextOption(): Option[SQLContext] = { + private[sql] def getActive(): Option[SQLContext] = { Option(activeContext.get()) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index 5da5aea17e25..107570f9dbcc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -42,9 +42,8 @@ class QueryExecution(val sqlContext: SQLContext, val logical: LogicalPlan) { lazy val optimizedPlan: LogicalPlan = sqlContext.optimizer.execute(withCachedData) - // TODO: Don't just pick the first one... lazy val sparkPlan: SparkPlan = { - SparkPlan.currentContext.set(sqlContext) + SQLContext.setActive(sqlContext) sqlContext.planner.plan(optimizedPlan).next() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 534a3bcb8364..507641ff8263 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -23,21 +23,15 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.Logging import org.apache.spark.rdd.{RDD, RDDOperationScope} -import org.apache.spark.sql.SQLContext -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.CatalystTypeConverters +import org.apache.spark.sql.{Row, SQLContext} +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.metric.{LongSQLMetric, SQLMetric} import org.apache.spark.sql.types.DataType -object SparkPlan { - protected[sql] val currentContext = new ThreadLocal[SQLContext]() -} - /** * The base class for physical operators. */ @@ -49,7 +43,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ * populated by the query planning infrastructure. 
*/ @transient - protected[spark] final val sqlContext = SparkPlan.currentContext.get() + protected[spark] final val sqlContext = SQLContext.getActive().get protected def sparkContext = sqlContext.sparkContext @@ -69,7 +63,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ /** Overridden make copy also propogates sqlContext to copied plan. */ override def makeCopy(newArgs: Array[AnyRef]): SparkPlan = { - SparkPlan.currentContext.set(sqlContext) + SQLContext.setActive(sqlContext) super.makeCopy(newArgs) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MultiSQLContextsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MultiSQLContextsSuite.scala index 34c5c68fd1c1..162c0b56c6e1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MultiSQLContextsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MultiSQLContextsSuite.scala @@ -27,7 +27,7 @@ class MultiSQLContextsSuite extends SparkFunSuite with BeforeAndAfterAll { private var sparkConf: SparkConf = _ override protected def beforeAll(): Unit = { - originalActiveSQLContext = SQLContext.getActiveContextOption() + originalActiveSQLContext = SQLContext.getActive() originalInstantiatedSQLContext = SQLContext.getInstantiatedContextOption() SQLContext.clearActive() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala index b96d50a70b85..180050bdac00 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala @@ -30,7 +30,7 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { private var originalInstantiatedSQLContext: Option[SQLContext] = _ override protected def beforeAll(): Unit = { - originalActiveSQLContext = SQLContext.getActiveContextOption() + originalActiveSQLContext = SQLContext.getActive() originalInstantiatedSQLContext = SQLContext.getInstantiatedContextOption() SQLContext.clearActive() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala index 6876ab0f02b1..13d68a103a22 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution import org.apache.spark.rdd.RDD -import org.apache.spark.sql.Row +import org.apache.spark.sql.{SQLContext, Row} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Attribute, Literal, IsNull} import org.apache.spark.sql.catalyst.util.GenericArrayData @@ -94,7 +94,7 @@ class RowFormatConvertersSuite extends SparkPlanTest with SharedSQLContext { } test("SPARK-9683: copy UTF8String when convert unsafe array/map to safe") { - SparkPlan.currentContext.set(sqlContext) + SQLContext.setActive(sqlContext) val schema = ArrayType(StringType) val rows = (1 to 100).map { i => InternalRow(new GenericArrayData(Array[Any](UTF8String.fromString(i.toString)))) From 8df584b0200402d8b2ce0a8de24f7a760ced8655 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 30 Nov 2015 11:54:18 -0800 Subject: [PATCH 1592/1824] [SPARK-11982] [SQL] improve performance of cartesian product This PR improve the 
performance of CartesianProduct by caching the result of right plan. After this patch, the query time of TPC-DS Q65 go down to 4 seconds from 28 minutes (420X faster). cc nongli Author: Davies Liu Closes #9969 from davies/improve_cartesian. --- .../unsafe/sort/UnsafeExternalSorter.java | 63 +++++++++++++++ .../unsafe/sort/UnsafeInMemorySorter.java | 7 ++ .../execution/joins/CartesianProduct.scala | 76 +++++++++++++++++-- .../execution/metric/SQLMetricsSuite.scala | 2 +- 4 files changed, 139 insertions(+), 9 deletions(-) diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java index 2e4031267473..5a97f4f11340 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -21,6 +21,7 @@ import java.io.File; import java.io.IOException; import java.util.LinkedList; +import java.util.Queue; import com.google.common.annotations.VisibleForTesting; import org.slf4j.Logger; @@ -521,4 +522,66 @@ public long getKeyPrefix() { return upstream.getKeyPrefix(); } } + + /** + * Returns a iterator, which will return the rows in the order as inserted. + * + * It is the caller's responsibility to call `cleanupResources()` + * after consuming this iterator. + */ + public UnsafeSorterIterator getIterator() throws IOException { + if (spillWriters.isEmpty()) { + assert(inMemSorter != null); + return inMemSorter.getIterator(); + } else { + LinkedList queue = new LinkedList<>(); + for (UnsafeSorterSpillWriter spillWriter : spillWriters) { + queue.add(spillWriter.getReader(blockManager)); + } + if (inMemSorter != null) { + queue.add(inMemSorter.getIterator()); + } + return new ChainedIterator(queue); + } + } + + /** + * Chain multiple UnsafeSorterIterator together as single one. + */ + class ChainedIterator extends UnsafeSorterIterator { + + private final Queue iterators; + private UnsafeSorterIterator current; + + public ChainedIterator(Queue iterators) { + assert iterators.size() > 0; + this.iterators = iterators; + this.current = iterators.remove(); + } + + @Override + public boolean hasNext() { + while (!current.hasNext() && !iterators.isEmpty()) { + current = iterators.remove(); + } + return current.hasNext(); + } + + @Override + public void loadNext() throws IOException { + current.loadNext(); + } + + @Override + public Object getBaseObject() { return current.getBaseObject(); } + + @Override + public long getBaseOffset() { return current.getBaseOffset(); } + + @Override + public int getRecordLength() { return current.getRecordLength(); } + + @Override + public long getKeyPrefix() { return current.getKeyPrefix(); } + } } diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java index dce1f15a2963..c91e88f31bf9 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java @@ -226,4 +226,11 @@ public SortedIterator getSortedIterator() { sorter.sort(array, 0, pos / 2, sortComparator); return new SortedIterator(pos / 2); } + + /** + * Returns an iterator over record pointers in original order (inserted). 
+ */ + public SortedIterator getIterator() { + return new SortedIterator(pos / 2); + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala index f467519b802a..fa2bc7672131 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala @@ -17,16 +17,75 @@ package org.apache.spark.sql.execution.joins -import org.apache.spark.rdd.RDD +import org.apache.spark._ +import org.apache.spark.rdd.{CartesianPartition, CartesianRDD, RDD} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Attribute, JoinedRow} -import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner +import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow} import org.apache.spark.sql.execution.metric.SQLMetrics +import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} +import org.apache.spark.util.CompletionIterator +import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter + + +/** + * An optimized CartesianRDD for UnsafeRow, which will cache the rows from second child RDD, + * will be much faster than building the right partition for every row in left RDD, it also + * materialize the right RDD (in case of the right RDD is nondeterministic). + */ +private[spark] +class UnsafeCartesianRDD(left : RDD[UnsafeRow], right : RDD[UnsafeRow], numFieldsOfRight: Int) + extends CartesianRDD[UnsafeRow, UnsafeRow](left.sparkContext, left, right) { + + override def compute(split: Partition, context: TaskContext): Iterator[(UnsafeRow, UnsafeRow)] = { + // We will not sort the rows, so prefixComparator and recordComparator are null. 
+ val sorter = UnsafeExternalSorter.create( + context.taskMemoryManager(), + SparkEnv.get.blockManager, + context, + null, + null, + 1024, + SparkEnv.get.memoryManager.pageSizeBytes) + + val partition = split.asInstanceOf[CartesianPartition] + for (y <- rdd2.iterator(partition.s2, context)) { + sorter.insertRecord(y.getBaseObject, y.getBaseOffset, y.getSizeInBytes, 0) + } + + // Create an iterator from sorter and wrapper it as Iterator[UnsafeRow] + def createIter(): Iterator[UnsafeRow] = { + val iter = sorter.getIterator + val unsafeRow = new UnsafeRow + new Iterator[UnsafeRow] { + override def hasNext: Boolean = { + iter.hasNext + } + override def next(): UnsafeRow = { + iter.loadNext() + unsafeRow.pointTo(iter.getBaseObject, iter.getBaseOffset, numFieldsOfRight, + iter.getRecordLength) + unsafeRow + } + } + } + + val resultIter = + for (x <- rdd1.iterator(partition.s1, context); + y <- createIter()) yield (x, y) + CompletionIterator[(UnsafeRow, UnsafeRow), Iterator[(UnsafeRow, UnsafeRow)]]( + resultIter, sorter.cleanupResources) + } +} case class CartesianProduct(left: SparkPlan, right: SparkPlan) extends BinaryNode { override def output: Seq[Attribute] = left.output ++ right.output + override def canProcessSafeRows: Boolean = false + override def canProcessUnsafeRows: Boolean = true + override def outputsUnsafeRows: Boolean = true + override private[sql] lazy val metrics = Map( "numLeftRows" -> SQLMetrics.createLongMetric(sparkContext, "number of left rows"), "numRightRows" -> SQLMetrics.createLongMetric(sparkContext, "number of right rows"), @@ -39,18 +98,19 @@ case class CartesianProduct(left: SparkPlan, right: SparkPlan) extends BinaryNod val leftResults = left.execute().map { row => numLeftRows += 1 - row.copy() + row.asInstanceOf[UnsafeRow] } val rightResults = right.execute().map { row => numRightRows += 1 - row.copy() + row.asInstanceOf[UnsafeRow] } - leftResults.cartesian(rightResults).mapPartitionsInternal { iter => - val joinedRow = new JoinedRow + val pair = new UnsafeCartesianRDD(leftResults, rightResults, right.output.size) + pair.mapPartitionsInternal { iter => + val joiner = GenerateUnsafeRowJoiner.create(left.schema, right.schema) iter.map { r => numOutputRows += 1 - joinedRow(r._1, r._2) + joiner.join(r._1, r._2) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index ebfa1eaf3e5b..4f2cad19bfb6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -317,7 +317,7 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { testSparkPlanMetrics(df, 1, Map( 1L -> ("CartesianProduct", Map( "number of left rows" -> 12L, // left needs to be scanned twice - "number of right rows" -> 12L, // right is read 6 times + "number of right rows" -> 4L, // right is read twice "number of output rows" -> 12L))) ) } From f2fbfa444f6e8d27953ec2d1c0b3abd603c963f9 Mon Sep 17 00:00:00 2001 From: BenFradet Date: Mon, 30 Nov 2015 13:02:08 -0800 Subject: [PATCH 1593/1824] [MINOR][DOCS] fixed list display in ml-ensembles The list in ml-ensembles.md wasn't properly formatted and, as a result, was looking like this: ![old](http://i.imgur.com/2ZhELLR.png) This PR aims to make it look like this: ![new](http://i.imgur.com/0Xriwd2.png) Author: BenFradet Closes #10025 from BenFradet/ml-ensembles-doc. 
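The CartesianProduct change above (SPARK-11982) boils down to materializing the right-hand partition once per task and replaying it for every left row, instead of rebuilding the right-hand iterator for every left row. A minimal sketch of that idea in plain Scala, assuming nothing beyond the standard library: the names naiveCartesian and cachedCartesian are illustrative only and not part of the patch, and the real code spills the buffered rows through UnsafeExternalSorter rather than keeping them in an in-memory buffer.

    // Rebuilds the right-hand side for every left element.
    def naiveCartesian[A, B](left: Iterator[A], right: () => Iterator[B]): Iterator[(A, B)] =
      left.flatMap(a => right().map(b => (a, b)))

    // Buffers the right-hand side once, then replays the cached copy for every left element.
    def cachedCartesian[A, B](left: Iterator[A], right: () => Iterator[B]): Iterator[(A, B)] = {
      val buffered = right().toVector
      left.flatMap(a => buffered.iterator.map(b => (a, b)))
    }

The trade-off is memory (or spill) for the buffered copy versus recomputing a potentially expensive right-hand plan once per left row, which is what drives the large speedup the commit message reports for TPC-DS Q65.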
--- docs/ml-ensembles.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/ml-ensembles.md b/docs/ml-ensembles.md index f6c3c30d5334..14fef76f260f 100644 --- a/docs/ml-ensembles.md +++ b/docs/ml-ensembles.md @@ -20,6 +20,7 @@ Both use [MLlib decision trees](ml-decision-tree.html) as their base models. Users can find more information about ensemble algorithms in the [MLlib Ensemble guide](mllib-ensembles.html). In this section, we demonstrate the Pipelines API for ensembles. The main differences between this API and the [original MLlib ensembles API](mllib-ensembles.html) are: + * support for ML Pipelines * separation of classification vs. regression * use of DataFrame metadata to distinguish continuous and categorical features From 2c5dee0fb8e4d1734ea3a0f22e0b5bfd2f6dba46 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 30 Nov 2015 13:41:52 -0800 Subject: [PATCH 1594/1824] Revert "[SPARK-11206] Support SQL UI on the history server" This reverts commit cc243a079b1c039d6e7f0b410d1654d94a090e14 / PR #9297 I'm reverting this because it broke SQLListenerMemoryLeakSuite in the master Maven builds. See #9991 for a discussion of why this broke the tests. --- .rat-excludes | 1 - .../org/apache/spark/JavaSparkListener.java | 3 - .../apache/spark/SparkFirehoseListener.java | 4 - .../scheduler/EventLoggingListener.scala | 4 - .../spark/scheduler/SparkListener.scala | 24 +-- .../spark/scheduler/SparkListenerBus.scala | 1 - .../scala/org/apache/spark/ui/SparkUI.scala | 16 +- .../org/apache/spark/util/JsonProtocol.scala | 11 +- ...park.scheduler.SparkHistoryListenerFactory | 1 - .../org/apache/spark/sql/SQLContext.scala | 18 +-- .../spark/sql/execution/SQLExecution.scala | 24 ++- .../spark/sql/execution/SparkPlanInfo.scala | 46 ------ .../sql/execution/metric/SQLMetricInfo.scala | 30 ---- .../sql/execution/metric/SQLMetrics.scala | 56 +++---- .../sql/execution/ui/ExecutionPage.scala | 4 +- .../spark/sql/execution/ui/SQLListener.scala | 139 ++++++------------ .../spark/sql/execution/ui/SQLTab.scala | 12 +- .../sql/execution/ui/SparkPlanGraph.scala | 20 +-- .../execution/metric/SQLMetricsSuite.scala | 4 +- .../sql/execution/ui/SQLListenerSuite.scala | 43 +++--- .../spark/sql/test/SharedSQLContext.scala | 1 - 21 files changed, 135 insertions(+), 327 deletions(-) delete mode 100644 sql/core/src/main/resources/META-INF/services/org.apache.spark.scheduler.SparkHistoryListenerFactory delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetricInfo.scala diff --git a/.rat-excludes b/.rat-excludes index 7262c960ed6b..08fba6d351d6 100644 --- a/.rat-excludes +++ b/.rat-excludes @@ -82,5 +82,4 @@ INDEX gen-java.* .*avpr org.apache.spark.sql.sources.DataSourceRegister -org.apache.spark.scheduler.SparkHistoryListenerFactory .*parquet diff --git a/core/src/main/java/org/apache/spark/JavaSparkListener.java b/core/src/main/java/org/apache/spark/JavaSparkListener.java index 23bc9a2e8172..fa9acf0a15b8 100644 --- a/core/src/main/java/org/apache/spark/JavaSparkListener.java +++ b/core/src/main/java/org/apache/spark/JavaSparkListener.java @@ -82,7 +82,4 @@ public void onExecutorRemoved(SparkListenerExecutorRemoved executorRemoved) { } @Override public void onBlockUpdated(SparkListenerBlockUpdated blockUpdated) { } - @Override - public void onOtherEvent(SparkListenerEvent event) { } - } diff --git a/core/src/main/java/org/apache/spark/SparkFirehoseListener.java 
b/core/src/main/java/org/apache/spark/SparkFirehoseListener.java index e6b24afd88ad..1214d05ba606 100644 --- a/core/src/main/java/org/apache/spark/SparkFirehoseListener.java +++ b/core/src/main/java/org/apache/spark/SparkFirehoseListener.java @@ -118,8 +118,4 @@ public void onBlockUpdated(SparkListenerBlockUpdated blockUpdated) { onEvent(blockUpdated); } - @Override - public void onOtherEvent(SparkListenerEvent event) { - onEvent(event); - } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala index eaa07acc5132..000a021a528c 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala @@ -207,10 +207,6 @@ private[spark] class EventLoggingListener( // No-op because logging every update would be overkill override def onExecutorMetricsUpdate(event: SparkListenerExecutorMetricsUpdate): Unit = { } - override def onOtherEvent(event: SparkListenerEvent): Unit = { - logEvent(event, flushLogger = true) - } - /** * Stop logging events. The event log file will be renamed so that it loses the * ".inprogress" suffix. diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala index 075a7f13172d..896f1743332f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala @@ -22,19 +22,15 @@ import java.util.Properties import scala.collection.Map import scala.collection.mutable -import com.fasterxml.jackson.annotation.JsonTypeInfo - -import org.apache.spark.{Logging, SparkConf, TaskEndReason} +import org.apache.spark.{Logging, TaskEndReason} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler.cluster.ExecutorInfo import org.apache.spark.storage.{BlockManagerId, BlockUpdatedInfo} import org.apache.spark.util.{Distribution, Utils} -import org.apache.spark.ui.SparkUI @DeveloperApi -@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "Event") -trait SparkListenerEvent +sealed trait SparkListenerEvent @DeveloperApi case class SparkListenerStageSubmitted(stageInfo: StageInfo, properties: Properties = null) @@ -134,17 +130,6 @@ case class SparkListenerApplicationEnd(time: Long) extends SparkListenerEvent */ private[spark] case class SparkListenerLogStart(sparkVersion: String) extends SparkListenerEvent -/** - * Interface for creating history listeners defined in other modules like SQL, which are used to - * rebuild the history UI. - */ -private[spark] trait SparkHistoryListenerFactory { - /** - * Create listeners used to rebuild the history UI. - */ - def createListeners(conf: SparkConf, sparkUI: SparkUI): Seq[SparkListener] -} - /** * :: DeveloperApi :: * Interface for listening to events from the Spark scheduler. Note that this is an internal @@ -238,11 +223,6 @@ trait SparkListener { * Called when the driver receives a block update info. */ def onBlockUpdated(blockUpdated: SparkListenerBlockUpdated) { } - - /** - * Called when other events like SQL-specific events are posted. 
- */ - def onOtherEvent(event: SparkListenerEvent) { } } /** diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala index 95722a07144e..04afde33f5aa 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala @@ -61,7 +61,6 @@ private[spark] trait SparkListenerBus extends ListenerBus[SparkListener, SparkLi case blockUpdated: SparkListenerBlockUpdated => listener.onBlockUpdated(blockUpdated) case logStart: SparkListenerLogStart => // ignore event log metadata - case _ => listener.onOtherEvent(event) } } diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala index 8da6884a3853..4608bce202ec 100644 --- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala @@ -17,13 +17,10 @@ package org.apache.spark.ui -import java.util.{Date, ServiceLoader} - -import scala.collection.JavaConverters._ +import java.util.Date import org.apache.spark.status.api.v1.{ApiRootResource, ApplicationAttemptInfo, ApplicationInfo, UIRoot} -import org.apache.spark.util.Utils import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkContext} import org.apache.spark.scheduler._ import org.apache.spark.storage.StorageStatusListener @@ -157,16 +154,7 @@ private[spark] object SparkUI { appName: String, basePath: String, startTime: Long): SparkUI = { - val sparkUI = create( - None, conf, listenerBus, securityManager, appName, basePath, startTime = startTime) - - val listenerFactories = ServiceLoader.load(classOf[SparkHistoryListenerFactory], - Utils.getContextOrSparkClassLoader).asScala - listenerFactories.foreach { listenerFactory => - val listeners = listenerFactory.createListeners(conf, sparkUI) - listeners.foreach(listenerBus.addListener) - } - sparkUI + create(None, conf, listenerBus, securityManager, appName, basePath, startTime = startTime) } /** diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index 7f5d713ec650..c9beeb25e05a 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -19,21 +19,19 @@ package org.apache.spark.util import java.util.{Properties, UUID} +import org.apache.spark.scheduler.cluster.ExecutorInfo + import scala.collection.JavaConverters._ import scala.collection.Map -import com.fasterxml.jackson.databind.ObjectMapper -import com.fasterxml.jackson.module.scala.DefaultScalaModule import org.json4s.DefaultFormats import org.json4s.JsonDSL._ import org.json4s.JsonAST._ -import org.json4s.jackson.JsonMethods._ import org.apache.spark._ import org.apache.spark.executor._ import org.apache.spark.rdd.RDDOperationScope import org.apache.spark.scheduler._ -import org.apache.spark.scheduler.cluster.ExecutorInfo import org.apache.spark.storage._ /** @@ -56,8 +54,6 @@ private[spark] object JsonProtocol { private implicit val format = DefaultFormats - private val mapper = new ObjectMapper().registerModule(DefaultScalaModule) - /** ------------------------------------------------- * * JSON serialization methods for SparkListenerEvents | * -------------------------------------------------- */ @@ -100,7 +96,6 @@ private[spark] object JsonProtocol { executorMetricsUpdateToJson(metricsUpdate) case blockUpdated: 
SparkListenerBlockUpdated => throw new MatchError(blockUpdated) // TODO(ekl) implement this - case _ => parse(mapper.writeValueAsString(event)) } } @@ -516,8 +511,6 @@ private[spark] object JsonProtocol { case `executorRemoved` => executorRemovedFromJson(json) case `logStart` => logStartFromJson(json) case `metricsUpdate` => executorMetricsUpdateFromJson(json) - case other => mapper.readValue(compact(render(json)), Utils.classForName(other)) - .asInstanceOf[SparkListenerEvent] } } diff --git a/sql/core/src/main/resources/META-INF/services/org.apache.spark.scheduler.SparkHistoryListenerFactory b/sql/core/src/main/resources/META-INF/services/org.apache.spark.scheduler.SparkHistoryListenerFactory deleted file mode 100644 index 507100be9096..000000000000 --- a/sql/core/src/main/resources/META-INF/services/org.apache.spark.scheduler.SparkHistoryListenerFactory +++ /dev/null @@ -1 +0,0 @@ -org.apache.spark.sql.execution.ui.SQLHistoryListenerFactory diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 8d2783952532..9cc65de19180 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -1263,8 +1263,6 @@ object SQLContext { */ @transient private val instantiatedContext = new AtomicReference[SQLContext]() - @transient private val sqlListener = new AtomicReference[SQLListener]() - /** * Get the singleton SQLContext if it exists or create a new one using the given SparkContext. * @@ -1309,10 +1307,6 @@ object SQLContext { Option(instantiatedContext.get()) } - private[sql] def clearSqlListener(): Unit = { - sqlListener.set(null) - } - /** * Changes the SQLContext that will be returned in this thread and its children when * SQLContext.getOrCreate() is called. This can be used to ensure that a given thread receives @@ -1361,13 +1355,9 @@ object SQLContext { * Create a SQLListener then add it into SparkContext, and create an SQLTab if there is SparkUI. 
*/ private[sql] def createListenerAndUI(sc: SparkContext): SQLListener = { - if (sqlListener.get() == null) { - val listener = new SQLListener(sc.conf) - if (sqlListener.compareAndSet(null, listener)) { - sc.addSparkListener(listener) - sc.ui.foreach(new SQLTab(listener, _)) - } - } - sqlListener.get() + val listener = new SQLListener(sc.conf) + sc.addSparkListener(listener) + sc.ui.foreach(new SQLTab(listener, _)) + listener } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala index 34971986261c..1422e15549c9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala @@ -21,8 +21,7 @@ import java.util.concurrent.atomic.AtomicLong import org.apache.spark.SparkContext import org.apache.spark.sql.SQLContext -import org.apache.spark.sql.execution.ui.{SparkListenerSQLExecutionStart, - SparkListenerSQLExecutionEnd} +import org.apache.spark.sql.execution.ui.SparkPlanGraph import org.apache.spark.util.Utils private[sql] object SQLExecution { @@ -46,14 +45,25 @@ private[sql] object SQLExecution { sc.setLocalProperty(EXECUTION_ID_KEY, executionId.toString) val r = try { val callSite = Utils.getCallSite() - sqlContext.sparkContext.listenerBus.post(SparkListenerSQLExecutionStart( - executionId, callSite.shortForm, callSite.longForm, queryExecution.toString, - SparkPlanInfo.fromSparkPlan(queryExecution.executedPlan), System.currentTimeMillis())) + sqlContext.listener.onExecutionStart( + executionId, + callSite.shortForm, + callSite.longForm, + queryExecution.toString, + SparkPlanGraph(queryExecution.executedPlan), + System.currentTimeMillis()) try { body } finally { - sqlContext.sparkContext.listenerBus.post(SparkListenerSQLExecutionEnd( - executionId, System.currentTimeMillis())) + // Ideally, we need to make sure onExecutionEnd happens after onJobStart and onJobEnd. + // However, onJobStart and onJobEnd run in the listener thread. Because we cannot add new + // SQL event types to SparkListener since it's a public API, we cannot guarantee that. + // + // SQLListener should handle the case that onExecutionEnd happens before onJobEnd. + // + // The worst case is onExecutionEnd may happen before onJobStart when the listener thread + // is very busy. If so, we cannot track the jobs for the execution. It seems acceptable. + sqlContext.listener.onExecutionEnd(executionId, System.currentTimeMillis()) } } finally { sc.setLocalProperty(EXECUTION_ID_KEY, null) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala deleted file mode 100644 index 486ce34064e4..000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala +++ /dev/null @@ -1,46 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution - -import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.sql.execution.metric.SQLMetricInfo -import org.apache.spark.util.Utils - -/** - * :: DeveloperApi :: - * Stores information about a SQL SparkPlan. - */ -@DeveloperApi -class SparkPlanInfo( - val nodeName: String, - val simpleString: String, - val children: Seq[SparkPlanInfo], - val metrics: Seq[SQLMetricInfo]) - -private[sql] object SparkPlanInfo { - - def fromSparkPlan(plan: SparkPlan): SparkPlanInfo = { - val metrics = plan.metrics.toSeq.map { case (key, metric) => - new SQLMetricInfo(metric.name.getOrElse(key), metric.id, - Utils.getFormattedClassName(metric.param)) - } - val children = plan.children.map(fromSparkPlan) - - new SparkPlanInfo(plan.nodeName, plan.simpleString, children, metrics) - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetricInfo.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetricInfo.scala deleted file mode 100644 index 2708219ad348..000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetricInfo.scala +++ /dev/null @@ -1,30 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.metric - -import org.apache.spark.annotation.DeveloperApi - -/** - * :: DeveloperApi :: - * Stores information about a SQL Metric. - */ -@DeveloperApi -class SQLMetricInfo( - val name: String, - val accumulatorId: Long, - val metricParam: String) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala index 6c0f6f8a52dc..1c253e3942e9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala @@ -104,39 +104,21 @@ private class LongSQLMetricParam(val stringValue: Seq[Long] => String, initialVa override def zero: LongSQLMetricValue = new LongSQLMetricValue(initialValue) } -private object LongSQLMetricParam extends LongSQLMetricParam(_.sum.toString, 0L) - -private object StaticsLongSQLMetricParam extends LongSQLMetricParam( - (values: Seq[Long]) => { - // This is a workaround for SPARK-11013. 
- // We use -1 as initial value of the accumulator, if the accumulator is valid, we will update - // it at the end of task and the value will be at least 0. - val validValues = values.filter(_ >= 0) - val Seq(sum, min, med, max) = { - val metric = if (validValues.length == 0) { - Seq.fill(4)(0L) - } else { - val sorted = validValues.sorted - Seq(sorted.sum, sorted(0), sorted(validValues.length / 2), sorted(validValues.length - 1)) - } - metric.map(Utils.bytesToString) - } - s"\n$sum ($min, $med, $max)" - }, -1L) - private[sql] object SQLMetrics { private def createLongMetric( sc: SparkContext, name: String, - param: LongSQLMetricParam): LongSQLMetric = { + stringValue: Seq[Long] => String, + initialValue: Long): LongSQLMetric = { + val param = new LongSQLMetricParam(stringValue, initialValue) val acc = new LongSQLMetric(name, param) sc.cleaner.foreach(_.registerAccumulatorForCleanup(acc)) acc } def createLongMetric(sc: SparkContext, name: String): LongSQLMetric = { - createLongMetric(sc, name, LongSQLMetricParam) + createLongMetric(sc, name, _.sum.toString, 0L) } /** @@ -144,25 +126,31 @@ private[sql] object SQLMetrics { * spill size, etc. */ def createSizeMetric(sc: SparkContext, name: String): LongSQLMetric = { + val stringValue = (values: Seq[Long]) => { + // This is a workaround for SPARK-11013. + // We use -1 as initial value of the accumulator, if the accumulator is valid, we will update + // it at the end of task and the value will be at least 0. + val validValues = values.filter(_ >= 0) + val Seq(sum, min, med, max) = { + val metric = if (validValues.length == 0) { + Seq.fill(4)(0L) + } else { + val sorted = validValues.sorted + Seq(sorted.sum, sorted(0), sorted(validValues.length / 2), sorted(validValues.length - 1)) + } + metric.map(Utils.bytesToString) + } + s"\n$sum ($min, $med, $max)" + } // The final result of this metric in physical operator UI may looks like: // data size total (min, med, max): // 100GB (100MB, 1GB, 10GB) - createLongMetric(sc, s"$name total (min, med, max)", StaticsLongSQLMetricParam) - } - - def getMetricParam(metricParamName: String): SQLMetricParam[SQLMetricValue[Any], Any] = { - val longSQLMetricParam = Utils.getFormattedClassName(LongSQLMetricParam) - val staticsSQLMetricParam = Utils.getFormattedClassName(StaticsLongSQLMetricParam) - val metricParam = metricParamName match { - case `longSQLMetricParam` => LongSQLMetricParam - case `staticsSQLMetricParam` => StaticsLongSQLMetricParam - } - metricParam.asInstanceOf[SQLMetricParam[SQLMetricValue[Any], Any]] + createLongMetric(sc, s"$name total (min, med, max)", stringValue, -1L) } /** * A metric that its value will be ignored. Use this one when we need a metric parameter but don't * care about the value. 
*/ - val nullLongMetric = new LongSQLMetric("null", LongSQLMetricParam) + val nullLongMetric = new LongSQLMetric("null", new LongSQLMetricParam(_.sum.toString, 0L)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala index c74ad4040699..e74d6fb396e1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala @@ -19,7 +19,9 @@ package org.apache.spark.sql.execution.ui import javax.servlet.http.HttpServletRequest -import scala.xml.Node +import scala.xml.{Node, Unparsed} + +import org.apache.commons.lang3.StringEscapeUtils import org.apache.spark.Logging import org.apache.spark.ui.{UIUtils, WebUIPage} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala index e19a1e3e5851..5a072de400b6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala @@ -19,34 +19,11 @@ package org.apache.spark.sql.execution.ui import scala.collection.mutable -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler._ import org.apache.spark.sql.execution.SQLExecution -import org.apache.spark.sql.execution.SparkPlanInfo -import org.apache.spark.sql.execution.metric.{LongSQLMetricValue, SQLMetricValue, SQLMetricParam} +import org.apache.spark.sql.execution.metric.{SQLMetricParam, SQLMetricValue} import org.apache.spark.{JobExecutionStatus, Logging, SparkConf} -import org.apache.spark.ui.SparkUI - -@DeveloperApi -case class SparkListenerSQLExecutionStart( - executionId: Long, - description: String, - details: String, - physicalPlanDescription: String, - sparkPlanInfo: SparkPlanInfo, - time: Long) - extends SparkListenerEvent - -@DeveloperApi -case class SparkListenerSQLExecutionEnd(executionId: Long, time: Long) - extends SparkListenerEvent - -private[sql] class SQLHistoryListenerFactory extends SparkHistoryListenerFactory { - - override def createListeners(conf: SparkConf, sparkUI: SparkUI): Seq[SparkListener] = { - List(new SQLHistoryListener(conf, sparkUI)) - } -} private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Logging { @@ -141,8 +118,7 @@ private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Loggi override def onExecutorMetricsUpdate( executorMetricsUpdate: SparkListenerExecutorMetricsUpdate): Unit = synchronized { for ((taskId, stageId, stageAttemptID, metrics) <- executorMetricsUpdate.taskMetrics) { - updateTaskAccumulatorValues(taskId, stageId, stageAttemptID, metrics.accumulatorUpdates(), - finishTask = false) + updateTaskAccumulatorValues(taskId, stageId, stageAttemptID, metrics, finishTask = false) } } @@ -164,7 +140,7 @@ private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Loggi taskEnd.taskInfo.taskId, taskEnd.stageId, taskEnd.stageAttemptId, - taskEnd.taskMetrics.accumulatorUpdates(), + taskEnd.taskMetrics, finishTask = true) } @@ -172,12 +148,15 @@ private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Loggi * Update the accumulator values of a task with the latest metrics for this task. This is called * every time we receive an executor heartbeat or when a task finishes. 
*/ - protected def updateTaskAccumulatorValues( + private def updateTaskAccumulatorValues( taskId: Long, stageId: Int, stageAttemptID: Int, - accumulatorUpdates: Map[Long, Any], + metrics: TaskMetrics, finishTask: Boolean): Unit = { + if (metrics == null) { + return + } _stageIdToStageMetrics.get(stageId) match { case Some(stageMetrics) => @@ -195,9 +174,9 @@ private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Loggi case Some(taskMetrics) => if (finishTask) { taskMetrics.finished = true - taskMetrics.accumulatorUpdates = accumulatorUpdates + taskMetrics.accumulatorUpdates = metrics.accumulatorUpdates() } else if (!taskMetrics.finished) { - taskMetrics.accumulatorUpdates = accumulatorUpdates + taskMetrics.accumulatorUpdates = metrics.accumulatorUpdates() } else { // If a task is finished, we should not override with accumulator updates from // heartbeat reports @@ -206,7 +185,7 @@ private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Loggi // TODO Now just set attemptId to 0. Should fix here when we can get the attempt // id from SparkListenerExecutorMetricsUpdate stageMetrics.taskIdToMetricUpdates(taskId) = new SQLTaskMetrics( - attemptId = 0, finished = finishTask, accumulatorUpdates) + attemptId = 0, finished = finishTask, metrics.accumulatorUpdates()) } } case None => @@ -214,40 +193,38 @@ private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Loggi } } - override def onOtherEvent(event: SparkListenerEvent): Unit = event match { - case SparkListenerSQLExecutionStart(executionId, description, details, - physicalPlanDescription, sparkPlanInfo, time) => - val physicalPlanGraph = SparkPlanGraph(sparkPlanInfo) - val sqlPlanMetrics = physicalPlanGraph.nodes.flatMap { node => - node.metrics.map(metric => metric.accumulatorId -> metric) - } - val executionUIData = new SQLExecutionUIData( - executionId, - description, - details, - physicalPlanDescription, - physicalPlanGraph, - sqlPlanMetrics.toMap, - time) - synchronized { - activeExecutions(executionId) = executionUIData - _executionIdToData(executionId) = executionUIData - } - case SparkListenerSQLExecutionEnd(executionId, time) => synchronized { - _executionIdToData.get(executionId).foreach { executionUIData => - executionUIData.completionTime = Some(time) - if (!executionUIData.hasRunningJobs) { - // onExecutionEnd happens after all "onJobEnd"s - // So we should update the execution lists. - markExecutionFinished(executionId) - } else { - // There are some running jobs, onExecutionEnd happens before some "onJobEnd"s. - // Then we don't if the execution is successful, so let the last onJobEnd updates the - // execution lists. 
- } + def onExecutionStart( + executionId: Long, + description: String, + details: String, + physicalPlanDescription: String, + physicalPlanGraph: SparkPlanGraph, + time: Long): Unit = { + val sqlPlanMetrics = physicalPlanGraph.nodes.flatMap { node => + node.metrics.map(metric => metric.accumulatorId -> metric) + } + + val executionUIData = new SQLExecutionUIData(executionId, description, details, + physicalPlanDescription, physicalPlanGraph, sqlPlanMetrics.toMap, time) + synchronized { + activeExecutions(executionId) = executionUIData + _executionIdToData(executionId) = executionUIData + } + } + + def onExecutionEnd(executionId: Long, time: Long): Unit = synchronized { + _executionIdToData.get(executionId).foreach { executionUIData => + executionUIData.completionTime = Some(time) + if (!executionUIData.hasRunningJobs) { + // onExecutionEnd happens after all "onJobEnd"s + // So we should update the execution lists. + markExecutionFinished(executionId) + } else { + // There are some running jobs, onExecutionEnd happens before some "onJobEnd"s. + // Then we don't if the execution is successful, so let the last onJobEnd updates the + // execution lists. } } - case _ => // Ignore } private def markExecutionFinished(executionId: Long): Unit = { @@ -312,38 +289,6 @@ private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Loggi } -private[spark] class SQLHistoryListener(conf: SparkConf, sparkUI: SparkUI) - extends SQLListener(conf) { - - private var sqlTabAttached = false - - override def onExecutorMetricsUpdate( - executorMetricsUpdate: SparkListenerExecutorMetricsUpdate): Unit = synchronized { - // Do nothing - } - - override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = synchronized { - updateTaskAccumulatorValues( - taskEnd.taskInfo.taskId, - taskEnd.stageId, - taskEnd.stageAttemptId, - taskEnd.taskInfo.accumulables.map { acc => - (acc.id, new LongSQLMetricValue(acc.update.getOrElse("0").toLong)) - }.toMap, - finishTask = true) - } - - override def onOtherEvent(event: SparkListenerEvent): Unit = event match { - case _: SparkListenerSQLExecutionStart => - if (!sqlTabAttached) { - new SQLTab(this, sparkUI) - sqlTabAttached = true - } - super.onOtherEvent(event) - case _ => super.onOtherEvent(event) - } -} - /** * Represent all necessary data for an execution that will be used in Web UI. 
*/ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLTab.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLTab.scala index 4f50b2ecdc8f..9c27944d42fc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLTab.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLTab.scala @@ -17,11 +17,13 @@ package org.apache.spark.sql.execution.ui +import java.util.concurrent.atomic.AtomicInteger + import org.apache.spark.Logging import org.apache.spark.ui.{SparkUI, SparkUITab} private[sql] class SQLTab(val listener: SQLListener, sparkUI: SparkUI) - extends SparkUITab(sparkUI, "SQL") with Logging { + extends SparkUITab(sparkUI, SQLTab.nextTabName) with Logging { val parent = sparkUI @@ -33,5 +35,13 @@ private[sql] class SQLTab(val listener: SQLListener, sparkUI: SparkUI) } private[sql] object SQLTab { + private val STATIC_RESOURCE_DIR = "org/apache/spark/sql/execution/ui/static" + + private val nextTabId = new AtomicInteger(0) + + private def nextTabName: String = { + val nextId = nextTabId.getAndIncrement() + if (nextId == 0) "SQL" else s"SQL$nextId" + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala index 7af0ff09c5c6..f1fce5478a3f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala @@ -21,8 +21,8 @@ import java.util.concurrent.atomic.AtomicLong import scala.collection.mutable -import org.apache.spark.sql.execution.SparkPlanInfo -import org.apache.spark.sql.execution.metric.SQLMetrics +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.metric.{SQLMetricParam, SQLMetricValue} /** * A graph used for storing information of an executionPlan of DataFrame. @@ -48,27 +48,27 @@ private[sql] object SparkPlanGraph { /** * Build a SparkPlanGraph from the root of a SparkPlan tree. 
*/ - def apply(planInfo: SparkPlanInfo): SparkPlanGraph = { + def apply(plan: SparkPlan): SparkPlanGraph = { val nodeIdGenerator = new AtomicLong(0) val nodes = mutable.ArrayBuffer[SparkPlanGraphNode]() val edges = mutable.ArrayBuffer[SparkPlanGraphEdge]() - buildSparkPlanGraphNode(planInfo, nodeIdGenerator, nodes, edges) + buildSparkPlanGraphNode(plan, nodeIdGenerator, nodes, edges) new SparkPlanGraph(nodes, edges) } private def buildSparkPlanGraphNode( - planInfo: SparkPlanInfo, + plan: SparkPlan, nodeIdGenerator: AtomicLong, nodes: mutable.ArrayBuffer[SparkPlanGraphNode], edges: mutable.ArrayBuffer[SparkPlanGraphEdge]): SparkPlanGraphNode = { - val metrics = planInfo.metrics.map { metric => - SQLPlanMetric(metric.name, metric.accumulatorId, - SQLMetrics.getMetricParam(metric.metricParam)) + val metrics = plan.metrics.toSeq.map { case (key, metric) => + SQLPlanMetric(metric.name.getOrElse(key), metric.id, + metric.param.asInstanceOf[SQLMetricParam[SQLMetricValue[Any], Any]]) } val node = SparkPlanGraphNode( - nodeIdGenerator.getAndIncrement(), planInfo.nodeName, planInfo.simpleString, metrics) + nodeIdGenerator.getAndIncrement(), plan.nodeName, plan.simpleString, metrics) nodes += node - val childrenNodes = planInfo.children.map( + val childrenNodes = plan.children.map( child => buildSparkPlanGraphNode(child, nodeIdGenerator, nodes, edges)) for (child <- childrenNodes) { edges += SparkPlanGraphEdge(child.id, node.id) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index 4f2cad19bfb6..82867ab4967b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -26,7 +26,6 @@ import org.apache.xbean.asm5.Opcodes._ import org.apache.spark.SparkFunSuite import org.apache.spark.sql._ -import org.apache.spark.sql.execution.SparkPlanInfo import org.apache.spark.sql.execution.ui.SparkPlanGraph import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext @@ -83,8 +82,7 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { if (jobs.size == expectedNumOfJobs) { // If we can track all jobs, check the metric values val metricValues = sqlContext.listener.getExecutionMetrics(executionId) - val actualMetrics = SparkPlanGraph(SparkPlanInfo.fromSparkPlan( - df.queryExecution.executedPlan)).nodes.filter { node => + val actualMetrics = SparkPlanGraph(df.queryExecution.executedPlan).nodes.filter { node => expectedMetrics.contains(node.id) }.map { node => val nodeMetrics = node.metrics.map { metric => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala index f93d081d0c30..c15aac775096 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala @@ -21,10 +21,10 @@ import java.util.Properties import org.apache.spark.{SparkException, SparkContext, SparkConf, SparkFunSuite} import org.apache.spark.executor.TaskMetrics +import org.apache.spark.sql.execution.metric.LongSQLMetricValue import org.apache.spark.scheduler._ import org.apache.spark.sql.{DataFrame, SQLContext} -import org.apache.spark.sql.execution.{SparkPlanInfo, SQLExecution} -import 
org.apache.spark.sql.execution.metric.LongSQLMetricValue +import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.test.SharedSQLContext class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { @@ -82,8 +82,7 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { val executionId = 0 val df = createTestDataFrame val accumulatorIds = - SparkPlanGraph(SparkPlanInfo.fromSparkPlan(df.queryExecution.executedPlan)) - .nodes.flatMap(_.metrics.map(_.accumulatorId)) + SparkPlanGraph(df.queryExecution.executedPlan).nodes.flatMap(_.metrics.map(_.accumulatorId)) // Assume all accumulators are long var accumulatorValue = 0L val accumulatorUpdates = accumulatorIds.map { id => @@ -91,13 +90,13 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { (id, accumulatorValue) }.toMap - listener.onOtherEvent(SparkListenerSQLExecutionStart( + listener.onExecutionStart( executionId, "test", "test", df.queryExecution.toString, - SparkPlanInfo.fromSparkPlan(df.queryExecution.executedPlan), - System.currentTimeMillis())) + SparkPlanGraph(df.queryExecution.executedPlan), + System.currentTimeMillis()) val executionUIData = listener.executionIdToData(0) @@ -207,8 +206,7 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { time = System.currentTimeMillis(), JobSucceeded )) - listener.onOtherEvent(SparkListenerSQLExecutionEnd( - executionId, System.currentTimeMillis())) + listener.onExecutionEnd(executionId, System.currentTimeMillis()) assert(executionUIData.runningJobs.isEmpty) assert(executionUIData.succeededJobs === Seq(0)) @@ -221,20 +219,19 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { val listener = new SQLListener(sqlContext.sparkContext.conf) val executionId = 0 val df = createTestDataFrame - listener.onOtherEvent(SparkListenerSQLExecutionStart( + listener.onExecutionStart( executionId, "test", "test", df.queryExecution.toString, - SparkPlanInfo.fromSparkPlan(df.queryExecution.executedPlan), - System.currentTimeMillis())) + SparkPlanGraph(df.queryExecution.executedPlan), + System.currentTimeMillis()) listener.onJobStart(SparkListenerJobStart( jobId = 0, time = System.currentTimeMillis(), stageInfos = Nil, createProperties(executionId))) - listener.onOtherEvent(SparkListenerSQLExecutionEnd( - executionId, System.currentTimeMillis())) + listener.onExecutionEnd(executionId, System.currentTimeMillis()) listener.onJobEnd(SparkListenerJobEnd( jobId = 0, time = System.currentTimeMillis(), @@ -251,13 +248,13 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { val listener = new SQLListener(sqlContext.sparkContext.conf) val executionId = 0 val df = createTestDataFrame - listener.onOtherEvent(SparkListenerSQLExecutionStart( + listener.onExecutionStart( executionId, "test", "test", df.queryExecution.toString, - SparkPlanInfo.fromSparkPlan(df.queryExecution.executedPlan), - System.currentTimeMillis())) + SparkPlanGraph(df.queryExecution.executedPlan), + System.currentTimeMillis()) listener.onJobStart(SparkListenerJobStart( jobId = 0, time = System.currentTimeMillis(), @@ -274,8 +271,7 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { time = System.currentTimeMillis(), stageInfos = Nil, createProperties(executionId))) - listener.onOtherEvent(SparkListenerSQLExecutionEnd( - executionId, System.currentTimeMillis())) + listener.onExecutionEnd(executionId, System.currentTimeMillis()) listener.onJobEnd(SparkListenerJobEnd( jobId = 1, time = System.currentTimeMillis(), @@ -292,20 
+288,19 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { val listener = new SQLListener(sqlContext.sparkContext.conf) val executionId = 0 val df = createTestDataFrame - listener.onOtherEvent(SparkListenerSQLExecutionStart( + listener.onExecutionStart( executionId, "test", "test", df.queryExecution.toString, - SparkPlanInfo.fromSparkPlan(df.queryExecution.executedPlan), - System.currentTimeMillis())) + SparkPlanGraph(df.queryExecution.executedPlan), + System.currentTimeMillis()) listener.onJobStart(SparkListenerJobStart( jobId = 0, time = System.currentTimeMillis(), stageInfos = Seq.empty, createProperties(executionId))) - listener.onOtherEvent(SparkListenerSQLExecutionEnd( - executionId, System.currentTimeMillis())) + listener.onExecutionEnd(executionId, System.currentTimeMillis()) listener.onJobEnd(SparkListenerJobEnd( jobId = 0, time = System.currentTimeMillis(), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala index e7b376548787..963d10eed62e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala @@ -42,7 +42,6 @@ trait SharedSQLContext extends SQLTestUtils { * Initialize the [[TestSQLContext]]. */ protected override def beforeAll(): Unit = { - SQLContext.clearSqlListener() if (_ctx == null) { _ctx = new TestSQLContext } From a8ceec5e8c1572dd3d74783c06c78b7ca0b8a7ce Mon Sep 17 00:00:00 2001 From: Teng Qiu Date: Tue, 1 Dec 2015 07:27:32 +0900 Subject: [PATCH 1595/1824] [SPARK-12053][CORE] EventLoggingListener.getLogPath needs 4 parameters ```EventLoggingListener.getLogPath``` needs 4 input arguments: https://github.com/apache/spark/blob/v1.6.0-preview2/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala#L276-L280 the 3rd parameter should be appAttemptId, 4th parameter is codec... Author: Teng Qiu Closes #10044 from chutium/SPARK-12053. --- core/src/main/scala/org/apache/spark/deploy/master/Master.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index 9952c97dbdff..1355e1ad1b52 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -934,7 +934,7 @@ private[deploy] class Master( } val eventLogFilePrefix = EventLoggingListener.getLogPath( - eventLogDir, app.id, app.desc.eventLogCodec) + eventLogDir, app.id, appAttemptId = None, compressionCodecName = app.desc.eventLogCodec) val fs = Utils.getHadoopFileSystem(eventLogDir, hadoopConf) val inProgressExists = fs.exists(new Path(eventLogFilePrefix + EventLoggingListener.IN_PROGRESS)) From e232720a65dfb9ae6135cbb7674e35eddd88d625 Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Mon, 30 Nov 2015 14:56:51 -0800 Subject: [PATCH 1596/1824] [SPARK-11689][ML] Add user guide and example code for LDA under spark.ml jira: https://issues.apache.org/jira/browse/SPARK-11689 Add simple user guide for LDA under spark.ml and example code under examples/. Use include_example to include example code in the user guide markdown. Check SPARK-11606 for instructions. Original PR is reverted due to document build error. https://github.com/apache/spark/pull/9722 mengxr feynmanliang yinxusen Sorry for the troubling. 
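The SPARK-12053 fix above works because the four-parameter getLogPath takes two trailing Option arguments, so a codec passed positionally as the third argument binds to appAttemptId and still compiles; naming the arguments removes the ambiguity, which is exactly what the one-line change in Master.scala does. A minimal sketch of the pitfall, using a simplified stand-in signature (the getLogPath defined here is illustrative only, not Spark's actual implementation):

    // Simplified stand-in with the same four-parameter shape as the method discussed above.
    def getLogPath(
        logBaseDir: String,
        appId: String,
        appAttemptId: Option[String] = None,
        compressionCodecName: Option[String] = None): String = {
      val attempt = appAttemptId.map("_" + _).getOrElse("")
      val codec = compressionCodecName.map("." + _).getOrElse("")
      logBaseDir.stripSuffix("/") + "/" + appId + attempt + codec
    }

    getLogPath("/logs", "app-1", Some("lz4"))
    // compiles, but the codec value binds to appAttemptId: "/logs/app-1_lz4"

    getLogPath("/logs", "app-1", appAttemptId = None, compressionCodecName = Some("lz4"))
    // named arguments, as in the patch: "/logs/app-1.lz4"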
Author: Yuhao Yang Closes #9974 from hhbyyh/ldaMLExample. --- docs/ml-clustering.md | 31 ++++++ docs/ml-guide.md | 3 +- docs/mllib-guide.md | 1 + .../spark/examples/ml/JavaLDAExample.java | 97 +++++++++++++++++++ .../apache/spark/examples/ml/LDAExample.scala | 77 +++++++++++++++ 5 files changed, 208 insertions(+), 1 deletion(-) create mode 100644 docs/ml-clustering.md create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaLDAExample.java create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/LDAExample.scala diff --git a/docs/ml-clustering.md b/docs/ml-clustering.md new file mode 100644 index 000000000000..cfefb5dfbde9 --- /dev/null +++ b/docs/ml-clustering.md @@ -0,0 +1,31 @@ +--- +layout: global +title: Clustering - ML +displayTitle: ML - Clustering +--- + +In this section, we introduce the pipeline API for [clustering in mllib](mllib-clustering.html). + +## Latent Dirichlet allocation (LDA) + +`LDA` is implemented as an `Estimator` that supports both `EMLDAOptimizer` and `OnlineLDAOptimizer`, +and generates a `LDAModel` as the base models. Expert users may cast a `LDAModel` generated by +`EMLDAOptimizer` to a `DistributedLDAModel` if needed. + +
+<div class="codetabs"> + +<div data-lang="scala" markdown="1">
    + +Refer to the [Scala API docs](api/scala/index.html#org.apache.spark.ml.clustering.LDA) for more details. + +{% include_example scala/org/apache/spark/examples/ml/LDAExample.scala %} +
+</div> + +<div data-lang="java" markdown="1">
    + +Refer to the [Java API docs](api/java/org/apache/spark/ml/clustering/LDA.html) for more details. + +{% include_example java/org/apache/spark/examples/ml/JavaLDAExample.java %} +
+</div> + +</div>
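As a minimal sketch of the flow this section describes (assuming a `dataset` DataFrame with a vector column named `features`, as built in the included examples, and assuming the EM optimizer is selected by the string name `"em"`), fitting the estimator and guarding the expert-level cast might look like this:

{% highlight scala %}
import org.apache.spark.ml.clustering.{DistributedLDAModel, LDA}

// Fit LDA on a DataFrame that has a vector column named "features".
val lda = new LDA()
  .setK(10)
  .setMaxIter(10)
  .setFeaturesCol("features")
  .setOptimizer("em") // assumed to select EMLDAOptimizer; "online" is the other choice
val model = lda.fit(dataset)

// Only an EM-trained model is a DistributedLDAModel; guard the cast instead of assuming it.
val distributedModel: Option[DistributedLDAModel] = model match {
  case m: DistributedLDAModel => Some(m)
  case _ => None // an online-trained model is a LocalLDAModel
}

// describeTopics works on either model type.
model.describeTopics(3).show(false)
{% endhighlight %}

The guarded match avoids a ClassCastException if the optimizer setting is later switched to the online optimizer, whose models are local rather than distributed.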
    \ No newline at end of file diff --git a/docs/ml-guide.md b/docs/ml-guide.md index be18a05361a1..6f35b30c3d4d 100644 --- a/docs/ml-guide.md +++ b/docs/ml-guide.md @@ -40,6 +40,7 @@ Also, some algorithms have additional capabilities in the `spark.ml` API; e.g., provide class probabilities, and linear models provide model summaries. * [Feature extraction, transformation, and selection](ml-features.html) +* [Clustering](ml-clustering.html) * [Decision Trees for classification and regression](ml-decision-tree.html) * [Ensembles](ml-ensembles.html) * [Linear methods with elastic net regularization](ml-linear-methods.html) @@ -950,4 +951,4 @@ model.transform(test) {% endhighlight %}
    - + \ No newline at end of file diff --git a/docs/mllib-guide.md b/docs/mllib-guide.md index 91e50ccfecec..54e35fcbb15a 100644 --- a/docs/mllib-guide.md +++ b/docs/mllib-guide.md @@ -69,6 +69,7 @@ We list major functionality from both below, with links to detailed guides. concepts. It also contains sections on using algorithms within the Pipelines API, for example: * [Feature extraction, transformation, and selection](ml-features.html) +* [Clustering](ml-clustering.html) * [Decision trees for classification and regression](ml-decision-tree.html) * [Ensembles](ml-ensembles.html) * [Linear methods with elastic net regularization](ml-linear-methods.html) diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaLDAExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaLDAExample.java new file mode 100644 index 000000000000..3a5d3237c85f --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaLDAExample.java @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; +// $example on$ +import java.util.regex.Pattern; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.ml.clustering.LDA; +import org.apache.spark.ml.clustering.LDAModel; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.linalg.VectorUDT; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.catalyst.expressions.GenericRow; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +// $example off$ + +/** + * An example demonstrating LDA + * Run with + *
+ * <pre> + * bin/run-example ml.JavaLDAExample
+ * </pre>
    + */ +public class JavaLDAExample { + + // $example on$ + private static class ParseVector implements Function { + private static final Pattern separator = Pattern.compile(" "); + + @Override + public Row call(String line) { + String[] tok = separator.split(line); + double[] point = new double[tok.length]; + for (int i = 0; i < tok.length; ++i) { + point[i] = Double.parseDouble(tok[i]); + } + Vector[] points = {Vectors.dense(point)}; + return new GenericRow(points); + } + } + + public static void main(String[] args) { + + String inputFile = "data/mllib/sample_lda_data.txt"; + + // Parses the arguments + SparkConf conf = new SparkConf().setAppName("JavaLDAExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext sqlContext = new SQLContext(jsc); + + // Loads data + JavaRDD points = jsc.textFile(inputFile).map(new ParseVector()); + StructField[] fields = {new StructField("features", new VectorUDT(), false, Metadata.empty())}; + StructType schema = new StructType(fields); + DataFrame dataset = sqlContext.createDataFrame(points, schema); + + // Trains a LDA model + LDA lda = new LDA() + .setK(10) + .setMaxIter(10); + LDAModel model = lda.fit(dataset); + + System.out.println(model.logLikelihood(dataset)); + System.out.println(model.logPerplexity(dataset)); + + // Shows the result + DataFrame topics = model.describeTopics(3); + topics.show(false); + model.transform(dataset).show(false); + + jsc.stop(); + } + // $example off$ +} diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/LDAExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/LDAExample.scala new file mode 100644 index 000000000000..419ce3d87a6a --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/LDAExample.scala @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml + +// scalastyle:off println +import org.apache.spark.{SparkContext, SparkConf} +import org.apache.spark.mllib.linalg.{VectorUDT, Vectors} +// $example on$ +import org.apache.spark.ml.clustering.LDA +import org.apache.spark.sql.{Row, SQLContext} +import org.apache.spark.sql.types.{StructField, StructType} +// $example off$ + +/** + * An example demonstrating a LDA of ML pipeline. 
+ * Run with + * {{{ + * bin/run-example ml.LDAExample + * }}} + */ +object LDAExample { + + final val FEATURES_COL = "features" + + def main(args: Array[String]): Unit = { + + val input = "data/mllib/sample_lda_data.txt" + // Creates a Spark context and a SQL context + val conf = new SparkConf().setAppName(s"${this.getClass.getSimpleName}") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + // Loads data + val rowRDD = sc.textFile(input).filter(_.nonEmpty) + .map(_.split(" ").map(_.toDouble)).map(Vectors.dense).map(Row(_)) + val schema = StructType(Array(StructField(FEATURES_COL, new VectorUDT, false))) + val dataset = sqlContext.createDataFrame(rowRDD, schema) + + // Trains a LDA model + val lda = new LDA() + .setK(10) + .setMaxIter(10) + .setFeaturesCol(FEATURES_COL) + val model = lda.fit(dataset) + val transformed = model.transform(dataset) + + val ll = model.logLikelihood(dataset) + val lp = model.logPerplexity(dataset) + + // describeTopics + val topics = model.describeTopics(3) + + // Shows the result + topics.show(false) + transformed.show(false) + + // $example off$ + sc.stop() + } +} +// scalastyle:on println From de64b65f7cf2ac58c1abc310ba547637fdbb8557 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Mon, 30 Nov 2015 15:01:08 -0800 Subject: [PATCH 1597/1824] [SPARK-11975][ML] Remove duplicate mllib example (DT/RF/GBT in Java/Python) Remove duplicate mllib example (DT/RF/GBT in Java/Python). Since we have tutorial code for DT/RF/GBT classification/regression in Scala/Java/Python and example applications for DT/RF/GBT in Scala, so we mark these as duplicated and remove them. mengxr Author: Yanbo Liang Closes #9954 from yanboliang/SPARK-11975. --- .../examples/mllib/JavaDecisionTree.java | 116 -------------- .../mllib/JavaGradientBoostedTreesRunner.java | 126 --------------- .../mllib/JavaRandomForestExample.java | 139 ----------------- .../main/python/mllib/decision_tree_runner.py | 144 ------------------ .../python/mllib/gradient_boosted_trees.py | 77 ---------- .../python/mllib/random_forest_example.py | 90 ----------- 6 files changed, 692 deletions(-) delete mode 100644 examples/src/main/java/org/apache/spark/examples/mllib/JavaDecisionTree.java delete mode 100644 examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTreesRunner.java delete mode 100644 examples/src/main/java/org/apache/spark/examples/mllib/JavaRandomForestExample.java delete mode 100755 examples/src/main/python/mllib/decision_tree_runner.py delete mode 100644 examples/src/main/python/mllib/gradient_boosted_trees.py delete mode 100755 examples/src/main/python/mllib/random_forest_example.py diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaDecisionTree.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaDecisionTree.java deleted file mode 100644 index 1f82e3f4cb18..000000000000 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaDecisionTree.java +++ /dev/null @@ -1,116 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.examples.mllib; - -import java.util.HashMap; - -import scala.Tuple2; - -import org.apache.spark.api.java.function.Function2; -import org.apache.spark.api.java.JavaPairRDD; -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.api.java.function.Function; -import org.apache.spark.api.java.function.PairFunction; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.tree.DecisionTree; -import org.apache.spark.mllib.tree.model.DecisionTreeModel; -import org.apache.spark.mllib.util.MLUtils; -import org.apache.spark.SparkConf; - -/** - * Classification and regression using decision trees. - */ -public final class JavaDecisionTree { - - public static void main(String[] args) { - String datapath = "data/mllib/sample_libsvm_data.txt"; - if (args.length == 1) { - datapath = args[0]; - } else if (args.length > 1) { - System.err.println("Usage: JavaDecisionTree "); - System.exit(1); - } - SparkConf sparkConf = new SparkConf().setAppName("JavaDecisionTree"); - JavaSparkContext sc = new JavaSparkContext(sparkConf); - - JavaRDD data = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD().cache(); - - // Compute the number of classes from the data. - Integer numClasses = data.map(new Function() { - @Override public Double call(LabeledPoint p) { - return p.label(); - } - }).countByValue().size(); - - // Set parameters. - // Empty categoricalFeaturesInfo indicates all features are continuous. - HashMap categoricalFeaturesInfo = new HashMap(); - String impurity = "gini"; - Integer maxDepth = 5; - Integer maxBins = 32; - - // Train a DecisionTree model for classification. - final DecisionTreeModel model = DecisionTree.trainClassifier(data, numClasses, - categoricalFeaturesInfo, impurity, maxDepth, maxBins); - - // Evaluate model on training instances and compute training error - JavaPairRDD predictionAndLabel = - data.mapToPair(new PairFunction() { - @Override public Tuple2 call(LabeledPoint p) { - return new Tuple2(model.predict(p.features()), p.label()); - } - }); - Double trainErr = - 1.0 * predictionAndLabel.filter(new Function, Boolean>() { - @Override public Boolean call(Tuple2 pl) { - return !pl._1().equals(pl._2()); - } - }).count() / data.count(); - System.out.println("Training error: " + trainErr); - System.out.println("Learned classification tree model:\n" + model); - - // Train a DecisionTree model for regression. 
- impurity = "variance"; - final DecisionTreeModel regressionModel = DecisionTree.trainRegressor(data, - categoricalFeaturesInfo, impurity, maxDepth, maxBins); - - // Evaluate model on training instances and compute training error - JavaPairRDD regressorPredictionAndLabel = - data.mapToPair(new PairFunction() { - @Override public Tuple2 call(LabeledPoint p) { - return new Tuple2(regressionModel.predict(p.features()), p.label()); - } - }); - Double trainMSE = - regressorPredictionAndLabel.map(new Function, Double>() { - @Override public Double call(Tuple2 pl) { - Double diff = pl._1() - pl._2(); - return diff * diff; - } - }).reduce(new Function2() { - @Override public Double call(Double a, Double b) { - return a + b; - } - }) / data.count(); - System.out.println("Training Mean Squared Error: " + trainMSE); - System.out.println("Learned regression tree model:\n" + regressionModel); - - sc.stop(); - } -} diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTreesRunner.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTreesRunner.java deleted file mode 100644 index a1844d5d07ad..000000000000 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTreesRunner.java +++ /dev/null @@ -1,126 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.examples.mllib; - -import scala.Tuple2; - -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaPairRDD; -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.api.java.function.Function; -import org.apache.spark.api.java.function.Function2; -import org.apache.spark.api.java.function.PairFunction; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.tree.GradientBoostedTrees; -import org.apache.spark.mllib.tree.configuration.BoostingStrategy; -import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel; -import org.apache.spark.mllib.util.MLUtils; - -/** - * Classification and regression using gradient-boosted decision trees. 
- */ -public final class JavaGradientBoostedTreesRunner { - - private static void usage() { - System.err.println("Usage: JavaGradientBoostedTreesRunner " + - " "); - System.exit(-1); - } - - public static void main(String[] args) { - String datapath = "data/mllib/sample_libsvm_data.txt"; - String algo = "Classification"; - if (args.length >= 1) { - datapath = args[0]; - } - if (args.length >= 2) { - algo = args[1]; - } - if (args.length > 2) { - usage(); - } - SparkConf sparkConf = new SparkConf().setAppName("JavaGradientBoostedTreesRunner"); - JavaSparkContext sc = new JavaSparkContext(sparkConf); - - JavaRDD data = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD().cache(); - - // Set parameters. - // Note: All features are treated as continuous. - BoostingStrategy boostingStrategy = BoostingStrategy.defaultParams(algo); - boostingStrategy.setNumIterations(10); - boostingStrategy.treeStrategy().setMaxDepth(5); - - if (algo.equals("Classification")) { - // Compute the number of classes from the data. - Integer numClasses = data.map(new Function() { - @Override public Double call(LabeledPoint p) { - return p.label(); - } - }).countByValue().size(); - boostingStrategy.treeStrategy().setNumClasses(numClasses); - - // Train a GradientBoosting model for classification. - final GradientBoostedTreesModel model = GradientBoostedTrees.train(data, boostingStrategy); - - // Evaluate model on training instances and compute training error - JavaPairRDD predictionAndLabel = - data.mapToPair(new PairFunction() { - @Override public Tuple2 call(LabeledPoint p) { - return new Tuple2(model.predict(p.features()), p.label()); - } - }); - Double trainErr = - 1.0 * predictionAndLabel.filter(new Function, Boolean>() { - @Override public Boolean call(Tuple2 pl) { - return !pl._1().equals(pl._2()); - } - }).count() / data.count(); - System.out.println("Training error: " + trainErr); - System.out.println("Learned classification tree model:\n" + model); - } else if (algo.equals("Regression")) { - // Train a GradientBoosting model for classification. - final GradientBoostedTreesModel model = GradientBoostedTrees.train(data, boostingStrategy); - - // Evaluate model on training instances and compute training error - JavaPairRDD predictionAndLabel = - data.mapToPair(new PairFunction() { - @Override public Tuple2 call(LabeledPoint p) { - return new Tuple2(model.predict(p.features()), p.label()); - } - }); - Double trainMSE = - predictionAndLabel.map(new Function, Double>() { - @Override public Double call(Tuple2 pl) { - Double diff = pl._1() - pl._2(); - return diff * diff; - } - }).reduce(new Function2() { - @Override public Double call(Double a, Double b) { - return a + b; - } - }) / data.count(); - System.out.println("Training Mean Squared Error: " + trainMSE); - System.out.println("Learned regression tree model:\n" + model); - } else { - usage(); - } - - sc.stop(); - } -} diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaRandomForestExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaRandomForestExample.java deleted file mode 100644 index 89a4e092a5af..000000000000 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaRandomForestExample.java +++ /dev/null @@ -1,139 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. 
- * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.examples.mllib; - -import scala.Tuple2; - -import java.util.HashMap; - -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaPairRDD; -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.api.java.function.Function; -import org.apache.spark.api.java.function.Function2; -import org.apache.spark.api.java.function.PairFunction; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.tree.RandomForest; -import org.apache.spark.mllib.tree.model.RandomForestModel; -import org.apache.spark.mllib.util.MLUtils; - -public final class JavaRandomForestExample { - - /** - * Note: This example illustrates binary classification. - * For information on multiclass classification, please refer to the JavaDecisionTree.java - * example. - */ - private static void testClassification(JavaRDD trainingData, - JavaRDD testData) { - // Train a RandomForest model. - // Empty categoricalFeaturesInfo indicates all features are continuous. - Integer numClasses = 2; - HashMap categoricalFeaturesInfo = new HashMap(); - Integer numTrees = 3; // Use more in practice. - String featureSubsetStrategy = "auto"; // Let the algorithm choose. - String impurity = "gini"; - Integer maxDepth = 4; - Integer maxBins = 32; - Integer seed = 12345; - - final RandomForestModel model = RandomForest.trainClassifier(trainingData, numClasses, - categoricalFeaturesInfo, numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins, - seed); - - // Evaluate model on test instances and compute test error - JavaPairRDD predictionAndLabel = - testData.mapToPair(new PairFunction() { - @Override - public Tuple2 call(LabeledPoint p) { - return new Tuple2(model.predict(p.features()), p.label()); - } - }); - Double testErr = - 1.0 * predictionAndLabel.filter(new Function, Boolean>() { - @Override - public Boolean call(Tuple2 pl) { - return !pl._1().equals(pl._2()); - } - }).count() / testData.count(); - System.out.println("Test Error: " + testErr); - System.out.println("Learned classification forest model:\n" + model.toDebugString()); - } - - private static void testRegression(JavaRDD trainingData, - JavaRDD testData) { - // Train a RandomForest model. - // Empty categoricalFeaturesInfo indicates all features are continuous. - HashMap categoricalFeaturesInfo = new HashMap(); - Integer numTrees = 3; // Use more in practice. - String featureSubsetStrategy = "auto"; // Let the algorithm choose. 
- String impurity = "variance"; - Integer maxDepth = 4; - Integer maxBins = 32; - Integer seed = 12345; - - final RandomForestModel model = RandomForest.trainRegressor(trainingData, - categoricalFeaturesInfo, numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins, - seed); - - // Evaluate model on test instances and compute test error - JavaPairRDD predictionAndLabel = - testData.mapToPair(new PairFunction() { - @Override - public Tuple2 call(LabeledPoint p) { - return new Tuple2(model.predict(p.features()), p.label()); - } - }); - Double testMSE = - predictionAndLabel.map(new Function, Double>() { - @Override - public Double call(Tuple2 pl) { - Double diff = pl._1() - pl._2(); - return diff * diff; - } - }).reduce(new Function2() { - @Override - public Double call(Double a, Double b) { - return a + b; - } - }) / testData.count(); - System.out.println("Test Mean Squared Error: " + testMSE); - System.out.println("Learned regression forest model:\n" + model.toDebugString()); - } - - public static void main(String[] args) { - SparkConf sparkConf = new SparkConf().setAppName("JavaRandomForestExample"); - JavaSparkContext sc = new JavaSparkContext(sparkConf); - - // Load and parse the data file. - String datapath = "data/mllib/sample_libsvm_data.txt"; - JavaRDD data = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD(); - // Split the data into training and test sets (30% held out for testing) - JavaRDD[] splits = data.randomSplit(new double[]{0.7, 0.3}); - JavaRDD trainingData = splits[0]; - JavaRDD testData = splits[1]; - - System.out.println("\nRunning example of classification using RandomForest\n"); - testClassification(trainingData, testData); - - System.out.println("\nRunning example of regression using RandomForest\n"); - testRegression(trainingData, testData); - sc.stop(); - } -} diff --git a/examples/src/main/python/mllib/decision_tree_runner.py b/examples/src/main/python/mllib/decision_tree_runner.py deleted file mode 100755 index 513ed8fd5145..000000000000 --- a/examples/src/main/python/mllib/decision_tree_runner.py +++ /dev/null @@ -1,144 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -""" -Decision tree classification and regression using MLlib. - -This example requires NumPy (http://www.numpy.org/). -""" -from __future__ import print_function - -import numpy -import os -import sys - -from operator import add - -from pyspark import SparkContext -from pyspark.mllib.regression import LabeledPoint -from pyspark.mllib.tree import DecisionTree -from pyspark.mllib.util import MLUtils - - -def getAccuracy(dtModel, data): - """ - Return accuracy of DecisionTreeModel on the given RDD[LabeledPoint]. 
- """ - seqOp = (lambda acc, x: acc + (x[0] == x[1])) - predictions = dtModel.predict(data.map(lambda x: x.features)) - truth = data.map(lambda p: p.label) - trainCorrect = predictions.zip(truth).aggregate(0, seqOp, add) - if data.count() == 0: - return 0 - return trainCorrect / (0.0 + data.count()) - - -def getMSE(dtModel, data): - """ - Return mean squared error (MSE) of DecisionTreeModel on the given - RDD[LabeledPoint]. - """ - seqOp = (lambda acc, x: acc + numpy.square(x[0] - x[1])) - predictions = dtModel.predict(data.map(lambda x: x.features)) - truth = data.map(lambda p: p.label) - trainMSE = predictions.zip(truth).aggregate(0, seqOp, add) - if data.count() == 0: - return 0 - return trainMSE / (0.0 + data.count()) - - -def reindexClassLabels(data): - """ - Re-index class labels in a dataset to the range {0,...,numClasses-1}. - If all labels in that range already appear at least once, - then the returned RDD is the same one (without a mapping). - Note: If a label simply does not appear in the data, - the index will not include it. - Be aware of this when reindexing subsampled data. - :param data: RDD of LabeledPoint where labels are integer values - denoting labels for a classification problem. - :return: Pair (reindexedData, origToNewLabels) where - reindexedData is an RDD of LabeledPoint with labels in - the range {0,...,numClasses-1}, and - origToNewLabels is a dictionary mapping original labels - to new labels. - """ - # classCounts: class --> # examples in class - classCounts = data.map(lambda x: x.label).countByValue() - numExamples = sum(classCounts.values()) - sortedClasses = sorted(classCounts.keys()) - numClasses = len(classCounts) - # origToNewLabels: class --> index in 0,...,numClasses-1 - if (numClasses < 2): - print("Dataset for classification should have at least 2 classes." - " The given dataset had only %d classes." % numClasses, file=sys.stderr) - exit(1) - origToNewLabels = dict([(sortedClasses[i], i) for i in range(0, numClasses)]) - - print("numClasses = %d" % numClasses) - print("Per-class example fractions, counts:") - print("Class\tFrac\tCount") - for c in sortedClasses: - frac = classCounts[c] / (numExamples + 0.0) - print("%g\t%g\t%d" % (c, frac, classCounts[c])) - - if (sortedClasses[0] == 0 and sortedClasses[-1] == numClasses - 1): - return (data, origToNewLabels) - else: - reindexedData = \ - data.map(lambda x: LabeledPoint(origToNewLabels[x.label], x.features)) - return (reindexedData, origToNewLabels) - - -def usage(): - print("Usage: decision_tree_runner [libsvm format data filepath]", file=sys.stderr) - exit(1) - - -if __name__ == "__main__": - if len(sys.argv) > 2: - usage() - sc = SparkContext(appName="PythonDT") - - # Load data. - dataPath = 'data/mllib/sample_libsvm_data.txt' - if len(sys.argv) == 2: - dataPath = sys.argv[1] - if not os.path.isfile(dataPath): - sc.stop() - usage() - points = MLUtils.loadLibSVMFile(sc, dataPath) - - # Re-index class labels if needed. - (reindexedData, origToNewLabels) = reindexClassLabels(points) - numClasses = len(origToNewLabels) - - # Train a classifier. - categoricalFeaturesInfo = {} # no categorical features - model = DecisionTree.trainClassifier(reindexedData, numClasses=numClasses, - categoricalFeaturesInfo=categoricalFeaturesInfo) - # Print learned tree and stats. 
- print("Trained DecisionTree for classification:") - print(" Model numNodes: %d" % model.numNodes()) - print(" Model depth: %d" % model.depth()) - print(" Training accuracy: %g" % getAccuracy(model, reindexedData)) - if model.numNodes() < 20: - print(model.toDebugString()) - else: - print(model) - - sc.stop() diff --git a/examples/src/main/python/mllib/gradient_boosted_trees.py b/examples/src/main/python/mllib/gradient_boosted_trees.py deleted file mode 100644 index 781bd61c9d2b..000000000000 --- a/examples/src/main/python/mllib/gradient_boosted_trees.py +++ /dev/null @@ -1,77 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -""" -Gradient boosted Trees classification and regression using MLlib. -""" -from __future__ import print_function - -import sys - -from pyspark.context import SparkContext -from pyspark.mllib.tree import GradientBoostedTrees -from pyspark.mllib.util import MLUtils - - -def testClassification(trainingData, testData): - # Train a GradientBoostedTrees model. - # Empty categoricalFeaturesInfo indicates all features are continuous. - model = GradientBoostedTrees.trainClassifier(trainingData, categoricalFeaturesInfo={}, - numIterations=30, maxDepth=4) - # Evaluate model on test instances and compute test error - predictions = model.predict(testData.map(lambda x: x.features)) - labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions) - testErr = labelsAndPredictions.filter(lambda v_p: v_p[0] != v_p[1]).count() \ - / float(testData.count()) - print('Test Error = ' + str(testErr)) - print('Learned classification ensemble model:') - print(model.toDebugString()) - - -def testRegression(trainingData, testData): - # Train a GradientBoostedTrees model. - # Empty categoricalFeaturesInfo indicates all features are continuous. - model = GradientBoostedTrees.trainRegressor(trainingData, categoricalFeaturesInfo={}, - numIterations=30, maxDepth=4) - # Evaluate model on test instances and compute test error - predictions = model.predict(testData.map(lambda x: x.features)) - labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions) - testMSE = labelsAndPredictions.map(lambda vp: (vp[0] - vp[1]) * (vp[0] - vp[1])).sum() \ - / float(testData.count()) - print('Test Mean Squared Error = ' + str(testMSE)) - print('Learned regression ensemble model:') - print(model.toDebugString()) - - -if __name__ == "__main__": - if len(sys.argv) > 1: - print("Usage: gradient_boosted_trees", file=sys.stderr) - exit(1) - sc = SparkContext(appName="PythonGradientBoostedTrees") - - # Load and parse the data file into an RDD of LabeledPoint. 
- data = MLUtils.loadLibSVMFile(sc, 'data/mllib/sample_libsvm_data.txt') - # Split the data into training and test sets (30% held out for testing) - (trainingData, testData) = data.randomSplit([0.7, 0.3]) - - print('\nRunning example of classification using GradientBoostedTrees\n') - testClassification(trainingData, testData) - - print('\nRunning example of regression using GradientBoostedTrees\n') - testRegression(trainingData, testData) - - sc.stop() diff --git a/examples/src/main/python/mllib/random_forest_example.py b/examples/src/main/python/mllib/random_forest_example.py deleted file mode 100755 index 4cfdad868c66..000000000000 --- a/examples/src/main/python/mllib/random_forest_example.py +++ /dev/null @@ -1,90 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -""" -Random Forest classification and regression using MLlib. - -Note: This example illustrates binary classification. - For information on multiclass classification, please refer to the decision_tree_runner.py - example. -""" -from __future__ import print_function - -import sys - -from pyspark.context import SparkContext -from pyspark.mllib.tree import RandomForest -from pyspark.mllib.util import MLUtils - - -def testClassification(trainingData, testData): - # Train a RandomForest model. - # Empty categoricalFeaturesInfo indicates all features are continuous. - # Note: Use larger numTrees in practice. - # Setting featureSubsetStrategy="auto" lets the algorithm choose. - model = RandomForest.trainClassifier(trainingData, numClasses=2, - categoricalFeaturesInfo={}, - numTrees=3, featureSubsetStrategy="auto", - impurity='gini', maxDepth=4, maxBins=32) - - # Evaluate model on test instances and compute test error - predictions = model.predict(testData.map(lambda x: x.features)) - labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions) - testErr = labelsAndPredictions.filter(lambda v_p: v_p[0] != v_p[1]).count()\ - / float(testData.count()) - print('Test Error = ' + str(testErr)) - print('Learned classification forest model:') - print(model.toDebugString()) - - -def testRegression(trainingData, testData): - # Train a RandomForest model. - # Empty categoricalFeaturesInfo indicates all features are continuous. - # Note: Use larger numTrees in practice. - # Setting featureSubsetStrategy="auto" lets the algorithm choose. 
- model = RandomForest.trainRegressor(trainingData, categoricalFeaturesInfo={}, - numTrees=3, featureSubsetStrategy="auto", - impurity='variance', maxDepth=4, maxBins=32) - - # Evaluate model on test instances and compute test error - predictions = model.predict(testData.map(lambda x: x.features)) - labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions) - testMSE = labelsAndPredictions.map(lambda v_p1: (v_p1[0] - v_p1[1]) * (v_p1[0] - v_p1[1]))\ - .sum() / float(testData.count()) - print('Test Mean Squared Error = ' + str(testMSE)) - print('Learned regression forest model:') - print(model.toDebugString()) - - -if __name__ == "__main__": - if len(sys.argv) > 1: - print("Usage: random_forest_example", file=sys.stderr) - exit(1) - sc = SparkContext(appName="PythonRandomForestExample") - - # Load and parse the data file into an RDD of LabeledPoint. - data = MLUtils.loadLibSVMFile(sc, 'data/mllib/sample_libsvm_data.txt') - # Split the data into training and test sets (30% held out for testing) - (trainingData, testData) = data.randomSplit([0.7, 0.3]) - - print('\nRunning example of classification using RandomForest\n') - testClassification(trainingData, testData) - - print('\nRunning example of regression using RandomForest\n') - testRegression(trainingData, testData) - - sc.stop() From 55358889309cf2d856b72e72e0f3081dfdf61cfa Mon Sep 17 00:00:00 2001 From: Feynman Liang Date: Mon, 30 Nov 2015 15:38:44 -0800 Subject: [PATCH 1598/1824] [SPARK-11960][MLLIB][DOC] User guide for streaming tests CC jkbradley mengxr josepablocam Author: Feynman Liang Closes #10005 from feynmanliang/streaming-test-user-guide. --- docs/mllib-guide.md | 1 + docs/mllib-statistics.md | 25 +++++++++++++++++++ .../examples/mllib/StreamingTestExample.scala | 2 ++ 3 files changed, 28 insertions(+) diff --git a/docs/mllib-guide.md b/docs/mllib-guide.md index 54e35fcbb15a..43772adcf26e 100644 --- a/docs/mllib-guide.md +++ b/docs/mllib-guide.md @@ -34,6 +34,7 @@ We list major functionality from both below, with links to detailed guides. * [correlations](mllib-statistics.html#correlations) * [stratified sampling](mllib-statistics.html#stratified-sampling) * [hypothesis testing](mllib-statistics.html#hypothesis-testing) + * [streaming significance testing](mllib-statistics.html#streaming-significance-testing) * [random data generation](mllib-statistics.html#random-data-generation) * [Classification and regression](mllib-classification-regression.html) * [linear models (SVMs, logistic regression, linear regression)](mllib-linear-methods.html) diff --git a/docs/mllib-statistics.md b/docs/mllib-statistics.md index ade5b0768aef..de209f68e19c 100644 --- a/docs/mllib-statistics.md +++ b/docs/mllib-statistics.md @@ -521,6 +521,31 @@ print(testResult) # summary of the test including the p-value, test statistic, +### Streaming Significance Testing +MLlib provides online implementations of some tests to support use cases +like A/B testing. These tests may be performed on a Spark Streaming +`DStream[(Boolean,Double)]` where the first element of each tuple +indicates control group (`false`) or treatment group (`true`) and the +second element is the value of an observation. + +Streaming significance testing supports the following parameters: + +* `peacePeriod` - The number of initial data points from the stream to +ignore, used to mitigate novelty effects. +* `windowSize` - The number of past batches to perform hypothesis +testing over. Setting to `0` will perform cumulative processing using +all prior batches. + + +
+<div class="codetabs"> +<div data-lang="scala" markdown="1">
    +[`StreamingTest`](api/scala/index.html#org.apache.spark.mllib.stat.test.StreamingTest) +provides streaming hypothesis testing. + +{% include_example scala/org/apache/spark/examples/mllib/StreamingTestExample.scala %} +
+</div> +</div>
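A minimal sketch of wiring these parameters together, assuming `data` is the `DStream[(Boolean, Double)]` described above and assuming builder-style setters named after the parameters plus a `setTestMethod` that accepts `"welch"`:

{% highlight scala %}
import org.apache.spark.mllib.stat.test.StreamingTest

// `data` is a DStream[(Boolean, Double)]: (isTreatmentGroup, observedValue).
val streamingTest = new StreamingTest()
  .setPeacePeriod(0)       // skip no initial batches
  .setWindowSize(0)        // 0 => cumulative testing over all batches seen so far
  .setTestMethod("welch")  // assumed test name accepted by the implementation

// Each output batch carries the latest test statistic and p-value.
val out = streamingTest.registerStream(data)
out.print()
{% endhighlight %}

Setting `windowSize` to a positive value instead restricts the test to the most recent batches, which limits the influence of stale observations in a long-running A/B test.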
    + ## Random data generation diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingTestExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingTestExample.scala index ab29f90254d3..b6677c647663 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingTestExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingTestExample.scala @@ -64,6 +64,7 @@ object StreamingTestExample { dir.toString }) + // $example on$ val data = ssc.textFileStream(dataDir).map(line => line.split(",") match { case Array(label, value) => (label.toBoolean, value.toDouble) }) @@ -75,6 +76,7 @@ object StreamingTestExample { val out = streamingTest.registerStream(data) out.print() + // $example off$ // Stop processing if test becomes significant or we time out var timeoutCounter = numBatchesTimeout From ecc00ec3faf09b0e21bc4b468cb45e45d05cf482 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 30 Nov 2015 15:42:10 -0800 Subject: [PATCH 1599/1824] fix Maven build --- .../main/scala/org/apache/spark/sql/execution/SparkPlan.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 507641ff8263..a78177751c9d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -43,7 +43,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ * populated by the query planning infrastructure. */ @transient - protected[spark] final val sqlContext = SQLContext.getActive().get + protected[spark] final val sqlContext = SQLContext.getActive().getOrElse(null) protected def sparkContext = sqlContext.sparkContext From edb26e7f4e1164645971c9a139eb29ddec8acc5d Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Mon, 30 Nov 2015 16:31:59 -0800 Subject: [PATCH 1600/1824] [SPARK-12058][HOTFIX] Disable KinesisStreamTests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit KinesisStreamTests in test.py is broken because of #9403. See https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder/46896/testReport/(root)/KinesisStreamTests/test_kinesis_stream/ Because Streaming Python didn’t work when merging https://github.com/apache/spark/pull/9403, the PR build didn’t report the Python test failure actually. This PR just disabled the test to unblock #10039 Author: Shixiong Zhu Closes #10047 from zsxwing/disable-python-kinesis-test. --- python/pyspark/streaming/tests.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index d380d697bc51..a647e6bf3958 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -1409,6 +1409,7 @@ def test_kinesis_stream_api(self): InitialPositionInStream.LATEST, 2, StorageLevel.MEMORY_AND_DISK_2, "awsAccessKey", "awsSecretKey") + @unittest.skip("Enable it when we fix SPAKR-12058") def test_kinesis_stream(self): if not are_kinesis_tests_enabled: sys.stderr.write( From d3ca8cfac286ae19f8bedc736877ea9d0a0a072c Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 30 Nov 2015 16:37:27 -0800 Subject: [PATCH 1601/1824] [SPARK-12000] Fix API doc generation issues This pull request fixes multiple issues with API doc generation. 
- Modify the Jekyll plugin so that the entire doc build fails if API docs cannot be generated. This will make it easy to detect when the doc build breaks, since this will now trigger Jenkins failures. - Change how we handle the `-target` compiler option flag in order to fix `javadoc` generation. - Incorporate doc changes from thunterdb (in #10048). Closes #10048. Author: Josh Rosen Author: Timothy Hunter Closes #10049 from JoshRosen/fix-doc-build. --- docs/_plugins/copy_api_dirs.rb | 6 +++--- .../apache/spark/network/client/StreamCallback.java | 4 ++-- .../org/apache/spark/network/server/RpcHandler.java | 2 +- project/SparkBuild.scala | 11 ++++++++--- 4 files changed, 14 insertions(+), 9 deletions(-) diff --git a/docs/_plugins/copy_api_dirs.rb b/docs/_plugins/copy_api_dirs.rb index 01718d98dffe..f2f3e2e65314 100644 --- a/docs/_plugins/copy_api_dirs.rb +++ b/docs/_plugins/copy_api_dirs.rb @@ -27,7 +27,7 @@ cd("..") puts "Running 'build/sbt -Pkinesis-asl clean compile unidoc' from " + pwd + "; this may take a few minutes..." - puts `build/sbt -Pkinesis-asl clean compile unidoc` + system("build/sbt -Pkinesis-asl clean compile unidoc") || raise("Unidoc generation failed") puts "Moving back into docs dir." cd("docs") @@ -117,7 +117,7 @@ puts "Moving to python/docs directory and building sphinx." cd("../python/docs") - puts `make html` + system(make html) || raise("Python doc generation failed") puts "Moving back into home dir." cd("../../") @@ -131,7 +131,7 @@ # Build SparkR API docs puts "Moving to R directory and building roxygen docs." cd("R") - puts `./create-docs.sh` + system("./create-docs.sh") || raise("R doc generation failed") puts "Moving back into home dir." cd("../") diff --git a/network/common/src/main/java/org/apache/spark/network/client/StreamCallback.java b/network/common/src/main/java/org/apache/spark/network/client/StreamCallback.java index 093fada320cc..51d34cac6e63 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/StreamCallback.java +++ b/network/common/src/main/java/org/apache/spark/network/client/StreamCallback.java @@ -21,8 +21,8 @@ import java.nio.ByteBuffer; /** - * Callback for streaming data. Stream data will be offered to the {@link onData(ByteBuffer)} - * method as it arrives. Once all the stream data is received, {@link onComplete()} will be + * Callback for streaming data. Stream data will be offered to the {@link onData(String, ByteBuffer)} + * method as it arrives. Once all the stream data is received, {@link onComplete(String)} will be * called. *

    * The network library guarantees that a single thread will call these methods at a time, but diff --git a/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java b/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java index 65109ddfe13b..1a11f7b3820c 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java @@ -55,7 +55,7 @@ public abstract void receive( /** * Receives an RPC message that does not expect a reply. The default implementation will - * call "{@link receive(TransportClient, byte[], RpcResponseCallback}" and log a warning if + * call "{@link receive(TransportClient, byte[], RpcResponseCallback)}" and log a warning if * any of the callback methods are called. * * @param client A channel client which enables the handler to make requests back to the sender diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index f575f0012d59..63290d8a666e 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -160,7 +160,12 @@ object SparkBuild extends PomBuild { javacOptions in Compile ++= Seq( "-encoding", "UTF-8", - "-source", javacJVMVersion.value, + "-source", javacJVMVersion.value + ), + // This -target option cannot be set in the Compile configuration scope since `javadoc` doesn't + // play nicely with it; see https://github.com/sbt/sbt/issues/355#issuecomment-3817629 for + // additional discussion and explanation. + javacOptions in (Compile, compile) ++= Seq( "-target", javacJVMVersion.value ), @@ -547,9 +552,9 @@ object Unidoc { publish := {}, unidocProjectFilter in(ScalaUnidoc, unidoc) := - inAnyProject -- inProjects(OldDeps.project, repl, examples, tools, streamingFlumeSink, yarn), + inAnyProject -- inProjects(OldDeps.project, repl, examples, tools, streamingFlumeSink, yarn, testTags), unidocProjectFilter in(JavaUnidoc, unidoc) := - inAnyProject -- inProjects(OldDeps.project, repl, bagel, examples, tools, streamingFlumeSink, yarn), + inAnyProject -- inProjects(OldDeps.project, repl, bagel, examples, tools, streamingFlumeSink, yarn, testTags), // Skip actual catalyst, but include the subproject. // Catalyst is not public API and contains quasiquotes which break scaladoc. From e6dc89a33951e9197a77dbcacf022c27469ae41e Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Mon, 30 Nov 2015 17:18:44 -0800 Subject: [PATCH 1602/1824] [SPARK-12035] Add more debug information in include_example tag of Jekyll https://issues.apache.org/jira/browse/SPARK-12035 When we debuging lots of example code files, like in https://github.com/apache/spark/pull/10002, it's hard to know which file causes errors due to limited information in `include_example.rb`. With their filenames, we can locate bugs easily. Author: Xusen Yin Closes #10026 from yinxusen/SPARK-12035. --- docs/_plugins/include_example.rb | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/docs/_plugins/include_example.rb b/docs/_plugins/include_example.rb index 564c86680f68..f7485826a762 100644 --- a/docs/_plugins/include_example.rb +++ b/docs/_plugins/include_example.rb @@ -75,10 +75,10 @@ def select_lines(code) .select { |l, i| l.include? "$example off$" } .map { |l, i| i } - raise "Start indices amount is not equal to end indices amount, please check the code." \ + raise "Start indices amount is not equal to end indices amount, see #{@file}." 
\ unless startIndices.size == endIndices.size - raise "No code is selected by include_example, please check the code." \ + raise "No code is selected by include_example, see #{@file}." \ if startIndices.size == 0 # Select and join code blocks together, with a space line between each of two continuous @@ -86,8 +86,10 @@ def select_lines(code) lastIndex = -1 result = "" startIndices.zip(endIndices).each do |start, endline| - raise "Overlapping between two example code blocks are not allowed." if start <= lastIndex - raise "$example on$ should not be in the same line with $example off$." if start == endline + raise "Overlapping between two example code blocks are not allowed, see #{@file}." \ + if start <= lastIndex + raise "$example on$ should not be in the same line with $example off$, see #{@file}." \ + if start == endline lastIndex = endline range = Range.new(start + 1, endline - 1) result += trim_codeblock(lines[range]).join From 0a46e4377216a1f7de478f220c3b3042a77789e2 Mon Sep 17 00:00:00 2001 From: CodingCat Date: Mon, 30 Nov 2015 17:19:26 -0800 Subject: [PATCH 1603/1824] [SPARK-12037][CORE] initialize heartbeatReceiverRef before calling startDriverHeartbeat https://issues.apache.org/jira/browse/SPARK-12037 a simple fix by changing the order of the statements Author: CodingCat Closes #10032 from CodingCat/SPARK-12037. --- .../main/scala/org/apache/spark/executor/Executor.scala | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 6154f06e3ac1..7b68dfe5ad06 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -109,6 +109,10 @@ private[spark] class Executor( // Executor for the heartbeat task. private val heartbeater = ThreadUtils.newDaemonSingleThreadScheduledExecutor("driver-heartbeater") + // must be initialized before running startDriverHeartbeat() + private val heartbeatReceiverRef = + RpcUtils.makeDriverRef(HeartbeatReceiver.ENDPOINT_NAME, conf, env.rpcEnv) + startDriverHeartbeater() def launchTask( @@ -411,9 +415,6 @@ private[spark] class Executor( } } - private val heartbeatReceiverRef = - RpcUtils.makeDriverRef(HeartbeatReceiver.ENDPOINT_NAME, conf, env.rpcEnv) - /** Reports heartbeat and metrics for active tasks to the driver. */ private def reportHeartBeat(): Unit = { // list of (task id, metrics) to send back to the driver From 9bf2120672ae0f620a217ccd96bef189ab75e0d6 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Mon, 30 Nov 2015 17:22:05 -0800 Subject: [PATCH 1604/1824] [SPARK-12007][NETWORK] Avoid copies in the network lib's RPC layer. This change seems large, but most of it is just replacing `byte[]` with `ByteBuffer` and `new byte[]` with `ByteBuffer.allocate()`, since it changes the network library's API. The following are parts of the code that actually have meaningful changes: - The Message implementations were changed to inherit from a new AbstractMessage that can optionally hold a reference to a body (in the form of a ManagedBuffer); this is similar to how ResponseWithBody worked before, except now it's not restricted to just responses. - The TransportFrameDecoder was pretty much rewritten to avoid copies as much as possible; it doesn't rely on CompositeByteBuf to accumulate incoming data anymore, since CompositeByteBuf has issues when slices are retained. 
The code now is able to create frames without having to resort to copying bytes except for a few bytes (containing the frame length) in very rare cases. - Some minor changes in the SASL layer to convert things back to `byte[]` since the JDK SASL API operates on those. Author: Marcelo Vanzin Closes #9987 from vanzin/SPARK-12007. --- .../mesos/MesosExternalShuffleService.scala | 3 +- .../network/netty/NettyBlockRpcServer.scala | 8 +- .../netty/NettyBlockTransferService.scala | 6 +- .../apache/spark/rpc/netty/NettyRpcEnv.scala | 16 +- .../org/apache/spark/rpc/netty/Outbox.scala | 9 +- .../rpc/netty/NettyRpcHandlerSuite.scala | 3 +- .../network/client/RpcResponseCallback.java | 4 +- .../spark/network/client/TransportClient.java | 16 +- .../client/TransportResponseHandler.java | 16 +- .../network/protocol/AbstractMessage.java | 54 +++++++ ...Body.java => AbstractResponseMessage.java} | 16 +- .../network/protocol/ChunkFetchFailure.java | 2 +- .../network/protocol/ChunkFetchRequest.java | 2 +- .../network/protocol/ChunkFetchSuccess.java | 8 +- .../spark/network/protocol/Message.java | 11 +- .../network/protocol/MessageEncoder.java | 29 ++-- .../spark/network/protocol/OneWayMessage.java | 33 ++-- .../spark/network/protocol/RpcFailure.java | 2 +- .../spark/network/protocol/RpcRequest.java | 34 +++-- .../spark/network/protocol/RpcResponse.java | 39 +++-- .../spark/network/protocol/StreamFailure.java | 2 +- .../spark/network/protocol/StreamRequest.java | 2 +- .../network/protocol/StreamResponse.java | 19 +-- .../network/sasl/SaslClientBootstrap.java | 12 +- .../spark/network/sasl/SaslMessage.java | 31 ++-- .../spark/network/sasl/SaslRpcHandler.java | 26 +++- .../spark/network/server/MessageHandler.java | 2 +- .../spark/network/server/NoOpRpcHandler.java | 8 +- .../spark/network/server/RpcHandler.java | 8 +- .../server/TransportChannelHandler.java | 2 +- .../server/TransportRequestHandler.java | 15 +- .../apache/spark/network/util/JavaUtils.java | 48 ++++-- .../network/util/TransportFrameDecoder.java | 142 ++++++++++++------ .../network/ChunkFetchIntegrationSuite.java | 5 +- .../apache/spark/network/ProtocolSuite.java | 10 +- .../RequestTimeoutIntegrationSuite.java | 51 ++++--- .../spark/network/RpcIntegrationSuite.java | 26 ++-- .../org/apache/spark/network/StreamSuite.java | 5 +- .../TransportResponseHandlerSuite.java | 24 +-- .../spark/network/sasl/SparkSaslSuite.java | 43 ++++-- .../util/TransportFrameDecoderSuite.java | 23 ++- .../shuffle/ExternalShuffleBlockHandler.java | 9 +- .../shuffle/ExternalShuffleClient.java | 3 +- .../shuffle/OneForOneBlockFetcher.java | 7 +- .../mesos/MesosExternalShuffleClient.java | 5 +- .../protocol/BlockTransferMessage.java | 8 +- .../network/sasl/SaslIntegrationSuite.java | 18 ++- .../shuffle/BlockTransferMessagesSuite.java | 2 +- .../ExternalShuffleBlockHandlerSuite.java | 21 +-- .../shuffle/OneForOneBlockFetcherSuite.java | 8 +- 50 files changed, 589 insertions(+), 307 deletions(-) create mode 100644 network/common/src/main/java/org/apache/spark/network/protocol/AbstractMessage.java rename network/common/src/main/java/org/apache/spark/network/protocol/{ResponseWithBody.java => AbstractResponseMessage.java} (63%) diff --git a/core/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala index 12337a940a41..8ffcfc0878a4 100644 --- a/core/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala +++ 
b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala @@ -18,6 +18,7 @@ package org.apache.spark.deploy.mesos import java.net.SocketAddress +import java.nio.ByteBuffer import scala.collection.mutable @@ -56,7 +57,7 @@ private[mesos] class MesosExternalShuffleBlockHandler(transportConf: TransportCo } } connectedApps(address) = appId - callback.onSuccess(new Array[Byte](0)) + callback.onSuccess(ByteBuffer.allocate(0)) case _ => super.handleMessage(message, client, callback) } } diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala index 76968249fb62..df8c21fb837e 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala @@ -47,9 +47,9 @@ class NettyBlockRpcServer( override def receive( client: TransportClient, - messageBytes: Array[Byte], + rpcMessage: ByteBuffer, responseContext: RpcResponseCallback): Unit = { - val message = BlockTransferMessage.Decoder.fromByteArray(messageBytes) + val message = BlockTransferMessage.Decoder.fromByteBuffer(rpcMessage) logTrace(s"Received request: $message") message match { @@ -58,7 +58,7 @@ class NettyBlockRpcServer( openBlocks.blockIds.map(BlockId.apply).map(blockManager.getBlockData) val streamId = streamManager.registerStream(appId, blocks.iterator.asJava) logTrace(s"Registered streamId $streamId with ${blocks.size} buffers") - responseContext.onSuccess(new StreamHandle(streamId, blocks.size).toByteArray) + responseContext.onSuccess(new StreamHandle(streamId, blocks.size).toByteBuffer) case uploadBlock: UploadBlock => // StorageLevel is serialized as bytes using our JavaSerializer. 
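To make the `byte[]` -> `ByteBuffer` handler API concrete, here is a minimal sketch of an RPC handler written against the new signatures shown in this patch (`receive` taking a `ByteBuffer`, `RpcResponseCallback.onSuccess` taking a `ByteBuffer`). It is illustrative only: `AckRpcHandler` is invented, and it assumes the existing `OneForOneStreamManager` as the stream manager.

```scala
import java.nio.ByteBuffer

import org.apache.spark.network.client.{RpcResponseCallback, TransportClient}
import org.apache.spark.network.server.{OneForOneStreamManager, RpcHandler, StreamManager}

// Hypothetical handler (not part of this patch): with SPARK-12007 the message body
// arrives as a ByteBuffer view of the frame instead of a copied byte[] array.
class AckRpcHandler extends RpcHandler {

  private val streamManager = new OneForOneStreamManager()

  override def receive(
      client: TransportClient,
      message: ByteBuffer,
      callback: RpcResponseCallback): Unit = {
    // Acknowledge with an empty buffer, mirroring how the handlers in this patch
    // respond when there is no payload to return.
    callback.onSuccess(ByteBuffer.allocate(0))
  }

  override def getStreamManager(): StreamManager = streamManager
}
```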
@@ -66,7 +66,7 @@ class NettyBlockRpcServer( serializer.newInstance().deserialize(ByteBuffer.wrap(uploadBlock.metadata)) val data = new NioManagedBuffer(ByteBuffer.wrap(uploadBlock.blockData)) blockManager.putBlockData(BlockId(uploadBlock.blockId), data, level) - responseContext.onSuccess(new Array[Byte](0)) + responseContext.onSuccess(ByteBuffer.allocate(0)) } } diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala index b0694e3c6c8a..82c16e855b0c 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala @@ -17,6 +17,8 @@ package org.apache.spark.network.netty +import java.nio.ByteBuffer + import scala.collection.JavaConverters._ import scala.concurrent.{Future, Promise} @@ -133,9 +135,9 @@ class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManage data } - client.sendRpc(new UploadBlock(appId, execId, blockId.toString, levelBytes, array).toByteArray, + client.sendRpc(new UploadBlock(appId, execId, blockId.toString, levelBytes, array).toByteBuffer, new RpcResponseCallback { - override def onSuccess(response: Array[Byte]): Unit = { + override def onSuccess(response: ByteBuffer): Unit = { logTrace(s"Successfully uploaded block $blockId") result.success((): Unit) } diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala index c7d74fa1d919..68c5f44145b0 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala @@ -241,16 +241,14 @@ private[netty] class NettyRpcEnv( promise.future.mapTo[T].recover(timeout.addMessageIfTimeout)(ThreadUtils.sameThread) } - private[netty] def serialize(content: Any): Array[Byte] = { - val buffer = javaSerializerInstance.serialize(content) - java.util.Arrays.copyOfRange( - buffer.array(), buffer.arrayOffset + buffer.position, buffer.arrayOffset + buffer.limit) + private[netty] def serialize(content: Any): ByteBuffer = { + javaSerializerInstance.serialize(content) } - private[netty] def deserialize[T: ClassTag](client: TransportClient, bytes: Array[Byte]): T = { + private[netty] def deserialize[T: ClassTag](client: TransportClient, bytes: ByteBuffer): T = { NettyRpcEnv.currentClient.withValue(client) { deserialize { () => - javaSerializerInstance.deserialize[T](ByteBuffer.wrap(bytes)) + javaSerializerInstance.deserialize[T](bytes) } } } @@ -557,7 +555,7 @@ private[netty] class NettyRpcHandler( override def receive( client: TransportClient, - message: Array[Byte], + message: ByteBuffer, callback: RpcResponseCallback): Unit = { val messageToDispatch = internalReceive(client, message) dispatcher.postRemoteMessage(messageToDispatch, callback) @@ -565,12 +563,12 @@ private[netty] class NettyRpcHandler( override def receive( client: TransportClient, - message: Array[Byte]): Unit = { + message: ByteBuffer): Unit = { val messageToDispatch = internalReceive(client, message) dispatcher.postOneWayMessage(messageToDispatch) } - private def internalReceive(client: TransportClient, message: Array[Byte]): RequestMessage = { + private def internalReceive(client: TransportClient, message: ByteBuffer): RequestMessage = { val addr = client.getChannel().remoteAddress().asInstanceOf[InetSocketAddress] assert(addr != null) val clientAddr = 
RpcAddress(addr.getHostName, addr.getPort) diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala index 36fdd00bbc4c..2316ebe347bb 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala @@ -17,6 +17,7 @@ package org.apache.spark.rpc.netty +import java.nio.ByteBuffer import java.util.concurrent.Callable import javax.annotation.concurrent.GuardedBy @@ -34,7 +35,7 @@ private[netty] sealed trait OutboxMessage { } -private[netty] case class OneWayOutboxMessage(content: Array[Byte]) extends OutboxMessage +private[netty] case class OneWayOutboxMessage(content: ByteBuffer) extends OutboxMessage with Logging { override def sendWith(client: TransportClient): Unit = { @@ -48,9 +49,9 @@ private[netty] case class OneWayOutboxMessage(content: Array[Byte]) extends Outb } private[netty] case class RpcOutboxMessage( - content: Array[Byte], + content: ByteBuffer, _onFailure: (Throwable) => Unit, - _onSuccess: (TransportClient, Array[Byte]) => Unit) + _onSuccess: (TransportClient, ByteBuffer) => Unit) extends OutboxMessage with RpcResponseCallback { private var client: TransportClient = _ @@ -70,7 +71,7 @@ private[netty] case class RpcOutboxMessage( _onFailure(e) } - override def onSuccess(response: Array[Byte]): Unit = { + override def onSuccess(response: ByteBuffer): Unit = { _onSuccess(client, response) } diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala index 323184cdd9b6..ebd6f700710b 100644 --- a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.rpc.netty import java.net.InetSocketAddress +import java.nio.ByteBuffer import io.netty.channel.Channel import org.mockito.Mockito._ @@ -32,7 +33,7 @@ class NettyRpcHandlerSuite extends SparkFunSuite { val env = mock(classOf[NettyRpcEnv]) val sm = mock(classOf[StreamManager]) - when(env.deserialize(any(classOf[TransportClient]), any(classOf[Array[Byte]]))(any())) + when(env.deserialize(any(classOf[TransportClient]), any(classOf[ByteBuffer]))(any())) .thenReturn(RequestMessage(RpcAddress("localhost", 12345), null, null)) test("receive") { diff --git a/network/common/src/main/java/org/apache/spark/network/client/RpcResponseCallback.java b/network/common/src/main/java/org/apache/spark/network/client/RpcResponseCallback.java index 6ec960d79542..47e93f9846fa 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/RpcResponseCallback.java +++ b/network/common/src/main/java/org/apache/spark/network/client/RpcResponseCallback.java @@ -17,13 +17,15 @@ package org.apache.spark.network.client; +import java.nio.ByteBuffer; + /** * Callback for the result of a single RPC. This will be invoked once with either success or * failure. */ public interface RpcResponseCallback { /** Successful serialized result from server. */ - void onSuccess(byte[] response); + void onSuccess(ByteBuffer response); /** Exception either propagated from server or raised on client side. 
*/ void onFailure(Throwable e); diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java index 8a58e7b24585..c49ca4d5ee92 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java @@ -20,6 +20,7 @@ import java.io.Closeable; import java.io.IOException; import java.net.SocketAddress; +import java.nio.ByteBuffer; import java.util.UUID; import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; @@ -36,6 +37,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.apache.spark.network.buffer.NioManagedBuffer; import org.apache.spark.network.protocol.ChunkFetchRequest; import org.apache.spark.network.protocol.OneWayMessage; import org.apache.spark.network.protocol.RpcRequest; @@ -212,7 +214,7 @@ public void operationComplete(ChannelFuture future) throws Exception { * @param callback Callback to handle the RPC's reply. * @return The RPC's id. */ - public long sendRpc(byte[] message, final RpcResponseCallback callback) { + public long sendRpc(ByteBuffer message, final RpcResponseCallback callback) { final String serverAddr = NettyUtils.getRemoteAddress(channel); final long startTime = System.currentTimeMillis(); logger.trace("Sending RPC to {}", serverAddr); @@ -220,7 +222,7 @@ public long sendRpc(byte[] message, final RpcResponseCallback callback) { final long requestId = Math.abs(UUID.randomUUID().getLeastSignificantBits()); handler.addRpcRequest(requestId, callback); - channel.writeAndFlush(new RpcRequest(requestId, message)).addListener( + channel.writeAndFlush(new RpcRequest(requestId, new NioManagedBuffer(message))).addListener( new ChannelFutureListener() { @Override public void operationComplete(ChannelFuture future) throws Exception { @@ -249,12 +251,12 @@ public void operationComplete(ChannelFuture future) throws Exception { * Synchronously sends an opaque message to the RpcHandler on the server-side, waiting for up to * a specified timeout for a response. */ - public byte[] sendRpcSync(byte[] message, long timeoutMs) { - final SettableFuture result = SettableFuture.create(); + public ByteBuffer sendRpcSync(ByteBuffer message, long timeoutMs) { + final SettableFuture result = SettableFuture.create(); sendRpc(message, new RpcResponseCallback() { @Override - public void onSuccess(byte[] response) { + public void onSuccess(ByteBuffer response) { result.set(response); } @@ -279,8 +281,8 @@ public void onFailure(Throwable e) { * * @param message The message to send. 
*/ - public void send(byte[] message) { - channel.writeAndFlush(new OneWayMessage(message)); + public void send(ByteBuffer message) { + channel.writeAndFlush(new OneWayMessage(new NioManagedBuffer(message))); } /** diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java b/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java index 4c15045363b8..23a8dba59344 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java @@ -136,7 +136,7 @@ public void exceptionCaught(Throwable cause) { } @Override - public void handle(ResponseMessage message) { + public void handle(ResponseMessage message) throws Exception { String remoteAddress = NettyUtils.getRemoteAddress(channel); if (message instanceof ChunkFetchSuccess) { ChunkFetchSuccess resp = (ChunkFetchSuccess) message; @@ -144,11 +144,11 @@ public void handle(ResponseMessage message) { if (listener == null) { logger.warn("Ignoring response for block {} from {} since it is not outstanding", resp.streamChunkId, remoteAddress); - resp.body.release(); + resp.body().release(); } else { outstandingFetches.remove(resp.streamChunkId); - listener.onSuccess(resp.streamChunkId.chunkIndex, resp.body); - resp.body.release(); + listener.onSuccess(resp.streamChunkId.chunkIndex, resp.body()); + resp.body().release(); } } else if (message instanceof ChunkFetchFailure) { ChunkFetchFailure resp = (ChunkFetchFailure) message; @@ -166,10 +166,14 @@ public void handle(ResponseMessage message) { RpcResponseCallback listener = outstandingRpcs.get(resp.requestId); if (listener == null) { logger.warn("Ignoring response for RPC {} from {} ({} bytes) since it is not outstanding", - resp.requestId, remoteAddress, resp.response.length); + resp.requestId, remoteAddress, resp.body().size()); } else { outstandingRpcs.remove(resp.requestId); - listener.onSuccess(resp.response); + try { + listener.onSuccess(resp.body().nioByteBuffer()); + } finally { + resp.body().release(); + } } } else if (message instanceof RpcFailure) { RpcFailure resp = (RpcFailure) message; diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/AbstractMessage.java b/network/common/src/main/java/org/apache/spark/network/protocol/AbstractMessage.java new file mode 100644 index 000000000000..2924218c2f08 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/protocol/AbstractMessage.java @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.network.protocol; + +import com.google.common.base.Objects; + +import org.apache.spark.network.buffer.ManagedBuffer; + +/** + * Abstract class for messages which optionally contain a body kept in a separate buffer. + */ +public abstract class AbstractMessage implements Message { + private final ManagedBuffer body; + private final boolean isBodyInFrame; + + protected AbstractMessage() { + this(null, false); + } + + protected AbstractMessage(ManagedBuffer body, boolean isBodyInFrame) { + this.body = body; + this.isBodyInFrame = isBodyInFrame; + } + + @Override + public ManagedBuffer body() { + return body; + } + + @Override + public boolean isBodyInFrame() { + return isBodyInFrame; + } + + protected boolean equals(AbstractMessage other) { + return isBodyInFrame == other.isBodyInFrame && Objects.equal(body, other.body); + } + +} diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/ResponseWithBody.java b/network/common/src/main/java/org/apache/spark/network/protocol/AbstractResponseMessage.java similarity index 63% rename from network/common/src/main/java/org/apache/spark/network/protocol/ResponseWithBody.java rename to network/common/src/main/java/org/apache/spark/network/protocol/AbstractResponseMessage.java index 67be77e39f71..c362c92fc4f5 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/ResponseWithBody.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/AbstractResponseMessage.java @@ -17,23 +17,15 @@ package org.apache.spark.network.protocol; -import com.google.common.base.Objects; -import io.netty.buffer.ByteBuf; - import org.apache.spark.network.buffer.ManagedBuffer; -import org.apache.spark.network.buffer.NettyManagedBuffer; /** - * Abstract class for response messages that contain a large data portion kept in a separate - * buffer. These messages are treated especially by MessageEncoder. + * Abstract class for response messages. */ -public abstract class ResponseWithBody implements ResponseMessage { - public final ManagedBuffer body; - public final boolean isBodyInFrame; +public abstract class AbstractResponseMessage extends AbstractMessage implements ResponseMessage { - protected ResponseWithBody(ManagedBuffer body, boolean isBodyInFrame) { - this.body = body; - this.isBodyInFrame = isBodyInFrame; + protected AbstractResponseMessage(ManagedBuffer body, boolean isBodyInFrame) { + super(body, isBodyInFrame); } public abstract ResponseMessage createFailureResponse(String error); diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java index f0363830b61a..7b28a9a96948 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java @@ -23,7 +23,7 @@ /** * Response to {@link ChunkFetchRequest} when there is an error fetching the chunk. 
*/ -public final class ChunkFetchFailure implements ResponseMessage { +public final class ChunkFetchFailure extends AbstractMessage implements ResponseMessage { public final StreamChunkId streamChunkId; public final String errorString; diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchRequest.java b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchRequest.java index 5a173af54f61..26d063feb5fe 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchRequest.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchRequest.java @@ -24,7 +24,7 @@ * Request to fetch a sequence of a single chunk of a stream. This will correspond to a single * {@link org.apache.spark.network.protocol.ResponseMessage} (either success or failure). */ -public final class ChunkFetchRequest implements RequestMessage { +public final class ChunkFetchRequest extends AbstractMessage implements RequestMessage { public final StreamChunkId streamChunkId; public ChunkFetchRequest(StreamChunkId streamChunkId) { diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java index e6a7e9a8b414..94c2ac9b20e4 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java @@ -30,7 +30,7 @@ * may be written by Netty in a more efficient manner (i.e., zero-copy write). * Similarly, the client-side decoding will reuse the Netty ByteBuf as the buffer. */ -public final class ChunkFetchSuccess extends ResponseWithBody { +public final class ChunkFetchSuccess extends AbstractResponseMessage { public final StreamChunkId streamChunkId; public ChunkFetchSuccess(StreamChunkId streamChunkId, ManagedBuffer buffer) { @@ -67,14 +67,14 @@ public static ChunkFetchSuccess decode(ByteBuf buf) { @Override public int hashCode() { - return Objects.hashCode(streamChunkId, body); + return Objects.hashCode(streamChunkId, body()); } @Override public boolean equals(Object other) { if (other instanceof ChunkFetchSuccess) { ChunkFetchSuccess o = (ChunkFetchSuccess) other; - return streamChunkId.equals(o.streamChunkId) && body.equals(o.body); + return streamChunkId.equals(o.streamChunkId) && super.equals(o); } return false; } @@ -83,7 +83,7 @@ public boolean equals(Object other) { public String toString() { return Objects.toStringHelper(this) .add("streamChunkId", streamChunkId) - .add("buffer", body) + .add("buffer", body()) .toString(); } } diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/Message.java b/network/common/src/main/java/org/apache/spark/network/protocol/Message.java index 39afd03db60e..66f5b8b3a59c 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/Message.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/Message.java @@ -19,17 +19,25 @@ import io.netty.buffer.ByteBuf; +import org.apache.spark.network.buffer.ManagedBuffer; + /** An on-the-wire transmittable message. */ public interface Message extends Encodable { /** Used to identify this request type. */ Type type(); + /** An optional body for the message. */ + ManagedBuffer body(); + + /** Whether to include the body of the message in the same frame as the message. 
*/ + boolean isBodyInFrame(); + /** Preceding every serialized Message is its type, which allows us to deserialize it. */ public static enum Type implements Encodable { ChunkFetchRequest(0), ChunkFetchSuccess(1), ChunkFetchFailure(2), RpcRequest(3), RpcResponse(4), RpcFailure(5), StreamRequest(6), StreamResponse(7), StreamFailure(8), - OneWayMessage(9); + OneWayMessage(9), User(-1); private final byte id; @@ -57,6 +65,7 @@ public static Type decode(ByteBuf buf) { case 7: return StreamResponse; case 8: return StreamFailure; case 9: return OneWayMessage; + case -1: throw new IllegalArgumentException("User type messages cannot be decoded."); default: throw new IllegalArgumentException("Unknown message type: " + id); } } diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java b/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java index 6cce97c807dc..abca22347b78 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java @@ -42,25 +42,28 @@ public final class MessageEncoder extends MessageToMessageEncoder { * data to 'out', in order to enable zero-copy transfer. */ @Override - public void encode(ChannelHandlerContext ctx, Message in, List out) { + public void encode(ChannelHandlerContext ctx, Message in, List out) throws Exception { Object body = null; long bodyLength = 0; boolean isBodyInFrame = false; - // Detect ResponseWithBody messages and get the data buffer out of them. - // The body is used in order to enable zero-copy transfer for the payload. - if (in instanceof ResponseWithBody) { - ResponseWithBody resp = (ResponseWithBody) in; + // If the message has a body, take it out to enable zero-copy transfer for the payload. + if (in.body() != null) { try { - bodyLength = resp.body.size(); - body = resp.body.convertToNetty(); - isBodyInFrame = resp.isBodyInFrame; + bodyLength = in.body().size(); + body = in.body().convertToNetty(); + isBodyInFrame = in.isBodyInFrame(); } catch (Exception e) { - // Re-encode this message as a failure response. - String error = e.getMessage() != null ? e.getMessage() : "null"; - logger.error(String.format("Error processing %s for client %s", - resp, ctx.channel().remoteAddress()), e); - encode(ctx, resp.createFailureResponse(error), out); + if (in instanceof AbstractResponseMessage) { + AbstractResponseMessage resp = (AbstractResponseMessage) in; + // Re-encode this message as a failure response. + String error = e.getMessage() != null ? 
e.getMessage() : "null"; + logger.error(String.format("Error processing %s for client %s", + in, ctx.channel().remoteAddress()), e); + encode(ctx, resp.createFailureResponse(error), out); + } else { + throw e; + } return; } } diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/OneWayMessage.java b/network/common/src/main/java/org/apache/spark/network/protocol/OneWayMessage.java index 95a0270be3da..efe0470f3587 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/OneWayMessage.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/OneWayMessage.java @@ -17,21 +17,21 @@ package org.apache.spark.network.protocol; -import java.util.Arrays; - import com.google.common.base.Objects; import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; + +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.buffer.NettyManagedBuffer; /** * A RPC that does not expect a reply, which is handled by a remote * {@link org.apache.spark.network.server.RpcHandler}. */ -public final class OneWayMessage implements RequestMessage { - /** Serialized message to send to remote RpcHandler. */ - public final byte[] message; +public final class OneWayMessage extends AbstractMessage implements RequestMessage { - public OneWayMessage(byte[] message) { - this.message = message; + public OneWayMessage(ManagedBuffer body) { + super(body, true); } @Override @@ -39,29 +39,34 @@ public OneWayMessage(byte[] message) { @Override public int encodedLength() { - return Encoders.ByteArrays.encodedLength(message); + // The integer (a.k.a. the body size) is not really used, since that information is already + // encoded in the frame length. But this maintains backwards compatibility with versions of + // RpcRequest that use Encoders.ByteArrays. + return 4; } @Override public void encode(ByteBuf buf) { - Encoders.ByteArrays.encode(buf, message); + // See comment in encodedLength(). + buf.writeInt((int) body().size()); } public static OneWayMessage decode(ByteBuf buf) { - byte[] message = Encoders.ByteArrays.decode(buf); - return new OneWayMessage(message); + // See comment in encodedLength(). + buf.readInt(); + return new OneWayMessage(new NettyManagedBuffer(buf.retain())); } @Override public int hashCode() { - return Arrays.hashCode(message); + return Objects.hashCode(body()); } @Override public boolean equals(Object other) { if (other instanceof OneWayMessage) { OneWayMessage o = (OneWayMessage) other; - return Arrays.equals(message, o.message); + return super.equals(o); } return false; } @@ -69,7 +74,7 @@ public boolean equals(Object other) { @Override public String toString() { return Objects.toStringHelper(this) - .add("message", message) + .add("body", body()) .toString(); } } diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java b/network/common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java index 2dfc7876ba32..a76624ef5dc9 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java @@ -21,7 +21,7 @@ import io.netty.buffer.ByteBuf; /** Response to {@link RpcRequest} for a failed RPC. 
*/ -public final class RpcFailure implements ResponseMessage { +public final class RpcFailure extends AbstractMessage implements ResponseMessage { public final long requestId; public final String errorString; diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java b/network/common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java index 745039db742f..96213794a801 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java @@ -17,26 +17,25 @@ package org.apache.spark.network.protocol; -import java.util.Arrays; - import com.google.common.base.Objects; import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; + +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.buffer.NettyManagedBuffer; /** * A generic RPC which is handled by a remote {@link org.apache.spark.network.server.RpcHandler}. * This will correspond to a single * {@link org.apache.spark.network.protocol.ResponseMessage} (either success or failure). */ -public final class RpcRequest implements RequestMessage { +public final class RpcRequest extends AbstractMessage implements RequestMessage { /** Used to link an RPC request with its response. */ public final long requestId; - /** Serialized message to send to remote RpcHandler. */ - public final byte[] message; - - public RpcRequest(long requestId, byte[] message) { + public RpcRequest(long requestId, ManagedBuffer message) { + super(message, true); this.requestId = requestId; - this.message = message; } @Override @@ -44,31 +43,36 @@ public RpcRequest(long requestId, byte[] message) { @Override public int encodedLength() { - return 8 + Encoders.ByteArrays.encodedLength(message); + // The integer (a.k.a. the body size) is not really used, since that information is already + // encoded in the frame length. But this maintains backwards compatibility with versions of + // RpcRequest that use Encoders.ByteArrays. + return 8 + 4; } @Override public void encode(ByteBuf buf) { buf.writeLong(requestId); - Encoders.ByteArrays.encode(buf, message); + // See comment in encodedLength(). + buf.writeInt((int) body().size()); } public static RpcRequest decode(ByteBuf buf) { long requestId = buf.readLong(); - byte[] message = Encoders.ByteArrays.decode(buf); - return new RpcRequest(requestId, message); + // See comment in encodedLength(). 
+ buf.readInt(); + return new RpcRequest(requestId, new NettyManagedBuffer(buf.retain())); } @Override public int hashCode() { - return Objects.hashCode(requestId, Arrays.hashCode(message)); + return Objects.hashCode(requestId, body()); } @Override public boolean equals(Object other) { if (other instanceof RpcRequest) { RpcRequest o = (RpcRequest) other; - return requestId == o.requestId && Arrays.equals(message, o.message); + return requestId == o.requestId && super.equals(o); } return false; } @@ -77,7 +81,7 @@ public boolean equals(Object other) { public String toString() { return Objects.toStringHelper(this) .add("requestId", requestId) - .add("message", message) + .add("body", body()) .toString(); } } diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java b/network/common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java index 1671cd444f03..bae866e14a1e 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java @@ -17,49 +17,62 @@ package org.apache.spark.network.protocol; -import java.util.Arrays; - import com.google.common.base.Objects; import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; + +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.buffer.NettyManagedBuffer; /** Response to {@link RpcRequest} for a successful RPC. */ -public final class RpcResponse implements ResponseMessage { +public final class RpcResponse extends AbstractResponseMessage { public final long requestId; - public final byte[] response; - public RpcResponse(long requestId, byte[] response) { + public RpcResponse(long requestId, ManagedBuffer message) { + super(message, true); this.requestId = requestId; - this.response = response; } @Override public Type type() { return Type.RpcResponse; } @Override - public int encodedLength() { return 8 + Encoders.ByteArrays.encodedLength(response); } + public int encodedLength() { + // The integer (a.k.a. the body size) is not really used, since that information is already + // encoded in the frame length. But this maintains backwards compatibility with versions of + // RpcRequest that use Encoders.ByteArrays. + return 8 + 4; + } @Override public void encode(ByteBuf buf) { buf.writeLong(requestId); - Encoders.ByteArrays.encode(buf, response); + // See comment in encodedLength(). + buf.writeInt((int) body().size()); + } + + @Override + public ResponseMessage createFailureResponse(String error) { + return new RpcFailure(requestId, error); } public static RpcResponse decode(ByteBuf buf) { long requestId = buf.readLong(); - byte[] response = Encoders.ByteArrays.decode(buf); - return new RpcResponse(requestId, response); + // See comment in encodedLength(). 
+ buf.readInt(); + return new RpcResponse(requestId, new NettyManagedBuffer(buf.retain())); } @Override public int hashCode() { - return Objects.hashCode(requestId, Arrays.hashCode(response)); + return Objects.hashCode(requestId, body()); } @Override public boolean equals(Object other) { if (other instanceof RpcResponse) { RpcResponse o = (RpcResponse) other; - return requestId == o.requestId && Arrays.equals(response, o.response); + return requestId == o.requestId && super.equals(o); } return false; } @@ -68,7 +81,7 @@ public boolean equals(Object other) { public String toString() { return Objects.toStringHelper(this) .add("requestId", requestId) - .add("response", response) + .add("body", body()) .toString(); } } diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/StreamFailure.java b/network/common/src/main/java/org/apache/spark/network/protocol/StreamFailure.java index e3dade2ebf90..26747ee55b4d 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/StreamFailure.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/StreamFailure.java @@ -26,7 +26,7 @@ /** * Message indicating an error when transferring a stream. */ -public final class StreamFailure implements ResponseMessage { +public final class StreamFailure extends AbstractMessage implements ResponseMessage { public final String streamId; public final String error; diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/StreamRequest.java b/network/common/src/main/java/org/apache/spark/network/protocol/StreamRequest.java index 821e8f53884d..35af5a84ba6b 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/StreamRequest.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/StreamRequest.java @@ -29,7 +29,7 @@ * The stream ID is an arbitrary string that needs to be negotiated between the two endpoints before * the data can be streamed. */ -public final class StreamRequest implements RequestMessage { +public final class StreamRequest extends AbstractMessage implements RequestMessage { public final String streamId; public StreamRequest(String streamId) { diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/StreamResponse.java b/network/common/src/main/java/org/apache/spark/network/protocol/StreamResponse.java index ac5ab9a323a1..51b899930f72 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/StreamResponse.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/StreamResponse.java @@ -30,15 +30,15 @@ * sender. The receiver is expected to set a temporary channel handler that will consume the * number of bytes this message says the stream has. 
*/ -public final class StreamResponse extends ResponseWithBody { - public final String streamId; - public final long byteCount; +public final class StreamResponse extends AbstractResponseMessage { + public final String streamId; + public final long byteCount; - public StreamResponse(String streamId, long byteCount, ManagedBuffer buffer) { - super(buffer, false); - this.streamId = streamId; - this.byteCount = byteCount; - } + public StreamResponse(String streamId, long byteCount, ManagedBuffer buffer) { + super(buffer, false); + this.streamId = streamId; + this.byteCount = byteCount; + } @Override public Type type() { return Type.StreamResponse; } @@ -68,7 +68,7 @@ public static StreamResponse decode(ByteBuf buf) { @Override public int hashCode() { - return Objects.hashCode(byteCount, streamId); + return Objects.hashCode(byteCount, streamId, body()); } @Override @@ -85,6 +85,7 @@ public String toString() { return Objects.toStringHelper(this) .add("streamId", streamId) .add("byteCount", byteCount) + .add("body", body()) .toString(); } diff --git a/network/common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java b/network/common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java index 69923769d44b..68381037d689 100644 --- a/network/common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java +++ b/network/common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java @@ -17,6 +17,8 @@ package org.apache.spark.network.sasl; +import java.io.IOException; +import java.nio.ByteBuffer; import javax.security.sasl.Sasl; import javax.security.sasl.SaslException; @@ -28,6 +30,7 @@ import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.client.TransportClientBootstrap; +import org.apache.spark.network.util.JavaUtils; import org.apache.spark.network.util.TransportConf; /** @@ -70,11 +73,12 @@ public void doBootstrap(TransportClient client, Channel channel) { while (!saslClient.isComplete()) { SaslMessage msg = new SaslMessage(appId, payload); - ByteBuf buf = Unpooled.buffer(msg.encodedLength()); + ByteBuf buf = Unpooled.buffer(msg.encodedLength() + (int) msg.body().size()); msg.encode(buf); + buf.writeBytes(msg.body().nioByteBuffer()); - byte[] response = client.sendRpcSync(buf.array(), conf.saslRTTimeoutMs()); - payload = saslClient.response(response); + ByteBuffer response = client.sendRpcSync(buf.nioBuffer(), conf.saslRTTimeoutMs()); + payload = saslClient.response(JavaUtils.bufferToArray(response)); } client.setClientId(appId); @@ -88,6 +92,8 @@ public void doBootstrap(TransportClient client, Channel channel) { saslClient = null; logger.debug("Channel {} configured for SASL encryption.", client); } + } catch (IOException ioe) { + throw new RuntimeException(ioe); } finally { if (saslClient != null) { try { diff --git a/network/common/src/main/java/org/apache/spark/network/sasl/SaslMessage.java b/network/common/src/main/java/org/apache/spark/network/sasl/SaslMessage.java index cad76ab7aa54..e52b526f09c7 100644 --- a/network/common/src/main/java/org/apache/spark/network/sasl/SaslMessage.java +++ b/network/common/src/main/java/org/apache/spark/network/sasl/SaslMessage.java @@ -18,38 +18,50 @@ package org.apache.spark.network.sasl; import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; -import org.apache.spark.network.protocol.Encodable; +import org.apache.spark.network.buffer.NettyManagedBuffer; import org.apache.spark.network.protocol.Encoders; +import 
org.apache.spark.network.protocol.AbstractMessage; /** * Encodes a Sasl-related message which is attempting to authenticate using some credentials tagged * with the given appId. This appId allows a single SaslRpcHandler to multiplex different * applications which may be using different sets of credentials. */ -class SaslMessage implements Encodable { +class SaslMessage extends AbstractMessage { /** Serialization tag used to catch incorrect payloads. */ private static final byte TAG_BYTE = (byte) 0xEA; public final String appId; - public final byte[] payload; - public SaslMessage(String appId, byte[] payload) { + public SaslMessage(String appId, byte[] message) { + this(appId, Unpooled.wrappedBuffer(message)); + } + + public SaslMessage(String appId, ByteBuf message) { + super(new NettyManagedBuffer(message), true); this.appId = appId; - this.payload = payload; } + @Override + public Type type() { return Type.User; } + @Override public int encodedLength() { - return 1 + Encoders.Strings.encodedLength(appId) + Encoders.ByteArrays.encodedLength(payload); + // The integer (a.k.a. the body size) is not really used, since that information is already + // encoded in the frame length. But this maintains backwards compatibility with versions of + // RpcRequest that use Encoders.ByteArrays. + return 1 + Encoders.Strings.encodedLength(appId) + 4; } @Override public void encode(ByteBuf buf) { buf.writeByte(TAG_BYTE); Encoders.Strings.encode(buf, appId); - Encoders.ByteArrays.encode(buf, payload); + // See comment in encodedLength(). + buf.writeInt((int) body().size()); } public static SaslMessage decode(ByteBuf buf) { @@ -59,7 +71,8 @@ public static SaslMessage decode(ByteBuf buf) { } String appId = Encoders.Strings.decode(buf); - byte[] payload = Encoders.ByteArrays.decode(buf); - return new SaslMessage(appId, payload); + // See comment in encodedLength(). + buf.readInt(); + return new SaslMessage(appId, buf.retain()); } } diff --git a/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java b/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java index 830db94b890c..c215bd9d1504 100644 --- a/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java @@ -17,8 +17,11 @@ package org.apache.spark.network.sasl; +import java.io.IOException; +import java.nio.ByteBuffer; import javax.security.sasl.Sasl; +import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; import io.netty.channel.Channel; import org.slf4j.Logger; @@ -28,6 +31,7 @@ import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.server.RpcHandler; import org.apache.spark.network.server.StreamManager; +import org.apache.spark.network.util.JavaUtils; import org.apache.spark.network.util.TransportConf; /** @@ -70,14 +74,20 @@ class SaslRpcHandler extends RpcHandler { } @Override - public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) { + public void receive(TransportClient client, ByteBuffer message, RpcResponseCallback callback) { if (isComplete) { // Authentication complete, delegate to base handler. 
delegate.receive(client, message, callback); return; } - SaslMessage saslMessage = SaslMessage.decode(Unpooled.wrappedBuffer(message)); + ByteBuf nettyBuf = Unpooled.wrappedBuffer(message); + SaslMessage saslMessage; + try { + saslMessage = SaslMessage.decode(nettyBuf); + } finally { + nettyBuf.release(); + } if (saslServer == null) { // First message in the handshake, setup the necessary state. @@ -86,8 +96,14 @@ public void receive(TransportClient client, byte[] message, RpcResponseCallback conf.saslServerAlwaysEncrypt()); } - byte[] response = saslServer.response(saslMessage.payload); - callback.onSuccess(response); + byte[] response; + try { + response = saslServer.response(JavaUtils.bufferToArray( + saslMessage.body().nioByteBuffer())); + } catch (IOException ioe) { + throw new RuntimeException(ioe); + } + callback.onSuccess(ByteBuffer.wrap(response)); // Setup encryption after the SASL response is sent, otherwise the client can't parse the // response. It's ok to change the channel pipeline here since we are processing an incoming @@ -109,7 +125,7 @@ public void receive(TransportClient client, byte[] message, RpcResponseCallback } @Override - public void receive(TransportClient client, byte[] message) { + public void receive(TransportClient client, ByteBuffer message) { delegate.receive(client, message); } diff --git a/network/common/src/main/java/org/apache/spark/network/server/MessageHandler.java b/network/common/src/main/java/org/apache/spark/network/server/MessageHandler.java index b80c15106ecb..3843406b2740 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/MessageHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/server/MessageHandler.java @@ -26,7 +26,7 @@ */ public abstract class MessageHandler { /** Handles the receipt of a single message. */ - public abstract void handle(T message); + public abstract void handle(T message) throws Exception; /** Invoked when an exception was caught on the Channel. */ public abstract void exceptionCaught(Throwable cause); diff --git a/network/common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java b/network/common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java index 1502b7489e86..6ed61da5c7ef 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java @@ -1,5 +1,3 @@ -package org.apache.spark.network.server; - /* * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with @@ -17,6 +15,10 @@ * limitations under the License. 
*/ +package org.apache.spark.network.server; + +import java.nio.ByteBuffer; + import org.apache.spark.network.client.RpcResponseCallback; import org.apache.spark.network.client.TransportClient; @@ -29,7 +31,7 @@ public NoOpRpcHandler() { } @Override - public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) { + public void receive(TransportClient client, ByteBuffer message, RpcResponseCallback callback) { throw new UnsupportedOperationException("Cannot handle messages"); } diff --git a/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java b/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java index 1a11f7b3820c..ee1c68369947 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java @@ -17,6 +17,8 @@ package org.apache.spark.network.server; +import java.nio.ByteBuffer; + import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -44,7 +46,7 @@ public abstract class RpcHandler { */ public abstract void receive( TransportClient client, - byte[] message, + ByteBuffer message, RpcResponseCallback callback); /** @@ -62,7 +64,7 @@ public abstract void receive( * of this RPC. This will always be the exact same object for a particular channel. * @param message The serialized bytes of the RPC. */ - public void receive(TransportClient client, byte[] message) { + public void receive(TransportClient client, ByteBuffer message) { receive(client, message, ONE_WAY_CALLBACK); } @@ -79,7 +81,7 @@ private static class OneWayRpcCallback implements RpcResponseCallback { private final Logger logger = LoggerFactory.getLogger(OneWayRpcCallback.class); @Override - public void onSuccess(byte[] response) { + public void onSuccess(ByteBuffer response) { logger.warn("Response provided for one-way RPC."); } diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java b/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java index 3164e0067903..09435bcbab35 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java @@ -99,7 +99,7 @@ public void channelUnregistered(ChannelHandlerContext ctx) throws Exception { } @Override - public void channelRead0(ChannelHandlerContext ctx, Message request) { + public void channelRead0(ChannelHandlerContext ctx, Message request) throws Exception { if (request instanceof RequestMessage) { requestHandler.handle((RequestMessage) request); } else { diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java index db18ea77d107..c864d7ce16bd 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java @@ -17,6 +17,8 @@ package org.apache.spark.network.server; +import java.nio.ByteBuffer; + import com.google.common.base.Preconditions; import com.google.common.base.Throwables; import io.netty.channel.Channel; @@ -26,6 +28,7 @@ import org.slf4j.LoggerFactory; import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.buffer.NioManagedBuffer; import 
org.apache.spark.network.client.RpcResponseCallback; import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.protocol.ChunkFetchRequest; @@ -143,10 +146,10 @@ private void processStreamRequest(final StreamRequest req) { private void processRpcRequest(final RpcRequest req) { try { - rpcHandler.receive(reverseClient, req.message, new RpcResponseCallback() { + rpcHandler.receive(reverseClient, req.body().nioByteBuffer(), new RpcResponseCallback() { @Override - public void onSuccess(byte[] response) { - respond(new RpcResponse(req.requestId, response)); + public void onSuccess(ByteBuffer response) { + respond(new RpcResponse(req.requestId, new NioManagedBuffer(response))); } @Override @@ -157,14 +160,18 @@ public void onFailure(Throwable e) { } catch (Exception e) { logger.error("Error while invoking RpcHandler#receive() on RPC id " + req.requestId, e); respond(new RpcFailure(req.requestId, Throwables.getStackTraceAsString(e))); + } finally { + req.body().release(); } } private void processOneWayMessage(OneWayMessage req) { try { - rpcHandler.receive(reverseClient, req.message); + rpcHandler.receive(reverseClient, req.body().nioByteBuffer()); } catch (Exception e) { logger.error("Error while invoking RpcHandler#receive() for one-way message.", e); + } finally { + req.body().release(); } } diff --git a/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java b/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java index 7d27439cfde7..b3d8e0cd7cdc 100644 --- a/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java +++ b/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java @@ -132,7 +132,7 @@ private static boolean isSymlink(File file) throws IOException { return !fileInCanonicalDir.getCanonicalFile().equals(fileInCanonicalDir.getAbsoluteFile()); } - private static final ImmutableMap timeSuffixes = + private static final ImmutableMap timeSuffixes = ImmutableMap.builder() .put("us", TimeUnit.MICROSECONDS) .put("ms", TimeUnit.MILLISECONDS) @@ -164,32 +164,32 @@ private static boolean isSymlink(File file) throws IOException { */ private static long parseTimeString(String str, TimeUnit unit) { String lower = str.toLowerCase().trim(); - + try { Matcher m = Pattern.compile("(-?[0-9]+)([a-z]+)?").matcher(lower); if (!m.matches()) { throw new NumberFormatException("Failed to parse time string: " + str); } - + long val = Long.parseLong(m.group(1)); String suffix = m.group(2); - + // Check for invalid suffixes if (suffix != null && !timeSuffixes.containsKey(suffix)) { throw new NumberFormatException("Invalid suffix: \"" + suffix + "\""); } - + // If suffix is valid use that, otherwise none was provided and use the default passed return unit.convert(val, suffix != null ? timeSuffixes.get(suffix) : unit); } catch (NumberFormatException e) { String timeError = "Time must be specified as seconds (s), " + "milliseconds (ms), microseconds (us), minutes (m or min), hour (h), or day (d). " + "E.g. 50s, 100ms, or 250us."; - + throw new NumberFormatException(timeError + "\n" + e.getMessage()); } } - + /** * Convert a time parameter such as (50s, 100ms, or 250us) to milliseconds for internal use. If * no suffix is provided, the passed number is assumed to be in ms. @@ -205,10 +205,10 @@ public static long timeStringAsMs(String str) { public static long timeStringAsSec(String str) { return parseTimeString(str, TimeUnit.SECONDS); } - + /** * Convert a passed byte string (e.g. 
50b, 100kb, or 250mb) to a ByteUnit for - * internal use. If no suffix is provided a direct conversion of the provided default is + * internal use. If no suffix is provided a direct conversion of the provided default is * attempted. */ private static long parseByteString(String str, ByteUnit unit) { @@ -217,7 +217,7 @@ private static long parseByteString(String str, ByteUnit unit) { try { Matcher m = Pattern.compile("([0-9]+)([a-z]+)?").matcher(lower); Matcher fractionMatcher = Pattern.compile("([0-9]+\\.[0-9]+)([a-z]+)?").matcher(lower); - + if (m.matches()) { long val = Long.parseLong(m.group(1)); String suffix = m.group(2); @@ -228,14 +228,14 @@ private static long parseByteString(String str, ByteUnit unit) { } // If suffix is valid use that, otherwise none was provided and use the default passed - return unit.convertFrom(val, suffix != null ? byteSuffixes.get(suffix) : unit); + return unit.convertFrom(val, suffix != null ? byteSuffixes.get(suffix) : unit); } else if (fractionMatcher.matches()) { - throw new NumberFormatException("Fractional values are not supported. Input was: " + throw new NumberFormatException("Fractional values are not supported. Input was: " + fractionMatcher.group(1)); } else { - throw new NumberFormatException("Failed to parse byte string: " + str); + throw new NumberFormatException("Failed to parse byte string: " + str); } - + } catch (NumberFormatException e) { String timeError = "Size must be specified as bytes (b), " + "kibibytes (k), mebibytes (m), gibibytes (g), tebibytes (t), or pebibytes(p). " + @@ -248,7 +248,7 @@ private static long parseByteString(String str, ByteUnit unit) { /** * Convert a passed byte string (e.g. 50b, 100k, or 250m) to bytes for * internal use. - * + * * If no suffix is provided, the passed number is assumed to be in bytes. */ public static long byteStringAsBytes(String str) { @@ -264,7 +264,7 @@ public static long byteStringAsBytes(String str) { public static long byteStringAsKb(String str) { return parseByteString(str, ByteUnit.KiB); } - + /** * Convert a passed byte string (e.g. 50b, 100k, or 250m) to mebibytes for * internal use. @@ -284,4 +284,20 @@ public static long byteStringAsMb(String str) { public static long byteStringAsGb(String str) { return parseByteString(str, ByteUnit.GiB); } + + /** + * Returns a byte array with the buffer's contents, trying to avoid copying the data if + * possible. 
+ */ + public static byte[] bufferToArray(ByteBuffer buffer) { + if (buffer.hasArray() && buffer.arrayOffset() == 0 && + buffer.array().length == buffer.remaining()) { + return buffer.array(); + } else { + byte[] bytes = new byte[buffer.remaining()]; + buffer.get(bytes); + return bytes; + } + } + } diff --git a/network/common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java b/network/common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java index 5889562dd970..a466c729154a 100644 --- a/network/common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java +++ b/network/common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java @@ -17,9 +17,13 @@ package org.apache.spark.network.util; +import java.util.Iterator; +import java.util.LinkedList; + import com.google.common.base.Preconditions; import io.netty.buffer.ByteBuf; import io.netty.buffer.CompositeByteBuf; +import io.netty.buffer.Unpooled; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelInboundHandlerAdapter; @@ -44,84 +48,138 @@ public class TransportFrameDecoder extends ChannelInboundHandlerAdapter { public static final String HANDLER_NAME = "frameDecoder"; private static final int LENGTH_SIZE = 8; private static final int MAX_FRAME_SIZE = Integer.MAX_VALUE; + private static final int UNKNOWN_FRAME_SIZE = -1; + + private final LinkedList buffers = new LinkedList<>(); + private final ByteBuf frameLenBuf = Unpooled.buffer(LENGTH_SIZE, LENGTH_SIZE); - private CompositeByteBuf buffer; + private long totalSize = 0; + private long nextFrameSize = UNKNOWN_FRAME_SIZE; private volatile Interceptor interceptor; @Override public void channelRead(ChannelHandlerContext ctx, Object data) throws Exception { ByteBuf in = (ByteBuf) data; + buffers.add(in); + totalSize += in.readableBytes(); + + while (!buffers.isEmpty()) { + // First, feed the interceptor, and if it's still, active, try again. + if (interceptor != null) { + ByteBuf first = buffers.getFirst(); + int available = first.readableBytes(); + if (feedInterceptor(first)) { + assert !first.isReadable() : "Interceptor still active but buffer has data."; + } - if (buffer == null) { - buffer = in.alloc().compositeBuffer(); - } - - buffer.addComponent(in).writerIndex(buffer.writerIndex() + in.readableBytes()); - - while (buffer.isReadable()) { - discardReadBytes(); - if (!feedInterceptor()) { + int read = available - first.readableBytes(); + if (read == available) { + buffers.removeFirst().release(); + } + totalSize -= read; + } else { + // Interceptor is not active, so try to decode one frame. ByteBuf frame = decodeNext(); if (frame == null) { break; } - ctx.fireChannelRead(frame); } } - - discardReadBytes(); } - private void discardReadBytes() { - // If the buffer's been retained by downstream code, then make a copy of the remaining - // bytes into a new buffer. Otherwise, just discard stale components. - if (buffer.refCnt() > 1) { - CompositeByteBuf newBuffer = buffer.alloc().compositeBuffer(); + private long decodeFrameSize() { + if (nextFrameSize != UNKNOWN_FRAME_SIZE || totalSize < LENGTH_SIZE) { + return nextFrameSize; + } - if (buffer.readableBytes() > 0) { - ByteBuf spillBuf = buffer.alloc().buffer(buffer.readableBytes()); - spillBuf.writeBytes(buffer); - newBuffer.addComponent(spillBuf).writerIndex(spillBuf.readableBytes()); + // We know there's enough data. If the first buffer contains all the data, great. 
Otherwise, + // hold the bytes for the frame length in a composite buffer until we have enough data to read + // the frame size. Normally, it should be rare to need more than one buffer to read the frame + // size. + ByteBuf first = buffers.getFirst(); + if (first.readableBytes() >= LENGTH_SIZE) { + nextFrameSize = first.readLong() - LENGTH_SIZE; + totalSize -= LENGTH_SIZE; + if (!first.isReadable()) { + buffers.removeFirst().release(); } + return nextFrameSize; + } - buffer.release(); - buffer = newBuffer; - } else { - buffer.discardReadComponents(); + while (frameLenBuf.readableBytes() < LENGTH_SIZE) { + ByteBuf next = buffers.getFirst(); + int toRead = Math.min(next.readableBytes(), LENGTH_SIZE - frameLenBuf.readableBytes()); + frameLenBuf.writeBytes(next, toRead); + if (!next.isReadable()) { + buffers.removeFirst().release(); + } } + + nextFrameSize = frameLenBuf.readLong() - LENGTH_SIZE; + totalSize -= LENGTH_SIZE; + frameLenBuf.clear(); + return nextFrameSize; } private ByteBuf decodeNext() throws Exception { - if (buffer.readableBytes() < LENGTH_SIZE) { + long frameSize = decodeFrameSize(); + if (frameSize == UNKNOWN_FRAME_SIZE || totalSize < frameSize) { return null; } - int frameLen = (int) buffer.readLong() - LENGTH_SIZE; - if (buffer.readableBytes() < frameLen) { - buffer.readerIndex(buffer.readerIndex() - LENGTH_SIZE); - return null; + // Reset size for next frame. + nextFrameSize = UNKNOWN_FRAME_SIZE; + + Preconditions.checkArgument(frameSize < MAX_FRAME_SIZE, "Too large frame: %s", frameSize); + Preconditions.checkArgument(frameSize > 0, "Frame length should be positive: %s", frameSize); + + // If the first buffer holds the entire frame, return it. + int remaining = (int) frameSize; + if (buffers.getFirst().readableBytes() >= remaining) { + return nextBufferForFrame(remaining); } - Preconditions.checkArgument(frameLen < MAX_FRAME_SIZE, "Too large frame: %s", frameLen); - Preconditions.checkArgument(frameLen > 0, "Frame length should be positive: %s", frameLen); + // Otherwise, create a composite buffer. + CompositeByteBuf frame = buffers.getFirst().alloc().compositeBuffer(); + while (remaining > 0) { + ByteBuf next = nextBufferForFrame(remaining); + remaining -= next.readableBytes(); + frame.addComponent(next).writerIndex(frame.writerIndex() + next.readableBytes()); + } + assert remaining == 0; + return frame; + } + + /** + * Takes the first buffer in the internal list, and either adjust it to fit in the frame + * (by taking a slice out of it) or remove it from the internal list. + */ + private ByteBuf nextBufferForFrame(int bytesToRead) { + ByteBuf buf = buffers.getFirst(); + ByteBuf frame; + + if (buf.readableBytes() > bytesToRead) { + frame = buf.retain().readSlice(bytesToRead); + totalSize -= bytesToRead; + } else { + frame = buf; + buffers.removeFirst(); + totalSize -= frame.readableBytes(); + } - ByteBuf frame = buffer.readSlice(frameLen); - frame.retain(); return frame; } @Override public void channelInactive(ChannelHandlerContext ctx) throws Exception { - if (buffer != null) { - if (buffer.isReadable()) { - feedInterceptor(); - } - buffer.release(); + for (ByteBuf b : buffers) { + b.release(); } if (interceptor != null) { interceptor.channelInactive(); } + frameLenBuf.release(); super.channelInactive(ctx); } @@ -141,8 +199,8 @@ public void setInterceptor(Interceptor interceptor) { /** * @return Whether the interceptor is still active after processing the data. 
*/ - private boolean feedInterceptor() throws Exception { - if (interceptor != null && !interceptor.handle(buffer)) { + private boolean feedInterceptor(ByteBuf buf) throws Exception { + if (interceptor != null && !interceptor.handle(buf)) { interceptor = null; } return interceptor != null; diff --git a/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java b/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java index 50a324e29338..70c849d60e0a 100644 --- a/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java @@ -107,7 +107,10 @@ public ManagedBuffer getChunk(long streamId, int chunkIndex) { }; RpcHandler handler = new RpcHandler() { @Override - public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) { + public void receive( + TransportClient client, + ByteBuffer message, + RpcResponseCallback callback) { throw new UnsupportedOperationException(); } diff --git a/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java b/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java index 1aa20900ffe7..6c8dd742f4b6 100644 --- a/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java @@ -82,10 +82,10 @@ private void testClientToServer(Message msg) { @Test public void requests() { testClientToServer(new ChunkFetchRequest(new StreamChunkId(1, 2))); - testClientToServer(new RpcRequest(12345, new byte[0])); - testClientToServer(new RpcRequest(12345, new byte[100])); + testClientToServer(new RpcRequest(12345, new TestManagedBuffer(0))); + testClientToServer(new RpcRequest(12345, new TestManagedBuffer(10))); testClientToServer(new StreamRequest("abcde")); - testClientToServer(new OneWayMessage(new byte[100])); + testClientToServer(new OneWayMessage(new TestManagedBuffer(10))); } @Test @@ -94,8 +94,8 @@ public void responses() { testServerToClient(new ChunkFetchSuccess(new StreamChunkId(1, 2), new TestManagedBuffer(0))); testServerToClient(new ChunkFetchFailure(new StreamChunkId(1, 2), "this is an error")); testServerToClient(new ChunkFetchFailure(new StreamChunkId(1, 2), "")); - testServerToClient(new RpcResponse(12345, new byte[0])); - testServerToClient(new RpcResponse(12345, new byte[1000])); + testServerToClient(new RpcResponse(12345, new TestManagedBuffer(0))); + testServerToClient(new RpcResponse(12345, new TestManagedBuffer(100))); testServerToClient(new RpcFailure(0, "this is an error")); testServerToClient(new RpcFailure(0, "")); // Note: buffer size must be "0" since StreamResponse's buffer is written differently to the diff --git a/network/common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java b/network/common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java index 42955ef69235..f9b5bf96d621 100644 --- a/network/common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java @@ -31,6 +31,7 @@ import org.apache.spark.network.util.MapConfigProvider; import org.apache.spark.network.util.TransportConf; import org.junit.*; +import static org.junit.Assert.*; import java.io.IOException; import java.nio.ByteBuffer; @@ -84,13 +85,16 @@ public void tearDown() { @Test public void 
timeoutInactiveRequests() throws Exception { final Semaphore semaphore = new Semaphore(1); - final byte[] response = new byte[16]; + final int responseSize = 16; RpcHandler handler = new RpcHandler() { @Override - public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) { + public void receive( + TransportClient client, + ByteBuffer message, + RpcResponseCallback callback) { try { semaphore.tryAcquire(FOREVER, TimeUnit.MILLISECONDS); - callback.onSuccess(response); + callback.onSuccess(ByteBuffer.allocate(responseSize)); } catch (InterruptedException e) { // do nothing } @@ -110,15 +114,15 @@ public StreamManager getStreamManager() { // First completes quickly (semaphore starts at 1). TestCallback callback0 = new TestCallback(); synchronized (callback0) { - client.sendRpc(new byte[0], callback0); + client.sendRpc(ByteBuffer.allocate(0), callback0); callback0.wait(FOREVER); - assert (callback0.success.length == response.length); + assertEquals(responseSize, callback0.successLength); } // Second times out after 2 seconds, with slack. Must be IOException. TestCallback callback1 = new TestCallback(); synchronized (callback1) { - client.sendRpc(new byte[0], callback1); + client.sendRpc(ByteBuffer.allocate(0), callback1); callback1.wait(4 * 1000); assert (callback1.failure != null); assert (callback1.failure instanceof IOException); @@ -131,13 +135,16 @@ public StreamManager getStreamManager() { @Test public void timeoutCleanlyClosesClient() throws Exception { final Semaphore semaphore = new Semaphore(0); - final byte[] response = new byte[16]; + final int responseSize = 16; RpcHandler handler = new RpcHandler() { @Override - public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) { + public void receive( + TransportClient client, + ByteBuffer message, + RpcResponseCallback callback) { try { semaphore.tryAcquire(FOREVER, TimeUnit.MILLISECONDS); - callback.onSuccess(response); + callback.onSuccess(ByteBuffer.allocate(responseSize)); } catch (InterruptedException e) { // do nothing } @@ -158,7 +165,7 @@ public StreamManager getStreamManager() { clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); TestCallback callback0 = new TestCallback(); synchronized (callback0) { - client0.sendRpc(new byte[0], callback0); + client0.sendRpc(ByteBuffer.allocate(0), callback0); callback0.wait(FOREVER); assert (callback0.failure instanceof IOException); assert (!client0.isActive()); @@ -170,10 +177,10 @@ public StreamManager getStreamManager() { clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); TestCallback callback1 = new TestCallback(); synchronized (callback1) { - client1.sendRpc(new byte[0], callback1); + client1.sendRpc(ByteBuffer.allocate(0), callback1); callback1.wait(FOREVER); - assert (callback1.success.length == response.length); - assert (callback1.failure == null); + assertEquals(responseSize, callback1.successLength); + assertNull(callback1.failure); } } @@ -191,7 +198,10 @@ public ManagedBuffer getChunk(long streamId, int chunkIndex) { }; RpcHandler handler = new RpcHandler() { @Override - public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) { + public void receive( + TransportClient client, + ByteBuffer message, + RpcResponseCallback callback) { throw new UnsupportedOperationException(); } @@ -218,9 +228,10 @@ public StreamManager getStreamManager() { synchronized (callback0) { // not complete yet, but should complete soon - assert (callback0.success == 
null && callback0.failure == null); + assertEquals(-1, callback0.successLength); + assertNull(callback0.failure); callback0.wait(2 * 1000); - assert (callback0.failure instanceof IOException); + assertTrue(callback0.failure instanceof IOException); } synchronized (callback1) { @@ -235,13 +246,13 @@ public StreamManager getStreamManager() { */ class TestCallback implements RpcResponseCallback, ChunkReceivedCallback { - byte[] success; + int successLength = -1; Throwable failure; @Override - public void onSuccess(byte[] response) { + public void onSuccess(ByteBuffer response) { synchronized(this) { - success = response; + successLength = response.remaining(); this.notifyAll(); } } @@ -258,7 +269,7 @@ public void onFailure(Throwable e) { public void onSuccess(int chunkIndex, ManagedBuffer buffer) { synchronized(this) { try { - success = buffer.nioByteBuffer().array(); + successLength = buffer.nioByteBuffer().remaining(); this.notifyAll(); } catch (IOException e) { // weird diff --git a/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java b/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java index 88fa2258bb79..9e9be98c140b 100644 --- a/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java @@ -17,6 +17,7 @@ package org.apache.spark.network; +import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.Collections; import java.util.HashSet; @@ -26,7 +27,6 @@ import java.util.concurrent.Semaphore; import java.util.concurrent.TimeUnit; -import com.google.common.base.Charsets; import com.google.common.collect.Sets; import org.junit.AfterClass; import org.junit.BeforeClass; @@ -41,6 +41,7 @@ import org.apache.spark.network.server.RpcHandler; import org.apache.spark.network.server.StreamManager; import org.apache.spark.network.server.TransportServer; +import org.apache.spark.network.util.JavaUtils; import org.apache.spark.network.util.SystemPropertyConfigProvider; import org.apache.spark.network.util.TransportConf; @@ -55,11 +56,14 @@ public static void setUp() throws Exception { TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); rpcHandler = new RpcHandler() { @Override - public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) { - String msg = new String(message, Charsets.UTF_8); + public void receive( + TransportClient client, + ByteBuffer message, + RpcResponseCallback callback) { + String msg = JavaUtils.bytesToString(message); String[] parts = msg.split("/"); if (parts[0].equals("hello")) { - callback.onSuccess(("Hello, " + parts[1] + "!").getBytes(Charsets.UTF_8)); + callback.onSuccess(JavaUtils.stringToBytes("Hello, " + parts[1] + "!")); } else if (parts[0].equals("return error")) { callback.onFailure(new RuntimeException("Returned: " + parts[1])); } else if (parts[0].equals("throw error")) { @@ -68,9 +72,8 @@ public void receive(TransportClient client, byte[] message, RpcResponseCallback } @Override - public void receive(TransportClient client, byte[] message) { - String msg = new String(message, Charsets.UTF_8); - oneWayMsgs.add(msg); + public void receive(TransportClient client, ByteBuffer message) { + oneWayMsgs.add(JavaUtils.bytesToString(message)); } @Override @@ -103,8 +106,9 @@ private RpcResult sendRPC(String ... 
commands) throws Exception { RpcResponseCallback callback = new RpcResponseCallback() { @Override - public void onSuccess(byte[] message) { - res.successMessages.add(new String(message, Charsets.UTF_8)); + public void onSuccess(ByteBuffer message) { + String response = JavaUtils.bytesToString(message); + res.successMessages.add(response); sem.release(); } @@ -116,7 +120,7 @@ public void onFailure(Throwable e) { }; for (String command : commands) { - client.sendRpc(command.getBytes(Charsets.UTF_8), callback); + client.sendRpc(JavaUtils.stringToBytes(command), callback); } if (!sem.tryAcquire(commands.length, 5, TimeUnit.SECONDS)) { @@ -173,7 +177,7 @@ public void sendOneWayMessage() throws Exception { final String message = "no reply"; TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); try { - client.send(message.getBytes(Charsets.UTF_8)); + client.send(JavaUtils.stringToBytes(message)); assertEquals(0, client.getHandler().numOutstandingRequests()); // Make sure the message arrives. diff --git a/network/common/src/test/java/org/apache/spark/network/StreamSuite.java b/network/common/src/test/java/org/apache/spark/network/StreamSuite.java index 538f3efe8d6f..9c49556927f0 100644 --- a/network/common/src/test/java/org/apache/spark/network/StreamSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/StreamSuite.java @@ -116,7 +116,10 @@ public ManagedBuffer openStream(String streamId) { }; RpcHandler handler = new RpcHandler() { @Override - public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) { + public void receive( + TransportClient client, + ByteBuffer message, + RpcResponseCallback callback) { throw new UnsupportedOperationException(); } diff --git a/network/common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java b/network/common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java index 30144f4a9fc7..128f7cba7435 100644 --- a/network/common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java @@ -17,6 +17,8 @@ package org.apache.spark.network; +import java.nio.ByteBuffer; + import io.netty.channel.Channel; import io.netty.channel.local.LocalChannel; import org.junit.Test; @@ -27,6 +29,7 @@ import static org.mockito.Mockito.*; import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.buffer.NioManagedBuffer; import org.apache.spark.network.client.ChunkReceivedCallback; import org.apache.spark.network.client.RpcResponseCallback; import org.apache.spark.network.client.StreamCallback; @@ -42,7 +45,7 @@ public class TransportResponseHandlerSuite { @Test - public void handleSuccessfulFetch() { + public void handleSuccessfulFetch() throws Exception { StreamChunkId streamChunkId = new StreamChunkId(1, 0); TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel()); @@ -56,7 +59,7 @@ public void handleSuccessfulFetch() { } @Test - public void handleFailedFetch() { + public void handleFailedFetch() throws Exception { StreamChunkId streamChunkId = new StreamChunkId(1, 0); TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel()); ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class); @@ -69,7 +72,7 @@ public void handleFailedFetch() { } @Test - public void clearAllOutstandingRequests() { + public void clearAllOutstandingRequests() throws Exception { 
TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel()); ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class); handler.addFetchRequest(new StreamChunkId(1, 0), callback); @@ -88,23 +91,24 @@ public void clearAllOutstandingRequests() { } @Test - public void handleSuccessfulRPC() { + public void handleSuccessfulRPC() throws Exception { TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel()); RpcResponseCallback callback = mock(RpcResponseCallback.class); handler.addRpcRequest(12345, callback); assertEquals(1, handler.numOutstandingRequests()); - handler.handle(new RpcResponse(54321, new byte[7])); // should be ignored + // This response should be ignored. + handler.handle(new RpcResponse(54321, new NioManagedBuffer(ByteBuffer.allocate(7)))); assertEquals(1, handler.numOutstandingRequests()); - byte[] arr = new byte[10]; - handler.handle(new RpcResponse(12345, arr)); - verify(callback, times(1)).onSuccess(eq(arr)); + ByteBuffer resp = ByteBuffer.allocate(10); + handler.handle(new RpcResponse(12345, new NioManagedBuffer(resp))); + verify(callback, times(1)).onSuccess(eq(ByteBuffer.allocate(10))); assertEquals(0, handler.numOutstandingRequests()); } @Test - public void handleFailedRPC() { + public void handleFailedRPC() throws Exception { TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel()); RpcResponseCallback callback = mock(RpcResponseCallback.class); handler.addRpcRequest(12345, callback); @@ -119,7 +123,7 @@ public void handleFailedRPC() { } @Test - public void testActiveStreams() { + public void testActiveStreams() throws Exception { Channel c = new LocalChannel(); c.pipeline().addLast(TransportFrameDecoder.HANDLER_NAME, new TransportFrameDecoder()); TransportResponseHandler handler = new TransportResponseHandler(c); diff --git a/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java b/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java index a6f180bc40c9..751516b9d82a 100644 --- a/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java @@ -22,7 +22,7 @@ import java.io.File; import java.lang.reflect.Method; -import java.nio.charset.StandardCharsets; +import java.nio.ByteBuffer; import java.util.Arrays; import java.util.List; import java.util.Random; @@ -57,6 +57,7 @@ import org.apache.spark.network.server.TransportServer; import org.apache.spark.network.server.TransportServerBootstrap; import org.apache.spark.network.util.ByteArrayWritableChannel; +import org.apache.spark.network.util.JavaUtils; import org.apache.spark.network.util.SystemPropertyConfigProvider; import org.apache.spark.network.util.TransportConf; @@ -123,39 +124,53 @@ public void testNonMatching() { } @Test - public void testSaslAuthentication() throws Exception { + public void testSaslAuthentication() throws Throwable { testBasicSasl(false); } @Test - public void testSaslEncryption() throws Exception { + public void testSaslEncryption() throws Throwable { testBasicSasl(true); } - private void testBasicSasl(boolean encrypt) throws Exception { + private void testBasicSasl(boolean encrypt) throws Throwable { RpcHandler rpcHandler = mock(RpcHandler.class); doAnswer(new Answer() { @Override public Void answer(InvocationOnMock invocation) { - byte[] message = (byte[]) invocation.getArguments()[1]; + ByteBuffer message = (ByteBuffer) invocation.getArguments()[1]; 
RpcResponseCallback cb = (RpcResponseCallback) invocation.getArguments()[2]; - assertEquals("Ping", new String(message, StandardCharsets.UTF_8)); - cb.onSuccess("Pong".getBytes(StandardCharsets.UTF_8)); + assertEquals("Ping", JavaUtils.bytesToString(message)); + cb.onSuccess(JavaUtils.stringToBytes("Pong")); return null; } }) .when(rpcHandler) - .receive(any(TransportClient.class), any(byte[].class), any(RpcResponseCallback.class)); + .receive(any(TransportClient.class), any(ByteBuffer.class), any(RpcResponseCallback.class)); SaslTestCtx ctx = new SaslTestCtx(rpcHandler, encrypt, false); try { - byte[] response = ctx.client.sendRpcSync("Ping".getBytes(StandardCharsets.UTF_8), - TimeUnit.SECONDS.toMillis(10)); - assertEquals("Pong", new String(response, StandardCharsets.UTF_8)); + ByteBuffer response = ctx.client.sendRpcSync(JavaUtils.stringToBytes("Ping"), + TimeUnit.SECONDS.toMillis(10)); + assertEquals("Pong", JavaUtils.bytesToString(response)); } finally { ctx.close(); // There should be 2 terminated events; one for the client, one for the server. - verify(rpcHandler, times(2)).connectionTerminated(any(TransportClient.class)); + Throwable error = null; + long deadline = System.nanoTime() + TimeUnit.NANOSECONDS.convert(10, TimeUnit.SECONDS); + while (deadline > System.nanoTime()) { + try { + verify(rpcHandler, times(2)).connectionTerminated(any(TransportClient.class)); + error = null; + break; + } catch (Throwable t) { + error = t; + TimeUnit.MILLISECONDS.sleep(10); + } + } + if (error != null) { + throw error; + } } } @@ -325,8 +340,8 @@ public void testDataEncryptionIsActuallyEnabled() throws Exception { SaslTestCtx ctx = null; try { ctx = new SaslTestCtx(mock(RpcHandler.class), true, true); - ctx.client.sendRpcSync("Ping".getBytes(StandardCharsets.UTF_8), - TimeUnit.SECONDS.toMillis(10)); + ctx.client.sendRpcSync(JavaUtils.stringToBytes("Ping"), + TimeUnit.SECONDS.toMillis(10)); fail("Should have failed to send RPC to server."); } catch (Exception e) { assertFalse(e.getCause() instanceof TimeoutException); diff --git a/network/common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java b/network/common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java index 19475c21ffce..d4de4a941d48 100644 --- a/network/common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java @@ -118,6 +118,27 @@ public Void answer(InvocationOnMock in) { } } + @Test + public void testSplitLengthField() throws Exception { + byte[] frame = new byte[1024 * (RND.nextInt(31) + 1)]; + ByteBuf buf = Unpooled.buffer(frame.length + 8); + buf.writeLong(frame.length + 8); + buf.writeBytes(frame); + + TransportFrameDecoder decoder = new TransportFrameDecoder(); + ChannelHandlerContext ctx = mockChannelHandlerContext(); + try { + decoder.channelRead(ctx, buf.readSlice(RND.nextInt(7)).retain()); + verify(ctx, never()).fireChannelRead(any(ByteBuf.class)); + decoder.channelRead(ctx, buf); + verify(ctx).fireChannelRead(any(ByteBuf.class)); + assertEquals(0, buf.refCnt()); + } finally { + decoder.channelInactive(ctx); + release(buf); + } + } + @Test(expected = IllegalArgumentException.class) public void testNegativeFrameSize() throws Exception { testInvalidFrame(-1); @@ -183,7 +204,7 @@ private void testInvalidFrame(long size) throws Exception { try { decoder.channelRead(ctx, frame); } finally { - frame.release(); + release(frame); } } diff --git 
a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java index 3ddf5c3c3918..f22187a01db0 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java @@ -19,6 +19,7 @@ import java.io.File; import java.io.IOException; +import java.nio.ByteBuffer; import java.util.List; import com.google.common.annotations.VisibleForTesting; @@ -66,8 +67,8 @@ public ExternalShuffleBlockHandler( } @Override - public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) { - BlockTransferMessage msgObj = BlockTransferMessage.Decoder.fromByteArray(message); + public void receive(TransportClient client, ByteBuffer message, RpcResponseCallback callback) { + BlockTransferMessage msgObj = BlockTransferMessage.Decoder.fromByteBuffer(message); handleMessage(msgObj, client, callback); } @@ -85,13 +86,13 @@ protected void handleMessage( } long streamId = streamManager.registerStream(client.getClientId(), blocks.iterator()); logger.trace("Registered streamId {} with {} buffers", streamId, msg.blockIds.length); - callback.onSuccess(new StreamHandle(streamId, msg.blockIds.length).toByteArray()); + callback.onSuccess(new StreamHandle(streamId, msg.blockIds.length).toByteBuffer()); } else if (msgObj instanceof RegisterExecutor) { RegisterExecutor msg = (RegisterExecutor) msgObj; checkAuth(client, msg.appId); blockManager.registerExecutor(msg.appId, msg.execId, msg.executorInfo); - callback.onSuccess(new byte[0]); + callback.onSuccess(ByteBuffer.wrap(new byte[0])); } else { throw new UnsupportedOperationException("Unexpected message: " + msgObj); diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java index ef3a9dcc8711..58ca87d9d3b1 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java @@ -18,6 +18,7 @@ package org.apache.spark.network.shuffle; import java.io.IOException; +import java.nio.ByteBuffer; import java.util.List; import com.google.common.base.Preconditions; @@ -139,7 +140,7 @@ public void registerWithShuffleServer( checkInit(); TransportClient client = clientFactory.createUnmanagedClient(host, port); try { - byte[] registerMessage = new RegisterExecutor(appId, execId, executorInfo).toByteArray(); + ByteBuffer registerMessage = new RegisterExecutor(appId, execId, executorInfo).toByteBuffer(); client.sendRpcSync(registerMessage, 5000 /* timeoutMs */); } finally { client.close(); diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java index e653f5cb147e..1b2ddbf1ed91 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java @@ -17,6 +17,7 @@ package org.apache.spark.network.shuffle; +import java.nio.ByteBuffer; import java.util.Arrays; import org.slf4j.Logger; @@ -89,11 +90,11 @@ public void start() { throw new IllegalArgumentException("Zero-sized 
blockIds array"); } - client.sendRpc(openMessage.toByteArray(), new RpcResponseCallback() { + client.sendRpc(openMessage.toByteBuffer(), new RpcResponseCallback() { @Override - public void onSuccess(byte[] response) { + public void onSuccess(ByteBuffer response) { try { - streamHandle = (StreamHandle) BlockTransferMessage.Decoder.fromByteArray(response); + streamHandle = (StreamHandle) BlockTransferMessage.Decoder.fromByteBuffer(response); logger.trace("Successfully opened blocks {}, preparing to fetch chunks.", streamHandle); // Immediately request all chunks -- we expect that the total size of the request is diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java index 7543b6be4f2a..675820308bd4 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java @@ -18,6 +18,7 @@ package org.apache.spark.network.shuffle.mesos; import java.io.IOException; +import java.nio.ByteBuffer; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -54,11 +55,11 @@ public MesosExternalShuffleClient( public void registerDriverWithShuffleService(String host, int port) throws IOException { checkInit(); - byte[] registerDriver = new RegisterDriver(appId).toByteArray(); + ByteBuffer registerDriver = new RegisterDriver(appId).toByteBuffer(); TransportClient client = clientFactory.createClient(host, port); client.sendRpc(registerDriver, new RpcResponseCallback() { @Override - public void onSuccess(byte[] response) { + public void onSuccess(ByteBuffer response) { logger.info("Successfully registered app " + appId + " with external shuffle service."); } diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java index fcb52363e632..7fbe3384b4d4 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java @@ -17,6 +17,8 @@ package org.apache.spark.network.shuffle.protocol; +import java.nio.ByteBuffer; + import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; @@ -53,7 +55,7 @@ private Type(int id) { // NB: Java does not support static methods in interfaces, so we must put this in a static class. public static class Decoder { /** Deserializes the 'type' byte followed by the message itself. */ - public static BlockTransferMessage fromByteArray(byte[] msg) { + public static BlockTransferMessage fromByteBuffer(ByteBuffer msg) { ByteBuf buf = Unpooled.wrappedBuffer(msg); byte type = buf.readByte(); switch (type) { @@ -68,12 +70,12 @@ public static BlockTransferMessage fromByteArray(byte[] msg) { } /** Serializes the 'type' byte followed by the message itself. 
*/ - public byte[] toByteArray() { + public ByteBuffer toByteBuffer() { // Allow room for encoded message, plus the type byte ByteBuf buf = Unpooled.buffer(encodedLength() + 1); buf.writeByte(type().id); encode(buf); assert buf.writableBytes() == 0 : "Writable bytes remain: " + buf.writableBytes(); - return buf.array(); + return buf.nioBuffer(); } } diff --git a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java index 1c2fa4d0d462..19c870aebb02 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java @@ -18,6 +18,7 @@ package org.apache.spark.network.sasl; import java.io.IOException; +import java.nio.ByteBuffer; import java.util.Arrays; import java.util.concurrent.atomic.AtomicReference; @@ -52,6 +53,7 @@ import org.apache.spark.network.shuffle.protocol.OpenBlocks; import org.apache.spark.network.shuffle.protocol.RegisterExecutor; import org.apache.spark.network.shuffle.protocol.StreamHandle; +import org.apache.spark.network.util.JavaUtils; import org.apache.spark.network.util.SystemPropertyConfigProvider; import org.apache.spark.network.util.TransportConf; @@ -107,8 +109,8 @@ public void testGoodClient() throws IOException { TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); String msg = "Hello, World!"; - byte[] resp = client.sendRpcSync(msg.getBytes(), TIMEOUT_MS); - assertEquals(msg, new String(resp)); // our rpc handler should just return the given msg + ByteBuffer resp = client.sendRpcSync(JavaUtils.stringToBytes(msg), TIMEOUT_MS); + assertEquals(msg, JavaUtils.bytesToString(resp)); } @Test @@ -136,7 +138,7 @@ public void testNoSaslClient() throws IOException { TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); try { - client.sendRpcSync(new byte[13], TIMEOUT_MS); + client.sendRpcSync(ByteBuffer.allocate(13), TIMEOUT_MS); fail("Should have failed"); } catch (Exception e) { assertTrue(e.getMessage(), e.getMessage().contains("Expected SaslMessage")); @@ -144,7 +146,7 @@ public void testNoSaslClient() throws IOException { try { // Guessing the right tag byte doesn't magically get you in... - client.sendRpcSync(new byte[] { (byte) 0xEA }, TIMEOUT_MS); + client.sendRpcSync(ByteBuffer.wrap(new byte[] { (byte) 0xEA }), TIMEOUT_MS); fail("Should have failed"); } catch (Exception e) { assertTrue(e.getMessage(), e.getMessage().contains("java.lang.IndexOutOfBoundsException")); @@ -222,13 +224,13 @@ public synchronized void onBlockFetchFailure(String blockId, Throwable t) { new String[] { System.getProperty("java.io.tmpdir") }, 1, "org.apache.spark.shuffle.sort.SortShuffleManager"); RegisterExecutor regmsg = new RegisterExecutor("app-1", "0", executorInfo); - client1.sendRpcSync(regmsg.toByteArray(), TIMEOUT_MS); + client1.sendRpcSync(regmsg.toByteBuffer(), TIMEOUT_MS); // Make a successful request to fetch blocks, which creates a new stream. But do not actually // fetch any blocks, to keep the stream open. 
OpenBlocks openMessage = new OpenBlocks("app-1", "0", blockIds); - byte[] response = client1.sendRpcSync(openMessage.toByteArray(), TIMEOUT_MS); - StreamHandle stream = (StreamHandle) BlockTransferMessage.Decoder.fromByteArray(response); + ByteBuffer response = client1.sendRpcSync(openMessage.toByteBuffer(), TIMEOUT_MS); + StreamHandle stream = (StreamHandle) BlockTransferMessage.Decoder.fromByteBuffer(response); long streamId = stream.streamId; // Create a second client, authenticated with a different app ID, and try to read from @@ -275,7 +277,7 @@ public synchronized void onFailure(int chunkIndex, Throwable t) { /** RPC handler which simply responds with the message it received. */ public static class TestRpcHandler extends RpcHandler { @Override - public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) { + public void receive(TransportClient client, ByteBuffer message, RpcResponseCallback callback) { callback.onSuccess(message); } diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/BlockTransferMessagesSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/BlockTransferMessagesSuite.java index d65de9ca550a..86c8609e7070 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/BlockTransferMessagesSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/BlockTransferMessagesSuite.java @@ -36,7 +36,7 @@ public void serializeOpenShuffleBlocks() { } private void checkSerializeDeserialize(BlockTransferMessage msg) { - BlockTransferMessage msg2 = BlockTransferMessage.Decoder.fromByteArray(msg.toByteArray()); + BlockTransferMessage msg2 = BlockTransferMessage.Decoder.fromByteBuffer(msg.toByteBuffer()); assertEquals(msg, msg2); assertEquals(msg.hashCode(), msg2.hashCode()); assertEquals(msg.toString(), msg2.toString()); diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java index e61390cf5706..9379412155e8 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java @@ -60,12 +60,12 @@ public void testRegisterExecutor() { RpcResponseCallback callback = mock(RpcResponseCallback.class); ExecutorShuffleInfo config = new ExecutorShuffleInfo(new String[] {"/a", "/b"}, 16, "sort"); - byte[] registerMessage = new RegisterExecutor("app0", "exec1", config).toByteArray(); + ByteBuffer registerMessage = new RegisterExecutor("app0", "exec1", config).toByteBuffer(); handler.receive(client, registerMessage, callback); verify(blockResolver, times(1)).registerExecutor("app0", "exec1", config); - verify(callback, times(1)).onSuccess((byte[]) any()); - verify(callback, never()).onFailure((Throwable) any()); + verify(callback, times(1)).onSuccess(any(ByteBuffer.class)); + verify(callback, never()).onFailure(any(Throwable.class)); } @SuppressWarnings("unchecked") @@ -77,17 +77,18 @@ public void testOpenShuffleBlocks() { ManagedBuffer block1Marker = new NioManagedBuffer(ByteBuffer.wrap(new byte[7])); when(blockResolver.getBlockData("app0", "exec1", "b0")).thenReturn(block0Marker); when(blockResolver.getBlockData("app0", "exec1", "b1")).thenReturn(block1Marker); - byte[] openBlocks = new OpenBlocks("app0", "exec1", new String[] { "b0", "b1" }).toByteArray(); + ByteBuffer 
openBlocks = new OpenBlocks("app0", "exec1", new String[] { "b0", "b1" }) + .toByteBuffer(); handler.receive(client, openBlocks, callback); verify(blockResolver, times(1)).getBlockData("app0", "exec1", "b0"); verify(blockResolver, times(1)).getBlockData("app0", "exec1", "b1"); - ArgumentCaptor response = ArgumentCaptor.forClass(byte[].class); + ArgumentCaptor response = ArgumentCaptor.forClass(ByteBuffer.class); verify(callback, times(1)).onSuccess(response.capture()); verify(callback, never()).onFailure((Throwable) any()); StreamHandle handle = - (StreamHandle) BlockTransferMessage.Decoder.fromByteArray(response.getValue()); + (StreamHandle) BlockTransferMessage.Decoder.fromByteBuffer(response.getValue()); assertEquals(2, handle.numChunks); @SuppressWarnings("unchecked") @@ -104,7 +105,7 @@ public void testOpenShuffleBlocks() { public void testBadMessages() { RpcResponseCallback callback = mock(RpcResponseCallback.class); - byte[] unserializableMsg = new byte[] { 0x12, 0x34, 0x56 }; + ByteBuffer unserializableMsg = ByteBuffer.wrap(new byte[] { 0x12, 0x34, 0x56 }); try { handler.receive(client, unserializableMsg, callback); fail("Should have thrown"); @@ -112,7 +113,7 @@ public void testBadMessages() { // pass } - byte[] unexpectedMsg = new UploadBlock("a", "e", "b", new byte[1], new byte[2]).toByteArray(); + ByteBuffer unexpectedMsg = new UploadBlock("a", "e", "b", new byte[1], new byte[2]).toByteBuffer(); try { handler.receive(client, unexpectedMsg, callback); fail("Should have thrown"); @@ -120,7 +121,7 @@ public void testBadMessages() { // pass } - verify(callback, never()).onSuccess((byte[]) any()); - verify(callback, never()).onFailure((Throwable) any()); + verify(callback, never()).onSuccess(any(ByteBuffer.class)); + verify(callback, never()).onFailure(any(Throwable.class)); } } diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java index b35a6d685dd0..2590b9ce4c1f 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java @@ -134,14 +134,14 @@ private BlockFetchingListener fetchBlocks(final LinkedHashMap() { @Override public Void answer(InvocationOnMock invocationOnMock) throws Throwable { - BlockTransferMessage message = BlockTransferMessage.Decoder.fromByteArray( - (byte[]) invocationOnMock.getArguments()[0]); + BlockTransferMessage message = BlockTransferMessage.Decoder.fromByteBuffer( + (ByteBuffer) invocationOnMock.getArguments()[0]); RpcResponseCallback callback = (RpcResponseCallback) invocationOnMock.getArguments()[1]; - callback.onSuccess(new StreamHandle(123, blocks.size()).toByteArray()); + callback.onSuccess(new StreamHandle(123, blocks.size()).toByteBuffer()); assertEquals(new OpenBlocks("app-id", "exec-id", blockIds), message); return null; } - }).when(client).sendRpc((byte[]) any(), (RpcResponseCallback) any()); + }).when(client).sendRpc(any(ByteBuffer.class), any(RpcResponseCallback.class)); // Respond to each chunk request with a single buffer from our blocks array. 
final AtomicInteger expectedChunkIndex = new AtomicInteger(0); From 96bf468c7860be317c20ccacf259910968d2dc83 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Mon, 30 Nov 2015 17:33:09 -0800 Subject: [PATCH 1605/1824] [SPARK-12049][CORE] User JVM shutdown hook can cause deadlock at shutdown Avoid potential deadlock with a user app's shutdown hook thread by more narrowly synchronizing access to 'hooks' Author: Sean Owen Closes #10042 from srowen/SPARK-12049. --- .../spark/util/ShutdownHookManager.scala | 33 +++++++++---------- 1 file changed, 16 insertions(+), 17 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala b/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala index 4012dca3ecdf..620f226a23e1 100644 --- a/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala +++ b/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala @@ -206,7 +206,7 @@ private[spark] object ShutdownHookManager extends Logging { private [util] class SparkShutdownHookManager { private val hooks = new PriorityQueue[SparkShutdownHook]() - private var shuttingDown = false + @volatile private var shuttingDown = false /** * Install a hook to run at shutdown and run all registered hooks in order. Hadoop 1.x does not @@ -232,28 +232,27 @@ private [util] class SparkShutdownHookManager { } } - def runAll(): Unit = synchronized { + def runAll(): Unit = { shuttingDown = true - while (!hooks.isEmpty()) { - Try(Utils.logUncaughtExceptions(hooks.poll().run())) + var nextHook: SparkShutdownHook = null + while ({ nextHook = hooks.synchronized { hooks.poll() }; nextHook != null }) { + Try(Utils.logUncaughtExceptions(nextHook.run())) } } - def add(priority: Int, hook: () => Unit): AnyRef = synchronized { - checkState() - val hookRef = new SparkShutdownHook(priority, hook) - hooks.add(hookRef) - hookRef - } - - def remove(ref: AnyRef): Boolean = synchronized { - hooks.remove(ref) + def add(priority: Int, hook: () => Unit): AnyRef = { + hooks.synchronized { + if (shuttingDown) { + throw new IllegalStateException("Shutdown hooks cannot be modified during shutdown.") + } + val hookRef = new SparkShutdownHook(priority, hook) + hooks.add(hookRef) + hookRef + } } - private def checkState(): Unit = { - if (shuttingDown) { - throw new IllegalStateException("Shutdown hooks cannot be modified during shutdown.") - } + def remove(ref: AnyRef): Boolean = { + hooks.synchronized { hooks.remove(ref) } } } From f73379be2b0c286957b678a996cb56afc96015eb Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 30 Nov 2015 17:15:47 -0800 Subject: [PATCH 1606/1824] [HOTFIX][SPARK-12000] Add missing quotes in Jekyll API docs plugin. I accidentally omitted these as part of #10049. --- docs/_plugins/copy_api_dirs.rb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/_plugins/copy_api_dirs.rb b/docs/_plugins/copy_api_dirs.rb index f2f3e2e65314..174c202e3791 100644 --- a/docs/_plugins/copy_api_dirs.rb +++ b/docs/_plugins/copy_api_dirs.rb @@ -117,7 +117,7 @@ puts "Moving to python/docs directory and building sphinx." cd("../python/docs") - system(make html) || raise("Python doc generation failed") + system("make html") || raise("Python doc generation failed") puts "Moving back into home dir." 
cd("../../") From 9693b0d5a55bc1d2da96f04fe2c6de59a8dfcc1b Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 30 Nov 2015 20:56:42 -0800 Subject: [PATCH 1607/1824] [SPARK-12018][SQL] Refactor common subexpression elimination code JIRA: https://issues.apache.org/jira/browse/SPARK-12018 The code of common subexpression elimination can be factored and simplified. Some unnecessary variables can be removed. Author: Liang-Chi Hsieh Closes #10009 from viirya/refactor-subexpr-eliminate. --- .../sql/catalyst/expressions/Expression.scala | 10 ++---- .../expressions/codegen/CodeGenerator.scala | 34 +++++-------------- .../codegen/GenerateUnsafeProjection.scala | 4 +-- 3 files changed, 14 insertions(+), 34 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 169435a10ea2..b55d3653a715 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -94,13 +94,9 @@ abstract class Expression extends TreeNode[Expression] { def gen(ctx: CodeGenContext): GeneratedExpressionCode = { ctx.subExprEliminationExprs.get(this).map { subExprState => // This expression is repeated meaning the code to evaluated has already been added - // as a function, `subExprState.fnName`. Just call that. - val code = - s""" - |/* $this */ - |${subExprState.fnName}(${ctx.INPUT_ROW}); - """.stripMargin.trim - GeneratedExpressionCode(code, subExprState.code.isNull, subExprState.code.value) + // as a function and called in advance. Just use it. + val code = s"/* $this */" + GeneratedExpressionCode(code, subExprState.isNull, subExprState.value) }.getOrElse { val isNull = ctx.freshName("isNull") val primitive = ctx.freshName("primitive") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 2f3d6aeb86c5..440c7d2fc115 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -104,16 +104,13 @@ class CodeGenContext { val equivalentExpressions: EquivalentExpressions = new EquivalentExpressions // State used for subexpression elimination. - case class SubExprEliminationState( - isLoaded: String, - code: GeneratedExpressionCode, - fnName: String) + case class SubExprEliminationState(isNull: String, value: String) // Foreach expression that is participating in subexpression elimination, the state to use. val subExprEliminationExprs = mutable.HashMap.empty[Expression, SubExprEliminationState] - // The collection of isLoaded variables that need to be reset on each row. - val subExprIsLoadedVariables = mutable.ArrayBuffer.empty[String] + // The collection of sub-exression result resetting methods that need to be called on each row. 
+ val subExprResetVariables = mutable.ArrayBuffer.empty[String] final val JAVA_BOOLEAN = "boolean" final val JAVA_BYTE = "byte" @@ -408,7 +405,6 @@ class CodeGenContext { val commonExprs = equivalentExpressions.getAllEquivalentExprs.filter(_.size > 1) commonExprs.foreach(e => { val expr = e.head - val isLoaded = freshName("isLoaded") val isNull = freshName("isNull") val value = freshName("value") val fnName = freshName("evalExpr") @@ -417,18 +413,12 @@ class CodeGenContext { val code = expr.gen(this) val fn = s""" - |private void $fnName(InternalRow ${INPUT_ROW}) { - | if (!$isLoaded) { - | ${code.code.trim} - | $isLoaded = true; - | $isNull = ${code.isNull}; - | $value = ${code.value}; - | } + |private void $fnName(InternalRow $INPUT_ROW) { + | ${code.code.trim} + | $isNull = ${code.isNull}; + | $value = ${code.value}; |} """.stripMargin - code.code = fn - code.isNull = isNull - code.value = value addNewFunction(fnName, fn) @@ -448,18 +438,12 @@ class CodeGenContext { // 2. Less code. // Currently, we will do this for all non-leaf only expression trees (i.e. expr trees with // at least two nodes) as the cost of doing it is expected to be low. - - // Maintain the loaded value and isNull as member variables. This is necessary if the codegen - // function is split across multiple functions. - // TODO: maintaining this as a local variable probably allows the compiler to do better - // optimizations. - addMutableState("boolean", isLoaded, s"$isLoaded = false;") addMutableState("boolean", isNull, s"$isNull = false;") addMutableState(javaType(expr.dataType), value, s"$value = ${defaultValue(expr.dataType)};") - subExprIsLoadedVariables += isLoaded - val state = SubExprEliminationState(isLoaded, code, fnName) + subExprResetVariables += s"$fnName($INPUT_ROW);" + val state = SubExprEliminationState(isNull, value) e.foreach(subExprEliminationExprs.put(_, state)) }) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 7b6c9373ebe3..68005afb21d2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -287,8 +287,8 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val holderClass = classOf[BufferHolder].getName ctx.addMutableState(holderClass, bufferHolder, s"this.$bufferHolder = new $holderClass();") - // Reset the isLoaded flag for each row. - val subexprReset = ctx.subExprIsLoadedVariables.map { v => s"${v} = false;" }.mkString("\n") + // Reset the subexpression values for each row. + val subexprReset = ctx.subExprResetVariables.mkString("\n") val code = s""" From a0af0e351e45a8be47a6f65efd132eaa4a00c9e4 Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Tue, 1 Dec 2015 09:26:58 +0000 Subject: [PATCH 1608/1824] [SPARK-11898][MLLIB] Use broadcast for the global tables in Word2Vec jira: https://issues.apache.org/jira/browse/SPARK-11898 syn0Global and sync1Global in word2vec are quite large objects with size (vocab * vectorSize * 8), yet they are passed to worker using basic task serialization. Use broadcast can greatly improve the performance. My benchmark shows that, for 1M vocabulary and default vectorSize 100, changing to broadcast can help, 1. decrease the worker memory consumption by 45%. 2. 
decrease running time by 40%. This will also help extend the upper limit for Word2Vec. Author: Yuhao Yang Closes #9878 from hhbyyh/w2vBC. --- .../scala/org/apache/spark/mllib/feature/Word2Vec.scala | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index a47f27b0afb1..655ac0bb5545 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -316,12 +316,15 @@ class Word2Vec extends Serializable with Logging { Array.fill[Float](vocabSize * vectorSize)((initRandom.nextFloat() - 0.5f) / vectorSize) val syn1Global = new Array[Float](vocabSize * vectorSize) var alpha = learningRate + for (k <- 1 to numIterations) { + val bcSyn0Global = sc.broadcast(syn0Global) + val bcSyn1Global = sc.broadcast(syn1Global) val partial = newSentences.mapPartitionsWithIndex { case (idx, iter) => val random = new XORShiftRandom(seed ^ ((idx + 1) << 16) ^ ((-k - 1) << 8)) val syn0Modify = new Array[Int](vocabSize) val syn1Modify = new Array[Int](vocabSize) - val model = iter.foldLeft((syn0Global, syn1Global, 0, 0)) { + val model = iter.foldLeft((bcSyn0Global.value, bcSyn1Global.value, 0, 0)) { case ((syn0, syn1, lastWordCount, wordCount), sentence) => var lwc = lastWordCount var wc = wordCount @@ -405,6 +408,8 @@ class Word2Vec extends Serializable with Logging { } i += 1 } + bcSyn0Global.unpersist(false) + bcSyn1Global.unpersist(false) } newSentences.unpersist() From c87531b765f8934a9a6c0f673617e0abfa5e5f0e Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 1 Dec 2015 07:42:37 -0800 Subject: [PATCH 1609/1824] [SPARK-11949][SQL] Set field nullable property for GroupingSets to get correct results for null values JIRA: https://issues.apache.org/jira/browse/SPARK-11949 The result of cube plan uses incorrect schema. The schema of cube result should set nullable property to true because the grouping expressions will have null values. Author: Liang-Chi Hsieh Closes #10038 from viirya/fix-cube. --- .../apache/spark/sql/catalyst/analysis/Analyzer.scala | 10 ++++++++-- .../org/apache/spark/sql/DataFrameAggregateSuite.scala | 10 ++++++++++ 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 94ffbbb2e5c6..b8f212fca750 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -223,6 +223,11 @@ class Analyzer( case other => Alias(other, other.toString)() } + // TODO: We need to use bitmasks to determine which grouping expressions need to be + // set as nullable. For example, if we have GROUPING SETS ((a,b), a), we do not need + // to change the nullability of a. + val attributeMap = groupByAliases.map(a => (a -> a.toAttribute.withNullability(true))).toMap + val aggregations: Seq[NamedExpression] = x.aggregations.map { // If an expression is an aggregate (contains a AggregateExpression) then we dont change // it so that the aggregation is computed on the unmodified value of its argument @@ -231,12 +236,13 @@ class Analyzer( // If not then its a grouping expression and we need to use the modified (with nulls from // Expand) value of the expression. 
case expr => expr.transformDown { - case e => groupByAliases.find(_.child.semanticEquals(e)).map(_.toAttribute).getOrElse(e) + case e => + groupByAliases.find(_.child.semanticEquals(e)).map(attributeMap(_)).getOrElse(e) }.asInstanceOf[NamedExpression] } val child = Project(x.child.output ++ groupByAliases, x.child) - val groupByAttributes = groupByAliases.map(_.toAttribute) + val groupByAttributes = groupByAliases.map(attributeMap(_)) Aggregate( groupByAttributes :+ VirtualColumn.groupingIdAttribute, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index b5c636d0de1d..b1004bc5bc29 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -21,6 +21,7 @@ import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.DecimalType +case class Fact(date: Int, hour: Int, minute: Int, room_name: String, temp: Double) class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { import testImplicits._ @@ -86,6 +87,15 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { Row(null, 2013, 78000.0) :: Row(null, null, 113000.0) :: Nil ) + + val df0 = sqlContext.sparkContext.parallelize(Seq( + Fact(20151123, 18, 35, "room1", 18.6), + Fact(20151123, 18, 35, "room2", 22.4), + Fact(20151123, 18, 36, "room1", 17.4), + Fact(20151123, 18, 36, "room2", 25.6))).toDF() + + val cube0 = df0.cube("date", "hour", "minute", "room_name").agg(Map("temp" -> "avg")) + assert(cube0.where("date IS NULL").count > 0) } test("rollup overlapping columns") { From 1401166576c7018c5f9c31e0a6703d5fb16ea339 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Tue, 1 Dec 2015 09:45:55 -0800 Subject: [PATCH 1610/1824] [SPARK-12060][CORE] Avoid memory copy in JavaSerializerInstance.serialize `JavaSerializerInstance.serialize` uses `ByteArrayOutputStream.toByteArray` to get the serialized data. `ByteArrayOutputStream.toByteArray` needs to copy the content in the internal array to a new array. However, since the array will be converted to `ByteBuffer` at once, we can avoid the memory copy. This PR added `ByteBufferOutputStream` to access the protected `buf` and convert it to a `ByteBuffer` directly. Author: Shixiong Zhu Closes #10051 from zsxwing/SPARK-12060. 
--- .../spark/serializer/JavaSerializer.scala | 7 ++--- .../spark/util/ByteBufferOutputStream.scala | 31 +++++++++++++++++++ 2 files changed, 34 insertions(+), 4 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/util/ByteBufferOutputStream.scala diff --git a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala index b463a71d5bd7..ea718a0edbe7 100644 --- a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala @@ -24,8 +24,7 @@ import scala.reflect.ClassTag import org.apache.spark.SparkConf import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.util.ByteBufferInputStream -import org.apache.spark.util.Utils +import org.apache.spark.util.{ByteBufferInputStream, ByteBufferOutputStream, Utils} private[spark] class JavaSerializationStream( out: OutputStream, counterReset: Int, extraDebugInfo: Boolean) @@ -96,11 +95,11 @@ private[spark] class JavaSerializerInstance( extends SerializerInstance { override def serialize[T: ClassTag](t: T): ByteBuffer = { - val bos = new ByteArrayOutputStream() + val bos = new ByteBufferOutputStream() val out = serializeStream(bos) out.writeObject(t) out.close() - ByteBuffer.wrap(bos.toByteArray) + bos.toByteBuffer } override def deserialize[T: ClassTag](bytes: ByteBuffer): T = { diff --git a/core/src/main/scala/org/apache/spark/util/ByteBufferOutputStream.scala b/core/src/main/scala/org/apache/spark/util/ByteBufferOutputStream.scala new file mode 100644 index 000000000000..92e45224db81 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/ByteBufferOutputStream.scala @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import java.io.ByteArrayOutputStream +import java.nio.ByteBuffer + +/** + * Provide a zero-copy way to convert data in ByteArrayOutputStream to ByteBuffer + */ +private[spark] class ByteBufferOutputStream extends ByteArrayOutputStream { + + def toByteBuffer: ByteBuffer = { + return ByteBuffer.wrap(buf, 0, count) + } +} From 69dbe6b40df35d488d4ee343098ac70d00bbdafb Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Tue, 1 Dec 2015 10:21:31 -0800 Subject: [PATCH 1611/1824] [SPARK-12046][DOC] Fixes various ScalaDoc/JavaDoc issues This PR backports PR #10039 to master Author: Cheng Lian Closes #10063 from liancheng/spark-12046.doc-fix.master. 
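Most of the fixes below share one pattern: Scaladoc bullet lists need a blank `*` line before them and consistent hanging indentation to render as lists, and ASCII diagrams or code samples have to sit inside `{{{ ... }}}` blocks. A minimal illustration with a made-up class:

    /**
     * One-line summary of the class.
     *
     *  - first item; continuation lines are indented so that they
     *    line up under the item text
     *  - second item
     *
     * {{{
     *   // verbatim code or ASCII diagrams go inside triple braces
     *   val example = 1 + 1
     * }}}
     */
    class ExampleDoc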
--- .../spark/api/java/function/Function4.java | 2 +- .../spark/api/java/function/VoidFunction.java | 2 +- .../api/java/function/VoidFunction2.java | 2 +- .../apache/spark/api/java/JavaPairRDD.scala | 16 +++---- .../org/apache/spark/memory/package.scala | 14 +++--- .../org/apache/spark/rdd/CoGroupedRDD.scala | 2 +- .../apache/spark/rdd/PairRDDFunctions.scala | 6 +-- .../org/apache/spark/rdd/ShuffledRDD.scala | 2 +- .../org/apache/spark/scheduler/Task.scala | 5 ++- .../serializer/SerializationDebugger.scala | 13 +++--- .../scala/org/apache/spark/util/Vector.scala | 1 + .../util/collection/ExternalSorter.scala | 30 ++++++------- .../WritablePartitionedPairCollection.scala | 7 +-- .../streaming/kinesis/KinesisReceiver.scala | 23 +++++----- .../streaming/kinesis/KinesisUtils.scala | 13 +++--- .../mllib/optimization/GradientDescent.scala | 12 ++--- project/SparkBuild.scala | 2 + .../scala/org/apache/spark/sql/Column.scala | 11 ++--- .../spark/streaming/StreamingContext.scala | 11 ++--- .../streaming/dstream/FileInputDStream.scala | 19 ++++---- .../streaming/receiver/BlockGenerator.scala | 22 ++++----- .../scheduler/ReceiverSchedulingPolicy.scala | 45 ++++++++++--------- .../util/FileBasedWriteAheadLog.scala | 7 +-- .../spark/streaming/util/RecurringTimer.scala | 8 ++-- .../org/apache/spark/deploy/yarn/Client.scala | 10 ++--- 25 files changed, 152 insertions(+), 133 deletions(-) diff --git a/core/src/main/java/org/apache/spark/api/java/function/Function4.java b/core/src/main/java/org/apache/spark/api/java/function/Function4.java index fd727d64863d..9c35a22ca9d0 100644 --- a/core/src/main/java/org/apache/spark/api/java/function/Function4.java +++ b/core/src/main/java/org/apache/spark/api/java/function/Function4.java @@ -23,5 +23,5 @@ * A four-argument function that takes arguments of type T1, T2, T3 and T4 and returns an R. */ public interface Function4 extends Serializable { - public R call(T1 v1, T2 v2, T3 v3, T4 v4) throws Exception; + R call(T1 v1, T2 v2, T3 v3, T4 v4) throws Exception; } diff --git a/core/src/main/java/org/apache/spark/api/java/function/VoidFunction.java b/core/src/main/java/org/apache/spark/api/java/function/VoidFunction.java index 2a10435b7523..f30d42ee5796 100644 --- a/core/src/main/java/org/apache/spark/api/java/function/VoidFunction.java +++ b/core/src/main/java/org/apache/spark/api/java/function/VoidFunction.java @@ -23,5 +23,5 @@ * A function with no return value. */ public interface VoidFunction extends Serializable { - public void call(T t) throws Exception; + void call(T t) throws Exception; } diff --git a/core/src/main/java/org/apache/spark/api/java/function/VoidFunction2.java b/core/src/main/java/org/apache/spark/api/java/function/VoidFunction2.java index 6c576ab67845..da9ae1c9c5cd 100644 --- a/core/src/main/java/org/apache/spark/api/java/function/VoidFunction2.java +++ b/core/src/main/java/org/apache/spark/api/java/function/VoidFunction2.java @@ -23,5 +23,5 @@ * A two-argument function that takes arguments of type T1 and T2 with no return value. 
*/ public interface VoidFunction2 extends Serializable { - public void call(T1 v1, T2 v2) throws Exception; + void call(T1 v1, T2 v2) throws Exception; } diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala index 0b0c6e5bb8cc..87deaf20e2b2 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala @@ -215,13 +215,13 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) /** * Generic function to combine the elements for each key using a custom set of aggregation * functions. Turns a JavaPairRDD[(K, V)] into a result of type JavaPairRDD[(K, C)], for a - * "combined type" C * Note that V and C can be different -- for example, one might group an + * "combined type" C. Note that V and C can be different -- for example, one might group an * RDD of type (Int, Int) into an RDD of type (Int, List[Int]). Users provide three * functions: * - * - `createCombiner`, which turns a V into a C (e.g., creates a one-element list) - * - `mergeValue`, to merge a V into a C (e.g., adds it to the end of a list) - * - `mergeCombiners`, to combine two C's into a single one. + * - `createCombiner`, which turns a V into a C (e.g., creates a one-element list) + * - `mergeValue`, to merge a V into a C (e.g., adds it to the end of a list) + * - `mergeCombiners`, to combine two C's into a single one. * * In addition, users can control the partitioning of the output RDD, the serializer that is use * for the shuffle, and whether to perform map-side aggregation (if a mapper can produce multiple @@ -247,13 +247,13 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) /** * Generic function to combine the elements for each key using a custom set of aggregation * functions. Turns a JavaPairRDD[(K, V)] into a result of type JavaPairRDD[(K, C)], for a - * "combined type" C * Note that V and C can be different -- for example, one might group an + * "combined type" C. Note that V and C can be different -- for example, one might group an * RDD of type (Int, Int) into an RDD of type (Int, List[Int]). Users provide three * functions: * - * - `createCombiner`, which turns a V into a C (e.g., creates a one-element list) - * - `mergeValue`, to merge a V into a C (e.g., adds it to the end of a list) - * - `mergeCombiners`, to combine two C's into a single one. + * - `createCombiner`, which turns a V into a C (e.g., creates a one-element list) + * - `mergeValue`, to merge a V into a C (e.g., adds it to the end of a list) + * - `mergeCombiners`, to combine two C's into a single one. * * In addition, users can control the partitioning of the output RDD. This method automatically * uses map-side aggregation in shuffling the RDD. diff --git a/core/src/main/scala/org/apache/spark/memory/package.scala b/core/src/main/scala/org/apache/spark/memory/package.scala index 564e30d2ffd6..3d00cd9cb637 100644 --- a/core/src/main/scala/org/apache/spark/memory/package.scala +++ b/core/src/main/scala/org/apache/spark/memory/package.scala @@ -21,13 +21,13 @@ package org.apache.spark * This package implements Spark's memory management system. This system consists of two main * components, a JVM-wide memory manager and a per-task manager: * - * - [[org.apache.spark.memory.MemoryManager]] manages Spark's overall memory usage within a JVM. 
- * This component implements the policies for dividing the available memory across tasks and for - * allocating memory between storage (memory used caching and data transfer) and execution (memory - * used by computations, such as shuffles, joins, sorts, and aggregations). - * - [[org.apache.spark.memory.TaskMemoryManager]] manages the memory allocated by individual tasks. - * Tasks interact with TaskMemoryManager and never directly interact with the JVM-wide - * MemoryManager. + * - [[org.apache.spark.memory.MemoryManager]] manages Spark's overall memory usage within a JVM. + * This component implements the policies for dividing the available memory across tasks and for + * allocating memory between storage (memory used caching and data transfer) and execution + * (memory used by computations, such as shuffles, joins, sorts, and aggregations). + * - [[org.apache.spark.memory.TaskMemoryManager]] manages the memory allocated by individual + * tasks. Tasks interact with TaskMemoryManager and never directly interact with the JVM-wide + * MemoryManager. * * Internally, each of these components have additional abstractions for memory bookkeeping: * diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala index 935c3babd8ea..3a0ca1d81329 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala @@ -70,7 +70,7 @@ private[spark] class CoGroupPartition( * * Note: This is an internal API. We recommend users use RDD.cogroup(...) instead of * instantiating this directly. - + * * @param rdds parent RDDs. * @param part partitioner used to partition the shuffle output */ diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index c6181902ace6..44d195587a08 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -65,9 +65,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * Note that V and C can be different -- for example, one might group an RDD of type * (Int, Int) into an RDD of type (Int, Seq[Int]). Users provide three functions: * - * - `createCombiner`, which turns a V into a C (e.g., creates a one-element list) - * - `mergeValue`, to merge a V into a C (e.g., adds it to the end of a list) - * - `mergeCombiners`, to combine two C's into a single one. + * - `createCombiner`, which turns a V into a C (e.g., creates a one-element list) + * - `mergeValue`, to merge a V into a C (e.g., adds it to the end of a list) + * - `mergeCombiners`, to combine two C's into a single one. * * In addition, users can control the partitioning of the output RDD, and whether to perform * map-side aggregation (if a mapper can produce multiple items with the same key). 
diff --git a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala index a013c3f66a3a..3ef506e1562b 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala @@ -86,7 +86,7 @@ class ShuffledRDD[K: ClassTag, V: ClassTag, C: ClassTag]( Array.tabulate[Partition](part.numPartitions)(i => new ShuffledRDDPartition(i)) } - override def getPreferredLocations(partition: Partition): Seq[String] = { + override protected def getPreferredLocations(partition: Partition): Seq[String] = { val tracker = SparkEnv.get.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster] val dep = dependencies.head.asInstanceOf[ShuffleDependency[K, V, C]] tracker.getPreferredLocationsForShuffle(dep, partition.index) diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index 4fb32ba8cb18..2fcd5aa57d11 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -33,8 +33,9 @@ import org.apache.spark.util.Utils /** * A unit of execution. We have two kinds of Task's in Spark: - * - [[org.apache.spark.scheduler.ShuffleMapTask]] - * - [[org.apache.spark.scheduler.ResultTask]] + * + * - [[org.apache.spark.scheduler.ShuffleMapTask]] + * - [[org.apache.spark.scheduler.ResultTask]] * * A Spark job consists of one or more stages. The very last stage in a job consists of multiple * ResultTasks, while earlier stages consist of ShuffleMapTasks. A ResultTask executes the task diff --git a/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala b/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala index a1b1e1631eaf..e2951d8a3e09 100644 --- a/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala +++ b/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala @@ -53,12 +53,13 @@ private[spark] object SerializationDebugger extends Logging { /** * Find the path leading to a not serializable object. This method is modeled after OpenJDK's * serialization mechanism, and handles the following cases: - * - primitives - * - arrays of primitives - * - arrays of non-primitive objects - * - Serializable objects - * - Externalizable objects - * - writeReplace + * + * - primitives + * - arrays of primitives + * - arrays of non-primitive objects + * - Serializable objects + * - Externalizable objects + * - writeReplace * * It does not yet handle writeObject override, but that shouldn't be too hard to do either. 
*/ diff --git a/core/src/main/scala/org/apache/spark/util/Vector.scala b/core/src/main/scala/org/apache/spark/util/Vector.scala index 2ed827eab46d..6b3fa8491904 100644 --- a/core/src/main/scala/org/apache/spark/util/Vector.scala +++ b/core/src/main/scala/org/apache/spark/util/Vector.scala @@ -122,6 +122,7 @@ class Vector(val elements: Array[Double]) extends Serializable { override def toString: String = elements.mkString("(", ", ", ")") } +@deprecated("Use Vectors.dense from Spark's mllib.linalg package instead.", "1.0.0") object Vector { def apply(elements: Array[Double]): Vector = new Vector(elements) diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index 2440139ac95e..44b1d90667e6 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -67,24 +67,24 @@ import org.apache.spark.storage.{BlockId, DiskBlockObjectWriter} * * At a high level, this class works internally as follows: * - * - We repeatedly fill up buffers of in-memory data, using either a PartitionedAppendOnlyMap if - * we want to combine by key, or a PartitionedPairBuffer if we don't. - * Inside these buffers, we sort elements by partition ID and then possibly also by key. - * To avoid calling the partitioner multiple times with each key, we store the partition ID - * alongside each record. + * - We repeatedly fill up buffers of in-memory data, using either a PartitionedAppendOnlyMap if + * we want to combine by key, or a PartitionedPairBuffer if we don't. + * Inside these buffers, we sort elements by partition ID and then possibly also by key. + * To avoid calling the partitioner multiple times with each key, we store the partition ID + * alongside each record. * - * - When each buffer reaches our memory limit, we spill it to a file. This file is sorted first - * by partition ID and possibly second by key or by hash code of the key, if we want to do - * aggregation. For each file, we track how many objects were in each partition in memory, so we - * don't have to write out the partition ID for every element. + * - When each buffer reaches our memory limit, we spill it to a file. This file is sorted first + * by partition ID and possibly second by key or by hash code of the key, if we want to do + * aggregation. For each file, we track how many objects were in each partition in memory, so we + * don't have to write out the partition ID for every element. * - * - When the user requests an iterator or file output, the spilled files are merged, along with - * any remaining in-memory data, using the same sort order defined above (unless both sorting - * and aggregation are disabled). If we need to aggregate by key, we either use a total ordering - * from the ordering parameter, or read the keys with the same hash code and compare them with - * each other for equality to merge values. + * - When the user requests an iterator or file output, the spilled files are merged, along with + * any remaining in-memory data, using the same sort order defined above (unless both sorting + * and aggregation are disabled). If we need to aggregate by key, we either use a total ordering + * from the ordering parameter, or read the keys with the same hash code and compare them with + * each other for equality to merge values. * - * - Users are expected to call stop() at the end to delete all the intermediate files. 
+ * - Users are expected to call stop() at the end to delete all the intermediate files. */ private[spark] class ExternalSorter[K, V, C]( context: TaskContext, diff --git a/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala b/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala index 38848e9018c6..5232c2bd8d6f 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala @@ -23,9 +23,10 @@ import org.apache.spark.storage.DiskBlockObjectWriter /** * A common interface for size-tracking collections of key-value pairs that - * - Have an associated partition for each key-value pair. - * - Support a memory-efficient sorted iterator - * - Support a WritablePartitionedIterator for writing the contents directly as bytes. + * + * - Have an associated partition for each key-value pair. + * - Support a memory-efficient sorted iterator + * - Support a WritablePartitionedIterator for writing the contents directly as bytes. */ private[spark] trait WritablePartitionedPairCollection[K, V] { /** diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala index 97dbb918573a..05080835fc4a 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala @@ -46,17 +46,18 @@ case class SerializableAWSCredentials(accessKeyId: String, secretKey: String) * https://github.com/awslabs/amazon-kinesis-client * * The way this Receiver works is as follows: - * - The receiver starts a KCL Worker, which is essentially runs a threadpool of multiple - * KinesisRecordProcessor - * - Each KinesisRecordProcessor receives data from a Kinesis shard in batches. Each batch is - * inserted into a Block Generator, and the corresponding range of sequence numbers is recorded. - * - When the block generator defines a block, then the recorded sequence number ranges that were - * inserted into the block are recorded separately for being used later. - * - When the block is ready to be pushed, the block is pushed and the ranges are reported as - * metadata of the block. In addition, the ranges are used to find out the latest sequence - * number for each shard that can be checkpointed through the DynamoDB. - * - Periodically, each KinesisRecordProcessor checkpoints the latest successfully stored sequence - * number for it own shard. + * + * - The receiver starts a KCL Worker, which is essentially runs a threadpool of multiple + * KinesisRecordProcessor + * - Each KinesisRecordProcessor receives data from a Kinesis shard in batches. Each batch is + * inserted into a Block Generator, and the corresponding range of sequence numbers is recorded. + * - When the block generator defines a block, then the recorded sequence number ranges that were + * inserted into the block are recorded separately for being used later. + * - When the block is ready to be pushed, the block is pushed and the ranges are reported as + * metadata of the block. In addition, the ranges are used to find out the latest sequence + * number for each shard that can be checkpointed through the DynamoDB. 
+ * - Periodically, each KinesisRecordProcessor checkpoints the latest successfully stored sequence + * number for it own shard. * * @param streamName Kinesis stream name * @param endpointUrl Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com) diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala index 2849fd8a8210..2de6195716e5 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala @@ -226,12 +226,13 @@ object KinesisUtils { * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis. * * Note: - * - The AWS credentials will be discovered using the DefaultAWSCredentialsProviderChain - * on the workers. See AWS documentation to understand how DefaultAWSCredentialsProviderChain - * gets AWS credentials. - * - The region of the `endpointUrl` will be used for DynamoDB and CloudWatch. - * - The Kinesis application name used by the Kinesis Client Library (KCL) will be the app name in - * [[org.apache.spark.SparkConf]]. + * + * - The AWS credentials will be discovered using the DefaultAWSCredentialsProviderChain + * on the workers. See AWS documentation to understand how DefaultAWSCredentialsProviderChain + * gets AWS credentials. + * - The region of the `endpointUrl` will be used for DynamoDB and CloudWatch. + * - The Kinesis application name used by the Kinesis Client Library (KCL) will be the app name + * in [[org.apache.spark.SparkConf]]. * * @param ssc StreamingContext object * @param streamName Kinesis stream name diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala index 3b663b5defb0..37bb6f6097f6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala @@ -81,11 +81,13 @@ class GradientDescent private[spark] (private var gradient: Gradient, private va * Set the convergence tolerance. Default 0.001 * convergenceTol is a condition which decides iteration termination. * The end of iteration is decided based on below logic. - * - If the norm of the new solution vector is >1, the diff of solution vectors - * is compared to relative tolerance which means normalizing by the norm of - * the new solution vector. - * - If the norm of the new solution vector is <=1, the diff of solution vectors - * is compared to absolute tolerance which is not normalizing. + * + * - If the norm of the new solution vector is >1, the diff of solution vectors + * is compared to relative tolerance which means normalizing by the norm of + * the new solution vector. + * - If the norm of the new solution vector is <=1, the diff of solution vectors + * is compared to absolute tolerance which is not normalizing. + * * Must be between 0.0 and 1.0 inclusively. 
*/ def setConvergenceTol(tolerance: Double): this.type = { diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 63290d8a666e..b1dcaedcba75 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -535,6 +535,8 @@ object Unidoc { .map(_.filterNot(_.getName.contains("$"))) .map(_.filterNot(_.getCanonicalPath.contains("akka"))) .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/deploy"))) + .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/examples"))) + .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/memory"))) .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/network"))) .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/shuffle"))) .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/executor"))) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index b3cd9e1eff14..ad6af481fadc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -136,11 +136,12 @@ class Column(protected[sql] val expr: Expression) extends Logging { /** * Extracts a value or values from a complex type. * The following types of extraction are supported: - * - Given an Array, an integer ordinal can be used to retrieve a single value. - * - Given a Map, a key of the correct type can be used to retrieve an individual value. - * - Given a Struct, a string fieldName can be used to extract that field. - * - Given an Array of Structs, a string fieldName can be used to extract filed - * of every struct in that array, and return an Array of fields + * + * - Given an Array, an integer ordinal can be used to retrieve a single value. + * - Given a Map, a key of the correct type can be used to retrieve an individual value. + * - Given a Struct, a string fieldName can be used to extract that field. + * - Given an Array of Structs, a string fieldName can be used to extract filed + * of every struct in that array, and return an Array of fields * * @group expr_ops * @since 1.4.0 diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala index aee172a4f549..6fb8ad38abce 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala @@ -574,11 +574,12 @@ class StreamingContext private[streaming] ( * :: DeveloperApi :: * * Return the current state of the context. The context can be in three possible states - - * - StreamingContextState.INTIALIZED - The context has been created, but not been started yet. - * Input DStreams, transformations and output operations can be created on the context. - * - StreamingContextState.ACTIVE - The context has been started, and been not stopped. - * Input DStreams, transformations and output operations cannot be created on the context. - * - StreamingContextState.STOPPED - The context has been stopped and cannot be used any more. + * + * - StreamingContextState.INTIALIZED - The context has been created, but not been started yet. + * Input DStreams, transformations and output operations can be created on the context. + * - StreamingContextState.ACTIVE - The context has been started, and been not stopped. + * Input DStreams, transformations and output operations cannot be created on the context. 
+ * - StreamingContextState.STOPPED - The context has been stopped and cannot be used any more. */ @DeveloperApi def getState(): StreamingContextState = synchronized { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala index 40208a64861f..cb5b1f252e90 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala @@ -42,6 +42,7 @@ import org.apache.spark.util.{SerializableConfiguration, TimeStampedHashMap, Uti * class remembers the information about the files selected in past batches for * a certain duration (say, "remember window") as shown in the figure below. * + * {{{ * |<----- remember window ----->| * ignore threshold --->| |<--- current batch time * |____.____.____.____.____.____| @@ -49,6 +50,7 @@ import org.apache.spark.util.{SerializableConfiguration, TimeStampedHashMap, Uti * ---------------------|----|----|----|----|----|----|-----------------------> Time * |____|____|____|____|____|____| * remembered batches + * }}} * * The trailing end of the window is the "ignore threshold" and all files whose mod times * are less than this threshold are assumed to have already been selected and are therefore @@ -59,14 +61,15 @@ import org.apache.spark.util.{SerializableConfiguration, TimeStampedHashMap, Uti * `isNewFile` for more details. * * This makes some assumptions from the underlying file system that the system is monitoring. - * - The clock of the file system is assumed to synchronized with the clock of the machine running - * the streaming app. - * - If a file is to be visible in the directory listings, it must be visible within a certain - * duration of the mod time of the file. This duration is the "remember window", which is set to - * 1 minute (see `FileInputDStream.minRememberDuration`). Otherwise, the file will never be - * selected as the mod time will be less than the ignore threshold when it becomes visible. - * - Once a file is visible, the mod time cannot change. If it does due to appends, then the - * processing semantics are undefined. + * + * - The clock of the file system is assumed to synchronized with the clock of the machine running + * the streaming app. + * - If a file is to be visible in the directory listings, it must be visible within a certain + * duration of the mod time of the file. This duration is the "remember window", which is set to + * 1 minute (see `FileInputDStream.minRememberDuration`). Otherwise, the file will never be + * selected as the mod time will be less than the ignore threshold when it becomes visible. + * - Once a file is visible, the mod time cannot change. If it does due to appends, then the + * processing semantics are undefined. */ private[streaming] class FileInputDStream[K, V, F <: NewInputFormat[K, V]]( diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala index 421d60ae359f..cc7c04bfc9f6 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala @@ -84,13 +84,14 @@ private[streaming] class BlockGenerator( /** * The BlockGenerator can be in 5 possible states, in the order as follows. 
- * - Initialized: Nothing has been started - * - Active: start() has been called, and it is generating blocks on added data. - * - StoppedAddingData: stop() has been called, the adding of data has been stopped, - * but blocks are still being generated and pushed. - * - StoppedGeneratingBlocks: Generating of blocks has been stopped, but - * they are still being pushed. - * - StoppedAll: Everything has stopped, and the BlockGenerator object can be GCed. + * + * - Initialized: Nothing has been started + * - Active: start() has been called, and it is generating blocks on added data. + * - StoppedAddingData: stop() has been called, the adding of data has been stopped, + * but blocks are still being generated and pushed. + * - StoppedGeneratingBlocks: Generating of blocks has been stopped, but + * they are still being pushed. + * - StoppedAll: Everything has stopped, and the BlockGenerator object can be GCed. */ private object GeneratorState extends Enumeration { type GeneratorState = Value @@ -125,9 +126,10 @@ private[streaming] class BlockGenerator( /** * Stop everything in the right order such that all the data added is pushed out correctly. - * - First, stop adding data to the current buffer. - * - Second, stop generating blocks. - * - Finally, wait for queue of to-be-pushed blocks to be drained. + * + * - First, stop adding data to the current buffer. + * - Second, stop generating blocks. + * - Finally, wait for queue of to-be-pushed blocks to be drained. */ def stop(): Unit = { // Set the state to stop adding data diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicy.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicy.scala index 234bc8660da8..391a461f0812 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicy.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicy.scala @@ -27,28 +27,29 @@ import org.apache.spark.streaming.receiver.Receiver * A class that tries to schedule receivers with evenly distributed. There are two phases for * scheduling receivers. * - * - The first phase is global scheduling when ReceiverTracker is starting and we need to schedule - * all receivers at the same time. ReceiverTracker will call `scheduleReceivers` at this phase. - * It will try to schedule receivers such that they are evenly distributed. ReceiverTracker should - * update its `receiverTrackingInfoMap` according to the results of `scheduleReceivers`. - * `ReceiverTrackingInfo.scheduledLocations` for each receiver should be set to an location list - * that contains the scheduled locations. Then when a receiver is starting, it will send a - * register request and `ReceiverTracker.registerReceiver` will be called. In - * `ReceiverTracker.registerReceiver`, if a receiver's scheduled locations is set, it should check - * if the location of this receiver is one of the scheduled locations, if not, the register will - * be rejected. - * - The second phase is local scheduling when a receiver is restarting. There are two cases of - * receiver restarting: - * - If a receiver is restarting because it's rejected due to the real location and the scheduled - * locations mismatching, in other words, it fails to start in one of the locations that - * `scheduleReceivers` suggested, `ReceiverTracker` should firstly choose the executors that are - * still alive in the list of scheduled locations, then use them to launch the receiver job. 
- * - If a receiver is restarting without a scheduled locations list, or the executors in the list - * are dead, `ReceiverTracker` should call `rescheduleReceiver`. If so, `ReceiverTracker` should - * not set `ReceiverTrackingInfo.scheduledLocations` for this receiver, instead, it should clear - * it. Then when this receiver is registering, we can know this is a local scheduling, and - * `ReceiverTrackingInfo` should call `rescheduleReceiver` again to check if the launching - * location is matching. + * - The first phase is global scheduling when ReceiverTracker is starting and we need to schedule + * all receivers at the same time. ReceiverTracker will call `scheduleReceivers` at this phase. + * It will try to schedule receivers such that they are evenly distributed. ReceiverTracker + * should update its `receiverTrackingInfoMap` according to the results of `scheduleReceivers`. + * `ReceiverTrackingInfo.scheduledLocations` for each receiver should be set to an location list + * that contains the scheduled locations. Then when a receiver is starting, it will send a + * register request and `ReceiverTracker.registerReceiver` will be called. In + * `ReceiverTracker.registerReceiver`, if a receiver's scheduled locations is set, it should + * check if the location of this receiver is one of the scheduled locations, if not, the register + * will be rejected. + * - The second phase is local scheduling when a receiver is restarting. There are two cases of + * receiver restarting: + * - If a receiver is restarting because it's rejected due to the real location and the scheduled + * locations mismatching, in other words, it fails to start in one of the locations that + * `scheduleReceivers` suggested, `ReceiverTracker` should firstly choose the executors that + * are still alive in the list of scheduled locations, then use them to launch the receiver + * job. + * - If a receiver is restarting without a scheduled locations list, or the executors in the list + * are dead, `ReceiverTracker` should call `rescheduleReceiver`. If so, `ReceiverTracker` + * should not set `ReceiverTrackingInfo.scheduledLocations` for this receiver, instead, it + * should clear it. Then when this receiver is registering, we can know this is a local + * scheduling, and `ReceiverTrackingInfo` should call `rescheduleReceiver` again to check if + * the launching location is matching. * * In conclusion, we should make a global schedule, try to achieve that exactly as long as possible, * otherwise do local scheduling. diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala index f5165f7c3912..a99b57083583 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala @@ -34,9 +34,10 @@ import org.apache.spark.{Logging, SparkConf} /** * This class manages write ahead log files. - * - Writes records (bytebuffers) to periodically rotating log files. - * - Recovers the log files and the reads the recovered records upon failures. - * - Cleans up old log files. + * + * - Writes records (bytebuffers) to periodically rotating log files. + * - Recovers the log files and the reads the recovered records upon failures. + * - Cleans up old log files. 
* * Uses [[org.apache.spark.streaming.util.FileBasedWriteAheadLogWriter]] to write * and [[org.apache.spark.streaming.util.FileBasedWriteAheadLogReader]] to read. diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/RecurringTimer.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/RecurringTimer.scala index 0148cb51c6f0..bfb53614050a 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/RecurringTimer.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/RecurringTimer.scala @@ -72,10 +72,10 @@ class RecurringTimer(clock: Clock, period: Long, callback: (Long) => Unit, name: /** * Stop the timer, and return the last time the callback was made. - * - interruptTimer = true will interrupt the callback - * if it is in progress (not guaranteed to give correct time in this case). - * - interruptTimer = false guarantees that there will be at least one callback after `stop` has - * been called. + * + * @param interruptTimer True will interrupt the callback if it is in progress (not guaranteed to + * give correct time in this case). False guarantees that there will be at + * least one callback after `stop` has been called. */ def stop(interruptTimer: Boolean): Long = synchronized { if (!stopped) { diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index a77a3e2420e2..f0590d2d222e 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -1336,11 +1336,11 @@ object Client extends Logging { * * This method uses two configuration values: * - * - spark.yarn.config.gatewayPath: a string that identifies a portion of the input path that may - * only be valid in the gateway node. - * - spark.yarn.config.replacementPath: a string with which to replace the gateway path. This may - * contain, for example, env variable references, which will be expanded by the NMs when - * starting containers. + * - spark.yarn.config.gatewayPath: a string that identifies a portion of the input path that may + * only be valid in the gateway node. + * - spark.yarn.config.replacementPath: a string with which to replace the gateway path. This may + * contain, for example, env variable references, which will be expanded by the NMs when + * starting containers. * * If either config is not available, the input path is returned. */ From 8ddc55f1d582cccc3ca135510b2ea776e889e481 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 1 Dec 2015 10:22:55 -0800 Subject: [PATCH 1612/1824] [SPARK-12068][SQL] use a single column in Dataset.groupBy and count will fail The reason is that, for a single culumn `RowEncoder`(or a single field product encoder), when we use it as the encoder for grouping key, we should also combine the grouping attributes, although there is only one grouping attribute. Author: Wenchen Fan Closes #10059 from cloud-fan/bug. 
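Roughly the shape of the call that used to fail, mirroring the regression test added below (assumes a `SQLContext` named `sqlContext` with its implicits imported):

    import sqlContext.implicits._

    val ds = Seq("abc", "xyz", "hello").toDS()
    // There is only one grouping attribute here, but the Tuple1 key encoder is not
    // flat, so the key column still has to be wrapped in a struct when resolving it.
    val counted = ds.groupBy(s => Tuple1(s.length)).count()
    counted.collect()   // expected: (Tuple1(3), 2) and (Tuple1(5), 1)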
--- .../scala/org/apache/spark/sql/Dataset.scala | 2 +- .../org/apache/spark/sql/GroupedDataset.scala | 7 ++++--- .../org/apache/spark/sql/DatasetSuite.scala | 19 +++++++++++++++++++ .../org/apache/spark/sql/QueryTest.scala | 6 +++--- 4 files changed, 27 insertions(+), 7 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index da4600133290..c357f88a94dd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -70,7 +70,7 @@ class Dataset[T] private[sql]( * implicit so that we can use it when constructing new [[Dataset]] objects that have the same * object type (that will be possibly resolved to a different schema). */ - private implicit val unresolvedTEncoder: ExpressionEncoder[T] = encoderFor(tEncoder) + private[sql] implicit val unresolvedTEncoder: ExpressionEncoder[T] = encoderFor(tEncoder) /** The encoder for this [[Dataset]] that has been resolved to its output schema. */ private[sql] val resolvedTEncoder: ExpressionEncoder[T] = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala index a10a89342fb5..4bf0b256fcb4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala @@ -228,10 +228,11 @@ class GroupedDataset[K, V] private[sql]( val namedColumns = columns.map( _.withInputType(resolvedVEncoder, dataAttributes).named) - val keyColumn = if (groupingAttributes.length > 1) { - Alias(CreateStruct(groupingAttributes), "key")() - } else { + val keyColumn = if (resolvedKEncoder.flat) { + assert(groupingAttributes.length == 1) groupingAttributes.head + } else { + Alias(CreateStruct(groupingAttributes), "key")() } val aggregate = Aggregate(groupingAttributes, keyColumn +: namedColumns, logicalPlan) val execution = new QueryExecution(sqlContext, aggregate) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 7d539180ded9..a2c8d201563e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -272,6 +272,16 @@ class DatasetSuite extends QueryTest with SharedSQLContext { 3 -> "abcxyz", 5 -> "hello") } + test("groupBy single field class, count") { + val ds = Seq("abc", "xyz", "hello").toDS() + val count = ds.groupBy(s => Tuple1(s.length)).count() + + checkAnswer( + count, + (Tuple1(3), 2L), (Tuple1(5), 1L) + ) + } + test("groupBy columns, map") { val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() val grouped = ds.groupBy($"_1") @@ -282,6 +292,15 @@ class DatasetSuite extends QueryTest with SharedSQLContext { ("a", 30), ("b", 3), ("c", 1)) } + test("groupBy columns, count") { + val ds = Seq("a" -> 1, "b" -> 1, "a" -> 2).toDS() + val count = ds.groupBy($"_1").count() + + checkAnswer( + count, + (Row("a"), 2L), (Row("b"), 1L)) + } + test("groupBy columns asKey, map") { val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() val grouped = ds.groupBy($"_1").keyAs[String] diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index 6ea1fe4ccfd8..8f476dd0f99b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ 
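A rough illustration of the scenario described above (assumes a `SQLContext` named `sqlContext`; the data and names are made up for the example):

    case class Data(a: Int, b: String)

    import sqlContext.implicits._

    // The underlying schema is [a: int, b: bigint] while the encoder for Data
    // expects [a: int, b: string].
    val df = Seq((1, 2L), (3, 4L)).toDF("a", "b")
    val ds = df.as[Data]
    // Before this patch the required type of `b` was lost, and constructing Data
    // failed at runtime with an Int and a Long. With UpCast in place, an allowed
    // cast (here bigint -> string) is inserted automatically, while a potentially
    // truncating one (for example mapping a bigint column onto an Int field) fails
    // analysis with a message that prints the walked type path of the target object.
    ds.collect()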
b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -64,12 +64,12 @@ abstract class QueryTest extends PlanTest { * for cases where reordering is done on fields. For such tests, user `checkDecoding` instead * which performs a subset of the checks done by this function. */ - protected def checkAnswer[T : Encoder]( - ds: => Dataset[T], + protected def checkAnswer[T]( + ds: Dataset[T], expectedAnswer: T*): Unit = { checkAnswer( ds.toDF(), - sqlContext.createDataset(expectedAnswer).toDF().collect().toSeq) + sqlContext.createDataset(expectedAnswer)(ds.unresolvedTEncoder).toDF().collect().toSeq) checkDecoding(ds, expectedAnswer: _*) } From 9df24624afedd993a39ab46c8211ae153aedef1a Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 1 Dec 2015 10:24:53 -0800 Subject: [PATCH 1613/1824] [SPARK-11856][SQL] add type cast if the real type is different but compatible with encoder schema When we build the `fromRowExpression` for an encoder, we set up a lot of "unresolved" stuff and lost the required data type, which may lead to runtime error if the real type doesn't match the encoder's schema. For example, we build an encoder for `case class Data(a: Int, b: String)` and the real type is `[a: int, b: long]`, then we will hit runtime error and say that we can't construct class `Data` with int and long, because we lost the information that `b` should be a string. Author: Wenchen Fan Closes #9840 from cloud-fan/err-msg. --- .../spark/sql/catalyst/ScalaReflection.scala | 93 +++++++-- .../sql/catalyst/analysis/Analyzer.scala | 40 ++++ .../catalyst/analysis/HiveTypeCoercion.scala | 2 +- .../catalyst/encoders/ExpressionEncoder.scala | 4 +- .../spark/sql/catalyst/expressions/Cast.scala | 9 + .../expressions/complexTypeCreator.scala | 2 +- .../apache/spark/sql/types/DecimalType.scala | 12 ++ .../encoders/EncoderResolutionSuite.scala | 180 ++++++++++++++++++ .../spark/sql/DatasetAggregatorSuite.scala | 4 +- .../org/apache/spark/sql/DatasetSuite.scala | 21 +- 10 files changed, 335 insertions(+), 32 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index d133ad3f0d89..9b6b5b8bd1a2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -18,9 +18,8 @@ package org.apache.spark.sql.catalyst import org.apache.spark.sql.catalyst.analysis.{UnresolvedExtractValue, UnresolvedAttribute} -import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayBasedMapData, ArrayData, DateTimeUtils} +import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayBasedMapData, DateTimeUtils} import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils @@ -117,31 +116,75 @@ object ScalaReflection extends ScalaReflection { * from ordinal 0 (since there are no names to map to). The actual location can be moved by * calling resolve/bind with a new schema. 
*/ - def constructorFor[T : TypeTag]: Expression = constructorFor(localTypeOf[T], None) + def constructorFor[T : TypeTag]: Expression = { + val tpe = localTypeOf[T] + val clsName = getClassNameFromType(tpe) + val walkedTypePath = s"""- root class: "${clsName}"""" :: Nil + constructorFor(tpe, None, walkedTypePath) + } private def constructorFor( tpe: `Type`, - path: Option[Expression]): Expression = ScalaReflectionLock.synchronized { + path: Option[Expression], + walkedTypePath: Seq[String]): Expression = ScalaReflectionLock.synchronized { /** Returns the current path with a sub-field extracted. */ - def addToPath(part: String): Expression = path - .map(p => UnresolvedExtractValue(p, expressions.Literal(part))) - .getOrElse(UnresolvedAttribute(part)) + def addToPath(part: String, dataType: DataType, walkedTypePath: Seq[String]): Expression = { + val newPath = path + .map(p => UnresolvedExtractValue(p, expressions.Literal(part))) + .getOrElse(UnresolvedAttribute(part)) + upCastToExpectedType(newPath, dataType, walkedTypePath) + } /** Returns the current path with a field at ordinal extracted. */ - def addToPathOrdinal(ordinal: Int, dataType: DataType): Expression = path - .map(p => GetStructField(p, ordinal)) - .getOrElse(BoundReference(ordinal, dataType, false)) + def addToPathOrdinal( + ordinal: Int, + dataType: DataType, + walkedTypePath: Seq[String]): Expression = { + val newPath = path + .map(p => GetStructField(p, ordinal)) + .getOrElse(BoundReference(ordinal, dataType, false)) + upCastToExpectedType(newPath, dataType, walkedTypePath) + } /** Returns the current path or `BoundReference`. */ - def getPath: Expression = path.getOrElse(BoundReference(0, schemaFor(tpe).dataType, true)) + def getPath: Expression = { + val dataType = schemaFor(tpe).dataType + if (path.isDefined) { + path.get + } else { + upCastToExpectedType(BoundReference(0, dataType, true), dataType, walkedTypePath) + } + } + + /** + * When we build the `fromRowExpression` for an encoder, we set up a lot of "unresolved" stuff + * and lost the required data type, which may lead to runtime error if the real type doesn't + * match the encoder's schema. + * For example, we build an encoder for `case class Data(a: Int, b: String)` and the real type + * is [a: int, b: long], then we will hit runtime error and say that we can't construct class + * `Data` with int and long, because we lost the information that `b` should be a string. + * + * This method help us "remember" the required data type by adding a `UpCast`. Note that we + * don't need to cast struct type because there must be `UnresolvedExtractValue` or + * `GetStructField` wrapping it, thus we only need to handle leaf type. 
+ */ + def upCastToExpectedType( + expr: Expression, + expected: DataType, + walkedTypePath: Seq[String]): Expression = expected match { + case _: StructType => expr + case _ => UpCast(expr, expected, walkedTypePath) + } tpe match { case t if !dataTypeFor(t).isInstanceOf[ObjectType] => getPath case t if t <:< localTypeOf[Option[_]] => val TypeRef(_, _, Seq(optType)) = t - WrapOption(constructorFor(optType, path)) + val className = getClassNameFromType(optType) + val newTypePath = s"""- option value class: "$className"""" +: walkedTypePath + WrapOption(constructorFor(optType, path, newTypePath)) case t if t <:< localTypeOf[java.lang.Integer] => val boxedType = classOf[java.lang.Integer] @@ -219,9 +262,11 @@ object ScalaReflection extends ScalaReflection { primitiveMethod.map { method => Invoke(getPath, method, arrayClassFor(elementType)) }.getOrElse { + val className = getClassNameFromType(elementType) + val newTypePath = s"""- array element class: "$className"""" +: walkedTypePath Invoke( MapObjects( - p => constructorFor(elementType, Some(p)), + p => constructorFor(elementType, Some(p), newTypePath), getPath, schemaFor(elementType).dataType), "array", @@ -230,10 +275,12 @@ object ScalaReflection extends ScalaReflection { case t if t <:< localTypeOf[Seq[_]] => val TypeRef(_, _, Seq(elementType)) = t + val className = getClassNameFromType(elementType) + val newTypePath = s"""- array element class: "$className"""" +: walkedTypePath val arrayData = Invoke( MapObjects( - p => constructorFor(elementType, Some(p)), + p => constructorFor(elementType, Some(p), newTypePath), getPath, schemaFor(elementType).dataType), "array", @@ -246,12 +293,13 @@ object ScalaReflection extends ScalaReflection { arrayData :: Nil) case t if t <:< localTypeOf[Map[_, _]] => + // TODO: add walked type path for map val TypeRef(_, _, Seq(keyType, valueType)) = t val keyData = Invoke( MapObjects( - p => constructorFor(keyType, Some(p)), + p => constructorFor(keyType, Some(p), walkedTypePath), Invoke(getPath, "keyArray", ArrayType(schemaFor(keyType).dataType)), schemaFor(keyType).dataType), "array", @@ -260,7 +308,7 @@ object ScalaReflection extends ScalaReflection { val valueData = Invoke( MapObjects( - p => constructorFor(valueType, Some(p)), + p => constructorFor(valueType, Some(p), walkedTypePath), Invoke(getPath, "valueArray", ArrayType(schemaFor(valueType).dataType)), schemaFor(valueType).dataType), "array", @@ -297,12 +345,19 @@ object ScalaReflection extends ScalaReflection { val fieldName = p.name.toString val fieldType = p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs) val dataType = schemaFor(fieldType).dataType - + val clsName = getClassNameFromType(fieldType) + val newTypePath = s"""- field (class: "$clsName", name: "$fieldName")""" +: walkedTypePath // For tuples, we based grab the inner fields by ordinal instead of name. 
if (cls.getName startsWith "scala.Tuple") { - constructorFor(fieldType, Some(addToPathOrdinal(i, dataType))) + constructorFor( + fieldType, + Some(addToPathOrdinal(i, dataType, newTypePath)), + newTypePath) } else { - constructorFor(fieldType, Some(addToPath(fieldName))) + constructorFor( + fieldType, + Some(addToPath(fieldName, dataType, newTypePath)), + newTypePath) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index b8f212fca750..765327c474e6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -72,6 +72,7 @@ class Analyzer( ResolveReferences :: ResolveGroupingAnalytics :: ResolvePivot :: + ResolveUpCast :: ResolveSortReferences :: ResolveGenerate :: ResolveFunctions :: @@ -1182,3 +1183,42 @@ object ComputeCurrentTime extends Rule[LogicalPlan] { } } } + +/** + * Replace the `UpCast` expression by `Cast`, and throw exceptions if the cast may truncate. + */ +object ResolveUpCast extends Rule[LogicalPlan] { + private def fail(from: Expression, to: DataType, walkedTypePath: Seq[String]) = { + throw new AnalysisException(s"Cannot up cast `${from.prettyString}` from " + + s"${from.dataType.simpleString} to ${to.simpleString} as it may truncate\n" + + "The type path of the target object is:\n" + walkedTypePath.mkString("", "\n", "\n") + + "You can either add an explicit cast to the input data or choose a higher precision " + + "type of the field in the target object") + } + + private def illegalNumericPrecedence(from: DataType, to: DataType): Boolean = { + val fromPrecedence = HiveTypeCoercion.numericPrecedence.indexOf(from) + val toPrecedence = HiveTypeCoercion.numericPrecedence.indexOf(to) + toPrecedence > 0 && fromPrecedence > toPrecedence + } + + def apply(plan: LogicalPlan): LogicalPlan = { + plan transformAllExpressions { + case u @ UpCast(child, _, _) if !child.resolved => u + + case UpCast(child, dataType, walkedTypePath) => (child.dataType, dataType) match { + case (from: NumericType, to: DecimalType) if !to.isWiderThan(from) => + fail(child, to, walkedTypePath) + case (from: DecimalType, to: NumericType) if !from.isTighterThan(to) => + fail(child, to, walkedTypePath) + case (from, to) if illegalNumericPrecedence(from, to) => + fail(child, to, walkedTypePath) + case (TimestampType, DateType) => + fail(child, DateType, walkedTypePath) + case (StringType, to: NumericType) => + fail(child, to, walkedTypePath) + case _ => Cast(child, dataType) + } + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index f90fc3cc1218..29502a59915f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -53,7 +53,7 @@ object HiveTypeCoercion { // See https://cwiki.apache.org/confluence/display/Hive/LanguageManual+Types. 
// The conversion for integral and floating point types have a linear widening hierarchy: - private val numericPrecedence = + private[sql] val numericPrecedence = IndexedSeq( ByteType, ShortType, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index 0c10a56c555f..06ffe864552f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.analysis.{SimpleAnalyzer, UnresolvedExtract import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection} +import org.apache.spark.sql.catalyst.optimizer.SimplifyCasts import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.types.{StructField, ObjectType, StructType} @@ -235,12 +236,13 @@ case class ExpressionEncoder[T]( val plan = Project(Alias(unbound, "")() :: Nil, LocalRelation(schema)) val analyzedPlan = SimpleAnalyzer.execute(plan) + val optimizedPlan = SimplifyCasts(analyzedPlan) // In order to construct instances of inner classes (for example those declared in a REPL cell), // we need an instance of the outer scope. This rule substitues those outer objects into // expressions that are missing them by looking up the name in the SQLContexts `outerScopes` // registry. - copy(fromRowExpression = analyzedPlan.expressions.head.children.head transform { + copy(fromRowExpression = optimizedPlan.expressions.head.children.head transform { case n: NewInstance if n.outerPointer.isEmpty && n.cls.isMemberClass => val outer = outerScopes.get(n.cls.getDeclaringClass.getName) if (outer == null) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index a2c6c39fd8ce..cb60d5958d53 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -914,3 +914,12 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { """ } } + +/** + * Cast the child expression to the target data type, but will throw error if the cast might + * truncate, e.g. long -> int, timestamp -> data. 
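At the user level, this check turns silent truncation during encoder resolution into an analysis error. A minimal illustrative sketch only — assuming a SQLContext named `sqlContext` with its implicits imported, and a made-up case class:

    case class Small(i: Int)
    // Mapping a bigint column onto the Int field now fails analysis instead of truncating:
    //   sqlContext.range(10).toDF("i").as[Small]
    //   => AnalysisException: Cannot up cast `i` from bigint to int as it may truncate ...
    // Widening the field to Long, or casting the input column explicitly, makes it legal:
    //   sqlContext.range(10).select($"id".cast("int").as("i")).as[Small]
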
+ */ +case class UpCast(child: Expression, dataType: DataType, walkedTypePath: Seq[String]) + extends UnaryExpression with Unevaluable { + override lazy val resolved = false +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 1854dfaa7db3..72cc89c8be91 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -126,7 +126,7 @@ case class CreateStruct(children: Seq[Expression]) extends Expression { case class CreateNamedStruct(children: Seq[Expression]) extends Expression { /** - * Returns Aliased [[Expressions]] that could be used to construct a flattened version of this + * Returns Aliased [[Expression]]s that could be used to construct a flattened version of this * StructType. */ def flatten: Seq[NamedExpression] = valExprs.zip(names).map { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala index 0cd352d0fa92..ce45245b9f6d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala @@ -90,6 +90,18 @@ case class DecimalType(precision: Int, scale: Int) extends FractionalType { case _ => false } + /** + * Returns whether this DecimalType is tighter than `other`. If yes, it means `this` + * can be casted into `other` safely without losing any precision or range. + */ + private[sql] def isTighterThan(other: DataType): Boolean = other match { + case dt: DecimalType => + (precision - scale) <= (dt.precision - dt.scale) && scale <= dt.scale + case dt: IntegralType => + isTighterThan(DecimalType.forType(dt)) + case _ => false + } + /** * The default size of a value of the DecimalType is 4096 bytes. */ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala new file mode 100644 index 000000000000..0289988342e7 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala @@ -0,0 +1,180 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.encoders + +import scala.reflect.runtime.universe.TypeTag + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.types._ + +case class StringLongClass(a: String, b: Long) + +case class StringIntClass(a: String, b: Int) + +case class ComplexClass(a: Long, b: StringLongClass) + +class EncoderResolutionSuite extends PlanTest { + test("real type doesn't match encoder schema but they are compatible: product") { + val encoder = ExpressionEncoder[StringLongClass] + val cls = classOf[StringLongClass] + + { + val attrs = Seq('a.string, 'b.int) + val fromRowExpr: Expression = encoder.resolve(attrs, null).fromRowExpression + val expected: Expression = NewInstance( + cls, + toExternalString('a.string) :: 'b.int.cast(LongType) :: Nil, + false, + ObjectType(cls)) + compareExpressions(fromRowExpr, expected) + } + + { + val attrs = Seq('a.int, 'b.long) + val fromRowExpr = encoder.resolve(attrs, null).fromRowExpression + val expected = NewInstance( + cls, + toExternalString('a.int.cast(StringType)) :: 'b.long :: Nil, + false, + ObjectType(cls)) + compareExpressions(fromRowExpr, expected) + } + } + + test("real type doesn't match encoder schema but they are compatible: nested product") { + val encoder = ExpressionEncoder[ComplexClass] + val innerCls = classOf[StringLongClass] + val cls = classOf[ComplexClass] + + val structType = new StructType().add("a", IntegerType).add("b", LongType) + val attrs = Seq('a.int, 'b.struct(structType)) + val fromRowExpr: Expression = encoder.resolve(attrs, null).fromRowExpression + val expected: Expression = NewInstance( + cls, + Seq( + 'a.int.cast(LongType), + If( + 'b.struct(structType).isNull, + Literal.create(null, ObjectType(innerCls)), + NewInstance( + innerCls, + Seq( + toExternalString( + GetStructField('b.struct(structType), 0, Some("a")).cast(StringType)), + GetStructField('b.struct(structType), 1, Some("b"))), + false, + ObjectType(innerCls)) + )), + false, + ObjectType(cls)) + compareExpressions(fromRowExpr, expected) + } + + test("real type doesn't match encoder schema but they are compatible: tupled encoder") { + val encoder = ExpressionEncoder.tuple( + ExpressionEncoder[StringLongClass], + ExpressionEncoder[Long]) + val cls = classOf[StringLongClass] + + val structType = new StructType().add("a", StringType).add("b", ByteType) + val attrs = Seq('a.struct(structType), 'b.int) + val fromRowExpr: Expression = encoder.resolve(attrs, null).fromRowExpression + val expected: Expression = NewInstance( + classOf[Tuple2[_, _]], + Seq( + NewInstance( + cls, + Seq( + toExternalString(GetStructField('a.struct(structType), 0, Some("a"))), + GetStructField('a.struct(structType), 1, Some("b")).cast(LongType)), + false, + ObjectType(cls)), + 'b.int.cast(LongType)), + false, + ObjectType(classOf[Tuple2[_, _]])) + compareExpressions(fromRowExpr, expected) + } + + private def toExternalString(e: Expression): Expression = { + Invoke(e, "toString", ObjectType(classOf[String]), Nil) + } + + test("throw exception if real type is not compatible with encoder schema") { + val msg1 = intercept[AnalysisException] { + ExpressionEncoder[StringIntClass].resolve(Seq('a.string, 'b.long), null) + }.message + assert(msg1 == + s""" + |Cannot up cast `b` from bigint to int as it may truncate + |The type path of the target object is: + |- field (class: "scala.Int", name: 
"b") + |- root class: "org.apache.spark.sql.catalyst.encoders.StringIntClass" + |You can either add an explicit cast to the input data or choose a higher precision type + """.stripMargin.trim + " of the field in the target object") + + val msg2 = intercept[AnalysisException] { + val structType = new StructType().add("a", StringType).add("b", DecimalType.SYSTEM_DEFAULT) + ExpressionEncoder[ComplexClass].resolve(Seq('a.long, 'b.struct(structType)), null) + }.message + assert(msg2 == + s""" + |Cannot up cast `b.b` from decimal(38,18) to bigint as it may truncate + |The type path of the target object is: + |- field (class: "scala.Long", name: "b") + |- field (class: "org.apache.spark.sql.catalyst.encoders.StringLongClass", name: "b") + |- root class: "org.apache.spark.sql.catalyst.encoders.ComplexClass" + |You can either add an explicit cast to the input data or choose a higher precision type + """.stripMargin.trim + " of the field in the target object") + } + + // test for leaf types + castSuccess[Int, Long] + castSuccess[java.sql.Date, java.sql.Timestamp] + castSuccess[Long, String] + castSuccess[Int, java.math.BigDecimal] + castSuccess[Long, java.math.BigDecimal] + + castFail[Long, Int] + castFail[java.sql.Timestamp, java.sql.Date] + castFail[java.math.BigDecimal, Double] + castFail[Double, java.math.BigDecimal] + castFail[java.math.BigDecimal, Int] + castFail[String, Long] + + + private def castSuccess[T: TypeTag, U: TypeTag]: Unit = { + val from = ExpressionEncoder[T] + val to = ExpressionEncoder[U] + val catalystType = from.schema.head.dataType.simpleString + test(s"cast from $catalystType to ${implicitly[TypeTag[U]].tpe} should success") { + to.resolve(from.schema.toAttributes, null) + } + } + + private def castFail[T: TypeTag, U: TypeTag]: Unit = { + val from = ExpressionEncoder[T] + val to = ExpressionEncoder[U] + val catalystType = from.schema.head.dataType.simpleString + test(s"cast from $catalystType to ${implicitly[TypeTag[U]].tpe} should fail") { + intercept[AnalysisException](to.resolve(from.schema.toAttributes, null)) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala index 19dce5d1e2f3..c6d2bf07b280 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala @@ -131,9 +131,9 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext { checkAnswer( ds.groupBy(_._1).agg( sum(_._2), - expr("sum(_2)").as[Int], + expr("sum(_2)").as[Long], count("*")), - ("a", 30, 30, 2L), ("b", 3, 3, 2L), ("c", 1, 1, 1L)) + ("a", 30, 30L, 2L), ("b", 3, 3L, 2L), ("c", 1, 1L, 1L)) } test("typed aggregation: complex case") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index a2c8d201563e..542e4d6c43b9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -335,24 +335,24 @@ class DatasetSuite extends QueryTest with SharedSQLContext { val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() checkAnswer( - ds.groupBy(_._1).agg(sum("_2").as[Int]), - ("a", 30), ("b", 3), ("c", 1)) + ds.groupBy(_._1).agg(sum("_2").as[Long]), + ("a", 30L), ("b", 3L), ("c", 1L)) } test("typed aggregation: expr, expr") { val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 
1)).toDS() checkAnswer( - ds.groupBy(_._1).agg(sum("_2").as[Int], sum($"_2" + 1).as[Long]), - ("a", 30, 32L), ("b", 3, 5L), ("c", 1, 2L)) + ds.groupBy(_._1).agg(sum("_2").as[Long], sum($"_2" + 1).as[Long]), + ("a", 30L, 32L), ("b", 3L, 5L), ("c", 1L, 2L)) } test("typed aggregation: expr, expr, expr") { val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() checkAnswer( - ds.groupBy(_._1).agg(sum("_2").as[Int], sum($"_2" + 1).as[Long], count("*").as[Long]), - ("a", 30, 32L, 2L), ("b", 3, 5L, 2L), ("c", 1, 2L, 1L)) + ds.groupBy(_._1).agg(sum("_2").as[Long], sum($"_2" + 1).as[Long], count("*")), + ("a", 30L, 32L, 2L), ("b", 3L, 5L, 2L), ("c", 1L, 2L, 1L)) } test("typed aggregation: expr, expr, expr, expr") { @@ -360,11 +360,11 @@ class DatasetSuite extends QueryTest with SharedSQLContext { checkAnswer( ds.groupBy(_._1).agg( - sum("_2").as[Int], + sum("_2").as[Long], sum($"_2" + 1).as[Long], count("*").as[Long], avg("_2").as[Double]), - ("a", 30, 32L, 2L, 15.0), ("b", 3, 5L, 2L, 1.5), ("c", 1, 2L, 1L, 1.0)) + ("a", 30L, 32L, 2L, 15.0), ("b", 3L, 5L, 2L, 1.5), ("c", 1L, 2L, 1L, 1.0)) } test("cogroup") { @@ -476,6 +476,11 @@ class DatasetSuite extends QueryTest with SharedSQLContext { ((nullInt, "1"), (new java.lang.Integer(22), "2")), ((new java.lang.Integer(22), "2"), (new java.lang.Integer(22), "2"))) } + + test("change encoder with compatible schema") { + val ds = Seq(2 -> 2.toByte, 3 -> 3.toByte).toDF("a", "b").as[ClassData] + assert(ds.collect().toSeq == Seq(ClassData("2", 2), ClassData("3", 3))) + } } From fd95eeaf491809c6bb0f83d46b37b5e2eebbcbca Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 1 Dec 2015 10:35:12 -0800 Subject: [PATCH 1614/1824] [SPARK-11954][SQL] Encoder for JavaBeans create java version of `constructorFor` and `extractorFor` in `JavaTypeInference` Author: Wenchen Fan This patch had conflicts when merged, resolved by Committer: Michael Armbrust Closes #9937 from cloud-fan/pojo. --- .../scala/org/apache/spark/sql/Encoder.scala | 18 + .../sql/catalyst/JavaTypeInference.scala | 313 +++++++++++++++++- .../catalyst/encoders/ExpressionEncoder.scala | 21 +- .../sql/catalyst/expressions/objects.scala | 42 ++- .../spark/sql/catalyst/trees/TreeNode.scala | 27 +- .../sql/catalyst/util/ArrayBasedMapData.scala | 5 + .../sql/catalyst/util/GenericArrayData.scala | 3 + .../sql/catalyst/trees/TreeNodeSuite.scala | 25 ++ .../apache/spark/sql/JavaDatasetSuite.java | 174 +++++++++- 9 files changed, 608 insertions(+), 20 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala index 03aa25eda807..c40061ae0aaf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala @@ -97,6 +97,24 @@ object Encoders { */ def STRING: Encoder[java.lang.String] = ExpressionEncoder() + /** + * Creates an encoder for Java Bean of type T. + * + * T must be publicly accessible. + * + * supported types for java bean field: + * - primitive types: boolean, int, double, etc. + * - boxed types: Boolean, Integer, Double, etc. + * - String + * - java.math.BigDecimal + * - time related: java.sql.Date, java.sql.Timestamp + * - collection types: only array and java.util.List currently, map support is in progress + * - nested java bean. 
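As a hedged sketch of what such a bean can look like when written in Scala — the class and field names are made up, the class is assumed to be compiled as a top-level class rather than defined in the REPL, and `sqlContext` is assumed to exist:

    import scala.beans.BeanProperty
    import org.apache.spark.sql.Encoders

    // @BeanProperty generates the getter/setter pairs that the bean encoder inspects.
    class SimplePoint extends java.io.Serializable {
      @BeanProperty var x: Int = 0
      @BeanProperty var label: String = _
    }

    val pointEncoder = Encoders.bean(classOf[SimplePoint])
    // e.g. sqlContext.createDataset(Seq(new SimplePoint))(pointEncoder)
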
+ * + * @since 1.6.0 + */ + def bean[T](beanClass: Class[T]): Encoder[T] = ExpressionEncoder.javaBean(beanClass) + /** * (Scala-specific) Creates an encoder that serializes objects of type T using Kryo. * This encoder maps T into a single byte array (binary) field. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala index 7d4cfbe6faec..c8ee87e8819f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala @@ -17,14 +17,20 @@ package org.apache.spark.sql.catalyst -import java.beans.Introspector +import java.beans.{PropertyDescriptor, Introspector} import java.lang.{Iterable => JIterable} -import java.util.{Iterator => JIterator, Map => JMap} +import java.util.{Iterator => JIterator, Map => JMap, List => JList} import scala.language.existentials import com.google.common.reflect.TypeToken + import org.apache.spark.sql.types._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedExtractValue} +import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayBasedMapData, DateTimeUtils} +import org.apache.spark.unsafe.types.UTF8String + /** * Type-inference utilities for POJOs and Java collections. @@ -33,13 +39,14 @@ object JavaTypeInference { private val iterableType = TypeToken.of(classOf[JIterable[_]]) private val mapType = TypeToken.of(classOf[JMap[_, _]]) + private val listType = TypeToken.of(classOf[JList[_]]) private val iteratorReturnType = classOf[JIterable[_]].getMethod("iterator").getGenericReturnType private val nextReturnType = classOf[JIterator[_]].getMethod("next").getGenericReturnType private val keySetReturnType = classOf[JMap[_, _]].getMethod("keySet").getGenericReturnType private val valuesReturnType = classOf[JMap[_, _]].getMethod("values").getGenericReturnType /** - * Infers the corresponding SQL data type of a JavaClean class. + * Infers the corresponding SQL data type of a JavaBean class. * @param beanClass Java type * @return (SQL data type, nullable) */ @@ -58,6 +65,8 @@ object JavaTypeInference { (c.getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance(), true) case c: Class[_] if c == classOf[java.lang.String] => (StringType, true) + case c: Class[_] if c == classOf[Array[Byte]] => (BinaryType, true) + case c: Class[_] if c == java.lang.Short.TYPE => (ShortType, false) case c: Class[_] if c == java.lang.Integer.TYPE => (IntegerType, false) case c: Class[_] if c == java.lang.Long.TYPE => (LongType, false) @@ -87,15 +96,14 @@ object JavaTypeInference { (ArrayType(dataType, nullable), true) case _ if mapType.isAssignableFrom(typeToken) => - val typeToken2 = typeToken.asInstanceOf[TypeToken[_ <: JMap[_, _]]] - val mapSupertype = typeToken2.getSupertype(classOf[JMap[_, _]]) - val keyType = elementType(mapSupertype.resolveType(keySetReturnType)) - val valueType = elementType(mapSupertype.resolveType(valuesReturnType)) + val (keyType, valueType) = mapKeyValueType(typeToken) val (keyDataType, _) = inferDataType(keyType) val (valueDataType, nullable) = inferDataType(valueType) (MapType(keyDataType, valueDataType, nullable), true) case _ => + // TODO: we should only collect properties that have getter and setter. However, some tests + // pass in scala case class as java bean class which doesn't have getter and setter. 
val beanInfo = Introspector.getBeanInfo(typeToken.getRawType) val properties = beanInfo.getPropertyDescriptors.filterNot(_.getName == "class") val fields = properties.map { property => @@ -107,11 +115,294 @@ object JavaTypeInference { } } + private def getJavaBeanProperties(beanClass: Class[_]): Array[PropertyDescriptor] = { + val beanInfo = Introspector.getBeanInfo(beanClass) + beanInfo.getPropertyDescriptors + .filter(p => p.getReadMethod != null && p.getWriteMethod != null) + } + private def elementType(typeToken: TypeToken[_]): TypeToken[_] = { val typeToken2 = typeToken.asInstanceOf[TypeToken[_ <: JIterable[_]]] - val iterableSupertype = typeToken2.getSupertype(classOf[JIterable[_]]) - val iteratorType = iterableSupertype.resolveType(iteratorReturnType) - val itemType = iteratorType.resolveType(nextReturnType) - itemType + val iterableSuperType = typeToken2.getSupertype(classOf[JIterable[_]]) + val iteratorType = iterableSuperType.resolveType(iteratorReturnType) + iteratorType.resolveType(nextReturnType) + } + + private def mapKeyValueType(typeToken: TypeToken[_]): (TypeToken[_], TypeToken[_]) = { + val typeToken2 = typeToken.asInstanceOf[TypeToken[_ <: JMap[_, _]]] + val mapSuperType = typeToken2.getSupertype(classOf[JMap[_, _]]) + val keyType = elementType(mapSuperType.resolveType(keySetReturnType)) + val valueType = elementType(mapSuperType.resolveType(valuesReturnType)) + keyType -> valueType + } + + /** + * Returns the Spark SQL DataType for a given java class. Where this is not an exact mapping + * to a native type, an ObjectType is returned. + * + * Unlike `inferDataType`, this function doesn't do any massaging of types into the Spark SQL type + * system. As a result, ObjectType will be returned for things like boxed Integers. + */ + private def inferExternalType(cls: Class[_]): DataType = cls match { + case c if c == java.lang.Boolean.TYPE => BooleanType + case c if c == java.lang.Byte.TYPE => ByteType + case c if c == java.lang.Short.TYPE => ShortType + case c if c == java.lang.Integer.TYPE => IntegerType + case c if c == java.lang.Long.TYPE => LongType + case c if c == java.lang.Float.TYPE => FloatType + case c if c == java.lang.Double.TYPE => DoubleType + case c if c == classOf[Array[Byte]] => BinaryType + case _ => ObjectType(cls) + } + + /** + * Returns an expression that can be used to construct an object of java bean `T` given an input + * row with a compatible schema. Fields of the row will be extracted using UnresolvedAttributes + * of the same name as the constructor arguments. Nested classes will have their fields accessed + * using UnresolvedExtractValue. + */ + def constructorFor(beanClass: Class[_]): Expression = { + constructorFor(TypeToken.of(beanClass), None) + } + + private def constructorFor(typeToken: TypeToken[_], path: Option[Expression]): Expression = { + /** Returns the current path with a sub-field extracted. */ + def addToPath(part: String): Expression = path + .map(p => UnresolvedExtractValue(p, expressions.Literal(part))) + .getOrElse(UnresolvedAttribute(part)) + + /** Returns the current path or `BoundReference`. 
*/ + def getPath: Expression = path.getOrElse(BoundReference(0, inferDataType(typeToken)._1, true)) + + typeToken.getRawType match { + case c if !inferExternalType(c).isInstanceOf[ObjectType] => getPath + + case c if c == classOf[java.lang.Short] => + NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c)) + case c if c == classOf[java.lang.Integer] => + NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c)) + case c if c == classOf[java.lang.Long] => + NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c)) + case c if c == classOf[java.lang.Double] => + NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c)) + case c if c == classOf[java.lang.Byte] => + NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c)) + case c if c == classOf[java.lang.Float] => + NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c)) + case c if c == classOf[java.lang.Boolean] => + NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c)) + + case c if c == classOf[java.sql.Date] => + StaticInvoke( + DateTimeUtils, + ObjectType(c), + "toJavaDate", + getPath :: Nil, + propagateNull = true) + + case c if c == classOf[java.sql.Timestamp] => + StaticInvoke( + DateTimeUtils, + ObjectType(c), + "toJavaTimestamp", + getPath :: Nil, + propagateNull = true) + + case c if c == classOf[java.lang.String] => + Invoke(getPath, "toString", ObjectType(classOf[String])) + + case c if c == classOf[java.math.BigDecimal] => + Invoke(getPath, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal])) + + case c if c.isArray => + val elementType = c.getComponentType + val primitiveMethod = elementType match { + case c if c == java.lang.Boolean.TYPE => Some("toBooleanArray") + case c if c == java.lang.Byte.TYPE => Some("toByteArray") + case c if c == java.lang.Short.TYPE => Some("toShortArray") + case c if c == java.lang.Integer.TYPE => Some("toIntArray") + case c if c == java.lang.Long.TYPE => Some("toLongArray") + case c if c == java.lang.Float.TYPE => Some("toFloatArray") + case c if c == java.lang.Double.TYPE => Some("toDoubleArray") + case _ => None + } + + primitiveMethod.map { method => + Invoke(getPath, method, ObjectType(c)) + }.getOrElse { + Invoke( + MapObjects( + p => constructorFor(typeToken.getComponentType, Some(p)), + getPath, + inferDataType(elementType)._1), + "array", + ObjectType(c)) + } + + case c if listType.isAssignableFrom(typeToken) => + val et = elementType(typeToken) + val array = + Invoke( + MapObjects( + p => constructorFor(et, Some(p)), + getPath, + inferDataType(et)._1), + "array", + ObjectType(classOf[Array[Any]])) + + StaticInvoke(classOf[java.util.Arrays], ObjectType(c), "asList", array :: Nil) + + case _ if mapType.isAssignableFrom(typeToken) => + val (keyType, valueType) = mapKeyValueType(typeToken) + val keyDataType = inferDataType(keyType)._1 + val valueDataType = inferDataType(valueType)._1 + + val keyData = + Invoke( + MapObjects( + p => constructorFor(keyType, Some(p)), + Invoke(getPath, "keyArray", ArrayType(keyDataType)), + keyDataType), + "array", + ObjectType(classOf[Array[Any]])) + + val valueData = + Invoke( + MapObjects( + p => constructorFor(valueType, Some(p)), + Invoke(getPath, "valueArray", ArrayType(valueDataType)), + valueDataType), + "array", + ObjectType(classOf[Array[Any]])) + + StaticInvoke( + ArrayBasedMapData, + ObjectType(classOf[JMap[_, _]]), + "toJavaMap", + keyData :: valueData :: Nil) + + case other => + val properties = getJavaBeanProperties(other) + assert(properties.length > 0) + + 
val setters = properties.map { p => + val fieldName = p.getName + val fieldType = typeToken.method(p.getReadMethod).getReturnType + p.getWriteMethod.getName -> constructorFor(fieldType, Some(addToPath(fieldName))) + }.toMap + + val newInstance = NewInstance(other, Nil, propagateNull = false, ObjectType(other)) + val result = InitializeJavaBean(newInstance, setters) + + if (path.nonEmpty) { + expressions.If( + IsNull(getPath), + expressions.Literal.create(null, ObjectType(other)), + result + ) + } else { + result + } + } + } + + /** + * Returns expressions for extracting all the fields from the given type. + */ + def extractorsFor(beanClass: Class[_]): CreateNamedStruct = { + val inputObject = BoundReference(0, ObjectType(beanClass), nullable = true) + extractorFor(inputObject, TypeToken.of(beanClass)).asInstanceOf[CreateNamedStruct] + } + + private def extractorFor(inputObject: Expression, typeToken: TypeToken[_]): Expression = { + + def toCatalystArray(input: Expression, elementType: TypeToken[_]): Expression = { + val (dataType, nullable) = inferDataType(elementType) + if (ScalaReflection.isNativeType(dataType)) { + NewInstance( + classOf[GenericArrayData], + input :: Nil, + dataType = ArrayType(dataType, nullable)) + } else { + MapObjects(extractorFor(_, elementType), input, ObjectType(elementType.getRawType)) + } + } + + if (!inputObject.dataType.isInstanceOf[ObjectType]) { + inputObject + } else { + typeToken.getRawType match { + case c if c == classOf[String] => + StaticInvoke( + classOf[UTF8String], + StringType, + "fromString", + inputObject :: Nil) + + case c if c == classOf[java.sql.Timestamp] => + StaticInvoke( + DateTimeUtils, + TimestampType, + "fromJavaTimestamp", + inputObject :: Nil) + + case c if c == classOf[java.sql.Date] => + StaticInvoke( + DateTimeUtils, + DateType, + "fromJavaDate", + inputObject :: Nil) + + case c if c == classOf[java.math.BigDecimal] => + StaticInvoke( + Decimal, + DecimalType.SYSTEM_DEFAULT, + "apply", + inputObject :: Nil) + + case c if c == classOf[java.lang.Boolean] => + Invoke(inputObject, "booleanValue", BooleanType) + case c if c == classOf[java.lang.Byte] => + Invoke(inputObject, "byteValue", ByteType) + case c if c == classOf[java.lang.Short] => + Invoke(inputObject, "shortValue", ShortType) + case c if c == classOf[java.lang.Integer] => + Invoke(inputObject, "intValue", IntegerType) + case c if c == classOf[java.lang.Long] => + Invoke(inputObject, "longValue", LongType) + case c if c == classOf[java.lang.Float] => + Invoke(inputObject, "floatValue", FloatType) + case c if c == classOf[java.lang.Double] => + Invoke(inputObject, "doubleValue", DoubleType) + + case _ if typeToken.isArray => + toCatalystArray(inputObject, typeToken.getComponentType) + + case _ if listType.isAssignableFrom(typeToken) => + toCatalystArray(inputObject, elementType(typeToken)) + + case _ if mapType.isAssignableFrom(typeToken) => + // TODO: for java map, if we get the keys and values by `keySet` and `values`, we can + // not guarantee they have same iteration order(which is different from scala map). + // A possible solution is creating a new `MapObjects` that can iterate a map directly. 
+ throw new UnsupportedOperationException("map type is not supported currently") + + case other => + val properties = getJavaBeanProperties(other) + if (properties.length > 0) { + CreateNamedStruct(properties.flatMap { p => + val fieldName = p.getName + val fieldType = typeToken.method(p.getReadMethod).getReturnType + val fieldValue = Invoke( + inputObject, + p.getReadMethod.getName, + inferExternalType(fieldType.getRawType)) + expressions.Literal(fieldName) :: extractorFor(fieldValue, fieldType) :: Nil + }) + } else { + throw new UnsupportedOperationException(s"no encoder found for ${other.getName}") + } + } + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index 06ffe864552f..3e8420ecb9cc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -29,8 +29,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection} import org.apache.spark.sql.catalyst.optimizer.SimplifyCasts -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.catalyst.{JavaTypeInference, InternalRow, ScalaReflection} import org.apache.spark.sql.types.{StructField, ObjectType, StructType} /** @@ -68,6 +67,22 @@ object ExpressionEncoder { ClassTag[T](cls)) } + // TODO: improve error message for java bean encoder. + def javaBean[T](beanClass: Class[T]): ExpressionEncoder[T] = { + val schema = JavaTypeInference.inferDataType(beanClass)._1 + assert(schema.isInstanceOf[StructType]) + + val toRowExpression = JavaTypeInference.extractorsFor(beanClass) + val fromRowExpression = JavaTypeInference.constructorFor(beanClass) + + new ExpressionEncoder[T]( + schema.asInstanceOf[StructType], + flat = false, + toRowExpression.flatten, + fromRowExpression, + ClassTag[T](beanClass)) + } + /** * Given a set of N encoders, constructs a new encoder that produce objects as items in an * N-tuple. Note that these encoders should be unresolved so that information about @@ -216,7 +231,7 @@ case class ExpressionEncoder[T]( */ def assertUnresolved(): Unit = { (fromRowExpression +: toRowExpressions).foreach(_.foreach { - case a: AttributeReference => + case a: AttributeReference if a.name != "loopVar" => sys.error(s"Unresolved encoder expected, but $a was found.") case _ => }) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala index 62d09f0f5510..e6ab9a31be59 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala @@ -346,7 +346,8 @@ case class LambdaVariable(value: String, isNull: String, dataType: DataType) ext * as an ArrayType. This is similar to a typical map operation, but where the lambda function * is expressed using catalyst expressions. 
* - * The following collection ObjectTypes are currently supported: Seq, Array, ArrayData + * The following collection ObjectTypes are currently supported: + * Seq, Array, ArrayData, java.util.List * * @param function A function that returns an expression, given an attribute that can be used * to access the current value. This is does as a lambda function so that @@ -386,6 +387,8 @@ case class MapObjects( (".size()", (i: String) => s".apply($i)", false) case ObjectType(cls) if cls.isArray => (".length", (i: String) => s"[$i]", false) + case ObjectType(cls) if classOf[java.util.List[_]].isAssignableFrom(cls) => + (".size()", (i: String) => s".get($i)", false) case ArrayType(t, _) => val (sqlType, primitiveElement) = t match { case m: MapType => (m, false) @@ -596,3 +599,40 @@ case class DecodeUsingSerializer[T](child: Expression, tag: ClassTag[T], kryo: B override def dataType: DataType = ObjectType(tag.runtimeClass) } + +/** + * Initialize a Java Bean instance by setting its field values via setters. + */ +case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Expression]) + extends Expression { + + override def nullable: Boolean = beanInstance.nullable + override def children: Seq[Expression] = beanInstance +: setters.values.toSeq + override def dataType: DataType = beanInstance.dataType + + override def eval(input: InternalRow): Any = + throw new UnsupportedOperationException("Only code-generated evaluation is supported.") + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val instanceGen = beanInstance.gen(ctx) + + val initialize = setters.map { + case (setterMethod, fieldValue) => + val fieldGen = fieldValue.gen(ctx) + s""" + ${fieldGen.code} + ${instanceGen.value}.$setterMethod(${fieldGen.value}); + """ + } + + ev.isNull = instanceGen.isNull + ev.value = instanceGen.value + + s""" + ${instanceGen.code} + if (!${instanceGen.isNull}) { + ${initialize.mkString("\n")} + } + """ + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index 35f087baccde..f1cea07976a3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.trees +import scala.collection.Map + import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.types.{StructType, DataType} @@ -191,6 +193,19 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { case nonChild: AnyRef => nonChild case null => null } + case m: Map[_, _] => m.mapValues { + case arg: TreeNode[_] if containsChild(arg) => + val newChild = remainingNewChildren.remove(0) + val oldChild = remainingOldChildren.remove(0) + if (newChild fastEquals oldChild) { + oldChild + } else { + changed = true + newChild + } + case nonChild: AnyRef => nonChild + case null => null + }.view.force // `mapValues` is lazy and we need to force it to materialize case arg: TreeNode[_] if containsChild(arg) => val newChild = remainingNewChildren.remove(0) val oldChild = remainingOldChildren.remove(0) @@ -262,7 +277,17 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { } else { Some(arg) } - case m: Map[_, _] => m + case m: Map[_, _] => m.mapValues { + case arg: TreeNode[_] if containsChild(arg) => + val newChild = nextOperation(arg.asInstanceOf[BaseType], rule) + if 
(!(newChild fastEquals arg)) { + changed = true + newChild + } else { + arg + } + case other => other + }.view.force // `mapValues` is lazy and we need to force it to materialize case d: DataType => d // Avoid unpacking Structs case args: Traversable[_] => args.map { case arg: TreeNode[_] if containsChild(arg) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapData.scala index 70b028d2b3f7..d85b72ed83de 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapData.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapData.scala @@ -70,4 +70,9 @@ object ArrayBasedMapData { def toScalaMap(keys: Seq[Any], values: Seq[Any]): Map[Any, Any] = { keys.zip(values).toMap } + + def toJavaMap(keys: Array[Any], values: Array[Any]): java.util.Map[Any, Any] = { + import scala.collection.JavaConverters._ + keys.zip(values).toMap.asJava + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala index 96588bb5dc1b..2b8cdc1e23ab 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.util +import scala.collection.JavaConverters._ + import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types.{DataType, Decimal} import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} @@ -24,6 +26,7 @@ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} class GenericArrayData(val array: Array[Any]) extends ArrayData { def this(seq: Seq[Any]) = this(seq.toArray) + def this(list: java.util.List[Any]) = this(list.asScala) // TODO: This is boxing. We should specialize. 
def this(primitiveArray: Array[Int]) = this(primitiveArray.toSeq) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala index 8fff39906b34..965bdb1515e5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala @@ -38,6 +38,13 @@ case class ComplexPlan(exprs: Seq[Seq[Expression]]) override def output: Seq[Attribute] = Nil } +case class ExpressionInMap(map: Map[String, Expression]) extends Expression with Unevaluable { + override def children: Seq[Expression] = map.values.toSeq + override def nullable: Boolean = true + override def dataType: NullType = NullType + override lazy val resolved = true +} + class TreeNodeSuite extends SparkFunSuite { test("top node changed") { val after = Literal(1) transform { case Literal(1, _) => Literal(2) } @@ -236,4 +243,22 @@ class TreeNodeSuite extends SparkFunSuite { val expected = ComplexPlan(Seq(Seq(Literal("1")), Seq(Literal("2")))) assert(expected === actual) } + + test("expressions inside a map") { + val expression = ExpressionInMap(Map("1" -> Literal(1), "2" -> Literal(2))) + + { + val actual = expression.transform { + case Literal(i: Int, _) => Literal(i + 1) + } + val expected = ExpressionInMap(Map("1" -> Literal(2), "2" -> Literal(3))) + assert(actual === expected) + } + + { + val actual = expression.withNewChildren(Seq(Literal(2), Literal(3))) + val expected = ExpressionInMap(Map("1" -> Literal(2), "2" -> Literal(3))) + assert(actual === expected) + } + } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index 67a3190cb7d4..ae47f4fe0e23 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -31,14 +31,15 @@ import org.apache.spark.SparkContext; import org.apache.spark.api.java.function.*; import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.sql.Encoder; -import org.apache.spark.sql.Encoders; -import org.apache.spark.sql.Dataset; -import org.apache.spark.sql.GroupedDataset; +import org.apache.spark.sql.*; import org.apache.spark.sql.expressions.Aggregator; import org.apache.spark.sql.test.TestSQLContext; +import org.apache.spark.sql.catalyst.encoders.OuterScopes; +import org.apache.spark.sql.catalyst.expressions.GenericRow; +import org.apache.spark.sql.types.StructType; import static org.apache.spark.sql.functions.*; +import static org.apache.spark.sql.types.DataTypes.*; public class JavaDatasetSuite implements Serializable { private transient JavaSparkContext jsc; @@ -506,4 +507,169 @@ public void testJavaEncoderErrorMessageForPrivateClass() { public void testKryoEncoderErrorMessageForPrivateClass() { Encoders.kryo(PrivateClassTest.class); } + + public class SimpleJavaBean implements Serializable { + private boolean a; + private int b; + private byte[] c; + private String[] d; + private List e; + private List f; + + public boolean isA() { + return a; + } + + public void setA(boolean a) { + this.a = a; + } + + public int getB() { + return b; + } + + public void setB(int b) { + this.b = b; + } + + public byte[] getC() { + return c; + } + + public void setC(byte[] c) { + this.c = c; + } + + public String[] getD() { + return d; + } + + public void 
setD(String[] d) { + this.d = d; + } + + public List getE() { + return e; + } + + public void setE(List e) { + this.e = e; + } + + public List getF() { + return f; + } + + public void setF(List f) { + this.f = f; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + SimpleJavaBean that = (SimpleJavaBean) o; + + if (a != that.a) return false; + if (b != that.b) return false; + if (!Arrays.equals(c, that.c)) return false; + if (!Arrays.equals(d, that.d)) return false; + if (!e.equals(that.e)) return false; + return f.equals(that.f); + } + + @Override + public int hashCode() { + int result = (a ? 1 : 0); + result = 31 * result + b; + result = 31 * result + Arrays.hashCode(c); + result = 31 * result + Arrays.hashCode(d); + result = 31 * result + e.hashCode(); + result = 31 * result + f.hashCode(); + return result; + } + } + + public class NestedJavaBean implements Serializable { + private SimpleJavaBean a; + + public SimpleJavaBean getA() { + return a; + } + + public void setA(SimpleJavaBean a) { + this.a = a; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + NestedJavaBean that = (NestedJavaBean) o; + + return a.equals(that.a); + } + + @Override + public int hashCode() { + return a.hashCode(); + } + } + + @Test + public void testJavaBeanEncoder() { + OuterScopes.addOuterScope(this); + SimpleJavaBean obj1 = new SimpleJavaBean(); + obj1.setA(true); + obj1.setB(3); + obj1.setC(new byte[]{1, 2}); + obj1.setD(new String[]{"hello", null}); + obj1.setE(Arrays.asList("a", "b")); + obj1.setF(Arrays.asList(100L, null, 200L)); + SimpleJavaBean obj2 = new SimpleJavaBean(); + obj2.setA(false); + obj2.setB(30); + obj2.setC(new byte[]{3, 4}); + obj2.setD(new String[]{null, "world"}); + obj2.setE(Arrays.asList("x", "y")); + obj2.setF(Arrays.asList(300L, null, 400L)); + + List data = Arrays.asList(obj1, obj2); + Dataset ds = context.createDataset(data, Encoders.bean(SimpleJavaBean.class)); + Assert.assertEquals(data, ds.collectAsList()); + + NestedJavaBean obj3 = new NestedJavaBean(); + obj3.setA(obj1); + + List data2 = Arrays.asList(obj3); + Dataset ds2 = context.createDataset(data2, Encoders.bean(NestedJavaBean.class)); + Assert.assertEquals(data2, ds2.collectAsList()); + + Row row1 = new GenericRow(new Object[]{ + true, + 3, + new byte[]{1, 2}, + new String[]{"hello", null}, + Arrays.asList("a", "b"), + Arrays.asList(100L, null, 200L)}); + Row row2 = new GenericRow(new Object[]{ + false, + 30, + new byte[]{3, 4}, + new String[]{null, "world"}, + Arrays.asList("x", "y"), + Arrays.asList(300L, null, 400L)}); + StructType schema = new StructType() + .add("a", BooleanType, false) + .add("b", IntegerType, false) + .add("c", BinaryType) + .add("d", createArrayType(StringType)) + .add("e", createArrayType(StringType)) + .add("f", createArrayType(LongType)); + Dataset ds3 = context.createDataFrame(Arrays.asList(row1, row2), schema) + .as(Encoders.bean(SimpleJavaBean.class)); + Assert.assertEquals(data, ds3.collectAsList()); + } } From 0a7bca2da04aefff16f2513ec27a92e69ceb77f6 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 1 Dec 2015 10:38:59 -0800 Subject: [PATCH 1615/1824] [SPARK-11905][SQL] Support Persist/Cache and Unpersist in Dataset APIs Persist and Unpersist exist in both RDD and Dataframe APIs. I think they are still very critical in Dataset APIs. Not sure if my understanding is correct? 
If so, could you help me check if the implementation is acceptable? Please provide your opinions. marmbrus rxin cloud-fan Thank you very much! Author: gatorsmile Author: xiaoli Author: Xiao Li Closes #9889 from gatorsmile/persistDS. --- .../org/apache/spark/sql/DataFrame.scala | 9 +++ .../scala/org/apache/spark/sql/Dataset.scala | 50 +++++++++++- .../org/apache/spark/sql/SQLContext.scala | 9 +++ .../spark/sql/execution/CacheManager.scala | 27 ++++--- .../apache/spark/sql/DatasetCacheSuite.scala | 80 +++++++++++++++++++ .../org/apache/spark/sql/QueryTest.scala | 5 +- 6 files changed, 162 insertions(+), 18 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 6197f10813a3..eb8700369275 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -1584,6 +1584,7 @@ class DataFrame private[sql]( def distinct(): DataFrame = dropDuplicates() /** + * Persist this [[DataFrame]] with the default storage level (`MEMORY_AND_DISK`). * @group basic * @since 1.3.0 */ @@ -1593,12 +1594,17 @@ class DataFrame private[sql]( } /** + * Persist this [[DataFrame]] with the default storage level (`MEMORY_AND_DISK`). * @group basic * @since 1.3.0 */ def cache(): this.type = persist() /** + * Persist this [[DataFrame]] with the given storage level. + * @param newLevel One of: `MEMORY_ONLY`, `MEMORY_AND_DISK`, `MEMORY_ONLY_SER`, + * `MEMORY_AND_DISK_SER`, `DISK_ONLY`, `MEMORY_ONLY_2`, + * `MEMORY_AND_DISK_2`, etc. * @group basic * @since 1.3.0 */ @@ -1608,6 +1614,8 @@ class DataFrame private[sql]( } /** + * Mark the [[DataFrame]] as non-persistent, and remove all blocks for it from memory and disk. + * @param blocking Whether to block until all blocks are deleted. * @group basic * @since 1.3.0 */ @@ -1617,6 +1625,7 @@ class DataFrame private[sql]( } /** + * Mark the [[DataFrame]] as non-persistent, and remove all blocks for it from memory and disk. * @group basic * @since 1.3.0 */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index c357f88a94dd..d6bb1d2ad8e5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.plans.JoinType import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.{Queryable, QueryExecution} import org.apache.spark.sql.types.StructType +import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils /** @@ -565,7 +566,7 @@ class Dataset[T] private[sql]( * combined. * * Note that, this function is not a typical set union operation, in that it does not eliminate - * duplicate items. As such, it is analagous to `UNION ALL` in SQL. + * duplicate items. As such, it is analogous to `UNION ALL` in SQL. 
* @since 1.6.0 */ def union(other: Dataset[T]): Dataset[T] = withPlan[T](other)(Union) @@ -618,7 +619,6 @@ class Dataset[T] private[sql]( case _ => Alias(CreateStruct(rightOutput), "_2")() } - implicit val tuple2Encoder: Encoder[(T, U)] = ExpressionEncoder.tuple(this.unresolvedTEncoder, other.unresolvedTEncoder) withPlan[(T, U)](other) { (left, right) => @@ -697,11 +697,55 @@ class Dataset[T] private[sql]( */ def takeAsList(num: Int): java.util.List[T] = java.util.Arrays.asList(take(num) : _*) + /** + * Persist this [[Dataset]] with the default storage level (`MEMORY_AND_DISK`). + * @since 1.6.0 + */ + def persist(): this.type = { + sqlContext.cacheManager.cacheQuery(this) + this + } + + /** + * Persist this [[Dataset]] with the default storage level (`MEMORY_AND_DISK`). + * @since 1.6.0 + */ + def cache(): this.type = persist() + + /** + * Persist this [[Dataset]] with the given storage level. + * @param newLevel One of: `MEMORY_ONLY`, `MEMORY_AND_DISK`, `MEMORY_ONLY_SER`, + * `MEMORY_AND_DISK_SER`, `DISK_ONLY`, `MEMORY_ONLY_2`, + * `MEMORY_AND_DISK_2`, etc. + * @group basic + * @since 1.6.0 + */ + def persist(newLevel: StorageLevel): this.type = { + sqlContext.cacheManager.cacheQuery(this, None, newLevel) + this + } + + /** + * Mark the [[Dataset]] as non-persistent, and remove all blocks for it from memory and disk. + * @param blocking Whether to block until all blocks are deleted. + * @since 1.6.0 + */ + def unpersist(blocking: Boolean): this.type = { + sqlContext.cacheManager.tryUncacheQuery(this, blocking) + this + } + + /** + * Mark the [[Dataset]] as non-persistent, and remove all blocks for it from memory and disk. + * @since 1.6.0 + */ + def unpersist(): this.type = unpersist(blocking = false) + /* ******************** * * Internal Functions * * ******************** */ - private[sql] def logicalPlan = queryExecution.analyzed + private[sql] def logicalPlan: LogicalPlan = queryExecution.analyzed private[sql] def withPlan(f: LogicalPlan => LogicalPlan): Dataset[T] = new Dataset[T](sqlContext, sqlContext.executePlan(f(logicalPlan)), tEncoder) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 9cc65de19180..4e2625086837 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -338,6 +338,15 @@ class SQLContext private[sql]( cacheManager.lookupCachedData(table(tableName)).nonEmpty } + /** + * Returns true if the [[Queryable]] is currently cached in-memory. + * @group cachemgmt + * @since 1.3.0 + */ + private[sql] def isCached(qName: Queryable): Boolean = { + cacheManager.lookupCachedData(qName).nonEmpty + } + /** * Caches the specified table in-memory. 
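The new Dataset methods can be exercised just like their DataFrame counterparts. An illustrative sketch only, assuming `sqlContext.implicits._` is in scope:

    import org.apache.spark.storage.StorageLevel

    val ds = Seq(1, 2, 3).toDS()
    ds.persist(StorageLevel.MEMORY_AND_DISK)   // equivalent to ds.cache()
    ds.count()                                 // the first action materializes the in-memory columnar cache
    ds.unpersist()
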
* @group cachemgmt diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala index 293fcfe96e67..50f6562815c2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala @@ -20,7 +20,6 @@ package org.apache.spark.sql.execution import java.util.concurrent.locks.ReentrantReadWriteLock import org.apache.spark.Logging -import org.apache.spark.sql.DataFrame import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.columnar.InMemoryRelation import org.apache.spark.storage.StorageLevel @@ -75,12 +74,12 @@ private[sql] class CacheManager extends Logging { } /** - * Caches the data produced by the logical representation of the given [[DataFrame]]. Unlike - * `RDD.cache()`, the default storage level is set to be `MEMORY_AND_DISK` because recomputing - * the in-memory columnar representation of the underlying table is expensive. + * Caches the data produced by the logical representation of the given [[Queryable]]. + * Unlike `RDD.cache()`, the default storage level is set to be `MEMORY_AND_DISK` because + * recomputing the in-memory columnar representation of the underlying table is expensive. */ private[sql] def cacheQuery( - query: DataFrame, + query: Queryable, tableName: Option[String] = None, storageLevel: StorageLevel = MEMORY_AND_DISK): Unit = writeLock { val planToCache = query.queryExecution.analyzed @@ -95,13 +94,13 @@ private[sql] class CacheManager extends Logging { sqlContext.conf.useCompression, sqlContext.conf.columnBatchSize, storageLevel, - sqlContext.executePlan(query.logicalPlan).executedPlan, + sqlContext.executePlan(planToCache).executedPlan, tableName)) } } - /** Removes the data for the given [[DataFrame]] from the cache */ - private[sql] def uncacheQuery(query: DataFrame, blocking: Boolean = true): Unit = writeLock { + /** Removes the data for the given [[Queryable]] from the cache */ + private[sql] def uncacheQuery(query: Queryable, blocking: Boolean = true): Unit = writeLock { val planToCache = query.queryExecution.analyzed val dataIndex = cachedData.indexWhere(cd => planToCache.sameResult(cd.plan)) require(dataIndex >= 0, s"Table $query is not cached.") @@ -109,9 +108,11 @@ private[sql] class CacheManager extends Logging { cachedData.remove(dataIndex) } - /** Tries to remove the data for the given [[DataFrame]] from the cache if it's cached */ + /** Tries to remove the data for the given [[Queryable]] from the cache + * if it's cached + */ private[sql] def tryUncacheQuery( - query: DataFrame, + query: Queryable, blocking: Boolean = true): Boolean = writeLock { val planToCache = query.queryExecution.analyzed val dataIndex = cachedData.indexWhere(cd => planToCache.sameResult(cd.plan)) @@ -123,12 +124,12 @@ private[sql] class CacheManager extends Logging { found } - /** Optionally returns cached data for the given [[DataFrame]] */ - private[sql] def lookupCachedData(query: DataFrame): Option[CachedData] = readLock { + /** Optionally returns cached data for the given [[Queryable]] */ + private[sql] def lookupCachedData(query: Queryable): Option[CachedData] = readLock { lookupCachedData(query.queryExecution.analyzed) } - /** Optionally returns cached data for the given LogicalPlan. */ + /** Optionally returns cached data for the given [[LogicalPlan]]. 
*/ private[sql] def lookupCachedData(plan: LogicalPlan): Option[CachedData] = readLock { cachedData.find(cd => plan.sameResult(cd.plan)) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala new file mode 100644 index 000000000000..3a283a4e1f61 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import scala.language.postfixOps + +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.SharedSQLContext + + +class DatasetCacheSuite extends QueryTest with SharedSQLContext { + import testImplicits._ + + test("persist and unpersist") { + val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS().select(expr("_2 + 1").as[Int]) + val cached = ds.cache() + // count triggers the caching action. It should not throw. + cached.count() + // Make sure, the Dataset is indeed cached. + assertCached(cached) + // Check result. + checkAnswer( + cached, + 2, 3, 4) + // Drop the cache. 
+ cached.unpersist() + assert(!sqlContext.isCached(cached), "The Dataset should not be cached.") + } + + test("persist and then rebind right encoder when join 2 datasets") { + val ds1 = Seq("1", "2").toDS().as("a") + val ds2 = Seq(2, 3).toDS().as("b") + + ds1.persist() + assertCached(ds1) + ds2.persist() + assertCached(ds2) + + val joined = ds1.joinWith(ds2, $"a.value" === $"b.value") + checkAnswer(joined, ("2", 2)) + assertCached(joined, 2) + + ds1.unpersist() + assert(!sqlContext.isCached(ds1), "The Dataset ds1 should not be cached.") + ds2.unpersist() + assert(!sqlContext.isCached(ds2), "The Dataset ds2 should not be cached.") + } + + test("persist and then groupBy columns asKey, map") { + val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() + val grouped = ds.groupBy($"_1").keyAs[String] + val agged = grouped.mapGroups { case (g, iter) => (g, iter.map(_._2).sum) } + agged.persist() + + checkAnswer( + agged.filter(_._1 == "b"), + ("b", 3)) + assertCached(agged.filter(_._1 == "b")) + + ds.unpersist() + assert(!sqlContext.isCached(ds), "The Dataset ds should not be cached.") + agged.unpersist() + assert(!sqlContext.isCached(agged), "The Dataset agged should not be cached.") + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index 8f476dd0f99b..bc22fb8b7bdb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -24,6 +24,7 @@ import scala.collection.JavaConverters._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.execution.columnar.InMemoryRelation +import org.apache.spark.sql.execution.Queryable abstract class QueryTest extends PlanTest { @@ -163,9 +164,9 @@ abstract class QueryTest extends PlanTest { } /** - * Asserts that a given [[DataFrame]] will be executed using the given number of cached results. + * Asserts that a given [[Queryable]] will be executed using the given number of cached results. */ - def assertCached(query: DataFrame, numCachedTables: Int = 1): Unit = { + def assertCached(query: Queryable, numCachedTables: Int = 1): Unit = { val planWithCaching = query.queryExecution.withCachedData val cachedData = planWithCaching collect { case cached: InMemoryRelation => cached From 6a8cf80cc8ef435ec46138fa57325bda5d68f3ce Mon Sep 17 00:00:00 2001 From: woj-i Date: Tue, 1 Dec 2015 11:05:45 -0800 Subject: [PATCH 1616/1824] [SPARK-11821] Propagate Kerberos keytab for all environments andrewor14 the same PR as in branch 1.5 harishreedharan Author: woj-i Closes #9859 from woj-i/master. 
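A minimal sketch (not part of this patch) of what JVM-wide propagation enables: once SparkSubmit exposes the keytab and principal as system properties, driver-side code can re-login from anywhere in the JVM via Hadoop's UserGroupInformation. The key names follow the settings documented for YARN (spark.yarn.principal / spark.yarn.keytab); the snippet is an illustration, not code from this change.

    import org.apache.hadoop.security.UserGroupInformation
    import org.apache.spark.SparkConf

    // Illustrative only: read the propagated settings and log in if both are present.
    val conf = new SparkConf()
    val principal = conf.get("spark.yarn.principal", null)
    val keytab = conf.get("spark.yarn.keytab", null)
    if (principal != null && keytab != null) {
      UserGroupInformation.loginUserFromKeytab(principal, keytab)
    }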
--- .../main/scala/org/apache/spark/deploy/SparkSubmit.scala | 4 ++++ docs/running-on-yarn.md | 4 ++-- docs/sql-programming-guide.md | 7 ++++--- 3 files changed, 10 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 2e912b59afdb..52d3ab34c178 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -545,6 +545,10 @@ object SparkSubmit { if (args.isPython) { sysProps.put("spark.yarn.isPython", "true") } + } + + // assure a keytab is available from any place in a JVM + if (clusterManager == YARN || clusterManager == LOCAL) { if (args.principal != null) { require(args.keytab != null, "Keytab must be specified when principal is specified") if (!new File(args.keytab).exists()) { diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index 925a1e0ba6fc..06413f83c3a7 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -358,14 +358,14 @@ If you need a reference to the proper location to put log files in the YARN so t The full path to the file that contains the keytab for the principal specified above. This keytab will be copied to the node running the YARN Application Master via the Secure Distributed Cache, - for renewing the login tickets and the delegation tokens periodically. + for renewing the login tickets and the delegation tokens periodically. (Works also with the "local" master) spark.yarn.principal (none) - Principal to be used to login to KDC, while running on secure HDFS. + Principal to be used to login to KDC, while running on secure HDFS. (Works also with the "local" master) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index d7b205c2fa0d..7b1d97baa382 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1614,7 +1614,8 @@ This command builds a new assembly jar that includes Hive. Note that this Hive a on all of the worker nodes, as they will need access to the Hive serialization and deserialization libraries (SerDes) in order to access data stored in Hive. -Configuration of Hive is done by placing your `hive-site.xml` file in `conf/`. Please note when running +Configuration of Hive is done by placing your `hive-site.xml`, `core-site.xml` (for security configuration), + `hdfs-site.xml` (for HDFS configuration) file in `conf/`. Please note when running the query on a YARN cluster (`cluster` mode), the `datanucleus` jars under the `lib_managed/jars` directory and `hive-site.xml` under `conf/` directory need to be available on the driver and all executors launched by the YARN cluster. The convenient way to do this is adding them through the `--jars` option and `--file` option of the @@ -2028,7 +2029,7 @@ Beeline will ask you for a username and password. In non-secure mode, simply ent your machine and a blank password. For secure mode, please follow the instructions given in the [beeline documentation](https://cwiki.apache.org/confluence/display/Hive/HiveServer2+Clients). -Configuration of Hive is done by placing your `hive-site.xml` file in `conf/`. +Configuration of Hive is done by placing your `hive-site.xml`, `core-site.xml` and `hdfs-site.xml` files in `conf/`. You may also use the beeline script that comes with Hive. 
@@ -2053,7 +2054,7 @@ To start the Spark SQL CLI, run the following in the Spark directory: ./bin/spark-sql -Configuration of Hive is done by placing your `hive-site.xml` file in `conf/`. +Configuration of Hive is done by placing your `hive-site.xml`, `core-site.xml` and `hdfs-site.xml` files in `conf/`. You may run `./bin/spark-sql --help` for a complete list of all available options. From 34e7093c1131162b3aa05b65a19a633a0b5b633e Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 1 Dec 2015 11:49:20 -0800 Subject: [PATCH 1617/1824] [SPARK-12065] Upgrade Tachyon from 0.8.1 to 0.8.2 This commit upgrades the Tachyon dependency from 0.8.1 to 0.8.2. Author: Josh Rosen Closes #10054 from JoshRosen/upgrade-to-tachyon-0.8.2. --- core/pom.xml | 2 +- make-distribution.sh | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/core/pom.xml b/core/pom.xml index 37e3f168ab37..61744bb5c7bf 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -270,7 +270,7 @@ org.tachyonproject tachyon-client - 0.8.1 + 0.8.2 org.apache.hadoop diff --git a/make-distribution.sh b/make-distribution.sh index 7b417fe7cf61..e64ceb802464 100755 --- a/make-distribution.sh +++ b/make-distribution.sh @@ -33,7 +33,7 @@ SPARK_HOME="$(cd "`dirname "$0"`"; pwd)" DISTDIR="$SPARK_HOME/dist" SPARK_TACHYON=false -TACHYON_VERSION="0.8.1" +TACHYON_VERSION="0.8.2" TACHYON_TGZ="tachyon-${TACHYON_VERSION}-bin.tar.gz" TACHYON_URL="http://tachyon-project.org/downloads/files/${TACHYON_VERSION}/${TACHYON_TGZ}" From 2cef1cdfbb5393270ae83179b6a4e50c3cbf9e93 Mon Sep 17 00:00:00 2001 From: Nong Li Date: Tue, 1 Dec 2015 12:59:53 -0800 Subject: [PATCH 1618/1824] [SPARK-12030] Fix Platform.copyMemory to handle overlapping regions. This bug was exposed as memory corruption in Timsort which uses copyMemory to copy large regions that can overlap. The prior implementation did not handle this case half the time and always copied forward, resulting in the data being corrupt. Author: Nong Li Closes #10068 from nongli/spark-12030. --- .../org/apache/spark/unsafe/Platform.java | 27 ++++++-- .../spark/unsafe/PlatformUtilSuite.java | 61 +++++++++++++++++++ 2 files changed, 82 insertions(+), 6 deletions(-) create mode 100644 unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java b/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java index 1c16da982923..0d6b215fe5aa 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java @@ -107,12 +107,27 @@ public static void freeMemory(long address) { public static void copyMemory( Object src, long srcOffset, Object dst, long dstOffset, long length) { - while (length > 0) { - long size = Math.min(length, UNSAFE_COPY_THRESHOLD); - _UNSAFE.copyMemory(src, srcOffset, dst, dstOffset, size); - length -= size; - srcOffset += size; - dstOffset += size; + // Check if dstOffset is before or after srcOffset to determine if we should copy + // forward or backwards. This is necessary in case src and dst overlap. 
+ if (dstOffset < srcOffset) { + while (length > 0) { + long size = Math.min(length, UNSAFE_COPY_THRESHOLD); + _UNSAFE.copyMemory(src, srcOffset, dst, dstOffset, size); + length -= size; + srcOffset += size; + dstOffset += size; + } + } else { + srcOffset += length; + dstOffset += length; + while (length > 0) { + long size = Math.min(length, UNSAFE_COPY_THRESHOLD); + srcOffset -= size; + dstOffset -= size; + _UNSAFE.copyMemory(src, srcOffset, dst, dstOffset, size); + length -= size; + } + } } diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java new file mode 100644 index 000000000000..693ec6ec58db --- /dev/null +++ b/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe; + +import org.junit.Assert; +import org.junit.Test; + +public class PlatformUtilSuite { + + @Test + public void overlappingCopyMemory() { + byte[] data = new byte[3 * 1024 * 1024]; + int size = 2 * 1024 * 1024; + for (int i = 0; i < data.length; ++i) { + data[i] = (byte)i; + } + + Platform.copyMemory(data, Platform.BYTE_ARRAY_OFFSET, data, Platform.BYTE_ARRAY_OFFSET, size); + for (int i = 0; i < data.length; ++i) { + Assert.assertEquals((byte)i, data[i]); + } + + Platform.copyMemory( + data, + Platform.BYTE_ARRAY_OFFSET + 1, + data, + Platform.BYTE_ARRAY_OFFSET, + size); + for (int i = 0; i < size; ++i) { + Assert.assertEquals((byte)(i + 1), data[i]); + } + + for (int i = 0; i < data.length; ++i) { + data[i] = (byte)i; + } + Platform.copyMemory( + data, + Platform.BYTE_ARRAY_OFFSET, + data, + Platform.BYTE_ARRAY_OFFSET + 1, + size); + for (int i = 0; i < size; ++i) { + Assert.assertEquals((byte)i, data[i + 1]); + } + } +} From 60b541ee1b97c9e5e84aa2af2ce856f316ad22b3 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 1 Dec 2015 14:08:36 -0800 Subject: [PATCH 1619/1824] [SPARK-12004] Preserve the RDD partitioner through RDD checkpointing The solution is the save the RDD partitioner in a separate file in the RDD checkpoint directory. That is, `/_partitioner`. In most cases, whether the RDD partitioner was recovered or not, does not affect the correctness, only reduces performance. So this solution makes a best-effort attempt to save and recover the partitioner. If either fails, the checkpointing is not affected. This makes this patch safe and backward compatible. Author: Tathagata Das Closes #9983 from tdas/SPARK-12004. 
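A short usage sketch (not part of the patch) of the behaviour this enables, assuming a SparkContext `sc` with a checkpoint directory already set:

    import org.apache.spark.HashPartitioner

    // Checkpoint a hash-partitioned pair RDD; with this change the partitioner is also
    // written, best-effort, to a "_partitioner" file next to the checkpointed data.
    val pairs = sc.parallelize(1 to 100).map(i => (i % 10, i)).partitionBy(new HashPartitioner(4))
    pairs.checkpoint()
    pairs.count()   // the first action materializes the checkpoint files

    // An RDD recovered from those files now reports Some(HashPartitioner(4)) instead of None,
    // so later operations keyed the same way (e.g. a join) can avoid an extra shuffle.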
--- .../spark/rdd/ReliableCheckpointRDD.scala | 122 +++++++++++++++++- .../spark/rdd/ReliableRDDCheckpointData.scala | 21 +-- .../org/apache/spark/CheckpointSuite.scala | 61 ++++++++- 3 files changed, 173 insertions(+), 31 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala index a69be6a068bb..fa71b8c26233 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala @@ -20,12 +20,12 @@ package org.apache.spark.rdd import java.io.IOException import scala.reflect.ClassTag +import scala.util.control.NonFatal import org.apache.hadoop.fs.Path import org.apache.spark._ import org.apache.spark.broadcast.Broadcast -import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.util.{SerializableConfiguration, Utils} /** @@ -33,8 +33,9 @@ import org.apache.spark.util.{SerializableConfiguration, Utils} */ private[spark] class ReliableCheckpointRDD[T: ClassTag]( sc: SparkContext, - val checkpointPath: String) - extends CheckpointRDD[T](sc) { + val checkpointPath: String, + _partitioner: Option[Partitioner] = None + ) extends CheckpointRDD[T](sc) { @transient private val hadoopConf = sc.hadoopConfiguration @transient private val cpath = new Path(checkpointPath) @@ -47,7 +48,13 @@ private[spark] class ReliableCheckpointRDD[T: ClassTag]( /** * Return the path of the checkpoint directory this RDD reads data from. */ - override def getCheckpointFile: Option[String] = Some(checkpointPath) + override val getCheckpointFile: Option[String] = Some(checkpointPath) + + override val partitioner: Option[Partitioner] = { + _partitioner.orElse { + ReliableCheckpointRDD.readCheckpointedPartitionerFile(context, checkpointPath) + } + } /** * Return partitions described by the files in the checkpoint directory. @@ -100,10 +107,52 @@ private[spark] object ReliableCheckpointRDD extends Logging { "part-%05d".format(partitionIndex) } + private def checkpointPartitionerFileName(): String = { + "_partitioner" + } + + /** + * Write RDD to checkpoint files and return a ReliableCheckpointRDD representing the RDD. 
+ */ + def writeRDDToCheckpointDirectory[T: ClassTag]( + originalRDD: RDD[T], + checkpointDir: String, + blockSize: Int = -1): ReliableCheckpointRDD[T] = { + + val sc = originalRDD.sparkContext + + // Create the output path for the checkpoint + val checkpointDirPath = new Path(checkpointDir) + val fs = checkpointDirPath.getFileSystem(sc.hadoopConfiguration) + if (!fs.mkdirs(checkpointDirPath)) { + throw new SparkException(s"Failed to create checkpoint path $checkpointDirPath") + } + + // Save to file, and reload it as an RDD + val broadcastedConf = sc.broadcast( + new SerializableConfiguration(sc.hadoopConfiguration)) + // TODO: This is expensive because it computes the RDD again unnecessarily (SPARK-8582) + sc.runJob(originalRDD, + writePartitionToCheckpointFile[T](checkpointDirPath.toString, broadcastedConf) _) + + if (originalRDD.partitioner.nonEmpty) { + writePartitionerToCheckpointDir(sc, originalRDD.partitioner.get, checkpointDirPath) + } + + val newRDD = new ReliableCheckpointRDD[T]( + sc, checkpointDirPath.toString, originalRDD.partitioner) + if (newRDD.partitions.length != originalRDD.partitions.length) { + throw new SparkException( + s"Checkpoint RDD $newRDD(${newRDD.partitions.length}) has different " + + s"number of partitions from original RDD $originalRDD(${originalRDD.partitions.length})") + } + newRDD + } + /** - * Write this partition's values to a checkpoint file. + * Write a RDD partition's data to a checkpoint file. */ - def writeCheckpointFile[T: ClassTag]( + def writePartitionToCheckpointFile[T: ClassTag]( path: String, broadcastedConf: Broadcast[SerializableConfiguration], blockSize: Int = -1)(ctx: TaskContext, iterator: Iterator[T]) { @@ -151,6 +200,67 @@ private[spark] object ReliableCheckpointRDD extends Logging { } } + /** + * Write a partitioner to the given RDD checkpoint directory. This is done on a best-effort + * basis; any exception while writing the partitioner is caught, logged and ignored. + */ + private def writePartitionerToCheckpointDir( + sc: SparkContext, partitioner: Partitioner, checkpointDirPath: Path): Unit = { + try { + val partitionerFilePath = new Path(checkpointDirPath, checkpointPartitionerFileName) + val bufferSize = sc.conf.getInt("spark.buffer.size", 65536) + val fs = partitionerFilePath.getFileSystem(sc.hadoopConfiguration) + val fileOutputStream = fs.create(partitionerFilePath, false, bufferSize) + val serializer = SparkEnv.get.serializer.newInstance() + val serializeStream = serializer.serializeStream(fileOutputStream) + Utils.tryWithSafeFinally { + serializeStream.writeObject(partitioner) + } { + serializeStream.close() + } + logDebug(s"Written partitioner to $partitionerFilePath") + } catch { + case NonFatal(e) => + logWarning(s"Error writing partitioner $partitioner to $checkpointDirPath") + } + } + + + /** + * Read a partitioner from the given RDD checkpoint directory, if it exists. + * This is done on a best-effort basis; any exception while reading the partitioner is + * caught, logged and ignored. 
+ */ + private def readCheckpointedPartitionerFile( + sc: SparkContext, + checkpointDirPath: String): Option[Partitioner] = { + try { + val bufferSize = sc.conf.getInt("spark.buffer.size", 65536) + val partitionerFilePath = new Path(checkpointDirPath, checkpointPartitionerFileName) + val fs = partitionerFilePath.getFileSystem(sc.hadoopConfiguration) + if (fs.exists(partitionerFilePath)) { + val fileInputStream = fs.open(partitionerFilePath, bufferSize) + val serializer = SparkEnv.get.serializer.newInstance() + val deserializeStream = serializer.deserializeStream(fileInputStream) + val partitioner = Utils.tryWithSafeFinally[Partitioner] { + deserializeStream.readObject[Partitioner] + } { + deserializeStream.close() + } + logDebug(s"Read partitioner from $partitionerFilePath") + Some(partitioner) + } else { + logDebug("No partitioner file") + None + } + } catch { + case NonFatal(e) => + logWarning(s"Error reading partitioner from $checkpointDirPath, " + + s"partitioner will not be recovered which may lead to performance loss", e) + None + } + } + /** * Read the content of the specified checkpoint file. */ diff --git a/core/src/main/scala/org/apache/spark/rdd/ReliableRDDCheckpointData.scala b/core/src/main/scala/org/apache/spark/rdd/ReliableRDDCheckpointData.scala index 91cad6662e4d..cac6cbe780e9 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ReliableRDDCheckpointData.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ReliableRDDCheckpointData.scala @@ -55,25 +55,7 @@ private[spark] class ReliableRDDCheckpointData[T: ClassTag](@transient private v * This is called immediately after the first action invoked on this RDD has completed. */ protected override def doCheckpoint(): CheckpointRDD[T] = { - - // Create the output path for the checkpoint - val path = new Path(cpDir) - val fs = path.getFileSystem(rdd.context.hadoopConfiguration) - if (!fs.mkdirs(path)) { - throw new SparkException(s"Failed to create checkpoint path $cpDir") - } - - // Save to file, and reload it as an RDD - val broadcastedConf = rdd.context.broadcast( - new SerializableConfiguration(rdd.context.hadoopConfiguration)) - // TODO: This is expensive because it computes the RDD again unnecessarily (SPARK-8582) - rdd.context.runJob(rdd, ReliableCheckpointRDD.writeCheckpointFile[T](cpDir, broadcastedConf) _) - val newRDD = new ReliableCheckpointRDD[T](rdd.context, cpDir) - if (newRDD.partitions.length != rdd.partitions.length) { - throw new SparkException( - s"Checkpoint RDD $newRDD(${newRDD.partitions.length}) has different " + - s"number of partitions from original RDD $rdd(${rdd.partitions.length})") - } + val newRDD = ReliableCheckpointRDD.writeRDDToCheckpointDirectory(rdd, cpDir) // Optionally clean our checkpoint files if the reference is out of scope if (rdd.conf.getBoolean("spark.cleaner.referenceTracking.cleanCheckpoints", false)) { @@ -83,7 +65,6 @@ private[spark] class ReliableRDDCheckpointData[T: ClassTag](@transient private v } logInfo(s"Done checkpointing RDD ${rdd.id} to $cpDir, new parent is RDD ${newRDD.id}") - newRDD } diff --git a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala index ab23326c6c25..553d46285ac0 100644 --- a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala +++ b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala @@ -21,7 +21,8 @@ import java.io.File import scala.reflect.ClassTag -import org.apache.spark.CheckpointSuite._ +import org.apache.hadoop.fs.Path + import org.apache.spark.rdd._ import 
org.apache.spark.storage.{BlockId, StorageLevel, TestBlockId} import org.apache.spark.util.Utils @@ -74,8 +75,10 @@ trait RDDCheckpointTester { self: SparkFunSuite => // Test whether the checkpoint file has been created if (reliableCheckpoint) { - assert( - collectFunc(sparkContext.checkpointFile[U](operatedRDD.getCheckpointFile.get)) === result) + assert(operatedRDD.getCheckpointFile.nonEmpty) + val recoveredRDD = sparkContext.checkpointFile[U](operatedRDD.getCheckpointFile.get) + assert(collectFunc(recoveredRDD) === result) + assert(recoveredRDD.partitioner === operatedRDD.partitioner) } // Test whether dependencies have been changed from its earlier parent RDD @@ -211,9 +214,14 @@ trait RDDCheckpointTester { self: SparkFunSuite => } /** Run a test twice, once for local checkpointing and once for reliable checkpointing. */ - protected def runTest(name: String)(body: Boolean => Unit): Unit = { + protected def runTest( + name: String, + skipLocalCheckpoint: Boolean = false + )(body: Boolean => Unit): Unit = { test(name + " [reliable checkpoint]")(body(true)) - test(name + " [local checkpoint]")(body(false)) + if (!skipLocalCheckpoint) { + test(name + " [local checkpoint]")(body(false)) + } } /** @@ -264,6 +272,49 @@ class CheckpointSuite extends SparkFunSuite with RDDCheckpointTester with LocalS assert(flatMappedRDD.collect() === result) } + runTest("checkpointing partitioners", skipLocalCheckpoint = true) { _: Boolean => + + def testPartitionerCheckpointing( + partitioner: Partitioner, + corruptPartitionerFile: Boolean = false + ): Unit = { + val rddWithPartitioner = sc.makeRDD(1 to 4).map { _ -> 1 }.partitionBy(partitioner) + rddWithPartitioner.checkpoint() + rddWithPartitioner.count() + assert(rddWithPartitioner.getCheckpointFile.get.nonEmpty, + "checkpointing was not successful") + + if (corruptPartitionerFile) { + // Overwrite the partitioner file with garbage data + val checkpointDir = new Path(rddWithPartitioner.getCheckpointFile.get) + val fs = checkpointDir.getFileSystem(sc.hadoopConfiguration) + val partitionerFile = fs.listStatus(checkpointDir) + .find(_.getPath.getName.contains("partitioner")) + .map(_.getPath) + require(partitionerFile.nonEmpty, "could not find the partitioner file for testing") + val output = fs.create(partitionerFile.get, true) + output.write(100) + output.close() + } + + val newRDD = sc.checkpointFile[(Int, Int)](rddWithPartitioner.getCheckpointFile.get) + assert(newRDD.collect().toSet === rddWithPartitioner.collect().toSet, "RDD not recovered") + + if (!corruptPartitionerFile) { + assert(newRDD.partitioner != None, "partitioner not recovered") + assert(newRDD.partitioner === rddWithPartitioner.partitioner, + "recovered partitioner does not match") + } else { + assert(newRDD.partitioner == None, "partitioner unexpectedly recovered") + } + } + + testPartitionerCheckpointing(partitioner) + + // Test that corrupted partitioner file does not prevent recovery of RDD + testPartitionerCheckpointing(partitioner, corruptPartitionerFile = true) + } + runTest("RDDs with one-to-one dependencies") { reliableCheckpoint: Boolean => testRDD(_.map(x => x.toString), reliableCheckpoint) testRDD(_.flatMap(x => 1 to x), reliableCheckpoint) From 328b757d5d4486ea3c2e246780792d7a57ee85e5 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Tue, 1 Dec 2015 15:13:10 -0800 Subject: [PATCH 1620/1824] Revert "[SPARK-12060][CORE] Avoid memory copy in JavaSerializerInstance.serialize" This reverts commit 1401166576c7018c5f9c31e0a6703d5fb16ea339. 
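For context, a hedged sketch (simplified, not Spark's actual code) of the two serialization paths this revert switches between:

    import java.io.ByteArrayOutputStream
    import java.nio.ByteBuffer

    val bos = new ByteArrayOutputStream()
    bos.write(Array[Byte](1, 2, 3))

    // Path restored by the revert: toByteArray() copies the internal buffer before wrapping it.
    val copied: ByteBuffer = ByteBuffer.wrap(bos.toByteArray)

    // The reverted ByteBufferOutputStream instead wrapped the internal buffer directly
    // (ByteBuffer.wrap(buf, 0, count)), avoiding that copy.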
--- .../spark/serializer/JavaSerializer.scala | 7 +++-- .../spark/util/ByteBufferOutputStream.scala | 31 ------------------- 2 files changed, 4 insertions(+), 34 deletions(-) delete mode 100644 core/src/main/scala/org/apache/spark/util/ByteBufferOutputStream.scala diff --git a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala index ea718a0edbe7..b463a71d5bd7 100644 --- a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala @@ -24,7 +24,8 @@ import scala.reflect.ClassTag import org.apache.spark.SparkConf import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.util.{ByteBufferInputStream, ByteBufferOutputStream, Utils} +import org.apache.spark.util.ByteBufferInputStream +import org.apache.spark.util.Utils private[spark] class JavaSerializationStream( out: OutputStream, counterReset: Int, extraDebugInfo: Boolean) @@ -95,11 +96,11 @@ private[spark] class JavaSerializerInstance( extends SerializerInstance { override def serialize[T: ClassTag](t: T): ByteBuffer = { - val bos = new ByteBufferOutputStream() + val bos = new ByteArrayOutputStream() val out = serializeStream(bos) out.writeObject(t) out.close() - bos.toByteBuffer + ByteBuffer.wrap(bos.toByteArray) } override def deserialize[T: ClassTag](bytes: ByteBuffer): T = { diff --git a/core/src/main/scala/org/apache/spark/util/ByteBufferOutputStream.scala b/core/src/main/scala/org/apache/spark/util/ByteBufferOutputStream.scala deleted file mode 100644 index 92e45224db81..000000000000 --- a/core/src/main/scala/org/apache/spark/util/ByteBufferOutputStream.scala +++ /dev/null @@ -1,31 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.util - -import java.io.ByteArrayOutputStream -import java.nio.ByteBuffer - -/** - * Provide a zero-copy way to convert data in ByteArrayOutputStream to ByteBuffer - */ -private[spark] class ByteBufferOutputStream extends ByteArrayOutputStream { - - def toByteBuffer: ByteBuffer = { - return ByteBuffer.wrap(buf, 0, count) - } -} From e76431f886ae8061545b3216e8e2fb38c4db1f43 Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Tue, 1 Dec 2015 15:21:53 -0800 Subject: [PATCH 1621/1824] [SPARK-11961][DOC] Add docs of ChiSqSelector https://issues.apache.org/jira/browse/SPARK-11961 Author: Xusen Yin Closes #9965 from yinxusen/SPARK-11961. 
--- docs/ml-features.md | 50 +++++++++++++ .../examples/ml/JavaChiSqSelectorExample.java | 71 +++++++++++++++++++ .../examples/ml/ChiSqSelectorExample.scala | 57 +++++++++++++++ 3 files changed, 178 insertions(+) create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaChiSqSelectorExample.java create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/ChiSqSelectorExample.scala diff --git a/docs/ml-features.md b/docs/ml-features.md index cd1838d6d288..5f888775553f 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -1949,3 +1949,53 @@ output.select("features", "label").show() {% endhighlight %} + +## ChiSqSelector + +`ChiSqSelector` stands for Chi-Squared feature selection. It operates on labeled data with +categorical features. ChiSqSelector orders features based on a +[Chi-Squared test of independence](https://en.wikipedia.org/wiki/Chi-squared_test) +from the class, and then filters (selects) the top features which the class label depends on the +most. This is akin to yielding the features with the most predictive power. + +**Examples** + +Assume that we have a DataFrame with the columns `id`, `features`, and `clicked`, which is used as +our target to be predicted: + +~~~ +id | features | clicked +---|-----------------------|--------- + 7 | [0.0, 0.0, 18.0, 1.0] | 1.0 + 8 | [0.0, 1.0, 12.0, 0.0] | 0.0 + 9 | [1.0, 0.0, 15.0, 0.1] | 0.0 +~~~ + +If we use `ChiSqSelector` with a `numTopFeatures = 1`, then according to our label `clicked` the +last column in our `features` chosen as the most useful feature: + +~~~ +id | features | clicked | selectedFeatures +---|-----------------------|---------|------------------ + 7 | [0.0, 0.0, 18.0, 1.0] | 1.0 | [1.0] + 8 | [0.0, 1.0, 12.0, 0.0] | 0.0 | [0.0] + 9 | [1.0, 0.0, 15.0, 0.1] | 0.0 | [0.1] +~~~ + +
    +
    + +Refer to the [ChiSqSelector Scala docs](api/scala/index.html#org.apache.spark.ml.feature.ChiSqSelector) +for more details on the API. + +{% include_example scala/org/apache/spark/examples/ml/ChiSqSelectorExample.scala %} +
    + +
    + +Refer to the [ChiSqSelector Java docs](api/java/org/apache/spark/ml/feature/ChiSqSelector.html) +for more details on the API. + +{% include_example java/org/apache/spark/examples/ml/JavaChiSqSelectorExample.java %} +
    +
    diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaChiSqSelectorExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaChiSqSelectorExample.java new file mode 100644 index 000000000000..ede05d6e2011 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaChiSqSelectorExample.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import java.util.Arrays; + +import org.apache.spark.ml.feature.ChiSqSelector; +import org.apache.spark.mllib.linalg.VectorUDT; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +// $example off$ + +public class JavaChiSqSelectorExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaChiSqSelectorExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext sqlContext = new SQLContext(jsc); + + // $example on$ + JavaRDD jrdd = jsc.parallelize(Arrays.asList( + RowFactory.create(7, Vectors.dense(0.0, 0.0, 18.0, 1.0), 1.0), + RowFactory.create(8, Vectors.dense(0.0, 1.0, 12.0, 0.0), 0.0), + RowFactory.create(9, Vectors.dense(1.0, 0.0, 15.0, 0.1), 0.0) + )); + StructType schema = new StructType(new StructField[]{ + new StructField("id", DataTypes.IntegerType, false, Metadata.empty()), + new StructField("features", new VectorUDT(), false, Metadata.empty()), + new StructField("clicked", DataTypes.DoubleType, false, Metadata.empty()) + }); + + DataFrame df = sqlContext.createDataFrame(jrdd, schema); + + ChiSqSelector selector = new ChiSqSelector() + .setNumTopFeatures(1) + .setFeaturesCol("features") + .setLabelCol("clicked") + .setOutputCol("selectedFeatures"); + + DataFrame result = selector.fit(df).transform(df); + result.show(); + // $example off$ + jsc.stop(); + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/ChiSqSelectorExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/ChiSqSelectorExample.scala new file mode 100644 index 000000000000..a8d2bc4907e8 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/ChiSqSelectorExample.scala @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.ChiSqSelector +import org.apache.spark.mllib.linalg.Vectors +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object ChiSqSelectorExample { + def main(args: Array[String]) { + val conf = new SparkConf().setAppName("ChiSqSelectorExample") + val sc = new SparkContext(conf) + + val sqlContext = SQLContext.getOrCreate(sc) + import sqlContext.implicits._ + + // $example on$ + val data = Seq( + (7, Vectors.dense(0.0, 0.0, 18.0, 1.0), 1.0), + (8, Vectors.dense(0.0, 1.0, 12.0, 0.0), 0.0), + (9, Vectors.dense(1.0, 0.0, 15.0, 0.1), 0.0) + ) + + val df = sc.parallelize(data).toDF("id", "features", "clicked") + + val selector = new ChiSqSelector() + .setNumTopFeatures(1) + .setFeaturesCol("features") + .setLabelCol("clicked") + .setOutputCol("selectedFeatures") + + val result = selector.fit(df).transform(df) + result.show() + // $example off$ + sc.stop() + } +} +// scalastyle:on println From f292018f8e57779debc04998456ec875f628133b Mon Sep 17 00:00:00 2001 From: jerryshao Date: Tue, 1 Dec 2015 15:26:10 -0800 Subject: [PATCH 1622/1824] [SPARK-12002][STREAMING][PYSPARK] Fix python direct stream checkpoint recovery issue Fixed a minor race condition in #10017 Closes #10017 Author: jerryshao Author: Shixiong Zhu Closes #10074 from zsxwing/review-pr10017. 
--- python/pyspark/streaming/tests.py | 49 +++++++++++++++++++++++++++++++ python/pyspark/streaming/util.py | 13 ++++---- 2 files changed, 56 insertions(+), 6 deletions(-) diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index a647e6bf3958..d50c6b8d4a42 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -1149,6 +1149,55 @@ def test_topic_and_partition_equality(self): self.assertNotEqual(topic_and_partition_a, topic_and_partition_c) self.assertNotEqual(topic_and_partition_a, topic_and_partition_d) + @unittest.skipIf(sys.version >= "3", "long type not support") + def test_kafka_direct_stream_transform_with_checkpoint(self): + """Test the Python direct Kafka stream transform with checkpoint correctly recovered.""" + topic = self._randomTopic() + sendData = {"a": 1, "b": 2, "c": 3} + kafkaParams = {"metadata.broker.list": self._kafkaTestUtils.brokerAddress(), + "auto.offset.reset": "smallest"} + + self._kafkaTestUtils.createTopic(topic) + self._kafkaTestUtils.sendMessages(topic, sendData) + + offsetRanges = [] + + def transformWithOffsetRanges(rdd): + for o in rdd.offsetRanges(): + offsetRanges.append(o) + return rdd + + self.ssc.stop(False) + self.ssc = None + tmpdir = "checkpoint-test-%d" % random.randint(0, 10000) + + def setup(): + ssc = StreamingContext(self.sc, 0.5) + ssc.checkpoint(tmpdir) + stream = KafkaUtils.createDirectStream(ssc, [topic], kafkaParams) + stream.transform(transformWithOffsetRanges).count().pprint() + return ssc + + try: + ssc1 = StreamingContext.getOrCreate(tmpdir, setup) + ssc1.start() + self.wait_for(offsetRanges, 1) + self.assertEqual(offsetRanges, [OffsetRange(topic, 0, long(0), long(6))]) + + # To make sure some checkpoint is written + time.sleep(3) + ssc1.stop(False) + ssc1 = None + + # Restart again to make sure the checkpoint is recovered correctly + ssc2 = StreamingContext.getOrCreate(tmpdir, setup) + ssc2.start() + ssc2.awaitTermination(3) + ssc2.stop(stopSparkContext=False, stopGraceFully=True) + ssc2 = None + finally: + shutil.rmtree(tmpdir) + @unittest.skipIf(sys.version >= "3", "long type not support") def test_kafka_rdd_message_handler(self): """Test Python direct Kafka RDD MessageHandler.""" diff --git a/python/pyspark/streaming/util.py b/python/pyspark/streaming/util.py index c7f02bca2ae3..abbbf6eb9394 100644 --- a/python/pyspark/streaming/util.py +++ b/python/pyspark/streaming/util.py @@ -37,11 +37,11 @@ def __init__(self, ctx, func, *deserializers): self.ctx = ctx self.func = func self.deserializers = deserializers - self._rdd_wrapper = lambda jrdd, ctx, ser: RDD(jrdd, ctx, ser) + self.rdd_wrap_func = lambda jrdd, ctx, ser: RDD(jrdd, ctx, ser) self.failure = None def rdd_wrapper(self, func): - self._rdd_wrapper = func + self.rdd_wrap_func = func return self def call(self, milliseconds, jrdds): @@ -59,7 +59,7 @@ def call(self, milliseconds, jrdds): if len(sers) < len(jrdds): sers += (sers[0],) * (len(jrdds) - len(sers)) - rdds = [self._rdd_wrapper(jrdd, self.ctx, ser) if jrdd else None + rdds = [self.rdd_wrap_func(jrdd, self.ctx, ser) if jrdd else None for jrdd, ser in zip(jrdds, sers)] t = datetime.fromtimestamp(milliseconds / 1000.0) r = self.func(t, *rdds) @@ -101,7 +101,8 @@ def dumps(self, id): self.failure = None try: func = self.gateway.gateway_property.pool[id] - return bytearray(self.serializer.dumps((func.func, func.deserializers))) + return bytearray(self.serializer.dumps(( + func.func, func.rdd_wrap_func, func.deserializers))) except: self.failure = 
traceback.format_exc() @@ -109,8 +110,8 @@ def loads(self, data): # Clear the failure self.failure = None try: - f, deserializers = self.serializer.loads(bytes(data)) - return TransformFunction(self.ctx, f, *deserializers) + f, wrap_func, deserializers = self.serializer.loads(bytes(data)) + return TransformFunction(self.ctx, f, *deserializers).rdd_wrapper(wrap_func) except: self.failure = traceback.format_exc() From ef6790fdc3b70b9d6184ec2b3d926e4b0e4b15f6 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 2 Dec 2015 07:29:45 +0800 Subject: [PATCH 1623/1824] [SPARK-12075][SQL] Speed up HiveComparisionTest by avoiding / speeding up TestHive.reset() When profiling HiveCompatibilitySuite, I noticed that most of the time seems to be spent in expensive `TestHive.reset()` calls. This patch speeds up suites based on HiveComparisionTest, such as HiveCompatibilitySuite, with the following changes: - Avoid `TestHive.reset()` whenever possible: - Use a simple set of heuristics to guess whether we need to call `reset()` in between tests. - As a safety-net, automatically re-run failed tests by calling `reset()` before the re-attempt. - Speed up the expensive parts of `TestHive.reset()`: loading the `src` and `srcpart` tables took roughly 600ms per test, so we now avoid this by using a simple heuristic which only loads those tables by tests that reference them. This is based on simple string matching over the test queries which errs on the side of loading in more situations than might be strictly necessary. After these changes, HiveCompatibilitySuite seems to run in about 10 minutes. This PR is a revival of #6663, an earlier experimental PR from June, where I played around with several possible speedups for this suite. Author: Josh Rosen Closes #10055 from JoshRosen/speculative-testhive-reset. --- .../apache/spark/sql/hive/test/TestHive.scala | 7 -- .../hive/execution/HiveComparisonTest.scala | 67 +++++++++++++++++-- .../hive/execution/HiveQueryFileTest.scala | 2 +- 3 files changed, 62 insertions(+), 14 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index 6883d305cbea..2e2d201bf254 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -443,13 +443,6 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { defaultOverrides() runSqlHive("USE default") - - // Just loading src makes a lot of tests pass. This is because some tests do something like - // drop an index on src at the beginning. Since we just pass DDL to hive this bypasses our - // Analyzer and thus the test table auto-loading mechanism. - // Remove after we handle more DDL operations natively. 
- loadTestTable("src") - loadTestTable("srcpart") } catch { case e: Exception => logError("FATAL ERROR: Failed to reset TestDB state.", e) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala index aa95ba94fa87..4455430aa727 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala @@ -209,7 +209,11 @@ abstract class HiveComparisonTest } val installHooksCommand = "(?i)SET.*hooks".r - def createQueryTest(testCaseName: String, sql: String, reset: Boolean = true) { + def createQueryTest( + testCaseName: String, + sql: String, + reset: Boolean = true, + tryWithoutResettingFirst: Boolean = false) { // testCaseName must not contain ':', which is not allowed to appear in a filename of Windows assert(!testCaseName.contains(":")) @@ -240,9 +244,6 @@ abstract class HiveComparisonTest test(testCaseName) { logDebug(s"=== HIVE TEST: $testCaseName ===") - // Clear old output for this testcase. - outputDirectories.map(new File(_, testCaseName)).filter(_.exists()).foreach(_.delete()) - val sqlWithoutComment = sql.split("\n").filterNot(l => l.matches("--.*(?<=[^\\\\]);")).mkString("\n") val allQueries = @@ -269,11 +270,32 @@ abstract class HiveComparisonTest }.mkString("\n== Console version of this test ==\n", "\n", "\n") } - try { + def doTest(reset: Boolean, isSpeculative: Boolean = false): Unit = { + // Clear old output for this testcase. + outputDirectories.map(new File(_, testCaseName)).filter(_.exists()).foreach(_.delete()) + if (reset) { TestHive.reset() } + // Many tests drop indexes on src and srcpart at the beginning, so we need to load those + // tables here. Since DROP INDEX DDL is just passed to Hive, it bypasses the analyzer and + // thus the tables referenced in those DDL commands cannot be extracted for use by our + // test table auto-loading mechanism. In addition, the tests which use the SHOW TABLES + // command expect these tables to exist. + val hasShowTableCommand = queryList.exists(_.toLowerCase.contains("show tables")) + for (table <- Seq("src", "srcpart")) { + val hasMatchingQuery = queryList.exists { query => + val normalizedQuery = query.toLowerCase.stripSuffix(";") + normalizedQuery.endsWith(table) || + normalizedQuery.contains(s"from $table") || + normalizedQuery.contains(s"from default.$table") + } + if (hasShowTableCommand || hasMatchingQuery) { + TestHive.loadTestTable(table) + } + } + val hiveCacheFiles = queryList.zipWithIndex.map { case (queryString, i) => val cachedAnswerName = s"$testCaseName-$i-${getMd5(queryString)}" @@ -430,12 +452,45 @@ abstract class HiveComparisonTest """.stripMargin stringToFile(new File(wrongDirectory, testCaseName), errorMessage + consoleTestCase) - fail(errorMessage) + if (isSpeculative && !reset) { + fail("Failed on first run; retrying") + } else { + fail(errorMessage) + } } } // Touch passed file. 
new FileOutputStream(new File(passedDirectory, testCaseName)).close() + } + + val canSpeculativelyTryWithoutReset: Boolean = { + val excludedSubstrings = Seq( + "into table", + "create table", + "drop index" + ) + !queryList.map(_.toLowerCase).exists { query => + excludedSubstrings.exists(s => query.contains(s)) + } + } + + try { + try { + if (tryWithoutResettingFirst && canSpeculativelyTryWithoutReset) { + doTest(reset = false, isSpeculative = true) + } else { + doTest(reset) + } + } catch { + case tf: org.scalatest.exceptions.TestFailedException => + if (tryWithoutResettingFirst && canSpeculativelyTryWithoutReset) { + logWarning("Test failed without reset(); retrying with reset()") + doTest(reset = true) + } else { + throw tf + } + } } catch { case tf: org.scalatest.exceptions.TestFailedException => throw tf case originalException: Exception => diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQueryFileTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQueryFileTest.scala index f7b37dae0a5f..f96c989c4614 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQueryFileTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQueryFileTest.scala @@ -59,7 +59,7 @@ abstract class HiveQueryFileTest extends HiveComparisonTest { runAll) { // Build a test case and submit it to scala test framework... val queriesString = fileToString(testCaseFile) - createQueryTest(testCaseName, queriesString) + createQueryTest(testCaseName, queriesString, reset = true, tryWithoutResettingFirst = true) } else { // Only output warnings for the built in whitelist as this clutters the output when the user // trying to execute a single test from the commandline. From 47a0abc343550c855e679de12983f43e6fcc0171 Mon Sep 17 00:00:00 2001 From: Nong Li Date: Tue, 1 Dec 2015 15:30:21 -0800 Subject: [PATCH 1624/1824] [SPARK-11328][SQL] Improve error message when hitting this issue The issue is that the output commiter is not idempotent and retry attempts will fail because the output file already exists. It is not safe to clean up the file as this output committer is by design not retryable. Currently, the job fails with a confusing file exists error. This patch is a stop gap to tell the user to look at the top of the error log for the proper message. This is difficult to test locally as Spark is hardcoded not to retry. Manually verified by upping the retry attempts. Author: Nong Li Author: Nong Li Closes #10080 from nongli/spark-11328. 
--- .../datasources/WriterContainer.scala | 22 +++++++++++++++++-- .../DirectParquetOutputCommitter.scala | 3 ++- 2 files changed, 22 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala index 1b59b19d9420..ad5536725889 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala @@ -124,6 +124,24 @@ private[sql] abstract class BaseWriterContainer( } } + protected def newOutputWriter(path: String): OutputWriter = { + try { + outputWriterFactory.newInstance(path, dataSchema, taskAttemptContext) + } catch { + case e: org.apache.hadoop.fs.FileAlreadyExistsException => + if (outputCommitter.isInstanceOf[parquet.DirectParquetOutputCommitter]) { + // Spark-11382: DirectParquetOutputCommitter is not idempotent, meaning on retry + // attempts, the task will fail because the output file is created from a prior attempt. + // This often means the most visible error to the user is misleading. Augment the error + // to tell the user to look for the actual error. + throw new SparkException("The output file already exists but this could be due to a " + + "failure from an earlier attempt. Look through the earlier logs or stage page for " + + "the first error.\n File exists error: " + e) + } + throw e + } + } + private def newOutputCommitter(context: TaskAttemptContext): OutputCommitter = { val defaultOutputCommitter = outputFormatClass.newInstance().getOutputCommitter(context) @@ -234,7 +252,7 @@ private[sql] class DefaultWriterContainer( executorSideSetup(taskContext) val configuration = SparkHadoopUtil.get.getConfigurationFromJobContext(taskAttemptContext) configuration.set("spark.sql.sources.output.path", outputPath) - val writer = outputWriterFactory.newInstance(getWorkPath, dataSchema, taskAttemptContext) + val writer = newOutputWriter(getWorkPath) writer.initConverter(dataSchema) var writerClosed = false @@ -403,7 +421,7 @@ private[sql] class DynamicPartitionWriterContainer( val configuration = SparkHadoopUtil.get.getConfigurationFromJobContext(taskAttemptContext) configuration.set( "spark.sql.sources.output.path", new Path(outputPath, partitionPath).toString) - val newWriter = outputWriterFactory.newInstance(path.toString, dataSchema, taskAttemptContext) + val newWriter = super.newOutputWriter(path.toString) newWriter.initConverter(dataSchema) newWriter } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/DirectParquetOutputCommitter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/DirectParquetOutputCommitter.scala index 300e8677b312..1a4e99ff10af 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/DirectParquetOutputCommitter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/DirectParquetOutputCommitter.scala @@ -41,7 +41,8 @@ import org.apache.parquet.hadoop.{ParquetFileReader, ParquetFileWriter, ParquetO * no safe way undo a failed appending job (that's why both `abortTask()` and `abortJob()` are * left empty). 
*/ -private[parquet] class DirectParquetOutputCommitter(outputPath: Path, context: TaskAttemptContext) +private[datasources] class DirectParquetOutputCommitter( + outputPath: Path, context: TaskAttemptContext) extends ParquetOutputCommitter(outputPath, context) { val LOG = Log.getLog(classOf[ParquetOutputCommitter]) From 5a8b5fdd6ffa58f015cdadf3f2c6df78e0a388ad Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Tue, 1 Dec 2015 15:32:57 -0800 Subject: [PATCH 1625/1824] [SPARK-11788][SQL] surround timestamp/date value with quotes in JDBC data source When query the Timestamp or Date column like the following val filtered = jdbcdf.where($"TIMESTAMP_COLUMN" >= beg && $"TIMESTAMP_COLUMN" < end) The generated SQL query is "TIMESTAMP_COLUMN >= 2015-01-01 00:00:00.0" It should have quote around the Timestamp/Date value such as "TIMESTAMP_COLUMN >= '2015-01-01 00:00:00.0'" Author: Huaxin Gao Closes #9872 from huaxingao/spark-11788. --- .../sql/execution/datasources/jdbc/JDBCRDD.scala | 4 +++- .../scala/org/apache/spark/sql/jdbc/JDBCSuite.scala | 11 +++++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala index 57a8a044a37c..392d3ed58e3c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution.datasources.jdbc -import java.sql.{Connection, DriverManager, ResultSet, ResultSetMetaData, SQLException} +import java.sql.{Connection, Date, DriverManager, ResultSet, ResultSetMetaData, SQLException, Timestamp} import java.util.Properties import scala.util.control.NonFatal @@ -267,6 +267,8 @@ private[sql] class JDBCRDD( */ private def compileValue(value: Any): Any = value match { case stringValue: String => s"'${escapeSql(stringValue)}'" + case timestampValue: Timestamp => "'" + timestampValue + "'" + case dateValue: Date => "'" + dateValue + "'" case _ => value } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index d530b1a469ce..8c24aa3151bc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -484,4 +484,15 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext assert(h2.getTableExistsQuery(table) == defaultQuery) assert(derby.getTableExistsQuery(table) == defaultQuery) } + + test("Test DataFrame.where for Date and Timestamp") { + // Regression test for bug SPARK-11788 + val timestamp = java.sql.Timestamp.valueOf("2001-02-20 11:22:33.543543"); + val date = java.sql.Date.valueOf("1995-01-01") + val jdbcDf = sqlContext.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties) + val rows = jdbcDf.where($"B" > date && $"C" > timestamp).collect() + assert(rows(0).getAs[java.sql.Date](1) === java.sql.Date.valueOf("1996-01-01")) + assert(rows(0).getAs[java.sql.Timestamp](2) + === java.sql.Timestamp.valueOf("2002-02-20 11:22:33.543543")) + } } From 5872a9d89fe2720c2bcb1fc7494136947a72581c Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Tue, 1 Dec 2015 16:24:04 -0800 Subject: [PATCH 1626/1824] [SPARK-11352][SQL] Escape */ in the generated comments. 
https://issues.apache.org/jira/browse/SPARK-11352 Author: Yin Huai Closes #10072 from yhuai/SPARK-11352. --- .../spark/sql/catalyst/expressions/Expression.scala | 10 ++++++++-- .../catalyst/expressions/codegen/CodegenFallback.scala | 2 +- .../sql/catalyst/expressions/CodeGenerationSuite.scala | 9 +++++++++ 3 files changed, 18 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index b55d3653a715..4ee6542455a6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -95,7 +95,7 @@ abstract class Expression extends TreeNode[Expression] { ctx.subExprEliminationExprs.get(this).map { subExprState => // This expression is repeated meaning the code to evaluated has already been added // as a function and called in advance. Just use it. - val code = s"/* $this */" + val code = s"/* ${this.toCommentSafeString} */" GeneratedExpressionCode(code, subExprState.isNull, subExprState.value) }.getOrElse { val isNull = ctx.freshName("isNull") @@ -103,7 +103,7 @@ abstract class Expression extends TreeNode[Expression] { val ve = GeneratedExpressionCode("", isNull, primitive) ve.code = genCode(ctx, ve) // Add `this` in the comment. - ve.copy(s"/* $this */\n" + ve.code.trim) + ve.copy(s"/* ${this.toCommentSafeString} */\n" + ve.code.trim) } } @@ -214,6 +214,12 @@ abstract class Expression extends TreeNode[Expression] { } override def toString: String = prettyName + flatArguments.mkString("(", ",", ")") + + /** + * Returns the string representation of this expression that is safe to be put in + * code comments of generated code. + */ + protected def toCommentSafeString: String = this.toString.replace("*/", "\\*\\/") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala index a31574c251af..26fb143d1e45 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala @@ -33,7 +33,7 @@ trait CodegenFallback extends Expression { ctx.references += this val objectTerm = ctx.freshName("obj") s""" - /* expression: ${this} */ + /* expression: ${this.toCommentSafeString} */ java.lang.Object $objectTerm = expressions[${ctx.references.size - 1}].eval(${ctx.INPUT_ROW}); boolean ${ev.isNull} = $objectTerm == null; ${ctx.javaType(this.dataType)} ${ev.value} = ${ctx.defaultValue(this.dataType)}; diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index 002ed16dcfe7..fe754240dcd6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -98,4 +98,13 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { unsafeRow.getStruct(3, 1).getStruct(0, 2).setInt(1, 4) assert(internalRow === internalRow2) } + + test("*/ in the data") { + // When */ appears in a comment block (i.e. 
in /**/), code gen will break. + // So, in Expression and CodegenFallback, we escape */ to \*\/. + checkEvaluation( + EqualTo(BoundReference(0, StringType, false), Literal.create("*/", StringType)), + true, + InternalRow(UTF8String.fromString("*/"))) + } } From e96a70d5ab2e2b43a2df17a550fa9ed2ee0001c4 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Tue, 1 Dec 2015 17:18:45 -0800 Subject: [PATCH 1627/1824] [SPARK-11596][SQL] In TreeNode's argString, if a TreeNode is not a child of the current TreeNode, we should only return the simpleString. In TreeNode's argString, if a TreeNode is not a child of the current TreeNode, we will only return the simpleString. I tested the [following case provided by Cristian](https://issues.apache.org/jira/browse/SPARK-11596?focusedCommentId=15019241&page=com.atlassian.jira.plugin.system.issuetabpanels:comment-tabpanel#comment-15019241). ``` val c = (1 to 20).foldLeft[Option[DataFrame]] (None) { (curr, idx) => println(s"PROCESSING >>>>>>>>>>> $idx") val df = sqlContext.sparkContext.parallelize((0 to 10).zipWithIndex).toDF("A", "B") val union = curr.map(_.unionAll(df)).getOrElse(df) union.cache() Some(union) } c.get.explain(true) ``` Without the change, `c.get.explain(true)` took 100s. With the change, `c.get.explain(true)` took 26ms. https://issues.apache.org/jira/browse/SPARK-11596 Author: Yin Huai Closes #10079 from yhuai/SPARK-11596. --- .../scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index f1cea07976a3..ad2bd78430a6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -380,7 +380,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { /** Returns a string representing the arguments to this node, minus any children */ def argString: String = productIterator.flatMap { case tn: TreeNode[_] if containsChild(tn) => Nil - case tn: TreeNode[_] if tn.toString contains "\n" => s"(${tn.simpleString})" :: Nil + case tn: TreeNode[_] => s"(${tn.simpleString})" :: Nil case seq: Seq[BaseType] if seq.toSet.subsetOf(children.toSet) => Nil case seq: Seq[_] => seq.mkString("[", ",", "]") :: Nil case set: Set[_] => set.mkString("{", ",", "}") :: Nil From 1ce4adf55b535518c2e63917a827fac1f2df4e8e Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Tue, 1 Dec 2015 19:36:34 -0800 Subject: [PATCH 1628/1824] [SPARK-8414] Ensure context cleaner periodic cleanups Garbage collection triggers cleanups. If the driver JVM is huge and there is little memory pressure, we may never clean up shuffle files on executors. This is a problem for long-running applications (e.g. streaming). Author: Andrew Or Closes #10070 from andrewor14/periodic-gc. 
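The scheduling pattern this change relies on can be sketched on its own, outside of ContextCleaner (a minimal illustrative sketch; the object name and hard-coded interval are assumptions, while the real code uses a daemon thread from ThreadUtils and reads `spark.cleaner.periodicGC.interval`):

```scala
import java.util.concurrent.{Executors, ScheduledExecutorService, TimeUnit}

object PeriodicGcSketch {
  // Trigger a GC on a fixed schedule so that weak references get collected even
  // when a large driver heap sees little memory pressure of its own.
  def start(intervalSeconds: Long): ScheduledExecutorService = {
    val service = Executors.newSingleThreadScheduledExecutor()
    service.scheduleAtFixedRate(new Runnable {
      override def run(): Unit = System.gc()
    }, intervalSeconds, intervalSeconds, TimeUnit.SECONDS)
    service
  }

  // Shut the scheduler down when the owning component stops.
  def stop(service: ScheduledExecutorService): Unit = service.shutdown()
}
```

Forcing a collection is what ultimately enqueues the weak references that the cleaning thread consumes, which is why an otherwise idle driver heap needs the periodic nudge.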
--- .../org/apache/spark/ContextCleaner.scala | 21 ++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/ContextCleaner.scala b/core/src/main/scala/org/apache/spark/ContextCleaner.scala index d23c1533db75..bc732535fed8 100644 --- a/core/src/main/scala/org/apache/spark/ContextCleaner.scala +++ b/core/src/main/scala/org/apache/spark/ContextCleaner.scala @@ -18,12 +18,13 @@ package org.apache.spark import java.lang.ref.{ReferenceQueue, WeakReference} +import java.util.concurrent.{TimeUnit, ScheduledExecutorService} import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer} import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.{RDD, ReliableRDDCheckpointData} -import org.apache.spark.util.Utils +import org.apache.spark.util.{ThreadUtils, Utils} /** * Classes that represent cleaning tasks. @@ -66,6 +67,20 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { private val cleaningThread = new Thread() { override def run() { keepCleaning() }} + private val periodicGCService: ScheduledExecutorService = + ThreadUtils.newDaemonSingleThreadScheduledExecutor("context-cleaner-periodic-gc") + + /** + * How often to trigger a garbage collection in this JVM. + * + * This context cleaner triggers cleanups only when weak references are garbage collected. + * In long-running applications with large driver JVMs, where there is little memory pressure + * on the driver, this may happen very occasionally or not at all. Not cleaning at all may + * lead to executors running out of disk space after a while. + */ + private val periodicGCInterval = + sc.conf.getTimeAsSeconds("spark.cleaner.periodicGC.interval", "30min") + /** * Whether the cleaning thread will block on cleanup tasks (other than shuffle, which * is controlled by the `spark.cleaner.referenceTracking.blocking.shuffle` parameter). @@ -104,6 +119,9 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { cleaningThread.setDaemon(true) cleaningThread.setName("Spark Context Cleaner") cleaningThread.start() + periodicGCService.scheduleAtFixedRate(new Runnable { + override def run(): Unit = System.gc() + }, periodicGCInterval, periodicGCInterval, TimeUnit.SECONDS) } /** @@ -119,6 +137,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { cleaningThread.interrupt() } cleaningThread.join() + periodicGCService.shutdown() } /** Register a RDD for cleanup when it is garbage collected. */ From d96f8c997b9bb5c3d61f513d2c71d67ccf8e85d6 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Tue, 1 Dec 2015 19:51:12 -0800 Subject: [PATCH 1629/1824] [SPARK-12081] Make unified memory manager work with small heaps The existing `spark.memory.fraction` (default 0.75) gives the system 25% of the space to work with. For small heaps, this is not enough: e.g. default 1GB leaves only 250MB system memory. This is especially a problem in local mode, where the driver and executor are crammed in the same JVM. Members of the community have reported driver OOM's in such cases. **New proposal.** We now reserve 300MB before taking the 75%. For 1GB JVMs, this leaves `(1024 - 300) * 0.75 = 543MB` for execution and storage. This is proposal (1) listed in the [JIRA](https://issues.apache.org/jira/browse/SPARK-12081). Author: Andrew Or Closes #10081 from andrewor14/unified-memory-small-heaps. 
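The arithmetic behind the proposal can be written out as a small sketch (illustrative names only; the actual constants and config keys live in UnifiedMemoryManager, shown in the diff below):

```scala
object SmallHeapMathSketch {
  val reservedSystemMemoryBytes: Long = 300L * 1024 * 1024 // fixed 300MB reserve
  val memoryFraction: Double = 0.75                        // spark.memory.fraction default

  /** Memory left for execution + storage once the reserve is taken off the heap. */
  def maxMemory(systemMemoryBytes: Long): Long = {
    val minSystemMemory = (reservedSystemMemoryBytes * 1.5).toLong
    require(systemMemoryBytes >= minSystemMemory,
      s"System memory $systemMemoryBytes must be at least $minSystemMemory; use a larger heap")
    ((systemMemoryBytes - reservedSystemMemoryBytes) * memoryFraction).toLong
  }

  // A 1GB heap: (1024MB - 300MB) * 0.75 = 543MB for execution and storage.
  def oneGigExample: Long = maxMemory(1024L * 1024 * 1024)
}
```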
--- .../spark/memory/UnifiedMemoryManager.scala | 22 +++++++++++++++---- .../memory/UnifiedMemoryManagerSuite.scala | 20 +++++++++++++++++ docs/configuration.md | 4 ++-- docs/tuning.md | 2 +- 4 files changed, 41 insertions(+), 7 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala b/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala index 8be5b0541909..48b4e23433e4 100644 --- a/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala +++ b/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala @@ -26,7 +26,7 @@ import org.apache.spark.storage.{BlockStatus, BlockId} * A [[MemoryManager]] that enforces a soft boundary between execution and storage such that * either side can borrow memory from the other. * - * The region shared between execution and storage is a fraction of the total heap space + * The region shared between execution and storage is a fraction of (the total heap space - 300MB) * configurable through `spark.memory.fraction` (default 0.75). The position of the boundary * within this space is further determined by `spark.memory.storageFraction` (default 0.5). * This means the size of the storage region is 0.75 * 0.5 = 0.375 of the heap space by default. @@ -48,7 +48,7 @@ import org.apache.spark.storage.{BlockStatus, BlockId} */ private[spark] class UnifiedMemoryManager private[memory] ( conf: SparkConf, - maxMemory: Long, + val maxMemory: Long, private val storageRegionSize: Long, numCores: Int) extends MemoryManager( @@ -130,6 +130,12 @@ private[spark] class UnifiedMemoryManager private[memory] ( object UnifiedMemoryManager { + // Set aside a fixed amount of memory for non-storage, non-execution purposes. + // This serves a function similar to `spark.memory.fraction`, but guarantees that we reserve + // sufficient memory for the system even for small heaps. E.g. if we have a 1GB JVM, then + // the memory used for execution and storage will be (1024 - 300) * 0.75 = 543MB by default. + private val RESERVED_SYSTEM_MEMORY_BYTES = 300 * 1024 * 1024 + def apply(conf: SparkConf, numCores: Int): UnifiedMemoryManager = { val maxMemory = getMaxMemory(conf) new UnifiedMemoryManager( @@ -144,8 +150,16 @@ object UnifiedMemoryManager { * Return the total amount of memory shared between execution and storage, in bytes. */ private def getMaxMemory(conf: SparkConf): Long = { - val systemMaxMemory = conf.getLong("spark.testing.memory", Runtime.getRuntime.maxMemory) + val systemMemory = conf.getLong("spark.testing.memory", Runtime.getRuntime.maxMemory) + val reservedMemory = conf.getLong("spark.testing.reservedMemory", + if (conf.contains("spark.testing")) 0 else RESERVED_SYSTEM_MEMORY_BYTES) + val minSystemMemory = reservedMemory * 1.5 + if (systemMemory < minSystemMemory) { + throw new IllegalArgumentException(s"System memory $systemMemory must " + + s"be at least $minSystemMemory. 
Please use a larger heap size.") + } + val usableMemory = systemMemory - reservedMemory val memoryFraction = conf.getDouble("spark.memory.fraction", 0.75) - (systemMaxMemory * memoryFraction).toLong + (usableMemory * memoryFraction).toLong } } diff --git a/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala b/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala index 8cebe81c3bff..e97c898a4478 100644 --- a/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala @@ -182,4 +182,24 @@ class UnifiedMemoryManagerSuite extends MemoryManagerSuite with PrivateMethodTes assertEnsureFreeSpaceCalled(ms, 850L) } + test("small heap") { + val systemMemory = 1024 * 1024 + val reservedMemory = 300 * 1024 + val memoryFraction = 0.8 + val conf = new SparkConf() + .set("spark.memory.fraction", memoryFraction.toString) + .set("spark.testing.memory", systemMemory.toString) + .set("spark.testing.reservedMemory", reservedMemory.toString) + val mm = UnifiedMemoryManager(conf, numCores = 1) + val expectedMaxMemory = ((systemMemory - reservedMemory) * memoryFraction).toLong + assert(mm.maxMemory === expectedMaxMemory) + + // Try using a system memory that's too small + val conf2 = conf.clone().set("spark.testing.memory", (reservedMemory / 2).toString) + val exception = intercept[IllegalArgumentException] { + UnifiedMemoryManager(conf2, numCores = 1) + } + assert(exception.getMessage.contains("larger heap size")) + } + } diff --git a/docs/configuration.md b/docs/configuration.md index 741d6b2b37a8..c39b4890851b 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -719,8 +719,8 @@ Apart from these, the following properties are also available, and may be useful spark.memory.fraction 0.75 - Fraction of the heap space used for execution and storage. The lower this is, the more - frequently spills and cached data eviction occur. The purpose of this config is to set + Fraction of (heap space - 300MB) used for execution and storage. The lower this is, the + more frequently spills and cached data eviction occur. The purpose of this config is to set aside memory for internal metadata, user data structures, and imprecise size estimation in the case of sparse, unusually large records. Leaving this at the default value is recommended. For more detail, see diff --git a/docs/tuning.md b/docs/tuning.md index a8fe7a453279..e73ed69ffbbf 100644 --- a/docs/tuning.md +++ b/docs/tuning.md @@ -114,7 +114,7 @@ variety of workloads without requiring user expertise of how memory is divided i Although there are two relevant configurations, the typical user should not need to adjust them as the default values are applicable to most workloads: -* `spark.memory.fraction` expresses the size of `M` as a fraction of the total JVM heap space +* `spark.memory.fraction` expresses the size of `M` as a fraction of the (JVM heap space - 300MB) (default 0.75). The rest of the space (25%) is reserved for user data structures, internal metadata in Spark, and safeguarding against OOM errors in the case of sparse and unusually large records. 
From 96691feae0229fd693c29475620be2c4059dd080 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 1 Dec 2015 20:17:12 -0800 Subject: [PATCH 1630/1824] [SPARK-12077][SQL] change the default plan for single distinct Use try to match the behavior for single distinct aggregation with Spark 1.5, but that's not scalable, we should be robust by default, have a flag to address performance regression for low cardinality aggregation. cc yhuai nongli Author: Davies Liu Closes #10075 from davies/agg_15. --- sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala | 2 +- .../scala/org/apache/spark/sql/execution/PlannerSuite.scala | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 5ef3a48c56a8..58adf64e4986 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -451,7 +451,7 @@ private[spark] object SQLConf { val SPECIALIZE_SINGLE_DISTINCT_AGG_PLANNING = booleanConf("spark.sql.specializeSingleDistinctAggPlanning", - defaultValue = Some(true), + defaultValue = Some(false), isPublic = false, doc = "When true, if a query only has a single distinct column and it has " + "grouping expressions, we will use our planner rule to handle this distinct " + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index dfec139985f7..a4626259b282 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -44,10 +44,10 @@ class PlannerSuite extends SharedSQLContext { fail(s"Could query play aggregation query $query. Is it an aggregation query?")) val aggregations = planned.collect { case n if n.nodeName contains "Aggregate" => n } - // For the new aggregation code path, there will be three aggregate operator for + // For the new aggregation code path, there will be four aggregate operator for // distinct aggregations. assert( - aggregations.size == 2 || aggregations.size == 3, + aggregations.size == 2 || aggregations.size == 4, s"The plan of query $query does not have partial aggregations.") } From 8a75a3049539eeef04c0db51736e97070c162b46 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 1 Dec 2015 21:04:52 -0800 Subject: [PATCH 1631/1824] [SPARK-12087][STREAMING] Create new JobConf for every batch in saveAsHadoopFiles The JobConf object created in `DStream.saveAsHadoopFiles` is used concurrently in multiple places: * The JobConf is updated by `RDD.saveAsHadoopFile()` before the job is launched * The JobConf is serialized as part of the DStream checkpoints. These concurrent accesses (updating in one thread, while the another thread is serializing it) can lead to concurrentModidicationException in the underlying Java hashmap using in the internal Hadoop Configuration object. The solution is to create a new JobConf in every batch, that is updated by `RDD.saveAsHadoopFile()`, while the checkpointing serializes the original JobConf. Tests to be added in #9988 will fail reliably without this patch. Keeping this patch really small to make sure that it can be added to previous branches. Author: Tathagata Das Closes #10088 from tdas/SPARK-12087. 
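The copy-per-batch pattern at the heart of the fix looks roughly like this (a minimal sketch with illustrative names, not the PairDStreamFunctions code itself):

```scala
import org.apache.hadoop.mapred.JobConf

object PerBatchJobConfSketch {
  // The shared JobConf may be serialized by the checkpointing thread at any time,
  // so each batch mutates a private copy instead of the shared instance.
  def saveBatch(sharedConf: JobConf, path: String)(save: (JobConf, String) => Unit): Unit = {
    val perBatchConf = new JobConf(sharedConf) // fresh copy owned by this batch only
    save(perBatchConf, path)
  }
}
```

The checkpoint still serializes the original, untouched JobConf, so the two threads no longer race on the same underlying Hadoop Configuration.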
--- .../apache/spark/streaming/dstream/PairDStreamFunctions.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala index fb691eed27e3..2762309134eb 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala @@ -730,7 +730,8 @@ class PairDStreamFunctions[K, V](self: DStream[(K, V)]) val serializableConf = new SerializableJobConf(conf) val saveFunc = (rdd: RDD[(K, V)], time: Time) => { val file = rddToFileName(prefix, suffix, time) - rdd.saveAsHadoopFile(file, keyClass, valueClass, outputFormatClass, serializableConf.value) + rdd.saveAsHadoopFile(file, keyClass, valueClass, outputFormatClass, + new JobConf(serializableConf.value)) } self.foreachRDD(saveFunc) } From 0f37d1d7ed7f6e34f98f2a3c274918de29e7a1d7 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 1 Dec 2015 21:51:33 -0800 Subject: [PATCH 1632/1824] [SPARK-11949][SQL] Check bitmasks to set nullable property Following up #10038. We can use bitmasks to determine which grouping expressions need to be set as nullable. cc yhuai Author: Liang-Chi Hsieh Closes #10067 from viirya/fix-cube-following. --- .../spark/sql/catalyst/analysis/Analyzer.scala | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 765327c474e6..d3163dcd4db9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -224,10 +224,15 @@ class Analyzer( case other => Alias(other, other.toString)() } - // TODO: We need to use bitmasks to determine which grouping expressions need to be - // set as nullable. For example, if we have GROUPING SETS ((a,b), a), we do not need - // to change the nullability of a. - val attributeMap = groupByAliases.map(a => (a -> a.toAttribute.withNullability(true))).toMap + val nonNullBitmask = x.bitmasks.reduce(_ & _) + + val attributeMap = groupByAliases.zipWithIndex.map { case (a, idx) => + if ((nonNullBitmask & 1 << idx) == 0) { + (a -> a.toAttribute.withNullability(true)) + } else { + (a -> a.toAttribute) + } + }.toMap val aggregations: Seq[NamedExpression] = x.aggregations.map { // If an expression is an aggregate (contains a AggregateExpression) then we dont change From 4375eb3f48fc7ae90caf6c21a0d3ab0b66bf4efa Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 1 Dec 2015 22:41:48 -0800 Subject: [PATCH 1633/1824] [SPARK-12090] [PYSPARK] consider shuffle in coalesce() Author: Davies Liu Closes #10090 from davies/fix_coalesce. 
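The flag being forwarded has the same meaning as in the Scala RDD API; a quick illustration, assuming an existing SparkContext (names here are illustrative):

```scala
import org.apache.spark.SparkContext

object CoalesceShuffleSketch {
  def demo(sc: SparkContext): (Int, Int) = {
    val rdd = sc.parallelize(1 to 100, 10)
    // shuffle = false (the default) only merges existing partitions downwards.
    val merged = rdd.coalesce(2)
    // shuffle = true redistributes the data and can also raise the partition count.
    val widened = rdd.coalesce(20, shuffle = true)
    (merged.partitions.length, widened.partitions.length)
  }
}
```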
--- python/pyspark/rdd.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 4b4d59647b2b..00bb9a62e904 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -2015,7 +2015,7 @@ def coalesce(self, numPartitions, shuffle=False): >>> sc.parallelize([1, 2, 3, 4, 5], 3).coalesce(1).glom().collect() [[1, 2, 3, 4, 5]] """ - jrdd = self._jrdd.coalesce(numPartitions) + jrdd = self._jrdd.coalesce(numPartitions, shuffle) return RDD(jrdd, self.ctx, self._jrdd_deserializer) def zip(self, other): From 128c29035b4e7383cc3a9a6c7a9ab6136205ac6c Mon Sep 17 00:00:00 2001 From: Jeroen Schot Date: Wed, 2 Dec 2015 09:40:07 +0000 Subject: [PATCH 1634/1824] [SPARK-3580][CORE] Add Consistent Method To Get Number of RDD Partitions Across Different Languages I have tried to address all the comments in pull request https://github.com/apache/spark/pull/2447. Note that the second commit (using the new method in all internal code of all components) is quite intrusive and could be omitted. Author: Jeroen Schot Closes #9767 from schot/master. --- .../org/apache/spark/api/java/JavaRDDLike.scala | 5 +++++ core/src/main/scala/org/apache/spark/rdd/RDD.scala | 8 +++++++- .../test/java/org/apache/spark/JavaAPISuite.java | 13 +++++++++++++ .../test/scala/org/apache/spark/rdd/RDDSuite.scala | 1 + project/MimaExcludes.scala | 4 ++++ 5 files changed, 30 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala index 1e9d4f1803a8..0e4d7dce0f2f 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala @@ -28,6 +28,7 @@ import com.google.common.base.Optional import org.apache.hadoop.io.compress.CompressionCodec import org.apache.spark._ +import org.apache.spark.annotation.Since import org.apache.spark.api.java.JavaPairRDD._ import org.apache.spark.api.java.JavaSparkContext.fakeClassTag import org.apache.spark.api.java.JavaUtils.mapAsSerializableJavaMap @@ -62,6 +63,10 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { /** Set of partitions in this RDD. */ def partitions: JList[Partition] = rdd.partitions.toSeq.asJava + /** Return the number of partitions in this RDD. */ + @Since("1.6.0") + def getNumPartitions: Int = rdd.getNumPartitions + /** The partitioner of this RDD. */ def partitioner: Optional[Partitioner] = JavaUtils.optionToOptional(rdd.partitioner) diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 8b3731d93578..9fe9d83a705b 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -31,7 +31,7 @@ import org.apache.hadoop.mapred.TextOutputFormat import org.apache.spark._ import org.apache.spark.Partitioner._ -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.{Since, DeveloperApi} import org.apache.spark.api.java.JavaRDD import org.apache.spark.partial.BoundedDouble import org.apache.spark.partial.CountEvaluator @@ -242,6 +242,12 @@ abstract class RDD[T: ClassTag]( } } + /** + * Returns the number of partitions of this RDD. + */ + @Since("1.6.0") + final def getNumPartitions: Int = partitions.length + /** * Get the preferred locations of a partition, taking into account whether the * RDD is checkpointed. 
diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java index 4d4e9820500e..11f1248c24d3 100644 --- a/core/src/test/java/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java @@ -973,6 +973,19 @@ public Iterator call(Integer index, Iterator iter) { Assert.assertEquals("[3, 7]", partitionSums.collect().toString()); } + @Test + public void getNumPartitions(){ + JavaRDD rdd1 = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8), 3); + JavaDoubleRDD rdd2 = sc.parallelizeDoubles(Arrays.asList(1.0, 2.0, 3.0, 4.0), 2); + JavaPairRDD rdd3 = sc.parallelizePairs(Arrays.asList( + new Tuple2<>("a", 1), + new Tuple2<>("aa", 2), + new Tuple2<>("aaa", 3) + ), 2); + Assert.assertEquals(3, rdd1.getNumPartitions()); + Assert.assertEquals(2, rdd2.getNumPartitions()); + Assert.assertEquals(2, rdd3.getNumPartitions()); + } @Test public void repartition() { diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index 5f718ea9f7be..46ed5c04f433 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -34,6 +34,7 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext { test("basic operations") { val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) + assert(nums.getNumPartitions === 2) assert(nums.collect().toList === List(1, 2, 3, 4)) assert(nums.toLocalIterator.toList === List(1, 2, 3, 4)) val dups = sc.makeRDD(Array(1, 1, 2, 2, 3, 3, 4, 4), 2) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 566bfe8efb7a..d3a3c0ceb68c 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -155,6 +155,10 @@ object MimaExcludes { "org.apache.spark.storage.BlockManagerMessages$GetRpcHostPortForExecutor"), ProblemFilters.exclude[MissingClassProblem]( "org.apache.spark.storage.BlockManagerMessages$GetRpcHostPortForExecutor$") + ) ++ Seq( + // SPARK-3580 Add getNumPartitions method to JavaRDD + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.java.JavaRDDLike.getNumPartitions") ) case v if v.startsWith("1.5") => Seq( From a1542ce2f33ad365ff437d2d3014b9de2f6670e5 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Wed, 2 Dec 2015 09:36:12 -0800 Subject: [PATCH 1635/1824] [SPARK-12094][SQL] Prettier tree string for TreeNode When examining plans of complex queries with multiple joins, a pain point of mine is that, it's hard to immediately see the sibling node of a specific query plan node. This PR adds tree lines for the tree string of a `TreeNode`, so that the result can be visually more intuitive. Author: Cheng Lian Closes #10099 from liancheng/prettier-tree-string. 
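The layout introduced here can be reproduced with a small standalone sketch (Node and render are illustrative names; the real logic is TreeNode.generateTreeString in the diff below):

```scala
object TreeStringSketch {
  case class Node(name: String, children: Seq[Node] = Nil)

  // ":- " marks a child that still has siblings below it, "+- " marks the last
  // child; ancestors that were last children indent with spaces, others with ":".
  def render(node: Node, lastChildren: Seq[Boolean] = Nil,
      builder: StringBuilder = new StringBuilder): StringBuilder = {
    if (lastChildren.nonEmpty) {
      lastChildren.init.foreach(isLast => builder.append(if (isLast) "   " else ":  "))
      builder.append(if (lastChildren.last) "+- " else ":- ")
    }
    builder.append(node.name).append("\n")
    if (node.children.nonEmpty) {
      node.children.init.foreach(render(_, lastChildren :+ false, builder))
      render(node.children.last, lastChildren :+ true, builder)
    }
    builder
  }
}
```

Rendering `Node("Join", Seq(Node("Scan a", Seq(Node("Relation a"))), Node("Scan b")))` with this sketch produces the lines `:- Scan a`, `:  +- Relation a` and `+- Scan b` under `Join`, which is what makes sibling nodes easy to spot in plans with many joins.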
--- .../spark/sql/catalyst/trees/TreeNode.scala | 31 ++++++++++++++++--- 1 file changed, 26 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index ad2bd78430a6..dfea583e0146 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -393,7 +393,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { override def toString: String = treeString /** Returns a string representation of the nodes in this tree */ - def treeString: String = generateTreeString(0, new StringBuilder).toString + def treeString: String = generateTreeString(0, Nil, new StringBuilder).toString /** * Returns a string representation of the nodes in this tree, where each operator is numbered. @@ -419,12 +419,33 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { } } - /** Appends the string represent of this node and its children to the given StringBuilder. */ - protected def generateTreeString(depth: Int, builder: StringBuilder): StringBuilder = { - builder.append(" " * depth) + /** + * Appends the string represent of this node and its children to the given StringBuilder. + * + * The `i`-th element in `lastChildren` indicates whether the ancestor of the current node at + * depth `i + 1` is the last child of its own parent node. The depth of the root node is 0, and + * `lastChildren` for the root node should be empty. + */ + protected def generateTreeString( + depth: Int, lastChildren: Seq[Boolean], builder: StringBuilder): StringBuilder = { + if (depth > 0) { + lastChildren.init.foreach { isLast => + val prefixFragment = if (isLast) " " else ": " + builder.append(prefixFragment) + } + + val branch = if (lastChildren.last) "+- " else ":- " + builder.append(branch) + } + builder.append(simpleString) builder.append("\n") - children.foreach(_.generateTreeString(depth + 1, builder)) + + if (children.nonEmpty) { + children.init.foreach(_.generateTreeString(depth + 1, lastChildren :+ false, builder)) + children.last.generateTreeString(depth + 1, lastChildren :+ true, builder) + } + builder } From 452690ba1cc3c667bdd9f3022c43c9a10267880b Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 2 Dec 2015 13:44:01 -0800 Subject: [PATCH 1636/1824] [SPARK-12001] Allow partially-stopped StreamingContext to be completely stopped If `StreamingContext.stop()` is interrupted midway through the call, the context will be marked as stopped but certain state will have not been cleaned up. Because `state = STOPPED` will be set, subsequent `stop()` calls will be unable to finish stopping the context, preventing any new StreamingContexts from being created. This patch addresses this issue by only marking the context as `STOPPED` once the `stop()` has successfully completed which allows `stop()` to be called a second time in order to finish stopping the context in case the original `stop()` call was interrupted. I discovered this issue by examining logs from a failed Jenkins run in which this race condition occurred in `FailureSuite`, leaking an unstoppable context and causing all subsequent tests to fail. Author: Josh Rosen Closes #9982 from JoshRosen/SPARK-12001. 
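The state-machine pattern described above can be sketched independently of StreamingContext (an illustrative sketch only; each cleanup step must itself be idempotent because the active branch may run twice after an interrupted call):

```scala
object IdempotentStopSketch {
  sealed trait State
  case object Initialized extends State
  case object Active extends State
  case object Stopped extends State

  final class Stoppable(cleanupSteps: Seq[() => Unit]) {
    @volatile private var state: State = Active

    def stop(): Unit = synchronized {
      state match {
        case Initialized | Stopped =>
          state = Stopped // nothing left to clean up
        case Active =>
          // Run every idempotent cleanup step first; only record the final state
          // once they have all succeeded, so an interrupted stop() can be retried.
          cleanupSteps.foreach(step => step())
          state = Stopped
      }
    }
  }
}
```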
--- .../spark/streaming/StreamingContext.scala | 49 ++++++++++--------- 1 file changed, 27 insertions(+), 22 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala index 6fb8ad38abce..cf843e3e8b8e 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala @@ -699,28 +699,33 @@ class StreamingContext private[streaming] ( " AsynchronousListenerBus") } synchronized { - try { - state match { - case INITIALIZED => - logWarning("StreamingContext has not been started yet") - case STOPPED => - logWarning("StreamingContext has already been stopped") - case ACTIVE => - scheduler.stop(stopGracefully) - // Removing the streamingSource to de-register the metrics on stop() - env.metricsSystem.removeSource(streamingSource) - uiTab.foreach(_.detach()) - StreamingContext.setActiveContext(null) - waiter.notifyStop() - if (shutdownHookRef != null) { - shutdownHookRefToRemove = shutdownHookRef - shutdownHookRef = null - } - logInfo("StreamingContext stopped successfully") - } - } finally { - // The state should always be Stopped after calling `stop()`, even if we haven't started yet - state = STOPPED + // The state should always be Stopped after calling `stop()`, even if we haven't started yet + state match { + case INITIALIZED => + logWarning("StreamingContext has not been started yet") + state = STOPPED + case STOPPED => + logWarning("StreamingContext has already been stopped") + state = STOPPED + case ACTIVE => + // It's important that we don't set state = STOPPED until the very end of this case, + // since we need to ensure that we're still able to call `stop()` to recover from + // a partially-stopped StreamingContext which resulted from this `stop()` call being + // interrupted. See SPARK-12001 for more details. Because the body of this case can be + // executed twice in the case of a partial stop, all methods called here need to be + // idempotent. + scheduler.stop(stopGracefully) + // Removing the streamingSource to de-register the metrics on stop() + env.metricsSystem.removeSource(streamingSource) + uiTab.foreach(_.detach()) + StreamingContext.setActiveContext(null) + waiter.notifyStop() + if (shutdownHookRef != null) { + shutdownHookRefToRemove = shutdownHookRef + shutdownHookRef = null + } + logInfo("StreamingContext stopped successfully") + state = STOPPED } } if (shutdownHookRefToRemove != null) { From de07d06abecf3516c95d099b6c01a86e0c8cfd8c Mon Sep 17 00:00:00 2001 From: Yu ISHIKAWA Date: Wed, 2 Dec 2015 14:15:54 -0800 Subject: [PATCH 1637/1824] [SPARK-10266][DOCUMENTATION, ML] Fixed @Since annotation for ml.tunning cc mengxr noel-smith I worked on this issues based on https://github.com/apache/spark/pull/8729. ehsanmok thank you for your contricution! Author: Yu ISHIKAWA Author: Ehsan M.Kermani Closes #9338 from yu-iskw/JIRA-10266. 
--- .../spark/ml/tuning/CrossValidator.scala | 34 ++++++++++++++----- .../spark/ml/tuning/ParamGridBuilder.scala | 14 ++++++-- .../ml/tuning/TrainValidationSplit.scala | 26 +++++++++++--- 3 files changed, 58 insertions(+), 16 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index 83a904837426..5c09f1aaff80 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -19,18 +19,18 @@ package org.apache.spark.ml.tuning import com.github.fommil.netlib.F2jBLAS import org.apache.hadoop.fs.Path -import org.json4s.{JObject, DefaultFormats} import org.json4s.jackson.JsonMethods._ +import org.json4s.{DefaultFormats, JObject} -import org.apache.spark.ml.classification.OneVsRestParams -import org.apache.spark.ml.feature.RFormulaModel -import org.apache.spark.{SparkContext, Logging} +import org.apache.spark.{Logging, SparkContext} import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml._ +import org.apache.spark.ml.classification.OneVsRestParams import org.apache.spark.ml.evaluation.Evaluator +import org.apache.spark.ml.feature.RFormulaModel import org.apache.spark.ml.param._ -import org.apache.spark.ml.util._ import org.apache.spark.ml.util.DefaultParamsReader.Metadata +import org.apache.spark.ml.util._ import org.apache.spark.mllib.util.MLUtils import org.apache.spark.sql.DataFrame import org.apache.spark.sql.types.StructType @@ -58,26 +58,34 @@ private[ml] trait CrossValidatorParams extends ValidatorParams { * :: Experimental :: * K-fold cross validation. */ +@Since("1.2.0") @Experimental -class CrossValidator(override val uid: String) extends Estimator[CrossValidatorModel] +class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) + extends Estimator[CrossValidatorModel] with CrossValidatorParams with MLWritable with Logging { + @Since("1.2.0") def this() = this(Identifiable.randomUID("cv")) private val f2jBLAS = new F2jBLAS /** @group setParam */ + @Since("1.2.0") def setEstimator(value: Estimator[_]): this.type = set(estimator, value) /** @group setParam */ + @Since("1.2.0") def setEstimatorParamMaps(value: Array[ParamMap]): this.type = set(estimatorParamMaps, value) /** @group setParam */ + @Since("1.2.0") def setEvaluator(value: Evaluator): this.type = set(evaluator, value) /** @group setParam */ + @Since("1.2.0") def setNumFolds(value: Int): this.type = set(numFolds, value) + @Since("1.4.0") override def fit(dataset: DataFrame): CrossValidatorModel = { val schema = dataset.schema transformSchema(schema, logging = true) @@ -116,10 +124,12 @@ class CrossValidator(override val uid: String) extends Estimator[CrossValidatorM copyValues(new CrossValidatorModel(uid, bestModel, metrics).setParent(this)) } + @Since("1.4.0") override def transformSchema(schema: StructType): StructType = { $(estimator).transformSchema(schema) } + @Since("1.4.0") override def validateParams(): Unit = { super.validateParams() val est = $(estimator) @@ -128,6 +138,7 @@ class CrossValidator(override val uid: String) extends Estimator[CrossValidatorM } } + @Since("1.4.0") override def copy(extra: ParamMap): CrossValidator = { val copied = defaultCopy(extra).asInstanceOf[CrossValidator] if (copied.isDefined(estimator)) { @@ -308,26 +319,31 @@ object CrossValidator extends MLReadable[CrossValidator] { * @param avgMetrics Average cross-validation metrics for each 
paramMap in * [[CrossValidator.estimatorParamMaps]], in the corresponding order. */ +@Since("1.2.0") @Experimental class CrossValidatorModel private[ml] ( - override val uid: String, - val bestModel: Model[_], - val avgMetrics: Array[Double]) + @Since("1.4.0") override val uid: String, + @Since("1.2.0") val bestModel: Model[_], + @Since("1.5.0") val avgMetrics: Array[Double]) extends Model[CrossValidatorModel] with CrossValidatorParams with MLWritable { + @Since("1.4.0") override def validateParams(): Unit = { bestModel.validateParams() } + @Since("1.4.0") override def transform(dataset: DataFrame): DataFrame = { transformSchema(dataset.schema, logging = true) bestModel.transform(dataset) } + @Since("1.4.0") override def transformSchema(schema: StructType): StructType = { bestModel.transformSchema(schema) } + @Since("1.4.0") override def copy(extra: ParamMap): CrossValidatorModel = { val copied = new CrossValidatorModel( uid, diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/ParamGridBuilder.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/ParamGridBuilder.scala index 98a8f0330ca4..b836d2a2340e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/ParamGridBuilder.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/ParamGridBuilder.scala @@ -20,21 +20,23 @@ package org.apache.spark.ml.tuning import scala.annotation.varargs import scala.collection.mutable -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.param._ /** * :: Experimental :: * Builder for a param grid used in grid search-based model selection. */ +@Since("1.2.0") @Experimental -class ParamGridBuilder { +class ParamGridBuilder @Since("1.2.0") { private val paramGrid = mutable.Map.empty[Param[_], Iterable[_]] /** * Sets the given parameters in this grid to fixed values. */ + @Since("1.2.0") def baseOn(paramMap: ParamMap): this.type = { baseOn(paramMap.toSeq: _*) this @@ -43,6 +45,7 @@ class ParamGridBuilder { /** * Sets the given parameters in this grid to fixed values. */ + @Since("1.2.0") @varargs def baseOn(paramPairs: ParamPair[_]*): this.type = { paramPairs.foreach { p => @@ -54,6 +57,7 @@ class ParamGridBuilder { /** * Adds a param with multiple values (overwrites if the input param exists). */ + @Since("1.2.0") def addGrid[T](param: Param[T], values: Iterable[T]): this.type = { paramGrid.put(param, values) this @@ -64,6 +68,7 @@ class ParamGridBuilder { /** * Adds a double param with multiple values. */ + @Since("1.2.0") def addGrid(param: DoubleParam, values: Array[Double]): this.type = { addGrid[Double](param, values) } @@ -71,6 +76,7 @@ class ParamGridBuilder { /** * Adds a int param with multiple values. */ + @Since("1.2.0") def addGrid(param: IntParam, values: Array[Int]): this.type = { addGrid[Int](param, values) } @@ -78,6 +84,7 @@ class ParamGridBuilder { /** * Adds a float param with multiple values. */ + @Since("1.2.0") def addGrid(param: FloatParam, values: Array[Float]): this.type = { addGrid[Float](param, values) } @@ -85,6 +92,7 @@ class ParamGridBuilder { /** * Adds a long param with multiple values. */ + @Since("1.2.0") def addGrid(param: LongParam, values: Array[Long]): this.type = { addGrid[Long](param, values) } @@ -92,6 +100,7 @@ class ParamGridBuilder { /** * Adds a boolean param with true and false. 
*/ + @Since("1.2.0") def addGrid(param: BooleanParam): this.type = { addGrid[Boolean](param, Array(true, false)) } @@ -99,6 +108,7 @@ class ParamGridBuilder { /** * Builds and returns all combinations of parameters specified by the param grid. */ + @Since("1.2.0") def build(): Array[ParamMap] = { var paramMaps = Array(new ParamMap) paramGrid.foreach { case (param, values) => diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala index 73a14b831015..adf06302047a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala @@ -18,7 +18,7 @@ package org.apache.spark.ml.tuning import org.apache.spark.Logging -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.evaluation.Evaluator import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param.{DoubleParam, ParamMap, ParamValidators} @@ -51,24 +51,32 @@ private[ml] trait TrainValidationSplitParams extends ValidatorParams { * and uses evaluation metric on the validation set to select the best model. * Similar to [[CrossValidator]], but only splits the set once. */ +@Since("1.5.0") @Experimental -class TrainValidationSplit(override val uid: String) extends Estimator[TrainValidationSplitModel] +class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: String) + extends Estimator[TrainValidationSplitModel] with TrainValidationSplitParams with Logging { + @Since("1.5.0") def this() = this(Identifiable.randomUID("tvs")) /** @group setParam */ + @Since("1.5.0") def setEstimator(value: Estimator[_]): this.type = set(estimator, value) /** @group setParam */ + @Since("1.5.0") def setEstimatorParamMaps(value: Array[ParamMap]): this.type = set(estimatorParamMaps, value) /** @group setParam */ + @Since("1.5.0") def setEvaluator(value: Evaluator): this.type = set(evaluator, value) /** @group setParam */ + @Since("1.5.0") def setTrainRatio(value: Double): this.type = set(trainRatio, value) + @Since("1.5.0") override def fit(dataset: DataFrame): TrainValidationSplitModel = { val schema = dataset.schema transformSchema(schema, logging = true) @@ -108,10 +116,12 @@ class TrainValidationSplit(override val uid: String) extends Estimator[TrainVali copyValues(new TrainValidationSplitModel(uid, bestModel, metrics).setParent(this)) } + @Since("1.5.0") override def transformSchema(schema: StructType): StructType = { $(estimator).transformSchema(schema) } + @Since("1.5.0") override def validateParams(): Unit = { super.validateParams() val est = $(estimator) @@ -120,6 +130,7 @@ class TrainValidationSplit(override val uid: String) extends Estimator[TrainVali } } + @Since("1.5.0") override def copy(extra: ParamMap): TrainValidationSplit = { val copied = defaultCopy(extra).asInstanceOf[TrainValidationSplit] if (copied.isDefined(estimator)) { @@ -140,26 +151,31 @@ class TrainValidationSplit(override val uid: String) extends Estimator[TrainVali * @param bestModel Estimator determined best model. * @param validationMetrics Evaluated validation metrics. 
*/ +@Since("1.5.0") @Experimental class TrainValidationSplitModel private[ml] ( - override val uid: String, - val bestModel: Model[_], - val validationMetrics: Array[Double]) + @Since("1.5.0") override val uid: String, + @Since("1.5.0") val bestModel: Model[_], + @Since("1.5.0") val validationMetrics: Array[Double]) extends Model[TrainValidationSplitModel] with TrainValidationSplitParams { + @Since("1.5.0") override def validateParams(): Unit = { bestModel.validateParams() } + @Since("1.5.0") override def transform(dataset: DataFrame): DataFrame = { transformSchema(dataset.schema, logging = true) bestModel.transform(dataset) } + @Since("1.5.0") override def transformSchema(schema: StructType): StructType = { bestModel.transformSchema(schema) } + @Since("1.5.0") override def copy(extra: ParamMap): TrainValidationSplitModel = { val copied = new TrainValidationSplitModel ( uid, From d0d7ec533062151269b300ed455cf150a69098c0 Mon Sep 17 00:00:00 2001 From: Yadong Qi Date: Thu, 3 Dec 2015 08:48:49 +0800 Subject: [PATCH 1638/1824] [SPARK-12093][SQL] Fix the error of comment in DDLParser Author: Yadong Qi Closes #10096 from watermen/patch-1. --- .../apache/spark/sql/execution/datasources/DDLParser.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DDLParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DDLParser.scala index 6969b423d01b..f22508b21090 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DDLParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DDLParser.scala @@ -66,15 +66,15 @@ class DDLParser(parseQuery: String => LogicalPlan) protected def start: Parser[LogicalPlan] = ddl /** - * `CREATE [TEMPORARY] TABLE avroTable [IF NOT EXISTS] + * `CREATE [TEMPORARY] TABLE [IF NOT EXISTS] avroTable * USING org.apache.spark.sql.avro * OPTIONS (path "../hive/src/test/resources/data/files/episodes.avro")` * or - * `CREATE [TEMPORARY] TABLE avroTable(intField int, stringField string...) [IF NOT EXISTS] + * `CREATE [TEMPORARY] TABLE [IF NOT EXISTS] avroTable(intField int, stringField string...) * USING org.apache.spark.sql.avro * OPTIONS (path "../hive/src/test/resources/data/files/episodes.avro")` * or - * `CREATE [TEMPORARY] TABLE avroTable [IF NOT EXISTS] + * `CREATE [TEMPORARY] TABLE [IF NOT EXISTS] avroTable * USING org.apache.spark.sql.avro * OPTIONS (path "../hive/src/test/resources/data/files/episodes.avro")` * AS SELECT ... From 9bb695b7a82d837e2c7a724514ea6b203efb5364 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Wed, 2 Dec 2015 17:19:31 -0800 Subject: [PATCH 1639/1824] [SPARK-12000] do not specify arg types when reference a method in ScalaDoc This fixes SPARK-12000, verified on my local with JDK 7. It seems that `scaladoc` try to match method names and messed up with annotations. cc: JoshRosen jkbradley Author: Xiangrui Meng Closes #10114 from mengxr/SPARK-12000.2. 
--- .../org/apache/spark/mllib/clustering/BisectingKMeans.scala | 2 +- .../apache/spark/mllib/clustering/BisectingKMeansModel.scala | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala index 29a7aa0bb63f..82adfa6ffd59 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala @@ -214,7 +214,7 @@ class BisectingKMeans private ( } /** - * Java-friendly version of [[run(RDD[Vector])*]] + * Java-friendly version of [[run()]]. */ def run(data: JavaRDD[Vector]): BisectingKMeansModel = run(data.rdd) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala index 5015f1540d92..f942e5613ffa 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala @@ -64,7 +64,7 @@ class BisectingKMeansModel @Since("1.6.0") ( } /** - * Java-friendly version of [[predict(RDD[Vector])*]] + * Java-friendly version of [[predict()]]. */ @Since("1.6.0") def predict(points: JavaRDD[Vector]): JavaRDD[java.lang.Integer] = @@ -88,7 +88,7 @@ class BisectingKMeansModel @Since("1.6.0") ( } /** - * Java-friendly version of [[computeCost(RDD[Vector])*]]. + * Java-friendly version of [[computeCost()]]. */ @Since("1.6.0") def computeCost(data: JavaRDD[Vector]): Double = this.computeCost(data.rdd) From ae402533738be06ac802914ed3e48f0d5fa54cbe Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 3 Dec 2015 11:12:02 +0800 Subject: [PATCH 1640/1824] [SPARK-12082][FLAKY-TEST] Increase timeouts in NettyBlockTransferSecuritySuite We should try increasing a timeout in NettyBlockTransferSecuritySuite in order to reduce that suite's flakiness in Jenkins. Author: Josh Rosen Closes #10113 from JoshRosen/SPARK-12082. --- .../spark/network/netty/NettyBlockTransferSecuritySuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala index 3940527fb874..98da94139f7f 100644 --- a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala +++ b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala @@ -148,7 +148,7 @@ class NettyBlockTransferSecuritySuite extends SparkFunSuite with MockitoSugar wi } }) - Await.ready(promise.future, FiniteDuration(1000, TimeUnit.MILLISECONDS)) + Await.ready(promise.future, FiniteDuration(10, TimeUnit.SECONDS)) promise.future.value.get } } From ec2b6c26c9b6bd59d29b5d7af2742aca7e6e0b07 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Thu, 3 Dec 2015 11:21:24 +0800 Subject: [PATCH 1641/1824] [SPARK-12109][SQL] Expressions's simpleString should delegate to its toString. https://issues.apache.org/jira/browse/SPARK-12109 The change of https://issues.apache.org/jira/browse/SPARK-11596 exposed the problem. In the sql plan viz, the filter shows ![image](https://cloud.githubusercontent.com/assets/2072857/11547075/1a285230-9906-11e5-8481-2bb451e35ef1.png) After changes in this PR, the viz is back to normal. 
![image](https://cloud.githubusercontent.com/assets/2072857/11547080/2bc570f4-9906-11e5-8897-3b3bff173276.png) Author: Yin Huai Closes #10111 from yhuai/SPARK-12109. --- .../org/apache/spark/sql/catalyst/expressions/Expression.scala | 3 ++- .../spark/sql/catalyst/expressions/windowExpressions.scala | 3 --- .../scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala | 2 +- 3 files changed, 3 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 4ee6542455a6..614f0c075fd2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -207,12 +207,13 @@ abstract class Expression extends TreeNode[Expression] { }.toString } - private def flatArguments = productIterator.flatMap { case t: Traversable[_] => t case single => single :: Nil } + override def simpleString: String = toString + override def toString: String = prettyName + flatArguments.mkString("(", ",", ")") /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala index 09ec0e333aa4..1680aa8252ec 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala @@ -71,9 +71,6 @@ case class WindowSpecDefinition( childrenResolved && checkInputDataTypes().isSuccess && frameSpecification.isInstanceOf[SpecifiedWindowFrame] - - override def toString: String = simpleString - override def nullable: Boolean = true override def foldable: Boolean = false override def dataType: DataType = throw new UnsupportedOperationException diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index dfea583e0146..d838d845d20f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -380,7 +380,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { /** Returns a string representing the arguments to this node, minus any children */ def argString: String = productIterator.flatMap { case tn: TreeNode[_] if containsChild(tn) => Nil - case tn: TreeNode[_] => s"(${tn.simpleString})" :: Nil + case tn: TreeNode[_] => s"${tn.simpleString}" :: Nil case seq: Seq[BaseType] if seq.toSet.subsetOf(children.toSet) => Nil case seq: Seq[_] => seq.mkString("[", ",", "]") :: Nil case set: Set[_] => set.mkString("{", ",", "}") :: Nil From 5349851f368a1b5dab8a99c0d51c9638ce7aec56 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Thu, 3 Dec 2015 08:42:21 +0000 Subject: [PATCH 1642/1824] =?UTF-8?q?[SPARK-12088][SQL]=20check=20connecti?= =?UTF-8?q?on.isClosed=20before=20calling=20connection=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit In Java Spec java.sql.Connection, it has boolean getAutoCommit() throws SQLException Throws: SQLException - if a database access error occurs or this method is called on a closed connection So if conn.getAutoCommit is called on a closed connection, a 
SQLException will be thrown. Even though the code catch the SQLException and program can continue, I think we should check conn.isClosed before calling conn.getAutoCommit to avoid the unnecessary SQLException. Author: Huaxin Gao Closes #10095 from huaxingao/spark-12088. --- .../apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala index 392d3ed58e3c..b9dd7f6b4099 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala @@ -498,7 +498,7 @@ private[sql] class JDBCRDD( } try { if (null != conn) { - if (!conn.getAutoCommit && !conn.isClosed) { + if (!conn.isClosed && !conn.getAutoCommit) { try { conn.commit() } catch { From 7470d9edbb0a45e714c96b5d55eff30724c0653a Mon Sep 17 00:00:00 2001 From: Jeff Zhang Date: Thu, 3 Dec 2015 15:36:28 +0000 Subject: [PATCH 1643/1824] [DOCUMENTATION][MLLIB] typo in mllib doc \cc mengxr Author: Jeff Zhang Closes #10093 from zjffdu/mllib_typo. --- docs/ml-features.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/ml-features.md b/docs/ml-features.md index 5f888775553f..05c2c96c5ec5 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -1232,7 +1232,7 @@ lInfNormData = normalizer.transform(dataFrame, {normalizer.p: float("inf")}) * `withStd`: True by default. Scales the data to unit standard deviation. * `withMean`: False by default. Centers the data with mean before scaling. It will build a dense output, so this does not work on sparse input and will raise an exception. -`StandardScaler` is a `Model` which can be `fit` on a dataset to produce a `StandardScalerModel`; this amounts to computing summary statistics. The model can then transform a `Vector` column in a dataset to have unit standard deviation and/or zero mean features. +`StandardScaler` is an `Estimator` which can be `fit` on a dataset to produce a `StandardScalerModel`; this amounts to computing summary statistics. The model can then transform a `Vector` column in a dataset to have unit standard deviation and/or zero mean features. Note that if the standard deviation of a feature is zero, it will return default `0.0` value in the `Vector` for that feature. From 95b3cf125b905b4ef8705c46f2ef255377b0a9dc Mon Sep 17 00:00:00 2001 From: microwishing Date: Thu, 3 Dec 2015 16:09:05 +0000 Subject: [PATCH 1644/1824] [DOCUMENTATION][KAFKA] fix typo in kafka/OffsetRange.scala this is to fix some typo in external/kafka/src/main/scala/org/apache/spark/streaming/kafka/OffsetRange.scala Author: microwishing Closes #10121 from microwishing/master. 
--- .../scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala | 2 +- .../scala/org/apache/spark/streaming/kafka/OffsetRange.scala | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala index 86394ea8a685..45a6982b9afe 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala @@ -151,7 +151,7 @@ private[kafka] class KafkaTestUtils extends Logging { } } - /** Create a Kafka topic and wait until it propagated to the whole cluster */ + /** Create a Kafka topic and wait until it is propagated to the whole cluster */ def createTopic(topic: String): Unit = { AdminUtils.createTopic(zkClient, topic, 1, 1) // wait until metadata is propagated diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/OffsetRange.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/OffsetRange.scala index 8a5f37149451..d9b856e4697a 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/OffsetRange.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/OffsetRange.scala @@ -20,7 +20,7 @@ package org.apache.spark.streaming.kafka import kafka.common.TopicAndPartition /** - * Represents any object that has a collection of [[OffsetRange]]s. This can be used access the + * Represents any object that has a collection of [[OffsetRange]]s. This can be used to access the * offset ranges in RDDs generated by the direct Kafka DStream (see * [[KafkaUtils.createDirectStream()]]). * {{{ From 43c575cb1766b32c74db17216194a8a74119b759 Mon Sep 17 00:00:00 2001 From: felixcheung Date: Thu, 3 Dec 2015 09:22:21 -0800 Subject: [PATCH 1645/1824] [SPARK-12116][SPARKR][DOCS] document how to workaround function name conflicts with dplyr shivaram Author: felixcheung Closes #10119 from felixcheung/rdocdplyrmasked. --- docs/sparkr.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/sparkr.md b/docs/sparkr.md index cfb9b41350f4..01148786b79d 100644 --- a/docs/sparkr.md +++ b/docs/sparkr.md @@ -384,5 +384,6 @@ The following functions are masked by the SparkR package: +Since part of SparkR is modeled on the `dplyr` package, certain functions in SparkR share the same names with those in `dplyr`. Depending on the load order of the two packages, some functions from the package loaded first are masked by those in the package loaded after. In such case, prefix such calls with the package name, for instance, `SparkR::cume_dist(x)` or `dplyr::cume_dist(x)`. + You can inspect the search path in R with [`search()`](https://stat.ethz.ch/R-manual/R-devel/library/base/html/search.html) - From 8fa3e474a8ba180188361c0ad7e2704c3e2258d3 Mon Sep 17 00:00:00 2001 From: Steve Loughran Date: Thu, 3 Dec 2015 10:33:06 -0800 Subject: [PATCH 1646/1824] [SPARK-11314][YARN] add service API and test service for Yarn Cluster schedulers MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This is purely the yarn/src/main and yarn/src/test bits of the YARN ATS integration: the extension model to load and run implementations of `SchedulerExtensionService` in the yarn cluster scheduler process —and to stop them afterwards. 
There's duplication between the two schedulers, yarn-client and yarn-cluster, at least in terms of setting everything up, because the common superclass, `YarnSchedulerBackend` is in spark-core, and the extension services need the YARN app/attempt IDs. If you look at how the the extension services are loaded, the case class `SchedulerExtensionServiceBinding` is used to pass in config info -currently just the spark context and the yarn IDs, of which one, the attemptID, will be null when running client-side. I'm passing in a case class to ensure that it would be possible in future to add extra arguments to the binding class, yet, as the method signature will not have changed, still be able to load existing services. There's no functional extension service here, just one for testing. The real tests come in the bigger pull requests. At the same time, there's no restriction of this extension service purely to the ATS history publisher. Anything else that wants to listen to the spark context and publish events could use this, and I'd also consider writing one for the YARN-913 registry service, so that the URLs of the web UI would be locatable through that (low priority; would make more sense if integrated with a REST client). There's no minicluster test. Given the test execution overhead of setting up minicluster tests, it'd probably be better to add an extension service into one of the existing tests. Author: Steve Loughran Closes #9182 from steveloughran/stevel/feature/SPARK-1537-service. --- .../spark/deploy/yarn/ApplicationMaster.scala | 8 + .../cluster/SchedulerExtensionService.scala | 154 ++++++++++++++++++ .../cluster/YarnClientSchedulerBackend.scala | 22 +-- .../cluster/YarnClusterSchedulerBackend.scala | 20 +-- .../cluster/YarnSchedulerBackend.scala | 70 +++++++- .../ExtensionServiceIntegrationSuite.scala | 71 ++++++++ .../cluster/SimpleExtensionService.scala | 34 ++++ .../cluster/StubApplicationAttemptId.scala | 48 ++++++ .../scheduler/cluster/StubApplicationId.scala | 42 +++++ 9 files changed, 431 insertions(+), 38 deletions(-) create mode 100644 yarn/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerExtensionService.scala rename {core => yarn}/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala (81%) create mode 100644 yarn/src/test/scala/org/apache/spark/scheduler/cluster/ExtensionServiceIntegrationSuite.scala create mode 100644 yarn/src/test/scala/org/apache/spark/scheduler/cluster/SimpleExtensionService.scala create mode 100644 yarn/src/test/scala/org/apache/spark/scheduler/cluster/StubApplicationAttemptId.scala create mode 100644 yarn/src/test/scala/org/apache/spark/scheduler/cluster/StubApplicationId.scala diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 50ae7ffeec4c..13ef4dfd6416 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -117,6 +117,10 @@ private[spark] class ApplicationMaster( private var delegationTokenRenewerOption: Option[AMDelegationTokenRenewer] = None + def getAttemptId(): ApplicationAttemptId = { + client.getAttemptId() + } + final def run(): Int = { try { val appAttemptId = client.getAttemptId() @@ -662,6 +666,10 @@ object ApplicationMaster extends Logging { master.sparkContextStopped(sc) } + private[spark] def getAttemptId(): ApplicationAttemptId = { + master.getAttemptId + } + } /** diff --git 
a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerExtensionService.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerExtensionService.scala new file mode 100644 index 000000000000..c06452184539 --- /dev/null +++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerExtensionService.scala @@ -0,0 +1,154 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler.cluster + +import java.util.concurrent.atomic.AtomicBoolean + +import org.apache.hadoop.yarn.api.records.{ApplicationAttemptId, ApplicationId} + +import org.apache.spark.{Logging, SparkContext} +import org.apache.spark.util.Utils + +/** + * An extension service that can be loaded into a Spark YARN scheduler. + * A Service that can be started and stopped. + * + * 1. For implementations to be loadable by `SchedulerExtensionServices`, + * they must provide an empty constructor. + * 2. The `stop()` operation MUST be idempotent, and succeed even if `start()` was + * never invoked. + */ +trait SchedulerExtensionService { + + /** + * Start the extension service. This should be a no-op if + * called more than once. + * @param binding binding to the spark application and YARN + */ + def start(binding: SchedulerExtensionServiceBinding): Unit + + /** + * Stop the service + * The `stop()` operation MUST be idempotent, and succeed even if `start()` was + * never invoked. + */ + def stop(): Unit +} + +/** + * Binding information for a [[SchedulerExtensionService]]. + * + * The attempt ID will be set if the service is started within a YARN application master; + * there is then a different attempt ID for every time that AM is restarted. + * When the service binding is instantiated in client mode, there's no attempt ID, as it lacks + * this information. + * @param sparkContext current spark context + * @param applicationId YARN application ID + * @param attemptId YARN attemptID. This will always be unset in client mode, and always set in + * cluster mode. + */ +case class SchedulerExtensionServiceBinding( + sparkContext: SparkContext, + applicationId: ApplicationId, + attemptId: Option[ApplicationAttemptId] = None) + +/** + * Container for [[SchedulerExtensionService]] instances. + * + * Loads Extension Services from the configuration property + * `"spark.yarn.services"`, instantiates and starts them. + * When stopped, it stops all child entries. + * + * The order in which child extension services are started and stopped + * is undefined. 
+ */ +private[spark] class SchedulerExtensionServices extends SchedulerExtensionService + with Logging { + private var serviceOption: Option[String] = None + private var services: List[SchedulerExtensionService] = Nil + private val started = new AtomicBoolean(false) + private var binding: SchedulerExtensionServiceBinding = _ + + /** + * Binding operation will load the named services and call bind on them too; the + * entire set of services are then ready for `init()` and `start()` calls. + * + * @param binding binding to the spark application and YARN + */ + def start(binding: SchedulerExtensionServiceBinding): Unit = { + if (started.getAndSet(true)) { + logWarning("Ignoring re-entrant start operation") + return + } + require(binding.sparkContext != null, "Null context parameter") + require(binding.applicationId != null, "Null appId parameter") + this.binding = binding + val sparkContext = binding.sparkContext + val appId = binding.applicationId + val attemptId = binding.attemptId + logInfo(s"Starting Yarn extension services with app $appId and attemptId $attemptId") + + serviceOption = sparkContext.getConf.getOption(SchedulerExtensionServices.SPARK_YARN_SERVICES) + services = serviceOption + .map { s => + s.split(",").map(_.trim()).filter(!_.isEmpty) + .map { sClass => + val instance = Utils.classForName(sClass) + .newInstance() + .asInstanceOf[SchedulerExtensionService] + // bind this service + instance.start(binding) + logInfo(s"Service $sClass started") + instance + }.toList + }.getOrElse(Nil) + } + + /** + * Get the list of services. + * + * @return a list of services; Nil until the service is started + */ + def getServices: List[SchedulerExtensionService] = services + + /** + * Stop the services; idempotent. + * + */ + override def stop(): Unit = { + if (started.getAndSet(false)) { + logInfo(s"Stopping $this") + services.foreach { s => + Utils.tryLogNonFatalError(s.stop()) + } + } + } + + override def toString(): String = s"""SchedulerExtensionServices + |(serviceOption=$serviceOption, + | services=$services, + | started=$started)""".stripMargin +} + +private[spark] object SchedulerExtensionServices { + + /** + * A list of comma separated services to instantiate in the scheduler + */ + val SPARK_YARN_SERVICES = "spark.yarn.services" +} diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala index 20771f655473..0e27a2665e93 100644 --- a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala +++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala @@ -19,7 +19,7 @@ package org.apache.spark.scheduler.cluster import scala.collection.mutable.ArrayBuffer -import org.apache.hadoop.yarn.api.records.{ApplicationId, YarnApplicationState} +import org.apache.hadoop.yarn.api.records.YarnApplicationState import org.apache.spark.{SparkException, Logging, SparkContext} import org.apache.spark.deploy.yarn.{Client, ClientArguments, YarnSparkHadoopUtil} @@ -33,7 +33,6 @@ private[spark] class YarnClientSchedulerBackend( with Logging { private var client: Client = null - private var appId: ApplicationId = null private var monitorThread: MonitorThread = null /** @@ -54,13 +53,12 @@ private[spark] class YarnClientSchedulerBackend( val args = new ClientArguments(argsArrayBuf.toArray, conf) totalExpectedExecutors = args.numExecutors client = new Client(args, conf) - appId = client.submitApplication() + 
bindToYarn(client.submitApplication(), None) // SPARK-8687: Ensure all necessary properties have already been set before // we initialize our driver scheduler backend, which serves these properties // to the executors super.start() - waitForApplication() // SPARK-8851: In yarn-client mode, the AM still does the credentials refresh. The driver @@ -116,8 +114,8 @@ private[spark] class YarnClientSchedulerBackend( * This assumes both `client` and `appId` have already been set. */ private def waitForApplication(): Unit = { - assert(client != null && appId != null, "Application has not been submitted yet!") - val (state, _) = client.monitorApplication(appId, returnOnRunning = true) // blocking + assert(client != null && appId.isDefined, "Application has not been submitted yet!") + val (state, _) = client.monitorApplication(appId.get, returnOnRunning = true) // blocking if (state == YarnApplicationState.FINISHED || state == YarnApplicationState.FAILED || state == YarnApplicationState.KILLED) { @@ -125,7 +123,7 @@ private[spark] class YarnClientSchedulerBackend( "It might have been killed or unable to launch application master.") } if (state == YarnApplicationState.RUNNING) { - logInfo(s"Application $appId has started running.") + logInfo(s"Application ${appId.get} has started running.") } } @@ -141,7 +139,7 @@ private[spark] class YarnClientSchedulerBackend( override def run() { try { - val (state, _) = client.monitorApplication(appId, logApplicationReport = false) + val (state, _) = client.monitorApplication(appId.get, logApplicationReport = false) logError(s"Yarn application has already exited with state $state!") allowInterrupt = false sc.stop() @@ -163,7 +161,7 @@ private[spark] class YarnClientSchedulerBackend( * This assumes both `client` and `appId` have already been set. 
*/ private def asyncMonitorApplication(): MonitorThread = { - assert(client != null && appId != null, "Application has not been submitted yet!") + assert(client != null && appId.isDefined, "Application has not been submitted yet!") val t = new MonitorThread t.setName("Yarn application state monitor") t.setDaemon(true) @@ -193,10 +191,4 @@ private[spark] class YarnClientSchedulerBackend( logInfo("Stopped") } - override def applicationId(): String = { - Option(appId).map(_.toString).getOrElse { - logWarning("Application ID is not initialized yet.") - super.applicationId - } - } } diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala index 50b699f11b21..ced597bed36d 100644 --- a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala +++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala @@ -21,7 +21,7 @@ import org.apache.hadoop.yarn.api.ApplicationConstants.Environment import org.apache.hadoop.yarn.conf.YarnConfiguration import org.apache.spark.SparkContext -import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil +import org.apache.spark.deploy.yarn.{ApplicationMaster, YarnSparkHadoopUtil} import org.apache.spark.scheduler.TaskSchedulerImpl import org.apache.spark.util.Utils @@ -31,26 +31,12 @@ private[spark] class YarnClusterSchedulerBackend( extends YarnSchedulerBackend(scheduler, sc) { override def start() { + val attemptId = ApplicationMaster.getAttemptId + bindToYarn(attemptId.getApplicationId(), Some(attemptId)) super.start() totalExpectedExecutors = YarnSparkHadoopUtil.getInitialTargetExecutorNumber(sc.conf) } - override def applicationId(): String = - // In YARN Cluster mode, the application ID is expected to be set, so log an error if it's - // not found. - sc.getConf.getOption("spark.yarn.app.id").getOrElse { - logError("Application ID is not set.") - super.applicationId - } - - override def applicationAttemptId(): Option[String] = - // In YARN Cluster mode, the attempt ID is expected to be set, so log an error if it's - // not found. 
- sc.getConf.getOption("spark.yarn.app.attemptId").orElse { - logError("Application attempt ID is not set.") - super.applicationAttemptId - } - override def getDriverLogUrls: Option[Map[String, String]] = { var driverLogs: Option[Map[String, String]] = None try { diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala similarity index 81% rename from core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala rename to yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala index 80da37b09b59..e3dd87798f01 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala +++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala @@ -17,17 +17,17 @@ package org.apache.spark.scheduler.cluster -import scala.collection.mutable.ArrayBuffer -import scala.concurrent.{Future, ExecutionContext} +import scala.concurrent.{ExecutionContext, Future} +import scala.util.control.NonFatal + +import org.apache.hadoop.yarn.api.records.{ApplicationAttemptId, ApplicationId} import org.apache.spark.{Logging, SparkContext} import org.apache.spark.rpc._ -import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ import org.apache.spark.scheduler._ +import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ import org.apache.spark.ui.JettyUtils -import org.apache.spark.util.{ThreadUtils, RpcUtils} - -import scala.util.control.NonFatal +import org.apache.spark.util.{RpcUtils, ThreadUtils} /** * Abstract Yarn scheduler backend that contains common logic @@ -51,6 +51,64 @@ private[spark] abstract class YarnSchedulerBackend( private implicit val askTimeout = RpcUtils.askRpcTimeout(sc.conf) + /** Application ID. */ + protected var appId: Option[ApplicationId] = None + + /** Attempt ID. This is unset for client-mode schedulers */ + private var attemptId: Option[ApplicationAttemptId] = None + + /** Scheduler extension services. */ + private val services: SchedulerExtensionServices = new SchedulerExtensionServices() + + /** + * Bind to YARN. This *must* be done before calling [[start()]]. + * + * @param appId YARN application ID + * @param attemptId Optional YARN attempt ID + */ + protected def bindToYarn(appId: ApplicationId, attemptId: Option[ApplicationAttemptId]): Unit = { + this.appId = Some(appId) + this.attemptId = attemptId + } + + override def start() { + require(appId.isDefined, "application ID unset") + val binding = SchedulerExtensionServiceBinding(sc, appId.get, attemptId) + services.start(binding) + super.start() + } + + override def stop(): Unit = { + try { + super.stop() + } finally { + services.stop() + } + } + + /** + * Get the attempt ID for this run, if the cluster manager supports multiple + * attempts. Applications run in client mode will not have attempt IDs. + * + * @return The application attempt id, if available. + */ + override def applicationAttemptId(): Option[String] = { + attemptId.map(_.toString) + } + + /** + * Get an application ID associated with the job. + * This returns the string value of [[appId]] if set, otherwise + * the locally-generated ID from the superclass. 
+ * @return The application ID + */ + override def applicationId(): String = { + appId.map(_.toString).getOrElse { + logWarning("Application ID is not initialized yet.") + super.applicationId + } + } + /** * Request executors from the ApplicationMaster by specifying the total number desired. * This includes executors already pending or running. diff --git a/yarn/src/test/scala/org/apache/spark/scheduler/cluster/ExtensionServiceIntegrationSuite.scala b/yarn/src/test/scala/org/apache/spark/scheduler/cluster/ExtensionServiceIntegrationSuite.scala new file mode 100644 index 000000000000..b4d1b0a3d22a --- /dev/null +++ b/yarn/src/test/scala/org/apache/spark/scheduler/cluster/ExtensionServiceIntegrationSuite.scala @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler.cluster + +import org.scalatest.BeforeAndAfter + +import org.apache.spark.{LocalSparkContext, Logging, SparkConf, SparkContext, SparkFunSuite} + +/** + * Test the integration with [[SchedulerExtensionServices]] + */ +class ExtensionServiceIntegrationSuite extends SparkFunSuite + with LocalSparkContext with BeforeAndAfter + with Logging { + + val applicationId = new StubApplicationId(0, 1111L) + val attemptId = new StubApplicationAttemptId(applicationId, 1) + + /* + * Setup phase creates the spark context + */ + before { + val sparkConf = new SparkConf() + sparkConf.set(SchedulerExtensionServices.SPARK_YARN_SERVICES, + classOf[SimpleExtensionService].getName()) + sparkConf.setMaster("local").setAppName("ExtensionServiceIntegrationSuite") + sc = new SparkContext(sparkConf) + } + + test("Instantiate") { + val services = new SchedulerExtensionServices() + assertResult(Nil, "non-nil service list") { + services.getServices + } + services.start(SchedulerExtensionServiceBinding(sc, applicationId)) + services.stop() + } + + test("Contains SimpleExtensionService Service") { + val services = new SchedulerExtensionServices() + try { + services.start(SchedulerExtensionServiceBinding(sc, applicationId)) + val serviceList = services.getServices + assert(serviceList.nonEmpty, "empty service list") + val (service :: Nil) = serviceList + val simpleService = service.asInstanceOf[SimpleExtensionService] + assert(simpleService.started.get, "service not started") + services.stop() + assert(!simpleService.started.get, "service not stopped") + } finally { + services.stop() + } + } +} + + diff --git a/yarn/src/test/scala/org/apache/spark/scheduler/cluster/SimpleExtensionService.scala b/yarn/src/test/scala/org/apache/spark/scheduler/cluster/SimpleExtensionService.scala new file mode 100644 index 000000000000..9b8c98cda8da --- /dev/null +++ b/yarn/src/test/scala/org/apache/spark/scheduler/cluster/SimpleExtensionService.scala @@ -0,0 +1,34 @@ +/* + * Licensed to 
the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler.cluster + +import java.util.concurrent.atomic.AtomicBoolean + +private[spark] class SimpleExtensionService extends SchedulerExtensionService { + + /** started flag; set in the `start()` call, stopped in `stop()`. */ + val started = new AtomicBoolean(false) + + override def start(binding: SchedulerExtensionServiceBinding): Unit = { + started.set(true) + } + + override def stop(): Unit = { + started.set(false) + } +} diff --git a/yarn/src/test/scala/org/apache/spark/scheduler/cluster/StubApplicationAttemptId.scala b/yarn/src/test/scala/org/apache/spark/scheduler/cluster/StubApplicationAttemptId.scala new file mode 100644 index 000000000000..4b57b9509a65 --- /dev/null +++ b/yarn/src/test/scala/org/apache/spark/scheduler/cluster/StubApplicationAttemptId.scala @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler.cluster + +import org.apache.hadoop.yarn.api.records.{ApplicationAttemptId, ApplicationId} + +/** + * A stub application ID; can be set in constructor and/or updated later. 
+ * @param applicationId application ID + * @param attempt an attempt counter + */ +class StubApplicationAttemptId(var applicationId: ApplicationId, var attempt: Int) + extends ApplicationAttemptId { + + override def setApplicationId(appID: ApplicationId): Unit = { + applicationId = appID + } + + override def getAttemptId: Int = { + attempt + } + + override def setAttemptId(attemptId: Int): Unit = { + attempt = attemptId + } + + override def getApplicationId: ApplicationId = { + applicationId + } + + override def build(): Unit = { + } +} diff --git a/yarn/src/test/scala/org/apache/spark/scheduler/cluster/StubApplicationId.scala b/yarn/src/test/scala/org/apache/spark/scheduler/cluster/StubApplicationId.scala new file mode 100644 index 000000000000..bffa0e09befd --- /dev/null +++ b/yarn/src/test/scala/org/apache/spark/scheduler/cluster/StubApplicationId.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler.cluster + +import org.apache.hadoop.yarn.api.records.ApplicationId + +/** + * Simple Testing Application Id; ID and cluster timestamp are set in constructor + * and cannot be updated. + * @param id app id + * @param clusterTimestamp timestamp + */ +private[spark] class StubApplicationId(id: Int, clusterTimestamp: Long) extends ApplicationId { + override def getId: Int = { + id + } + + override def getClusterTimestamp: Long = { + clusterTimestamp + } + + override def setId(id: Int): Unit = {} + + override def setClusterTimestamp(clusterTimestamp: Long): Unit = {} + + override def build(): Unit = {} +} From 7bc9e1db2c47387ee693bcbeb4a8a2cbe11909cf Mon Sep 17 00:00:00 2001 From: jerryshao Date: Thu, 3 Dec 2015 11:05:12 -0800 Subject: [PATCH 1647/1824] [SPARK-12059][CORE] Avoid assertion error when unexpected state transition met in Master Downgrade to warning log for unexpected state transition. andrewor14 please review, thanks a lot. Author: jerryshao Closes #10091 from jerryshao/SPARK-12059. 
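To make the SPARK-11314 extension model above concrete: a user-supplied service only needs a no-arg constructor, the two lifecycle methods, and an entry in the spark.yarn.services list. The sketch below is hypothetical (class name and behaviour invented for illustration), but the trait, the binding case class and the configuration key are the ones introduced by that patch:

    import java.util.concurrent.atomic.AtomicBoolean

    import org.apache.spark.scheduler.cluster.{SchedulerExtensionService, SchedulerExtensionServiceBinding}

    // Hypothetical service that just records the IDs it was bound to.
    class AppIdLoggingService extends SchedulerExtensionService {
      private val started = new AtomicBoolean(false)

      override def start(binding: SchedulerExtensionServiceBinding): Unit = {
        if (started.compareAndSet(false, true)) {
          // attemptId is None in client mode and set in cluster mode
          println(s"Bound to app ${binding.applicationId}, attempt ${binding.attemptId}")
        }
      }

      // Must be idempotent and safe even if start() was never called.
      override def stop(): Unit = started.set(false)
    }

It would then be enabled with something like --conf spark.yarn.services=com.example.AppIdLoggingService (the package is made up here); SchedulerExtensionServices splits that value on commas and instantiates each named class reflectively.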
--- .../main/scala/org/apache/spark/deploy/master/Master.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index 1355e1ad1b52..04b20e0d6ab9 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -257,8 +257,9 @@ private[deploy] class Master( exec.state = state if (state == ExecutorState.RUNNING) { - assert(oldState == ExecutorState.LAUNCHING, - s"executor $execId state transfer from $oldState to RUNNING is illegal") + if (oldState != ExecutorState.LAUNCHING) { + logWarning(s"Executor $execId state transfer from $oldState to RUNNING is unexpected") + } appInfo.resetRetryCount() } From 649be4fa4532dcd3001df8345f9f7e970a3fbc65 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Thu, 3 Dec 2015 11:06:25 -0800 Subject: [PATCH 1648/1824] [SPARK-12101][CORE] Fix thread pools that cannot cache tasks in Worker and AppClient `SynchronousQueue` cannot cache any task. This issue is similar to #9978. It's an easy fix. Just use the fixed `ThreadUtils.newDaemonCachedThreadPool`. Author: Shixiong Zhu Closes #10108 from zsxwing/fix-threadpool. --- .../org/apache/spark/deploy/client/AppClient.scala | 10 ++++------ .../org/apache/spark/deploy/worker/Worker.scala | 10 ++++------ .../apache/spark/deploy/yarn/YarnAllocator.scala | 14 ++++---------- 3 files changed, 12 insertions(+), 22 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala index df6ba7d669ce..1e2f469214b8 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala @@ -68,12 +68,10 @@ private[spark] class AppClient( // A thread pool for registering with masters. Because registering with a master is a blocking // action, this thread pool must be able to create "masterRpcAddresses.size" threads at the same // time so that we can register with all masters. - private val registerMasterThreadPool = new ThreadPoolExecutor( - 0, - masterRpcAddresses.length, // Make sure we can register with all masters at the same time - 60L, TimeUnit.SECONDS, - new SynchronousQueue[Runnable](), - ThreadUtils.namedThreadFactory("appclient-register-master-threadpool")) + private val registerMasterThreadPool = ThreadUtils.newDaemonCachedThreadPool( + "appclient-register-master-threadpool", + masterRpcAddresses.length // Make sure we can register with all masters at the same time + ) // A scheduled executor for scheduling the registration actions private val registrationRetryThread = diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index 418faf8fc967..1afc1ff59f2f 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -146,12 +146,10 @@ private[deploy] class Worker( // A thread pool for registering with masters. Because registering with a master is a blocking // action, this thread pool must be able to create "masterRpcAddresses.size" threads at the same // time so that we can register with all masters. 
- private val registerMasterThreadPool = new ThreadPoolExecutor( - 0, - masterRpcAddresses.size, // Make sure we can register with all masters at the same time - 60L, TimeUnit.SECONDS, - new SynchronousQueue[Runnable](), - ThreadUtils.namedThreadFactory("worker-register-master-threadpool")) + private val registerMasterThreadPool = ThreadUtils.newDaemonCachedThreadPool( + "worker-register-master-threadpool", + masterRpcAddresses.size // Make sure we can register with all masters at the same time + ) var coresUsed = 0 var memoryUsed = 0 diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala index 73cd9031f025..4e044aa4788d 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala @@ -25,8 +25,6 @@ import scala.collection.mutable import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} import scala.collection.JavaConverters._ -import com.google.common.util.concurrent.ThreadFactoryBuilder - import org.apache.hadoop.conf.Configuration import org.apache.hadoop.yarn.api.records._ import org.apache.hadoop.yarn.client.api.AMRMClient @@ -40,7 +38,7 @@ import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil._ import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef} import org.apache.spark.scheduler.{ExecutorExited, ExecutorLossReason} import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.RemoveExecutor -import org.apache.spark.util.Utils +import org.apache.spark.util.ThreadUtils /** * YarnAllocator is charged with requesting containers from the YARN ResourceManager and deciding @@ -117,13 +115,9 @@ private[yarn] class YarnAllocator( // Resource capability requested for each executors private[yarn] val resource = Resource.newInstance(executorMemory + memoryOverhead, executorCores) - private val launcherPool = new ThreadPoolExecutor( - // max pool size of Integer.MAX_VALUE is ignored because we use an unbounded queue - sparkConf.getInt("spark.yarn.containerLauncherMaxThreads", 25), Integer.MAX_VALUE, - 1, TimeUnit.MINUTES, - new LinkedBlockingQueue[Runnable](), - new ThreadFactoryBuilder().setNameFormat("ContainerLauncher #%d").setDaemon(true).build()) - launcherPool.allowCoreThreadTimeOut(true) + private val launcherPool = ThreadUtils.newDaemonCachedThreadPool( + "ContainerLauncher", + sparkConf.getInt("spark.yarn.containerLauncherMaxThreads", 25)) // For testing private val launchContainers = sparkConf.getBoolean("spark.yarn.launchContainers", true) From 688e521c2833a00069272a6749153d721a0996f6 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Thu, 3 Dec 2015 11:09:29 -0800 Subject: [PATCH 1649/1824] [SPARK-12108] Make event logs smaller **Problem.** Event logs in 1.6 were much bigger than 1.5. I ran page rank and the event log size in 1.6 was almost 5x that in 1.5. I did a bisect to find that the RDD callsite added in #9398 is largely responsible for this. **Solution.** This patch removes the long form of the callsite (which is not used!) from the event log. This reduces the size of the event log significantly. *Note on compatibility*: if this patch is to be merged into 1.6.0, then it won't break any compatibility. Otherwise, if it is merged into 1.6.1, then we might need to add more backward compatibility handling logic (currently does not exist yet). Author: Andrew Or Closes #10115 from andrewor14/smaller-event-logs. 
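On the SPARK-12101 fix above, the underlying point is that a SynchronousQueue has no capacity at all: once the pool's maximum number of threads is busy, a further execute() is rejected rather than buffered, whereas the replacement pool queues the extra work. A rough REPL sketch of the two behaviours (simplified; not the actual ThreadUtils.newDaemonCachedThreadPool implementation, which also installs a named daemon thread factory):

    import java.util.concurrent.{LinkedBlockingQueue, SynchronousQueue, ThreadPoolExecutor, TimeUnit}

    // A SynchronousQueue hands tasks straight to an idle thread; with both
    // threads busy, a third execute() throws RejectedExecutionException.
    val rejecting = new ThreadPoolExecutor(
      0, 2, 60L, TimeUnit.SECONDS, new SynchronousQueue[Runnable]())

    // An unbounded LinkedBlockingQueue buffers ("caches") extra tasks until a
    // worker frees up; this mirrors what the fixed code builds via ThreadUtils.
    val caching = new ThreadPoolExecutor(
      2, 2, 60L, TimeUnit.SECONDS, new LinkedBlockingQueue[Runnable]())
    caching.allowCoreThreadTimeOut(true)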
--- .../org/apache/spark/storage/RDDInfo.scala | 4 +-- .../spark/ui/scope/RDDOperationGraph.scala | 4 +-- .../org/apache/spark/util/JsonProtocol.scala | 17 ++------- .../apache/spark/util/JsonProtocolSuite.scala | 35 ++++++++----------- 4 files changed, 20 insertions(+), 40 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala b/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala index 87c1b981e7e1..94e8559bd2e9 100644 --- a/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala +++ b/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala @@ -28,7 +28,7 @@ class RDDInfo( val numPartitions: Int, var storageLevel: StorageLevel, val parentIds: Seq[Int], - val callSite: CallSite = CallSite.empty, + val callSite: String = "", val scope: Option[RDDOperationScope] = None) extends Ordered[RDDInfo] { @@ -58,6 +58,6 @@ private[spark] object RDDInfo { val rddName = Option(rdd.name).getOrElse(Utils.getFormattedClassName(rdd)) val parentIds = rdd.dependencies.map(_.rdd.id) new RDDInfo(rdd.id, rddName, rdd.partitions.length, - rdd.getStorageLevel, parentIds, rdd.creationSite, rdd.scope) + rdd.getStorageLevel, parentIds, rdd.creationSite.shortForm, rdd.scope) } } diff --git a/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala b/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala index 24274562657b..e9c8a8e299cd 100644 --- a/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala +++ b/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala @@ -39,7 +39,7 @@ private[ui] case class RDDOperationGraph( rootCluster: RDDOperationCluster) /** A node in an RDDOperationGraph. This represents an RDD. */ -private[ui] case class RDDOperationNode(id: Int, name: String, cached: Boolean, callsite: CallSite) +private[ui] case class RDDOperationNode(id: Int, name: String, cached: Boolean, callsite: String) /** * A directed edge connecting two nodes in an RDDOperationGraph. @@ -178,7 +178,7 @@ private[ui] object RDDOperationGraph extends Logging { /** Return the dot representation of a node in an RDDOperationGraph. 
*/ private def makeDotNode(node: RDDOperationNode): String = { - val label = s"${node.name} [${node.id}]\n${node.callsite.shortForm}" + val label = s"${node.name} [${node.id}]\n${node.callsite}" s"""${node.id} [label="$label"]""" } diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index c9beeb25e05a..2d2bd90eb339 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -398,7 +398,7 @@ private[spark] object JsonProtocol { ("RDD ID" -> rddInfo.id) ~ ("Name" -> rddInfo.name) ~ ("Scope" -> rddInfo.scope.map(_.toJson)) ~ - ("Callsite" -> callsiteToJson(rddInfo.callSite)) ~ + ("Callsite" -> rddInfo.callSite) ~ ("Parent IDs" -> parentIds) ~ ("Storage Level" -> storageLevel) ~ ("Number of Partitions" -> rddInfo.numPartitions) ~ @@ -408,11 +408,6 @@ private[spark] object JsonProtocol { ("Disk Size" -> rddInfo.diskSize) } - def callsiteToJson(callsite: CallSite): JValue = { - ("Short Form" -> callsite.shortForm) ~ - ("Long Form" -> callsite.longForm) - } - def storageLevelToJson(storageLevel: StorageLevel): JValue = { ("Use Disk" -> storageLevel.useDisk) ~ ("Use Memory" -> storageLevel.useMemory) ~ @@ -857,9 +852,7 @@ private[spark] object JsonProtocol { val scope = Utils.jsonOption(json \ "Scope") .map(_.extract[String]) .map(RDDOperationScope.fromJson) - val callsite = Utils.jsonOption(json \ "Callsite") - .map(callsiteFromJson) - .getOrElse(CallSite.empty) + val callsite = Utils.jsonOption(json \ "Callsite").map(_.extract[String]).getOrElse("") val parentIds = Utils.jsonOption(json \ "Parent IDs") .map { l => l.extract[List[JValue]].map(_.extract[Int]) } .getOrElse(Seq.empty) @@ -880,12 +873,6 @@ private[spark] object JsonProtocol { rddInfo } - def callsiteFromJson(json: JValue): CallSite = { - val shortForm = (json \ "Short Form").extract[String] - val longForm = (json \ "Long Form").extract[String] - CallSite(shortForm, longForm) - } - def storageLevelFromJson(json: JValue): StorageLevel = { val useDisk = (json \ "Use Disk").extract[Boolean] val useMemory = (json \ "Use Memory").extract[Boolean] diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index 3f94ef704191..1939ce5c743b 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -111,7 +111,6 @@ class JsonProtocolSuite extends SparkFunSuite { test("Dependent Classes") { val logUrlMap = Map("stderr" -> "mystderr", "stdout" -> "mystdout").toMap testRDDInfo(makeRddInfo(2, 3, 4, 5L, 6L)) - testCallsite(CallSite("happy", "birthday")) testStageInfo(makeStageInfo(10, 20, 30, 40L, 50L)) testTaskInfo(makeTaskInfo(999L, 888, 55, 777L, false)) testTaskMetrics(makeTaskMetrics( @@ -343,13 +342,13 @@ class JsonProtocolSuite extends SparkFunSuite { // "Scope" and "Parent IDs" were introduced in Spark 1.4.0 // "Callsite" was introduced in Spark 1.6.0 val rddInfo = new RDDInfo(1, "one", 100, StorageLevel.NONE, Seq(1, 6, 8), - CallSite("short", "long"), Some(new RDDOperationScope("fable"))) + "callsite", Some(new RDDOperationScope("fable"))) val oldRddInfoJson = JsonProtocol.rddInfoToJson(rddInfo) .removeField({ _._1 == "Parent IDs"}) .removeField({ _._1 == "Scope"}) .removeField({ _._1 == "Callsite"}) val expectedRddInfo = new RDDInfo( - 1, "one", 100, StorageLevel.NONE, Seq.empty, 
CallSite.empty, scope = None) + 1, "one", 100, StorageLevel.NONE, Seq.empty, "", scope = None) assertEquals(expectedRddInfo, JsonProtocol.rddInfoFromJson(oldRddInfoJson)) } @@ -397,11 +396,6 @@ class JsonProtocolSuite extends SparkFunSuite { assertEquals(info, newInfo) } - private def testCallsite(callsite: CallSite): Unit = { - val newCallsite = JsonProtocol.callsiteFromJson(JsonProtocol.callsiteToJson(callsite)) - assert(callsite === newCallsite) - } - private def testStageInfo(info: StageInfo) { val newInfo = JsonProtocol.stageInfoFromJson(JsonProtocol.stageInfoToJson(info)) assertEquals(info, newInfo) @@ -726,8 +720,7 @@ class JsonProtocolSuite extends SparkFunSuite { } private def makeRddInfo(a: Int, b: Int, c: Int, d: Long, e: Long) = { - val r = new RDDInfo(a, "mayor", b, StorageLevel.MEMORY_AND_DISK, - Seq(1, 4, 7), CallSite(a.toString, b.toString)) + val r = new RDDInfo(a, "mayor", b, StorageLevel.MEMORY_AND_DISK, Seq(1, 4, 7), a.toString) r.numCachedPartitions = c r.memSize = d r.diskSize = e @@ -870,7 +863,7 @@ class JsonProtocolSuite extends SparkFunSuite { | { | "RDD ID": 101, | "Name": "mayor", - | "Callsite": {"Short Form": "101", "Long Form": "201"}, + | "Callsite": "101", | "Parent IDs": [1, 4, 7], | "Storage Level": { | "Use Disk": true, @@ -1273,7 +1266,7 @@ class JsonProtocolSuite extends SparkFunSuite { | { | "RDD ID": 1, | "Name": "mayor", - | "Callsite": {"Short Form": "1", "Long Form": "200"}, + | "Callsite": "1", | "Parent IDs": [1, 4, 7], | "Storage Level": { | "Use Disk": true, @@ -1317,7 +1310,7 @@ class JsonProtocolSuite extends SparkFunSuite { | { | "RDD ID": 2, | "Name": "mayor", - | "Callsite": {"Short Form": "2", "Long Form": "400"}, + | "Callsite": "2", | "Parent IDs": [1, 4, 7], | "Storage Level": { | "Use Disk": true, @@ -1335,7 +1328,7 @@ class JsonProtocolSuite extends SparkFunSuite { | { | "RDD ID": 3, | "Name": "mayor", - | "Callsite": {"Short Form": "3", "Long Form": "401"}, + | "Callsite": "3", | "Parent IDs": [1, 4, 7], | "Storage Level": { | "Use Disk": true, @@ -1379,7 +1372,7 @@ class JsonProtocolSuite extends SparkFunSuite { | { | "RDD ID": 3, | "Name": "mayor", - | "Callsite": {"Short Form": "3", "Long Form": "600"}, + | "Callsite": "3", | "Parent IDs": [1, 4, 7], | "Storage Level": { | "Use Disk": true, @@ -1397,7 +1390,7 @@ class JsonProtocolSuite extends SparkFunSuite { | { | "RDD ID": 4, | "Name": "mayor", - | "Callsite": {"Short Form": "4", "Long Form": "601"}, + | "Callsite": "4", | "Parent IDs": [1, 4, 7], | "Storage Level": { | "Use Disk": true, @@ -1415,7 +1408,7 @@ class JsonProtocolSuite extends SparkFunSuite { | { | "RDD ID": 5, | "Name": "mayor", - | "Callsite": {"Short Form": "5", "Long Form": "602"}, + | "Callsite": "5", | "Parent IDs": [1, 4, 7], | "Storage Level": { | "Use Disk": true, @@ -1459,7 +1452,7 @@ class JsonProtocolSuite extends SparkFunSuite { | { | "RDD ID": 4, | "Name": "mayor", - | "Callsite": {"Short Form": "4", "Long Form": "800"}, + | "Callsite": "4", | "Parent IDs": [1, 4, 7], | "Storage Level": { | "Use Disk": true, @@ -1477,7 +1470,7 @@ class JsonProtocolSuite extends SparkFunSuite { | { | "RDD ID": 5, | "Name": "mayor", - | "Callsite": {"Short Form": "5", "Long Form": "801"}, + | "Callsite": "5", | "Parent IDs": [1, 4, 7], | "Storage Level": { | "Use Disk": true, @@ -1495,7 +1488,7 @@ class JsonProtocolSuite extends SparkFunSuite { | { | "RDD ID": 6, | "Name": "mayor", - | "Callsite": {"Short Form": "6", "Long Form": "802"}, + | "Callsite": "6", | "Parent IDs": [1, 4, 7], | "Storage Level": { | "Use 
Disk": true, @@ -1513,7 +1506,7 @@ class JsonProtocolSuite extends SparkFunSuite { | { | "RDD ID": 7, | "Name": "mayor", - | "Callsite": {"Short Form": "7", "Long Form": "803"}, + | "Callsite": "7", | "Parent IDs": [1, 4, 7], | "Storage Level": { | "Use Disk": true, From d576e76bbaa818480d31d2b8fbbe4b15718307d9 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Thu, 3 Dec 2015 11:37:34 -0800 Subject: [PATCH 1650/1824] [MINOR][ML] Use coefficients replace weights Use ```coefficients``` replace ```weights```, I wish they are the last two. mengxr Author: Yanbo Liang Closes #10065 from yanboliang/coefficients. --- python/pyspark/ml/classification.py | 2 +- python/pyspark/ml/regression.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 4a2982e2047f..5599b8f3ecd8 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -49,7 +49,7 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti ... Row(label=0.0, weight=2.0, features=Vectors.sparse(1, [], []))]).toDF() >>> lr = LogisticRegression(maxIter=5, regParam=0.01, weightCol="weight") >>> model = lr.fit(df) - >>> model.weights + >>> model.coefficients DenseVector([5.5...]) >>> model.intercept -2.68... diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index 944e648ec880..a0bb8ceed886 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -40,7 +40,7 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction Linear regression. The learning objective is to minimize the squared error, with regularization. - The specific squared error loss function used is: L = 1/2n ||A weights - y||^2^ + The specific squared error loss function used is: L = 1/2n ||A coefficients - y||^2^ This support multiple types of regularization: - none (a.k.a. ordinary least squares) From ad7cea6f776a39801d6bb5bb829d1800b175b2ab Mon Sep 17 00:00:00 2001 From: Nicholas Chammas Date: Thu, 3 Dec 2015 11:59:10 -0800 Subject: [PATCH 1651/1824] [SPARK-12107][EC2] Update spark-ec2 versions I haven't created a JIRA. If we absolutely need one I'll do it, but I'm fine with not getting mentioned in the release notes if that's the only purpose it'll serve. cc marmbrus - We should include this in 1.6-RC2 if there is one. I can open a second PR against branch-1.6 if necessary. Author: Nicholas Chammas Closes #10109 from nchammas/spark-ec2-versions. 
--- ec2/spark_ec2.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py index 84a950c9f652..19d5980560fe 100755 --- a/ec2/spark_ec2.py +++ b/ec2/spark_ec2.py @@ -51,7 +51,7 @@ raw_input = input xrange = range -SPARK_EC2_VERSION = "1.5.0" +SPARK_EC2_VERSION = "1.6.0" SPARK_EC2_DIR = os.path.dirname(os.path.realpath(__file__)) VALID_SPARK_VERSIONS = set([ @@ -72,7 +72,10 @@ "1.3.1", "1.4.0", "1.4.1", - "1.5.0" + "1.5.0", + "1.5.1", + "1.5.2", + "1.6.0", ]) SPARK_TACHYON_MAP = { @@ -87,7 +90,10 @@ "1.3.1": "0.5.0", "1.4.0": "0.6.4", "1.4.1": "0.6.4", - "1.5.0": "0.7.1" + "1.5.0": "0.7.1", + "1.5.1": "0.7.1", + "1.5.2": "0.7.1", + "1.6.0": "0.8.2", } DEFAULT_SPARK_VERSION = SPARK_EC2_VERSION From a02d47277379e1e82d0ee41b2205434f9ffbc3e5 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Thu, 3 Dec 2015 12:00:09 -0800 Subject: [PATCH 1652/1824] [FLAKY-TEST-FIX][STREAMING][TEST] Make sure StreamingContexts are shutdown after test Author: Tathagata Das Closes #10124 from tdas/InputStreamSuite-flaky-test. --- .../spark/streaming/InputStreamsSuite.scala | 122 +++++++++--------- 1 file changed, 61 insertions(+), 61 deletions(-) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala index 047e38ef9099..3a3176b91b1e 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala @@ -206,28 +206,28 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { val numTotalRecords = numThreads * numRecordsPerThread val testReceiver = new MultiThreadTestReceiver(numThreads, numRecordsPerThread) MultiThreadTestReceiver.haveAllThreadsFinished = false - - // set up the network stream using the test receiver - val ssc = new StreamingContext(conf, batchDuration) - val networkStream = ssc.receiverStream[Int](testReceiver) - val countStream = networkStream.count val outputBuffer = new ArrayBuffer[Seq[Long]] with SynchronizedBuffer[Seq[Long]] - val outputStream = new TestOutputStream(countStream, outputBuffer) def output: ArrayBuffer[Long] = outputBuffer.flatMap(x => x) - outputStream.register() - ssc.start() - - // Let the data from the receiver be received - val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] - val startTime = System.currentTimeMillis() - while((!MultiThreadTestReceiver.haveAllThreadsFinished || output.sum < numTotalRecords) && - System.currentTimeMillis() - startTime < 5000) { - Thread.sleep(100) - clock.advance(batchDuration.milliseconds) + + // set up the network stream using the test receiver + withStreamingContext(new StreamingContext(conf, batchDuration)) { ssc => + val networkStream = ssc.receiverStream[Int](testReceiver) + val countStream = networkStream.count + + val outputStream = new TestOutputStream(countStream, outputBuffer) + outputStream.register() + ssc.start() + + // Let the data from the receiver be received + val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] + val startTime = System.currentTimeMillis() + while ((!MultiThreadTestReceiver.haveAllThreadsFinished || output.sum < numTotalRecords) && + System.currentTimeMillis() - startTime < 5000) { + Thread.sleep(100) + clock.advance(batchDuration.milliseconds) + } + Thread.sleep(1000) } - Thread.sleep(1000) - logInfo("Stopping context") - ssc.stop() // Verify whether data received was as expected 
logInfo("--------------------------------") @@ -239,30 +239,30 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { } test("queue input stream - oneAtATime = true") { - // Set up the streaming context and input streams - val ssc = new StreamingContext(conf, batchDuration) - val queue = new SynchronizedQueue[RDD[String]]() - val queueStream = ssc.queueStream(queue, oneAtATime = true) - val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String]] - val outputStream = new TestOutputStream(queueStream, outputBuffer) - def output: ArrayBuffer[Seq[String]] = outputBuffer.filter(_.size > 0) - outputStream.register() - ssc.start() - - // Setup data queued into the stream - val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] val input = Seq("1", "2", "3", "4", "5") val expectedOutput = input.map(Seq(_)) + val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String]] + def output: ArrayBuffer[Seq[String]] = outputBuffer.filter(_.size > 0) - val inputIterator = input.toIterator - for (i <- 0 until input.size) { - // Enqueue more than 1 item per tick but they should dequeue one at a time - inputIterator.take(2).foreach(i => queue += ssc.sparkContext.makeRDD(Seq(i))) - clock.advance(batchDuration.milliseconds) + // Set up the streaming context and input streams + withStreamingContext(new StreamingContext(conf, batchDuration)) { ssc => + val queue = new SynchronizedQueue[RDD[String]]() + val queueStream = ssc.queueStream(queue, oneAtATime = true) + val outputStream = new TestOutputStream(queueStream, outputBuffer) + outputStream.register() + ssc.start() + + // Setup data queued into the stream + val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] + + val inputIterator = input.toIterator + for (i <- 0 until input.size) { + // Enqueue more than 1 item per tick but they should dequeue one at a time + inputIterator.take(2).foreach(i => queue += ssc.sparkContext.makeRDD(Seq(i))) + clock.advance(batchDuration.milliseconds) + } + Thread.sleep(1000) } - Thread.sleep(1000) - logInfo("Stopping context") - ssc.stop() // Verify whether data received was as expected logInfo("--------------------------------") @@ -282,33 +282,33 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { } test("queue input stream - oneAtATime = false") { - // Set up the streaming context and input streams - val ssc = new StreamingContext(conf, batchDuration) - val queue = new SynchronizedQueue[RDD[String]]() - val queueStream = ssc.queueStream(queue, oneAtATime = false) val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String]] - val outputStream = new TestOutputStream(queueStream, outputBuffer) def output: ArrayBuffer[Seq[String]] = outputBuffer.filter(_.size > 0) - outputStream.register() - ssc.start() - - // Setup data queued into the stream - val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] val input = Seq("1", "2", "3", "4", "5") val expectedOutput = Seq(Seq("1", "2", "3"), Seq("4", "5")) - // Enqueue the first 3 items (one by one), they should be merged in the next batch - val inputIterator = input.toIterator - inputIterator.take(3).foreach(i => queue += ssc.sparkContext.makeRDD(Seq(i))) - clock.advance(batchDuration.milliseconds) - Thread.sleep(1000) - - // Enqueue the remaining items (again one by one), merged in the final batch - inputIterator.foreach(i => queue += ssc.sparkContext.makeRDD(Seq(i))) - clock.advance(batchDuration.milliseconds) - Thread.sleep(1000) - logInfo("Stopping context") - 
ssc.stop() + // Set up the streaming context and input streams + withStreamingContext(new StreamingContext(conf, batchDuration)) { ssc => + val queue = new SynchronizedQueue[RDD[String]]() + val queueStream = ssc.queueStream(queue, oneAtATime = false) + val outputStream = new TestOutputStream(queueStream, outputBuffer) + outputStream.register() + ssc.start() + + // Setup data queued into the stream + val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] + + // Enqueue the first 3 items (one by one), they should be merged in the next batch + val inputIterator = input.toIterator + inputIterator.take(3).foreach(i => queue += ssc.sparkContext.makeRDD(Seq(i))) + clock.advance(batchDuration.milliseconds) + Thread.sleep(1000) + + // Enqueue the remaining items (again one by one), merged in the final batch + inputIterator.foreach(i => queue += ssc.sparkContext.makeRDD(Seq(i))) + clock.advance(batchDuration.milliseconds) + Thread.sleep(1000) + } // Verify whether data received was as expected logInfo("--------------------------------") From 2213441e5e0fba01e05826257604aa427cdf2598 Mon Sep 17 00:00:00 2001 From: felixcheung Date: Thu, 3 Dec 2015 13:25:20 -0800 Subject: [PATCH 1653/1824] [SPARK-12019][SPARKR] Support character vector for sparkR.init(), check param and fix doc and add tests. Spark submit expects comma-separated list Author: felixcheung Closes #10034 from felixcheung/sparkrinitdoc. --- R/pkg/R/client.R | 10 ++++-- R/pkg/R/sparkR.R | 56 ++++++++++++++++++++++----------- R/pkg/R/utils.R | 5 +++ R/pkg/inst/tests/test_client.R | 9 ++++++ R/pkg/inst/tests/test_context.R | 20 ++++++++++++ 5 files changed, 79 insertions(+), 21 deletions(-) diff --git a/R/pkg/R/client.R b/R/pkg/R/client.R index c811d1dac3bd..25e99390a9c8 100644 --- a/R/pkg/R/client.R +++ b/R/pkg/R/client.R @@ -44,12 +44,16 @@ determineSparkSubmitBin <- function() { } generateSparkSubmitArgs <- function(args, sparkHome, jars, sparkSubmitOpts, packages) { + jars <- paste0(jars, collapse = ",") if (jars != "") { - jars <- paste("--jars", jars) + # construct the jars argument with a space between --jars and comma-separated values + jars <- paste0("--jars ", jars) } - if (!identical(packages, "")) { - packages <- paste("--packages", packages) + packages <- paste0(packages, collapse = ",") + if (packages != "") { + # construct the packages argument with a space between --packages and comma-separated values + packages <- paste0("--packages ", packages) } combinedArgs <- paste(jars, packages, sparkSubmitOpts, args, sep = " ") diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R index 7ff3fa628b9c..d2bfad553104 100644 --- a/R/pkg/R/sparkR.R +++ b/R/pkg/R/sparkR.R @@ -86,13 +86,13 @@ sparkR.stop <- function() { #' and use SparkR, refer to SparkR programming guide at #' \url{http://spark.apache.org/docs/latest/sparkr.html#starting-up-sparkcontext-sqlcontext}. #' -#' @param master The Spark master URL. +#' @param master The Spark master URL #' @param appName Application name to register with cluster manager #' @param sparkHome Spark Home directory -#' @param sparkEnvir Named list of environment variables to set on worker nodes. -#' @param sparkExecutorEnv Named list of environment variables to be used when launching executors. -#' @param sparkJars Character string vector of jar files to pass to the worker nodes. 
-#' @param sparkPackages Character string vector of packages from spark-packages.org +#' @param sparkEnvir Named list of environment variables to set on worker nodes +#' @param sparkExecutorEnv Named list of environment variables to be used when launching executors +#' @param sparkJars Character vector of jar files to pass to the worker nodes +#' @param sparkPackages Character vector of packages from spark-packages.org #' @export #' @examples #'\dontrun{ @@ -102,7 +102,9 @@ sparkR.stop <- function() { #' sc <- sparkR.init("yarn-client", "SparkR", "/home/spark", #' list(spark.executor.memory="4g"), #' list(LD_LIBRARY_PATH="/directory of JVM libraries (libjvm.so) on workers/"), -#' c("jarfile1.jar","jarfile2.jar")) +#' c("one.jar", "two.jar", "three.jar"), +#' c("com.databricks:spark-avro_2.10:2.0.1", +#' "com.databricks:spark-csv_2.10:1.3.0")) #'} sparkR.init <- function( @@ -120,15 +122,8 @@ sparkR.init <- function( return(get(".sparkRjsc", envir = .sparkREnv)) } - jars <- suppressWarnings(normalizePath(as.character(sparkJars))) - - # Classpath separator is ";" on Windows - # URI needs four /// as from http://stackoverflow.com/a/18522792 - if (.Platform$OS.type == "unix") { - uriSep <- "//" - } else { - uriSep <- "////" - } + jars <- processSparkJars(sparkJars) + packages <- processSparkPackages(sparkPackages) sparkEnvirMap <- convertNamedListToEnv(sparkEnvir) @@ -145,7 +140,7 @@ sparkR.init <- function( sparkHome = sparkHome, jars = jars, sparkSubmitOpts = submitOps, - packages = sparkPackages) + packages = packages) # wait atmost 100 seconds for JVM to launch wait <- 0.1 for (i in 1:25) { @@ -195,8 +190,14 @@ sparkR.init <- function( paste0("$LD_LIBRARY_PATH:",Sys.getenv("LD_LIBRARY_PATH")) } - nonEmptyJars <- Filter(function(x) { x != "" }, jars) - localJarPaths <- lapply(nonEmptyJars, + # Classpath separator is ";" on Windows + # URI needs four /// as from http://stackoverflow.com/a/18522792 + if (.Platform$OS.type == "unix") { + uriSep <- "//" + } else { + uriSep <- "////" + } + localJarPaths <- lapply(jars, function(j) { utils::URLencode(paste("file:", uriSep, j, sep = "")) }) # Set the start time to identify jobjs @@ -366,3 +367,22 @@ getClientModeSparkSubmitOpts <- function(submitOps, sparkEnvirMap) { # --option must be before the application class "sparkr-shell" in submitOps paste0(paste0(envirToOps, collapse = ""), submitOps) } + +# Utility function that handles sparkJars argument, and normalize paths +processSparkJars <- function(jars) { + splittedJars <- splitString(jars) + if (length(splittedJars) > length(jars)) { + warning("sparkJars as a comma-separated string is deprecated, use character vector instead") + } + normalized <- suppressWarnings(normalizePath(splittedJars)) + normalized +} + +# Utility function that handles sparkPackages argument +processSparkPackages <- function(packages) { + splittedPackages <- splitString(packages) + if (length(splittedPackages) > length(packages)) { + warning("sparkPackages as a comma-separated string is deprecated, use character vector instead") + } + splittedPackages +} diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R index 45c77a86c958..43105aaa3842 100644 --- a/R/pkg/R/utils.R +++ b/R/pkg/R/utils.R @@ -636,3 +636,8 @@ assignNewEnv <- function(data) { } env } + +# Utility function to split by ',' and whitespace, remove empty tokens +splitString <- function(input) { + Filter(nzchar, unlist(strsplit(input, ",|\\s"))) +} diff --git a/R/pkg/inst/tests/test_client.R b/R/pkg/inst/tests/test_client.R index 8a20991f89af..a0664f32f31c 100644 --- 
a/R/pkg/inst/tests/test_client.R +++ b/R/pkg/inst/tests/test_client.R @@ -34,3 +34,12 @@ test_that("no package specified doesn't add packages flag", { test_that("multiple packages don't produce a warning", { expect_that(generateSparkSubmitArgs("", "", "", "", c("A", "B")), not(gives_warning())) }) + +test_that("sparkJars sparkPackages as character vectors", { + args <- generateSparkSubmitArgs("", "", c("one.jar", "two.jar", "three.jar"), "", + c("com.databricks:spark-avro_2.10:2.0.1", + "com.databricks:spark-csv_2.10:1.3.0")) + expect_match(args, "--jars one.jar,two.jar,three.jar") + expect_match(args, + "--packages com.databricks:spark-avro_2.10:2.0.1,com.databricks:spark-csv_2.10:1.3.0") +}) diff --git a/R/pkg/inst/tests/test_context.R b/R/pkg/inst/tests/test_context.R index 80c1b89a4c62..1707e314beff 100644 --- a/R/pkg/inst/tests/test_context.R +++ b/R/pkg/inst/tests/test_context.R @@ -92,3 +92,23 @@ test_that("getClientModeSparkSubmitOpts() returns spark-submit args from whiteli " --driver-memory 4g sparkr-shell2")) # nolint end }) + +test_that("sparkJars sparkPackages as comma-separated strings", { + expect_warning(processSparkJars(" a, b ")) + jars <- suppressWarnings(processSparkJars(" a, b ")) + expect_equal(jars, c("a", "b")) + + jars <- suppressWarnings(processSparkJars(" abc ,, def ")) + expect_equal(jars, c("abc", "def")) + + jars <- suppressWarnings(processSparkJars(c(" abc ,, def ", "", "xyz", " ", "a,b"))) + expect_equal(jars, c("abc", "def", "xyz", "a", "b")) + + p <- processSparkPackages(c("ghi", "lmn")) + expect_equal(p, c("ghi", "lmn")) + + # check normalizePath + f <- dir()[[1]] + expect_that(processSparkJars(f), not(gives_warning())) + expect_match(processSparkJars(f), f) +}) From f434f36d508eb4dcade70871611fc022ae0feb56 Mon Sep 17 00:00:00 2001 From: Anderson de Andrade Date: Thu, 3 Dec 2015 16:37:00 -0800 Subject: [PATCH 1654/1824] [SPARK-12056][CORE] Create a TaskAttemptContext only after calling setConf. TaskAttemptContext's constructor will clone the configuration instead of referencing it. Calling setConf after creating a TaskAttemptContext therefore means that any changes made to the configuration inside setConf are invisible to RecordReader instances. As an example, Titan's InputFormat will change conf when calling setConf. They wrap their InputFormat around Cassandra's ColumnFamilyInputFormat, and append Cassandra's configuration. This change fixes the following error when using Titan's CassandraInputFormat with Spark: *java.lang.RuntimeException: org.apache.thrift.protocol.TProtocolException: Required field 'keyspace' was not present! Struct: set_keyspace_args(keyspace:null)* There's a discussion of this error here: https://groups.google.com/forum/#!topic/aureliusgraphs/4zpwyrYbGAE Author: Anderson de Andrade Closes #10046 from adeandrade/newhadooprdd-fix.
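The ordering constraint described above can be sketched in isolation. This is a hedged illustration only: newConfiguredAttemptContext is a hypothetical helper, and it uses Hadoop's public TaskAttemptContextImpl directly rather than the private Spark wrappers touched in the diff below. The point is simply that, per the description above, the attempt context takes its own copy of the configuration when constructed, so a Configurable input format must be allowed to enrich that configuration first.

    import org.apache.hadoop.conf.{Configurable, Configuration}
    import org.apache.hadoop.mapreduce.{InputFormat, TaskAttemptID}
    import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl

    // Hypothetical helper, not part of this patch: let the input format mutate `conf`
    // (e.g. Titan/Cassandra adding the keyspace) before the context snapshots it.
    def newConfiguredAttemptContext(
        format: InputFormat[_, _],
        conf: Configuration,
        attemptId: TaskAttemptID): TaskAttemptContextImpl = {
      format match {
        case configurable: Configurable => configurable.setConf(conf)
        case _ => // input format takes no extra configuration
      }
      // Constructed last: as the commit message notes, the context clones `conf`,
      // so anything added after this point would not be seen by the RecordReader.
      new TaskAttemptContextImpl(conf, attemptId)
    }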
--- core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index d1960990da0f..86f38ae836b2 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -138,14 +138,14 @@ class NewHadoopRDD[K, V]( } inputMetrics.setBytesReadCallback(bytesReadCallback) - val attemptId = newTaskAttemptID(jobTrackerId, id, isMap = true, split.index, 0) - val hadoopAttemptContext = newTaskAttemptContext(conf, attemptId) val format = inputFormatClass.newInstance format match { case configurable: Configurable => configurable.setConf(conf) case _ => } + val attemptId = newTaskAttemptID(jobTrackerId, id, isMap = true, split.index, 0) + val hadoopAttemptContext = newTaskAttemptContext(conf, attemptId) private var reader = format.createRecordReader( split.serializableHadoopSplit.value, hadoopAttemptContext) reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext) From b6e9963ee4bf0ffb62c8e9829a551bcdc31e12e3 Mon Sep 17 00:00:00 2001 From: Carson Wang Date: Thu, 3 Dec 2015 16:39:12 -0800 Subject: [PATCH 1655/1824] [SPARK-11206] Support SQL UI on the history server (resubmit) Resubmit #9297 and #9991 On the live web UI, there is a SQL tab which provides valuable information for the SQL query. But once the workload is finished, we won't see the SQL tab on the history server. It will be helpful if we support SQL UI on the history server so we can analyze it even after its execution. To support SQL UI on the history server: 1. I added an onOtherEvent method to the SparkListener trait and post all SQL related events to the same event bus. 2. Two SQL events SparkListenerSQLExecutionStart and SparkListenerSQLExecutionEnd are defined in the sql module. 3. The new SQL events are written to event log using Jackson. 4. A new trait SparkHistoryListenerFactory is added to allow the history server to feed events to the SQL history listener. The SQL implementation is loaded at runtime using java.util.ServiceLoader. Author: Carson Wang Closes #10061 from carsonwang/SqlHistoryUI. 
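Because the SQL events now travel over the shared listener bus through the SparkListener.onOtherEvent hook added here, any listener can observe them. Below is a minimal, hedged sketch (the SqlExecutionLogger class is hypothetical and not part of this patch) that reacts to the two event types defined further down in SQLListener.scala.

    import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent}
    import org.apache.spark.sql.execution.ui.{SparkListenerSQLExecutionEnd, SparkListenerSQLExecutionStart}

    // Hypothetical listener: logs the SQL execution events that this change posts
    // to the common event bus (and therefore also writes to the event log).
    class SqlExecutionLogger extends SparkListener {
      override def onOtherEvent(event: SparkListenerEvent): Unit = event match {
        case start: SparkListenerSQLExecutionStart =>
          println(s"SQL execution ${start.executionId} started: ${start.description}")
        case end: SparkListenerSQLExecutionEnd =>
          println(s"SQL execution ${end.executionId} finished at ${end.time}")
        case _ => // not a SQL event, ignore
      }
    }

    // Usage, assuming an active SparkContext named `sc`:
    //   sc.addSparkListener(new SqlExecutionLogger)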
--- .rat-excludes | 1 + .../org/apache/spark/JavaSparkListener.java | 3 + .../apache/spark/SparkFirehoseListener.java | 4 + .../scheduler/EventLoggingListener.scala | 4 + .../spark/scheduler/SparkListener.scala | 24 ++- .../spark/scheduler/SparkListenerBus.scala | 1 + .../scala/org/apache/spark/ui/SparkUI.scala | 16 +- .../org/apache/spark/util/JsonProtocol.scala | 11 +- ...park.scheduler.SparkHistoryListenerFactory | 1 + .../org/apache/spark/sql/SQLContext.scala | 19 ++- .../spark/sql/execution/SQLExecution.scala | 24 +-- .../spark/sql/execution/SparkPlanInfo.scala | 46 ++++++ .../sql/execution/metric/SQLMetricInfo.scala | 30 ++++ .../sql/execution/metric/SQLMetrics.scala | 56 ++++--- .../sql/execution/ui/ExecutionPage.scala | 4 +- .../spark/sql/execution/ui/SQLListener.scala | 139 ++++++++++++------ .../spark/sql/execution/ui/SQLTab.scala | 12 +- .../sql/execution/ui/SparkPlanGraph.scala | 20 +-- .../execution/metric/SQLMetricsSuite.scala | 4 +- .../sql/execution/ui/SQLListenerSuite.scala | 44 +++--- .../spark/sql/test/SharedSQLContext.scala | 1 + 21 files changed, 329 insertions(+), 135 deletions(-) create mode 100644 sql/core/src/main/resources/META-INF/services/org.apache.spark.scheduler.SparkHistoryListenerFactory create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetricInfo.scala diff --git a/.rat-excludes b/.rat-excludes index 08fba6d351d6..7262c960ed6b 100644 --- a/.rat-excludes +++ b/.rat-excludes @@ -82,4 +82,5 @@ INDEX gen-java.* .*avpr org.apache.spark.sql.sources.DataSourceRegister +org.apache.spark.scheduler.SparkHistoryListenerFactory .*parquet diff --git a/core/src/main/java/org/apache/spark/JavaSparkListener.java b/core/src/main/java/org/apache/spark/JavaSparkListener.java index fa9acf0a15b8..23bc9a2e8172 100644 --- a/core/src/main/java/org/apache/spark/JavaSparkListener.java +++ b/core/src/main/java/org/apache/spark/JavaSparkListener.java @@ -82,4 +82,7 @@ public void onExecutorRemoved(SparkListenerExecutorRemoved executorRemoved) { } @Override public void onBlockUpdated(SparkListenerBlockUpdated blockUpdated) { } + @Override + public void onOtherEvent(SparkListenerEvent event) { } + } diff --git a/core/src/main/java/org/apache/spark/SparkFirehoseListener.java b/core/src/main/java/org/apache/spark/SparkFirehoseListener.java index 1214d05ba606..e6b24afd88ad 100644 --- a/core/src/main/java/org/apache/spark/SparkFirehoseListener.java +++ b/core/src/main/java/org/apache/spark/SparkFirehoseListener.java @@ -118,4 +118,8 @@ public void onBlockUpdated(SparkListenerBlockUpdated blockUpdated) { onEvent(blockUpdated); } + @Override + public void onOtherEvent(SparkListenerEvent event) { + onEvent(event); + } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala index 000a021a528c..eaa07acc5132 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala @@ -207,6 +207,10 @@ private[spark] class EventLoggingListener( // No-op because logging every update would be overkill override def onExecutorMetricsUpdate(event: SparkListenerExecutorMetricsUpdate): Unit = { } + override def onOtherEvent(event: SparkListenerEvent): Unit = { + logEvent(event, flushLogger = true) + } + /** * Stop logging events. 
The event log file will be renamed so that it loses the * ".inprogress" suffix. diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala index 896f1743332f..075a7f13172d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala @@ -22,15 +22,19 @@ import java.util.Properties import scala.collection.Map import scala.collection.mutable -import org.apache.spark.{Logging, TaskEndReason} +import com.fasterxml.jackson.annotation.JsonTypeInfo + +import org.apache.spark.{Logging, SparkConf, TaskEndReason} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler.cluster.ExecutorInfo import org.apache.spark.storage.{BlockManagerId, BlockUpdatedInfo} import org.apache.spark.util.{Distribution, Utils} +import org.apache.spark.ui.SparkUI @DeveloperApi -sealed trait SparkListenerEvent +@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "Event") +trait SparkListenerEvent @DeveloperApi case class SparkListenerStageSubmitted(stageInfo: StageInfo, properties: Properties = null) @@ -130,6 +134,17 @@ case class SparkListenerApplicationEnd(time: Long) extends SparkListenerEvent */ private[spark] case class SparkListenerLogStart(sparkVersion: String) extends SparkListenerEvent +/** + * Interface for creating history listeners defined in other modules like SQL, which are used to + * rebuild the history UI. + */ +private[spark] trait SparkHistoryListenerFactory { + /** + * Create listeners used to rebuild the history UI. + */ + def createListeners(conf: SparkConf, sparkUI: SparkUI): Seq[SparkListener] +} + /** * :: DeveloperApi :: * Interface for listening to events from the Spark scheduler. Note that this is an internal @@ -223,6 +238,11 @@ trait SparkListener { * Called when the driver receives a block update info. */ def onBlockUpdated(blockUpdated: SparkListenerBlockUpdated) { } + + /** + * Called when other events like SQL-specific events are posted. 
+ */ + def onOtherEvent(event: SparkListenerEvent) { } } /** diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala index 04afde33f5aa..95722a07144e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala @@ -61,6 +61,7 @@ private[spark] trait SparkListenerBus extends ListenerBus[SparkListener, SparkLi case blockUpdated: SparkListenerBlockUpdated => listener.onBlockUpdated(blockUpdated) case logStart: SparkListenerLogStart => // ignore event log metadata + case _ => listener.onOtherEvent(event) } } diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala index 4608bce202ec..8da6884a3853 100644 --- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala @@ -17,10 +17,13 @@ package org.apache.spark.ui -import java.util.Date +import java.util.{Date, ServiceLoader} + +import scala.collection.JavaConverters._ import org.apache.spark.status.api.v1.{ApiRootResource, ApplicationAttemptInfo, ApplicationInfo, UIRoot} +import org.apache.spark.util.Utils import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkContext} import org.apache.spark.scheduler._ import org.apache.spark.storage.StorageStatusListener @@ -154,7 +157,16 @@ private[spark] object SparkUI { appName: String, basePath: String, startTime: Long): SparkUI = { - create(None, conf, listenerBus, securityManager, appName, basePath, startTime = startTime) + val sparkUI = create( + None, conf, listenerBus, securityManager, appName, basePath, startTime = startTime) + + val listenerFactories = ServiceLoader.load(classOf[SparkHistoryListenerFactory], + Utils.getContextOrSparkClassLoader).asScala + listenerFactories.foreach { listenerFactory => + val listeners = listenerFactory.createListeners(conf, sparkUI) + listeners.foreach(listenerBus.addListener) + } + sparkUI } /** diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index 2d2bd90eb339..cb0f1bf79f3d 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -19,19 +19,21 @@ package org.apache.spark.util import java.util.{Properties, UUID} -import org.apache.spark.scheduler.cluster.ExecutorInfo - import scala.collection.JavaConverters._ import scala.collection.Map +import com.fasterxml.jackson.databind.ObjectMapper +import com.fasterxml.jackson.module.scala.DefaultScalaModule import org.json4s.DefaultFormats import org.json4s.JsonDSL._ import org.json4s.JsonAST._ +import org.json4s.jackson.JsonMethods._ import org.apache.spark._ import org.apache.spark.executor._ import org.apache.spark.rdd.RDDOperationScope import org.apache.spark.scheduler._ +import org.apache.spark.scheduler.cluster.ExecutorInfo import org.apache.spark.storage._ /** @@ -54,6 +56,8 @@ private[spark] object JsonProtocol { private implicit val format = DefaultFormats + private val mapper = new ObjectMapper().registerModule(DefaultScalaModule) + /** ------------------------------------------------- * * JSON serialization methods for SparkListenerEvents | * -------------------------------------------------- */ @@ -96,6 +100,7 @@ private[spark] object JsonProtocol { executorMetricsUpdateToJson(metricsUpdate) case blockUpdated: 
SparkListenerBlockUpdated => throw new MatchError(blockUpdated) // TODO(ekl) implement this + case _ => parse(mapper.writeValueAsString(event)) } } @@ -506,6 +511,8 @@ private[spark] object JsonProtocol { case `executorRemoved` => executorRemovedFromJson(json) case `logStart` => logStartFromJson(json) case `metricsUpdate` => executorMetricsUpdateFromJson(json) + case other => mapper.readValue(compact(render(json)), Utils.classForName(other)) + .asInstanceOf[SparkListenerEvent] } } diff --git a/sql/core/src/main/resources/META-INF/services/org.apache.spark.scheduler.SparkHistoryListenerFactory b/sql/core/src/main/resources/META-INF/services/org.apache.spark.scheduler.SparkHistoryListenerFactory new file mode 100644 index 000000000000..507100be9096 --- /dev/null +++ b/sql/core/src/main/resources/META-INF/services/org.apache.spark.scheduler.SparkHistoryListenerFactory @@ -0,0 +1 @@ +org.apache.spark.sql.execution.ui.SQLHistoryListenerFactory diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 4e2625086837..db286ea8700b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -1245,6 +1245,7 @@ class SQLContext private[sql]( sparkContext.addSparkListener(new SparkListener { override def onApplicationEnd(applicationEnd: SparkListenerApplicationEnd): Unit = { SQLContext.clearInstantiatedContext() + SQLContext.clearSqlListener() } }) @@ -1272,6 +1273,8 @@ object SQLContext { */ @transient private val instantiatedContext = new AtomicReference[SQLContext]() + @transient private val sqlListener = new AtomicReference[SQLListener]() + /** * Get the singleton SQLContext if it exists or create a new one using the given SparkContext. * @@ -1316,6 +1319,10 @@ object SQLContext { Option(instantiatedContext.get()) } + private[sql] def clearSqlListener(): Unit = { + sqlListener.set(null) + } + /** * Changes the SQLContext that will be returned in this thread and its children when * SQLContext.getOrCreate() is called. This can be used to ensure that a given thread receives @@ -1364,9 +1371,13 @@ object SQLContext { * Create a SQLListener then add it into SparkContext, and create an SQLTab if there is SparkUI. 
*/ private[sql] def createListenerAndUI(sc: SparkContext): SQLListener = { - val listener = new SQLListener(sc.conf) - sc.addSparkListener(listener) - sc.ui.foreach(new SQLTab(listener, _)) - listener + if (sqlListener.get() == null) { + val listener = new SQLListener(sc.conf) + if (sqlListener.compareAndSet(null, listener)) { + sc.addSparkListener(listener) + sc.ui.foreach(new SQLTab(listener, _)) + } + } + sqlListener.get() } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala index 1422e15549c9..34971986261c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala @@ -21,7 +21,8 @@ import java.util.concurrent.atomic.AtomicLong import org.apache.spark.SparkContext import org.apache.spark.sql.SQLContext -import org.apache.spark.sql.execution.ui.SparkPlanGraph +import org.apache.spark.sql.execution.ui.{SparkListenerSQLExecutionStart, + SparkListenerSQLExecutionEnd} import org.apache.spark.util.Utils private[sql] object SQLExecution { @@ -45,25 +46,14 @@ private[sql] object SQLExecution { sc.setLocalProperty(EXECUTION_ID_KEY, executionId.toString) val r = try { val callSite = Utils.getCallSite() - sqlContext.listener.onExecutionStart( - executionId, - callSite.shortForm, - callSite.longForm, - queryExecution.toString, - SparkPlanGraph(queryExecution.executedPlan), - System.currentTimeMillis()) + sqlContext.sparkContext.listenerBus.post(SparkListenerSQLExecutionStart( + executionId, callSite.shortForm, callSite.longForm, queryExecution.toString, + SparkPlanInfo.fromSparkPlan(queryExecution.executedPlan), System.currentTimeMillis())) try { body } finally { - // Ideally, we need to make sure onExecutionEnd happens after onJobStart and onJobEnd. - // However, onJobStart and onJobEnd run in the listener thread. Because we cannot add new - // SQL event types to SparkListener since it's a public API, we cannot guarantee that. - // - // SQLListener should handle the case that onExecutionEnd happens before onJobEnd. - // - // The worst case is onExecutionEnd may happen before onJobStart when the listener thread - // is very busy. If so, we cannot track the jobs for the execution. It seems acceptable. - sqlContext.listener.onExecutionEnd(executionId, System.currentTimeMillis()) + sqlContext.sparkContext.listenerBus.post(SparkListenerSQLExecutionEnd( + executionId, System.currentTimeMillis())) } } finally { sc.setLocalProperty(EXECUTION_ID_KEY, null) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala new file mode 100644 index 000000000000..486ce34064e4 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.execution.metric.SQLMetricInfo +import org.apache.spark.util.Utils + +/** + * :: DeveloperApi :: + * Stores information about a SQL SparkPlan. + */ +@DeveloperApi +class SparkPlanInfo( + val nodeName: String, + val simpleString: String, + val children: Seq[SparkPlanInfo], + val metrics: Seq[SQLMetricInfo]) + +private[sql] object SparkPlanInfo { + + def fromSparkPlan(plan: SparkPlan): SparkPlanInfo = { + val metrics = plan.metrics.toSeq.map { case (key, metric) => + new SQLMetricInfo(metric.name.getOrElse(key), metric.id, + Utils.getFormattedClassName(metric.param)) + } + val children = plan.children.map(fromSparkPlan) + + new SparkPlanInfo(plan.nodeName, plan.simpleString, children, metrics) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetricInfo.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetricInfo.scala new file mode 100644 index 000000000000..2708219ad348 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetricInfo.scala @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.metric + +import org.apache.spark.annotation.DeveloperApi + +/** + * :: DeveloperApi :: + * Stores information about a SQL Metric. + */ +@DeveloperApi +class SQLMetricInfo( + val name: String, + val accumulatorId: Long, + val metricParam: String) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala index 1c253e3942e9..6c0f6f8a52dc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala @@ -104,21 +104,39 @@ private class LongSQLMetricParam(val stringValue: Seq[Long] => String, initialVa override def zero: LongSQLMetricValue = new LongSQLMetricValue(initialValue) } +private object LongSQLMetricParam extends LongSQLMetricParam(_.sum.toString, 0L) + +private object StaticsLongSQLMetricParam extends LongSQLMetricParam( + (values: Seq[Long]) => { + // This is a workaround for SPARK-11013. 
+ // We use -1 as initial value of the accumulator, if the accumulator is valid, we will update + // it at the end of task and the value will be at least 0. + val validValues = values.filter(_ >= 0) + val Seq(sum, min, med, max) = { + val metric = if (validValues.length == 0) { + Seq.fill(4)(0L) + } else { + val sorted = validValues.sorted + Seq(sorted.sum, sorted(0), sorted(validValues.length / 2), sorted(validValues.length - 1)) + } + metric.map(Utils.bytesToString) + } + s"\n$sum ($min, $med, $max)" + }, -1L) + private[sql] object SQLMetrics { private def createLongMetric( sc: SparkContext, name: String, - stringValue: Seq[Long] => String, - initialValue: Long): LongSQLMetric = { - val param = new LongSQLMetricParam(stringValue, initialValue) + param: LongSQLMetricParam): LongSQLMetric = { val acc = new LongSQLMetric(name, param) sc.cleaner.foreach(_.registerAccumulatorForCleanup(acc)) acc } def createLongMetric(sc: SparkContext, name: String): LongSQLMetric = { - createLongMetric(sc, name, _.sum.toString, 0L) + createLongMetric(sc, name, LongSQLMetricParam) } /** @@ -126,31 +144,25 @@ private[sql] object SQLMetrics { * spill size, etc. */ def createSizeMetric(sc: SparkContext, name: String): LongSQLMetric = { - val stringValue = (values: Seq[Long]) => { - // This is a workaround for SPARK-11013. - // We use -1 as initial value of the accumulator, if the accumulator is valid, we will update - // it at the end of task and the value will be at least 0. - val validValues = values.filter(_ >= 0) - val Seq(sum, min, med, max) = { - val metric = if (validValues.length == 0) { - Seq.fill(4)(0L) - } else { - val sorted = validValues.sorted - Seq(sorted.sum, sorted(0), sorted(validValues.length / 2), sorted(validValues.length - 1)) - } - metric.map(Utils.bytesToString) - } - s"\n$sum ($min, $med, $max)" - } // The final result of this metric in physical operator UI may looks like: // data size total (min, med, max): // 100GB (100MB, 1GB, 10GB) - createLongMetric(sc, s"$name total (min, med, max)", stringValue, -1L) + createLongMetric(sc, s"$name total (min, med, max)", StaticsLongSQLMetricParam) + } + + def getMetricParam(metricParamName: String): SQLMetricParam[SQLMetricValue[Any], Any] = { + val longSQLMetricParam = Utils.getFormattedClassName(LongSQLMetricParam) + val staticsSQLMetricParam = Utils.getFormattedClassName(StaticsLongSQLMetricParam) + val metricParam = metricParamName match { + case `longSQLMetricParam` => LongSQLMetricParam + case `staticsSQLMetricParam` => StaticsLongSQLMetricParam + } + metricParam.asInstanceOf[SQLMetricParam[SQLMetricValue[Any], Any]] } /** * A metric that its value will be ignored. Use this one when we need a metric parameter but don't * care about the value. 
*/ - val nullLongMetric = new LongSQLMetric("null", new LongSQLMetricParam(_.sum.toString, 0L)) + val nullLongMetric = new LongSQLMetric("null", LongSQLMetricParam) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala index e74d6fb396e1..c74ad4040699 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala @@ -19,9 +19,7 @@ package org.apache.spark.sql.execution.ui import javax.servlet.http.HttpServletRequest -import scala.xml.{Node, Unparsed} - -import org.apache.commons.lang3.StringEscapeUtils +import scala.xml.Node import org.apache.spark.Logging import org.apache.spark.ui.{UIUtils, WebUIPage} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala index 5a072de400b6..e19a1e3e5851 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala @@ -19,11 +19,34 @@ package org.apache.spark.sql.execution.ui import scala.collection.mutable -import org.apache.spark.executor.TaskMetrics +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.scheduler._ import org.apache.spark.sql.execution.SQLExecution -import org.apache.spark.sql.execution.metric.{SQLMetricParam, SQLMetricValue} +import org.apache.spark.sql.execution.SparkPlanInfo +import org.apache.spark.sql.execution.metric.{LongSQLMetricValue, SQLMetricValue, SQLMetricParam} import org.apache.spark.{JobExecutionStatus, Logging, SparkConf} +import org.apache.spark.ui.SparkUI + +@DeveloperApi +case class SparkListenerSQLExecutionStart( + executionId: Long, + description: String, + details: String, + physicalPlanDescription: String, + sparkPlanInfo: SparkPlanInfo, + time: Long) + extends SparkListenerEvent + +@DeveloperApi +case class SparkListenerSQLExecutionEnd(executionId: Long, time: Long) + extends SparkListenerEvent + +private[sql] class SQLHistoryListenerFactory extends SparkHistoryListenerFactory { + + override def createListeners(conf: SparkConf, sparkUI: SparkUI): Seq[SparkListener] = { + List(new SQLHistoryListener(conf, sparkUI)) + } +} private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Logging { @@ -118,7 +141,8 @@ private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Loggi override def onExecutorMetricsUpdate( executorMetricsUpdate: SparkListenerExecutorMetricsUpdate): Unit = synchronized { for ((taskId, stageId, stageAttemptID, metrics) <- executorMetricsUpdate.taskMetrics) { - updateTaskAccumulatorValues(taskId, stageId, stageAttemptID, metrics, finishTask = false) + updateTaskAccumulatorValues(taskId, stageId, stageAttemptID, metrics.accumulatorUpdates(), + finishTask = false) } } @@ -140,7 +164,7 @@ private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Loggi taskEnd.taskInfo.taskId, taskEnd.stageId, taskEnd.stageAttemptId, - taskEnd.taskMetrics, + taskEnd.taskMetrics.accumulatorUpdates(), finishTask = true) } @@ -148,15 +172,12 @@ private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Loggi * Update the accumulator values of a task with the latest metrics for this task. This is called * every time we receive an executor heartbeat or when a task finishes. 
*/ - private def updateTaskAccumulatorValues( + protected def updateTaskAccumulatorValues( taskId: Long, stageId: Int, stageAttemptID: Int, - metrics: TaskMetrics, + accumulatorUpdates: Map[Long, Any], finishTask: Boolean): Unit = { - if (metrics == null) { - return - } _stageIdToStageMetrics.get(stageId) match { case Some(stageMetrics) => @@ -174,9 +195,9 @@ private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Loggi case Some(taskMetrics) => if (finishTask) { taskMetrics.finished = true - taskMetrics.accumulatorUpdates = metrics.accumulatorUpdates() + taskMetrics.accumulatorUpdates = accumulatorUpdates } else if (!taskMetrics.finished) { - taskMetrics.accumulatorUpdates = metrics.accumulatorUpdates() + taskMetrics.accumulatorUpdates = accumulatorUpdates } else { // If a task is finished, we should not override with accumulator updates from // heartbeat reports @@ -185,7 +206,7 @@ private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Loggi // TODO Now just set attemptId to 0. Should fix here when we can get the attempt // id from SparkListenerExecutorMetricsUpdate stageMetrics.taskIdToMetricUpdates(taskId) = new SQLTaskMetrics( - attemptId = 0, finished = finishTask, metrics.accumulatorUpdates()) + attemptId = 0, finished = finishTask, accumulatorUpdates) } } case None => @@ -193,38 +214,40 @@ private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Loggi } } - def onExecutionStart( - executionId: Long, - description: String, - details: String, - physicalPlanDescription: String, - physicalPlanGraph: SparkPlanGraph, - time: Long): Unit = { - val sqlPlanMetrics = physicalPlanGraph.nodes.flatMap { node => - node.metrics.map(metric => metric.accumulatorId -> metric) - } - - val executionUIData = new SQLExecutionUIData(executionId, description, details, - physicalPlanDescription, physicalPlanGraph, sqlPlanMetrics.toMap, time) - synchronized { - activeExecutions(executionId) = executionUIData - _executionIdToData(executionId) = executionUIData - } - } - - def onExecutionEnd(executionId: Long, time: Long): Unit = synchronized { - _executionIdToData.get(executionId).foreach { executionUIData => - executionUIData.completionTime = Some(time) - if (!executionUIData.hasRunningJobs) { - // onExecutionEnd happens after all "onJobEnd"s - // So we should update the execution lists. - markExecutionFinished(executionId) - } else { - // There are some running jobs, onExecutionEnd happens before some "onJobEnd"s. - // Then we don't if the execution is successful, so let the last onJobEnd updates the - // execution lists. 
+ override def onOtherEvent(event: SparkListenerEvent): Unit = event match { + case SparkListenerSQLExecutionStart(executionId, description, details, + physicalPlanDescription, sparkPlanInfo, time) => + val physicalPlanGraph = SparkPlanGraph(sparkPlanInfo) + val sqlPlanMetrics = physicalPlanGraph.nodes.flatMap { node => + node.metrics.map(metric => metric.accumulatorId -> metric) + } + val executionUIData = new SQLExecutionUIData( + executionId, + description, + details, + physicalPlanDescription, + physicalPlanGraph, + sqlPlanMetrics.toMap, + time) + synchronized { + activeExecutions(executionId) = executionUIData + _executionIdToData(executionId) = executionUIData + } + case SparkListenerSQLExecutionEnd(executionId, time) => synchronized { + _executionIdToData.get(executionId).foreach { executionUIData => + executionUIData.completionTime = Some(time) + if (!executionUIData.hasRunningJobs) { + // onExecutionEnd happens after all "onJobEnd"s + // So we should update the execution lists. + markExecutionFinished(executionId) + } else { + // There are some running jobs, onExecutionEnd happens before some "onJobEnd"s. + // Then we don't if the execution is successful, so let the last onJobEnd updates the + // execution lists. + } } } + case _ => // Ignore } private def markExecutionFinished(executionId: Long): Unit = { @@ -289,6 +312,38 @@ private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Loggi } +private[spark] class SQLHistoryListener(conf: SparkConf, sparkUI: SparkUI) + extends SQLListener(conf) { + + private var sqlTabAttached = false + + override def onExecutorMetricsUpdate( + executorMetricsUpdate: SparkListenerExecutorMetricsUpdate): Unit = synchronized { + // Do nothing + } + + override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = synchronized { + updateTaskAccumulatorValues( + taskEnd.taskInfo.taskId, + taskEnd.stageId, + taskEnd.stageAttemptId, + taskEnd.taskInfo.accumulables.map { acc => + (acc.id, new LongSQLMetricValue(acc.update.getOrElse("0").toLong)) + }.toMap, + finishTask = true) + } + + override def onOtherEvent(event: SparkListenerEvent): Unit = event match { + case _: SparkListenerSQLExecutionStart => + if (!sqlTabAttached) { + new SQLTab(this, sparkUI) + sqlTabAttached = true + } + super.onOtherEvent(event) + case _ => super.onOtherEvent(event) + } +} + /** * Represent all necessary data for an execution that will be used in Web UI. 
*/ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLTab.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLTab.scala index 9c27944d42fc..4f50b2ecdc8f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLTab.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLTab.scala @@ -17,13 +17,11 @@ package org.apache.spark.sql.execution.ui -import java.util.concurrent.atomic.AtomicInteger - import org.apache.spark.Logging import org.apache.spark.ui.{SparkUI, SparkUITab} private[sql] class SQLTab(val listener: SQLListener, sparkUI: SparkUI) - extends SparkUITab(sparkUI, SQLTab.nextTabName) with Logging { + extends SparkUITab(sparkUI, "SQL") with Logging { val parent = sparkUI @@ -35,13 +33,5 @@ private[sql] class SQLTab(val listener: SQLListener, sparkUI: SparkUI) } private[sql] object SQLTab { - private val STATIC_RESOURCE_DIR = "org/apache/spark/sql/execution/ui/static" - - private val nextTabId = new AtomicInteger(0) - - private def nextTabName: String = { - val nextId = nextTabId.getAndIncrement() - if (nextId == 0) "SQL" else s"SQL$nextId" - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala index f1fce5478a3f..7af0ff09c5c6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala @@ -21,8 +21,8 @@ import java.util.concurrent.atomic.AtomicLong import scala.collection.mutable -import org.apache.spark.sql.execution.SparkPlan -import org.apache.spark.sql.execution.metric.{SQLMetricParam, SQLMetricValue} +import org.apache.spark.sql.execution.SparkPlanInfo +import org.apache.spark.sql.execution.metric.SQLMetrics /** * A graph used for storing information of an executionPlan of DataFrame. @@ -48,27 +48,27 @@ private[sql] object SparkPlanGraph { /** * Build a SparkPlanGraph from the root of a SparkPlan tree. 
*/ - def apply(plan: SparkPlan): SparkPlanGraph = { + def apply(planInfo: SparkPlanInfo): SparkPlanGraph = { val nodeIdGenerator = new AtomicLong(0) val nodes = mutable.ArrayBuffer[SparkPlanGraphNode]() val edges = mutable.ArrayBuffer[SparkPlanGraphEdge]() - buildSparkPlanGraphNode(plan, nodeIdGenerator, nodes, edges) + buildSparkPlanGraphNode(planInfo, nodeIdGenerator, nodes, edges) new SparkPlanGraph(nodes, edges) } private def buildSparkPlanGraphNode( - plan: SparkPlan, + planInfo: SparkPlanInfo, nodeIdGenerator: AtomicLong, nodes: mutable.ArrayBuffer[SparkPlanGraphNode], edges: mutable.ArrayBuffer[SparkPlanGraphEdge]): SparkPlanGraphNode = { - val metrics = plan.metrics.toSeq.map { case (key, metric) => - SQLPlanMetric(metric.name.getOrElse(key), metric.id, - metric.param.asInstanceOf[SQLMetricParam[SQLMetricValue[Any], Any]]) + val metrics = planInfo.metrics.map { metric => + SQLPlanMetric(metric.name, metric.accumulatorId, + SQLMetrics.getMetricParam(metric.metricParam)) } val node = SparkPlanGraphNode( - nodeIdGenerator.getAndIncrement(), plan.nodeName, plan.simpleString, metrics) + nodeIdGenerator.getAndIncrement(), planInfo.nodeName, planInfo.simpleString, metrics) nodes += node - val childrenNodes = plan.children.map( + val childrenNodes = planInfo.children.map( child => buildSparkPlanGraphNode(child, nodeIdGenerator, nodes, edges)) for (child <- childrenNodes) { edges += SparkPlanGraphEdge(child.id, node.id) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index 82867ab4967b..4f2cad19bfb6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -26,6 +26,7 @@ import org.apache.xbean.asm5.Opcodes._ import org.apache.spark.SparkFunSuite import org.apache.spark.sql._ +import org.apache.spark.sql.execution.SparkPlanInfo import org.apache.spark.sql.execution.ui.SparkPlanGraph import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext @@ -82,7 +83,8 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { if (jobs.size == expectedNumOfJobs) { // If we can track all jobs, check the metric values val metricValues = sqlContext.listener.getExecutionMetrics(executionId) - val actualMetrics = SparkPlanGraph(df.queryExecution.executedPlan).nodes.filter { node => + val actualMetrics = SparkPlanGraph(SparkPlanInfo.fromSparkPlan( + df.queryExecution.executedPlan)).nodes.filter { node => expectedMetrics.contains(node.id) }.map { node => val nodeMetrics = node.metrics.map { metric => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala index c15aac775096..12a4e1356fed 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala @@ -21,10 +21,10 @@ import java.util.Properties import org.apache.spark.{SparkException, SparkContext, SparkConf, SparkFunSuite} import org.apache.spark.executor.TaskMetrics -import org.apache.spark.sql.execution.metric.LongSQLMetricValue import org.apache.spark.scheduler._ import org.apache.spark.sql.{DataFrame, SQLContext} -import org.apache.spark.sql.execution.SQLExecution +import 
org.apache.spark.sql.execution.{SparkPlanInfo, SQLExecution} +import org.apache.spark.sql.execution.metric.LongSQLMetricValue import org.apache.spark.sql.test.SharedSQLContext class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { @@ -82,7 +82,8 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { val executionId = 0 val df = createTestDataFrame val accumulatorIds = - SparkPlanGraph(df.queryExecution.executedPlan).nodes.flatMap(_.metrics.map(_.accumulatorId)) + SparkPlanGraph(SparkPlanInfo.fromSparkPlan(df.queryExecution.executedPlan)) + .nodes.flatMap(_.metrics.map(_.accumulatorId)) // Assume all accumulators are long var accumulatorValue = 0L val accumulatorUpdates = accumulatorIds.map { id => @@ -90,13 +91,13 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { (id, accumulatorValue) }.toMap - listener.onExecutionStart( + listener.onOtherEvent(SparkListenerSQLExecutionStart( executionId, "test", "test", df.queryExecution.toString, - SparkPlanGraph(df.queryExecution.executedPlan), - System.currentTimeMillis()) + SparkPlanInfo.fromSparkPlan(df.queryExecution.executedPlan), + System.currentTimeMillis())) val executionUIData = listener.executionIdToData(0) @@ -206,7 +207,8 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { time = System.currentTimeMillis(), JobSucceeded )) - listener.onExecutionEnd(executionId, System.currentTimeMillis()) + listener.onOtherEvent(SparkListenerSQLExecutionEnd( + executionId, System.currentTimeMillis())) assert(executionUIData.runningJobs.isEmpty) assert(executionUIData.succeededJobs === Seq(0)) @@ -219,19 +221,20 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { val listener = new SQLListener(sqlContext.sparkContext.conf) val executionId = 0 val df = createTestDataFrame - listener.onExecutionStart( + listener.onOtherEvent(SparkListenerSQLExecutionStart( executionId, "test", "test", df.queryExecution.toString, - SparkPlanGraph(df.queryExecution.executedPlan), - System.currentTimeMillis()) + SparkPlanInfo.fromSparkPlan(df.queryExecution.executedPlan), + System.currentTimeMillis())) listener.onJobStart(SparkListenerJobStart( jobId = 0, time = System.currentTimeMillis(), stageInfos = Nil, createProperties(executionId))) - listener.onExecutionEnd(executionId, System.currentTimeMillis()) + listener.onOtherEvent(SparkListenerSQLExecutionEnd( + executionId, System.currentTimeMillis())) listener.onJobEnd(SparkListenerJobEnd( jobId = 0, time = System.currentTimeMillis(), @@ -248,13 +251,13 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { val listener = new SQLListener(sqlContext.sparkContext.conf) val executionId = 0 val df = createTestDataFrame - listener.onExecutionStart( + listener.onOtherEvent(SparkListenerSQLExecutionStart( executionId, "test", "test", df.queryExecution.toString, - SparkPlanGraph(df.queryExecution.executedPlan), - System.currentTimeMillis()) + SparkPlanInfo.fromSparkPlan(df.queryExecution.executedPlan), + System.currentTimeMillis())) listener.onJobStart(SparkListenerJobStart( jobId = 0, time = System.currentTimeMillis(), @@ -271,7 +274,8 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { time = System.currentTimeMillis(), stageInfos = Nil, createProperties(executionId))) - listener.onExecutionEnd(executionId, System.currentTimeMillis()) + listener.onOtherEvent(SparkListenerSQLExecutionEnd( + executionId, System.currentTimeMillis())) listener.onJobEnd(SparkListenerJobEnd( jobId = 1, time = 
System.currentTimeMillis(), @@ -288,19 +292,20 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { val listener = new SQLListener(sqlContext.sparkContext.conf) val executionId = 0 val df = createTestDataFrame - listener.onExecutionStart( + listener.onOtherEvent(SparkListenerSQLExecutionStart( executionId, "test", "test", df.queryExecution.toString, - SparkPlanGraph(df.queryExecution.executedPlan), - System.currentTimeMillis()) + SparkPlanInfo.fromSparkPlan(df.queryExecution.executedPlan), + System.currentTimeMillis())) listener.onJobStart(SparkListenerJobStart( jobId = 0, time = System.currentTimeMillis(), stageInfos = Seq.empty, createProperties(executionId))) - listener.onExecutionEnd(executionId, System.currentTimeMillis()) + listener.onOtherEvent(SparkListenerSQLExecutionEnd( + executionId, System.currentTimeMillis())) listener.onJobEnd(SparkListenerJobEnd( jobId = 0, time = System.currentTimeMillis(), @@ -338,6 +343,7 @@ class SQLListenerMemoryLeakSuite extends SparkFunSuite { .set("spark.sql.ui.retainedExecutions", "50") // Set it to 50 to run this test quickly val sc = new SparkContext(conf) try { + SQLContext.clearSqlListener() val sqlContext = new SQLContext(sc) import sqlContext.implicits._ // Run 100 successful executions and 100 failed executions. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala index 963d10eed62e..e7b376548787 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala @@ -42,6 +42,7 @@ trait SharedSQLContext extends SQLTestUtils { * Initialize the [[TestSQLContext]]. */ protected override def beforeAll(): Unit = { + SQLContext.clearSqlListener() if (_ctx == null) { _ctx = new TestSQLContext } From 5011f264fb53705c528250bd055acbc2eca2baaa Mon Sep 17 00:00:00 2001 From: Sun Rui Date: Thu, 3 Dec 2015 21:11:10 -0800 Subject: [PATCH 1656/1824] [SPARK-12104][SPARKR] collect() does not handle multiple columns with same name. Author: Sun Rui Closes #10118 from sun-rui/SPARK-12104. --- R/pkg/R/DataFrame.R | 8 ++++---- R/pkg/inst/tests/test_sparkSQL.R | 6 ++++++ 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index a82ded9c51fa..81b4e6b91d8a 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -822,21 +822,21 @@ setMethod("collect", # Get a column of complex type returns a list. # Get a cell from a column of complex type returns a list instead of a vector. col <- listCols[[colIndex]] - colName <- dtypes[[colIndex]][[1]] if (length(col) <= 0) { - df[[colName]] <- col + df[[colIndex]] <- col } else { colType <- dtypes[[colIndex]][[2]] # Note that "binary" columns behave like complex types. 
if (!is.null(PRIMITIVE_TYPES[[colType]]) && colType != "binary") { vec <- do.call(c, col) stopifnot(class(vec) != "list") - df[[colName]] <- vec + df[[colIndex]] <- vec } else { - df[[colName]] <- col + df[[colIndex]] <- col } } } + names(df) <- names(x) df } }) diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index 92ec82096c6d..1e7cb5409970 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -530,6 +530,11 @@ test_that("collect() returns a data.frame", { expect_equal(names(rdf)[1], "age") expect_equal(nrow(rdf), 0) expect_equal(ncol(rdf), 2) + + # collect() correctly handles multiple columns with same name + df <- createDataFrame(sqlContext, list(list(1, 2)), schema = c("name", "name")) + ldf <- collect(df) + expect_equal(names(ldf), c("name", "name")) }) test_that("limit() returns DataFrame with the correct number of rows", { @@ -1197,6 +1202,7 @@ test_that("join() and merge() on a DataFrame", { joined <- join(df, df2) expect_equal(names(joined), c("age", "name", "name", "test")) expect_equal(count(joined), 12) + expect_equal(names(collect(joined)), c("age", "name", "name", "test")) joined2 <- join(df, df2, df$name == df2$name) expect_equal(names(joined2), c("age", "name", "name", "test")) From 4106d80fb6a16713a6cd2f15ab9d60f2527d9be5 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Fri, 4 Dec 2015 01:42:29 -0800 Subject: [PATCH 1657/1824] [SPARK-12122][STREAMING] Prevent batches from being submitted twice after recovering StreamingContext from checkpoint Author: Tathagata Das Closes #10127 from tdas/SPARK-12122. --- .../org/apache/spark/streaming/scheduler/JobGenerator.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala index 2de035d166e7..8dfdc1f57b40 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala @@ -220,7 +220,8 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { logInfo("Batches pending processing (" + pendingTimes.size + " batches): " + pendingTimes.mkString(", ")) // Reschedule jobs for these times - val timesToReschedule = (pendingTimes ++ downTimes).distinct.sorted(Time.ordering) + val timesToReschedule = (pendingTimes ++ downTimes).filter { _ < restartTime } + .distinct.sorted(Time.ordering) logInfo("Batches to reschedule (" + timesToReschedule.size + " batches): " + timesToReschedule.mkString(", ")) timesToReschedule.foreach { time => From 17e4e021ae7fdf5e4dd05a0473faa529e3e80dbb Mon Sep 17 00:00:00 2001 From: kaklakariada Date: Fri, 4 Dec 2015 14:43:16 +0000 Subject: [PATCH 1658/1824] Add links howto to setup IDEs for developing spark These links make it easier for new developers to work with Spark in their IDE. Author: kaklakariada Closes #10104 from kaklakariada/readme-developing-ide-gettting-started. --- README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/README.md b/README.md index c0d6a946035a..d5804d1a20b4 100644 --- a/README.md +++ b/README.md @@ -27,6 +27,8 @@ To build Spark and its example programs, run: (You do not need to do this if you downloaded a pre-built package.) More detailed documentation is available from the project site, at ["Building Spark"](http://spark.apache.org/docs/latest/building-spark.html). 
+For developing Spark using an IDE, see [Eclipse](https://cwiki.apache.org/confluence/display/SPARK/Useful+Developer+Tools#UsefulDeveloperTools-Eclipse) +and [IntelliJ](https://cwiki.apache.org/confluence/display/SPARK/Useful+Developer+Tools#UsefulDeveloperTools-IntelliJ). ## Interactive Scala Shell From 95296d9b1ad1d9e9396d7dfd0015ef27ce1cf341 Mon Sep 17 00:00:00 2001 From: Nong Date: Fri, 4 Dec 2015 10:01:20 -0800 Subject: [PATCH 1659/1824] [SPARK-12089] [SQL] Fix memory corrupt due to freeing a page being referenced When the spillable sort iterator was spilled, it was mistakenly keeping the last page in memory rather than the current page. This causes the current record to get corrupted. Author: Nong Closes #10142 from nongli/spark-12089. --- .../util/collection/unsafe/sort/UnsafeExternalSorter.java | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java index 5a97f4f11340..79d74b23ceae 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -443,6 +443,7 @@ public long spill() throws IOException { UnsafeInMemorySorter.SortedIterator inMemIterator = ((UnsafeInMemorySorter.SortedIterator) upstream).clone(); + // Iterate over the records that have not been returned and spill them. final UnsafeSorterSpillWriter spillWriter = new UnsafeSorterSpillWriter(blockManager, fileBufferSizeBytes, writeMetrics, numRecords); while (inMemIterator.hasNext()) { @@ -458,9 +459,11 @@ public long spill() throws IOException { long released = 0L; synchronized (UnsafeExternalSorter.this) { - // release the pages except the one that is used + // release the pages except the one that is used. There can still be a caller that + // is accessing the current record. We free this page in that caller's next loadNext() + // call. for (MemoryBlock page : allocatedPages) { - if (!loaded || page.getBaseObject() != inMemIterator.getBaseObject()) { + if (!loaded || page.getBaseObject() != upstream.getBaseObject()) { released += page.size(); freePage(page); } else { From d0d82227785dcd6c49a986431c476c7838a9541c Mon Sep 17 00:00:00 2001 From: Dmitry Erastov Date: Fri, 4 Dec 2015 12:03:45 -0800 Subject: [PATCH 1660/1824] [SPARK-6990][BUILD] Add Java linting script; fix minor warnings This replaces https://github.com/apache/spark/pull/9696 Invoke Checkstyle and print any errors to the console, failing the step. Use Google's style rules modified according to https://cwiki.apache.org/confluence/display/SPARK/Spark+Code+Style+Guide Some important checks are disabled (see TODOs in `checkstyle.xml`) due to multiple violations being present in the codebase. Suggest fixing those TODOs in a separate PR(s). More on Checkstyle can be found on the [official website](http://checkstyle.sourceforge.net/). Sample output (from [build 46345](https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder/46345/consoleFull)) (duplicated because I run the build twice with different profiles): > Checkstyle checks failed at following occurrences: [ERROR] src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java:[217,7] (coding) MissingSwitchDefault: switch without "default" clause. 
> [ERROR] src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java:[198,10] (modifier) ModifierOrder: 'protected' modifier out of order with the JLS suggestions. > [ERROR] src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java:[217,7] (coding) MissingSwitchDefault: switch without "default" clause. > [ERROR] src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java:[198,10] (modifier) ModifierOrder: 'protected' modifier out of order with the JLS suggestions. > [error] running /home/jenkins/workspace/SparkPullRequestBuilder2/dev/lint-java ; received return code 1 Also fix some of the minor violations that didn't require sweeping changes. Apologies for the previous botched PRs - I finally figured out the issue. cr: JoshRosen, pwendell > I state that the contribution is my original work, and I license the work to the project under the project's open source license. Author: Dmitry Erastov Closes #9867 from dskrvk/master. --- checkstyle-suppressions.xml | 33 ++++ checkstyle.xml | 164 ++++++++++++++++++ .../unsafe/sort/UnsafeInMemorySorter.java | 2 +- .../map/AbstractBytesToBytesMapSuite.java | 4 +- dev/lint-java | 30 ++++ dev/run-tests-jenkins.py | 1 + dev/run-tests.py | 7 + dev/sparktestsupport/__init__.py | 1 + .../examples/ml/JavaSimpleParamsExample.java | 2 +- .../spark/examples/mllib/JavaLDAExample.java | 3 +- ...ultiLabelClassificationMetricsExample.java | 12 +- ...ulticlassClassificationMetricsExample.java | 12 +- .../mllib/JavaRankingMetricsExample.java | 4 +- .../mllib/JavaRecommendationExample.java | 2 +- .../mllib/JavaRegressionMetricsExample.java | 3 +- .../streaming/JavaSqlNetworkWordCount.java | 4 +- .../ml/feature/JavaStringIndexerSuite.java | 6 +- .../spark/mllib/clustering/JavaLDASuite.java | 2 +- .../shuffle/ExternalShuffleBlockResolver.java | 2 +- .../network/sasl/SaslIntegrationSuite.java | 2 +- pom.xml | 24 +++ .../execution/UnsafeExternalRowSorter.java | 2 +- .../spark/sql/types/SQLUserDefinedType.java | 2 +- .../SpecificParquetRecordReaderBase.java | 2 +- .../apache/spark/sql/hive/test/Complex.java | 86 ++++++--- .../spark/streaming/util/WriteAheadLog.java | 10 +- .../apache/spark/streaming/JavaAPISuite.java | 4 +- .../streaming/JavaTrackStateByKeySuite.java | 2 +- .../apache/spark/tags/ExtendedHiveTest.java | 1 + .../apache/spark/tags/ExtendedYarnTest.java | 1 + .../apache/spark/unsafe/types/UTF8String.java | 8 +- 31 files changed, 368 insertions(+), 70 deletions(-) create mode 100644 checkstyle-suppressions.xml create mode 100644 checkstyle.xml create mode 100755 dev/lint-java diff --git a/checkstyle-suppressions.xml b/checkstyle-suppressions.xml new file mode 100644 index 000000000000..9242be3d0357 --- /dev/null +++ b/checkstyle-suppressions.xml @@ -0,0 +1,33 @@ + + + + + + + + + diff --git a/checkstyle.xml b/checkstyle.xml new file mode 100644 index 000000000000..a493ee443c75 --- /dev/null +++ b/checkstyle.xml @@ -0,0 +1,164 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java index c91e88f31bf9..c16cbce9a0f6 100644 --- 
a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java @@ -175,7 +175,7 @@ private SortedIterator(int numRecords) { this.position = 0; } - public SortedIterator clone () { + public SortedIterator clone() { SortedIterator iter = new SortedIterator(numRecords); iter.position = position; iter.baseObject = baseObject; diff --git a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java index d87a1d2a56d9..a5c583f9f284 100644 --- a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java +++ b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java @@ -356,8 +356,8 @@ public void iteratingOverDataPagesWithWastedSpace() throws Exception { final java.util.BitSet valuesSeen = new java.util.BitSet(NUM_ENTRIES); final Iterator iter = map.iterator(); - final long key[] = new long[KEY_LENGTH / 8]; - final long value[] = new long[VALUE_LENGTH / 8]; + final long[] key = new long[KEY_LENGTH / 8]; + final long[] value = new long[VALUE_LENGTH / 8]; while (iter.hasNext()) { final BytesToBytesMap.Location loc = iter.next(); Assert.assertTrue(loc.isDefined()); diff --git a/dev/lint-java b/dev/lint-java new file mode 100755 index 000000000000..fe8ab83d562d --- /dev/null +++ b/dev/lint-java @@ -0,0 +1,30 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +SCRIPT_DIR="$( cd "$( dirname "$0" )" && pwd )" +SPARK_ROOT_DIR="$(dirname $SCRIPT_DIR)" + +ERRORS=$($SCRIPT_DIR/../build/mvn -Pkinesis-asl -Pyarn -Phive -Phive-thriftserver checkstyle:check | grep ERROR) + +if test ! -z "$ERRORS"; then + echo -e "Checkstyle checks failed at following occurrences:\n$ERRORS" + exit 1 +else + echo -e "Checkstyle checks passed." 
+fi diff --git a/dev/run-tests-jenkins.py b/dev/run-tests-jenkins.py index 4f390ef1eaa3..7aecea25b209 100755 --- a/dev/run-tests-jenkins.py +++ b/dev/run-tests-jenkins.py @@ -119,6 +119,7 @@ def run_tests(tests_timeout): ERROR_CODES["BLOCK_GENERAL"]: 'some tests', ERROR_CODES["BLOCK_RAT"]: 'RAT tests', ERROR_CODES["BLOCK_SCALA_STYLE"]: 'Scala style tests', + ERROR_CODES["BLOCK_JAVA_STYLE"]: 'Java style tests', ERROR_CODES["BLOCK_PYTHON_STYLE"]: 'Python style tests', ERROR_CODES["BLOCK_R_STYLE"]: 'R style tests', ERROR_CODES["BLOCK_DOCUMENTATION"]: 'to generate documentation', diff --git a/dev/run-tests.py b/dev/run-tests.py index 9e1abb069719..e7e10f1d8c72 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -198,6 +198,11 @@ def run_scala_style_checks(): run_cmd([os.path.join(SPARK_HOME, "dev", "lint-scala")]) +def run_java_style_checks(): + set_title_and_block("Running Java style checks", "BLOCK_JAVA_STYLE") + run_cmd([os.path.join(SPARK_HOME, "dev", "lint-java")]) + + def run_python_style_checks(): set_title_and_block("Running Python style checks", "BLOCK_PYTHON_STYLE") run_cmd([os.path.join(SPARK_HOME, "dev", "lint-python")]) @@ -522,6 +527,8 @@ def main(): # style checks if not changed_files or any(f.endswith(".scala") for f in changed_files): run_scala_style_checks() + if not changed_files or any(f.endswith(".java") for f in changed_files): + run_java_style_checks() if not changed_files or any(f.endswith(".py") for f in changed_files): run_python_style_checks() if not changed_files or any(f.endswith(".R") for f in changed_files): diff --git a/dev/sparktestsupport/__init__.py b/dev/sparktestsupport/__init__.py index 8ab6d9e37ca2..0e8032d13341 100644 --- a/dev/sparktestsupport/__init__.py +++ b/dev/sparktestsupport/__init__.py @@ -31,5 +31,6 @@ "BLOCK_SPARK_UNIT_TESTS": 18, "BLOCK_PYSPARK_UNIT_TESTS": 19, "BLOCK_SPARKR_UNIT_TESTS": 20, + "BLOCK_JAVA_STYLE": 21, "BLOCK_TIMEOUT": 124 } diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java index 94beeced3d47..ea83e8fef9eb 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java @@ -77,7 +77,7 @@ public static void main(String[] args) { ParamMap paramMap = new ParamMap(); paramMap.put(lr.maxIter().w(20)); // Specify 1 Param. paramMap.put(lr.maxIter(), 30); // This overwrites the original maxIter. - double thresholds[] = {0.45, 0.55}; + double[] thresholds = {0.45, 0.55}; paramMap.put(lr.regParam().w(0.1), lr.thresholds().w(thresholds)); // Specify multiple Params. // One can also combine ParamMaps. 
diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaLDAExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaLDAExample.java index fd53c81cc497..de8e739ac925 100644 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaLDAExample.java +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaLDAExample.java @@ -41,8 +41,9 @@ public static void main(String[] args) { public Vector call(String s) { String[] sarray = s.trim().split(" "); double[] values = new double[sarray.length]; - for (int i = 0; i < sarray.length; i++) + for (int i = 0; i < sarray.length; i++) { values[i] = Double.parseDouble(sarray[i]); + } return Vectors.dense(values); } } diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaMultiLabelClassificationMetricsExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaMultiLabelClassificationMetricsExample.java index b54e1ea3f2bc..5ba01e0d0881 100644 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaMultiLabelClassificationMetricsExample.java +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaMultiLabelClassificationMetricsExample.java @@ -57,12 +57,12 @@ public static void main(String[] args) { // Stats by labels for (int i = 0; i < metrics.labels().length - 1; i++) { - System.out.format("Class %1.1f precision = %f\n", metrics.labels()[i], metrics.precision - (metrics.labels()[i])); - System.out.format("Class %1.1f recall = %f\n", metrics.labels()[i], metrics.recall(metrics - .labels()[i])); - System.out.format("Class %1.1f F1 score = %f\n", metrics.labels()[i], metrics.f1Measure - (metrics.labels()[i])); + System.out.format("Class %1.1f precision = %f\n", metrics.labels()[i], metrics.precision( + metrics.labels()[i])); + System.out.format("Class %1.1f recall = %f\n", metrics.labels()[i], metrics.recall( + metrics.labels()[i])); + System.out.format("Class %1.1f F1 score = %f\n", metrics.labels()[i], metrics.f1Measure( + metrics.labels()[i])); } // Micro stats diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaMulticlassClassificationMetricsExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaMulticlassClassificationMetricsExample.java index 21f628fb51b6..5247c9c74861 100644 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaMulticlassClassificationMetricsExample.java +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaMulticlassClassificationMetricsExample.java @@ -74,12 +74,12 @@ public Tuple2 call(LabeledPoint p) { // Stats by labels for (int i = 0; i < metrics.labels().length; i++) { - System.out.format("Class %f precision = %f\n", metrics.labels()[i],metrics.precision - (metrics.labels()[i])); - System.out.format("Class %f recall = %f\n", metrics.labels()[i], metrics.recall(metrics - .labels()[i])); - System.out.format("Class %f F1 score = %f\n", metrics.labels()[i], metrics.fMeasure - (metrics.labels()[i])); + System.out.format("Class %f precision = %f\n", metrics.labels()[i],metrics.precision( + metrics.labels()[i])); + System.out.format("Class %f recall = %f\n", metrics.labels()[i], metrics.recall( + metrics.labels()[i])); + System.out.format("Class %f F1 score = %f\n", metrics.labels()[i], metrics.fMeasure( + metrics.labels()[i])); } //Weighted stats diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaRankingMetricsExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaRankingMetricsExample.java index 7c4c97e74681..47ab3fc35824 100644 --- 
a/examples/src/main/java/org/apache/spark/examples/mllib/JavaRankingMetricsExample.java +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaRankingMetricsExample.java @@ -120,8 +120,8 @@ public List call(Rating[] docs) { } } ); - JavaRDD, List>> relevantDocs = userMoviesList.join - (userRecommendedList).values(); + JavaRDD, List>> relevantDocs = userMoviesList.join( + userRecommendedList).values(); // Instantiate the metrics object RankingMetrics metrics = RankingMetrics.of(relevantDocs); diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaRecommendationExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaRecommendationExample.java index 1065fde953b9..c179e7578cdf 100644 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaRecommendationExample.java +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaRecommendationExample.java @@ -29,7 +29,7 @@ // $example off$ public class JavaRecommendationExample { - public static void main(String args[]) { + public static void main(String[] args) { // $example on$ SparkConf conf = new SparkConf().setAppName("Java Collaborative Filtering Example"); JavaSparkContext jsc = new JavaSparkContext(conf); diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaRegressionMetricsExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaRegressionMetricsExample.java index d2efc6bf9777..4e89dd0c37c5 100644 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaRegressionMetricsExample.java +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaRegressionMetricsExample.java @@ -43,8 +43,9 @@ public static void main(String[] args) { public LabeledPoint call(String line) { String[] parts = line.split(" "); double[] v = new double[parts.length - 1]; - for (int i = 1; i < parts.length - 1; i++) + for (int i = 1; i < parts.length - 1; i++) { v[i - 1] = Double.parseDouble(parts[i].split(":")[1]); + } return new LabeledPoint(Double.parseDouble(parts[0]), Vectors.dense(v)); } } diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaSqlNetworkWordCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaSqlNetworkWordCount.java index 46562ddbbcb5..3515d7be45d3 100644 --- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaSqlNetworkWordCount.java +++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaSqlNetworkWordCount.java @@ -112,8 +112,8 @@ public JavaRecord call(String word) { /** Lazily instantiated singleton instance of SQLContext */ class JavaSQLContextSingleton { - static private transient SQLContext instance = null; - static public SQLContext getInstance(SparkContext sparkContext) { + private static transient SQLContext instance = null; + public static SQLContext getInstance(SparkContext sparkContext) { if (instance == null) { instance = new SQLContext(sparkContext); } diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java index 6b2c48ef1c34..b2df79ba74fe 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java @@ -58,7 +58,7 @@ public void testStringIndexer() { createStructField("label", StringType, false) }); List data = Arrays.asList( - c(0, "a"), c(1, "b"), c(2, "c"), c(3, "a"), c(4, "a"), c(5, "c")); + cr(0, "a"), cr(1, "b"), cr(2, "c"), cr(3, 
"a"), cr(4, "a"), cr(5, "c")); DataFrame dataset = sqlContext.createDataFrame(data, schema); StringIndexer indexer = new StringIndexer() @@ -67,12 +67,12 @@ public void testStringIndexer() { DataFrame output = indexer.fit(dataset).transform(dataset); Assert.assertArrayEquals( - new Row[] { c(0, 0.0), c(1, 2.0), c(2, 1.0), c(3, 0.0), c(4, 0.0), c(5, 1.0) }, + new Row[] { cr(0, 0.0), cr(1, 2.0), cr(2, 1.0), cr(3, 0.0), cr(4, 0.0), cr(5, 1.0) }, output.orderBy("id").select("id", "labelIndex").collect()); } /** An alias for RowFactory.create. */ - private Row c(Object... values) { + private Row cr(Object... values) { return RowFactory.create(values); } } diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java index 3fea359a3b46..225a216270b3 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java @@ -144,7 +144,7 @@ public Boolean call(Tuple2 tuple2) { } @Test - public void OnlineOptimizerCompatibility() { + public void onlineOptimizerCompatibility() { int k = 3; double topicSmoothing = 1.2; double termSmoothing = 1.2; diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java index 0d4dd6afac76..e5cb68c8a4db 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java @@ -419,7 +419,7 @@ private static void storeVersion(DB db) throws IOException { public static class StoreVersion { - final static byte[] KEY = "StoreVersion".getBytes(Charsets.UTF_8); + static final byte[] KEY = "StoreVersion".getBytes(Charsets.UTF_8); public final int major; public final int minor; diff --git a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java index 19c870aebb02..f573d962fe36 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java @@ -61,7 +61,7 @@ public class SaslIntegrationSuite { // Use a long timeout to account for slow / overloaded build machines. In the normal case, // tests should finish way before the timeout expires. 
- private final static long TIMEOUT_MS = 10_000; + private static final long TIMEOUT_MS = 10_000; static TransportServer server; static TransportConf conf; diff --git a/pom.xml b/pom.xml index 234fd5dea1a6..16e656d11961 100644 --- a/pom.xml +++ b/pom.xml @@ -2256,6 +2256,30 @@ + + org.apache.maven.plugins + maven-checkstyle-plugin + 2.17 + + false + false + true + false + ${basedir}/src/main/java + ${basedir}/src/test/java + checkstyle.xml + ${basedir}/target/checkstyle-output.xml + ${project.build.sourceEncoding} + ${project.reporting.outputEncoding} + + + + + check + + + + org.apache.maven.plugins diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java index 3986d6e18f77..352002b3499a 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java @@ -51,7 +51,7 @@ final class UnsafeExternalRowSorter { private final PrefixComputer prefixComputer; private final UnsafeExternalSorter sorter; - public static abstract class PrefixComputer { + public abstract static class PrefixComputer { abstract long computePrefix(InternalRow row); } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/types/SQLUserDefinedType.java b/sql/catalyst/src/main/java/org/apache/spark/sql/types/SQLUserDefinedType.java index df64a878b6b3..1e4e5ede8cc1 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/types/SQLUserDefinedType.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/types/SQLUserDefinedType.java @@ -41,5 +41,5 @@ * Returns an instance of the UserDefinedType which can serialize and deserialize the user * class to and from Catalyst built-in types. */ - Class > udt(); + Class> udt(); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java index 2ed30c1f5a8d..842dcb8c93dc 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java @@ -195,7 +195,7 @@ protected static final class NullIntIterator extends IntIterator { * Creates a reader for definition and repetition levels, returning an optimized one if * the levels are not needed. 
*/ - static protected IntIterator createRLEIterator(int maxLevel, BytesInput bytes, + protected static IntIterator createRLEIterator(int maxLevel, BytesInput bytes, ColumnDescriptor descriptor) throws IOException { try { if (maxLevel == 0) return new NullIntIterator(); diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/test/Complex.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/test/Complex.java index e010112bb932..4ef1f276d1bb 100644 --- a/sql/hive/src/test/java/org/apache/spark/sql/hive/test/Complex.java +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/test/Complex.java @@ -489,6 +489,7 @@ public void setFieldValue(_Fields field, Object value) { } break; + default: } } @@ -512,6 +513,7 @@ public Object getFieldValue(_Fields field) { case M_STRING_STRING: return getMStringString(); + default: } throw new IllegalStateException(); } @@ -535,75 +537,91 @@ public boolean isSet(_Fields field) { return isSetLintString(); case M_STRING_STRING: return isSetMStringString(); + default: } throw new IllegalStateException(); } @Override public boolean equals(Object that) { - if (that == null) + if (that == null) { return false; - if (that instanceof Complex) + } + if (that instanceof Complex) { return this.equals((Complex)that); + } return false; } public boolean equals(Complex that) { - if (that == null) + if (that == null) { return false; + } boolean this_present_aint = true; boolean that_present_aint = true; if (this_present_aint || that_present_aint) { - if (!(this_present_aint && that_present_aint)) + if (!(this_present_aint && that_present_aint)) { return false; - if (this.aint != that.aint) + } + if (this.aint != that.aint) { return false; + } } boolean this_present_aString = true && this.isSetAString(); boolean that_present_aString = true && that.isSetAString(); if (this_present_aString || that_present_aString) { - if (!(this_present_aString && that_present_aString)) + if (!(this_present_aString && that_present_aString)) { return false; - if (!this.aString.equals(that.aString)) + } + if (!this.aString.equals(that.aString)) { return false; + } } boolean this_present_lint = true && this.isSetLint(); boolean that_present_lint = true && that.isSetLint(); if (this_present_lint || that_present_lint) { - if (!(this_present_lint && that_present_lint)) + if (!(this_present_lint && that_present_lint)) { return false; - if (!this.lint.equals(that.lint)) + } + if (!this.lint.equals(that.lint)) { return false; + } } boolean this_present_lString = true && this.isSetLString(); boolean that_present_lString = true && that.isSetLString(); if (this_present_lString || that_present_lString) { - if (!(this_present_lString && that_present_lString)) + if (!(this_present_lString && that_present_lString)) { return false; - if (!this.lString.equals(that.lString)) + } + if (!this.lString.equals(that.lString)) { return false; + } } boolean this_present_lintString = true && this.isSetLintString(); boolean that_present_lintString = true && that.isSetLintString(); if (this_present_lintString || that_present_lintString) { - if (!(this_present_lintString && that_present_lintString)) + if (!(this_present_lintString && that_present_lintString)) { return false; - if (!this.lintString.equals(that.lintString)) + } + if (!this.lintString.equals(that.lintString)) { return false; + } } boolean this_present_mStringString = true && this.isSetMStringString(); boolean that_present_mStringString = true && that.isSetMStringString(); if (this_present_mStringString || that_present_mStringString) { - if 
(!(this_present_mStringString && that_present_mStringString)) + if (!(this_present_mStringString && that_present_mStringString)) { return false; - if (!this.mStringString.equals(that.mStringString)) + } + if (!this.mStringString.equals(that.mStringString)) { return false; + } } return true; @@ -615,33 +633,39 @@ public int hashCode() { boolean present_aint = true; builder.append(present_aint); - if (present_aint) + if (present_aint) { builder.append(aint); + } boolean present_aString = true && (isSetAString()); builder.append(present_aString); - if (present_aString) + if (present_aString) { builder.append(aString); + } boolean present_lint = true && (isSetLint()); builder.append(present_lint); - if (present_lint) + if (present_lint) { builder.append(lint); + } boolean present_lString = true && (isSetLString()); builder.append(present_lString); - if (present_lString) + if (present_lString) { builder.append(lString); + } boolean present_lintString = true && (isSetLintString()); builder.append(present_lintString); - if (present_lintString) + if (present_lintString) { builder.append(lintString); + } boolean present_mStringString = true && (isSetMStringString()); builder.append(present_mStringString); - if (present_mStringString) + if (present_mStringString) { builder.append(mStringString); + } return builder.toHashCode(); } @@ -737,7 +761,9 @@ public String toString() { sb.append("aint:"); sb.append(this.aint); first = false; - if (!first) sb.append(", "); + if (!first) { + sb.append(", "); + } sb.append("aString:"); if (this.aString == null) { sb.append("null"); @@ -745,7 +771,9 @@ public String toString() { sb.append(this.aString); } first = false; - if (!first) sb.append(", "); + if (!first) { + sb.append(", "); + } sb.append("lint:"); if (this.lint == null) { sb.append("null"); @@ -753,7 +781,9 @@ public String toString() { sb.append(this.lint); } first = false; - if (!first) sb.append(", "); + if (!first) { + sb.append(", "); + } sb.append("lString:"); if (this.lString == null) { sb.append("null"); @@ -761,7 +791,9 @@ public String toString() { sb.append(this.lString); } first = false; - if (!first) sb.append(", "); + if (!first) { + sb.append(", "); + } sb.append("lintString:"); if (this.lintString == null) { sb.append("null"); @@ -769,7 +801,9 @@ public String toString() { sb.append(this.lintString); } first = false; - if (!first) sb.append(", "); + if (!first) { + sb.append(", "); + } sb.append("mStringString:"); if (this.mStringString == null) { sb.append("null"); diff --git a/streaming/src/main/java/org/apache/spark/streaming/util/WriteAheadLog.java b/streaming/src/main/java/org/apache/spark/streaming/util/WriteAheadLog.java index 3738fc1a235c..2803cad8095d 100644 --- a/streaming/src/main/java/org/apache/spark/streaming/util/WriteAheadLog.java +++ b/streaming/src/main/java/org/apache/spark/streaming/util/WriteAheadLog.java @@ -37,26 +37,26 @@ public abstract class WriteAheadLog { * ensure that the written data is durable and readable (using the record handle) by the * time this function returns. */ - abstract public WriteAheadLogRecordHandle write(ByteBuffer record, long time); + public abstract WriteAheadLogRecordHandle write(ByteBuffer record, long time); /** * Read a written record based on the given record handle. */ - abstract public ByteBuffer read(WriteAheadLogRecordHandle handle); + public abstract ByteBuffer read(WriteAheadLogRecordHandle handle); /** * Read and return an iterator of all the records that have been written but not yet cleaned up. 
*/ - abstract public Iterator readAll(); + public abstract Iterator readAll(); /** * Clean all the records that are older than the threshold time. It can wait for * the completion of the deletion. */ - abstract public void clean(long threshTime, boolean waitForCompletion); + public abstract void clean(long threshTime, boolean waitForCompletion); /** * Close this log and release any resources. */ - abstract public void close(); + public abstract void close(); } diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java index 609bb4413b6b..9722c60bba1c 100644 --- a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java +++ b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java @@ -1332,12 +1332,12 @@ public Optional call(List values, Optional state) { public void testUpdateStateByKeyWithInitial() { List>> inputData = stringIntKVStream; - List> initial = Arrays.asList ( + List> initial = Arrays.asList( new Tuple2<>("california", 1), new Tuple2<>("new york", 2)); JavaRDD> tmpRDD = ssc.sparkContext().parallelize(initial); - JavaPairRDD initialRDD = JavaPairRDD.fromJavaRDD (tmpRDD); + JavaPairRDD initialRDD = JavaPairRDD.fromJavaRDD(tmpRDD); List>> expected = Arrays.asList( Arrays.asList(new Tuple2<>("california", 5), diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaTrackStateByKeySuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaTrackStateByKeySuite.java index eac4cdd14a68..89d0bb7b617e 100644 --- a/streaming/src/test/java/org/apache/spark/streaming/JavaTrackStateByKeySuite.java +++ b/streaming/src/test/java/org/apache/spark/streaming/JavaTrackStateByKeySuite.java @@ -95,7 +95,7 @@ public Double call(Optional one, State state) { JavaTrackStateDStream stateDstream2 = wordsDstream.trackStateByKey( - StateSpec. 
function(trackStateFunc2) + StateSpec.function(trackStateFunc2) .initialState(initialRDD) .numPartitions(10) .partitioner(new HashPartitioner(10)) diff --git a/tags/src/main/java/org/apache/spark/tags/ExtendedHiveTest.java b/tags/src/main/java/org/apache/spark/tags/ExtendedHiveTest.java index 1b0c416b0fe4..83279e5e93c0 100644 --- a/tags/src/main/java/org/apache/spark/tags/ExtendedHiveTest.java +++ b/tags/src/main/java/org/apache/spark/tags/ExtendedHiveTest.java @@ -18,6 +18,7 @@ package org.apache.spark.tags; import java.lang.annotation.*; + import org.scalatest.TagAnnotation; @TagAnnotation diff --git a/tags/src/main/java/org/apache/spark/tags/ExtendedYarnTest.java b/tags/src/main/java/org/apache/spark/tags/ExtendedYarnTest.java index 2a631bfc88cf..108300168e17 100644 --- a/tags/src/main/java/org/apache/spark/tags/ExtendedYarnTest.java +++ b/tags/src/main/java/org/apache/spark/tags/ExtendedYarnTest.java @@ -18,6 +18,7 @@ package org.apache.spark.tags; import java.lang.annotation.*; + import org.scalatest.TagAnnotation; @TagAnnotation diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index 4bd3fd777207..5b6138680876 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -900,9 +900,9 @@ public int levenshteinDistance(UTF8String other) { m = swap; } - int p[] = new int[n + 1]; - int d[] = new int[n + 1]; - int swap[]; + int[] p = new int[n + 1]; + int[] d = new int[n + 1]; + int[] swap; int i, i_bytes, j, j_bytes, num_bytes_j, cost; @@ -965,7 +965,7 @@ public UTF8String soundex() { // first character must be a letter return this; } - byte sx[] = {'0', '0', '0', '0'}; + byte[] sx = {'0', '0', '0', '0'}; sx[0] = b; int sxi = 1; int idx = b - 'A'; From 302d68de87dbaf1974accf49de26fc01fc0eb089 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Fri, 4 Dec 2015 12:08:42 -0800 Subject: [PATCH 1661/1824] [SPARK-12058][STREAMING][KINESIS][TESTS] fix Kinesis python tests Python tests require access to the `KinesisTestUtils` file. When this file exists under src/test, python can't access it, since it is not available in the assembly jar. However, if we move KinesisTestUtils to src/main, we need to add the KinesisProducerLibrary as a dependency. In order to avoid this, I moved KinesisTestUtils to src/main, and extended it with ExtendedKinesisTestUtils which is under src/test that adds support for the KPL. cc zsxwing tdas Author: Burak Yavuz Closes #10050 from brkyvz/kinesis-py. 
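As a condensed illustration of the refactoring this commit message describes (the full diff follows below), data generation is hidden behind a small trait so that the simple PutRecord-based generator can live under src/main — and therefore reach the Python tests through the assembly jar — while the KPL-backed generator and its dependency stay under src/test. The sketch reuses the trait name from the patch but stubs out all Kinesis calls; the stub class names are assumptions for illustration only.

```scala
package org.apache.spark.streaming.kinesis

// Condensed, illustrative sketch of the shape introduced by the patch below;
// the actual Kinesis client calls are stubbed out here.
private[kinesis] trait KinesisDataGenerator {
  /** Sends the data and returns shardId -> (datum, sequence number) pairs. */
  def sendData(streamName: String, data: Seq[Int]): Map[String, Seq[(Int, String)]]
}

private[kinesis] class SimpleDataGeneratorStub extends KinesisDataGenerator {
  // The real SimpleDataGenerator issues one PutRecord per datum through the
  // AmazonKinesisClient; this stub only models the returned shape.
  override def sendData(streamName: String, data: Seq[Int]): Map[String, Seq[(Int, String)]] =
    Map("shardId-000000000000" -> data.map(n => (n, s"seq-$n")))
}

private[kinesis] class KinesisTestUtilsSketch {
  // src/main default: no KPL. The KPL-based subclass under src/test overrides
  // this to return an aggregating generator when aggregate = true.
  protected def getProducer(aggregate: Boolean): KinesisDataGenerator =
    if (!aggregate) new SimpleDataGeneratorStub
    else throw new UnsupportedOperationException(
      "Aggregation is not supported through this code path")
}
```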
--- .../streaming/kinesis/KinesisTestUtils.scala | 88 +++++++++---------- .../kinesis/KPLBasedKinesisTestUtils.scala | 72 +++++++++++++++ .../kinesis/KinesisBackedBlockRDDSuite.scala | 2 +- .../kinesis/KinesisStreamSuite.scala | 2 +- python/pyspark/streaming/tests.py | 1 - 5 files changed, 115 insertions(+), 50 deletions(-) rename extras/kinesis-asl/src/{test => main}/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala (82%) create mode 100644 extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KPLBasedKinesisTestUtils.scala diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala similarity index 82% rename from extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala rename to extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala index 7487aa1c1263..0ace453ee928 100644 --- a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala @@ -31,13 +31,13 @@ import com.amazonaws.services.dynamodbv2.AmazonDynamoDBClient import com.amazonaws.services.dynamodbv2.document.DynamoDB import com.amazonaws.services.kinesis.AmazonKinesisClient import com.amazonaws.services.kinesis.model._ -import com.amazonaws.services.kinesis.producer.{KinesisProducer, KinesisProducerConfiguration, UserRecordResult} -import com.google.common.util.concurrent.{FutureCallback, Futures} import org.apache.spark.Logging /** - * Shared utility methods for performing Kinesis tests that actually transfer data + * Shared utility methods for performing Kinesis tests that actually transfer data. + * + * PLEASE KEEP THIS FILE UNDER src/main AS PYTHON TESTS NEED ACCESS TO THIS FILE! 
*/ private[kinesis] class KinesisTestUtils extends Logging { @@ -54,7 +54,7 @@ private[kinesis] class KinesisTestUtils extends Logging { @volatile private var _streamName: String = _ - private lazy val kinesisClient = { + protected lazy val kinesisClient = { val client = new AmazonKinesisClient(KinesisTestUtils.getAWSCredentials()) client.setEndpoint(endpointUrl) client @@ -66,14 +66,12 @@ private[kinesis] class KinesisTestUtils extends Logging { new DynamoDB(dynamoDBClient) } - private lazy val kinesisProducer: KinesisProducer = { - val conf = new KinesisProducerConfiguration() - .setRecordMaxBufferedTime(1000) - .setMaxConnections(1) - .setRegion(regionName) - .setMetricsLevel("none") - - new KinesisProducer(conf) + protected def getProducer(aggregate: Boolean): KinesisDataGenerator = { + if (!aggregate) { + new SimpleDataGenerator(kinesisClient) + } else { + throw new UnsupportedOperationException("Aggregation is not supported through this code path") + } } def streamName: String = { @@ -104,41 +102,8 @@ private[kinesis] class KinesisTestUtils extends Logging { */ def pushData(testData: Seq[Int], aggregate: Boolean): Map[String, Seq[(Int, String)]] = { require(streamCreated, "Stream not yet created, call createStream() to create one") - val shardIdToSeqNumbers = new mutable.HashMap[String, ArrayBuffer[(Int, String)]]() - - testData.foreach { num => - val str = num.toString - val data = ByteBuffer.wrap(str.getBytes()) - if (aggregate) { - val future = kinesisProducer.addUserRecord(streamName, str, data) - val kinesisCallBack = new FutureCallback[UserRecordResult]() { - override def onFailure(t: Throwable): Unit = {} // do nothing - - override def onSuccess(result: UserRecordResult): Unit = { - val shardId = result.getShardId - val seqNumber = result.getSequenceNumber() - val sentSeqNumbers = shardIdToSeqNumbers.getOrElseUpdate(shardId, - new ArrayBuffer[(Int, String)]()) - sentSeqNumbers += ((num, seqNumber)) - } - } - - Futures.addCallback(future, kinesisCallBack) - kinesisProducer.flushSync() // make sure we send all data before returning the map - } else { - val putRecordRequest = new PutRecordRequest().withStreamName(streamName) - .withData(data) - .withPartitionKey(str) - - val putRecordResult = kinesisClient.putRecord(putRecordRequest) - val shardId = putRecordResult.getShardId - val seqNumber = putRecordResult.getSequenceNumber() - val sentSeqNumbers = shardIdToSeqNumbers.getOrElseUpdate(shardId, - new ArrayBuffer[(Int, String)]()) - sentSeqNumbers += ((num, seqNumber)) - } - } - + val producer = getProducer(aggregate) + val shardIdToSeqNumbers = producer.sendData(streamName, testData) logInfo(s"Pushed $testData:\n\t ${shardIdToSeqNumbers.mkString("\n\t")}") shardIdToSeqNumbers.toMap } @@ -264,3 +229,32 @@ private[kinesis] object KinesisTestUtils { } } } + +/** A wrapper interface that will allow us to consolidate the code for synthetic data generation. */ +private[kinesis] trait KinesisDataGenerator { + /** Sends the data to Kinesis and returns the metadata for everything that has been sent. 
*/ + def sendData(streamName: String, data: Seq[Int]): Map[String, Seq[(Int, String)]] +} + +private[kinesis] class SimpleDataGenerator( + client: AmazonKinesisClient) extends KinesisDataGenerator { + override def sendData(streamName: String, data: Seq[Int]): Map[String, Seq[(Int, String)]] = { + val shardIdToSeqNumbers = new mutable.HashMap[String, ArrayBuffer[(Int, String)]]() + data.foreach { num => + val str = num.toString + val data = ByteBuffer.wrap(str.getBytes()) + val putRecordRequest = new PutRecordRequest().withStreamName(streamName) + .withData(data) + .withPartitionKey(str) + + val putRecordResult = client.putRecord(putRecordRequest) + val shardId = putRecordResult.getShardId + val seqNumber = putRecordResult.getSequenceNumber() + val sentSeqNumbers = shardIdToSeqNumbers.getOrElseUpdate(shardId, + new ArrayBuffer[(Int, String)]()) + sentSeqNumbers += ((num, seqNumber)) + } + + shardIdToSeqNumbers.toMap + } +} diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KPLBasedKinesisTestUtils.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KPLBasedKinesisTestUtils.scala new file mode 100644 index 000000000000..fdb270eaad8c --- /dev/null +++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KPLBasedKinesisTestUtils.scala @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.streaming.kinesis + +import java.nio.ByteBuffer + +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer + +import com.amazonaws.services.kinesis.producer.{KinesisProducer => KPLProducer, KinesisProducerConfiguration, UserRecordResult} +import com.google.common.util.concurrent.{FutureCallback, Futures} + +private[kinesis] class KPLBasedKinesisTestUtils extends KinesisTestUtils { + override protected def getProducer(aggregate: Boolean): KinesisDataGenerator = { + if (!aggregate) { + new SimpleDataGenerator(kinesisClient) + } else { + new KPLDataGenerator(regionName) + } + } +} + +/** A wrapper for the KinesisProducer provided in the KPL. 
*/ +private[kinesis] class KPLDataGenerator(regionName: String) extends KinesisDataGenerator { + + private lazy val producer: KPLProducer = { + val conf = new KinesisProducerConfiguration() + .setRecordMaxBufferedTime(1000) + .setMaxConnections(1) + .setRegion(regionName) + .setMetricsLevel("none") + + new KPLProducer(conf) + } + + override def sendData(streamName: String, data: Seq[Int]): Map[String, Seq[(Int, String)]] = { + val shardIdToSeqNumbers = new mutable.HashMap[String, ArrayBuffer[(Int, String)]]() + data.foreach { num => + val str = num.toString + val data = ByteBuffer.wrap(str.getBytes()) + val future = producer.addUserRecord(streamName, str, data) + val kinesisCallBack = new FutureCallback[UserRecordResult]() { + override def onFailure(t: Throwable): Unit = {} // do nothing + + override def onSuccess(result: UserRecordResult): Unit = { + val shardId = result.getShardId + val seqNumber = result.getSequenceNumber() + val sentSeqNumbers = shardIdToSeqNumbers.getOrElseUpdate(shardId, + new ArrayBuffer[(Int, String)]()) + sentSeqNumbers += ((num, seqNumber)) + } + } + Futures.addCallback(future, kinesisCallBack) + } + producer.flushSync() + shardIdToSeqNumbers.toMap + } +} diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala index 52c61dfb1c02..d85b4cda8ce9 100644 --- a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala +++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala @@ -40,7 +40,7 @@ abstract class KinesisBackedBlockRDDTests(aggregateTestData: Boolean) override def beforeAll(): Unit = { runIfTestsEnabled("Prepare KinesisTestUtils") { - testUtils = new KinesisTestUtils() + testUtils = new KPLBasedKinesisTestUtils() testUtils.createStream() shardIdToDataAndSeqNumbers = testUtils.pushData(testData, aggregate = aggregateTestData) diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala index dee30444d8cc..78cec021b78c 100644 --- a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala +++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala @@ -63,7 +63,7 @@ abstract class KinesisStreamTests(aggregateTestData: Boolean) extends KinesisFun sc = new SparkContext(conf) runIfTestsEnabled("Prepare KinesisTestUtils") { - testUtils = new KinesisTestUtils() + testUtils = new KPLBasedKinesisTestUtils() testUtils.createStream() } } diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index d50c6b8d4a42..a2bfd79e1abc 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -1458,7 +1458,6 @@ def test_kinesis_stream_api(self): InitialPositionInStream.LATEST, 2, StorageLevel.MEMORY_AND_DISK_2, "awsAccessKey", "awsSecretKey") - @unittest.skip("Enable it when we fix SPAKR-12058") def test_kinesis_stream(self): if not are_kinesis_tests_enabled: sys.stderr.write( From d64806b37373c5cc4fd158a9f5005743bd00bf28 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Fri, 4 Dec 2015 13:05:07 -0800 Subject: [PATCH 1662/1824] [SPARK-11314][BUILD][HOTFIX] Add exclusion for moved YARN classes. Author: Marcelo Vanzin Closes #10147 from vanzin/SPARK-11314. 
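For context on the exclusion added in the diff below: MiMa filters of this kind are ordinarily expressed through `ProblemFilters`. The snippet that follows is only an illustration of that hand-written form under the assumption that the moved class should simply be ignored by the checker; the actual patch uses Spark's `MimaBuild.excludeSparkClass` helper instead, as shown in the diff.

```scala
import com.typesafe.tools.mima.core._

object MovedYarnClassExcludes {
  // What a filter of this kind looks like when written by hand: ignore the
  // "missing class" report for a private class that moved to another module.
  val filters = Seq(
    ProblemFilters.exclude[MissingClassProblem](
      "org.apache.spark.scheduler.cluster.YarnSchedulerBackend$YarnSchedulerEndpoint"))
}
```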
--- project/MimaExcludes.scala | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index d3a3c0ceb68c..b4aa6adc3c62 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -159,7 +159,10 @@ object MimaExcludes { // SPARK-3580 Add getNumPartitions method to JavaRDD ProblemFilters.exclude[MissingMethodProblem]( "org.apache.spark.api.java.JavaRDDLike.getNumPartitions") - ) + ) ++ + // SPARK-11314: YARN backend moved to yarn sub-module and MiMA complains even though it's a + // private class. + MimaBuild.excludeSparkClass("scheduler.cluster.YarnSchedulerBackend$YarnSchedulerEndpoint") case v if v.startsWith("1.5") => Seq( MimaBuild.excludeSparkPackage("network"), From b7204e1d41271d2e8443484371770936664350b1 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sat, 5 Dec 2015 08:15:30 +0800 Subject: [PATCH 1663/1824] [SPARK-12112][BUILD] Upgrade to SBT 0.13.9 We should upgrade to SBT 0.13.9, since this is a requirement in order to use SBT's new Maven-style resolution features (which will be done in a separate patch, because it's blocked by some binary compatibility issues in the POM reader plugin). I also upgraded Scalastyle to version 0.8.0, which was necessary in order to fix a Scala 2.10.5 compatibility issue (see https://github.com/scalastyle/scalastyle/issues/156). The newer Scalastyle is slightly stricter about whitespace surrounding tokens, so I fixed the new style violations. Author: Josh Rosen Closes #10112 from JoshRosen/upgrade-to-sbt-0.13.9. --- .../org/apache/spark/deploy/JsonProtocol.scala | 2 +- .../test/scala/org/apache/spark/rdd/RDDSuite.scala | 10 +++++----- .../spark/serializer/KryoSerializerSuite.scala | 8 ++++---- .../streaming/clickstream/PageViewGenerator.scala | 4 ++-- .../streaming/clickstream/PageViewStream.scala | 6 ++++-- .../streaming/flume/sink/SparkSinkSuite.scala | 4 ++-- .../spark/graphx/lib/TriangleCountSuite.scala | 2 +- .../scala/org/apache/spark/ml/param/params.scala | 2 ++ .../mllib/stat/test/StreamingTestMethod.scala | 4 ++-- .../DecisionTreeClassifierSuite.scala | 4 ++-- .../ml/regression/DecisionTreeRegressorSuite.scala | 6 +++--- .../spark/mllib/tree/DecisionTreeSuite.scala | 14 +++++++------- pom.xml | 2 +- project/build.properties | 2 +- project/plugins.sbt | 7 +------ .../spark/sql/catalyst/expressions/CastSuite.scala | 2 +- .../datasources/parquet/ParquetRelation.scala | 8 ++++---- .../sql/execution/columnar/ColumnTypeSuite.scala | 2 +- .../compression/RunLengthEncodingSuite.scala | 4 ++-- .../sql/execution/metric/SQLMetricsSuite.scala | 2 +- 20 files changed, 47 insertions(+), 48 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala index ccffb3665298..220b20bf7cbd 100644 --- a/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala @@ -45,7 +45,7 @@ private[deploy] object JsonProtocol { ("id" -> obj.id) ~ ("name" -> obj.desc.name) ~ ("cores" -> obj.desc.maxCores) ~ - ("user" -> obj.desc.user) ~ + ("user" -> obj.desc.user) ~ ("memoryperslave" -> obj.desc.memoryPerExecutorMB) ~ ("submitdate" -> obj.submitDate.toString) ~ ("state" -> obj.state.toString) ~ diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index 46ed5c04f433..007a71f87cf1 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ 
b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -101,21 +101,21 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext { } test("SparkContext.union creates UnionRDD if at least one RDD has no partitioner") { - val rddWithPartitioner = sc.parallelize(Seq(1->true)).partitionBy(new HashPartitioner(1)) - val rddWithNoPartitioner = sc.parallelize(Seq(2->true)) + val rddWithPartitioner = sc.parallelize(Seq(1 -> true)).partitionBy(new HashPartitioner(1)) + val rddWithNoPartitioner = sc.parallelize(Seq(2 -> true)) val unionRdd = sc.union(rddWithNoPartitioner, rddWithPartitioner) assert(unionRdd.isInstanceOf[UnionRDD[_]]) } test("SparkContext.union creates PartitionAwareUnionRDD if all RDDs have partitioners") { - val rddWithPartitioner = sc.parallelize(Seq(1->true)).partitionBy(new HashPartitioner(1)) + val rddWithPartitioner = sc.parallelize(Seq(1 -> true)).partitionBy(new HashPartitioner(1)) val unionRdd = sc.union(rddWithPartitioner, rddWithPartitioner) assert(unionRdd.isInstanceOf[PartitionerAwareUnionRDD[_]]) } test("PartitionAwareUnionRDD raises exception if at least one RDD has no partitioner") { - val rddWithPartitioner = sc.parallelize(Seq(1->true)).partitionBy(new HashPartitioner(1)) - val rddWithNoPartitioner = sc.parallelize(Seq(2->true)) + val rddWithPartitioner = sc.parallelize(Seq(1 -> true)).partitionBy(new HashPartitioner(1)) + val rddWithNoPartitioner = sc.parallelize(Seq(2 -> true)) intercept[IllegalArgumentException] { new PartitionerAwareUnionRDD(sc, Seq(rddWithNoPartitioner, rddWithPartitioner)) } diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala index e428414cf6e8..f81fe3113106 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala @@ -144,10 +144,10 @@ class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext { check(mutable.Map("one" -> 1, "two" -> 2)) check(mutable.HashMap(1 -> "one", 2 -> "two")) check(mutable.HashMap("one" -> 1, "two" -> 2)) - check(List(Some(mutable.HashMap(1->1, 2->2)), None, Some(mutable.HashMap(3->4)))) + check(List(Some(mutable.HashMap(1 -> 1, 2 -> 2)), None, Some(mutable.HashMap(3 -> 4)))) check(List( mutable.HashMap("one" -> 1, "two" -> 2), - mutable.HashMap(1->"one", 2->"two", 3->"three"))) + mutable.HashMap(1 -> "one", 2 -> "two", 3 -> "three"))) } test("Bug: SPARK-10251") { @@ -174,10 +174,10 @@ class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext { check(mutable.Map("one" -> 1, "two" -> 2)) check(mutable.HashMap(1 -> "one", 2 -> "two")) check(mutable.HashMap("one" -> 1, "two" -> 2)) - check(List(Some(mutable.HashMap(1->1, 2->2)), None, Some(mutable.HashMap(3->4)))) + check(List(Some(mutable.HashMap(1 -> 1, 2 -> 2)), None, Some(mutable.HashMap(3 -> 4)))) check(List( mutable.HashMap("one" -> 1, "two" -> 2), - mutable.HashMap(1->"one", 2->"two", 3->"three"))) + mutable.HashMap(1 -> "one", 2 -> "two", 3 -> "three"))) } test("ranges") { diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala index bea7a47cb285..2fcccb22dddf 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala +++ 
b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala @@ -51,8 +51,8 @@ object PageView extends Serializable { */ // scalastyle:on object PageViewGenerator { - val pages = Map("http://foo.com/" -> .7, - "http://foo.com/news" -> 0.2, + val pages = Map("http://foo.com/" -> .7, + "http://foo.com/news" -> 0.2, "http://foo.com/contact" -> .1) val httpStatus = Map(200 -> .95, 404 -> .05) diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala index 4ef238606f82..723616817f6a 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala @@ -86,8 +86,10 @@ object PageViewStream { .map("Unique active users: " + _) // An external dataset we want to join to this stream - val userList = ssc.sparkContext.parallelize( - Map(1 -> "Patrick Wendell", 2 -> "Reynold Xin", 3 -> "Matei Zaharia").toSeq) + val userList = ssc.sparkContext.parallelize(Seq( + 1 -> "Patrick Wendell", + 2 -> "Reynold Xin", + 3 -> "Matei Zaharia")) metric match { case "pageCounts" => pageCounts.print() diff --git a/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala b/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala index d2654700ea72..941fde45cd7b 100644 --- a/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala +++ b/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala @@ -36,11 +36,11 @@ import org.jboss.netty.channel.socket.nio.NioClientSocketChannelFactory // Spark core main, which has too many dependencies to require here manually. // For this reason, we continue to use FunSuite and ignore the scalastyle checks // that fail if this is detected. -//scalastyle:off +// scalastyle:off import org.scalatest.FunSuite class SparkSinkSuite extends FunSuite { -//scalastyle:on +// scalastyle:on val eventsPerBatch = 1000 val channelCapacity = 5000 diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/TriangleCountSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/TriangleCountSuite.scala index c47552cf3a3b..608e43cf3ff5 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/lib/TriangleCountSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/TriangleCountSuite.scala @@ -26,7 +26,7 @@ class TriangleCountSuite extends SparkFunSuite with LocalSparkContext { test("Count a single triangle") { withSpark { sc => - val rawEdges = sc.parallelize(Array( 0L->1L, 1L->2L, 2L->0L ), 2) + val rawEdges = sc.parallelize(Array( 0L -> 1L, 1L -> 2L, 2L -> 0L ), 2) val graph = Graph.fromEdgeTuples(rawEdges, true).cache() val triangleCount = graph.triangleCount() val verts = triangleCount.vertices diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index d182b0a98896..ee7e89edd879 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -82,7 +82,9 @@ class Param[T](val parent: String, val name: String, val doc: String, val isVali def w(value: T): ParamPair[T] = this -> value /** Creates a param pair with the given value (for Scala). 
*/ + // scalastyle:off def ->(value: T): ParamPair[T] = ParamPair(this, value) + // scalastyle:on /** Encodes a param value into JSON, which can be decoded by [[jsonDecode()]]. */ def jsonEncode(value: T): String = { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTestMethod.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTestMethod.scala index a7eaed51b4d5..911b4b923735 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTestMethod.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTestMethod.scala @@ -152,8 +152,8 @@ private[stat] object StudentTTest extends StreamingTestMethod with Logging { private[stat] object StreamingTestMethod { // Note: after new `StreamingTestMethod`s are implemented, please update this map. private final val TEST_NAME_TO_OBJECT: Map[String, StreamingTestMethod] = Map( - "welch"->WelchTTest, - "student"->StudentTTest) + "welch" -> WelchTTest, + "student" -> StudentTTest) def getTestMethodFromName(method: String): StreamingTestMethod = TEST_NAME_TO_OBJECT.get(method) match { diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala index 92b8f84144ab..fda2711fed0f 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala @@ -73,7 +73,7 @@ class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkConte .setMaxDepth(2) .setMaxBins(100) .setSeed(1) - val categoricalFeatures = Map(0 -> 3, 1-> 3) + val categoricalFeatures = Map(0 -> 3, 1 -> 3) val numClasses = 2 compareAPIs(categoricalDataPointsRDD, dt, categoricalFeatures, numClasses) } @@ -214,7 +214,7 @@ class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkConte .setMaxBins(2) .setMaxDepth(2) .setMinInstancesPerNode(2) - val categoricalFeatures = Map(0 -> 2, 1-> 2) + val categoricalFeatures = Map(0 -> 2, 1 -> 2) val numClasses = 2 compareAPIs(rdd, dt, categoricalFeatures, numClasses) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala index e0d5afa7a7e9..6999a910c34a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala @@ -50,7 +50,7 @@ class DecisionTreeRegressorSuite extends SparkFunSuite with MLlibTestSparkContex .setMaxDepth(2) .setMaxBins(100) .setSeed(1) - val categoricalFeatures = Map(0 -> 3, 1-> 3) + val categoricalFeatures = Map(0 -> 3, 1 -> 3) compareAPIs(categoricalDataPointsRDD, dt, categoricalFeatures) } @@ -59,12 +59,12 @@ class DecisionTreeRegressorSuite extends SparkFunSuite with MLlibTestSparkContex .setImpurity("variance") .setMaxDepth(2) .setMaxBins(100) - val categoricalFeatures = Map(0 -> 2, 1-> 2) + val categoricalFeatures = Map(0 -> 2, 1 -> 2) compareAPIs(categoricalDataPointsRDD, dt, categoricalFeatures) } test("copied model must have the same parent") { - val categoricalFeatures = Map(0 -> 2, 1-> 2) + val categoricalFeatures = Map(0 -> 2, 1 -> 2) val df = TreeTests.setMetadata(categoricalDataPointsRDD, categoricalFeatures, numClasses = 0) val model = new DecisionTreeRegressor() .setImpurity("variance") diff --git 
a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index 1a4299db4eab..bf8fe1acac2f 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -64,7 +64,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { maxDepth = 2, numClasses = 2, maxBins = 100, - categoricalFeaturesInfo = Map(0 -> 2, 1-> 2)) + categoricalFeaturesInfo = Map(0 -> 2, 1 -> 2)) val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) @@ -178,7 +178,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { maxDepth = 2, numClasses = 100, maxBins = 100, - categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) + categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3)) val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) assert(metadata.isUnordered(featureIndex = 0)) @@ -237,7 +237,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { maxDepth = 2, numClasses = 100, maxBins = 100, - categoricalFeaturesInfo = Map(0 -> 10, 1-> 10)) + categoricalFeaturesInfo = Map(0 -> 10, 1 -> 10)) // 2^(10-1) - 1 > 100, so categorical features will be ordered val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) @@ -421,7 +421,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { numClasses = 2, maxDepth = 2, maxBins = 100, - categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) + categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3)) val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) assert(!metadata.isUnordered(featureIndex = 0)) @@ -455,7 +455,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { Variance, maxDepth = 2, maxBins = 100, - categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) + categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3)) val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) assert(!metadata.isUnordered(featureIndex = 0)) @@ -484,7 +484,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { Variance, maxDepth = 2, maxBins = 100, - categoricalFeaturesInfo = Map(0 -> 2, 1-> 2)) + categoricalFeaturesInfo = Map(0 -> 2, 1 -> 2)) val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) assert(!metadata.isUnordered(featureIndex = 0)) assert(!metadata.isUnordered(featureIndex = 1)) @@ -788,7 +788,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { val rdd = sc.parallelize(arr) val strategy = new Strategy(algo = Classification, impurity = Gini, - maxBins = 2, maxDepth = 2, categoricalFeaturesInfo = Map(0 -> 2, 1-> 2), + maxBins = 2, maxDepth = 2, categoricalFeaturesInfo = Map(0 -> 2, 1 -> 2), numClasses = 2, minInstancesPerNode = 2) val rootNode = DecisionTree.train(rdd, strategy).topNode diff --git a/pom.xml b/pom.xml index 16e656d11961..ae2ff8878b0a 100644 --- a/pom.xml +++ b/pom.xml @@ -2235,7 +2235,7 @@ org.scalastyle scalastyle-maven-plugin - 0.7.0 + 0.8.0 false true diff --git a/project/build.properties b/project/build.properties index 064ec843da9e..86ca8755820a 100644 --- a/project/build.properties +++ b/project/build.properties @@ -14,4 +14,4 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# -sbt.version=0.13.7 +sbt.version=0.13.9 diff --git a/project/plugins.sbt b/project/plugins.sbt index c06687d8f197..5e23224cf8aa 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -10,14 +10,9 @@ addSbtPlugin("com.typesafe.sbteclipse" % "sbteclipse-plugin" % "2.2.0") addSbtPlugin("com.github.mpeltonen" % "sbt-idea" % "1.6.0") -// For Sonatype publishing -//resolvers += Resolver.url("sbt-plugin-releases", new URL("http://scalasbt.artifactoryonline.com/scalasbt/sbt-plugin-releases/"))(Resolver.ivyStylePatterns) - -//addSbtPlugin("com.jsuereth" % "xsbt-gpg-plugin" % "0.6") - addSbtPlugin("net.virtual-void" % "sbt-dependency-graph" % "0.7.4") -addSbtPlugin("org.scalastyle" %% "scalastyle-sbt-plugin" % "0.7.0") +addSbtPlugin("org.scalastyle" %% "scalastyle-sbt-plugin" % "0.8.0") addSbtPlugin("com.typesafe" % "sbt-mima-plugin" % "0.1.6") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala index ab77a764483e..a98e16c25321 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala @@ -734,7 +734,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { val complex = Literal.create( Row( Seq("123", "true", "f"), - Map("a" ->"123", "b" -> "true", "c" -> "f"), + Map("a" -> "123", "b" -> "true", "c" -> "f"), Row(0)), StructType(Seq( StructField("a", diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala index fdd745f48e97..bb3e2786978c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala @@ -862,9 +862,9 @@ private[sql] object ParquetRelation extends Logging { // The parquet compression short names val shortParquetCompressionCodecNames = Map( - "NONE" -> CompressionCodecName.UNCOMPRESSED, + "NONE" -> CompressionCodecName.UNCOMPRESSED, "UNCOMPRESSED" -> CompressionCodecName.UNCOMPRESSED, - "SNAPPY" -> CompressionCodecName.SNAPPY, - "GZIP" -> CompressionCodecName.GZIP, - "LZO" -> CompressionCodecName.LZO) + "SNAPPY" -> CompressionCodecName.SNAPPY, + "GZIP" -> CompressionCodecName.GZIP, + "LZO" -> CompressionCodecName.LZO) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnTypeSuite.scala index 34dd96929e6c..706ff1f99850 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnTypeSuite.scala @@ -35,7 +35,7 @@ class ColumnTypeSuite extends SparkFunSuite with Logging { test("defaultSize") { val checks = Map( - NULL-> 0, BOOLEAN -> 1, BYTE -> 1, SHORT -> 2, INT -> 4, LONG -> 8, + NULL -> 0, BOOLEAN -> 1, BYTE -> 1, SHORT -> 2, INT -> 4, LONG -> 8, FLOAT -> 4, DOUBLE -> 8, COMPACT_DECIMAL(15, 10) -> 8, LARGE_DECIMAL(20, 10) -> 12, STRING -> 8, BINARY -> 16, STRUCT_TYPE -> 20, ARRAY_TYPE -> 16, MAP_TYPE -> 32) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/RunLengthEncodingSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/RunLengthEncodingSuite.scala index ce3affba55c7..95642e93ae9f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/RunLengthEncodingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/RunLengthEncodingSuite.scala @@ -100,11 +100,11 @@ class RunLengthEncodingSuite extends SparkFunSuite { } test(s"$RunLengthEncoding with $typeName: simple case") { - skeleton(2, Seq(0 -> 2, 1 ->2)) + skeleton(2, Seq(0 -> 2, 1 -> 2)) } test(s"$RunLengthEncoding with $typeName: run length == 1") { - skeleton(2, Seq(0 -> 1, 1 ->1)) + skeleton(2, Seq(0 -> 1, 1 -> 1)) } test(s"$RunLengthEncoding with $typeName: single long run") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index 4f2cad19bfb6..4339f7260dcb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -116,7 +116,7 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { // PhysicalRDD(nodeId = 1) -> Project(nodeId = 0) val df = person.select('name) testSparkPlanMetrics(df, 1, Map( - 0L ->("Project", Map( + 0L -> ("Project", Map( "number of rows" -> 2L))) ) } From bbfc16ec9d690c2dfa20896bd6d33f9783b9c109 Mon Sep 17 00:00:00 2001 From: meiyoula <1039320815@qq.com> Date: Fri, 4 Dec 2015 16:50:40 -0800 Subject: [PATCH 1664/1824] [SPARK-12142][CORE]Reply false when container allocator is not ready and reset target Using Dynamic Allocation function, when a new AM is starting, and ExecutorAllocationManager send RequestExecutor message to AM. If the container allocator is not ready, the whole app will hang on Author: meiyoula <1039320815@qq.com> Closes #10138 from XuTingjun/patch-1. 
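The change below turns the request into an explicit acknowledgement: the AM replies false while its container allocator is still initializing, and the ExecutorAllocationManager rolls its target back so a later scheduling round re-sends the request instead of silently losing it. A minimal sketch of that acknowledge-or-roll-back pattern (illustrative names only, not Spark's actual RPC classes):

    // Hypothetical, simplified stand-ins for the AM endpoint and the driver-side manager.
    class AllocatorEndpoint {
      @volatile var allocatorReady: Boolean = false   // flips to true once YARN registration completes

      // AM side: acknowledge only when the allocator exists, otherwise tell the caller "not yet".
      def requestExecutors(target: Int): Boolean = allocatorReady
    }

    class AllocationManager(endpoint: AllocatorEndpoint) {
      private var numExecutorsTarget = 0

      // Driver side: roll the target back when the request is not acknowledged,
      // so the same delta is requested again on the next attempt.
      def updateTarget(newTarget: Int): Int = {
        val oldTarget = numExecutorsTarget
        numExecutorsTarget = newTarget
        if (endpoint.requestExecutors(newTarget)) {
          newTarget - oldTarget   // executors actually requested this round
        } else {
          numExecutorsTarget = oldTarget
          0
        }
      }
    }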
--- .../scala/org/apache/spark/ExecutorAllocationManager.scala | 1 + .../scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala index 6419218f47c8..34c32ce31296 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala @@ -370,6 +370,7 @@ private[spark] class ExecutorAllocationManager( } else { logWarning( s"Unable to reach the cluster manager to request $numExecutorsTarget total executors!") + numExecutorsTarget = oldNumExecutorsTarget 0 } } diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 13ef4dfd6416..1970f7d150fe 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -600,11 +600,12 @@ private[spark] class ApplicationMaster( localityAwareTasks, hostToLocalTaskCount)) { resetAllocatorInterval() } + context.reply(true) case None => logWarning("Container allocator is not ready to request executors yet.") + context.reply(false) } - context.reply(true) case KillExecutors(executorIds) => logInfo(s"Driver requested to kill executor(s) ${executorIds.mkString(", ")}.") From f30373f5ee60f9892c28771e34b208e4f1f675a6 Mon Sep 17 00:00:00 2001 From: rotems Date: Fri, 4 Dec 2015 16:58:31 -0800 Subject: [PATCH 1665/1824] [SPARK-12080][CORE] Kryo - Support multiple user registrators Author: rotems Closes #10078 from Botnaim/KryoMultipleCustomRegistrators. --- .../scala/org/apache/spark/serializer/KryoSerializer.scala | 6 ++++-- docs/configuration.md | 4 ++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index d5ba690ed04b..7b77f78ce6f1 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -70,7 +70,9 @@ class KryoSerializer(conf: SparkConf) private val referenceTracking = conf.getBoolean("spark.kryo.referenceTracking", true) private val registrationRequired = conf.getBoolean("spark.kryo.registrationRequired", false) - private val userRegistrator = conf.getOption("spark.kryo.registrator") + private val userRegistrators = conf.get("spark.kryo.registrator", "") + .split(',') + .filter(!_.isEmpty) private val classesToRegister = conf.get("spark.kryo.classesToRegister", "") .split(',') .filter(!_.isEmpty) @@ -119,7 +121,7 @@ class KryoSerializer(conf: SparkConf) classesToRegister .foreach { className => kryo.register(Class.forName(className, true, classLoader)) } // Allow the user to register their own classes by setting spark.kryo.registrator. 
- userRegistrator + userRegistrators .map(Class.forName(_, true, classLoader).newInstance().asInstanceOf[KryoRegistrator]) .foreach { reg => reg.registerClasses(kryo) } // scalastyle:on classforname diff --git a/docs/configuration.md b/docs/configuration.md index c39b4890851b..fd61ddc244f4 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -647,10 +647,10 @@ Apart from these, the following properties are also available, and may be useful spark.kryo.registrator (none) - If you use Kryo serialization, set this class to register your custom classes with Kryo. This + If you use Kryo serialization, give a comma-separated list of classes that register your custom classes with Kryo. This property is useful if you need to register your classes in a custom way, e.g. to specify a custom field serializer. Otherwise spark.kryo.classesToRegister is simpler. It should be - set to a class that extends + set to classes that extend KryoRegistrator. See the tuning guide for more details. From 3af53e61fd604fe8000e1fdf656d60b79c842d1c Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Fri, 4 Dec 2015 17:02:04 -0800 Subject: [PATCH 1666/1824] [SPARK-12084][CORE] Fix codes that uses ByteBuffer.array incorrectly `ByteBuffer` doesn't guarantee all contents in `ByteBuffer.array` are valid. E.g, a ByteBuffer returned by `ByteBuffer.slice`. We should not use the whole content of `ByteBuffer` unless we know that's correct. This patch fixed all places that use `ByteBuffer.array` incorrectly. Author: Shixiong Zhu Closes #10083 from zsxwing/bytebuffer-array. --- .../network/netty/NettyBlockTransferService.scala | 12 +++--------- .../org/apache/spark/scheduler/DAGScheduler.scala | 6 ++++-- .../scala/org/apache/spark/scheduler/Task.scala | 4 ++-- .../spark/serializer/GenericAvroSerializer.scala | 5 ++++- .../apache/spark/serializer/KryoSerializer.scala | 4 ++-- .../spark/storage/TachyonBlockManager.scala | 2 +- .../main/scala/org/apache/spark/util/Utils.scala | 15 ++++++++++++++- .../unsafe/map/AbstractBytesToBytesMapSuite.java | 5 +++-- .../apache/spark/scheduler/TaskContextSuite.scala | 3 ++- .../pythonconverters/AvroConverters.scala | 5 ++++- .../spark/streaming/flume/FlumeInputDStream.scala | 6 +++--- .../streaming/flume/FlumePollingStreamSuite.scala | 4 ++-- .../spark/streaming/flume/FlumeStreamSuite.scala | 4 ++-- .../streaming/kinesis/KinesisStreamSuite.scala | 3 ++- .../parquet/UnsafeRowParquetRecordReader.java | 7 ++++--- .../spark/sql/execution/SparkSqlSerializer.scala | 3 ++- .../columnar/InMemoryColumnarTableScan.scala | 5 ++++- .../parquet/CatalystRowConverter.scala | 8 ++++---- .../scheduler/ReceivedBlockTracker.scala | 5 +++-- .../streaming/util/BatchedWriteAheadLog.scala | 15 ++++++--------- .../util/FileBasedWriteAheadLogWriter.scala | 14 +++----------- .../spark/streaming/JavaWriteAheadLogSuite.java | 15 +++++++-------- 22 files changed, 81 insertions(+), 69 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala index 82c16e855b0c..40604a4da18d 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala @@ -30,6 +30,7 @@ import org.apache.spark.network.sasl.{SaslClientBootstrap, SaslServerBootstrap} import org.apache.spark.network.server._ import org.apache.spark.network.shuffle.{RetryingBlockFetcher, BlockFetchingListener, 
OneForOneBlockFetcher} import org.apache.spark.network.shuffle.protocol.UploadBlock +import org.apache.spark.network.util.JavaUtils import org.apache.spark.serializer.JavaSerializer import org.apache.spark.storage.{BlockId, StorageLevel} import org.apache.spark.util.Utils @@ -123,17 +124,10 @@ class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManage // StorageLevel is serialized as bytes using our JavaSerializer. Everything else is encoded // using our binary protocol. - val levelBytes = serializer.newInstance().serialize(level).array() + val levelBytes = JavaUtils.bufferToArray(serializer.newInstance().serialize(level)) // Convert or copy nio buffer into array in order to serialize it. - val nioBuffer = blockData.nioByteBuffer() - val array = if (nioBuffer.hasArray) { - nioBuffer.array() - } else { - val data = new Array[Byte](nioBuffer.remaining()) - nioBuffer.get(data) - data - } + val array = JavaUtils.bufferToArray(blockData.nioByteBuffer()) client.sendRpc(new UploadBlock(appId, execId, blockId.toString, levelBytes, array).toByteBuffer, new RpcResponseCallback { diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index e01a9609b9a0..5582720bbcff 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -34,6 +34,7 @@ import org.apache.commons.lang3.SerializationUtils import org.apache.spark._ import org.apache.spark.broadcast.Broadcast import org.apache.spark.executor.TaskMetrics +import org.apache.spark.network.util.JavaUtils import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult} import org.apache.spark.rdd.RDD import org.apache.spark.rpc.RpcTimeout @@ -997,9 +998,10 @@ class DAGScheduler( // For ResultTask, serialize and broadcast (rdd, func). 
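// Background for the change just below: ByteBuffer.array() exposes the whole backing array,
// which may be larger than the buffer's logical [position, limit) window (for example a sliced
// or pooled buffer with a non-zero arrayOffset). JavaUtils.bufferToArray is used instead;
// conceptually it returns the backing array only when it lines up exactly with that window,
// and otherwise copies just the visible bytes, roughly:
//   val out = new Array[Byte](bb.remaining()); bb.get(out); out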
val taskBinaryBytes: Array[Byte] = stage match { case stage: ShuffleMapStage => - closureSerializer.serialize((stage.rdd, stage.shuffleDep): AnyRef).array() + JavaUtils.bufferToArray( + closureSerializer.serialize((stage.rdd, stage.shuffleDep): AnyRef)) case stage: ResultStage => - closureSerializer.serialize((stage.rdd, stage.func): AnyRef).array() + JavaUtils.bufferToArray(closureSerializer.serialize((stage.rdd, stage.func): AnyRef)) } taskBinary = sc.broadcast(taskBinaryBytes) diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index 2fcd5aa57d11..5fe5ae8c4581 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -191,8 +191,8 @@ private[spark] object Task { // Write the task itself and finish dataOut.flush() - val taskBytes = serializer.serialize(task).array() - out.write(taskBytes) + val taskBytes = serializer.serialize(task) + Utils.writeByteBuffer(taskBytes, out) ByteBuffer.wrap(out.toByteArray) } diff --git a/core/src/main/scala/org/apache/spark/serializer/GenericAvroSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/GenericAvroSerializer.scala index 62f8aae7f212..8d6af9cae892 100644 --- a/core/src/main/scala/org/apache/spark/serializer/GenericAvroSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/GenericAvroSerializer.scala @@ -81,7 +81,10 @@ private[serializer] class GenericAvroSerializer(schemas: Map[Long, String]) * seen values so to limit the number of times that decompression has to be done. */ def decompress(schemaBytes: ByteBuffer): Schema = decompressCache.getOrElseUpdate(schemaBytes, { - val bis = new ByteArrayInputStream(schemaBytes.array()) + val bis = new ByteArrayInputStream( + schemaBytes.array(), + schemaBytes.arrayOffset() + schemaBytes.position(), + schemaBytes.remaining()) val bytes = IOUtils.toByteArray(codec.compressedInputStream(bis)) new Schema.Parser().parse(new String(bytes, "UTF-8")) }) diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index 7b77f78ce6f1..62d445f3d7bd 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -309,7 +309,7 @@ private[spark] class KryoSerializerInstance(ks: KryoSerializer) extends Serializ override def deserialize[T: ClassTag](bytes: ByteBuffer): T = { val kryo = borrowKryo() try { - input.setBuffer(bytes.array) + input.setBuffer(bytes.array(), bytes.arrayOffset() + bytes.position(), bytes.remaining()) kryo.readClassAndObject(input).asInstanceOf[T] } finally { releaseKryo(kryo) @@ -321,7 +321,7 @@ private[spark] class KryoSerializerInstance(ks: KryoSerializer) extends Serializ val oldClassLoader = kryo.getClassLoader try { kryo.setClassLoader(loader) - input.setBuffer(bytes.array) + input.setBuffer(bytes.array(), bytes.arrayOffset() + bytes.position(), bytes.remaining()) kryo.readClassAndObject(input).asInstanceOf[T] } finally { kryo.setClassLoader(oldClassLoader) diff --git a/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala index 22878783fca6..d14fe4613528 100644 --- a/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala @@ -103,7 +103,7 @@ 
private[spark] class TachyonBlockManager() extends ExternalBlockManager with Log val file = getFile(blockId) val os = file.getOutStream(WriteType.TRY_CACHE) try { - os.write(bytes.array()) + Utils.writeByteBuffer(bytes, os) } catch { case NonFatal(e) => logWarning(s"Failed to put bytes of block $blockId into Tachyon", e) diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index af632349c9ca..9dbe66e7eefb 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -178,7 +178,20 @@ private[spark] object Utils extends Logging { /** * Primitive often used when writing [[java.nio.ByteBuffer]] to [[java.io.DataOutput]] */ - def writeByteBuffer(bb: ByteBuffer, out: ObjectOutput): Unit = { + def writeByteBuffer(bb: ByteBuffer, out: DataOutput): Unit = { + if (bb.hasArray) { + out.write(bb.array(), bb.arrayOffset() + bb.position(), bb.remaining()) + } else { + val bbval = new Array[Byte](bb.remaining()) + bb.get(bbval) + out.write(bbval) + } + } + + /** + * Primitive often used when writing [[java.nio.ByteBuffer]] to [[java.io.OutputStream]] + */ + def writeByteBuffer(bb: ByteBuffer, out: OutputStream): Unit = { if (bb.hasArray) { out.write(bb.array(), bb.arrayOffset() + bb.position(), bb.remaining()) } else { diff --git a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java index a5c583f9f284..8724a3498842 100644 --- a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java +++ b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java @@ -41,6 +41,7 @@ import org.apache.spark.executor.ShuffleWriteMetrics; import org.apache.spark.memory.TestMemoryManager; import org.apache.spark.memory.TaskMemoryManager; +import org.apache.spark.network.util.JavaUtils; import org.apache.spark.serializer.SerializerInstance; import org.apache.spark.storage.*; import org.apache.spark.unsafe.Platform; @@ -430,7 +431,7 @@ public void randomizedStressTest() { } for (Map.Entry entry : expected.entrySet()) { - final byte[] key = entry.getKey().array(); + final byte[] key = JavaUtils.bufferToArray(entry.getKey()); final byte[] value = entry.getValue(); final BytesToBytesMap.Location loc = map.lookup(key, Platform.BYTE_ARRAY_OFFSET, key.length); @@ -480,7 +481,7 @@ public void randomizedTestWithRecordsLargerThanPageSize() { } } for (Map.Entry entry : expected.entrySet()) { - final byte[] key = entry.getKey().array(); + final byte[] key = JavaUtils.bufferToArray(entry.getKey()); final byte[] value = entry.getValue(); final BytesToBytesMap.Location loc = map.lookup(key, Platform.BYTE_ARRAY_OFFSET, key.length); diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala index 450ab7b9fe92..d83d0aee4225 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala @@ -23,6 +23,7 @@ import org.mockito.Matchers.any import org.scalatest.BeforeAndAfter import org.apache.spark._ +import org.apache.spark.network.util.JavaUtils import org.apache.spark.rdd.RDD import org.apache.spark.util.{TaskCompletionListener, TaskCompletionListenerException} import org.apache.spark.metrics.source.JvmSource @@ -57,7 +58,7 @@ class TaskContextSuite extends 
SparkFunSuite with BeforeAndAfter with LocalSpark } val closureSerializer = SparkEnv.get.closureSerializer.newInstance() val func = (c: TaskContext, i: Iterator[String]) => i.next() - val taskBinary = sc.broadcast(closureSerializer.serialize((rdd, func)).array) + val taskBinary = sc.broadcast(JavaUtils.bufferToArray(closureSerializer.serialize((rdd, func)))) val task = new ResultTask[String, String]( 0, 0, taskBinary, rdd.partitions(0), Seq.empty, 0, Seq.empty) intercept[RuntimeException] { diff --git a/examples/src/main/scala/org/apache/spark/examples/pythonconverters/AvroConverters.scala b/examples/src/main/scala/org/apache/spark/examples/pythonconverters/AvroConverters.scala index 805184e740f0..cf12c98b4af6 100644 --- a/examples/src/main/scala/org/apache/spark/examples/pythonconverters/AvroConverters.scala +++ b/examples/src/main/scala/org/apache/spark/examples/pythonconverters/AvroConverters.scala @@ -79,7 +79,10 @@ object AvroConversionUtil extends Serializable { def unpackBytes(obj: Any): Array[Byte] = { val bytes: Array[Byte] = obj match { - case buf: java.nio.ByteBuffer => buf.array() + case buf: java.nio.ByteBuffer => + val arr = new Array[Byte](buf.remaining()) + buf.get(arr) + arr case arr: Array[Byte] => arr case other => throw new SparkException( s"Unknown BYTES type ${other.getClass.getName}") diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala index c8780aa83bdb..2b9116eb3c79 100644 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala @@ -93,9 +93,9 @@ class SparkFlumeEvent() extends Externalizable { /* Serialize to bytes. 
*/ def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException { - val body = event.getBody.array() - out.writeInt(body.length) - out.write(body) + val body = event.getBody + out.writeInt(body.remaining()) + Utils.writeByteBuffer(body, out) val numHeaders = event.getHeaders.size() out.writeInt(numHeaders) diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala index 5fd2711f5f7d..bb951a6ef100 100644 --- a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala +++ b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala @@ -24,11 +24,11 @@ import scala.collection.mutable.{SynchronizedBuffer, ArrayBuffer} import scala.concurrent.duration._ import scala.language.postfixOps -import com.google.common.base.Charsets.UTF_8 import org.scalatest.BeforeAndAfter import org.scalatest.concurrent.Eventually._ import org.apache.spark.{Logging, SparkConf, SparkFunSuite} +import org.apache.spark.network.util.JavaUtils import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.dstream.ReceiverInputDStream import org.apache.spark.streaming.{Seconds, TestOutputStream, StreamingContext} @@ -119,7 +119,7 @@ class FlumePollingStreamSuite extends SparkFunSuite with BeforeAndAfter with Log val headers = flattenOutputBuffer.map(_.event.getHeaders.asScala.map { case (key, value) => (key.toString, value.toString) }).map(_.asJava) - val bodies = flattenOutputBuffer.map(e => new String(e.event.getBody.array(), UTF_8)) + val bodies = flattenOutputBuffer.map(e => JavaUtils.bytesToString(e.event.getBody)) utils.assertOutput(headers.asJava, bodies.asJava) } } finally { diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala index f315e0a7ca23..b29e591c0737 100644 --- a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala +++ b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala @@ -22,7 +22,6 @@ import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer} import scala.concurrent.duration._ import scala.language.postfixOps -import com.google.common.base.Charsets import org.jboss.netty.channel.ChannelPipeline import org.jboss.netty.channel.socket.SocketChannel import org.jboss.netty.channel.socket.nio.NioClientSocketChannelFactory @@ -31,6 +30,7 @@ import org.scalatest.{BeforeAndAfter, Matchers} import org.scalatest.concurrent.Eventually._ import org.apache.spark.{Logging, SparkConf, SparkFunSuite} +import org.apache.spark.network.util.JavaUtils import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.{Milliseconds, StreamingContext, TestOutputStream} @@ -63,7 +63,7 @@ class FlumeStreamSuite extends SparkFunSuite with BeforeAndAfter with Matchers w event => event.getHeaders.get("test") should be("header") } - val output = outputEvents.map(event => new String(event.getBody.array(), Charsets.UTF_8)) + val output = outputEvents.map(event => JavaUtils.bytesToString(event.getBody)) output should be (input) } } finally { diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala index 78cec021b78c..6fe24fe81165 100644 --- 
a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala +++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala @@ -29,6 +29,7 @@ import org.scalatest.Matchers._ import org.scalatest.concurrent.Eventually import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} +import org.apache.spark.network.util.JavaUtils import org.apache.spark.rdd.RDD import org.apache.spark.storage.{StorageLevel, StreamBlockId} import org.apache.spark.streaming._ @@ -196,7 +197,7 @@ abstract class KinesisStreamTests(aggregateTestData: Boolean) extends KinesisFun testIfEnabled("custom message handling") { val awsCredentials = KinesisTestUtils.getAWSCredentials() - def addFive(r: Record): Int = new String(r.getData.array()).toInt + 5 + def addFive(r: Record): Int = JavaUtils.bytesToString(r.getData).toInt + 5 val stream = KinesisUtils.createStream(ssc, appName, testUtils.streamName, testUtils.endpointUrl, testUtils.regionName, InitialPositionInStream.LATEST, Seconds(10), StorageLevel.MEMORY_ONLY, addFive, diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java index dade488ca281..0cc4566c9cdd 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java @@ -332,12 +332,13 @@ private void decodeBinaryBatch(int col, int num) throws IOException { for (int n = 0; n < num; ++n) { if (columnReaders[col].next()) { ByteBuffer bytes = columnReaders[col].nextBinary().toByteBuffer(); - int len = bytes.limit() - bytes.position(); + int len = bytes.remaining(); if (originalTypes[col] == OriginalType.UTF8) { - UTF8String str = UTF8String.fromBytes(bytes.array(), bytes.position(), len); + UTF8String str = + UTF8String.fromBytes(bytes.array(), bytes.arrayOffset() + bytes.position(), len); rowWriters[n].write(col, str); } else { - rowWriters[n].write(col, bytes.array(), bytes.position(), len); + rowWriters[n].write(col, bytes.array(), bytes.arrayOffset() + bytes.position(), len); } rows[n].setNotNullAt(col); } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala index 8317f648ccb4..45a8e0324826 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala @@ -26,6 +26,7 @@ import com.esotericsoftware.kryo.io.{Input, Output} import com.esotericsoftware.kryo.{Kryo, Serializer} import com.twitter.chill.ResourcePool +import org.apache.spark.network.util.JavaUtils import org.apache.spark.serializer.{KryoSerializer, SerializerInstance} import org.apache.spark.sql.types.Decimal import org.apache.spark.util.MutablePair @@ -76,7 +77,7 @@ private[sql] object SparkSqlSerializer { def serialize[T: ClassTag](o: T): Array[Byte] = acquireRelease { k => - k.serialize(o).array() + JavaUtils.bufferToArray(k.serialize(o)) } def deserialize[T: ClassTag](bytes: Array[Byte]): T = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala index 
ce701fb3a7f2..3c5a8cb2aa93 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.columnar import scala.collection.mutable.ArrayBuffer +import org.apache.spark.network.util.JavaUtils import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation @@ -163,7 +164,9 @@ private[sql] case class InMemoryRelation( .flatMap(_.values)) batchStats += stats - CachedBatch(rowCount, columnBuilders.map(_.build().array()), stats) + CachedBatch(rowCount, columnBuilders.map { builder => + JavaUtils.bufferToArray(builder.build()) + }, stats) } def hasNext: Boolean = rowIterator.hasNext diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala index 94298fae2d69..8851bc23cd05 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala @@ -327,8 +327,8 @@ private[parquet] class CatalystRowConverter( // are using `Binary.toByteBuffer.array()` to steal the underlying byte array without copying // it. val buffer = value.toByteBuffer - val offset = buffer.position() - val numBytes = buffer.limit() - buffer.position() + val offset = buffer.arrayOffset() + buffer.position() + val numBytes = buffer.remaining() updater.set(UTF8String.fromBytes(buffer.array(), offset, numBytes)) } } @@ -644,8 +644,8 @@ private[parquet] object CatalystRowConverter { // copying it. 
val buffer = binary.toByteBuffer val bytes = buffer.array() - val start = buffer.position() - val end = buffer.limit() + val start = buffer.arrayOffset() + buffer.position() + val end = buffer.arrayOffset() + buffer.limit() var unscaled = 0L var i = start diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala index 500dc70c9850..4dab64d696b3 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala @@ -27,6 +27,7 @@ import scala.util.control.NonFatal import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path +import org.apache.spark.network.util.JavaUtils import org.apache.spark.streaming.Time import org.apache.spark.streaming.util.{BatchedWriteAheadLog, WriteAheadLog, WriteAheadLogUtils} import org.apache.spark.util.{Clock, Utils} @@ -210,9 +211,9 @@ private[streaming] class ReceivedBlockTracker( writeAheadLogOption.foreach { writeAheadLog => logInfo(s"Recovering from write ahead logs in ${checkpointDirOption.get}") writeAheadLog.readAll().asScala.foreach { byteBuffer => - logTrace("Recovering record " + byteBuffer) + logInfo("Recovering record " + byteBuffer) Utils.deserialize[ReceivedBlockTrackerLogEvent]( - byteBuffer.array, Thread.currentThread().getContextClassLoader) match { + JavaUtils.bufferToArray(byteBuffer), Thread.currentThread().getContextClassLoader) match { case BlockAdditionEvent(receivedBlockInfo) => insertAddedBlock(receivedBlockInfo) case BatchAllocationEvent(time, allocatedBlocks) => diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/BatchedWriteAheadLog.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/BatchedWriteAheadLog.scala index 6e6ed8d81972..7158abc08894 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/BatchedWriteAheadLog.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/BatchedWriteAheadLog.scala @@ -28,6 +28,7 @@ import scala.concurrent.duration._ import scala.util.control.NonFatal import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.network.util.JavaUtils import org.apache.spark.util.Utils /** @@ -197,17 +198,10 @@ private[util] object BatchedWriteAheadLog { */ case class Record(data: ByteBuffer, time: Long, promise: Promise[WriteAheadLogRecordHandle]) - /** Copies the byte array of a ByteBuffer. */ - private def getByteArray(buffer: ByteBuffer): Array[Byte] = { - val byteArray = new Array[Byte](buffer.remaining()) - buffer.get(byteArray) - byteArray - } - /** Aggregate multiple serialized ReceivedBlockTrackerLogEvents in a single ByteBuffer. */ def aggregate(records: Seq[Record]): ByteBuffer = { ByteBuffer.wrap(Utils.serialize[Array[Array[Byte]]]( - records.map(record => getByteArray(record.data)).toArray)) + records.map(record => JavaUtils.bufferToArray(record.data)).toArray)) } /** @@ -216,10 +210,13 @@ private[util] object BatchedWriteAheadLog { * method therefore needs to be backwards compatible. 
*/ def deaggregate(buffer: ByteBuffer): Array[ByteBuffer] = { + val prevPosition = buffer.position() try { - Utils.deserialize[Array[Array[Byte]]](getByteArray(buffer)).map(ByteBuffer.wrap) + Utils.deserialize[Array[Array[Byte]]](JavaUtils.bufferToArray(buffer)).map(ByteBuffer.wrap) } catch { case _: ClassCastException => // users may restart a stream with batching enabled + // Restore `position` so that the user can read `buffer` later + buffer.position(prevPosition) Array(buffer) } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLogWriter.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLogWriter.scala index e146bec32a45..1185f30265f6 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLogWriter.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLogWriter.scala @@ -24,6 +24,8 @@ import scala.util.Try import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.FSDataOutputStream +import org.apache.spark.util.Utils + /** * A writer for writing byte-buffers to a write ahead log file. */ @@ -48,17 +50,7 @@ private[streaming] class FileBasedWriteAheadLogWriter(path: String, hadoopConf: val lengthToWrite = data.remaining() val segment = new FileBasedWriteAheadLogSegment(path, nextOffset, lengthToWrite) stream.writeInt(lengthToWrite) - if (data.hasArray) { - stream.write(data.array()) - } else { - // If the buffer is not backed by an array, we transfer using temp array - // Note that despite the extra array copy, this should be faster than byte-by-byte copy - while (data.hasRemaining) { - val array = new Array[Byte](data.remaining) - data.get(array) - stream.write(array) - } - } + Utils.writeByteBuffer(data, stream: OutputStream) flush() nextOffset = stream.getPos() segment diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaWriteAheadLogSuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaWriteAheadLogSuite.java index 09b5f8ed0327..f02fa87f6194 100644 --- a/streaming/src/test/java/org/apache/spark/streaming/JavaWriteAheadLogSuite.java +++ b/streaming/src/test/java/org/apache/spark/streaming/JavaWriteAheadLogSuite.java @@ -17,7 +17,6 @@ package org.apache.spark.streaming; -import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.nio.ByteBuffer; import java.util.Arrays; @@ -27,6 +26,7 @@ import com.google.common.base.Function; import com.google.common.collect.Iterators; import org.apache.spark.SparkConf; +import org.apache.spark.network.util.JavaUtils; import org.apache.spark.streaming.util.WriteAheadLog; import org.apache.spark.streaming.util.WriteAheadLogRecordHandle; import org.apache.spark.streaming.util.WriteAheadLogUtils; @@ -112,20 +112,19 @@ public void testCustomWAL() { WriteAheadLog wal = WriteAheadLogUtils.createLogForDriver(conf, null, null); String data1 = "data1"; - WriteAheadLogRecordHandle handle = - wal.write(ByteBuffer.wrap(data1.getBytes(StandardCharsets.UTF_8)), 1234); + WriteAheadLogRecordHandle handle = wal.write(JavaUtils.stringToBytes(data1), 1234); Assert.assertTrue(handle instanceof JavaWriteAheadLogSuiteHandle); - Assert.assertEquals(new String(wal.read(handle).array(), StandardCharsets.UTF_8), data1); + Assert.assertEquals(JavaUtils.bytesToString(wal.read(handle)), data1); - wal.write(ByteBuffer.wrap("data2".getBytes(StandardCharsets.UTF_8)), 1235); - wal.write(ByteBuffer.wrap("data3".getBytes(StandardCharsets.UTF_8)), 1236); - 
wal.write(ByteBuffer.wrap("data4".getBytes(StandardCharsets.UTF_8)), 1237); + wal.write(JavaUtils.stringToBytes("data2"), 1235); + wal.write(JavaUtils.stringToBytes("data3"), 1236); + wal.write(JavaUtils.stringToBytes("data4"), 1237); wal.clean(1236, false); Iterator dataIterator = wal.readAll(); List readData = new ArrayList<>(); while (dataIterator.hasNext()) { - readData.add(new String(dataIterator.next().array(), StandardCharsets.UTF_8)); + readData.add(JavaUtils.bytesToString(dataIterator.next())); } Assert.assertEquals(readData, Arrays.asList("data3", "data4")); } From ee94b70ce56661ea26c5aad17778ade32f3f1d3d Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Sat, 5 Dec 2015 15:27:31 +0000 Subject: [PATCH 1667/1824] [SPARK-12096][MLLIB] remove the old constraint in word2vec jira: https://issues.apache.org/jira/browse/SPARK-12096 word2vec now can handle much bigger vocabulary. The old constraint vocabSize.toLong * vectorSize < Ine.max / 8 should be removed. new constraint is vocabSize.toLong * vectorSize < max array length (usually a little less than Int.MaxValue) I tested with vocabsize over 18M and vectorsize = 100. srowen jkbradley Sorry to miss this in last PR. I was reminded today. Author: Yuhao Yang Closes #10103 from hhbyyh/w2vCapacity. --- .../main/scala/org/apache/spark/mllib/feature/Word2Vec.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index 655ac0bb5545..be12d4528603 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -306,10 +306,10 @@ class Word2Vec extends Serializable with Logging { val newSentences = sentences.repartition(numPartitions).cache() val initRandom = new XORShiftRandom(seed) - if (vocabSize.toLong * vectorSize * 8 >= Int.MaxValue) { + if (vocabSize.toLong * vectorSize >= Int.MaxValue) { throw new RuntimeException("Please increase minCount or decrease vectorSize in Word2Vec" + " to avoid an OOM. You are highly recommended to make your vocabSize*vectorSize, " + - "which is " + vocabSize + "*" + vectorSize + " for now, less than `Int.MaxValue/8`.") + "which is " + vocabSize + "*" + vectorSize + " for now, less than `Int.MaxValue`.") } val syn0Global = From e9c9ae22b96e08e5bb40a029e84d342efb1aec0c Mon Sep 17 00:00:00 2001 From: Antonio Murgia Date: Sat, 5 Dec 2015 15:42:02 +0000 Subject: [PATCH 1668/1824] [SPARK-11994][MLLIB] Word2VecModel load and save cause SparkException when model is bigger than spark.kryoserializer.buffer.max Author: Antonio Murgia Closes #9989 from tmnd1991/SPARK-11932. 
--- .../apache/spark/mllib/feature/Word2Vec.scala | 16 ++++++++++++---- .../spark/mllib/feature/Word2VecSuite.scala | 19 +++++++++++++++++++ 2 files changed, 31 insertions(+), 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index be12d4528603..b693f3c8e4bd 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -604,13 +604,21 @@ object Word2VecModel extends Loader[Word2VecModel] { val vectorSize = model.values.head.size val numWords = model.size - val metadata = compact(render - (("class" -> classNameV1_0) ~ ("version" -> formatVersionV1_0) ~ - ("vectorSize" -> vectorSize) ~ ("numWords" -> numWords))) + val metadata = compact(render( + ("class" -> classNameV1_0) ~ ("version" -> formatVersionV1_0) ~ + ("vectorSize" -> vectorSize) ~ ("numWords" -> numWords))) sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path)) + // We want to partition the model in partitions of size 32MB + val partitionSize = (1L << 25) + // We calculate the approximate size of the model + // We only calculate the array size, not considering + // the string size, the formula is: + // floatSize * numWords * vectorSize + val approxSize = 4L * numWords * vectorSize + val nPartitions = ((approxSize / partitionSize) + 1).toInt val dataArray = model.toSeq.map { case (w, v) => Data(w, v) } - sc.parallelize(dataArray.toSeq, 1).toDF().write.parquet(Loader.dataPath(path)) + sc.parallelize(dataArray.toSeq, nPartitions).toDF().write.parquet(Loader.dataPath(path)) } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala index a864eec460f2..37d01e287669 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala @@ -92,4 +92,23 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext { } } + + test("big model load / save") { + // create a model bigger than 32MB since 9000 * 1000 * 4 > 2^25 + val word2VecMap = Map((0 to 9000).map(i => s"$i" -> Array.fill(1000)(0.1f)): _*) + val model = new Word2VecModel(word2VecMap) + + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + + try { + model.save(sc, path) + val sameModel = Word2VecModel.load(sc, path) + assert(sameModel.getVectors.mapValues(_.toSeq) === model.getVectors.mapValues(_.toSeq)) + } finally { + Utils.deleteRecursively(tempDir) + } + } + + } From 7da674851928ed23eb651a3e2f8233e7a684ac41 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Sat, 5 Dec 2015 15:52:52 +0000 Subject: [PATCH 1669/1824] [SPARK-11988][ML][MLLIB] Update JPMML to 1.2.7 Update JPMML pmml-model to 1.2.7 Author: Sean Owen Closes #9972 from srowen/SPARK-11988. 
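Most of the diff below is mechanical: pmml-model 1.2 replaces the chained withX(...) mutators with JavaBean-style setX(...) for single-valued fields and addX(...) for collection fields. A condensed before/after sketch of that rename, using only classes and calls that appear in these hunks:

    import org.dmg.pmml.{MiningFunctionType, MiningSchema, RegressionModel, RegressionTable}

    val miningSchema = new MiningSchema()
    val table = new RegressionTable(0.0)   // intercept

    // pmml-model 1.1.x style (old):
    //   new RegressionModel()
    //     .withFunctionName(MiningFunctionType.REGRESSION)
    //     .withMiningSchema(miningSchema)
    //     .withRegressionTables(table)

    // pmml-model 1.2.x style (new): setX for single-valued fields, addX for collections.
    val model = new RegressionModel()
      .setFunctionName(MiningFunctionType.REGRESSION)
      .setMiningSchema(miningSchema)
      .addRegressionTables(table)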
--- LICENSE | 3 +- mllib/pom.xml | 2 +- .../BinaryClassificationPMMLModelExport.scala | 32 +++++++------- .../GeneralizedLinearPMMLModelExport.scala | 26 +++++------ .../pmml/export/KMeansPMMLModelExport.scala | 44 +++++++++---------- .../mllib/pmml/export/PMMLModelExport.scala | 17 +++---- 6 files changed, 59 insertions(+), 65 deletions(-) diff --git a/LICENSE b/LICENSE index 0db2d14465bd..a2f75b817ab3 100644 --- a/LICENSE +++ b/LICENSE @@ -1,4 +1,3 @@ - Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ @@ -237,7 +236,7 @@ The following components are provided under a BSD-style license. See project lin The text of each license is also included at licenses/LICENSE-[project].txt. (BSD 3 Clause) netlib core (com.github.fommil.netlib:core:1.1.2 - https://github.com/fommil/netlib-java/core) - (BSD 3 Clause) JPMML-Model (org.jpmml:pmml-model:1.1.15 - https://github.com/jpmml/jpmml-model) + (BSD 3 Clause) JPMML-Model (org.jpmml:pmml-model:1.2.7 - https://github.com/jpmml/jpmml-model) (BSD 3-clause style license) jblas (org.jblas:jblas:1.2.4 - http://jblas.org/) (BSD License) AntLR Parser Generator (antlr:antlr:2.7.7 - http://www.antlr.org/) (BSD licence) ANTLR ST4 4.0.4 (org.antlr:ST4:4.0.4 - http://www.stringtemplate.org) diff --git a/mllib/pom.xml b/mllib/pom.xml index 70139121d8c7..df50aca1a3f7 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -109,7 +109,7 @@ org.jpmml pmml-model - 1.1.15 + 1.2.7 com.sun.xml.fastinfoset diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExport.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExport.scala index 622b53a252ac..7abb1bf7ce96 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExport.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExport.scala @@ -45,7 +45,7 @@ private[mllib] class BinaryClassificationPMMLModelExport( val fields = new SArray[FieldName](model.weights.size) val dataDictionary = new DataDictionary val miningSchema = new MiningSchema - val regressionTableYES = new RegressionTable(model.intercept).withTargetCategory("1") + val regressionTableYES = new RegressionTable(model.intercept).setTargetCategory("1") var interceptNO = threshold if (RegressionNormalizationMethodType.LOGIT == normalizationMethod) { if (threshold <= 0) { @@ -56,35 +56,35 @@ private[mllib] class BinaryClassificationPMMLModelExport( interceptNO = -math.log(1 / threshold - 1) } } - val regressionTableNO = new RegressionTable(interceptNO).withTargetCategory("0") + val regressionTableNO = new RegressionTable(interceptNO).setTargetCategory("0") val regressionModel = new RegressionModel() - .withFunctionName(MiningFunctionType.CLASSIFICATION) - .withMiningSchema(miningSchema) - .withModelName(description) - .withNormalizationMethod(normalizationMethod) - .withRegressionTables(regressionTableYES, regressionTableNO) + .setFunctionName(MiningFunctionType.CLASSIFICATION) + .setMiningSchema(miningSchema) + .setModelName(description) + .setNormalizationMethod(normalizationMethod) + .addRegressionTables(regressionTableYES, regressionTableNO) for (i <- 0 until model.weights.size) { fields(i) = FieldName.create("field_" + i) - dataDictionary.withDataFields(new DataField(fields(i), OpType.CONTINUOUS, DataType.DOUBLE)) + dataDictionary.addDataFields(new DataField(fields(i), OpType.CONTINUOUS, DataType.DOUBLE)) miningSchema - .withMiningFields(new MiningField(fields(i)) - 
.withUsageType(FieldUsageType.ACTIVE)) - regressionTableYES.withNumericPredictors(new NumericPredictor(fields(i), model.weights(i))) + .addMiningFields(new MiningField(fields(i)) + .setUsageType(FieldUsageType.ACTIVE)) + regressionTableYES.addNumericPredictors(new NumericPredictor(fields(i), model.weights(i))) } // add target field val targetField = FieldName.create("target") dataDictionary - .withDataFields(new DataField(targetField, OpType.CATEGORICAL, DataType.STRING)) + .addDataFields(new DataField(targetField, OpType.CATEGORICAL, DataType.STRING)) miningSchema - .withMiningFields(new MiningField(targetField) - .withUsageType(FieldUsageType.TARGET)) + .addMiningFields(new MiningField(targetField) + .setUsageType(FieldUsageType.TARGET)) - dataDictionary.withNumberOfFields(dataDictionary.getDataFields.size) + dataDictionary.setNumberOfFields(dataDictionary.getDataFields.size) pmml.setDataDictionary(dataDictionary) - pmml.withModels(regressionModel) + pmml.addModels(regressionModel) } } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/GeneralizedLinearPMMLModelExport.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/GeneralizedLinearPMMLModelExport.scala index 1874786af000..4d951d2973a6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/GeneralizedLinearPMMLModelExport.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/GeneralizedLinearPMMLModelExport.scala @@ -45,31 +45,31 @@ private[mllib] class GeneralizedLinearPMMLModelExport( val miningSchema = new MiningSchema val regressionTable = new RegressionTable(model.intercept) val regressionModel = new RegressionModel() - .withFunctionName(MiningFunctionType.REGRESSION) - .withMiningSchema(miningSchema) - .withModelName(description) - .withRegressionTables(regressionTable) + .setFunctionName(MiningFunctionType.REGRESSION) + .setMiningSchema(miningSchema) + .setModelName(description) + .addRegressionTables(regressionTable) for (i <- 0 until model.weights.size) { fields(i) = FieldName.create("field_" + i) - dataDictionary.withDataFields(new DataField(fields(i), OpType.CONTINUOUS, DataType.DOUBLE)) + dataDictionary.addDataFields(new DataField(fields(i), OpType.CONTINUOUS, DataType.DOUBLE)) miningSchema - .withMiningFields(new MiningField(fields(i)) - .withUsageType(FieldUsageType.ACTIVE)) - regressionTable.withNumericPredictors(new NumericPredictor(fields(i), model.weights(i))) + .addMiningFields(new MiningField(fields(i)) + .setUsageType(FieldUsageType.ACTIVE)) + regressionTable.addNumericPredictors(new NumericPredictor(fields(i), model.weights(i))) } // for completeness add target field val targetField = FieldName.create("target") - dataDictionary.withDataFields(new DataField(targetField, OpType.CONTINUOUS, DataType.DOUBLE)) + dataDictionary.addDataFields(new DataField(targetField, OpType.CONTINUOUS, DataType.DOUBLE)) miningSchema - .withMiningFields(new MiningField(targetField) - .withUsageType(FieldUsageType.TARGET)) + .addMiningFields(new MiningField(targetField) + .setUsageType(FieldUsageType.TARGET)) - dataDictionary.withNumberOfFields(dataDictionary.getDataFields.size) + dataDictionary.setNumberOfFields(dataDictionary.getDataFields.size) pmml.setDataDictionary(dataDictionary) - pmml.withModels(regressionModel) + pmml.addModels(regressionModel) } } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExport.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExport.scala index 069e7afc9fca..b5b824bb9c9b 
100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExport.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExport.scala @@ -42,42 +42,42 @@ private[mllib] class KMeansPMMLModelExport(model : KMeansModel) extends PMMLMode val dataDictionary = new DataDictionary val miningSchema = new MiningSchema val comparisonMeasure = new ComparisonMeasure() - .withKind(ComparisonMeasure.Kind.DISTANCE) - .withMeasure(new SquaredEuclidean()) + .setKind(ComparisonMeasure.Kind.DISTANCE) + .setMeasure(new SquaredEuclidean()) val clusteringModel = new ClusteringModel() - .withModelName("k-means") - .withMiningSchema(miningSchema) - .withComparisonMeasure(comparisonMeasure) - .withFunctionName(MiningFunctionType.CLUSTERING) - .withModelClass(ClusteringModel.ModelClass.CENTER_BASED) - .withNumberOfClusters(model.clusterCenters.length) + .setModelName("k-means") + .setMiningSchema(miningSchema) + .setComparisonMeasure(comparisonMeasure) + .setFunctionName(MiningFunctionType.CLUSTERING) + .setModelClass(ClusteringModel.ModelClass.CENTER_BASED) + .setNumberOfClusters(model.clusterCenters.length) for (i <- 0 until clusterCenter.size) { fields(i) = FieldName.create("field_" + i) - dataDictionary.withDataFields(new DataField(fields(i), OpType.CONTINUOUS, DataType.DOUBLE)) + dataDictionary.addDataFields(new DataField(fields(i), OpType.CONTINUOUS, DataType.DOUBLE)) miningSchema - .withMiningFields(new MiningField(fields(i)) - .withUsageType(FieldUsageType.ACTIVE)) - clusteringModel.withClusteringFields( - new ClusteringField(fields(i)).withCompareFunction(CompareFunctionType.ABS_DIFF)) + .addMiningFields(new MiningField(fields(i)) + .setUsageType(FieldUsageType.ACTIVE)) + clusteringModel.addClusteringFields( + new ClusteringField(fields(i)).setCompareFunction(CompareFunctionType.ABS_DIFF)) } - dataDictionary.withNumberOfFields(dataDictionary.getDataFields.size) + dataDictionary.setNumberOfFields(dataDictionary.getDataFields.size) - for (i <- 0 until model.clusterCenters.length) { + for (i <- model.clusterCenters.indices) { val cluster = new Cluster() - .withName("cluster_" + i) - .withArray(new org.dmg.pmml.Array() - .withType(Array.Type.REAL) - .withN(clusterCenter.size) - .withValue(model.clusterCenters(i).toArray.mkString(" "))) + .setName("cluster_" + i) + .setArray(new org.dmg.pmml.Array() + .setType(Array.Type.REAL) + .setN(clusterCenter.size) + .setValue(model.clusterCenters(i).toArray.mkString(" "))) // we don't have the size of the single cluster but only the centroids (withValue) // .withSize(value) - clusteringModel.withClusters(cluster) + clusteringModel.addClusters(cluster) } pmml.setDataDictionary(dataDictionary) - pmml.withModels(clusteringModel) + pmml.addModels(clusteringModel) } } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExport.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExport.scala index 9267e6dbdb85..426bb818c926 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExport.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExport.scala @@ -30,19 +30,14 @@ private[mllib] trait PMMLModelExport { * Holder of the exported model in PMML format */ @BeanProperty - val pmml: PMML = new PMML - - pmml.setVersion("4.2") - setHeader(pmml) - - private def setHeader(pmml: PMML): Unit = { + val pmml: PMML = { val version = getClass.getPackage.getImplementationVersion - val app = new Application().withName("Apache Spark 
MLlib").withVersion(version) + val app = new Application("Apache Spark MLlib").setVersion(version) val timestamp = new Timestamp() - .withContent(new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss").format(new Date())) + .addContent(new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss").format(new Date())) val header = new Header() - .withApplication(app) - .withTimestamp(timestamp) - pmml.setHeader(header) + .setApplication(app) + .setTimestamp(timestamp) + new PMML("4.2", header, null) } } From c8d0e160dadf3b23c5caa379ba9ad5547794eaa0 Mon Sep 17 00:00:00 2001 From: Sun Rui Date: Sat, 5 Dec 2015 15:49:51 -0800 Subject: [PATCH 1670/1824] [SPARK-11774][SPARKR] Implement struct(), encode(), decode() functions in SparkR. Author: Sun Rui Closes #9804 from sun-rui/SPARK-11774. --- R/pkg/NAMESPACE | 3 ++ R/pkg/R/functions.R | 59 ++++++++++++++++++++++++++++++++ R/pkg/R/generics.R | 12 +++++++ R/pkg/inst/tests/test_sparkSQL.R | 37 ++++++++++++++++---- 4 files changed, 105 insertions(+), 6 deletions(-) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 43e5e0119e7f..565a2b1a68b5 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -134,8 +134,10 @@ exportMethods("%in%", "datediff", "dayofmonth", "dayofyear", + "decode", "dense_rank", "desc", + "encode", "endsWith", "exp", "explode", @@ -225,6 +227,7 @@ exportMethods("%in%", "stddev", "stddev_pop", "stddev_samp", + "struct", "sqrt", "startsWith", "substr", diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index b30331c61c9a..7432cb8e7ccf 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -357,6 +357,40 @@ setMethod("dayofyear", column(jc) }) +#' decode +#' +#' Computes the first argument into a string from a binary using the provided character set +#' (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16'). +#' +#' @rdname decode +#' @name decode +#' @family string_funcs +#' @export +#' @examples \dontrun{decode(df$c, "UTF-8")} +setMethod("decode", + signature(x = "Column", charset = "character"), + function(x, charset) { + jc <- callJStatic("org.apache.spark.sql.functions", "decode", x@jc, charset) + column(jc) + }) + +#' encode +#' +#' Computes the first argument into a binary from a string using the provided character set +#' (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16'). +#' +#' @rdname encode +#' @name encode +#' @family string_funcs +#' @export +#' @examples \dontrun{encode(df$c, "UTF-8")} +setMethod("encode", + signature(x = "Column", charset = "character"), + function(x, charset) { + jc <- callJStatic("org.apache.spark.sql.functions", "encode", x@jc, charset) + column(jc) + }) + #' exp #' #' Computes the exponential of the given value. @@ -1039,6 +1073,31 @@ setMethod("stddev_samp", column(jc) }) +#' struct +#' +#' Creates a new struct column that composes multiple input columns. +#' +#' @rdname struct +#' @name struct +#' @family normal_funcs +#' @export +#' @examples +#' \dontrun{ +#' struct(df$c, df$d) +#' struct("col1", "col2") +#' } +setMethod("struct", + signature(x = "characterOrColumn"), + function(x, ...) { + if (class(x) == "Column") { + jcols <- lapply(list(x, ...), function(x) { x@jc }) + jc <- callJStatic("org.apache.spark.sql.functions", "struct", jcols) + } else { + jc <- callJStatic("org.apache.spark.sql.functions", "struct", x, list(...)) + } + column(jc) + }) + #' sqrt #' #' Computes the square root of the specified float value. 
diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 711ce38f9e10..4b5f786d3946 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -744,10 +744,18 @@ setGeneric("dayofmonth", function(x) { standardGeneric("dayofmonth") }) #' @export setGeneric("dayofyear", function(x) { standardGeneric("dayofyear") }) +#' @rdname decode +#' @export +setGeneric("decode", function(x, charset) { standardGeneric("decode") }) + #' @rdname dense_rank #' @export setGeneric("dense_rank", function(x) { standardGeneric("dense_rank") }) +#' @rdname encode +#' @export +setGeneric("encode", function(x, charset) { standardGeneric("encode") }) + #' @rdname explode #' @export setGeneric("explode", function(x) { standardGeneric("explode") }) @@ -1001,6 +1009,10 @@ setGeneric("stddev_pop", function(x) { standardGeneric("stddev_pop") }) #' @export setGeneric("stddev_samp", function(x) { standardGeneric("stddev_samp") }) +#' @rdname struct +#' @export +setGeneric("struct", function(x, ...) { standardGeneric("struct") }) + #' @rdname substring_index #' @export setGeneric("substring_index", function(x, delim, count) { standardGeneric("substring_index") }) diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index 1e7cb5409970..2d26b92ac727 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -27,6 +27,11 @@ checkStructField <- function(actual, expectedName, expectedType, expectedNullabl expect_equal(actual$nullable(), expectedNullable) } +markUtf8 <- function(s) { + Encoding(s) <- "UTF-8" + s +} + # Tests for SparkSQL functions in SparkR sc <- sparkR.init() @@ -551,11 +556,6 @@ test_that("collect() and take() on a DataFrame return the same number of rows an }) test_that("collect() support Unicode characters", { - markUtf8 <- function(s) { - Encoding(s) <- "UTF-8" - s - } - lines <- c("{\"name\":\"안녕하세요\"}", "{\"name\":\"您好\", \"age\":30}", "{\"name\":\"こんにちは\", \"age\":19}", @@ -933,8 +933,33 @@ test_that("column functions", { # Test that stats::lag is working expect_equal(length(lag(ldeaths, 12)), 72) + + # Test struct() + df <- createDataFrame(sqlContext, + list(list(1L, 2L, 3L), list(4L, 5L, 6L)), + schema = c("a", "b", "c")) + result <- collect(select(df, struct("a", "c"))) + expected <- data.frame(row.names = 1:2) + expected$"struct(a,c)" <- list(listToStruct(list(a = 1L, c = 3L)), + listToStruct(list(a = 4L, c = 6L))) + expect_equal(result, expected) + + result <- collect(select(df, struct(df$a, df$b))) + expected <- data.frame(row.names = 1:2) + expected$"struct(a,b)" <- list(listToStruct(list(a = 1L, b = 2L)), + listToStruct(list(a = 4L, b = 5L))) + expect_equal(result, expected) + + # Test encode(), decode() + bytes <- as.raw(c(0xe5, 0xa4, 0xa7, 0xe5, 0x8d, 0x83, 0xe4, 0xb8, 0x96, 0xe7, 0x95, 0x8c)) + df <- createDataFrame(sqlContext, + list(list(markUtf8("大千世界"), "utf-8", bytes)), + schema = c("a", "b", "c")) + result <- collect(select(df, encode(df$a, "utf-8"), decode(df$c, "utf-8"))) + expect_equal(result[[1]][[1]], bytes) + expect_equal(result[[2]], markUtf8("大千世界")) }) -# + test_that("column binary mathfunctions", { lines <- c("{\"a\":1, \"b\":5}", "{\"a\":2, \"b\":6}", From 895b6c474735d7e0a38283f92292daa5c35875ee Mon Sep 17 00:00:00 2001 From: felixcheung Date: Sat, 5 Dec 2015 16:00:12 -0800 Subject: [PATCH 1671/1824] [SPARK-11715][SPARKR] Add R support corr for Column Aggregration Need to match existing method signature Author: felixcheung Closes #9680 from felixcheung/rcorr.
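The diff below widens the corr generic so that the same name dispatches both on a DataFrame (the existing stat function) and on Column objects (the new aggregate expression). A rough usage sketch, assuming a hypothetical DataFrame df with numeric columns x and y (not taken from the patch):

```
# New Column form: Pearson correlation as an aggregate expression
collect(agg(df, corr(df$x, df$y)))

# Existing DataFrame stat form, still reachable through the same generic
corr(df, "x", "y", method = "pearson")
```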
--- R/pkg/R/functions.R | 15 +++++++++++++++ R/pkg/R/generics.R | 2 +- R/pkg/R/stats.R | 9 +++++---- R/pkg/inst/tests/test_sparkSQL.R | 2 +- 4 files changed, 22 insertions(+), 6 deletions(-) diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 7432cb8e7ccf..25231451df3d 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -259,6 +259,21 @@ setMethod("column", function(x) { col(x) }) +#' corr +#' +#' Computes the Pearson Correlation Coefficient for two Columns. +#' +#' @rdname corr +#' @name corr +#' @family math_funcs +#' @export +#' @examples \dontrun{corr(df$c, df$d)} +setMethod("corr", signature(x = "Column"), + function(x, col2) { + stopifnot(class(col2) == "Column") + jc <- callJStatic("org.apache.spark.sql.functions", "corr", x@jc, col2@jc) + column(jc) + }) #' cos #' diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 4b5f786d3946..acfd4841e19a 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -411,7 +411,7 @@ setGeneric("cov", function(x, col1, col2) {standardGeneric("cov") }) #' @rdname statfunctions #' @export -setGeneric("corr", function(x, col1, col2, method = "pearson") {standardGeneric("corr") }) +setGeneric("corr", function(x, ...) {standardGeneric("corr") }) #' @rdname summary #' @export diff --git a/R/pkg/R/stats.R b/R/pkg/R/stats.R index f79329b11540..d17cce9c756e 100644 --- a/R/pkg/R/stats.R +++ b/R/pkg/R/stats.R @@ -77,7 +77,7 @@ setMethod("cov", #' Calculates the correlation of two columns of a DataFrame. #' Currently only supports the Pearson Correlation Coefficient. #' For Spearman Correlation, consider using RDD methods found in MLlib's Statistics. -#' +#' #' @param x A SparkSQL DataFrame #' @param col1 the name of the first column #' @param col2 the name of the second column @@ -95,8 +95,9 @@ setMethod("cov", #' corr <- corr(df, "title", "gender", method = "pearson") #' } setMethod("corr", - signature(x = "DataFrame", col1 = "character", col2 = "character"), + signature(x = "DataFrame"), function(x, col1, col2, method = "pearson") { + stopifnot(class(col1) == "character" && class(col2) == "character") statFunctions <- callJMethod(x@sdf, "stat") callJMethod(statFunctions, "corr", col1, col2, method) }) @@ -109,7 +110,7 @@ setMethod("corr", #' #' @param x A SparkSQL DataFrame. #' @param cols A vector column names to search frequent items in. -#' @param support (Optional) The minimum frequency for an item to be considered `frequent`. +#' @param support (Optional) The minimum frequency for an item to be considered `frequent`. #' Should be greater than 1e-4. Default support = 0.01. #' @return a local R data.frame with the frequent items in each column #' @@ -131,7 +132,7 @@ setMethod("freqItems", signature(x = "DataFrame", cols = "character"), #' sampleBy #' #' Returns a stratified sample without replacement based on the fraction given on each stratum. -#' +#' #' @param x A SparkSQL DataFrame #' @param col column that defines strata #' @param fractions A named list giving sampling fraction for each stratum. 
If a stratum is diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index 2d26b92ac727..a5a234a02d9f 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -892,7 +892,7 @@ test_that("column functions", { c11 <- to_date(c) + trim(c) + unbase64(c) + unhex(c) + upper(c) c12 <- variance(c) c13 <- lead("col", 1) + lead(c, 1) + lag("col", 1) + lag(c, 1) - c14 <- cume_dist() + ntile(1) + c14 <- cume_dist() + ntile(1) + corr(c, c1) c15 <- dense_rank() + percent_rank() + rank() + row_number() # Test if base::rank() is exposed From 6979edf4e1a93caafa8d286692097dd377d7616d Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Sat, 5 Dec 2015 16:39:01 -0800 Subject: [PATCH 1672/1824] [SPARK-12115][SPARKR] Change numPartitions() to getNumPartitions() to be consistent with Scala/Python Change ```numPartitions()``` to ```getNumPartitions()``` to be consistent with Scala/Python. Note: If we can not catch up with 1.6 release, it will be breaking change for 1.7 that we also need to explain in release note. cc sun-rui felixcheung shivaram Author: Yanbo Liang Closes #10123 from yanboliang/spark-12115. --- R/pkg/R/RDD.R | 55 ++++++++++++++++++++++--------------- R/pkg/R/generics.R | 6 +++- R/pkg/R/pairRDD.R | 4 +-- R/pkg/inst/tests/test_rdd.R | 10 +++---- 4 files changed, 45 insertions(+), 30 deletions(-) diff --git a/R/pkg/R/RDD.R b/R/pkg/R/RDD.R index 47945c2825da..00c40c38cabc 100644 --- a/R/pkg/R/RDD.R +++ b/R/pkg/R/RDD.R @@ -306,17 +306,28 @@ setMethod("checkpoint", #'\dontrun{ #' sc <- sparkR.init() #' rdd <- parallelize(sc, 1:10, 2L) -#' numPartitions(rdd) # 2L +#' getNumPartitions(rdd) # 2L #'} -#' @rdname numPartitions +#' @rdname getNumPartitions +#' @aliases getNumPartitions,RDD-method +#' @noRd +setMethod("getNumPartitions", + signature(x = "RDD"), + function(x) { + callJMethod(getJRDD(x), "getNumPartitions") + }) + +#' Gets the number of partitions of an RDD, the same as getNumPartitions. +#' But this function has been deprecated, please use getNumPartitions. 
+#' +#' @rdname getNumPartitions #' @aliases numPartitions,RDD-method #' @noRd setMethod("numPartitions", signature(x = "RDD"), function(x) { - jrdd <- getJRDD(x) - partitions <- callJMethod(jrdd, "partitions") - callJMethod(partitions, "size") + .Deprecated("getNumPartitions") + getNumPartitions(x) }) #' Collect elements of an RDD @@ -443,7 +454,7 @@ setMethod("countByValue", signature(x = "RDD"), function(x) { ones <- lapply(x, function(item) { list(item, 1L) }) - collect(reduceByKey(ones, `+`, numPartitions(x))) + collect(reduceByKey(ones, `+`, getNumPartitions(x))) }) #' Apply a function to all elements @@ -759,7 +770,7 @@ setMethod("take", resList <- list() index <- -1 jrdd <- getJRDD(x) - numPartitions <- numPartitions(x) + numPartitions <- getNumPartitions(x) serializedModeRDD <- getSerializedMode(x) # TODO(shivaram): Collect more than one partition based on size @@ -823,7 +834,7 @@ setMethod("first", #' @noRd setMethod("distinct", signature(x = "RDD"), - function(x, numPartitions = SparkR:::numPartitions(x)) { + function(x, numPartitions = SparkR:::getNumPartitions(x)) { identical.mapped <- lapply(x, function(x) { list(x, NULL) }) reduced <- reduceByKey(identical.mapped, function(x, y) { x }, @@ -993,8 +1004,8 @@ setMethod("keyBy", #'\dontrun{ #' sc <- sparkR.init() #' rdd <- parallelize(sc, list(1, 2, 3, 4, 5, 6, 7), 4L) -#' numPartitions(rdd) # 4 -#' numPartitions(repartition(rdd, 2L)) # 2 +#' getNumPartitions(rdd) # 4 +#' getNumPartitions(repartition(rdd, 2L)) # 2 #'} #' @rdname repartition #' @aliases repartition,RDD @@ -1014,8 +1025,8 @@ setMethod("repartition", #'\dontrun{ #' sc <- sparkR.init() #' rdd <- parallelize(sc, list(1, 2, 3, 4, 5), 3L) -#' numPartitions(rdd) # 3 -#' numPartitions(coalesce(rdd, 1L)) # 1 +#' getNumPartitions(rdd) # 3 +#' getNumPartitions(coalesce(rdd, 1L)) # 1 #'} #' @rdname coalesce #' @aliases coalesce,RDD @@ -1024,7 +1035,7 @@ setMethod("coalesce", signature(x = "RDD", numPartitions = "numeric"), function(x, numPartitions, shuffle = FALSE) { numPartitions <- numToInt(numPartitions) - if (shuffle || numPartitions > SparkR:::numPartitions(x)) { + if (shuffle || numPartitions > SparkR:::getNumPartitions(x)) { func <- function(partIndex, part) { set.seed(partIndex) # partIndex as seed start <- as.integer(base::sample(numPartitions, 1) - 1) @@ -1112,7 +1123,7 @@ setMethod("saveAsTextFile", #' @noRd setMethod("sortBy", signature(x = "RDD", func = "function"), - function(x, func, ascending = TRUE, numPartitions = SparkR:::numPartitions(x)) { + function(x, func, ascending = TRUE, numPartitions = SparkR:::getNumPartitions(x)) { values(sortByKey(keyBy(x, func), ascending, numPartitions)) }) @@ -1144,7 +1155,7 @@ takeOrderedElem <- function(x, num, ascending = TRUE) { resList <- list() index <- -1 jrdd <- getJRDD(newRdd) - numPartitions <- numPartitions(newRdd) + numPartitions <- getNumPartitions(newRdd) serializedModeRDD <- getSerializedMode(newRdd) while (TRUE) { @@ -1368,7 +1379,7 @@ setMethod("setName", setMethod("zipWithUniqueId", signature(x = "RDD"), function(x) { - n <- numPartitions(x) + n <- getNumPartitions(x) partitionFunc <- function(partIndex, part) { mapply( @@ -1409,7 +1420,7 @@ setMethod("zipWithUniqueId", setMethod("zipWithIndex", signature(x = "RDD"), function(x) { - n <- numPartitions(x) + n <- getNumPartitions(x) if (n > 1) { nums <- collect(lapplyPartition(x, function(part) { @@ -1521,8 +1532,8 @@ setMethod("unionRDD", setMethod("zipRDD", signature(x = "RDD", other = "RDD"), function(x, other) { - n1 <- numPartitions(x) - n2 <- 
numPartitions(other) + n1 <- getNumPartitions(x) + n2 <- getNumPartitions(other) if (n1 != n2) { stop("Can only zip RDDs which have the same number of partitions.") } @@ -1588,7 +1599,7 @@ setMethod("cartesian", #' @noRd setMethod("subtract", signature(x = "RDD", other = "RDD"), - function(x, other, numPartitions = SparkR:::numPartitions(x)) { + function(x, other, numPartitions = SparkR:::getNumPartitions(x)) { mapFunction <- function(e) { list(e, NA) } rdd1 <- map(x, mapFunction) rdd2 <- map(other, mapFunction) @@ -1620,7 +1631,7 @@ setMethod("subtract", #' @noRd setMethod("intersection", signature(x = "RDD", other = "RDD"), - function(x, other, numPartitions = SparkR:::numPartitions(x)) { + function(x, other, numPartitions = SparkR:::getNumPartitions(x)) { rdd1 <- map(x, function(v) { list(v, NA) }) rdd2 <- map(other, function(v) { list(v, NA) }) @@ -1661,7 +1672,7 @@ setMethod("zipPartitions", if (length(rrdds) == 1) { return(rrdds[[1]]) } - nPart <- sapply(rrdds, numPartitions) + nPart <- sapply(rrdds, getNumPartitions) if (length(unique(nPart)) != 1) { stop("Can only zipPartitions RDDs which have the same number of partitions.") } diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index acfd4841e19a..29dd11f41ff5 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -133,7 +133,11 @@ setGeneric("sumRDD", function(x) { standardGeneric("sumRDD") }) # @export setGeneric("name", function(x) { standardGeneric("name") }) -# @rdname numPartitions +# @rdname getNumPartitions +# @export +setGeneric("getNumPartitions", function(x) { standardGeneric("getNumPartitions") }) + +# @rdname getNumPartitions # @export setGeneric("numPartitions", function(x) { standardGeneric("numPartitions") }) diff --git a/R/pkg/R/pairRDD.R b/R/pkg/R/pairRDD.R index 991bea4d2022..334c11d2f89a 100644 --- a/R/pkg/R/pairRDD.R +++ b/R/pkg/R/pairRDD.R @@ -750,7 +750,7 @@ setMethod("cogroup", #' @noRd setMethod("sortByKey", signature(x = "RDD"), - function(x, ascending = TRUE, numPartitions = SparkR:::numPartitions(x)) { + function(x, ascending = TRUE, numPartitions = SparkR:::getNumPartitions(x)) { rangeBounds <- list() if (numPartitions > 1) { @@ -818,7 +818,7 @@ setMethod("sortByKey", #' @noRd setMethod("subtractByKey", signature(x = "RDD", other = "RDD"), - function(x, other, numPartitions = SparkR:::numPartitions(x)) { + function(x, other, numPartitions = SparkR:::getNumPartitions(x)) { filterFunction <- function(elem) { iters <- elem[[2]] (length(iters[[1]]) > 0) && (length(iters[[2]]) == 0) diff --git a/R/pkg/inst/tests/test_rdd.R b/R/pkg/inst/tests/test_rdd.R index 71aed2bb9d6a..7423b4f2bed1 100644 --- a/R/pkg/inst/tests/test_rdd.R +++ b/R/pkg/inst/tests/test_rdd.R @@ -28,8 +28,8 @@ intPairs <- list(list(1L, -1), list(2L, 100), list(2L, 1), list(1L, 200)) intRdd <- parallelize(sc, intPairs, 2L) test_that("get number of partitions in RDD", { - expect_equal(numPartitions(rdd), 2) - expect_equal(numPartitions(intRdd), 2) + expect_equal(getNumPartitions(rdd), 2) + expect_equal(getNumPartitions(intRdd), 2) }) test_that("first on RDD", { @@ -304,18 +304,18 @@ test_that("repartition/coalesce on RDDs", { # repartition r1 <- repartition(rdd, 2) - expect_equal(numPartitions(r1), 2L) + expect_equal(getNumPartitions(r1), 2L) count <- length(collectPartition(r1, 0L)) expect_true(count >= 8 && count <= 12) r2 <- repartition(rdd, 6) - expect_equal(numPartitions(r2), 6L) + expect_equal(getNumPartitions(r2), 6L) count <- length(collectPartition(r2, 0L)) expect_true(count >= 0 && count <= 4) # coalesce r3 <- coalesce(rdd, 1) 
- expect_equal(numPartitions(r3), 1L) + expect_equal(getNumPartitions(r3), 1L) count <- length(collectPartition(r3, 0L)) expect_equal(count, 20) }) From b6e8e63a0dbe471187a146c96fdaddc6b8a8e55e Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Sat, 5 Dec 2015 22:51:05 -0800 Subject: [PATCH 1673/1824] [SPARK-12044][SPARKR] Fix usage of isnan, isNaN 1, Add ```isNaN``` to ```Column``` for SparkR. ```Column``` should has three related variable functions: ```isNaN, isNull, isNotNull```. 2, Replace ```DataFrame.isNaN``` with ```DataFrame.isnan``` at SparkR side. Because ```DataFrame.isNaN``` has been deprecated and will be removed at Spark 2.0. 3, Add ```isnull``` to ```DataFrame``` for SparkR. ```DataFrame``` should has two related functions: ```isnan, isnull```. cc shivaram sun-rui felixcheung Author: Yanbo Liang Closes #10037 from yanboliang/spark-12044. --- R/pkg/R/column.R | 2 +- R/pkg/R/functions.R | 26 +++++++++++++++++++------- R/pkg/R/generics.R | 8 ++++++-- R/pkg/inst/tests/test_sparkSQL.R | 6 +++++- 4 files changed, 31 insertions(+), 11 deletions(-) diff --git a/R/pkg/R/column.R b/R/pkg/R/column.R index 20de3907b7dd..7bb8ef2595b5 100644 --- a/R/pkg/R/column.R +++ b/R/pkg/R/column.R @@ -56,7 +56,7 @@ operators <- list( "&" = "and", "|" = "or", #, "!" = "unary_$bang" "^" = "pow" ) -column_functions1 <- c("asc", "desc", "isNull", "isNotNull") +column_functions1 <- c("asc", "desc", "isNaN", "isNull", "isNotNull") column_functions2 <- c("like", "rlike", "startsWith", "endsWith", "getField", "getItem", "contains") createOperator <- function(op) { diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 25231451df3d..09e4e04335a3 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -537,19 +537,31 @@ setMethod("initcap", column(jc) }) -#' isNaN +#' is.nan #' -#' Return true iff the column is NaN. +#' Return true if the column is NaN, alias for \link{isnan} #' -#' @rdname isNaN -#' @name isNaN +#' @rdname is.nan +#' @name is.nan #' @family normal_funcs #' @export -#' @examples \dontrun{isNaN(df$c)} -setMethod("isNaN", +#' @examples +#' \dontrun{ +#' is.nan(df$c) +#' isnan(df$c) +#' } +setMethod("is.nan", + signature(x = "Column"), + function(x) { + isnan(x) + }) + +#' @rdname is.nan +#' @name isnan +setMethod("isnan", signature(x = "Column"), function(x) { - jc <- callJStatic("org.apache.spark.sql.functions", "isNaN", x@jc) + jc <- callJStatic("org.apache.spark.sql.functions", "isnan", x@jc) column(jc) }) diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 29dd11f41ff5..c383e6e78b8b 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -625,6 +625,10 @@ setGeneric("getField", function(x, ...) { standardGeneric("getField") }) #' @export setGeneric("getItem", function(x, ...) 
{ standardGeneric("getItem") }) +#' @rdname column +#' @export +setGeneric("isNaN", function(x) { standardGeneric("isNaN") }) + #' @rdname column #' @export setGeneric("isNull", function(x) { standardGeneric("isNull") }) @@ -808,9 +812,9 @@ setGeneric("initcap", function(x) { standardGeneric("initcap") }) #' @export setGeneric("instr", function(y, x) { standardGeneric("instr") }) -#' @rdname isNaN +#' @rdname is.nan #' @export -setGeneric("isNaN", function(x) { standardGeneric("isNaN") }) +setGeneric("isnan", function(x) { standardGeneric("isnan") }) #' @rdname kurtosis #' @export diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index a5a234a02d9f..6ef03ae97635 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -883,7 +883,7 @@ test_that("column functions", { c2 <- avg(c) + base64(c) + bin(c) + bitwiseNOT(c) + cbrt(c) + ceil(c) + cos(c) c3 <- cosh(c) + count(c) + crc32(c) + exp(c) c4 <- explode(c) + expm1(c) + factorial(c) + first(c) + floor(c) + hex(c) - c5 <- hour(c) + initcap(c) + isNaN(c) + last(c) + last_day(c) + length(c) + c5 <- hour(c) + initcap(c) + last(c) + last_day(c) + length(c) c6 <- log(c) + (c) + log1p(c) + log2(c) + lower(c) + ltrim(c) + max(c) + md5(c) c7 <- mean(c) + min(c) + month(c) + negate(c) + quarter(c) c8 <- reverse(c) + rint(c) + round(c) + rtrim(c) + sha1(c) @@ -894,6 +894,10 @@ test_that("column functions", { c13 <- lead("col", 1) + lead(c, 1) + lag("col", 1) + lag(c, 1) c14 <- cume_dist() + ntile(1) + corr(c, c1) c15 <- dense_rank() + percent_rank() + rank() + row_number() + c16 <- is.nan(c) + isnan(c) + isNaN(c) + + # Test if base::is.nan() is exposed + expect_equal(is.nan(c("a", "b")), c(FALSE, FALSE)) # Test if base::rank() is exposed expect_equal(class(rank())[[1]], "Column") From 04b6799932707f0a4aa4da0f2fc838bdb29794ce Mon Sep 17 00:00:00 2001 From: gcc Date: Sun, 6 Dec 2015 16:27:40 +0000 Subject: [PATCH 1674/1824] [SPARK-12048][SQL] Prevent to close JDBC resources twice Author: gcc Closes #10101 from rh99/master. --- .../apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala index b9dd7f6b4099..1c348ed62fc7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala @@ -511,6 +511,7 @@ private[sql] class JDBCRDD( } catch { case e: Exception => logWarning("Exception closing connection", e) } + closed = true } override def hasNext: Boolean = { From 49efd03bacad6060d99ed5e2fe53ba3df1d1317e Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Sun, 6 Dec 2015 11:15:02 -0800 Subject: [PATCH 1675/1824] [SPARK-12138][SQL] Escape \u in the generated comments of codegen When \u appears in a comment block (i.e. in /**/), code gen will break. So, in Expression and CodegenFallback, we escape \u to \\u. yhuai Please review it. I did reproduce it and it works after the fix. Thanks! Author: gatorsmile Closes #10155 from gatorsmile/escapeU. 
--- .../spark/sql/catalyst/expressions/Expression.scala | 4 +++- .../sql/catalyst/expressions/CodeGenerationSuite.scala | 9 +++++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 614f0c075fd2..6d807c9ecf30 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -220,7 +220,9 @@ abstract class Expression extends TreeNode[Expression] { * Returns the string representation of this expression that is safe to be put in * code comments of generated code. */ - protected def toCommentSafeString: String = this.toString.replace("*/", "\\*\\/") + protected def toCommentSafeString: String = this.toString + .replace("*/", "\\*\\/") + .replace("\\u", "\\\\u") } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index fe754240dcd6..cd2ef7dcd0cd 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -107,4 +107,13 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { true, InternalRow(UTF8String.fromString("*/"))) } + + test("\\u in the data") { + // When \ u appears in a comment block (i.e. in /**/), code gen will break. + // So, in Expression and CodegenFallback, we escape \ u to \\u. + checkEvaluation( + EqualTo(BoundReference(0, StringType, false), Literal.create("\\u", StringType)), + true, + InternalRow(UTF8String.fromString("\\u"))) + } } From 80a824d36eec9d9a9f092ee1741453851218ec73 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sun, 6 Dec 2015 17:35:01 -0800 Subject: [PATCH 1676/1824] [SPARK-12152][PROJECT-INFRA] Speed up Scalastyle checks by only invoking SBT once Currently, `dev/scalastyle` invokes SBT four times, but these invocations can be replaced with a single invocation, saving about one minute of build time. Author: Josh Rosen Closes #10151 from JoshRosen/speed-up-scalastyle. --- dev/scalastyle | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/dev/scalastyle b/dev/scalastyle index ad93f7e85b27..8fd3604b9f45 100755 --- a/dev/scalastyle +++ b/dev/scalastyle @@ -17,14 +17,17 @@ # limitations under the License. # -echo -e "q\n" | build/sbt -Pkinesis-asl -Phive -Phive-thriftserver scalastyle > scalastyle.txt -echo -e "q\n" | build/sbt -Pkinesis-asl -Phive -Phive-thriftserver test:scalastyle >> scalastyle.txt -# Check style with YARN built too -echo -e "q\n" | build/sbt -Pkinesis-asl -Pyarn -Phadoop-2.2 scalastyle >> scalastyle.txt -echo -e "q\n" | build/sbt -Pkinesis-asl -Pyarn -Phadoop-2.2 test:scalastyle >> scalastyle.txt - -ERRORS=$(cat scalastyle.txt | awk '{if($1~/error/)print}') -rm scalastyle.txt +# NOTE: echo "q" is needed because SBT prompts the user for input on encountering a build file +# with failure (either resolution or compilation); the "q" makes SBT quit. +ERRORS=$(echo -e "q\n" \ + | build/sbt \ + -Pkinesis-asl \ + -Pyarn \ + -Phive \ + -Phive-thriftserver \ + scalastyle test:scalastyle \ + | awk '{if($1~/error/)print}' \ +) if test ! 
-z "$ERRORS"; then echo -e "Scalastyle checks failed at following occurrences:\n$ERRORS" From 6fd9e70e3ed43836a0685507fff9949f921234f4 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Mon, 7 Dec 2015 00:21:55 -0800 Subject: [PATCH 1677/1824] [SPARK-12106][STREAMING][FLAKY-TEST] BatchedWAL test transiently flaky when Jenkins load is high We need to make sure that the last entry is indeed the last entry in the queue. Author: Burak Yavuz Closes #10110 from brkyvz/batch-wal-test-fix. --- .../streaming/util/BatchedWriteAheadLog.scala | 6 ++++-- .../spark/streaming/util/WriteAheadLogSuite.scala | 14 ++++++++++---- 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/BatchedWriteAheadLog.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/BatchedWriteAheadLog.scala index 7158abc08894..b2cd524f28b7 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/BatchedWriteAheadLog.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/BatchedWriteAheadLog.scala @@ -166,10 +166,12 @@ private[util] class BatchedWriteAheadLog(val wrappedLog: WriteAheadLog, conf: Sp var segment: WriteAheadLogRecordHandle = null if (buffer.length > 0) { logDebug(s"Batched ${buffer.length} records for Write Ahead Log write") + // threads may not be able to add items in order by time + val sortedByTime = buffer.sortBy(_.time) // We take the latest record for the timestamp. Please refer to the class Javadoc for // detailed explanation - val time = buffer.last.time - segment = wrappedLog.write(aggregate(buffer), time) + val time = sortedByTime.last.time + segment = wrappedLog.write(aggregate(sortedByTime), time) } buffer.foreach(_.promise.success(segment)) } catch { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala index eaa88ea3cd38..ef1e89df3130 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala @@ -480,7 +480,7 @@ class BatchedWriteAheadLogSuite extends CommonWriteAheadLogTests( p } - test("BatchedWriteAheadLog - name log with aggregated entries with the timestamp of last entry") { + test("BatchedWriteAheadLog - name log with the highest timestamp of aggregated entries") { val blockingWal = new BlockingWriteAheadLog(wal, walHandle) val batchedWal = new BatchedWriteAheadLog(blockingWal, sparkConf) @@ -500,8 +500,14 @@ class BatchedWriteAheadLogSuite extends CommonWriteAheadLogTests( // rest of the records will be batched while it takes time for 3 to get written writeAsync(batchedWal, event2, 5L) writeAsync(batchedWal, event3, 8L) - writeAsync(batchedWal, event4, 12L) - writeAsync(batchedWal, event5, 10L) + // we would like event 5 to be written before event 4 in order to test that they get + // sorted before being aggregated + writeAsync(batchedWal, event5, 12L) + eventually(timeout(1 second)) { + assert(blockingWal.isBlocked) + assert(batchedWal.invokePrivate(queueLength()) === 3) + } + writeAsync(batchedWal, event4, 10L) eventually(timeout(1 second)) { assert(walBatchingThreadPool.getActiveCount === 5) assert(batchedWal.invokePrivate(queueLength()) === 4) @@ -517,7 +523,7 @@ class BatchedWriteAheadLogSuite extends CommonWriteAheadLogTests( // the file name should be the timestamp of the last record, as events should be naturally // in order of timestamp, and we 
need the last element. val bufferCaptor = ArgumentCaptor.forClass(classOf[ByteBuffer]) - verify(wal, times(1)).write(bufferCaptor.capture(), meq(10L)) + verify(wal, times(1)).write(bufferCaptor.capture(), meq(12L)) val records = BatchedWriteAheadLog.deaggregate(bufferCaptor.getValue).map(byteBufferToString) assert(records.toSet === queuedEvents) } From 9cde7d5fa87e7ddfff0b9c1212920a1d9000539b Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 7 Dec 2015 10:34:18 -0800 Subject: [PATCH 1678/1824] [SPARK-12032] [SQL] Re-order inner joins to do join with conditions first Currently, the order of joins is exactly the same as SQL query, some conditions may not pushed down to the correct join, then those join will become cross product and is extremely slow. This patch try to re-order the inner joins (which are common in SQL query), pick the joins that have self-contain conditions first, delay those that does not have conditions. After this patch, the TPCDS query Q64/65 can run hundreds times faster. cc marmbrus nongli Author: Davies Liu Closes #10073 from davies/reorder_joins. --- .../sql/catalyst/optimizer/Optimizer.scala | 56 ++++++++++- .../sql/catalyst/planning/patterns.scala | 40 +++++++- .../catalyst/optimizer/JoinOrderSuite.scala | 95 +++++++++++++++++++ 3 files changed, 185 insertions(+), 6 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOrderSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 06d14fcf8b9c..f6088695a927 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -18,14 +18,12 @@ package org.apache.spark.sql.catalyst.optimizer import scala.collection.immutable.HashSet + import org.apache.spark.sql.catalyst.analysis.{CleanupAliases, EliminateSubQueries} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.catalyst.plans.Inner -import org.apache.spark.sql.catalyst.plans.FullOuter -import org.apache.spark.sql.catalyst.plans.LeftOuter -import org.apache.spark.sql.catalyst.plans.RightOuter -import org.apache.spark.sql.catalyst.plans.LeftSemi +import org.apache.spark.sql.catalyst.planning.ExtractFiltersAndInnerJoins +import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, LeftOuter, LeftSemi, RightOuter} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.types._ @@ -44,6 +42,7 @@ object DefaultOptimizer extends Optimizer { // Operator push down SetOperationPushDown, SamplePushDown, + ReorderJoin, PushPredicateThroughJoin, PushPredicateThroughProject, PushPredicateThroughGenerate, @@ -711,6 +710,53 @@ object PushPredicateThroughAggregate extends Rule[LogicalPlan] with PredicateHel } } +/** + * Reorder the joins and push all the conditions into join, so that the bottom ones have at least + * one condition. + * + * The order of joins will not be changed if all of them already have at least one condition. + */ +object ReorderJoin extends Rule[LogicalPlan] with PredicateHelper { + + /** + * Join a list of plans together and push down the conditions into them. + * + * The joined plan are picked from left to right, prefer those has at least one join condition. 
+ * + * @param input a list of LogicalPlans to join. + * @param conditions a list of condition for join. + */ + def createOrderedJoin(input: Seq[LogicalPlan], conditions: Seq[Expression]): LogicalPlan = { + assert(input.size >= 2) + if (input.size == 2) { + Join(input(0), input(1), Inner, conditions.reduceLeftOption(And)) + } else { + val left :: rest = input.toList + // find out the first join that have at least one join condition + val conditionalJoin = rest.find { plan => + val refs = left.outputSet ++ plan.outputSet + conditions.filterNot(canEvaluate(_, left)).filterNot(canEvaluate(_, plan)) + .exists(_.references.subsetOf(refs)) + } + // pick the next one if no condition left + val right = conditionalJoin.getOrElse(rest.head) + + val joinedRefs = left.outputSet ++ right.outputSet + val (joinConditions, others) = conditions.partition(_.references.subsetOf(joinedRefs)) + val joined = Join(left, right, Inner, joinConditions.reduceLeftOption(And)) + + // should not have reference to same logical plan + createOrderedJoin(Seq(joined) ++ rest.filterNot(_ eq right), others) + } + } + + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case j @ ExtractFiltersAndInnerJoins(input, conditions) + if input.size > 2 && conditions.nonEmpty => + createOrderedJoin(input, conditions) + } +} + /** * Pushes down [[Filter]] operators where the `condition` can be * evaluated using only the attributes of the left or right side of a join. Other diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index 6f4f11406d7c..cd3f15cbe107 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -21,7 +21,6 @@ import org.apache.spark.Logging import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.trees.TreeNodeRef /** * A pattern that matches any number of project or filter operations on top of another relational @@ -132,6 +131,45 @@ object ExtractEquiJoinKeys extends Logging with PredicateHelper { } } +/** + * A pattern that collects the filter and inner joins. + * + * Filter + * | + * inner Join + * / \ ----> (Seq(plan0, plan1, plan2), conditions) + * Filter plan2 + * | + * inner join + * / \ + * plan0 plan1 + * + * Note: This pattern currently only works for left-deep trees. + */ +object ExtractFiltersAndInnerJoins extends PredicateHelper { + + // flatten all inner joins, which are next to each other + def flattenJoin(plan: LogicalPlan): (Seq[LogicalPlan], Seq[Expression]) = plan match { + case Join(left, right, Inner, cond) => + val (plans, conditions) = flattenJoin(left) + (plans ++ Seq(right), conditions ++ cond.toSeq) + + case Filter(filterCondition, j @ Join(left, right, Inner, joinCondition)) => + val (plans, conditions) = flattenJoin(j) + (plans, conditions ++ splitConjunctivePredicates(filterCondition)) + + case _ => (Seq(plan), Seq()) + } + + def unapply(plan: LogicalPlan): Option[(Seq[LogicalPlan], Seq[Expression])] = plan match { + case f @ Filter(filterCondition, j @ Join(_, _, Inner, _)) => + Some(flattenJoin(f)) + case j @ Join(_, _, Inner, _) => + Some(flattenJoin(j)) + case _ => None + } +} + /** * A pattern that collects all adjacent unions and returns their children as a Seq. 
*/ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOrderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOrderSuite.scala new file mode 100644 index 000000000000..9b1e16c72764 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOrderSuite.scala @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.analysis +import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.planning.ExtractFiltersAndInnerJoins +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.RuleExecutor + + +class JoinOrderSuite extends PlanTest { + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("Subqueries", Once, + EliminateSubQueries) :: + Batch("Filter Pushdown", Once, + CombineFilters, + PushPredicateThroughProject, + BooleanSimplification, + ReorderJoin, + PushPredicateThroughJoin, + PushPredicateThroughGenerate, + PushPredicateThroughAggregate, + ColumnPruning, + ProjectCollapsing) :: Nil + + } + + val testRelation = LocalRelation('a.int, 'b.int, 'c.int) + val testRelation1 = LocalRelation('d.int) + + test("extract filters and joins") { + val x = testRelation.subquery('x) + val y = testRelation1.subquery('y) + val z = testRelation.subquery('z) + + def testExtract(plan: LogicalPlan, expected: Option[(Seq[LogicalPlan], Seq[Expression])]) { + assert(ExtractFiltersAndInnerJoins.unapply(plan) === expected) + } + + testExtract(x, None) + testExtract(x.where("x.b".attr === 1), None) + testExtract(x.join(y), Some(Seq(x, y), Seq())) + testExtract(x.join(y, condition = Some("x.b".attr === "y.d".attr)), + Some(Seq(x, y), Seq("x.b".attr === "y.d".attr))) + testExtract(x.join(y).where("x.b".attr === "y.d".attr), + Some(Seq(x, y), Seq("x.b".attr === "y.d".attr))) + testExtract(x.join(y).join(z), Some(Seq(x, y, z), Seq())) + testExtract(x.join(y).where("x.b".attr === "y.d".attr).join(z), + Some(Seq(x, y, z), Seq("x.b".attr === "y.d".attr))) + testExtract(x.join(y).join(x.join(z)), Some(Seq(x, y, x.join(z)), Seq())) + testExtract(x.join(y).join(x.join(z)).where("x.b".attr === "y.d".attr), + Some(Seq(x, y, x.join(z)), Seq("x.b".attr === "y.d".attr))) + } + + test("reorder inner joins") { + val x = testRelation.subquery('x) + val y = testRelation1.subquery('y) + val z = testRelation.subquery('z) + + val 
originalQuery = { + x.join(y).join(z) + .where(("x.b".attr === "z.b".attr) && ("y.d".attr === "z.a".attr)) + } + + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = + x.join(z, condition = Some("x.b".attr === "z.b".attr)) + .join(y, condition = Some("y.d".attr === "z.a".attr)) + .analyze + + comparePlans(optimized, analysis.EliminateSubQueries(correctAnswer)) + } +} From 39d677c8f1ee7ebd7e142bec0415cf8f90ac84b6 Mon Sep 17 00:00:00 2001 From: Sun Rui Date: Mon, 7 Dec 2015 10:38:17 -0800 Subject: [PATCH 1679/1824] [SPARK-12034][SPARKR] Eliminate warnings in SparkR test cases. This PR: 1. Suppress all known warnings. 2. Cleanup test cases and fix some errors in test cases. 3. Fix errors in HiveContext related test cases. These test cases are actually not run previously due to a bug of creating TestHiveContext. 4. Support 'testthat' package version 0.11.0 which prefers that test cases be under 'tests/testthat' 5. Make sure the default Hadoop file system is local when running test cases. 6. Turn on warnings into errors. Author: Sun Rui Closes #10030 from sun-rui/SPARK-12034. --- R/pkg/inst/tests/{ => testthat}/jarTest.R | 0 .../tests/{ => testthat}/packageInAJarTest.R | 0 R/pkg/inst/tests/{ => testthat}/test_Serde.R | 0 .../tests/{ => testthat}/test_binaryFile.R | 0 .../{ => testthat}/test_binary_function.R | 0 .../tests/{ => testthat}/test_broadcast.R | 0 R/pkg/inst/tests/{ => testthat}/test_client.R | 0 .../inst/tests/{ => testthat}/test_context.R | 0 .../tests/{ => testthat}/test_includeJAR.R | 2 +- .../{ => testthat}/test_includePackage.R | 0 R/pkg/inst/tests/{ => testthat}/test_mllib.R | 14 ++-- .../{ => testthat}/test_parallelize_collect.R | 0 R/pkg/inst/tests/{ => testthat}/test_rdd.R | 0 .../inst/tests/{ => testthat}/test_shuffle.R | 0 .../inst/tests/{ => testthat}/test_sparkSQL.R | 68 +++++++++++-------- R/pkg/inst/tests/{ => testthat}/test_take.R | 0 .../inst/tests/{ => testthat}/test_textFile.R | 0 R/pkg/inst/tests/{ => testthat}/test_utils.R | 0 R/pkg/tests/run-all.R | 3 + R/run-tests.sh | 2 +- 20 files changed, 50 insertions(+), 39 deletions(-) rename R/pkg/inst/tests/{ => testthat}/jarTest.R (100%) rename R/pkg/inst/tests/{ => testthat}/packageInAJarTest.R (100%) rename R/pkg/inst/tests/{ => testthat}/test_Serde.R (100%) rename R/pkg/inst/tests/{ => testthat}/test_binaryFile.R (100%) rename R/pkg/inst/tests/{ => testthat}/test_binary_function.R (100%) rename R/pkg/inst/tests/{ => testthat}/test_broadcast.R (100%) rename R/pkg/inst/tests/{ => testthat}/test_client.R (100%) rename R/pkg/inst/tests/{ => testthat}/test_context.R (100%) rename R/pkg/inst/tests/{ => testthat}/test_includeJAR.R (94%) rename R/pkg/inst/tests/{ => testthat}/test_includePackage.R (100%) rename R/pkg/inst/tests/{ => testthat}/test_mllib.R (90%) rename R/pkg/inst/tests/{ => testthat}/test_parallelize_collect.R (100%) rename R/pkg/inst/tests/{ => testthat}/test_rdd.R (100%) rename R/pkg/inst/tests/{ => testthat}/test_shuffle.R (100%) rename R/pkg/inst/tests/{ => testthat}/test_sparkSQL.R (97%) rename R/pkg/inst/tests/{ => testthat}/test_take.R (100%) rename R/pkg/inst/tests/{ => testthat}/test_textFile.R (100%) rename R/pkg/inst/tests/{ => testthat}/test_utils.R (100%) diff --git a/R/pkg/inst/tests/jarTest.R b/R/pkg/inst/tests/testthat/jarTest.R similarity index 100% rename from R/pkg/inst/tests/jarTest.R rename to R/pkg/inst/tests/testthat/jarTest.R diff --git a/R/pkg/inst/tests/packageInAJarTest.R b/R/pkg/inst/tests/testthat/packageInAJarTest.R similarity index 100% rename from 
R/pkg/inst/tests/packageInAJarTest.R rename to R/pkg/inst/tests/testthat/packageInAJarTest.R diff --git a/R/pkg/inst/tests/test_Serde.R b/R/pkg/inst/tests/testthat/test_Serde.R similarity index 100% rename from R/pkg/inst/tests/test_Serde.R rename to R/pkg/inst/tests/testthat/test_Serde.R diff --git a/R/pkg/inst/tests/test_binaryFile.R b/R/pkg/inst/tests/testthat/test_binaryFile.R similarity index 100% rename from R/pkg/inst/tests/test_binaryFile.R rename to R/pkg/inst/tests/testthat/test_binaryFile.R diff --git a/R/pkg/inst/tests/test_binary_function.R b/R/pkg/inst/tests/testthat/test_binary_function.R similarity index 100% rename from R/pkg/inst/tests/test_binary_function.R rename to R/pkg/inst/tests/testthat/test_binary_function.R diff --git a/R/pkg/inst/tests/test_broadcast.R b/R/pkg/inst/tests/testthat/test_broadcast.R similarity index 100% rename from R/pkg/inst/tests/test_broadcast.R rename to R/pkg/inst/tests/testthat/test_broadcast.R diff --git a/R/pkg/inst/tests/test_client.R b/R/pkg/inst/tests/testthat/test_client.R similarity index 100% rename from R/pkg/inst/tests/test_client.R rename to R/pkg/inst/tests/testthat/test_client.R diff --git a/R/pkg/inst/tests/test_context.R b/R/pkg/inst/tests/testthat/test_context.R similarity index 100% rename from R/pkg/inst/tests/test_context.R rename to R/pkg/inst/tests/testthat/test_context.R diff --git a/R/pkg/inst/tests/test_includeJAR.R b/R/pkg/inst/tests/testthat/test_includeJAR.R similarity index 94% rename from R/pkg/inst/tests/test_includeJAR.R rename to R/pkg/inst/tests/testthat/test_includeJAR.R index cc1faeabffe3..f89aa8e507fd 100644 --- a/R/pkg/inst/tests/test_includeJAR.R +++ b/R/pkg/inst/tests/testthat/test_includeJAR.R @@ -20,7 +20,7 @@ runScript <- function() { sparkHome <- Sys.getenv("SPARK_HOME") sparkTestJarPath <- "R/lib/SparkR/test_support/sparktestjar_2.10-1.0.jar" jarPath <- paste("--jars", shQuote(file.path(sparkHome, sparkTestJarPath))) - scriptPath <- file.path(sparkHome, "R/lib/SparkR/tests/jarTest.R") + scriptPath <- file.path(sparkHome, "R/lib/SparkR/tests/testthat/jarTest.R") submitPath <- file.path(sparkHome, "bin/spark-submit") res <- system2(command = submitPath, args = c(jarPath, scriptPath), diff --git a/R/pkg/inst/tests/test_includePackage.R b/R/pkg/inst/tests/testthat/test_includePackage.R similarity index 100% rename from R/pkg/inst/tests/test_includePackage.R rename to R/pkg/inst/tests/testthat/test_includePackage.R diff --git a/R/pkg/inst/tests/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R similarity index 90% rename from R/pkg/inst/tests/test_mllib.R rename to R/pkg/inst/tests/testthat/test_mllib.R index e0667e5e22c1..08099dd96a87 100644 --- a/R/pkg/inst/tests/test_mllib.R +++ b/R/pkg/inst/tests/testthat/test_mllib.R @@ -26,7 +26,7 @@ sc <- sparkR.init() sqlContext <- sparkRSQL.init(sc) test_that("glm and predict", { - training <- createDataFrame(sqlContext, iris) + training <- suppressWarnings(createDataFrame(sqlContext, iris)) test <- select(training, "Sepal_Length") model <- glm(Sepal_Width ~ Sepal_Length, training, family = "gaussian") prediction <- predict(model, test) @@ -39,7 +39,7 @@ test_that("glm and predict", { }) test_that("glm should work with long formula", { - training <- createDataFrame(sqlContext, iris) + training <- suppressWarnings(createDataFrame(sqlContext, iris)) training$LongLongLongLongLongName <- training$Sepal_Width training$VeryLongLongLongLonLongName <- training$Sepal_Length training$AnotherLongLongLongLongName <- training$Species @@ -51,7 +51,7 @@ test_that("glm 
should work with long formula", { }) test_that("predictions match with native glm", { - training <- createDataFrame(sqlContext, iris) + training <- suppressWarnings(createDataFrame(sqlContext, iris)) model <- glm(Sepal_Width ~ Sepal_Length + Species, data = training) vals <- collect(select(predict(model, training), "prediction")) rVals <- predict(glm(Sepal.Width ~ Sepal.Length + Species, data = iris), iris) @@ -59,7 +59,7 @@ test_that("predictions match with native glm", { }) test_that("dot minus and intercept vs native glm", { - training <- createDataFrame(sqlContext, iris) + training <- suppressWarnings(createDataFrame(sqlContext, iris)) model <- glm(Sepal_Width ~ . - Species + 0, data = training) vals <- collect(select(predict(model, training), "prediction")) rVals <- predict(glm(Sepal.Width ~ . - Species + 0, data = iris), iris) @@ -67,7 +67,7 @@ test_that("dot minus and intercept vs native glm", { }) test_that("feature interaction vs native glm", { - training <- createDataFrame(sqlContext, iris) + training <- suppressWarnings(createDataFrame(sqlContext, iris)) model <- glm(Sepal_Width ~ Species:Sepal_Length, data = training) vals <- collect(select(predict(model, training), "prediction")) rVals <- predict(glm(Sepal.Width ~ Species:Sepal.Length, data = iris), iris) @@ -75,7 +75,7 @@ test_that("feature interaction vs native glm", { }) test_that("summary coefficients match with native glm", { - training <- createDataFrame(sqlContext, iris) + training <- suppressWarnings(createDataFrame(sqlContext, iris)) stats <- summary(glm(Sepal_Width ~ Sepal_Length + Species, data = training, solver = "normal")) coefs <- unlist(stats$coefficients) devianceResiduals <- unlist(stats$devianceResiduals) @@ -92,7 +92,7 @@ test_that("summary coefficients match with native glm", { }) test_that("summary coefficients match with native glm of family 'binomial'", { - df <- createDataFrame(sqlContext, iris) + df <- suppressWarnings(createDataFrame(sqlContext, iris)) training <- filter(df, df$Species != "setosa") stats <- summary(glm(Species ~ Sepal_Length + Sepal_Width, data = training, family = "binomial")) diff --git a/R/pkg/inst/tests/test_parallelize_collect.R b/R/pkg/inst/tests/testthat/test_parallelize_collect.R similarity index 100% rename from R/pkg/inst/tests/test_parallelize_collect.R rename to R/pkg/inst/tests/testthat/test_parallelize_collect.R diff --git a/R/pkg/inst/tests/test_rdd.R b/R/pkg/inst/tests/testthat/test_rdd.R similarity index 100% rename from R/pkg/inst/tests/test_rdd.R rename to R/pkg/inst/tests/testthat/test_rdd.R diff --git a/R/pkg/inst/tests/test_shuffle.R b/R/pkg/inst/tests/testthat/test_shuffle.R similarity index 100% rename from R/pkg/inst/tests/test_shuffle.R rename to R/pkg/inst/tests/testthat/test_shuffle.R diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R similarity index 97% rename from R/pkg/inst/tests/test_sparkSQL.R rename to R/pkg/inst/tests/testthat/test_sparkSQL.R index 6ef03ae97635..39fc94aea5fb 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -133,38 +133,45 @@ test_that("create DataFrame from RDD", { expect_equal(columns(df), c("a", "b")) expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) - df <- jsonFile(sqlContext, jsonPathNa) - hiveCtx <- tryCatch({ - newJObject("org.apache.spark.sql.hive.test.TestHiveContext", ssc) - }, - error = function(err) { - skip("Hive is not build with SparkSQL, skipped") - }) - sql(hiveCtx, "CREATE TABLE people (name string, age double, 
height float)") - insertInto(df, "people") - expect_equal(sql(hiveCtx, "SELECT age from people WHERE name = 'Bob'"), c(16)) - expect_equal(sql(hiveCtx, "SELECT height from people WHERE name ='Bob'"), c(176.5)) - schema <- structType(structField("name", "string"), structField("age", "integer"), structField("height", "float")) - df2 <- createDataFrame(sqlContext, df.toRDD, schema) - df2AsDF <- as.DataFrame(sqlContext, df.toRDD, schema) + df <- read.df(sqlContext, jsonPathNa, "json", schema) + df2 <- createDataFrame(sqlContext, toRDD(df), schema) + df2AsDF <- as.DataFrame(sqlContext, toRDD(df), schema) expect_equal(columns(df2), c("name", "age", "height")) expect_equal(columns(df2AsDF), c("name", "age", "height")) expect_equal(dtypes(df2), list(c("name", "string"), c("age", "int"), c("height", "float"))) expect_equal(dtypes(df2AsDF), list(c("name", "string"), c("age", "int"), c("height", "float"))) - expect_equal(collect(where(df2, df2$name == "Bob")), c("Bob", 16, 176.5)) - expect_equal(collect(where(df2AsDF, df2$name == "Bob")), c("Bob", 16, 176.5)) + expect_equal(as.list(collect(where(df2, df2$name == "Bob"))), + list(name = "Bob", age = 16, height = 176.5)) + expect_equal(as.list(collect(where(df2AsDF, df2AsDF$name == "Bob"))), + list(name = "Bob", age = 16, height = 176.5)) localDF <- data.frame(name=c("John", "Smith", "Sarah"), - age=c(19, 23, 18), - height=c(164.10, 181.4, 173.7)) + age=c(19L, 23L, 18L), + height=c(176.5, 181.4, 173.7)) df <- createDataFrame(sqlContext, localDF, schema) expect_is(df, "DataFrame") expect_equal(count(df), 3) expect_equal(columns(df), c("name", "age", "height")) expect_equal(dtypes(df), list(c("name", "string"), c("age", "int"), c("height", "float"))) - expect_equal(collect(where(df, df$name == "John")), c("John", 19, 164.10)) + expect_equal(as.list(collect(where(df, df$name == "John"))), + list(name = "John", age = 19L, height = 176.5)) + + ssc <- callJMethod(sc, "sc") + hiveCtx <- tryCatch({ + newJObject("org.apache.spark.sql.hive.test.TestHiveContext", ssc) + }, + error = function(err) { + skip("Hive is not build with SparkSQL, skipped") + }) + sql(hiveCtx, "CREATE TABLE people (name string, age double, height float)") + df <- read.df(hiveCtx, jsonPathNa, "json", schema) + invisible(insertInto(df, "people")) + expect_equal(collect(sql(hiveCtx, "SELECT age from people WHERE name = 'Bob'"))$age, + c(16)) + expect_equal(collect(sql(hiveCtx, "SELECT height from people WHERE name ='Bob'"))$height, + c(176.5)) }) test_that("convert NAs to null type in DataFrames", { @@ -250,7 +257,7 @@ test_that("create DataFrame from list or data.frame", { ldf2 <- collect(df) expect_equal(ldf$a, ldf2$a) - irisdf <- createDataFrame(sqlContext, iris) + irisdf <- suppressWarnings(createDataFrame(sqlContext, iris)) iris_collected <- collect(irisdf) expect_equivalent(iris_collected[,-5], iris[,-5]) expect_equal(iris_collected$Species, as.character(iris$Species)) @@ -463,7 +470,7 @@ test_that("union on two RDDs created from DataFrames returns an RRDD", { RDD2 <- toRDD(df) unioned <- unionRDD(RDD1, RDD2) expect_is(unioned, "RDD") - expect_equal(SparkR:::getSerializedMode(unioned), "byte") + expect_equal(getSerializedMode(unioned), "byte") expect_equal(collect(unioned)[[2]]$name, "Andy") }) @@ -485,13 +492,13 @@ test_that("union on mixed serialization types correctly returns a byte RRDD", { unionByte <- unionRDD(rdd, dfRDD) expect_is(unionByte, "RDD") - expect_equal(SparkR:::getSerializedMode(unionByte), "byte") + expect_equal(getSerializedMode(unionByte), "byte") 
expect_equal(collect(unionByte)[[1]], 1) expect_equal(collect(unionByte)[[12]]$name, "Andy") unionString <- unionRDD(textRDD, dfRDD) expect_is(unionString, "RDD") - expect_equal(SparkR:::getSerializedMode(unionString), "byte") + expect_equal(getSerializedMode(unionString), "byte") expect_equal(collect(unionString)[[1]], "Michael") expect_equal(collect(unionString)[[5]]$name, "Andy") }) @@ -504,7 +511,7 @@ test_that("objectFile() works with row serialization", { objectIn <- objectFile(sc, objectPath) expect_is(objectIn, "RDD") - expect_equal(SparkR:::getSerializedMode(objectIn), "byte") + expect_equal(getSerializedMode(objectIn), "byte") expect_equal(collect(objectIn)[[2]]$age, 30) }) @@ -849,6 +856,7 @@ test_that("write.df() as parquet file", { }) test_that("test HiveContext", { + ssc <- callJMethod(sc, "sc") hiveCtx <- tryCatch({ newJObject("org.apache.spark.sql.hive.test.TestHiveContext", ssc) }, @@ -863,10 +871,10 @@ test_that("test HiveContext", { expect_equal(count(df2), 3) jsonPath2 <- tempfile(pattern="sparkr-test", fileext=".tmp") - saveAsTable(df, "json", "json", "append", path = jsonPath2) - df3 <- sql(hiveCtx, "select * from json") + invisible(saveAsTable(df, "json2", "json", "append", path = jsonPath2)) + df3 <- sql(hiveCtx, "select * from json2") expect_is(df3, "DataFrame") - expect_equal(count(df3), 6) + expect_equal(count(df3), 3) }) test_that("column operators", { @@ -1311,7 +1319,7 @@ test_that("toJSON() returns an RDD of the correct values", { df <- jsonFile(sqlContext, jsonPath) testRDD <- toJSON(df) expect_is(testRDD, "RDD") - expect_equal(SparkR:::getSerializedMode(testRDD), "string") + expect_equal(getSerializedMode(testRDD), "string") expect_equal(collect(testRDD)[[1]], mockLines[1]) }) @@ -1641,7 +1649,7 @@ test_that("SQL error message is returned from JVM", { expect_equal(grepl("Table not found: blah", retError), TRUE) }) -irisDF <- createDataFrame(sqlContext, iris) +irisDF <- suppressWarnings(createDataFrame(sqlContext, iris)) test_that("Method as.data.frame as a synonym for collect()", { expect_equal(as.data.frame(irisDF), collect(irisDF)) @@ -1670,7 +1678,7 @@ test_that("attach() on a DataFrame", { }) test_that("with() on a DataFrame", { - df <- createDataFrame(sqlContext, iris) + df <- suppressWarnings(createDataFrame(sqlContext, iris)) expect_error(Sepal_Length) sum1 <- with(df, list(summary(Sepal_Length), summary(Sepal_Width))) expect_equal(collect(sum1[[1]])[1, "Sepal_Length"], "150") diff --git a/R/pkg/inst/tests/test_take.R b/R/pkg/inst/tests/testthat/test_take.R similarity index 100% rename from R/pkg/inst/tests/test_take.R rename to R/pkg/inst/tests/testthat/test_take.R diff --git a/R/pkg/inst/tests/test_textFile.R b/R/pkg/inst/tests/testthat/test_textFile.R similarity index 100% rename from R/pkg/inst/tests/test_textFile.R rename to R/pkg/inst/tests/testthat/test_textFile.R diff --git a/R/pkg/inst/tests/test_utils.R b/R/pkg/inst/tests/testthat/test_utils.R similarity index 100% rename from R/pkg/inst/tests/test_utils.R rename to R/pkg/inst/tests/testthat/test_utils.R diff --git a/R/pkg/tests/run-all.R b/R/pkg/tests/run-all.R index 4f8a1ed2d83e..1d04656ac259 100644 --- a/R/pkg/tests/run-all.R +++ b/R/pkg/tests/run-all.R @@ -18,4 +18,7 @@ library(testthat) library(SparkR) +# Turn all warnings into errors +options("warn" = 2) + test_package("SparkR") diff --git a/R/run-tests.sh b/R/run-tests.sh index e82ad0ba2cd0..e64a4ea94c58 100755 --- a/R/run-tests.sh +++ b/R/run-tests.sh @@ -23,7 +23,7 @@ FAILED=0 LOGFILE=$FWDIR/unit-tests.out rm -f $LOGFILE 
-SPARK_TESTING=1 $FWDIR/../bin/sparkR --driver-java-options "-Dlog4j.configuration=file:$FWDIR/log4j.properties" $FWDIR/pkg/tests/run-all.R 2>&1 | tee -a $LOGFILE +SPARK_TESTING=1 $FWDIR/../bin/sparkR --driver-java-options "-Dlog4j.configuration=file:$FWDIR/log4j.properties" --conf spark.hadoop.fs.default.name="file:///" $FWDIR/pkg/tests/run-all.R 2>&1 | tee -a $LOGFILE FAILED=$((PIPESTATUS[0]||$FAILED)) if [[ $FAILED != 0 ]]; then From ef3f047c07ef0ac4a3a97e6bc11e1c28c6c8f9a0 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 7 Dec 2015 11:00:25 -0800 Subject: [PATCH 1680/1824] [SPARK-12132] [PYSPARK] raise KeyboardInterrupt inside SIGINT handler Currently, the current line is not cleared by Ctrl-C. After this patch: ``` >>> asdfasdf^C Traceback (most recent call last): File "~/spark/python/pyspark/context.py", line 225, in signal_handler raise KeyboardInterrupt() KeyboardInterrupt ``` It's still worse than 1.5 (and before). Author: Davies Liu Closes #10134 from davies/fix_cltrc. --- python/pyspark/context.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 77710a13394c..529d16b48039 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -222,6 +222,7 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize, # create a signal handler which would be invoked on receiving SIGINT def signal_handler(signal, frame): self.cancelAllJobs() + raise KeyboardInterrupt() # see http://stackoverflow.com/questions/23206787/ if isinstance(threading.current_thread(), threading._MainThread): From 5d80d8c6a54b2113022eff31187e6d97521bd2cf Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Mon, 7 Dec 2015 11:03:59 -0800 Subject: [PATCH 1681/1824] [SPARK-11932][STREAMING] Partition previous TrackStateRDD if partitioner not present The reason is that TrackStateRDDs generated by trackStateByKey expect the previous batch's TrackStateRDDs to have a partitioner. However, when recovering from DStream checkpoints, the RDDs recovered from RDD checkpoints do not have a partitioner attached to them. This is because RDD checkpoints do not preserve the partitioner (SPARK-12004). While #9983 solves SPARK-12004 by preserving the partitioner through RDD checkpoints, there may be a non-zero chance that the saving and recovery fails. To be resilient, this PR repartitions the previous state RDD if the partitioner is not detected. Author: Tathagata Das Closes #9988 from tdas/SPARK-11932.
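For illustration only, here is a minimal, self-contained sketch of the pattern this change relies on: reuse a pair RDD when it already carries the expected partitioner, otherwise repartition it before building on top of it. This is not the code touched by the patch (the real change is in `InternalTrackStateDStream.compute` in the diff below); the object name, the `HashPartitioner(2)`, and the toy input are assumptions made for the sketch.

```scala
import org.apache.spark.{HashPartitioner, SparkConf, SparkContext}

// Hypothetical, minimal sketch; not the Spark Streaming code changed by this patch.
object PartitionerCheckSketch {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("partitioner-check-sketch").setMaster("local[2]")
    val sc = new SparkContext(conf)
    val expected = new HashPartitioner(2)

    // Stand-in for a state RDD recovered from a checkpoint: it carries no partitioner.
    val recovered = sc.parallelize(Seq("a" -> 1, "b" -> 2, "c" -> 3))

    // Reuse the RDD only if it already has the expected partitioner,
    // otherwise shuffle it into that partitioning first.
    val stateRDD =
      if (recovered.partitioner == Some(expected)) recovered
      else recovered.partitionBy(expected)

    assert(stateRDD.partitioner == Some(expected))
    sc.stop()
  }
}
```

The same check-then-`partitionBy` shape is what the diff below applies to the recovered TrackStateRDD before it is used as the previous state.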
--- .../apache/spark/streaming/Checkpoint.scala | 2 +- .../streaming/dstream/TrackStateDStream.scala | 39 ++-- .../spark/streaming/rdd/TrackStateRDD.scala | 29 ++- .../spark/streaming/CheckpointSuite.scala | 189 +++++++++++++----- .../spark/streaming/TestSuiteBase.scala | 6 + .../streaming/TrackStateByKeySuite.scala | 77 +++++-- 6 files changed, 258 insertions(+), 84 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala index fd0e8d5d690b..d0046afdeb44 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala @@ -277,7 +277,7 @@ class CheckpointWriter( val bytes = Checkpoint.serialize(checkpoint, conf) executor.execute(new CheckpointWriteHandler( checkpoint.checkpointTime, bytes, clearCheckpointDataLater)) - logDebug("Submitted checkpoint of time " + checkpoint.checkpointTime + " writer queue") + logInfo("Submitted checkpoint of time " + checkpoint.checkpointTime + " writer queue") } catch { case rej: RejectedExecutionException => logError("Could not submit checkpoint task to the thread pool executor", rej) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackStateDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackStateDStream.scala index 0ada1111ce30..ea6213420e7a 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackStateDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackStateDStream.scala @@ -132,22 +132,37 @@ class InternalTrackStateDStream[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassT /** Method that generates a RDD for the given time */ override def compute(validTime: Time): Option[RDD[TrackStateRDDRecord[K, S, E]]] = { // Get the previous state or create a new empty state RDD - val prevStateRDD = getOrCompute(validTime - slideDuration).getOrElse { - TrackStateRDD.createFromPairRDD[K, V, S, E]( - spec.getInitialStateRDD().getOrElse(new EmptyRDD[(K, S)](ssc.sparkContext)), - partitioner, validTime - ) + val prevStateRDD = getOrCompute(validTime - slideDuration) match { + case Some(rdd) => + if (rdd.partitioner != Some(partitioner)) { + // If the RDD is not partitioned the right way, let us repartition it using the + // partition index as the key. 
This is to ensure that state RDD is always partitioned + // before creating another state RDD using it + TrackStateRDD.createFromRDD[K, V, S, E]( + rdd.flatMap { _.stateMap.getAll() }, partitioner, validTime) + } else { + rdd + } + case None => + TrackStateRDD.createFromPairRDD[K, V, S, E]( + spec.getInitialStateRDD().getOrElse(new EmptyRDD[(K, S)](ssc.sparkContext)), + partitioner, + validTime + ) } + // Compute the new state RDD with previous state RDD and partitioned data RDD - parent.getOrCompute(validTime).map { dataRDD => - val partitionedDataRDD = dataRDD.partitionBy(partitioner) - val timeoutThresholdTime = spec.getTimeoutInterval().map { interval => - (validTime - interval).milliseconds - } - new TrackStateRDD( - prevStateRDD, partitionedDataRDD, trackingFunction, validTime, timeoutThresholdTime) + // Even if there is no data RDD, use an empty one to create a new state RDD + val dataRDD = parent.getOrCompute(validTime).getOrElse { + context.sparkContext.emptyRDD[(K, V)] + } + val partitionedDataRDD = dataRDD.partitionBy(partitioner) + val timeoutThresholdTime = spec.getTimeoutInterval().map { interval => + (validTime - interval).milliseconds } + Some(new TrackStateRDD( + prevStateRDD, partitionedDataRDD, trackingFunction, validTime, timeoutThresholdTime)) } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala b/streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala index 7050378d0feb..30aafcf1460e 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala @@ -179,22 +179,43 @@ private[streaming] class TrackStateRDD[K: ClassTag, V: ClassTag, S: ClassTag, E: private[streaming] object TrackStateRDD { - def createFromPairRDD[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag]( + def createFromPairRDD[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassTag]( pairRDD: RDD[(K, S)], partitioner: Partitioner, - updateTime: Time): TrackStateRDD[K, V, S, T] = { + updateTime: Time): TrackStateRDD[K, V, S, E] = { val rddOfTrackStateRecords = pairRDD.partitionBy(partitioner).mapPartitions ({ iterator => val stateMap = StateMap.create[K, S](SparkEnv.get.conf) iterator.foreach { case (key, state) => stateMap.put(key, state, updateTime.milliseconds) } - Iterator(TrackStateRDDRecord(stateMap, Seq.empty[T])) + Iterator(TrackStateRDDRecord(stateMap, Seq.empty[E])) }, preservesPartitioning = true) val emptyDataRDD = pairRDD.sparkContext.emptyRDD[(K, V)].partitionBy(partitioner) val noOpFunc = (time: Time, key: K, value: Option[V], state: State[S]) => None - new TrackStateRDD[K, V, S, T](rddOfTrackStateRecords, emptyDataRDD, noOpFunc, updateTime, None) + new TrackStateRDD[K, V, S, E](rddOfTrackStateRecords, emptyDataRDD, noOpFunc, updateTime, None) + } + + def createFromRDD[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassTag]( + rdd: RDD[(K, S, Long)], + partitioner: Partitioner, + updateTime: Time): TrackStateRDD[K, V, S, E] = { + + val pairRDD = rdd.map { x => (x._1, (x._2, x._3)) } + val rddOfTrackStateRecords = pairRDD.partitionBy(partitioner).mapPartitions({ iterator => + val stateMap = StateMap.create[K, S](SparkEnv.get.conf) + iterator.foreach { case (key, (state, updateTime)) => + stateMap.put(key, state, updateTime) + } + Iterator(TrackStateRDDRecord(stateMap, Seq.empty[E])) + }, preservesPartitioning = true) + + val emptyDataRDD = pairRDD.sparkContext.emptyRDD[(K, V)].partitionBy(partitioner) + + val noOpFunc = (time: Time, key: 
K, value: Option[V], state: State[S]) => None + + new TrackStateRDD[K, V, S, E](rddOfTrackStateRecords, emptyDataRDD, noOpFunc, updateTime, None) } } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala index b1cbc7163bee..cd28d3cf408d 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala @@ -33,17 +33,149 @@ import org.mockito.Mockito.mock import org.scalatest.concurrent.Eventually._ import org.scalatest.time.SpanSugar._ -import org.apache.spark.TestUtils +import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite, TestUtils} import org.apache.spark.streaming.dstream.{DStream, FileInputDStream} import org.apache.spark.streaming.scheduler._ import org.apache.spark.util.{MutableURLClassLoader, Clock, ManualClock, Utils} +/** + * A trait of that can be mixed in to get methods for testing DStream operations under + * DStream checkpointing. Note that the implementations of this trait has to implement + * the `setupCheckpointOperation` + */ +trait DStreamCheckpointTester { self: SparkFunSuite => + + /** + * Tests a streaming operation under checkpointing, by restarting the operation + * from checkpoint file and verifying whether the final output is correct. + * The output is assumed to have come from a reliable queue which an replay + * data as required. + * + * NOTE: This takes into consideration that the last batch processed before + * master failure will be re-processed after restart/recovery. + */ + protected def testCheckpointedOperation[U: ClassTag, V: ClassTag]( + input: Seq[Seq[U]], + operation: DStream[U] => DStream[V], + expectedOutput: Seq[Seq[V]], + numBatchesBeforeRestart: Int, + batchDuration: Duration = Milliseconds(500), + stopSparkContextAfterTest: Boolean = true + ) { + require(numBatchesBeforeRestart < expectedOutput.size, + "Number of batches before context restart less than number of expected output " + + "(i.e. 
number of total batches to run)") + require(StreamingContext.getActive().isEmpty, + "Cannot run test with already active streaming context") + + // Current code assumes that number of batches to be run = number of inputs + val totalNumBatches = input.size + val batchDurationMillis = batchDuration.milliseconds + + // Setup the stream computation + val checkpointDir = Utils.createTempDir(this.getClass.getSimpleName()).toString + logDebug(s"Using checkpoint directory $checkpointDir") + val ssc = createContextForCheckpointOperation(batchDuration) + require(ssc.conf.get("spark.streaming.clock") === classOf[ManualClock].getName, + "Cannot run test without manual clock in the conf") + + val inputStream = new TestInputStream(ssc, input, numPartitions = 2) + val operatedStream = operation(inputStream) + operatedStream.print() + val outputStream = new TestOutputStreamWithPartitions(operatedStream, + new ArrayBuffer[Seq[Seq[V]]] with SynchronizedBuffer[Seq[Seq[V]]]) + outputStream.register() + ssc.checkpoint(checkpointDir) + + // Do the computation for initial number of batches, create checkpoint file and quit + val beforeRestartOutput = generateOutput[V](ssc, + Time(batchDurationMillis * numBatchesBeforeRestart), checkpointDir, stopSparkContextAfterTest) + assertOutput(beforeRestartOutput, expectedOutput, beforeRestart = true) + // Restart and complete the computation from checkpoint file + logInfo( + "\n-------------------------------------------\n" + + " Restarting stream computation " + + "\n-------------------------------------------\n" + ) + + val restartedSsc = new StreamingContext(checkpointDir) + val afterRestartOutput = generateOutput[V](restartedSsc, + Time(batchDurationMillis * totalNumBatches), checkpointDir, stopSparkContextAfterTest) + assertOutput(afterRestartOutput, expectedOutput, beforeRestart = false) + } + + protected def createContextForCheckpointOperation(batchDuration: Duration): StreamingContext = { + val conf = new SparkConf().setMaster("local").setAppName(this.getClass.getSimpleName) + conf.set("spark.streaming.clock", classOf[ManualClock].getName()) + new StreamingContext(SparkContext.getOrCreate(conf), batchDuration) + } + + private def generateOutput[V: ClassTag]( + ssc: StreamingContext, + targetBatchTime: Time, + checkpointDir: String, + stopSparkContext: Boolean + ): Seq[Seq[V]] = { + try { + val batchDuration = ssc.graph.batchDuration + val batchCounter = new BatchCounter(ssc) + ssc.start() + val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] + val currentTime = clock.getTimeMillis() + + logInfo("Manual clock before advancing = " + clock.getTimeMillis()) + clock.setTime(targetBatchTime.milliseconds) + logInfo("Manual clock after advancing = " + clock.getTimeMillis()) + + val outputStream = ssc.graph.getOutputStreams().filter { dstream => + dstream.isInstanceOf[TestOutputStreamWithPartitions[V]] + }.head.asInstanceOf[TestOutputStreamWithPartitions[V]] + + eventually(timeout(10 seconds)) { + ssc.awaitTerminationOrTimeout(10) + assert(batchCounter.getLastCompletedBatchTime === targetBatchTime) + } + + eventually(timeout(10 seconds)) { + val checkpointFilesOfLatestTime = Checkpoint.getCheckpointFiles(checkpointDir).filter { + _.toString.contains(clock.getTimeMillis.toString) + } + // Checkpoint files are written twice for every batch interval. So assert that both + // are written to make sure that both of them have been written. 
+ assert(checkpointFilesOfLatestTime.size === 2) + } + outputStream.output.map(_.flatten) + + } finally { + ssc.stop(stopSparkContext = stopSparkContext) + } + } + + private def assertOutput[V: ClassTag]( + output: Seq[Seq[V]], + expectedOutput: Seq[Seq[V]], + beforeRestart: Boolean): Unit = { + val expectedPartialOutput = if (beforeRestart) { + expectedOutput.take(output.size) + } else { + expectedOutput.takeRight(output.size) + } + val setComparison = output.zip(expectedPartialOutput).forall { + case (o, e) => o.toSet === e.toSet + } + assert(setComparison, s"set comparison failed\n" + + s"Expected output items:\n${expectedPartialOutput.mkString("\n")}\n" + + s"Generated output items: ${output.mkString("\n")}" + ) + } +} + /** * This test suites tests the checkpointing functionality of DStreams - * the checkpointing of a DStream's RDDs as well as the checkpointing of * the whole DStream graph. */ -class CheckpointSuite extends TestSuiteBase { +class CheckpointSuite extends TestSuiteBase with DStreamCheckpointTester { var ssc: StreamingContext = null @@ -56,7 +188,7 @@ class CheckpointSuite extends TestSuiteBase { override def afterFunction() { super.afterFunction() - if (ssc != null) ssc.stop() + if (ssc != null) { ssc.stop() } Utils.deleteRecursively(new File(checkpointDir)) } @@ -251,7 +383,9 @@ class CheckpointSuite extends TestSuiteBase { Seq(("", 2)), Seq(), Seq(("a", 2), ("b", 1)), - Seq(("", 2)), Seq() ), + Seq(("", 2)), + Seq() + ), 3 ) } @@ -634,53 +768,6 @@ class CheckpointSuite extends TestSuiteBase { checkpointWriter.stop() } - /** - * Tests a streaming operation under checkpointing, by restarting the operation - * from checkpoint file and verifying whether the final output is correct. - * The output is assumed to have come from a reliable queue which an replay - * data as required. - * - * NOTE: This takes into consideration that the last batch processed before - * master failure will be re-processed after restart/recovery. - */ - def testCheckpointedOperation[U: ClassTag, V: ClassTag]( - input: Seq[Seq[U]], - operation: DStream[U] => DStream[V], - expectedOutput: Seq[Seq[V]], - initialNumBatches: Int - ) { - - // Current code assumes that: - // number of inputs = number of outputs = number of batches to be run - val totalNumBatches = input.size - val nextNumBatches = totalNumBatches - initialNumBatches - val initialNumExpectedOutputs = initialNumBatches - val nextNumExpectedOutputs = expectedOutput.size - initialNumExpectedOutputs + 1 - // because the last batch will be processed again - - // Do the computation for initial number of batches, create checkpoint file and quit - ssc = setupStreams[U, V](input, operation) - ssc.start() - val output = advanceTimeWithRealDelay[V](ssc, initialNumBatches) - ssc.stop() - verifyOutput[V](output, expectedOutput.take(initialNumBatches), true) - Thread.sleep(1000) - - // Restart and complete the computation from checkpoint file - logInfo( - "\n-------------------------------------------\n" + - " Restarting stream computation " + - "\n-------------------------------------------\n" - ) - ssc = new StreamingContext(checkpointDir) - ssc.start() - val outputNew = advanceTimeWithRealDelay[V](ssc, nextNumBatches) - // the first element will be re-processed data of the last batch before restart - verifyOutput[V](outputNew, expectedOutput.takeRight(nextNumExpectedOutputs), true) - ssc.stop() - ssc = null - } - /** * Advances the manual clock on the streaming scheduler by given number of batches. 
* It also waits for the expected amount of time for each batch. diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala index a45c92d9c7bc..be0f4636a6cb 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala @@ -142,6 +142,7 @@ class BatchCounter(ssc: StreamingContext) { // All access to this state should be guarded by `BatchCounter.this.synchronized` private var numCompletedBatches = 0 private var numStartedBatches = 0 + private var lastCompletedBatchTime: Time = null private val listener = new StreamingListener { override def onBatchStarted(batchStarted: StreamingListenerBatchStarted): Unit = @@ -152,6 +153,7 @@ class BatchCounter(ssc: StreamingContext) { override def onBatchCompleted(batchCompleted: StreamingListenerBatchCompleted): Unit = BatchCounter.this.synchronized { numCompletedBatches += 1 + lastCompletedBatchTime = batchCompleted.batchInfo.batchTime BatchCounter.this.notifyAll() } } @@ -165,6 +167,10 @@ class BatchCounter(ssc: StreamingContext) { numStartedBatches } + def getLastCompletedBatchTime: Time = this.synchronized { + lastCompletedBatchTime + } + /** * Wait until `expectedNumCompletedBatches` batches are completed, or timeout. Return true if * `expectedNumCompletedBatches` batches are completed. Otherwise, return false to indicate it's diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala index 58aef74c0040..1fc320d31b18 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala @@ -25,31 +25,27 @@ import scala.reflect.ClassTag import org.scalatest.PrivateMethodTester._ import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} -import org.apache.spark.streaming.dstream.{InternalTrackStateDStream, TrackStateDStream, TrackStateDStreamImpl} +import org.apache.spark.streaming.dstream.{DStream, InternalTrackStateDStream, TrackStateDStream, TrackStateDStreamImpl} import org.apache.spark.util.{ManualClock, Utils} import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} -class TrackStateByKeySuite extends SparkFunSuite with BeforeAndAfterAll with BeforeAndAfter { +class TrackStateByKeySuite extends SparkFunSuite + with DStreamCheckpointTester with BeforeAndAfterAll with BeforeAndAfter { private var sc: SparkContext = null - private var ssc: StreamingContext = null - private var checkpointDir: File = null - private val batchDuration = Seconds(1) + protected var checkpointDir: File = null + protected val batchDuration = Seconds(1) before { - StreamingContext.getActive().foreach { - _.stop(stopSparkContext = false) - } + StreamingContext.getActive().foreach { _.stop(stopSparkContext = false) } checkpointDir = Utils.createTempDir("checkpoint") - - ssc = new StreamingContext(sc, batchDuration) - ssc.checkpoint(checkpointDir.toString) } after { - StreamingContext.getActive().foreach { - _.stop(stopSparkContext = false) + if (checkpointDir != null) { + Utils.deleteRecursively(checkpointDir) } + StreamingContext.getActive().foreach { _.stop(stopSparkContext = false) } } override def beforeAll(): Unit = { @@ -242,7 +238,7 @@ class TrackStateByKeySuite extends SparkFunSuite with BeforeAndAfterAll with Bef assert(dstreamImpl.stateClass === 
classOf[Double]) assert(dstreamImpl.emittedClass === classOf[Long]) } - + val ssc = new StreamingContext(sc, batchDuration) val inputStream = new TestInputStream[(String, Int)](ssc, Seq.empty, numPartitions = 2) // Defining StateSpec inline with trackStateByKey and simple function implicitly gets the types @@ -451,8 +447,9 @@ class TrackStateByKeySuite extends SparkFunSuite with BeforeAndAfterAll with Bef expectedCheckpointDuration: Duration, explicitCheckpointDuration: Option[Duration] = None ): Unit = { + val ssc = new StreamingContext(sc, batchDuration) + try { - ssc = new StreamingContext(sc, batchDuration) val inputStream = new TestInputStream(ssc, Seq.empty[Seq[Int]], 2).map(_ -> 1) val dummyFunc = (value: Option[Int], state: State[Int]) => 0 val trackStateStream = inputStream.trackStateByKey(StateSpec.function(dummyFunc)) @@ -462,11 +459,12 @@ class TrackStateByKeySuite extends SparkFunSuite with BeforeAndAfterAll with Bef trackStateStream.checkpoint(d) } trackStateStream.register() + ssc.checkpoint(checkpointDir.toString) ssc.start() // should initialize all the checkpoint durations assert(trackStateStream.checkpointDuration === null) assert(internalTrackStateStream.checkpointDuration === expectedCheckpointDuration) } finally { - StreamingContext.getActive().foreach { _.stop(stopSparkContext = false) } + ssc.stop(stopSparkContext = false) } } @@ -479,6 +477,50 @@ class TrackStateByKeySuite extends SparkFunSuite with BeforeAndAfterAll with Bef testCheckpointDuration(Seconds(10), Seconds(20), Some(Seconds(20))) } + + test("trackStateByKey - driver failure recovery") { + val inputData = + Seq( + Seq(), + Seq("a"), + Seq("a", "b"), + Seq("a", "b", "c"), + Seq("a", "b"), + Seq("a"), + Seq() + ) + + val stateData = + Seq( + Seq(), + Seq(("a", 1)), + Seq(("a", 2), ("b", 1)), + Seq(("a", 3), ("b", 2), ("c", 1)), + Seq(("a", 4), ("b", 3), ("c", 1)), + Seq(("a", 5), ("b", 3), ("c", 1)), + Seq(("a", 5), ("b", 3), ("c", 1)) + ) + + def operation(dstream: DStream[String]): DStream[(String, Int)] = { + + val checkpointDuration = batchDuration * (stateData.size / 2) + + val runningCount = (value: Option[Int], state: State[Int]) => { + state.update(state.getOption().getOrElse(0) + value.getOrElse(0)) + state.get() + } + + val trackStateStream = dstream.map { _ -> 1 }.trackStateByKey( + StateSpec.function(runningCount)) + // Set internval make sure there is one RDD checkpointing + trackStateStream.checkpoint(checkpointDuration) + trackStateStream.stateSnapshots() + } + + testCheckpointedOperation(inputData, operation, stateData, inputData.size / 2, + batchDuration = batchDuration, stopSparkContextAfterTest = false) + } + private def testOperation[K: ClassTag, S: ClassTag, T: ClassTag]( input: Seq[Seq[K]], trackStateSpec: StateSpec[K, Int, S, T], @@ -500,6 +542,7 @@ class TrackStateByKeySuite extends SparkFunSuite with BeforeAndAfterAll with Bef ): (Seq[Seq[T]], Seq[Seq[(K, S)]]) = { // Setup the stream computation + val ssc = new StreamingContext(sc, Seconds(1)) val inputStream = new TestInputStream(ssc, input, numPartitions = 2) val trackeStateStream = inputStream.map(x => (x, 1)).trackStateByKey(trackStateSpec) val collectedOutputs = new ArrayBuffer[Seq[T]] with SynchronizedBuffer[Seq[T]] @@ -511,12 +554,14 @@ class TrackStateByKeySuite extends SparkFunSuite with BeforeAndAfterAll with Bef stateSnapshotStream.register() val batchCounter = new BatchCounter(ssc) + ssc.checkpoint(checkpointDir.toString) ssc.start() val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] 
clock.advance(batchDuration.milliseconds * numBatches) batchCounter.waitUntilBatchesCompleted(numBatches, 10000) + ssc.stop(stopSparkContext = false) (collectedOutputs, collectedStateSnapshots) } From 3f4efb5c23b029496b112760fa062ff070c20334 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Mon, 7 Dec 2015 12:01:09 -0800 Subject: [PATCH 1682/1824] [SPARK-12060][CORE] Avoid memory copy in JavaSerializerInstance.serialize Merged #10051 again since #10083 is resolved. This reverts commit 328b757d5d4486ea3c2e246780792d7a57ee85e5. Author: Shixiong Zhu Closes #10167 from zsxwing/merge-SPARK-12060. --- .../spark/serializer/JavaSerializer.scala | 7 ++--- .../spark/util/ByteBufferOutputStream.scala | 31 +++++++++++++++++++ 2 files changed, 34 insertions(+), 4 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/util/ByteBufferOutputStream.scala diff --git a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala index b463a71d5bd7..ea718a0edbe7 100644 --- a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala @@ -24,8 +24,7 @@ import scala.reflect.ClassTag import org.apache.spark.SparkConf import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.util.ByteBufferInputStream -import org.apache.spark.util.Utils +import org.apache.spark.util.{ByteBufferInputStream, ByteBufferOutputStream, Utils} private[spark] class JavaSerializationStream( out: OutputStream, counterReset: Int, extraDebugInfo: Boolean) @@ -96,11 +95,11 @@ private[spark] class JavaSerializerInstance( extends SerializerInstance { override def serialize[T: ClassTag](t: T): ByteBuffer = { - val bos = new ByteArrayOutputStream() + val bos = new ByteBufferOutputStream() val out = serializeStream(bos) out.writeObject(t) out.close() - ByteBuffer.wrap(bos.toByteArray) + bos.toByteBuffer } override def deserialize[T: ClassTag](bytes: ByteBuffer): T = { diff --git a/core/src/main/scala/org/apache/spark/util/ByteBufferOutputStream.scala b/core/src/main/scala/org/apache/spark/util/ByteBufferOutputStream.scala new file mode 100644 index 000000000000..92e45224db81 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/ByteBufferOutputStream.scala @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.util + +import java.io.ByteArrayOutputStream +import java.nio.ByteBuffer + +/** + * Provide a zero-copy way to convert data in ByteArrayOutputStream to ByteBuffer + */ +private[spark] class ByteBufferOutputStream extends ByteArrayOutputStream { + + def toByteBuffer: ByteBuffer = { + return ByteBuffer.wrap(buf, 0, count) + } +} From 871e85d9c14c6b19068cc732951a8ae8db61b411 Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Mon, 7 Dec 2015 13:16:47 -0800 Subject: [PATCH 1683/1824] [SPARK-11963][DOC] Add docs for QuantileDiscretizer https://issues.apache.org/jira/browse/SPARK-11963 Author: Xusen Yin Closes #9962 from yinxusen/SPARK-11963. --- docs/ml-features.md | 65 +++++++++++++++++ .../ml/JavaQuantileDiscretizerExample.java | 71 +++++++++++++++++++ .../ml/QuantileDiscretizerExample.scala | 49 +++++++++++++ 3 files changed, 185 insertions(+) create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaQuantileDiscretizerExample.java create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/QuantileDiscretizerExample.scala diff --git a/docs/ml-features.md b/docs/ml-features.md index 05c2c96c5ec5..b499d6ec3b90 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -1705,6 +1705,71 @@ print(output.select("features", "clicked").first()) +## QuantileDiscretizer + +`QuantileDiscretizer` takes a column with continuous features and outputs a column with binned +categorical features. +The bin ranges are chosen by taking a sample of the data and dividing it into roughly equal parts. +The lower and upper bin bounds will be `-Infinity` and `+Infinity`, covering all real values. +This attempts to find `numBuckets` partitions based on a sample of the given input data, but it may +find fewer depending on the data sample values. + +Note that the result may be different every time you run it, since the sample strategy behind it is +non-deterministic. + +**Examples** + +Assume that we have a DataFrame with the columns `id`, `hour`: + +~~~ + id | hour +----|------ + 0 | 18.0 +----|------ + 1 | 19.0 +----|------ + 2 | 8.0 +----|------ + 3 | 5.0 +----|------ + 4 | 2.2 +~~~ + +`hour` is a continuous feature with `Double` type. We want to turn the continuous feature into +categorical one. Given `numBuckets = 3`, we should get the following DataFrame: + +~~~ + id | hour | result +----|------|------ + 0 | 18.0 | 2.0 +----|------|------ + 1 | 19.0 | 2.0 +----|------|------ + 2 | 8.0 | 1.0 +----|------|------ + 3 | 5.0 | 1.0 +----|------|------ + 4 | 2.2 | 0.0 +~~~ + +
+<div class="codetabs">
+<div data-lang="scala" markdown="1">
+
+Refer to the [QuantileDiscretizer Scala docs](api/scala/index.html#org.apache.spark.ml.feature.QuantileDiscretizer)
+for more details on the API.
+
+{% include_example scala/org/apache/spark/examples/ml/QuantileDiscretizerExample.scala %}
+</div>
+
+<div data-lang="java" markdown="1">
+
+Refer to the [QuantileDiscretizer Java docs](api/java/org/apache/spark/ml/feature/QuantileDiscretizer.html)
+for more details on the API.
+
+{% include_example java/org/apache/spark/examples/ml/JavaQuantileDiscretizerExample.java %}
+</div>
+</div>
    + # Feature Selectors ## VectorSlicer diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaQuantileDiscretizerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaQuantileDiscretizerExample.java new file mode 100644 index 000000000000..251ae79d9a10 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaQuantileDiscretizerExample.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; +// $example on$ +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.QuantileDiscretizer; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +// $example off$ + +public class JavaQuantileDiscretizerExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaQuantileDiscretizerExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext sqlContext = new SQLContext(jsc); + + // $example on$ + JavaRDD jrdd = jsc.parallelize( + Arrays.asList( + RowFactory.create(0, 18.0), + RowFactory.create(1, 19.0), + RowFactory.create(2, 8.0), + RowFactory.create(3, 5.0), + RowFactory.create(4, 2.2) + ) + ); + + StructType schema = new StructType(new StructField[]{ + new StructField("id", DataTypes.IntegerType, false, Metadata.empty()), + new StructField("hour", DataTypes.DoubleType, false, Metadata.empty()) + }); + + DataFrame df = sqlContext.createDataFrame(jrdd, schema); + + QuantileDiscretizer discretizer = new QuantileDiscretizer() + .setInputCol("hour") + .setOutputCol("result") + .setNumBuckets(3); + + DataFrame result = discretizer.fit(df).transform(df); + result.show(); + // $example off$ + jsc.stop(); + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/QuantileDiscretizerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/QuantileDiscretizerExample.scala new file mode 100644 index 000000000000..8f29b7eaa6d2 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/QuantileDiscretizerExample.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.QuantileDiscretizer +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object QuantileDiscretizerExample { + def main(args: Array[String]) { + val conf = new SparkConf().setAppName("QuantileDiscretizerExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + import sqlContext.implicits._ + + // $example on$ + val data = Array((0, 18.0), (1, 19.0), (2, 8.0), (3, 5.0), (4, 2.2)) + val df = sc.parallelize(data).toDF("id", "hour") + + val discretizer = new QuantileDiscretizer() + .setInputCol("hour") + .setOutputCol("result") + .setNumBuckets(3) + + val result = discretizer.fit(df).transform(df) + result.show() + // $example off$ + sc.stop() + } +} +// scalastyle:on println From 84b809445f39b9030f272528bdaa39d1559cbc6e Mon Sep 17 00:00:00 2001 From: tedyu Date: Mon, 7 Dec 2015 14:58:09 -0800 Subject: [PATCH 1684/1824] [SPARK-11884] Drop multiple columns in the DataFrame API See the thread Ben started: http://search-hadoop.com/m/q3RTtveEuhjsr7g/ This PR adds drop() method to DataFrame which accepts multiple column names Author: tedyu Closes #9862 from ted-yu/master. --- .../org/apache/spark/sql/DataFrame.scala | 24 ++++++++++++------- .../org/apache/spark/sql/DataFrameSuite.scala | 7 ++++++ 2 files changed, 23 insertions(+), 8 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index eb8700369275..243a8c853f90 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -1261,16 +1261,24 @@ class DataFrame private[sql]( * @since 1.4.0 */ def drop(colName: String): DataFrame = { + drop(Seq(colName) : _*) + } + + /** + * Returns a new [[DataFrame]] with columns dropped. + * This is a no-op if schema doesn't contain column name(s). 
+ * @group dfops + * @since 1.6.0 + */ + @scala.annotation.varargs + def drop(colNames: String*): DataFrame = { val resolver = sqlContext.analyzer.resolver - val shouldDrop = schema.exists(f => resolver(f.name, colName)) - if (shouldDrop) { - val colsAfterDrop = schema.filter { field => - val name = field.name - !resolver(name, colName) - }.map(f => Column(f.name)) - select(colsAfterDrop : _*) - } else { + val remainingCols = + schema.filter(f => colNames.forall(n => !resolver(f.name, n))).map(f => Column(f.name)) + if (remainingCols.size == this.schema.size) { this + } else { + this.select(remainingCols: _*) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 76e9648aa753..605a6549dd68 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -378,6 +378,13 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { assert(df.schema.map(_.name) === Seq("value")) } + test("drop columns using drop") { + val src = Seq((0, 2, 3)).toDF("a", "b", "c") + val df = src.drop("a", "b") + checkAnswer(df, Row(3)) + assert(df.schema.map(_.name) === Seq("c")) + } + test("drop unknown column (no-op)") { val df = testData.drop("random") checkAnswer( From 36282f78b888743066843727426c6d806231aa97 Mon Sep 17 00:00:00 2001 From: Andrew Ray Date: Mon, 7 Dec 2015 15:01:00 -0800 Subject: [PATCH 1685/1824] [SPARK-12184][PYTHON] Make python api doc for pivot consistant with scala doc In SPARK-11946 the API for pivot was changed a bit and got updated doc, the doc changes were not made for the python api though. This PR updates the python doc to be consistent. Author: Andrew Ray Closes #10176 from aray/sql-pivot-python-doc. --- python/pyspark/sql/group.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py index 1911588309af..9ca303a974cd 100644 --- a/python/pyspark/sql/group.py +++ b/python/pyspark/sql/group.py @@ -169,16 +169,20 @@ def sum(self, *cols): @since(1.6) def pivot(self, pivot_col, values=None): - """Pivots a column of the current DataFrame and perform the specified aggregation. + """ + Pivots a column of the current [[DataFrame]] and perform the specified aggregation. + There are two versions of pivot function: one that requires the caller to specify the list + of distinct values to pivot on, and one that does not. The latter is more concise but less + efficient, because Spark needs to first compute the list of distinct values internally. - :param pivot_col: Column to pivot - :param values: Optional list of values of pivot column that will be translated to columns in - the output DataFrame. If values are not provided the method will do an immediate call - to .distinct() on the pivot column. + :param pivot_col: Name of the column to pivot. + :param values: List of values that will be translated to columns in the output DataFrame. 
+ // Compute the sum of earnings for each year by course with each course as a separate column >>> df4.groupBy("year").pivot("course", ["dotNET", "Java"]).sum("earnings").collect() [Row(year=2012, dotNET=15000, Java=20000), Row(year=2013, dotNET=48000, Java=30000)] + // Or without specifying column values (less efficient) >>> df4.groupBy("year").pivot("course").sum("earnings").collect() [Row(year=2012, Java=20000, dotNET=15000), Row(year=2013, Java=30000, dotNET=48000)] """ From 3e7e05f5ee763925ed60410d7de04cf36b723de1 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Mon, 7 Dec 2015 16:37:09 -0800 Subject: [PATCH 1686/1824] [SPARK-12160][MLLIB] Use SQLContext.getOrCreate in MLlib Switched from using SQLContext constructor to using getOrCreate, mainly in model save/load methods. This covers all instances in spark.mllib. There were no uses of the constructor in spark.ml. CC: mengxr yhuai Author: Joseph K. Bradley Closes #10161 from jkbradley/mllib-sqlcontext-fix. --- .../apache/spark/mllib/api/python/PythonMLLibAPI.scala | 6 +++--- .../apache/spark/mllib/classification/NaiveBayes.scala | 8 ++++---- .../classification/impl/GLMClassificationModel.scala | 4 ++-- .../spark/mllib/clustering/GaussianMixtureModel.scala | 4 ++-- .../org/apache/spark/mllib/clustering/KMeansModel.scala | 4 ++-- .../spark/mllib/clustering/PowerIterationClustering.scala | 4 ++-- .../org/apache/spark/mllib/feature/ChiSqSelector.scala | 4 ++-- .../scala/org/apache/spark/mllib/feature/Word2Vec.scala | 4 ++-- .../mllib/recommendation/MatrixFactorizationModel.scala | 4 ++-- .../spark/mllib/regression/IsotonicRegression.scala | 4 ++-- .../spark/mllib/regression/impl/GLMRegressionModel.scala | 4 ++-- .../apache/spark/mllib/tree/model/DecisionTreeModel.scala | 4 ++-- .../spark/mllib/tree/model/treeEnsembleModels.scala | 4 ++-- 13 files changed, 29 insertions(+), 29 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index 54b03a9f9028..2aa6aec0b434 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -1191,7 +1191,7 @@ private[python] class PythonMLLibAPI extends Serializable { def getIndexedRows(indexedRowMatrix: IndexedRowMatrix): DataFrame = { // We use DataFrames for serialization of IndexedRows to Python, // so return a DataFrame. - val sqlContext = new SQLContext(indexedRowMatrix.rows.sparkContext) + val sqlContext = SQLContext.getOrCreate(indexedRowMatrix.rows.sparkContext) sqlContext.createDataFrame(indexedRowMatrix.rows) } @@ -1201,7 +1201,7 @@ private[python] class PythonMLLibAPI extends Serializable { def getMatrixEntries(coordinateMatrix: CoordinateMatrix): DataFrame = { // We use DataFrames for serialization of MatrixEntry entries to // Python, so return a DataFrame. - val sqlContext = new SQLContext(coordinateMatrix.entries.sparkContext) + val sqlContext = SQLContext.getOrCreate(coordinateMatrix.entries.sparkContext) sqlContext.createDataFrame(coordinateMatrix.entries) } @@ -1211,7 +1211,7 @@ private[python] class PythonMLLibAPI extends Serializable { def getMatrixBlocks(blockMatrix: BlockMatrix): DataFrame = { // We use DataFrames for serialization of sub-matrix blocks to // Python, so return a DataFrame. 
- val sqlContext = new SQLContext(blockMatrix.blocks.sparkContext) + val sqlContext = SQLContext.getOrCreate(blockMatrix.blocks.sparkContext) sqlContext.createDataFrame(blockMatrix.blocks) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala index a956084ae06e..aef9ef2cb052 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala @@ -192,7 +192,7 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] { modelType: String) def save(sc: SparkContext, path: String, data: Data): Unit = { - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) import sqlContext.implicits._ // Create JSON metadata. @@ -208,7 +208,7 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] { @Since("1.3.0") def load(sc: SparkContext, path: String): NaiveBayesModel = { - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) // Load Parquet data. val dataRDD = sqlContext.read.parquet(dataPath(path)) // Check schema explicitly since erasure makes it hard to use match-case for checking. @@ -239,7 +239,7 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] { theta: Array[Array[Double]]) def save(sc: SparkContext, path: String, data: Data): Unit = { - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) import sqlContext.implicits._ // Create JSON metadata. @@ -254,7 +254,7 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] { } def load(sc: SparkContext, path: String): NaiveBayesModel = { - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) // Load Parquet data. val dataRDD = sqlContext.read.parquet(dataPath(path)) // Check schema explicitly since erasure makes it hard to use match-case for checking. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala index fe09f6b75d28..2910c027ae06 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala @@ -51,7 +51,7 @@ private[classification] object GLMClassificationModel { weights: Vector, intercept: Double, threshold: Option[Double]): Unit = { - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) import sqlContext.implicits._ // Create JSON metadata. 
@@ -74,7 +74,7 @@ private[classification] object GLMClassificationModel { */ def loadData(sc: SparkContext, path: String, modelClass: String): Data = { val datapath = Loader.dataPath(path) - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) val dataRDD = sqlContext.read.parquet(datapath) val dataArray = dataRDD.select("weights", "intercept", "threshold").take(1) assert(dataArray.size == 1, s"Unable to load $modelClass data from: $datapath") diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala index 2115f7d99c18..74d13e4f7794 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala @@ -145,7 +145,7 @@ object GaussianMixtureModel extends Loader[GaussianMixtureModel] { weights: Array[Double], gaussians: Array[MultivariateGaussian]): Unit = { - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) import sqlContext.implicits._ // Create JSON metadata. @@ -162,7 +162,7 @@ object GaussianMixtureModel extends Loader[GaussianMixtureModel] { def load(sc: SparkContext, path: String): GaussianMixtureModel = { val dataPath = Loader.dataPath(path) - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) val dataFrame = sqlContext.read.parquet(dataPath) // Check schema explicitly since erasure makes it hard to use match-case for checking. Loader.checkSchema[Data](dataFrame.schema) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala index a74158498272..91fa9b0d3590 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala @@ -124,7 +124,7 @@ object KMeansModel extends Loader[KMeansModel] { val thisClassName = "org.apache.spark.mllib.clustering.KMeansModel" def save(sc: SparkContext, model: KMeansModel, path: String): Unit = { - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) import sqlContext.implicits._ val metadata = compact(render( ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ ("k" -> model.k))) @@ -137,7 +137,7 @@ object KMeansModel extends Loader[KMeansModel] { def load(sc: SparkContext, path: String): KMeansModel = { implicit val formats = DefaultFormats - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path) assert(className == thisClassName) assert(formatVersion == thisFormatVersion) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala index 7cd9b08fa8e0..bb1804505948 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala @@ -70,7 +70,7 @@ object PowerIterationClusteringModel extends Loader[PowerIterationClusteringMode @Since("1.4.0") def save(sc: SparkContext, model: PowerIterationClusteringModel, path: String): Unit = { - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) import sqlContext.implicits._ val 
metadata = compact(render( @@ -84,7 +84,7 @@ object PowerIterationClusteringModel extends Loader[PowerIterationClusteringMode @Since("1.4.0") def load(sc: SparkContext, path: String): PowerIterationClusteringModel = { implicit val formats = DefaultFormats - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path) assert(className == thisClassName) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala index d4d022afde05..eaa99cfe82e2 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala @@ -134,7 +134,7 @@ object ChiSqSelectorModel extends Loader[ChiSqSelectorModel] { val thisClassName = "org.apache.spark.mllib.feature.ChiSqSelectorModel" def save(sc: SparkContext, model: ChiSqSelectorModel, path: String): Unit = { - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) import sqlContext.implicits._ val metadata = compact(render( ("class" -> thisClassName) ~ ("version" -> thisFormatVersion))) @@ -150,7 +150,7 @@ object ChiSqSelectorModel extends Loader[ChiSqSelectorModel] { def load(sc: SparkContext, path: String): ChiSqSelectorModel = { implicit val formats = DefaultFormats - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path) assert(className == thisClassName) assert(formatVersion == thisFormatVersion) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index b693f3c8e4bd..23b1514e3080 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -587,7 +587,7 @@ object Word2VecModel extends Loader[Word2VecModel] { def load(sc: SparkContext, path: String): Word2VecModel = { val dataPath = Loader.dataPath(path) - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) val dataFrame = sqlContext.read.parquet(dataPath) // Check schema explicitly since erasure makes it hard to use match-case for checking. 
Loader.checkSchema[Data](dataFrame.schema) @@ -599,7 +599,7 @@ object Word2VecModel extends Loader[Word2VecModel] { def save(sc: SparkContext, path: String, model: Map[String, Array[Float]]): Unit = { - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) import sqlContext.implicits._ val vectorSize = model.values.head.size diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala index 46562eb2ad0f..0dc40483dd0f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala @@ -353,7 +353,7 @@ object MatrixFactorizationModel extends Loader[MatrixFactorizationModel] { */ def save(model: MatrixFactorizationModel, path: String): Unit = { val sc = model.userFeatures.sparkContext - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) import sqlContext.implicits._ val metadata = compact(render( ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ ("rank" -> model.rank))) @@ -364,7 +364,7 @@ object MatrixFactorizationModel extends Loader[MatrixFactorizationModel] { def load(sc: SparkContext, path: String): MatrixFactorizationModel = { implicit val formats = DefaultFormats - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) val (className, formatVersion, metadata) = loadMetadata(sc, path) assert(className == thisClassName) assert(formatVersion == thisFormatVersion) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala index ec78ea24539b..f235089873ab 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala @@ -185,7 +185,7 @@ object IsotonicRegressionModel extends Loader[IsotonicRegressionModel] { boundaries: Array[Double], predictions: Array[Double], isotonic: Boolean): Unit = { - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) val metadata = compact(render( ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ @@ -198,7 +198,7 @@ object IsotonicRegressionModel extends Loader[IsotonicRegressionModel] { } def load(sc: SparkContext, path: String): (Array[Double], Array[Double]) = { - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) val dataRDD = sqlContext.read.parquet(dataPath(path)) checkSchema[Data](dataRDD.schema) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala index 317d3a570263..02af281fb726 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala @@ -47,7 +47,7 @@ private[regression] object GLMRegressionModel { modelClass: String, weights: Vector, intercept: Double): Unit = { - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) import sqlContext.implicits._ // Create JSON metadata. 
@@ -71,7 +71,7 @@ private[regression] object GLMRegressionModel { */ def loadData(sc: SparkContext, path: String, modelClass: String, numFeatures: Int): Data = { val datapath = Loader.dataPath(path) - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) val dataRDD = sqlContext.read.parquet(datapath) val dataArray = dataRDD.select("weights", "intercept").take(1) assert(dataArray.size == 1, s"Unable to load $modelClass data from: $datapath") diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala index 54c136aecf66..89c470d57343 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala @@ -201,7 +201,7 @@ object DecisionTreeModel extends Loader[DecisionTreeModel] with Logging { } def save(sc: SparkContext, path: String, model: DecisionTreeModel): Unit = { - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) import sqlContext.implicits._ // SPARK-6120: We do a hacky check here so users understand why save() is failing @@ -242,7 +242,7 @@ object DecisionTreeModel extends Loader[DecisionTreeModel] with Logging { def load(sc: SparkContext, path: String, algo: String, numNodes: Int): DecisionTreeModel = { val datapath = Loader.dataPath(path) - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) // Load Parquet data. val dataRDD = sqlContext.read.parquet(datapath) // Check schema explicitly since erasure makes it hard to use match-case for checking. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala index 90e032e3d984..3f427f0be3af 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala @@ -408,7 +408,7 @@ private[tree] object TreeEnsembleModel extends Logging { case class EnsembleNodeData(treeId: Int, node: NodeData) def save(sc: SparkContext, path: String, model: TreeEnsembleModel, className: String): Unit = { - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) import sqlContext.implicits._ // SPARK-6120: We do a hacky check here so users understand why save() is failing @@ -468,7 +468,7 @@ private[tree] object TreeEnsembleModel extends Logging { path: String, treeAlgo: String): Array[DecisionTreeModel] = { val datapath = Loader.dataPath(path) - val sqlContext = new SQLContext(sc) + val sqlContext = SQLContext.getOrCreate(sc) val nodes = sqlContext.read.parquet(datapath).map(NodeData.apply) val trees = constructTrees(nodes) trees.map(new DecisionTreeModel(_, Algo.fromString(treeAlgo))) From 78209b0ccaf3f22b5e2345dfb2b98edfdb746819 Mon Sep 17 00:00:00 2001 From: somideshmukh Date: Mon, 7 Dec 2015 23:26:34 -0800 Subject: [PATCH 1687/1824] [SPARK-11551][DOC][EXAMPLE] Replace example code in ml-features.md using include_example Made a new patch containing only the markdown examples moved to the examples/ folder. Only three Java examples were not shifted since they contained compilation errors; these classes are 1) StandardScaler 2) NormalizerExample 3) VectorIndexer Author: Xusen Yin Author: somideshmukh Closes #10002 from somideshmukh/SomilBranch1.33.
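As background for this change, a sketch of how the include_example convention is wired up, under the assumption (suggested by the markers in the new example files below) that the Jekyll plugin splices only the code between `// $example on$` and `// $example off$` into the page where a `{% include_example %}` tag appears. The file below is hypothetical: its name, path, and Binarizer-based body are illustrative and not part of this patch.

```scala
// Hypothetical file: examples/src/main/scala/org/apache/spark/examples/ml/MyBinarizerExample.scala
package org.apache.spark.examples.ml

// $example on$
import org.apache.spark.ml.feature.Binarizer
// $example off$
import org.apache.spark.sql.SQLContext
import org.apache.spark.{SparkConf, SparkContext}

object MyBinarizerExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("MyBinarizerExample"))
    val sqlContext = new SQLContext(sc)

    // $example on$
    // Only the code between the on/off markers ends up in the rendered docs page.
    val data = sqlContext.createDataFrame(Seq((0, 0.1), (1, 0.8), (2, 0.2))).toDF("id", "feature")
    val binarizer = new Binarizer()
      .setInputCol("feature")
      .setOutputCol("binarized_feature")
      .setThreshold(0.5)
    binarizer.transform(data).show()
    // $example off$

    sc.stop()
  }
}
```

Keeping the context setup and `sc.stop()` outside the markers is what keeps the snippets rendered into ml-features.md short; the real example files added by this patch follow the same shape.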
--- docs/ml-features.md | 1109 +---------------- .../examples/ml/JavaBinarizerExample.java | 68 + .../examples/ml/JavaBucketizerExample.java | 70 ++ .../spark/examples/ml/JavaDCTExample.java | 65 + .../ml/JavaElementwiseProductExample.java | 75 ++ .../examples/ml/JavaMinMaxScalerExample.java | 50 + .../spark/examples/ml/JavaNGramExample.java | 71 ++ .../examples/ml/JavaNormalizerExample.java | 52 + .../examples/ml/JavaOneHotEncoderExample.java | 77 ++ .../spark/examples/ml/JavaPCAExample.java | 71 ++ .../ml/JavaPolynomialExpansionExample.java | 71 ++ .../examples/ml/JavaRFormulaExample.java | 69 + .../ml/JavaStandardScalerExample.java | 53 + .../ml/JavaStopWordsRemoverExample.java | 65 + .../examples/ml/JavaStringIndexerExample.java | 66 + .../examples/ml/JavaTokenizerExample.java | 75 ++ .../ml/JavaVectorAssemblerExample.java | 67 + .../examples/ml/JavaVectorIndexerExample.java | 60 + .../examples/ml/JavaVectorSlicerExample.java | 73 ++ .../src/main/python/ml/binarizer_example.py | 43 + .../src/main/python/ml/bucketizer_example.py | 42 + .../python/ml/elementwise_product_example.py | 39 + examples/src/main/python/ml/n_gram_example.py | 42 + .../src/main/python/ml/normalizer_example.py | 41 + .../main/python/ml/onehot_encoder_example.py | 47 + examples/src/main/python/ml/pca_example.py | 42 + .../python/ml/polynomial_expansion_example.py | 43 + .../src/main/python/ml/rformula_example.py | 44 + .../main/python/ml/standard_scaler_example.py | 42 + .../python/ml/stopwords_remover_example.py | 40 + .../main/python/ml/string_indexer_example.py | 39 + .../src/main/python/ml/tokenizer_example.py | 44 + .../python/ml/vector_assembler_example.py | 42 + .../main/python/ml/vector_indexer_example.py | 39 + .../spark/examples/ml/BinarizerExample.scala | 48 + .../spark/examples/ml/BucketizerExample.scala | 51 + .../apache/spark/examples/ml/DCTExample.scala | 54 + .../ml/ElementWiseProductExample.scala | 53 + .../examples/ml/MinMaxScalerExample.scala | 49 + .../spark/examples/ml/NGramExample.scala | 47 + .../spark/examples/ml/NormalizerExample.scala | 50 + .../examples/ml/OneHotEncoderExample.scala | 58 + .../apache/spark/examples/ml/PCAExample.scala | 54 + .../ml/PolynomialExpansionExample.scala | 53 + .../spark/examples/ml/RFormulaExample.scala | 49 + .../examples/ml/StandardScalerExample.scala | 51 + .../examples/ml/StopWordsRemoverExample.scala | 48 + .../examples/ml/StringIndexerExample.scala | 49 + .../spark/examples/ml/TokenizerExample.scala | 54 + .../examples/ml/VectorAssemblerExample.scala | 49 + .../examples/ml/VectorIndexerExample.scala | 53 + .../examples/ml/VectorSlicerExample.scala | 58 + 52 files changed, 2806 insertions(+), 1058 deletions(-) create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaBinarizerExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaBucketizerExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaDCTExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaElementwiseProductExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaMinMaxScalerExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaNGramExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaNormalizerExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaOneHotEncoderExample.java create mode 100644 
examples/src/main/java/org/apache/spark/examples/ml/JavaPCAExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaPolynomialExpansionExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaRFormulaExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaStandardScalerExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaStopWordsRemoverExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaStringIndexerExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaTokenizerExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaVectorAssemblerExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaVectorIndexerExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaVectorSlicerExample.java create mode 100644 examples/src/main/python/ml/binarizer_example.py create mode 100644 examples/src/main/python/ml/bucketizer_example.py create mode 100644 examples/src/main/python/ml/elementwise_product_example.py create mode 100644 examples/src/main/python/ml/n_gram_example.py create mode 100644 examples/src/main/python/ml/normalizer_example.py create mode 100644 examples/src/main/python/ml/onehot_encoder_example.py create mode 100644 examples/src/main/python/ml/pca_example.py create mode 100644 examples/src/main/python/ml/polynomial_expansion_example.py create mode 100644 examples/src/main/python/ml/rformula_example.py create mode 100644 examples/src/main/python/ml/standard_scaler_example.py create mode 100644 examples/src/main/python/ml/stopwords_remover_example.py create mode 100644 examples/src/main/python/ml/string_indexer_example.py create mode 100644 examples/src/main/python/ml/tokenizer_example.py create mode 100644 examples/src/main/python/ml/vector_assembler_example.py create mode 100644 examples/src/main/python/ml/vector_indexer_example.py create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/BinarizerExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/BucketizerExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/DCTExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/ElementWiseProductExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/MinMaxScalerExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/NGramExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/NormalizerExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/OneHotEncoderExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/PCAExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/PolynomialExpansionExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/RFormulaExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/StandardScalerExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/StopWordsRemoverExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/StringIndexerExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/TokenizerExample.scala create mode 100644 
examples/src/main/scala/org/apache/spark/examples/ml/VectorAssemblerExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/VectorIndexerExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/VectorSlicerExample.scala diff --git a/docs/ml-features.md b/docs/ml-features.md index b499d6ec3b90..5105a948fec8 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -170,25 +170,7 @@ Refer to the [Tokenizer Scala docs](api/scala/index.html#org.apache.spark.ml.fea and the [RegexTokenizer Scala docs](api/scala/index.html#org.apache.spark.ml.feature.Tokenizer) for more details on the API. -{% highlight scala %} -import org.apache.spark.ml.feature.{Tokenizer, RegexTokenizer} - -val sentenceDataFrame = sqlContext.createDataFrame(Seq( - (0, "Hi I heard about Spark"), - (1, "I wish Java could use case classes"), - (2, "Logistic,regression,models,are,neat") -)).toDF("label", "sentence") -val tokenizer = new Tokenizer().setInputCol("sentence").setOutputCol("words") -val regexTokenizer = new RegexTokenizer() - .setInputCol("sentence") - .setOutputCol("words") - .setPattern("\\W") // alternatively .setPattern("\\w+").setGaps(false) - -val tokenized = tokenizer.transform(sentenceDataFrame) -tokenized.select("words", "label").take(3).foreach(println) -val regexTokenized = regexTokenizer.transform(sentenceDataFrame) -regexTokenized.select("words", "label").take(3).foreach(println) -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/TokenizerExample.scala %}
    @@ -197,44 +179,7 @@ Refer to the [Tokenizer Java docs](api/java/org/apache/spark/ml/feature/Tokenize and the [RegexTokenizer Java docs](api/java/org/apache/spark/ml/feature/RegexTokenizer.html) for more details on the API. -{% highlight java %} -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.ml.feature.RegexTokenizer; -import org.apache.spark.ml.feature.Tokenizer; -import org.apache.spark.mllib.linalg.Vector; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.types.DataTypes; -import org.apache.spark.sql.types.Metadata; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; - -JavaRDD jrdd = jsc.parallelize(Arrays.asList( - RowFactory.create(0, "Hi I heard about Spark"), - RowFactory.create(1, "I wish Java could use case classes"), - RowFactory.create(2, "Logistic,regression,models,are,neat") -)); -StructType schema = new StructType(new StructField[]{ - new StructField("label", DataTypes.DoubleType, false, Metadata.empty()), - new StructField("sentence", DataTypes.StringType, false, Metadata.empty()) -}); -DataFrame sentenceDataFrame = sqlContext.createDataFrame(jrdd, schema); -Tokenizer tokenizer = new Tokenizer().setInputCol("sentence").setOutputCol("words"); -DataFrame wordsDataFrame = tokenizer.transform(sentenceDataFrame); -for (Row r : wordsDataFrame.select("words", "label").take(3)) { - java.util.List words = r.getList(0); - for (String word : words) System.out.print(word + " "); - System.out.println(); -} - -RegexTokenizer regexTokenizer = new RegexTokenizer() - .setInputCol("sentence") - .setOutputCol("words") - .setPattern("\\W"); // alternatively .setPattern("\\w+").setGaps(false); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaTokenizerExample.java %}
    @@ -243,21 +188,7 @@ Refer to the [Tokenizer Python docs](api/python/pyspark.ml.html#pyspark.ml.featu the the [RegexTokenizer Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.RegexTokenizer) for more details on the API. -{% highlight python %} -from pyspark.ml.feature import Tokenizer, RegexTokenizer - -sentenceDataFrame = sqlContext.createDataFrame([ - (0, "Hi I heard about Spark"), - (1, "I wish Java could use case classes"), - (2, "Logistic,regression,models,are,neat") -], ["label", "sentence"]) -tokenizer = Tokenizer(inputCol="sentence", outputCol="words") -wordsDataFrame = tokenizer.transform(sentenceDataFrame) -for words_label in wordsDataFrame.select("words", "label").take(3): - print(words_label) -regexTokenizer = RegexTokenizer(inputCol="sentence", outputCol="words", pattern="\\W") -# alternatively, pattern="\\w+", gaps(False) -{% endhighlight %} +{% include_example python/ml/tokenizer_example.py %}
    @@ -306,19 +237,7 @@ filtered out. Refer to the [StopWordsRemover Scala docs](api/scala/index.html#org.apache.spark.ml.feature.StopWordsRemover) for more details on the API. -{% highlight scala %} -import org.apache.spark.ml.feature.StopWordsRemover - -val remover = new StopWordsRemover() - .setInputCol("raw") - .setOutputCol("filtered") -val dataSet = sqlContext.createDataFrame(Seq( - (0, Seq("I", "saw", "the", "red", "baloon")), - (1, Seq("Mary", "had", "a", "little", "lamb")) -)).toDF("id", "raw") - -remover.transform(dataSet).show() -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/StopWordsRemoverExample.scala %}
    @@ -326,34 +245,7 @@ remover.transform(dataSet).show() Refer to the [StopWordsRemover Java docs](api/java/org/apache/spark/ml/feature/StopWordsRemover.html) for more details on the API. -{% highlight java %} -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.ml.feature.StopWordsRemover; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.types.DataTypes; -import org.apache.spark.sql.types.Metadata; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; - -StopWordsRemover remover = new StopWordsRemover() - .setInputCol("raw") - .setOutputCol("filtered"); - -JavaRDD rdd = jsc.parallelize(Arrays.asList( - RowFactory.create(Arrays.asList("I", "saw", "the", "red", "baloon")), - RowFactory.create(Arrays.asList("Mary", "had", "a", "little", "lamb")) -)); -StructType schema = new StructType(new StructField[] { - new StructField("raw", DataTypes.createArrayType(DataTypes.StringType), false, Metadata.empty()) -}); -DataFrame dataset = jsql.createDataFrame(rdd, schema); - -remover.transform(dataset).show(); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaStopWordsRemoverExample.java %}
    @@ -361,17 +253,7 @@ remover.transform(dataset).show(); Refer to the [StopWordsRemover Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.StopWordsRemover) for more details on the API. -{% highlight python %} -from pyspark.ml.feature import StopWordsRemover - -sentenceData = sqlContext.createDataFrame([ - (0, ["I", "saw", "the", "red", "baloon"]), - (1, ["Mary", "had", "a", "little", "lamb"]) -], ["label", "raw"]) - -remover = StopWordsRemover(inputCol="raw", outputCol="filtered") -remover.transform(sentenceData).show(truncate=False) -{% endhighlight %} +{% include_example python/ml/stopwords_remover_example.py %}
    @@ -388,19 +270,7 @@ An [n-gram](https://en.wikipedia.org/wiki/N-gram) is a sequence of $n$ tokens (t Refer to the [NGram Scala docs](api/scala/index.html#org.apache.spark.ml.feature.NGram) for more details on the API. -{% highlight scala %} -import org.apache.spark.ml.feature.NGram - -val wordDataFrame = sqlContext.createDataFrame(Seq( - (0, Array("Hi", "I", "heard", "about", "Spark")), - (1, Array("I", "wish", "Java", "could", "use", "case", "classes")), - (2, Array("Logistic", "regression", "models", "are", "neat")) -)).toDF("label", "words") - -val ngram = new NGram().setInputCol("words").setOutputCol("ngrams") -val ngramDataFrame = ngram.transform(wordDataFrame) -ngramDataFrame.take(3).map(_.getAs[Stream[String]]("ngrams").toList).foreach(println) -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/NGramExample.scala %}
    @@ -408,38 +278,7 @@ ngramDataFrame.take(3).map(_.getAs[Stream[String]]("ngrams").toList).foreach(pri Refer to the [NGram Java docs](api/java/org/apache/spark/ml/feature/NGram.html) for more details on the API. -{% highlight java %} -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.ml.feature.NGram; -import org.apache.spark.mllib.linalg.Vector; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.types.DataTypes; -import org.apache.spark.sql.types.Metadata; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; - -JavaRDD jrdd = jsc.parallelize(Arrays.asList( - RowFactory.create(0.0, Arrays.asList("Hi", "I", "heard", "about", "Spark")), - RowFactory.create(1.0, Arrays.asList("I", "wish", "Java", "could", "use", "case", "classes")), - RowFactory.create(2.0, Arrays.asList("Logistic", "regression", "models", "are", "neat")) -)); -StructType schema = new StructType(new StructField[]{ - new StructField("label", DataTypes.DoubleType, false, Metadata.empty()), - new StructField("words", DataTypes.createArrayType(DataTypes.StringType), false, Metadata.empty()) -}); -DataFrame wordDataFrame = sqlContext.createDataFrame(jrdd, schema); -NGram ngramTransformer = new NGram().setInputCol("words").setOutputCol("ngrams"); -DataFrame ngramDataFrame = ngramTransformer.transform(wordDataFrame); -for (Row r : ngramDataFrame.select("ngrams", "label").take(3)) { - java.util.List ngrams = r.getList(0); - for (String ngram : ngrams) System.out.print(ngram + " --- "); - System.out.println(); -} -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaNGramExample.java %}
    @@ -447,19 +286,7 @@ for (Row r : ngramDataFrame.select("ngrams", "label").take(3)) { Refer to the [NGram Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.NGram) for more details on the API. -{% highlight python %} -from pyspark.ml.feature import NGram - -wordDataFrame = sqlContext.createDataFrame([ - (0, ["Hi", "I", "heard", "about", "Spark"]), - (1, ["I", "wish", "Java", "could", "use", "case", "classes"]), - (2, ["Logistic", "regression", "models", "are", "neat"]) -], ["label", "words"]) -ngram = NGram(inputCol="words", outputCol="ngrams") -ngramDataFrame = ngram.transform(wordDataFrame) -for ngrams_label in ngramDataFrame.select("ngrams", "label").take(3): - print(ngrams_label) -{% endhighlight %} +{% include_example python/ml/n_gram_example.py %}
    @@ -476,26 +303,7 @@ Binarization is the process of thresholding numerical features to binary (0/1) f Refer to the [Binarizer Scala docs](api/scala/index.html#org.apache.spark.ml.feature.Binarizer) for more details on the API. -{% highlight scala %} -import org.apache.spark.ml.feature.Binarizer -import org.apache.spark.sql.DataFrame - -val data = Array( - (0, 0.1), - (1, 0.8), - (2, 0.2) -) -val dataFrame: DataFrame = sqlContext.createDataFrame(data).toDF("label", "feature") - -val binarizer: Binarizer = new Binarizer() - .setInputCol("feature") - .setOutputCol("binarized_feature") - .setThreshold(0.5) - -val binarizedDataFrame = binarizer.transform(dataFrame) -val binarizedFeatures = binarizedDataFrame.select("binarized_feature") -binarizedFeatures.collect().foreach(println) -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/BinarizerExample.scala %}
    @@ -503,40 +311,7 @@ binarizedFeatures.collect().foreach(println) Refer to the [Binarizer Java docs](api/java/org/apache/spark/ml/feature/Binarizer.html) for more details on the API. -{% highlight java %} -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.ml.feature.Binarizer; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.types.DataTypes; -import org.apache.spark.sql.types.Metadata; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; - -JavaRDD jrdd = jsc.parallelize(Arrays.asList( - RowFactory.create(0, 0.1), - RowFactory.create(1, 0.8), - RowFactory.create(2, 0.2) -)); -StructType schema = new StructType(new StructField[]{ - new StructField("label", DataTypes.DoubleType, false, Metadata.empty()), - new StructField("feature", DataTypes.DoubleType, false, Metadata.empty()) -}); -DataFrame continuousDataFrame = jsql.createDataFrame(jrdd, schema); -Binarizer binarizer = new Binarizer() - .setInputCol("feature") - .setOutputCol("binarized_feature") - .setThreshold(0.5); -DataFrame binarizedDataFrame = binarizer.transform(continuousDataFrame); -DataFrame binarizedFeatures = binarizedDataFrame.select("binarized_feature"); -for (Row r : binarizedFeatures.collect()) { - Double binarized_value = r.getDouble(0); - System.out.println(binarized_value); -} -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaBinarizerExample.java %}
    @@ -544,20 +319,7 @@ for (Row r : binarizedFeatures.collect()) { Refer to the [Binarizer Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.Binarizer) for more details on the API. -{% highlight python %} -from pyspark.ml.feature import Binarizer - -continuousDataFrame = sqlContext.createDataFrame([ - (0, 0.1), - (1, 0.8), - (2, 0.2) -], ["label", "feature"]) -binarizer = Binarizer(threshold=0.5, inputCol="feature", outputCol="binarized_feature") -binarizedDataFrame = binarizer.transform(continuousDataFrame) -binarizedFeatures = binarizedDataFrame.select("binarized_feature") -for binarized_feature, in binarizedFeatures.collect(): - print(binarized_feature) -{% endhighlight %} +{% include_example python/ml/binarizer_example.py %}
    @@ -571,25 +333,7 @@ for binarized_feature, in binarizedFeatures.collect(): Refer to the [PCA Scala docs](api/scala/index.html#org.apache.spark.ml.feature.PCA) for more details on the API. -{% highlight scala %} -import org.apache.spark.ml.feature.PCA -import org.apache.spark.mllib.linalg.Vectors - -val data = Array( - Vectors.sparse(5, Seq((1, 1.0), (3, 7.0))), - Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0), - Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0) -) -val df = sqlContext.createDataFrame(data.map(Tuple1.apply)).toDF("features") -val pca = new PCA() - .setInputCol("features") - .setOutputCol("pcaFeatures") - .setK(3) - .fit(df) -val pcaDF = pca.transform(df) -val result = pcaDF.select("pcaFeatures") -result.show() -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/PCAExample.scala %}
    @@ -597,42 +341,7 @@ result.show() Refer to the [PCA Java docs](api/java/org/apache/spark/ml/feature/PCA.html) for more details on the API. -{% highlight java %} -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.ml.feature.PCA -import org.apache.spark.ml.feature.PCAModel -import org.apache.spark.mllib.linalg.VectorUDT; -import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.SQLContext; -import org.apache.spark.sql.types.Metadata; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; - -JavaSparkContext jsc = ... -SQLContext jsql = ... -JavaRDD data = jsc.parallelize(Arrays.asList( - RowFactory.create(Vectors.sparse(5, new int[]{1, 3}, new double[]{1.0, 7.0})), - RowFactory.create(Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0)), - RowFactory.create(Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0)) -)); -StructType schema = new StructType(new StructField[] { - new StructField("features", new VectorUDT(), false, Metadata.empty()), -}); -DataFrame df = jsql.createDataFrame(data, schema); -PCAModel pca = new PCA() - .setInputCol("features") - .setOutputCol("pcaFeatures") - .setK(3) - .fit(df); -DataFrame result = pca.transform(df).select("pcaFeatures"); -result.show(); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaPCAExample.java %}
    @@ -640,19 +349,7 @@ result.show(); Refer to the [PCA Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.PCA) for more details on the API. -{% highlight python %} -from pyspark.ml.feature import PCA -from pyspark.mllib.linalg import Vectors - -data = [(Vectors.sparse(5, [(1, 1.0), (3, 7.0)]),), - (Vectors.dense([2.0, 0.0, 3.0, 4.0, 5.0]),), - (Vectors.dense([4.0, 0.0, 0.0, 6.0, 7.0]),)] -df = sqlContext.createDataFrame(data,["features"]) -pca = PCA(k=3, inputCol="features", outputCol="pcaFeatures") -model = pca.fit(df) -result = model.transform(df).select("pcaFeatures") -result.show(truncate=False) -{% endhighlight %} +{% include_example python/ml/pca_example.py %}
    @@ -666,23 +363,7 @@ result.show(truncate=False) Refer to the [PolynomialExpansion Scala docs](api/scala/index.html#org.apache.spark.ml.feature.PolynomialExpansion) for more details on the API. -{% highlight scala %} -import org.apache.spark.ml.feature.PolynomialExpansion -import org.apache.spark.mllib.linalg.Vectors - -val data = Array( - Vectors.dense(-2.0, 2.3), - Vectors.dense(0.0, 0.0), - Vectors.dense(0.6, -1.1) -) -val df = sqlContext.createDataFrame(data.map(Tuple1.apply)).toDF("features") -val polynomialExpansion = new PolynomialExpansion() - .setInputCol("features") - .setOutputCol("polyFeatures") - .setDegree(3) -val polyDF = polynomialExpansion.transform(df) -polyDF.select("polyFeatures").take(3).foreach(println) -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/PolynomialExpansionExample.scala %}
    @@ -690,43 +371,7 @@ polyDF.select("polyFeatures").take(3).foreach(println) Refer to the [PolynomialExpansion Java docs](api/java/org/apache/spark/ml/feature/PolynomialExpansion.html) for more details on the API. -{% highlight java %} -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.mllib.linalg.Vector; -import org.apache.spark.mllib.linalg.VectorUDT; -import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.SQLContext; -import org.apache.spark.sql.types.Metadata; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; - -JavaSparkContext jsc = ... -SQLContext jsql = ... -PolynomialExpansion polyExpansion = new PolynomialExpansion() - .setInputCol("features") - .setOutputCol("polyFeatures") - .setDegree(3); -JavaRDD data = jsc.parallelize(Arrays.asList( - RowFactory.create(Vectors.dense(-2.0, 2.3)), - RowFactory.create(Vectors.dense(0.0, 0.0)), - RowFactory.create(Vectors.dense(0.6, -1.1)) -)); -StructType schema = new StructType(new StructField[] { - new StructField("features", new VectorUDT(), false, Metadata.empty()), -}); -DataFrame df = jsql.createDataFrame(data, schema); -DataFrame polyDF = polyExpansion.transform(df); -Row[] row = polyDF.select("polyFeatures").take(3); -for (Row r : row) { - System.out.println(r.get(0)); -} -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaPolynomialExpansionExample.java %}
    @@ -734,20 +379,7 @@ for (Row r : row) { Refer to the [PolynomialExpansion Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.PolynomialExpansion) for more details on the API. -{% highlight python %} -from pyspark.ml.feature import PolynomialExpansion -from pyspark.mllib.linalg import Vectors - -df = sqlContext.createDataFrame( - [(Vectors.dense([-2.0, 2.3]), ), - (Vectors.dense([0.0, 0.0]), ), - (Vectors.dense([0.6, -1.1]), )], - ["features"]) -px = PolynomialExpansion(degree=2, inputCol="features", outputCol="polyFeatures") -polyDF = px.transform(df) -for expanded in polyDF.select("polyFeatures").take(3): - print(expanded) -{% endhighlight %} +{% include_example python/ml/polynomial_expansion_example.py %}
    @@ -771,22 +403,7 @@ $0$th DCT coefficient and _not_ the $N/2$th). Refer to the [DCT Scala docs](api/scala/index.html#org.apache.spark.ml.feature.DCT) for more details on the API. -{% highlight scala %} -import org.apache.spark.ml.feature.DCT -import org.apache.spark.mllib.linalg.Vectors - -val data = Seq( - Vectors.dense(0.0, 1.0, -2.0, 3.0), - Vectors.dense(-1.0, 2.0, 4.0, -7.0), - Vectors.dense(14.0, -2.0, -5.0, 1.0)) -val df = sqlContext.createDataFrame(data.map(Tuple1.apply)).toDF("features") -val dct = new DCT() - .setInputCol("features") - .setOutputCol("featuresDCT") - .setInverse(false) -val dctDf = dct.transform(df) -dctDf.select("featuresDCT").show(3) -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/DCTExample.scala %}
@@ -794,39 +411,7 @@ dctDf.select("featuresDCT").show(3) Refer to the [DCT Java docs](api/java/org/apache/spark/ml/feature/DCT.html) for more details on the API. -{% highlight java %} -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.ml.feature.DCT; -import org.apache.spark.mllib.linalg.Vector; -import org.apache.spark.mllib.linalg.VectorUDT; -import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.SQLContext; -import org.apache.spark.sql.types.Metadata; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; - -JavaRDD data = jsc.parallelize(Arrays.asList( - RowFactory.create(Vectors.dense(0.0, 1.0, -2.0, 3.0)), - RowFactory.create(Vectors.dense(-1.0, 2.0, 4.0, -7.0)), - RowFactory.create(Vectors.dense(14.0, -2.0, -5.0, 1.0)) -)); -StructType schema = new StructType(new StructField[] { - new StructField("features", new VectorUDT(), false, Metadata.empty()), -}); -DataFrame df = jsql.createDataFrame(data, schema); -DCT dct = new DCT() - .setInputCol("features") - .setOutputCol("featuresDCT") - .setInverse(false); -DataFrame dctDf = dct.transform(df); -dctDf.select("featuresDCT").show(3); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaDCTExample.java %}
    @@ -881,18 +466,7 @@ index `2`. Refer to the [StringIndexer Scala docs](api/scala/index.html#org.apache.spark.ml.feature.StringIndexer) for more details on the API. -{% highlight scala %} -import org.apache.spark.ml.feature.StringIndexer - -val df = sqlContext.createDataFrame( - Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")) -).toDF("id", "category") -val indexer = new StringIndexer() - .setInputCol("category") - .setOutputCol("categoryIndex") -val indexed = indexer.fit(df).transform(df) -indexed.show() -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/StringIndexerExample.scala %}
    @@ -900,37 +474,7 @@ indexed.show() Refer to the [StringIndexer Java docs](api/java/org/apache/spark/ml/feature/StringIndexer.html) for more details on the API. -{% highlight java %} -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.ml.feature.StringIndexer; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; -import static org.apache.spark.sql.types.DataTypes.*; - -JavaRDD jrdd = jsc.parallelize(Arrays.asList( - RowFactory.create(0, "a"), - RowFactory.create(1, "b"), - RowFactory.create(2, "c"), - RowFactory.create(3, "a"), - RowFactory.create(4, "a"), - RowFactory.create(5, "c") -)); -StructType schema = new StructType(new StructField[] { - createStructField("id", DoubleType, false), - createStructField("category", StringType, false) -}); -DataFrame df = sqlContext.createDataFrame(jrdd, schema); -StringIndexer indexer = new StringIndexer() - .setInputCol("category") - .setOutputCol("categoryIndex"); -DataFrame indexed = indexer.fit(df).transform(df); -indexed.show(); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaStringIndexerExample.java %}
    @@ -938,16 +482,7 @@ indexed.show(); Refer to the [StringIndexer Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.StringIndexer) for more details on the API. -{% highlight python %} -from pyspark.ml.feature import StringIndexer - -df = sqlContext.createDataFrame( - [(0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")], - ["id", "category"]) -indexer = StringIndexer(inputCol="category", outputCol="categoryIndex") -indexed = indexer.fit(df).transform(df) -indexed.show() -{% endhighlight %} +{% include_example python/ml/string_indexer_example.py %}
    @@ -961,29 +496,7 @@ indexed.show() Refer to the [OneHotEncoder Scala docs](api/scala/index.html#org.apache.spark.ml.feature.OneHotEncoder) for more details on the API. -{% highlight scala %} -import org.apache.spark.ml.feature.{OneHotEncoder, StringIndexer} - -val df = sqlContext.createDataFrame(Seq( - (0, "a"), - (1, "b"), - (2, "c"), - (3, "a"), - (4, "a"), - (5, "c") -)).toDF("id", "category") - -val indexer = new StringIndexer() - .setInputCol("category") - .setOutputCol("categoryIndex") - .fit(df) -val indexed = indexer.transform(df) - -val encoder = new OneHotEncoder().setInputCol("categoryIndex"). - setOutputCol("categoryVec") -val encoded = encoder.transform(indexed) -encoded.select("id", "categoryVec").foreach(println) -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/OneHotEncoderExample.scala %}
    @@ -991,45 +504,7 @@ encoded.select("id", "categoryVec").foreach(println) Refer to the [OneHotEncoder Java docs](api/java/org/apache/spark/ml/feature/OneHotEncoder.html) for more details on the API. -{% highlight java %} -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.ml.feature.OneHotEncoder; -import org.apache.spark.ml.feature.StringIndexer; -import org.apache.spark.ml.feature.StringIndexerModel; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.types.DataTypes; -import org.apache.spark.sql.types.Metadata; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; - -JavaRDD jrdd = jsc.parallelize(Arrays.asList( - RowFactory.create(0, "a"), - RowFactory.create(1, "b"), - RowFactory.create(2, "c"), - RowFactory.create(3, "a"), - RowFactory.create(4, "a"), - RowFactory.create(5, "c") -)); -StructType schema = new StructType(new StructField[]{ - new StructField("id", DataTypes.DoubleType, false, Metadata.empty()), - new StructField("category", DataTypes.StringType, false, Metadata.empty()) -}); -DataFrame df = sqlContext.createDataFrame(jrdd, schema); -StringIndexerModel indexer = new StringIndexer() - .setInputCol("category") - .setOutputCol("categoryIndex") - .fit(df); -DataFrame indexed = indexer.transform(df); - -OneHotEncoder encoder = new OneHotEncoder() - .setInputCol("categoryIndex") - .setOutputCol("categoryVec"); -DataFrame encoded = encoder.transform(indexed); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaOneHotEncoderExample.java %}
    @@ -1037,24 +512,7 @@ DataFrame encoded = encoder.transform(indexed); Refer to the [OneHotEncoder Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.OneHotEncoder) for more details on the API. -{% highlight python %} -from pyspark.ml.feature import OneHotEncoder, StringIndexer - -df = sqlContext.createDataFrame([ - (0, "a"), - (1, "b"), - (2, "c"), - (3, "a"), - (4, "a"), - (5, "c") -], ["id", "category"]) - -stringIndexer = StringIndexer(inputCol="category", outputCol="categoryIndex") -model = stringIndexer.fit(df) -indexed = model.transform(df) -encoder = OneHotEncoder(includeFirst=False, inputCol="categoryIndex", outputCol="categoryVec") -encoded = encoder.transform(indexed) -{% endhighlight %} +{% include_example python/ml/onehot_encoder_example.py %}
    @@ -1078,23 +536,7 @@ In the example below, we read in a dataset of labeled points and then use `Vecto Refer to the [VectorIndexer Scala docs](api/scala/index.html#org.apache.spark.ml.feature.VectorIndexer) for more details on the API. -{% highlight scala %} -import org.apache.spark.ml.feature.VectorIndexer - -val data = sqlContext.read.format("libsvm") - .load("data/mllib/sample_libsvm_data.txt") -val indexer = new VectorIndexer() - .setInputCol("features") - .setOutputCol("indexed") - .setMaxCategories(10) -val indexerModel = indexer.fit(data) -val categoricalFeatures: Set[Int] = indexerModel.categoryMaps.keys.toSet -println(s"Chose ${categoricalFeatures.size} categorical features: " + - categoricalFeatures.mkString(", ")) - -// Create new column "indexed" with categorical values transformed to indices -val indexedData = indexerModel.transform(data) -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/VectorIndexerExample.scala %}
    @@ -1102,30 +544,7 @@ val indexedData = indexerModel.transform(data) Refer to the [VectorIndexer Java docs](api/java/org/apache/spark/ml/feature/VectorIndexer.html) for more details on the API. -{% highlight java %} -import java.util.Map; - -import org.apache.spark.ml.feature.VectorIndexer; -import org.apache.spark.ml.feature.VectorIndexerModel; -import org.apache.spark.sql.DataFrame; - -DataFrame data = sqlContext.read().format("libsvm") - .load("data/mllib/sample_libsvm_data.txt"); -VectorIndexer indexer = new VectorIndexer() - .setInputCol("features") - .setOutputCol("indexed") - .setMaxCategories(10); -VectorIndexerModel indexerModel = indexer.fit(data); -Map> categoryMaps = indexerModel.javaCategoryMaps(); -System.out.print("Chose " + categoryMaps.size() + "categorical features:"); -for (Integer feature : categoryMaps.keySet()) { - System.out.print(" " + feature); -} -System.out.println(); - -// Create new column "indexed" with categorical values transformed to indices -DataFrame indexedData = indexerModel.transform(data); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaVectorIndexerExample.java %}
    @@ -1133,17 +552,7 @@ DataFrame indexedData = indexerModel.transform(data); Refer to the [VectorIndexer Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.VectorIndexer) for more details on the API. -{% highlight python %} -from pyspark.ml.feature import VectorIndexer - -data = sqlContext.read.format("libsvm") - .load("data/mllib/sample_libsvm_data.txt") -indexer = VectorIndexer(inputCol="features", outputCol="indexed", maxCategories=10) -indexerModel = indexer.fit(data) - -# Create new column "indexed" with categorical values transformed to indices -indexedData = indexerModel.transform(data) -{% endhighlight %} +{% include_example python/ml/vector_indexer_example.py %}
    @@ -1160,22 +569,7 @@ The following example demonstrates how to load a dataset in libsvm format and th Refer to the [Normalizer Scala docs](api/scala/index.html#org.apache.spark.ml.feature.Normalizer) for more details on the API. -{% highlight scala %} -import org.apache.spark.ml.feature.Normalizer - -val dataFrame = sqlContext.read.format("libsvm") - .load("data/mllib/sample_libsvm_data.txt") - -// Normalize each Vector using $L^1$ norm. -val normalizer = new Normalizer() - .setInputCol("features") - .setOutputCol("normFeatures") - .setP(1.0) -val l1NormData = normalizer.transform(dataFrame) - -// Normalize each Vector using $L^\infty$ norm. -val lInfNormData = normalizer.transform(dataFrame, normalizer.p -> Double.PositiveInfinity) -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/NormalizerExample.scala %}
    @@ -1183,24 +577,7 @@ val lInfNormData = normalizer.transform(dataFrame, normalizer.p -> Double.Positi Refer to the [Normalizer Java docs](api/java/org/apache/spark/ml/feature/Normalizer.html) for more details on the API. -{% highlight java %} -import org.apache.spark.ml.feature.Normalizer; -import org.apache.spark.sql.DataFrame; - -DataFrame dataFrame = sqlContext.read().format("libsvm") - .load("data/mllib/sample_libsvm_data.txt"); - -// Normalize each Vector using $L^1$ norm. -Normalizer normalizer = new Normalizer() - .setInputCol("features") - .setOutputCol("normFeatures") - .setP(1.0); -DataFrame l1NormData = normalizer.transform(dataFrame); - -// Normalize each Vector using $L^\infty$ norm. -DataFrame lInfNormData = - normalizer.transform(dataFrame, normalizer.p().w(Double.POSITIVE_INFINITY)); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaNormalizerExample.java %}
    @@ -1208,19 +585,7 @@ DataFrame lInfNormData = Refer to the [Normalizer Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.Normalizer) for more details on the API. -{% highlight python %} -from pyspark.ml.feature import Normalizer - -dataFrame = sqlContext.read.format("libsvm") - .load("data/mllib/sample_libsvm_data.txt") - -# Normalize each Vector using $L^1$ norm. -normalizer = Normalizer(inputCol="features", outputCol="normFeatures", p=1.0) -l1NormData = normalizer.transform(dataFrame) - -# Normalize each Vector using $L^\infty$ norm. -lInfNormData = normalizer.transform(dataFrame, {normalizer.p: float("inf")}) -{% endhighlight %} +{% include_example python/ml/normalizer_example.py %}
    @@ -1244,23 +609,7 @@ The following example demonstrates how to load a dataset in libsvm format and th Refer to the [StandardScaler Scala docs](api/scala/index.html#org.apache.spark.ml.feature.StandardScaler) for more details on the API. -{% highlight scala %} -import org.apache.spark.ml.feature.StandardScaler - -val dataFrame = sqlContext.read.format("libsvm") - .load("data/mllib/sample_libsvm_data.txt") -val scaler = new StandardScaler() - .setInputCol("features") - .setOutputCol("scaledFeatures") - .setWithStd(true) - .setWithMean(false) - -// Compute summary statistics by fitting the StandardScaler -val scalerModel = scaler.fit(dataFrame) - -// Normalize each feature to have unit standard deviation. -val scaledData = scalerModel.transform(dataFrame) -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/StandardScalerExample.scala %}
    @@ -1268,25 +617,7 @@ val scaledData = scalerModel.transform(dataFrame) Refer to the [StandardScaler Java docs](api/java/org/apache/spark/ml/feature/StandardScaler.html) for more details on the API. -{% highlight java %} -import org.apache.spark.ml.feature.StandardScaler; -import org.apache.spark.ml.feature.StandardScalerModel; -import org.apache.spark.sql.DataFrame; - -DataFrame dataFrame = sqlContext.read().format("libsvm") - .load("data/mllib/sample_libsvm_data.txt"); -StandardScaler scaler = new StandardScaler() - .setInputCol("features") - .setOutputCol("scaledFeatures") - .setWithStd(true) - .setWithMean(false); - -// Compute summary statistics by fitting the StandardScaler -StandardScalerModel scalerModel = scaler.fit(dataFrame); - -// Normalize each feature to have unit standard deviation. -DataFrame scaledData = scalerModel.transform(dataFrame); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaStandardScalerExample.java %}
    @@ -1294,20 +625,7 @@ DataFrame scaledData = scalerModel.transform(dataFrame); Refer to the [StandardScaler Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.StandardScaler) for more details on the API. -{% highlight python %} -from pyspark.ml.feature import StandardScaler - -dataFrame = sqlContext.read.format("libsvm") - .load("data/mllib/sample_libsvm_data.txt") -scaler = StandardScaler(inputCol="features", outputCol="scaledFeatures", - withStd=True, withMean=False) - -# Compute summary statistics by fitting the StandardScaler -scalerModel = scaler.fit(dataFrame) - -# Normalize each feature to have unit standard deviation. -scaledData = scalerModel.transform(dataFrame) -{% endhighlight %} +{% include_example python/ml/standard_scaler_example.py %}
    @@ -1337,21 +655,7 @@ Refer to the [MinMaxScaler Scala docs](api/scala/index.html#org.apache.spark.ml. and the [MinMaxScalerModel Scala docs](api/scala/index.html#org.apache.spark.ml.feature.MinMaxScalerModel) for more details on the API. -{% highlight scala %} -import org.apache.spark.ml.feature.MinMaxScaler - -val dataFrame = sqlContext.read.format("libsvm") - .load("data/mllib/sample_libsvm_data.txt") -val scaler = new MinMaxScaler() - .setInputCol("features") - .setOutputCol("scaledFeatures") - -// Compute summary statistics and generate MinMaxScalerModel -val scalerModel = scaler.fit(dataFrame) - -// rescale each feature to range [min, max]. -val scaledData = scalerModel.transform(dataFrame) -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/MinMaxScalerExample.scala %}
    @@ -1360,24 +664,7 @@ Refer to the [MinMaxScaler Java docs](api/java/org/apache/spark/ml/feature/MinMa and the [MinMaxScalerModel Java docs](api/java/org/apache/spark/ml/feature/MinMaxScalerModel.html) for more details on the API. -{% highlight java %} -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.ml.feature.MinMaxScaler; -import org.apache.spark.ml.feature.MinMaxScalerModel; -import org.apache.spark.sql.DataFrame; - -DataFrame dataFrame = sqlContext.read().format("libsvm") - .load("data/mllib/sample_libsvm_data.txt"); -MinMaxScaler scaler = new MinMaxScaler() - .setInputCol("features") - .setOutputCol("scaledFeatures"); - -// Compute summary statistics and generate MinMaxScalerModel -MinMaxScalerModel scalerModel = scaler.fit(dataFrame); - -// rescale each feature to range [min, max]. -DataFrame scaledData = scalerModel.transform(dataFrame); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaMinMaxScalerExample.java %}
    @@ -1401,23 +688,7 @@ The following example demonstrates how to bucketize a column of `Double`s into a Refer to the [Bucketizer Scala docs](api/scala/index.html#org.apache.spark.ml.feature.Bucketizer) for more details on the API. -{% highlight scala %} -import org.apache.spark.ml.feature.Bucketizer -import org.apache.spark.sql.DataFrame - -val splits = Array(Double.NegativeInfinity, -0.5, 0.0, 0.5, Double.PositiveInfinity) - -val data = Array(-0.5, -0.3, 0.0, 0.2) -val dataFrame = sqlContext.createDataFrame(data.map(Tuple1.apply)).toDF("features") - -val bucketizer = new Bucketizer() - .setInputCol("features") - .setOutputCol("bucketedFeatures") - .setSplits(splits) - -// Transform original data into its bucket index. -val bucketedData = bucketizer.transform(dataFrame) -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/BucketizerExample.scala %}
    @@ -1425,38 +696,7 @@ val bucketedData = bucketizer.transform(dataFrame) Refer to the [Bucketizer Java docs](api/java/org/apache/spark/ml/feature/Bucketizer.html) for more details on the API. -{% highlight java %} -import java.util.Arrays; - -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.types.DataTypes; -import org.apache.spark.sql.types.Metadata; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; - -double[] splits = {Double.NEGATIVE_INFINITY, -0.5, 0.0, 0.5, Double.POSITIVE_INFINITY}; - -JavaRDD data = jsc.parallelize(Arrays.asList( - RowFactory.create(-0.5), - RowFactory.create(-0.3), - RowFactory.create(0.0), - RowFactory.create(0.2) -)); -StructType schema = new StructType(new StructField[] { - new StructField("features", DataTypes.DoubleType, false, Metadata.empty()) -}); -DataFrame dataFrame = jsql.createDataFrame(data, schema); - -Bucketizer bucketizer = new Bucketizer() - .setInputCol("features") - .setOutputCol("bucketedFeatures") - .setSplits(splits); - -// Transform original data into its bucket index. -DataFrame bucketedData = bucketizer.transform(dataFrame); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaBucketizerExample.java %}
    @@ -1464,19 +704,7 @@ DataFrame bucketedData = bucketizer.transform(dataFrame); Refer to the [Bucketizer Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.Bucketizer) for more details on the API. -{% highlight python %} -from pyspark.ml.feature import Bucketizer - -splits = [-float("inf"), -0.5, 0.0, 0.5, float("inf")] - -data = [(-0.5,), (-0.3,), (0.0,), (0.2,)] -dataFrame = sqlContext.createDataFrame(data, ["features"]) - -bucketizer = Bucketizer(splits=splits, inputCol="features", outputCol="bucketedFeatures") - -# Transform original data into its bucket index. -bucketedData = bucketizer.transform(dataFrame) -{% endhighlight %} +{% include_example python/ml/bucketizer_example.py %}
    @@ -1508,25 +736,7 @@ This example below demonstrates how to transform vectors using a transforming ve Refer to the [ElementwiseProduct Scala docs](api/scala/index.html#org.apache.spark.ml.feature.ElementwiseProduct) for more details on the API. -{% highlight scala %} -import org.apache.spark.ml.feature.ElementwiseProduct -import org.apache.spark.mllib.linalg.Vectors - -// Create some vector data; also works for sparse vectors -val dataFrame = sqlContext.createDataFrame(Seq( - ("a", Vectors.dense(1.0, 2.0, 3.0)), - ("b", Vectors.dense(4.0, 5.0, 6.0)))).toDF("id", "vector") - -val transformingVector = Vectors.dense(0.0, 1.0, 2.0) -val transformer = new ElementwiseProduct() - .setScalingVec(transformingVector) - .setInputCol("vector") - .setOutputCol("transformedVector") - -// Batch transform the vectors to create new column: -transformer.transform(dataFrame).show() - -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/ElementwiseProductExample.scala %}
    @@ -1534,41 +744,7 @@ transformer.transform(dataFrame).show() Refer to the [ElementwiseProduct Java docs](api/java/org/apache/spark/ml/feature/ElementwiseProduct.html) for more details on the API. -{% highlight java %} -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.ml.feature.ElementwiseProduct; -import org.apache.spark.mllib.linalg.Vector; -import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.SQLContext; -import org.apache.spark.sql.types.DataTypes; -import org.apache.spark.sql.types.Metadata; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; - -// Create some vector data; also works for sparse vectors -JavaRDD jrdd = jsc.parallelize(Arrays.asList( - RowFactory.create("a", Vectors.dense(1.0, 2.0, 3.0)), - RowFactory.create("b", Vectors.dense(4.0, 5.0, 6.0)) -)); -List fields = new ArrayList(2); -fields.add(DataTypes.createStructField("id", DataTypes.StringType, false)); -fields.add(DataTypes.createStructField("vector", DataTypes.StringType, false)); -StructType schema = DataTypes.createStructType(fields); -DataFrame dataFrame = sqlContext.createDataFrame(jrdd, schema); -Vector transformingVector = Vectors.dense(0.0, 1.0, 2.0); -ElementwiseProduct transformer = new ElementwiseProduct() - .setScalingVec(transformingVector) - .setInputCol("vector") - .setOutputCol("transformedVector"); -// Batch transform the vectors to create new column: -transformer.transform(dataFrame).show(); - -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaElementwiseProductExample.java %}
    @@ -1576,19 +752,8 @@ transformer.transform(dataFrame).show(); Refer to the [ElementwiseProduct Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.ElementwiseProduct) for more details on the API. -{% highlight python %} -from pyspark.ml.feature import ElementwiseProduct -from pyspark.mllib.linalg import Vectors - -data = [(Vectors.dense([1.0, 2.0, 3.0]),), (Vectors.dense([4.0, 5.0, 6.0]),)] -df = sqlContext.createDataFrame(data, ["vector"]) -transformer = ElementwiseProduct(scalingVec=Vectors.dense([0.0, 1.0, 2.0]), - inputCol="vector", outputCol="transformedVector") -transformer.transform(df).show() - -{% endhighlight %} +{% include_example python/ml/elementwise_product_example.py %}
    - ## VectorAssembler @@ -1632,19 +797,7 @@ output column to `features`, after transformation we should get the following Da Refer to the [VectorAssembler Scala docs](api/scala/index.html#org.apache.spark.ml.feature.VectorAssembler) for more details on the API. -{% highlight scala %} -import org.apache.spark.mllib.linalg.Vectors -import org.apache.spark.ml.feature.VectorAssembler - -val dataset = sqlContext.createDataFrame( - Seq((0, 18, 1.0, Vectors.dense(0.0, 10.0, 0.5), 1.0)) -).toDF("id", "hour", "mobile", "userFeatures", "clicked") -val assembler = new VectorAssembler() - .setInputCols(Array("hour", "mobile", "userFeatures")) - .setOutputCol("features") -val output = assembler.transform(dataset) -println(output.select("features", "clicked").first()) -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/VectorAssemblerExample.scala %}
    @@ -1652,36 +805,7 @@ println(output.select("features", "clicked").first()) Refer to the [VectorAssembler Java docs](api/java/org/apache/spark/ml/feature/VectorAssembler.html) for more details on the API. -{% highlight java %} -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.mllib.linalg.VectorUDT; -import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.types.*; -import static org.apache.spark.sql.types.DataTypes.*; - -StructType schema = createStructType(new StructField[] { - createStructField("id", IntegerType, false), - createStructField("hour", IntegerType, false), - createStructField("mobile", DoubleType, false), - createStructField("userFeatures", new VectorUDT(), false), - createStructField("clicked", DoubleType, false) -}); -Row row = RowFactory.create(0, 18, 1.0, Vectors.dense(0.0, 10.0, 0.5), 1.0); -JavaRDD rdd = jsc.parallelize(Arrays.asList(row)); -DataFrame dataset = sqlContext.createDataFrame(rdd, schema); - -VectorAssembler assembler = new VectorAssembler() - .setInputCols(new String[] {"hour", "mobile", "userFeatures"}) - .setOutputCol("features"); - -DataFrame output = assembler.transform(dataset); -System.out.println(output.select("features", "clicked").first()); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaVectorAssemblerExample.java %}
    @@ -1689,19 +813,7 @@ System.out.println(output.select("features", "clicked").first()); Refer to the [VectorAssembler Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.VectorAssembler) for more details on the API. -{% highlight python %} -from pyspark.mllib.linalg import Vectors -from pyspark.ml.feature import VectorAssembler - -dataset = sqlContext.createDataFrame( - [(0, 18, 1.0, Vectors.dense([0.0, 10.0, 0.5]), 1.0)], - ["id", "hour", "mobile", "userFeatures", "clicked"]) -assembler = VectorAssembler( - inputCols=["hour", "mobile", "userFeatures"], - outputCol="features") -output = assembler.transform(dataset) -print(output.select("features", "clicked").first()) -{% endhighlight %} +{% include_example python/ml/vector_assembler_example.py %}
    @@ -1831,33 +943,7 @@ Suppose also that we have a potential input attributes for the `userFeatures`, i Refer to the [VectorSlicer Scala docs](api/scala/index.html#org.apache.spark.ml.feature.VectorSlicer) for more details on the API. -{% highlight scala %} -import org.apache.spark.mllib.linalg.Vectors -import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NumericAttribute} -import org.apache.spark.ml.feature.VectorSlicer -import org.apache.spark.sql.types.StructType -import org.apache.spark.sql.{DataFrame, Row, SQLContext} - -val data = Array( - Vectors.sparse(3, Seq((0, -2.0), (1, 2.3))), - Vectors.dense(-2.0, 2.3, 0.0) -) - -val defaultAttr = NumericAttribute.defaultAttr -val attrs = Array("f1", "f2", "f3").map(defaultAttr.withName) -val attrGroup = new AttributeGroup("userFeatures", attrs.asInstanceOf[Array[Attribute]]) - -val dataRDD = sc.parallelize(data).map(Row.apply) -val dataset = sqlContext.createDataFrame(dataRDD, StructType(attrGroup.toStructField())) - -val slicer = new VectorSlicer().setInputCol("userFeatures").setOutputCol("features") - -slicer.setIndices(1).setNames("f3") -// or slicer.setIndices(Array(1, 2)), or slicer.setNames(Array("f2", "f3")) - -val output = slicer.transform(dataset) -println(output.select("userFeatures", "features").first()) -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/VectorSlicerExample.scala %}
    @@ -1865,41 +951,7 @@ println(output.select("userFeatures", "features").first()) Refer to the [VectorSlicer Java docs](api/java/org/apache/spark/ml/feature/VectorSlicer.html) for more details on the API. -{% highlight java %} -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.types.*; -import static org.apache.spark.sql.types.DataTypes.*; - -Attribute[] attrs = new Attribute[]{ - NumericAttribute.defaultAttr().withName("f1"), - NumericAttribute.defaultAttr().withName("f2"), - NumericAttribute.defaultAttr().withName("f3") -}; -AttributeGroup group = new AttributeGroup("userFeatures", attrs); - -JavaRDD jrdd = jsc.parallelize(Lists.newArrayList( - RowFactory.create(Vectors.sparse(3, new int[]{0, 1}, new double[]{-2.0, 2.3})), - RowFactory.create(Vectors.dense(-2.0, 2.3, 0.0)) -)); - -DataFrame dataset = jsql.createDataFrame(jrdd, (new StructType()).add(group.toStructField())); - -VectorSlicer vectorSlicer = new VectorSlicer() - .setInputCol("userFeatures").setOutputCol("features"); - -vectorSlicer.setIndices(new int[]{1}).setNames(new String[]{"f3"}); -// or slicer.setIndices(new int[]{1, 2}), or slicer.setNames(new String[]{"f2", "f3"}) - -DataFrame output = vectorSlicer.transform(dataset); - -System.out.println(output.select("userFeatures", "features").first()); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaVectorSlicerExample.java %}
    @@ -1936,21 +988,7 @@ id | country | hour | clicked | features | label Refer to the [RFormula Scala docs](api/scala/index.html#org.apache.spark.ml.feature.RFormula) for more details on the API. -{% highlight scala %} -import org.apache.spark.ml.feature.RFormula - -val dataset = sqlContext.createDataFrame(Seq( - (7, "US", 18, 1.0), - (8, "CA", 12, 0.0), - (9, "NZ", 15, 0.0) -)).toDF("id", "country", "hour", "clicked") -val formula = new RFormula() - .setFormula("clicked ~ country + hour") - .setFeaturesCol("features") - .setLabelCol("label") -val output = formula.fit(dataset).transform(dataset) -output.select("features", "label").show() -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/RFormulaExample.scala %}
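As above, the referenced RFormulaExample.scala is not included in this excerpt. A sketch of how it presumably wraps the removed snippet into a standalone program, following the same `$example on$`/`$example off$` layout (object name and setup code are assumptions):

{% highlight scala %}
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext
// $example on$
import org.apache.spark.ml.feature.RFormula
// $example off$

object RFormulaExample {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("RFormulaExample")
    val sc = new SparkContext(conf)
    val sqlContext = new SQLContext(sc)

    // $example on$
    val dataset = sqlContext.createDataFrame(Seq(
      (7, "US", 18, 1.0),
      (8, "CA", 12, 0.0),
      (9, "NZ", 15, 0.0)
    )).toDF("id", "country", "hour", "clicked")

    // "clicked ~ country + hour": predict clicked from country and hour.
    val formula = new RFormula()
      .setFormula("clicked ~ country + hour")
      .setFeaturesCol("features")
      .setLabelCol("label")

    val output = formula.fit(dataset).transform(dataset)
    output.select("features", "label").show()
    // $example off$

    sc.stop()
  }
}
{% endhighlight %}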
    @@ -1958,38 +996,7 @@ output.select("features", "label").show() Refer to the [RFormula Java docs](api/java/org/apache/spark/ml/feature/RFormula.html) for more details on the API. -{% highlight java %} -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.ml.feature.RFormula; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.types.*; -import static org.apache.spark.sql.types.DataTypes.*; - -StructType schema = createStructType(new StructField[] { - createStructField("id", IntegerType, false), - createStructField("country", StringType, false), - createStructField("hour", IntegerType, false), - createStructField("clicked", DoubleType, false) -}); -JavaRDD rdd = jsc.parallelize(Arrays.asList( - RowFactory.create(7, "US", 18, 1.0), - RowFactory.create(8, "CA", 12, 0.0), - RowFactory.create(9, "NZ", 15, 0.0) -)); -DataFrame dataset = sqlContext.createDataFrame(rdd, schema); - -RFormula formula = new RFormula() - .setFormula("clicked ~ country + hour") - .setFeaturesCol("features") - .setLabelCol("label"); - -DataFrame output = formula.fit(dataset).transform(dataset); -output.select("features", "label").show(); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaRFormulaExample.java %}
    @@ -1997,21 +1004,7 @@ output.select("features", "label").show(); Refer to the [RFormula Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.RFormula) for more details on the API. -{% highlight python %} -from pyspark.ml.feature import RFormula - -dataset = sqlContext.createDataFrame( - [(7, "US", 18, 1.0), - (8, "CA", 12, 0.0), - (9, "NZ", 15, 0.0)], - ["id", "country", "hour", "clicked"]) -formula = RFormula( - formula="clicked ~ country + hour", - featuresCol="features", - labelCol="label") -output = formula.fit(dataset).transform(dataset) -output.select("features", "label").show() -{% endhighlight %} +{% include_example python/ml/rformula_example.py %}
    diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaBinarizerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaBinarizerExample.java new file mode 100644 index 000000000000..9698cac50437 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaBinarizerExample.java @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.Binarizer; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +// $example off$ + +public class JavaBinarizerExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaBinarizerExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext jsql = new SQLContext(jsc); + + // $example on$ + JavaRDD jrdd = jsc.parallelize(Arrays.asList( + RowFactory.create(0, 0.1), + RowFactory.create(1, 0.8), + RowFactory.create(2, 0.2) + )); + StructType schema = new StructType(new StructField[]{ + new StructField("label", DataTypes.DoubleType, false, Metadata.empty()), + new StructField("feature", DataTypes.DoubleType, false, Metadata.empty()) + }); + DataFrame continuousDataFrame = jsql.createDataFrame(jrdd, schema); + Binarizer binarizer = new Binarizer() + .setInputCol("feature") + .setOutputCol("binarized_feature") + .setThreshold(0.5); + DataFrame binarizedDataFrame = binarizer.transform(continuousDataFrame); + DataFrame binarizedFeatures = binarizedDataFrame.select("binarized_feature"); + for (Row r : binarizedFeatures.collect()) { + Double binarized_value = r.getDouble(0); + System.out.println(binarized_value); + } + // $example off$ + jsc.stop(); + } +} \ No newline at end of file diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaBucketizerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaBucketizerExample.java new file mode 100644 index 000000000000..b06a23e76d60 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaBucketizerExample.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.Bucketizer; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +// $example off$ + +public class JavaBucketizerExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaBucketizerExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext jsql = new SQLContext(jsc); + + // $example on$ + double[] splits = {Double.NEGATIVE_INFINITY, -0.5, 0.0, 0.5, Double.POSITIVE_INFINITY}; + + JavaRDD data = jsc.parallelize(Arrays.asList( + RowFactory.create(-0.5), + RowFactory.create(-0.3), + RowFactory.create(0.0), + RowFactory.create(0.2) + )); + StructType schema = new StructType(new StructField[]{ + new StructField("features", DataTypes.DoubleType, false, Metadata.empty()) + }); + DataFrame dataFrame = jsql.createDataFrame(data, schema); + + Bucketizer bucketizer = new Bucketizer() + .setInputCol("features") + .setOutputCol("bucketedFeatures") + .setSplits(splits); + + // Transform original data into its bucket index. + DataFrame bucketedData = bucketizer.transform(dataFrame); + // $example off$ + jsc.stop(); + } +} + + diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaDCTExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaDCTExample.java new file mode 100644 index 000000000000..35c0d534a45e --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaDCTExample.java @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.DCT; +import org.apache.spark.mllib.linalg.VectorUDT; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +// $example off$ + +public class JavaDCTExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaDCTExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext jsql = new SQLContext(jsc); + + // $example on$ + JavaRDD data = jsc.parallelize(Arrays.asList( + RowFactory.create(Vectors.dense(0.0, 1.0, -2.0, 3.0)), + RowFactory.create(Vectors.dense(-1.0, 2.0, 4.0, -7.0)), + RowFactory.create(Vectors.dense(14.0, -2.0, -5.0, 1.0)) + )); + StructType schema = new StructType(new StructField[]{ + new StructField("features", new VectorUDT(), false, Metadata.empty()), + }); + DataFrame df = jsql.createDataFrame(data, schema); + DCT dct = new DCT() + .setInputCol("features") + .setOutputCol("featuresDCT") + .setInverse(false); + DataFrame dctDf = dct.transform(df); + dctDf.select("featuresDCT").show(3); + // $example off$ + jsc.stop(); + } +} + diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaElementwiseProductExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaElementwiseProductExample.java new file mode 100644 index 000000000000..2898accec61b --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaElementwiseProductExample.java @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.ElementwiseProduct; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.linalg.VectorUDT; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +// $example off$ + +public class JavaElementwiseProductExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaElementwiseProductExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext sqlContext = new SQLContext(jsc); + + // $example on$ + // Create some vector data; also works for sparse vectors + JavaRDD jrdd = jsc.parallelize(Arrays.asList( + RowFactory.create("a", Vectors.dense(1.0, 2.0, 3.0)), + RowFactory.create("b", Vectors.dense(4.0, 5.0, 6.0)) + )); + + List fields = new ArrayList(2); + fields.add(DataTypes.createStructField("id", DataTypes.StringType, false)); + fields.add(DataTypes.createStructField("vector", new VectorUDT(), false)); + + StructType schema = DataTypes.createStructType(fields); + + DataFrame dataFrame = sqlContext.createDataFrame(jrdd, schema); + + Vector transformingVector = Vectors.dense(0.0, 1.0, 2.0); + + ElementwiseProduct transformer = new ElementwiseProduct() + .setScalingVec(transformingVector) + .setInputCol("vector") + .setOutputCol("transformedVector"); + + // Batch transform the vectors to create new column: + transformer.transform(dataFrame).show(); + // $example off$ + jsc.stop(); + } +} \ No newline at end of file diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaMinMaxScalerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaMinMaxScalerExample.java new file mode 100644 index 000000000000..138b3ab6aba4 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaMinMaxScalerExample.java @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import org.apache.spark.ml.feature.MinMaxScaler; +import org.apache.spark.ml.feature.MinMaxScalerModel; +import org.apache.spark.sql.DataFrame; +// $example off$ + +public class JavaMinMaxScalerExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JaveMinMaxScalerExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext jsql = new SQLContext(jsc); + + // $example on$ + DataFrame dataFrame = jsql.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt"); + MinMaxScaler scaler = new MinMaxScaler() + .setInputCol("features") + .setOutputCol("scaledFeatures"); + + // Compute summary statistics and generate MinMaxScalerModel + MinMaxScalerModel scalerModel = scaler.fit(dataFrame); + + // rescale each feature to range [min, max]. + DataFrame scaledData = scalerModel.transform(dataFrame); + // $example off$ + jsc.stop(); + } +} \ No newline at end of file diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaNGramExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaNGramExample.java new file mode 100644 index 000000000000..8fd75ed8b5f4 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaNGramExample.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.NGram; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +// $example off$ + +public class JavaNGramExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaNGramExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext sqlContext = new SQLContext(jsc); + + // $example on$ + JavaRDD jrdd = jsc.parallelize(Arrays.asList( + RowFactory.create(0.0, Arrays.asList("Hi", "I", "heard", "about", "Spark")), + RowFactory.create(1.0, Arrays.asList("I", "wish", "Java", "could", "use", "case", "classes")), + RowFactory.create(2.0, Arrays.asList("Logistic", "regression", "models", "are", "neat")) + )); + + StructType schema = new StructType(new StructField[]{ + new StructField("label", DataTypes.DoubleType, false, Metadata.empty()), + new StructField( + "words", DataTypes.createArrayType(DataTypes.StringType), false, Metadata.empty()) + }); + + DataFrame wordDataFrame = sqlContext.createDataFrame(jrdd, schema); + + NGram ngramTransformer = new NGram().setInputCol("words").setOutputCol("ngrams"); + + DataFrame ngramDataFrame = ngramTransformer.transform(wordDataFrame); + + for (Row r : ngramDataFrame.select("ngrams", "label").take(3)) { + java.util.List ngrams = r.getList(0); + for (String ngram : ngrams) System.out.print(ngram + " --- "); + System.out.println(); + } + // $example off$ + jsc.stop(); + } +} \ No newline at end of file diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaNormalizerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaNormalizerExample.java new file mode 100644 index 000000000000..6283a355e1fe --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaNormalizerExample.java @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import org.apache.spark.ml.feature.Normalizer; +import org.apache.spark.sql.DataFrame; +// $example off$ + +public class JavaNormalizerExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaNormalizerExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext jsql = new SQLContext(jsc); + + // $example on$ + DataFrame dataFrame = jsql.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt"); + + // Normalize each Vector using $L^1$ norm. + Normalizer normalizer = new Normalizer() + .setInputCol("features") + .setOutputCol("normFeatures") + .setP(1.0); + + DataFrame l1NormData = normalizer.transform(dataFrame); + + // Normalize each Vector using $L^\infty$ norm. + DataFrame lInfNormData = + normalizer.transform(dataFrame, normalizer.p().w(Double.POSITIVE_INFINITY)); + // $example off$ + jsc.stop(); + } +} \ No newline at end of file diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaOneHotEncoderExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaOneHotEncoderExample.java new file mode 100644 index 000000000000..172a9cc6feb2 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaOneHotEncoderExample.java @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.OneHotEncoder; +import org.apache.spark.ml.feature.StringIndexer; +import org.apache.spark.ml.feature.StringIndexerModel; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +// $example off$ + +public class JavaOneHotEncoderExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaOneHotEncoderExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext sqlContext = new SQLContext(jsc); + + // $example on$ + JavaRDD jrdd = jsc.parallelize(Arrays.asList( + RowFactory.create(0, "a"), + RowFactory.create(1, "b"), + RowFactory.create(2, "c"), + RowFactory.create(3, "a"), + RowFactory.create(4, "a"), + RowFactory.create(5, "c") + )); + + StructType schema = new StructType(new StructField[]{ + new StructField("id", DataTypes.DoubleType, false, Metadata.empty()), + new StructField("category", DataTypes.StringType, false, Metadata.empty()) + }); + + DataFrame df = sqlContext.createDataFrame(jrdd, schema); + + StringIndexerModel indexer = new StringIndexer() + .setInputCol("category") + .setOutputCol("categoryIndex") + .fit(df); + DataFrame indexed = indexer.transform(df); + + OneHotEncoder encoder = new OneHotEncoder() + .setInputCol("categoryIndex") + .setOutputCol("categoryVec"); + DataFrame encoded = encoder.transform(indexed); + // $example off$ + jsc.stop(); + } +} + diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaPCAExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaPCAExample.java new file mode 100644 index 000000000000..8282fab084f3 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaPCAExample.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.PCA; +import org.apache.spark.ml.feature.PCAModel; +import org.apache.spark.mllib.linalg.VectorUDT; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +// $example off$ + +public class JavaPCAExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaPCAExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext jsql = new SQLContext(jsc); + + // $example on$ + JavaRDD data = jsc.parallelize(Arrays.asList( + RowFactory.create(Vectors.sparse(5, new int[]{1, 3}, new double[]{1.0, 7.0})), + RowFactory.create(Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0)), + RowFactory.create(Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0)) + )); + + StructType schema = new StructType(new StructField[]{ + new StructField("features", new VectorUDT(), false, Metadata.empty()), + }); + + DataFrame df = jsql.createDataFrame(data, schema); + + PCAModel pca = new PCA() + .setInputCol("features") + .setOutputCol("pcaFeatures") + .setK(3) + .fit(df); + + DataFrame result = pca.transform(df).select("pcaFeatures"); + result.show(); + // $example off$ + jsc.stop(); + } +} + diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaPolynomialExpansionExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaPolynomialExpansionExample.java new file mode 100644 index 000000000000..668f71e64056 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaPolynomialExpansionExample.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.PolynomialExpansion; +import org.apache.spark.mllib.linalg.VectorUDT; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +// $example off$ + +public class JavaPolynomialExpansionExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaPolynomialExpansionExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext jsql = new SQLContext(jsc); + + // $example on$ + PolynomialExpansion polyExpansion = new PolynomialExpansion() + .setInputCol("features") + .setOutputCol("polyFeatures") + .setDegree(3); + + JavaRDD data = jsc.parallelize(Arrays.asList( + RowFactory.create(Vectors.dense(-2.0, 2.3)), + RowFactory.create(Vectors.dense(0.0, 0.0)), + RowFactory.create(Vectors.dense(0.6, -1.1)) + )); + + StructType schema = new StructType(new StructField[]{ + new StructField("features", new VectorUDT(), false, Metadata.empty()), + }); + + DataFrame df = jsql.createDataFrame(data, schema); + DataFrame polyDF = polyExpansion.transform(df); + + Row[] row = polyDF.select("polyFeatures").take(3); + for (Row r : row) { + System.out.println(r.get(0)); + } + // $example off$ + jsc.stop(); + } +} \ No newline at end of file diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaRFormulaExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaRFormulaExample.java new file mode 100644 index 000000000000..1e1062b541ad --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaRFormulaExample.java @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.RFormula; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +import static org.apache.spark.sql.types.DataTypes.*; +// $example off$ + +public class JavaRFormulaExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaRFormulaExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext sqlContext = new SQLContext(jsc); + + // $example on$ + StructType schema = createStructType(new StructField[]{ + createStructField("id", IntegerType, false), + createStructField("country", StringType, false), + createStructField("hour", IntegerType, false), + createStructField("clicked", DoubleType, false) + }); + + JavaRDD rdd = jsc.parallelize(Arrays.asList( + RowFactory.create(7, "US", 18, 1.0), + RowFactory.create(8, "CA", 12, 0.0), + RowFactory.create(9, "NZ", 15, 0.0) + )); + + DataFrame dataset = sqlContext.createDataFrame(rdd, schema); + RFormula formula = new RFormula() + .setFormula("clicked ~ country + hour") + .setFeaturesCol("features") + .setLabelCol("label"); + DataFrame output = formula.fit(dataset).transform(dataset); + output.select("features", "label").show(); + // $example off$ + jsc.stop(); + } +} + diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaStandardScalerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaStandardScalerExample.java new file mode 100644 index 000000000000..0cbdc97e8ae3 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaStandardScalerExample.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import org.apache.spark.ml.feature.StandardScaler; +import org.apache.spark.ml.feature.StandardScalerModel; +import org.apache.spark.sql.DataFrame; +// $example off$ + +public class JavaStandardScalerExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaStandardScalerExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext jsql = new SQLContext(jsc); + + // $example on$ + DataFrame dataFrame = jsql.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt"); + + StandardScaler scaler = new StandardScaler() + .setInputCol("features") + .setOutputCol("scaledFeatures") + .setWithStd(true) + .setWithMean(false); + + // Compute summary statistics by fitting the StandardScaler + StandardScalerModel scalerModel = scaler.fit(dataFrame); + + // Normalize each feature to have unit standard deviation. + DataFrame scaledData = scalerModel.transform(dataFrame); + // $example off$ + jsc.stop(); + } +} \ No newline at end of file diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaStopWordsRemoverExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaStopWordsRemoverExample.java new file mode 100644 index 000000000000..b6b201c6b68d --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaStopWordsRemoverExample.java @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.StopWordsRemover; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +// $example off$ + +public class JavaStopWordsRemoverExample { + + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaStopWordsRemoverExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext jsql = new SQLContext(jsc); + + // $example on$ + StopWordsRemover remover = new StopWordsRemover() + .setInputCol("raw") + .setOutputCol("filtered"); + + JavaRDD rdd = jsc.parallelize(Arrays.asList( + RowFactory.create(Arrays.asList("I", "saw", "the", "red", "baloon")), + RowFactory.create(Arrays.asList("Mary", "had", "a", "little", "lamb")) + )); + + StructType schema = new StructType(new StructField[]{ + new StructField( + "raw", DataTypes.createArrayType(DataTypes.StringType), false, Metadata.empty()) + }); + + DataFrame dataset = jsql.createDataFrame(rdd, schema); + remover.transform(dataset).show(); + // $example off$ + jsc.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaStringIndexerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaStringIndexerExample.java new file mode 100644 index 000000000000..05d12c1e702f --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaStringIndexerExample.java @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.StringIndexer; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +import static org.apache.spark.sql.types.DataTypes.*; +// $example off$ + +public class JavaStringIndexerExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaStringIndexerExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext sqlContext = new SQLContext(jsc); + + // $example on$ + JavaRDD jrdd = jsc.parallelize(Arrays.asList( + RowFactory.create(0, "a"), + RowFactory.create(1, "b"), + RowFactory.create(2, "c"), + RowFactory.create(3, "a"), + RowFactory.create(4, "a"), + RowFactory.create(5, "c") + )); + StructType schema = new StructType(new StructField[]{ + createStructField("id", IntegerType, false), + createStructField("category", StringType, false) + }); + DataFrame df = sqlContext.createDataFrame(jrdd, schema); + StringIndexer indexer = new StringIndexer() + .setInputCol("category") + .setOutputCol("categoryIndex"); + DataFrame indexed = indexer.fit(df).transform(df); + indexed.show(); + // $example off$ + jsc.stop(); + } +} \ No newline at end of file diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaTokenizerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaTokenizerExample.java new file mode 100644 index 000000000000..617dc3f66e3b --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaTokenizerExample.java @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.RegexTokenizer; +import org.apache.spark.ml.feature.Tokenizer; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +// $example off$ + +public class JavaTokenizerExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaTokenizerExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext sqlContext = new SQLContext(jsc); + + // $example on$ + JavaRDD jrdd = jsc.parallelize(Arrays.asList( + RowFactory.create(0, "Hi I heard about Spark"), + RowFactory.create(1, "I wish Java could use case classes"), + RowFactory.create(2, "Logistic,regression,models,are,neat") + )); + + StructType schema = new StructType(new StructField[]{ + new StructField("label", DataTypes.IntegerType, false, Metadata.empty()), + new StructField("sentence", DataTypes.StringType, false, Metadata.empty()) + }); + + DataFrame sentenceDataFrame = sqlContext.createDataFrame(jrdd, schema); + + Tokenizer tokenizer = new Tokenizer().setInputCol("sentence").setOutputCol("words"); + + DataFrame wordsDataFrame = tokenizer.transform(sentenceDataFrame); + for (Row r : wordsDataFrame.select("words", "label"). take(3)) { + java.util.List words = r.getList(0); + for (String word : words) System.out.print(word + " "); + System.out.println(); + } + + RegexTokenizer regexTokenizer = new RegexTokenizer() + .setInputCol("sentence") + .setOutputCol("words") + .setPattern("\\W"); // alternatively .setPattern("\\w+").setGaps(false); + // $example off$ + jsc.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorAssemblerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorAssemblerExample.java new file mode 100644 index 000000000000..7e230b5897c1 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorAssemblerExample.java @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.VectorAssembler; +import org.apache.spark.mllib.linalg.VectorUDT; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.*; + +import static org.apache.spark.sql.types.DataTypes.*; +// $example off$ + +public class JavaVectorAssemblerExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaVectorAssemblerExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext sqlContext = new SQLContext(jsc); + + // $example on$ + StructType schema = createStructType(new StructField[]{ + createStructField("id", IntegerType, false), + createStructField("hour", IntegerType, false), + createStructField("mobile", DoubleType, false), + createStructField("userFeatures", new VectorUDT(), false), + createStructField("clicked", DoubleType, false) + }); + Row row = RowFactory.create(0, 18, 1.0, Vectors.dense(0.0, 10.0, 0.5), 1.0); + JavaRDD rdd = jsc.parallelize(Arrays.asList(row)); + DataFrame dataset = sqlContext.createDataFrame(rdd, schema); + + VectorAssembler assembler = new VectorAssembler() + .setInputCols(new String[]{"hour", "mobile", "userFeatures"}) + .setOutputCol("features"); + + DataFrame output = assembler.transform(dataset); + System.out.println(output.select("features", "clicked").first()); + // $example off$ + jsc.stop(); + } +} + diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorIndexerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorIndexerExample.java new file mode 100644 index 000000000000..06b4bf6bf8ff --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorIndexerExample.java @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import java.util.Map; + +import org.apache.spark.ml.feature.VectorIndexer; +import org.apache.spark.ml.feature.VectorIndexerModel; +import org.apache.spark.sql.DataFrame; +// $example off$ + +public class JavaVectorIndexerExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaVectorIndexerExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext jsql = new SQLContext(jsc); + + // $example on$ + DataFrame data = jsql.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt"); + + VectorIndexer indexer = new VectorIndexer() + .setInputCol("features") + .setOutputCol("indexed") + .setMaxCategories(10); + VectorIndexerModel indexerModel = indexer.fit(data); + + Map> categoryMaps = indexerModel.javaCategoryMaps(); + System.out.print("Chose " + categoryMaps.size() + " categorical features:"); + + for (Integer feature : categoryMaps.keySet()) { + System.out.print(" " + feature); + } + System.out.println(); + + // Create new column "indexed" with categorical values transformed to indices + DataFrame indexedData = indexerModel.transform(data); + // $example off$ + jsc.stop(); + } +} \ No newline at end of file diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorSlicerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorSlicerExample.java new file mode 100644 index 000000000000..4d5cb04ff5e2 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorSlicerExample.java @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import com.google.common.collect.Lists; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.attribute.Attribute; +import org.apache.spark.ml.attribute.AttributeGroup; +import org.apache.spark.ml.attribute.NumericAttribute; +import org.apache.spark.ml.feature.VectorSlicer; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.*; +// $example off$ + +public class JavaVectorSlicerExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaVectorSlicerExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext jsql = new SQLContext(jsc); + + // $example on$ + Attribute[] attrs = new Attribute[]{ + NumericAttribute.defaultAttr().withName("f1"), + NumericAttribute.defaultAttr().withName("f2"), + NumericAttribute.defaultAttr().withName("f3") + }; + AttributeGroup group = new AttributeGroup("userFeatures", attrs); + + JavaRDD jrdd = jsc.parallelize(Lists.newArrayList( + RowFactory.create(Vectors.sparse(3, new int[]{0, 1}, new double[]{-2.0, 2.3})), + RowFactory.create(Vectors.dense(-2.0, 2.3, 0.0)) + )); + + DataFrame dataset = jsql.createDataFrame(jrdd, (new StructType()).add(group.toStructField())); + + VectorSlicer vectorSlicer = new VectorSlicer() + .setInputCol("userFeatures").setOutputCol("features"); + + vectorSlicer.setIndices(new int[]{1}).setNames(new String[]{"f3"}); + // or slicer.setIndices(new int[]{1, 2}), or slicer.setNames(new String[]{"f2", "f3"}) + + DataFrame output = vectorSlicer.transform(dataset); + + System.out.println(output.select("userFeatures", "features").first()); + // $example off$ + jsc.stop(); + } +} + diff --git a/examples/src/main/python/ml/binarizer_example.py b/examples/src/main/python/ml/binarizer_example.py new file mode 100644 index 000000000000..960ad208be12 --- /dev/null +++ b/examples/src/main/python/ml/binarizer_example.py @@ -0,0 +1,43 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.feature import Binarizer +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="BinarizerExample") + sqlContext = SQLContext(sc) + + # $example on$ + continuousDataFrame = sqlContext.createDataFrame([ + (0, 0.1), + (1, 0.8), + (2, 0.2) + ], ["label", "feature"]) + binarizer = Binarizer(threshold=0.5, inputCol="feature", outputCol="binarized_feature") + binarizedDataFrame = binarizer.transform(continuousDataFrame) + binarizedFeatures = binarizedDataFrame.select("binarized_feature") + for binarized_feature, in binarizedFeatures.collect(): + print(binarized_feature) + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/bucketizer_example.py b/examples/src/main/python/ml/bucketizer_example.py new file mode 100644 index 000000000000..a12750aa9248 --- /dev/null +++ b/examples/src/main/python/ml/bucketizer_example.py @@ -0,0 +1,42 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.feature import Bucketizer +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="BucketizerExample") + sqlContext = SQLContext(sc) + + # $example on$ + splits = [-float("inf"), -0.5, 0.0, 0.5, float("inf")] + + data = [(-0.5,), (-0.3,), (0.0,), (0.2,)] + dataFrame = sqlContext.createDataFrame(data, ["features"]) + + bucketizer = Bucketizer(splits=splits, inputCol="features", outputCol="bucketedFeatures") + + # Transform original data into its bucket index. + bucketedData = bucketizer.transform(dataFrame) + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/elementwise_product_example.py b/examples/src/main/python/ml/elementwise_product_example.py new file mode 100644 index 000000000000..c85cb0d89543 --- /dev/null +++ b/examples/src/main/python/ml/elementwise_product_example.py @@ -0,0 +1,39 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.feature import ElementwiseProduct +from pyspark.mllib.linalg import Vectors +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="ElementwiseProductExample") + sqlContext = SQLContext(sc) + + # $example on$ + data = [(Vectors.dense([1.0, 2.0, 3.0]),), (Vectors.dense([4.0, 5.0, 6.0]),)] + df = sqlContext.createDataFrame(data, ["vector"]) + transformer = ElementwiseProduct(scalingVec=Vectors.dense([0.0, 1.0, 2.0]), + inputCol="vector", outputCol="transformedVector") + transformer.transform(df).show() + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/n_gram_example.py b/examples/src/main/python/ml/n_gram_example.py new file mode 100644 index 000000000000..f2d85f53e721 --- /dev/null +++ b/examples/src/main/python/ml/n_gram_example.py @@ -0,0 +1,42 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.feature import NGram +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="NGramExample") + sqlContext = SQLContext(sc) + + # $example on$ + wordDataFrame = sqlContext.createDataFrame([ + (0, ["Hi", "I", "heard", "about", "Spark"]), + (1, ["I", "wish", "Java", "could", "use", "case", "classes"]), + (2, ["Logistic", "regression", "models", "are", "neat"]) + ], ["label", "words"]) + ngram = NGram(inputCol="words", outputCol="ngrams") + ngramDataFrame = ngram.transform(wordDataFrame) + for ngrams_label in ngramDataFrame.select("ngrams", "label").take(3): + print(ngrams_label) + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/normalizer_example.py b/examples/src/main/python/ml/normalizer_example.py new file mode 100644 index 000000000000..833d93e976a7 --- /dev/null +++ b/examples/src/main/python/ml/normalizer_example.py @@ -0,0 +1,41 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.feature import Normalizer +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="NormalizerExample") + sqlContext = SQLContext(sc) + + # $example on$ + dataFrame = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + + # Normalize each Vector using $L^1$ norm. + normalizer = Normalizer(inputCol="features", outputCol="normFeatures", p=1.0) + l1NormData = normalizer.transform(dataFrame) + + # Normalize each Vector using $L^\infty$ norm. + lInfNormData = normalizer.transform(dataFrame, {normalizer.p: float("inf")}) + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/onehot_encoder_example.py b/examples/src/main/python/ml/onehot_encoder_example.py new file mode 100644 index 000000000000..7529dfd09213 --- /dev/null +++ b/examples/src/main/python/ml/onehot_encoder_example.py @@ -0,0 +1,47 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.feature import OneHotEncoder, StringIndexer +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="OneHotEncoderExample") + sqlContext = SQLContext(sc) + + # $example on$ + df = sqlContext.createDataFrame([ + (0, "a"), + (1, "b"), + (2, "c"), + (3, "a"), + (4, "a"), + (5, "c") + ], ["id", "category"]) + + stringIndexer = StringIndexer(inputCol="category", outputCol="categoryIndex") + model = stringIndexer.fit(df) + indexed = model.transform(df) + encoder = OneHotEncoder(dropLast=False, inputCol="categoryIndex", outputCol="categoryVec") + encoded = encoder.transform(indexed) + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/pca_example.py b/examples/src/main/python/ml/pca_example.py new file mode 100644 index 000000000000..8b66140a40a7 --- /dev/null +++ b/examples/src/main/python/ml/pca_example.py @@ -0,0 +1,42 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. 
+# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.feature import PCA +from pyspark.mllib.linalg import Vectors +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="PCAExample") + sqlContext = SQLContext(sc) + + # $example on$ + data = [(Vectors.sparse(5, [(1, 1.0), (3, 7.0)]),), + (Vectors.dense([2.0, 0.0, 3.0, 4.0, 5.0]),), + (Vectors.dense([4.0, 0.0, 0.0, 6.0, 7.0]),)] + df = sqlContext.createDataFrame(data,["features"]) + pca = PCA(k=3, inputCol="features", outputCol="pcaFeatures") + model = pca.fit(df) + result = model.transform(df).select("pcaFeatures") + result.show(truncate=False) + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/polynomial_expansion_example.py b/examples/src/main/python/ml/polynomial_expansion_example.py new file mode 100644 index 000000000000..030a6132a451 --- /dev/null +++ b/examples/src/main/python/ml/polynomial_expansion_example.py @@ -0,0 +1,43 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.feature import PolynomialExpansion +from pyspark.mllib.linalg import Vectors +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="PolynomialExpansionExample") + sqlContext = SQLContext(sc) + + # $example on$ + df = sqlContext.createDataFrame( + [(Vectors.dense([-2.0, 2.3]), ), + (Vectors.dense([0.0, 0.0]), ), + (Vectors.dense([0.6, -1.1]), )], + ["features"]) + px = PolynomialExpansion(degree=2, inputCol="features", outputCol="polyFeatures") + polyDF = px.transform(df) + for expanded in polyDF.select("polyFeatures").take(3): + print(expanded) + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/rformula_example.py b/examples/src/main/python/ml/rformula_example.py new file mode 100644 index 000000000000..b544a1470076 --- /dev/null +++ b/examples/src/main/python/ml/rformula_example.py @@ -0,0 +1,44 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. 
See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.feature import RFormula +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="RFormulaExample") + sqlContext = SQLContext(sc) + + # $example on$ + dataset = sqlContext.createDataFrame( + [(7, "US", 18, 1.0), + (8, "CA", 12, 0.0), + (9, "NZ", 15, 0.0)], + ["id", "country", "hour", "clicked"]) + formula = RFormula( + formula="clicked ~ country + hour", + featuresCol="features", + labelCol="label") + output = formula.fit(dataset).transform(dataset) + output.select("features", "label").show() + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/standard_scaler_example.py b/examples/src/main/python/ml/standard_scaler_example.py new file mode 100644 index 000000000000..139acecbfb53 --- /dev/null +++ b/examples/src/main/python/ml/standard_scaler_example.py @@ -0,0 +1,42 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.feature import StandardScaler +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="StandardScalerExample") + sqlContext = SQLContext(sc) + + # $example on$ + dataFrame = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + scaler = StandardScaler(inputCol="features", outputCol="scaledFeatures", + withStd=True, withMean=False) + + # Compute summary statistics by fitting the StandardScaler + scalerModel = scaler.fit(dataFrame) + + # Normalize each feature to have unit standard deviation. 
+ scaledData = scalerModel.transform(dataFrame) + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/stopwords_remover_example.py b/examples/src/main/python/ml/stopwords_remover_example.py new file mode 100644 index 000000000000..01f94af8ca75 --- /dev/null +++ b/examples/src/main/python/ml/stopwords_remover_example.py @@ -0,0 +1,40 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.feature import StopWordsRemover +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="StopWordsRemoverExample") + sqlContext = SQLContext(sc) + + # $example on$ + sentenceData = sqlContext.createDataFrame([ + (0, ["I", "saw", "the", "red", "baloon"]), + (1, ["Mary", "had", "a", "little", "lamb"]) + ], ["label", "raw"]) + + remover = StopWordsRemover(inputCol="raw", outputCol="filtered") + remover.transform(sentenceData).show(truncate=False) + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/string_indexer_example.py b/examples/src/main/python/ml/string_indexer_example.py new file mode 100644 index 000000000000..58a8cb5d56b7 --- /dev/null +++ b/examples/src/main/python/ml/string_indexer_example.py @@ -0,0 +1,39 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.feature import StringIndexer +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="StringIndexerExample") + sqlContext = SQLContext(sc) + + # $example on$ + df = sqlContext.createDataFrame( + [(0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")], + ["id", "category"]) + indexer = StringIndexer(inputCol="category", outputCol="categoryIndex") + indexed = indexer.fit(df).transform(df) + indexed.show() + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/tokenizer_example.py b/examples/src/main/python/ml/tokenizer_example.py new file mode 100644 index 000000000000..ce9b225be535 --- /dev/null +++ b/examples/src/main/python/ml/tokenizer_example.py @@ -0,0 +1,44 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.feature import Tokenizer, RegexTokenizer +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="TokenizerExample") + sqlContext = SQLContext(sc) + + # $example on$ + sentenceDataFrame = sqlContext.createDataFrame([ + (0, "Hi I heard about Spark"), + (1, "I wish Java could use case classes"), + (2, "Logistic,regression,models,are,neat") + ], ["label", "sentence"]) + tokenizer = Tokenizer(inputCol="sentence", outputCol="words") + wordsDataFrame = tokenizer.transform(sentenceDataFrame) + for words_label in wordsDataFrame.select("words", "label").take(3): + print(words_label) + regexTokenizer = RegexTokenizer(inputCol="sentence", outputCol="words", pattern="\\W") + # alternatively, pattern="\\w+", gaps(False) + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/vector_assembler_example.py b/examples/src/main/python/ml/vector_assembler_example.py new file mode 100644 index 000000000000..04f64839f188 --- /dev/null +++ b/examples/src/main/python/ml/vector_assembler_example.py @@ -0,0 +1,42 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.mllib.linalg import Vectors +from pyspark.ml.feature import VectorAssembler +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="VectorAssemblerExample") + sqlContext = SQLContext(sc) + + # $example on$ + dataset = sqlContext.createDataFrame( + [(0, 18, 1.0, Vectors.dense([0.0, 10.0, 0.5]), 1.0)], + ["id", "hour", "mobile", "userFeatures", "clicked"]) + assembler = VectorAssembler( + inputCols=["hour", "mobile", "userFeatures"], + outputCol="features") + output = assembler.transform(dataset) + print(output.select("features", "clicked").first()) + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/vector_indexer_example.py b/examples/src/main/python/ml/vector_indexer_example.py new file mode 100644 index 000000000000..cc00d1454f2e --- /dev/null +++ b/examples/src/main/python/ml/vector_indexer_example.py @@ -0,0 +1,39 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.feature import VectorIndexer +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="VectorIndexerExample") + sqlContext = SQLContext(sc) + + # $example on$ + data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + indexer = VectorIndexer(inputCol="features", outputCol="indexed", maxCategories=10) + indexerModel = indexer.fit(data) + + # Create new column "indexed" with categorical values transformed to indices + indexedData = indexerModel.transform(data) + # $example off$ + + sc.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/BinarizerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/BinarizerExample.scala new file mode 100644 index 000000000000..e724aa587294 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/BinarizerExample.scala @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.Binarizer +// $example off$ +import org.apache.spark.sql.{DataFrame, SQLContext} +import org.apache.spark.{SparkConf, SparkContext} + +object BinarizerExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("BinarizerExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + // $example on$ + val data = Array((0, 0.1), (1, 0.8), (2, 0.2)) + val dataFrame: DataFrame = sqlContext.createDataFrame(data).toDF("label", "feature") + + val binarizer: Binarizer = new Binarizer() + .setInputCol("feature") + .setOutputCol("binarized_feature") + .setThreshold(0.5) + + val binarizedDataFrame = binarizer.transform(dataFrame) + val binarizedFeatures = binarizedDataFrame.select("binarized_feature") + binarizedFeatures.collect().foreach(println) + // $example off$ + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/BucketizerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/BucketizerExample.scala new file mode 100644 index 000000000000..30c2776d3968 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/BucketizerExample.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.Bucketizer +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object BucketizerExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("BucketizerExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val splits = Array(Double.NegativeInfinity, -0.5, 0.0, 0.5, Double.PositiveInfinity) + + val data = Array(-0.5, -0.3, 0.0, 0.2) + val dataFrame = sqlContext.createDataFrame(data.map(Tuple1.apply)).toDF("features") + + val bucketizer = new Bucketizer() + .setInputCol("features") + .setOutputCol("bucketedFeatures") + .setSplits(splits) + + // Transform original data into its bucket index. 
+ val bucketedData = bucketizer.transform(dataFrame) + // $example off$ + sc.stop() + } +} +// scalastyle:on println + diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DCTExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DCTExample.scala new file mode 100644 index 000000000000..314c2c28a2a1 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DCTExample.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.DCT +import org.apache.spark.mllib.linalg.Vectors +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object DCTExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("DCTExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val data = Seq( + Vectors.dense(0.0, 1.0, -2.0, 3.0), + Vectors.dense(-1.0, 2.0, 4.0, -7.0), + Vectors.dense(14.0, -2.0, -5.0, 1.0)) + + val df = sqlContext.createDataFrame(data.map(Tuple1.apply)).toDF("features") + + val dct = new DCT() + .setInputCol("features") + .setOutputCol("featuresDCT") + .setInverse(false) + + val dctDf = dct.transform(df) + dctDf.select("featuresDCT").show(3) + // $example off$ + sc.stop() + } +} +// scalastyle:on println + diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/ElementWiseProductExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/ElementWiseProductExample.scala new file mode 100644 index 000000000000..ac50bb7b2b15 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/ElementWiseProductExample.scala @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.ElementwiseProduct +import org.apache.spark.mllib.linalg.Vectors +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object ElementwiseProductExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("ElementwiseProductExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + // Create some vector data; also works for sparse vectors + val dataFrame = sqlContext.createDataFrame(Seq( + ("a", Vectors.dense(1.0, 2.0, 3.0)), + ("b", Vectors.dense(4.0, 5.0, 6.0)))).toDF("id", "vector") + + val transformingVector = Vectors.dense(0.0, 1.0, 2.0) + val transformer = new ElementwiseProduct() + .setScalingVec(transformingVector) + .setInputCol("vector") + .setOutputCol("transformedVector") + + // Batch transform the vectors to create new column: + transformer.transform(dataFrame).show() + // $example off$ + sc.stop() + } +} +// scalastyle:on println + diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/MinMaxScalerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/MinMaxScalerExample.scala new file mode 100644 index 000000000000..dac3679a5bf7 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/MinMaxScalerExample.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.MinMaxScaler +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object MinMaxScalerExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("MinMaxScalerExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val dataFrame = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + + val scaler = new MinMaxScaler() + .setInputCol("features") + .setOutputCol("scaledFeatures") + + // Compute summary statistics and generate MinMaxScalerModel + val scalerModel = scaler.fit(dataFrame) + + // rescale each feature to range [min, max]. 
+ val scaledData = scalerModel.transform(dataFrame) + // $example off$ + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/NGramExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/NGramExample.scala new file mode 100644 index 000000000000..8a85f71b56f3 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/NGramExample.scala @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.NGram +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object NGramExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("NGramExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val wordDataFrame = sqlContext.createDataFrame(Seq( + (0, Array("Hi", "I", "heard", "about", "Spark")), + (1, Array("I", "wish", "Java", "could", "use", "case", "classes")), + (2, Array("Logistic", "regression", "models", "are", "neat")) + )).toDF("label", "words") + + val ngram = new NGram().setInputCol("words").setOutputCol("ngrams") + val ngramDataFrame = ngram.transform(wordDataFrame) + ngramDataFrame.take(3).map(_.getAs[Stream[String]]("ngrams").toList).foreach(println) + // $example off$ + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/NormalizerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/NormalizerExample.scala new file mode 100644 index 000000000000..17571f0aad79 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/NormalizerExample.scala @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.Normalizer +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object NormalizerExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("NormalizerExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val dataFrame = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + + // Normalize each Vector using $L^1$ norm. + val normalizer = new Normalizer() + .setInputCol("features") + .setOutputCol("normFeatures") + .setP(1.0) + + val l1NormData = normalizer.transform(dataFrame) + + // Normalize each Vector using $L^\infty$ norm. + val lInfNormData = normalizer.transform(dataFrame, normalizer.p -> Double.PositiveInfinity) + // $example off$ + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/OneHotEncoderExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/OneHotEncoderExample.scala new file mode 100644 index 000000000000..4512736943dd --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/OneHotEncoderExample.scala @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.{OneHotEncoder, StringIndexer} +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object OneHotEncoderExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("OneHotEncoderExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val df = sqlContext.createDataFrame(Seq( + (0, "a"), + (1, "b"), + (2, "c"), + (3, "a"), + (4, "a"), + (5, "c") + )).toDF("id", "category") + + val indexer = new StringIndexer() + .setInputCol("category") + .setOutputCol("categoryIndex") + .fit(df) + val indexed = indexer.transform(df) + + val encoder = new OneHotEncoder().setInputCol("categoryIndex"). 
+ setOutputCol("categoryVec") + val encoded = encoder.transform(indexed) + encoded.select("id", "categoryVec").foreach(println) + // $example off$ + sc.stop() + } +} +// scalastyle:on println + diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/PCAExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/PCAExample.scala new file mode 100644 index 000000000000..a18d4f33973d --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/PCAExample.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.PCA +import org.apache.spark.mllib.linalg.Vectors +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object PCAExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("PCAExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val data = Array( + Vectors.sparse(5, Seq((1, 1.0), (3, 7.0))), + Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0), + Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0) + ) + val df = sqlContext.createDataFrame(data.map(Tuple1.apply)).toDF("features") + val pca = new PCA() + .setInputCol("features") + .setOutputCol("pcaFeatures") + .setK(3) + .fit(df) + val pcaDF = pca.transform(df) + val result = pcaDF.select("pcaFeatures") + result.show() + // $example off$ + sc.stop() + } +} +// scalastyle:on println + diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/PolynomialExpansionExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/PolynomialExpansionExample.scala new file mode 100644 index 000000000000..b8e9e6952a5e --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/PolynomialExpansionExample.scala @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.PolynomialExpansion +import org.apache.spark.mllib.linalg.Vectors +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object PolynomialExpansionExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("PolynomialExpansionExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val data = Array( + Vectors.dense(-2.0, 2.3), + Vectors.dense(0.0, 0.0), + Vectors.dense(0.6, -1.1) + ) + val df = sqlContext.createDataFrame(data.map(Tuple1.apply)).toDF("features") + val polynomialExpansion = new PolynomialExpansion() + .setInputCol("features") + .setOutputCol("polyFeatures") + .setDegree(3) + val polyDF = polynomialExpansion.transform(df) + polyDF.select("polyFeatures").take(3).foreach(println) + // $example off$ + sc.stop() + } +} +// scalastyle:on println + + diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/RFormulaExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/RFormulaExample.scala new file mode 100644 index 000000000000..286866edea50 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/RFormulaExample.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.RFormula +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object RFormulaExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("RFormulaExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val dataset = sqlContext.createDataFrame(Seq( + (7, "US", 18, 1.0), + (8, "CA", 12, 0.0), + (9, "NZ", 15, 0.0) + )).toDF("id", "country", "hour", "clicked") + val formula = new RFormula() + .setFormula("clicked ~ country + hour") + .setFeaturesCol("features") + .setLabelCol("label") + val output = formula.fit(dataset).transform(dataset) + output.select("features", "label").show() + // $example off$ + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/StandardScalerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/StandardScalerExample.scala new file mode 100644 index 000000000000..646ce0f13ecf --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/StandardScalerExample.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.StandardScaler +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object StandardScalerExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("StandardScalerExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val dataFrame = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + + val scaler = new StandardScaler() + .setInputCol("features") + .setOutputCol("scaledFeatures") + .setWithStd(true) + .setWithMean(false) + + // Compute summary statistics by fitting the StandardScaler. + val scalerModel = scaler.fit(dataFrame) + + // Normalize each feature to have unit standard deviation. + val scaledData = scalerModel.transform(dataFrame) + // $example off$ + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/StopWordsRemoverExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/StopWordsRemoverExample.scala new file mode 100644 index 000000000000..655ffce08d3a --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/StopWordsRemoverExample.scala @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.StopWordsRemover +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object StopWordsRemoverExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("StopWordsRemoverExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val remover = new StopWordsRemover() + .setInputCol("raw") + .setOutputCol("filtered") + + val dataSet = sqlContext.createDataFrame(Seq( + (0, Seq("I", "saw", "the", "red", "baloon")), + (1, Seq("Mary", "had", "a", "little", "lamb")) + )).toDF("id", "raw") + + remover.transform(dataSet).show() + // $example off$ + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/StringIndexerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/StringIndexerExample.scala new file mode 100644 index 000000000000..1be8a5f33f7c --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/StringIndexerExample.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.StringIndexer +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object StringIndexerExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("StringIndexerExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val df = sqlContext.createDataFrame( + Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")) + ).toDF("id", "category") + + val indexer = new StringIndexer() + .setInputCol("category") + .setOutputCol("categoryIndex") + + val indexed = indexer.fit(df).transform(df) + indexed.show() + // $example off$ + sc.stop() + } +} +// scalastyle:on println + diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/TokenizerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/TokenizerExample.scala new file mode 100644 index 000000000000..01e0d1388a2f --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/TokenizerExample.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.{RegexTokenizer, Tokenizer} +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object TokenizerExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("TokenizerExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val sentenceDataFrame = sqlContext.createDataFrame(Seq( + (0, "Hi I heard about Spark"), + (1, "I wish Java could use case classes"), + (2, "Logistic,regression,models,are,neat") + )).toDF("label", "sentence") + + val tokenizer = new Tokenizer().setInputCol("sentence").setOutputCol("words") + val regexTokenizer = new RegexTokenizer() + .setInputCol("sentence") + .setOutputCol("words") + .setPattern("\\W") // alternatively .setPattern("\\w+").setGaps(false) + + val tokenized = tokenizer.transform(sentenceDataFrame) + tokenized.select("words", "label").take(3).foreach(println) + val regexTokenized = regexTokenizer.transform(sentenceDataFrame) + regexTokenized.select("words", "label").take(3).foreach(println) + // $example off$ + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/VectorAssemblerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/VectorAssemblerExample.scala new file mode 100644 index 000000000000..d527924419f8 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/VectorAssemblerExample.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.VectorAssembler +import org.apache.spark.mllib.linalg.Vectors +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object VectorAssemblerExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("VectorAssemblerExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val dataset = sqlContext.createDataFrame( + Seq((0, 18, 1.0, Vectors.dense(0.0, 10.0, 0.5), 1.0)) + ).toDF("id", "hour", "mobile", "userFeatures", "clicked") + + val assembler = new VectorAssembler() + .setInputCols(Array("hour", "mobile", "userFeatures")) + .setOutputCol("features") + + val output = assembler.transform(dataset) + println(output.select("features", "clicked").first()) + // $example off$ + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/VectorIndexerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/VectorIndexerExample.scala new file mode 100644 index 000000000000..14279d610fda --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/VectorIndexerExample.scala @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.VectorIndexer +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object VectorIndexerExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("VectorIndexerExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + + val indexer = new VectorIndexer() + .setInputCol("features") + .setOutputCol("indexed") + .setMaxCategories(10) + + val indexerModel = indexer.fit(data) + + val categoricalFeatures: Set[Int] = indexerModel.categoryMaps.keys.toSet + println(s"Chose ${categoricalFeatures.size} categorical features: " + + categoricalFeatures.mkString(", ")) + + // Create new column "indexed" with categorical values transformed to indices + val indexedData = indexerModel.transform(data) + // $example off$ + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/VectorSlicerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/VectorSlicerExample.scala new file mode 100644 index 000000000000..04f19829eff8 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/VectorSlicerExample.scala @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NumericAttribute} +import org.apache.spark.ml.feature.VectorSlicer +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.sql.Row +import org.apache.spark.sql.types.StructType +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object VectorSlicerExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("VectorSlicerExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val data = Array(Row(Vectors.dense(-2.0, 2.3, 0.0))) + + val defaultAttr = NumericAttribute.defaultAttr + val attrs = Array("f1", "f2", "f3").map(defaultAttr.withName) + val attrGroup = new AttributeGroup("userFeatures", attrs.asInstanceOf[Array[Attribute]]) + + val dataRDD = sc.parallelize(data) + val dataset = sqlContext.createDataFrame(dataRDD, StructType(Array(attrGroup.toStructField()))) + + val slicer = new VectorSlicer().setInputCol("userFeatures").setOutputCol("features") + + slicer.setIndices(Array(1)).setNames(Array("f3")) + // or slicer.setIndices(Array(1, 2)), or slicer.setNames(Array("f2", "f3")) + + val output = slicer.transform(dataset) + println(output.select("userFeatures", "features").first()) + // $example off$ + sc.stop() + } +} +// scalastyle:on println From 73896588dd3af6ba77c9692cd5120ee32448eb22 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Mon, 7 Dec 2015 23:34:16 -0800 Subject: [PATCH 1688/1824] Closes #10098 From 7d05a624510f7299b3dd07f87c203db1ff7caa3e Mon Sep 17 00:00:00 2001 From: Takahashi Hiroshi Date: Mon, 7 Dec 2015 23:46:55 -0800 Subject: [PATCH 1689/1824] [SPARK-10259][ML] Add @since annotation to ml.classification Add since annotation to ml.classification Author: Takahashi Hiroshi Closes #8534 from taishi-oss/issue10259. --- .../DecisionTreeClassifier.scala | 30 +++++++-- .../ml/classification/GBTClassifier.scala | 35 ++++++++-- .../classification/LogisticRegression.scala | 64 +++++++++++++++---- .../MultilayerPerceptronClassifier.scala | 23 +++++-- .../spark/ml/classification/NaiveBayes.scala | 19 ++++-- .../spark/ml/classification/OneVsRest.scala | 24 +++++-- .../RandomForestClassifier.scala | 34 ++++++++-- 7 files changed, 185 insertions(+), 44 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala index c478aea44ace..8c4cec132665 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala @@ -17,7 +17,7 @@ package org.apache.spark.ml.classification -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.tree.{DecisionTreeModel, DecisionTreeParams, Node, TreeClassifierParams} import org.apache.spark.ml.tree.impl.RandomForest @@ -36,32 +36,44 @@ import org.apache.spark.sql.DataFrame * It supports both binary and multiclass labels, as well as both continuous and categorical * features. 
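 * A minimal training sketch (illustrative only; the `trainingData` DataFrame with "label" and
 * "features" columns is an assumption, not part of this patch):
 * {{{
 *   val dt = new DecisionTreeClassifier()
 *     .setLabelCol("label")
 *     .setFeaturesCol("features")
 *     .setMaxDepth(5)
 *     .setImpurity("gini")
 *   val model = dt.fit(trainingData)
 * }}}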
*/ +@Since("1.4.0") @Experimental -final class DecisionTreeClassifier(override val uid: String) +final class DecisionTreeClassifier @Since("1.4.0") ( + @Since("1.4.0") override val uid: String) extends ProbabilisticClassifier[Vector, DecisionTreeClassifier, DecisionTreeClassificationModel] with DecisionTreeParams with TreeClassifierParams { + @Since("1.4.0") def this() = this(Identifiable.randomUID("dtc")) // Override parameter setters from parent trait for Java API compatibility. + @Since("1.4.0") override def setMaxDepth(value: Int): this.type = super.setMaxDepth(value) + @Since("1.4.0") override def setMaxBins(value: Int): this.type = super.setMaxBins(value) + @Since("1.4.0") override def setMinInstancesPerNode(value: Int): this.type = super.setMinInstancesPerNode(value) + @Since("1.4.0") override def setMinInfoGain(value: Double): this.type = super.setMinInfoGain(value) + @Since("1.4.0") override def setMaxMemoryInMB(value: Int): this.type = super.setMaxMemoryInMB(value) + @Since("1.4.0") override def setCacheNodeIds(value: Boolean): this.type = super.setCacheNodeIds(value) + @Since("1.4.0") override def setCheckpointInterval(value: Int): this.type = super.setCheckpointInterval(value) + @Since("1.4.0") override def setImpurity(value: String): this.type = super.setImpurity(value) + @Since("1.6.0") override def setSeed(value: Long): this.type = super.setSeed(value) override protected def train(dataset: DataFrame): DecisionTreeClassificationModel = { @@ -89,12 +101,15 @@ final class DecisionTreeClassifier(override val uid: String) subsamplingRate = 1.0) } + @Since("1.4.1") override def copy(extra: ParamMap): DecisionTreeClassifier = defaultCopy(extra) } +@Since("1.4.0") @Experimental object DecisionTreeClassifier { /** Accessor for supported impurities: entropy, gini */ + @Since("1.4.0") final val supportedImpurities: Array[String] = TreeClassifierParams.supportedImpurities } @@ -104,12 +119,13 @@ object DecisionTreeClassifier { * It supports both binary and multiclass labels, as well as both continuous and categorical * features. 
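 * An illustrative sketch of applying a fitted model (the `dt` estimator and the `trainingData`
 * and `testData` DataFrames are assumptions, not part of this patch):
 * {{{
 *   val model: DecisionTreeClassificationModel = dt.fit(trainingData)
 *   model.transform(testData).select("prediction", "label").show()
 * }}}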
*/ +@Since("1.4.0") @Experimental final class DecisionTreeClassificationModel private[ml] ( - override val uid: String, - override val rootNode: Node, - override val numFeatures: Int, - override val numClasses: Int) + @Since("1.4.0")override val uid: String, + @Since("1.4.0")override val rootNode: Node, + @Since("1.6.0")override val numFeatures: Int, + @Since("1.5.0")override val numClasses: Int) extends ProbabilisticClassificationModel[Vector, DecisionTreeClassificationModel] with DecisionTreeModel with Serializable { @@ -142,11 +158,13 @@ final class DecisionTreeClassificationModel private[ml] ( } } + @Since("1.4.0") override def copy(extra: ParamMap): DecisionTreeClassificationModel = { copyValues(new DecisionTreeClassificationModel(uid, rootNode, numFeatures, numClasses), extra) .setParent(parent) } + @Since("1.4.0") override def toString: String = { s"DecisionTreeClassificationModel (uid=$uid) of depth $depth with $numNodes nodes" } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index 74aef94bf767..cda2bca58c50 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -20,7 +20,7 @@ package org.apache.spark.ml.classification import com.github.fommil.netlib.BLAS.{getInstance => blas} import org.apache.spark.Logging -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.{PredictionModel, Predictor} import org.apache.spark.ml.param.{Param, ParamMap} import org.apache.spark.ml.regression.DecisionTreeRegressionModel @@ -44,36 +44,47 @@ import org.apache.spark.sql.types.DoubleType * It supports binary labels, as well as both continuous and categorical features. * Note: Multiclass labels are not currently supported. */ +@Since("1.4.0") @Experimental -final class GBTClassifier(override val uid: String) +final class GBTClassifier @Since("1.4.0") ( + @Since("1.4.0") override val uid: String) extends Predictor[Vector, GBTClassifier, GBTClassificationModel] with GBTParams with TreeClassifierParams with Logging { + @Since("1.4.0") def this() = this(Identifiable.randomUID("gbtc")) // Override parameter setters from parent trait for Java API compatibility. // Parameters from TreeClassifierParams: + @Since("1.4.0") override def setMaxDepth(value: Int): this.type = super.setMaxDepth(value) + @Since("1.4.0") override def setMaxBins(value: Int): this.type = super.setMaxBins(value) + @Since("1.4.0") override def setMinInstancesPerNode(value: Int): this.type = super.setMinInstancesPerNode(value) + @Since("1.4.0") override def setMinInfoGain(value: Double): this.type = super.setMinInfoGain(value) + @Since("1.4.0") override def setMaxMemoryInMB(value: Int): this.type = super.setMaxMemoryInMB(value) + @Since("1.4.0") override def setCacheNodeIds(value: Boolean): this.type = super.setCacheNodeIds(value) + @Since("1.4.0") override def setCheckpointInterval(value: Int): this.type = super.setCheckpointInterval(value) /** * The impurity setting is ignored for GBT models. * Individual trees are built using impurity "Variance." 
*/ + @Since("1.4.0") override def setImpurity(value: String): this.type = { logWarning("GBTClassifier.setImpurity should NOT be used") this @@ -81,8 +92,10 @@ final class GBTClassifier(override val uid: String) // Parameters from TreeEnsembleParams: + @Since("1.4.0") override def setSubsamplingRate(value: Double): this.type = super.setSubsamplingRate(value) + @Since("1.4.0") override def setSeed(value: Long): this.type = { logWarning("The 'seed' parameter is currently ignored by Gradient Boosting.") super.setSeed(value) @@ -90,8 +103,10 @@ final class GBTClassifier(override val uid: String) // Parameters from GBTParams: + @Since("1.4.0") override def setMaxIter(value: Int): this.type = super.setMaxIter(value) + @Since("1.4.0") override def setStepSize(value: Double): this.type = super.setStepSize(value) // Parameters for GBTClassifier: @@ -102,6 +117,7 @@ final class GBTClassifier(override val uid: String) * (default = logistic) * @group param */ + @Since("1.4.0") val lossType: Param[String] = new Param[String](this, "lossType", "Loss function which GBT" + " tries to minimize (case-insensitive). Supported options:" + s" ${GBTClassifier.supportedLossTypes.mkString(", ")}", @@ -110,9 +126,11 @@ final class GBTClassifier(override val uid: String) setDefault(lossType -> "logistic") /** @group setParam */ + @Since("1.4.0") def setLossType(value: String): this.type = set(lossType, value) /** @group getParam */ + @Since("1.4.0") def getLossType: String = $(lossType).toLowerCase /** (private[ml]) Convert new loss to old loss. */ @@ -145,13 +163,16 @@ final class GBTClassifier(override val uid: String) GBTClassificationModel.fromOld(oldModel, this, categoricalFeatures, numFeatures) } + @Since("1.4.1") override def copy(extra: ParamMap): GBTClassifier = defaultCopy(extra) } +@Since("1.4.0") @Experimental object GBTClassifier { // The losses below should be lowercase. /** Accessor for supported loss settings: logistic */ + @Since("1.4.0") final val supportedLossTypes: Array[String] = Array("logistic").map(_.toLowerCase) } @@ -164,12 +185,13 @@ object GBTClassifier { * @param _trees Decision trees in the ensemble. * @param _treeWeights Weights for the decision trees in the ensemble. */ +@Since("1.6.0") @Experimental final class GBTClassificationModel private[ml]( - override val uid: String, + @Since("1.6.0") override val uid: String, private val _trees: Array[DecisionTreeRegressionModel], private val _treeWeights: Array[Double], - override val numFeatures: Int) + @Since("1.6.0") override val numFeatures: Int) extends PredictionModel[Vector, GBTClassificationModel] with TreeEnsembleModel with Serializable { @@ -182,11 +204,14 @@ final class GBTClassificationModel private[ml]( * @param _trees Decision trees in the ensemble. * @param _treeWeights Weights for the decision trees in the ensemble. 
*/ + @Since("1.6.0") def this(uid: String, _trees: Array[DecisionTreeRegressionModel], _treeWeights: Array[Double]) = this(uid, _trees, _treeWeights, -1) + @Since("1.4.0") override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]] + @Since("1.4.0") override def treeWeights: Array[Double] = _treeWeights override protected def transformImpl(dataset: DataFrame): DataFrame = { @@ -205,11 +230,13 @@ final class GBTClassificationModel private[ml]( if (prediction > 0.0) 1.0 else 0.0 } + @Since("1.4.0") override def copy(extra: ParamMap): GBTClassificationModel = { copyValues(new GBTClassificationModel(uid, _trees, _treeWeights, numFeatures), extra).setParent(parent) } + @Since("1.4.0") override def toString: String = { s"GBTClassificationModel (uid=$uid) with $numTrees trees" } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index d320d64dd90d..19cc323d5073 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -24,7 +24,7 @@ import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS, import org.apache.hadoop.fs.Path import org.apache.spark.{Logging, SparkException} -import org.apache.spark.annotation.{Since, Experimental} +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.feature.Instance import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ @@ -154,11 +154,14 @@ private[classification] trait LogisticRegressionParams extends ProbabilisticClas * Currently, this class only supports binary classification. It will support multiclass * in the future. */ +@Since("1.2.0") @Experimental -class LogisticRegression(override val uid: String) +class LogisticRegression @Since("1.2.0") ( + @Since("1.4.0") override val uid: String) extends ProbabilisticClassifier[Vector, LogisticRegression, LogisticRegressionModel] with LogisticRegressionParams with DefaultParamsWritable with Logging { + @Since("1.4.0") def this() = this(Identifiable.randomUID("logreg")) /** @@ -166,6 +169,7 @@ class LogisticRegression(override val uid: String) * Default is 0.0. * @group setParam */ + @Since("1.2.0") def setRegParam(value: Double): this.type = set(regParam, value) setDefault(regParam -> 0.0) @@ -176,6 +180,7 @@ class LogisticRegression(override val uid: String) * Default is 0.0 which is an L2 penalty. * @group setParam */ + @Since("1.4.0") def setElasticNetParam(value: Double): this.type = set(elasticNetParam, value) setDefault(elasticNetParam -> 0.0) @@ -184,6 +189,7 @@ class LogisticRegression(override val uid: String) * Default is 100. * @group setParam */ + @Since("1.2.0") def setMaxIter(value: Int): this.type = set(maxIter, value) setDefault(maxIter -> 100) @@ -193,6 +199,7 @@ class LogisticRegression(override val uid: String) * Default is 1E-6. * @group setParam */ + @Since("1.4.0") def setTol(value: Double): this.type = set(tol, value) setDefault(tol -> 1E-6) @@ -201,6 +208,7 @@ class LogisticRegression(override val uid: String) * Default is true. * @group setParam */ + @Since("1.4.0") def setFitIntercept(value: Boolean): this.type = set(fitIntercept, value) setDefault(fitIntercept -> true) @@ -213,11 +221,14 @@ class LogisticRegression(override val uid: String) * Default is true. 
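 * An illustrative chained-setter sketch (the `training` DataFrame is an assumption, not part of
 * this patch):
 * {{{
 *   val lr = new LogisticRegression()
 *     .setMaxIter(100)
 *     .setRegParam(0.01)
 *     .setElasticNetParam(0.0)
 *     .setStandardization(true)
 *   val lrModel = lr.fit(training)
 * }}}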
* @group setParam */ + @Since("1.5.0") def setStandardization(value: Boolean): this.type = set(standardization, value) setDefault(standardization -> true) + @Since("1.5.0") override def setThreshold(value: Double): this.type = super.setThreshold(value) + @Since("1.5.0") override def getThreshold: Double = super.getThreshold /** @@ -226,11 +237,14 @@ class LogisticRegression(override val uid: String) * Default is empty, so all instances have weight one. * @group setParam */ + @Since("1.6.0") def setWeightCol(value: String): this.type = set(weightCol, value) setDefault(weightCol -> "") + @Since("1.5.0") override def setThresholds(value: Array[Double]): this.type = super.setThresholds(value) + @Since("1.5.0") override def getThresholds: Array[Double] = super.getThresholds override protected def train(dataset: DataFrame): LogisticRegressionModel = { @@ -384,11 +398,14 @@ class LogisticRegression(override val uid: String) model.setSummary(logRegSummary) } + @Since("1.4.0") override def copy(extra: ParamMap): LogisticRegression = defaultCopy(extra) } +@Since("1.6.0") object LogisticRegression extends DefaultParamsReadable[LogisticRegression] { + @Since("1.6.0") override def load(path: String): LogisticRegression = super.load(path) } @@ -396,23 +413,28 @@ object LogisticRegression extends DefaultParamsReadable[LogisticRegression] { * :: Experimental :: * Model produced by [[LogisticRegression]]. */ +@Since("1.4.0") @Experimental class LogisticRegressionModel private[ml] ( - override val uid: String, - val coefficients: Vector, - val intercept: Double) + @Since("1.4.0") override val uid: String, + @Since("1.6.0") val coefficients: Vector, + @Since("1.3.0") val intercept: Double) extends ProbabilisticClassificationModel[Vector, LogisticRegressionModel] with LogisticRegressionParams with MLWritable { @deprecated("Use coefficients instead.", "1.6.0") def weights: Vector = coefficients + @Since("1.5.0") override def setThreshold(value: Double): this.type = super.setThreshold(value) + @Since("1.5.0") override def getThreshold: Double = super.getThreshold + @Since("1.5.0") override def setThresholds(value: Array[Double]): this.type = super.setThresholds(value) + @Since("1.5.0") override def getThresholds: Array[Double] = super.getThresholds /** Margin (rawPrediction) for class label 1. For binary classification only. */ @@ -426,8 +448,10 @@ class LogisticRegressionModel private[ml] ( 1.0 / (1.0 + math.exp(-m)) } + @Since("1.6.0") override val numFeatures: Int = coefficients.size + @Since("1.3.0") override val numClasses: Int = 2 private var trainingSummary: Option[LogisticRegressionTrainingSummary] = None @@ -436,6 +460,7 @@ class LogisticRegressionModel private[ml] ( * Gets summary of model on training set. An exception is * thrown if `trainingSummary == None`. */ + @Since("1.5.0") def summary: LogisticRegressionTrainingSummary = trainingSummary match { case Some(summ) => summ case None => @@ -451,6 +476,7 @@ class LogisticRegressionModel private[ml] ( } /** Indicates whether a training summary exists for this model instance. 
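 * For example (a sketch only; `lrModel` is an assumed, already-fitted LogisticRegressionModel):
 * {{{
 *   if (lrModel.hasSummary) {
 *     val trainingSummary = lrModel.summary
 *     println(trainingSummary.objectiveHistory.mkString(", "))
 *   }
 * }}}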
*/ + @Since("1.5.0") def hasSummary: Boolean = trainingSummary.isDefined /** @@ -493,6 +519,7 @@ class LogisticRegressionModel private[ml] ( Vectors.dense(-m, m) } + @Since("1.4.0") override def copy(extra: ParamMap): LogisticRegressionModel = { val newModel = copyValues(new LogisticRegressionModel(uid, coefficients, intercept), extra) if (trainingSummary.isDefined) newModel.setSummary(trainingSummary.get) @@ -710,12 +737,13 @@ sealed trait LogisticRegressionSummary extends Serializable { * @param objectiveHistory objective function (scaled loss + regularization) at each iteration. */ @Experimental +@Since("1.5.0") class BinaryLogisticRegressionTrainingSummary private[classification] ( - predictions: DataFrame, - probabilityCol: String, - labelCol: String, - featuresCol: String, - val objectiveHistory: Array[Double]) + @Since("1.5.0") predictions: DataFrame, + @Since("1.5.0") probabilityCol: String, + @Since("1.5.0") labelCol: String, + @Since("1.6.0") featuresCol: String, + @Since("1.5.0") val objectiveHistory: Array[Double]) extends BinaryLogisticRegressionSummary(predictions, probabilityCol, labelCol, featuresCol) with LogisticRegressionTrainingSummary { @@ -731,11 +759,13 @@ class BinaryLogisticRegressionTrainingSummary private[classification] ( * @param featuresCol field in "predictions" which gives the features of each instance as a vector. */ @Experimental +@Since("1.5.0") class BinaryLogisticRegressionSummary private[classification] ( - @transient override val predictions: DataFrame, - override val probabilityCol: String, - override val labelCol: String, - override val featuresCol: String) extends LogisticRegressionSummary { + @Since("1.5.0") @transient override val predictions: DataFrame, + @Since("1.5.0") override val probabilityCol: String, + @Since("1.5.0") override val labelCol: String, + @Since("1.6.0") override val featuresCol: String) extends LogisticRegressionSummary { + private val sqlContext = predictions.sqlContext import sqlContext.implicits._ @@ -760,6 +790,7 @@ class BinaryLogisticRegressionSummary private[classification] ( * This will change in later Spark versions. * @see http://en.wikipedia.org/wiki/Receiver_operating_characteristic */ + @Since("1.5.0") @transient lazy val roc: DataFrame = binaryMetrics.roc().toDF("FPR", "TPR") /** @@ -768,6 +799,7 @@ class BinaryLogisticRegressionSummary private[classification] ( * Note: This ignores instance weights (setting all to 1.0) from [[LogisticRegression.weightCol]]. * This will change in later Spark versions. */ + @Since("1.5.0") lazy val areaUnderROC: Double = binaryMetrics.areaUnderROC() /** @@ -777,6 +809,7 @@ class BinaryLogisticRegressionSummary private[classification] ( * Note: This ignores instance weights (setting all to 1.0) from [[LogisticRegression.weightCol]]. * This will change in later Spark versions. */ + @Since("1.5.0") @transient lazy val pr: DataFrame = binaryMetrics.pr().toDF("recall", "precision") /** @@ -785,6 +818,7 @@ class BinaryLogisticRegressionSummary private[classification] ( * Note: This ignores instance weights (setting all to 1.0) from [[LogisticRegression.weightCol]]. * This will change in later Spark versions. */ + @Since("1.5.0") @transient lazy val fMeasureByThreshold: DataFrame = { binaryMetrics.fMeasureByThreshold().toDF("threshold", "F-Measure") } @@ -797,6 +831,7 @@ class BinaryLogisticRegressionSummary private[classification] ( * Note: This ignores instance weights (setting all to 1.0) from [[LogisticRegression.weightCol]]. * This will change in later Spark versions. 
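 * For example (a sketch only; `summary` is an assumed BinaryLogisticRegressionSummary instance):
 * {{{
 *   summary.precisionByThreshold.show()
 * }}}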
*/ + @Since("1.5.0") @transient lazy val precisionByThreshold: DataFrame = { binaryMetrics.precisionByThreshold().toDF("threshold", "precision") } @@ -809,6 +844,7 @@ class BinaryLogisticRegressionSummary private[classification] ( * Note: This ignores instance weights (setting all to 1.0) from [[LogisticRegression.weightCol]]. * This will change in later Spark versions. */ + @Since("1.5.0") @transient lazy val recallByThreshold: DataFrame = { binaryMetrics.recallByThreshold().toDF("threshold", "recall") } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala index cd7462596dd9..a691aa005ef5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala @@ -19,7 +19,7 @@ package org.apache.spark.ml.classification import scala.collection.JavaConverters._ -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.param.shared.{HasTol, HasMaxIter, HasSeed} import org.apache.spark.ml.{PredictorParams, PredictionModel, Predictor} import org.apache.spark.ml.param.{IntParam, ParamValidators, IntArrayParam, ParamMap} @@ -104,19 +104,23 @@ private object LabelConverter { * Each layer has sigmoid activation function, output layer has softmax. * Number of inputs has to be equal to the size of feature vectors. * Number of outputs has to be equal to the total number of labels. - * */ +@Since("1.5.0") @Experimental -class MultilayerPerceptronClassifier(override val uid: String) +class MultilayerPerceptronClassifier @Since("1.5.0") ( + @Since("1.5.0") override val uid: String) extends Predictor[Vector, MultilayerPerceptronClassifier, MultilayerPerceptronClassificationModel] with MultilayerPerceptronParams { + @Since("1.5.0") def this() = this(Identifiable.randomUID("mlpc")) /** @group setParam */ + @Since("1.5.0") def setLayers(value: Array[Int]): this.type = set(layers, value) /** @group setParam */ + @Since("1.5.0") def setBlockSize(value: Int): this.type = set(blockSize, value) /** @@ -124,6 +128,7 @@ class MultilayerPerceptronClassifier(override val uid: String) * Default is 100. * @group setParam */ + @Since("1.5.0") def setMaxIter(value: Int): this.type = set(maxIter, value) /** @@ -132,14 +137,17 @@ class MultilayerPerceptronClassifier(override val uid: String) * Default is 1E-4. * @group setParam */ + @Since("1.5.0") def setTol(value: Double): this.type = set(tol, value) /** * Set the seed for weights initialization. 
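 * A minimal configuration sketch (the layer sizes and the `train` DataFrame are assumptions,
 * not part of this patch):
 * {{{
 *   val mlp = new MultilayerPerceptronClassifier()
 *     .setLayers(Array(4, 5, 4, 3))
 *     .setBlockSize(128)
 *     .setSeed(1234L)
 *     .setMaxIter(100)
 *   val model = mlp.fit(train)
 * }}}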
* @group setParam */ + @Since("1.5.0") def setSeed(value: Long): this.type = set(seed, value) + @Since("1.5.0") override def copy(extra: ParamMap): MultilayerPerceptronClassifier = defaultCopy(extra) /** @@ -173,14 +181,16 @@ class MultilayerPerceptronClassifier(override val uid: String) * @param weights vector of initial weights for the model that consists of the weights of layers * @return prediction model */ +@Since("1.5.0") @Experimental class MultilayerPerceptronClassificationModel private[ml] ( - override val uid: String, - val layers: Array[Int], - val weights: Vector) + @Since("1.5.0") override val uid: String, + @Since("1.5.0") val layers: Array[Int], + @Since("1.5.0") val weights: Vector) extends PredictionModel[Vector, MultilayerPerceptronClassificationModel] with Serializable { + @Since("1.6.0") override val numFeatures: Int = layers.head private val mlpModel = FeedForwardTopology.multiLayerPerceptron(layers, true).getInstance(weights) @@ -200,6 +210,7 @@ class MultilayerPerceptronClassificationModel private[ml] ( LabelConverter.decodeLabel(mlpModel.predict(features)) } + @Since("1.5.0") override def copy(extra: ParamMap): MultilayerPerceptronClassificationModel = { copyValues(new MultilayerPerceptronClassificationModel(uid, layers, weights), extra) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala index c512a2cb8bf3..718f49d3aedc 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala @@ -72,11 +72,14 @@ private[ml] trait NaiveBayesParams extends PredictorParams { * ([[http://nlp.stanford.edu/IR-book/html/htmledition/the-bernoulli-model-1.html]]). * The input feature values must be nonnegative. */ +@Since("1.5.0") @Experimental -class NaiveBayes(override val uid: String) +class NaiveBayes @Since("1.5.0") ( + @Since("1.5.0") override val uid: String) extends ProbabilisticClassifier[Vector, NaiveBayes, NaiveBayesModel] with NaiveBayesParams with DefaultParamsWritable { + @Since("1.5.0") def this() = this(Identifiable.randomUID("nb")) /** @@ -84,6 +87,7 @@ class NaiveBayes(override val uid: String) * Default is 1.0. 
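 * For example (a sketch only; the `train` DataFrame of labels and non-negative features is an
 * assumption, not part of this patch):
 * {{{
 *   val nb = new NaiveBayes()
 *     .setSmoothing(1.0)
 *     .setModelType("multinomial")
 *   val model = nb.fit(train)
 * }}}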
* @group setParam */ + @Since("1.5.0") def setSmoothing(value: Double): this.type = set(smoothing, value) setDefault(smoothing -> 1.0) @@ -93,6 +97,7 @@ class NaiveBayes(override val uid: String) * Default is "multinomial" * @group setParam */ + @Since("1.5.0") def setModelType(value: String): this.type = set(modelType, value) setDefault(modelType -> OldNaiveBayes.Multinomial) @@ -102,6 +107,7 @@ class NaiveBayes(override val uid: String) NaiveBayesModel.fromOld(oldModel, this) } + @Since("1.5.0") override def copy(extra: ParamMap): NaiveBayes = defaultCopy(extra) } @@ -119,11 +125,12 @@ object NaiveBayes extends DefaultParamsReadable[NaiveBayes] { * @param theta log of class conditional probabilities, whose dimension is C (number of classes) * by D (number of features) */ +@Since("1.5.0") @Experimental class NaiveBayesModel private[ml] ( - override val uid: String, - val pi: Vector, - val theta: Matrix) + @Since("1.5.0") override val uid: String, + @Since("1.5.0") val pi: Vector, + @Since("1.5.0") val theta: Matrix) extends ProbabilisticClassificationModel[Vector, NaiveBayesModel] with NaiveBayesParams with MLWritable { @@ -148,8 +155,10 @@ class NaiveBayesModel private[ml] ( throw new UnknownError(s"Invalid modelType: ${$(modelType)}.") } + @Since("1.6.0") override val numFeatures: Int = theta.numCols + @Since("1.5.0") override val numClasses: Int = pi.size private def multinomialCalculation(features: Vector) = { @@ -206,10 +215,12 @@ class NaiveBayesModel private[ml] ( } } + @Since("1.5.0") override def copy(extra: ParamMap): NaiveBayesModel = { copyValues(new NaiveBayesModel(uid, pi, theta).setParent(this.parent), extra) } + @Since("1.5.0") override def toString: String = { s"NaiveBayesModel (uid=$uid) with ${pi.size} classes" } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala index debc164bf243..08a51109d6c6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala @@ -21,7 +21,7 @@ import java.util.UUID import scala.language.existentials -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml._ import org.apache.spark.ml.attribute._ import org.apache.spark.ml.param.{Param, ParamMap} @@ -70,17 +70,20 @@ private[ml] trait OneVsRestParams extends PredictorParams { * The i-th model is produced by testing the i-th class (taking label 1) vs the rest * (taking label 0). 
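 * An illustrative sketch of how such a model is produced (the base classifier choice and the
 * `train`/`test` DataFrames are assumptions, not part of this patch):
 * {{{
 *   val ovr = new OneVsRest().setClassifier(new LogisticRegression().setMaxIter(10))
 *   val ovrModel = ovr.fit(train)
 *   val predictions = ovrModel.transform(test)
 * }}}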
*/ +@Since("1.4.0") @Experimental final class OneVsRestModel private[ml] ( - override val uid: String, - labelMetadata: Metadata, - val models: Array[_ <: ClassificationModel[_, _]]) + @Since("1.4.0") override val uid: String, + @Since("1.4.0") labelMetadata: Metadata, + @Since("1.4.0") val models: Array[_ <: ClassificationModel[_, _]]) extends Model[OneVsRestModel] with OneVsRestParams { + @Since("1.4.0") override def transformSchema(schema: StructType): StructType = { validateAndTransformSchema(schema, fitting = false, getClassifier.featuresDataType) } + @Since("1.4.0") override def transform(dataset: DataFrame): DataFrame = { // Check schema transformSchema(dataset.schema, logging = true) @@ -134,6 +137,7 @@ final class OneVsRestModel private[ml] ( .drop(accColName) } + @Since("1.4.1") override def copy(extra: ParamMap): OneVsRestModel = { val copied = new OneVsRestModel( uid, labelMetadata, models.map(_.copy(extra).asInstanceOf[ClassificationModel[_, _]])) @@ -150,30 +154,39 @@ final class OneVsRestModel private[ml] ( * Each example is scored against all k models and the model with highest score * is picked to label the example. */ +@Since("1.4.0") @Experimental -final class OneVsRest(override val uid: String) +final class OneVsRest @Since("1.4.0") ( + @Since("1.4.0") override val uid: String) extends Estimator[OneVsRestModel] with OneVsRestParams { + @Since("1.4.0") def this() = this(Identifiable.randomUID("oneVsRest")) /** @group setParam */ + @Since("1.4.0") def setClassifier(value: Classifier[_, _, _]): this.type = { set(classifier, value.asInstanceOf[ClassifierType]) } /** @group setParam */ + @Since("1.5.0") def setLabelCol(value: String): this.type = set(labelCol, value) /** @group setParam */ + @Since("1.5.0") def setFeaturesCol(value: String): this.type = set(featuresCol, value) /** @group setParam */ + @Since("1.5.0") def setPredictionCol(value: String): this.type = set(predictionCol, value) + @Since("1.4.0") override def transformSchema(schema: StructType): StructType = { validateAndTransformSchema(schema, fitting = true, getClassifier.featuresDataType) } + @Since("1.4.0") override def fit(dataset: DataFrame): OneVsRestModel = { // determine number of classes either from metadata if provided, or via computation. val labelSchema = dataset.schema($(labelCol)) @@ -222,6 +235,7 @@ final class OneVsRest(override val uid: String) copyValues(model) } + @Since("1.4.1") override def copy(extra: ParamMap): OneVsRest = { val copied = defaultCopy(extra).asInstanceOf[OneVsRest] if (isDefined(classifier)) { diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index bae329692a68..d6d85ad2533a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -17,7 +17,7 @@ package org.apache.spark.ml.classification -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.tree.impl.RandomForest import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.tree.{DecisionTreeModel, RandomForestParams, TreeClassifierParams, TreeEnsembleModel} @@ -38,44 +38,59 @@ import org.apache.spark.sql.functions._ * It supports both binary and multiclass labels, as well as both continuous and categorical * features. 
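 * A minimal training sketch (illustrative only; the `trainingData` DataFrame is an assumption,
 * not part of this patch):
 * {{{
 *   val rf = new RandomForestClassifier()
 *     .setNumTrees(20)
 *     .setMaxDepth(5)
 *     .setFeatureSubsetStrategy("auto")
 *   val model = rf.fit(trainingData)
 * }}}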
*/ +@Since("1.4.0") @Experimental -final class RandomForestClassifier(override val uid: String) +final class RandomForestClassifier @Since("1.4.0") ( + @Since("1.4.0") override val uid: String) extends ProbabilisticClassifier[Vector, RandomForestClassifier, RandomForestClassificationModel] with RandomForestParams with TreeClassifierParams { + @Since("1.4.0") def this() = this(Identifiable.randomUID("rfc")) // Override parameter setters from parent trait for Java API compatibility. // Parameters from TreeClassifierParams: + @Since("1.4.0") override def setMaxDepth(value: Int): this.type = super.setMaxDepth(value) + @Since("1.4.0") override def setMaxBins(value: Int): this.type = super.setMaxBins(value) + @Since("1.4.0") override def setMinInstancesPerNode(value: Int): this.type = super.setMinInstancesPerNode(value) + @Since("1.4.0") override def setMinInfoGain(value: Double): this.type = super.setMinInfoGain(value) + @Since("1.4.0") override def setMaxMemoryInMB(value: Int): this.type = super.setMaxMemoryInMB(value) + @Since("1.4.0") override def setCacheNodeIds(value: Boolean): this.type = super.setCacheNodeIds(value) + @Since("1.4.0") override def setCheckpointInterval(value: Int): this.type = super.setCheckpointInterval(value) + @Since("1.4.0") override def setImpurity(value: String): this.type = super.setImpurity(value) // Parameters from TreeEnsembleParams: + @Since("1.4.0") override def setSubsamplingRate(value: Double): this.type = super.setSubsamplingRate(value) + @Since("1.4.0") override def setSeed(value: Long): this.type = super.setSeed(value) // Parameters from RandomForestParams: + @Since("1.4.0") override def setNumTrees(value: Int): this.type = super.setNumTrees(value) + @Since("1.4.0") override def setFeatureSubsetStrategy(value: String): this.type = super.setFeatureSubsetStrategy(value) @@ -99,15 +114,19 @@ final class RandomForestClassifier(override val uid: String) new RandomForestClassificationModel(trees, numFeatures, numClasses) } + @Since("1.4.1") override def copy(extra: ParamMap): RandomForestClassifier = defaultCopy(extra) } +@Since("1.4.0") @Experimental object RandomForestClassifier { /** Accessor for supported impurity settings: entropy, gini */ + @Since("1.4.0") final val supportedImpurities: Array[String] = TreeClassifierParams.supportedImpurities /** Accessor for supported featureSubsetStrategy settings: auto, all, onethird, sqrt, log2 */ + @Since("1.4.0") final val supportedFeatureSubsetStrategies: Array[String] = RandomForestParams.supportedFeatureSubsetStrategies } @@ -120,12 +139,13 @@ object RandomForestClassifier { * @param _trees Decision trees in the ensemble. * Warning: These have null parents. */ +@Since("1.4.0") @Experimental final class RandomForestClassificationModel private[ml] ( - override val uid: String, + @Since("1.5.0") override val uid: String, private val _trees: Array[DecisionTreeClassificationModel], - override val numFeatures: Int, - override val numClasses: Int) + @Since("1.6.0") override val numFeatures: Int, + @Since("1.5.0") override val numClasses: Int) extends ProbabilisticClassificationModel[Vector, RandomForestClassificationModel] with TreeEnsembleModel with Serializable { @@ -141,11 +161,13 @@ final class RandomForestClassificationModel private[ml] ( numClasses: Int) = this(Identifiable.randomUID("rfc"), trees, numFeatures, numClasses) + @Since("1.4.0") override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]] // Note: We may add support for weights (based on tree performance) later on. 
private lazy val _treeWeights: Array[Double] = Array.fill[Double](numTrees)(1.0) + @Since("1.4.0") override def treeWeights: Array[Double] = _treeWeights override protected def transformImpl(dataset: DataFrame): DataFrame = { @@ -186,11 +208,13 @@ final class RandomForestClassificationModel private[ml] ( } } + @Since("1.4.0") override def copy(extra: ParamMap): RandomForestClassificationModel = { copyValues(new RandomForestClassificationModel(uid, _trees, numFeatures, numClasses), extra) .setParent(parent) } + @Since("1.4.0") override def toString: String = { s"RandomForestClassificationModel (uid=$uid) with $numTrees trees" } From 4a39b5a1bee28cec792d509654f6236390cafdcb Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Mon, 7 Dec 2015 23:50:57 -0800 Subject: [PATCH 1690/1824] [SPARK-11958][SPARK-11957][ML][DOC] SQLTransformer user guide and example code Add ```SQLTransformer``` user guide, example code and make Scala API doc more clear. Author: Yanbo Liang Closes #10006 from yanboliang/spark-11958. --- docs/ml-features.md | 59 +++++++++++++++++++ .../ml/JavaSQLTransformerExample.java | 59 +++++++++++++++++++ .../src/main/python/ml/sql_transformer.py | 40 +++++++++++++ .../examples/ml/SQLTransformerExample.scala | 45 ++++++++++++++ .../spark/ml/feature/SQLTransformer.scala | 11 +++- 5 files changed, 212 insertions(+), 2 deletions(-) create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaSQLTransformerExample.java create mode 100644 examples/src/main/python/ml/sql_transformer.py create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/SQLTransformerExample.scala diff --git a/docs/ml-features.md b/docs/ml-features.md index 5105a948fec8..f85e0d56d2e4 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -756,6 +756,65 @@ for more details on the API. +## SQLTransformer + +`SQLTransformer` implements the transformations which are defined by SQL statement. +Currently we only support SQL syntax like `"SELECT ... FROM __THIS__ ..."` +where `"__THIS__"` represents the underlying table of the input dataset. +The select clause specifies the fields, constants, and expressions to display in +the output, it can be any select clause that Spark SQL supports. Users can also +use Spark SQL built-in function and UDFs to operate on these selected columns. +For example, `SQLTransformer` supports statements like: + +* `SELECT a, a + b AS a_b FROM __THIS__` +* `SELECT a, SQRT(b) AS b_sqrt FROM __THIS__ where a > 5` +* `SELECT a, b, SUM(c) AS c_sum FROM __THIS__ GROUP BY a, b` + +**Examples** + +Assume that we have the following DataFrame with columns `id`, `v1` and `v2`: + +~~~~ + id | v1 | v2 +----|-----|----- + 0 | 1.0 | 3.0 + 2 | 2.0 | 5.0 +~~~~ + +This is the output of the `SQLTransformer` with statement `"SELECT *, (v1 + v2) AS v3, (v1 * v2) AS v4 FROM __THIS__"`: + +~~~~ + id | v1 | v2 | v3 | v4 +----|-----|-----|-----|----- + 0 | 1.0 | 3.0 | 4.0 | 3.0 + 2 | 2.0 | 5.0 | 7.0 |10.0 +~~~~ + +
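+A minimal Scala sketch of this transformation (mirroring the complete example files referenced
+per language below; `sqlContext` is assumed to be an existing `SQLContext`):
+
+~~~~
+val df = sqlContext.createDataFrame(
+  Seq((0, 1.0, 3.0), (2, 2.0, 5.0))).toDF("id", "v1", "v2")
+val sqlTrans = new SQLTransformer().setStatement(
+  "SELECT *, (v1 + v2) AS v3, (v1 * v2) AS v4 FROM __THIS__")
+sqlTrans.transform(df).show()
+~~~~
+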
+<div class="codetabs">
+<div data-lang="scala" markdown="1">
+
+Refer to the [SQLTransformer Scala docs](api/scala/index.html#org.apache.spark.ml.feature.SQLTransformer)
+for more details on the API.
+
+{% include_example scala/org/apache/spark/examples/ml/SQLTransformerExample.scala %}
+</div>
+
+<div data-lang="java" markdown="1">
+
+Refer to the [SQLTransformer Java docs](api/java/org/apache/spark/ml/feature/SQLTransformer.html)
+for more details on the API.
+
+{% include_example java/org/apache/spark/examples/ml/JavaSQLTransformerExample.java %}
+</div>
+
+<div data-lang="python" markdown="1">
+
+Refer to the [SQLTransformer Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.SQLTransformer) for more details on the API.
+
+{% include_example python/ml/sql_transformer.py %}
+</div>
    + ## VectorAssembler `VectorAssembler` is a transformer that combines a given list of columns into a single vector diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaSQLTransformerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaSQLTransformerExample.java new file mode 100644 index 000000000000..d55c70796a96 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaSQLTransformerExample.java @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +// $example on$ +import java.util.Arrays; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.ml.feature.SQLTransformer; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.types.*; +// $example off$ + +public class JavaSQLTransformerExample { + public static void main(String[] args) { + + SparkConf conf = new SparkConf().setAppName("JavaSQLTransformerExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext sqlContext = new SQLContext(jsc); + + // $example on$ + JavaRDD jrdd = jsc.parallelize(Arrays.asList( + RowFactory.create(0, 1.0, 3.0), + RowFactory.create(2, 2.0, 5.0) + )); + StructType schema = new StructType(new StructField [] { + new StructField("id", DataTypes.IntegerType, false, Metadata.empty()), + new StructField("v1", DataTypes.DoubleType, false, Metadata.empty()), + new StructField("v2", DataTypes.DoubleType, false, Metadata.empty()) + }); + DataFrame df = sqlContext.createDataFrame(jrdd, schema); + + SQLTransformer sqlTrans = new SQLTransformer().setStatement( + "SELECT *, (v1 + v2) AS v3, (v1 * v2) AS v4 FROM __THIS__"); + + sqlTrans.transform(df).show(); + // $example off$ + } +} diff --git a/examples/src/main/python/ml/sql_transformer.py b/examples/src/main/python/ml/sql_transformer.py new file mode 100644 index 000000000000..9575d728d815 --- /dev/null +++ b/examples/src/main/python/ml/sql_transformer.py @@ -0,0 +1,40 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +# $example on$ +from pyspark.ml.feature import SQLTransformer +# $example off$ +from pyspark.sql import SQLContext + +if __name__ == "__main__": + sc = SparkContext(appName="SQLTransformerExample") + sqlContext = SQLContext(sc) + + # $example on$ + df = sqlContext.createDataFrame([ + (0, 1.0, 3.0), + (2, 2.0, 5.0) + ], ["id", "v1", "v2"]) + sqlTrans = SQLTransformer( + statement="SELECT *, (v1 + v2) AS v3, (v1 * v2) AS v4 FROM __THIS__") + sqlTrans.transform(df).show() + # $example off$ + + sc.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/SQLTransformerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/SQLTransformerExample.scala new file mode 100644 index 000000000000..014abd1fdbc6 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/SQLTransformerExample.scala @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.SQLTransformer +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + + +object SQLTransformerExample { + def main(args: Array[String]) { + val conf = new SparkConf().setAppName("SQLTransformerExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val df = sqlContext.createDataFrame( + Seq((0, 1.0, 3.0), (2, 2.0, 5.0))).toDF("id", "v1", "v2") + + val sqlTrans = new SQLTransformer().setStatement( + "SELECT *, (v1 + v2) AS v3, (v1 * v2) AS v4 FROM __THIS__") + + sqlTrans.transform(df).show() + // $example off$ + } +} +// scalastyle:on println diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala index 3a735017ba83..c09f4d076c96 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala @@ -27,9 +27,16 @@ import org.apache.spark.sql.types.StructType /** * :: Experimental :: - * Implements the transforms which are defined by SQL statement. - * Currently we only support SQL syntax like 'SELECT ... 
FROM __THIS__' + * Implements the transformations which are defined by SQL statement. + * Currently we only support SQL syntax like 'SELECT ... FROM __THIS__ ...' * where '__THIS__' represents the underlying table of the input dataset. + * The select clause specifies the fields, constants, and expressions to display in + * the output, it can be any select clause that Spark SQL supports. Users can also + * use Spark SQL built-in function and UDFs to operate on these selected columns. + * For example, [[SQLTransformer]] supports statements like: + * - SELECT a, a + b AS a_b FROM __THIS__ + * - SELECT a, SQRT(b) AS b_sqrt FROM __THIS__ where a > 5 + * - SELECT a, b, SUM(c) AS c_sum FROM __THIS__ GROUP BY a, b */ @Experimental @Since("1.6.0") From 48a9804b2ad89b3fb204c79f0dbadbcfea15d8dc Mon Sep 17 00:00:00 2001 From: cody koeninger Date: Tue, 8 Dec 2015 11:02:35 +0000 Subject: [PATCH 1691/1824] =?UTF-8?q?[SPARK-12103][STREAMING][KAFKA][DOC]?= =?UTF-8?q?=20document=20that=20K=20means=20Key=20and=20V=20=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …means Value Author: cody koeninger Closes #10132 from koeninger/SPARK-12103. --- .../spark/streaming/kafka/KafkaUtils.scala | 61 +++++++++++++++++++ 1 file changed, 61 insertions(+) diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala index ad2fb8aa5f24..fe572220528d 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala @@ -51,6 +51,7 @@ object KafkaUtils { * in its own thread * @param storageLevel Storage level to use for storing the received objects * (default: StorageLevel.MEMORY_AND_DISK_SER_2) + * @return DStream of (Kafka message key, Kafka message value) */ def createStream( ssc: StreamingContext, @@ -74,6 +75,11 @@ object KafkaUtils { * @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed * in its own thread. * @param storageLevel Storage level to use for storing the received objects + * @tparam K type of Kafka message key + * @tparam V type of Kafka message value + * @tparam U type of Kafka message key decoder + * @tparam T type of Kafka message value decoder + * @return DStream of (Kafka message key, Kafka message value) */ def createStream[K: ClassTag, V: ClassTag, U <: Decoder[_]: ClassTag, T <: Decoder[_]: ClassTag]( ssc: StreamingContext, @@ -93,6 +99,7 @@ object KafkaUtils { * @param groupId The group id for this consumer * @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed * in its own thread + * @return DStream of (Kafka message key, Kafka message value) */ def createStream( jssc: JavaStreamingContext, @@ -111,6 +118,7 @@ object KafkaUtils { * @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed * in its own thread. * @param storageLevel RDD storage level. + * @return DStream of (Kafka message key, Kafka message value) */ def createStream( jssc: JavaStreamingContext, @@ -135,6 +143,11 @@ object KafkaUtils { * @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed * in its own thread * @param storageLevel RDD storage level. 
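+  * For example, an illustrative sketch of the simplest overload (the host, group and topic
+  * names below are placeholders, not values from this patch):
+  * {{{
+  *   val stream = KafkaUtils.createStream(ssc, "zkHost:2181", "my-group", Map("myTopic" -> 1))
+  *   stream.map { case (key, value) => value }
+  * }}}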
+ * @tparam K type of Kafka message key + * @tparam V type of Kafka message value + * @tparam U type of Kafka message key decoder + * @tparam T type of Kafka message value decoder + * @return DStream of (Kafka message key, Kafka message value) */ def createStream[K, V, U <: Decoder[_], T <: Decoder[_]]( jssc: JavaStreamingContext, @@ -219,6 +232,11 @@ object KafkaUtils { * host1:port1,host2:port2 form. * @param offsetRanges Each OffsetRange in the batch corresponds to a * range of offsets for a given Kafka topic/partition + * @tparam K type of Kafka message key + * @tparam V type of Kafka message value + * @tparam KD type of Kafka message key decoder + * @tparam VD type of Kafka message value decoder + * @return RDD of (Kafka message key, Kafka message value) */ def createRDD[ K: ClassTag, @@ -251,6 +269,12 @@ object KafkaUtils { * @param leaders Kafka brokers for each TopicAndPartition in offsetRanges. May be an empty map, * in which case leaders will be looked up on the driver. * @param messageHandler Function for translating each message and metadata into the desired type + * @tparam K type of Kafka message key + * @tparam V type of Kafka message value + * @tparam KD type of Kafka message key decoder + * @tparam VD type of Kafka message value decoder + * @tparam R type returned by messageHandler + * @return RDD of R */ def createRDD[ K: ClassTag, @@ -288,6 +312,15 @@ object KafkaUtils { * host1:port1,host2:port2 form. * @param offsetRanges Each OffsetRange in the batch corresponds to a * range of offsets for a given Kafka topic/partition + * @param keyClass type of Kafka message key + * @param valueClass type of Kafka message value + * @param keyDecoderClass type of Kafka message key decoder + * @param valueDecoderClass type of Kafka message value decoder + * @tparam K type of Kafka message key + * @tparam V type of Kafka message value + * @tparam KD type of Kafka message key decoder + * @tparam VD type of Kafka message value decoder + * @return RDD of (Kafka message key, Kafka message value) */ def createRDD[K, V, KD <: Decoder[K], VD <: Decoder[V]]( jsc: JavaSparkContext, @@ -321,6 +354,12 @@ object KafkaUtils { * @param leaders Kafka brokers for each TopicAndPartition in offsetRanges. May be an empty map, * in which case leaders will be looked up on the driver. 
* @param messageHandler Function for translating each message and metadata into the desired type + * @tparam K type of Kafka message key + * @tparam V type of Kafka message value + * @tparam KD type of Kafka message key decoder + * @tparam VD type of Kafka message value decoder + * @tparam R type returned by messageHandler + * @return RDD of R */ def createRDD[K, V, KD <: Decoder[K], VD <: Decoder[V], R]( jsc: JavaSparkContext, @@ -373,6 +412,12 @@ object KafkaUtils { * @param fromOffsets Per-topic/partition Kafka offsets defining the (inclusive) * starting point of the stream * @param messageHandler Function for translating each message and metadata into the desired type + * @tparam K type of Kafka message key + * @tparam V type of Kafka message value + * @tparam KD type of Kafka message key decoder + * @tparam VD type of Kafka message value decoder + * @tparam R type returned by messageHandler + * @return DStream of R */ def createDirectStream[ K: ClassTag, @@ -419,6 +464,11 @@ object KafkaUtils { * If not starting from a checkpoint, "auto.offset.reset" may be set to "largest" or "smallest" * to determine where the stream starts (defaults to "largest") * @param topics Names of the topics to consume + * @tparam K type of Kafka message key + * @tparam V type of Kafka message value + * @tparam KD type of Kafka message key decoder + * @tparam VD type of Kafka message value decoder + * @return DStream of (Kafka message key, Kafka message value) */ def createDirectStream[ K: ClassTag, @@ -470,6 +520,12 @@ object KafkaUtils { * @param fromOffsets Per-topic/partition Kafka offsets defining the (inclusive) * starting point of the stream * @param messageHandler Function for translating each message and metadata into the desired type + * @tparam K type of Kafka message key + * @tparam V type of Kafka message value + * @tparam KD type of Kafka message key decoder + * @tparam VD type of Kafka message value decoder + * @tparam R type returned by messageHandler + * @return DStream of R */ def createDirectStream[K, V, KD <: Decoder[K], VD <: Decoder[V], R]( jssc: JavaStreamingContext, @@ -529,6 +585,11 @@ object KafkaUtils { * If not starting from a checkpoint, "auto.offset.reset" may be set to "largest" or "smallest" * to determine where the stream starts (defaults to "largest") * @param topics Names of the topics to consume + * @tparam K type of Kafka message key + * @tparam V type of Kafka message value + * @tparam KD type of Kafka message key decoder + * @tparam VD type of Kafka message value decoder + * @return DStream of (Kafka message key, Kafka message value) */ def createDirectStream[K, V, KD <: Decoder[K], VD <: Decoder[V]]( jssc: JavaStreamingContext, From 708129187a460aca30790281e9221c0cd5e271df Mon Sep 17 00:00:00 2001 From: Jeff Zhang Date: Tue, 8 Dec 2015 11:05:06 +0000 Subject: [PATCH 1692/1824] [SPARK-12166][TEST] Unset hadoop related environment in testing Author: Jeff Zhang Closes #10172 from zjffdu/SPARK-12166. --- bin/spark-class | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/bin/spark-class b/bin/spark-class index 87d06693af4f..5d964ba96abd 100755 --- a/bin/spark-class +++ b/bin/spark-class @@ -71,6 +71,12 @@ fi export _SPARK_ASSEMBLY="$SPARK_ASSEMBLY_JAR" +# For tests +if [[ -n "$SPARK_TESTING" ]]; then + unset YARN_CONF_DIR + unset HADOOP_CONF_DIR +fi + # The launcher library will print arguments separated by a NULL character, to allow arguments with # characters that would be otherwise interpreted by the shell. 
Read that in a while loop, populating # an array that will be used to exec the final command. From 037b7e76a7f8b59e031873a768d81417dd180472 Mon Sep 17 00:00:00 2001 From: Nakul Jindal Date: Tue, 8 Dec 2015 11:08:27 +0000 Subject: [PATCH 1693/1824] [SPARK-11439][ML] Optimization of creating sparse feature without dense one Sparse feature generated in LinearDataGenerator does not create dense vectors as an intermediate any more. Author: Nakul Jindal Closes #9756 from nakul02/SPARK-11439_sparse_without_creating_dense_feature. --- .../mllib/util/LinearDataGenerator.scala | 44 ++-- .../evaluation/RegressionEvaluatorSuite.scala | 6 +- .../ml/regression/LinearRegressionSuite.scala | 214 ++++++++++-------- 3 files changed, 142 insertions(+), 122 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala index 6ff07eed6cfd..094528e2ece0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala @@ -24,7 +24,7 @@ import com.github.fommil.netlib.BLAS.{getInstance => blas} import org.apache.spark.SparkContext import org.apache.spark.annotation.{DeveloperApi, Since} -import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.linalg.{BLAS, Vectors} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.rdd.RDD @@ -131,35 +131,27 @@ object LinearDataGenerator { eps: Double, sparsity: Double): Seq[LabeledPoint] = { require(0.0 <= sparsity && sparsity <= 1.0) - val rnd = new Random(seed) - val x = Array.fill[Array[Double]](nPoints)( - Array.fill[Double](weights.length)(rnd.nextDouble())) - - val sparseRnd = new Random(seed) - x.foreach { v => - var i = 0 - val len = v.length - while (i < len) { - if (sparseRnd.nextDouble() < sparsity) { - v(i) = 0.0 - } else { - v(i) = (v(i) - 0.5) * math.sqrt(12.0 * xVariance(i)) + xMean(i) - } - i += 1 - } - } - val y = x.map { xi => - blas.ddot(weights.length, xi, 1, weights, 1) + intercept + eps * rnd.nextGaussian() - } + val rnd = new Random(seed) + def rndElement(i: Int) = {(rnd.nextDouble() - 0.5) * math.sqrt(12.0 * xVariance(i)) + xMean(i)} - y.zip(x).map { p => - if (sparsity == 0.0) { + if (sparsity == 0.0) { + (0 until nPoints).map { _ => + val features = Vectors.dense(weights.indices.map { rndElement(_) }.toArray) + val label = BLAS.dot(Vectors.dense(weights), features) + + intercept + eps * rnd.nextGaussian() // Return LabeledPoints with DenseVector - LabeledPoint(p._1, Vectors.dense(p._2)) - } else { + LabeledPoint(label, features) + } + } else { + (0 until nPoints).map { _ => + val indices = weights.indices.filter { _ => rnd.nextDouble() <= sparsity} + val values = indices.map { rndElement(_) } + val features = Vectors.sparse(weights.length, indices.toArray, values.toArray) + val label = BLAS.dot(Vectors.dense(weights), features) + + intercept + eps * rnd.nextGaussian() // Return LabeledPoints with SparseVector - LabeledPoint(p._1, Vectors.dense(p._2).toSparse) + LabeledPoint(label, features) } } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala index 60886bf77d2f..954d3bedc14b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala @@ 
-65,15 +65,15 @@ class RegressionEvaluatorSuite // default = rmse val evaluator = new RegressionEvaluator() - assert(evaluator.evaluate(predictions) ~== 0.1019382 absTol 0.001) + assert(evaluator.evaluate(predictions) ~== 0.1013829 absTol 0.01) // r2 score evaluator.setMetricName("r2") - assert(evaluator.evaluate(predictions) ~== 0.9998196 absTol 0.001) + assert(evaluator.evaluate(predictions) ~== 0.9998387 absTol 0.01) // mae evaluator.setMetricName("mae") - assert(evaluator.evaluate(predictions) ~== 0.08036075 absTol 0.001) + assert(evaluator.evaluate(predictions) ~== 0.08399089 absTol 0.01) } test("read/write") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index 2bdc0e184d73..2f3e703f4c25 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -42,6 +42,7 @@ class LinearRegressionSuite In `LinearRegressionSuite`, we will make sure that the model trained by SparkML is the same as the one trained by R's glmnet package. The following instruction describes how to reproduce the data in R. + In a spark-shell, use the following code: import org.apache.spark.mllib.util.LinearDataGenerator val data = @@ -184,15 +185,15 @@ class LinearRegressionSuite 3 x 1 sparse Matrix of class "dgCMatrix" s0 (Intercept) . - as.numeric.data.V2. 6.995908 - as.numeric.data.V3. 5.275131 + as.numeric.data.V2. 6.973403 + as.numeric.data.V3. 5.284370 */ - val coefficientsR = Vectors.dense(6.995908, 5.275131) + val coefficientsR = Vectors.dense(6.973403, 5.284370) - assert(model1.intercept ~== 0 absTol 1E-3) - assert(model1.coefficients ~= coefficientsR relTol 1E-3) - assert(model2.intercept ~== 0 absTol 1E-3) - assert(model2.coefficients ~= coefficientsR relTol 1E-3) + assert(model1.intercept ~== 0 absTol 1E-2) + assert(model1.coefficients ~= coefficientsR relTol 1E-2) + assert(model2.intercept ~== 0 absTol 1E-2) + assert(model2.coefficients ~= coefficientsR relTol 1E-2) /* Then again with the data with no intercept: @@ -235,14 +236,14 @@ class LinearRegressionSuite > coefficients 3 x 1 sparse Matrix of class "dgCMatrix" s0 - (Intercept) 6.24300 - as.numeric.data.V2. 4.024821 - as.numeric.data.V3. 6.679841 + (Intercept) 6.242284 + as.numeric.d1.V2. 4.019605 + as.numeric.d1.V3. 6.679538 */ - val interceptR1 = 6.24300 - val coefficientsR1 = Vectors.dense(4.024821, 6.679841) - assert(model1.intercept ~== interceptR1 relTol 1E-3) - assert(model1.coefficients ~= coefficientsR1 relTol 1E-3) + val interceptR1 = 6.242284 + val coefficientsR1 = Vectors.dense(4.019605, 6.679538) + assert(model1.intercept ~== interceptR1 relTol 1E-2) + assert(model1.coefficients ~= coefficientsR1 relTol 1E-2) /* coefficients <- coef(glmnet(features, label, family="gaussian", alpha = 1.0, @@ -296,14 +297,14 @@ class LinearRegressionSuite 3 x 1 sparse Matrix of class "dgCMatrix" s0 (Intercept) . - as.numeric.data.V2. 6.299752 - as.numeric.data.V3. 4.772913 + as.numeric.data.V2. 6.272927 + as.numeric.data.V3. 
4.782604 */ val interceptR1 = 0.0 - val coefficientsR1 = Vectors.dense(6.299752, 4.772913) + val coefficientsR1 = Vectors.dense(6.272927, 4.782604) - assert(model1.intercept ~== interceptR1 absTol 1E-3) - assert(model1.coefficients ~= coefficientsR1 relTol 1E-3) + assert(model1.intercept ~== interceptR1 absTol 1E-2) + assert(model1.coefficients ~= coefficientsR1 relTol 1E-2) /* coefficients <- coef(glmnet(features, label, family="gaussian", alpha = 1.0, @@ -312,14 +313,14 @@ class LinearRegressionSuite 3 x 1 sparse Matrix of class "dgCMatrix" s0 (Intercept) . - as.numeric.data.V2. 6.232193 - as.numeric.data.V3. 4.764229 + as.numeric.data.V2. 6.207817 + as.numeric.data.V3. 4.775780 */ val interceptR2 = 0.0 - val coefficientsR2 = Vectors.dense(6.232193, 4.764229) + val coefficientsR2 = Vectors.dense(6.207817, 4.775780) - assert(model2.intercept ~== interceptR2 absTol 1E-3) - assert(model2.coefficients ~= coefficientsR2 relTol 1E-3) + assert(model2.intercept ~== interceptR2 absTol 1E-2) + assert(model2.coefficients ~= coefficientsR2 relTol 1E-2) model1.transform(datasetWithDenseFeature).select("features", "prediction") .collect().foreach { @@ -347,15 +348,15 @@ class LinearRegressionSuite > coefficients 3 x 1 sparse Matrix of class "dgCMatrix" s0 - (Intercept) 5.269376 - as.numeric.data.V2. 3.736216 - as.numeric.data.V3. 5.712356) + (Intercept) 5.260103 + as.numeric.d1.V2. 3.725522 + as.numeric.d1.V3. 5.711203 */ - val interceptR1 = 5.269376 - val coefficientsR1 = Vectors.dense(3.736216, 5.712356) + val interceptR1 = 5.260103 + val coefficientsR1 = Vectors.dense(3.725522, 5.711203) - assert(model1.intercept ~== interceptR1 relTol 1E-3) - assert(model1.coefficients ~= coefficientsR1 relTol 1E-3) + assert(model1.intercept ~== interceptR1 relTol 1E-2) + assert(model1.coefficients ~= coefficientsR1 relTol 1E-2) /* coefficients <- coef(glmnet(features, label, family="gaussian", alpha = 0.0, lambda = 2.3, @@ -363,15 +364,15 @@ class LinearRegressionSuite > coefficients 3 x 1 sparse Matrix of class "dgCMatrix" s0 - (Intercept) 5.791109 - as.numeric.data.V2. 3.435466 - as.numeric.data.V3. 5.910406 + (Intercept) 5.790885 + as.numeric.d1.V2. 3.432373 + as.numeric.d1.V3. 5.919196 */ - val interceptR2 = 5.791109 - val coefficientsR2 = Vectors.dense(3.435466, 5.910406) + val interceptR2 = 5.790885 + val coefficientsR2 = Vectors.dense(3.432373, 5.919196) - assert(model2.intercept ~== interceptR2 relTol 1E-3) - assert(model2.coefficients ~= coefficientsR2 relTol 1E-3) + assert(model2.intercept ~== interceptR2 relTol 1E-2) + assert(model2.coefficients ~= coefficientsR2 relTol 1E-2) model1.transform(datasetWithDenseFeature).select("features", "prediction").collect().foreach { case Row(features: DenseVector, prediction1: Double) => @@ -398,15 +399,15 @@ class LinearRegressionSuite > coefficients 3 x 1 sparse Matrix of class "dgCMatrix" s0 - (Intercept) . - as.numeric.data.V2. 5.522875 - as.numeric.data.V3. 4.214502 + (Intercept) . + as.numeric.d1.V2. 5.493430 + as.numeric.d1.V3. 
4.223082 */ val interceptR1 = 0.0 - val coefficientsR1 = Vectors.dense(5.522875, 4.214502) + val coefficientsR1 = Vectors.dense(5.493430, 4.223082) - assert(model1.intercept ~== interceptR1 absTol 1E-3) - assert(model1.coefficients ~= coefficientsR1 relTol 1E-3) + assert(model1.intercept ~== interceptR1 absTol 1E-2) + assert(model1.coefficients ~= coefficientsR1 relTol 1E-2) /* coefficients <- coef(glmnet(features, label, family="gaussian", alpha = 0.0, lambda = 2.3, @@ -415,14 +416,14 @@ class LinearRegressionSuite 3 x 1 sparse Matrix of class "dgCMatrix" s0 (Intercept) . - as.numeric.data.V2. 5.263704 - as.numeric.data.V3. 4.187419 + as.numeric.d1.V2. 5.244324 + as.numeric.d1.V3. 4.203106 */ val interceptR2 = 0.0 - val coefficientsR2 = Vectors.dense(5.263704, 4.187419) + val coefficientsR2 = Vectors.dense(5.244324, 4.203106) - assert(model2.intercept ~== interceptR2 absTol 1E-3) - assert(model2.coefficients ~= coefficientsR2 relTol 1E-3) + assert(model2.intercept ~== interceptR2 absTol 1E-2) + assert(model2.coefficients ~= coefficientsR2 relTol 1E-2) model1.transform(datasetWithDenseFeature).select("features", "prediction").collect().foreach { case Row(features: DenseVector, prediction1: Double) => @@ -457,15 +458,15 @@ class LinearRegressionSuite > coefficients 3 x 1 sparse Matrix of class "dgCMatrix" s0 - (Intercept) 6.324108 - as.numeric.data.V2. 3.168435 - as.numeric.data.V3. 5.200403 + (Intercept) 5.689855 + as.numeric.d1.V2. 3.661181 + as.numeric.d1.V3. 6.000274 */ - val interceptR1 = 5.696056 - val coefficientsR1 = Vectors.dense(3.670489, 6.001122) + val interceptR1 = 5.689855 + val coefficientsR1 = Vectors.dense(3.661181, 6.000274) - assert(model1.intercept ~== interceptR1 relTol 1E-3) - assert(model1.coefficients ~= coefficientsR1 relTol 1E-3) + assert(model1.intercept ~== interceptR1 relTol 1E-2) + assert(model1.coefficients ~= coefficientsR1 relTol 1E-2) /* coefficients <- coef(glmnet(features, label, family="gaussian", alpha = 0.3, lambda = 1.6 @@ -473,15 +474,15 @@ class LinearRegressionSuite > coefficients 3 x 1 sparse Matrix of class "dgCMatrix" s0 - (Intercept) 6.114723 - as.numeric.data.V2. 3.409937 - as.numeric.data.V3. 6.146531 + (Intercept) 6.113890 + as.numeric.d1.V2. 3.407021 + as.numeric.d1.V3. 6.152512 */ - val interceptR2 = 6.114723 - val coefficientsR2 = Vectors.dense(3.409937, 6.146531) + val interceptR2 = 6.113890 + val coefficientsR2 = Vectors.dense(3.407021, 6.152512) - assert(model2.intercept ~== interceptR2 relTol 1E-3) - assert(model2.coefficients ~= coefficientsR2 relTol 1E-3) + assert(model2.intercept ~== interceptR2 relTol 1E-2) + assert(model2.coefficients ~= coefficientsR2 relTol 1E-2) model1.transform(datasetWithDenseFeature).select("features", "prediction") .collect().foreach { @@ -518,15 +519,15 @@ class LinearRegressionSuite > coefficients 3 x 1 sparse Matrix of class "dgCMatrix" s0 - (Intercept) . - as.numeric.dataM.V2. 5.673348 - as.numeric.dataM.V3. 4.322251 + (Intercept) . + as.numeric.d1.V2. 5.643748 + as.numeric.d1.V3. 
4.331519 */ val interceptR1 = 0.0 - val coefficientsR1 = Vectors.dense(5.673348, 4.322251) + val coefficientsR1 = Vectors.dense(5.643748, 4.331519) - assert(model1.intercept ~== interceptR1 absTol 1E-3) - assert(model1.coefficients ~= coefficientsR1 relTol 1E-3) + assert(model1.intercept ~== interceptR1 absTol 1E-2) + assert(model1.coefficients ~= coefficientsR1 relTol 1E-2) /* coefficients <- coef(glmnet(features, label, family="gaussian", alpha = 0.3, @@ -535,14 +536,15 @@ class LinearRegressionSuite 3 x 1 sparse Matrix of class "dgCMatrix" s0 (Intercept) . - as.numeric.data.V2. 5.477988 - as.numeric.data.V3. 4.297622 + as.numeric.d1.V2. 5.455902 + as.numeric.d1.V3. 4.312266 + */ val interceptR2 = 0.0 - val coefficientsR2 = Vectors.dense(5.477988, 4.297622) + val coefficientsR2 = Vectors.dense(5.455902, 4.312266) - assert(model2.intercept ~== interceptR2 absTol 1E-3) - assert(model2.coefficients ~= coefficientsR2 relTol 1E-3) + assert(model2.intercept ~== interceptR2 absTol 1E-2) + assert(model2.coefficients ~= coefficientsR2 relTol 1E-2) model1.transform(datasetWithDenseFeature).select("features", "prediction") .collect().foreach { @@ -592,21 +594,47 @@ class LinearRegressionSuite } /* - Use the following R code to generate model training results. - - predictions <- predict(fit, newx=features) - residuals <- label - predictions - > mean(residuals^2) # MSE - [1] 0.009720325 - > mean(abs(residuals)) # MAD - [1] 0.07863206 - > cor(predictions, label)^2# r^2 - [,1] - s0 0.9998749 + # Use the following R code to generate model training results. + + # path/part-00000 is the file generated by running LinearDataGenerator.generateLinearInput + # as described before the beforeAll() method. + d1 <- read.csv("path/part-00000", header=FALSE, stringsAsFactors=FALSE) + fit <- glm(V1 ~ V2 + V3, data = d1, family = "gaussian") + names(f1)[1] = c("V2") + names(f1)[2] = c("V3") + f1 <- data.frame(as.numeric(d1$V2), as.numeric(d1$V3)) + predictions <- predict(fit, newdata=f1) + l1 <- as.numeric(d1$V1) + + residuals <- l1 - predictions + > mean(residuals^2) # MSE + [1] 0.00985449 + > mean(abs(residuals)) # MAD + [1] 0.07961668 + > cor(predictions, l1)^2 # r^2 + [1] 0.9998737 + + > summary(fit) + + Call: + glm(formula = V1 ~ V2 + V3, family = "gaussian", data = d1) + + Deviance Residuals: + Min 1Q Median 3Q Max + -0.47082 -0.06797 0.00002 0.06725 0.34635 + + Coefficients: + Estimate Std. Error t value Pr(>|t|) + (Intercept) 6.3022157 0.0018600 3388 <2e-16 *** + V2 4.6982442 0.0011805 3980 <2e-16 *** + V3 7.1994344 0.0009044 7961 <2e-16 *** + --- + + .... */ - assert(model.summary.meanSquaredError ~== 0.00972035 relTol 1E-5) - assert(model.summary.meanAbsoluteError ~== 0.07863206 relTol 1E-5) - assert(model.summary.r2 ~== 0.9998749 relTol 1E-5) + assert(model.summary.meanSquaredError ~== 0.00985449 relTol 1E-4) + assert(model.summary.meanAbsoluteError ~== 0.07961668 relTol 1E-4) + assert(model.summary.r2 ~== 0.9998737 relTol 1E-4) // Normal solver uses "WeightedLeastSquares". This algorithm does not generate // objective history because it does not run through iterations. @@ -621,14 +649,14 @@ class LinearRegressionSuite // To clalify that the normal solver is used here. 
assert(model.summary.objectiveHistory.length == 1) assert(model.summary.objectiveHistory(0) == 0.0) - val devianceResidualsR = Array(-0.35566, 0.34504) - val seCoefR = Array(0.0011756, 0.0009032, 0.0018489) - val tValsR = Array(3998, 7971, 3407) + val devianceResidualsR = Array(-0.47082, 0.34635) + val seCoefR = Array(0.0011805, 0.0009044, 0.0018600) + val tValsR = Array(3980, 7961, 3388) val pValsR = Array(0, 0, 0) model.summary.devianceResiduals.zip(devianceResidualsR).foreach { x => - assert(x._1 ~== x._2 absTol 1E-5) } + assert(x._1 ~== x._2 absTol 1E-4) } model.summary.coefficientStandardErrors.zip(seCoefR).foreach{ x => - assert(x._1 ~== x._2 absTol 1E-5) } + assert(x._1 ~== x._2 absTol 1E-4) } model.summary.tValues.map(_.round).zip(tValsR).foreach{ x => assert(x._1 === x._2) } model.summary.pValues.map(_.round).zip(pValsR).foreach{ x => assert(x._1 === x._2) } } From da2012a0e152aa078bdd19a5c7f91786a2dd7016 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Tue, 8 Dec 2015 19:18:59 +0800 Subject: [PATCH 1694/1824] [SPARK-11551][DOC][EXAMPLE] Revert PR #10002 This reverts PR #10002, commit 78209b0ccaf3f22b5e2345dfb2b98edfdb746819. The original PR wasn't tested on Jenkins before being merged. Author: Cheng Lian Closes #10200 from liancheng/revert-pr-10002. --- docs/ml-features.md | 1109 ++++++++++++++++- .../examples/ml/JavaBinarizerExample.java | 68 - .../examples/ml/JavaBucketizerExample.java | 70 -- .../spark/examples/ml/JavaDCTExample.java | 65 - .../ml/JavaElementwiseProductExample.java | 75 -- .../examples/ml/JavaMinMaxScalerExample.java | 50 - .../spark/examples/ml/JavaNGramExample.java | 71 -- .../examples/ml/JavaNormalizerExample.java | 52 - .../examples/ml/JavaOneHotEncoderExample.java | 77 -- .../spark/examples/ml/JavaPCAExample.java | 71 -- .../ml/JavaPolynomialExpansionExample.java | 71 -- .../examples/ml/JavaRFormulaExample.java | 69 - .../ml/JavaStandardScalerExample.java | 53 - .../ml/JavaStopWordsRemoverExample.java | 65 - .../examples/ml/JavaStringIndexerExample.java | 66 - .../examples/ml/JavaTokenizerExample.java | 75 -- .../ml/JavaVectorAssemblerExample.java | 67 - .../examples/ml/JavaVectorIndexerExample.java | 60 - .../examples/ml/JavaVectorSlicerExample.java | 73 -- .../src/main/python/ml/binarizer_example.py | 43 - .../src/main/python/ml/bucketizer_example.py | 42 - .../python/ml/elementwise_product_example.py | 39 - examples/src/main/python/ml/n_gram_example.py | 42 - .../src/main/python/ml/normalizer_example.py | 41 - .../main/python/ml/onehot_encoder_example.py | 47 - examples/src/main/python/ml/pca_example.py | 42 - .../python/ml/polynomial_expansion_example.py | 43 - .../src/main/python/ml/rformula_example.py | 44 - .../main/python/ml/standard_scaler_example.py | 42 - .../python/ml/stopwords_remover_example.py | 40 - .../main/python/ml/string_indexer_example.py | 39 - .../src/main/python/ml/tokenizer_example.py | 44 - .../python/ml/vector_assembler_example.py | 42 - .../main/python/ml/vector_indexer_example.py | 39 - .../spark/examples/ml/BinarizerExample.scala | 48 - .../spark/examples/ml/BucketizerExample.scala | 51 - .../apache/spark/examples/ml/DCTExample.scala | 54 - .../ml/ElementWiseProductExample.scala | 53 - .../examples/ml/MinMaxScalerExample.scala | 49 - .../spark/examples/ml/NGramExample.scala | 47 - .../spark/examples/ml/NormalizerExample.scala | 50 - .../examples/ml/OneHotEncoderExample.scala | 58 - .../apache/spark/examples/ml/PCAExample.scala | 54 - .../ml/PolynomialExpansionExample.scala | 53 - .../spark/examples/ml/RFormulaExample.scala | 
49 - .../examples/ml/StandardScalerExample.scala | 51 - .../examples/ml/StopWordsRemoverExample.scala | 48 - .../examples/ml/StringIndexerExample.scala | 49 - .../spark/examples/ml/TokenizerExample.scala | 54 - .../examples/ml/VectorAssemblerExample.scala | 49 - .../examples/ml/VectorIndexerExample.scala | 53 - .../examples/ml/VectorSlicerExample.scala | 58 - 52 files changed, 1058 insertions(+), 2806 deletions(-) delete mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaBinarizerExample.java delete mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaBucketizerExample.java delete mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaDCTExample.java delete mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaElementwiseProductExample.java delete mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaMinMaxScalerExample.java delete mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaNGramExample.java delete mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaNormalizerExample.java delete mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaOneHotEncoderExample.java delete mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaPCAExample.java delete mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaPolynomialExpansionExample.java delete mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaRFormulaExample.java delete mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaStandardScalerExample.java delete mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaStopWordsRemoverExample.java delete mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaStringIndexerExample.java delete mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaTokenizerExample.java delete mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaVectorAssemblerExample.java delete mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaVectorIndexerExample.java delete mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaVectorSlicerExample.java delete mode 100644 examples/src/main/python/ml/binarizer_example.py delete mode 100644 examples/src/main/python/ml/bucketizer_example.py delete mode 100644 examples/src/main/python/ml/elementwise_product_example.py delete mode 100644 examples/src/main/python/ml/n_gram_example.py delete mode 100644 examples/src/main/python/ml/normalizer_example.py delete mode 100644 examples/src/main/python/ml/onehot_encoder_example.py delete mode 100644 examples/src/main/python/ml/pca_example.py delete mode 100644 examples/src/main/python/ml/polynomial_expansion_example.py delete mode 100644 examples/src/main/python/ml/rformula_example.py delete mode 100644 examples/src/main/python/ml/standard_scaler_example.py delete mode 100644 examples/src/main/python/ml/stopwords_remover_example.py delete mode 100644 examples/src/main/python/ml/string_indexer_example.py delete mode 100644 examples/src/main/python/ml/tokenizer_example.py delete mode 100644 examples/src/main/python/ml/vector_assembler_example.py delete mode 100644 examples/src/main/python/ml/vector_indexer_example.py delete mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/BinarizerExample.scala delete mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/BucketizerExample.scala delete mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/DCTExample.scala 
delete mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/ElementWiseProductExample.scala delete mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/MinMaxScalerExample.scala delete mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/NGramExample.scala delete mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/NormalizerExample.scala delete mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/OneHotEncoderExample.scala delete mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/PCAExample.scala delete mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/PolynomialExpansionExample.scala delete mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/RFormulaExample.scala delete mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/StandardScalerExample.scala delete mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/StopWordsRemoverExample.scala delete mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/StringIndexerExample.scala delete mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/TokenizerExample.scala delete mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/VectorAssemblerExample.scala delete mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/VectorIndexerExample.scala delete mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/VectorSlicerExample.scala diff --git a/docs/ml-features.md b/docs/ml-features.md index f85e0d56d2e4..01d6abeb5ba6 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -170,7 +170,25 @@ Refer to the [Tokenizer Scala docs](api/scala/index.html#org.apache.spark.ml.fea and the [RegexTokenizer Scala docs](api/scala/index.html#org.apache.spark.ml.feature.Tokenizer) for more details on the API. -{% include_example scala/org/apache/spark/examples/ml/TokenizerExample.scala %} +{% highlight scala %} +import org.apache.spark.ml.feature.{Tokenizer, RegexTokenizer} + +val sentenceDataFrame = sqlContext.createDataFrame(Seq( + (0, "Hi I heard about Spark"), + (1, "I wish Java could use case classes"), + (2, "Logistic,regression,models,are,neat") +)).toDF("label", "sentence") +val tokenizer = new Tokenizer().setInputCol("sentence").setOutputCol("words") +val regexTokenizer = new RegexTokenizer() + .setInputCol("sentence") + .setOutputCol("words") + .setPattern("\\W") // alternatively .setPattern("\\w+").setGaps(false) + +val tokenized = tokenizer.transform(sentenceDataFrame) +tokenized.select("words", "label").take(3).foreach(println) +val regexTokenized = regexTokenizer.transform(sentenceDataFrame) +regexTokenized.select("words", "label").take(3).foreach(println) +{% endhighlight %}
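As a brief aside to the Scala tokenizer snippet above: the alternative mentioned only in its comment can be spelled out. The sketch below is illustrative, reusing `sentenceDataFrame` from that snippet and assuming the Spark 1.6-era `RegexTokenizer` setters; the variable name `wordMatcher` is invented here.

{% highlight scala %}
import org.apache.spark.ml.feature.RegexTokenizer

// Match word characters directly instead of splitting on non-word gaps.
val wordMatcher = new RegexTokenizer()
  .setInputCol("sentence")
  .setOutputCol("words")
  .setPattern("\\w+")
  .setGaps(false)         // the pattern now describes tokens, not delimiters
  .setMinTokenLength(2)   // optionally drop single-character tokens

wordMatcher.transform(sentenceDataFrame).select("words").show()
{% endhighlight %}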
    @@ -179,7 +197,44 @@ Refer to the [Tokenizer Java docs](api/java/org/apache/spark/ml/feature/Tokenize and the [RegexTokenizer Java docs](api/java/org/apache/spark/ml/feature/RegexTokenizer.html) for more details on the API. -{% include_example java/org/apache/spark/examples/ml/JavaTokenizerExample.java %} +{% highlight java %} +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.RegexTokenizer; +import org.apache.spark.ml.feature.Tokenizer; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +JavaRDD jrdd = jsc.parallelize(Arrays.asList( + RowFactory.create(0, "Hi I heard about Spark"), + RowFactory.create(1, "I wish Java could use case classes"), + RowFactory.create(2, "Logistic,regression,models,are,neat") +)); +StructType schema = new StructType(new StructField[]{ + new StructField("label", DataTypes.DoubleType, false, Metadata.empty()), + new StructField("sentence", DataTypes.StringType, false, Metadata.empty()) +}); +DataFrame sentenceDataFrame = sqlContext.createDataFrame(jrdd, schema); +Tokenizer tokenizer = new Tokenizer().setInputCol("sentence").setOutputCol("words"); +DataFrame wordsDataFrame = tokenizer.transform(sentenceDataFrame); +for (Row r : wordsDataFrame.select("words", "label").take(3)) { + java.util.List words = r.getList(0); + for (String word : words) System.out.print(word + " "); + System.out.println(); +} + +RegexTokenizer regexTokenizer = new RegexTokenizer() + .setInputCol("sentence") + .setOutputCol("words") + .setPattern("\\W"); // alternatively .setPattern("\\w+").setGaps(false); +{% endhighlight %}
    @@ -188,7 +243,21 @@ Refer to the [Tokenizer Python docs](api/python/pyspark.ml.html#pyspark.ml.featu the the [RegexTokenizer Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.RegexTokenizer) for more details on the API. -{% include_example python/ml/tokenizer_example.py %} +{% highlight python %} +from pyspark.ml.feature import Tokenizer, RegexTokenizer + +sentenceDataFrame = sqlContext.createDataFrame([ + (0, "Hi I heard about Spark"), + (1, "I wish Java could use case classes"), + (2, "Logistic,regression,models,are,neat") +], ["label", "sentence"]) +tokenizer = Tokenizer(inputCol="sentence", outputCol="words") +wordsDataFrame = tokenizer.transform(sentenceDataFrame) +for words_label in wordsDataFrame.select("words", "label").take(3): + print(words_label) +regexTokenizer = RegexTokenizer(inputCol="sentence", outputCol="words", pattern="\\W") +# alternatively, pattern="\\w+", gaps(False) +{% endhighlight %}
    @@ -237,7 +306,19 @@ filtered out. Refer to the [StopWordsRemover Scala docs](api/scala/index.html#org.apache.spark.ml.feature.StopWordsRemover) for more details on the API. -{% include_example scala/org/apache/spark/examples/ml/StopWordsRemoverExample.scala %} +{% highlight scala %} +import org.apache.spark.ml.feature.StopWordsRemover + +val remover = new StopWordsRemover() + .setInputCol("raw") + .setOutputCol("filtered") +val dataSet = sqlContext.createDataFrame(Seq( + (0, Seq("I", "saw", "the", "red", "baloon")), + (1, Seq("Mary", "had", "a", "little", "lamb")) +)).toDF("id", "raw") + +remover.transform(dataSet).show() +{% endhighlight %}
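The Scala StopWordsRemover snippet above relies on the built-in English stop-word list. A minimal sketch of the custom-list variant, reusing `dataSet` from that snippet and assuming the 1.6-era setters; the word list here is arbitrary:

{% highlight scala %}
import org.apache.spark.ml.feature.StopWordsRemover

// Override the default English list with a custom one and match case sensitively.
val customRemover = new StopWordsRemover()
  .setInputCol("raw")
  .setOutputCol("filtered")
  .setStopWords(Array("saw", "had", "a"))
  .setCaseSensitive(true)

customRemover.transform(dataSet).show()
{% endhighlight %}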
    @@ -245,7 +326,34 @@ for more details on the API. Refer to the [StopWordsRemover Java docs](api/java/org/apache/spark/ml/feature/StopWordsRemover.html) for more details on the API. -{% include_example java/org/apache/spark/examples/ml/JavaStopWordsRemoverExample.java %} +{% highlight java %} +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.StopWordsRemover; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +StopWordsRemover remover = new StopWordsRemover() + .setInputCol("raw") + .setOutputCol("filtered"); + +JavaRDD rdd = jsc.parallelize(Arrays.asList( + RowFactory.create(Arrays.asList("I", "saw", "the", "red", "baloon")), + RowFactory.create(Arrays.asList("Mary", "had", "a", "little", "lamb")) +)); +StructType schema = new StructType(new StructField[] { + new StructField("raw", DataTypes.createArrayType(DataTypes.StringType), false, Metadata.empty()) +}); +DataFrame dataset = jsql.createDataFrame(rdd, schema); + +remover.transform(dataset).show(); +{% endhighlight %}
    @@ -253,7 +361,17 @@ for more details on the API. Refer to the [StopWordsRemover Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.StopWordsRemover) for more details on the API. -{% include_example python/ml/stopwords_remover_example.py %} +{% highlight python %} +from pyspark.ml.feature import StopWordsRemover + +sentenceData = sqlContext.createDataFrame([ + (0, ["I", "saw", "the", "red", "baloon"]), + (1, ["Mary", "had", "a", "little", "lamb"]) +], ["label", "raw"]) + +remover = StopWordsRemover(inputCol="raw", outputCol="filtered") +remover.transform(sentenceData).show(truncate=False) +{% endhighlight %}
    @@ -270,7 +388,19 @@ An [n-gram](https://en.wikipedia.org/wiki/N-gram) is a sequence of $n$ tokens (t Refer to the [NGram Scala docs](api/scala/index.html#org.apache.spark.ml.feature.NGram) for more details on the API. -{% include_example scala/org/apache/spark/examples/ml/NGramExample.scala %} +{% highlight scala %} +import org.apache.spark.ml.feature.NGram + +val wordDataFrame = sqlContext.createDataFrame(Seq( + (0, Array("Hi", "I", "heard", "about", "Spark")), + (1, Array("I", "wish", "Java", "could", "use", "case", "classes")), + (2, Array("Logistic", "regression", "models", "are", "neat")) +)).toDF("label", "words") + +val ngram = new NGram().setInputCol("words").setOutputCol("ngrams") +val ngramDataFrame = ngram.transform(wordDataFrame) +ngramDataFrame.take(3).map(_.getAs[Stream[String]]("ngrams").toList).foreach(println) +{% endhighlight %}
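The Scala NGram snippet above uses the default n = 2. A small illustrative sketch of changing n, reusing `wordDataFrame` from that snippet; `trigram` is a name introduced here:

{% highlight scala %}
import org.apache.spark.ml.feature.NGram

// Emit trigrams instead of the default bigrams; rows with fewer than
// three tokens simply produce an empty sequence of n-grams.
val trigram = new NGram()
  .setInputCol("words")
  .setOutputCol("trigrams")
  .setN(3)

trigram.transform(wordDataFrame).select("trigrams").show()
{% endhighlight %}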
    @@ -278,7 +408,38 @@ for more details on the API. Refer to the [NGram Java docs](api/java/org/apache/spark/ml/feature/NGram.html) for more details on the API. -{% include_example java/org/apache/spark/examples/ml/JavaNGramExample.java %} +{% highlight java %} +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.NGram; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +JavaRDD jrdd = jsc.parallelize(Arrays.asList( + RowFactory.create(0.0, Arrays.asList("Hi", "I", "heard", "about", "Spark")), + RowFactory.create(1.0, Arrays.asList("I", "wish", "Java", "could", "use", "case", "classes")), + RowFactory.create(2.0, Arrays.asList("Logistic", "regression", "models", "are", "neat")) +)); +StructType schema = new StructType(new StructField[]{ + new StructField("label", DataTypes.DoubleType, false, Metadata.empty()), + new StructField("words", DataTypes.createArrayType(DataTypes.StringType), false, Metadata.empty()) +}); +DataFrame wordDataFrame = sqlContext.createDataFrame(jrdd, schema); +NGram ngramTransformer = new NGram().setInputCol("words").setOutputCol("ngrams"); +DataFrame ngramDataFrame = ngramTransformer.transform(wordDataFrame); +for (Row r : ngramDataFrame.select("ngrams", "label").take(3)) { + java.util.List ngrams = r.getList(0); + for (String ngram : ngrams) System.out.print(ngram + " --- "); + System.out.println(); +} +{% endhighlight %}
    @@ -286,7 +447,19 @@ for more details on the API. Refer to the [NGram Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.NGram) for more details on the API. -{% include_example python/ml/n_gram_example.py %} +{% highlight python %} +from pyspark.ml.feature import NGram + +wordDataFrame = sqlContext.createDataFrame([ + (0, ["Hi", "I", "heard", "about", "Spark"]), + (1, ["I", "wish", "Java", "could", "use", "case", "classes"]), + (2, ["Logistic", "regression", "models", "are", "neat"]) +], ["label", "words"]) +ngram = NGram(inputCol="words", outputCol="ngrams") +ngramDataFrame = ngram.transform(wordDataFrame) +for ngrams_label in ngramDataFrame.select("ngrams", "label").take(3): + print(ngrams_label) +{% endhighlight %}
    @@ -303,7 +476,26 @@ Binarization is the process of thresholding numerical features to binary (0/1) f Refer to the [Binarizer Scala docs](api/scala/index.html#org.apache.spark.ml.feature.Binarizer) for more details on the API. -{% include_example scala/org/apache/spark/examples/ml/BinarizerExample.scala %} +{% highlight scala %} +import org.apache.spark.ml.feature.Binarizer +import org.apache.spark.sql.DataFrame + +val data = Array( + (0, 0.1), + (1, 0.8), + (2, 0.2) +) +val dataFrame: DataFrame = sqlContext.createDataFrame(data).toDF("label", "feature") + +val binarizer: Binarizer = new Binarizer() + .setInputCol("feature") + .setOutputCol("binarized_feature") + .setThreshold(0.5) + +val binarizedDataFrame = binarizer.transform(dataFrame) +val binarizedFeatures = binarizedDataFrame.select("binarized_feature") +binarizedFeatures.collect().foreach(println) +{% endhighlight %}
    @@ -311,7 +503,40 @@ for more details on the API. Refer to the [Binarizer Java docs](api/java/org/apache/spark/ml/feature/Binarizer.html) for more details on the API. -{% include_example java/org/apache/spark/examples/ml/JavaBinarizerExample.java %} +{% highlight java %} +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.Binarizer; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +JavaRDD jrdd = jsc.parallelize(Arrays.asList( + RowFactory.create(0, 0.1), + RowFactory.create(1, 0.8), + RowFactory.create(2, 0.2) +)); +StructType schema = new StructType(new StructField[]{ + new StructField("label", DataTypes.DoubleType, false, Metadata.empty()), + new StructField("feature", DataTypes.DoubleType, false, Metadata.empty()) +}); +DataFrame continuousDataFrame = jsql.createDataFrame(jrdd, schema); +Binarizer binarizer = new Binarizer() + .setInputCol("feature") + .setOutputCol("binarized_feature") + .setThreshold(0.5); +DataFrame binarizedDataFrame = binarizer.transform(continuousDataFrame); +DataFrame binarizedFeatures = binarizedDataFrame.select("binarized_feature"); +for (Row r : binarizedFeatures.collect()) { + Double binarized_value = r.getDouble(0); + System.out.println(binarized_value); +} +{% endhighlight %}
    @@ -319,7 +544,20 @@ for more details on the API. Refer to the [Binarizer Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.Binarizer) for more details on the API. -{% include_example python/ml/binarizer_example.py %} +{% highlight python %} +from pyspark.ml.feature import Binarizer + +continuousDataFrame = sqlContext.createDataFrame([ + (0, 0.1), + (1, 0.8), + (2, 0.2) +], ["label", "feature"]) +binarizer = Binarizer(threshold=0.5, inputCol="feature", outputCol="binarized_feature") +binarizedDataFrame = binarizer.transform(continuousDataFrame) +binarizedFeatures = binarizedDataFrame.select("binarized_feature") +for binarized_feature, in binarizedFeatures.collect(): + print(binarized_feature) +{% endhighlight %}
    @@ -333,7 +571,25 @@ for more details on the API. Refer to the [PCA Scala docs](api/scala/index.html#org.apache.spark.ml.feature.PCA) for more details on the API. -{% include_example scala/org/apache/spark/examples/ml/PCAExample.scala %} +{% highlight scala %} +import org.apache.spark.ml.feature.PCA +import org.apache.spark.mllib.linalg.Vectors + +val data = Array( + Vectors.sparse(5, Seq((1, 1.0), (3, 7.0))), + Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0), + Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0) +) +val df = sqlContext.createDataFrame(data.map(Tuple1.apply)).toDF("features") +val pca = new PCA() + .setInputCol("features") + .setOutputCol("pcaFeatures") + .setK(3) + .fit(df) +val pcaDF = pca.transform(df) +val result = pcaDF.select("pcaFeatures") +result.show() +{% endhighlight %}
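Beyond the projected features shown in the Scala PCA snippet above, the fitted model also exposes the projection itself. A minimal sketch, reusing the `pca` model from that snippet and assuming the 1.6-era `PCAModel.pc` field:

{% highlight scala %}
// The principal components form a numFeatures x k matrix; the input
// vectors are projected onto its columns.
val components = pca.pc
println(s"${components.numRows} x ${components.numCols} projection matrix")
println(components)
{% endhighlight %}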
    @@ -341,7 +597,42 @@ for more details on the API. Refer to the [PCA Java docs](api/java/org/apache/spark/ml/feature/PCA.html) for more details on the API. -{% include_example java/org/apache/spark/examples/ml/JavaPCAExample.java %} +{% highlight java %} +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.ml.feature.PCA +import org.apache.spark.ml.feature.PCAModel +import org.apache.spark.mllib.linalg.VectorUDT; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +JavaSparkContext jsc = ... +SQLContext jsql = ... +JavaRDD data = jsc.parallelize(Arrays.asList( + RowFactory.create(Vectors.sparse(5, new int[]{1, 3}, new double[]{1.0, 7.0})), + RowFactory.create(Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0)), + RowFactory.create(Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0)) +)); +StructType schema = new StructType(new StructField[] { + new StructField("features", new VectorUDT(), false, Metadata.empty()), +}); +DataFrame df = jsql.createDataFrame(data, schema); +PCAModel pca = new PCA() + .setInputCol("features") + .setOutputCol("pcaFeatures") + .setK(3) + .fit(df); +DataFrame result = pca.transform(df).select("pcaFeatures"); +result.show(); +{% endhighlight %}
    @@ -349,7 +640,19 @@ for more details on the API. Refer to the [PCA Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.PCA) for more details on the API. -{% include_example python/ml/pca_example.py %} +{% highlight python %} +from pyspark.ml.feature import PCA +from pyspark.mllib.linalg import Vectors + +data = [(Vectors.sparse(5, [(1, 1.0), (3, 7.0)]),), + (Vectors.dense([2.0, 0.0, 3.0, 4.0, 5.0]),), + (Vectors.dense([4.0, 0.0, 0.0, 6.0, 7.0]),)] +df = sqlContext.createDataFrame(data,["features"]) +pca = PCA(k=3, inputCol="features", outputCol="pcaFeatures") +model = pca.fit(df) +result = model.transform(df).select("pcaFeatures") +result.show(truncate=False) +{% endhighlight %}
    @@ -363,7 +666,23 @@ for more details on the API. Refer to the [PolynomialExpansion Scala docs](api/scala/index.html#org.apache.spark.ml.feature.PolynomialExpansion) for more details on the API. -{% include_example scala/org/apache/spark/examples/ml/PolynomialExpansionExample.scala %} +{% highlight scala %} +import org.apache.spark.ml.feature.PolynomialExpansion +import org.apache.spark.mllib.linalg.Vectors + +val data = Array( + Vectors.dense(-2.0, 2.3), + Vectors.dense(0.0, 0.0), + Vectors.dense(0.6, -1.1) +) +val df = sqlContext.createDataFrame(data.map(Tuple1.apply)).toDF("features") +val polynomialExpansion = new PolynomialExpansion() + .setInputCol("features") + .setOutputCol("polyFeatures") + .setDegree(3) +val polyDF = polynomialExpansion.transform(df) +polyDF.select("polyFeatures").take(3).foreach(println) +{% endhighlight %}
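For the degree-3 expansion above, the output length grows combinatorially with the degree. A sketch of the degree-2 case, reusing `df` from the Scala snippet; for d input features and degree n the expansion has C(d + n, n) - 1 entries, so 5 per row here:

{% highlight scala %}
import org.apache.spark.ml.feature.PolynomialExpansion

// Degree-2 expansion of the same two-dimensional vectors:
// x, x^2, y, x*y, y^2 for each input (x, y).
val quadratic = new PolynomialExpansion()
  .setInputCol("features")
  .setOutputCol("quadFeatures")
  .setDegree(2)

quadratic.transform(df).select("quadFeatures").show()
{% endhighlight %}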
    @@ -371,7 +690,43 @@ for more details on the API. Refer to the [PolynomialExpansion Java docs](api/java/org/apache/spark/ml/feature/PolynomialExpansion.html) for more details on the API. -{% include_example java/org/apache/spark/examples/ml/JavaPolynomialExpansionExample.java %} +{% highlight java %} +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.linalg.VectorUDT; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +JavaSparkContext jsc = ... +SQLContext jsql = ... +PolynomialExpansion polyExpansion = new PolynomialExpansion() + .setInputCol("features") + .setOutputCol("polyFeatures") + .setDegree(3); +JavaRDD data = jsc.parallelize(Arrays.asList( + RowFactory.create(Vectors.dense(-2.0, 2.3)), + RowFactory.create(Vectors.dense(0.0, 0.0)), + RowFactory.create(Vectors.dense(0.6, -1.1)) +)); +StructType schema = new StructType(new StructField[] { + new StructField("features", new VectorUDT(), false, Metadata.empty()), +}); +DataFrame df = jsql.createDataFrame(data, schema); +DataFrame polyDF = polyExpansion.transform(df); +Row[] row = polyDF.select("polyFeatures").take(3); +for (Row r : row) { + System.out.println(r.get(0)); +} +{% endhighlight %}
    @@ -379,7 +734,20 @@ for more details on the API. Refer to the [PolynomialExpansion Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.PolynomialExpansion) for more details on the API. -{% include_example python/ml/polynomial_expansion_example.py %} +{% highlight python %} +from pyspark.ml.feature import PolynomialExpansion +from pyspark.mllib.linalg import Vectors + +df = sqlContext.createDataFrame( + [(Vectors.dense([-2.0, 2.3]), ), + (Vectors.dense([0.0, 0.0]), ), + (Vectors.dense([0.6, -1.1]), )], + ["features"]) +px = PolynomialExpansion(degree=2, inputCol="features", outputCol="polyFeatures") +polyDF = px.transform(df) +for expanded in polyDF.select("polyFeatures").take(3): + print(expanded) +{% endhighlight %}
    @@ -403,7 +771,22 @@ $0$th DCT coefficient and _not_ the $N/2$th). Refer to the [DCT Scala docs](api/scala/index.html#org.apache.spark.ml.feature.DCT) for more details on the API. -{% include_example scala/org/apache/spark/examples/ml/DCTExample.scala %} +{% highlight scala %} +import org.apache.spark.ml.feature.DCT +import org.apache.spark.mllib.linalg.Vectors + +val data = Seq( + Vectors.dense(0.0, 1.0, -2.0, 3.0), + Vectors.dense(-1.0, 2.0, 4.0, -7.0), + Vectors.dense(14.0, -2.0, -5.0, 1.0)) +val df = sqlContext.createDataFrame(data.map(Tuple1.apply)).toDF("features") +val dct = new DCT() + .setInputCol("features") + .setOutputCol("featuresDCT") + .setInverse(false) +val dctDf = dct.transform(df) +dctDf.select("featuresDCT").show(3) +{% endhighlight %}
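Since the Scala DCT snippet above sets `setInverse(false)`, a round trip makes the flag concrete. A sketch only, reusing the `DCT` import and `dctDf` from that snippet:

{% highlight scala %}
// Applying the inverse transform to the DCT output recovers the original
// vectors up to floating-point rounding.
val idct = new DCT()
  .setInputCol("featuresDCT")
  .setOutputCol("featuresIDCT")
  .setInverse(true)

idct.transform(dctDf).select("featuresIDCT").show(3)
{% endhighlight %}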
    @@ -411,7 +794,39 @@ for more details on the API. Refer to the [DCT Java docs](api/java/org/apache/spark/ml/feature/DCT.html) for more details on the API. -{% include_example java/org/apache/spark/examples/ml/JavaDCTExample.java %}} +{% highlight java %} +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.ml.feature.DCT; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.linalg.VectorUDT; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +JavaRDD data = jsc.parallelize(Arrays.asList( + RowFactory.create(Vectors.dense(0.0, 1.0, -2.0, 3.0)), + RowFactory.create(Vectors.dense(-1.0, 2.0, 4.0, -7.0)), + RowFactory.create(Vectors.dense(14.0, -2.0, -5.0, 1.0)) +)); +StructType schema = new StructType(new StructField[] { + new StructField("features", new VectorUDT(), false, Metadata.empty()), +}); +DataFrame df = jsql.createDataFrame(data, schema); +DCT dct = new DCT() + .setInputCol("features") + .setOutputCol("featuresDCT") + .setInverse(false); +DataFrame dctDf = dct.transform(df); +dctDf.select("featuresDCT").show(3); +{% endhighlight %}
    @@ -466,7 +881,18 @@ index `2`. Refer to the [StringIndexer Scala docs](api/scala/index.html#org.apache.spark.ml.feature.StringIndexer) for more details on the API. -{% include_example scala/org/apache/spark/examples/ml/StringIndexerExample.scala %} +{% highlight scala %} +import org.apache.spark.ml.feature.StringIndexer + +val df = sqlContext.createDataFrame( + Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")) +).toDF("id", "category") +val indexer = new StringIndexer() + .setInputCol("category") + .setOutputCol("categoryIndex") +val indexed = indexer.fit(df).transform(df) +indexed.show() +{% endhighlight %}
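The indices produced by the Scala StringIndexer snippet above can be mapped back to strings with the companion transformer. A minimal sketch, reusing `indexed` from that snippet and assuming the 1.6-era `IndexToString`, which reads the label metadata written by `StringIndexer`:

{% highlight scala %}
import org.apache.spark.ml.feature.IndexToString

// Recover the original category strings from the generated indices.
val converter = new IndexToString()
  .setInputCol("categoryIndex")
  .setOutputCol("originalCategory")

converter.transform(indexed).select("id", "categoryIndex", "originalCategory").show()
{% endhighlight %}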
    @@ -474,7 +900,37 @@ for more details on the API. Refer to the [StringIndexer Java docs](api/java/org/apache/spark/ml/feature/StringIndexer.html) for more details on the API. -{% include_example java/org/apache/spark/examples/ml/JavaStringIndexerExample.java %} +{% highlight java %} +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.StringIndexer; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import static org.apache.spark.sql.types.DataTypes.*; + +JavaRDD jrdd = jsc.parallelize(Arrays.asList( + RowFactory.create(0, "a"), + RowFactory.create(1, "b"), + RowFactory.create(2, "c"), + RowFactory.create(3, "a"), + RowFactory.create(4, "a"), + RowFactory.create(5, "c") +)); +StructType schema = new StructType(new StructField[] { + createStructField("id", DoubleType, false), + createStructField("category", StringType, false) +}); +DataFrame df = sqlContext.createDataFrame(jrdd, schema); +StringIndexer indexer = new StringIndexer() + .setInputCol("category") + .setOutputCol("categoryIndex"); +DataFrame indexed = indexer.fit(df).transform(df); +indexed.show(); +{% endhighlight %}
    @@ -482,7 +938,16 @@ for more details on the API. Refer to the [StringIndexer Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.StringIndexer) for more details on the API. -{% include_example python/ml/string_indexer_example.py %} +{% highlight python %} +from pyspark.ml.feature import StringIndexer + +df = sqlContext.createDataFrame( + [(0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")], + ["id", "category"]) +indexer = StringIndexer(inputCol="category", outputCol="categoryIndex") +indexed = indexer.fit(df).transform(df) +indexed.show() +{% endhighlight %}
    @@ -496,7 +961,29 @@ for more details on the API. Refer to the [OneHotEncoder Scala docs](api/scala/index.html#org.apache.spark.ml.feature.OneHotEncoder) for more details on the API. -{% include_example scala/org/apache/spark/examples/ml/OneHotEncoderExample.scala %} +{% highlight scala %} +import org.apache.spark.ml.feature.{OneHotEncoder, StringIndexer} + +val df = sqlContext.createDataFrame(Seq( + (0, "a"), + (1, "b"), + (2, "c"), + (3, "a"), + (4, "a"), + (5, "c") +)).toDF("id", "category") + +val indexer = new StringIndexer() + .setInputCol("category") + .setOutputCol("categoryIndex") + .fit(df) +val indexed = indexer.transform(df) + +val encoder = new OneHotEncoder().setInputCol("categoryIndex"). + setOutputCol("categoryVec") +val encoded = encoder.transform(indexed) +encoded.select("id", "categoryVec").foreach(println) +{% endhighlight %}
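By default the encoder in the Scala snippet above drops the last category, yielding vectors of size (categories - 1). A sketch of the alternative, reusing `indexed` and the `OneHotEncoder` import from that snippet; `fullEncoder` is a name introduced here:

{% highlight scala %}
// Keep the last category as well, so the output vector has one slot
// per distinct category.
val fullEncoder = new OneHotEncoder()
  .setInputCol("categoryIndex")
  .setOutputCol("categoryVec")
  .setDropLast(false)

fullEncoder.transform(indexed).select("id", "categoryVec").show()
{% endhighlight %}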
    @@ -504,7 +991,45 @@ for more details on the API. Refer to the [OneHotEncoder Java docs](api/java/org/apache/spark/ml/feature/OneHotEncoder.html) for more details on the API. -{% include_example java/org/apache/spark/examples/ml/JavaOneHotEncoderExample.java %} +{% highlight java %} +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.OneHotEncoder; +import org.apache.spark.ml.feature.StringIndexer; +import org.apache.spark.ml.feature.StringIndexerModel; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +JavaRDD jrdd = jsc.parallelize(Arrays.asList( + RowFactory.create(0, "a"), + RowFactory.create(1, "b"), + RowFactory.create(2, "c"), + RowFactory.create(3, "a"), + RowFactory.create(4, "a"), + RowFactory.create(5, "c") +)); +StructType schema = new StructType(new StructField[]{ + new StructField("id", DataTypes.DoubleType, false, Metadata.empty()), + new StructField("category", DataTypes.StringType, false, Metadata.empty()) +}); +DataFrame df = sqlContext.createDataFrame(jrdd, schema); +StringIndexerModel indexer = new StringIndexer() + .setInputCol("category") + .setOutputCol("categoryIndex") + .fit(df); +DataFrame indexed = indexer.transform(df); + +OneHotEncoder encoder = new OneHotEncoder() + .setInputCol("categoryIndex") + .setOutputCol("categoryVec"); +DataFrame encoded = encoder.transform(indexed); +{% endhighlight %}
    @@ -512,7 +1037,24 @@ for more details on the API. Refer to the [OneHotEncoder Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.OneHotEncoder) for more details on the API. -{% include_example python/ml/onehot_encoder_example.py %} +{% highlight python %} +from pyspark.ml.feature import OneHotEncoder, StringIndexer + +df = sqlContext.createDataFrame([ + (0, "a"), + (1, "b"), + (2, "c"), + (3, "a"), + (4, "a"), + (5, "c") +], ["id", "category"]) + +stringIndexer = StringIndexer(inputCol="category", outputCol="categoryIndex") +model = stringIndexer.fit(df) +indexed = model.transform(df) +encoder = OneHotEncoder(includeFirst=False, inputCol="categoryIndex", outputCol="categoryVec") +encoded = encoder.transform(indexed) +{% endhighlight %}
    @@ -536,7 +1078,23 @@ In the example below, we read in a dataset of labeled points and then use `Vecto Refer to the [VectorIndexer Scala docs](api/scala/index.html#org.apache.spark.ml.feature.VectorIndexer) for more details on the API. -{% include_example scala/org/apache/spark/examples/ml/VectorIndexerExample.scala %} +{% highlight scala %} +import org.apache.spark.ml.feature.VectorIndexer + +val data = sqlContext.read.format("libsvm") + .load("data/mllib/sample_libsvm_data.txt") +val indexer = new VectorIndexer() + .setInputCol("features") + .setOutputCol("indexed") + .setMaxCategories(10) +val indexerModel = indexer.fit(data) +val categoricalFeatures: Set[Int] = indexerModel.categoryMaps.keys.toSet +println(s"Chose ${categoricalFeatures.size} categorical features: " + + categoricalFeatures.mkString(", ")) + +// Create new column "indexed" with categorical values transformed to indices +val indexedData = indexerModel.transform(data) +{% endhighlight %}
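The `maxCategories` threshold in the Scala VectorIndexer snippet above decides which columns are treated as categorical. A sketch of a stricter setting, reusing `data` and the `VectorIndexer` import from that snippet:

{% highlight scala %}
// With maxCategories = 2 only binary features are indexed as categorical;
// features with more distinct values are left as continuous.
val binaryIndexer = new VectorIndexer()
  .setInputCol("features")
  .setOutputCol("indexedBinary")
  .setMaxCategories(2)

val binaryModel = binaryIndexer.fit(data)
println(s"Categorical features: ${binaryModel.categoryMaps.keys.toSeq.sorted.mkString(", ")}")
{% endhighlight %}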
    @@ -544,7 +1102,30 @@ for more details on the API. Refer to the [VectorIndexer Java docs](api/java/org/apache/spark/ml/feature/VectorIndexer.html) for more details on the API. -{% include_example java/org/apache/spark/examples/ml/JavaVectorIndexerExample.java %} +{% highlight java %} +import java.util.Map; + +import org.apache.spark.ml.feature.VectorIndexer; +import org.apache.spark.ml.feature.VectorIndexerModel; +import org.apache.spark.sql.DataFrame; + +DataFrame data = sqlContext.read().format("libsvm") + .load("data/mllib/sample_libsvm_data.txt"); +VectorIndexer indexer = new VectorIndexer() + .setInputCol("features") + .setOutputCol("indexed") + .setMaxCategories(10); +VectorIndexerModel indexerModel = indexer.fit(data); +Map> categoryMaps = indexerModel.javaCategoryMaps(); +System.out.print("Chose " + categoryMaps.size() + "categorical features:"); +for (Integer feature : categoryMaps.keySet()) { + System.out.print(" " + feature); +} +System.out.println(); + +// Create new column "indexed" with categorical values transformed to indices +DataFrame indexedData = indexerModel.transform(data); +{% endhighlight %}
    @@ -552,7 +1133,17 @@ for more details on the API. Refer to the [VectorIndexer Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.VectorIndexer) for more details on the API. -{% include_example python/ml/vector_indexer_example.py %} +{% highlight python %} +from pyspark.ml.feature import VectorIndexer + +data = sqlContext.read.format("libsvm") + .load("data/mllib/sample_libsvm_data.txt") +indexer = VectorIndexer(inputCol="features", outputCol="indexed", maxCategories=10) +indexerModel = indexer.fit(data) + +# Create new column "indexed" with categorical values transformed to indices +indexedData = indexerModel.transform(data) +{% endhighlight %}
    @@ -569,7 +1160,22 @@ The following example demonstrates how to load a dataset in libsvm format and th Refer to the [Normalizer Scala docs](api/scala/index.html#org.apache.spark.ml.feature.Normalizer) for more details on the API. -{% include_example scala/org/apache/spark/examples/ml/NormalizerExample.scala %} +{% highlight scala %} +import org.apache.spark.ml.feature.Normalizer + +val dataFrame = sqlContext.read.format("libsvm") + .load("data/mllib/sample_libsvm_data.txt") + +// Normalize each Vector using $L^1$ norm. +val normalizer = new Normalizer() + .setInputCol("features") + .setOutputCol("normFeatures") + .setP(1.0) +val l1NormData = normalizer.transform(dataFrame) + +// Normalize each Vector using $L^\infty$ norm. +val lInfNormData = normalizer.transform(dataFrame, normalizer.p -> Double.PositiveInfinity) +{% endhighlight %}
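The Scala Normalizer snippet above shows the L^1 and L^infinity cases explicitly. For completeness, a sketch of the default, reusing `dataFrame` and the `Normalizer` import from that snippet:

{% highlight scala %}
// With no explicit p, Normalizer uses the L^2 norm, scaling each row
// of "features" to unit Euclidean length.
val l2Normalizer = new Normalizer()
  .setInputCol("features")
  .setOutputCol("l2NormFeatures")

val l2NormData = l2Normalizer.transform(dataFrame)
{% endhighlight %}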
    @@ -577,7 +1183,24 @@ for more details on the API. Refer to the [Normalizer Java docs](api/java/org/apache/spark/ml/feature/Normalizer.html) for more details on the API. -{% include_example java/org/apache/spark/examples/ml/JavaNormalizerExample.java %} +{% highlight java %} +import org.apache.spark.ml.feature.Normalizer; +import org.apache.spark.sql.DataFrame; + +DataFrame dataFrame = sqlContext.read().format("libsvm") + .load("data/mllib/sample_libsvm_data.txt"); + +// Normalize each Vector using $L^1$ norm. +Normalizer normalizer = new Normalizer() + .setInputCol("features") + .setOutputCol("normFeatures") + .setP(1.0); +DataFrame l1NormData = normalizer.transform(dataFrame); + +// Normalize each Vector using $L^\infty$ norm. +DataFrame lInfNormData = + normalizer.transform(dataFrame, normalizer.p().w(Double.POSITIVE_INFINITY)); +{% endhighlight %}
    @@ -585,7 +1208,19 @@ for more details on the API. Refer to the [Normalizer Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.Normalizer) for more details on the API. -{% include_example python/ml/normalizer_example.py %} +{% highlight python %} +from pyspark.ml.feature import Normalizer + +dataFrame = sqlContext.read.format("libsvm") + .load("data/mllib/sample_libsvm_data.txt") + +# Normalize each Vector using $L^1$ norm. +normalizer = Normalizer(inputCol="features", outputCol="normFeatures", p=1.0) +l1NormData = normalizer.transform(dataFrame) + +# Normalize each Vector using $L^\infty$ norm. +lInfNormData = normalizer.transform(dataFrame, {normalizer.p: float("inf")}) +{% endhighlight %}
    @@ -609,7 +1244,23 @@ The following example demonstrates how to load a dataset in libsvm format and th Refer to the [StandardScaler Scala docs](api/scala/index.html#org.apache.spark.ml.feature.StandardScaler) for more details on the API. -{% include_example scala/org/apache/spark/examples/ml/StandardScalerExample.scala %} +{% highlight scala %} +import org.apache.spark.ml.feature.StandardScaler + +val dataFrame = sqlContext.read.format("libsvm") + .load("data/mllib/sample_libsvm_data.txt") +val scaler = new StandardScaler() + .setInputCol("features") + .setOutputCol("scaledFeatures") + .setWithStd(true) + .setWithMean(false) + +// Compute summary statistics by fitting the StandardScaler +val scalerModel = scaler.fit(dataFrame) + +// Normalize each feature to have unit standard deviation. +val scaledData = scalerModel.transform(dataFrame) +{% endhighlight %}
    @@ -617,7 +1268,25 @@ for more details on the API. Refer to the [StandardScaler Java docs](api/java/org/apache/spark/ml/feature/StandardScaler.html) for more details on the API. -{% include_example java/org/apache/spark/examples/ml/JavaStandardScalerExample.java %} +{% highlight java %} +import org.apache.spark.ml.feature.StandardScaler; +import org.apache.spark.ml.feature.StandardScalerModel; +import org.apache.spark.sql.DataFrame; + +DataFrame dataFrame = sqlContext.read().format("libsvm") + .load("data/mllib/sample_libsvm_data.txt"); +StandardScaler scaler = new StandardScaler() + .setInputCol("features") + .setOutputCol("scaledFeatures") + .setWithStd(true) + .setWithMean(false); + +// Compute summary statistics by fitting the StandardScaler +StandardScalerModel scalerModel = scaler.fit(dataFrame); + +// Normalize each feature to have unit standard deviation. +DataFrame scaledData = scalerModel.transform(dataFrame); +{% endhighlight %}
@@ -625,7 +1294,20 @@ for more details on the API.

Refer to the [StandardScaler Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.StandardScaler)
for more details on the API.

-{% include_example python/ml/standard_scaler_example.py %}
+{% highlight python %}
+from pyspark.ml.feature import StandardScaler
+
+dataFrame = sqlContext.read.format("libsvm") \
+    .load("data/mllib/sample_libsvm_data.txt")
+scaler = StandardScaler(inputCol="features", outputCol="scaledFeatures",
+                        withStd=True, withMean=False)
+
+# Compute summary statistics by fitting the StandardScaler
+scalerModel = scaler.fit(dataFrame)
+
+# Normalize each feature to have unit standard deviation.
+scaledData = scalerModel.transform(dataFrame)
+{% endhighlight %}
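+
+Setting `withMean` to `true` additionally centers each feature at zero before scaling. Note
+that centering produces dense output, so it may be inappropriate for very wide sparse data.
+A rough Scala sketch, reusing the `dataFrame` and imports from the Scala example above:
+
+{% highlight scala %}
+// Center each feature to zero mean and scale to unit standard deviation.
+val centeringScaler = new StandardScaler()
+  .setInputCol("features")
+  .setOutputCol("centeredScaledFeatures")
+  .setWithMean(true)
+  .setWithStd(true)
+val centeredScaledData = centeringScaler.fit(dataFrame).transform(dataFrame)
+{% endhighlight %}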
    @@ -655,7 +1337,21 @@ Refer to the [MinMaxScaler Scala docs](api/scala/index.html#org.apache.spark.ml. and the [MinMaxScalerModel Scala docs](api/scala/index.html#org.apache.spark.ml.feature.MinMaxScalerModel) for more details on the API. -{% include_example scala/org/apache/spark/examples/ml/MinMaxScalerExample.scala %} +{% highlight scala %} +import org.apache.spark.ml.feature.MinMaxScaler + +val dataFrame = sqlContext.read.format("libsvm") + .load("data/mllib/sample_libsvm_data.txt") +val scaler = new MinMaxScaler() + .setInputCol("features") + .setOutputCol("scaledFeatures") + +// Compute summary statistics and generate MinMaxScalerModel +val scalerModel = scaler.fit(dataFrame) + +// rescale each feature to range [min, max]. +val scaledData = scalerModel.transform(dataFrame) +{% endhighlight %}
    @@ -664,7 +1360,24 @@ Refer to the [MinMaxScaler Java docs](api/java/org/apache/spark/ml/feature/MinMa and the [MinMaxScalerModel Java docs](api/java/org/apache/spark/ml/feature/MinMaxScalerModel.html) for more details on the API. -{% include_example java/org/apache/spark/examples/ml/JavaMinMaxScalerExample.java %} +{% highlight java %} +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.MinMaxScaler; +import org.apache.spark.ml.feature.MinMaxScalerModel; +import org.apache.spark.sql.DataFrame; + +DataFrame dataFrame = sqlContext.read().format("libsvm") + .load("data/mllib/sample_libsvm_data.txt"); +MinMaxScaler scaler = new MinMaxScaler() + .setInputCol("features") + .setOutputCol("scaledFeatures"); + +// Compute summary statistics and generate MinMaxScalerModel +MinMaxScalerModel scalerModel = scaler.fit(dataFrame); + +// rescale each feature to range [min, max]. +DataFrame scaledData = scalerModel.transform(dataFrame); +{% endhighlight %}
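+
+By default `MinMaxScaler` rescales each feature to `[0, 1]`; the target range can be changed
+through its `min` and `max` parameters. A minimal Scala sketch, reusing the `dataFrame` and
+imports from the Scala example above:
+
+{% highlight scala %}
+// Rescale each feature to [-1, 1] instead of the default [0, 1].
+val symmetricScaler = new MinMaxScaler()
+  .setInputCol("features")
+  .setOutputCol("symScaledFeatures")
+  .setMin(-1.0)
+  .setMax(1.0)
+val symScaledData = symmetricScaler.fit(dataFrame).transform(dataFrame)
+{% endhighlight %}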
    @@ -688,7 +1401,23 @@ The following example demonstrates how to bucketize a column of `Double`s into a Refer to the [Bucketizer Scala docs](api/scala/index.html#org.apache.spark.ml.feature.Bucketizer) for more details on the API. -{% include_example scala/org/apache/spark/examples/ml/BucketizerExample.scala %} +{% highlight scala %} +import org.apache.spark.ml.feature.Bucketizer +import org.apache.spark.sql.DataFrame + +val splits = Array(Double.NegativeInfinity, -0.5, 0.0, 0.5, Double.PositiveInfinity) + +val data = Array(-0.5, -0.3, 0.0, 0.2) +val dataFrame = sqlContext.createDataFrame(data.map(Tuple1.apply)).toDF("features") + +val bucketizer = new Bucketizer() + .setInputCol("features") + .setOutputCol("bucketedFeatures") + .setSplits(splits) + +// Transform original data into its bucket index. +val bucketedData = bucketizer.transform(dataFrame) +{% endhighlight %}
@@ -696,7 +1425,38 @@ for more details on the API.

Refer to the [Bucketizer Java docs](api/java/org/apache/spark/ml/feature/Bucketizer.html)
for more details on the API.

-{% include_example java/org/apache/spark/examples/ml/JavaBucketizerExample.java %}
+{% highlight java %}
+import java.util.Arrays;
+
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.ml.feature.Bucketizer;
+import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.RowFactory;
+import org.apache.spark.sql.types.DataTypes;
+import org.apache.spark.sql.types.Metadata;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
+
+double[] splits = {Double.NEGATIVE_INFINITY, -0.5, 0.0, 0.5, Double.POSITIVE_INFINITY};
+
+JavaRDD<Row> data = jsc.parallelize(Arrays.asList(
+  RowFactory.create(-0.5),
+  RowFactory.create(-0.3),
+  RowFactory.create(0.0),
+  RowFactory.create(0.2)
+));
+StructType schema = new StructType(new StructField[] {
+  new StructField("features", DataTypes.DoubleType, false, Metadata.empty())
+});
+DataFrame dataFrame = jsql.createDataFrame(data, schema);
+
+Bucketizer bucketizer = new Bucketizer()
+  .setInputCol("features")
+  .setOutputCol("bucketedFeatures")
+  .setSplits(splits);
+
+// Transform original data into its bucket index.
+DataFrame bucketedData = bucketizer.transform(dataFrame);
+{% endhighlight %}
    @@ -704,7 +1464,19 @@ for more details on the API. Refer to the [Bucketizer Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.Bucketizer) for more details on the API. -{% include_example python/ml/bucketizer_example.py %} +{% highlight python %} +from pyspark.ml.feature import Bucketizer + +splits = [-float("inf"), -0.5, 0.0, 0.5, float("inf")] + +data = [(-0.5,), (-0.3,), (0.0,), (0.2,)] +dataFrame = sqlContext.createDataFrame(data, ["features"]) + +bucketizer = Bucketizer(splits=splits, inputCol="features", outputCol="bucketedFeatures") + +# Transform original data into its bucket index. +bucketedData = bucketizer.transform(dataFrame) +{% endhighlight %}
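+
+With the splits used above, the buckets are `[-inf, -0.5)`, `[-0.5, 0.0)`, `[0.0, 0.5)` and
+`[0.5, inf]`, so the four sample values map to bucket indices 1, 1, 2 and 2. A short Scala
+sketch that makes this visible, continuing from the Scala example above:
+
+{% highlight scala %}
+// Expected bucket indices for -0.5, -0.3, 0.0, 0.2 are 1.0, 1.0, 2.0, 2.0.
+bucketedData.select("features", "bucketedFeatures").show()
+{% endhighlight %}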
    @@ -736,7 +1508,25 @@ This example below demonstrates how to transform vectors using a transforming ve Refer to the [ElementwiseProduct Scala docs](api/scala/index.html#org.apache.spark.ml.feature.ElementwiseProduct) for more details on the API. -{% include_example scala/org/apache/spark/examples/ml/ElementwiseProductExample.scala %} +{% highlight scala %} +import org.apache.spark.ml.feature.ElementwiseProduct +import org.apache.spark.mllib.linalg.Vectors + +// Create some vector data; also works for sparse vectors +val dataFrame = sqlContext.createDataFrame(Seq( + ("a", Vectors.dense(1.0, 2.0, 3.0)), + ("b", Vectors.dense(4.0, 5.0, 6.0)))).toDF("id", "vector") + +val transformingVector = Vectors.dense(0.0, 1.0, 2.0) +val transformer = new ElementwiseProduct() + .setScalingVec(transformingVector) + .setInputCol("vector") + .setOutputCol("transformedVector") + +// Batch transform the vectors to create new column: +transformer.transform(dataFrame).show() + +{% endhighlight %}
@@ -744,7 +1534,41 @@ for more details on the API.

Refer to the [ElementwiseProduct Java docs](api/java/org/apache/spark/ml/feature/ElementwiseProduct.html)
for more details on the API.

-{% include_example java/org/apache/spark/examples/ml/JavaElementwiseProductExample.java %}
+{% highlight java %}
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.ml.feature.ElementwiseProduct;
+import org.apache.spark.mllib.linalg.Vector;
+import org.apache.spark.mllib.linalg.VectorUDT;
+import org.apache.spark.mllib.linalg.Vectors;
+import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.RowFactory;
+import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.types.DataTypes;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
+
+// Create some vector data; also works for sparse vectors
+JavaRDD<Row> jrdd = jsc.parallelize(Arrays.asList(
+  RowFactory.create("a", Vectors.dense(1.0, 2.0, 3.0)),
+  RowFactory.create("b", Vectors.dense(4.0, 5.0, 6.0))
+));
+List<StructField> fields = new ArrayList<StructField>(2);
+fields.add(DataTypes.createStructField("id", DataTypes.StringType, false));
+fields.add(DataTypes.createStructField("vector", new VectorUDT(), false));
+StructType schema = DataTypes.createStructType(fields);
+DataFrame dataFrame = sqlContext.createDataFrame(jrdd, schema);
+Vector transformingVector = Vectors.dense(0.0, 1.0, 2.0);
+ElementwiseProduct transformer = new ElementwiseProduct()
+  .setScalingVec(transformingVector)
+  .setInputCol("vector")
+  .setOutputCol("transformedVector");
+// Batch transform the vectors to create new column:
+transformer.transform(dataFrame).show();
+
+{% endhighlight %}
    @@ -752,8 +1576,19 @@ for more details on the API. Refer to the [ElementwiseProduct Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.ElementwiseProduct) for more details on the API. -{% include_example python/ml/elementwise_product_example.py %} +{% highlight python %} +from pyspark.ml.feature import ElementwiseProduct +from pyspark.mllib.linalg import Vectors + +data = [(Vectors.dense([1.0, 2.0, 3.0]),), (Vectors.dense([4.0, 5.0, 6.0]),)] +df = sqlContext.createDataFrame(data, ["vector"]) +transformer = ElementwiseProduct(scalingVec=Vectors.dense([0.0, 1.0, 2.0]), + inputCol="vector", outputCol="transformedVector") +transformer.transform(df).show() + +{% endhighlight %}
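+
+The transformer simply takes the element-wise (Hadamard) product of each input vector with
+the scaling vector. Continuing from the Scala example above, a quick sketch of the expected
+result:
+
+{% highlight scala %}
+// (1.0, 2.0, 3.0) * (0.0, 1.0, 2.0) = (0.0, 2.0, 6.0)
+// (4.0, 5.0, 6.0) * (0.0, 1.0, 2.0) = (0.0, 5.0, 12.0)
+transformer.transform(dataFrame).select("vector", "transformedVector").show()
+{% endhighlight %}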
    + ## SQLTransformer @@ -856,7 +1691,19 @@ output column to `features`, after transformation we should get the following Da Refer to the [VectorAssembler Scala docs](api/scala/index.html#org.apache.spark.ml.feature.VectorAssembler) for more details on the API. -{% include_example scala/org/apache/spark/examples/ml/VectorAssemblerExample.scala %} +{% highlight scala %} +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.ml.feature.VectorAssembler + +val dataset = sqlContext.createDataFrame( + Seq((0, 18, 1.0, Vectors.dense(0.0, 10.0, 0.5), 1.0)) +).toDF("id", "hour", "mobile", "userFeatures", "clicked") +val assembler = new VectorAssembler() + .setInputCols(Array("hour", "mobile", "userFeatures")) + .setOutputCol("features") +val output = assembler.transform(dataset) +println(output.select("features", "clicked").first()) +{% endhighlight %}
@@ -864,7 +1711,36 @@ for more details on the API.

Refer to the [VectorAssembler Java docs](api/java/org/apache/spark/ml/feature/VectorAssembler.html)
for more details on the API.

-{% include_example java/org/apache/spark/examples/ml/JavaVectorAssemblerExample.java %}
+{% highlight java %}
+import java.util.Arrays;
+
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.ml.feature.VectorAssembler;
+import org.apache.spark.mllib.linalg.VectorUDT;
+import org.apache.spark.mllib.linalg.Vectors;
+import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.RowFactory;
+import org.apache.spark.sql.types.*;
+import static org.apache.spark.sql.types.DataTypes.*;
+
+StructType schema = createStructType(new StructField[] {
+  createStructField("id", IntegerType, false),
+  createStructField("hour", IntegerType, false),
+  createStructField("mobile", DoubleType, false),
+  createStructField("userFeatures", new VectorUDT(), false),
+  createStructField("clicked", DoubleType, false)
+});
+Row row = RowFactory.create(0, 18, 1.0, Vectors.dense(0.0, 10.0, 0.5), 1.0);
+JavaRDD<Row> rdd = jsc.parallelize(Arrays.asList(row));
+DataFrame dataset = sqlContext.createDataFrame(rdd, schema);
+
+VectorAssembler assembler = new VectorAssembler()
+  .setInputCols(new String[] {"hour", "mobile", "userFeatures"})
+  .setOutputCol("features");
+
+DataFrame output = assembler.transform(dataset);
+System.out.println(output.select("features", "clicked").first());
+{% endhighlight %}
    @@ -872,7 +1748,19 @@ for more details on the API. Refer to the [VectorAssembler Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.VectorAssembler) for more details on the API. -{% include_example python/ml/vector_assembler_example.py %} +{% highlight python %} +from pyspark.mllib.linalg import Vectors +from pyspark.ml.feature import VectorAssembler + +dataset = sqlContext.createDataFrame( + [(0, 18, 1.0, Vectors.dense([0.0, 10.0, 0.5]), 1.0)], + ["id", "hour", "mobile", "userFeatures", "clicked"]) +assembler = VectorAssembler( + inputCols=["hour", "mobile", "userFeatures"], + outputCol="features") +output = assembler.transform(dataset) +print(output.select("features", "clicked").first()) +{% endhighlight %}
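+
+`VectorAssembler` concatenates its input columns in the order they are listed, flattening any
+vector columns along the way, so the single row above yields a five-element vector. A short
+Scala check, reusing `output` from the Scala example:
+
+{% highlight scala %}
+// "hour" (18), "mobile" (1.0) and the three "userFeatures" values are concatenated
+// into a single vector: [18.0, 1.0, 0.0, 10.0, 0.5]
+println(output.select("features").first())
+{% endhighlight %}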
@@ -1002,7 +1890,33 @@ Suppose also that we have a potential input attributes for the `userFeatures`, i

Refer to the [VectorSlicer Scala docs](api/scala/index.html#org.apache.spark.ml.feature.VectorSlicer)
for more details on the API.

-{% include_example scala/org/apache/spark/examples/ml/VectorSlicerExample.scala %}
+{% highlight scala %}
+import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NumericAttribute}
+import org.apache.spark.ml.feature.VectorSlicer
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.{DataFrame, Row, SQLContext}
+
+val data = Array(
+  Vectors.sparse(3, Seq((0, -2.0), (1, 2.3))),
+  Vectors.dense(-2.0, 2.3, 0.0)
+)
+
+val defaultAttr = NumericAttribute.defaultAttr
+val attrs = Array("f1", "f2", "f3").map(defaultAttr.withName)
+val attrGroup = new AttributeGroup("userFeatures", attrs.asInstanceOf[Array[Attribute]])
+
+val dataRDD = sc.parallelize(data).map(Row.apply)
+val dataset = sqlContext.createDataFrame(dataRDD, StructType(Array(attrGroup.toStructField())))
+
+val slicer = new VectorSlicer().setInputCol("userFeatures").setOutputCol("features")
+
+slicer.setIndices(Array(1)).setNames(Array("f3"))
+// or slicer.setIndices(Array(1, 2)), or slicer.setNames(Array("f2", "f3"))
+
+val output = slicer.transform(dataset)
+println(output.select("userFeatures", "features").first())
+{% endhighlight %}
@@ -1010,7 +1924,41 @@ for more details on the API.

Refer to the [VectorSlicer Java docs](api/java/org/apache/spark/ml/feature/VectorSlicer.html)
for more details on the API.

-{% include_example java/org/apache/spark/examples/ml/JavaVectorSlicerExample.java %}
+{% highlight java %}
+import java.util.Arrays;
+
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.ml.attribute.Attribute;
+import org.apache.spark.ml.attribute.AttributeGroup;
+import org.apache.spark.ml.attribute.NumericAttribute;
+import org.apache.spark.ml.feature.VectorSlicer;
+import org.apache.spark.mllib.linalg.Vectors;
+import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.RowFactory;
+import org.apache.spark.sql.types.*;
+import static org.apache.spark.sql.types.DataTypes.*;
+
+Attribute[] attrs = new Attribute[]{
+  NumericAttribute.defaultAttr().withName("f1"),
+  NumericAttribute.defaultAttr().withName("f2"),
+  NumericAttribute.defaultAttr().withName("f3")
+};
+AttributeGroup group = new AttributeGroup("userFeatures", attrs);
+
+JavaRDD<Row> jrdd = jsc.parallelize(Arrays.asList(
+  RowFactory.create(Vectors.sparse(3, new int[]{0, 1}, new double[]{-2.0, 2.3})),
+  RowFactory.create(Vectors.dense(-2.0, 2.3, 0.0))
+));
+
+DataFrame dataset = jsql.createDataFrame(jrdd, (new StructType()).add(group.toStructField()));
+
+VectorSlicer vectorSlicer = new VectorSlicer()
+  .setInputCol("userFeatures").setOutputCol("features");
+
+vectorSlicer.setIndices(new int[]{1}).setNames(new String[]{"f3"});
+// or vectorSlicer.setIndices(new int[]{1, 2}), or vectorSlicer.setNames(new String[]{"f2", "f3"})
+
+DataFrame output = vectorSlicer.transform(dataset);
+
+System.out.println(output.select("userFeatures", "features").first());
+{% endhighlight %}
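+
+Because the slicer above selects index `1` and the named attribute `"f3"` (index 2), both
+input rows reduce to two-element vectors. A short Scala sketch of the expected result,
+reusing `output` from the Scala example:
+
+{% highlight scala %}
+// Both rows hold [-2.0, 2.3, 0.0]; slicing out indices (1, 2) gives [2.3, 0.0].
+output.select("userFeatures", "features").show()
+{% endhighlight %}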
    @@ -1047,7 +1995,21 @@ id | country | hour | clicked | features | label Refer to the [RFormula Scala docs](api/scala/index.html#org.apache.spark.ml.feature.RFormula) for more details on the API. -{% include_example scala/org/apache/spark/examples/ml/RFormulaExample.scala %} +{% highlight scala %} +import org.apache.spark.ml.feature.RFormula + +val dataset = sqlContext.createDataFrame(Seq( + (7, "US", 18, 1.0), + (8, "CA", 12, 0.0), + (9, "NZ", 15, 0.0) +)).toDF("id", "country", "hour", "clicked") +val formula = new RFormula() + .setFormula("clicked ~ country + hour") + .setFeaturesCol("features") + .setLabelCol("label") +val output = formula.fit(dataset).transform(dataset) +output.select("features", "label").show() +{% endhighlight %}
@@ -1055,7 +2017,38 @@ for more details on the API.

Refer to the [RFormula Java docs](api/java/org/apache/spark/ml/feature/RFormula.html)
for more details on the API.

-{% include_example java/org/apache/spark/examples/ml/JavaRFormulaExample.java %}
+{% highlight java %}
+import java.util.Arrays;
+
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.ml.feature.RFormula;
+import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.RowFactory;
+import org.apache.spark.sql.types.*;
+import static org.apache.spark.sql.types.DataTypes.*;
+
+StructType schema = createStructType(new StructField[] {
+  createStructField("id", IntegerType, false),
+  createStructField("country", StringType, false),
+  createStructField("hour", IntegerType, false),
+  createStructField("clicked", DoubleType, false)
+});
+JavaRDD<Row> rdd = jsc.parallelize(Arrays.asList(
+  RowFactory.create(7, "US", 18, 1.0),
+  RowFactory.create(8, "CA", 12, 0.0),
+  RowFactory.create(9, "NZ", 15, 0.0)
+));
+DataFrame dataset = sqlContext.createDataFrame(rdd, schema);
+
+RFormula formula = new RFormula()
+  .setFormula("clicked ~ country + hour")
+  .setFeaturesCol("features")
+  .setLabelCol("label");
+
+DataFrame output = formula.fit(dataset).transform(dataset);
+output.select("features", "label").show();
+{% endhighlight %}
    @@ -1063,7 +2056,21 @@ for more details on the API. Refer to the [RFormula Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.RFormula) for more details on the API. -{% include_example python/ml/rformula_example.py %} +{% highlight python %} +from pyspark.ml.feature import RFormula + +dataset = sqlContext.createDataFrame( + [(7, "US", 18, 1.0), + (8, "CA", 12, 0.0), + (9, "NZ", 15, 0.0)], + ["id", "country", "hour", "clicked"]) +formula = RFormula( + formula="clicked ~ country + hour", + featuresCol="features", + labelCol="label") +output = formula.fit(dataset).transform(dataset) +output.select("features", "label").show() +{% endhighlight %}
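+
+Under the hood, string columns referenced in the formula (here `country`) are string-indexed
+and dummy-coded with one reference level dropped, while numeric columns such as `hour` are
+used as-is, so each `features` vector above should have 2 + 1 = 3 elements. A short Scala
+sketch, reusing `output` from the Scala example:
+
+{% highlight scala %}
+// Two dummy indicator slots for "country" plus one slot for "hour".
+output.select("id", "features", "label").show()
+{% endhighlight %}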
    diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaBinarizerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaBinarizerExample.java deleted file mode 100644 index 9698cac50437..000000000000 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaBinarizerExample.java +++ /dev/null @@ -1,68 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.examples.ml; - -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.sql.SQLContext; - -// $example on$ -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.ml.feature.Binarizer; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.types.DataTypes; -import org.apache.spark.sql.types.Metadata; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; -// $example off$ - -public class JavaBinarizerExample { - public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("JavaBinarizerExample"); - JavaSparkContext jsc = new JavaSparkContext(conf); - SQLContext jsql = new SQLContext(jsc); - - // $example on$ - JavaRDD jrdd = jsc.parallelize(Arrays.asList( - RowFactory.create(0, 0.1), - RowFactory.create(1, 0.8), - RowFactory.create(2, 0.2) - )); - StructType schema = new StructType(new StructField[]{ - new StructField("label", DataTypes.DoubleType, false, Metadata.empty()), - new StructField("feature", DataTypes.DoubleType, false, Metadata.empty()) - }); - DataFrame continuousDataFrame = jsql.createDataFrame(jrdd, schema); - Binarizer binarizer = new Binarizer() - .setInputCol("feature") - .setOutputCol("binarized_feature") - .setThreshold(0.5); - DataFrame binarizedDataFrame = binarizer.transform(continuousDataFrame); - DataFrame binarizedFeatures = binarizedDataFrame.select("binarized_feature"); - for (Row r : binarizedFeatures.collect()) { - Double binarized_value = r.getDouble(0); - System.out.println(binarized_value); - } - // $example off$ - jsc.stop(); - } -} \ No newline at end of file diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaBucketizerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaBucketizerExample.java deleted file mode 100644 index b06a23e76d60..000000000000 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaBucketizerExample.java +++ /dev/null @@ -1,70 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. 
- * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.examples.ml; - -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.sql.SQLContext; - -// $example on$ -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.ml.feature.Bucketizer; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.types.DataTypes; -import org.apache.spark.sql.types.Metadata; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; -// $example off$ - -public class JavaBucketizerExample { - public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("JavaBucketizerExample"); - JavaSparkContext jsc = new JavaSparkContext(conf); - SQLContext jsql = new SQLContext(jsc); - - // $example on$ - double[] splits = {Double.NEGATIVE_INFINITY, -0.5, 0.0, 0.5, Double.POSITIVE_INFINITY}; - - JavaRDD data = jsc.parallelize(Arrays.asList( - RowFactory.create(-0.5), - RowFactory.create(-0.3), - RowFactory.create(0.0), - RowFactory.create(0.2) - )); - StructType schema = new StructType(new StructField[]{ - new StructField("features", DataTypes.DoubleType, false, Metadata.empty()) - }); - DataFrame dataFrame = jsql.createDataFrame(data, schema); - - Bucketizer bucketizer = new Bucketizer() - .setInputCol("features") - .setOutputCol("bucketedFeatures") - .setSplits(splits); - - // Transform original data into its bucket index. - DataFrame bucketedData = bucketizer.transform(dataFrame); - // $example off$ - jsc.stop(); - } -} - - diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaDCTExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaDCTExample.java deleted file mode 100644 index 35c0d534a45e..000000000000 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaDCTExample.java +++ /dev/null @@ -1,65 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.examples.ml; - -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.sql.SQLContext; - -// $example on$ -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.ml.feature.DCT; -import org.apache.spark.mllib.linalg.VectorUDT; -import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.types.Metadata; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; -// $example off$ - -public class JavaDCTExample { - public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("JavaDCTExample"); - JavaSparkContext jsc = new JavaSparkContext(conf); - SQLContext jsql = new SQLContext(jsc); - - // $example on$ - JavaRDD data = jsc.parallelize(Arrays.asList( - RowFactory.create(Vectors.dense(0.0, 1.0, -2.0, 3.0)), - RowFactory.create(Vectors.dense(-1.0, 2.0, 4.0, -7.0)), - RowFactory.create(Vectors.dense(14.0, -2.0, -5.0, 1.0)) - )); - StructType schema = new StructType(new StructField[]{ - new StructField("features", new VectorUDT(), false, Metadata.empty()), - }); - DataFrame df = jsql.createDataFrame(data, schema); - DCT dct = new DCT() - .setInputCol("features") - .setOutputCol("featuresDCT") - .setInverse(false); - DataFrame dctDf = dct.transform(df); - dctDf.select("featuresDCT").show(3); - // $example off$ - jsc.stop(); - } -} - diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaElementwiseProductExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaElementwiseProductExample.java deleted file mode 100644 index 2898accec61b..000000000000 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaElementwiseProductExample.java +++ /dev/null @@ -1,75 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.examples.ml; - -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.sql.SQLContext; - -// $example on$ -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.ml.feature.ElementwiseProduct; -import org.apache.spark.mllib.linalg.Vector; -import org.apache.spark.mllib.linalg.VectorUDT; -import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.types.DataTypes; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; -// $example off$ - -public class JavaElementwiseProductExample { - public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("JavaElementwiseProductExample"); - JavaSparkContext jsc = new JavaSparkContext(conf); - SQLContext sqlContext = new SQLContext(jsc); - - // $example on$ - // Create some vector data; also works for sparse vectors - JavaRDD jrdd = jsc.parallelize(Arrays.asList( - RowFactory.create("a", Vectors.dense(1.0, 2.0, 3.0)), - RowFactory.create("b", Vectors.dense(4.0, 5.0, 6.0)) - )); - - List fields = new ArrayList(2); - fields.add(DataTypes.createStructField("id", DataTypes.StringType, false)); - fields.add(DataTypes.createStructField("vector", new VectorUDT(), false)); - - StructType schema = DataTypes.createStructType(fields); - - DataFrame dataFrame = sqlContext.createDataFrame(jrdd, schema); - - Vector transformingVector = Vectors.dense(0.0, 1.0, 2.0); - - ElementwiseProduct transformer = new ElementwiseProduct() - .setScalingVec(transformingVector) - .setInputCol("vector") - .setOutputCol("transformedVector"); - - // Batch transform the vectors to create new column: - transformer.transform(dataFrame).show(); - // $example off$ - jsc.stop(); - } -} \ No newline at end of file diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaMinMaxScalerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaMinMaxScalerExample.java deleted file mode 100644 index 138b3ab6aba4..000000000000 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaMinMaxScalerExample.java +++ /dev/null @@ -1,50 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.examples.ml; - -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.sql.SQLContext; - -// $example on$ -import org.apache.spark.ml.feature.MinMaxScaler; -import org.apache.spark.ml.feature.MinMaxScalerModel; -import org.apache.spark.sql.DataFrame; -// $example off$ - -public class JavaMinMaxScalerExample { - public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("JaveMinMaxScalerExample"); - JavaSparkContext jsc = new JavaSparkContext(conf); - SQLContext jsql = new SQLContext(jsc); - - // $example on$ - DataFrame dataFrame = jsql.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt"); - MinMaxScaler scaler = new MinMaxScaler() - .setInputCol("features") - .setOutputCol("scaledFeatures"); - - // Compute summary statistics and generate MinMaxScalerModel - MinMaxScalerModel scalerModel = scaler.fit(dataFrame); - - // rescale each feature to range [min, max]. - DataFrame scaledData = scalerModel.transform(dataFrame); - // $example off$ - jsc.stop(); - } -} \ No newline at end of file diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaNGramExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaNGramExample.java deleted file mode 100644 index 8fd75ed8b5f4..000000000000 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaNGramExample.java +++ /dev/null @@ -1,71 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.examples.ml; - -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.sql.SQLContext; - -// $example on$ -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.ml.feature.NGram; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.types.DataTypes; -import org.apache.spark.sql.types.Metadata; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; -// $example off$ - -public class JavaNGramExample { - public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("JavaNGramExample"); - JavaSparkContext jsc = new JavaSparkContext(conf); - SQLContext sqlContext = new SQLContext(jsc); - - // $example on$ - JavaRDD jrdd = jsc.parallelize(Arrays.asList( - RowFactory.create(0.0, Arrays.asList("Hi", "I", "heard", "about", "Spark")), - RowFactory.create(1.0, Arrays.asList("I", "wish", "Java", "could", "use", "case", "classes")), - RowFactory.create(2.0, Arrays.asList("Logistic", "regression", "models", "are", "neat")) - )); - - StructType schema = new StructType(new StructField[]{ - new StructField("label", DataTypes.DoubleType, false, Metadata.empty()), - new StructField( - "words", DataTypes.createArrayType(DataTypes.StringType), false, Metadata.empty()) - }); - - DataFrame wordDataFrame = sqlContext.createDataFrame(jrdd, schema); - - NGram ngramTransformer = new NGram().setInputCol("words").setOutputCol("ngrams"); - - DataFrame ngramDataFrame = ngramTransformer.transform(wordDataFrame); - - for (Row r : ngramDataFrame.select("ngrams", "label").take(3)) { - java.util.List ngrams = r.getList(0); - for (String ngram : ngrams) System.out.print(ngram + " --- "); - System.out.println(); - } - // $example off$ - jsc.stop(); - } -} \ No newline at end of file diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaNormalizerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaNormalizerExample.java deleted file mode 100644 index 6283a355e1fe..000000000000 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaNormalizerExample.java +++ /dev/null @@ -1,52 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.examples.ml; - -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.sql.SQLContext; - -// $example on$ -import org.apache.spark.ml.feature.Normalizer; -import org.apache.spark.sql.DataFrame; -// $example off$ - -public class JavaNormalizerExample { - public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("JavaNormalizerExample"); - JavaSparkContext jsc = new JavaSparkContext(conf); - SQLContext jsql = new SQLContext(jsc); - - // $example on$ - DataFrame dataFrame = jsql.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt"); - - // Normalize each Vector using $L^1$ norm. - Normalizer normalizer = new Normalizer() - .setInputCol("features") - .setOutputCol("normFeatures") - .setP(1.0); - - DataFrame l1NormData = normalizer.transform(dataFrame); - - // Normalize each Vector using $L^\infty$ norm. - DataFrame lInfNormData = - normalizer.transform(dataFrame, normalizer.p().w(Double.POSITIVE_INFINITY)); - // $example off$ - jsc.stop(); - } -} \ No newline at end of file diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaOneHotEncoderExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaOneHotEncoderExample.java deleted file mode 100644 index 172a9cc6feb2..000000000000 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaOneHotEncoderExample.java +++ /dev/null @@ -1,77 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.examples.ml; - -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.sql.SQLContext; - -// $example on$ -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.ml.feature.OneHotEncoder; -import org.apache.spark.ml.feature.StringIndexer; -import org.apache.spark.ml.feature.StringIndexerModel; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.types.DataTypes; -import org.apache.spark.sql.types.Metadata; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; -// $example off$ - -public class JavaOneHotEncoderExample { - public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("JavaOneHotEncoderExample"); - JavaSparkContext jsc = new JavaSparkContext(conf); - SQLContext sqlContext = new SQLContext(jsc); - - // $example on$ - JavaRDD jrdd = jsc.parallelize(Arrays.asList( - RowFactory.create(0, "a"), - RowFactory.create(1, "b"), - RowFactory.create(2, "c"), - RowFactory.create(3, "a"), - RowFactory.create(4, "a"), - RowFactory.create(5, "c") - )); - - StructType schema = new StructType(new StructField[]{ - new StructField("id", DataTypes.DoubleType, false, Metadata.empty()), - new StructField("category", DataTypes.StringType, false, Metadata.empty()) - }); - - DataFrame df = sqlContext.createDataFrame(jrdd, schema); - - StringIndexerModel indexer = new StringIndexer() - .setInputCol("category") - .setOutputCol("categoryIndex") - .fit(df); - DataFrame indexed = indexer.transform(df); - - OneHotEncoder encoder = new OneHotEncoder() - .setInputCol("categoryIndex") - .setOutputCol("categoryVec"); - DataFrame encoded = encoder.transform(indexed); - // $example off$ - jsc.stop(); - } -} - diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaPCAExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaPCAExample.java deleted file mode 100644 index 8282fab084f3..000000000000 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaPCAExample.java +++ /dev/null @@ -1,71 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.examples.ml; - -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.sql.SQLContext; - -// $example on$ -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.ml.feature.PCA; -import org.apache.spark.ml.feature.PCAModel; -import org.apache.spark.mllib.linalg.VectorUDT; -import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.types.Metadata; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; -// $example off$ - -public class JavaPCAExample { - public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("JavaPCAExample"); - JavaSparkContext jsc = new JavaSparkContext(conf); - SQLContext jsql = new SQLContext(jsc); - - // $example on$ - JavaRDD data = jsc.parallelize(Arrays.asList( - RowFactory.create(Vectors.sparse(5, new int[]{1, 3}, new double[]{1.0, 7.0})), - RowFactory.create(Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0)), - RowFactory.create(Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0)) - )); - - StructType schema = new StructType(new StructField[]{ - new StructField("features", new VectorUDT(), false, Metadata.empty()), - }); - - DataFrame df = jsql.createDataFrame(data, schema); - - PCAModel pca = new PCA() - .setInputCol("features") - .setOutputCol("pcaFeatures") - .setK(3) - .fit(df); - - DataFrame result = pca.transform(df).select("pcaFeatures"); - result.show(); - // $example off$ - jsc.stop(); - } -} - diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaPolynomialExpansionExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaPolynomialExpansionExample.java deleted file mode 100644 index 668f71e64056..000000000000 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaPolynomialExpansionExample.java +++ /dev/null @@ -1,71 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.examples.ml; - -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.sql.SQLContext; - -// $example on$ -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.ml.feature.PolynomialExpansion; -import org.apache.spark.mllib.linalg.VectorUDT; -import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.types.Metadata; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; -// $example off$ - -public class JavaPolynomialExpansionExample { - public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("JavaPolynomialExpansionExample"); - JavaSparkContext jsc = new JavaSparkContext(conf); - SQLContext jsql = new SQLContext(jsc); - - // $example on$ - PolynomialExpansion polyExpansion = new PolynomialExpansion() - .setInputCol("features") - .setOutputCol("polyFeatures") - .setDegree(3); - - JavaRDD data = jsc.parallelize(Arrays.asList( - RowFactory.create(Vectors.dense(-2.0, 2.3)), - RowFactory.create(Vectors.dense(0.0, 0.0)), - RowFactory.create(Vectors.dense(0.6, -1.1)) - )); - - StructType schema = new StructType(new StructField[]{ - new StructField("features", new VectorUDT(), false, Metadata.empty()), - }); - - DataFrame df = jsql.createDataFrame(data, schema); - DataFrame polyDF = polyExpansion.transform(df); - - Row[] row = polyDF.select("polyFeatures").take(3); - for (Row r : row) { - System.out.println(r.get(0)); - } - // $example off$ - jsc.stop(); - } -} \ No newline at end of file diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaRFormulaExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaRFormulaExample.java deleted file mode 100644 index 1e1062b541ad..000000000000 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaRFormulaExample.java +++ /dev/null @@ -1,69 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.examples.ml; - -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.sql.SQLContext; - -// $example on$ -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.ml.feature.RFormula; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; - -import static org.apache.spark.sql.types.DataTypes.*; -// $example off$ - -public class JavaRFormulaExample { - public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("JavaRFormulaExample"); - JavaSparkContext jsc = new JavaSparkContext(conf); - SQLContext sqlContext = new SQLContext(jsc); - - // $example on$ - StructType schema = createStructType(new StructField[]{ - createStructField("id", IntegerType, false), - createStructField("country", StringType, false), - createStructField("hour", IntegerType, false), - createStructField("clicked", DoubleType, false) - }); - - JavaRDD rdd = jsc.parallelize(Arrays.asList( - RowFactory.create(7, "US", 18, 1.0), - RowFactory.create(8, "CA", 12, 0.0), - RowFactory.create(9, "NZ", 15, 0.0) - )); - - DataFrame dataset = sqlContext.createDataFrame(rdd, schema); - RFormula formula = new RFormula() - .setFormula("clicked ~ country + hour") - .setFeaturesCol("features") - .setLabelCol("label"); - DataFrame output = formula.fit(dataset).transform(dataset); - output.select("features", "label").show(); - // $example off$ - jsc.stop(); - } -} - diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaStandardScalerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaStandardScalerExample.java deleted file mode 100644 index 0cbdc97e8ae3..000000000000 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaStandardScalerExample.java +++ /dev/null @@ -1,53 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.examples.ml; - -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.sql.SQLContext; - -// $example on$ -import org.apache.spark.ml.feature.StandardScaler; -import org.apache.spark.ml.feature.StandardScalerModel; -import org.apache.spark.sql.DataFrame; -// $example off$ - -public class JavaStandardScalerExample { - public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("JavaStandardScalerExample"); - JavaSparkContext jsc = new JavaSparkContext(conf); - SQLContext jsql = new SQLContext(jsc); - - // $example on$ - DataFrame dataFrame = jsql.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt"); - - StandardScaler scaler = new StandardScaler() - .setInputCol("features") - .setOutputCol("scaledFeatures") - .setWithStd(true) - .setWithMean(false); - - // Compute summary statistics by fitting the StandardScaler - StandardScalerModel scalerModel = scaler.fit(dataFrame); - - // Normalize each feature to have unit standard deviation. - DataFrame scaledData = scalerModel.transform(dataFrame); - // $example off$ - jsc.stop(); - } -} \ No newline at end of file diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaStopWordsRemoverExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaStopWordsRemoverExample.java deleted file mode 100644 index b6b201c6b68d..000000000000 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaStopWordsRemoverExample.java +++ /dev/null @@ -1,65 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.examples.ml; - -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.sql.SQLContext; - -// $example on$ -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.ml.feature.StopWordsRemover; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.types.DataTypes; -import org.apache.spark.sql.types.Metadata; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; -// $example off$ - -public class JavaStopWordsRemoverExample { - - public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("JavaStopWordsRemoverExample"); - JavaSparkContext jsc = new JavaSparkContext(conf); - SQLContext jsql = new SQLContext(jsc); - - // $example on$ - StopWordsRemover remover = new StopWordsRemover() - .setInputCol("raw") - .setOutputCol("filtered"); - - JavaRDD rdd = jsc.parallelize(Arrays.asList( - RowFactory.create(Arrays.asList("I", "saw", "the", "red", "baloon")), - RowFactory.create(Arrays.asList("Mary", "had", "a", "little", "lamb")) - )); - - StructType schema = new StructType(new StructField[]{ - new StructField( - "raw", DataTypes.createArrayType(DataTypes.StringType), false, Metadata.empty()) - }); - - DataFrame dataset = jsql.createDataFrame(rdd, schema); - remover.transform(dataset).show(); - // $example off$ - jsc.stop(); - } -} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaStringIndexerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaStringIndexerExample.java deleted file mode 100644 index 05d12c1e702f..000000000000 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaStringIndexerExample.java +++ /dev/null @@ -1,66 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.examples.ml; - -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.sql.SQLContext; - -// $example on$ -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.ml.feature.StringIndexer; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; - -import static org.apache.spark.sql.types.DataTypes.*; -// $example off$ - -public class JavaStringIndexerExample { - public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("JavaStringIndexerExample"); - JavaSparkContext jsc = new JavaSparkContext(conf); - SQLContext sqlContext = new SQLContext(jsc); - - // $example on$ - JavaRDD jrdd = jsc.parallelize(Arrays.asList( - RowFactory.create(0, "a"), - RowFactory.create(1, "b"), - RowFactory.create(2, "c"), - RowFactory.create(3, "a"), - RowFactory.create(4, "a"), - RowFactory.create(5, "c") - )); - StructType schema = new StructType(new StructField[]{ - createStructField("id", IntegerType, false), - createStructField("category", StringType, false) - }); - DataFrame df = sqlContext.createDataFrame(jrdd, schema); - StringIndexer indexer = new StringIndexer() - .setInputCol("category") - .setOutputCol("categoryIndex"); - DataFrame indexed = indexer.fit(df).transform(df); - indexed.show(); - // $example off$ - jsc.stop(); - } -} \ No newline at end of file diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaTokenizerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaTokenizerExample.java deleted file mode 100644 index 617dc3f66e3b..000000000000 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaTokenizerExample.java +++ /dev/null @@ -1,75 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */
-
-package org.apache.spark.examples.ml;
-
-import org.apache.spark.SparkConf;
-import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.sql.SQLContext;
-
-// $example on$
-import java.util.Arrays;
-
-import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.ml.feature.RegexTokenizer;
-import org.apache.spark.ml.feature.Tokenizer;
-import org.apache.spark.sql.DataFrame;
-import org.apache.spark.sql.Row;
-import org.apache.spark.sql.RowFactory;
-import org.apache.spark.sql.types.DataTypes;
-import org.apache.spark.sql.types.Metadata;
-import org.apache.spark.sql.types.StructField;
-import org.apache.spark.sql.types.StructType;
-// $example off$
-
-public class JavaTokenizerExample {
-  public static void main(String[] args) {
-    SparkConf conf = new SparkConf().setAppName("JavaTokenizerExample");
-    JavaSparkContext jsc = new JavaSparkContext(conf);
-    SQLContext sqlContext = new SQLContext(jsc);
-
-    // $example on$
-    JavaRDD<Row> jrdd = jsc.parallelize(Arrays.asList(
-      RowFactory.create(0, "Hi I heard about Spark"),
-      RowFactory.create(1, "I wish Java could use case classes"),
-      RowFactory.create(2, "Logistic,regression,models,are,neat")
-    ));
-
-    StructType schema = new StructType(new StructField[]{
-      new StructField("label", DataTypes.IntegerType, false, Metadata.empty()),
-      new StructField("sentence", DataTypes.StringType, false, Metadata.empty())
-    });
-
-    DataFrame sentenceDataFrame = sqlContext.createDataFrame(jrdd, schema);
-
-    Tokenizer tokenizer = new Tokenizer().setInputCol("sentence").setOutputCol("words");
-
-    DataFrame wordsDataFrame = tokenizer.transform(sentenceDataFrame);
-    for (Row r : wordsDataFrame.select("words", "label"). take(3)) {
-      java.util.List<String> words = r.getList(0);
-      for (String word : words) System.out.print(word + " ");
-      System.out.println();
-    }
-
-    RegexTokenizer regexTokenizer = new RegexTokenizer()
-      .setInputCol("sentence")
-      .setOutputCol("words")
-      .setPattern("\\W");  // alternatively .setPattern("\\w+").setGaps(false);
-    // $example off$
-    jsc.stop();
-  }
-}
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorAssemblerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorAssemblerExample.java
deleted file mode 100644
index 7e230b5897c1..000000000000
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorAssemblerExample.java
+++ /dev/null
@@ -1,67 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.examples.ml;
-
-import org.apache.spark.SparkConf;
-import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.sql.SQLContext;
-
-// $example on$
-import java.util.Arrays;
-
-import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.ml.feature.VectorAssembler;
-import org.apache.spark.mllib.linalg.VectorUDT;
-import org.apache.spark.mllib.linalg.Vectors;
-import org.apache.spark.sql.DataFrame;
-import org.apache.spark.sql.Row;
-import org.apache.spark.sql.RowFactory;
-import org.apache.spark.sql.types.*;
-
-import static org.apache.spark.sql.types.DataTypes.*;
-// $example off$
-
-public class JavaVectorAssemblerExample {
-  public static void main(String[] args) {
-    SparkConf conf = new SparkConf().setAppName("JavaVectorAssemblerExample");
-    JavaSparkContext jsc = new JavaSparkContext(conf);
-    SQLContext sqlContext = new SQLContext(jsc);
-
-    // $example on$
-    StructType schema = createStructType(new StructField[]{
-      createStructField("id", IntegerType, false),
-      createStructField("hour", IntegerType, false),
-      createStructField("mobile", DoubleType, false),
-      createStructField("userFeatures", new VectorUDT(), false),
-      createStructField("clicked", DoubleType, false)
-    });
-    Row row = RowFactory.create(0, 18, 1.0, Vectors.dense(0.0, 10.0, 0.5), 1.0);
-    JavaRDD<Row> rdd = jsc.parallelize(Arrays.asList(row));
-    DataFrame dataset = sqlContext.createDataFrame(rdd, schema);
-
-    VectorAssembler assembler = new VectorAssembler()
-      .setInputCols(new String[]{"hour", "mobile", "userFeatures"})
-      .setOutputCol("features");
-
-    DataFrame output = assembler.transform(dataset);
-    System.out.println(output.select("features", "clicked").first());
-    // $example off$
-    jsc.stop();
-  }
-}
-
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorIndexerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorIndexerExample.java
deleted file mode 100644
index 06b4bf6bf8ff..000000000000
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorIndexerExample.java
+++ /dev/null
@@ -1,60 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.examples.ml;
-
-import org.apache.spark.SparkConf;
-import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.sql.SQLContext;
-
-// $example on$
-import java.util.Map;
-
-import org.apache.spark.ml.feature.VectorIndexer;
-import org.apache.spark.ml.feature.VectorIndexerModel;
-import org.apache.spark.sql.DataFrame;
-// $example off$
-
-public class JavaVectorIndexerExample {
-  public static void main(String[] args) {
-    SparkConf conf = new SparkConf().setAppName("JavaVectorIndexerExample");
-    JavaSparkContext jsc = new JavaSparkContext(conf);
-    SQLContext jsql = new SQLContext(jsc);
-
-    // $example on$
-    DataFrame data = jsql.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt");
-
-    VectorIndexer indexer = new VectorIndexer()
-      .setInputCol("features")
-      .setOutputCol("indexed")
-      .setMaxCategories(10);
-    VectorIndexerModel indexerModel = indexer.fit(data);
-
-    Map<Integer, Map<Double, Integer>> categoryMaps = indexerModel.javaCategoryMaps();
-    System.out.print("Chose " + categoryMaps.size() + " categorical features:");
-
-    for (Integer feature : categoryMaps.keySet()) {
-      System.out.print(" " + feature);
-    }
-    System.out.println();
-
-    // Create new column "indexed" with categorical values transformed to indices
-    DataFrame indexedData = indexerModel.transform(data);
-    // $example off$
-    jsc.stop();
-  }
-}
\ No newline at end of file
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorSlicerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorSlicerExample.java
deleted file mode 100644
index 4d5cb04ff5e2..000000000000
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorSlicerExample.java
+++ /dev/null
@@ -1,73 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.examples.ml;
-
-import org.apache.spark.SparkConf;
-import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.sql.SQLContext;
-
-// $example on$
-import com.google.common.collect.Lists;
-
-import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.ml.attribute.Attribute;
-import org.apache.spark.ml.attribute.AttributeGroup;
-import org.apache.spark.ml.attribute.NumericAttribute;
-import org.apache.spark.ml.feature.VectorSlicer;
-import org.apache.spark.mllib.linalg.Vectors;
-import org.apache.spark.sql.DataFrame;
-import org.apache.spark.sql.Row;
-import org.apache.spark.sql.RowFactory;
-import org.apache.spark.sql.types.*;
-// $example off$
-
-public class JavaVectorSlicerExample {
-  public static void main(String[] args) {
-    SparkConf conf = new SparkConf().setAppName("JavaVectorSlicerExample");
-    JavaSparkContext jsc = new JavaSparkContext(conf);
-    SQLContext jsql = new SQLContext(jsc);
-
-    // $example on$
-    Attribute[] attrs = new Attribute[]{
-      NumericAttribute.defaultAttr().withName("f1"),
-      NumericAttribute.defaultAttr().withName("f2"),
-      NumericAttribute.defaultAttr().withName("f3")
-    };
-    AttributeGroup group = new AttributeGroup("userFeatures", attrs);
-
-    JavaRDD<Row> jrdd = jsc.parallelize(Lists.newArrayList(
-      RowFactory.create(Vectors.sparse(3, new int[]{0, 1}, new double[]{-2.0, 2.3})),
-      RowFactory.create(Vectors.dense(-2.0, 2.3, 0.0))
-    ));
-
-    DataFrame dataset = jsql.createDataFrame(jrdd, (new StructType()).add(group.toStructField()));
-
-    VectorSlicer vectorSlicer = new VectorSlicer()
-      .setInputCol("userFeatures").setOutputCol("features");
-
-    vectorSlicer.setIndices(new int[]{1}).setNames(new String[]{"f3"});
-    // or slicer.setIndices(new int[]{1, 2}), or slicer.setNames(new String[]{"f2", "f3"})
-
-    DataFrame output = vectorSlicer.transform(dataset);
-
-    System.out.println(output.select("userFeatures", "features").first());
-    // $example off$
-    jsc.stop();
-  }
-}
-
diff --git a/examples/src/main/python/ml/binarizer_example.py b/examples/src/main/python/ml/binarizer_example.py
deleted file mode 100644
index 960ad208be12..000000000000
--- a/examples/src/main/python/ml/binarizer_example.py
+++ /dev/null
@@ -1,43 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one or more
-# contributor license agreements.  See the NOTICE file distributed with
-# this work for additional information regarding copyright ownership.
-# The ASF licenses this file to You under the Apache License, Version 2.0
-# (the "License"); you may not use this file except in compliance with
-# the License.  You may obtain a copy of the License at
-#
-#    http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# - -from __future__ import print_function - -from pyspark import SparkContext -from pyspark.sql import SQLContext -# $example on$ -from pyspark.ml.feature import Binarizer -# $example off$ - -if __name__ == "__main__": - sc = SparkContext(appName="BinarizerExample") - sqlContext = SQLContext(sc) - - # $example on$ - continuousDataFrame = sqlContext.createDataFrame([ - (0, 0.1), - (1, 0.8), - (2, 0.2) - ], ["label", "feature"]) - binarizer = Binarizer(threshold=0.5, inputCol="feature", outputCol="binarized_feature") - binarizedDataFrame = binarizer.transform(continuousDataFrame) - binarizedFeatures = binarizedDataFrame.select("binarized_feature") - for binarized_feature, in binarizedFeatures.collect(): - print(binarized_feature) - # $example off$ - - sc.stop() diff --git a/examples/src/main/python/ml/bucketizer_example.py b/examples/src/main/python/ml/bucketizer_example.py deleted file mode 100644 index a12750aa9248..000000000000 --- a/examples/src/main/python/ml/bucketizer_example.py +++ /dev/null @@ -1,42 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -from __future__ import print_function - -from pyspark import SparkContext -from pyspark.sql import SQLContext -# $example on$ -from pyspark.ml.feature import Bucketizer -# $example off$ - -if __name__ == "__main__": - sc = SparkContext(appName="BucketizerExample") - sqlContext = SQLContext(sc) - - # $example on$ - splits = [-float("inf"), -0.5, 0.0, 0.5, float("inf")] - - data = [(-0.5,), (-0.3,), (0.0,), (0.2,)] - dataFrame = sqlContext.createDataFrame(data, ["features"]) - - bucketizer = Bucketizer(splits=splits, inputCol="features", outputCol="bucketedFeatures") - - # Transform original data into its bucket index. - bucketedData = bucketizer.transform(dataFrame) - # $example off$ - - sc.stop() diff --git a/examples/src/main/python/ml/elementwise_product_example.py b/examples/src/main/python/ml/elementwise_product_example.py deleted file mode 100644 index c85cb0d89543..000000000000 --- a/examples/src/main/python/ml/elementwise_product_example.py +++ /dev/null @@ -1,39 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. -# - -from __future__ import print_function - -from pyspark import SparkContext -from pyspark.sql import SQLContext -# $example on$ -from pyspark.ml.feature import ElementwiseProduct -from pyspark.mllib.linalg import Vectors -# $example off$ - -if __name__ == "__main__": - sc = SparkContext(appName="ElementwiseProductExample") - sqlContext = SQLContext(sc) - - # $example on$ - data = [(Vectors.dense([1.0, 2.0, 3.0]),), (Vectors.dense([4.0, 5.0, 6.0]),)] - df = sqlContext.createDataFrame(data, ["vector"]) - transformer = ElementwiseProduct(scalingVec=Vectors.dense([0.0, 1.0, 2.0]), - inputCol="vector", outputCol="transformedVector") - transformer.transform(df).show() - # $example off$ - - sc.stop() diff --git a/examples/src/main/python/ml/n_gram_example.py b/examples/src/main/python/ml/n_gram_example.py deleted file mode 100644 index f2d85f53e721..000000000000 --- a/examples/src/main/python/ml/n_gram_example.py +++ /dev/null @@ -1,42 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -from __future__ import print_function - -from pyspark import SparkContext -from pyspark.sql import SQLContext -# $example on$ -from pyspark.ml.feature import NGram -# $example off$ - -if __name__ == "__main__": - sc = SparkContext(appName="NGramExample") - sqlContext = SQLContext(sc) - - # $example on$ - wordDataFrame = sqlContext.createDataFrame([ - (0, ["Hi", "I", "heard", "about", "Spark"]), - (1, ["I", "wish", "Java", "could", "use", "case", "classes"]), - (2, ["Logistic", "regression", "models", "are", "neat"]) - ], ["label", "words"]) - ngram = NGram(inputCol="words", outputCol="ngrams") - ngramDataFrame = ngram.transform(wordDataFrame) - for ngrams_label in ngramDataFrame.select("ngrams", "label").take(3): - print(ngrams_label) - # $example off$ - - sc.stop() diff --git a/examples/src/main/python/ml/normalizer_example.py b/examples/src/main/python/ml/normalizer_example.py deleted file mode 100644 index 833d93e976a7..000000000000 --- a/examples/src/main/python/ml/normalizer_example.py +++ /dev/null @@ -1,41 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. 
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -from __future__ import print_function - -from pyspark import SparkContext -from pyspark.sql import SQLContext -# $example on$ -from pyspark.ml.feature import Normalizer -# $example off$ - -if __name__ == "__main__": - sc = SparkContext(appName="NormalizerExample") - sqlContext = SQLContext(sc) - - # $example on$ - dataFrame = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") - - # Normalize each Vector using $L^1$ norm. - normalizer = Normalizer(inputCol="features", outputCol="normFeatures", p=1.0) - l1NormData = normalizer.transform(dataFrame) - - # Normalize each Vector using $L^\infty$ norm. - lInfNormData = normalizer.transform(dataFrame, {normalizer.p: float("inf")}) - # $example off$ - - sc.stop() diff --git a/examples/src/main/python/ml/onehot_encoder_example.py b/examples/src/main/python/ml/onehot_encoder_example.py deleted file mode 100644 index 7529dfd09213..000000000000 --- a/examples/src/main/python/ml/onehot_encoder_example.py +++ /dev/null @@ -1,47 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -from __future__ import print_function - -from pyspark import SparkContext -from pyspark.sql import SQLContext -# $example on$ -from pyspark.ml.feature import OneHotEncoder, StringIndexer -# $example off$ - -if __name__ == "__main__": - sc = SparkContext(appName="OneHotEncoderExample") - sqlContext = SQLContext(sc) - - # $example on$ - df = sqlContext.createDataFrame([ - (0, "a"), - (1, "b"), - (2, "c"), - (3, "a"), - (4, "a"), - (5, "c") - ], ["id", "category"]) - - stringIndexer = StringIndexer(inputCol="category", outputCol="categoryIndex") - model = stringIndexer.fit(df) - indexed = model.transform(df) - encoder = OneHotEncoder(dropLast=False, inputCol="categoryIndex", outputCol="categoryVec") - encoded = encoder.transform(indexed) - # $example off$ - - sc.stop() diff --git a/examples/src/main/python/ml/pca_example.py b/examples/src/main/python/ml/pca_example.py deleted file mode 100644 index 8b66140a40a7..000000000000 --- a/examples/src/main/python/ml/pca_example.py +++ /dev/null @@ -1,42 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. 
-# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -from __future__ import print_function - -from pyspark import SparkContext -from pyspark.sql import SQLContext -# $example on$ -from pyspark.ml.feature import PCA -from pyspark.mllib.linalg import Vectors -# $example off$ - -if __name__ == "__main__": - sc = SparkContext(appName="PCAExample") - sqlContext = SQLContext(sc) - - # $example on$ - data = [(Vectors.sparse(5, [(1, 1.0), (3, 7.0)]),), - (Vectors.dense([2.0, 0.0, 3.0, 4.0, 5.0]),), - (Vectors.dense([4.0, 0.0, 0.0, 6.0, 7.0]),)] - df = sqlContext.createDataFrame(data,["features"]) - pca = PCA(k=3, inputCol="features", outputCol="pcaFeatures") - model = pca.fit(df) - result = model.transform(df).select("pcaFeatures") - result.show(truncate=False) - # $example off$ - - sc.stop() diff --git a/examples/src/main/python/ml/polynomial_expansion_example.py b/examples/src/main/python/ml/polynomial_expansion_example.py deleted file mode 100644 index 030a6132a451..000000000000 --- a/examples/src/main/python/ml/polynomial_expansion_example.py +++ /dev/null @@ -1,43 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -from __future__ import print_function - -from pyspark import SparkContext -from pyspark.sql import SQLContext -# $example on$ -from pyspark.ml.feature import PolynomialExpansion -from pyspark.mllib.linalg import Vectors -# $example off$ - -if __name__ == "__main__": - sc = SparkContext(appName="PolynomialExpansionExample") - sqlContext = SQLContext(sc) - - # $example on$ - df = sqlContext.createDataFrame( - [(Vectors.dense([-2.0, 2.3]), ), - (Vectors.dense([0.0, 0.0]), ), - (Vectors.dense([0.6, -1.1]), )], - ["features"]) - px = PolynomialExpansion(degree=2, inputCol="features", outputCol="polyFeatures") - polyDF = px.transform(df) - for expanded in polyDF.select("polyFeatures").take(3): - print(expanded) - # $example off$ - - sc.stop() diff --git a/examples/src/main/python/ml/rformula_example.py b/examples/src/main/python/ml/rformula_example.py deleted file mode 100644 index b544a1470076..000000000000 --- a/examples/src/main/python/ml/rformula_example.py +++ /dev/null @@ -1,44 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. 
See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -from __future__ import print_function - -from pyspark import SparkContext -from pyspark.sql import SQLContext -# $example on$ -from pyspark.ml.feature import RFormula -# $example off$ - -if __name__ == "__main__": - sc = SparkContext(appName="RFormulaExample") - sqlContext = SQLContext(sc) - - # $example on$ - dataset = sqlContext.createDataFrame( - [(7, "US", 18, 1.0), - (8, "CA", 12, 0.0), - (9, "NZ", 15, 0.0)], - ["id", "country", "hour", "clicked"]) - formula = RFormula( - formula="clicked ~ country + hour", - featuresCol="features", - labelCol="label") - output = formula.fit(dataset).transform(dataset) - output.select("features", "label").show() - # $example off$ - - sc.stop() diff --git a/examples/src/main/python/ml/standard_scaler_example.py b/examples/src/main/python/ml/standard_scaler_example.py deleted file mode 100644 index 139acecbfb53..000000000000 --- a/examples/src/main/python/ml/standard_scaler_example.py +++ /dev/null @@ -1,42 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -from __future__ import print_function - -from pyspark import SparkContext -from pyspark.sql import SQLContext -# $example on$ -from pyspark.ml.feature import StandardScaler -# $example off$ - -if __name__ == "__main__": - sc = SparkContext(appName="StandardScalerExample") - sqlContext = SQLContext(sc) - - # $example on$ - dataFrame = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") - scaler = StandardScaler(inputCol="features", outputCol="scaledFeatures", - withStd=True, withMean=False) - - # Compute summary statistics by fitting the StandardScaler - scalerModel = scaler.fit(dataFrame) - - # Normalize each feature to have unit standard deviation. 
- scaledData = scalerModel.transform(dataFrame) - # $example off$ - - sc.stop() diff --git a/examples/src/main/python/ml/stopwords_remover_example.py b/examples/src/main/python/ml/stopwords_remover_example.py deleted file mode 100644 index 01f94af8ca75..000000000000 --- a/examples/src/main/python/ml/stopwords_remover_example.py +++ /dev/null @@ -1,40 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -from __future__ import print_function - -from pyspark import SparkContext -from pyspark.sql import SQLContext -# $example on$ -from pyspark.ml.feature import StopWordsRemover -# $example off$ - -if __name__ == "__main__": - sc = SparkContext(appName="StopWordsRemoverExample") - sqlContext = SQLContext(sc) - - # $example on$ - sentenceData = sqlContext.createDataFrame([ - (0, ["I", "saw", "the", "red", "baloon"]), - (1, ["Mary", "had", "a", "little", "lamb"]) - ], ["label", "raw"]) - - remover = StopWordsRemover(inputCol="raw", outputCol="filtered") - remover.transform(sentenceData).show(truncate=False) - # $example off$ - - sc.stop() diff --git a/examples/src/main/python/ml/string_indexer_example.py b/examples/src/main/python/ml/string_indexer_example.py deleted file mode 100644 index 58a8cb5d56b7..000000000000 --- a/examples/src/main/python/ml/string_indexer_example.py +++ /dev/null @@ -1,39 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# - -from __future__ import print_function - -from pyspark import SparkContext -from pyspark.sql import SQLContext -# $example on$ -from pyspark.ml.feature import StringIndexer -# $example off$ - -if __name__ == "__main__": - sc = SparkContext(appName="StringIndexerExample") - sqlContext = SQLContext(sc) - - # $example on$ - df = sqlContext.createDataFrame( - [(0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")], - ["id", "category"]) - indexer = StringIndexer(inputCol="category", outputCol="categoryIndex") - indexed = indexer.fit(df).transform(df) - indexed.show() - # $example off$ - - sc.stop() diff --git a/examples/src/main/python/ml/tokenizer_example.py b/examples/src/main/python/ml/tokenizer_example.py deleted file mode 100644 index ce9b225be535..000000000000 --- a/examples/src/main/python/ml/tokenizer_example.py +++ /dev/null @@ -1,44 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -from __future__ import print_function - -from pyspark import SparkContext -from pyspark.sql import SQLContext -# $example on$ -from pyspark.ml.feature import Tokenizer, RegexTokenizer -# $example off$ - -if __name__ == "__main__": - sc = SparkContext(appName="TokenizerExample") - sqlContext = SQLContext(sc) - - # $example on$ - sentenceDataFrame = sqlContext.createDataFrame([ - (0, "Hi I heard about Spark"), - (1, "I wish Java could use case classes"), - (2, "Logistic,regression,models,are,neat") - ], ["label", "sentence"]) - tokenizer = Tokenizer(inputCol="sentence", outputCol="words") - wordsDataFrame = tokenizer.transform(sentenceDataFrame) - for words_label in wordsDataFrame.select("words", "label").take(3): - print(words_label) - regexTokenizer = RegexTokenizer(inputCol="sentence", outputCol="words", pattern="\\W") - # alternatively, pattern="\\w+", gaps(False) - # $example off$ - - sc.stop() diff --git a/examples/src/main/python/ml/vector_assembler_example.py b/examples/src/main/python/ml/vector_assembler_example.py deleted file mode 100644 index 04f64839f188..000000000000 --- a/examples/src/main/python/ml/vector_assembler_example.py +++ /dev/null @@ -1,42 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. -# - -from __future__ import print_function - -from pyspark import SparkContext -from pyspark.sql import SQLContext -# $example on$ -from pyspark.mllib.linalg import Vectors -from pyspark.ml.feature import VectorAssembler -# $example off$ - -if __name__ == "__main__": - sc = SparkContext(appName="VectorAssemblerExample") - sqlContext = SQLContext(sc) - - # $example on$ - dataset = sqlContext.createDataFrame( - [(0, 18, 1.0, Vectors.dense([0.0, 10.0, 0.5]), 1.0)], - ["id", "hour", "mobile", "userFeatures", "clicked"]) - assembler = VectorAssembler( - inputCols=["hour", "mobile", "userFeatures"], - outputCol="features") - output = assembler.transform(dataset) - print(output.select("features", "clicked").first()) - # $example off$ - - sc.stop() diff --git a/examples/src/main/python/ml/vector_indexer_example.py b/examples/src/main/python/ml/vector_indexer_example.py deleted file mode 100644 index cc00d1454f2e..000000000000 --- a/examples/src/main/python/ml/vector_indexer_example.py +++ /dev/null @@ -1,39 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -from __future__ import print_function - -from pyspark import SparkContext -from pyspark.sql import SQLContext -# $example on$ -from pyspark.ml.feature import VectorIndexer -# $example off$ - -if __name__ == "__main__": - sc = SparkContext(appName="VectorIndexerExample") - sqlContext = SQLContext(sc) - - # $example on$ - data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") - indexer = VectorIndexer(inputCol="features", outputCol="indexed", maxCategories=10) - indexerModel = indexer.fit(data) - - # Create new column "indexed" with categorical values transformed to indices - indexedData = indexerModel.transform(data) - # $example off$ - - sc.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/BinarizerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/BinarizerExample.scala deleted file mode 100644 index e724aa587294..000000000000 --- a/examples/src/main/scala/org/apache/spark/examples/ml/BinarizerExample.scala +++ /dev/null @@ -1,48 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// scalastyle:off println -package org.apache.spark.examples.ml - -// $example on$ -import org.apache.spark.ml.feature.Binarizer -// $example off$ -import org.apache.spark.sql.{DataFrame, SQLContext} -import org.apache.spark.{SparkConf, SparkContext} - -object BinarizerExample { - def main(args: Array[String]): Unit = { - val conf = new SparkConf().setAppName("BinarizerExample") - val sc = new SparkContext(conf) - val sqlContext = new SQLContext(sc) - // $example on$ - val data = Array((0, 0.1), (1, 0.8), (2, 0.2)) - val dataFrame: DataFrame = sqlContext.createDataFrame(data).toDF("label", "feature") - - val binarizer: Binarizer = new Binarizer() - .setInputCol("feature") - .setOutputCol("binarized_feature") - .setThreshold(0.5) - - val binarizedDataFrame = binarizer.transform(dataFrame) - val binarizedFeatures = binarizedDataFrame.select("binarized_feature") - binarizedFeatures.collect().foreach(println) - // $example off$ - sc.stop() - } -} -// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/BucketizerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/BucketizerExample.scala deleted file mode 100644 index 30c2776d3968..000000000000 --- a/examples/src/main/scala/org/apache/spark/examples/ml/BucketizerExample.scala +++ /dev/null @@ -1,51 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// scalastyle:off println -package org.apache.spark.examples.ml - -// $example on$ -import org.apache.spark.ml.feature.Bucketizer -// $example off$ -import org.apache.spark.sql.SQLContext -import org.apache.spark.{SparkConf, SparkContext} - -object BucketizerExample { - def main(args: Array[String]): Unit = { - val conf = new SparkConf().setAppName("BucketizerExample") - val sc = new SparkContext(conf) - val sqlContext = new SQLContext(sc) - - // $example on$ - val splits = Array(Double.NegativeInfinity, -0.5, 0.0, 0.5, Double.PositiveInfinity) - - val data = Array(-0.5, -0.3, 0.0, 0.2) - val dataFrame = sqlContext.createDataFrame(data.map(Tuple1.apply)).toDF("features") - - val bucketizer = new Bucketizer() - .setInputCol("features") - .setOutputCol("bucketedFeatures") - .setSplits(splits) - - // Transform original data into its bucket index. 
- val bucketedData = bucketizer.transform(dataFrame) - // $example off$ - sc.stop() - } -} -// scalastyle:on println - diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DCTExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DCTExample.scala deleted file mode 100644 index 314c2c28a2a1..000000000000 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DCTExample.scala +++ /dev/null @@ -1,54 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// scalastyle:off println -package org.apache.spark.examples.ml - -// $example on$ -import org.apache.spark.ml.feature.DCT -import org.apache.spark.mllib.linalg.Vectors -// $example off$ -import org.apache.spark.sql.SQLContext -import org.apache.spark.{SparkConf, SparkContext} - -object DCTExample { - def main(args: Array[String]): Unit = { - val conf = new SparkConf().setAppName("DCTExample") - val sc = new SparkContext(conf) - val sqlContext = new SQLContext(sc) - - // $example on$ - val data = Seq( - Vectors.dense(0.0, 1.0, -2.0, 3.0), - Vectors.dense(-1.0, 2.0, 4.0, -7.0), - Vectors.dense(14.0, -2.0, -5.0, 1.0)) - - val df = sqlContext.createDataFrame(data.map(Tuple1.apply)).toDF("features") - - val dct = new DCT() - .setInputCol("features") - .setOutputCol("featuresDCT") - .setInverse(false) - - val dctDf = dct.transform(df) - dctDf.select("featuresDCT").show(3) - // $example off$ - sc.stop() - } -} -// scalastyle:on println - diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/ElementWiseProductExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/ElementWiseProductExample.scala deleted file mode 100644 index ac50bb7b2b15..000000000000 --- a/examples/src/main/scala/org/apache/spark/examples/ml/ElementWiseProductExample.scala +++ /dev/null @@ -1,53 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -// scalastyle:off println -package org.apache.spark.examples.ml - -// $example on$ -import org.apache.spark.ml.feature.ElementwiseProduct -import org.apache.spark.mllib.linalg.Vectors -// $example off$ -import org.apache.spark.sql.SQLContext -import org.apache.spark.{SparkConf, SparkContext} - -object ElementwiseProductExample { - def main(args: Array[String]): Unit = { - val conf = new SparkConf().setAppName("ElementwiseProductExample") - val sc = new SparkContext(conf) - val sqlContext = new SQLContext(sc) - - // $example on$ - // Create some vector data; also works for sparse vectors - val dataFrame = sqlContext.createDataFrame(Seq( - ("a", Vectors.dense(1.0, 2.0, 3.0)), - ("b", Vectors.dense(4.0, 5.0, 6.0)))).toDF("id", "vector") - - val transformingVector = Vectors.dense(0.0, 1.0, 2.0) - val transformer = new ElementwiseProduct() - .setScalingVec(transformingVector) - .setInputCol("vector") - .setOutputCol("transformedVector") - - // Batch transform the vectors to create new column: - transformer.transform(dataFrame).show() - // $example off$ - sc.stop() - } -} -// scalastyle:on println - diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/MinMaxScalerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/MinMaxScalerExample.scala deleted file mode 100644 index dac3679a5bf7..000000000000 --- a/examples/src/main/scala/org/apache/spark/examples/ml/MinMaxScalerExample.scala +++ /dev/null @@ -1,49 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// scalastyle:off println -package org.apache.spark.examples.ml - -// $example on$ -import org.apache.spark.ml.feature.MinMaxScaler -// $example off$ -import org.apache.spark.sql.SQLContext -import org.apache.spark.{SparkConf, SparkContext} - -object MinMaxScalerExample { - def main(args: Array[String]): Unit = { - val conf = new SparkConf().setAppName("MinMaxScalerExample") - val sc = new SparkContext(conf) - val sqlContext = new SQLContext(sc) - - // $example on$ - val dataFrame = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") - - val scaler = new MinMaxScaler() - .setInputCol("features") - .setOutputCol("scaledFeatures") - - // Compute summary statistics and generate MinMaxScalerModel - val scalerModel = scaler.fit(dataFrame) - - // rescale each feature to range [min, max]. 
- val scaledData = scalerModel.transform(dataFrame) - // $example off$ - sc.stop() - } -} -// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/NGramExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/NGramExample.scala deleted file mode 100644 index 8a85f71b56f3..000000000000 --- a/examples/src/main/scala/org/apache/spark/examples/ml/NGramExample.scala +++ /dev/null @@ -1,47 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// scalastyle:off println -package org.apache.spark.examples.ml - -// $example on$ -import org.apache.spark.ml.feature.NGram -// $example off$ -import org.apache.spark.sql.SQLContext -import org.apache.spark.{SparkConf, SparkContext} - -object NGramExample { - def main(args: Array[String]): Unit = { - val conf = new SparkConf().setAppName("NGramExample") - val sc = new SparkContext(conf) - val sqlContext = new SQLContext(sc) - - // $example on$ - val wordDataFrame = sqlContext.createDataFrame(Seq( - (0, Array("Hi", "I", "heard", "about", "Spark")), - (1, Array("I", "wish", "Java", "could", "use", "case", "classes")), - (2, Array("Logistic", "regression", "models", "are", "neat")) - )).toDF("label", "words") - - val ngram = new NGram().setInputCol("words").setOutputCol("ngrams") - val ngramDataFrame = ngram.transform(wordDataFrame) - ngramDataFrame.take(3).map(_.getAs[Stream[String]]("ngrams").toList).foreach(println) - // $example off$ - sc.stop() - } -} -// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/NormalizerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/NormalizerExample.scala deleted file mode 100644 index 17571f0aad79..000000000000 --- a/examples/src/main/scala/org/apache/spark/examples/ml/NormalizerExample.scala +++ /dev/null @@ -1,50 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -// scalastyle:off println -package org.apache.spark.examples.ml - -// $example on$ -import org.apache.spark.ml.feature.Normalizer -// $example off$ -import org.apache.spark.sql.SQLContext -import org.apache.spark.{SparkConf, SparkContext} - -object NormalizerExample { - def main(args: Array[String]): Unit = { - val conf = new SparkConf().setAppName("NormalizerExample") - val sc = new SparkContext(conf) - val sqlContext = new SQLContext(sc) - - // $example on$ - val dataFrame = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") - - // Normalize each Vector using $L^1$ norm. - val normalizer = new Normalizer() - .setInputCol("features") - .setOutputCol("normFeatures") - .setP(1.0) - - val l1NormData = normalizer.transform(dataFrame) - - // Normalize each Vector using $L^\infty$ norm. - val lInfNormData = normalizer.transform(dataFrame, normalizer.p -> Double.PositiveInfinity) - // $example off$ - sc.stop() - } -} -// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/OneHotEncoderExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/OneHotEncoderExample.scala deleted file mode 100644 index 4512736943dd..000000000000 --- a/examples/src/main/scala/org/apache/spark/examples/ml/OneHotEncoderExample.scala +++ /dev/null @@ -1,58 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// scalastyle:off println -package org.apache.spark.examples.ml - -// $example on$ -import org.apache.spark.ml.feature.{OneHotEncoder, StringIndexer} -// $example off$ -import org.apache.spark.sql.SQLContext -import org.apache.spark.{SparkConf, SparkContext} - -object OneHotEncoderExample { - def main(args: Array[String]): Unit = { - val conf = new SparkConf().setAppName("OneHotEncoderExample") - val sc = new SparkContext(conf) - val sqlContext = new SQLContext(sc) - - // $example on$ - val df = sqlContext.createDataFrame(Seq( - (0, "a"), - (1, "b"), - (2, "c"), - (3, "a"), - (4, "a"), - (5, "c") - )).toDF("id", "category") - - val indexer = new StringIndexer() - .setInputCol("category") - .setOutputCol("categoryIndex") - .fit(df) - val indexed = indexer.transform(df) - - val encoder = new OneHotEncoder().setInputCol("categoryIndex"). 
- setOutputCol("categoryVec") - val encoded = encoder.transform(indexed) - encoded.select("id", "categoryVec").foreach(println) - // $example off$ - sc.stop() - } -} -// scalastyle:on println - diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/PCAExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/PCAExample.scala deleted file mode 100644 index a18d4f33973d..000000000000 --- a/examples/src/main/scala/org/apache/spark/examples/ml/PCAExample.scala +++ /dev/null @@ -1,54 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// scalastyle:off println -package org.apache.spark.examples.ml - -// $example on$ -import org.apache.spark.ml.feature.PCA -import org.apache.spark.mllib.linalg.Vectors -// $example off$ -import org.apache.spark.sql.SQLContext -import org.apache.spark.{SparkConf, SparkContext} - -object PCAExample { - def main(args: Array[String]): Unit = { - val conf = new SparkConf().setAppName("PCAExample") - val sc = new SparkContext(conf) - val sqlContext = new SQLContext(sc) - - // $example on$ - val data = Array( - Vectors.sparse(5, Seq((1, 1.0), (3, 7.0))), - Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0), - Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0) - ) - val df = sqlContext.createDataFrame(data.map(Tuple1.apply)).toDF("features") - val pca = new PCA() - .setInputCol("features") - .setOutputCol("pcaFeatures") - .setK(3) - .fit(df) - val pcaDF = pca.transform(df) - val result = pcaDF.select("pcaFeatures") - result.show() - // $example off$ - sc.stop() - } -} -// scalastyle:on println - diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/PolynomialExpansionExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/PolynomialExpansionExample.scala deleted file mode 100644 index b8e9e6952a5e..000000000000 --- a/examples/src/main/scala/org/apache/spark/examples/ml/PolynomialExpansionExample.scala +++ /dev/null @@ -1,53 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -// scalastyle:off println -package org.apache.spark.examples.ml - -// $example on$ -import org.apache.spark.ml.feature.PolynomialExpansion -import org.apache.spark.mllib.linalg.Vectors -// $example off$ -import org.apache.spark.sql.SQLContext -import org.apache.spark.{SparkConf, SparkContext} - -object PolynomialExpansionExample { - def main(args: Array[String]): Unit = { - val conf = new SparkConf().setAppName("PolynomialExpansionExample") - val sc = new SparkContext(conf) - val sqlContext = new SQLContext(sc) - - // $example on$ - val data = Array( - Vectors.dense(-2.0, 2.3), - Vectors.dense(0.0, 0.0), - Vectors.dense(0.6, -1.1) - ) - val df = sqlContext.createDataFrame(data.map(Tuple1.apply)).toDF("features") - val polynomialExpansion = new PolynomialExpansion() - .setInputCol("features") - .setOutputCol("polyFeatures") - .setDegree(3) - val polyDF = polynomialExpansion.transform(df) - polyDF.select("polyFeatures").take(3).foreach(println) - // $example off$ - sc.stop() - } -} -// scalastyle:on println - - diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/RFormulaExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/RFormulaExample.scala deleted file mode 100644 index 286866edea50..000000000000 --- a/examples/src/main/scala/org/apache/spark/examples/ml/RFormulaExample.scala +++ /dev/null @@ -1,49 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// scalastyle:off println -package org.apache.spark.examples.ml - -// $example on$ -import org.apache.spark.ml.feature.RFormula -// $example off$ -import org.apache.spark.sql.SQLContext -import org.apache.spark.{SparkConf, SparkContext} - -object RFormulaExample { - def main(args: Array[String]): Unit = { - val conf = new SparkConf().setAppName("RFormulaExample") - val sc = new SparkContext(conf) - val sqlContext = new SQLContext(sc) - - // $example on$ - val dataset = sqlContext.createDataFrame(Seq( - (7, "US", 18, 1.0), - (8, "CA", 12, 0.0), - (9, "NZ", 15, 0.0) - )).toDF("id", "country", "hour", "clicked") - val formula = new RFormula() - .setFormula("clicked ~ country + hour") - .setFeaturesCol("features") - .setLabelCol("label") - val output = formula.fit(dataset).transform(dataset) - output.select("features", "label").show() - // $example off$ - sc.stop() - } -} -// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/StandardScalerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/StandardScalerExample.scala deleted file mode 100644 index 646ce0f13ecf..000000000000 --- a/examples/src/main/scala/org/apache/spark/examples/ml/StandardScalerExample.scala +++ /dev/null @@ -1,51 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. 
See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// scalastyle:off println -package org.apache.spark.examples.ml - -// $example on$ -import org.apache.spark.ml.feature.StandardScaler -// $example off$ -import org.apache.spark.sql.SQLContext -import org.apache.spark.{SparkConf, SparkContext} - -object StandardScalerExample { - def main(args: Array[String]): Unit = { - val conf = new SparkConf().setAppName("StandardScalerExample") - val sc = new SparkContext(conf) - val sqlContext = new SQLContext(sc) - - // $example on$ - val dataFrame = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") - - val scaler = new StandardScaler() - .setInputCol("features") - .setOutputCol("scaledFeatures") - .setWithStd(true) - .setWithMean(false) - - // Compute summary statistics by fitting the StandardScaler. - val scalerModel = scaler.fit(dataFrame) - - // Normalize each feature to have unit standard deviation. - val scaledData = scalerModel.transform(dataFrame) - // $example off$ - sc.stop() - } -} -// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/StopWordsRemoverExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/StopWordsRemoverExample.scala deleted file mode 100644 index 655ffce08d3a..000000000000 --- a/examples/src/main/scala/org/apache/spark/examples/ml/StopWordsRemoverExample.scala +++ /dev/null @@ -1,48 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -// scalastyle:off println -package org.apache.spark.examples.ml - -// $example on$ -import org.apache.spark.ml.feature.StopWordsRemover -// $example off$ -import org.apache.spark.sql.SQLContext -import org.apache.spark.{SparkConf, SparkContext} - -object StopWordsRemoverExample { - def main(args: Array[String]): Unit = { - val conf = new SparkConf().setAppName("StopWordsRemoverExample") - val sc = new SparkContext(conf) - val sqlContext = new SQLContext(sc) - - // $example on$ - val remover = new StopWordsRemover() - .setInputCol("raw") - .setOutputCol("filtered") - - val dataSet = sqlContext.createDataFrame(Seq( - (0, Seq("I", "saw", "the", "red", "baloon")), - (1, Seq("Mary", "had", "a", "little", "lamb")) - )).toDF("id", "raw") - - remover.transform(dataSet).show() - // $example off$ - sc.stop() - } -} -// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/StringIndexerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/StringIndexerExample.scala deleted file mode 100644 index 1be8a5f33f7c..000000000000 --- a/examples/src/main/scala/org/apache/spark/examples/ml/StringIndexerExample.scala +++ /dev/null @@ -1,49 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// scalastyle:off println -package org.apache.spark.examples.ml - -// $example on$ -import org.apache.spark.ml.feature.StringIndexer -// $example off$ -import org.apache.spark.sql.SQLContext -import org.apache.spark.{SparkConf, SparkContext} - -object StringIndexerExample { - def main(args: Array[String]): Unit = { - val conf = new SparkConf().setAppName("StringIndexerExample") - val sc = new SparkContext(conf) - val sqlContext = new SQLContext(sc) - - // $example on$ - val df = sqlContext.createDataFrame( - Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")) - ).toDF("id", "category") - - val indexer = new StringIndexer() - .setInputCol("category") - .setOutputCol("categoryIndex") - - val indexed = indexer.fit(df).transform(df) - indexed.show() - // $example off$ - sc.stop() - } -} -// scalastyle:on println - diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/TokenizerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/TokenizerExample.scala deleted file mode 100644 index 01e0d1388a2f..000000000000 --- a/examples/src/main/scala/org/apache/spark/examples/ml/TokenizerExample.scala +++ /dev/null @@ -1,54 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. 
- * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// scalastyle:off println -package org.apache.spark.examples.ml - -// $example on$ -import org.apache.spark.ml.feature.{RegexTokenizer, Tokenizer} -// $example off$ -import org.apache.spark.sql.SQLContext -import org.apache.spark.{SparkConf, SparkContext} - -object TokenizerExample { - def main(args: Array[String]): Unit = { - val conf = new SparkConf().setAppName("TokenizerExample") - val sc = new SparkContext(conf) - val sqlContext = new SQLContext(sc) - - // $example on$ - val sentenceDataFrame = sqlContext.createDataFrame(Seq( - (0, "Hi I heard about Spark"), - (1, "I wish Java could use case classes"), - (2, "Logistic,regression,models,are,neat") - )).toDF("label", "sentence") - - val tokenizer = new Tokenizer().setInputCol("sentence").setOutputCol("words") - val regexTokenizer = new RegexTokenizer() - .setInputCol("sentence") - .setOutputCol("words") - .setPattern("\\W") // alternatively .setPattern("\\w+").setGaps(false) - - val tokenized = tokenizer.transform(sentenceDataFrame) - tokenized.select("words", "label").take(3).foreach(println) - val regexTokenized = regexTokenizer.transform(sentenceDataFrame) - regexTokenized.select("words", "label").take(3).foreach(println) - // $example off$ - sc.stop() - } -} -// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/VectorAssemblerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/VectorAssemblerExample.scala deleted file mode 100644 index d527924419f8..000000000000 --- a/examples/src/main/scala/org/apache/spark/examples/ml/VectorAssemblerExample.scala +++ /dev/null @@ -1,49 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -// scalastyle:off println -package org.apache.spark.examples.ml - -// $example on$ -import org.apache.spark.ml.feature.VectorAssembler -import org.apache.spark.mllib.linalg.Vectors -// $example off$ -import org.apache.spark.sql.SQLContext -import org.apache.spark.{SparkConf, SparkContext} - -object VectorAssemblerExample { - def main(args: Array[String]): Unit = { - val conf = new SparkConf().setAppName("VectorAssemblerExample") - val sc = new SparkContext(conf) - val sqlContext = new SQLContext(sc) - - // $example on$ - val dataset = sqlContext.createDataFrame( - Seq((0, 18, 1.0, Vectors.dense(0.0, 10.0, 0.5), 1.0)) - ).toDF("id", "hour", "mobile", "userFeatures", "clicked") - - val assembler = new VectorAssembler() - .setInputCols(Array("hour", "mobile", "userFeatures")) - .setOutputCol("features") - - val output = assembler.transform(dataset) - println(output.select("features", "clicked").first()) - // $example off$ - sc.stop() - } -} -// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/VectorIndexerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/VectorIndexerExample.scala deleted file mode 100644 index 14279d610fda..000000000000 --- a/examples/src/main/scala/org/apache/spark/examples/ml/VectorIndexerExample.scala +++ /dev/null @@ -1,53 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -// scalastyle:off println -package org.apache.spark.examples.ml - -// $example on$ -import org.apache.spark.ml.feature.VectorIndexer -// $example off$ -import org.apache.spark.sql.SQLContext -import org.apache.spark.{SparkConf, SparkContext} - -object VectorIndexerExample { - def main(args: Array[String]): Unit = { - val conf = new SparkConf().setAppName("VectorIndexerExample") - val sc = new SparkContext(conf) - val sqlContext = new SQLContext(sc) - - // $example on$ - val data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") - - val indexer = new VectorIndexer() - .setInputCol("features") - .setOutputCol("indexed") - .setMaxCategories(10) - - val indexerModel = indexer.fit(data) - - val categoricalFeatures: Set[Int] = indexerModel.categoryMaps.keys.toSet - println(s"Chose ${categoricalFeatures.size} categorical features: " + - categoricalFeatures.mkString(", ")) - - // Create new column "indexed" with categorical values transformed to indices - val indexedData = indexerModel.transform(data) - // $example off$ - sc.stop() - } -} -// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/VectorSlicerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/VectorSlicerExample.scala deleted file mode 100644 index 04f19829eff8..000000000000 --- a/examples/src/main/scala/org/apache/spark/examples/ml/VectorSlicerExample.scala +++ /dev/null @@ -1,58 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -// scalastyle:off println -package org.apache.spark.examples.ml - -// $example on$ -import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NumericAttribute} -import org.apache.spark.ml.feature.VectorSlicer -import org.apache.spark.mllib.linalg.Vectors -import org.apache.spark.sql.Row -import org.apache.spark.sql.types.StructType -// $example off$ -import org.apache.spark.sql.SQLContext -import org.apache.spark.{SparkConf, SparkContext} - -object VectorSlicerExample { - def main(args: Array[String]): Unit = { - val conf = new SparkConf().setAppName("VectorSlicerExample") - val sc = new SparkContext(conf) - val sqlContext = new SQLContext(sc) - - // $example on$ - val data = Array(Row(Vectors.dense(-2.0, 2.3, 0.0))) - - val defaultAttr = NumericAttribute.defaultAttr - val attrs = Array("f1", "f2", "f3").map(defaultAttr.withName) - val attrGroup = new AttributeGroup("userFeatures", attrs.asInstanceOf[Array[Attribute]]) - - val dataRDD = sc.parallelize(data) - val dataset = sqlContext.createDataFrame(dataRDD, StructType(Array(attrGroup.toStructField()))) - - val slicer = new VectorSlicer().setInputCol("userFeatures").setOutputCol("features") - - slicer.setIndices(Array(1)).setNames(Array("f3")) - // or slicer.setIndices(Array(1, 2)), or slicer.setNames(Array("f2", "f3")) - - val output = slicer.transform(dataset) - println(output.select("userFeatures", "features").first()) - // $example off$ - sc.stop() - } -} -// scalastyle:on println From e3735ce1602826f0a8e0ca9e08730923843449ee Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Tue, 8 Dec 2015 14:34:47 +0000 Subject: [PATCH 1695/1824] [SPARK-11652][CORE] Remote code execution with InvokerTransformer Fix commons-collection group ID to commons-collections for version 3.x Patches earlier PR at https://github.com/apache/spark/pull/9731 Author: Sean Owen Closes #10198 from srowen/SPARK-11652.2. --- pom.xml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pom.xml b/pom.xml index ae2ff8878b0a..5daca03f6143 100644 --- a/pom.xml +++ b/pom.xml @@ -478,7 +478,7 @@ ${commons.math3.version}
    - org.apache.commons + commons-collections commons-collections ${commons.collections.version} From 6cb06e8711fd6ac10c57faeb94bc323cae1cef27 Mon Sep 17 00:00:00 2001 From: Xin Ren Date: Tue, 8 Dec 2015 11:44:51 -0600 Subject: [PATCH 1696/1824] [SPARK-11155][WEB UI] Stage summary json should include stage duration The json endpoint for stages doesn't include information on the stage duration that is present in the UI. This looks like a simple oversight, they should be included. eg., the metrics should be included at api/v1/applications//stages. Metrics I've added are: submissionTime, firstTaskLaunchedTime and completionTime Author: Xin Ren Closes #10107 from keypointt/SPARK-11155. --- .../status/api/v1/AllStagesResource.scala | 14 ++++- .../org/apache/spark/status/api/v1/api.scala | 3 + .../complete_stage_list_json_expectation.json | 11 +++- .../failed_stage_list_json_expectation.json | 5 +- .../one_stage_attempt_json_expectation.json | 5 +- .../one_stage_json_expectation.json | 5 +- .../stage_list_json_expectation.json | 14 ++++- ...ist_with_accumulable_json_expectation.json | 5 +- ...age_with_accumulable_json_expectation.json | 5 +- .../api/v1/AllStagesResourceSuite.scala | 62 +++++++++++++++++++ project/MimaExcludes.scala | 4 +- 11 files changed, 124 insertions(+), 9 deletions(-) create mode 100644 core/src/test/scala/org/apache/spark/status/api/v1/AllStagesResourceSuite.scala diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala index 24a0b5220695..31b4dd7c0f42 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala @@ -17,8 +17,8 @@ package org.apache.spark.status.api.v1 import java.util.{Arrays, Date, List => JList} -import javax.ws.rs.{GET, PathParam, Produces, QueryParam} import javax.ws.rs.core.MediaType +import javax.ws.rs.{GET, Produces, QueryParam} import org.apache.spark.executor.{InputMetrics => InternalInputMetrics, OutputMetrics => InternalOutputMetrics, ShuffleReadMetrics => InternalShuffleReadMetrics, ShuffleWriteMetrics => InternalShuffleWriteMetrics, TaskMetrics => InternalTaskMetrics} import org.apache.spark.scheduler.{AccumulableInfo => InternalAccumulableInfo, StageInfo} @@ -59,6 +59,15 @@ private[v1] object AllStagesResource { stageUiData: StageUIData, includeDetails: Boolean): StageData = { + val taskLaunchTimes = stageUiData.taskData.values.map(_.taskInfo.launchTime).filter(_ > 0) + + val firstTaskLaunchedTime: Option[Date] = + if (taskLaunchTimes.nonEmpty) { + Some(new Date(taskLaunchTimes.min)) + } else { + None + } + val taskData = if (includeDetails) { Some(stageUiData.taskData.map { case (k, v) => k -> convertTaskData(v) } ) } else { @@ -92,6 +101,9 @@ private[v1] object AllStagesResource { numCompleteTasks = stageUiData.numCompleteTasks, numFailedTasks = stageUiData.numFailedTasks, executorRunTime = stageUiData.executorRunTime, + submissionTime = stageInfo.submissionTime.map(new Date(_)), + firstTaskLaunchedTime, + completionTime = stageInfo.completionTime.map(new Date(_)), inputBytes = stageUiData.inputBytes, inputRecords = stageUiData.inputRecords, outputBytes = stageUiData.outputBytes, diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala index baddfc50c1a4..5feb1dc2e5b7 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala +++ 
b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala @@ -120,6 +120,9 @@ class StageData private[spark]( val numFailedTasks: Int, val executorRunTime: Long, + val submissionTime: Option[Date], + val firstTaskLaunchedTime: Option[Date], + val completionTime: Option[Date], val inputBytes: Long, val inputRecords: Long, diff --git a/core/src/test/resources/HistoryServerExpectations/complete_stage_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/complete_stage_list_json_expectation.json index 31ac9beea878..8f8067f86d57 100644 --- a/core/src/test/resources/HistoryServerExpectations/complete_stage_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/complete_stage_list_json_expectation.json @@ -6,6 +6,9 @@ "numCompleteTasks" : 8, "numFailedTasks" : 0, "executorRunTime" : 162, + "submissionTime" : "2015-02-03T16:43:07.191GMT", + "firstTaskLaunchedTime" : "2015-02-03T16:43:07.191GMT", + "completionTime" : "2015-02-03T16:43:07.226GMT", "inputBytes" : 160, "inputRecords" : 0, "outputBytes" : 0, @@ -28,6 +31,9 @@ "numCompleteTasks" : 8, "numFailedTasks" : 0, "executorRunTime" : 3476, + "submissionTime" : "2015-02-03T16:43:05.829GMT", + "firstTaskLaunchedTime" : "2015-02-03T16:43:05.829GMT", + "completionTime" : "2015-02-03T16:43:06.286GMT", "inputBytes" : 28000128, "inputRecords" : 0, "outputBytes" : 0, @@ -50,6 +56,9 @@ "numCompleteTasks" : 8, "numFailedTasks" : 0, "executorRunTime" : 4338, + "submissionTime" : "2015-02-03T16:43:04.228GMT", + "firstTaskLaunchedTime" : "2015-02-03T16:43:04.234GMT", + "completionTime" : "2015-02-03T16:43:04.819GMT", "inputBytes" : 0, "inputRecords" : 0, "outputBytes" : 0, @@ -64,4 +73,4 @@ "details" : "org.apache.spark.rdd.RDD.count(RDD.scala:910)\n$line9.$read$$iwC$$iwC$$iwC$$iwC.(:15)\n$line9.$read$$iwC$$iwC$$iwC.(:20)\n$line9.$read$$iwC$$iwC.(:22)\n$line9.$read$$iwC.(:24)\n$line9.$read.(:26)\n$line9.$read$.(:30)\n$line9.$read$.()\n$line9.$eval$.(:7)\n$line9.$eval$.()\n$line9.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:57)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:606)\norg.apache.spark.repl.SparkIMain$ReadEvalPrint.call(SparkIMain.scala:852)\norg.apache.spark.repl.SparkIMain$Request.loadAndRun(SparkIMain.scala:1125)\norg.apache.spark.repl.SparkIMain.loadAndRunReq$1(SparkIMain.scala:674)\norg.apache.spark.repl.SparkIMain.interpret(SparkIMain.scala:705)\norg.apache.spark.repl.SparkIMain.interpret(SparkIMain.scala:669)", "schedulingPool" : "default", "accumulatorUpdates" : [ ] -} ] \ No newline at end of file +} ] diff --git a/core/src/test/resources/HistoryServerExpectations/failed_stage_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/failed_stage_list_json_expectation.json index bff6a4f69d07..08b692eda802 100644 --- a/core/src/test/resources/HistoryServerExpectations/failed_stage_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/failed_stage_list_json_expectation.json @@ -6,6 +6,9 @@ "numCompleteTasks" : 7, "numFailedTasks" : 1, "executorRunTime" : 278, + "submissionTime" : "2015-02-03T16:43:06.296GMT", + "firstTaskLaunchedTime" : "2015-02-03T16:43:06.296GMT", + "completionTime" : "2015-02-03T16:43:06.347GMT", "inputBytes" : 0, "inputRecords" : 0, "outputBytes" : 0, @@ -20,4 +23,4 @@ "details" : 
"org.apache.spark.rdd.RDD.count(RDD.scala:910)\n$line11.$read$$iwC$$iwC$$iwC$$iwC.(:20)\n$line11.$read$$iwC$$iwC$$iwC.(:25)\n$line11.$read$$iwC$$iwC.(:27)\n$line11.$read$$iwC.(:29)\n$line11.$read.(:31)\n$line11.$read$.(:35)\n$line11.$read$.()\n$line11.$eval$.(:7)\n$line11.$eval$.()\n$line11.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:57)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:606)\norg.apache.spark.repl.SparkIMain$ReadEvalPrint.call(SparkIMain.scala:852)\norg.apache.spark.repl.SparkIMain$Request.loadAndRun(SparkIMain.scala:1125)\norg.apache.spark.repl.SparkIMain.loadAndRunReq$1(SparkIMain.scala:674)\norg.apache.spark.repl.SparkIMain.interpret(SparkIMain.scala:705)\norg.apache.spark.repl.SparkIMain.interpret(SparkIMain.scala:669)", "schedulingPool" : "default", "accumulatorUpdates" : [ ] -} ] \ No newline at end of file +} ] diff --git a/core/src/test/resources/HistoryServerExpectations/one_stage_attempt_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/one_stage_attempt_json_expectation.json index 111cb8163eb3..b07011d4f113 100644 --- a/core/src/test/resources/HistoryServerExpectations/one_stage_attempt_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/one_stage_attempt_json_expectation.json @@ -6,6 +6,9 @@ "numCompleteTasks" : 8, "numFailedTasks" : 0, "executorRunTime" : 3476, + "submissionTime" : "2015-02-03T16:43:05.829GMT", + "firstTaskLaunchedTime" : "2015-02-03T16:43:05.829GMT", + "completionTime" : "2015-02-03T16:43:06.286GMT", "inputBytes" : 28000128, "inputRecords" : 0, "outputBytes" : 0, @@ -267,4 +270,4 @@ "diskBytesSpilled" : 0 } } -} \ No newline at end of file +} diff --git a/core/src/test/resources/HistoryServerExpectations/one_stage_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/one_stage_json_expectation.json index ef339f89afa4..2f71520549e1 100644 --- a/core/src/test/resources/HistoryServerExpectations/one_stage_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/one_stage_json_expectation.json @@ -6,6 +6,9 @@ "numCompleteTasks" : 8, "numFailedTasks" : 0, "executorRunTime" : 3476, + "submissionTime" : "2015-02-03T16:43:05.829GMT", + "firstTaskLaunchedTime" : "2015-02-03T16:43:05.829GMT", + "completionTime" : "2015-02-03T16:43:06.286GMT", "inputBytes" : 28000128, "inputRecords" : 0, "outputBytes" : 0, @@ -267,4 +270,4 @@ "diskBytesSpilled" : 0 } } -} ] \ No newline at end of file +} ] diff --git a/core/src/test/resources/HistoryServerExpectations/stage_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_list_json_expectation.json index 056fac708859..5b957ed54955 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_list_json_expectation.json @@ -6,6 +6,9 @@ "numCompleteTasks" : 8, "numFailedTasks" : 0, "executorRunTime" : 162, + "submissionTime" : "2015-02-03T16:43:07.191GMT", + "firstTaskLaunchedTime" : "2015-02-03T16:43:07.191GMT", + "completionTime" : "2015-02-03T16:43:07.226GMT", "inputBytes" : 160, "inputRecords" : 0, "outputBytes" : 0, @@ -28,6 +31,9 @@ "numCompleteTasks" : 8, "numFailedTasks" : 0, "executorRunTime" : 3476, + "submissionTime" : "2015-02-03T16:43:05.829GMT", + "firstTaskLaunchedTime" : "2015-02-03T16:43:05.829GMT", + 
"completionTime" : "2015-02-03T16:43:06.286GMT", "inputBytes" : 28000128, "inputRecords" : 0, "outputBytes" : 0, @@ -50,6 +56,9 @@ "numCompleteTasks" : 8, "numFailedTasks" : 0, "executorRunTime" : 4338, + "submissionTime" : "2015-02-03T16:43:04.228GMT", + "firstTaskLaunchedTime" : "2015-02-03T16:43:04.234GMT", + "completionTime" : "2015-02-03T16:43:04.819GMT", "inputBytes" : 0, "inputRecords" : 0, "outputBytes" : 0, @@ -72,6 +81,9 @@ "numCompleteTasks" : 7, "numFailedTasks" : 1, "executorRunTime" : 278, + "submissionTime" : "2015-02-03T16:43:06.296GMT", + "firstTaskLaunchedTime" : "2015-02-03T16:43:06.296GMT", + "completionTime" : "2015-02-03T16:43:06.347GMT", "inputBytes" : 0, "inputRecords" : 0, "outputBytes" : 0, @@ -86,4 +98,4 @@ "details" : "org.apache.spark.rdd.RDD.count(RDD.scala:910)\n$line11.$read$$iwC$$iwC$$iwC$$iwC.(:20)\n$line11.$read$$iwC$$iwC$$iwC.(:25)\n$line11.$read$$iwC$$iwC.(:27)\n$line11.$read$$iwC.(:29)\n$line11.$read.(:31)\n$line11.$read$.(:35)\n$line11.$read$.()\n$line11.$eval$.(:7)\n$line11.$eval$.()\n$line11.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:57)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:606)\norg.apache.spark.repl.SparkIMain$ReadEvalPrint.call(SparkIMain.scala:852)\norg.apache.spark.repl.SparkIMain$Request.loadAndRun(SparkIMain.scala:1125)\norg.apache.spark.repl.SparkIMain.loadAndRunReq$1(SparkIMain.scala:674)\norg.apache.spark.repl.SparkIMain.interpret(SparkIMain.scala:705)\norg.apache.spark.repl.SparkIMain.interpret(SparkIMain.scala:669)", "schedulingPool" : "default", "accumulatorUpdates" : [ ] -} ] \ No newline at end of file +} ] diff --git a/core/src/test/resources/HistoryServerExpectations/stage_list_with_accumulable_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_list_with_accumulable_json_expectation.json index 79ccacd30969..afa425f8c27b 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_list_with_accumulable_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_list_with_accumulable_json_expectation.json @@ -6,6 +6,9 @@ "numCompleteTasks" : 8, "numFailedTasks" : 0, "executorRunTime" : 120, + "submissionTime" : "2015-03-16T19:25:36.103GMT", + "firstTaskLaunchedTime" : "2015-03-16T19:25:36.515GMT", + "completionTime" : "2015-03-16T19:25:36.579GMT", "inputBytes" : 0, "inputRecords" : 0, "outputBytes" : 0, @@ -24,4 +27,4 @@ "name" : "my counter", "value" : "5050" } ] -} ] \ No newline at end of file +} ] diff --git a/core/src/test/resources/HistoryServerExpectations/stage_with_accumulable_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_with_accumulable_json_expectation.json index 32d5731676ad..12665a152c9e 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_with_accumulable_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_with_accumulable_json_expectation.json @@ -6,6 +6,9 @@ "numCompleteTasks" : 8, "numFailedTasks" : 0, "executorRunTime" : 120, + "submissionTime" : "2015-03-16T19:25:36.103GMT", + "firstTaskLaunchedTime" : "2015-03-16T19:25:36.515GMT", + "completionTime" : "2015-03-16T19:25:36.579GMT", "inputBytes" : 0, "inputRecords" : 0, "outputBytes" : 0, @@ -239,4 +242,4 @@ "diskBytesSpilled" : 0 } } -} \ No newline at end of file +} diff --git 
a/core/src/test/scala/org/apache/spark/status/api/v1/AllStagesResourceSuite.scala b/core/src/test/scala/org/apache/spark/status/api/v1/AllStagesResourceSuite.scala new file mode 100644 index 000000000000..88817dccf349 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/status/api/v1/AllStagesResourceSuite.scala @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.status.api.v1 + +import java.util.Date + +import scala.collection.mutable.HashMap + +import org.apache.spark.SparkFunSuite +import org.apache.spark.scheduler.{StageInfo, TaskInfo, TaskLocality} +import org.apache.spark.ui.jobs.UIData.{StageUIData, TaskUIData} + +class AllStagesResourceSuite extends SparkFunSuite { + + def getFirstTaskLaunchTime(taskLaunchTimes: Seq[Long]): Option[Date] = { + val tasks = new HashMap[Long, TaskUIData] + taskLaunchTimes.zipWithIndex.foreach { case (time, idx) => + tasks(idx.toLong) = new TaskUIData( + new TaskInfo(idx, idx, 1, time, "", "", TaskLocality.ANY, false), None, None) + } + + val stageUiData = new StageUIData() + stageUiData.taskData = tasks + val status = StageStatus.ACTIVE + val stageInfo = new StageInfo( + 1, 1, "stage 1", 10, Seq.empty, Seq.empty, "details abc", Seq.empty) + val stageData = AllStagesResource.stageUiToStageData(status, stageInfo, stageUiData, false) + + stageData.firstTaskLaunchedTime + } + + test("firstTaskLaunchedTime when there are no tasks") { + val result = getFirstTaskLaunchTime(Seq()) + assert(result == None) + } + + test("firstTaskLaunchedTime when there are tasks but none launched") { + val result = getFirstTaskLaunchTime(Seq(-100L, -200L, -300L)) + assert(result == None) + } + + test("firstTaskLaunchedTime when there are tasks and some launched") { + val result = getFirstTaskLaunchTime(Seq(-100L, 1449255596000L, 1449255597000L)) + assert(result == Some(new Date(1449255596000L))) + } + +} diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index b4aa6adc3c62..685cb419ca8a 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -132,7 +132,9 @@ object MimaExcludes { ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.sql.jdbc.NoopDialect$") ) ++ Seq ( ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.status.api.v1.ApplicationInfo.this") + "org.apache.spark.status.api.v1.ApplicationInfo.this"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.status.api.v1.StageData.this") ) ++ Seq( // SPARK-11766 add toJson to Vector ProblemFilters.exclude[MissingMethodProblem]( From 75c60bf4ba91e45e76a6e27f054a1c550eb6ff94 Mon Sep 17 00:00:00 2001 From: tedyu Date: Tue, 8 Dec 2015 10:01:44 -0800 Subject: [PATCH 1697/1824] [SPARK-12074] Avoid memory copy involving 
ByteBuffer.wrap(ByteArrayOutputStream.toByteArray) SPARK-12060 fixed JavaSerializerInstance.serialize This PR applies the same technique on two other classes. zsxwing Author: tedyu Closes #10177 from tedyu/master. --- core/src/main/scala/org/apache/spark/scheduler/Task.scala | 7 +++---- .../main/scala/org/apache/spark/storage/BlockManager.scala | 4 ++-- .../org/apache/spark/util/ByteBufferOutputStream.scala | 4 +++- 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index 5fe5ae8c4581..d4bc3a5c900f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -27,8 +27,7 @@ import org.apache.spark.{Accumulator, SparkEnv, TaskContextImpl, TaskContext} import org.apache.spark.executor.TaskMetrics import org.apache.spark.memory.TaskMemoryManager import org.apache.spark.serializer.SerializerInstance -import org.apache.spark.util.ByteBufferInputStream -import org.apache.spark.util.Utils +import org.apache.spark.util.{ByteBufferInputStream, ByteBufferOutputStream, Utils} /** @@ -172,7 +171,7 @@ private[spark] object Task { serializer: SerializerInstance) : ByteBuffer = { - val out = new ByteArrayOutputStream(4096) + val out = new ByteBufferOutputStream(4096) val dataOut = new DataOutputStream(out) // Write currentFiles @@ -193,7 +192,7 @@ private[spark] object Task { dataOut.flush() val taskBytes = serializer.serialize(task) Utils.writeByteBuffer(taskBytes, out) - ByteBuffer.wrap(out.toByteArray) + out.toByteBuffer } /** diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index ab0007fb7899..ed05143877e2 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -1202,9 +1202,9 @@ private[spark] class BlockManager( blockId: BlockId, values: Iterator[Any], serializer: Serializer = defaultSerializer): ByteBuffer = { - val byteStream = new ByteArrayOutputStream(4096) + val byteStream = new ByteBufferOutputStream(4096) dataSerializeStream(blockId, byteStream, values, serializer) - ByteBuffer.wrap(byteStream.toByteArray) + byteStream.toByteBuffer } /** diff --git a/core/src/main/scala/org/apache/spark/util/ByteBufferOutputStream.scala b/core/src/main/scala/org/apache/spark/util/ByteBufferOutputStream.scala index 92e45224db81..8527e3ae692e 100644 --- a/core/src/main/scala/org/apache/spark/util/ByteBufferOutputStream.scala +++ b/core/src/main/scala/org/apache/spark/util/ByteBufferOutputStream.scala @@ -23,7 +23,9 @@ import java.nio.ByteBuffer /** * Provide a zero-copy way to convert data in ByteArrayOutputStream to ByteBuffer */ -private[spark] class ByteBufferOutputStream extends ByteArrayOutputStream { +private[spark] class ByteBufferOutputStream(capacity: Int) extends ByteArrayOutputStream(capacity) { + + def this() = this(32) def toByteBuffer: ByteBuffer = { return ByteBuffer.wrap(buf, 0, count) From 381f17b540d92507cc07adf18bce8bc7e5ca5407 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 8 Dec 2015 10:13:40 -0800 Subject: [PATCH 1698/1824] [SPARK-12201][SQL] add type coercion rule for greatest/least checked with hive, greatest/least should cast their children to a tightest common type, i.e. 
`(int, long) => long`, `(int, string) => error`, `(decimal(10,5), decimal(5, 10)) => error` Author: Wenchen Fan Closes #10196 from cloud-fan/type-coercion. --- .../catalyst/analysis/HiveTypeCoercion.scala | 14 +++++++++++ .../ExpressionTypeCheckingSuite.scala | 10 ++++++++ .../analysis/HiveTypeCoercionSuite.scala | 23 +++++++++++++++++++ 3 files changed, 47 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 29502a59915f..dbcbd6854b47 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -594,6 +594,20 @@ object HiveTypeCoercion { case None => c } + case g @ Greatest(children) if children.map(_.dataType).distinct.size > 1 => + val types = children.map(_.dataType) + findTightestCommonType(types) match { + case Some(finalDataType) => Greatest(children.map(Cast(_, finalDataType))) + case None => g + } + + case l @ Least(children) if children.map(_.dataType).distinct.size > 1 => + val types = children.map(_.dataType) + findTightestCommonType(types) match { + case Some(finalDataType) => Least(children.map(Cast(_, finalDataType))) + case None => l + } + case NaNvl(l, r) if l.dataType == DoubleType && r.dataType == FloatType => NaNvl(l, Cast(r, DoubleType)) case NaNvl(l, r) if l.dataType == FloatType && r.dataType == DoubleType => diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala index ba1866efc84e..915c585ec91f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala @@ -32,6 +32,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { 'intField.int, 'stringField.string, 'booleanField.boolean, + 'decimalField.decimal(8, 0), 'arrayField.array(StringType), 'mapField.map(StringType, LongType)) @@ -189,4 +190,13 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { assertError(Round('intField, 'mapField), "requires int type") assertError(Round('booleanField, 'intField), "requires numeric type") } + + test("check types for Greatest/Least") { + for (operator <- Seq[(Seq[Expression] => Expression)](Greatest, Least)) { + assertError(operator(Seq('booleanField)), "requires at least 2 arguments") + assertError(operator(Seq('intField, 'stringField)), "should all have the same type") + assertError(operator(Seq('intField, 'decimalField)), "should all have the same type") + assertError(operator(Seq('mapField, 'mapField)), "does not support ordering") + } + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala index d3fafaae8993..142915056f45 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala @@ -251,6 +251,29 @@ class HiveTypeCoercionSuite extends PlanTest { :: Nil)) } + test("greatest/least cast") { + for (operator <- Seq[(Seq[Expression] => 
Expression)](Greatest, Least)) { + ruleTest(HiveTypeCoercion.FunctionArgumentConversion, + operator(Literal(1.0) + :: Literal(1) + :: Literal.create(1.0, FloatType) + :: Nil), + operator(Cast(Literal(1.0), DoubleType) + :: Cast(Literal(1), DoubleType) + :: Cast(Literal.create(1.0, FloatType), DoubleType) + :: Nil)) + ruleTest(HiveTypeCoercion.FunctionArgumentConversion, + operator(Literal(1L) + :: Literal(1) + :: Literal(new java.math.BigDecimal("1000000000000000000000")) + :: Nil), + operator(Cast(Literal(1L), DecimalType(22, 0)) + :: Cast(Literal(1), DecimalType(22, 0)) + :: Cast(Literal(new java.math.BigDecimal("1000000000000000000000")), DecimalType(22, 0)) + :: Nil)) + } + } + test("nanvl casts") { ruleTest(HiveTypeCoercion.FunctionArgumentConversion, NaNvl(Literal.create(1.0, FloatType), Literal.create(1.0, DoubleType)), From c0b13d5565c45ae2acbe8cfb17319c92b6a634e4 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 8 Dec 2015 10:15:58 -0800 Subject: [PATCH 1699/1824] [SPARK-12195][SQL] Adding BigDecimal, Date and Timestamp into Encoder This PR is to add three more data types into Encoder, including `BigDecimal`, `Date` and `Timestamp`. marmbrus cloud-fan rxin Could you take a quick look at these three types? Not sure if it can be merged to 1.6. Thank you very much! Author: gatorsmile Closes #10188 from gatorsmile/dataTypesinEncoder. --- .../scala/org/apache/spark/sql/Encoder.scala | 18 ++++++++++++++++++ .../org/apache/spark/sql/JavaDatasetSuite.java | 17 +++++++++++++++++ 2 files changed, 35 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala index c40061ae0aaf..3ca5ade7f30f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala @@ -97,6 +97,24 @@ object Encoders { */ def STRING: Encoder[java.lang.String] = ExpressionEncoder() + /** + * An encoder for nullable decimal type. + * @since 1.6.0 + */ + def DECIMAL: Encoder[java.math.BigDecimal] = ExpressionEncoder() + + /** + * An encoder for nullable date type. + * @since 1.6.0 + */ + def DATE: Encoder[java.sql.Date] = ExpressionEncoder() + + /** + * An encoder for nullable timestamp type. + * @since 1.6.0 + */ + def TIMESTAMP: Encoder[java.sql.Timestamp] = ExpressionEncoder() + /** * Creates an encoder for Java Bean of type T. 
* diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index ae47f4fe0e23..383a2d0badb5 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -18,6 +18,9 @@ package test.org.apache.spark.sql; import java.io.Serializable; +import java.math.BigDecimal; +import java.sql.Date; +import java.sql.Timestamp; import java.util.*; import scala.Tuple2; @@ -385,6 +388,20 @@ public void testNestedTupleEncoder() { Assert.assertEquals(data3, ds3.collectAsList()); } + @Test + public void testPrimitiveEncoder() { + Encoder> encoder = + Encoders.tuple(Encoders.DOUBLE(), Encoders.DECIMAL(), Encoders.DATE(), Encoders.TIMESTAMP(), + Encoders.FLOAT()); + List> data = + Arrays.asList(new Tuple5( + 1.7976931348623157E308, new BigDecimal("0.922337203685477589"), + Date.valueOf("1970-01-01"), new Timestamp(System.currentTimeMillis()), Float.MAX_VALUE)); + Dataset> ds = + context.createDataset(data, encoder); + Assert.assertEquals(data, ds.collectAsList()); + } + @Test public void testTypedAggregation() { Encoder> encoder = Encoders.tuple(Encoders.STRING(), Encoders.INT()); From 5d96a710a5ed543ec81e383620fc3b2a808b26a1 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 8 Dec 2015 10:25:57 -0800 Subject: [PATCH 1700/1824] [SPARK-12188][SQL] Code refactoring and comment correction in Dataset APIs This PR contains the following updates: - Created a new private variable `boundTEncoder` that can be shared by multiple functions, `RDD`, `select` and `collect`. - Replaced all the `queryExecution.analyzed` by the function call `logicalPlan` - A few API comments are using wrong class names (e.g., `DataFrame`) or parameter names (e.g., `n`) - A few API descriptions are wrong. (e.g., `mapPartitions`) marmbrus rxin cloud-fan Could you take a look and check if they are appropriate? Thank you! Author: gatorsmile Closes #10184 from gatorsmile/datasetClean. --- .../scala/org/apache/spark/sql/Dataset.scala | 80 +++++++++---------- 1 file changed, 40 insertions(+), 40 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index d6bb1d2ad8e5..3bd18a14f9e8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -67,15 +67,21 @@ class Dataset[T] private[sql]( tEncoder: Encoder[T]) extends Queryable with Serializable { /** - * An unresolved version of the internal encoder for the type of this dataset. This one is marked - * implicit so that we can use it when constructing new [[Dataset]] objects that have the same - * object type (that will be possibly resolved to a different schema). + * An unresolved version of the internal encoder for the type of this [[Dataset]]. This one is + * marked implicit so that we can use it when constructing new [[Dataset]] objects that have the + * same object type (that will be possibly resolved to a different schema). */ private[sql] implicit val unresolvedTEncoder: ExpressionEncoder[T] = encoderFor(tEncoder) /** The encoder for this [[Dataset]] that has been resolved to its output schema. 
*/ private[sql] val resolvedTEncoder: ExpressionEncoder[T] = - unresolvedTEncoder.resolve(queryExecution.analyzed.output, OuterScopes.outerScopes) + unresolvedTEncoder.resolve(logicalPlan.output, OuterScopes.outerScopes) + + /** + * The encoder where the expressions used to construct an object from an input row have been + * bound to the ordinals of the given schema. + */ + private[sql] val boundTEncoder = resolvedTEncoder.bind(logicalPlan.output) private implicit def classTag = resolvedTEncoder.clsTag @@ -89,7 +95,7 @@ class Dataset[T] private[sql]( override def schema: StructType = resolvedTEncoder.schema /** - * Prints the schema of the underlying [[DataFrame]] to the console in a nice tree format. + * Prints the schema of the underlying [[Dataset]] to the console in a nice tree format. * @since 1.6.0 */ override def printSchema(): Unit = toDF().printSchema() @@ -111,7 +117,7 @@ class Dataset[T] private[sql]( * ************* */ /** - * Returns a new `Dataset` where each record has been mapped on to the specified type. The + * Returns a new [[Dataset]] where each record has been mapped on to the specified type. The * method used to map columns depend on the type of `U`: * - When `U` is a class, fields for the class will be mapped to columns of the same name * (case sensitivity is determined by `spark.sql.caseSensitive`) @@ -145,7 +151,7 @@ class Dataset[T] private[sql]( def toDF(): DataFrame = DataFrame(sqlContext, logicalPlan) /** - * Returns this Dataset. + * Returns this [[Dataset]]. * @since 1.6.0 */ // This is declared with parentheses to prevent the Scala compiler from treating @@ -153,15 +159,12 @@ class Dataset[T] private[sql]( def toDS(): Dataset[T] = this /** - * Converts this Dataset to an RDD. + * Converts this [[Dataset]] to an [[RDD]]. * @since 1.6.0 */ def rdd: RDD[T] = { - val tEnc = resolvedTEncoder - val input = queryExecution.analyzed.output queryExecution.toRdd.mapPartitions { iter => - val bound = tEnc.bind(input) - iter.map(bound.fromRow) + iter.map(boundTEncoder.fromRow) } } @@ -189,7 +192,7 @@ class Dataset[T] private[sql]( def show(numRows: Int): Unit = show(numRows, truncate = true) /** - * Displays the top 20 rows of [[DataFrame]] in a tabular form. Strings more than 20 characters + * Displays the top 20 rows of [[Dataset]] in a tabular form. Strings more than 20 characters * will be truncated, and all cells will be aligned right. * * @since 1.6.0 @@ -197,7 +200,7 @@ class Dataset[T] private[sql]( def show(): Unit = show(20) /** - * Displays the top 20 rows of [[DataFrame]] in a tabular form. + * Displays the top 20 rows of [[Dataset]] in a tabular form. * * @param truncate Whether truncate long strings. If true, strings more than 20 characters will * be truncated and all cells will be aligned right @@ -207,7 +210,7 @@ class Dataset[T] private[sql]( def show(truncate: Boolean): Unit = show(20, truncate) /** - * Displays the [[DataFrame]] in a tabular form. For example: + * Displays the [[Dataset]] in a tabular form. For example: * {{{ * year month AVG('Adj Close) MAX('Adj Close) * 1980 12 0.503218 0.595103 @@ -291,7 +294,7 @@ class Dataset[T] private[sql]( /** * (Scala-specific) - * Returns a new [[Dataset]] that contains the result of applying `func` to each element. + * Returns a new [[Dataset]] that contains the result of applying `func` to each partition. 
* @since 1.6.0 */ def mapPartitions[U : Encoder](func: Iterator[T] => Iterator[U]): Dataset[U] = { @@ -307,7 +310,7 @@ class Dataset[T] private[sql]( /** * (Java-specific) - * Returns a new [[Dataset]] that contains the result of applying `func` to each element. + * Returns a new [[Dataset]] that contains the result of applying `func` to each partition. * @since 1.6.0 */ def mapPartitions[U](f: MapPartitionsFunction[T, U], encoder: Encoder[U]): Dataset[U] = { @@ -341,28 +344,28 @@ class Dataset[T] private[sql]( /** * (Scala-specific) - * Runs `func` on each element of this Dataset. + * Runs `func` on each element of this [[Dataset]]. * @since 1.6.0 */ def foreach(func: T => Unit): Unit = rdd.foreach(func) /** * (Java-specific) - * Runs `func` on each element of this Dataset. + * Runs `func` on each element of this [[Dataset]]. * @since 1.6.0 */ def foreach(func: ForeachFunction[T]): Unit = foreach(func.call(_)) /** * (Scala-specific) - * Runs `func` on each partition of this Dataset. + * Runs `func` on each partition of this [[Dataset]]. * @since 1.6.0 */ def foreachPartition(func: Iterator[T] => Unit): Unit = rdd.foreachPartition(func) /** * (Java-specific) - * Runs `func` on each partition of this Dataset. + * Runs `func` on each partition of this [[Dataset]]. * @since 1.6.0 */ def foreachPartition(func: ForeachPartitionFunction[T]): Unit = @@ -374,7 +377,7 @@ class Dataset[T] private[sql]( /** * (Scala-specific) - * Reduces the elements of this Dataset using the specified binary function. The given function + * Reduces the elements of this [[Dataset]] using the specified binary function. The given `func` * must be commutative and associative or the result may be non-deterministic. * @since 1.6.0 */ @@ -382,7 +385,7 @@ class Dataset[T] private[sql]( /** * (Java-specific) - * Reduces the elements of this Dataset using the specified binary function. The given function + * Reduces the elements of this Dataset using the specified binary function. The given `func` * must be commutative and associative or the result may be non-deterministic. * @since 1.6.0 */ @@ -390,11 +393,11 @@ class Dataset[T] private[sql]( /** * (Scala-specific) - * Returns a [[GroupedDataset]] where the data is grouped by the given key function. + * Returns a [[GroupedDataset]] where the data is grouped by the given key `func`. * @since 1.6.0 */ def groupBy[K : Encoder](func: T => K): GroupedDataset[K, T] = { - val inputPlan = queryExecution.analyzed + val inputPlan = logicalPlan val withGroupingKey = AppendColumns(func, resolvedTEncoder, inputPlan) val executed = sqlContext.executePlan(withGroupingKey) @@ -429,18 +432,18 @@ class Dataset[T] private[sql]( /** * (Java-specific) - * Returns a [[GroupedDataset]] where the data is grouped by the given key function. + * Returns a [[GroupedDataset]] where the data is grouped by the given key `func`. * @since 1.6.0 */ - def groupBy[K](f: MapFunction[T, K], encoder: Encoder[K]): GroupedDataset[K, T] = - groupBy(f.call(_))(encoder) + def groupBy[K](func: MapFunction[T, K], encoder: Encoder[K]): GroupedDataset[K, T] = + groupBy(func.call(_))(encoder) /* ****************** * * Typed Relational * * ****************** */ /** - * Selects a set of column based expressions. + * Returns a new [[DataFrame]] by selecting a set of column based expressions. 
* {{{ * df.select($"colA", $"colB" + 1) * }}} @@ -464,8 +467,8 @@ class Dataset[T] private[sql]( sqlContext, Project( c1.withInputType( - resolvedTEncoder.bind(queryExecution.analyzed.output), - queryExecution.analyzed.output).named :: Nil, + boundTEncoder, + logicalPlan.output).named :: Nil, logicalPlan)) } @@ -477,7 +480,7 @@ class Dataset[T] private[sql]( protected def selectUntyped(columns: TypedColumn[_, _]*): Dataset[_] = { val encoders = columns.map(_.encoder) val namedColumns = - columns.map(_.withInputType(resolvedTEncoder, queryExecution.analyzed.output).named) + columns.map(_.withInputType(resolvedTEncoder, logicalPlan.output).named) val execution = new QueryExecution(sqlContext, Project(namedColumns, logicalPlan)) new Dataset(sqlContext, execution, ExpressionEncoder.tuple(encoders)) @@ -654,7 +657,7 @@ class Dataset[T] private[sql]( * Returns an array that contains all the elements in this [[Dataset]]. * * Running collect requires moving all the data into the application's driver process, and - * doing so on a very large dataset can crash the driver process with OutOfMemoryError. + * doing so on a very large [[Dataset]] can crash the driver process with OutOfMemoryError. * * For Java API, use [[collectAsList]]. * @since 1.6.0 @@ -662,17 +665,14 @@ class Dataset[T] private[sql]( def collect(): Array[T] = { // This is different from Dataset.rdd in that it collects Rows, and then runs the encoders // to convert the rows into objects of type T. - val tEnc = resolvedTEncoder - val input = queryExecution.analyzed.output - val bound = tEnc.bind(input) - queryExecution.toRdd.map(_.copy()).collect().map(bound.fromRow) + queryExecution.toRdd.map(_.copy()).collect().map(boundTEncoder.fromRow) } /** * Returns an array that contains all the elements in this [[Dataset]]. * * Running collect requires moving all the data into the application's driver process, and - * doing so on a very large dataset can crash the driver process with OutOfMemoryError. + * doing so on a very large [[Dataset]] can crash the driver process with OutOfMemoryError. * * For Java API, use [[collectAsList]]. * @since 1.6.0 @@ -683,7 +683,7 @@ class Dataset[T] private[sql]( * Returns the first `num` elements of this [[Dataset]] as an array. * * Running take requires moving data into the application's driver process, and doing so with - * a very large `n` can crash the driver process with OutOfMemoryError. + * a very large `num` can crash the driver process with OutOfMemoryError. * @since 1.6.0 */ def take(num: Int): Array[T] = withPlan(Limit(Literal(num), _)).collect() @@ -692,7 +692,7 @@ class Dataset[T] private[sql]( * Returns the first `num` elements of this [[Dataset]] as an array. * * Running take requires moving data into the application's driver process, and doing so with - * a very large `n` can crash the driver process with OutOfMemoryError. + * a very large `num` can crash the driver process with OutOfMemoryError. * @since 1.6.0 */ def takeAsList(num: Int): java.util.List[T] = java.util.Arrays.asList(take(num) : _*) From 872a2ee281d84f40a786f765bf772cdb06e8c956 Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Tue, 8 Dec 2015 10:29:51 -0800 Subject: [PATCH 1701/1824] [SPARK-10393] use ML pipeline in LDA example jira: https://issues.apache.org/jira/browse/SPARK-10393 Since the logic of the text processing part has been moved to ML estimators/transformers, replace the related code in LDA Example with the ML pipeline. Author: Yuhao Yang Author: yuhaoyang Closes #8551 from hhbyyh/ldaExUpdate. 
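For orientation, a minimal sketch of the pipeline-based preprocessing this change moves to, built from the spark.ml transformers the patch imports (RegexTokenizer, StopWordsRemover, CountVectorizer); paths and vocabSize are the existing preprocess() parameters, while names such as docDF, rawTokens and tokens are illustrative and need not match the patch itself:

import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.feature.{CountVectorizer, CountVectorizerModel, RegexTokenizer, StopWordsRemover}
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.sql.{Row, SQLContext}

val sqlContext = SQLContext.getOrCreate(sc)
import sqlContext.implicits._

// One document per input line; keep a stable id for each document.
val docDF = sc.textFile(paths.mkString(",")).zipWithIndex().toDF("docs", "docId")

// Tokenize, drop stop words, and build term-count vectors in one Pipeline
// (sketch assumes vocabSize > 0 and default stop words / token pattern).
val tokenizer = new RegexTokenizer()
  .setInputCol("docs")
  .setOutputCol("rawTokens")
val stopWordsRemover = new StopWordsRemover()
  .setInputCol("rawTokens")
  .setOutputCol("tokens")
val countVectorizer = new CountVectorizer()
  .setVocabSize(vocabSize)
  .setInputCol("tokens")
  .setOutputCol("features")
val pipeline = new Pipeline()
  .setStages(Array(tokenizer, stopWordsRemover, countVectorizer))

val model = pipeline.fit(docDF)

// (docId, termCountVector) pairs in the shape mllib LDA.run expects.
val documents = model.transform(docDF)
  .select("docId", "features")
  .rdd
  .map { case Row(id: Long, features: Vector) => (id, features) }

// Vocabulary terms recovered from the fitted CountVectorizerModel.
val vocab = model.stages(2).asInstanceOf[CountVectorizerModel].vocabulary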
--- .../spark/examples/mllib/LDAExample.scala | 153 +++++------------- 1 file changed, 40 insertions(+), 113 deletions(-) diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala index 75b0f69cf91a..70010b05e434 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala @@ -18,19 +18,16 @@ // scalastyle:off println package org.apache.spark.examples.mllib -import java.text.BreakIterator - -import scala.collection.mutable - import scopt.OptionParser import org.apache.log4j.{Level, Logger} - -import org.apache.spark.{SparkContext, SparkConf} -import org.apache.spark.mllib.clustering.{EMLDAOptimizer, OnlineLDAOptimizer, DistributedLDAModel, LDA} -import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.ml.Pipeline +import org.apache.spark.ml.feature.{CountVectorizer, CountVectorizerModel, RegexTokenizer, StopWordsRemover} +import org.apache.spark.mllib.clustering.{DistributedLDAModel, EMLDAOptimizer, LDA, OnlineLDAOptimizer} +import org.apache.spark.mllib.linalg.Vector import org.apache.spark.rdd.RDD - +import org.apache.spark.sql.{Row, SQLContext} +import org.apache.spark.{SparkConf, SparkContext} /** * An example Latent Dirichlet Allocation (LDA) app. Run with @@ -192,115 +189,45 @@ object LDAExample { vocabSize: Int, stopwordFile: String): (RDD[(Long, Vector)], Array[String], Long) = { + val sqlContext = SQLContext.getOrCreate(sc) + import sqlContext.implicits._ + // Get dataset of document texts // One document per line in each text file. If the input consists of many small files, // this can result in a large number of small partitions, which can degrade performance. // In this case, consider using coalesce() to create fewer, larger partitions. - val textRDD: RDD[String] = sc.textFile(paths.mkString(",")) - - // Split text into words - val tokenizer = new SimpleTokenizer(sc, stopwordFile) - val tokenized: RDD[(Long, IndexedSeq[String])] = textRDD.zipWithIndex().map { case (text, id) => - id -> tokenizer.getWords(text) - } - tokenized.cache() - - // Counts words: RDD[(word, wordCount)] - val wordCounts: RDD[(String, Long)] = tokenized - .flatMap { case (_, tokens) => tokens.map(_ -> 1L) } - .reduceByKey(_ + _) - wordCounts.cache() - val fullVocabSize = wordCounts.count() - // Select vocab - // (vocab: Map[word -> id], total tokens after selecting vocab) - val (vocab: Map[String, Int], selectedTokenCount: Long) = { - val tmpSortedWC: Array[(String, Long)] = if (vocabSize == -1 || fullVocabSize <= vocabSize) { - // Use all terms - wordCounts.collect().sortBy(-_._2) - } else { - // Sort terms to select vocab - wordCounts.sortBy(_._2, ascending = false).take(vocabSize) - } - (tmpSortedWC.map(_._1).zipWithIndex.toMap, tmpSortedWC.map(_._2).sum) - } - - val documents = tokenized.map { case (id, tokens) => - // Filter tokens by vocabulary, and create word count vector representation of document. 
- val wc = new mutable.HashMap[Int, Int]() - tokens.foreach { term => - if (vocab.contains(term)) { - val termIndex = vocab(term) - wc(termIndex) = wc.getOrElse(termIndex, 0) + 1 - } - } - val indices = wc.keys.toArray.sorted - val values = indices.map(i => wc(i).toDouble) - - val sb = Vectors.sparse(vocab.size, indices, values) - (id, sb) - } - - val vocabArray = new Array[String](vocab.size) - vocab.foreach { case (term, i) => vocabArray(i) = term } - - (documents, vocabArray, selectedTokenCount) - } -} - -/** - * Simple Tokenizer. - * - * TODO: Formalize the interface, and make this a public class in mllib.feature - */ -private class SimpleTokenizer(sc: SparkContext, stopwordFile: String) extends Serializable { - - private val stopwords: Set[String] = if (stopwordFile.isEmpty) { - Set.empty[String] - } else { - val stopwordText = sc.textFile(stopwordFile).collect() - stopwordText.flatMap(_.stripMargin.split("\\s+")).toSet - } - - // Matches sequences of Unicode letters - private val allWordRegex = "^(\\p{L}*)$".r - - // Ignore words shorter than this length. - private val minWordLength = 3 - - def getWords(text: String): IndexedSeq[String] = { - - val words = new mutable.ArrayBuffer[String]() - - // Use Java BreakIterator to tokenize text into words. - val wb = BreakIterator.getWordInstance - wb.setText(text) - - // current,end index start,end of each word - var current = wb.first() - var end = wb.next() - while (end != BreakIterator.DONE) { - // Convert to lowercase - val word: String = text.substring(current, end).toLowerCase - // Remove short words and strings that aren't only letters - word match { - case allWordRegex(w) if w.length >= minWordLength && !stopwords.contains(w) => - words += w - case _ => - } - - current = end - try { - end = wb.next() - } catch { - case e: Exception => - // Ignore remaining text in line. - // This is a known bug in BreakIterator (for some Java versions), - // which fails when it sees certain characters. 
- end = BreakIterator.DONE - } + val df = sc.textFile(paths.mkString(",")).toDF("docs") + val customizedStopWords: Array[String] = if (stopwordFile.isEmpty) { + Array.empty[String] + } else { + val stopWordText = sc.textFile(stopwordFile).collect() + stopWordText.flatMap(_.stripMargin.split("\\s+")) } - words + val tokenizer = new RegexTokenizer() + .setInputCol("docs") + .setOutputCol("rawTokens") + val stopWordsRemover = new StopWordsRemover() + .setInputCol("rawTokens") + .setOutputCol("tokens") + stopWordsRemover.setStopWords(stopWordsRemover.getStopWords ++ customizedStopWords) + val countVectorizer = new CountVectorizer() + .setVocabSize(vocabSize) + .setInputCol("tokens") + .setOutputCol("features") + + val pipeline = new Pipeline() + .setStages(Array(tokenizer, stopWordsRemover, countVectorizer)) + + val model = pipeline.fit(df) + val documents = model.transform(df) + .select("features") + .map { case Row(features: Vector) => features } + .zipWithIndex() + .map(_.swap) + + (documents, + model.stages(2).asInstanceOf[CountVectorizerModel].vocabulary, // vocabulary + documents.map(_._2.numActives).sum().toLong) // total token count } - } // scalastyle:on println From 4bcb894948c1b7294d84e2bf58abb1d79e6759c6 Mon Sep 17 00:00:00 2001 From: Andrew Ray Date: Tue, 8 Dec 2015 10:52:17 -0800 Subject: [PATCH 1702/1824] [SPARK-12205][SQL] Pivot fails Analysis when aggregate is UnresolvedFunction Delays application of ResolvePivot until all aggregates are resolved to prevent problems with UnresolvedFunction and adds unit test Author: Andrew Ray Closes #10202 from aray/sql-pivot-unresolved-function. --- .../org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 2 +- .../scala/org/apache/spark/sql/DataFramePivotSuite.scala | 8 ++++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index d3163dcd4db9..ca00a5e49f66 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -259,7 +259,7 @@ class Analyzer( object ResolvePivot extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case p: Pivot if !p.childrenResolved => p + case p: Pivot if !p.childrenResolved | !p.aggregates.forall(_.resolved) => p case Pivot(groupByExprs, pivotColumn, pivotValues, aggregates, child) => val singleAgg = aggregates.size == 1 val pivotAggregates: Seq[NamedExpression] = pivotValues.flatMap { value => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala index fc53aba68ebb..bc1a336ea4fd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala @@ -85,4 +85,12 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext{ sqlContext.conf.setConf(SQLConf.DATAFRAME_PIVOT_MAX_VALUES, SQLConf.DATAFRAME_PIVOT_MAX_VALUES.defaultValue.get) } + + test("pivot with UnresolvedFunction") { + checkAnswer( + courseSales.groupBy("year").pivot("course", Seq("dotNET", "Java")) + .agg("earnings" -> "sum"), + Row(2012, 15000.0, 20000.0) :: Row(2013, 48000.0, 30000.0) :: Nil + ) + } } From 5cb4695051e3dac847b1ea14d62e54dcf672c31c Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Tue, 8 Dec 2015 
11:46:26 -0800 Subject: [PATCH 1703/1824] [SPARK-11605][MLLIB] ML 1.6 QA: API: Java compatibility, docs jira: https://issues.apache.org/jira/browse/SPARK-11605 Check Java compatibility for MLlib for this release. fix: 1. `StreamingTest.registerStream` needs java friendly interface. 2. `GradientBoostedTreesModel.computeInitialPredictionAndError` and `GradientBoostedTreesModel.updatePredictionError` has java compatibility issue. Mark them as `developerAPI`. TBD: [updated] no fix for now per discussion. `org.apache.spark.mllib.classification.LogisticRegressionModel` `public scala.Option getThreshold();` has wrong return type for Java invocation. `SVMModel` has the similar issue. Yet adding a `scala.Option getThreshold()` would result in an overloading error due to the same function signature. And adding a new function with different name seems to be not necessary. cc jkbradley feynmanliang Author: Yuhao Yang Closes #10102 from hhbyyh/javaAPI. --- .../examples/mllib/StreamingTestExample.scala | 4 +- .../spark/mllib/stat/test/StreamingTest.scala | 50 +++++++++++++++---- .../mllib/tree/model/treeEnsembleModels.scala | 6 ++- .../spark/mllib/stat/JavaStatisticsSuite.java | 38 ++++++++++++-- .../spark/mllib/stat/StreamingTestSuite.scala | 25 +++++----- 5 files changed, 96 insertions(+), 27 deletions(-) diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingTestExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingTestExample.scala index b6677c647663..49f5df39443e 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingTestExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingTestExample.scala @@ -18,7 +18,7 @@ package org.apache.spark.examples.mllib import org.apache.spark.SparkConf -import org.apache.spark.mllib.stat.test.StreamingTest +import org.apache.spark.mllib.stat.test.{BinarySample, StreamingTest} import org.apache.spark.streaming.{Seconds, StreamingContext} import org.apache.spark.util.Utils @@ -66,7 +66,7 @@ object StreamingTestExample { // $example on$ val data = ssc.textFileStream(dataDir).map(line => line.split(",") match { - case Array(label, value) => (label.toBoolean, value.toDouble) + case Array(label, value) => BinarySample(label.toBoolean, value.toDouble) }) val streamingTest = new StreamingTest() diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTest.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTest.scala index 75c6a51d0957..e990fe0768bc 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTest.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTest.scala @@ -17,12 +17,30 @@ package org.apache.spark.mllib.stat.test +import scala.beans.BeanInfo + import org.apache.spark.Logging import org.apache.spark.annotation.{Experimental, Since} -import org.apache.spark.rdd.RDD +import org.apache.spark.streaming.api.java.JavaDStream import org.apache.spark.streaming.dstream.DStream import org.apache.spark.util.StatCounter +/** + * Class that represents the group and value of a sample. + * + * @param isExperiment if the sample is of the experiment group. + * @param value numeric value of the observation. 
+ */ +@Since("1.6.0") +@BeanInfo +case class BinarySample @Since("1.6.0") ( + @Since("1.6.0") isExperiment: Boolean, + @Since("1.6.0") value: Double) { + override def toString: String = { + s"($isExperiment, $value)" + } +} + /** * :: Experimental :: * Performs online 2-sample significance testing for a stream of (Boolean, Double) pairs. The @@ -83,13 +101,13 @@ class StreamingTest @Since("1.6.0") () extends Logging with Serializable { /** * Register a [[DStream]] of values for significance testing. * - * @param data stream of (key,value) pairs where the key denotes group membership (true = - * experiment, false = control) and the value is the numerical metric to test for - * significance + * @param data stream of BinarySample(key,value) pairs where the key denotes group membership + * (true = experiment, false = control) and the value is the numerical metric to + * test for significance * @return stream of significance testing results */ @Since("1.6.0") - def registerStream(data: DStream[(Boolean, Double)]): DStream[StreamingTestResult] = { + def registerStream(data: DStream[BinarySample]): DStream[StreamingTestResult] = { val dataAfterPeacePeriod = dropPeacePeriod(data) val summarizedData = summarizeByKeyAndWindow(dataAfterPeacePeriod) val pairedSummaries = pairSummaries(summarizedData) @@ -97,9 +115,22 @@ class StreamingTest @Since("1.6.0") () extends Logging with Serializable { testMethod.doTest(pairedSummaries) } + /** + * Register a [[JavaDStream]] of values for significance testing. + * + * @param data stream of BinarySample(isExperiment,value) pairs where the isExperiment denotes + * group (true = experiment, false = control) and the value is the numerical metric + * to test for significance + * @return stream of significance testing results + */ + @Since("1.6.0") + def registerStream(data: JavaDStream[BinarySample]): JavaDStream[StreamingTestResult] = { + JavaDStream.fromDStream(registerStream(data.dstream)) + } + /** Drop all batches inside the peace period. */ private[stat] def dropPeacePeriod( - data: DStream[(Boolean, Double)]): DStream[(Boolean, Double)] = { + data: DStream[BinarySample]): DStream[BinarySample] = { data.transform { (rdd, time) => if (time.milliseconds > data.slideDuration.milliseconds * peacePeriod) { rdd @@ -111,9 +142,10 @@ class StreamingTest @Since("1.6.0") () extends Logging with Serializable { /** Compute summary statistics over each key and the specified test window size. 
*/ private[stat] def summarizeByKeyAndWindow( - data: DStream[(Boolean, Double)]): DStream[(Boolean, StatCounter)] = { + data: DStream[BinarySample]): DStream[(Boolean, StatCounter)] = { + val categoryValuePair = data.map(sample => (sample.isExperiment, sample.value)) if (this.windowSize == 0) { - data.updateStateByKey[StatCounter]( + categoryValuePair.updateStateByKey[StatCounter]( (newValues: Seq[Double], oldSummary: Option[StatCounter]) => { val newSummary = oldSummary.getOrElse(new StatCounter()) newSummary.merge(newValues) @@ -121,7 +153,7 @@ class StreamingTest @Since("1.6.0") () extends Logging with Serializable { }) } else { val windowDuration = data.slideDuration * this.windowSize - data + categoryValuePair .groupByKeyAndWindow(windowDuration) .mapValues { values => val summary = new StatCounter() diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala index 3f427f0be3af..feabcee24fa2 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala @@ -25,7 +25,7 @@ import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ import org.apache.spark.{Logging, SparkContext} -import org.apache.spark.annotation.Since +import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.regression.LabeledPoint @@ -186,6 +186,7 @@ class GradientBoostedTreesModel @Since("1.2.0") ( object GradientBoostedTreesModel extends Loader[GradientBoostedTreesModel] { /** + * :: DeveloperApi :: * Compute the initial predictions and errors for a dataset for the first * iteration of gradient boosting. * @param data: training data. @@ -196,6 +197,7 @@ object GradientBoostedTreesModel extends Loader[GradientBoostedTreesModel] { * corresponding to every sample. */ @Since("1.4.0") + @DeveloperApi def computeInitialPredictionAndError( data: RDD[LabeledPoint], initTreeWeight: Double, @@ -209,6 +211,7 @@ object GradientBoostedTreesModel extends Loader[GradientBoostedTreesModel] { } /** + * :: DeveloperApi :: * Update a zipped predictionError RDD * (as obtained with computeInitialPredictionAndError) * @param data: training data. @@ -220,6 +223,7 @@ object GradientBoostedTreesModel extends Loader[GradientBoostedTreesModel] { * corresponding to each sample. 
*/ @Since("1.4.0") + @DeveloperApi def updatePredictionError( data: RDD[LabeledPoint], predictionAndError: RDD[(Double, Double)], diff --git a/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java b/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java index 4795809e47a4..66b2ceacb05f 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java @@ -18,34 +18,49 @@ package org.apache.spark.mllib.stat; import java.io.Serializable; - import java.util.Arrays; +import java.util.List; import org.junit.After; import org.junit.Before; import org.junit.Test; +import static org.apache.spark.streaming.JavaTestUtils.*; import static org.junit.Assert.assertEquals; +import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaDoubleRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.linalg.Vectors; import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.stat.test.BinarySample; import org.apache.spark.mllib.stat.test.ChiSqTestResult; import org.apache.spark.mllib.stat.test.KolmogorovSmirnovTestResult; +import org.apache.spark.mllib.stat.test.StreamingTest; +import org.apache.spark.streaming.Duration; +import org.apache.spark.streaming.api.java.JavaDStream; +import org.apache.spark.streaming.api.java.JavaStreamingContext; public class JavaStatisticsSuite implements Serializable { private transient JavaSparkContext sc; + private transient JavaStreamingContext ssc; @Before public void setUp() { - sc = new JavaSparkContext("local", "JavaStatistics"); + SparkConf conf = new SparkConf() + .setMaster("local[2]") + .setAppName("JavaStatistics") + .set("spark.streaming.clock", "org.apache.spark.util.ManualClock"); + sc = new JavaSparkContext(conf); + ssc = new JavaStreamingContext(sc, new Duration(1000)); + ssc.checkpoint("checkpoint"); } @After public void tearDown() { - sc.stop(); + ssc.stop(); + ssc = null; sc = null; } @@ -76,4 +91,21 @@ public void chiSqTest() { new LabeledPoint(0.0, Vectors.dense(2.4, 8.1)))); ChiSqTestResult[] testResults = Statistics.chiSqTest(data); } + + @Test + public void streamingTest() { + List trainingBatch = Arrays.asList( + new BinarySample(true, 1.0), + new BinarySample(false, 2.0)); + JavaDStream training = + attachTestInputStream(ssc, Arrays.asList(trainingBatch, trainingBatch), 2); + int numBatches = 2; + StreamingTest model = new StreamingTest() + .setWindowSize(0) + .setPeacePeriod(0) + .setTestMethod("welch"); + model.registerStream(training); + attachTestOutputStream(training); + runStreams(ssc, numBatches, numBatches); + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/StreamingTestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/StreamingTestSuite.scala index d3e9ef4ff079..3c657c8cfe74 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/stat/StreamingTestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/StreamingTestSuite.scala @@ -18,7 +18,8 @@ package org.apache.spark.mllib.stat import org.apache.spark.SparkFunSuite -import org.apache.spark.mllib.stat.test.{StreamingTest, StreamingTestResult, StudentTTest, WelchTTest} +import org.apache.spark.mllib.stat.test.{StreamingTest, StreamingTestResult, StudentTTest, + WelchTTest, BinarySample} import org.apache.spark.streaming.TestSuiteBase import org.apache.spark.streaming.dstream.DStream import 
org.apache.spark.util.StatCounter @@ -48,7 +49,7 @@ class StreamingTestSuite extends SparkFunSuite with TestSuiteBase { // setup and run the model val ssc = setupStreams( - input, (inputDStream: DStream[(Boolean, Double)]) => model.registerStream(inputDStream)) + input, (inputDStream: DStream[BinarySample]) => model.registerStream(inputDStream)) val outputBatches = runStreams[StreamingTestResult](ssc, numBatches, numBatches) assert(outputBatches.flatten.forall(res => @@ -75,7 +76,7 @@ class StreamingTestSuite extends SparkFunSuite with TestSuiteBase { // setup and run the model val ssc = setupStreams( - input, (inputDStream: DStream[(Boolean, Double)]) => model.registerStream(inputDStream)) + input, (inputDStream: DStream[BinarySample]) => model.registerStream(inputDStream)) val outputBatches = runStreams[StreamingTestResult](ssc, numBatches, numBatches) assert(outputBatches.flatten.forall(res => @@ -102,7 +103,7 @@ class StreamingTestSuite extends SparkFunSuite with TestSuiteBase { // setup and run the model val ssc = setupStreams( - input, (inputDStream: DStream[(Boolean, Double)]) => model.registerStream(inputDStream)) + input, (inputDStream: DStream[BinarySample]) => model.registerStream(inputDStream)) val outputBatches = runStreams[StreamingTestResult](ssc, numBatches, numBatches) @@ -130,7 +131,7 @@ class StreamingTestSuite extends SparkFunSuite with TestSuiteBase { // setup and run the model val ssc = setupStreams( - input, (inputDStream: DStream[(Boolean, Double)]) => model.registerStream(inputDStream)) + input, (inputDStream: DStream[BinarySample]) => model.registerStream(inputDStream)) val outputBatches = runStreams[StreamingTestResult](ssc, numBatches, numBatches) assert(outputBatches.flatten.forall(res => @@ -157,7 +158,7 @@ class StreamingTestSuite extends SparkFunSuite with TestSuiteBase { // setup and run the model val ssc = setupStreams( input, - (inputDStream: DStream[(Boolean, Double)]) => model.summarizeByKeyAndWindow(inputDStream)) + (inputDStream: DStream[BinarySample]) => model.summarizeByKeyAndWindow(inputDStream)) val outputBatches = runStreams[(Boolean, StatCounter)](ssc, numBatches, numBatches) val outputCounts = outputBatches.flatten.map(_._2.count) @@ -190,7 +191,7 @@ class StreamingTestSuite extends SparkFunSuite with TestSuiteBase { // setup and run the model val ssc = setupStreams( - input, (inputDStream: DStream[(Boolean, Double)]) => model.dropPeacePeriod(inputDStream)) + input, (inputDStream: DStream[BinarySample]) => model.dropPeacePeriod(inputDStream)) val outputBatches = runStreams[(Boolean, Double)](ssc, numBatches, numBatches) assert(outputBatches.flatten.length == (numBatches - peacePeriod) * pointsPerBatch) @@ -210,11 +211,11 @@ class StreamingTestSuite extends SparkFunSuite with TestSuiteBase { .setPeacePeriod(0) val input = generateTestData(numBatches, pointsPerBatch, meanA, stdevA, meanB, stdevB, 42) - .map(batch => batch.filter(_._1)) // only keep one test group + .map(batch => batch.filter(_.isExperiment)) // only keep one test group // setup and run the model val ssc = setupStreams( - input, (inputDStream: DStream[(Boolean, Double)]) => model.registerStream(inputDStream)) + input, (inputDStream: DStream[BinarySample]) => model.registerStream(inputDStream)) val outputBatches = runStreams[StreamingTestResult](ssc, numBatches, numBatches) assert(outputBatches.flatten.forall(result => (result.pValue - 1.0).abs < 0.001)) @@ -228,13 +229,13 @@ class StreamingTestSuite extends SparkFunSuite with TestSuiteBase { stdevA: Double, meanB: Double, stdevB: 
Double, - seed: Int): (IndexedSeq[IndexedSeq[(Boolean, Double)]]) = { + seed: Int): (IndexedSeq[IndexedSeq[BinarySample]]) = { val rand = new XORShiftRandom(seed) val numTrues = pointsPerBatch / 2 val data = (0 until numBatches).map { i => - (0 until numTrues).map { idx => (true, meanA + stdevA * rand.nextGaussian())} ++ + (0 until numTrues).map { idx => BinarySample(true, meanA + stdevA * rand.nextGaussian())} ++ (pointsPerBatch / 2 until pointsPerBatch).map { idx => - (false, meanB + stdevB * rand.nextGaussian()) + BinarySample(false, meanB + stdevB * rand.nextGaussian()) } } From 06746b3005e5e9892d0314bee3bfdfaebc36d3d4 Mon Sep 17 00:00:00 2001 From: BenFradet Date: Tue, 8 Dec 2015 12:45:34 -0800 Subject: [PATCH 1704/1824] [SPARK-12159][ML] Add user guide section for IndexToString transformer Documentation regarding the `IndexToString` label transformer with code snippets in Scala/Java/Python. Author: BenFradet Closes #10166 from BenFradet/SPARK-12159. --- docs/ml-features.md | 104 +++++++++++++++--- .../examples/ml/JavaIndexToStringExample.java | 75 +++++++++++++ .../main/python/ml/index_to_string_example.py | 45 ++++++++ .../examples/ml/IndexToStringExample.scala | 60 ++++++++++ 4 files changed, 268 insertions(+), 16 deletions(-) create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaIndexToStringExample.java create mode 100644 examples/src/main/python/ml/index_to_string_example.py create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/IndexToStringExample.scala diff --git a/docs/ml-features.md b/docs/ml-features.md index 01d6abeb5ba6..e15c26836aff 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -835,10 +835,10 @@ dctDf.select("featuresDCT").show(3); `StringIndexer` encodes a string column of labels to a column of label indices. The indices are in `[0, numLabels)`, ordered by label frequencies. So the most frequent label gets index `0`. -If the input column is numeric, we cast it to string and index the string -values. When downstream pipeline components such as `Estimator` or -`Transformer` make use of this string-indexed label, you must set the input -column of the component to this string-indexed column name. In many cases, +If the input column is numeric, we cast it to string and index the string +values. When downstream pipeline components such as `Estimator` or +`Transformer` make use of this string-indexed label, you must set the input +column of the component to this string-indexed column name. In many cases, you can set the input column with `setInputCol`. **Examples** @@ -951,9 +951,78 @@ indexed.show() + +## IndexToString + +Symmetrically to `StringIndexer`, `IndexToString` maps a column of label indices +back to a column containing the original labels as strings. The common use case +is to produce indices from labels with `StringIndexer`, train a model with those +indices and retrieve the original labels from the column of predicted indices +with `IndexToString`. However, you are free to supply your own labels. 
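A condensed Scala sketch of that round trip (it mirrors the full `IndexToStringExample.scala` added later in this patch; an existing `sqlContext` is assumed):

```scala
import org.apache.spark.ml.feature.{IndexToString, StringIndexer}

val df = sqlContext.createDataFrame(Seq(
  (0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")
)).toDF("id", "category")

// Index the string column ...
val indexer = new StringIndexer()
  .setInputCol("category")
  .setOutputCol("categoryIndex")
  .fit(df)
val indexed = indexer.transform(df)

// ... then map the indices back to the original labels.
val converter = new IndexToString()
  .setInputCol("categoryIndex")
  .setOutputCol("originalCategory")

converter.transform(indexed).select("id", "originalCategory").show()
```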
+ +**Examples** + +Building on the `StringIndexer` example, let's assume we have the following +DataFrame with columns `id` and `categoryIndex`: + +~~~~ + id | categoryIndex +----|--------------- + 0 | 0.0 + 1 | 2.0 + 2 | 1.0 + 3 | 0.0 + 4 | 0.0 + 5 | 1.0 +~~~~ + +Applying `IndexToString` with `categoryIndex` as the input column, +`originalCategory` as the output column, we are able to retrieve our original +labels (they will be inferred from the columns' metadata): + +~~~~ + id | categoryIndex | originalCategory +----|---------------|----------------- + 0 | 0.0 | a + 1 | 2.0 | b + 2 | 1.0 | c + 3 | 0.0 | a + 4 | 0.0 | a + 5 | 1.0 | c +~~~~ + +
    +
    + +Refer to the [IndexToString Scala docs](api/scala/index.html#org.apache.spark.ml.feature.IndexToString) +for more details on the API. + +{% include_example scala/org/apache/spark/examples/ml/IndexToStringExample.scala %} + +
    + +
    + +Refer to the [IndexToString Java docs](api/java/org/apache/spark/ml/feature/IndexToString.html) +for more details on the API. + +{% include_example java/org/apache/spark/examples/ml/JavaIndexToStringExample.java %} + +
    + +
    + +Refer to the [IndexToString Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.IndexToString) +for more details on the API. + +{% include_example python/ml/index_to_string_example.py %} + +
    +
    + ## OneHotEncoder -[One-hot encoding](http://en.wikipedia.org/wiki/One-hot) maps a column of label indices to a column of binary vectors, with at most a single one-value. This encoding allows algorithms which expect continuous features, such as Logistic Regression, to use categorical features +[One-hot encoding](http://en.wikipedia.org/wiki/One-hot) maps a column of label indices to a column of binary vectors, with at most a single one-value. This encoding allows algorithms which expect continuous features, such as Logistic Regression, to use categorical features
    @@ -979,10 +1048,11 @@ val indexer = new StringIndexer() .fit(df) val indexed = indexer.transform(df) -val encoder = new OneHotEncoder().setInputCol("categoryIndex"). - setOutputCol("categoryVec") +val encoder = new OneHotEncoder() + .setInputCol("categoryIndex") + .setOutputCol("categoryVec") val encoded = encoder.transform(indexed) -encoded.select("id", "categoryVec").foreach(println) +encoded.select("id", "categoryVec").show() {% endhighlight %}
    @@ -1015,7 +1085,7 @@ JavaRDD jrdd = jsc.parallelize(Arrays.asList( RowFactory.create(5, "c") )); StructType schema = new StructType(new StructField[]{ - new StructField("id", DataTypes.DoubleType, false, Metadata.empty()), + new StructField("id", DataTypes.IntegerType, false, Metadata.empty()), new StructField("category", DataTypes.StringType, false, Metadata.empty()) }); DataFrame df = sqlContext.createDataFrame(jrdd, schema); @@ -1029,6 +1099,7 @@ OneHotEncoder encoder = new OneHotEncoder() .setInputCol("categoryIndex") .setOutputCol("categoryVec"); DataFrame encoded = encoder.transform(indexed); +encoded.select("id", "categoryVec").show(); {% endhighlight %}
    @@ -1054,6 +1125,7 @@ model = stringIndexer.fit(df) indexed = model.transform(df) encoder = OneHotEncoder(includeFirst=False, inputCol="categoryIndex", outputCol="categoryVec") encoded = encoder.transform(indexed) +encoded.select("id", "categoryVec").show() {% endhighlight %} @@ -1582,7 +1654,7 @@ from pyspark.mllib.linalg import Vectors data = [(Vectors.dense([1.0, 2.0, 3.0]),), (Vectors.dense([4.0, 5.0, 6.0]),)] df = sqlContext.createDataFrame(data, ["vector"]) -transformer = ElementwiseProduct(scalingVec=Vectors.dense([0.0, 1.0, 2.0]), +transformer = ElementwiseProduct(scalingVec=Vectors.dense([0.0, 1.0, 2.0]), inputCol="vector", outputCol="transformedVector") transformer.transform(df).show() @@ -1837,15 +1909,15 @@ for more details on the API. sub-array of the original features. It is useful for extracting features from a vector column. `VectorSlicer` accepts a vector column with a specified indices, then outputs a new vector column -whose values are selected via those indices. There are two types of indices, +whose values are selected via those indices. There are two types of indices, 1. Integer indices that represents the indices into the vector, `setIndices()`; - 2. String indices that represents the names of features into the vector, `setNames()`. + 2. String indices that represents the names of features into the vector, `setNames()`. *This requires the vector column to have an `AttributeGroup` since the implementation matches on the name field of an `Attribute`.* -Specification by integer and string are both acceptable. Moreover, you can use integer index and +Specification by integer and string are both acceptable. Moreover, you can use integer index and string name simultaneously. At least one feature must be selected. Duplicate features are not allowed, so there can be no overlap between selected indices and names. Note that if names of features are selected, an exception will be threw out when encountering with empty input attributes. @@ -1858,9 +1930,9 @@ followed by the selected names (in the order given). Suppose that we have a DataFrame with the column `userFeatures`: ~~~ - userFeatures + userFeatures ------------------ - [0.0, 10.0, 0.5] + [0.0, 10.0, 0.5] ~~~ `userFeatures` is a vector column that contains three user features. Assuming that the first column @@ -1874,7 +1946,7 @@ column named `features`: [0.0, 10.0, 0.5] | [10.0, 0.5] ~~~ -Suppose also that we have a potential input attributes for the `userFeatures`, i.e. +Suppose also that we have a potential input attributes for the `userFeatures`, i.e. `["f1", "f2", "f3"]`, then we can use `setNames("f2", "f3")` to select them. ~~~ diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaIndexToStringExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaIndexToStringExample.java new file mode 100644 index 000000000000..3ccd6993261e --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaIndexToStringExample.java @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import java.util.Arrays; + +import org.apache.spark.ml.feature.IndexToString; +import org.apache.spark.ml.feature.StringIndexer; +import org.apache.spark.ml.feature.StringIndexerModel; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +// $example off$ + +public class JavaIndexToStringExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaIndexToStringExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext sqlContext = new SQLContext(jsc); + + // $example on$ + JavaRDD jrdd = jsc.parallelize(Arrays.asList( + RowFactory.create(0, "a"), + RowFactory.create(1, "b"), + RowFactory.create(2, "c"), + RowFactory.create(3, "a"), + RowFactory.create(4, "a"), + RowFactory.create(5, "c") + )); + StructType schema = new StructType(new StructField[]{ + new StructField("id", DataTypes.IntegerType, false, Metadata.empty()), + new StructField("category", DataTypes.StringType, false, Metadata.empty()) + }); + DataFrame df = sqlContext.createDataFrame(jrdd, schema); + + StringIndexerModel indexer = new StringIndexer() + .setInputCol("category") + .setOutputCol("categoryIndex") + .fit(df); + DataFrame indexed = indexer.transform(df); + + IndexToString converter = new IndexToString() + .setInputCol("categoryIndex") + .setOutputCol("originalCategory"); + DataFrame converted = converter.transform(indexed); + converted.select("id", "originalCategory").show(); + // $example off$ + jsc.stop(); + } +} diff --git a/examples/src/main/python/ml/index_to_string_example.py b/examples/src/main/python/ml/index_to_string_example.py new file mode 100644 index 000000000000..fb0ba2950bbd --- /dev/null +++ b/examples/src/main/python/ml/index_to_string_example.py @@ -0,0 +1,45 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +from __future__ import print_function + +from pyspark import SparkContext +# $example on$ +from pyspark.ml.feature import IndexToString, StringIndexer +# $example off$ +from pyspark.sql import SQLContext + +if __name__ == "__main__": + sc = SparkContext(appName="IndexToStringExample") + sqlContext = SQLContext(sc) + + # $example on$ + df = sqlContext.createDataFrame( + [(0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")], + ["id", "category"]) + + stringIndexer = StringIndexer(inputCol="category", outputCol="categoryIndex") + model = stringIndexer.fit(df) + indexed = model.transform(df) + + converter = IndexToString(inputCol="categoryIndex", outputCol="originalCategory") + converted = converter.transform(indexed) + + converted.select("id", "originalCategory").show() + # $example off$ + + sc.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/IndexToStringExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/IndexToStringExample.scala new file mode 100644 index 000000000000..52537e5bb568 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/IndexToStringExample.scala @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} +// $example on$ +import org.apache.spark.ml.feature.{StringIndexer, IndexToString} +// $example off$ + +object IndexToStringExample { + def main(args: Array[String]) { + val conf = new SparkConf().setAppName("IndexToStringExample") + val sc = new SparkContext(conf) + + val sqlContext = SQLContext.getOrCreate(sc) + + // $example on$ + val df = sqlContext.createDataFrame(Seq( + (0, "a"), + (1, "b"), + (2, "c"), + (3, "a"), + (4, "a"), + (5, "c") + )).toDF("id", "category") + + val indexer = new StringIndexer() + .setInputCol("category") + .setOutputCol("categoryIndex") + .fit(df) + val indexed = indexer.transform(df) + + val converter = new IndexToString() + .setInputCol("categoryIndex") + .setOutputCol("originalCategory") + + val converted = converter.transform(indexed) + converted.select("id", "originalCategory").show() + // $example off$ + sc.stop() + } +} +// scalastyle:on println From 2ff17bcfb1818a1b1e357343cd507883dcfdd2ea Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Tue, 8 Dec 2015 13:13:56 -0800 Subject: [PATCH 1705/1824] [SPARK-3873][BUILD] Add style checker to enforce import ordering. The checker tries to follow as closely as possible the guidelines of the code style document, and makes some decisions where the guide is not clear. 
In particular: - wildcard imports come first when there are other imports in the same package - multi-import blocks come before single imports - lower-case names inside multi-import blocks come before others In some projects, such as graphx, there seems to be a convention to separate o.a.s imports from the project's own; to simplify the checker, I chose not to allow that, which is a strict interpretation of the code style guide, even though I think it makes sense. Since the checks are based on syntax only, some edge cases may generate spurious warnings; for example, when class names start with a lower case letter (and are thus treated as a package name by the checker). The checker is currently only generating warnings, and since there are many of those, the build output does get a little noisy. The idea is to fix the code (and the checker, as needed) little by little instead of having a huge change that touches everywhere. Author: Marcelo Vanzin Closes #6502 from vanzin/SPARK-3873. --- scalastyle-config.xml | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/scalastyle-config.xml b/scalastyle-config.xml index 050c3f360476..dab1ebddc666 100644 --- a/scalastyle-config.xml +++ b/scalastyle-config.xml @@ -153,7 +153,7 @@ This file is divided into 3 sections: @VisibleForTesting @@ -203,6 +203,18 @@ This file is divided into 3 sections: + + + + java,scala,3rdParty,spark + javax?\..+ + scala\..+ + (?!org\.apache\.spark\.).* + org\.apache\.spark\..* + + + + From 9494521695a1f1526aae76c0aea34a3bead96251 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Tue, 8 Dec 2015 14:34:15 -0800 Subject: [PATCH 1706/1824] [SPARK-12187] *MemoryPool classes should not be fully public This patch tightens them to `private[memory]`. Author: Andrew Or Closes #10182 from andrewor14/memory-visibility. --- .../scala/org/apache/spark/memory/ExecutionMemoryPool.scala | 2 +- core/src/main/scala/org/apache/spark/memory/MemoryPool.scala | 2 +- .../main/scala/org/apache/spark/memory/StorageMemoryPool.scala | 2 +- .../scala/org/apache/spark/memory/UnifiedMemoryManager.scala | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/memory/ExecutionMemoryPool.scala b/core/src/main/scala/org/apache/spark/memory/ExecutionMemoryPool.scala index 7825bae42587..9023e1ac012b 100644 --- a/core/src/main/scala/org/apache/spark/memory/ExecutionMemoryPool.scala +++ b/core/src/main/scala/org/apache/spark/memory/ExecutionMemoryPool.scala @@ -39,7 +39,7 @@ import org.apache.spark.Logging * @param lock a [[MemoryManager]] instance to synchronize on * @param poolName a human-readable name for this pool, for use in log messages */ -class ExecutionMemoryPool( +private[memory] class ExecutionMemoryPool( lock: Object, poolName: String ) extends MemoryPool(lock) with Logging { diff --git a/core/src/main/scala/org/apache/spark/memory/MemoryPool.scala b/core/src/main/scala/org/apache/spark/memory/MemoryPool.scala index bfeec47e3892..1b9edf9c43bd 100644 --- a/core/src/main/scala/org/apache/spark/memory/MemoryPool.scala +++ b/core/src/main/scala/org/apache/spark/memory/MemoryPool.scala @@ -27,7 +27,7 @@ import javax.annotation.concurrent.GuardedBy * to `Object` to avoid programming errors, since this object should only be used for * synchronization purposes. 
*/ -abstract class MemoryPool(lock: Object) { +private[memory] abstract class MemoryPool(lock: Object) { @GuardedBy("lock") private[this] var _poolSize: Long = 0 diff --git a/core/src/main/scala/org/apache/spark/memory/StorageMemoryPool.scala b/core/src/main/scala/org/apache/spark/memory/StorageMemoryPool.scala index 6a322eabf81e..fc4f0357e9f1 100644 --- a/core/src/main/scala/org/apache/spark/memory/StorageMemoryPool.scala +++ b/core/src/main/scala/org/apache/spark/memory/StorageMemoryPool.scala @@ -31,7 +31,7 @@ import org.apache.spark.storage.{MemoryStore, BlockStatus, BlockId} * * @param lock a [[MemoryManager]] instance to synchronize on */ -class StorageMemoryPool(lock: Object) extends MemoryPool(lock) with Logging { +private[memory] class StorageMemoryPool(lock: Object) extends MemoryPool(lock) with Logging { @GuardedBy("lock") private[this] var _memoryUsed: Long = 0L diff --git a/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala b/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala index 48b4e23433e4..0f1ea9ab39c0 100644 --- a/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala +++ b/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala @@ -49,7 +49,7 @@ import org.apache.spark.storage.{BlockStatus, BlockId} private[spark] class UnifiedMemoryManager private[memory] ( conf: SparkConf, val maxMemory: Long, - private val storageRegionSize: Long, + storageRegionSize: Long, numCores: Int) extends MemoryManager( conf, From 39594894232e0b70c5ca8b0df137da0d61223fd5 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Tue, 8 Dec 2015 15:58:35 -0800 Subject: [PATCH 1707/1824] [SPARK-12069][SQL] Update documentation with Datasets Author: Michael Armbrust Closes #10060 from marmbrus/docs. --- docs/_layouts/global.html | 2 +- docs/index.md | 2 +- docs/sql-programming-guide.md | 268 +++++++++++------- .../scala/org/apache/spark/sql/Encoder.scala | 48 +++- .../scala/org/apache/spark/sql/Column.scala | 21 +- 5 files changed, 237 insertions(+), 104 deletions(-) diff --git a/docs/_layouts/global.html b/docs/_layouts/global.html index 1b09e2221e17..0b5b0cd48af6 100755 --- a/docs/_layouts/global.html +++ b/docs/_layouts/global.html @@ -71,7 +71,7 @@
  • Spark Programming Guide
  • Spark Streaming
  • - DataFrames and SQL
  • + DataFrames, Datasets and SQL
  • MLlib (Machine Learning)
  • GraphX (Graph Processing)
  • Bagel (Pregel on Spark)
  • diff --git a/docs/index.md b/docs/index.md index f1d9e012c6cf..ae26f97c86c2 100644 --- a/docs/index.md +++ b/docs/index.md @@ -87,7 +87,7 @@ options for deployment: in all supported languages (Scala, Java, Python, R) * Modules built on Spark: * [Spark Streaming](streaming-programming-guide.html): processing real-time data streams - * [Spark SQL and DataFrames](sql-programming-guide.html): support for structured data and relational queries + * [Spark SQL, Datasets, and DataFrames](sql-programming-guide.html): support for structured data and relational queries * [MLlib](mllib-guide.html): built-in machine learning library * [GraphX](graphx-programming-guide.html): Spark's new API for graph processing diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 7b1d97baa382..9f87accd30f4 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1,6 +1,6 @@ --- layout: global -displayTitle: Spark SQL and DataFrame Guide +displayTitle: Spark SQL, DataFrames and Datasets Guide title: Spark SQL and DataFrames --- @@ -9,18 +9,51 @@ title: Spark SQL and DataFrames # Overview -Spark SQL is a Spark module for structured data processing. It provides a programming abstraction called DataFrames and can also act as distributed SQL query engine. +Spark SQL is a Spark module for structured data processing. Unlike the basic Spark RDD API, the interfaces provided +by Spark SQL provide Spark with more information about the structure of both the data and the computation being performed. Internally, +Spark SQL uses this extra information to perform extra optimizations. There are several ways to +interact with Spark SQL including SQL, the DataFrames API and the Datasets API. When computing a result +the same execution engine is used, independent of which API/language you are using to express the +computation. This unification means that developers can easily switch back and forth between the +various APIs based on which provides the most natural way to express a given transformation. -Spark SQL can also be used to read data from an existing Hive installation. For more on how to configure this feature, please refer to the [Hive Tables](#hive-tables) section. +All of the examples on this page use sample data included in the Spark distribution and can be run in +the `spark-shell`, `pyspark` shell, or `sparkR` shell. -# DataFrames +## SQL -A DataFrame is a distributed collection of data organized into named columns. It is conceptually equivalent to a table in a relational database or a data frame in R/Python, but with richer optimizations under the hood. DataFrames can be constructed from a wide array of sources such as: structured data files, tables in Hive, external databases, or existing RDDs. +One use of Spark SQL is to execute SQL queries written using either a basic SQL syntax or HiveQL. +Spark SQL can also be used to read data from an existing Hive installation. For more on how to +configure this feature, please refer to the [Hive Tables](#hive-tables) section. When running +SQL from within another programming language the results will be returned as a [DataFrame](#DataFrames). +You can also interact with the SQL interface using the [command-line](#running-the-spark-sql-cli) +or over [JDBC/ODBC](#running-the-thrift-jdbcodbc-server). 
-The DataFrame API is available in [Scala](api/scala/index.html#org.apache.spark.sql.DataFrame), [Java](api/java/index.html?org/apache/spark/sql/DataFrame.html), [Python](api/python/pyspark.sql.html#pyspark.sql.DataFrame), and [R](api/R/index.html). +## DataFrames -All of the examples on this page use sample data included in the Spark distribution and can be run in the `spark-shell`, `pyspark` shell, or `sparkR` shell. +A DataFrame is a distributed collection of data organized into named columns. It is conceptually +equivalent to a table in a relational database or a data frame in R/Python, but with richer +optimizations under the hood. DataFrames can be constructed from a wide array of [sources](#data-sources) such +as: structured data files, tables in Hive, external databases, or existing RDDs. +The DataFrame API is available in [Scala](api/scala/index.html#org.apache.spark.sql.DataFrame), +[Java](api/java/index.html?org/apache/spark/sql/DataFrame.html), +[Python](api/python/pyspark.sql.html#pyspark.sql.DataFrame), and [R](api/R/index.html). + +## Datasets + +A Dataset is a new experimental interface added in Spark 1.6 that tries to provide the benefits of +RDDs (strong typing, ability to use powerful lambda functions) with the benefits of Spark SQL's +optimized execution engine. A Dataset can be [constructed](#creating-datasets) from JVM objects and then manipulated +using functional transformations (map, flatMap, filter, etc.). + +The unified Dataset API can be used both in [Scala](api/scala/index.html#org.apache.spark.sql.Dataset) and +[Java](api/java/index.html?org/apache/spark/sql/Dataset.html). Python does not yet have support for +the Dataset API, but due to its dynamic nature many of the benefits are already available (i.e. you can +access the field of a row by name naturally `row.columnName`). Full python support will be added +in a future release. + +# Getting Started ## Starting Point: SQLContext @@ -29,7 +62,7 @@ All of the examples on this page use sample data included in the Spark distribut The entry point into all functionality in Spark SQL is the [`SQLContext`](api/scala/index.html#org.apache.spark.sql.SQLContext) class, or one of its -descendants. To create a basic `SQLContext`, all you need is a SparkContext. +descendants. To create a basic `SQLContext`, all you need is a SparkContext. {% highlight scala %} val sc: SparkContext // An existing SparkContext. @@ -45,7 +78,7 @@ import sqlContext.implicits._ The entry point into all functionality in Spark SQL is the [`SQLContext`](api/java/index.html#org.apache.spark.sql.SQLContext) class, or one of its -descendants. To create a basic `SQLContext`, all you need is a SparkContext. +descendants. To create a basic `SQLContext`, all you need is a SparkContext. {% highlight java %} JavaSparkContext sc = ...; // An existing JavaSparkContext. @@ -58,7 +91,7 @@ SQLContext sqlContext = new org.apache.spark.sql.SQLContext(sc); The entry point into all relational functionality in Spark is the [`SQLContext`](api/python/pyspark.sql.html#pyspark.sql.SQLContext) class, or one -of its decedents. To create a basic `SQLContext`, all you need is a SparkContext. +of its decedents. To create a basic `SQLContext`, all you need is a SparkContext. {% highlight python %} from pyspark.sql import SQLContext @@ -70,7 +103,7 @@ sqlContext = SQLContext(sc)
    The entry point into all relational functionality in Spark is the -`SQLContext` class, or one of its decedents. To create a basic `SQLContext`, all you need is a SparkContext. +`SQLContext` class, or one of its decedents. To create a basic `SQLContext`, all you need is a SparkContext. {% highlight r %} sqlContext <- sparkRSQL.init(sc) @@ -82,18 +115,18 @@ sqlContext <- sparkRSQL.init(sc) In addition to the basic `SQLContext`, you can also create a `HiveContext`, which provides a superset of the functionality provided by the basic `SQLContext`. Additional features include the ability to write queries using the more complete HiveQL parser, access to Hive UDFs, and the -ability to read data from Hive tables. To use a `HiveContext`, you do not need to have an +ability to read data from Hive tables. To use a `HiveContext`, you do not need to have an existing Hive setup, and all of the data sources available to a `SQLContext` are still available. `HiveContext` is only packaged separately to avoid including all of Hive's dependencies in the default -Spark build. If these dependencies are not a problem for your application then using `HiveContext` -is recommended for the 1.3 release of Spark. Future releases will focus on bringing `SQLContext` up +Spark build. If these dependencies are not a problem for your application then using `HiveContext` +is recommended for the 1.3 release of Spark. Future releases will focus on bringing `SQLContext` up to feature parity with a `HiveContext`. The specific variant of SQL that is used to parse queries can also be selected using the -`spark.sql.dialect` option. This parameter can be changed using either the `setConf` method on -a `SQLContext` or by using a `SET key=value` command in SQL. For a `SQLContext`, the only dialect -available is "sql" which uses a simple SQL parser provided by Spark SQL. In a `HiveContext`, the -default is "hiveql", though "sql" is also available. Since the HiveQL parser is much more complete, +`spark.sql.dialect` option. This parameter can be changed using either the `setConf` method on +a `SQLContext` or by using a `SET key=value` command in SQL. For a `SQLContext`, the only dialect +available is "sql" which uses a simple SQL parser provided by Spark SQL. In a `HiveContext`, the +default is "hiveql", though "sql" is also available. Since the HiveQL parser is much more complete, this is recommended for most use cases. @@ -215,7 +248,7 @@ df.groupBy("age").count().show() For a complete list of the types of operations that can be performed on a DataFrame refer to the [API Documentation](api/scala/index.html#org.apache.spark.sql.DataFrame). -In addition to simple column references and expressions, DataFrames also have a rich library of functions including string manipulation, date arithmetic, common math operations and more. The complete list is available in the [DataFrame Function Reference](api/scala/index.html#org.apache.spark.sql.functions$). +In addition to simple column references and expressions, DataFrames also have a rich library of functions including string manipulation, date arithmetic, common math operations and more. The complete list is available in the [DataFrame Function Reference](api/scala/index.html#org.apache.spark.sql.functions$).
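As a brief illustration of that function library (not part of the patch), applied to the `df` built from `people.json` in the examples above:

```scala
import org.apache.spark.sql.functions._

// Column expressions built from the functions package; "name" and "age"
// are the columns of the people DataFrame shown above.
df.select(upper(col("name")).as("NAME"), col("age") + 1).show()

// A couple of common aggregates from the same package.
df.agg(avg("age"), max("age")).show()
```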
    @@ -270,7 +303,7 @@ df.groupBy("age").count().show(); For a complete list of the types of operations that can be performed on a DataFrame refer to the [API Documentation](api/java/org/apache/spark/sql/DataFrame.html). -In addition to simple column references and expressions, DataFrames also have a rich library of functions including string manipulation, date arithmetic, common math operations and more. The complete list is available in the [DataFrame Function Reference](api/java/org/apache/spark/sql/functions.html). +In addition to simple column references and expressions, DataFrames also have a rich library of functions including string manipulation, date arithmetic, common math operations and more. The complete list is available in the [DataFrame Function Reference](api/java/org/apache/spark/sql/functions.html). @@ -331,7 +364,7 @@ df.groupBy("age").count().show() For a complete list of the types of operations that can be performed on a DataFrame refer to the [API Documentation](api/python/pyspark.sql.html#pyspark.sql.DataFrame). -In addition to simple column references and expressions, DataFrames also have a rich library of functions including string manipulation, date arithmetic, common math operations and more. The complete list is available in the [DataFrame Function Reference](api/python/pyspark.sql.html#module-pyspark.sql.functions). +In addition to simple column references and expressions, DataFrames also have a rich library of functions including string manipulation, date arithmetic, common math operations and more. The complete list is available in the [DataFrame Function Reference](api/python/pyspark.sql.html#module-pyspark.sql.functions). @@ -385,7 +418,7 @@ showDF(count(groupBy(df, "age"))) For a complete list of the types of operations that can be performed on a DataFrame refer to the [API Documentation](api/R/index.html). -In addition to simple column references and expressions, DataFrames also have a rich library of functions including string manipulation, date arithmetic, common math operations and more. The complete list is available in the [DataFrame Function Reference](api/R/index.html). +In addition to simple column references and expressions, DataFrames also have a rich library of functions including string manipulation, date arithmetic, common math operations and more. The complete list is available in the [DataFrame Function Reference](api/R/index.html). @@ -398,14 +431,14 @@ The `sql` function on a `SQLContext` enables applications to run SQL queries pro
{% highlight scala %}
-val sqlContext = ... // An existing SQLContext
+val sqlContext = ... // An existing SQLContext
val df = sqlContext.sql("SELECT * FROM table")
{% endhighlight %}

{% highlight java %}
-SQLContext sqlContext = ... // An existing SQLContext
+SQLContext sqlContext = ... // An existing SQLContext
DataFrame df = sqlContext.sql("SELECT * FROM table");
{% endhighlight %}
    @@ -428,15 +461,54 @@ df <- sql(sqlContext, "SELECT * FROM table")
    +## Creating Datasets + +Datasets are similar to RDDs, however, instead of using Java Serialization or Kryo they use +a specialized [Encoder](api/scala/index.html#org.apache.spark.sql.Encoder) to serialize the objects +for processing or transmitting over the network. While both encoders and standard serialization are +responsible for turning an object into bytes, encoders are code generated dynamically and use a format +that allows Spark to perform many operations like filtering, sorting and hashing without deserializing +the bytes back into an object. + +
    +
    + +{% highlight scala %} +// Encoders for most common types are automatically provided by importing sqlContext.implicits._ +val ds = Seq(1, 2, 3).toDS() +ds.map(_ + 1).collect() // Returns: Array(2, 3, 4) + +// Encoders are also created for case classes. +case class Person(name: String, age: Long) +val ds = Seq(Person("Andy", 32)).toDS() + +// DataFrames can be converted to a Dataset by providing a class. Mapping will be done by name. +val path = "examples/src/main/resources/people.json" +val people = sqlContext.read.json(path).as[Person] + +{% endhighlight %} + +
    + +
    + +{% highlight java %} +JavaSparkContext sc = ...; // An existing JavaSparkContext. +SQLContext sqlContext = new org.apache.spark.sql.SQLContext(sc); +{% endhighlight %} + +
    +
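The Java tab above stops after creating a `SQLContext`. For completeness, a hedged Scala sketch of building a `Dataset` with an explicitly supplied `Encoder` instead of the implicit ones; `sc` is a placeholder for an existing `SparkContext`, and the calls assume the Spark 1.6-era API.

{% highlight scala %}
// Sketch only: Encoders.STRING supplies the Encoder[String] explicitly.
import org.apache.spark.sql.{Encoders, SQLContext}

val sqlContext = new SQLContext(sc)
val ds = sqlContext.createDataset(Seq("abc", "abc", "xyz"))(Encoders.STRING)
ds.filter(_.startsWith("a")).count() // typed operations over the encoded data
{% endhighlight %}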
    + ## Interoperating with RDDs -Spark SQL supports two different methods for converting existing RDDs into DataFrames. The first -method uses reflection to infer the schema of an RDD that contains specific types of objects. This +Spark SQL supports two different methods for converting existing RDDs into DataFrames. The first +method uses reflection to infer the schema of an RDD that contains specific types of objects. This reflection based approach leads to more concise code and works well when you already know the schema while writing your Spark application. The second method for creating DataFrames is through a programmatic interface that allows you to -construct a schema and then apply it to an existing RDD. While this method is more verbose, it allows +construct a schema and then apply it to an existing RDD. While this method is more verbose, it allows you to construct DataFrames when the columns and their types are not known until runtime. ### Inferring the Schema Using Reflection @@ -445,11 +517,11 @@ you to construct DataFrames when the columns and their types are not known until
    The Scala interface for Spark SQL supports automatically converting an RDD containing case classes -to a DataFrame. The case class -defines the schema of the table. The names of the arguments to the case class are read using +to a DataFrame. The case class +defines the schema of the table. The names of the arguments to the case class are read using reflection and become the names of the columns. Case classes can also be nested or contain complex types such as Sequences or Arrays. This RDD can be implicitly converted to a DataFrame and then be -registered as a table. Tables can be used in subsequent SQL statements. +registered as a table. Tables can be used in subsequent SQL statements. {% highlight scala %} // sc is an existing SparkContext. @@ -486,9 +558,9 @@ teenagers.map(_.getValuesMap[Any](List("name", "age"))).collect().foreach(printl
    Spark SQL supports automatically converting an RDD of [JavaBeans](http://stackoverflow.com/questions/3295496/what-is-a-javabean-exactly) -into a DataFrame. The BeanInfo, obtained using reflection, defines the schema of the table. +into a DataFrame. The BeanInfo, obtained using reflection, defines the schema of the table. Currently, Spark SQL does not support JavaBeans that contain -nested or contain complex types such as Lists or Arrays. You can create a JavaBean by creating a +nested or contain complex types such as Lists or Arrays. You can create a JavaBean by creating a class that implements Serializable and has getters and setters for all of its fields. {% highlight java %} @@ -559,9 +631,9 @@ List teenagerNames = teenagers.javaRDD().map(new Function()
    -Spark SQL can convert an RDD of Row objects to a DataFrame, inferring the datatypes. Rows are constructed by passing a list of +Spark SQL can convert an RDD of Row objects to a DataFrame, inferring the datatypes. Rows are constructed by passing a list of key/value pairs as kwargs to the Row class. The keys of this list define the column names of the table, -and the types are inferred by looking at the first row. Since we currently only look at the first +and the types are inferred by looking at the first row. Since we currently only look at the first row, it is important that there is no missing data in the first row of the RDD. In future versions we plan to more completely infer the schema by looking at more data, similar to the inference that is performed on JSON files. @@ -780,7 +852,7 @@ for name in names.collect(): Spark SQL supports operating on a variety of data sources through the `DataFrame` interface. A DataFrame can be operated on as normal RDDs and can also be registered as a temporary table. -Registering a DataFrame as a table allows you to run SQL queries over its data. This section +Registering a DataFrame as a table allows you to run SQL queries over its data. This section describes the general methods for loading and saving data using the Spark Data Sources and then goes into specific options that are available for the built-in data sources. @@ -834,9 +906,9 @@ saveDF(select(df, "name", "age"), "namesAndAges.parquet") ### Manually Specifying Options You can also manually specify the data source that will be used along with any extra options -that you would like to pass to the data source. Data sources are specified by their fully qualified +that you would like to pass to the data source. Data sources are specified by their fully qualified name (i.e., `org.apache.spark.sql.parquet`), but for built-in sources you can also use their short -names (`json`, `parquet`, `jdbc`). DataFrames of any type can be converted into other types +names (`json`, `parquet`, `jdbc`). DataFrames of any type can be converted into other types using this syntax.
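To make the generic load/save calls described above concrete, a hedged Scala sketch; the paths mirror the sample files shipped with Spark and are otherwise placeholders.

{% highlight scala %}
// Sketch only: name the source explicitly, or use a short name for built-in sources.
val people = sqlContext.read.format("json").load("examples/src/main/resources/people.json")
people.select("name", "age").write.format("parquet").save("namesAndAges.parquet")
{% endhighlight %}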
    @@ -923,8 +995,8 @@ df <- sql(sqlContext, "SELECT * FROM parquet.`examples/src/main/resources/users. ### Save Modes Save operations can optionally take a `SaveMode`, that specifies how to handle existing data if -present. It is important to realize that these save modes do not utilize any locking and are not -atomic. Additionally, when performing a `Overwrite`, the data will be deleted before writing out the +present. It is important to realize that these save modes do not utilize any locking and are not +atomic. Additionally, when performing a `Overwrite`, the data will be deleted before writing out the new data. @@ -960,7 +1032,7 @@ new data.
    Ignore mode means that when saving a DataFrame to a data source, if data already exists, the save operation is expected to not save the contents of the DataFrame and to not - change the existing data. This is similar to a CREATE TABLE IF NOT EXISTS in SQL. + change the existing data. This is similar to a CREATE TABLE IF NOT EXISTS in SQL.
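As a hedged sketch of how a save mode is actually passed (Scala; `df` and the output paths are placeholders):

{% highlight scala %}
import org.apache.spark.sql.SaveMode

// Sketch only: Overwrite deletes any existing data before writing, as noted above.
df.write.mode(SaveMode.Overwrite).parquet("people.parquet")

// The mode can also be given as a string: "append", "overwrite", "ignore" or "error".
df.write.mode("ignore").json("people.json")
{% endhighlight %}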
    @@ -968,14 +1040,14 @@ new data. ### Saving to Persistent Tables When working with a `HiveContext`, `DataFrames` can also be saved as persistent tables using the -`saveAsTable` command. Unlike the `registerTempTable` command, `saveAsTable` will materialize the -contents of the dataframe and create a pointer to the data in the HiveMetastore. Persistent tables +`saveAsTable` command. Unlike the `registerTempTable` command, `saveAsTable` will materialize the +contents of the dataframe and create a pointer to the data in the HiveMetastore. Persistent tables will still exist even after your Spark program has restarted, as long as you maintain your connection -to the same metastore. A DataFrame for a persistent table can be created by calling the `table` +to the same metastore. A DataFrame for a persistent table can be created by calling the `table` method on a `SQLContext` with the name of the table. By default `saveAsTable` will create a "managed table", meaning that the location of the data will -be controlled by the metastore. Managed tables will also have their data deleted automatically +be controlled by the metastore. Managed tables will also have their data deleted automatically when a table is dropped. ## Parquet Files @@ -1003,7 +1075,7 @@ val people: RDD[Person] = ... // An RDD of case class objects, from the previous // The RDD is implicitly converted to a DataFrame by implicits, allowing it to be stored using Parquet. people.write.parquet("people.parquet") -// Read in the parquet file created above. Parquet files are self-describing so the schema is preserved. +// Read in the parquet file created above. Parquet files are self-describing so the schema is preserved. // The result of loading a Parquet file is also a DataFrame. val parquetFile = sqlContext.read.parquet("people.parquet") @@ -1025,7 +1097,7 @@ DataFrame schemaPeople = ... // The DataFrame from the previous example. // DataFrames can be saved as Parquet files, maintaining the schema information. schemaPeople.write().parquet("people.parquet"); -// Read in the Parquet file created above. Parquet files are self-describing so the schema is preserved. +// Read in the Parquet file created above. Parquet files are self-describing so the schema is preserved. // The result of loading a parquet file is also a DataFrame. DataFrame parquetFile = sqlContext.read().parquet("people.parquet"); @@ -1051,7 +1123,7 @@ schemaPeople # The DataFrame from the previous example. # DataFrames can be saved as Parquet files, maintaining the schema information. schemaPeople.write.parquet("people.parquet") -# Read in the Parquet file created above. Parquet files are self-describing so the schema is preserved. +# Read in the Parquet file created above. Parquet files are self-describing so the schema is preserved. # The result of loading a parquet file is also a DataFrame. parquetFile = sqlContext.read.parquet("people.parquet") @@ -1075,7 +1147,7 @@ schemaPeople # The DataFrame from the previous example. # DataFrames can be saved as Parquet files, maintaining the schema information. saveAsParquetFile(schemaPeople, "people.parquet") -# Read in the Parquet file created above. Parquet files are self-describing so the schema is preserved. +# Read in the Parquet file created above. Parquet files are self-describing so the schema is preserved. # The result of loading a parquet file is also a DataFrame. 
parquetFile <- parquetFile(sqlContext, "people.parquet") @@ -1110,10 +1182,10 @@ SELECT * FROM parquetTable ### Partition Discovery -Table partitioning is a common optimization approach used in systems like Hive. In a partitioned +Table partitioning is a common optimization approach used in systems like Hive. In a partitioned table, data are usually stored in different directories, with partitioning column values encoded in -the path of each partition directory. The Parquet data source is now able to discover and infer -partitioning information automatically. For example, we can store all our previously used +the path of each partition directory. The Parquet data source is now able to discover and infer +partitioning information automatically. For example, we can store all our previously used population data into a partitioned table using the following directory structure, with two extra columns, `gender` and `country` as partitioning columns: @@ -1155,7 +1227,7 @@ root {% endhighlight %} -Notice that the data types of the partitioning columns are automatically inferred. Currently, +Notice that the data types of the partitioning columns are automatically inferred. Currently, numeric data types and string type are supported. Sometimes users may not want to automatically infer the data types of the partitioning columns. For these use cases, the automatic type inference can be configured by `spark.sql.sources.partitionColumnTypeInference.enabled`, which is default to @@ -1164,13 +1236,13 @@ can be configured by `spark.sql.sources.partitionColumnTypeInference.enabled`, w ### Schema Merging -Like ProtocolBuffer, Avro, and Thrift, Parquet also supports schema evolution. Users can start with -a simple schema, and gradually add more columns to the schema as needed. In this way, users may end -up with multiple Parquet files with different but mutually compatible schemas. The Parquet data +Like ProtocolBuffer, Avro, and Thrift, Parquet also supports schema evolution. Users can start with +a simple schema, and gradually add more columns to the schema as needed. In this way, users may end +up with multiple Parquet files with different but mutually compatible schemas. The Parquet data source is now able to automatically detect this case and merge schemas of all these files. Since schema merging is a relatively expensive operation, and is not a necessity in most cases, we -turned it off by default starting from 1.5.0. You may enable it by +turned it off by default starting from 1.5.0. You may enable it by 1. setting data source option `mergeSchema` to `true` when reading Parquet files (as shown in the examples below), or @@ -1284,10 +1356,10 @@ processing. 1. Hive considers all columns nullable, while nullability in Parquet is significant Due to this reason, we must reconcile Hive metastore schema with Parquet schema when converting a -Hive metastore Parquet table to a Spark SQL Parquet table. The reconciliation rules are: +Hive metastore Parquet table to a Spark SQL Parquet table. The reconciliation rules are: 1. Fields that have the same name in both schema must have the same data type regardless of - nullability. The reconciled field should have the data type of the Parquet side, so that + nullability. The reconciled field should have the data type of the Parquet side, so that nullability is respected. 1. The reconciled schema contains exactly those fields defined in Hive metastore schema. @@ -1298,8 +1370,8 @@ Hive metastore Parquet table to a Spark SQL Parquet table. 
The reconciliation r #### Metadata Refreshing -Spark SQL caches Parquet metadata for better performance. When Hive metastore Parquet table -conversion is enabled, metadata of those converted tables are also cached. If these tables are +Spark SQL caches Parquet metadata for better performance. When Hive metastore Parquet table +conversion is enabled, metadata of those converted tables are also cached. If these tables are updated by Hive or other external tools, you need to refresh them manually to ensure consistent metadata. @@ -1362,7 +1434,7 @@ Configuration of Parquet can be done using the `setConf` method on `SQLContext` spark.sql.parquet.int96AsTimestamp true - Some Parquet-producing systems, in particular Impala and Hive, store Timestamp into INT96. This + Some Parquet-producing systems, in particular Impala and Hive, store Timestamp into INT96. This flag tells Spark SQL to interpret INT96 data as a timestamp to provide compatibility with these systems. @@ -1400,7 +1472,7 @@ Configuration of Parquet can be done using the `setConf` method on `SQLContext`

The output committer class used by Parquet. The specified class needs to be a subclass of
org.apache.hadoop.mapreduce.OutputCommitter. Typically, it's also a
subclass of org.apache.parquet.hadoop.ParquetOutputCommitter.
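A hedged sketch of the two mechanisms this Parquet section keeps referring to, `setConf`-based configuration and manual metadata refreshing; it assumes `sqlContext` is a `HiveContext` (where `refreshTable` lives in the 1.x API), and the table name is a placeholder.

{% highlight scala %}
// Sketch only.
sqlContext.setConf("spark.sql.parquet.int96AsTimestamp", "true")
sqlContext.setConf("spark.sql.parquet.filterPushdown", "true")

// If Hive or another external tool rewrites the table's files, refresh the cached metadata.
sqlContext.refreshTable("my_parquet_table")
{% endhighlight %}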

    @@ -1628,7 +1700,7 @@ YARN cluster. The convenient way to do this is adding them through the `--jars` When working with Hive one must construct a `HiveContext`, which inherits from `SQLContext`, and adds support for finding tables in the MetaStore and writing queries using HiveQL. Users who do -not have an existing Hive deployment can still create a `HiveContext`. When not configured by the +not have an existing Hive deployment can still create a `HiveContext`. When not configured by the hive-site.xml, the context automatically creates `metastore_db` in the current directory and creates `warehouse` directory indicated by HiveConf, which defaults to `/user/hive/warehouse`. Note that you may need to grant write privilege on `/user/hive/warehouse` to the user who starts @@ -1738,10 +1810,10 @@ The following options can be used to configure the version of Hive that is used enabled. When this option is chosen, spark.sql.hive.metastore.version must be either 1.2.1 or not defined.

  • maven
  • Use Hive jars of the specified version downloaded from Maven repositories. This configuration
    is not generally recommended for production deployments.
  • A classpath in the standard format for the JVM. This classpath must include all of Hive
    and its dependencies, including the correct version of Hadoop. These jars only need to be
    present on the driver, but if you are running in yarn cluster mode then you must ensure
    they are packaged with your application.
  • @@ -1776,7 +1848,7 @@ The following options can be used to configure the version of Hive that is used ## JDBC To Other Databases -Spark SQL also includes a data source that can read data from other databases using JDBC. This +Spark SQL also includes a data source that can read data from other databases using JDBC. This functionality should be preferred over using [JdbcRDD](api/scala/index.html#org.apache.spark.rdd.JdbcRDD). This is because the results are returned as a DataFrame and they can easily be processed in Spark SQL or joined with other data sources. @@ -1786,7 +1858,7 @@ provide a ClassTag. run queries using Spark SQL). To get started you will need to include the JDBC driver for you particular database on the -spark classpath. For example, to connect to postgres from the Spark Shell you would run the +spark classpath. For example, to connect to postgres from the Spark Shell you would run the following command: {% highlight bash %} @@ -1794,7 +1866,7 @@ SPARK_CLASSPATH=postgresql-9.3-1102-jdbc41.jar bin/spark-shell {% endhighlight %} Tables from the remote database can be loaded as a DataFrame or Spark SQL Temporary table using -the Data Sources API. The following options are supported: +the Data Sources API. The following options are supported: @@ -1807,8 +1879,8 @@ the Data Sources API. The following options are supported: @@ -1816,7 +1888,7 @@ the Data Sources API. The following options are supported: @@ -1825,7 +1897,7 @@ the Data Sources API. The following options are supported: @@ -1947,7 +2019,7 @@ Configuration of in-memory caching can be done using the `setConf` method on `SQ ## Other Configuration Options -The following options can also be used to tune the performance of query execution. It is possible +The following options can also be used to tune the performance of query execution. It is possible that these options will be deprecated in future release as more optimizations are performed automatically.
Property Name | Meaning
dbtable | The JDBC table that should be read. Note that anything that is valid in a FROM clause of a SQL query can be used. For example, instead of a full table you could also use a subquery in parentheses.
driver | The class name of the JDBC driver needed to connect to this URL. This class will be loaded on the master and workers before running any JDBC commands, to allow the driver to register itself with the JDBC subsystem.
    partitionColumn, lowerBound, upperBound, numPartitions - These options must all be specified if any of them is specified. They describe how to + These options must all be specified if any of them is specified. They describe how to partition the table when reading in parallel from multiple workers. partitionColumn must be a numeric column from the table in question. Notice that lowerBound and upperBound are just used to decide the @@ -1938,7 +2010,7 @@ Configuration of in-memory caching can be done using the `setConf` method on `SQ spark.sql.inMemoryColumnarStorage.batchSize 10000 - Controls the size of batches for columnar caching. Larger batch sizes can improve memory utilization + Controls the size of batches for columnar caching. Larger batch sizes can improve memory utilization and compression, but risk OOMs when caching data.
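Before the configuration tables continue, a hedged Scala sketch tying the JDBC data source options above together; the URL, table name and driver class are placeholders.

{% highlight scala %}
// Sketch only: any option from the JDBC table above can be passed the same way.
val jdbcDF = sqlContext.read.format("jdbc")
  .option("url", "jdbc:postgresql:dbserver")
  .option("dbtable", "schema.tablename")
  .option("driver", "org.postgresql.Driver")
  .load()
{% endhighlight %}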
    @@ -1957,7 +2029,7 @@ that these options will be deprecated in future release as more optimizations ar @@ -1995,8 +2067,8 @@ To start the JDBC/ODBC server, run the following in the Spark directory: ./sbin/start-thriftserver.sh This script accepts all `bin/spark-submit` command line options, plus a `--hiveconf` option to -specify Hive properties. You may run `./sbin/start-thriftserver.sh --help` for a complete list of -all available options. By default, the server listens on localhost:10000. You may override this +specify Hive properties. You may run `./sbin/start-thriftserver.sh --help` for a complete list of +all available options. By default, the server listens on localhost:10000. You may override this behaviour via either environment variables, i.e.: {% highlight bash %} @@ -2062,10 +2134,10 @@ options. ## Upgrading From Spark SQL 1.5 to 1.6 - - From Spark 1.6, by default the Thrift server runs in multi-session mode. Which means each JDBC/ODBC - connection owns a copy of their own SQL configuration and temporary function registry. Cached - tables are still shared though. If you prefer to run the Thrift server in the old single-session - mode, please set option `spark.sql.hive.thriftServer.singleSession` to `true`. You may either add + - From Spark 1.6, by default the Thrift server runs in multi-session mode. Which means each JDBC/ODBC + connection owns a copy of their own SQL configuration and temporary function registry. Cached + tables are still shared though. If you prefer to run the Thrift server in the old single-session + mode, please set option `spark.sql.hive.thriftServer.singleSession` to `true`. You may either add this option to `spark-defaults.conf`, or pass it to `start-thriftserver.sh` via `--conf`: {% highlight bash %} @@ -2077,20 +2149,20 @@ options. ## Upgrading From Spark SQL 1.4 to 1.5 - Optimized execution using manually managed memory (Tungsten) is now enabled by default, along with - code generation for expression evaluation. These features can both be disabled by setting + code generation for expression evaluation. These features can both be disabled by setting `spark.sql.tungsten.enabled` to `false`. - - Parquet schema merging is no longer enabled by default. It can be re-enabled by setting + - Parquet schema merging is no longer enabled by default. It can be re-enabled by setting `spark.sql.parquet.mergeSchema` to `true`. - Resolution of strings to columns in python now supports using dots (`.`) to qualify the column or - access nested values. For example `df['table.column.nestedField']`. However, this means that if - your column name contains any dots you must now escape them using backticks (e.g., ``table.`column.with.dots`.nested``). + access nested values. For example `df['table.column.nestedField']`. However, this means that if + your column name contains any dots you must now escape them using backticks (e.g., ``table.`column.with.dots`.nested``). - In-memory columnar storage partition pruning is on by default. It can be disabled by setting `spark.sql.inMemoryColumnarStorage.partitionPruning` to `false`. - Unlimited precision decimal columns are no longer supported, instead Spark SQL enforces a maximum - precision of 38. When inferring schema from `BigDecimal` objects, a precision of (38, 18) is now + precision of 38. When inferring schema from `BigDecimal` objects, a precision of (38, 18) is now used. When no precision is specified in DDL then the default remains `Decimal(10, 0)`. 
- Timestamps are now stored at a precision of 1us, rather than 1ns - - In the `sql` dialect, floating point numbers are now parsed as decimal. HiveQL parsing remains + - In the `sql` dialect, floating point numbers are now parsed as decimal. HiveQL parsing remains unchanged. - The canonical name of SQL/DataFrame functions are now lower case (e.g. sum vs SUM). - It has been determined that using the DirectOutputCommitter when speculation is enabled is unsafe @@ -2183,38 +2255,38 @@ sqlContext.setConf("spark.sql.retainGroupColumns", "false") ## Upgrading from Spark SQL 1.0-1.2 to 1.3 In Spark 1.3 we removed the "Alpha" label from Spark SQL and as part of this did a cleanup of the -available APIs. From Spark 1.3 onwards, Spark SQL will provide binary compatibility with other -releases in the 1.X series. This compatibility guarantee excludes APIs that are explicitly marked +available APIs. From Spark 1.3 onwards, Spark SQL will provide binary compatibility with other +releases in the 1.X series. This compatibility guarantee excludes APIs that are explicitly marked as unstable (i.e., DeveloperAPI or Experimental). #### Rename of SchemaRDD to DataFrame The largest change that users will notice when upgrading to Spark SQL 1.3 is that `SchemaRDD` has -been renamed to `DataFrame`. This is primarily because DataFrames no longer inherit from RDD +been renamed to `DataFrame`. This is primarily because DataFrames no longer inherit from RDD directly, but instead provide most of the functionality that RDDs provide though their own -implementation. DataFrames can still be converted to RDDs by calling the `.rdd` method. +implementation. DataFrames can still be converted to RDDs by calling the `.rdd` method. In Scala there is a type alias from `SchemaRDD` to `DataFrame` to provide source compatibility for -some use cases. It is still recommended that users update their code to use `DataFrame` instead. +some use cases. It is still recommended that users update their code to use `DataFrame` instead. Java and Python users will need to update their code. #### Unification of the Java and Scala APIs Prior to Spark 1.3 there were separate Java compatible classes (`JavaSQLContext` and `JavaSchemaRDD`) -that mirrored the Scala API. In Spark 1.3 the Java API and Scala API have been unified. Users -of either language should use `SQLContext` and `DataFrame`. In general theses classes try to +that mirrored the Scala API. In Spark 1.3 the Java API and Scala API have been unified. Users +of either language should use `SQLContext` and `DataFrame`. In general theses classes try to use types that are usable from both languages (i.e. `Array` instead of language specific collections). In some cases where no common type exists (e.g., for passing in closures or Maps) function overloading is used instead. -Additionally the Java specific types API has been removed. Users of both Scala and Java should +Additionally the Java specific types API has been removed. Users of both Scala and Java should use the classes present in `org.apache.spark.sql.types` to describe schema programmatically. #### Isolation of Implicit Conversions and Removal of dsl Package (Scala-only) Many of the code examples prior to Spark 1.3 started with `import sqlContext._`, which brought -all of the functions from sqlContext into scope. In Spark 1.3 we have isolated the implicit +all of the functions from sqlContext into scope. 
In Spark 1.3 we have isolated the implicit conversions for converting `RDD`s into `DataFrame`s into an object inside of the `SQLContext`. Users should now write `import sqlContext.implicits._`. @@ -2222,7 +2294,7 @@ Additionally, the implicit conversions now only augment RDDs that are composed o case classes or tuples) with a method `toDF`, instead of applying automatically. When using function inside of the DSL (now replaced with the `DataFrame` API) users used to import -`org.apache.spark.sql.catalyst.dsl`. Instead the public dataframe functions API should be used: +`org.apache.spark.sql.catalyst.dsl`. Instead the public dataframe functions API should be used: `import org.apache.spark.sql.functions._`. #### Removal of the type aliases in org.apache.spark.sql for DataType (Scala-only) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala index 3ca5ade7f30f..bb0fdc4c3d83 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala @@ -19,20 +19,60 @@ package org.apache.spark.sql import java.lang.reflect.Modifier +import scala.annotation.implicitNotFound import scala.reflect.{ClassTag, classTag} +import org.apache.spark.annotation.Experimental import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, encoderFor} import org.apache.spark.sql.catalyst.expressions.{DecodeUsingSerializer, BoundReference, EncodeUsingSerializer} import org.apache.spark.sql.types._ /** + * :: Experimental :: * Used to convert a JVM object of type `T` to and from the internal Spark SQL representation. * - * Encoders are not intended to be thread-safe and thus they are allow to avoid internal locking - * and reuse internal buffers to improve performance. + * == Scala == + * Encoders are generally created automatically through implicits from a `SQLContext`. + * + * {{{ + * import sqlContext.implicits._ + * + * val ds = Seq(1, 2, 3).toDS() // implicitly provided (sqlContext.implicits.newIntEncoder) + * }}} + * + * == Java == + * Encoders are specified by calling static methods on [[Encoders]]. + * + * {{{ + * List data = Arrays.asList("abc", "abc", "xyz"); + * Dataset ds = context.createDataset(data, Encoders.STRING()); + * }}} + * + * Encoders can be composed into tuples: + * + * {{{ + * Encoder> encoder2 = Encoders.tuple(Encoders.INT(), Encoders.STRING()); + * List> data2 = Arrays.asList(new scala.Tuple2(1, "a"); + * Dataset> ds2 = context.createDataset(data2, encoder2); + * }}} + * + * Or constructed from Java Beans: + * + * {{{ + * Encoders.bean(MyClass.class); + * }}} + * + * == Implementation == + * - Encoders are not required to be thread-safe and thus they do not need to use locks to guard + * against concurrent access if they reuse internal buffers to improve performance. * * @since 1.6.0 */ +@Experimental +@implicitNotFound("Unable to find encoder for type stored in a Dataset. Primitive types " + + "(Int, String, etc) and Product types (case classes) are supported by importing " + + "sqlContext.implicits._ Support for serializing other types will be added in future " + + "releases.") trait Encoder[T] extends Serializable { /** Returns the schema of encoding this type of object as a Row. */ @@ -43,10 +83,12 @@ trait Encoder[T] extends Serializable { } /** - * Methods for creating encoders. + * :: Experimental :: + * Methods for creating an [[Encoder]]. 
* * @since 1.6.0 */ +@Experimental object Encoders { /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index ad6af481fadc..d641fcac1c8a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -73,7 +73,26 @@ class TypedColumn[-T, U]( /** * :: Experimental :: - * A column in a [[DataFrame]]. + * A column that will be computed based on the data in a [[DataFrame]]. + * + * A new column is constructed based on the input columns present in a dataframe: + * + * {{{ + * df("columnName") // On a specific DataFrame. + * col("columnName") // A generic column no yet associcated with a DataFrame. + * col("columnName.field") // Extracting a struct field + * col("`a.column.with.dots`") // Escape `.` in column names. + * $"columnName" // Scala short hand for a named column. + * expr("a + 1") // A column that is constructed from a parsed SQL Expression. + * lit("1") // A column that produces a literal (constant) value. + * }}} + * + * [[Column]] objects can be composed to form complex expressions: + * + * {{{ + * $"a" + 1 + * $"a" === $"b" + * }}} * * @groupname java_expr_ops Java-specific expression operators * @groupname expr_ops Expression operators From 765c67f5f2e0b1367e37883f662d313661e3a0d9 Mon Sep 17 00:00:00 2001 From: Timothy Hunter Date: Tue, 8 Dec 2015 18:40:21 -0800 Subject: [PATCH 1708/1824] [SPARK-8517][ML][DOC] Reorganizes the spark.ml user guide This PR moves pieces of the spark.ml user guide to reflect suggestions in SPARK-8517. It does not introduce new content, as requested. screen shot 2015-12-08 at 11 36 00 am Author: Timothy Hunter Closes #10207 from thunterdb/spark-8517. --- docs/_data/menu-ml.yaml | 18 +- docs/ml-advanced.md | 13 + docs/ml-ann.md | 62 -- docs/ml-classification-regression.md | 775 ++++++++++++++++++++++ docs/ml-clustering.md | 5 + docs/ml-features.md | 4 +- docs/ml-intro.md | 941 +++++++++++++++++++++++++++ docs/mllib-guide.md | 15 +- 8 files changed, 1752 insertions(+), 81 deletions(-) create mode 100644 docs/ml-advanced.md delete mode 100644 docs/ml-ann.md create mode 100644 docs/ml-classification-regression.md create mode 100644 docs/ml-intro.md diff --git a/docs/_data/menu-ml.yaml b/docs/_data/menu-ml.yaml index dff3d33bf4ed..fe37d0573e46 100644 --- a/docs/_data/menu-ml.yaml +++ b/docs/_data/menu-ml.yaml @@ -1,10 +1,10 @@ -- text: Feature extraction, transformation, and selection +- text: "Overview: estimators, transformers and pipelines" + url: ml-intro.html +- text: Extracting, transforming and selecting features url: ml-features.html -- text: Decision trees for classification and regression - url: ml-decision-tree.html -- text: Ensembles - url: ml-ensembles.html -- text: Linear methods with elastic-net regularization - url: ml-linear-methods.html -- text: Multilayer perceptron classifier - url: ml-ann.html +- text: Classification and Regression + url: ml-classification-regression.html +- text: Clustering + url: ml-clustering.html +- text: Advanced topics + url: ml-advanced.html diff --git a/docs/ml-advanced.md b/docs/ml-advanced.md new file mode 100644 index 000000000000..b005633e56c1 --- /dev/null +++ b/docs/ml-advanced.md @@ -0,0 +1,13 @@ +--- +layout: global +title: Advanced topics - spark.ml +displayTitle: Advanced topics +--- + +# Optimization of linear methods + +The optimization algorithm underlying the implementation is called +[Orthant-Wise Limited-memory 
+QuasiNewton](http://research-srv.microsoft.com/en-us/um/people/jfgao/paper/icml07scalable.pdf) +(OWL-QN). It is an extension of L-BFGS that can effectively handle L1 +regularization and elastic net. diff --git a/docs/ml-ann.md b/docs/ml-ann.md deleted file mode 100644 index 6e763e8f4156..000000000000 --- a/docs/ml-ann.md +++ /dev/null @@ -1,62 +0,0 @@ ---- -layout: global -title: Multilayer perceptron classifier - ML -displayTitle: ML - Multilayer perceptron classifier ---- - - -`\[ -\newcommand{\R}{\mathbb{R}} -\newcommand{\E}{\mathbb{E}} -\newcommand{\x}{\mathbf{x}} -\newcommand{\y}{\mathbf{y}} -\newcommand{\wv}{\mathbf{w}} -\newcommand{\av}{\mathbf{\alpha}} -\newcommand{\bv}{\mathbf{b}} -\newcommand{\N}{\mathbb{N}} -\newcommand{\id}{\mathbf{I}} -\newcommand{\ind}{\mathbf{1}} -\newcommand{\0}{\mathbf{0}} -\newcommand{\unit}{\mathbf{e}} -\newcommand{\one}{\mathbf{1}} -\newcommand{\zero}{\mathbf{0}} -\]` - - -Multilayer perceptron classifier (MLPC) is a classifier based on the [feedforward artificial neural network](https://en.wikipedia.org/wiki/Feedforward_neural_network). -MLPC consists of multiple layers of nodes. -Each layer is fully connected to the next layer in the network. Nodes in the input layer represent the input data. All other nodes maps inputs to the outputs -by performing linear combination of the inputs with the node's weights `$\wv$` and bias `$\bv$` and applying an activation function. -It can be written in matrix form for MLPC with `$K+1$` layers as follows: -`\[ -\mathrm{y}(\x) = \mathrm{f_K}(...\mathrm{f_2}(\wv_2^T\mathrm{f_1}(\wv_1^T \x+b_1)+b_2)...+b_K) -\]` -Nodes in intermediate layers use sigmoid (logistic) function: -`\[ -\mathrm{f}(z_i) = \frac{1}{1 + e^{-z_i}} -\]` -Nodes in the output layer use softmax function: -`\[ -\mathrm{f}(z_i) = \frac{e^{z_i}}{\sum_{k=1}^N e^{z_k}} -\]` -The number of nodes `$N$` in the output layer corresponds to the number of classes. - -MLPC employes backpropagation for learning the model. We use logistic loss function for optimization and L-BFGS as optimization routine. - -**Examples** - -
    - -
    -{% include_example scala/org/apache/spark/examples/ml/MultilayerPerceptronClassifierExample.scala %} -
    - -
    -{% include_example java/org/apache/spark/examples/ml/JavaMultilayerPerceptronClassifierExample.java %} -
    - -
    -{% include_example python/ml/multilayer_perceptron_classification.py %} -
    - -
    diff --git a/docs/ml-classification-regression.md b/docs/ml-classification-regression.md new file mode 100644 index 000000000000..3663ffee3275 --- /dev/null +++ b/docs/ml-classification-regression.md @@ -0,0 +1,775 @@ +--- +layout: global +title: Classification and regression - spark.ml +displayTitle: Classification and regression in spark.ml +--- + + +`\[ +\newcommand{\R}{\mathbb{R}} +\newcommand{\E}{\mathbb{E}} +\newcommand{\x}{\mathbf{x}} +\newcommand{\y}{\mathbf{y}} +\newcommand{\wv}{\mathbf{w}} +\newcommand{\av}{\mathbf{\alpha}} +\newcommand{\bv}{\mathbf{b}} +\newcommand{\N}{\mathbb{N}} +\newcommand{\id}{\mathbf{I}} +\newcommand{\ind}{\mathbf{1}} +\newcommand{\0}{\mathbf{0}} +\newcommand{\unit}{\mathbf{e}} +\newcommand{\one}{\mathbf{1}} +\newcommand{\zero}{\mathbf{0}} +\]` + +**Table of Contents** + +* This will become a table of contents (this text will be scraped). +{:toc} + +In MLlib, we implement popular linear methods such as logistic +regression and linear least squares with $L_1$ or $L_2$ regularization. +Refer to [the linear methods in mllib](mllib-linear-methods.html) for +details. In `spark.ml`, we also include Pipelines API for [Elastic +net](http://en.wikipedia.org/wiki/Elastic_net_regularization), a hybrid +of $L_1$ and $L_2$ regularization proposed in [Zou et al, Regularization +and variable selection via the elastic +net](http://users.stat.umn.edu/~zouxx019/Papers/elasticnet.pdf). +Mathematically, it is defined as a convex combination of the $L_1$ and +the $L_2$ regularization terms: +`\[ +\alpha \left( \lambda \|\wv\|_1 \right) + (1-\alpha) \left( \frac{\lambda}{2}\|\wv\|_2^2 \right) , \alpha \in [0, 1], \lambda \geq 0 +\]` +By setting $\alpha$ properly, elastic net contains both $L_1$ and $L_2$ +regularization as special cases. For example, if a [linear +regression](https://en.wikipedia.org/wiki/Linear_regression) model is +trained with the elastic net parameter $\alpha$ set to $1$, it is +equivalent to a +[Lasso](http://en.wikipedia.org/wiki/Least_squares#Lasso_method) model. +On the other hand, if $\alpha$ is set to $0$, the trained model reduces +to a [ridge +regression](http://en.wikipedia.org/wiki/Tikhonov_regularization) model. +We implement Pipelines API for both linear regression and logistic +regression with elastic net regularization. + + +# Classification + +## Logistic regression + +Logistic regression is a popular method to predict a binary response. It is a special case of [Generalized Linear models](https://en.wikipedia.org/wiki/Generalized_linear_model) that predicts the probability of the outcome. +For more background and more details about the implementation, refer to the documentation of the [logistic regression in `spark.mllib`](mllib-linear-methods.html#logistic-regression). + + > The current implementation of logistic regression in `spark.ml` only supports binary classes. Support for multiclass regression will be added in the future. + +**Example** + +The following example shows how to train a logistic regression model +with elastic net regularization. `elasticNetParam` corresponds to +$\alpha$ and `regParam` corresponds to $\lambda$. + +
    + +
    +{% include_example scala/org/apache/spark/examples/ml/LogisticRegressionWithElasticNetExample.scala %} +
    + +
    +{% include_example java/org/apache/spark/examples/ml/JavaLogisticRegressionWithElasticNetExample.java %} +
    + +
    +{% include_example python/ml/logistic_regression_with_elastic_net.py %} +
    + +
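For readers who want the shape of the API inline (the `include_example` files referenced above contain the full programs), a hedged Scala sketch of elastic-net logistic regression; `training` is a placeholder `DataFrame` with "label" and "features" columns, and the calls assume the spark.ml 1.6-era API.

{% highlight scala %}
// Sketch only.
import org.apache.spark.ml.classification.LogisticRegression

val lr = new LogisticRegression()
  .setMaxIter(10)
  .setRegParam(0.3)        // corresponds to lambda
  .setElasticNetParam(0.8) // corresponds to alpha
val lrModel = lr.fit(training)
println(s"Coefficients: ${lrModel.coefficients} Intercept: ${lrModel.intercept}")
{% endhighlight %}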
+
+The `spark.ml` implementation of logistic regression also supports
+extracting a summary of the model over the training set. Note that the
+predictions and metrics, which are stored as `DataFrame` in
+`BinaryLogisticRegressionSummary`, are annotated `@transient` and hence
+are only available on the driver.
+
    + +
    + +[`LogisticRegressionTrainingSummary`](api/scala/index.html#org.apache.spark.ml.classification.LogisticRegressionTrainingSummary) +provides a summary for a +[`LogisticRegressionModel`](api/scala/index.html#org.apache.spark.ml.classification.LogisticRegressionModel). +Currently, only binary classification is supported and the +summary must be explicitly cast to +[`BinaryLogisticRegressionTrainingSummary`](api/scala/index.html#org.apache.spark.ml.classification.BinaryLogisticRegressionTrainingSummary). +This will likely change when multiclass classification is supported. + +Continuing the earlier example: + +{% include_example scala/org/apache/spark/examples/ml/LogisticRegressionSummaryExample.scala %} +
    + +
    +[`LogisticRegressionTrainingSummary`](api/java/org/apache/spark/ml/classification/LogisticRegressionTrainingSummary.html) +provides a summary for a +[`LogisticRegressionModel`](api/java/org/apache/spark/ml/classification/LogisticRegressionModel.html). +Currently, only binary classification is supported and the +summary must be explicitly cast to +[`BinaryLogisticRegressionTrainingSummary`](api/java/org/apache/spark/ml/classification/BinaryLogisticRegressionTrainingSummary.html). +This will likely change when multiclass classification is supported. + +Continuing the earlier example: + +{% include_example java/org/apache/spark/examples/ml/JavaLogisticRegressionSummaryExample.java %} +
    + + +
    +Logistic regression model summary is not yet supported in Python. +
    + +
    + + +## Decision tree classifier + +Decision trees are a popular family of classification and regression methods. +More information about the `spark.ml` implementation can be found further in the [section on decision trees](#decision-trees). + +**Example** + +The following examples load a dataset in LibSVM format, split it into training and test sets, train on the first dataset, and then evaluate on the held-out test set. +We use two feature transformers to prepare the data; these help index categories for the label and categorical features, adding metadata to the `DataFrame` which the Decision Tree algorithm can recognize. + +
    +
    + +More details on parameters can be found in the [Scala API documentation](api/scala/index.html#org.apache.spark.ml.classification.DecisionTreeClassifier). + +{% include_example scala/org/apache/spark/examples/ml/DecisionTreeClassificationExample.scala %} + +
    + +
    + +More details on parameters can be found in the [Java API documentation](api/java/org/apache/spark/ml/classification/DecisionTreeClassifier.html). + +{% include_example java/org/apache/spark/examples/ml/JavaDecisionTreeClassificationExample.java %} + +
    + +
    + +More details on parameters can be found in the [Python API documentation](api/python/pyspark.ml.html#pyspark.ml.classification.DecisionTreeClassifier). + +{% include_example python/ml/decision_tree_classification_example.py %} + +
    + +
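A hedged Scala sketch of the preparation described above: the two feature transformers and the decision tree wired into a `Pipeline`. Here `data` is a placeholder `DataFrame` with "label" and "features" columns (for example, loaded from LibSVM), and the calls assume the spark.ml 1.6-era API. The random forest and gradient-boosted tree sections below follow the same pattern with the classifier swapped out.

{% highlight scala %}
// Sketch only.
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.DecisionTreeClassifier
import org.apache.spark.ml.feature.{StringIndexer, VectorIndexer}

val labelIndexer = new StringIndexer()
  .setInputCol("label").setOutputCol("indexedLabel").fit(data)
val featureIndexer = new VectorIndexer()
  .setInputCol("features").setOutputCol("indexedFeatures")
  .setMaxCategories(4) // features with more than 4 distinct values are treated as continuous
  .fit(data)
val dt = new DecisionTreeClassifier()
  .setLabelCol("indexedLabel").setFeaturesCol("indexedFeatures")

val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3))
val model = new Pipeline().setStages(Array(labelIndexer, featureIndexer, dt)).fit(trainingData)
val predictions = model.transform(testData)
{% endhighlight %}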
    + +## Random forest classifier + +Random forests are a popular family of classification and regression methods. +More information about the `spark.ml` implementation can be found further in the [section on random forests](#random-forests). + +**Example** + +The following examples load a dataset in LibSVM format, split it into training and test sets, train on the first dataset, and then evaluate on the held-out test set. +We use two feature transformers to prepare the data; these help index categories for the label and categorical features, adding metadata to the `DataFrame` which the tree-based algorithms can recognize. + +
    +
    + +Refer to the [Scala API docs](api/scala/index.html#org.apache.spark.ml.classification.RandomForestClassifier) for more details. + +{% include_example scala/org/apache/spark/examples/ml/RandomForestClassifierExample.scala %} +
    + +
    + +Refer to the [Java API docs](api/java/org/apache/spark/ml/classification/RandomForestClassifier.html) for more details. + +{% include_example java/org/apache/spark/examples/ml/JavaRandomForestClassifierExample.java %} +
    + +
    + +Refer to the [Python API docs](api/python/pyspark.ml.html#pyspark.ml.classification.RandomForestClassifier) for more details. + +{% include_example python/ml/random_forest_classifier_example.py %} +
    +
    + +## Gradient-boosted tree classifier + +Gradient-boosted trees (GBTs) are a popular classification and regression method using ensembles of decision trees. +More information about the `spark.ml` implementation can be found further in the [section on GBTs](#gradient-boosted-trees-gbts). + +**Example** + +The following examples load a dataset in LibSVM format, split it into training and test sets, train on the first dataset, and then evaluate on the held-out test set. +We use two feature transformers to prepare the data; these help index categories for the label and categorical features, adding metadata to the `DataFrame` which the tree-based algorithms can recognize. + +
    +
    + +Refer to the [Scala API docs](api/scala/index.html#org.apache.spark.ml.classification.GBTClassifier) for more details. + +{% include_example scala/org/apache/spark/examples/ml/GradientBoostedTreeClassifierExample.scala %} +
    + +
    + +Refer to the [Java API docs](api/java/org/apache/spark/ml/classification/GBTClassifier.html) for more details. + +{% include_example java/org/apache/spark/examples/ml/JavaGradientBoostedTreeClassifierExample.java %} +
    + +
    + +Refer to the [Python API docs](api/python/pyspark.ml.html#pyspark.ml.classification.GBTClassifier) for more details. + +{% include_example python/ml/gradient_boosted_tree_classifier_example.py %} +
    +
+
+## Multilayer perceptron classifier
+
+Multilayer perceptron classifier (MLPC) is a classifier based on the [feedforward artificial neural network](https://en.wikipedia.org/wiki/Feedforward_neural_network).
+MLPC consists of multiple layers of nodes.
+Each layer is fully connected to the next layer in the network. Nodes in the input layer represent the input data. All other nodes map inputs to the outputs
+by performing a linear combination of the inputs with the node's weights `$\wv$` and bias `$\bv$` and applying an activation function.
+It can be written in matrix form for MLPC with `$K+1$` layers as follows:
+`\[
+\mathrm{y}(\x) = \mathrm{f_K}(...\mathrm{f_2}(\wv_2^T\mathrm{f_1}(\wv_1^T \x+b_1)+b_2)...+b_K)
+\]`
+Nodes in intermediate layers use the sigmoid (logistic) function:
+`\[
+\mathrm{f}(z_i) = \frac{1}{1 + e^{-z_i}}
+\]`
+Nodes in the output layer use the softmax function:
+`\[
+\mathrm{f}(z_i) = \frac{e^{z_i}}{\sum_{k=1}^N e^{z_k}}
+\]`
+The number of nodes `$N$` in the output layer corresponds to the number of classes.
+
+MLPC employs backpropagation for learning the model. We use the logistic loss function for optimization and L-BFGS as the optimization routine.
+
+**Example**
+
    + +
    +{% include_example scala/org/apache/spark/examples/ml/MultilayerPerceptronClassifierExample.scala %} +
    + +
    +{% include_example java/org/apache/spark/examples/ml/JavaMultilayerPerceptronClassifierExample.java %} +
    + +
    +{% include_example python/ml/multilayer_perceptron_classification.py %} +
    + +
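A hedged inline sketch of configuring the classifier described above (Scala): `train` is a placeholder `DataFrame` with "label" and "features" columns, and the layer sizes are only illustrative (4 inputs, two hidden layers, 3 classes).

{% highlight scala %}
// Sketch only (spark.ml 1.6-era API).
import org.apache.spark.ml.classification.MultilayerPerceptronClassifier

val mlp = new MultilayerPerceptronClassifier()
  .setLayers(Array(4, 5, 4, 3))
  .setBlockSize(128)
  .setSeed(1234L)
  .setMaxIter(100)
val mlpModel = mlp.fit(train)
{% endhighlight %}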
    + + +## One-vs-Rest classifier (a.k.a. One-vs-All) + +[OneVsRest](http://en.wikipedia.org/wiki/Multiclass_classification#One-vs.-rest) is an example of a machine learning reduction for performing multiclass classification given a base classifier that can perform binary classification efficiently. It is also known as "One-vs-All." + +`OneVsRest` is implemented as an `Estimator`. For the base classifier it takes instances of `Classifier` and creates a binary classification problem for each of the k classes. The classifier for class i is trained to predict whether the label is i or not, distinguishing class i from all other classes. + +Predictions are done by evaluating each binary classifier and the index of the most confident classifier is output as label. + +**Example** + +The example below demonstrates how to load the +[Iris dataset](http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/iris.scale), parse it as a DataFrame and perform multiclass classification using `OneVsRest`. The test error is calculated to measure the algorithm accuracy. + +
    +
    + +Refer to the [Scala API docs](api/scala/index.html#org.apache.spark.ml.classifier.OneVsRest) for more details. + +{% include_example scala/org/apache/spark/examples/ml/OneVsRestExample.scala %} +
    + +
    + +Refer to the [Java API docs](api/java/org/apache/spark/ml/classification/OneVsRest.html) for more details. + +{% include_example java/org/apache/spark/examples/ml/JavaOneVsRestExample.java %} +
    +
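A hedged sketch of the reduction described above, wrapping a binary logistic regression as the base classifier (Scala; `train` and `test` are placeholder DataFrames with "label" and "features" columns).

{% highlight scala %}
// Sketch only (spark.ml 1.6-era API).
import org.apache.spark.ml.classification.{LogisticRegression, OneVsRest}

val base = new LogisticRegression().setMaxIter(10).setTol(1e-6)
val ovr = new OneVsRest().setClassifier(base)
val ovrModel = ovr.fit(train)      // trains one binary model per class
val predictions = ovrModel.transform(test)
{% endhighlight %}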
    + + +# Regression + +## Linear regression + +The interface for working with linear regression models and model +summaries is similar to the logistic regression case. + +**Example** + +The following +example demonstrates training an elastic net regularized linear +regression model and extracting model summary statistics. + +
    + +
    +{% include_example scala/org/apache/spark/examples/ml/LinearRegressionWithElasticNetExample.scala %} +
    + +
    +{% include_example java/org/apache/spark/examples/ml/JavaLinearRegressionWithElasticNetExample.java %} +
    + +
    + +{% include_example python/ml/linear_regression_with_elastic_net.py %} +
    + +
    + + +## Decision tree regression + +Decision trees are a popular family of classification and regression methods. +More information about the `spark.ml` implementation can be found further in the [section on decision trees](#decision-trees). + +**Example** + +The following examples load a dataset in LibSVM format, split it into training and test sets, train on the first dataset, and then evaluate on the held-out test set. +We use a feature transformer to index categorical features, adding metadata to the `DataFrame` which the Decision Tree algorithm can recognize. + +
    +
    + +More details on parameters can be found in the [Scala API documentation](api/scala/index.html#org.apache.spark.ml.regression.DecisionTreeRegressor). + +{% include_example scala/org/apache/spark/examples/ml/DecisionTreeRegressionExample.scala %} +
    + +
    + +More details on parameters can be found in the [Java API documentation](api/java/org/apache/spark/ml/regression/DecisionTreeRegressor.html). + +{% include_example java/org/apache/spark/examples/ml/JavaDecisionTreeRegressionExample.java %} +
    + +
    + +More details on parameters can be found in the [Python API documentation](api/python/pyspark.ml.html#pyspark.ml.regression.DecisionTreeRegressor). + +{% include_example python/ml/decision_tree_regression_example.py %} +
    + +
    + + +## Random forest regression + +Random forests are a popular family of classification and regression methods. +More information about the `spark.ml` implementation can be found further in the [section on random forests](#random-forests). + +**Example** + +The following examples load a dataset in LibSVM format, split it into training and test sets, train on the first dataset, and then evaluate on the held-out test set. +We use a feature transformer to index categorical features, adding metadata to the `DataFrame` which the tree-based algorithms can recognize. + +
    +
    + +Refer to the [Scala API docs](api/scala/index.html#org.apache.spark.ml.regression.RandomForestRegressor) for more details. + +{% include_example scala/org/apache/spark/examples/ml/RandomForestRegressorExample.scala %} +
    + +
    + +Refer to the [Java API docs](api/java/org/apache/spark/ml/regression/RandomForestRegressor.html) for more details. + +{% include_example java/org/apache/spark/examples/ml/JavaRandomForestRegressorExample.java %} +
    + +
    + +Refer to the [Python API docs](api/python/pyspark.ml.html#pyspark.ml.regression.RandomForestRegressor) for more details. + +{% include_example python/ml/random_forest_regressor_example.py %} +
    +
    + +## Gradient-boosted tree regression + +Gradient-boosted trees (GBTs) are a popular regression method using ensembles of decision trees. +More information about the `spark.ml` implementation can be found further in the [section on GBTs](#gradient-boosted-trees-gbts). + +**Example** + +Note: For this example dataset, `GBTRegressor` actually only needs 1 iteration, but that will not +be true in general. + +
    +
    + +Refer to the [Scala API docs](api/scala/index.html#org.apache.spark.ml.regression.GBTRegressor) for more details. + +{% include_example scala/org/apache/spark/examples/ml/GradientBoostedTreeRegressorExample.scala %} +
    + +
    + +Refer to the [Java API docs](api/java/org/apache/spark/ml/regression/GBTRegressor.html) for more details. + +{% include_example java/org/apache/spark/examples/ml/JavaGradientBoostedTreeRegressorExample.java %} +
    + +
    + +Refer to the [Python API docs](api/python/pyspark.ml.html#pyspark.ml.regression.GBTRegressor) for more details. + +{% include_example python/ml/gradient_boosted_tree_regressor_example.py %} +
    +
    + + +## Survival regression + + +In `spark.ml`, we implement the [Accelerated failure time (AFT)](https://en.wikipedia.org/wiki/Accelerated_failure_time_model) +model which is a parametric survival regression model for censored data. +It describes a model for the log of survival time, so it's often called +log-linear model for survival analysis. Different from +[Proportional hazards](https://en.wikipedia.org/wiki/Proportional_hazards_model) model +designed for the same purpose, the AFT model is more easily to parallelize +because each instance contribute to the objective function independently. + +Given the values of the covariates $x^{'}$, for random lifetime $t_{i}$ of +subjects i = 1, ..., n, with possible right-censoring, +the likelihood function under the AFT model is given as: +`\[ +L(\beta,\sigma)=\prod_{i=1}^n[\frac{1}{\sigma}f_{0}(\frac{\log{t_{i}}-x^{'}\beta}{\sigma})]^{\delta_{i}}S_{0}(\frac{\log{t_{i}}-x^{'}\beta}{\sigma})^{1-\delta_{i}} +\]` +Where $\delta_{i}$ is the indicator of the event has occurred i.e. uncensored or not. +Using $\epsilon_{i}=\frac{\log{t_{i}}-x^{'}\beta}{\sigma}$, the log-likelihood function +assumes the form: +`\[ +\iota(\beta,\sigma)=\sum_{i=1}^{n}[-\delta_{i}\log\sigma+\delta_{i}\log{f_{0}}(\epsilon_{i})+(1-\delta_{i})\log{S_{0}(\epsilon_{i})}] +\]` +Where $S_{0}(\epsilon_{i})$ is the baseline survivor function, +and $f_{0}(\epsilon_{i})$ is corresponding density function. + +The most commonly used AFT model is based on the Weibull distribution of the survival time. +The Weibull distribution for lifetime corresponding to extreme value distribution for +log of the lifetime, and the $S_{0}(\epsilon)$ function is: +`\[ +S_{0}(\epsilon_{i})=\exp(-e^{\epsilon_{i}}) +\]` +the $f_{0}(\epsilon_{i})$ function is: +`\[ +f_{0}(\epsilon_{i})=e^{\epsilon_{i}}\exp(-e^{\epsilon_{i}}) +\]` +The log-likelihood function for AFT model with Weibull distribution of lifetime is: +`\[ +\iota(\beta,\sigma)= -\sum_{i=1}^n[\delta_{i}\log\sigma-\delta_{i}\epsilon_{i}+e^{\epsilon_{i}}] +\]` +Due to minimizing the negative log-likelihood equivalent to maximum a posteriori probability, +the loss function we use to optimize is $-\iota(\beta,\sigma)$. +The gradient functions for $\beta$ and $\log\sigma$ respectively are: +`\[ +\frac{\partial (-\iota)}{\partial \beta}=\sum_{1=1}^{n}[\delta_{i}-e^{\epsilon_{i}}]\frac{x_{i}}{\sigma} +\]` +`\[ +\frac{\partial (-\iota)}{\partial (\log\sigma)}=\sum_{i=1}^{n}[\delta_{i}+(\delta_{i}-e^{\epsilon_{i}})\epsilon_{i}] +\]` + +The AFT model can be formulated as a convex optimization problem, +i.e. the task of finding a minimizer of a convex function $-\iota(\beta,\sigma)$ +that depends coefficients vector $\beta$ and the log of scale parameter $\log\sigma$. +The optimization algorithm underlying the implementation is L-BFGS. +The implementation matches the result from R's survival function +[survreg](https://stat.ethz.ch/R-manual/R-devel/library/survival/html/survreg.html) + +**Example** + +
    + +
    +{% include_example scala/org/apache/spark/examples/ml/AFTSurvivalRegressionExample.scala %} +
    + +
    +{% include_example java/org/apache/spark/examples/ml/JavaAFTSurvivalRegressionExample.java %} +
    + +
    +{% include_example python/ml/aft_survival_regression.py %} +
    + +
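+
+The full AFT example lives in the files included above; the following is only a minimal Scala sketch of the same idea, where the toy data, column names, and quantile probabilities are illustrative assumptions rather than requirements of the API.
+
+{% highlight scala %}
+import org.apache.spark.ml.regression.AFTSurvivalRegression
+import org.apache.spark.mllib.linalg.Vectors
+
+// Toy censored data: (label = observed time, censor, features). Values are made up.
+val training = sqlContext.createDataFrame(Seq(
+  (1.218, 1.0, Vectors.dense(1.560, -0.605)),
+  (2.949, 0.0, Vectors.dense(0.346, 2.158)),
+  (3.627, 0.0, Vectors.dense(1.380, 0.231)),
+  (0.273, 1.0, Vectors.dense(0.520, 1.151)),
+  (4.199, 0.0, Vectors.dense(0.795, -0.226))
+)).toDF("label", "censor", "features")
+
+// Fit the Weibull AFT model and also request two predicted quantiles per instance.
+val aft = new AFTSurvivalRegression()
+  .setQuantileProbabilities(Array(0.3, 0.6))
+  .setQuantilesCol("quantiles")
+val model = aft.fit(training)
+
+println(s"Coefficients: ${model.coefficients} Intercept: ${model.intercept} Scale: ${model.scale}")
+model.transform(training).show()
+{% endhighlight %}
+
+The `censor` column follows the convention used by the implementation: 1.0 marks an uncensored (observed) event and 0.0 a censored one.
+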
    + + + +# Decision trees + +[Decision trees](http://en.wikipedia.org/wiki/Decision_tree_learning) +and their ensembles are popular methods for the machine learning tasks of +classification and regression. Decision trees are widely used since they are easy to interpret, +handle categorical features, extend to the multiclass classification setting, do not require +feature scaling, and are able to capture non-linearities and feature interactions. Tree ensemble +algorithms such as random forests and boosting are among the top performers for classification and +regression tasks. + +MLlib supports decision trees for binary and multiclass classification and for regression, +using both continuous and categorical features. The implementation partitions data by rows, +allowing distributed training with millions or even billions of instances. + +Users can find more information about the decision tree algorithm in the [MLlib Decision Tree guide](mllib-decision-tree.html). +The main differences between this API and the [original MLlib Decision Tree API](mllib-decision-tree.html) are: + +* support for ML Pipelines +* separation of Decision Trees for classification vs. regression +* use of DataFrame metadata to distinguish continuous and categorical features + + +The Pipelines API for Decision Trees offers a bit more functionality than the original API. In particular, for classification, users can get the predicted probability of each class (a.k.a. class conditional probabilities). + +Ensembles of trees (Random Forests and Gradient-Boosted Trees) are described below in the [Tree ensembles section](#tree-ensembles). + +## Inputs and Outputs + +We list the input and output (prediction) column types here. +All output columns are optional; to exclude an output column, set its corresponding Param to an empty string. + +### Input Columns + +
+
+| Param name | Type(s) | Default | Description |
+|:-----------|:--------|:--------|:------------|
+| labelCol | Double | "label" | Label to predict |
+| featuresCol | Vector | "features" | Feature vector |
+
+### Output Columns
+
+| Param name | Type(s) | Default | Description | Notes |
+|:-----------|:--------|:--------|:------------|:------|
+| predictionCol | Double | "prediction" | Predicted label | |
+| rawPredictionCol | Vector | "rawPrediction" | Vector of length # classes, with the counts of training instance labels at the tree node which makes the prediction | Classification only |
+| probabilityCol | Vector | "probability" | Vector of length # classes equal to rawPrediction normalized to a multinomial distribution | Classification only |
+
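+
+As a rough illustration of how these columns appear in practice, a `DecisionTreeClassifier` can be fit inside a small pipeline and its output columns inspected as in the sketch below. This is not one of the bundled examples; the dataset path and the `indexedLabel` column name are assumptions made for the illustration.
+
+{% highlight scala %}
+import org.apache.spark.ml.Pipeline
+import org.apache.spark.ml.classification.DecisionTreeClassifier
+import org.apache.spark.ml.feature.StringIndexer
+
+val data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt")
+
+// Index labels, adding the metadata that tree-based classifiers expect on the label column.
+val labelIndexer = new StringIndexer()
+  .setInputCol("label")
+  .setOutputCol("indexedLabel")
+  .fit(data)
+
+val dt = new DecisionTreeClassifier()
+  .setLabelCol("indexedLabel")
+  .setFeaturesCol("features")
+
+val model = new Pipeline().setStages(Array(labelIndexer, dt)).fit(data)
+
+// The transformed DataFrame carries the output columns described in the table above.
+model.transform(data)
+  .select("prediction", "rawPrediction", "probability")
+  .show(5)
+{% endhighlight %}
+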
    + + +# Tree Ensembles + +The Pipelines API supports two major tree ensemble algorithms: [Random Forests](http://en.wikipedia.org/wiki/Random_forest) and [Gradient-Boosted Trees (GBTs)](http://en.wikipedia.org/wiki/Gradient_boosting). +Both use [MLlib decision trees](ml-decision-tree.html) as their base models. + +Users can find more information about ensemble algorithms in the [MLlib Ensemble guide](mllib-ensembles.html). In this section, we demonstrate the Pipelines API for ensembles. + +The main differences between this API and the [original MLlib ensembles API](mllib-ensembles.html) are: + +* support for ML Pipelines +* separation of classification vs. regression +* use of DataFrame metadata to distinguish continuous and categorical features +* a bit more functionality for random forests: estimates of feature importance, as well as the predicted probability of each class (a.k.a. class conditional probabilities) for classification. + +## Random Forests + +[Random forests](http://en.wikipedia.org/wiki/Random_forest) +are ensembles of [decision trees](ml-decision-tree.html). +Random forests combine many decision trees in order to reduce the risk of overfitting. +MLlib supports random forests for binary and multiclass classification and for regression, +using both continuous and categorical features. + +For more information on the algorithm itself, please see the [`spark.mllib` documentation on random forests](mllib-ensembles.html). + +### Inputs and Outputs + +We list the input and output (prediction) column types here. +All output columns are optional; to exclude an output column, set its corresponding Param to an empty string. + +#### Input Columns + + + + + + + + + + + + + + + + + + + + + + + + +
+| Param name | Type(s) | Default | Description |
+|:-----------|:--------|:--------|:------------|
+| labelCol | Double | "label" | Label to predict |
+| featuresCol | Vector | "features" | Feature vector |
+
+#### Output Columns (Predictions)
+
+| Param name | Type(s) | Default | Description | Notes |
+|:-----------|:--------|:--------|:------------|:------|
+| predictionCol | Double | "prediction" | Predicted label | |
+| rawPredictionCol | Vector | "rawPrediction" | Vector of length # classes, with the counts of training instance labels at the tree node which makes the prediction | Classification only |
+| probabilityCol | Vector | "probability" | Vector of length # classes equal to rawPrediction normalized to a multinomial distribution | Classification only |
+
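+
+The extra random forest functionality mentioned above (feature importance estimates and per-class probabilities) can be seen directly on a fitted model. The sketch below is an illustration only, not the bundled example code; the dataset path, the `indexedLabel` name, and the number of trees are assumptions.
+
+{% highlight scala %}
+import org.apache.spark.ml.Pipeline
+import org.apache.spark.ml.classification.{RandomForestClassificationModel, RandomForestClassifier}
+import org.apache.spark.ml.feature.StringIndexer
+
+val data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt")
+
+// Index labels to attach the metadata that tree-based classifiers expect.
+val labelIndexer = new StringIndexer()
+  .setInputCol("label")
+  .setOutputCol("indexedLabel")
+  .fit(data)
+
+val rf = new RandomForestClassifier()
+  .setLabelCol("indexedLabel")
+  .setNumTrees(10)
+
+val model = new Pipeline().setStages(Array(labelIndexer, rf)).fit(data)
+
+// Per-class probabilities appear in the "probability" column ...
+model.transform(data).select("prediction", "probability").show(5)
+
+// ... and the fitted forest exposes an estimate of each feature's importance.
+val rfModel = model.stages.last.asInstanceOf[RandomForestClassificationModel]
+println(rfModel.featureImportances)
+{% endhighlight %}
+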
    + + + +## Gradient-Boosted Trees (GBTs) + +[Gradient-Boosted Trees (GBTs)](http://en.wikipedia.org/wiki/Gradient_boosting) +are ensembles of [decision trees](ml-decision-tree.html). +GBTs iteratively train decision trees in order to minimize a loss function. +MLlib supports GBTs for binary classification and for regression, +using both continuous and categorical features. + +For more information on the algorithm itself, please see the [`spark.mllib` documentation on GBTs](mllib-ensembles.html). + +### Inputs and Outputs + +We list the input and output (prediction) column types here. +All output columns are optional; to exclude an output column, set its corresponding Param to an empty string. + +#### Input Columns + + + + + + + + + + + + + + + + + + + + + + + + +
+| Param name | Type(s) | Default | Description |
+|:-----------|:--------|:--------|:------------|
+| labelCol | Double | "label" | Label to predict |
+| featuresCol | Vector | "features" | Feature vector |
+
+Note that `GBTClassifier` currently only supports binary labels.
+
+#### Output Columns (Predictions)
+
+| Param name | Type(s) | Default | Description | Notes |
+|:-----------|:--------|:--------|:------------|:------|
+| predictionCol | Double | "prediction" | Predicted label | |
+
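+
+For completeness, a hedged Scala sketch of fitting a `GBTClassifier` on binary-labeled data follows. It is not one of the bundled examples; the dataset path, the `indexedLabel` name, and the iteration count are assumptions.
+
+{% highlight scala %}
+import org.apache.spark.ml.Pipeline
+import org.apache.spark.ml.classification.GBTClassifier
+import org.apache.spark.ml.feature.StringIndexer
+
+// sample_libsvm_data.txt has binary labels, which is what GBTClassifier requires.
+val data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt")
+
+val labelIndexer = new StringIndexer()
+  .setInputCol("label")
+  .setOutputCol("indexedLabel")
+  .fit(data)
+
+val gbt = new GBTClassifier()
+  .setLabelCol("indexedLabel")
+  .setMaxIter(10)
+
+val model = new Pipeline().setStages(Array(labelIndexer, gbt)).fit(data)
+
+// Only the prediction column is produced; there is no rawPrediction or probability column yet.
+model.transform(data).select("prediction").show(5)
+{% endhighlight %}
+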
    + +In the future, `GBTClassifier` will also output columns for `rawPrediction` and `probability`, just as `RandomForestClassifier` does. + diff --git a/docs/ml-clustering.md b/docs/ml-clustering.md index cfefb5dfbde9..697777714b05 100644 --- a/docs/ml-clustering.md +++ b/docs/ml-clustering.md @@ -6,6 +6,11 @@ displayTitle: ML - Clustering In this section, we introduce the pipeline API for [clustering in mllib](mllib-clustering.html). +**Table of Contents** + +* This will become a table of contents (this text will be scraped). +{:toc} + ## Latent Dirichlet allocation (LDA) `LDA` is implemented as an `Estimator` that supports both `EMLDAOptimizer` and `OnlineLDAOptimizer`, diff --git a/docs/ml-features.md b/docs/ml-features.md index e15c26836aff..55e401221917 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -1,7 +1,7 @@ --- layout: global -title: Feature Extraction, Transformation, and Selection - SparkML -displayTitle: ML - Features +title: Extracting, transforming and selecting features +displayTitle: Extracting, transforming and selecting features --- This section covers algorithms for working with features, roughly divided into these groups: diff --git a/docs/ml-intro.md b/docs/ml-intro.md new file mode 100644 index 000000000000..d95a66ba2356 --- /dev/null +++ b/docs/ml-intro.md @@ -0,0 +1,941 @@ +--- +layout: global +title: "Overview: estimators, transformers and pipelines - spark.ml" +displayTitle: "Overview: estimators, transformers and pipelines" +--- + + +`\[ +\newcommand{\R}{\mathbb{R}} +\newcommand{\E}{\mathbb{E}} +\newcommand{\x}{\mathbf{x}} +\newcommand{\y}{\mathbf{y}} +\newcommand{\wv}{\mathbf{w}} +\newcommand{\av}{\mathbf{\alpha}} +\newcommand{\bv}{\mathbf{b}} +\newcommand{\N}{\mathbb{N}} +\newcommand{\id}{\mathbf{I}} +\newcommand{\ind}{\mathbf{1}} +\newcommand{\0}{\mathbf{0}} +\newcommand{\unit}{\mathbf{e}} +\newcommand{\one}{\mathbf{1}} +\newcommand{\zero}{\mathbf{0}} +\]` + + +The `spark.ml` package aims to provide a uniform set of high-level APIs built on top of +[DataFrames](sql-programming-guide.html#dataframes) that help users create and tune practical +machine learning pipelines. +See the [algorithm guides](#algorithm-guides) section below for guides on sub-packages of +`spark.ml`, including feature transformers unique to the Pipelines API, ensembles, and more. + +**Table of contents** + +* This will become a table of contents (this text will be scraped). +{:toc} + + +# Main concepts in Pipelines + +Spark ML standardizes APIs for machine learning algorithms to make it easier to combine multiple +algorithms into a single pipeline, or workflow. +This section covers the key concepts introduced by the Spark ML API, where the pipeline concept is +mostly inspired by the [scikit-learn](http://scikit-learn.org/) project. + +* **[`DataFrame`](ml-guide.html#dataframe)**: Spark ML uses `DataFrame` from Spark SQL as an ML + dataset, which can hold a variety of data types. + E.g., a `DataFrame` could have different columns storing text, feature vectors, true labels, and predictions. + +* **[`Transformer`](ml-guide.html#transformers)**: A `Transformer` is an algorithm which can transform one `DataFrame` into another `DataFrame`. +E.g., an ML model is a `Transformer` which transforms `DataFrame` with features into a `DataFrame` with predictions. + +* **[`Estimator`](ml-guide.html#estimators)**: An `Estimator` is an algorithm which can be fit on a `DataFrame` to produce a `Transformer`. 
+E.g., a learning algorithm is an `Estimator` which trains on a `DataFrame` and produces a model. + +* **[`Pipeline`](ml-guide.html#pipeline)**: A `Pipeline` chains multiple `Transformer`s and `Estimator`s together to specify an ML workflow. + +* **[`Parameter`](ml-guide.html#parameters)**: All `Transformer`s and `Estimator`s now share a common API for specifying parameters. + +## DataFrame + +Machine learning can be applied to a wide variety of data types, such as vectors, text, images, and structured data. +Spark ML adopts the `DataFrame` from Spark SQL in order to support a variety of data types. + +`DataFrame` supports many basic and structured types; see the [Spark SQL datatype reference](sql-programming-guide.html#spark-sql-datatype-reference) for a list of supported types. +In addition to the types listed in the Spark SQL guide, `DataFrame` can use ML [`Vector`](mllib-data-types.html#local-vector) types. + +A `DataFrame` can be created either implicitly or explicitly from a regular `RDD`. See the code examples below and the [Spark SQL programming guide](sql-programming-guide.html) for examples. + +Columns in a `DataFrame` are named. The code examples below use names such as "text," "features," and "label." + +## Pipeline components + +### Transformers + +A `Transformer` is an abstraction that includes feature transformers and learned models. +Technically, a `Transformer` implements a method `transform()`, which converts one `DataFrame` into +another, generally by appending one or more columns. +For example: + +* A feature transformer might take a `DataFrame`, read a column (e.g., text), map it into a new + column (e.g., feature vectors), and output a new `DataFrame` with the mapped column appended. +* A learning model might take a `DataFrame`, read the column containing feature vectors, predict the + label for each feature vector, and output a new `DataFrame` with predicted labels appended as a + column. + +### Estimators + +An `Estimator` abstracts the concept of a learning algorithm or any algorithm that fits or trains on +data. +Technically, an `Estimator` implements a method `fit()`, which accepts a `DataFrame` and produces a +`Model`, which is a `Transformer`. +For example, a learning algorithm such as `LogisticRegression` is an `Estimator`, and calling +`fit()` trains a `LogisticRegressionModel`, which is a `Model` and hence a `Transformer`. + +### Properties of pipeline components + +`Transformer.transform()`s and `Estimator.fit()`s are both stateless. In the future, stateful algorithms may be supported via alternative concepts. + +Each instance of a `Transformer` or `Estimator` has a unique ID, which is useful in specifying parameters (discussed below). + +## Pipeline + +In machine learning, it is common to run a sequence of algorithms to process and learn from data. +E.g., a simple text document processing workflow might include several stages: + +* Split each document's text into words. +* Convert each document's words into a numerical feature vector. +* Learn a prediction model using the feature vectors and labels. + +Spark ML represents such a workflow as a `Pipeline`, which consists of a sequence of +`PipelineStage`s (`Transformer`s and `Estimator`s) to be run in a specific order. +We will use this simple workflow as a running example in this section. + +### How it works + +A `Pipeline` is specified as a sequence of stages, and each stage is either a `Transformer` or an `Estimator`. 
+These stages are run in order, and the input `DataFrame` is transformed as it passes through each stage. +For `Transformer` stages, the `transform()` method is called on the `DataFrame`. +For `Estimator` stages, the `fit()` method is called to produce a `Transformer` (which becomes part of the `PipelineModel`, or fitted `Pipeline`), and that `Transformer`'s `transform()` method is called on the `DataFrame`. + +We illustrate this for the simple text document workflow. The figure below is for the *training time* usage of a `Pipeline`. + +

+*(Figure: Spark ML Pipeline Example)*

    + +Above, the top row represents a `Pipeline` with three stages. +The first two (`Tokenizer` and `HashingTF`) are `Transformer`s (blue), and the third (`LogisticRegression`) is an `Estimator` (red). +The bottom row represents data flowing through the pipeline, where cylinders indicate `DataFrame`s. +The `Pipeline.fit()` method is called on the original `DataFrame`, which has raw text documents and labels. +The `Tokenizer.transform()` method splits the raw text documents into words, adding a new column with words to the `DataFrame`. +The `HashingTF.transform()` method converts the words column into feature vectors, adding a new column with those vectors to the `DataFrame`. +Now, since `LogisticRegression` is an `Estimator`, the `Pipeline` first calls `LogisticRegression.fit()` to produce a `LogisticRegressionModel`. +If the `Pipeline` had more stages, it would call the `LogisticRegressionModel`'s `transform()` +method on the `DataFrame` before passing the `DataFrame` to the next stage. + +A `Pipeline` is an `Estimator`. +Thus, after a `Pipeline`'s `fit()` method runs, it produces a `PipelineModel`, which is a +`Transformer`. +This `PipelineModel` is used at *test time*; the figure below illustrates this usage. + +

+*(Figure: Spark ML PipelineModel Example)*

    + +In the figure above, the `PipelineModel` has the same number of stages as the original `Pipeline`, but all `Estimator`s in the original `Pipeline` have become `Transformer`s. +When the `PipelineModel`'s `transform()` method is called on a test dataset, the data are passed +through the fitted pipeline in order. +Each stage's `transform()` method updates the dataset and passes it to the next stage. + +`Pipeline`s and `PipelineModel`s help to ensure that training and test data go through identical feature processing steps. + +### Details + +*DAG `Pipeline`s*: A `Pipeline`'s stages are specified as an ordered array. The examples given here are all for linear `Pipeline`s, i.e., `Pipeline`s in which each stage uses data produced by the previous stage. It is possible to create non-linear `Pipeline`s as long as the data flow graph forms a Directed Acyclic Graph (DAG). This graph is currently specified implicitly based on the input and output column names of each stage (generally specified as parameters). If the `Pipeline` forms a DAG, then the stages must be specified in topological order. + +*Runtime checking*: Since `Pipeline`s can operate on `DataFrame`s with varied types, they cannot use +compile-time type checking. +`Pipeline`s and `PipelineModel`s instead do runtime checking before actually running the `Pipeline`. +This type checking is done using the `DataFrame` *schema*, a description of the data types of columns in the `DataFrame`. + +*Unique Pipeline stages*: A `Pipeline`'s stages should be unique instances. E.g., the same instance +`myHashingTF` should not be inserted into the `Pipeline` twice since `Pipeline` stages must have +unique IDs. However, different instances `myHashingTF1` and `myHashingTF2` (both of type `HashingTF`) +can be put into the same `Pipeline` since different instances will be created with different IDs. + +## Parameters + +Spark ML `Estimator`s and `Transformer`s use a uniform API for specifying parameters. + +A `Param` is a named parameter with self-contained documentation. +A `ParamMap` is a set of (parameter, value) pairs. + +There are two main ways to pass parameters to an algorithm: + +1. Set parameters for an instance. E.g., if `lr` is an instance of `LogisticRegression`, one could + call `lr.setMaxIter(10)` to make `lr.fit()` use at most 10 iterations. + This API resembles the API used in `spark.mllib` package. +2. Pass a `ParamMap` to `fit()` or `transform()`. Any parameters in the `ParamMap` will override parameters previously specified via setter methods. + +Parameters belong to specific instances of `Estimator`s and `Transformer`s. +For example, if we have two `LogisticRegression` instances `lr1` and `lr2`, then we can build a `ParamMap` with both `maxIter` parameters specified: `ParamMap(lr1.maxIter -> 10, lr2.maxIter -> 20)`. +This is useful if there are two algorithms with the `maxIter` parameter in a `Pipeline`. + +# Code examples + +This section gives code examples illustrating the functionality discussed above. +For more info, please refer to the API documentation +([Scala](api/scala/index.html#org.apache.spark.ml.package), +[Java](api/java/org/apache/spark/ml/package-summary.html), +and [Python](api/python/pyspark.ml.html)). +Some Spark ML algorithms are wrappers for `spark.mllib` algorithms, and the +[MLlib programming guide](mllib-guide.html) has details on specific algorithms. + +## Example: Estimator, Transformer, and Param + +This example covers the concepts of `Estimator`, `Transformer`, and `Param`. + +
    + +
    +{% highlight scala %} +import org.apache.spark.ml.classification.LogisticRegression +import org.apache.spark.ml.param.ParamMap +import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.sql.Row + +// Prepare training data from a list of (label, features) tuples. +val training = sqlContext.createDataFrame(Seq( + (1.0, Vectors.dense(0.0, 1.1, 0.1)), + (0.0, Vectors.dense(2.0, 1.0, -1.0)), + (0.0, Vectors.dense(2.0, 1.3, 1.0)), + (1.0, Vectors.dense(0.0, 1.2, -0.5)) +)).toDF("label", "features") + +// Create a LogisticRegression instance. This instance is an Estimator. +val lr = new LogisticRegression() +// Print out the parameters, documentation, and any default values. +println("LogisticRegression parameters:\n" + lr.explainParams() + "\n") + +// We may set parameters using setter methods. +lr.setMaxIter(10) + .setRegParam(0.01) + +// Learn a LogisticRegression model. This uses the parameters stored in lr. +val model1 = lr.fit(training) +// Since model1 is a Model (i.e., a Transformer produced by an Estimator), +// we can view the parameters it used during fit(). +// This prints the parameter (name: value) pairs, where names are unique IDs for this +// LogisticRegression instance. +println("Model 1 was fit using parameters: " + model1.parent.extractParamMap) + +// We may alternatively specify parameters using a ParamMap, +// which supports several methods for specifying parameters. +val paramMap = ParamMap(lr.maxIter -> 20) + .put(lr.maxIter, 30) // Specify 1 Param. This overwrites the original maxIter. + .put(lr.regParam -> 0.1, lr.threshold -> 0.55) // Specify multiple Params. + +// One can also combine ParamMaps. +val paramMap2 = ParamMap(lr.probabilityCol -> "myProbability") // Change output column name +val paramMapCombined = paramMap ++ paramMap2 + +// Now learn a new model using the paramMapCombined parameters. +// paramMapCombined overrides all parameters set earlier via lr.set* methods. +val model2 = lr.fit(training, paramMapCombined) +println("Model 2 was fit using parameters: " + model2.parent.extractParamMap) + +// Prepare test data. +val test = sqlContext.createDataFrame(Seq( + (1.0, Vectors.dense(-1.0, 1.5, 1.3)), + (0.0, Vectors.dense(3.0, 2.0, -0.1)), + (1.0, Vectors.dense(0.0, 2.2, -1.5)) +)).toDF("label", "features") + +// Make predictions on test data using the Transformer.transform() method. +// LogisticRegression.transform will only use the 'features' column. +// Note that model2.transform() outputs a 'myProbability' column instead of the usual +// 'probability' column since we renamed the lr.probabilityCol parameter previously. +model2.transform(test) + .select("features", "label", "myProbability", "prediction") + .collect() + .foreach { case Row(features: Vector, label: Double, prob: Vector, prediction: Double) => + println(s"($features, $label) -> prob=$prob, prediction=$prediction") + } + +{% endhighlight %} +
    + +
    +{% highlight java %} +import java.util.Arrays; +import java.util.List; + +import org.apache.spark.ml.classification.LogisticRegressionModel; +import org.apache.spark.ml.param.ParamMap; +import org.apache.spark.ml.classification.LogisticRegression; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; + +// Prepare training data. +// We use LabeledPoint, which is a JavaBean. Spark SQL can convert RDDs of JavaBeans +// into DataFrames, where it uses the bean metadata to infer the schema. +DataFrame training = sqlContext.createDataFrame(Arrays.asList( + new LabeledPoint(1.0, Vectors.dense(0.0, 1.1, 0.1)), + new LabeledPoint(0.0, Vectors.dense(2.0, 1.0, -1.0)), + new LabeledPoint(0.0, Vectors.dense(2.0, 1.3, 1.0)), + new LabeledPoint(1.0, Vectors.dense(0.0, 1.2, -0.5)) +), LabeledPoint.class); + +// Create a LogisticRegression instance. This instance is an Estimator. +LogisticRegression lr = new LogisticRegression(); +// Print out the parameters, documentation, and any default values. +System.out.println("LogisticRegression parameters:\n" + lr.explainParams() + "\n"); + +// We may set parameters using setter methods. +lr.setMaxIter(10) + .setRegParam(0.01); + +// Learn a LogisticRegression model. This uses the parameters stored in lr. +LogisticRegressionModel model1 = lr.fit(training); +// Since model1 is a Model (i.e., a Transformer produced by an Estimator), +// we can view the parameters it used during fit(). +// This prints the parameter (name: value) pairs, where names are unique IDs for this +// LogisticRegression instance. +System.out.println("Model 1 was fit using parameters: " + model1.parent().extractParamMap()); + +// We may alternatively specify parameters using a ParamMap. +ParamMap paramMap = new ParamMap() + .put(lr.maxIter().w(20)) // Specify 1 Param. + .put(lr.maxIter(), 30) // This overwrites the original maxIter. + .put(lr.regParam().w(0.1), lr.threshold().w(0.55)); // Specify multiple Params. + +// One can also combine ParamMaps. +ParamMap paramMap2 = new ParamMap() + .put(lr.probabilityCol().w("myProbability")); // Change output column name +ParamMap paramMapCombined = paramMap.$plus$plus(paramMap2); + +// Now learn a new model using the paramMapCombined parameters. +// paramMapCombined overrides all parameters set earlier via lr.set* methods. +LogisticRegressionModel model2 = lr.fit(training, paramMapCombined); +System.out.println("Model 2 was fit using parameters: " + model2.parent().extractParamMap()); + +// Prepare test documents. +DataFrame test = sqlContext.createDataFrame(Arrays.asList( + new LabeledPoint(1.0, Vectors.dense(-1.0, 1.5, 1.3)), + new LabeledPoint(0.0, Vectors.dense(3.0, 2.0, -0.1)), + new LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5)) +), LabeledPoint.class); + +// Make predictions on test documents using the Transformer.transform() method. +// LogisticRegression.transform will only use the 'features' column. +// Note that model2.transform() outputs a 'myProbability' column instead of the usual +// 'probability' column since we renamed the lr.probabilityCol parameter previously. +DataFrame results = model2.transform(test); +for (Row r: results.select("features", "label", "myProbability", "prediction").collect()) { + System.out.println("(" + r.get(0) + ", " + r.get(1) + ") -> prob=" + r.get(2) + + ", prediction=" + r.get(3)); +} + +{% endhighlight %} +
    + +
    +{% highlight python %} +from pyspark.mllib.linalg import Vectors +from pyspark.ml.classification import LogisticRegression +from pyspark.ml.param import Param, Params + +# Prepare training data from a list of (label, features) tuples. +training = sqlContext.createDataFrame([ + (1.0, Vectors.dense([0.0, 1.1, 0.1])), + (0.0, Vectors.dense([2.0, 1.0, -1.0])), + (0.0, Vectors.dense([2.0, 1.3, 1.0])), + (1.0, Vectors.dense([0.0, 1.2, -0.5]))], ["label", "features"]) + +# Create a LogisticRegression instance. This instance is an Estimator. +lr = LogisticRegression(maxIter=10, regParam=0.01) +# Print out the parameters, documentation, and any default values. +print "LogisticRegression parameters:\n" + lr.explainParams() + "\n" + +# Learn a LogisticRegression model. This uses the parameters stored in lr. +model1 = lr.fit(training) + +# Since model1 is a Model (i.e., a transformer produced by an Estimator), +# we can view the parameters it used during fit(). +# This prints the parameter (name: value) pairs, where names are unique IDs for this +# LogisticRegression instance. +print "Model 1 was fit using parameters: " +print model1.extractParamMap() + +# We may alternatively specify parameters using a Python dictionary as a paramMap +paramMap = {lr.maxIter: 20} +paramMap[lr.maxIter] = 30 # Specify 1 Param, overwriting the original maxIter. +paramMap.update({lr.regParam: 0.1, lr.threshold: 0.55}) # Specify multiple Params. + +# You can combine paramMaps, which are python dictionaries. +paramMap2 = {lr.probabilityCol: "myProbability"} # Change output column name +paramMapCombined = paramMap.copy() +paramMapCombined.update(paramMap2) + +# Now learn a new model using the paramMapCombined parameters. +# paramMapCombined overrides all parameters set earlier via lr.set* methods. +model2 = lr.fit(training, paramMapCombined) +print "Model 2 was fit using parameters: " +print model2.extractParamMap() + +# Prepare test data +test = sqlContext.createDataFrame([ + (1.0, Vectors.dense([-1.0, 1.5, 1.3])), + (0.0, Vectors.dense([3.0, 2.0, -0.1])), + (1.0, Vectors.dense([0.0, 2.2, -1.5]))], ["label", "features"]) + +# Make predictions on test data using the Transformer.transform() method. +# LogisticRegression.transform will only use the 'features' column. +# Note that model2.transform() outputs a "myProbability" column instead of the usual +# 'probability' column since we renamed the lr.probabilityCol parameter previously. +prediction = model2.transform(test) +selected = prediction.select("features", "label", "myProbability", "prediction") +for row in selected.collect(): + print row + +{% endhighlight %} +
    + +
    + +## Example: Pipeline + +This example follows the simple text document `Pipeline` illustrated in the figures above. + +
    + +
    +{% highlight scala %} +import org.apache.spark.ml.Pipeline +import org.apache.spark.ml.classification.LogisticRegression +import org.apache.spark.ml.feature.{HashingTF, Tokenizer} +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.sql.Row + +// Prepare training documents from a list of (id, text, label) tuples. +val training = sqlContext.createDataFrame(Seq( + (0L, "a b c d e spark", 1.0), + (1L, "b d", 0.0), + (2L, "spark f g h", 1.0), + (3L, "hadoop mapreduce", 0.0) +)).toDF("id", "text", "label") + +// Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr. +val tokenizer = new Tokenizer() + .setInputCol("text") + .setOutputCol("words") +val hashingTF = new HashingTF() + .setNumFeatures(1000) + .setInputCol(tokenizer.getOutputCol) + .setOutputCol("features") +val lr = new LogisticRegression() + .setMaxIter(10) + .setRegParam(0.01) +val pipeline = new Pipeline() + .setStages(Array(tokenizer, hashingTF, lr)) + +// Fit the pipeline to training documents. +val model = pipeline.fit(training) + +// Prepare test documents, which are unlabeled (id, text) tuples. +val test = sqlContext.createDataFrame(Seq( + (4L, "spark i j k"), + (5L, "l m n"), + (6L, "mapreduce spark"), + (7L, "apache hadoop") +)).toDF("id", "text") + +// Make predictions on test documents. +model.transform(test) + .select("id", "text", "probability", "prediction") + .collect() + .foreach { case Row(id: Long, text: String, prob: Vector, prediction: Double) => + println(s"($id, $text) --> prob=$prob, prediction=$prediction") + } + +{% endhighlight %} +
    + +
    +{% highlight java %} +import java.util.Arrays; +import java.util.List; + +import org.apache.spark.ml.Pipeline; +import org.apache.spark.ml.PipelineModel; +import org.apache.spark.ml.PipelineStage; +import org.apache.spark.ml.classification.LogisticRegression; +import org.apache.spark.ml.feature.HashingTF; +import org.apache.spark.ml.feature.Tokenizer; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; + +// Labeled and unlabeled instance types. +// Spark SQL can infer schema from Java Beans. +public class Document implements Serializable { + private long id; + private String text; + + public Document(long id, String text) { + this.id = id; + this.text = text; + } + + public long getId() { return this.id; } + public void setId(long id) { this.id = id; } + + public String getText() { return this.text; } + public void setText(String text) { this.text = text; } +} + +public class LabeledDocument extends Document implements Serializable { + private double label; + + public LabeledDocument(long id, String text, double label) { + super(id, text); + this.label = label; + } + + public double getLabel() { return this.label; } + public void setLabel(double label) { this.label = label; } +} + +// Prepare training documents, which are labeled. +DataFrame training = sqlContext.createDataFrame(Arrays.asList( + new LabeledDocument(0L, "a b c d e spark", 1.0), + new LabeledDocument(1L, "b d", 0.0), + new LabeledDocument(2L, "spark f g h", 1.0), + new LabeledDocument(3L, "hadoop mapreduce", 0.0) +), LabeledDocument.class); + +// Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr. +Tokenizer tokenizer = new Tokenizer() + .setInputCol("text") + .setOutputCol("words"); +HashingTF hashingTF = new HashingTF() + .setNumFeatures(1000) + .setInputCol(tokenizer.getOutputCol()) + .setOutputCol("features"); +LogisticRegression lr = new LogisticRegression() + .setMaxIter(10) + .setRegParam(0.01); +Pipeline pipeline = new Pipeline() + .setStages(new PipelineStage[] {tokenizer, hashingTF, lr}); + +// Fit the pipeline to training documents. +PipelineModel model = pipeline.fit(training); + +// Prepare test documents, which are unlabeled. +DataFrame test = sqlContext.createDataFrame(Arrays.asList( + new Document(4L, "spark i j k"), + new Document(5L, "l m n"), + new Document(6L, "mapreduce spark"), + new Document(7L, "apache hadoop") +), Document.class); + +// Make predictions on test documents. +DataFrame predictions = model.transform(test); +for (Row r: predictions.select("id", "text", "probability", "prediction").collect()) { + System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> prob=" + r.get(2) + + ", prediction=" + r.get(3)); +} + +{% endhighlight %} +
    + +
    +{% highlight python %} +from pyspark.ml import Pipeline +from pyspark.ml.classification import LogisticRegression +from pyspark.ml.feature import HashingTF, Tokenizer +from pyspark.sql import Row + +# Prepare training documents from a list of (id, text, label) tuples. +LabeledDocument = Row("id", "text", "label") +training = sqlContext.createDataFrame([ + (0L, "a b c d e spark", 1.0), + (1L, "b d", 0.0), + (2L, "spark f g h", 1.0), + (3L, "hadoop mapreduce", 0.0)], ["id", "text", "label"]) + +# Configure an ML pipeline, which consists of tree stages: tokenizer, hashingTF, and lr. +tokenizer = Tokenizer(inputCol="text", outputCol="words") +hashingTF = HashingTF(inputCol=tokenizer.getOutputCol(), outputCol="features") +lr = LogisticRegression(maxIter=10, regParam=0.01) +pipeline = Pipeline(stages=[tokenizer, hashingTF, lr]) + +# Fit the pipeline to training documents. +model = pipeline.fit(training) + +# Prepare test documents, which are unlabeled (id, text) tuples. +test = sqlContext.createDataFrame([ + (4L, "spark i j k"), + (5L, "l m n"), + (6L, "mapreduce spark"), + (7L, "apache hadoop")], ["id", "text"]) + +# Make predictions on test documents and print columns of interest. +prediction = model.transform(test) +selected = prediction.select("id", "text", "prediction") +for row in selected.collect(): + print(row) + +{% endhighlight %} +
    + +
+
+
+## Example: model selection via cross-validation
+
+An important task in ML is *model selection*, or using data to find the best model or parameters for a given task. This is also called *tuning*.
+`Pipeline`s facilitate model selection by making it easy to tune an entire `Pipeline` at once, rather than tuning each element in the `Pipeline` separately.
+
+Currently, `spark.ml` supports model selection using the [`CrossValidator`](api/scala/index.html#org.apache.spark.ml.tuning.CrossValidator) class, which takes an `Estimator`, a set of `ParamMap`s, and an [`Evaluator`](api/scala/index.html#org.apache.spark.ml.evaluation.Evaluator).
+`CrossValidator` begins by splitting the dataset into a set of *folds* which are used as separate training and test datasets; e.g., with `$k=3$` folds, `CrossValidator` will generate 3 (training, test) dataset pairs, each of which uses 2/3 of the data for training and 1/3 for testing.
+`CrossValidator` iterates through the set of `ParamMap`s. For each `ParamMap`, it trains the given `Estimator` and evaluates it using the given `Evaluator`.
+
+The `Evaluator` can be a [`RegressionEvaluator`](api/scala/index.html#org.apache.spark.ml.evaluation.RegressionEvaluator)
+for regression problems, a [`BinaryClassificationEvaluator`](api/scala/index.html#org.apache.spark.ml.evaluation.BinaryClassificationEvaluator)
+for binary data, or a [`MulticlassClassificationEvaluator`](api/scala/index.html#org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator)
+for multiclass problems. The default metric used to choose the best `ParamMap` can be overridden by the `setMetricName`
+method in each of these evaluators.
+
+The `ParamMap` which produces the best evaluation metric (averaged over the `$k$` folds) is selected as the best set of parameters.
+`CrossValidator` finally fits the `Estimator` using the best `ParamMap` and the entire dataset.
+
+The following example demonstrates using `CrossValidator` to select from a grid of parameters.
+To help construct the parameter grid, we use the [`ParamGridBuilder`](api/scala/index.html#org.apache.spark.ml.tuning.ParamGridBuilder) utility.
+
+Note that cross-validation over a grid of parameters is expensive.
+E.g., in the example below, the parameter grid has 3 values for `hashingTF.numFeatures` and 2 values for `lr.regParam`, and `CrossValidator` uses 2 folds. This multiplies out to `$(3 \times 2) \times 2 = 12$` different models being trained.
+In realistic settings, it can be common to try many more parameters and use more folds (`$k=3$` and `$k=10$` are common).
+In other words, using `CrossValidator` can be very expensive.
+However, it is also a well-established method for choosing parameters which is more statistically sound than heuristic hand-tuning.
+
    + +
    +{% highlight scala %} +import org.apache.spark.ml.Pipeline +import org.apache.spark.ml.classification.LogisticRegression +import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator +import org.apache.spark.ml.feature.{HashingTF, Tokenizer} +import org.apache.spark.ml.tuning.{ParamGridBuilder, CrossValidator} +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.sql.Row + +// Prepare training data from a list of (id, text, label) tuples. +val training = sqlContext.createDataFrame(Seq( + (0L, "a b c d e spark", 1.0), + (1L, "b d", 0.0), + (2L, "spark f g h", 1.0), + (3L, "hadoop mapreduce", 0.0), + (4L, "b spark who", 1.0), + (5L, "g d a y", 0.0), + (6L, "spark fly", 1.0), + (7L, "was mapreduce", 0.0), + (8L, "e spark program", 1.0), + (9L, "a e c l", 0.0), + (10L, "spark compile", 1.0), + (11L, "hadoop software", 0.0) +)).toDF("id", "text", "label") + +// Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr. +val tokenizer = new Tokenizer() + .setInputCol("text") + .setOutputCol("words") +val hashingTF = new HashingTF() + .setInputCol(tokenizer.getOutputCol) + .setOutputCol("features") +val lr = new LogisticRegression() + .setMaxIter(10) +val pipeline = new Pipeline() + .setStages(Array(tokenizer, hashingTF, lr)) + +// We use a ParamGridBuilder to construct a grid of parameters to search over. +// With 3 values for hashingTF.numFeatures and 2 values for lr.regParam, +// this grid will have 3 x 2 = 6 parameter settings for CrossValidator to choose from. +val paramGrid = new ParamGridBuilder() + .addGrid(hashingTF.numFeatures, Array(10, 100, 1000)) + .addGrid(lr.regParam, Array(0.1, 0.01)) + .build() + +// We now treat the Pipeline as an Estimator, wrapping it in a CrossValidator instance. +// This will allow us to jointly choose parameters for all Pipeline stages. +// A CrossValidator requires an Estimator, a set of Estimator ParamMaps, and an Evaluator. +// Note that the evaluator here is a BinaryClassificationEvaluator and its default metric +// is areaUnderROC. +val cv = new CrossValidator() + .setEstimator(pipeline) + .setEvaluator(new BinaryClassificationEvaluator) + .setEstimatorParamMaps(paramGrid) + .setNumFolds(2) // Use 3+ in practice + +// Run cross-validation, and choose the best set of parameters. +val cvModel = cv.fit(training) + +// Prepare test documents, which are unlabeled (id, text) tuples. +val test = sqlContext.createDataFrame(Seq( + (4L, "spark i j k"), + (5L, "l m n"), + (6L, "mapreduce spark"), + (7L, "apache hadoop") +)).toDF("id", "text") + +// Make predictions on test documents. cvModel uses the best model found (lrModel). +cvModel.transform(test) + .select("id", "text", "probability", "prediction") + .collect() + .foreach { case Row(id: Long, text: String, prob: Vector, prediction: Double) => + println(s"($id, $text) --> prob=$prob, prediction=$prediction") + } + +{% endhighlight %} +
    + +
    +{% highlight java %} +import java.util.Arrays; +import java.util.List; + +import org.apache.spark.ml.Pipeline; +import org.apache.spark.ml.PipelineStage; +import org.apache.spark.ml.classification.LogisticRegression; +import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator; +import org.apache.spark.ml.feature.HashingTF; +import org.apache.spark.ml.feature.Tokenizer; +import org.apache.spark.ml.param.ParamMap; +import org.apache.spark.ml.tuning.CrossValidator; +import org.apache.spark.ml.tuning.CrossValidatorModel; +import org.apache.spark.ml.tuning.ParamGridBuilder; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; + +// Labeled and unlabeled instance types. +// Spark SQL can infer schema from Java Beans. +public class Document implements Serializable { + private long id; + private String text; + + public Document(long id, String text) { + this.id = id; + this.text = text; + } + + public long getId() { return this.id; } + public void setId(long id) { this.id = id; } + + public String getText() { return this.text; } + public void setText(String text) { this.text = text; } +} + +public class LabeledDocument extends Document implements Serializable { + private double label; + + public LabeledDocument(long id, String text, double label) { + super(id, text); + this.label = label; + } + + public double getLabel() { return this.label; } + public void setLabel(double label) { this.label = label; } +} + + +// Prepare training documents, which are labeled. +DataFrame training = sqlContext.createDataFrame(Arrays.asList( + new LabeledDocument(0L, "a b c d e spark", 1.0), + new LabeledDocument(1L, "b d", 0.0), + new LabeledDocument(2L, "spark f g h", 1.0), + new LabeledDocument(3L, "hadoop mapreduce", 0.0), + new LabeledDocument(4L, "b spark who", 1.0), + new LabeledDocument(5L, "g d a y", 0.0), + new LabeledDocument(6L, "spark fly", 1.0), + new LabeledDocument(7L, "was mapreduce", 0.0), + new LabeledDocument(8L, "e spark program", 1.0), + new LabeledDocument(9L, "a e c l", 0.0), + new LabeledDocument(10L, "spark compile", 1.0), + new LabeledDocument(11L, "hadoop software", 0.0) +), LabeledDocument.class); + +// Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr. +Tokenizer tokenizer = new Tokenizer() + .setInputCol("text") + .setOutputCol("words"); +HashingTF hashingTF = new HashingTF() + .setNumFeatures(1000) + .setInputCol(tokenizer.getOutputCol()) + .setOutputCol("features"); +LogisticRegression lr = new LogisticRegression() + .setMaxIter(10) + .setRegParam(0.01); +Pipeline pipeline = new Pipeline() + .setStages(new PipelineStage[] {tokenizer, hashingTF, lr}); + +// We use a ParamGridBuilder to construct a grid of parameters to search over. +// With 3 values for hashingTF.numFeatures and 2 values for lr.regParam, +// this grid will have 3 x 2 = 6 parameter settings for CrossValidator to choose from. +ParamMap[] paramGrid = new ParamGridBuilder() + .addGrid(hashingTF.numFeatures(), new int[]{10, 100, 1000}) + .addGrid(lr.regParam(), new double[]{0.1, 0.01}) + .build(); + +// We now treat the Pipeline as an Estimator, wrapping it in a CrossValidator instance. +// This will allow us to jointly choose parameters for all Pipeline stages. +// A CrossValidator requires an Estimator, a set of Estimator ParamMaps, and an Evaluator. +// Note that the evaluator here is a BinaryClassificationEvaluator and its default metric +// is areaUnderROC. 
+CrossValidator cv = new CrossValidator() + .setEstimator(pipeline) + .setEvaluator(new BinaryClassificationEvaluator()) + .setEstimatorParamMaps(paramGrid) + .setNumFolds(2); // Use 3+ in practice + +// Run cross-validation, and choose the best set of parameters. +CrossValidatorModel cvModel = cv.fit(training); + +// Prepare test documents, which are unlabeled. +DataFrame test = sqlContext.createDataFrame(Arrays.asList( + new Document(4L, "spark i j k"), + new Document(5L, "l m n"), + new Document(6L, "mapreduce spark"), + new Document(7L, "apache hadoop") +), Document.class); + +// Make predictions on test documents. cvModel uses the best model found (lrModel). +DataFrame predictions = cvModel.transform(test); +for (Row r: predictions.select("id", "text", "probability", "prediction").collect()) { + System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> prob=" + r.get(2) + + ", prediction=" + r.get(3)); +} + +{% endhighlight %} +
    + +
+
+
+## Example: model selection via train validation split
+
+In addition to `CrossValidator`, Spark also offers `TrainValidationSplit` for hyper-parameter tuning.
+`TrainValidationSplit` only evaluates each combination of parameters once, as opposed to k times in
+the case of `CrossValidator`. It is therefore less expensive,
+but will not produce as reliable results when the training dataset is not sufficiently large.
+
+`TrainValidationSplit` takes an `Estimator`, a set of `ParamMap`s provided in the `estimatorParamMaps` parameter,
+and an `Evaluator`.
+It begins by splitting the dataset into two parts using the `trainRatio` parameter,
+which are used as separate training and validation datasets. For example, with `$trainRatio=0.75$` (the default),
+`TrainValidationSplit` will generate a training and validation dataset pair where 75% of the data is used for training and 25% for validation.
+Like `CrossValidator`, `TrainValidationSplit` iterates through the set of `ParamMap`s.
+For each combination of parameters, it trains the given `Estimator` and evaluates it using the given `Evaluator`.
+The `ParamMap` which produces the best evaluation metric is selected as the best option.
+`TrainValidationSplit` finally fits the `Estimator` using the best `ParamMap` and the entire dataset.
+
    + +
    +{% highlight scala %} +import org.apache.spark.ml.evaluation.RegressionEvaluator +import org.apache.spark.ml.regression.LinearRegression +import org.apache.spark.ml.tuning.{ParamGridBuilder, TrainValidationSplit} + +// Prepare training and test data. +val data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") +val Array(training, test) = data.randomSplit(Array(0.9, 0.1), seed = 12345) + +val lr = new LinearRegression() + +// We use a ParamGridBuilder to construct a grid of parameters to search over. +// TrainValidationSplit will try all combinations of values and determine best model using +// the evaluator. +val paramGrid = new ParamGridBuilder() + .addGrid(lr.regParam, Array(0.1, 0.01)) + .addGrid(lr.fitIntercept) + .addGrid(lr.elasticNetParam, Array(0.0, 0.5, 1.0)) + .build() + +// In this case the estimator is simply the linear regression. +// A TrainValidationSplit requires an Estimator, a set of Estimator ParamMaps, and an Evaluator. +val trainValidationSplit = new TrainValidationSplit() + .setEstimator(lr) + .setEvaluator(new RegressionEvaluator) + .setEstimatorParamMaps(paramGrid) + // 80% of the data will be used for training and the remaining 20% for validation. + .setTrainRatio(0.8) + +// Run train validation split, and choose the best set of parameters. +val model = trainValidationSplit.fit(training) + +// Make predictions on test data. model is the model with combination of parameters +// that performed best. +model.transform(test) + .select("features", "label", "prediction") + .show() + +{% endhighlight %} +
    + +
    +{% highlight java %} +import org.apache.spark.ml.evaluation.RegressionEvaluator; +import org.apache.spark.ml.param.ParamMap; +import org.apache.spark.ml.regression.LinearRegression; +import org.apache.spark.ml.tuning.*; +import org.apache.spark.sql.DataFrame; + +DataFrame data = jsql.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt"); + +// Prepare training and test data. +DataFrame[] splits = data.randomSplit(new double[] {0.9, 0.1}, 12345); +DataFrame training = splits[0]; +DataFrame test = splits[1]; + +LinearRegression lr = new LinearRegression(); + +// We use a ParamGridBuilder to construct a grid of parameters to search over. +// TrainValidationSplit will try all combinations of values and determine best model using +// the evaluator. +ParamMap[] paramGrid = new ParamGridBuilder() + .addGrid(lr.regParam(), new double[] {0.1, 0.01}) + .addGrid(lr.fitIntercept()) + .addGrid(lr.elasticNetParam(), new double[] {0.0, 0.5, 1.0}) + .build(); + +// In this case the estimator is simply the linear regression. +// A TrainValidationSplit requires an Estimator, a set of Estimator ParamMaps, and an Evaluator. +TrainValidationSplit trainValidationSplit = new TrainValidationSplit() + .setEstimator(lr) + .setEvaluator(new RegressionEvaluator()) + .setEstimatorParamMaps(paramGrid) + .setTrainRatio(0.8); // 80% for training and the remaining 20% for validation + +// Run train validation split, and choose the best set of parameters. +TrainValidationSplitModel model = trainValidationSplit.fit(training); + +// Make predictions on test data. model is the model with combination of parameters +// that performed best. +model.transform(test) + .select("features", "label", "prediction") + .show(); + +{% endhighlight %} +
    + +
    \ No newline at end of file diff --git a/docs/mllib-guide.md b/docs/mllib-guide.md index 43772adcf26e..3bc2b780601c 100644 --- a/docs/mllib-guide.md +++ b/docs/mllib-guide.md @@ -66,15 +66,14 @@ We list major functionality from both below, with links to detailed guides. # spark.ml: high-level APIs for ML pipelines -**[spark.ml programming guide](ml-guide.html)** provides an overview of the Pipelines API and major -concepts. It also contains sections on using algorithms within the Pipelines API, for example: - -* [Feature extraction, transformation, and selection](ml-features.html) +* [Overview: estimators, transformers and pipelines](ml-intro.html) +* [Extracting, transforming and selecting features](ml-features.html) +* [Classification and regression](ml-classification-regression.html) * [Clustering](ml-clustering.html) -* [Decision trees for classification and regression](ml-decision-tree.html) -* [Ensembles](ml-ensembles.html) -* [Linear methods with elastic net regularization](ml-linear-methods.html) -* [Multilayer perceptron classifier](ml-ann.html) +* [Advanced topics](ml-advanced.html) + +Some techniques are not available yet in spark.ml, most notably dimensionality reduction +Users can seemlessly combine the implementation of these techniques found in `spark.mllib` with the rest of the algorithms found in `spark.ml`. # Dependencies From a0046e379bee0852c39ece4ea719cde70d350b0e Mon Sep 17 00:00:00 2001 From: Dominik Dahlem Date: Tue, 8 Dec 2015 18:54:10 -0800 Subject: [PATCH 1709/1824] [SPARK-11343][ML] Documentation of float and double prediction/label columns in RegressionEvaluator felixcheung , mengxr Just added a message to require() Author: Dominik Dahlem Closes #9598 from dahlem/ddahlem_regression_evaluator_double_predictions_message_04112015. 
--- .../apache/spark/ml/evaluation/RegressionEvaluator.scala | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala index daaa174a086e..b6b25ecd01b3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala @@ -73,10 +73,15 @@ final class RegressionEvaluator @Since("1.4.0") (@Since("1.4.0") override val ui @Since("1.4.0") override def evaluate(dataset: DataFrame): Double = { val schema = dataset.schema + val predictionColName = $(predictionCol) val predictionType = schema($(predictionCol)).dataType - require(predictionType == FloatType || predictionType == DoubleType) + require(predictionType == FloatType || predictionType == DoubleType, + s"Prediction column $predictionColName must be of type float or double, " + + s" but not $predictionType") + val labelColName = $(labelCol) val labelType = schema($(labelCol)).dataType - require(labelType == FloatType || labelType == DoubleType) + require(labelType == FloatType || labelType == DoubleType, + s"Label column $labelColName must be of type float or double, but not $labelType") val predictionAndLabels = dataset .select(col($(predictionCol)).cast(DoubleType), col($(labelCol)).cast(DoubleType)) From 3934562d34bbe08d91c54b4bbee27870e93d7571 Mon Sep 17 00:00:00 2001 From: Fei Wang Date: Tue, 8 Dec 2015 21:32:31 -0800 Subject: [PATCH 1710/1824] [SPARK-12222] [CORE] Deserialize RoaringBitmap using Kryo serializer throw Buffer underflow exception Jira: https://issues.apache.org/jira/browse/SPARK-12222 Deserialize RoaringBitmap using Kryo serializer throw Buffer underflow exception: ``` com.esotericsoftware.kryo.KryoException: Buffer underflow. at com.esotericsoftware.kryo.io.Input.require(Input.java:156) at com.esotericsoftware.kryo.io.Input.skip(Input.java:131) at com.esotericsoftware.kryo.io.Input.skip(Input.java:264) ``` This is caused by a bug of kryo's `Input.skip(long count)`(https://github.com/EsotericSoftware/kryo/issues/119) and we call this method in `KryoInputDataInputBridge`. Instead of upgrade kryo's version, this pr bypass the kryo's `Input.skip(long count)` by directly call another `skip` method in kryo's Input.java(https://github.com/EsotericSoftware/kryo/blob/kryo-2.21/src/com/esotericsoftware/kryo/io/Input.java#L124), i.e. write the bug-fixed version of `Input.skip(long count)` in KryoInputDataInputBridge's `skipBytes` method. more detail link to https://github.com/apache/spark/pull/9748#issuecomment-162860246 Author: Fei Wang Closes #10213 from scwf/patch-1. 
--- .../spark/serializer/KryoSerializer.scala | 10 ++++++- .../serializer/KryoSerializerSuite.scala | 28 ++++++++++++++++++- 2 files changed, 36 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index 62d445f3d7bd..cb2ac5ea167e 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -400,7 +400,15 @@ private[serializer] class KryoInputDataInputBridge(input: KryoInput) extends Dat override def readUTF(): String = input.readString() // readString in kryo does utf8 override def readInt(): Int = input.readInt() override def readUnsignedShort(): Int = input.readShortUnsigned() - override def skipBytes(n: Int): Int = input.skip(n.toLong).toInt + override def skipBytes(n: Int): Int = { + var remaining: Long = n + while (remaining > 0) { + val skip = Math.min(Integer.MAX_VALUE, remaining).asInstanceOf[Int] + input.skip(skip) + remaining -= skip + } + n + } override def readFully(b: Array[Byte]): Unit = input.read(b) override def readFully(b: Array[Byte], off: Int, len: Int): Unit = input.read(b, off, len) override def readLine(): String = throw new UnsupportedOperationException("readLine") diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala index f81fe3113106..9fcc22b608c6 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala @@ -17,17 +17,21 @@ package org.apache.spark.serializer -import java.io.{ByteArrayInputStream, ByteArrayOutputStream} +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, FileOutputStream, FileInputStream} import scala.collection.JavaConverters._ import scala.collection.mutable import scala.reflect.ClassTag import com.esotericsoftware.kryo.Kryo +import com.esotericsoftware.kryo.io.{Input => KryoInput, Output => KryoOutput} + +import org.roaringbitmap.RoaringBitmap import org.apache.spark.{SharedSparkContext, SparkConf, SparkFunSuite} import org.apache.spark.scheduler.HighlyCompressedMapStatus import org.apache.spark.serializer.KryoTest._ +import org.apache.spark.util.Utils import org.apache.spark.storage.BlockManagerId class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext { @@ -350,6 +354,28 @@ class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext { assert(thrown.getMessage.contains(kryoBufferMaxProperty)) } + test("SPARK-12222: deserialize RoaringBitmap throw Buffer underflow exception") { + val dir = Utils.createTempDir() + val tmpfile = dir.toString + "/RoaringBitmap" + val outStream = new FileOutputStream(tmpfile) + val output = new KryoOutput(outStream) + val bitmap = new RoaringBitmap + bitmap.add(1) + bitmap.add(3) + bitmap.add(5) + bitmap.serialize(new KryoOutputDataOutputBridge(output)) + output.flush() + output.close() + + val inStream = new FileInputStream(tmpfile) + val input = new KryoInput(inStream) + val ret = new RoaringBitmap + ret.deserialize(new KryoInputDataInputBridge(input)) + input.close() + assert(ret == bitmap) + Utils.deleteRecursively(dir) + } + test("getAutoReset") { val ser = new KryoSerializer(new SparkConf).newInstance().asInstanceOf[KryoSerializerInstance] assert(ser.getAutoReset) From f6883bb7afa7d5df480e1c2b3db6cb77198550be Mon Sep 17 00:00:00 
2001 From: hyukjinkwon Date: Wed, 9 Dec 2015 15:15:30 +0800 Subject: [PATCH 1711/1824] [SPARK-11676][SQL] Parquet filter tests all pass if filters are not really pushed down Currently Parquet predicate tests all pass even if filters are not pushed down or this is disabled. In this PR, For checking evaluating filters, Simply it makes the expression from `expression.Filter` and then try to create filters just like Spark does. For checking the results, this manually accesses to the child rdd (of `expression.Filter`) and produces the results which should be filtered properly, and then compares it to expected values. Now, if filters are not pushed down or this is disabled, this throws exceptions. Author: hyukjinkwon Closes #9659 from HyukjinKwon/SPARK-11676. --- .../parquet/ParquetFilterSuite.scala | 69 +++++++++++-------- 1 file changed, 41 insertions(+), 28 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index cc5aae03d551..daf41bc292cc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -50,27 +50,33 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex val output = predicate.collect { case a: Attribute => a }.distinct withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true") { - val query = df - .select(output.map(e => Column(e)): _*) - .where(Column(predicate)) - - val maybeAnalyzedPredicate = query.queryExecution.optimizedPlan.collect { - case PhysicalOperation(_, filters, LogicalRelation(_: ParquetRelation, _)) => filters - }.flatten.reduceLeftOption(_ && _) - assert(maybeAnalyzedPredicate.isDefined) - - val selectedFilters = maybeAnalyzedPredicate.flatMap(DataSourceStrategy.translateFilter) - assert(selectedFilters.nonEmpty) - - selectedFilters.foreach { pred => - val maybeFilter = ParquetFilters.createFilter(df.schema, pred) - assert(maybeFilter.isDefined, s"Couldn't generate filter predicate for $pred") - maybeFilter.foreach { f => - // Doesn't bother checking type parameters here (e.g. `Eq[Integer]`) - assert(f.getClass === filterClass) + withSQLConf(SQLConf.PARQUET_UNSAFE_ROW_RECORD_READER_ENABLED.key -> "false") { + val query = df + .select(output.map(e => Column(e)): _*) + .where(Column(predicate)) + + var maybeRelation: Option[ParquetRelation] = None + val maybeAnalyzedPredicate = query.queryExecution.optimizedPlan.collect { + case PhysicalOperation(_, filters, LogicalRelation(relation: ParquetRelation, _)) => + maybeRelation = Some(relation) + filters + }.flatten.reduceLeftOption(_ && _) + assert(maybeAnalyzedPredicate.isDefined, "No filter is analyzed from the given query") + + val (_, selectedFilters) = + DataSourceStrategy.selectFilters(maybeRelation.get, maybeAnalyzedPredicate.toSeq) + assert(selectedFilters.nonEmpty, "No filter is pushed down") + + selectedFilters.foreach { pred => + val maybeFilter = ParquetFilters.createFilter(df.schema, pred) + assert(maybeFilter.isDefined, s"Couldn't generate filter predicate for $pred") + maybeFilter.foreach { f => + // Doesn't bother checking type parameters here (e.g. 
`Eq[Integer]`) + assert(f.getClass === filterClass) + } } + checker(stripSparkFilter(query), expected) } - checker(query, expected) } } @@ -104,6 +110,21 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex checkBinaryFilterPredicate(predicate, filterClass, Seq(Row(expected)))(df) } + /** + * Strip Spark-side filtering in order to check if a datasource filters rows correctly. + */ + protected def stripSparkFilter(df: DataFrame): DataFrame = { + val schema = df.schema + val childRDD = df + .queryExecution + .executedPlan.asInstanceOf[org.apache.spark.sql.execution.Filter] + .child + .execute() + .map(row => Row.fromSeq(row.toSeq(schema))) + + sqlContext.createDataFrame(childRDD, schema) + } + test("filter pushdown - boolean") { withParquetDataFrame((true :: false :: Nil).map(b => Tuple1.apply(Option(b)))) { implicit df => checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) @@ -347,19 +368,11 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex (1 to 3).map(i => (i, i.toString)).toDF("a", "b").write.parquet(path) val df = sqlContext.read.parquet(path).filter("a = 2") - // This is the source RDD without Spark-side filtering. - val childRDD = - df - .queryExecution - .executedPlan.asInstanceOf[org.apache.spark.sql.execution.Filter] - .child - .execute() - // The result should be single row. // When a filter is pushed to Parquet, Parquet can apply it to every row. // So, we can check the number of rows returned from the Parquet // to make sure our filter pushdown work. - assert(childRDD.count == 1) + assert(stripSparkFilter(df).count == 1) } } } From a113216865fd45ea39ae8f104e784af2cf667dcf Mon Sep 17 00:00:00 2001 From: uncleGen Date: Wed, 9 Dec 2015 15:09:40 +0000 Subject: [PATCH 1712/1824] [SPARK-12031][CORE][BUG] Integer overflow when do sampling Author: uncleGen Closes #10023 from uncleGen/1.6-bugfix. 
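The essence of the fix below is to stop passing a possibly-overflowed `Int` to `rand.nextInt` and instead derive the reservoir replacement index from a `Long` count of items seen via `nextDouble()`. A minimal sketch of that computation, assuming nothing beyond the standard library; the helper name `replacementIndex` is illustrative and not part of `SamplingUtils`.

```scala
import scala.util.Random

// Illustrative helper: pick a uniform replacement index in [0, itemsSeen) from a
// Long counter, so the number of items seen may safely exceed Int.MaxValue.
// (With an Int counter, Int.MaxValue + 1 wraps to -2147483648 and
// rand.nextInt(negativeBound) throws IllegalArgumentException.)
def replacementIndex(rand: Random, itemsSeen: Long): Long =
  (rand.nextDouble() * itemsSeen).toLong

val rand = new Random(42L)
val idx = replacementIndex(rand, 3000000000L)   // about 3 billion items seen
assert(idx >= 0L && idx < 3000000000L)
```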
--- .../src/main/scala/org/apache/spark/Partitioner.scala | 4 ++-- .../org/apache/spark/util/random/SamplingUtils.scala | 11 ++++++----- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/Partitioner.scala b/core/src/main/scala/org/apache/spark/Partitioner.scala index e4df7af81a6d..ef9a2dab1c10 100644 --- a/core/src/main/scala/org/apache/spark/Partitioner.scala +++ b/core/src/main/scala/org/apache/spark/Partitioner.scala @@ -253,7 +253,7 @@ private[spark] object RangePartitioner { */ def sketch[K : ClassTag]( rdd: RDD[K], - sampleSizePerPartition: Int): (Long, Array[(Int, Int, Array[K])]) = { + sampleSizePerPartition: Int): (Long, Array[(Int, Long, Array[K])]) = { val shift = rdd.id // val classTagK = classTag[K] // to avoid serializing the entire partitioner object val sketched = rdd.mapPartitionsWithIndex { (idx, iter) => @@ -262,7 +262,7 @@ private[spark] object RangePartitioner { iter, sampleSizePerPartition, seed) Iterator((idx, n, sample)) }.collect() - val numItems = sketched.map(_._2.toLong).sum + val numItems = sketched.map(_._2).sum (numItems, sketched) } diff --git a/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala b/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala index c9a864ae6277..f98932a47016 100644 --- a/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala @@ -34,7 +34,7 @@ private[spark] object SamplingUtils { input: Iterator[T], k: Int, seed: Long = Random.nextLong()) - : (Array[T], Int) = { + : (Array[T], Long) = { val reservoir = new Array[T](k) // Put the first k elements in the reservoir. var i = 0 @@ -52,16 +52,17 @@ private[spark] object SamplingUtils { (trimReservoir, i) } else { // If input size > k, continue the sampling process. + var l = i.toLong val rand = new XORShiftRandom(seed) while (input.hasNext) { val item = input.next() - val replacementIndex = rand.nextInt(i) + val replacementIndex = (rand.nextDouble() * l).toLong if (replacementIndex < k) { - reservoir(replacementIndex) = item + reservoir(replacementIndex.toInt) = item } - i += 1 + l += 1 } - (reservoir, i) + (reservoir, l) } } From 6e1c55eac4849669e119ce0d51f6d051830deb9f Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Wed, 9 Dec 2015 23:30:42 +0800 Subject: [PATCH 1713/1824] [SPARK-12012][SQL] Show more comprehensive PhysicalRDD metadata when visualizing SQL query plan This PR adds a `private[sql]` method `metadata` to `SparkPlan`, which can be used to describe detail information about a physical plan during visualization. Specifically, this PR uses this method to provide details of `PhysicalRDD`s translated from a data source relation. For example, a `ParquetRelation` converted from Hive metastore table `default.psrc` is now shown as the following screenshot: ![image](https://cloud.githubusercontent.com/assets/230655/11526657/e10cb7e6-9916-11e5-9afa-f108932ec890.png) And here is the screenshot for a regular `ParquetRelation` (not converted from Hive metastore table) loaded from a really long path: ![output](https://cloud.githubusercontent.com/assets/230655/11680582/37c66460-9e94-11e5-8f50-842db5309d5a.png) Author: Cheng Lian Closes #10004 from liancheng/spark-12012.physical-rdd-metadata. 
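To make the rendering concrete before diving into the diff: the patch threads a `metadata: Map[String, String]` from the physical scan node into its one-line description and the UI plan graph. The sketch below is a simplified stand-in, not the real `PhysicalRDD`; the keys `PushedFilters` and `InputPaths` match the patch, `default.psrc` and the `EqualTo(key,15)` filter come from the patch description and its test changes, and the attribute names and HDFS path are invented for illustration.

```scala
// Simplified stand-in mirroring PhysicalRDD.simpleString in the diff below.
case class ScanDescription(
    nodeName: String,
    output: Seq[String],
    metadata: Map[String, String]) {

  def simpleString: String = {
    // Sort entries so the rendered string is deterministic across runs.
    val metadataEntries = for ((key, value) <- metadata.toSeq.sorted) yield s"$key: $value"
    s"Scan $nodeName${output.mkString("[", ",", "]")}${metadataEntries.mkString(" ", ", ", "")}"
  }
}

val desc = ScanDescription(
  nodeName = "ParquetRelation: default.psrc",
  output = Seq("key#0", "value#1"),
  metadata = Map(
    "PushedFilters" -> "[EqualTo(key,15)]",
    "InputPaths" -> "hdfs://example/user/hive/warehouse/psrc"))

println(desc.simpleString)
// Scan ParquetRelation: default.psrc[key#0,value#1] InputPaths: hdfs://example/user/hive/warehouse/psrc, PushedFilters: [EqualTo(key,15)]
```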
--- python/pyspark/sql/dataframe.py | 2 +- .../spark/sql/execution/ExistingRDD.scala | 19 +++++--- .../spark/sql/execution/SparkPlan.scala | 5 +++ .../spark/sql/execution/SparkPlanInfo.scala | 3 +- .../spark/sql/execution/SparkStrategies.scala | 2 +- .../datasources/DataSourceStrategy.scala | 22 +++++++-- .../datasources/parquet/ParquetRelation.scala | 10 +++++ .../sql/execution/ui/SparkPlanGraph.scala | 45 +++++++++++-------- .../apache/spark/sql/sources/interfaces.scala | 2 +- .../spark/sql/execution/PlannerSuite.scala | 2 +- .../spark/sql/hive/HiveMetastoreCatalog.scala | 7 ++- 11 files changed, 87 insertions(+), 32 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 746bb55e14f2..78ab475eb466 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -213,7 +213,7 @@ def explain(self, extended=False): >>> df.explain() == Physical Plan == - Scan PhysicalRDD[age#0,name#1] + Scan ExistingRDD[age#0,name#1] >>> df.explain(True) == Parsed Logical Plan == diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index 623348f6768a..b8a43025882e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -97,22 +97,31 @@ private[sql] case class LogicalRDD( private[sql] case class PhysicalRDD( output: Seq[Attribute], rdd: RDD[InternalRow], - extraInformation: String, + override val nodeName: String, + override val metadata: Map[String, String] = Map.empty, override val outputsUnsafeRows: Boolean = false) extends LeafNode { protected override def doExecute(): RDD[InternalRow] = rdd - override def simpleString: String = "Scan " + extraInformation + output.mkString("[", ",", "]") + override def simpleString: String = { + val metadataEntries = for ((key, value) <- metadata.toSeq.sorted) yield s"$key: $value" + s"Scan $nodeName${output.mkString("[", ",", "]")}${metadataEntries.mkString(" ", ", ", "")}" + } } private[sql] object PhysicalRDD { + // Metadata keys + val INPUT_PATHS = "InputPaths" + val PUSHED_FILTERS = "PushedFilters" + def createFromDataSource( output: Seq[Attribute], rdd: RDD[InternalRow], relation: BaseRelation, - extraInformation: String = ""): PhysicalRDD = { - PhysicalRDD(output, rdd, relation.toString + extraInformation, - relation.isInstanceOf[HadoopFsRelation]) + metadata: Map[String, String] = Map.empty): PhysicalRDD = { + // All HadoopFsRelations output UnsafeRows + val outputUnsafeRows = relation.isInstanceOf[HadoopFsRelation] + PhysicalRDD(output, rdd, relation.toString, metadata, outputUnsafeRows) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index a78177751c9d..ec98f8104134 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -67,6 +67,11 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ super.makeCopy(newArgs) } + /** + * Return all metadata that describes more details of this SparkPlan. + */ + private[sql] def metadata: Map[String, String] = Map.empty + /** * Return all metrics containing metrics of this SparkPlan. 
*/ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala index 486ce34064e4..4f750ad13ab8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala @@ -30,6 +30,7 @@ class SparkPlanInfo( val nodeName: String, val simpleString: String, val children: Seq[SparkPlanInfo], + val metadata: Map[String, String], val metrics: Seq[SQLMetricInfo]) private[sql] object SparkPlanInfo { @@ -41,6 +42,6 @@ private[sql] object SparkPlanInfo { } val children = plan.children.map(fromSparkPlan) - new SparkPlanInfo(plan.nodeName, plan.simpleString, children, metrics) + new SparkPlanInfo(plan.nodeName, plan.simpleString, children, plan.metadata, metrics) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index f67c951bc066..25e98c0bdd43 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -363,7 +363,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { expressions, nPartitions.getOrElse(numPartitions)), planLater(child)) :: Nil case e @ EvaluatePython(udf, child, _) => BatchPythonEvaluation(udf, e.output, planLater(child)) :: Nil - case LogicalRDD(output, rdd) => PhysicalRDD(output, rdd, "PhysicalRDD") :: Nil + case LogicalRDD(output, rdd) => PhysicalRDD(output, rdd, "ExistingRDD") :: Nil case BroadcastHint(child) => apply(child) case _ => Nil } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 544d5eccec03..8a15a51d825e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.datasources +import scala.collection.mutable.ArrayBuffer + import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.rdd.{MapPartitionsRDD, RDD, UnionRDD} import org.apache.spark.sql.catalyst.CatalystTypeConverters.convertToScala @@ -25,6 +27,7 @@ import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, expressions} +import org.apache.spark.sql.execution.PhysicalRDD.{INPUT_PATHS, PUSHED_FILTERS} import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.{StringType, StructType} @@ -315,7 +318,20 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { // `Filter`s or cannot be handled by `relation`. 
val filterCondition = unhandledPredicates.reduceLeftOption(expressions.And) - val pushedFiltersString = pushedFilters.mkString(" PushedFilter: [", ",", "] ") + val metadata: Map[String, String] = { + val pairs = ArrayBuffer.empty[(String, String)] + + if (pushedFilters.nonEmpty) { + pairs += (PUSHED_FILTERS -> pushedFilters.mkString("[", ", ", "]")) + } + + relation.relation match { + case r: HadoopFsRelation => pairs += INPUT_PATHS -> r.paths.mkString(", ") + case _ => + } + + pairs.toMap + } if (projects.map(_.toAttribute) == projects && projectSet.size == projects.size && @@ -334,7 +350,7 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { val scan = execution.PhysicalRDD.createFromDataSource( projects.map(_.toAttribute), scanBuilder(requestedColumns, candidatePredicates, pushedFilters), - relation.relation, pushedFiltersString) + relation.relation, metadata) filterCondition.map(execution.Filter(_, scan)).getOrElse(scan) } else { // Don't request columns that are only referenced by pushed filters. @@ -344,7 +360,7 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { val scan = execution.PhysicalRDD.createFromDataSource( requestedColumns, scanBuilder(requestedColumns, candidatePredicates, pushedFilters), - relation.relation, pushedFiltersString) + relation.relation, metadata) execution.Project( projects, filterCondition.map(execution.Filter(_, scan)).getOrElse(scan)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala index bb3e2786978c..1af2a394f399 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala @@ -146,6 +146,12 @@ private[sql] class ParquetRelation( meta } + override def toString: String = { + parameters.get(ParquetRelation.METASTORE_TABLE_NAME).map { tableName => + s"${getClass.getSimpleName}: $tableName" + }.getOrElse(super.toString) + } + override def equals(other: Any): Boolean = other match { case that: ParquetRelation => val schemaEquality = if (shouldMergeSchemas) { @@ -521,6 +527,10 @@ private[sql] object ParquetRelation extends Logging { // internally. private[sql] val METASTORE_SCHEMA = "metastoreSchema" + // If a ParquetRelation is converted from a Hive metastore table, this option is set to the + // original Hive table name. + private[sql] val METASTORE_TABLE_NAME = "metastoreTableName" + /** * If parquet's block size (row group size) setting is larger than the min split size, * we use parquet's block size setting as the min split size. 
Otherwise, we will create diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala index 7af0ff09c5c6..3a6eff939982 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala @@ -66,7 +66,9 @@ private[sql] object SparkPlanGraph { SQLMetrics.getMetricParam(metric.metricParam)) } val node = SparkPlanGraphNode( - nodeIdGenerator.getAndIncrement(), planInfo.nodeName, planInfo.simpleString, metrics) + nodeIdGenerator.getAndIncrement(), planInfo.nodeName, + planInfo.simpleString, planInfo.metadata, metrics) + nodes += node val childrenNodes = planInfo.children.map( child => buildSparkPlanGraphNode(child, nodeIdGenerator, nodes, edges)) @@ -85,26 +87,33 @@ private[sql] object SparkPlanGraph { * @param metrics metrics that this SparkPlan node will track */ private[ui] case class SparkPlanGraphNode( - id: Long, name: String, desc: String, metrics: Seq[SQLPlanMetric]) { + id: Long, + name: String, + desc: String, + metadata: Map[String, String], + metrics: Seq[SQLPlanMetric]) { def makeDotNode(metricsValue: Map[Long, String]): String = { - val values = { - for (metric <- metrics; - value <- metricsValue.get(metric.accumulatorId)) yield { - metric.name + ": " + value - } + val builder = new mutable.StringBuilder(name) + + val values = for { + metric <- metrics + value <- metricsValue.get(metric.accumulatorId) + } yield { + metric.name + ": " + value } - val label = if (values.isEmpty) { - name - } else { - // If there are metrics, display all metrics in a separate line. We should use an escaped - // "\n" here to follow the dot syntax. - // - // Note: whitespace between two "\n"s is to create an empty line between the name of - // SparkPlan and metrics. If removing it, it won't display the empty line in UI. - name + "\\n \\n" + values.mkString("\\n") - } - s""" $id [label="$label"];""" + + if (values.nonEmpty) { + // If there are metrics, display each entry in a separate line. We should use an escaped + // "\n" here to follow the dot syntax. + // + // Note: whitespace between two "\n"s is to create an empty line between the name of + // SparkPlan and metrics. If removing it, it won't display the empty line in UI. 
+ builder ++= "\\n \\n" + builder ++= values.mkString("\\n") + } + + s""" $id [label="${builder.toString()}"];""" } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index 9ace25dc7d21..fc8ce6901dfc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -422,7 +422,7 @@ abstract class HadoopFsRelation private[sql]( parameters: Map[String, String]) extends BaseRelation with FileRelation with Logging { - override def toString: String = getClass.getSimpleName + paths.mkString("[", ",", "]") + override def toString: String = getClass.getSimpleName def this() = this(None, Map.empty[String, String]) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index a4626259b282..2fb439f50117 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -169,7 +169,7 @@ class PlannerSuite extends SharedSQLContext { withTempTable("testPushed") { val exp = sql("select * from testPushed where key = 15").queryExecution.executedPlan - assert(exp.toString.contains("PushedFilter: [EqualTo(key,15)]")) + assert(exp.toString.contains("PushedFilters: [EqualTo(key,15)]")) } } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 9a981d02ad67..08b291e08823 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -411,7 +411,12 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive // evil case insensitivity issue, which is reconciled within `ParquetRelation`. val parquetOptions = Map( ParquetRelation.METASTORE_SCHEMA -> metastoreSchema.json, - ParquetRelation.MERGE_SCHEMA -> mergeSchema.toString) + ParquetRelation.MERGE_SCHEMA -> mergeSchema.toString, + ParquetRelation.METASTORE_TABLE_NAME -> TableIdentifier( + metastoreRelation.tableName, + Some(metastoreRelation.databaseName) + ).unquotedString + ) val tableIdentifier = QualifiedTableName(metastoreRelation.databaseName, metastoreRelation.tableName) From 22b9a8740d51289434553d19b6b1ac34aecdc09a Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Wed, 9 Dec 2015 16:45:13 +0000 Subject: [PATCH 1714/1824] [SPARK-10299][ML] word2vec should allow users to specify the window size Currently word2vec has the window hard coded at 5, some users may want different sizes (for example if using on n-gram input or similar). User request comes from http://stackoverflow.com/questions/32231975/spark-word2vec-window-size . Author: Holden Karau Author: Holden Karau Closes #8513 from holdenk/SPARK-10299-word2vec-should-allow-users-to-specify-the-window-size. 
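A short usage sketch of the new parameter, mirroring the test added by this patch. It assumes an active `SparkContext` (`sc`) and `SQLContext` (`sqlContext`); everything else comes from the patch itself.

```scala
import org.apache.spark.ml.feature.Word2Vec
import sqlContext.implicits._

// Build a toy tokenized DataFrame the same way the new test does.
val sentence = "a q s t q s t b b b s t m s t m q " * 100
val doc = sc.parallelize(Seq(sentence, sentence)).map(_.split(" "))
val docDF = doc.zip(doc).toDF("text", "alsotext")

val word2Vec = new Word2Vec()
  .setInputCol("text")
  .setOutputCol("result")
  .setVectorSize(3)
  .setWindowSize(10)   // context words from [-window, window]; previously fixed at 5
  .setSeed(42L)

val model = word2Vec.fit(docDF)
model.findSynonyms("a", 3).show()   // nearest neighbours of "a" under the wider window
```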
--- .../apache/spark/ml/feature/Word2Vec.scala | 15 +++++++ .../apache/spark/mllib/feature/Word2Vec.scala | 11 ++++- .../spark/ml/feature/Word2VecSuite.scala | 43 +++++++++++++++++-- 3 files changed, 65 insertions(+), 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala index a8d61b6dea00..f105a983a34f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala @@ -49,6 +49,17 @@ private[feature] trait Word2VecBase extends Params /** @group getParam */ def getVectorSize: Int = $(vectorSize) + /** + * The window size (context words from [-window, window]) default 5. + * @group expertParam + */ + final val windowSize = new IntParam( + this, "windowSize", "the window size (context words from [-window, window])") + setDefault(windowSize -> 5) + + /** @group expertGetParam */ + def getWindowSize: Int = $(windowSize) + /** * Number of partitions for sentences of words. * Default: 1 @@ -106,6 +117,9 @@ final class Word2Vec(override val uid: String) extends Estimator[Word2VecModel] /** @group setParam */ def setVectorSize(value: Int): this.type = set(vectorSize, value) + /** @group expertSetParam */ + def setWindowSize(value: Int): this.type = set(windowSize, value) + /** @group setParam */ def setStepSize(value: Double): this.type = set(stepSize, value) @@ -131,6 +145,7 @@ final class Word2Vec(override val uid: String) extends Estimator[Word2VecModel] .setNumPartitions($(numPartitions)) .setSeed($(seed)) .setVectorSize($(vectorSize)) + .setWindowSize($(windowSize)) .fit(input) copyValues(new Word2VecModel(uid, wordVectors).setParent(this)) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index 23b1514e3080..1f400e1430eb 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -125,6 +125,15 @@ class Word2Vec extends Serializable with Logging { this } + /** + * Sets the window of words (default: 5) + */ + @Since("1.6.0") + def setWindowSize(window: Int): this.type = { + this.window = window + this + } + /** * Sets minCount, the minimum number of times a token must appear to be included in the word2vec * model's vocabulary (default: 5). 
@@ -141,7 +150,7 @@ class Word2Vec extends Serializable with Logging { private val MAX_SENTENCE_LENGTH = 1000 /** context words from [-window, window] */ - private val window = 5 + private var window = 5 private var trainWordsCount = 0 private var vocabSize = 0 diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala index a773244cd735..d561bbbb2552 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala @@ -35,7 +35,8 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul } test("Word2Vec") { - val sqlContext = new SQLContext(sc) + + val sqlContext = this.sqlContext import sqlContext.implicits._ val sentence = "a b " * 100 + "a c " * 10 @@ -77,7 +78,7 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul test("getVectors") { - val sqlContext = new SQLContext(sc) + val sqlContext = this.sqlContext import sqlContext.implicits._ val sentence = "a b " * 100 + "a c " * 10 @@ -118,7 +119,7 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul test("findSynonyms") { - val sqlContext = new SQLContext(sc) + val sqlContext = this.sqlContext import sqlContext.implicits._ val sentence = "a b " * 100 + "a c " * 10 @@ -141,7 +142,43 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul expectedSimilarity.zip(similarity).map { case (expected, actual) => assert(math.abs((expected - actual) / expected) < 1E-5) } + } + + test("window size") { + + val sqlContext = this.sqlContext + import sqlContext.implicits._ + + val sentence = "a q s t q s t b b b s t m s t m q " * 100 + "a c " * 10 + val doc = sc.parallelize(Seq(sentence, sentence)).map(line => line.split(" ")) + val docDF = doc.zip(doc).toDF("text", "alsotext") + + val model = new Word2Vec() + .setVectorSize(3) + .setWindowSize(2) + .setInputCol("text") + .setOutputCol("result") + .setSeed(42L) + .fit(docDF) + val (synonyms, similarity) = model.findSynonyms("a", 6).map { + case Row(w: String, sim: Double) => (w, sim) + }.collect().unzip + + // Increase the window size + val biggerModel = new Word2Vec() + .setVectorSize(3) + .setInputCol("text") + .setOutputCol("result") + .setSeed(42L) + .setWindowSize(10) + .fit(docDF) + + val (synonymsLarger, similarityLarger) = model.findSynonyms("a", 6).map { + case Row(w: String, sim: Double) => (w, sim) + }.collect().unzip + // The similarity score should be very different with the larger window + assert(math.abs(similarity(5) - similarityLarger(5) / similarity(5)) > 1E-5) } test("Word2Vec read/write") { From 6900f0173790ad2fa4c79a426bd2dec2d149daa2 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Wed, 9 Dec 2015 09:50:43 -0800 Subject: [PATCH 1715/1824] [SPARK-10582][YARN][CORE] Fix AM failure situation for dynamic allocation Because of AM failure, the target executor number between driver and AM will be different, which will lead to unexpected behavior in dynamic allocation. So when AM is re-registered with driver, state in `ExecutorAllocationManager` and `CoarseGrainedSchedulerBacked` should be reset. This issue is originally addressed in #8737 , here re-opened again. Thanks a lot KaiXinXiaoLei for finding this issue. andrewor14 and vanzin would you please help to review this, thanks a lot. Author: jerryshao Closes #9963 from jerryshao/SPARK-10582. 
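Conceptually, the driver only resets its allocation state when it sees a second `RegisterClusterManager` message, since that can only happen after the previous AM died. The sketch below is a simplified illustration of that decision, not the actual `YarnSchedulerBackend` code; the class and callback names are invented.

```scala
// Simplified illustration of the re-registration check this patch adds: the
// first registration just records that an AM exists; any later registration
// implies the previous AM failed, so driver-side state must be reset.
class AmRegistrationTracker(onAmReRegistered: () => Unit) {
  private var registeredBefore = false

  def onRegisterClusterManager(): Unit = synchronized {
    if (!registeredBefore) {
      registeredBefore = true     // first AM: nothing stale to clean up
    } else {
      onAmReRegistered()          // new AM after a failure: reset stale state
    }
  }
}

// The callback stands in for reset(): numExecutorsTarget back to its initial
// value, numExecutorsToAdd back to 1, and pending-to-remove executors cleared.
val tracker = new AmRegistrationTracker(() => println("resetting allocation state"))
tracker.onRegisterClusterManager()   // first registration: no reset
tracker.onRegisterClusterManager()   // re-registration: triggers the reset
```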
--- .../spark/ExecutorAllocationManager.scala | 18 +++- .../CoarseGrainedSchedulerBackend.scala | 19 +++++ .../ExecutorAllocationManagerSuite.scala | 84 +++++++++++++++++++ .../cluster/YarnSchedulerBackend.scala | 23 +++++ 4 files changed, 142 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala index 34c32ce31296..6176e258989d 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala @@ -89,6 +89,8 @@ private[spark] class ExecutorAllocationManager( private val minNumExecutors = conf.getInt("spark.dynamicAllocation.minExecutors", 0) private val maxNumExecutors = conf.getInt("spark.dynamicAllocation.maxExecutors", Integer.MAX_VALUE) + private val initialNumExecutors = conf.getInt("spark.dynamicAllocation.initialExecutors", + minNumExecutors) // How long there must be backlogged tasks for before an addition is triggered (seconds) private val schedulerBacklogTimeoutS = conf.getTimeAsSeconds( @@ -121,8 +123,7 @@ private[spark] class ExecutorAllocationManager( // The desired number of executors at this moment in time. If all our executors were to die, this // is the number of executors we would immediately want from the cluster manager. - private var numExecutorsTarget = - conf.getInt("spark.dynamicAllocation.initialExecutors", minNumExecutors) + private var numExecutorsTarget = initialNumExecutors // Executors that have been requested to be removed but have not been killed yet private val executorsPendingToRemove = new mutable.HashSet[String] @@ -240,6 +241,19 @@ private[spark] class ExecutorAllocationManager( executor.awaitTermination(10, TimeUnit.SECONDS) } + /** + * Reset the allocation manager to the initial state. Currently this will only be called in + * yarn-client mode when AM re-registers after a failure. + */ + def reset(): Unit = synchronized { + initializing = true + numExecutorsTarget = initialNumExecutors + numExecutorsToAdd = 1 + + executorsPendingToRemove.clear() + removeTimes.clear() + } + /** * The maximum number of executors we would need under the current load to satisfy all running * and pending tasks, rounded up. diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 505c161141c8..7efe16749e59 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -341,6 +341,25 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp } } + /** + * Reset the state of CoarseGrainedSchedulerBackend to the initial state. Currently it will only + * be called in the yarn-client mode when AM re-registers after a failure, also dynamic + * allocation is enabled. + * */ + protected def reset(): Unit = synchronized { + if (Utils.isDynamicAllocationEnabled(conf)) { + numPendingExecutors = 0 + executorsPendingToRemove.clear() + + // Remove all the lingering executors that should be removed but not yet. The reason might be + // because (1) disconnected event is not yet received; (2) executors die silently. 
+ executorDataMap.toMap.foreach { case (eid, _) => + driverEndpoint.askWithRetry[Boolean]( + RemoveExecutor(eid, SlaveLost("Stale executor after cluster manager re-registered."))) + } + } + } + override def reviveOffers() { driverEndpoint.send(ReviveOffers) } diff --git a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala index 116f027a0f98..fedfbd547b91 100644 --- a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala @@ -805,6 +805,90 @@ class ExecutorAllocationManagerSuite assert(maxNumExecutorsNeeded(manager) === 1) } + test("reset the state of allocation manager") { + sc = createSparkContext() + val manager = sc.executorAllocationManager.get + assert(numExecutorsTarget(manager) === 1) + assert(numExecutorsToAdd(manager) === 1) + + // Allocation manager is reset when adding executor requests are sent without reporting back + // executor added. + sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(0, 10))) + + assert(addExecutors(manager) === 1) + assert(numExecutorsTarget(manager) === 2) + assert(addExecutors(manager) === 2) + assert(numExecutorsTarget(manager) === 4) + assert(addExecutors(manager) === 1) + assert(numExecutorsTarget(manager) === 5) + + manager.reset() + assert(numExecutorsTarget(manager) === 1) + assert(numExecutorsToAdd(manager) === 1) + assert(executorIds(manager) === Set.empty) + + // Allocation manager is reset when executors are added. + sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(0, 10))) + + addExecutors(manager) + addExecutors(manager) + addExecutors(manager) + assert(numExecutorsTarget(manager) === 5) + + onExecutorAdded(manager, "first") + onExecutorAdded(manager, "second") + onExecutorAdded(manager, "third") + onExecutorAdded(manager, "fourth") + onExecutorAdded(manager, "fifth") + assert(executorIds(manager) === Set("first", "second", "third", "fourth", "fifth")) + + // Cluster manager lost will make all the live executors lost, so here simulate this behavior + onExecutorRemoved(manager, "first") + onExecutorRemoved(manager, "second") + onExecutorRemoved(manager, "third") + onExecutorRemoved(manager, "fourth") + onExecutorRemoved(manager, "fifth") + + manager.reset() + assert(numExecutorsTarget(manager) === 1) + assert(numExecutorsToAdd(manager) === 1) + assert(executorIds(manager) === Set.empty) + assert(removeTimes(manager) === Map.empty) + + // Allocation manager is reset when executors are pending to remove + addExecutors(manager) + addExecutors(manager) + addExecutors(manager) + assert(numExecutorsTarget(manager) === 5) + + onExecutorAdded(manager, "first") + onExecutorAdded(manager, "second") + onExecutorAdded(manager, "third") + onExecutorAdded(manager, "fourth") + onExecutorAdded(manager, "fifth") + assert(executorIds(manager) === Set("first", "second", "third", "fourth", "fifth")) + + removeExecutor(manager, "first") + removeExecutor(manager, "second") + assert(executorsPendingToRemove(manager) === Set("first", "second")) + assert(executorIds(manager) === Set("first", "second", "third", "fourth", "fifth")) + + + // Cluster manager lost will make all the live executors lost, so here simulate this behavior + onExecutorRemoved(manager, "first") + onExecutorRemoved(manager, "second") + onExecutorRemoved(manager, "third") + onExecutorRemoved(manager, "fourth") + onExecutorRemoved(manager, "fifth") + + 
manager.reset() + + assert(numExecutorsTarget(manager) === 1) + assert(numExecutorsToAdd(manager) === 1) + assert(executorsPendingToRemove(manager) === Set.empty) + assert(removeTimes(manager) === Map.empty) + } + private def createSparkContext( minExecutors: Int = 1, maxExecutors: Int = 5, diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala index e3dd87798f01..1431bceb256a 100644 --- a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala +++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala @@ -60,6 +60,9 @@ private[spark] abstract class YarnSchedulerBackend( /** Scheduler extension services. */ private val services: SchedulerExtensionServices = new SchedulerExtensionServices() + // Flag to specify whether this schedulerBackend should be reset. + private var shouldResetOnAmRegister = false + /** * Bind to YARN. This *must* be done before calling [[start()]]. * @@ -155,6 +158,16 @@ private[spark] abstract class YarnSchedulerBackend( new YarnDriverEndpoint(rpcEnv, properties) } + /** + * Reset the state of SchedulerBackend to the initial state. This is happened when AM is failed + * and re-registered itself to driver after a failure. The stale state in driver should be + * cleaned. + */ + override protected def reset(): Unit = { + super.reset() + sc.executorAllocationManager.foreach(_.reset()) + } + /** * Override the DriverEndpoint to add extra logic for the case when an executor is disconnected. * This endpoint communicates with the executors and queries the AM for an executor's exit @@ -218,6 +231,8 @@ private[spark] abstract class YarnSchedulerBackend( case None => logWarning("Attempted to check for an executor loss reason" + " before the AM has registered!") + driverEndpoint.askWithRetry[Boolean]( + RemoveExecutor(executorId, SlaveLost("AM is not yet registered."))) } } @@ -225,6 +240,13 @@ private[spark] abstract class YarnSchedulerBackend( case RegisterClusterManager(am) => logInfo(s"ApplicationMaster registered as $am") amEndpoint = Option(am) + if (!shouldResetOnAmRegister) { + shouldResetOnAmRegister = true + } else { + // AM is already registered before, this potentially means that AM failed and + // a new one registered after the failure. This will only happen in yarn-client mode. + reset() + } case AddWebUIFilter(filterName, filterParams, proxyBase) => addWebUIFilter(filterName, filterParams, proxyBase) @@ -270,6 +292,7 @@ private[spark] abstract class YarnSchedulerBackend( override def onDisconnected(remoteAddress: RpcAddress): Unit = { if (amEndpoint.exists(_.address == remoteAddress)) { logWarning(s"ApplicationMaster has disassociated: $remoteAddress") + amEndpoint = None } } From 442a7715a590ba2ea2446c73b1f914a16ae0ed4b Mon Sep 17 00:00:00 2001 From: Steve Loughran Date: Wed, 9 Dec 2015 10:25:38 -0800 Subject: [PATCH 1716/1824] [SPARK-12241][YARN] Improve failure reporting in Yarn client obtainTokenForHBase() This lines up the HBase token logic with that done for Hive in SPARK-11265: reflection with only CFNE being swallowed. There is a test, one which doesn't try to put HBase on the yarn/test class and really do the reflection (the way the hive introspection does). If people do want that then it could be added with careful POM work +also: cut an incorrect comment from the Hive test case before copying it, and a couple of imports that may have been related to the hive test in the past. 
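A condensed sketch of that pattern (reflective lookup, with only `ClassNotFoundException` treated as "HBase not present" and everything else allowed to surface); the class and method names match the diff below, but the wrapper itself is illustrative and simplified to return an untyped token.

```scala
import org.apache.hadoop.conf.Configuration

// Illustrative wrapper, not the patch itself: reflectively create an HBase
// configuration and fetch a token only when HBase security is kerberos.
def obtainTokenForHBaseSketch(conf: Configuration): Option[AnyRef] = {
  try {
    val loader = getClass.getClassLoader
    val hbaseConf = loader
      .loadClass("org.apache.hadoop.hbase.HBaseConfiguration")
      .getMethod("create", classOf[Configuration])
      .invoke(null, conf)
      .asInstanceOf[Configuration]
    if ("kerberos" == hbaseConf.get("hbase.security.authentication")) {
      val token = loader
        .loadClass("org.apache.hadoop.hbase.security.token.TokenUtil")
        .getMethod("obtainToken", classOf[Configuration])
        .invoke(null, hbaseConf)
      Some(token)
    } else {
      None
    }
  } catch {
    // Missing HBase on the classpath is expected; any other failure propagates
    // so it gets reported rather than silently swallowed.
    case _: ClassNotFoundException => None
  }
}
```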
Author: Steve Loughran Closes #10227 from steveloughran/stevel/patches/SPARK-12241-obtainTokenForHBase. --- .../org/apache/spark/deploy/yarn/Client.scala | 32 ++---------- .../deploy/yarn/YarnSparkHadoopUtil.scala | 51 ++++++++++++++++++- .../yarn/YarnSparkHadoopUtilSuite.scala | 12 ++++- 3 files changed, 64 insertions(+), 31 deletions(-) diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index f0590d2d222e..7742ec92eb4e 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -1369,40 +1369,16 @@ object Client extends Logging { } /** - * Obtain security token for HBase. + * Obtain a security token for HBase. */ def obtainTokenForHBase( sparkConf: SparkConf, conf: Configuration, credentials: Credentials): Unit = { if (shouldGetTokens(sparkConf, "hbase") && UserGroupInformation.isSecurityEnabled) { - val mirror = universe.runtimeMirror(getClass.getClassLoader) - - try { - val confCreate = mirror.classLoader. - loadClass("org.apache.hadoop.hbase.HBaseConfiguration"). - getMethod("create", classOf[Configuration]) - val obtainToken = mirror.classLoader. - loadClass("org.apache.hadoop.hbase.security.token.TokenUtil"). - getMethod("obtainToken", classOf[Configuration]) - - logDebug("Attempting to fetch HBase security token.") - - val hbaseConf = confCreate.invoke(null, conf).asInstanceOf[Configuration] - if ("kerberos" == hbaseConf.get("hbase.security.authentication")) { - val token = obtainToken.invoke(null, hbaseConf).asInstanceOf[Token[TokenIdentifier]] - credentials.addToken(token.getService, token) - logInfo("Added HBase security token to credentials.") - } - } catch { - case e: java.lang.NoSuchMethodException => - logInfo("HBase Method not found: " + e) - case e: java.lang.ClassNotFoundException => - logDebug("HBase Class not found: " + e) - case e: java.lang.NoClassDefFoundError => - logDebug("HBase Class not found: " + e) - case e: Exception => - logError("Exception when obtaining HBase security token: " + e) + YarnSparkHadoopUtil.get.obtainTokenForHBase(conf).foreach { token => + credentials.addToken(token.getService, token) + logInfo("Added HBase security token to credentials.") } } } diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala index a290ebeec900..36a2d6142988 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala @@ -33,7 +33,7 @@ import org.apache.hadoop.io.Text import org.apache.hadoop.mapred.{Master, JobConf} import org.apache.hadoop.security.Credentials import org.apache.hadoop.security.UserGroupInformation -import org.apache.hadoop.security.token.Token +import org.apache.hadoop.security.token.{Token, TokenIdentifier} import org.apache.hadoop.yarn.conf.YarnConfiguration import org.apache.hadoop.yarn.api.ApplicationConstants import org.apache.hadoop.yarn.api.ApplicationConstants.Environment @@ -216,6 +216,55 @@ class YarnSparkHadoopUtil extends SparkHadoopUtil { None } } + + /** + * Obtain a security token for HBase. + * + * Requirements + * + * 1. `"hbase.security.authentication" == "kerberos"` + * 2. The HBase classes `HBaseConfiguration` and `TokenUtil` could be loaded + * and invoked. + * + * @param conf Hadoop configuration; an HBase configuration is created + * from this. 
+ * @return a token if the requirements were met, `None` if not. + */ + def obtainTokenForHBase(conf: Configuration): Option[Token[TokenIdentifier]] = { + try { + obtainTokenForHBaseInner(conf) + } catch { + case e: ClassNotFoundException => + logInfo(s"HBase class not found $e") + logDebug("HBase class not found", e) + None + } + } + + /** + * Obtain a security token for HBase if `"hbase.security.authentication" == "kerberos"` + * + * @param conf Hadoop configuration; an HBase configuration is created + * from this. + * @return a token if one was needed + */ + def obtainTokenForHBaseInner(conf: Configuration): Option[Token[TokenIdentifier]] = { + val mirror = universe.runtimeMirror(getClass.getClassLoader) + val confCreate = mirror.classLoader. + loadClass("org.apache.hadoop.hbase.HBaseConfiguration"). + getMethod("create", classOf[Configuration]) + val obtainToken = mirror.classLoader. + loadClass("org.apache.hadoop.hbase.security.token.TokenUtil"). + getMethod("obtainToken", classOf[Configuration]) + val hbaseConf = confCreate.invoke(null, conf).asInstanceOf[Configuration] + if ("kerberos" == hbaseConf.get("hbase.security.authentication")) { + logDebug("Attempting to fetch HBase security token.") + Some(obtainToken.invoke(null, hbaseConf).asInstanceOf[Token[TokenIdentifier]]) + } else { + None + } + } + } object YarnSparkHadoopUtil { diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala index a70e66d39a64..3fafc91a166a 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala @@ -27,7 +27,6 @@ import org.apache.hadoop.hive.ql.metadata.HiveException import org.apache.hadoop.io.Text import org.apache.hadoop.yarn.api.ApplicationConstants import org.apache.hadoop.yarn.api.ApplicationConstants.Environment -import org.apache.hadoop.security.{Credentials, UserGroupInformation} import org.apache.hadoop.yarn.conf.YarnConfiguration import org.scalatest.Matchers @@ -259,7 +258,6 @@ class YarnSparkHadoopUtilSuite extends SparkFunSuite with Matchers with Logging assertNestedHiveException(intercept[InvocationTargetException] { util.obtainTokenForHiveMetastoreInner(hadoopConf, "alice") }) - // expect exception trapping code to unwind this hive-side exception assertNestedHiveException(intercept[InvocationTargetException] { util.obtainTokenForHiveMetastore(hadoopConf) }) @@ -276,6 +274,16 @@ class YarnSparkHadoopUtilSuite extends SparkFunSuite with Matchers with Logging inner } + test("Obtain tokens For HBase") { + val hadoopConf = new Configuration() + hadoopConf.set("hbase.security.authentication", "kerberos") + val util = new YarnSparkHadoopUtil + intercept[ClassNotFoundException] { + util.obtainTokenForHBaseInner(hadoopConf) + } + util.obtainTokenForHBase(hadoopConf) should be (None) + } + // This test needs to live here because it depends on isYarnMode returning true, which can only // happen in the YARN module. test("security manager token generation") { From aec5ea000ebb8921f42f006b694ef26f5df67d83 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 9 Dec 2015 11:39:59 -0800 Subject: [PATCH 1717/1824] [SPARK-12165][SPARK-12189] Fix bugs in eviction of storage memory by execution This patch fixes a bug in the eviction of storage memory by execution. 
## The bug: In general, execution should be able to evict storage memory when the total storage memory usage is greater than `maxMemory * spark.memory.storageFraction`. Due to a bug, however, Spark might wind up evicting no storage memory in certain cases where the storage memory usage was between `maxMemory * spark.memory.storageFraction` and `maxMemory`. For example, here is a regression test which illustrates the bug: ```scala val maxMemory = 1000L val taskAttemptId = 0L val (mm, ms) = makeThings(maxMemory) // Since we used the default storage fraction (0.5), we should be able to allocate 500 bytes // of storage memory which are immune to eviction by execution memory pressure. // Acquire enough storage memory to exceed the storage region size assert(mm.acquireStorageMemory(dummyBlock, 750L, evictedBlocks)) assertEvictBlocksToFreeSpaceNotCalled(ms) assert(mm.executionMemoryUsed === 0L) assert(mm.storageMemoryUsed === 750L) // At this point, storage is using 250 more bytes of memory than it is guaranteed, so execution // should be able to reclaim up to 250 bytes of storage memory. // Therefore, execution should now be able to require up to 500 bytes of memory: assert(mm.acquireExecutionMemory(500L, taskAttemptId, MemoryMode.ON_HEAP) === 500L) // <--- fails by only returning 250L assert(mm.storageMemoryUsed === 500L) assert(mm.executionMemoryUsed === 500L) assertEvictBlocksToFreeSpaceCalled(ms, 250L) ``` The problem relates to the control flow / interaction between `StorageMemoryPool.shrinkPoolToReclaimSpace()` and `MemoryStore.ensureFreeSpace()`. While trying to allocate the 500 bytes of execution memory, the `UnifiedMemoryManager` discovers that it will need to reclaim 250 bytes of memory from storage, so it calls `StorageMemoryPool.shrinkPoolToReclaimSpace(250L)`. This method, in turn, calls `MemoryStore.ensureFreeSpace(250L)`. However, `ensureFreeSpace()` first checks whether the requested space is less than `maxStorageMemory - storageMemoryUsed`, which will be true if there is any free execution memory because it turns out that `MemoryStore.maxStorageMemory = (maxMemory - onHeapExecutionMemoryPool.memoryUsed)` when the `UnifiedMemoryManager` is used. The control flow here is somewhat confusing (it grew to be messy / confusing over time / as a result of the merging / refactoring of several components). In the pre-Spark 1.6 code, `ensureFreeSpace` was called directly by the `MemoryStore` itself, whereas in 1.6 it's involved in a confusing control flow where `MemoryStore` calls `MemoryManager.acquireStorageMemory`, which then calls back into `MemoryStore.ensureFreeSpace`, which, in turn, calls `MemoryManager.freeStorageMemory`. ## The solution: The solution implemented in this patch is to remove the confusing circular control flow between `MemoryManager` and `MemoryStore`, making the storage memory acquisition process much more linear / straightforward. The key changes: - Remove a layer of inheritance which made the memory manager code harder to understand (53841174760a24a0df3eb1562af1f33dbe340eb9). - Move some bounds checks earlier in the call chain (13ba7ada77f87ef1ec362aec35c89a924e6987cb). - Refactor `ensureFreeSpace()` so that the part which evicts blocks can be called independently from the part which checks whether there is enough free space to avoid eviction (7c68ca09cb1b12f157400866983f753ac863380e). - Realize that this lets us remove a layer of overloads from `ensureFreeSpace` (eec4f6c87423d5e482b710e098486b3bbc4daf06). 
- Realize that `ensureFreeSpace()` can simply be replaced with an `evictBlocksToFreeSpace()` method which is called [after we've already figured out](https://github.com/apache/spark/blob/2dc842aea82c8895125d46a00aa43dfb0d121de9/core/src/main/scala/org/apache/spark/memory/StorageMemoryPool.scala#L88) how much memory needs to be reclaimed via eviction; (2dc842aea82c8895125d46a00aa43dfb0d121de9). Along the way, I fixed some problems with the mocks in `MemoryManagerSuite`: the old mocks would [unconditionally](https://github.com/apache/spark/blob/80a824d36eec9d9a9f092ee1741453851218ec73/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala#L84) report that a block had been evicted even if there was enough space in the storage pool such that eviction would be avoided. I also fixed a problem where `StorageMemoryPool._memoryUsed` might become negative due to freed memory being double-counted when excution evicts storage. The problem was that `StorageMemoryPoolshrinkPoolToFreeSpace` would [decrement `_memoryUsed`](https://github.com/apache/spark/commit/7c68ca09cb1b12f157400866983f753ac863380e#diff-935c68a9803be144ed7bafdd2f756a0fL133) even though `StorageMemoryPool.freeMemory` had already decremented it as each evicted block was freed. See SPARK-12189 for details. Author: Josh Rosen Author: Andrew Or Closes #10170 from JoshRosen/SPARK-12165. --- .../apache/spark/memory/MemoryManager.scala | 11 +- .../spark/memory/StaticMemoryManager.scala | 37 ++++- .../spark/memory/StorageMemoryPool.scala | 37 +++-- .../spark/memory/UnifiedMemoryManager.scala | 8 +- .../apache/spark/storage/MemoryStore.scala | 76 ++-------- .../spark/memory/MemoryManagerSuite.scala | 137 +++++++++--------- .../memory/StaticMemoryManagerSuite.scala | 52 ++++--- .../memory/UnifiedMemoryManagerSuite.scala | 76 +++++++--- 8 files changed, 230 insertions(+), 204 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala b/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala index ceb8ea434e1b..ae9e1ac0e246 100644 --- a/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala +++ b/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala @@ -77,9 +77,7 @@ private[spark] abstract class MemoryManager( def acquireStorageMemory( blockId: BlockId, numBytes: Long, - evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = synchronized { - storageMemoryPool.acquireMemory(blockId, numBytes, evictedBlocks) - } + evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean /** * Acquire N bytes of memory to unroll the given block, evicting existing ones if necessary. @@ -109,12 +107,7 @@ private[spark] abstract class MemoryManager( def acquireExecutionMemory( numBytes: Long, taskAttemptId: Long, - memoryMode: MemoryMode): Long = synchronized { - memoryMode match { - case MemoryMode.ON_HEAP => onHeapExecutionMemoryPool.acquireMemory(numBytes, taskAttemptId) - case MemoryMode.OFF_HEAP => offHeapExecutionMemoryPool.acquireMemory(numBytes, taskAttemptId) - } - } + memoryMode: MemoryMode): Long /** * Release numBytes of execution memory belonging to the given task. 
diff --git a/core/src/main/scala/org/apache/spark/memory/StaticMemoryManager.scala b/core/src/main/scala/org/apache/spark/memory/StaticMemoryManager.scala index 12a094306861..3554b558f212 100644 --- a/core/src/main/scala/org/apache/spark/memory/StaticMemoryManager.scala +++ b/core/src/main/scala/org/apache/spark/memory/StaticMemoryManager.scala @@ -49,19 +49,50 @@ private[spark] class StaticMemoryManager( } // Max number of bytes worth of blocks to evict when unrolling - private val maxMemoryToEvictForUnroll: Long = { + private val maxUnrollMemory: Long = { (maxStorageMemory * conf.getDouble("spark.storage.unrollFraction", 0.2)).toLong } + override def acquireStorageMemory( + blockId: BlockId, + numBytes: Long, + evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = synchronized { + if (numBytes > maxStorageMemory) { + // Fail fast if the block simply won't fit + logInfo(s"Will not store $blockId as the required space ($numBytes bytes) exceeds our " + + s"memory limit ($maxStorageMemory bytes)") + false + } else { + storageMemoryPool.acquireMemory(blockId, numBytes, evictedBlocks) + } + } + override def acquireUnrollMemory( blockId: BlockId, numBytes: Long, evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = synchronized { val currentUnrollMemory = storageMemoryPool.memoryStore.currentUnrollMemory - val maxNumBytesToFree = math.max(0, maxMemoryToEvictForUnroll - currentUnrollMemory) - val numBytesToFree = math.min(numBytes, maxNumBytesToFree) + val freeMemory = storageMemoryPool.memoryFree + // When unrolling, we will use all of the existing free memory, and, if necessary, + // some extra space freed from evicting cached blocks. We must place a cap on the + // amount of memory to be evicted by unrolling, however, otherwise unrolling one + // big block can blow away the entire cache. 
+ val maxNumBytesToFree = math.max(0, maxUnrollMemory - currentUnrollMemory - freeMemory) + // Keep it within the range 0 <= X <= maxNumBytesToFree + val numBytesToFree = math.max(0, math.min(maxNumBytesToFree, numBytes - freeMemory)) storageMemoryPool.acquireMemory(blockId, numBytes, numBytesToFree, evictedBlocks) } + + private[memory] + override def acquireExecutionMemory( + numBytes: Long, + taskAttemptId: Long, + memoryMode: MemoryMode): Long = synchronized { + memoryMode match { + case MemoryMode.ON_HEAP => onHeapExecutionMemoryPool.acquireMemory(numBytes, taskAttemptId) + case MemoryMode.OFF_HEAP => offHeapExecutionMemoryPool.acquireMemory(numBytes, taskAttemptId) + } + } } diff --git a/core/src/main/scala/org/apache/spark/memory/StorageMemoryPool.scala b/core/src/main/scala/org/apache/spark/memory/StorageMemoryPool.scala index fc4f0357e9f1..70af83b5ee09 100644 --- a/core/src/main/scala/org/apache/spark/memory/StorageMemoryPool.scala +++ b/core/src/main/scala/org/apache/spark/memory/StorageMemoryPool.scala @@ -65,7 +65,8 @@ private[memory] class StorageMemoryPool(lock: Object) extends MemoryPool(lock) w blockId: BlockId, numBytes: Long, evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = lock.synchronized { - acquireMemory(blockId, numBytes, numBytes, evictedBlocks) + val numBytesToFree = math.max(0, numBytes - memoryFree) + acquireMemory(blockId, numBytes, numBytesToFree, evictedBlocks) } /** @@ -73,7 +74,7 @@ private[memory] class StorageMemoryPool(lock: Object) extends MemoryPool(lock) w * * @param blockId the ID of the block we are acquiring storage memory for * @param numBytesToAcquire the size of this block - * @param numBytesToFree the size of space to be freed through evicting blocks + * @param numBytesToFree the amount of space to be freed through evicting blocks * @return whether all N bytes were successfully granted. */ def acquireMemory( @@ -84,16 +85,18 @@ private[memory] class StorageMemoryPool(lock: Object) extends MemoryPool(lock) w assert(numBytesToAcquire >= 0) assert(numBytesToFree >= 0) assert(memoryUsed <= poolSize) - memoryStore.ensureFreeSpace(blockId, numBytesToFree, evictedBlocks) - // Register evicted blocks, if any, with the active task metrics - Option(TaskContext.get()).foreach { tc => - val metrics = tc.taskMetrics() - val lastUpdatedBlocks = metrics.updatedBlocks.getOrElse(Seq[(BlockId, BlockStatus)]()) - metrics.updatedBlocks = Some(lastUpdatedBlocks ++ evictedBlocks.toSeq) + if (numBytesToFree > 0) { + memoryStore.evictBlocksToFreeSpace(Some(blockId), numBytesToFree, evictedBlocks) + // Register evicted blocks, if any, with the active task metrics + Option(TaskContext.get()).foreach { tc => + val metrics = tc.taskMetrics() + val lastUpdatedBlocks = metrics.updatedBlocks.getOrElse(Seq[(BlockId, BlockStatus)]()) + metrics.updatedBlocks = Some(lastUpdatedBlocks ++ evictedBlocks.toSeq) + } } // NOTE: If the memory store evicts blocks, then those evictions will synchronously call - // back into this StorageMemoryPool in order to free. Therefore, these variables should have - // been updated. + // back into this StorageMemoryPool in order to free memory. Therefore, these variables + // should have been updated. 
val enoughMemory = numBytesToAcquire <= memoryFree if (enoughMemory) { _memoryUsed += numBytesToAcquire @@ -121,18 +124,20 @@ private[memory] class StorageMemoryPool(lock: Object) extends MemoryPool(lock) w */ def shrinkPoolToFreeSpace(spaceToFree: Long): Long = lock.synchronized { // First, shrink the pool by reclaiming free memory: - val spaceFreedByReleasingUnusedMemory = Math.min(spaceToFree, memoryFree) + val spaceFreedByReleasingUnusedMemory = math.min(spaceToFree, memoryFree) decrementPoolSize(spaceFreedByReleasingUnusedMemory) - if (spaceFreedByReleasingUnusedMemory == spaceToFree) { - spaceFreedByReleasingUnusedMemory - } else { + val remainingSpaceToFree = spaceToFree - spaceFreedByReleasingUnusedMemory + if (remainingSpaceToFree > 0) { // If reclaiming free memory did not adequately shrink the pool, begin evicting blocks: val evictedBlocks = new ArrayBuffer[(BlockId, BlockStatus)] - memoryStore.ensureFreeSpace(spaceToFree - spaceFreedByReleasingUnusedMemory, evictedBlocks) + memoryStore.evictBlocksToFreeSpace(None, remainingSpaceToFree, evictedBlocks) val spaceFreedByEviction = evictedBlocks.map(_._2.memSize).sum - _memoryUsed -= spaceFreedByEviction + // When a block is released, BlockManager.dropFromMemory() calls releaseMemory(), so we do + // not need to decrement _memoryUsed here. However, we do need to decrement the pool size. decrementPoolSize(spaceFreedByEviction) spaceFreedByReleasingUnusedMemory + spaceFreedByEviction + } else { + spaceFreedByReleasingUnusedMemory } } } diff --git a/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala b/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala index 0f1ea9ab39c0..0b9f6a9dc052 100644 --- a/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala +++ b/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala @@ -100,7 +100,7 @@ private[spark] class UnifiedMemoryManager private[memory] ( case MemoryMode.OFF_HEAP => // For now, we only support on-heap caching of data, so we do not need to interact with // the storage pool when allocating off-heap memory. This will change in the future, though. - super.acquireExecutionMemory(numBytes, taskAttemptId, memoryMode) + offHeapExecutionMemoryPool.acquireMemory(numBytes, taskAttemptId) } } @@ -110,6 +110,12 @@ private[spark] class UnifiedMemoryManager private[memory] ( evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = synchronized { assert(onHeapExecutionMemoryPool.poolSize + storageMemoryPool.poolSize == maxMemory) assert(numBytes >= 0) + if (numBytes > maxStorageMemory) { + // Fail fast if the block simply won't fit + logInfo(s"Will not store $blockId as the required space ($numBytes bytes) exceeds our " + + s"memory limit ($maxStorageMemory bytes)") + return false + } if (numBytes > storageMemoryPool.memoryFree) { // There is not enough free memory in the storage pool, so try to borrow free memory from // the execution pool. diff --git a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala index 4dbac388e098..bdab8c2332fa 100644 --- a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala @@ -406,85 +406,41 @@ private[spark] class MemoryStore(blockManager: BlockManager, memoryManager: Memo } /** - * Try to free up a given amount of space by evicting existing blocks. 
- * - * @param space the amount of memory to free, in bytes - * @param droppedBlocks a holder for blocks evicted in the process - * @return whether the requested free space is freed. - */ - private[spark] def ensureFreeSpace( - space: Long, - droppedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = { - ensureFreeSpace(None, space, droppedBlocks) - } - - /** - * Try to free up a given amount of space to store a block by evicting existing ones. - * - * @param space the amount of memory to free, in bytes - * @param droppedBlocks a holder for blocks evicted in the process - * @return whether the requested free space is freed. - */ - private[spark] def ensureFreeSpace( - blockId: BlockId, - space: Long, - droppedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = { - ensureFreeSpace(Some(blockId), space, droppedBlocks) - } - - /** - * Try to free up a given amount of space to store a particular block, but can fail if - * either the block is bigger than our memory or it would require replacing another block - * from the same RDD (which leads to a wasteful cyclic replacement pattern for RDDs that - * don't fit into memory that we want to avoid). - * - * @param blockId the ID of the block we are freeing space for, if any - * @param space the size of this block - * @param droppedBlocks a holder for blocks evicted in the process - * @return whether the requested free space is freed. - */ - private def ensureFreeSpace( + * Try to evict blocks to free up a given amount of space to store a particular block. + * Can fail if either the block is bigger than our memory or it would require replacing + * another block from the same RDD (which leads to a wasteful cyclic replacement pattern for + * RDDs that don't fit into memory that we want to avoid). + * + * @param blockId the ID of the block we are freeing space for, if any + * @param space the size of this block + * @param droppedBlocks a holder for blocks evicted in the process + * @return whether the requested free space is freed. + */ + private[spark] def evictBlocksToFreeSpace( blockId: Option[BlockId], space: Long, droppedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = { + assert(space > 0) memoryManager.synchronized { - val freeMemory = maxMemory - memoryUsed + var freedMemory = 0L val rddToAdd = blockId.flatMap(getRddId) val selectedBlocks = new ArrayBuffer[BlockId] - var selectedMemory = 0L - - logInfo(s"Ensuring $space bytes of free space " + - blockId.map { id => s"for block $id" }.getOrElse("") + - s"(free: $freeMemory, max: $maxMemory)") - - // Fail fast if the block simply won't fit - if (space > maxMemory) { - logInfo("Will not " + blockId.map { id => s"store $id" }.getOrElse("free memory") + - s" as the required space ($space bytes) exceeds our memory limit ($maxMemory bytes)") - return false - } - - // No need to evict anything if there is already enough free space - if (freeMemory >= space) { - return true - } - // This is synchronized to ensure that the set of entries is not changed // (because of getValue or getBytes) while traversing the iterator, as that // can lead to exceptions. 
entries.synchronized { val iterator = entries.entrySet().iterator() - while (freeMemory + selectedMemory < space && iterator.hasNext) { + while (freedMemory < space && iterator.hasNext) { val pair = iterator.next() val blockId = pair.getKey if (rddToAdd.isEmpty || rddToAdd != getRddId(blockId)) { selectedBlocks += blockId - selectedMemory += pair.getValue.size + freedMemory += pair.getValue.size } } } - if (freeMemory + selectedMemory >= space) { + if (freedMemory >= space) { logInfo(s"${selectedBlocks.size} blocks selected for dropping") for (blockId <- selectedBlocks) { val entry = entries.synchronized { entries.get(blockId) } diff --git a/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala b/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala index f55d435fa33a..555b640cb424 100644 --- a/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala @@ -24,9 +24,10 @@ import scala.concurrent.duration.Duration import scala.concurrent.{Await, ExecutionContext, Future} import org.mockito.Matchers.{any, anyLong} -import org.mockito.Mockito.{mock, when} +import org.mockito.Mockito.{mock, when, RETURNS_SMART_NULLS} import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer +import org.scalatest.BeforeAndAfterEach import org.scalatest.time.SpanSugar._ import org.apache.spark.SparkFunSuite @@ -36,105 +37,105 @@ import org.apache.spark.storage.{BlockId, BlockStatus, MemoryStore, StorageLevel /** * Helper trait for sharing code among [[MemoryManager]] tests. */ -private[memory] trait MemoryManagerSuite extends SparkFunSuite { +private[memory] trait MemoryManagerSuite extends SparkFunSuite with BeforeAndAfterEach { - import MemoryManagerSuite.DEFAULT_ENSURE_FREE_SPACE_CALLED + protected val evictedBlocks = new mutable.ArrayBuffer[(BlockId, BlockStatus)] + + import MemoryManagerSuite.DEFAULT_EVICT_BLOCKS_TO_FREE_SPACE_CALLED // Note: Mockito's verify mechanism does not provide a way to reset method call counts // without also resetting stubbed methods. Since our test code relies on the latter, - // we need to use our own variable to track invocations of `ensureFreeSpace`. + // we need to use our own variable to track invocations of `evictBlocksToFreeSpace`. /** - * The amount of free space requested in the last call to [[MemoryStore.ensureFreeSpace]] + * The amount of space requested in the last call to [[MemoryStore.evictBlocksToFreeSpace]]. * - * This set whenever [[MemoryStore.ensureFreeSpace]] is called, and cleared when the test - * code makes explicit assertions on this variable through [[assertEnsureFreeSpaceCalled]]. + * This set whenever [[MemoryStore.evictBlocksToFreeSpace]] is called, and cleared when the test + * code makes explicit assertions on this variable through + * [[assertEvictBlocksToFreeSpaceCalled]]. */ - private val ensureFreeSpaceCalled = new AtomicLong(DEFAULT_ENSURE_FREE_SPACE_CALLED) + private val evictBlocksToFreeSpaceCalled = new AtomicLong(0) + + override def beforeEach(): Unit = { + super.beforeEach() + evictedBlocks.clear() + evictBlocksToFreeSpaceCalled.set(DEFAULT_EVICT_BLOCKS_TO_FREE_SPACE_CALLED) + } /** - * Make a mocked [[MemoryStore]] whose [[MemoryStore.ensureFreeSpace]] method is stubbed. + * Make a mocked [[MemoryStore]] whose [[MemoryStore.evictBlocksToFreeSpace]] method is stubbed. 
* - * This allows our test code to release storage memory when [[MemoryStore.ensureFreeSpace]] - * is called without relying on [[org.apache.spark.storage.BlockManager]] and all of its - * dependencies. + * This allows our test code to release storage memory when these methods are called + * without relying on [[org.apache.spark.storage.BlockManager]] and all of its dependencies. */ protected def makeMemoryStore(mm: MemoryManager): MemoryStore = { - val ms = mock(classOf[MemoryStore]) - when(ms.ensureFreeSpace(anyLong(), any())).thenAnswer(ensureFreeSpaceAnswer(mm, 0)) - when(ms.ensureFreeSpace(any(), anyLong(), any())).thenAnswer(ensureFreeSpaceAnswer(mm, 1)) + val ms = mock(classOf[MemoryStore], RETURNS_SMART_NULLS) + when(ms.evictBlocksToFreeSpace(any(), anyLong(), any())) + .thenAnswer(evictBlocksToFreeSpaceAnswer(mm)) mm.setMemoryStore(ms) ms } /** - * Make an [[Answer]] that stubs [[MemoryStore.ensureFreeSpace]] with the right arguments. - */ - private def ensureFreeSpaceAnswer(mm: MemoryManager, numBytesPos: Int): Answer[Boolean] = { + * Simulate the part of [[MemoryStore.evictBlocksToFreeSpace]] that releases storage memory. + * + * This is a significant simplification of the real method, which actually drops existing + * blocks based on the size of each block. Instead, here we simply release as many bytes + * as needed to ensure the requested amount of free space. This allows us to set up the + * test without relying on the [[org.apache.spark.storage.BlockManager]], which brings in + * many other dependencies. + * + * Every call to this method will set a global variable, [[evictBlocksToFreeSpaceCalled]], that + * records the number of bytes this is called with. This variable is expected to be cleared + * by the test code later through [[assertEvictBlocksToFreeSpaceCalled]]. + */ + private def evictBlocksToFreeSpaceAnswer(mm: MemoryManager): Answer[Boolean] = { new Answer[Boolean] { override def answer(invocation: InvocationOnMock): Boolean = { val args = invocation.getArguments - require(args.size > numBytesPos, s"bad test: expected >$numBytesPos arguments " + - s"in ensureFreeSpace, found ${args.size}") - require(args(numBytesPos).isInstanceOf[Long], s"bad test: expected ensureFreeSpace " + - s"argument at index $numBytesPos to be a Long: ${args.mkString(", ")}") - val numBytes = args(numBytesPos).asInstanceOf[Long] - val success = mockEnsureFreeSpace(mm, numBytes) - if (success) { + val numBytesToFree = args(1).asInstanceOf[Long] + assert(numBytesToFree > 0) + require(evictBlocksToFreeSpaceCalled.get() === DEFAULT_EVICT_BLOCKS_TO_FREE_SPACE_CALLED, + "bad test: evictBlocksToFreeSpace() variable was not reset") + evictBlocksToFreeSpaceCalled.set(numBytesToFree) + if (numBytesToFree <= mm.storageMemoryUsed) { + // We can evict enough blocks to fulfill the request for space + mm.releaseStorageMemory(numBytesToFree) args.last.asInstanceOf[mutable.Buffer[(BlockId, BlockStatus)]].append( - (null, BlockStatus(StorageLevel.MEMORY_ONLY, numBytes, 0L, 0L))) + (null, BlockStatus(StorageLevel.MEMORY_ONLY, numBytesToFree, 0L, 0L))) + // We need to add this call so that that the suite-level `evictedBlocks` is updated when + // execution evicts storage; in that case, args.last will not be equal to evictedBlocks + // because it will be a temporary buffer created inside of the MemoryManager rather than + // being passed in by the test code. 
+ if (!(evictedBlocks eq args.last)) { + evictedBlocks.append( + (null, BlockStatus(StorageLevel.MEMORY_ONLY, numBytesToFree, 0L, 0L))) + } + true + } else { + // No blocks were evicted because eviction would not free enough space. + false } - success - } - } - } - - /** - * Simulate the part of [[MemoryStore.ensureFreeSpace]] that releases storage memory. - * - * This is a significant simplification of the real method, which actually drops existing - * blocks based on the size of each block. Instead, here we simply release as many bytes - * as needed to ensure the requested amount of free space. This allows us to set up the - * test without relying on the [[org.apache.spark.storage.BlockManager]], which brings in - * many other dependencies. - * - * Every call to this method will set a global variable, [[ensureFreeSpaceCalled]], that - * records the number of bytes this is called with. This variable is expected to be cleared - * by the test code later through [[assertEnsureFreeSpaceCalled]]. - */ - private def mockEnsureFreeSpace(mm: MemoryManager, numBytes: Long): Boolean = mm.synchronized { - require(ensureFreeSpaceCalled.get() === DEFAULT_ENSURE_FREE_SPACE_CALLED, - "bad test: ensure free space variable was not reset") - // Record the number of bytes we freed this call - ensureFreeSpaceCalled.set(numBytes) - if (numBytes <= mm.maxStorageMemory) { - def freeMemory = mm.maxStorageMemory - mm.storageMemoryUsed - val spaceToRelease = numBytes - freeMemory - if (spaceToRelease > 0) { - mm.releaseStorageMemory(spaceToRelease) } - freeMemory >= numBytes - } else { - // We attempted to free more bytes than our max allowable memory - false } } /** - * Assert that [[MemoryStore.ensureFreeSpace]] is called with the given parameters. + * Assert that [[MemoryStore.evictBlocksToFreeSpace]] is called with the given parameters. */ - protected def assertEnsureFreeSpaceCalled(ms: MemoryStore, numBytes: Long): Unit = { - assert(ensureFreeSpaceCalled.get() === numBytes, - s"expected ensure free space to be called with $numBytes") - ensureFreeSpaceCalled.set(DEFAULT_ENSURE_FREE_SPACE_CALLED) + protected def assertEvictBlocksToFreeSpaceCalled(ms: MemoryStore, numBytes: Long): Unit = { + assert(evictBlocksToFreeSpaceCalled.get() === numBytes, + s"expected evictBlocksToFreeSpace() to be called with $numBytes") + evictBlocksToFreeSpaceCalled.set(DEFAULT_EVICT_BLOCKS_TO_FREE_SPACE_CALLED) } /** - * Assert that [[MemoryStore.ensureFreeSpace]] is NOT called. + * Assert that [[MemoryStore.evictBlocksToFreeSpace]] is NOT called. 
*/ - protected def assertEnsureFreeSpaceNotCalled[T](ms: MemoryStore): Unit = { - assert(ensureFreeSpaceCalled.get() === DEFAULT_ENSURE_FREE_SPACE_CALLED, - "ensure free space should not have been called!") + protected def assertEvictBlocksToFreeSpaceNotCalled[T](ms: MemoryStore): Unit = { + assert(evictBlocksToFreeSpaceCalled.get() === DEFAULT_EVICT_BLOCKS_TO_FREE_SPACE_CALLED, + "evictBlocksToFreeSpace() should not have been called!") + assert(evictedBlocks.isEmpty) } /** @@ -291,5 +292,5 @@ private[memory] trait MemoryManagerSuite extends SparkFunSuite { } private object MemoryManagerSuite { - private val DEFAULT_ENSURE_FREE_SPACE_CALLED = -1L + private val DEFAULT_EVICT_BLOCKS_TO_FREE_SPACE_CALLED = -1L } diff --git a/core/src/test/scala/org/apache/spark/memory/StaticMemoryManagerSuite.scala b/core/src/test/scala/org/apache/spark/memory/StaticMemoryManagerSuite.scala index 54cb28c389c2..6700b94f0f57 100644 --- a/core/src/test/scala/org/apache/spark/memory/StaticMemoryManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/memory/StaticMemoryManagerSuite.scala @@ -17,16 +17,13 @@ package org.apache.spark.memory -import scala.collection.mutable.ArrayBuffer - import org.mockito.Mockito.when import org.apache.spark.SparkConf -import org.apache.spark.storage.{BlockId, BlockStatus, MemoryStore, TestBlockId} +import org.apache.spark.storage.{MemoryStore, TestBlockId} class StaticMemoryManagerSuite extends MemoryManagerSuite { private val conf = new SparkConf().set("spark.storage.unrollFraction", "0.4") - private val evictedBlocks = new ArrayBuffer[(BlockId, BlockStatus)] /** * Make a [[StaticMemoryManager]] and a [[MemoryStore]] with limited class dependencies. @@ -85,33 +82,38 @@ class StaticMemoryManagerSuite extends MemoryManagerSuite { val (mm, ms) = makeThings(Long.MaxValue, maxStorageMem) assert(mm.storageMemoryUsed === 0L) assert(mm.acquireStorageMemory(dummyBlock, 10L, evictedBlocks)) - // `ensureFreeSpace` should be called with the number of bytes requested - assertEnsureFreeSpaceCalled(ms, 10L) + assertEvictBlocksToFreeSpaceNotCalled(ms) assert(mm.storageMemoryUsed === 10L) + assert(mm.acquireStorageMemory(dummyBlock, 100L, evictedBlocks)) - assertEnsureFreeSpaceCalled(ms, 100L) + assertEvictBlocksToFreeSpaceNotCalled(ms) assert(mm.storageMemoryUsed === 110L) // Acquire more than the max, not granted assert(!mm.acquireStorageMemory(dummyBlock, maxStorageMem + 1L, evictedBlocks)) - assertEnsureFreeSpaceCalled(ms, maxStorageMem + 1L) + assertEvictBlocksToFreeSpaceNotCalled(ms) assert(mm.storageMemoryUsed === 110L) // Acquire up to the max, requests after this are still granted due to LRU eviction assert(mm.acquireStorageMemory(dummyBlock, maxStorageMem, evictedBlocks)) - assertEnsureFreeSpaceCalled(ms, 1000L) + assertEvictBlocksToFreeSpaceCalled(ms, 110L) assert(mm.storageMemoryUsed === 1000L) assert(mm.acquireStorageMemory(dummyBlock, 1L, evictedBlocks)) - assertEnsureFreeSpaceCalled(ms, 1L) + assertEvictBlocksToFreeSpaceCalled(ms, 1L) + assert(evictedBlocks.nonEmpty) + evictedBlocks.clear() + // Note: We evicted 1 byte to put another 1-byte block in, so the storage memory used remains at + // 1000 bytes. This is different from real behavior, where the 1-byte block would have evicted + // the 1000-byte block entirely. This is set up differently so we can write finer-grained tests. 
assert(mm.storageMemoryUsed === 1000L) mm.releaseStorageMemory(800L) assert(mm.storageMemoryUsed === 200L) // Acquire after release assert(mm.acquireStorageMemory(dummyBlock, 1L, evictedBlocks)) - assertEnsureFreeSpaceCalled(ms, 1L) + assertEvictBlocksToFreeSpaceNotCalled(ms) assert(mm.storageMemoryUsed === 201L) mm.releaseAllStorageMemory() assert(mm.storageMemoryUsed === 0L) assert(mm.acquireStorageMemory(dummyBlock, 1L, evictedBlocks)) - assertEnsureFreeSpaceCalled(ms, 1L) + assertEvictBlocksToFreeSpaceNotCalled(ms) assert(mm.storageMemoryUsed === 1L) // Release beyond what was acquired mm.releaseStorageMemory(100L) @@ -133,7 +135,7 @@ class StaticMemoryManagerSuite extends MemoryManagerSuite { assert(mm.executionMemoryUsed === 200L) // Only storage memory should increase assert(mm.acquireStorageMemory(dummyBlock, 50L, evictedBlocks)) - assertEnsureFreeSpaceCalled(ms, 50L) + assertEvictBlocksToFreeSpaceNotCalled(ms) assert(mm.storageMemoryUsed === 50L) assert(mm.executionMemoryUsed === 200L) // Only execution memory should be released @@ -151,21 +153,25 @@ class StaticMemoryManagerSuite extends MemoryManagerSuite { val dummyBlock = TestBlockId("lonely water") val (mm, ms) = makeThings(Long.MaxValue, maxStorageMem) assert(mm.acquireUnrollMemory(dummyBlock, 100L, evictedBlocks)) - assertEnsureFreeSpaceCalled(ms, 100L) + when(ms.currentUnrollMemory).thenReturn(100L) + assertEvictBlocksToFreeSpaceNotCalled(ms) assert(mm.storageMemoryUsed === 100L) mm.releaseUnrollMemory(40L) assert(mm.storageMemoryUsed === 60L) when(ms.currentUnrollMemory).thenReturn(60L) - assert(mm.acquireUnrollMemory(dummyBlock, 500L, evictedBlocks)) + assert(mm.acquireStorageMemory(dummyBlock, 800L, evictedBlocks)) + assertEvictBlocksToFreeSpaceNotCalled(ms) + assert(mm.storageMemoryUsed === 860L) // `spark.storage.unrollFraction` is 0.4, so the max unroll space is 400 bytes. - // Since we already occupy 60 bytes, we will try to ensure only 400 - 60 = 340 bytes. - assertEnsureFreeSpaceCalled(ms, 340L) - assert(mm.storageMemoryUsed === 560L) - when(ms.currentUnrollMemory).thenReturn(560L) - assert(!mm.acquireUnrollMemory(dummyBlock, 800L, evictedBlocks)) - assert(mm.storageMemoryUsed === 560L) - // We already have 560 bytes > the max unroll space of 400 bytes, so no bytes are freed - assertEnsureFreeSpaceCalled(ms, 0L) + // Since we already occupy 60 bytes, we will try to evict only 400 - 60 = 340 bytes. 
+ assert(mm.acquireUnrollMemory(dummyBlock, 240L, evictedBlocks)) + assertEvictBlocksToFreeSpaceCalled(ms, 100L) + when(ms.currentUnrollMemory).thenReturn(300L) // 60 + 240 + assert(mm.storageMemoryUsed === 1000L) + evictedBlocks.clear() + assert(!mm.acquireUnrollMemory(dummyBlock, 150L, evictedBlocks)) + assertEvictBlocksToFreeSpaceCalled(ms, 100L) // 400 - 300 + assert(mm.storageMemoryUsed === 900L) // 100 bytes were evicted // Release beyond what was acquired mm.releaseUnrollMemory(maxStorageMem) assert(mm.storageMemoryUsed === 0L) diff --git a/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala b/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala index e97c898a4478..71221deeb4c2 100644 --- a/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala @@ -17,16 +17,13 @@ package org.apache.spark.memory -import scala.collection.mutable.ArrayBuffer - import org.scalatest.PrivateMethodTester import org.apache.spark.SparkConf -import org.apache.spark.storage.{BlockId, BlockStatus, MemoryStore, TestBlockId} +import org.apache.spark.storage.{MemoryStore, TestBlockId} class UnifiedMemoryManagerSuite extends MemoryManagerSuite with PrivateMethodTester { private val dummyBlock = TestBlockId("--") - private val evictedBlocks = new ArrayBuffer[(BlockId, BlockStatus)] private val storageFraction: Double = 0.5 @@ -78,33 +75,40 @@ class UnifiedMemoryManagerSuite extends MemoryManagerSuite with PrivateMethodTes val (mm, ms) = makeThings(maxMemory) assert(mm.storageMemoryUsed === 0L) assert(mm.acquireStorageMemory(dummyBlock, 10L, evictedBlocks)) - // `ensureFreeSpace` should be called with the number of bytes requested - assertEnsureFreeSpaceCalled(ms, 10L) + assertEvictBlocksToFreeSpaceNotCalled(ms) assert(mm.storageMemoryUsed === 10L) + assert(mm.acquireStorageMemory(dummyBlock, 100L, evictedBlocks)) - assertEnsureFreeSpaceCalled(ms, 100L) + assertEvictBlocksToFreeSpaceNotCalled(ms) assert(mm.storageMemoryUsed === 110L) // Acquire more than the max, not granted assert(!mm.acquireStorageMemory(dummyBlock, maxMemory + 1L, evictedBlocks)) - assertEnsureFreeSpaceCalled(ms, maxMemory + 1L) + assertEvictBlocksToFreeSpaceNotCalled(ms) assert(mm.storageMemoryUsed === 110L) // Acquire up to the max, requests after this are still granted due to LRU eviction assert(mm.acquireStorageMemory(dummyBlock, maxMemory, evictedBlocks)) - assertEnsureFreeSpaceCalled(ms, 1000L) + assertEvictBlocksToFreeSpaceCalled(ms, 110L) assert(mm.storageMemoryUsed === 1000L) + assert(evictedBlocks.nonEmpty) + evictedBlocks.clear() assert(mm.acquireStorageMemory(dummyBlock, 1L, evictedBlocks)) - assertEnsureFreeSpaceCalled(ms, 1L) + assertEvictBlocksToFreeSpaceCalled(ms, 1L) + assert(evictedBlocks.nonEmpty) + evictedBlocks.clear() + // Note: We evicted 1 byte to put another 1-byte block in, so the storage memory used remains at + // 1000 bytes. This is different from real behavior, where the 1-byte block would have evicted + // the 1000-byte block entirely. This is set up differently so we can write finer-grained tests. 
assert(mm.storageMemoryUsed === 1000L) mm.releaseStorageMemory(800L) assert(mm.storageMemoryUsed === 200L) // Acquire after release assert(mm.acquireStorageMemory(dummyBlock, 1L, evictedBlocks)) - assertEnsureFreeSpaceCalled(ms, 1L) + assertEvictBlocksToFreeSpaceNotCalled(ms) assert(mm.storageMemoryUsed === 201L) mm.releaseAllStorageMemory() assert(mm.storageMemoryUsed === 0L) assert(mm.acquireStorageMemory(dummyBlock, 1L, evictedBlocks)) - assertEnsureFreeSpaceCalled(ms, 1L) + assertEvictBlocksToFreeSpaceNotCalled(ms) assert(mm.storageMemoryUsed === 1L) // Release beyond what was acquired mm.releaseStorageMemory(100L) @@ -117,25 +121,27 @@ class UnifiedMemoryManagerSuite extends MemoryManagerSuite with PrivateMethodTes val (mm, ms) = makeThings(maxMemory) // Acquire enough storage memory to exceed the storage region assert(mm.acquireStorageMemory(dummyBlock, 750L, evictedBlocks)) - assertEnsureFreeSpaceCalled(ms, 750L) + assertEvictBlocksToFreeSpaceNotCalled(ms) assert(mm.executionMemoryUsed === 0L) assert(mm.storageMemoryUsed === 750L) // Execution needs to request 250 bytes to evict storage memory assert(mm.acquireExecutionMemory(100L, taskAttemptId, MemoryMode.ON_HEAP) === 100L) assert(mm.executionMemoryUsed === 100L) assert(mm.storageMemoryUsed === 750L) - assertEnsureFreeSpaceNotCalled(ms) + assertEvictBlocksToFreeSpaceNotCalled(ms) // Execution wants 200 bytes but only 150 are free, so storage is evicted assert(mm.acquireExecutionMemory(200L, taskAttemptId, MemoryMode.ON_HEAP) === 200L) assert(mm.executionMemoryUsed === 300L) - assertEnsureFreeSpaceCalled(ms, 50L) - assert(mm.executionMemoryUsed === 300L) + assert(mm.storageMemoryUsed === 700L) + assertEvictBlocksToFreeSpaceCalled(ms, 50L) + assert(evictedBlocks.nonEmpty) + evictedBlocks.clear() mm.releaseAllStorageMemory() require(mm.executionMemoryUsed === 300L) require(mm.storageMemoryUsed === 0, "bad test: all storage memory should have been released") // Acquire some storage memory again, but this time keep it within the storage region assert(mm.acquireStorageMemory(dummyBlock, 400L, evictedBlocks)) - assertEnsureFreeSpaceCalled(ms, 400L) + assertEvictBlocksToFreeSpaceNotCalled(ms) assert(mm.storageMemoryUsed === 400L) assert(mm.executionMemoryUsed === 300L) // Execution cannot evict storage because the latter is within the storage fraction, @@ -143,7 +149,27 @@ class UnifiedMemoryManagerSuite extends MemoryManagerSuite with PrivateMethodTes assert(mm.acquireExecutionMemory(400L, taskAttemptId, MemoryMode.ON_HEAP) === 300L) assert(mm.executionMemoryUsed === 600L) assert(mm.storageMemoryUsed === 400L) - assertEnsureFreeSpaceNotCalled(ms) + assertEvictBlocksToFreeSpaceNotCalled(ms) + } + + test("execution memory requests smaller than free memory should evict storage (SPARK-12165)") { + val maxMemory = 1000L + val taskAttemptId = 0L + val (mm, ms) = makeThings(maxMemory) + // Acquire enough storage memory to exceed the storage region size + assert(mm.acquireStorageMemory(dummyBlock, 700L, evictedBlocks)) + assertEvictBlocksToFreeSpaceNotCalled(ms) + assert(mm.executionMemoryUsed === 0L) + assert(mm.storageMemoryUsed === 700L) + // SPARK-12165: previously, MemoryStore would not evict anything because it would + // mistakenly think that the 300 bytes of free space was still available even after + // using it to expand the execution pool. Consequently, no storage memory was released + // and the following call granted only 300 bytes to execution. 
+ assert(mm.acquireExecutionMemory(500L, taskAttemptId, MemoryMode.ON_HEAP) === 500L) + assertEvictBlocksToFreeSpaceCalled(ms, 200L) + assert(mm.storageMemoryUsed === 500L) + assert(mm.executionMemoryUsed === 500L) + assert(evictedBlocks.nonEmpty) } test("storage does not evict execution") { @@ -154,32 +180,34 @@ class UnifiedMemoryManagerSuite extends MemoryManagerSuite with PrivateMethodTes assert(mm.acquireExecutionMemory(800L, taskAttemptId, MemoryMode.ON_HEAP) === 800L) assert(mm.executionMemoryUsed === 800L) assert(mm.storageMemoryUsed === 0L) - assertEnsureFreeSpaceNotCalled(ms) + assertEvictBlocksToFreeSpaceNotCalled(ms) // Storage should not be able to evict execution assert(mm.acquireStorageMemory(dummyBlock, 100L, evictedBlocks)) assert(mm.executionMemoryUsed === 800L) assert(mm.storageMemoryUsed === 100L) - assertEnsureFreeSpaceCalled(ms, 100L) + assertEvictBlocksToFreeSpaceNotCalled(ms) assert(!mm.acquireStorageMemory(dummyBlock, 250L, evictedBlocks)) assert(mm.executionMemoryUsed === 800L) assert(mm.storageMemoryUsed === 100L) - assertEnsureFreeSpaceCalled(ms, 250L) + // Do not attempt to evict blocks, since evicting will not free enough memory: + assertEvictBlocksToFreeSpaceNotCalled(ms) mm.releaseExecutionMemory(maxMemory, taskAttemptId, MemoryMode.ON_HEAP) mm.releaseStorageMemory(maxMemory) // Acquire some execution memory again, but this time keep it within the execution region assert(mm.acquireExecutionMemory(200L, taskAttemptId, MemoryMode.ON_HEAP) === 200L) assert(mm.executionMemoryUsed === 200L) assert(mm.storageMemoryUsed === 0L) - assertEnsureFreeSpaceNotCalled(ms) + assertEvictBlocksToFreeSpaceNotCalled(ms) // Storage should still not be able to evict execution assert(mm.acquireStorageMemory(dummyBlock, 750L, evictedBlocks)) assert(mm.executionMemoryUsed === 200L) assert(mm.storageMemoryUsed === 750L) - assertEnsureFreeSpaceCalled(ms, 750L) + assertEvictBlocksToFreeSpaceNotCalled(ms) // since there were 800 bytes free assert(!mm.acquireStorageMemory(dummyBlock, 850L, evictedBlocks)) assert(mm.executionMemoryUsed === 200L) assert(mm.storageMemoryUsed === 750L) - assertEnsureFreeSpaceCalled(ms, 850L) + // Do not attempt to evict blocks, since evicting will not free enough memory: + assertEvictBlocksToFreeSpaceNotCalled(ms) } test("small heap") { From 1eb7c22ce72a1b82ed194a51bbcf0da9c771605a Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Wed, 9 Dec 2015 19:47:38 +0000 Subject: [PATCH 1718/1824] [SPARK-11824][WEBUI] WebUI does not render descriptions with 'bad' HTML, throws console error Don't warn when description isn't valid HTML since it may properly be like "SELECT ... where foo <= 1" The tests for this code indicate that it's normal to handle strings like this that don't contain HTML as a string rather than markup. Hence logging every such instance as a warning is too noisy since it's not a problem. this is an issue for stages whose name contain SQL like the above CC tdas as author of this bit of code Author: Sean Owen Closes #10159 from srowen/SPARK-11824. 
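A note on the behaviour this commit message describes: the idea is to attempt to render the description as markup and, when parsing fails (as it will for plain strings such as "SELECT ... where foo <= 1"), fall back to plain text silently instead of logging a warning for every non-HTML description. The sketch below only illustrates that pattern and is not the actual code in UIUtils.scala; the helper name and the wrapper element are assumptions.

import scala.util.control.NonFatal
import scala.xml.{Node, Text, XML}

object DescriptionRendering {
  // Hypothetical helper: try to parse the description as XML so markup is
  // rendered; fall back to a plain Text node when parsing fails. No warning
  // is logged on failure, since non-HTML descriptions are normal and valid.
  def renderDescription(desc: String): Seq[Node] = {
    try {
      XML.loadString(s"<span>$desc</span>")
    } catch {
      case NonFatal(_) => Text(desc) // e.g. "SELECT ... where foo <= 1"
    }
  }
}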
--- core/src/main/scala/org/apache/spark/ui/UIUtils.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala index 1e8194f57888..81a6f07ec836 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala @@ -448,7 +448,6 @@ private[spark] object UIUtils extends Logging { new RuleTransformer(rule).transform(xml) } catch { case NonFatal(e) => - logWarning(s"Invalid job description: $desc ", e) {desc} } } From 051c6a066f7b5fcc7472412144c15b50a5319bd5 Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Wed, 9 Dec 2015 12:00:48 -0800 Subject: [PATCH 1719/1824] [SPARK-11551][DOC] Replace example code in ml-features.md using include_example PR on behalf of somideshmukh, thanks! Author: Xusen Yin Author: somideshmukh Closes #10219 from yinxusen/SPARK-11551. --- docs/ml-features.md | 1112 +---------------- .../examples/ml/JavaBinarizerExample.java | 68 + .../examples/ml/JavaBucketizerExample.java | 71 ++ .../spark/examples/ml/JavaDCTExample.java | 65 + .../ml/JavaElementwiseProductExample.java | 75 ++ .../examples/ml/JavaMinMaxScalerExample.java | 51 + .../spark/examples/ml/JavaNGramExample.java | 71 ++ .../examples/ml/JavaNormalizerExample.java | 54 + .../examples/ml/JavaOneHotEncoderExample.java | 78 ++ .../spark/examples/ml/JavaPCAExample.java | 71 ++ .../ml/JavaPolynomialExpansionExample.java | 71 ++ .../examples/ml/JavaRFormulaExample.java | 69 + .../ml/JavaStandardScalerExample.java | 54 + .../ml/JavaStopWordsRemoverExample.java | 65 + .../examples/ml/JavaStringIndexerExample.java | 66 + .../examples/ml/JavaTokenizerExample.java | 75 ++ .../ml/JavaVectorAssemblerExample.java | 67 + .../examples/ml/JavaVectorIndexerExample.java | 61 + .../examples/ml/JavaVectorSlicerExample.java | 73 ++ .../src/main/python/ml/binarizer_example.py | 43 + .../src/main/python/ml/bucketizer_example.py | 43 + .../python/ml/elementwise_product_example.py | 39 + examples/src/main/python/ml/n_gram_example.py | 42 + .../src/main/python/ml/normalizer_example.py | 43 + .../main/python/ml/onehot_encoder_example.py | 48 + examples/src/main/python/ml/pca_example.py | 42 + .../python/ml/polynomial_expansion_example.py | 43 + .../src/main/python/ml/rformula_example.py | 44 + .../main/python/ml/standard_scaler_example.py | 43 + .../python/ml/stopwords_remover_example.py | 40 + .../main/python/ml/string_indexer_example.py | 39 + .../src/main/python/ml/tokenizer_example.py | 44 + .../python/ml/vector_assembler_example.py | 42 + .../main/python/ml/vector_indexer_example.py | 40 + .../spark/examples/ml/BinarizerExample.scala | 48 + .../spark/examples/ml/BucketizerExample.scala | 52 + .../apache/spark/examples/ml/DCTExample.scala | 54 + .../ml/ElementWiseProductExample.scala | 52 + .../examples/ml/MinMaxScalerExample.scala | 50 + .../spark/examples/ml/NGramExample.scala | 47 + .../spark/examples/ml/NormalizerExample.scala | 52 + .../examples/ml/OneHotEncoderExample.scala | 58 + .../apache/spark/examples/ml/PCAExample.scala | 53 + .../ml/PolynomialExpansionExample.scala | 51 + .../spark/examples/ml/RFormulaExample.scala | 49 + .../examples/ml/StandardScalerExample.scala | 52 + .../examples/ml/StopWordsRemoverExample.scala | 48 + .../examples/ml/StringIndexerExample.scala | 48 + .../spark/examples/ml/TokenizerExample.scala | 54 + .../examples/ml/VectorAssemblerExample.scala | 49 + .../examples/ml/VectorIndexerExample.scala | 54 + 
.../examples/ml/VectorSlicerExample.scala | 58 + 52 files changed, 2820 insertions(+), 1061 deletions(-) create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaBinarizerExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaBucketizerExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaDCTExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaElementwiseProductExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaMinMaxScalerExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaNGramExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaNormalizerExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaOneHotEncoderExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaPCAExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaPolynomialExpansionExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaRFormulaExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaStandardScalerExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaStopWordsRemoverExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaStringIndexerExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaTokenizerExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaVectorAssemblerExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaVectorIndexerExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaVectorSlicerExample.java create mode 100644 examples/src/main/python/ml/binarizer_example.py create mode 100644 examples/src/main/python/ml/bucketizer_example.py create mode 100644 examples/src/main/python/ml/elementwise_product_example.py create mode 100644 examples/src/main/python/ml/n_gram_example.py create mode 100644 examples/src/main/python/ml/normalizer_example.py create mode 100644 examples/src/main/python/ml/onehot_encoder_example.py create mode 100644 examples/src/main/python/ml/pca_example.py create mode 100644 examples/src/main/python/ml/polynomial_expansion_example.py create mode 100644 examples/src/main/python/ml/rformula_example.py create mode 100644 examples/src/main/python/ml/standard_scaler_example.py create mode 100644 examples/src/main/python/ml/stopwords_remover_example.py create mode 100644 examples/src/main/python/ml/string_indexer_example.py create mode 100644 examples/src/main/python/ml/tokenizer_example.py create mode 100644 examples/src/main/python/ml/vector_assembler_example.py create mode 100644 examples/src/main/python/ml/vector_indexer_example.py create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/BinarizerExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/BucketizerExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/DCTExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/ElementWiseProductExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/MinMaxScalerExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/NGramExample.scala create mode 100644 
examples/src/main/scala/org/apache/spark/examples/ml/NormalizerExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/OneHotEncoderExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/PCAExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/PolynomialExpansionExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/RFormulaExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/StandardScalerExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/StopWordsRemoverExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/StringIndexerExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/TokenizerExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/VectorAssemblerExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/VectorIndexerExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/VectorSlicerExample.scala diff --git a/docs/ml-features.md b/docs/ml-features.md index 55e401221917..7ad7c4eb7ea6 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -170,25 +170,7 @@ Refer to the [Tokenizer Scala docs](api/scala/index.html#org.apache.spark.ml.fea and the [RegexTokenizer Scala docs](api/scala/index.html#org.apache.spark.ml.feature.Tokenizer) for more details on the API. -{% highlight scala %} -import org.apache.spark.ml.feature.{Tokenizer, RegexTokenizer} - -val sentenceDataFrame = sqlContext.createDataFrame(Seq( - (0, "Hi I heard about Spark"), - (1, "I wish Java could use case classes"), - (2, "Logistic,regression,models,are,neat") -)).toDF("label", "sentence") -val tokenizer = new Tokenizer().setInputCol("sentence").setOutputCol("words") -val regexTokenizer = new RegexTokenizer() - .setInputCol("sentence") - .setOutputCol("words") - .setPattern("\\W") // alternatively .setPattern("\\w+").setGaps(false) - -val tokenized = tokenizer.transform(sentenceDataFrame) -tokenized.select("words", "label").take(3).foreach(println) -val regexTokenized = regexTokenizer.transform(sentenceDataFrame) -regexTokenized.select("words", "label").take(3).foreach(println) -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/TokenizerExample.scala %}
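For readers following the doc change: {% include_example %} is a Jekyll tag in the Spark docs build that embeds the referenced example file into the page, so each snippet lives in one runnable file rather than being duplicated inline in the markdown. As a hedged sketch (the exact delimiter comments are an assumption about how these example files are typically marked up), the referenced TokenizerExample.scala would bracket the region to embed along these lines:

// Only the code between the example markers is pulled into ml-features.md;
// the surrounding object definition and SparkContext/SQLContext setup stay
// out of the rendered docs (sqlContext is created in that omitted setup).
// $example on$
import org.apache.spark.ml.feature.{RegexTokenizer, Tokenizer}

val sentenceDataFrame = sqlContext.createDataFrame(Seq(
  (0, "Hi I heard about Spark"),
  (1, "I wish Java could use case classes"),
  (2, "Logistic,regression,models,are,neat")
)).toDF("label", "sentence")

val tokenizer = new Tokenizer().setInputCol("sentence").setOutputCol("words")
tokenizer.transform(sentenceDataFrame).select("words", "label").take(3).foreach(println)
// $example off$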
    @@ -197,44 +179,7 @@ Refer to the [Tokenizer Java docs](api/java/org/apache/spark/ml/feature/Tokenize and the [RegexTokenizer Java docs](api/java/org/apache/spark/ml/feature/RegexTokenizer.html) for more details on the API. -{% highlight java %} -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.ml.feature.RegexTokenizer; -import org.apache.spark.ml.feature.Tokenizer; -import org.apache.spark.mllib.linalg.Vector; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.types.DataTypes; -import org.apache.spark.sql.types.Metadata; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; - -JavaRDD jrdd = jsc.parallelize(Arrays.asList( - RowFactory.create(0, "Hi I heard about Spark"), - RowFactory.create(1, "I wish Java could use case classes"), - RowFactory.create(2, "Logistic,regression,models,are,neat") -)); -StructType schema = new StructType(new StructField[]{ - new StructField("label", DataTypes.DoubleType, false, Metadata.empty()), - new StructField("sentence", DataTypes.StringType, false, Metadata.empty()) -}); -DataFrame sentenceDataFrame = sqlContext.createDataFrame(jrdd, schema); -Tokenizer tokenizer = new Tokenizer().setInputCol("sentence").setOutputCol("words"); -DataFrame wordsDataFrame = tokenizer.transform(sentenceDataFrame); -for (Row r : wordsDataFrame.select("words", "label").take(3)) { - java.util.List words = r.getList(0); - for (String word : words) System.out.print(word + " "); - System.out.println(); -} - -RegexTokenizer regexTokenizer = new RegexTokenizer() - .setInputCol("sentence") - .setOutputCol("words") - .setPattern("\\W"); // alternatively .setPattern("\\w+").setGaps(false); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaTokenizerExample.java %}
    @@ -243,21 +188,7 @@ Refer to the [Tokenizer Python docs](api/python/pyspark.ml.html#pyspark.ml.featu the the [RegexTokenizer Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.RegexTokenizer) for more details on the API. -{% highlight python %} -from pyspark.ml.feature import Tokenizer, RegexTokenizer - -sentenceDataFrame = sqlContext.createDataFrame([ - (0, "Hi I heard about Spark"), - (1, "I wish Java could use case classes"), - (2, "Logistic,regression,models,are,neat") -], ["label", "sentence"]) -tokenizer = Tokenizer(inputCol="sentence", outputCol="words") -wordsDataFrame = tokenizer.transform(sentenceDataFrame) -for words_label in wordsDataFrame.select("words", "label").take(3): - print(words_label) -regexTokenizer = RegexTokenizer(inputCol="sentence", outputCol="words", pattern="\\W") -# alternatively, pattern="\\w+", gaps(False) -{% endhighlight %} +{% include_example python/ml/tokenizer_example.py %}
    @@ -306,19 +237,7 @@ filtered out. Refer to the [StopWordsRemover Scala docs](api/scala/index.html#org.apache.spark.ml.feature.StopWordsRemover) for more details on the API. -{% highlight scala %} -import org.apache.spark.ml.feature.StopWordsRemover - -val remover = new StopWordsRemover() - .setInputCol("raw") - .setOutputCol("filtered") -val dataSet = sqlContext.createDataFrame(Seq( - (0, Seq("I", "saw", "the", "red", "baloon")), - (1, Seq("Mary", "had", "a", "little", "lamb")) -)).toDF("id", "raw") - -remover.transform(dataSet).show() -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/StopWordsRemoverExample.scala %}
    @@ -326,34 +245,7 @@ remover.transform(dataSet).show() Refer to the [StopWordsRemover Java docs](api/java/org/apache/spark/ml/feature/StopWordsRemover.html) for more details on the API. -{% highlight java %} -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.ml.feature.StopWordsRemover; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.types.DataTypes; -import org.apache.spark.sql.types.Metadata; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; - -StopWordsRemover remover = new StopWordsRemover() - .setInputCol("raw") - .setOutputCol("filtered"); - -JavaRDD rdd = jsc.parallelize(Arrays.asList( - RowFactory.create(Arrays.asList("I", "saw", "the", "red", "baloon")), - RowFactory.create(Arrays.asList("Mary", "had", "a", "little", "lamb")) -)); -StructType schema = new StructType(new StructField[] { - new StructField("raw", DataTypes.createArrayType(DataTypes.StringType), false, Metadata.empty()) -}); -DataFrame dataset = jsql.createDataFrame(rdd, schema); - -remover.transform(dataset).show(); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaStopWordsRemoverExample.java %}
    @@ -361,17 +253,7 @@ remover.transform(dataset).show(); Refer to the [StopWordsRemover Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.StopWordsRemover) for more details on the API. -{% highlight python %} -from pyspark.ml.feature import StopWordsRemover - -sentenceData = sqlContext.createDataFrame([ - (0, ["I", "saw", "the", "red", "baloon"]), - (1, ["Mary", "had", "a", "little", "lamb"]) -], ["label", "raw"]) - -remover = StopWordsRemover(inputCol="raw", outputCol="filtered") -remover.transform(sentenceData).show(truncate=False) -{% endhighlight %} +{% include_example python/ml/stopwords_remover_example.py %}
    @@ -388,19 +270,7 @@ An [n-gram](https://en.wikipedia.org/wiki/N-gram) is a sequence of $n$ tokens (t Refer to the [NGram Scala docs](api/scala/index.html#org.apache.spark.ml.feature.NGram) for more details on the API. -{% highlight scala %} -import org.apache.spark.ml.feature.NGram - -val wordDataFrame = sqlContext.createDataFrame(Seq( - (0, Array("Hi", "I", "heard", "about", "Spark")), - (1, Array("I", "wish", "Java", "could", "use", "case", "classes")), - (2, Array("Logistic", "regression", "models", "are", "neat")) -)).toDF("label", "words") - -val ngram = new NGram().setInputCol("words").setOutputCol("ngrams") -val ngramDataFrame = ngram.transform(wordDataFrame) -ngramDataFrame.take(3).map(_.getAs[Stream[String]]("ngrams").toList).foreach(println) -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/NGramExample.scala %}
    @@ -408,38 +278,7 @@ ngramDataFrame.take(3).map(_.getAs[Stream[String]]("ngrams").toList).foreach(pri Refer to the [NGram Java docs](api/java/org/apache/spark/ml/feature/NGram.html) for more details on the API. -{% highlight java %} -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.ml.feature.NGram; -import org.apache.spark.mllib.linalg.Vector; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.types.DataTypes; -import org.apache.spark.sql.types.Metadata; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; - -JavaRDD jrdd = jsc.parallelize(Arrays.asList( - RowFactory.create(0.0, Arrays.asList("Hi", "I", "heard", "about", "Spark")), - RowFactory.create(1.0, Arrays.asList("I", "wish", "Java", "could", "use", "case", "classes")), - RowFactory.create(2.0, Arrays.asList("Logistic", "regression", "models", "are", "neat")) -)); -StructType schema = new StructType(new StructField[]{ - new StructField("label", DataTypes.DoubleType, false, Metadata.empty()), - new StructField("words", DataTypes.createArrayType(DataTypes.StringType), false, Metadata.empty()) -}); -DataFrame wordDataFrame = sqlContext.createDataFrame(jrdd, schema); -NGram ngramTransformer = new NGram().setInputCol("words").setOutputCol("ngrams"); -DataFrame ngramDataFrame = ngramTransformer.transform(wordDataFrame); -for (Row r : ngramDataFrame.select("ngrams", "label").take(3)) { - java.util.List ngrams = r.getList(0); - for (String ngram : ngrams) System.out.print(ngram + " --- "); - System.out.println(); -} -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaNGramExample.java %}
    @@ -447,19 +286,7 @@ for (Row r : ngramDataFrame.select("ngrams", "label").take(3)) { Refer to the [NGram Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.NGram) for more details on the API. -{% highlight python %} -from pyspark.ml.feature import NGram - -wordDataFrame = sqlContext.createDataFrame([ - (0, ["Hi", "I", "heard", "about", "Spark"]), - (1, ["I", "wish", "Java", "could", "use", "case", "classes"]), - (2, ["Logistic", "regression", "models", "are", "neat"]) -], ["label", "words"]) -ngram = NGram(inputCol="words", outputCol="ngrams") -ngramDataFrame = ngram.transform(wordDataFrame) -for ngrams_label in ngramDataFrame.select("ngrams", "label").take(3): - print(ngrams_label) -{% endhighlight %} +{% include_example python/ml/n_gram_example.py %}
    @@ -476,26 +303,7 @@ Binarization is the process of thresholding numerical features to binary (0/1) f Refer to the [Binarizer Scala docs](api/scala/index.html#org.apache.spark.ml.feature.Binarizer) for more details on the API. -{% highlight scala %} -import org.apache.spark.ml.feature.Binarizer -import org.apache.spark.sql.DataFrame - -val data = Array( - (0, 0.1), - (1, 0.8), - (2, 0.2) -) -val dataFrame: DataFrame = sqlContext.createDataFrame(data).toDF("label", "feature") - -val binarizer: Binarizer = new Binarizer() - .setInputCol("feature") - .setOutputCol("binarized_feature") - .setThreshold(0.5) - -val binarizedDataFrame = binarizer.transform(dataFrame) -val binarizedFeatures = binarizedDataFrame.select("binarized_feature") -binarizedFeatures.collect().foreach(println) -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/BinarizerExample.scala %}
    @@ -503,40 +311,7 @@ binarizedFeatures.collect().foreach(println) Refer to the [Binarizer Java docs](api/java/org/apache/spark/ml/feature/Binarizer.html) for more details on the API. -{% highlight java %} -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.ml.feature.Binarizer; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.types.DataTypes; -import org.apache.spark.sql.types.Metadata; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; - -JavaRDD jrdd = jsc.parallelize(Arrays.asList( - RowFactory.create(0, 0.1), - RowFactory.create(1, 0.8), - RowFactory.create(2, 0.2) -)); -StructType schema = new StructType(new StructField[]{ - new StructField("label", DataTypes.DoubleType, false, Metadata.empty()), - new StructField("feature", DataTypes.DoubleType, false, Metadata.empty()) -}); -DataFrame continuousDataFrame = jsql.createDataFrame(jrdd, schema); -Binarizer binarizer = new Binarizer() - .setInputCol("feature") - .setOutputCol("binarized_feature") - .setThreshold(0.5); -DataFrame binarizedDataFrame = binarizer.transform(continuousDataFrame); -DataFrame binarizedFeatures = binarizedDataFrame.select("binarized_feature"); -for (Row r : binarizedFeatures.collect()) { - Double binarized_value = r.getDouble(0); - System.out.println(binarized_value); -} -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaBinarizerExample.java %}
    @@ -544,20 +319,7 @@ for (Row r : binarizedFeatures.collect()) { Refer to the [Binarizer Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.Binarizer) for more details on the API. -{% highlight python %} -from pyspark.ml.feature import Binarizer - -continuousDataFrame = sqlContext.createDataFrame([ - (0, 0.1), - (1, 0.8), - (2, 0.2) -], ["label", "feature"]) -binarizer = Binarizer(threshold=0.5, inputCol="feature", outputCol="binarized_feature") -binarizedDataFrame = binarizer.transform(continuousDataFrame) -binarizedFeatures = binarizedDataFrame.select("binarized_feature") -for binarized_feature, in binarizedFeatures.collect(): - print(binarized_feature) -{% endhighlight %} +{% include_example python/ml/binarizer_example.py %}
    @@ -571,25 +333,7 @@ for binarized_feature, in binarizedFeatures.collect(): Refer to the [PCA Scala docs](api/scala/index.html#org.apache.spark.ml.feature.PCA) for more details on the API. -{% highlight scala %} -import org.apache.spark.ml.feature.PCA -import org.apache.spark.mllib.linalg.Vectors - -val data = Array( - Vectors.sparse(5, Seq((1, 1.0), (3, 7.0))), - Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0), - Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0) -) -val df = sqlContext.createDataFrame(data.map(Tuple1.apply)).toDF("features") -val pca = new PCA() - .setInputCol("features") - .setOutputCol("pcaFeatures") - .setK(3) - .fit(df) -val pcaDF = pca.transform(df) -val result = pcaDF.select("pcaFeatures") -result.show() -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/PCAExample.scala %}
    @@ -597,42 +341,7 @@ result.show() Refer to the [PCA Java docs](api/java/org/apache/spark/ml/feature/PCA.html) for more details on the API. -{% highlight java %} -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.ml.feature.PCA -import org.apache.spark.ml.feature.PCAModel -import org.apache.spark.mllib.linalg.VectorUDT; -import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.SQLContext; -import org.apache.spark.sql.types.Metadata; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; - -JavaSparkContext jsc = ... -SQLContext jsql = ... -JavaRDD data = jsc.parallelize(Arrays.asList( - RowFactory.create(Vectors.sparse(5, new int[]{1, 3}, new double[]{1.0, 7.0})), - RowFactory.create(Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0)), - RowFactory.create(Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0)) -)); -StructType schema = new StructType(new StructField[] { - new StructField("features", new VectorUDT(), false, Metadata.empty()), -}); -DataFrame df = jsql.createDataFrame(data, schema); -PCAModel pca = new PCA() - .setInputCol("features") - .setOutputCol("pcaFeatures") - .setK(3) - .fit(df); -DataFrame result = pca.transform(df).select("pcaFeatures"); -result.show(); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaPCAExample.java %}
    @@ -640,19 +349,7 @@ result.show(); Refer to the [PCA Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.PCA) for more details on the API. -{% highlight python %} -from pyspark.ml.feature import PCA -from pyspark.mllib.linalg import Vectors - -data = [(Vectors.sparse(5, [(1, 1.0), (3, 7.0)]),), - (Vectors.dense([2.0, 0.0, 3.0, 4.0, 5.0]),), - (Vectors.dense([4.0, 0.0, 0.0, 6.0, 7.0]),)] -df = sqlContext.createDataFrame(data,["features"]) -pca = PCA(k=3, inputCol="features", outputCol="pcaFeatures") -model = pca.fit(df) -result = model.transform(df).select("pcaFeatures") -result.show(truncate=False) -{% endhighlight %} +{% include_example python/ml/pca_example.py %}
    @@ -666,23 +363,7 @@ result.show(truncate=False) Refer to the [PolynomialExpansion Scala docs](api/scala/index.html#org.apache.spark.ml.feature.PolynomialExpansion) for more details on the API. -{% highlight scala %} -import org.apache.spark.ml.feature.PolynomialExpansion -import org.apache.spark.mllib.linalg.Vectors - -val data = Array( - Vectors.dense(-2.0, 2.3), - Vectors.dense(0.0, 0.0), - Vectors.dense(0.6, -1.1) -) -val df = sqlContext.createDataFrame(data.map(Tuple1.apply)).toDF("features") -val polynomialExpansion = new PolynomialExpansion() - .setInputCol("features") - .setOutputCol("polyFeatures") - .setDegree(3) -val polyDF = polynomialExpansion.transform(df) -polyDF.select("polyFeatures").take(3).foreach(println) -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/PolynomialExpansionExample.scala %}
    @@ -690,43 +371,7 @@ polyDF.select("polyFeatures").take(3).foreach(println) Refer to the [PolynomialExpansion Java docs](api/java/org/apache/spark/ml/feature/PolynomialExpansion.html) for more details on the API. -{% highlight java %} -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.mllib.linalg.Vector; -import org.apache.spark.mllib.linalg.VectorUDT; -import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.SQLContext; -import org.apache.spark.sql.types.Metadata; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; - -JavaSparkContext jsc = ... -SQLContext jsql = ... -PolynomialExpansion polyExpansion = new PolynomialExpansion() - .setInputCol("features") - .setOutputCol("polyFeatures") - .setDegree(3); -JavaRDD data = jsc.parallelize(Arrays.asList( - RowFactory.create(Vectors.dense(-2.0, 2.3)), - RowFactory.create(Vectors.dense(0.0, 0.0)), - RowFactory.create(Vectors.dense(0.6, -1.1)) -)); -StructType schema = new StructType(new StructField[] { - new StructField("features", new VectorUDT(), false, Metadata.empty()), -}); -DataFrame df = jsql.createDataFrame(data, schema); -DataFrame polyDF = polyExpansion.transform(df); -Row[] row = polyDF.select("polyFeatures").take(3); -for (Row r : row) { - System.out.println(r.get(0)); -} -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaPolynomialExpansionExample.java %}
    @@ -734,20 +379,7 @@ for (Row r : row) { Refer to the [PolynomialExpansion Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.PolynomialExpansion) for more details on the API. -{% highlight python %} -from pyspark.ml.feature import PolynomialExpansion -from pyspark.mllib.linalg import Vectors - -df = sqlContext.createDataFrame( - [(Vectors.dense([-2.0, 2.3]), ), - (Vectors.dense([0.0, 0.0]), ), - (Vectors.dense([0.6, -1.1]), )], - ["features"]) -px = PolynomialExpansion(degree=2, inputCol="features", outputCol="polyFeatures") -polyDF = px.transform(df) -for expanded in polyDF.select("polyFeatures").take(3): - print(expanded) -{% endhighlight %} +{% include_example python/ml/polynomial_expansion_example.py %}
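A runnable sketch of the relocated polynomial-expansion snippet, kept here for readers without the examples tree; the bootstrap lines and app name are assumptions, while the transformer settings mirror the removed inline example (degree 2).

{% highlight python %}
from pyspark import SparkContext
from pyspark.sql import SQLContext
from pyspark.ml.feature import PolynomialExpansion
from pyspark.mllib.linalg import Vectors

sc = SparkContext(appName="PolynomialExpansionSketch")  # app name is illustrative
sqlContext = SQLContext(sc)

df = sqlContext.createDataFrame(
    [(Vectors.dense([-2.0, 2.3]),),
     (Vectors.dense([0.0, 0.0]),),
     (Vectors.dense([0.6, -1.1]),)],
    ["features"])

# Expand each 2-dimensional vector into the degree-2 polynomial feature space.
px = PolynomialExpansion(degree=2, inputCol="features", outputCol="polyFeatures")
for expanded in px.transform(df).select("polyFeatures").take(3):
    print(expanded)

sc.stop()
{% endhighlight %}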
    @@ -771,22 +403,7 @@ $0$th DCT coefficient and _not_ the $N/2$th). Refer to the [DCT Scala docs](api/scala/index.html#org.apache.spark.ml.feature.DCT) for more details on the API. -{% highlight scala %} -import org.apache.spark.ml.feature.DCT -import org.apache.spark.mllib.linalg.Vectors - -val data = Seq( - Vectors.dense(0.0, 1.0, -2.0, 3.0), - Vectors.dense(-1.0, 2.0, 4.0, -7.0), - Vectors.dense(14.0, -2.0, -5.0, 1.0)) -val df = sqlContext.createDataFrame(data.map(Tuple1.apply)).toDF("features") -val dct = new DCT() - .setInputCol("features") - .setOutputCol("featuresDCT") - .setInverse(false) -val dctDf = dct.transform(df) -dctDf.select("featuresDCT").show(3) -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/DCTExample.scala %}
@@ -794,39 +411,7 @@ dctDf.select("featuresDCT").show(3)

Refer to the [DCT Java docs](api/java/org/apache/spark/ml/feature/DCT.html) for more details on the API.

-{% highlight java %}
-import java.util.Arrays;
-
-import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.ml.feature.DCT;
-import org.apache.spark.mllib.linalg.Vector;
-import org.apache.spark.mllib.linalg.VectorUDT;
-import org.apache.spark.mllib.linalg.Vectors;
-import org.apache.spark.sql.DataFrame;
-import org.apache.spark.sql.Row;
-import org.apache.spark.sql.RowFactory;
-import org.apache.spark.sql.SQLContext;
-import org.apache.spark.sql.types.Metadata;
-import org.apache.spark.sql.types.StructField;
-import org.apache.spark.sql.types.StructType;
-
-JavaRDD<Row> data = jsc.parallelize(Arrays.asList(
-  RowFactory.create(Vectors.dense(0.0, 1.0, -2.0, 3.0)),
-  RowFactory.create(Vectors.dense(-1.0, 2.0, 4.0, -7.0)),
-  RowFactory.create(Vectors.dense(14.0, -2.0, -5.0, 1.0))
-));
-StructType schema = new StructType(new StructField[] {
-  new StructField("features", new VectorUDT(), false, Metadata.empty()),
-});
-DataFrame df = jsql.createDataFrame(data, schema);
-DCT dct = new DCT()
-  .setInputCol("features")
-  .setOutputCol("featuresDCT")
-  .setInverse(false);
-DataFrame dctDf = dct.transform(df);
-dctDf.select("featuresDCT").show(3);
-{% endhighlight %}
+{% include_example java/org/apache/spark/examples/ml/JavaDCTExample.java %}
    @@ -881,18 +466,7 @@ index `2`. Refer to the [StringIndexer Scala docs](api/scala/index.html#org.apache.spark.ml.feature.StringIndexer) for more details on the API. -{% highlight scala %} -import org.apache.spark.ml.feature.StringIndexer - -val df = sqlContext.createDataFrame( - Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")) -).toDF("id", "category") -val indexer = new StringIndexer() - .setInputCol("category") - .setOutputCol("categoryIndex") -val indexed = indexer.fit(df).transform(df) -indexed.show() -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/StringIndexerExample.scala %}
    @@ -900,37 +474,7 @@ indexed.show() Refer to the [StringIndexer Java docs](api/java/org/apache/spark/ml/feature/StringIndexer.html) for more details on the API. -{% highlight java %} -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.ml.feature.StringIndexer; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; -import static org.apache.spark.sql.types.DataTypes.*; - -JavaRDD jrdd = jsc.parallelize(Arrays.asList( - RowFactory.create(0, "a"), - RowFactory.create(1, "b"), - RowFactory.create(2, "c"), - RowFactory.create(3, "a"), - RowFactory.create(4, "a"), - RowFactory.create(5, "c") -)); -StructType schema = new StructType(new StructField[] { - createStructField("id", DoubleType, false), - createStructField("category", StringType, false) -}); -DataFrame df = sqlContext.createDataFrame(jrdd, schema); -StringIndexer indexer = new StringIndexer() - .setInputCol("category") - .setOutputCol("categoryIndex"); -DataFrame indexed = indexer.fit(df).transform(df); -indexed.show(); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaStringIndexerExample.java %}
    @@ -938,16 +482,7 @@ indexed.show(); Refer to the [StringIndexer Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.StringIndexer) for more details on the API. -{% highlight python %} -from pyspark.ml.feature import StringIndexer - -df = sqlContext.createDataFrame( - [(0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")], - ["id", "category"]) -indexer = StringIndexer(inputCol="category", outputCol="categoryIndex") -indexed = indexer.fit(df).transform(df) -indexed.show() -{% endhighlight %} +{% include_example python/ml/string_indexer_example.py %}
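For reference, a standalone sketch of the StringIndexer usage that the tag above now pulls in, based on the removed inline snippet; the context setup and app name are illustrative assumptions, not content of the new example file.

{% highlight python %}
from pyspark import SparkContext
from pyspark.sql import SQLContext
from pyspark.ml.feature import StringIndexer

sc = SparkContext(appName="StringIndexerSketch")  # app name is illustrative
sqlContext = SQLContext(sc)

df = sqlContext.createDataFrame(
    [(0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")],
    ["id", "category"])

# Fit the indexer on the string column and replace labels by frequency-ordered indices.
indexer = StringIndexer(inputCol="category", outputCol="categoryIndex")
indexer.fit(df).transform(df).show()

sc.stop()
{% endhighlight %}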
    @@ -1030,30 +565,7 @@ for more details on the API. Refer to the [OneHotEncoder Scala docs](api/scala/index.html#org.apache.spark.ml.feature.OneHotEncoder) for more details on the API. -{% highlight scala %} -import org.apache.spark.ml.feature.{OneHotEncoder, StringIndexer} - -val df = sqlContext.createDataFrame(Seq( - (0, "a"), - (1, "b"), - (2, "c"), - (3, "a"), - (4, "a"), - (5, "c") -)).toDF("id", "category") - -val indexer = new StringIndexer() - .setInputCol("category") - .setOutputCol("categoryIndex") - .fit(df) -val indexed = indexer.transform(df) - -val encoder = new OneHotEncoder() - .setInputCol("categoryIndex") - .setOutputCol("categoryVec") -val encoded = encoder.transform(indexed) -encoded.select("id", "categoryVec").show() -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/OneHotEncoderExample.scala %}
    @@ -1061,46 +573,7 @@ encoded.select("id", "categoryVec").show() Refer to the [OneHotEncoder Java docs](api/java/org/apache/spark/ml/feature/OneHotEncoder.html) for more details on the API. -{% highlight java %} -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.ml.feature.OneHotEncoder; -import org.apache.spark.ml.feature.StringIndexer; -import org.apache.spark.ml.feature.StringIndexerModel; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.types.DataTypes; -import org.apache.spark.sql.types.Metadata; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; - -JavaRDD jrdd = jsc.parallelize(Arrays.asList( - RowFactory.create(0, "a"), - RowFactory.create(1, "b"), - RowFactory.create(2, "c"), - RowFactory.create(3, "a"), - RowFactory.create(4, "a"), - RowFactory.create(5, "c") -)); -StructType schema = new StructType(new StructField[]{ - new StructField("id", DataTypes.IntegerType, false, Metadata.empty()), - new StructField("category", DataTypes.StringType, false, Metadata.empty()) -}); -DataFrame df = sqlContext.createDataFrame(jrdd, schema); -StringIndexerModel indexer = new StringIndexer() - .setInputCol("category") - .setOutputCol("categoryIndex") - .fit(df); -DataFrame indexed = indexer.transform(df); - -OneHotEncoder encoder = new OneHotEncoder() - .setInputCol("categoryIndex") - .setOutputCol("categoryVec"); -DataFrame encoded = encoder.transform(indexed); -encoded.select("id", "categoryVec").show(); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaOneHotEncoderExample.java %}
    @@ -1108,25 +581,7 @@ encoded.select("id", "categoryVec").show(); Refer to the [OneHotEncoder Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.OneHotEncoder) for more details on the API. -{% highlight python %} -from pyspark.ml.feature import OneHotEncoder, StringIndexer - -df = sqlContext.createDataFrame([ - (0, "a"), - (1, "b"), - (2, "c"), - (3, "a"), - (4, "a"), - (5, "c") -], ["id", "category"]) - -stringIndexer = StringIndexer(inputCol="category", outputCol="categoryIndex") -model = stringIndexer.fit(df) -indexed = model.transform(df) -encoder = OneHotEncoder(includeFirst=False, inputCol="categoryIndex", outputCol="categoryVec") -encoded = encoder.transform(indexed) -encoded.select("id", "categoryVec").show() -{% endhighlight %} +{% include_example python/ml/onehot_encoder_example.py %}
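A standalone sketch of the same StringIndexer-then-OneHotEncoder pipeline, based on the removed snippet. The encoder is left at its default settings here, since the `includeFirst` flag used in the removed listing may not match this Spark version, and the bootstrap lines are assumptions.

{% highlight python %}
from pyspark import SparkContext
from pyspark.sql import SQLContext
from pyspark.ml.feature import OneHotEncoder, StringIndexer

sc = SparkContext(appName="OneHotEncoderSketch")  # app name is illustrative
sqlContext = SQLContext(sc)

df = sqlContext.createDataFrame(
    [(0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")],
    ["id", "category"])

# Map the string column to numeric indices first, then one-hot encode those indices.
indexed = StringIndexer(inputCol="category", outputCol="categoryIndex").fit(df).transform(df)
encoder = OneHotEncoder(inputCol="categoryIndex", outputCol="categoryVec")
encoder.transform(indexed).select("id", "categoryVec").show()

sc.stop()
{% endhighlight %}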
    @@ -1150,23 +605,7 @@ In the example below, we read in a dataset of labeled points and then use `Vecto Refer to the [VectorIndexer Scala docs](api/scala/index.html#org.apache.spark.ml.feature.VectorIndexer) for more details on the API. -{% highlight scala %} -import org.apache.spark.ml.feature.VectorIndexer - -val data = sqlContext.read.format("libsvm") - .load("data/mllib/sample_libsvm_data.txt") -val indexer = new VectorIndexer() - .setInputCol("features") - .setOutputCol("indexed") - .setMaxCategories(10) -val indexerModel = indexer.fit(data) -val categoricalFeatures: Set[Int] = indexerModel.categoryMaps.keys.toSet -println(s"Chose ${categoricalFeatures.size} categorical features: " + - categoricalFeatures.mkString(", ")) - -// Create new column "indexed" with categorical values transformed to indices -val indexedData = indexerModel.transform(data) -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/VectorIndexerExample.scala %}
@@ -1174,30 +613,7 @@ val indexedData = indexerModel.transform(data)

Refer to the [VectorIndexer Java docs](api/java/org/apache/spark/ml/feature/VectorIndexer.html) for more details on the API.

-{% highlight java %}
-import java.util.Map;
-
-import org.apache.spark.ml.feature.VectorIndexer;
-import org.apache.spark.ml.feature.VectorIndexerModel;
-import org.apache.spark.sql.DataFrame;
-
-DataFrame data = sqlContext.read().format("libsvm")
-  .load("data/mllib/sample_libsvm_data.txt");
-VectorIndexer indexer = new VectorIndexer()
-  .setInputCol("features")
-  .setOutputCol("indexed")
-  .setMaxCategories(10);
-VectorIndexerModel indexerModel = indexer.fit(data);
-Map<Integer, Map<Double, Integer>> categoryMaps = indexerModel.javaCategoryMaps();
-System.out.print("Chose " + categoryMaps.size() + "categorical features:");
-for (Integer feature : categoryMaps.keySet()) {
-  System.out.print(" " + feature);
-}
-System.out.println();
-
-// Create new column "indexed" with categorical values transformed to indices
-DataFrame indexedData = indexerModel.transform(data);
-{% endhighlight %}
+{% include_example java/org/apache/spark/examples/ml/JavaVectorIndexerExample.java %}
    @@ -1205,17 +621,7 @@ DataFrame indexedData = indexerModel.transform(data); Refer to the [VectorIndexer Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.VectorIndexer) for more details on the API. -{% highlight python %} -from pyspark.ml.feature import VectorIndexer - -data = sqlContext.read.format("libsvm") - .load("data/mllib/sample_libsvm_data.txt") -indexer = VectorIndexer(inputCol="features", outputCol="indexed", maxCategories=10) -indexerModel = indexer.fit(data) - -# Create new column "indexed" with categorical values transformed to indices -indexedData = indexerModel.transform(data) -{% endhighlight %} +{% include_example python/ml/vector_indexer_example.py %}
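A standalone sketch of the VectorIndexer usage, based on the removed snippet; it assumes the script is run from a Spark checkout so that `data/mllib/sample_libsvm_data.txt` resolves, and the bootstrap lines and app name are illustrative.

{% highlight python %}
from pyspark import SparkContext
from pyspark.sql import SQLContext
from pyspark.ml.feature import VectorIndexer

sc = SparkContext(appName="VectorIndexerSketch")  # app name is illustrative
sqlContext = SQLContext(sc)

data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt")

# Treat any feature with at most 10 distinct values as categorical.
indexer = VectorIndexer(inputCol="features", outputCol="indexed", maxCategories=10)
indexerModel = indexer.fit(data)

# Create new column "indexed" with categorical values transformed to indices.
indexerModel.transform(data).show()

sc.stop()
{% endhighlight %}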
    @@ -1232,22 +638,7 @@ The following example demonstrates how to load a dataset in libsvm format and th Refer to the [Normalizer Scala docs](api/scala/index.html#org.apache.spark.ml.feature.Normalizer) for more details on the API. -{% highlight scala %} -import org.apache.spark.ml.feature.Normalizer - -val dataFrame = sqlContext.read.format("libsvm") - .load("data/mllib/sample_libsvm_data.txt") - -// Normalize each Vector using $L^1$ norm. -val normalizer = new Normalizer() - .setInputCol("features") - .setOutputCol("normFeatures") - .setP(1.0) -val l1NormData = normalizer.transform(dataFrame) - -// Normalize each Vector using $L^\infty$ norm. -val lInfNormData = normalizer.transform(dataFrame, normalizer.p -> Double.PositiveInfinity) -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/NormalizerExample.scala %}
    @@ -1255,24 +646,7 @@ val lInfNormData = normalizer.transform(dataFrame, normalizer.p -> Double.Positi Refer to the [Normalizer Java docs](api/java/org/apache/spark/ml/feature/Normalizer.html) for more details on the API. -{% highlight java %} -import org.apache.spark.ml.feature.Normalizer; -import org.apache.spark.sql.DataFrame; - -DataFrame dataFrame = sqlContext.read().format("libsvm") - .load("data/mllib/sample_libsvm_data.txt"); - -// Normalize each Vector using $L^1$ norm. -Normalizer normalizer = new Normalizer() - .setInputCol("features") - .setOutputCol("normFeatures") - .setP(1.0); -DataFrame l1NormData = normalizer.transform(dataFrame); - -// Normalize each Vector using $L^\infty$ norm. -DataFrame lInfNormData = - normalizer.transform(dataFrame, normalizer.p().w(Double.POSITIVE_INFINITY)); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaNormalizerExample.java %}
    @@ -1280,19 +654,7 @@ DataFrame lInfNormData = Refer to the [Normalizer Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.Normalizer) for more details on the API. -{% highlight python %} -from pyspark.ml.feature import Normalizer - -dataFrame = sqlContext.read.format("libsvm") - .load("data/mllib/sample_libsvm_data.txt") - -# Normalize each Vector using $L^1$ norm. -normalizer = Normalizer(inputCol="features", outputCol="normFeatures", p=1.0) -l1NormData = normalizer.transform(dataFrame) - -# Normalize each Vector using $L^\infty$ norm. -lInfNormData = normalizer.transform(dataFrame, {normalizer.p: float("inf")}) -{% endhighlight %} +{% include_example python/ml/normalizer_example.py %}
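A standalone sketch of the Normalizer usage, based on the removed snippet, showing both the configured L^1 norm and an L^infinity override at transform time; the bootstrap lines and app name are assumptions.

{% highlight python %}
from pyspark import SparkContext
from pyspark.sql import SQLContext
from pyspark.ml.feature import Normalizer

sc = SparkContext(appName="NormalizerSketch")  # app name is illustrative
sqlContext = SQLContext(sc)

dataFrame = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt")

# Normalize each Vector using the L^1 norm.
normalizer = Normalizer(inputCol="features", outputCol="normFeatures", p=1.0)
normalizer.transform(dataFrame).show()

# Normalize each Vector using the L^infinity norm by overriding p at transform time.
normalizer.transform(dataFrame, {normalizer.p: float("inf")}).show()

sc.stop()
{% endhighlight %}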
    @@ -1316,23 +678,7 @@ The following example demonstrates how to load a dataset in libsvm format and th Refer to the [StandardScaler Scala docs](api/scala/index.html#org.apache.spark.ml.feature.StandardScaler) for more details on the API. -{% highlight scala %} -import org.apache.spark.ml.feature.StandardScaler - -val dataFrame = sqlContext.read.format("libsvm") - .load("data/mllib/sample_libsvm_data.txt") -val scaler = new StandardScaler() - .setInputCol("features") - .setOutputCol("scaledFeatures") - .setWithStd(true) - .setWithMean(false) - -// Compute summary statistics by fitting the StandardScaler -val scalerModel = scaler.fit(dataFrame) - -// Normalize each feature to have unit standard deviation. -val scaledData = scalerModel.transform(dataFrame) -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/StandardScalerExample.scala %}
    @@ -1340,25 +686,7 @@ val scaledData = scalerModel.transform(dataFrame) Refer to the [StandardScaler Java docs](api/java/org/apache/spark/ml/feature/StandardScaler.html) for more details on the API. -{% highlight java %} -import org.apache.spark.ml.feature.StandardScaler; -import org.apache.spark.ml.feature.StandardScalerModel; -import org.apache.spark.sql.DataFrame; - -DataFrame dataFrame = sqlContext.read().format("libsvm") - .load("data/mllib/sample_libsvm_data.txt"); -StandardScaler scaler = new StandardScaler() - .setInputCol("features") - .setOutputCol("scaledFeatures") - .setWithStd(true) - .setWithMean(false); - -// Compute summary statistics by fitting the StandardScaler -StandardScalerModel scalerModel = scaler.fit(dataFrame); - -// Normalize each feature to have unit standard deviation. -DataFrame scaledData = scalerModel.transform(dataFrame); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaStandardScalerExample.java %}
    @@ -1366,20 +694,7 @@ DataFrame scaledData = scalerModel.transform(dataFrame); Refer to the [StandardScaler Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.StandardScaler) for more details on the API. -{% highlight python %} -from pyspark.ml.feature import StandardScaler - -dataFrame = sqlContext.read.format("libsvm") - .load("data/mllib/sample_libsvm_data.txt") -scaler = StandardScaler(inputCol="features", outputCol="scaledFeatures", - withStd=True, withMean=False) - -# Compute summary statistics by fitting the StandardScaler -scalerModel = scaler.fit(dataFrame) - -# Normalize each feature to have unit standard deviation. -scaledData = scalerModel.transform(dataFrame) -{% endhighlight %} +{% include_example python/ml/standard_scaler_example.py %}
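A standalone sketch of the StandardScaler usage, reconstructed from the removed snippet; the bootstrap lines and app name are assumptions, and the libsvm path assumes a Spark checkout.

{% highlight python %}
from pyspark import SparkContext
from pyspark.sql import SQLContext
from pyspark.ml.feature import StandardScaler

sc = SparkContext(appName="StandardScalerSketch")  # app name is illustrative
sqlContext = SQLContext(sc)

dataFrame = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt")
scaler = StandardScaler(inputCol="features", outputCol="scaledFeatures",
                        withStd=True, withMean=False)

# Compute summary statistics by fitting the StandardScaler.
scalerModel = scaler.fit(dataFrame)

# Normalize each feature to have unit standard deviation.
scalerModel.transform(dataFrame).show()

sc.stop()
{% endhighlight %}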
    @@ -1409,21 +724,7 @@ Refer to the [MinMaxScaler Scala docs](api/scala/index.html#org.apache.spark.ml. and the [MinMaxScalerModel Scala docs](api/scala/index.html#org.apache.spark.ml.feature.MinMaxScalerModel) for more details on the API. -{% highlight scala %} -import org.apache.spark.ml.feature.MinMaxScaler - -val dataFrame = sqlContext.read.format("libsvm") - .load("data/mllib/sample_libsvm_data.txt") -val scaler = new MinMaxScaler() - .setInputCol("features") - .setOutputCol("scaledFeatures") - -// Compute summary statistics and generate MinMaxScalerModel -val scalerModel = scaler.fit(dataFrame) - -// rescale each feature to range [min, max]. -val scaledData = scalerModel.transform(dataFrame) -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/MinMaxScalerExample.scala %}
    @@ -1432,24 +733,7 @@ Refer to the [MinMaxScaler Java docs](api/java/org/apache/spark/ml/feature/MinMa and the [MinMaxScalerModel Java docs](api/java/org/apache/spark/ml/feature/MinMaxScalerModel.html) for more details on the API. -{% highlight java %} -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.ml.feature.MinMaxScaler; -import org.apache.spark.ml.feature.MinMaxScalerModel; -import org.apache.spark.sql.DataFrame; - -DataFrame dataFrame = sqlContext.read().format("libsvm") - .load("data/mllib/sample_libsvm_data.txt"); -MinMaxScaler scaler = new MinMaxScaler() - .setInputCol("features") - .setOutputCol("scaledFeatures"); - -// Compute summary statistics and generate MinMaxScalerModel -MinMaxScalerModel scalerModel = scaler.fit(dataFrame); - -// rescale each feature to range [min, max]. -DataFrame scaledData = scalerModel.transform(dataFrame); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaMinMaxScalerExample.java %}
    @@ -1473,23 +757,7 @@ The following example demonstrates how to bucketize a column of `Double`s into a Refer to the [Bucketizer Scala docs](api/scala/index.html#org.apache.spark.ml.feature.Bucketizer) for more details on the API. -{% highlight scala %} -import org.apache.spark.ml.feature.Bucketizer -import org.apache.spark.sql.DataFrame - -val splits = Array(Double.NegativeInfinity, -0.5, 0.0, 0.5, Double.PositiveInfinity) - -val data = Array(-0.5, -0.3, 0.0, 0.2) -val dataFrame = sqlContext.createDataFrame(data.map(Tuple1.apply)).toDF("features") - -val bucketizer = new Bucketizer() - .setInputCol("features") - .setOutputCol("bucketedFeatures") - .setSplits(splits) - -// Transform original data into its bucket index. -val bucketedData = bucketizer.transform(dataFrame) -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/BucketizerExample.scala %}
    @@ -1497,38 +765,7 @@ val bucketedData = bucketizer.transform(dataFrame) Refer to the [Bucketizer Java docs](api/java/org/apache/spark/ml/feature/Bucketizer.html) for more details on the API. -{% highlight java %} -import java.util.Arrays; - -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.types.DataTypes; -import org.apache.spark.sql.types.Metadata; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; - -double[] splits = {Double.NEGATIVE_INFINITY, -0.5, 0.0, 0.5, Double.POSITIVE_INFINITY}; - -JavaRDD data = jsc.parallelize(Arrays.asList( - RowFactory.create(-0.5), - RowFactory.create(-0.3), - RowFactory.create(0.0), - RowFactory.create(0.2) -)); -StructType schema = new StructType(new StructField[] { - new StructField("features", DataTypes.DoubleType, false, Metadata.empty()) -}); -DataFrame dataFrame = jsql.createDataFrame(data, schema); - -Bucketizer bucketizer = new Bucketizer() - .setInputCol("features") - .setOutputCol("bucketedFeatures") - .setSplits(splits); - -// Transform original data into its bucket index. -DataFrame bucketedData = bucketizer.transform(dataFrame); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaBucketizerExample.java %}
    @@ -1536,19 +773,7 @@ DataFrame bucketedData = bucketizer.transform(dataFrame); Refer to the [Bucketizer Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.Bucketizer) for more details on the API. -{% highlight python %} -from pyspark.ml.feature import Bucketizer - -splits = [-float("inf"), -0.5, 0.0, 0.5, float("inf")] - -data = [(-0.5,), (-0.3,), (0.0,), (0.2,)] -dataFrame = sqlContext.createDataFrame(data, ["features"]) - -bucketizer = Bucketizer(splits=splits, inputCol="features", outputCol="bucketedFeatures") - -# Transform original data into its bucket index. -bucketedData = bucketizer.transform(dataFrame) -{% endhighlight %} +{% include_example python/ml/bucketizer_example.py %}
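A standalone sketch of the Bucketizer usage, based on the removed snippet; the bootstrap lines and app name are assumptions added so the script runs on its own.

{% highlight python %}
from pyspark import SparkContext
from pyspark.sql import SQLContext
from pyspark.ml.feature import Bucketizer

sc = SparkContext(appName="BucketizerSketch")  # app name is illustrative
sqlContext = SQLContext(sc)

# Four bucket boundaries define three finite buckets plus two unbounded ones.
splits = [-float("inf"), -0.5, 0.0, 0.5, float("inf")]

dataFrame = sqlContext.createDataFrame([(-0.5,), (-0.3,), (0.0,), (0.2,)], ["features"])
bucketizer = Bucketizer(splits=splits, inputCol="features", outputCol="bucketedFeatures")

# Transform original data into its bucket index.
bucketizer.transform(dataFrame).show()

sc.stop()
{% endhighlight %}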
    @@ -1580,25 +805,7 @@ This example below demonstrates how to transform vectors using a transforming ve Refer to the [ElementwiseProduct Scala docs](api/scala/index.html#org.apache.spark.ml.feature.ElementwiseProduct) for more details on the API. -{% highlight scala %} -import org.apache.spark.ml.feature.ElementwiseProduct -import org.apache.spark.mllib.linalg.Vectors - -// Create some vector data; also works for sparse vectors -val dataFrame = sqlContext.createDataFrame(Seq( - ("a", Vectors.dense(1.0, 2.0, 3.0)), - ("b", Vectors.dense(4.0, 5.0, 6.0)))).toDF("id", "vector") - -val transformingVector = Vectors.dense(0.0, 1.0, 2.0) -val transformer = new ElementwiseProduct() - .setScalingVec(transformingVector) - .setInputCol("vector") - .setOutputCol("transformedVector") - -// Batch transform the vectors to create new column: -transformer.transform(dataFrame).show() - -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/ElementwiseProductExample.scala %}
    @@ -1606,41 +813,7 @@ transformer.transform(dataFrame).show() Refer to the [ElementwiseProduct Java docs](api/java/org/apache/spark/ml/feature/ElementwiseProduct.html) for more details on the API. -{% highlight java %} -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.ml.feature.ElementwiseProduct; -import org.apache.spark.mllib.linalg.Vector; -import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.SQLContext; -import org.apache.spark.sql.types.DataTypes; -import org.apache.spark.sql.types.Metadata; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; - -// Create some vector data; also works for sparse vectors -JavaRDD jrdd = jsc.parallelize(Arrays.asList( - RowFactory.create("a", Vectors.dense(1.0, 2.0, 3.0)), - RowFactory.create("b", Vectors.dense(4.0, 5.0, 6.0)) -)); -List fields = new ArrayList(2); -fields.add(DataTypes.createStructField("id", DataTypes.StringType, false)); -fields.add(DataTypes.createStructField("vector", DataTypes.StringType, false)); -StructType schema = DataTypes.createStructType(fields); -DataFrame dataFrame = sqlContext.createDataFrame(jrdd, schema); -Vector transformingVector = Vectors.dense(0.0, 1.0, 2.0); -ElementwiseProduct transformer = new ElementwiseProduct() - .setScalingVec(transformingVector) - .setInputCol("vector") - .setOutputCol("transformedVector"); -// Batch transform the vectors to create new column: -transformer.transform(dataFrame).show(); - -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaElementwiseProductExample.java %}
    @@ -1648,19 +821,8 @@ transformer.transform(dataFrame).show(); Refer to the [ElementwiseProduct Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.ElementwiseProduct) for more details on the API. -{% highlight python %} -from pyspark.ml.feature import ElementwiseProduct -from pyspark.mllib.linalg import Vectors - -data = [(Vectors.dense([1.0, 2.0, 3.0]),), (Vectors.dense([4.0, 5.0, 6.0]),)] -df = sqlContext.createDataFrame(data, ["vector"]) -transformer = ElementwiseProduct(scalingVec=Vectors.dense([0.0, 1.0, 2.0]), - inputCol="vector", outputCol="transformedVector") -transformer.transform(df).show() - -{% endhighlight %} +{% include_example python/ml/elementwise_product_example.py %}
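A standalone sketch of the ElementwiseProduct usage, based on the removed snippet; the bootstrap lines and app name are assumptions.

{% highlight python %}
from pyspark import SparkContext
from pyspark.sql import SQLContext
from pyspark.ml.feature import ElementwiseProduct
from pyspark.mllib.linalg import Vectors

sc = SparkContext(appName="ElementwiseProductSketch")  # app name is illustrative
sqlContext = SQLContext(sc)

data = [(Vectors.dense([1.0, 2.0, 3.0]),), (Vectors.dense([4.0, 5.0, 6.0]),)]
df = sqlContext.createDataFrame(data, ["vector"])

# Multiply each input vector component-wise by the scaling vector (0.0, 1.0, 2.0).
transformer = ElementwiseProduct(scalingVec=Vectors.dense([0.0, 1.0, 2.0]),
                                 inputCol="vector", outputCol="transformedVector")
transformer.transform(df).show()

sc.stop()
{% endhighlight %}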
    - ## SQLTransformer @@ -1763,19 +925,7 @@ output column to `features`, after transformation we should get the following Da Refer to the [VectorAssembler Scala docs](api/scala/index.html#org.apache.spark.ml.feature.VectorAssembler) for more details on the API. -{% highlight scala %} -import org.apache.spark.mllib.linalg.Vectors -import org.apache.spark.ml.feature.VectorAssembler - -val dataset = sqlContext.createDataFrame( - Seq((0, 18, 1.0, Vectors.dense(0.0, 10.0, 0.5), 1.0)) -).toDF("id", "hour", "mobile", "userFeatures", "clicked") -val assembler = new VectorAssembler() - .setInputCols(Array("hour", "mobile", "userFeatures")) - .setOutputCol("features") -val output = assembler.transform(dataset) -println(output.select("features", "clicked").first()) -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/VectorAssemblerExample.scala %}
    @@ -1783,36 +933,7 @@ println(output.select("features", "clicked").first()) Refer to the [VectorAssembler Java docs](api/java/org/apache/spark/ml/feature/VectorAssembler.html) for more details on the API. -{% highlight java %} -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.mllib.linalg.VectorUDT; -import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.types.*; -import static org.apache.spark.sql.types.DataTypes.*; - -StructType schema = createStructType(new StructField[] { - createStructField("id", IntegerType, false), - createStructField("hour", IntegerType, false), - createStructField("mobile", DoubleType, false), - createStructField("userFeatures", new VectorUDT(), false), - createStructField("clicked", DoubleType, false) -}); -Row row = RowFactory.create(0, 18, 1.0, Vectors.dense(0.0, 10.0, 0.5), 1.0); -JavaRDD rdd = jsc.parallelize(Arrays.asList(row)); -DataFrame dataset = sqlContext.createDataFrame(rdd, schema); - -VectorAssembler assembler = new VectorAssembler() - .setInputCols(new String[] {"hour", "mobile", "userFeatures"}) - .setOutputCol("features"); - -DataFrame output = assembler.transform(dataset); -System.out.println(output.select("features", "clicked").first()); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaVectorAssemblerExample.java %}
    @@ -1820,19 +941,7 @@ System.out.println(output.select("features", "clicked").first()); Refer to the [VectorAssembler Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.VectorAssembler) for more details on the API. -{% highlight python %} -from pyspark.mllib.linalg import Vectors -from pyspark.ml.feature import VectorAssembler - -dataset = sqlContext.createDataFrame( - [(0, 18, 1.0, Vectors.dense([0.0, 10.0, 0.5]), 1.0)], - ["id", "hour", "mobile", "userFeatures", "clicked"]) -assembler = VectorAssembler( - inputCols=["hour", "mobile", "userFeatures"], - outputCol="features") -output = assembler.transform(dataset) -print(output.select("features", "clicked").first()) -{% endhighlight %} +{% include_example python/ml/vector_assembler_example.py %}
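A standalone sketch of the VectorAssembler usage, based on the removed snippet; the bootstrap lines and app name are assumptions.

{% highlight python %}
from pyspark import SparkContext
from pyspark.sql import SQLContext
from pyspark.mllib.linalg import Vectors
from pyspark.ml.feature import VectorAssembler

sc = SparkContext(appName="VectorAssemblerSketch")  # app name is illustrative
sqlContext = SQLContext(sc)

dataset = sqlContext.createDataFrame(
    [(0, 18, 1.0, Vectors.dense([0.0, 10.0, 0.5]), 1.0)],
    ["id", "hour", "mobile", "userFeatures", "clicked"])

# Combine the numeric columns and the existing vector column into a single feature vector.
assembler = VectorAssembler(inputCols=["hour", "mobile", "userFeatures"],
                            outputCol="features")
print(assembler.transform(dataset).select("features", "clicked").first())

sc.stop()
{% endhighlight %}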
    @@ -1962,33 +1071,7 @@ Suppose also that we have a potential input attributes for the `userFeatures`, i Refer to the [VectorSlicer Scala docs](api/scala/index.html#org.apache.spark.ml.feature.VectorSlicer) for more details on the API. -{% highlight scala %} -import org.apache.spark.mllib.linalg.Vectors -import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NumericAttribute} -import org.apache.spark.ml.feature.VectorSlicer -import org.apache.spark.sql.types.StructType -import org.apache.spark.sql.{DataFrame, Row, SQLContext} - -val data = Array( - Vectors.sparse(3, Seq((0, -2.0), (1, 2.3))), - Vectors.dense(-2.0, 2.3, 0.0) -) - -val defaultAttr = NumericAttribute.defaultAttr -val attrs = Array("f1", "f2", "f3").map(defaultAttr.withName) -val attrGroup = new AttributeGroup("userFeatures", attrs.asInstanceOf[Array[Attribute]]) - -val dataRDD = sc.parallelize(data).map(Row.apply) -val dataset = sqlContext.createDataFrame(dataRDD, StructType(attrGroup.toStructField())) - -val slicer = new VectorSlicer().setInputCol("userFeatures").setOutputCol("features") - -slicer.setIndices(1).setNames("f3") -// or slicer.setIndices(Array(1, 2)), or slicer.setNames(Array("f2", "f3")) - -val output = slicer.transform(dataset) -println(output.select("userFeatures", "features").first()) -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/VectorSlicerExample.scala %}
    @@ -1996,41 +1079,7 @@ println(output.select("userFeatures", "features").first()) Refer to the [VectorSlicer Java docs](api/java/org/apache/spark/ml/feature/VectorSlicer.html) for more details on the API. -{% highlight java %} -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.types.*; -import static org.apache.spark.sql.types.DataTypes.*; - -Attribute[] attrs = new Attribute[]{ - NumericAttribute.defaultAttr().withName("f1"), - NumericAttribute.defaultAttr().withName("f2"), - NumericAttribute.defaultAttr().withName("f3") -}; -AttributeGroup group = new AttributeGroup("userFeatures", attrs); - -JavaRDD jrdd = jsc.parallelize(Lists.newArrayList( - RowFactory.create(Vectors.sparse(3, new int[]{0, 1}, new double[]{-2.0, 2.3})), - RowFactory.create(Vectors.dense(-2.0, 2.3, 0.0)) -)); - -DataFrame dataset = jsql.createDataFrame(jrdd, (new StructType()).add(group.toStructField())); - -VectorSlicer vectorSlicer = new VectorSlicer() - .setInputCol("userFeatures").setOutputCol("features"); - -vectorSlicer.setIndices(new int[]{1}).setNames(new String[]{"f3"}); -// or slicer.setIndices(new int[]{1, 2}), or slicer.setNames(new String[]{"f2", "f3"}) - -DataFrame output = vectorSlicer.transform(dataset); - -System.out.println(output.select("userFeatures", "features").first()); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaVectorSlicerExample.java %}
    @@ -2067,21 +1116,7 @@ id | country | hour | clicked | features | label Refer to the [RFormula Scala docs](api/scala/index.html#org.apache.spark.ml.feature.RFormula) for more details on the API. -{% highlight scala %} -import org.apache.spark.ml.feature.RFormula - -val dataset = sqlContext.createDataFrame(Seq( - (7, "US", 18, 1.0), - (8, "CA", 12, 0.0), - (9, "NZ", 15, 0.0) -)).toDF("id", "country", "hour", "clicked") -val formula = new RFormula() - .setFormula("clicked ~ country + hour") - .setFeaturesCol("features") - .setLabelCol("label") -val output = formula.fit(dataset).transform(dataset) -output.select("features", "label").show() -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/RFormulaExample.scala %}
    @@ -2089,38 +1124,7 @@ output.select("features", "label").show() Refer to the [RFormula Java docs](api/java/org/apache/spark/ml/feature/RFormula.html) for more details on the API. -{% highlight java %} -import java.util.Arrays; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.ml.feature.RFormula; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.types.*; -import static org.apache.spark.sql.types.DataTypes.*; - -StructType schema = createStructType(new StructField[] { - createStructField("id", IntegerType, false), - createStructField("country", StringType, false), - createStructField("hour", IntegerType, false), - createStructField("clicked", DoubleType, false) -}); -JavaRDD rdd = jsc.parallelize(Arrays.asList( - RowFactory.create(7, "US", 18, 1.0), - RowFactory.create(8, "CA", 12, 0.0), - RowFactory.create(9, "NZ", 15, 0.0) -)); -DataFrame dataset = sqlContext.createDataFrame(rdd, schema); - -RFormula formula = new RFormula() - .setFormula("clicked ~ country + hour") - .setFeaturesCol("features") - .setLabelCol("label"); - -DataFrame output = formula.fit(dataset).transform(dataset); -output.select("features", "label").show(); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaRFormulaExample.java %}
    @@ -2128,21 +1132,7 @@ output.select("features", "label").show(); Refer to the [RFormula Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.RFormula) for more details on the API. -{% highlight python %} -from pyspark.ml.feature import RFormula - -dataset = sqlContext.createDataFrame( - [(7, "US", 18, 1.0), - (8, "CA", 12, 0.0), - (9, "NZ", 15, 0.0)], - ["id", "country", "hour", "clicked"]) -formula = RFormula( - formula="clicked ~ country + hour", - featuresCol="features", - labelCol="label") -output = formula.fit(dataset).transform(dataset) -output.select("features", "label").show() -{% endhighlight %} +{% include_example python/ml/rformula_example.py %}
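A standalone sketch of the RFormula usage, based on the removed snippet; the bootstrap lines and app name are assumptions, not part of the new example file.

{% highlight python %}
from pyspark import SparkContext
from pyspark.sql import SQLContext
from pyspark.ml.feature import RFormula

sc = SparkContext(appName="RFormulaSketch")  # app name is illustrative
sqlContext = SQLContext(sc)

dataset = sqlContext.createDataFrame(
    [(7, "US", 18, 1.0), (8, "CA", 12, 0.0), (9, "NZ", 15, 0.0)],
    ["id", "country", "hour", "clicked"])

# "clicked ~ country + hour" uses clicked as the label and country plus hour as features.
formula = RFormula(formula="clicked ~ country + hour",
                   featuresCol="features", labelCol="label")
formula.fit(dataset).transform(dataset).select("features", "label").show()

sc.stop()
{% endhighlight %}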
    diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaBinarizerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaBinarizerExample.java new file mode 100644 index 000000000000..9698cac50437 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaBinarizerExample.java @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.Binarizer; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +// $example off$ + +public class JavaBinarizerExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaBinarizerExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext jsql = new SQLContext(jsc); + + // $example on$ + JavaRDD jrdd = jsc.parallelize(Arrays.asList( + RowFactory.create(0, 0.1), + RowFactory.create(1, 0.8), + RowFactory.create(2, 0.2) + )); + StructType schema = new StructType(new StructField[]{ + new StructField("label", DataTypes.DoubleType, false, Metadata.empty()), + new StructField("feature", DataTypes.DoubleType, false, Metadata.empty()) + }); + DataFrame continuousDataFrame = jsql.createDataFrame(jrdd, schema); + Binarizer binarizer = new Binarizer() + .setInputCol("feature") + .setOutputCol("binarized_feature") + .setThreshold(0.5); + DataFrame binarizedDataFrame = binarizer.transform(continuousDataFrame); + DataFrame binarizedFeatures = binarizedDataFrame.select("binarized_feature"); + for (Row r : binarizedFeatures.collect()) { + Double binarized_value = r.getDouble(0); + System.out.println(binarized_value); + } + // $example off$ + jsc.stop(); + } +} \ No newline at end of file diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaBucketizerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaBucketizerExample.java new file mode 100644 index 000000000000..8ad369cc93e8 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaBucketizerExample.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.Bucketizer; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +// $example off$ + +public class JavaBucketizerExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaBucketizerExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext jsql = new SQLContext(jsc); + + // $example on$ + double[] splits = {Double.NEGATIVE_INFINITY, -0.5, 0.0, 0.5, Double.POSITIVE_INFINITY}; + + JavaRDD data = jsc.parallelize(Arrays.asList( + RowFactory.create(-0.5), + RowFactory.create(-0.3), + RowFactory.create(0.0), + RowFactory.create(0.2) + )); + StructType schema = new StructType(new StructField[]{ + new StructField("features", DataTypes.DoubleType, false, Metadata.empty()) + }); + DataFrame dataFrame = jsql.createDataFrame(data, schema); + + Bucketizer bucketizer = new Bucketizer() + .setInputCol("features") + .setOutputCol("bucketedFeatures") + .setSplits(splits); + + // Transform original data into its bucket index. + DataFrame bucketedData = bucketizer.transform(dataFrame); + bucketedData.show(); + // $example off$ + jsc.stop(); + } +} + + diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaDCTExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaDCTExample.java new file mode 100644 index 000000000000..35c0d534a45e --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaDCTExample.java @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.DCT; +import org.apache.spark.mllib.linalg.VectorUDT; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +// $example off$ + +public class JavaDCTExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaDCTExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext jsql = new SQLContext(jsc); + + // $example on$ + JavaRDD data = jsc.parallelize(Arrays.asList( + RowFactory.create(Vectors.dense(0.0, 1.0, -2.0, 3.0)), + RowFactory.create(Vectors.dense(-1.0, 2.0, 4.0, -7.0)), + RowFactory.create(Vectors.dense(14.0, -2.0, -5.0, 1.0)) + )); + StructType schema = new StructType(new StructField[]{ + new StructField("features", new VectorUDT(), false, Metadata.empty()), + }); + DataFrame df = jsql.createDataFrame(data, schema); + DCT dct = new DCT() + .setInputCol("features") + .setOutputCol("featuresDCT") + .setInverse(false); + DataFrame dctDf = dct.transform(df); + dctDf.select("featuresDCT").show(3); + // $example off$ + jsc.stop(); + } +} + diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaElementwiseProductExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaElementwiseProductExample.java new file mode 100644 index 000000000000..2898accec61b --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaElementwiseProductExample.java @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.ElementwiseProduct; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.linalg.VectorUDT; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +// $example off$ + +public class JavaElementwiseProductExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaElementwiseProductExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext sqlContext = new SQLContext(jsc); + + // $example on$ + // Create some vector data; also works for sparse vectors + JavaRDD jrdd = jsc.parallelize(Arrays.asList( + RowFactory.create("a", Vectors.dense(1.0, 2.0, 3.0)), + RowFactory.create("b", Vectors.dense(4.0, 5.0, 6.0)) + )); + + List fields = new ArrayList(2); + fields.add(DataTypes.createStructField("id", DataTypes.StringType, false)); + fields.add(DataTypes.createStructField("vector", new VectorUDT(), false)); + + StructType schema = DataTypes.createStructType(fields); + + DataFrame dataFrame = sqlContext.createDataFrame(jrdd, schema); + + Vector transformingVector = Vectors.dense(0.0, 1.0, 2.0); + + ElementwiseProduct transformer = new ElementwiseProduct() + .setScalingVec(transformingVector) + .setInputCol("vector") + .setOutputCol("transformedVector"); + + // Batch transform the vectors to create new column: + transformer.transform(dataFrame).show(); + // $example off$ + jsc.stop(); + } +} \ No newline at end of file diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaMinMaxScalerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaMinMaxScalerExample.java new file mode 100644 index 000000000000..2d50ba7faa1a --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaMinMaxScalerExample.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.spark.examples.ml;
+
+import org.apache.spark.SparkConf;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.sql.SQLContext;
+
+// $example on$
+import org.apache.spark.ml.feature.MinMaxScaler;
+import org.apache.spark.ml.feature.MinMaxScalerModel;
+import org.apache.spark.sql.DataFrame;
+// $example off$
+
+public class JavaMinMaxScalerExample {
+  public static void main(String[] args) {
+    SparkConf conf = new SparkConf().setAppName("JavaMinMaxScalerExample");
+    JavaSparkContext jsc = new JavaSparkContext(conf);
+    SQLContext jsql = new SQLContext(jsc);
+
+    // $example on$
+    DataFrame dataFrame = jsql.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt");
+    MinMaxScaler scaler = new MinMaxScaler()
+      .setInputCol("features")
+      .setOutputCol("scaledFeatures");
+
+    // Compute summary statistics and generate MinMaxScalerModel
+    MinMaxScalerModel scalerModel = scaler.fit(dataFrame);
+
+    // rescale each feature to range [min, max].
+    DataFrame scaledData = scalerModel.transform(dataFrame);
+    scaledData.show();
+    // $example off$
+    jsc.stop();
+  }
+}
\ No newline at end of file
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaNGramExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaNGramExample.java
new file mode 100644
index 000000000000..8fd75ed8b5f4
--- /dev/null
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaNGramExample.java
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.NGram; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +// $example off$ + +public class JavaNGramExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaNGramExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext sqlContext = new SQLContext(jsc); + + // $example on$ + JavaRDD jrdd = jsc.parallelize(Arrays.asList( + RowFactory.create(0.0, Arrays.asList("Hi", "I", "heard", "about", "Spark")), + RowFactory.create(1.0, Arrays.asList("I", "wish", "Java", "could", "use", "case", "classes")), + RowFactory.create(2.0, Arrays.asList("Logistic", "regression", "models", "are", "neat")) + )); + + StructType schema = new StructType(new StructField[]{ + new StructField("label", DataTypes.DoubleType, false, Metadata.empty()), + new StructField( + "words", DataTypes.createArrayType(DataTypes.StringType), false, Metadata.empty()) + }); + + DataFrame wordDataFrame = sqlContext.createDataFrame(jrdd, schema); + + NGram ngramTransformer = new NGram().setInputCol("words").setOutputCol("ngrams"); + + DataFrame ngramDataFrame = ngramTransformer.transform(wordDataFrame); + + for (Row r : ngramDataFrame.select("ngrams", "label").take(3)) { + java.util.List ngrams = r.getList(0); + for (String ngram : ngrams) System.out.print(ngram + " --- "); + System.out.println(); + } + // $example off$ + jsc.stop(); + } +} \ No newline at end of file diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaNormalizerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaNormalizerExample.java new file mode 100644 index 000000000000..ed3f6163c055 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaNormalizerExample.java @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import org.apache.spark.ml.feature.Normalizer; +import org.apache.spark.sql.DataFrame; +// $example off$ + +public class JavaNormalizerExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaNormalizerExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext jsql = new SQLContext(jsc); + + // $example on$ + DataFrame dataFrame = jsql.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt"); + + // Normalize each Vector using $L^1$ norm. + Normalizer normalizer = new Normalizer() + .setInputCol("features") + .setOutputCol("normFeatures") + .setP(1.0); + + DataFrame l1NormData = normalizer.transform(dataFrame); + l1NormData.show(); + + // Normalize each Vector using $L^\infty$ norm. + DataFrame lInfNormData = + normalizer.transform(dataFrame, normalizer.p().w(Double.POSITIVE_INFINITY)); + lInfNormData.show(); + // $example off$ + jsc.stop(); + } +} \ No newline at end of file diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaOneHotEncoderExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaOneHotEncoderExample.java new file mode 100644 index 000000000000..bc509607084b --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaOneHotEncoderExample.java @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.OneHotEncoder; +import org.apache.spark.ml.feature.StringIndexer; +import org.apache.spark.ml.feature.StringIndexerModel; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +// $example off$ + +public class JavaOneHotEncoderExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaOneHotEncoderExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext sqlContext = new SQLContext(jsc); + + // $example on$ + JavaRDD jrdd = jsc.parallelize(Arrays.asList( + RowFactory.create(0, "a"), + RowFactory.create(1, "b"), + RowFactory.create(2, "c"), + RowFactory.create(3, "a"), + RowFactory.create(4, "a"), + RowFactory.create(5, "c") + )); + + StructType schema = new StructType(new StructField[]{ + new StructField("id", DataTypes.DoubleType, false, Metadata.empty()), + new StructField("category", DataTypes.StringType, false, Metadata.empty()) + }); + + DataFrame df = sqlContext.createDataFrame(jrdd, schema); + + StringIndexerModel indexer = new StringIndexer() + .setInputCol("category") + .setOutputCol("categoryIndex") + .fit(df); + DataFrame indexed = indexer.transform(df); + + OneHotEncoder encoder = new OneHotEncoder() + .setInputCol("categoryIndex") + .setOutputCol("categoryVec"); + DataFrame encoded = encoder.transform(indexed); + encoded.select("id", "categoryVec").show(); + // $example off$ + jsc.stop(); + } +} + diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaPCAExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaPCAExample.java new file mode 100644 index 000000000000..8282fab084f3 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaPCAExample.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.PCA; +import org.apache.spark.ml.feature.PCAModel; +import org.apache.spark.mllib.linalg.VectorUDT; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +// $example off$ + +public class JavaPCAExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaPCAExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext jsql = new SQLContext(jsc); + + // $example on$ + JavaRDD<Row> data = jsc.parallelize(Arrays.asList( + RowFactory.create(Vectors.sparse(5, new int[]{1, 3}, new double[]{1.0, 7.0})), + RowFactory.create(Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0)), + RowFactory.create(Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0)) + )); + + StructType schema = new StructType(new StructField[]{ + new StructField("features", new VectorUDT(), false, Metadata.empty()), + }); + + DataFrame df = jsql.createDataFrame(data, schema); + + PCAModel pca = new PCA() + .setInputCol("features") + .setOutputCol("pcaFeatures") + .setK(3) + .fit(df); + + DataFrame result = pca.transform(df).select("pcaFeatures"); + result.show(); + // $example off$ + jsc.stop(); + } +} + diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaPolynomialExpansionExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaPolynomialExpansionExample.java new file mode 100644 index 000000000000..668f71e64056 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaPolynomialExpansionExample.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.PolynomialExpansion; +import org.apache.spark.mllib.linalg.VectorUDT; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +// $example off$ + +public class JavaPolynomialExpansionExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaPolynomialExpansionExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext jsql = new SQLContext(jsc); + + // $example on$ + PolynomialExpansion polyExpansion = new PolynomialExpansion() + .setInputCol("features") + .setOutputCol("polyFeatures") + .setDegree(3); + + JavaRDD<Row> data = jsc.parallelize(Arrays.asList( + RowFactory.create(Vectors.dense(-2.0, 2.3)), + RowFactory.create(Vectors.dense(0.0, 0.0)), + RowFactory.create(Vectors.dense(0.6, -1.1)) + )); + + StructType schema = new StructType(new StructField[]{ + new StructField("features", new VectorUDT(), false, Metadata.empty()), + }); + + DataFrame df = jsql.createDataFrame(data, schema); + DataFrame polyDF = polyExpansion.transform(df); + + Row[] row = polyDF.select("polyFeatures").take(3); + for (Row r : row) { + System.out.println(r.get(0)); + } + // $example off$ + jsc.stop(); + } +} \ No newline at end of file diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaRFormulaExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaRFormulaExample.java new file mode 100644 index 000000000000..1e1062b541ad --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaRFormulaExample.java @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.RFormula; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +import static org.apache.spark.sql.types.DataTypes.*; +// $example off$ + +public class JavaRFormulaExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaRFormulaExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext sqlContext = new SQLContext(jsc); + + // $example on$ + StructType schema = createStructType(new StructField[]{ + createStructField("id", IntegerType, false), + createStructField("country", StringType, false), + createStructField("hour", IntegerType, false), + createStructField("clicked", DoubleType, false) + }); + + JavaRDD<Row> rdd = jsc.parallelize(Arrays.asList( + RowFactory.create(7, "US", 18, 1.0), + RowFactory.create(8, "CA", 12, 0.0), + RowFactory.create(9, "NZ", 15, 0.0) + )); + + DataFrame dataset = sqlContext.createDataFrame(rdd, schema); + RFormula formula = new RFormula() + .setFormula("clicked ~ country + hour") + .setFeaturesCol("features") + .setLabelCol("label"); + DataFrame output = formula.fit(dataset).transform(dataset); + output.select("features", "label").show(); + // $example off$ + jsc.stop(); + } +} + diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaStandardScalerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaStandardScalerExample.java new file mode 100644 index 000000000000..da4756643f3c --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaStandardScalerExample.java @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
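
The "features" and "label" columns that RFormula produces are exactly what the ML estimators expect, which is the point of the example above. A minimal Scala sketch of feeding them into a logistic regression, assuming `output` is the transformed DataFrame from the example:

import org.apache.spark.ml.classification.LogisticRegression

// Fit a simple model on the RFormula output and inspect its predictions.
val lr = new LogisticRegression()
  .setMaxIter(10)
  .setRegParam(0.01)
val model = lr.fit(output)
model.transform(output).select("features", "label", "prediction").show()
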
+ */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import org.apache.spark.ml.feature.StandardScaler; +import org.apache.spark.ml.feature.StandardScalerModel; +import org.apache.spark.sql.DataFrame; +// $example off$ + +public class JavaStandardScalerExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaStandardScalerExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext jsql = new SQLContext(jsc); + + // $example on$ + DataFrame dataFrame = jsql.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt"); + + StandardScaler scaler = new StandardScaler() + .setInputCol("features") + .setOutputCol("scaledFeatures") + .setWithStd(true) + .setWithMean(false); + + // Compute summary statistics by fitting the StandardScaler + StandardScalerModel scalerModel = scaler.fit(dataFrame); + + // Normalize each feature to have unit standard deviation. + DataFrame scaledData = scalerModel.transform(dataFrame); + scaledData.show(); + // $example off$ + jsc.stop(); + } +} \ No newline at end of file diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaStopWordsRemoverExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaStopWordsRemoverExample.java new file mode 100644 index 000000000000..b6b201c6b68d --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaStopWordsRemoverExample.java @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.StopWordsRemover; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +// $example off$ + +public class JavaStopWordsRemoverExample { + + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaStopWordsRemoverExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext jsql = new SQLContext(jsc); + + // $example on$ + StopWordsRemover remover = new StopWordsRemover() + .setInputCol("raw") + .setOutputCol("filtered"); + + JavaRDD<Row> rdd = jsc.parallelize(Arrays.asList( + RowFactory.create(Arrays.asList("I", "saw", "the", "red", "balloon")), + RowFactory.create(Arrays.asList("Mary", "had", "a", "little", "lamb")) + )); + + StructType schema = new StructType(new StructField[]{ + new StructField( + "raw", DataTypes.createArrayType(DataTypes.StringType), false, Metadata.empty()) + }); + + DataFrame dataset = jsql.createDataFrame(rdd, schema); + remover.transform(dataset).show(); + // $example off$ + jsc.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaStringIndexerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaStringIndexerExample.java new file mode 100644 index 000000000000..05d12c1e702f --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaStringIndexerExample.java @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
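
The remover above relies on the built-in English stop-word list; both the word list and case sensitivity are configurable. A small Scala sketch, assuming a DataFrame `dataset` with an array-of-strings column named "raw" like the one built in the Java example:

import org.apache.spark.ml.feature.StopWordsRemover

// Use a custom stop-word list instead of the default English one.
val remover = new StopWordsRemover()
  .setInputCol("raw")
  .setOutputCol("filtered")
  .setStopWords(Array("saw", "little"))
  .setCaseSensitive(false)

remover.transform(dataset).show()
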
+ */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.StringIndexer; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +import static org.apache.spark.sql.types.DataTypes.*; +// $example off$ + +public class JavaStringIndexerExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaStringIndexerExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext sqlContext = new SQLContext(jsc); + + // $example on$ + JavaRDD<Row> jrdd = jsc.parallelize(Arrays.asList( + RowFactory.create(0, "a"), + RowFactory.create(1, "b"), + RowFactory.create(2, "c"), + RowFactory.create(3, "a"), + RowFactory.create(4, "a"), + RowFactory.create(5, "c") + )); + StructType schema = new StructType(new StructField[]{ + createStructField("id", IntegerType, false), + createStructField("category", StringType, false) + }); + DataFrame df = sqlContext.createDataFrame(jrdd, schema); + StringIndexer indexer = new StringIndexer() + .setInputCol("category") + .setOutputCol("categoryIndex"); + DataFrame indexed = indexer.fit(df).transform(df); + indexed.show(); + // $example off$ + jsc.stop(); + } +} \ No newline at end of file diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaTokenizerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaTokenizerExample.java new file mode 100644 index 000000000000..617dc3f66e3b --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaTokenizerExample.java @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
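
StringIndexer assigns indices by label frequency, most frequent first, so for the data above "a" maps to 0.0, "c" to 1.0, and "b" to 2.0. The mapping can be reversed with IndexToString; a brief Scala sketch, assuming `df` is the id/category DataFrame from the example:

import org.apache.spark.ml.feature.{IndexToString, StringIndexer}

val indexerModel = new StringIndexer()
  .setInputCol("category")
  .setOutputCol("categoryIndex")
  .fit(df)
val indexed = indexerModel.transform(df)

// IndexToString reads the label metadata written by StringIndexer and maps indices back.
val converter = new IndexToString()
  .setInputCol("categoryIndex")
  .setOutputCol("originalCategory")
converter.transform(indexed).select("id", "categoryIndex", "originalCategory").show()
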
+ */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.RegexTokenizer; +import org.apache.spark.ml.feature.Tokenizer; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +// $example off$ + +public class JavaTokenizerExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaTokenizerExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext sqlContext = new SQLContext(jsc); + + // $example on$ + JavaRDD<Row> jrdd = jsc.parallelize(Arrays.asList( + RowFactory.create(0, "Hi I heard about Spark"), + RowFactory.create(1, "I wish Java could use case classes"), + RowFactory.create(2, "Logistic,regression,models,are,neat") + )); + + StructType schema = new StructType(new StructField[]{ + new StructField("label", DataTypes.IntegerType, false, Metadata.empty()), + new StructField("sentence", DataTypes.StringType, false, Metadata.empty()) + }); + + DataFrame sentenceDataFrame = sqlContext.createDataFrame(jrdd, schema); + + Tokenizer tokenizer = new Tokenizer().setInputCol("sentence").setOutputCol("words"); + + DataFrame wordsDataFrame = tokenizer.transform(sentenceDataFrame); + for (Row r : wordsDataFrame.select("words", "label").take(3)) { + java.util.List<String> words = r.getList(0); + for (String word : words) System.out.print(word + " "); + System.out.println(); + } + + RegexTokenizer regexTokenizer = new RegexTokenizer() + .setInputCol("sentence") + .setOutputCol("words") + .setPattern("\\W"); // alternatively .setPattern("\\w+").setGaps(false); + // $example off$ + jsc.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorAssemblerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorAssemblerExample.java new file mode 100644 index 000000000000..7e230b5897c1 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorAssemblerExample.java @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
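
The RegexTokenizer at the end of the example splits on non-word characters; the commented alternative matches the tokens themselves. A short Scala sketch of that gaps-versus-tokens switch, assuming `sentenceDataFrame` holds the same label/sentence rows:

import org.apache.spark.ml.feature.RegexTokenizer

// With gaps=false the pattern describes the tokens to keep rather than the delimiters.
val regexTokenizer = new RegexTokenizer()
  .setInputCol("sentence")
  .setOutputCol("words")
  .setPattern("\\w+")
  .setGaps(false)

regexTokenizer.transform(sentenceDataFrame).select("words", "label").show()
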
+ */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.VectorAssembler; +import org.apache.spark.mllib.linalg.VectorUDT; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.*; + +import static org.apache.spark.sql.types.DataTypes.*; +// $example off$ + +public class JavaVectorAssemblerExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaVectorAssemblerExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext sqlContext = new SQLContext(jsc); + + // $example on$ + StructType schema = createStructType(new StructField[]{ + createStructField("id", IntegerType, false), + createStructField("hour", IntegerType, false), + createStructField("mobile", DoubleType, false), + createStructField("userFeatures", new VectorUDT(), false), + createStructField("clicked", DoubleType, false) + }); + Row row = RowFactory.create(0, 18, 1.0, Vectors.dense(0.0, 10.0, 0.5), 1.0); + JavaRDD<Row> rdd = jsc.parallelize(Arrays.asList(row)); + DataFrame dataset = sqlContext.createDataFrame(rdd, schema); + + VectorAssembler assembler = new VectorAssembler() + .setInputCols(new String[]{"hour", "mobile", "userFeatures"}) + .setOutputCol("features"); + + DataFrame output = assembler.transform(dataset); + System.out.println(output.select("features", "clicked").first()); + // $example off$ + jsc.stop(); + } +} + diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorIndexerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorIndexerExample.java new file mode 100644 index 000000000000..545758e31d97 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorIndexerExample.java @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import java.util.Map; + +import org.apache.spark.ml.feature.VectorIndexer; +import org.apache.spark.ml.feature.VectorIndexerModel; +import org.apache.spark.sql.DataFrame; +// $example off$ + +public class JavaVectorIndexerExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaVectorIndexerExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext jsql = new SQLContext(jsc); + + // $example on$ + DataFrame data = jsql.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt"); + + VectorIndexer indexer = new VectorIndexer() + .setInputCol("features") + .setOutputCol("indexed") + .setMaxCategories(10); + VectorIndexerModel indexerModel = indexer.fit(data); + + Map<Integer, Map<Double, Integer>> categoryMaps = indexerModel.javaCategoryMaps(); + System.out.print("Chose " + categoryMaps.size() + " categorical features:"); + + for (Integer feature : categoryMaps.keySet()) { + System.out.print(" " + feature); + } + System.out.println(); + + // Create new column "indexed" with categorical values transformed to indices + DataFrame indexedData = indexerModel.transform(data); + indexedData.show(); + // $example off$ + jsc.stop(); + } +} \ No newline at end of file diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorSlicerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorSlicerExample.java new file mode 100644 index 000000000000..4d5cb04ff5e2 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorSlicerExample.java @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
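
The value of VectorIndexer is the category metadata it attaches to the "indexed" column, which tree-based algorithms read in order to treat those features as categorical. A rough Scala sketch of that downstream use, assuming `indexedData` still carries the "label" column loaded from the libsvm file:

import org.apache.spark.ml.classification.DecisionTreeClassifier

// The classifier picks up the categorical-feature metadata from the "indexed" column.
val dt = new DecisionTreeClassifier()
  .setLabelCol("label")
  .setFeaturesCol("indexed")
val treeModel = dt.fit(indexedData)
println("Learned classification tree:\n" + treeModel.toDebugString)
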
+ */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; + +// $example on$ +import com.google.common.collect.Lists; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.attribute.Attribute; +import org.apache.spark.ml.attribute.AttributeGroup; +import org.apache.spark.ml.attribute.NumericAttribute; +import org.apache.spark.ml.feature.VectorSlicer; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.*; +// $example off$ + +public class JavaVectorSlicerExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaVectorSlicerExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext jsql = new SQLContext(jsc); + + // $example on$ + Attribute[] attrs = new Attribute[]{ + NumericAttribute.defaultAttr().withName("f1"), + NumericAttribute.defaultAttr().withName("f2"), + NumericAttribute.defaultAttr().withName("f3") + }; + AttributeGroup group = new AttributeGroup("userFeatures", attrs); + + JavaRDD<Row> jrdd = jsc.parallelize(Lists.newArrayList( + RowFactory.create(Vectors.sparse(3, new int[]{0, 1}, new double[]{-2.0, 2.3})), + RowFactory.create(Vectors.dense(-2.0, 2.3, 0.0)) + )); + + DataFrame dataset = jsql.createDataFrame(jrdd, (new StructType()).add(group.toStructField())); + + VectorSlicer vectorSlicer = new VectorSlicer() + .setInputCol("userFeatures").setOutputCol("features"); + + vectorSlicer.setIndices(new int[]{1}).setNames(new String[]{"f3"}); + // or vectorSlicer.setIndices(new int[]{1, 2}), or vectorSlicer.setNames(new String[]{"f2", "f3"}) + + DataFrame output = vectorSlicer.transform(dataset); + + System.out.println(output.select("userFeatures", "features").first()); + // $example off$ + jsc.stop(); + } +} + diff --git a/examples/src/main/python/ml/binarizer_example.py b/examples/src/main/python/ml/binarizer_example.py new file mode 100644 index 000000000000..317cfa638a5a --- /dev/null +++ b/examples/src/main/python/ml/binarizer_example.py @@ -0,0 +1,43 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.feature import Binarizer +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="BinarizerExample") + sqlContext = SQLContext(sc) + + # $example on$ + continuousDataFrame = sqlContext.createDataFrame([ + (0, 0.1), + (1, 0.8), + (2, 0.2) + ], ["label", "feature"]) + binarizer = Binarizer(threshold=0.5, inputCol="feature", outputCol="binarized_feature") + binarizedDataFrame = binarizer.transform(continuousDataFrame) + binarizedFeatures = binarizedDataFrame.select("binarized_feature") + for binarized_feature, in binarizedFeatures.collect(): + print(binarized_feature) + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/bucketizer_example.py b/examples/src/main/python/ml/bucketizer_example.py new file mode 100644 index 000000000000..4304255f350d --- /dev/null +++ b/examples/src/main/python/ml/bucketizer_example.py @@ -0,0 +1,43 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.feature import Bucketizer +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="BucketizerExample") + sqlContext = SQLContext(sc) + + # $example on$ + splits = [-float("inf"), -0.5, 0.0, 0.5, float("inf")] + + data = [(-0.5,), (-0.3,), (0.0,), (0.2,)] + dataFrame = sqlContext.createDataFrame(data, ["features"]) + + bucketizer = Bucketizer(splits=splits, inputCol="features", outputCol="bucketedFeatures") + + # Transform original data into its bucket index. + bucketedData = bucketizer.transform(dataFrame) + bucketedData.show() + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/elementwise_product_example.py b/examples/src/main/python/ml/elementwise_product_example.py new file mode 100644 index 000000000000..c85cb0d89543 --- /dev/null +++ b/examples/src/main/python/ml/elementwise_product_example.py @@ -0,0 +1,39 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.feature import ElementwiseProduct +from pyspark.mllib.linalg import Vectors +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="ElementwiseProductExample") + sqlContext = SQLContext(sc) + + # $example on$ + data = [(Vectors.dense([1.0, 2.0, 3.0]),), (Vectors.dense([4.0, 5.0, 6.0]),)] + df = sqlContext.createDataFrame(data, ["vector"]) + transformer = ElementwiseProduct(scalingVec=Vectors.dense([0.0, 1.0, 2.0]), + inputCol="vector", outputCol="transformedVector") + transformer.transform(df).show() + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/n_gram_example.py b/examples/src/main/python/ml/n_gram_example.py new file mode 100644 index 000000000000..f2d85f53e721 --- /dev/null +++ b/examples/src/main/python/ml/n_gram_example.py @@ -0,0 +1,42 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.feature import NGram +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="NGramExample") + sqlContext = SQLContext(sc) + + # $example on$ + wordDataFrame = sqlContext.createDataFrame([ + (0, ["Hi", "I", "heard", "about", "Spark"]), + (1, ["I", "wish", "Java", "could", "use", "case", "classes"]), + (2, ["Logistic", "regression", "models", "are", "neat"]) + ], ["label", "words"]) + ngram = NGram(inputCol="words", outputCol="ngrams") + ngramDataFrame = ngram.transform(wordDataFrame) + for ngrams_label in ngramDataFrame.select("ngrams", "label").take(3): + print(ngrams_label) + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/normalizer_example.py b/examples/src/main/python/ml/normalizer_example.py new file mode 100644 index 000000000000..d490221474c2 --- /dev/null +++ b/examples/src/main/python/ml/normalizer_example.py @@ -0,0 +1,43 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.feature import Normalizer +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="NormalizerExample") + sqlContext = SQLContext(sc) + + # $example on$ + dataFrame = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + + # Normalize each Vector using $L^1$ norm. + normalizer = Normalizer(inputCol="features", outputCol="normFeatures", p=1.0) + l1NormData = normalizer.transform(dataFrame) + l1NormData.show() + + # Normalize each Vector using $L^\infty$ norm. + lInfNormData = normalizer.transform(dataFrame, {normalizer.p: float("inf")}) + lInfNormData.show() + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/onehot_encoder_example.py b/examples/src/main/python/ml/onehot_encoder_example.py new file mode 100644 index 000000000000..0f94c26638d3 --- /dev/null +++ b/examples/src/main/python/ml/onehot_encoder_example.py @@ -0,0 +1,48 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.feature import OneHotEncoder, StringIndexer +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="OneHotEncoderExample") + sqlContext = SQLContext(sc) + + # $example on$ + df = sqlContext.createDataFrame([ + (0, "a"), + (1, "b"), + (2, "c"), + (3, "a"), + (4, "a"), + (5, "c") + ], ["id", "category"]) + + stringIndexer = StringIndexer(inputCol="category", outputCol="categoryIndex") + model = stringIndexer.fit(df) + indexed = model.transform(df) + encoder = OneHotEncoder(dropLast=False, inputCol="categoryIndex", outputCol="categoryVec") + encoded = encoder.transform(indexed) + encoded.select("id", "categoryVec").show() + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/pca_example.py b/examples/src/main/python/ml/pca_example.py new file mode 100644 index 000000000000..a17181f1b8a5 --- /dev/null +++ b/examples/src/main/python/ml/pca_example.py @@ -0,0 +1,42 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. 
+# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.feature import PCA +from pyspark.mllib.linalg import Vectors +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="PCAExample") + sqlContext = SQLContext(sc) + + # $example on$ + data = [(Vectors.sparse(5, [(1, 1.0), (3, 7.0)]),), + (Vectors.dense([2.0, 0.0, 3.0, 4.0, 5.0]),), + (Vectors.dense([4.0, 0.0, 0.0, 6.0, 7.0]),)] + df = sqlContext.createDataFrame(data, ["features"]) + pca = PCA(k=3, inputCol="features", outputCol="pcaFeatures") + model = pca.fit(df) + result = model.transform(df).select("pcaFeatures") + result.show(truncate=False) + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/polynomial_expansion_example.py b/examples/src/main/python/ml/polynomial_expansion_example.py new file mode 100644 index 000000000000..3d4fafd1a42e --- /dev/null +++ b/examples/src/main/python/ml/polynomial_expansion_example.py @@ -0,0 +1,43 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.feature import PolynomialExpansion +from pyspark.mllib.linalg import Vectors +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="PolynomialExpansionExample") + sqlContext = SQLContext(sc) + + # $example on$ + df = sqlContext\ + .createDataFrame([(Vectors.dense([-2.0, 2.3]), ), + (Vectors.dense([0.0, 0.0]), ), + (Vectors.dense([0.6, -1.1]), )], + ["features"]) + px = PolynomialExpansion(degree=2, inputCol="features", outputCol="polyFeatures") + polyDF = px.transform(df) + for expanded in polyDF.select("polyFeatures").take(3): + print(expanded) + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/rformula_example.py b/examples/src/main/python/ml/rformula_example.py new file mode 100644 index 000000000000..b544a1470076 --- /dev/null +++ b/examples/src/main/python/ml/rformula_example.py @@ -0,0 +1,44 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. 
See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.feature import RFormula +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="RFormulaExample") + sqlContext = SQLContext(sc) + + # $example on$ + dataset = sqlContext.createDataFrame( + [(7, "US", 18, 1.0), + (8, "CA", 12, 0.0), + (9, "NZ", 15, 0.0)], + ["id", "country", "hour", "clicked"]) + formula = RFormula( + formula="clicked ~ country + hour", + featuresCol="features", + labelCol="label") + output = formula.fit(dataset).transform(dataset) + output.select("features", "label").show() + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/standard_scaler_example.py b/examples/src/main/python/ml/standard_scaler_example.py new file mode 100644 index 000000000000..ae7aa85005bc --- /dev/null +++ b/examples/src/main/python/ml/standard_scaler_example.py @@ -0,0 +1,43 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.feature import StandardScaler +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="StandardScalerExample") + sqlContext = SQLContext(sc) + + # $example on$ + dataFrame = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + scaler = StandardScaler(inputCol="features", outputCol="scaledFeatures", + withStd=True, withMean=False) + + # Compute summary statistics by fitting the StandardScaler + scalerModel = scaler.fit(dataFrame) + + # Normalize each feature to have unit standard deviation. 
+ scaledData = scalerModel.transform(dataFrame) + scaledData.show() + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/stopwords_remover_example.py b/examples/src/main/python/ml/stopwords_remover_example.py new file mode 100644 index 000000000000..01f94af8ca75 --- /dev/null +++ b/examples/src/main/python/ml/stopwords_remover_example.py @@ -0,0 +1,40 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.feature import StopWordsRemover +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="StopWordsRemoverExample") + sqlContext = SQLContext(sc) + + # $example on$ + sentenceData = sqlContext.createDataFrame([ + (0, ["I", "saw", "the", "red", "balloon"]), + (1, ["Mary", "had", "a", "little", "lamb"]) + ], ["label", "raw"]) + + remover = StopWordsRemover(inputCol="raw", outputCol="filtered") + remover.transform(sentenceData).show(truncate=False) + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/string_indexer_example.py b/examples/src/main/python/ml/string_indexer_example.py new file mode 100644 index 000000000000..58a8cb5d56b7 --- /dev/null +++ b/examples/src/main/python/ml/string_indexer_example.py @@ -0,0 +1,39 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.feature import StringIndexer +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="StringIndexerExample") + sqlContext = SQLContext(sc) + + # $example on$ + df = sqlContext.createDataFrame( + [(0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")], + ["id", "category"]) + indexer = StringIndexer(inputCol="category", outputCol="categoryIndex") + indexed = indexer.fit(df).transform(df) + indexed.show() + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/tokenizer_example.py b/examples/src/main/python/ml/tokenizer_example.py new file mode 100644 index 000000000000..ce9b225be535 --- /dev/null +++ b/examples/src/main/python/ml/tokenizer_example.py @@ -0,0 +1,44 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.feature import Tokenizer, RegexTokenizer +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="TokenizerExample") + sqlContext = SQLContext(sc) + + # $example on$ + sentenceDataFrame = sqlContext.createDataFrame([ + (0, "Hi I heard about Spark"), + (1, "I wish Java could use case classes"), + (2, "Logistic,regression,models,are,neat") + ], ["label", "sentence"]) + tokenizer = Tokenizer(inputCol="sentence", outputCol="words") + wordsDataFrame = tokenizer.transform(sentenceDataFrame) + for words_label in wordsDataFrame.select("words", "label").take(3): + print(words_label) + regexTokenizer = RegexTokenizer(inputCol="sentence", outputCol="words", pattern="\\W") + # alternatively, pattern="\\w+", gaps(False) + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/vector_assembler_example.py b/examples/src/main/python/ml/vector_assembler_example.py new file mode 100644 index 000000000000..04f64839f188 --- /dev/null +++ b/examples/src/main/python/ml/vector_assembler_example.py @@ -0,0 +1,42 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.mllib.linalg import Vectors +from pyspark.ml.feature import VectorAssembler +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="VectorAssemblerExample") + sqlContext = SQLContext(sc) + + # $example on$ + dataset = sqlContext.createDataFrame( + [(0, 18, 1.0, Vectors.dense([0.0, 10.0, 0.5]), 1.0)], + ["id", "hour", "mobile", "userFeatures", "clicked"]) + assembler = VectorAssembler( + inputCols=["hour", "mobile", "userFeatures"], + outputCol="features") + output = assembler.transform(dataset) + print(output.select("features", "clicked").first()) + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/vector_indexer_example.py b/examples/src/main/python/ml/vector_indexer_example.py new file mode 100644 index 000000000000..146f41c1dd90 --- /dev/null +++ b/examples/src/main/python/ml/vector_indexer_example.py @@ -0,0 +1,40 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.feature import VectorIndexer +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="VectorIndexerExample") + sqlContext = SQLContext(sc) + + # $example on$ + data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + indexer = VectorIndexer(inputCol="features", outputCol="indexed", maxCategories=10) + indexerModel = indexer.fit(data) + + # Create new column "indexed" with categorical values transformed to indices + indexedData = indexerModel.transform(data) + indexedData.show() + # $example off$ + + sc.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/BinarizerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/BinarizerExample.scala new file mode 100644 index 000000000000..e724aa587294 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/BinarizerExample.scala @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.Binarizer +// $example off$ +import org.apache.spark.sql.{DataFrame, SQLContext} +import org.apache.spark.{SparkConf, SparkContext} + +object BinarizerExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("BinarizerExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + // $example on$ + val data = Array((0, 0.1), (1, 0.8), (2, 0.2)) + val dataFrame: DataFrame = sqlContext.createDataFrame(data).toDF("label", "feature") + + val binarizer: Binarizer = new Binarizer() + .setInputCol("feature") + .setOutputCol("binarized_feature") + .setThreshold(0.5) + + val binarizedDataFrame = binarizer.transform(dataFrame) + val binarizedFeatures = binarizedDataFrame.select("binarized_feature") + binarizedFeatures.collect().foreach(println) + // $example off$ + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/BucketizerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/BucketizerExample.scala new file mode 100644 index 000000000000..7c75e3d72b47 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/BucketizerExample.scala @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.Bucketizer +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object BucketizerExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("BucketizerExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val splits = Array(Double.NegativeInfinity, -0.5, 0.0, 0.5, Double.PositiveInfinity) + + val data = Array(-0.5, -0.3, 0.0, 0.2) + val dataFrame = sqlContext.createDataFrame(data.map(Tuple1.apply)).toDF("features") + + val bucketizer = new Bucketizer() + .setInputCol("features") + .setOutputCol("bucketedFeatures") + .setSplits(splits) + + // Transform original data into its bucket index. 
+ val bucketedData = bucketizer.transform(dataFrame) + bucketedData.show() + // $example off$ + sc.stop() + } +} +// scalastyle:on println + diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DCTExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DCTExample.scala new file mode 100644 index 000000000000..314c2c28a2a1 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DCTExample.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.DCT +import org.apache.spark.mllib.linalg.Vectors +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object DCTExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("DCTExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val data = Seq( + Vectors.dense(0.0, 1.0, -2.0, 3.0), + Vectors.dense(-1.0, 2.0, 4.0, -7.0), + Vectors.dense(14.0, -2.0, -5.0, 1.0)) + + val df = sqlContext.createDataFrame(data.map(Tuple1.apply)).toDF("features") + + val dct = new DCT() + .setInputCol("features") + .setOutputCol("featuresDCT") + .setInverse(false) + + val dctDf = dct.transform(df) + dctDf.select("featuresDCT").show(3) + // $example off$ + sc.stop() + } +} +// scalastyle:on println + diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/ElementWiseProductExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/ElementWiseProductExample.scala new file mode 100644 index 000000000000..872de51dc75d --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/ElementWiseProductExample.scala @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.ElementwiseProduct +import org.apache.spark.mllib.linalg.Vectors +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object ElementwiseProductExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("ElementwiseProductExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + // Create some vector data; also works for sparse vectors + val dataFrame = sqlContext.createDataFrame(Seq( + ("a", Vectors.dense(1.0, 2.0, 3.0)), + ("b", Vectors.dense(4.0, 5.0, 6.0)))).toDF("id", "vector") + + val transformingVector = Vectors.dense(0.0, 1.0, 2.0) + val transformer = new ElementwiseProduct() + .setScalingVec(transformingVector) + .setInputCol("vector") + .setOutputCol("transformedVector") + + // Batch transform the vectors to create new column: + transformer.transform(dataFrame).show() + // $example off$ + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/MinMaxScalerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/MinMaxScalerExample.scala new file mode 100644 index 000000000000..fb7f28c9886b --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/MinMaxScalerExample.scala @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.MinMaxScaler +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object MinMaxScalerExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("MinMaxScalerExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val dataFrame = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + + val scaler = new MinMaxScaler() + .setInputCol("features") + .setOutputCol("scaledFeatures") + + // Compute summary statistics and generate MinMaxScalerModel + val scalerModel = scaler.fit(dataFrame) + + // rescale each feature to range [min, max]. 
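+    // (the target range defaults to [0.0, 1.0]; it can be changed with setMin and setMax)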
+ val scaledData = scalerModel.transform(dataFrame) + scaledData.show() + // $example off$ + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/NGramExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/NGramExample.scala new file mode 100644 index 000000000000..8a85f71b56f3 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/NGramExample.scala @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.NGram +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object NGramExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("NGramExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val wordDataFrame = sqlContext.createDataFrame(Seq( + (0, Array("Hi", "I", "heard", "about", "Spark")), + (1, Array("I", "wish", "Java", "could", "use", "case", "classes")), + (2, Array("Logistic", "regression", "models", "are", "neat")) + )).toDF("label", "words") + + val ngram = new NGram().setInputCol("words").setOutputCol("ngrams") + val ngramDataFrame = ngram.transform(wordDataFrame) + ngramDataFrame.take(3).map(_.getAs[Stream[String]]("ngrams").toList).foreach(println) + // $example off$ + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/NormalizerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/NormalizerExample.scala new file mode 100644 index 000000000000..1990b55e8c5e --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/NormalizerExample.scala @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.Normalizer +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object NormalizerExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("NormalizerExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val dataFrame = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + + // Normalize each Vector using $L^1$ norm. + val normalizer = new Normalizer() + .setInputCol("features") + .setOutputCol("normFeatures") + .setP(1.0) + + val l1NormData = normalizer.transform(dataFrame) + l1NormData.show() + + // Normalize each Vector using $L^\infty$ norm. + val lInfNormData = normalizer.transform(dataFrame, normalizer.p -> Double.PositiveInfinity) + lInfNormData.show() + // $example off$ + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/OneHotEncoderExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/OneHotEncoderExample.scala new file mode 100644 index 000000000000..66602e211850 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/OneHotEncoderExample.scala @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.{OneHotEncoder, StringIndexer} +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object OneHotEncoderExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("OneHotEncoderExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val df = sqlContext.createDataFrame(Seq( + (0, "a"), + (1, "b"), + (2, "c"), + (3, "a"), + (4, "a"), + (5, "c") + )).toDF("id", "category") + + val indexer = new StringIndexer() + .setInputCol("category") + .setOutputCol("categoryIndex") + .fit(df) + val indexed = indexer.transform(df) + + val encoder = new OneHotEncoder() + .setInputCol("categoryIndex") + .setOutputCol("categoryVec") + val encoded = encoder.transform(indexed) + encoded.select("id", "categoryVec").show() + // $example off$ + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/PCAExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/PCAExample.scala new file mode 100644 index 000000000000..4c806f71a32c --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/PCAExample.scala @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.PCA +import org.apache.spark.mllib.linalg.Vectors +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object PCAExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("PCAExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val data = Array( + Vectors.sparse(5, Seq((1, 1.0), (3, 7.0))), + Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0), + Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0) + ) + val df = sqlContext.createDataFrame(data.map(Tuple1.apply)).toDF("features") + val pca = new PCA() + .setInputCol("features") + .setOutputCol("pcaFeatures") + .setK(3) + .fit(df) + val pcaDF = pca.transform(df) + val result = pcaDF.select("pcaFeatures") + result.show() + // $example off$ + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/PolynomialExpansionExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/PolynomialExpansionExample.scala new file mode 100644 index 000000000000..39fb79af3576 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/PolynomialExpansionExample.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.PolynomialExpansion +import org.apache.spark.mllib.linalg.Vectors +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object PolynomialExpansionExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("PolynomialExpansionExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val data = Array( + Vectors.dense(-2.0, 2.3), + Vectors.dense(0.0, 0.0), + Vectors.dense(0.6, -1.1) + ) + val df = sqlContext.createDataFrame(data.map(Tuple1.apply)).toDF("features") + val polynomialExpansion = new PolynomialExpansion() + .setInputCol("features") + .setOutputCol("polyFeatures") + .setDegree(3) + val polyDF = polynomialExpansion.transform(df) + polyDF.select("polyFeatures").take(3).foreach(println) + // $example off$ + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/RFormulaExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/RFormulaExample.scala new file mode 100644 index 000000000000..286866edea50 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/RFormulaExample.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.RFormula +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object RFormulaExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("RFormulaExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val dataset = sqlContext.createDataFrame(Seq( + (7, "US", 18, 1.0), + (8, "CA", 12, 0.0), + (9, "NZ", 15, 0.0) + )).toDF("id", "country", "hour", "clicked") + val formula = new RFormula() + .setFormula("clicked ~ country + hour") + .setFeaturesCol("features") + .setLabelCol("label") + val output = formula.fit(dataset).transform(dataset) + output.select("features", "label").show() + // $example off$ + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/StandardScalerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/StandardScalerExample.scala new file mode 100644 index 000000000000..e0a41e383a7e --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/StandardScalerExample.scala @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.StandardScaler +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object StandardScalerExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("StandardScalerExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val dataFrame = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + + val scaler = new StandardScaler() + .setInputCol("features") + .setOutputCol("scaledFeatures") + .setWithStd(true) + .setWithMean(false) + + // Compute summary statistics by fitting the StandardScaler. + val scalerModel = scaler.fit(dataFrame) + + // Normalize each feature to have unit standard deviation. + val scaledData = scalerModel.transform(dataFrame) + scaledData.show() + // $example off$ + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/StopWordsRemoverExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/StopWordsRemoverExample.scala new file mode 100644 index 000000000000..655ffce08d3a --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/StopWordsRemoverExample.scala @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.StopWordsRemover +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object StopWordsRemoverExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("StopWordsRemoverExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val remover = new StopWordsRemover() + .setInputCol("raw") + .setOutputCol("filtered") + + val dataSet = sqlContext.createDataFrame(Seq( + (0, Seq("I", "saw", "the", "red", "baloon")), + (1, Seq("Mary", "had", "a", "little", "lamb")) + )).toDF("id", "raw") + + remover.transform(dataSet).show() + // $example off$ + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/StringIndexerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/StringIndexerExample.scala new file mode 100644 index 000000000000..9fa494cd2473 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/StringIndexerExample.scala @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.StringIndexer +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object StringIndexerExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("StringIndexerExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val df = sqlContext.createDataFrame( + Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")) + ).toDF("id", "category") + + val indexer = new StringIndexer() + .setInputCol("category") + .setOutputCol("categoryIndex") + + val indexed = indexer.fit(df).transform(df) + indexed.show() + // $example off$ + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/TokenizerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/TokenizerExample.scala new file mode 100644 index 000000000000..01e0d1388a2f --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/TokenizerExample.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.{RegexTokenizer, Tokenizer} +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object TokenizerExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("TokenizerExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val sentenceDataFrame = sqlContext.createDataFrame(Seq( + (0, "Hi I heard about Spark"), + (1, "I wish Java could use case classes"), + (2, "Logistic,regression,models,are,neat") + )).toDF("label", "sentence") + + val tokenizer = new Tokenizer().setInputCol("sentence").setOutputCol("words") + val regexTokenizer = new RegexTokenizer() + .setInputCol("sentence") + .setOutputCol("words") + .setPattern("\\W") // alternatively .setPattern("\\w+").setGaps(false) + + val tokenized = tokenizer.transform(sentenceDataFrame) + tokenized.select("words", "label").take(3).foreach(println) + val regexTokenized = regexTokenizer.transform(sentenceDataFrame) + regexTokenized.select("words", "label").take(3).foreach(println) + // $example off$ + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/VectorAssemblerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/VectorAssemblerExample.scala new file mode 100644 index 000000000000..d527924419f8 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/VectorAssemblerExample.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.VectorAssembler +import org.apache.spark.mllib.linalg.Vectors +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object VectorAssemblerExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("VectorAssemblerExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val dataset = sqlContext.createDataFrame( + Seq((0, 18, 1.0, Vectors.dense(0.0, 10.0, 0.5), 1.0)) + ).toDF("id", "hour", "mobile", "userFeatures", "clicked") + + val assembler = new VectorAssembler() + .setInputCols(Array("hour", "mobile", "userFeatures")) + .setOutputCol("features") + + val output = assembler.transform(dataset) + println(output.select("features", "clicked").first()) + // $example off$ + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/VectorIndexerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/VectorIndexerExample.scala new file mode 100644 index 000000000000..685891c164e7 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/VectorIndexerExample.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.VectorIndexer +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object VectorIndexerExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("VectorIndexerExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + + val indexer = new VectorIndexer() + .setInputCol("features") + .setOutputCol("indexed") + .setMaxCategories(10) + + val indexerModel = indexer.fit(data) + + val categoricalFeatures: Set[Int] = indexerModel.categoryMaps.keys.toSet + println(s"Chose ${categoricalFeatures.size} categorical features: " + + categoricalFeatures.mkString(", ")) + + // Create new column "indexed" with categorical values transformed to indices + val indexedData = indexerModel.transform(data) + indexedData.show() + // $example off$ + sc.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/VectorSlicerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/VectorSlicerExample.scala new file mode 100644 index 000000000000..04f19829eff8 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/VectorSlicerExample.scala @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NumericAttribute} +import org.apache.spark.ml.feature.VectorSlicer +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.sql.Row +import org.apache.spark.sql.types.StructType +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +object VectorSlicerExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("VectorSlicerExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val data = Array(Row(Vectors.dense(-2.0, 2.3, 0.0))) + + val defaultAttr = NumericAttribute.defaultAttr + val attrs = Array("f1", "f2", "f3").map(defaultAttr.withName) + val attrGroup = new AttributeGroup("userFeatures", attrs.asInstanceOf[Array[Attribute]]) + + val dataRDD = sc.parallelize(data) + val dataset = sqlContext.createDataFrame(dataRDD, StructType(Array(attrGroup.toStructField()))) + + val slicer = new VectorSlicer().setInputCol("userFeatures").setOutputCol("features") + + slicer.setIndices(Array(1)).setNames(Array("f3")) + // or slicer.setIndices(Array(1, 2)), or slicer.setNames(Array("f2", "f3")) + + val output = slicer.transform(dataset) + println(output.select("userFeatures", "features").first()) + // $example off$ + sc.stop() + } +} +// scalastyle:on println From 7a8e587dc04c2fabc875d1754eae7f85b4fba6ba Mon Sep 17 00:00:00 2001 From: Andrew Ray Date: Wed, 9 Dec 2015 17:16:01 -0800 Subject: [PATCH 1720/1824] [SPARK-12211][DOC][GRAPHX] Fix version number in graphx doc for migration from 1.1 Migration from 1.1 section added to the GraphX doc in 1.2.0 (see https://spark.apache.org/docs/1.2.0/graphx-programming-guide.html#migrating-from-spark-11) uses \{{site.SPARK_VERSION}} as the version where changes were introduced, it should be just 1.2. Author: Andrew Ray Closes #10206 from aray/graphx-doc-1.1-migration. --- docs/graphx-programming-guide.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/graphx-programming-guide.md b/docs/graphx-programming-guide.md index 6a512ab234bb..9dea9b5904d2 100644 --- a/docs/graphx-programming-guide.md +++ b/docs/graphx-programming-guide.md @@ -70,7 +70,7 @@ operators (e.g., [subgraph](#structural_operators), [joinVertices](#join_operato ## Migrating from Spark 1.1 -GraphX in Spark {{site.SPARK_VERSION}} contains a few user facing API changes: +GraphX in Spark 1.2 contains a few user facing API changes: 1. To improve performance we have introduced a new version of [`mapReduceTriplets`][Graph.mapReduceTriplets] called From 8770bd1213f9b1051dabde9c5424ae7b32143a44 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Wed, 9 Dec 2015 17:24:04 -0800 Subject: [PATCH 1721/1824] [SPARK-12165][ADDENDUM] Fix outdated comments on unroll test JoshRosen Author: Andrew Or Closes #10229 from andrewor14/unroll-test-comments. 
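A quick sketch of the arithmetic behind the updated comments (assuming the suite's 1000-byte storage region and the 0.4 unroll fraction the comments refer to; all values are taken from the test below):

    val maxStorageMem = 1000L
    val maxUnrollMem = (maxStorageMem * 0.4).toLong  // 400 bytes may be used for unrolling
    // first request: 60 + 240 = 300 <= 400, so it is granted; storage is full,
    // so 860 + 240 - 1000 = 100 bytes of cached blocks are evicted
    // second request: 300 + 150 = 450 > 400, so it fails even after evicting
    // the 400 - 300 = 100 bytes that unrolling may still claim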
--- .../spark/memory/StaticMemoryManagerSuite.scala | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/memory/StaticMemoryManagerSuite.scala b/core/src/test/scala/org/apache/spark/memory/StaticMemoryManagerSuite.scala index 6700b94f0f57..272253bc94e9 100644 --- a/core/src/test/scala/org/apache/spark/memory/StaticMemoryManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/memory/StaticMemoryManagerSuite.scala @@ -163,15 +163,20 @@ class StaticMemoryManagerSuite extends MemoryManagerSuite { assertEvictBlocksToFreeSpaceNotCalled(ms) assert(mm.storageMemoryUsed === 860L) // `spark.storage.unrollFraction` is 0.4, so the max unroll space is 400 bytes. - // Since we already occupy 60 bytes, we will try to evict only 400 - 60 = 340 bytes. + // As of this point, cache memory is 800 bytes and current unroll memory is 60 bytes. + // Requesting 240 more bytes of unroll memory will leave our total unroll memory at + // 300 bytes, still under the 400-byte limit. Therefore, all 240 bytes are granted. assert(mm.acquireUnrollMemory(dummyBlock, 240L, evictedBlocks)) - assertEvictBlocksToFreeSpaceCalled(ms, 100L) + assertEvictBlocksToFreeSpaceCalled(ms, 100L) // 860 + 240 - 1000 when(ms.currentUnrollMemory).thenReturn(300L) // 60 + 240 assert(mm.storageMemoryUsed === 1000L) evictedBlocks.clear() + // We already have 300 bytes of unroll memory, so requesting 150 more will leave us + // above the 400-byte limit. Since there is not enough free memory, this request will + // fail even after evicting as much as we can (400 - 300 = 100 bytes). assert(!mm.acquireUnrollMemory(dummyBlock, 150L, evictedBlocks)) - assertEvictBlocksToFreeSpaceCalled(ms, 100L) // 400 - 300 - assert(mm.storageMemoryUsed === 900L) // 100 bytes were evicted + assertEvictBlocksToFreeSpaceCalled(ms, 100L) + assert(mm.storageMemoryUsed === 900L) // Release beyond what was acquired mm.releaseUnrollMemory(maxStorageMem) assert(mm.storageMemoryUsed === 0L) From ac8cdf1cdc148bd21290ecf4d4f9874f8c87cc14 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Wed, 9 Dec 2015 18:09:36 -0800 Subject: [PATCH 1722/1824] [SPARK-11678][SQL][DOCS] Document basePath in the programming guide. This PR adds documentation for `basePath`, which is a new parameter used by `HadoopFsRelation`. The compiled doc is shown below. ![image](https://cloud.githubusercontent.com/assets/2072857/11673132/1ba01192-9dcb-11e5-98d9-ac0b4e92e98c.png) JIRA: https://issues.apache.org/jira/browse/SPARK-11678 Author: Yin Huai Closes #10211 from yhuai/basePathDoc. --- docs/sql-programming-guide.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 9f87accd30f4..3f9a831eddc8 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1233,6 +1233,13 @@ infer the data types of the partitioning columns. For these use cases, the automatic type inference can be configured by `spark.sql.sources.partitionColumnTypeInference.enabled`, which defaults to `true`. When type inference is disabled, string type will be used for the partitioning columns. +Starting from Spark 1.6.0, partition discovery only finds partitions under the given paths +by default. For the above example, if users pass `path/to/table/gender=male` to either +`SQLContext.read.parquet` or `SQLContext.read.load`, `gender` will not be considered as a +partitioning column.
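+For illustration only, a minimal sketch of this behavior and of the `basePath` option
+described next (assuming the example directory layout above and Parquet data; the
+variable names are illustrative):
+
+{% highlight scala %}
+// Partition discovery starts at the given leaf directory, so no "gender" column is derived.
+val withoutBasePath = sqlContext.read.parquet("path/to/table/gender=male")
+
+// Declaring the table root via the basePath option keeps "gender" as a partitioning column.
+val withBasePath = sqlContext.read
+  .option("basePath", "path/to/table/")
+  .parquet("path/to/table/gender=male")
+{% endhighlight %}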
If users need to specify the base path that partition discovery +should start with, they can set `basePath` in the data source options. For example, +when `path/to/table/gender=male` is the path of the data and +users set `basePath` to `path/to/table/`, `gender` will be a partitioning column. ### Schema Merging From 2166c2a75083c2262e071a652dd52b1a33348b6e Mon Sep 17 00:00:00 2001 From: Mark Grover Date: Wed, 9 Dec 2015 18:37:35 -0800 Subject: [PATCH 1723/1824] [SPARK-11796] Fix httpclient and httpcore depedency issues related to docker-client This commit fixes dependency issues which prevented the Docker-based JDBC integration tests from running in the Maven build. Author: Mark Grover Closes #9876 from markgrover/master_docker. --- docker-integration-tests/pom.xml | 22 ++++++++++++++++++++++ pom.xml | 28 ++++++++++++++++++++++++++++ sql/core/pom.xml | 2 -- 3 files changed, 50 insertions(+), 2 deletions(-) diff --git a/docker-integration-tests/pom.xml b/docker-integration-tests/pom.xml index dee0c4aa37ae..39d3f344615e 100644 --- a/docker-integration-tests/pom.xml +++ b/docker-integration-tests/pom.xml @@ -71,6 +71,18 @@
    + + org.apache.httpcomponents + httpclient + 4.5 + test + + + org.apache.httpcomponents + httpcore + 4.4.1 + test + com.google.guava @@ -109,6 +121,16 @@ ${project.version} test + + mysql + mysql-connector-java + test + + + org.postgresql + postgresql + test + prob=$prob, prediction=$prediction") - } - -{% endhighlight %} - - -
    -{% highlight java %} -import java.util.Arrays; -import java.util.List; - -import org.apache.spark.ml.Pipeline; -import org.apache.spark.ml.PipelineModel; -import org.apache.spark.ml.PipelineStage; -import org.apache.spark.ml.classification.LogisticRegression; -import org.apache.spark.ml.feature.HashingTF; -import org.apache.spark.ml.feature.Tokenizer; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; - -// Labeled and unlabeled instance types. -// Spark SQL can infer schema from Java Beans. -public class Document implements Serializable { - private long id; - private String text; - - public Document(long id, String text) { - this.id = id; - this.text = text; - } - - public long getId() { return this.id; } - public void setId(long id) { this.id = id; } - - public String getText() { return this.text; } - public void setText(String text) { this.text = text; } -} - -public class LabeledDocument extends Document implements Serializable { - private double label; - - public LabeledDocument(long id, String text, double label) { - super(id, text); - this.label = label; - } - - public double getLabel() { return this.label; } - public void setLabel(double label) { this.label = label; } -} - -// Prepare training documents, which are labeled. -DataFrame training = sqlContext.createDataFrame(Arrays.asList( - new LabeledDocument(0L, "a b c d e spark", 1.0), - new LabeledDocument(1L, "b d", 0.0), - new LabeledDocument(2L, "spark f g h", 1.0), - new LabeledDocument(3L, "hadoop mapreduce", 0.0) -), LabeledDocument.class); - -// Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr. -Tokenizer tokenizer = new Tokenizer() - .setInputCol("text") - .setOutputCol("words"); -HashingTF hashingTF = new HashingTF() - .setNumFeatures(1000) - .setInputCol(tokenizer.getOutputCol()) - .setOutputCol("features"); -LogisticRegression lr = new LogisticRegression() - .setMaxIter(10) - .setRegParam(0.01); -Pipeline pipeline = new Pipeline() - .setStages(new PipelineStage[] {tokenizer, hashingTF, lr}); - -// Fit the pipeline to training documents. -PipelineModel model = pipeline.fit(training); - -// Prepare test documents, which are unlabeled. -DataFrame test = sqlContext.createDataFrame(Arrays.asList( - new Document(4L, "spark i j k"), - new Document(5L, "l m n"), - new Document(6L, "mapreduce spark"), - new Document(7L, "apache hadoop") -), Document.class); - -// Make predictions on test documents. -DataFrame predictions = model.transform(test); -for (Row r: predictions.select("id", "text", "probability", "prediction").collect()) { - System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> prob=" + r.get(2) - + ", prediction=" + r.get(3)); -} - -{% endhighlight %} -
    - -
    -{% highlight python %} -from pyspark.ml import Pipeline -from pyspark.ml.classification import LogisticRegression -from pyspark.ml.feature import HashingTF, Tokenizer -from pyspark.sql import Row - -# Prepare training documents from a list of (id, text, label) tuples. -LabeledDocument = Row("id", "text", "label") -training = sqlContext.createDataFrame([ - (0L, "a b c d e spark", 1.0), - (1L, "b d", 0.0), - (2L, "spark f g h", 1.0), - (3L, "hadoop mapreduce", 0.0)], ["id", "text", "label"]) - -# Configure an ML pipeline, which consists of tree stages: tokenizer, hashingTF, and lr. -tokenizer = Tokenizer(inputCol="text", outputCol="words") -hashingTF = HashingTF(inputCol=tokenizer.getOutputCol(), outputCol="features") -lr = LogisticRegression(maxIter=10, regParam=0.01) -pipeline = Pipeline(stages=[tokenizer, hashingTF, lr]) - -# Fit the pipeline to training documents. -model = pipeline.fit(training) - -# Prepare test documents, which are unlabeled (id, text) tuples. -test = sqlContext.createDataFrame([ - (4L, "spark i j k"), - (5L, "l m n"), - (6L, "mapreduce spark"), - (7L, "apache hadoop")], ["id", "text"]) - -# Make predictions on test documents and print columns of interest. -prediction = model.transform(test) -selected = prediction.select("id", "text", "prediction") -for row in selected.collect(): - print(row) - -{% endhighlight %} -
    - - - -## Example: model selection via cross-validation - -An important task in ML is *model selection*, or using data to find the best model or parameters for a given task. This is also called *tuning*. -`Pipeline`s facilitate model selection by making it easy to tune an entire `Pipeline` at once, rather than tuning each element in the `Pipeline` separately. - -Currently, `spark.ml` supports model selection using the [`CrossValidator`](api/scala/index.html#org.apache.spark.ml.tuning.CrossValidator) class, which takes an `Estimator`, a set of `ParamMap`s, and an [`Evaluator`](api/scala/index.html#org.apache.spark.ml.evaluation.Evaluator). -`CrossValidator` begins by splitting the dataset into a set of *folds* which are used as separate training and test datasets; e.g., with `$k=3$` folds, `CrossValidator` will generate 3 (training, test) dataset pairs, each of which uses 2/3 of the data for training and 1/3 for testing. -`CrossValidator` iterates through the set of `ParamMap`s. For each `ParamMap`, it trains the given `Estimator` and evaluates it using the given `Evaluator`. - -The `Evaluator` can be a [`RegressionEvaluator`](api/scala/index.html#org.apache.spark.ml.evaluation.RegressionEvaluator) -for regression problems, a [`BinaryClassificationEvaluator`](api/scala/index.html#org.apache.spark.ml.evaluation.BinaryClassificationEvaluator) -for binary data, or a [`MultiClassClassificationEvaluator`](api/scala/index.html#org.apache.spark.ml.evaluation.MultiClassClassificationEvaluator) -for multiclass problems. The default metric used to choose the best `ParamMap` can be overriden by the `setMetric` -method in each of these evaluators. - -The `ParamMap` which produces the best evaluation metric (averaged over the `$k$` folds) is selected as the best model. -`CrossValidator` finally fits the `Estimator` using the best `ParamMap` and the entire dataset. - -The following example demonstrates using `CrossValidator` to select from a grid of parameters. -To help construct the parameter grid, we use the [`ParamGridBuilder`](api/scala/index.html#org.apache.spark.ml.tuning.ParamGridBuilder) utility. - -Note that cross-validation over a grid of parameters is expensive. -E.g., in the example below, the parameter grid has 3 values for `hashingTF.numFeatures` and 2 values for `lr.regParam`, and `CrossValidator` uses 2 folds. This multiplies out to `$(3 \times 2) \times 2 = 12$` different models being trained. -In realistic settings, it can be common to try many more parameters and use more folds (`$k=3$` and `$k=10$` are common). -In other words, using `CrossValidator` can be very expensive. -However, it is also a well-established method for choosing parameters which is more statistically sound than heuristic hand-tuning. - -
    - -
    -{% highlight scala %} -import org.apache.spark.ml.Pipeline -import org.apache.spark.ml.classification.LogisticRegression -import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator -import org.apache.spark.ml.feature.{HashingTF, Tokenizer} -import org.apache.spark.ml.tuning.{ParamGridBuilder, CrossValidator} -import org.apache.spark.mllib.linalg.Vector -import org.apache.spark.sql.Row - -// Prepare training data from a list of (id, text, label) tuples. -val training = sqlContext.createDataFrame(Seq( - (0L, "a b c d e spark", 1.0), - (1L, "b d", 0.0), - (2L, "spark f g h", 1.0), - (3L, "hadoop mapreduce", 0.0), - (4L, "b spark who", 1.0), - (5L, "g d a y", 0.0), - (6L, "spark fly", 1.0), - (7L, "was mapreduce", 0.0), - (8L, "e spark program", 1.0), - (9L, "a e c l", 0.0), - (10L, "spark compile", 1.0), - (11L, "hadoop software", 0.0) -)).toDF("id", "text", "label") - -// Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr. -val tokenizer = new Tokenizer() - .setInputCol("text") - .setOutputCol("words") -val hashingTF = new HashingTF() - .setInputCol(tokenizer.getOutputCol) - .setOutputCol("features") -val lr = new LogisticRegression() - .setMaxIter(10) -val pipeline = new Pipeline() - .setStages(Array(tokenizer, hashingTF, lr)) - -// We use a ParamGridBuilder to construct a grid of parameters to search over. -// With 3 values for hashingTF.numFeatures and 2 values for lr.regParam, -// this grid will have 3 x 2 = 6 parameter settings for CrossValidator to choose from. -val paramGrid = new ParamGridBuilder() - .addGrid(hashingTF.numFeatures, Array(10, 100, 1000)) - .addGrid(lr.regParam, Array(0.1, 0.01)) - .build() - -// We now treat the Pipeline as an Estimator, wrapping it in a CrossValidator instance. -// This will allow us to jointly choose parameters for all Pipeline stages. -// A CrossValidator requires an Estimator, a set of Estimator ParamMaps, and an Evaluator. -// Note that the evaluator here is a BinaryClassificationEvaluator and its default metric -// is areaUnderROC. -val cv = new CrossValidator() - .setEstimator(pipeline) - .setEvaluator(new BinaryClassificationEvaluator) - .setEstimatorParamMaps(paramGrid) - .setNumFolds(2) // Use 3+ in practice - -// Run cross-validation, and choose the best set of parameters. -val cvModel = cv.fit(training) - -// Prepare test documents, which are unlabeled (id, text) tuples. -val test = sqlContext.createDataFrame(Seq( - (4L, "spark i j k"), - (5L, "l m n"), - (6L, "mapreduce spark"), - (7L, "apache hadoop") -)).toDF("id", "text") - -// Make predictions on test documents. cvModel uses the best model found (lrModel). -cvModel.transform(test) - .select("id", "text", "probability", "prediction") - .collect() - .foreach { case Row(id: Long, text: String, prob: Vector, prediction: Double) => - println(s"($id, $text) --> prob=$prob, prediction=$prediction") - } - -{% endhighlight %} -
    - -
    -{% highlight java %} -import java.util.Arrays; -import java.util.List; - -import org.apache.spark.ml.Pipeline; -import org.apache.spark.ml.PipelineStage; -import org.apache.spark.ml.classification.LogisticRegression; -import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator; -import org.apache.spark.ml.feature.HashingTF; -import org.apache.spark.ml.feature.Tokenizer; -import org.apache.spark.ml.param.ParamMap; -import org.apache.spark.ml.tuning.CrossValidator; -import org.apache.spark.ml.tuning.CrossValidatorModel; -import org.apache.spark.ml.tuning.ParamGridBuilder; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; - -// Labeled and unlabeled instance types. -// Spark SQL can infer schema from Java Beans. -public class Document implements Serializable { - private long id; - private String text; - - public Document(long id, String text) { - this.id = id; - this.text = text; - } - - public long getId() { return this.id; } - public void setId(long id) { this.id = id; } - - public String getText() { return this.text; } - public void setText(String text) { this.text = text; } -} - -public class LabeledDocument extends Document implements Serializable { - private double label; - - public LabeledDocument(long id, String text, double label) { - super(id, text); - this.label = label; - } - - public double getLabel() { return this.label; } - public void setLabel(double label) { this.label = label; } -} - - -// Prepare training documents, which are labeled. -DataFrame training = sqlContext.createDataFrame(Arrays.asList( - new LabeledDocument(0L, "a b c d e spark", 1.0), - new LabeledDocument(1L, "b d", 0.0), - new LabeledDocument(2L, "spark f g h", 1.0), - new LabeledDocument(3L, "hadoop mapreduce", 0.0), - new LabeledDocument(4L, "b spark who", 1.0), - new LabeledDocument(5L, "g d a y", 0.0), - new LabeledDocument(6L, "spark fly", 1.0), - new LabeledDocument(7L, "was mapreduce", 0.0), - new LabeledDocument(8L, "e spark program", 1.0), - new LabeledDocument(9L, "a e c l", 0.0), - new LabeledDocument(10L, "spark compile", 1.0), - new LabeledDocument(11L, "hadoop software", 0.0) -), LabeledDocument.class); - -// Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr. -Tokenizer tokenizer = new Tokenizer() - .setInputCol("text") - .setOutputCol("words"); -HashingTF hashingTF = new HashingTF() - .setNumFeatures(1000) - .setInputCol(tokenizer.getOutputCol()) - .setOutputCol("features"); -LogisticRegression lr = new LogisticRegression() - .setMaxIter(10) - .setRegParam(0.01); -Pipeline pipeline = new Pipeline() - .setStages(new PipelineStage[] {tokenizer, hashingTF, lr}); - -// We use a ParamGridBuilder to construct a grid of parameters to search over. -// With 3 values for hashingTF.numFeatures and 2 values for lr.regParam, -// this grid will have 3 x 2 = 6 parameter settings for CrossValidator to choose from. -ParamMap[] paramGrid = new ParamGridBuilder() - .addGrid(hashingTF.numFeatures(), new int[]{10, 100, 1000}) - .addGrid(lr.regParam(), new double[]{0.1, 0.01}) - .build(); - -// We now treat the Pipeline as an Estimator, wrapping it in a CrossValidator instance. -// This will allow us to jointly choose parameters for all Pipeline stages. -// A CrossValidator requires an Estimator, a set of Estimator ParamMaps, and an Evaluator. -// Note that the evaluator here is a BinaryClassificationEvaluator and its default metric -// is areaUnderROC. 
-CrossValidator cv = new CrossValidator() - .setEstimator(pipeline) - .setEvaluator(new BinaryClassificationEvaluator()) - .setEstimatorParamMaps(paramGrid) - .setNumFolds(2); // Use 3+ in practice - -// Run cross-validation, and choose the best set of parameters. -CrossValidatorModel cvModel = cv.fit(training); - -// Prepare test documents, which are unlabeled. -DataFrame test = sqlContext.createDataFrame(Arrays.asList( - new Document(4L, "spark i j k"), - new Document(5L, "l m n"), - new Document(6L, "mapreduce spark"), - new Document(7L, "apache hadoop") -), Document.class); - -// Make predictions on test documents. cvModel uses the best model found (lrModel). -DataFrame predictions = cvModel.transform(test); -for (Row r: predictions.select("id", "text", "probability", "prediction").collect()) { - System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> prob=" + r.get(2) - + ", prediction=" + r.get(3)); -} - -{% endhighlight %} -
    - -
    - -## Example: model selection via train validation split -In addition to `CrossValidator` Spark also offers `TrainValidationSplit` for hyper-parameter tuning. -`TrainValidationSplit` only evaluates each combination of parameters once as opposed to k times in - case of `CrossValidator`. It is therefore less expensive, - but will not produce as reliable results when the training dataset is not sufficiently large. - -`TrainValidationSplit` takes an `Estimator`, a set of `ParamMap`s provided in the `estimatorParamMaps` parameter, -and an `Evaluator`. -It begins by splitting the dataset into two parts using `trainRatio` parameter -which are used as separate training and test datasets. For example with `$trainRatio=0.75$` (default), -`TrainValidationSplit` will generate a training and test dataset pair where 75% of the data is used for training and 25% for validation. -Similar to `CrossValidator`, `TrainValidationSplit` also iterates through the set of `ParamMap`s. -For each combination of parameters, it trains the given `Estimator` and evaluates it using the given `Evaluator`. -The `ParamMap` which produces the best evaluation metric is selected as the best option. -`TrainValidationSplit` finally fits the `Estimator` using the best `ParamMap` and the entire dataset. - -
    - -
    -{% highlight scala %} -import org.apache.spark.ml.evaluation.RegressionEvaluator -import org.apache.spark.ml.regression.LinearRegression -import org.apache.spark.ml.tuning.{ParamGridBuilder, TrainValidationSplit} - -// Prepare training and test data. -val data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") -val Array(training, test) = data.randomSplit(Array(0.9, 0.1), seed = 12345) - -val lr = new LinearRegression() - -// We use a ParamGridBuilder to construct a grid of parameters to search over. -// TrainValidationSplit will try all combinations of values and determine best model using -// the evaluator. -val paramGrid = new ParamGridBuilder() - .addGrid(lr.regParam, Array(0.1, 0.01)) - .addGrid(lr.fitIntercept) - .addGrid(lr.elasticNetParam, Array(0.0, 0.5, 1.0)) - .build() - -// In this case the estimator is simply the linear regression. -// A TrainValidationSplit requires an Estimator, a set of Estimator ParamMaps, and an Evaluator. -val trainValidationSplit = new TrainValidationSplit() - .setEstimator(lr) - .setEvaluator(new RegressionEvaluator) - .setEstimatorParamMaps(paramGrid) - // 80% of the data will be used for training and the remaining 20% for validation. - .setTrainRatio(0.8) - -// Run train validation split, and choose the best set of parameters. -val model = trainValidationSplit.fit(training) - -// Make predictions on test data. model is the model with combination of parameters -// that performed best. -model.transform(test) - .select("features", "label", "prediction") - .show() - -{% endhighlight %} -
    - -
    -{% highlight java %} -import org.apache.spark.ml.evaluation.RegressionEvaluator; -import org.apache.spark.ml.param.ParamMap; -import org.apache.spark.ml.regression.LinearRegression; -import org.apache.spark.ml.tuning.*; -import org.apache.spark.sql.DataFrame; - -DataFrame data = jsql.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt"); - -// Prepare training and test data. -DataFrame[] splits = data.randomSplit(new double[] {0.9, 0.1}, 12345); -DataFrame training = splits[0]; -DataFrame test = splits[1]; - -LinearRegression lr = new LinearRegression(); - -// We use a ParamGridBuilder to construct a grid of parameters to search over. -// TrainValidationSplit will try all combinations of values and determine best model using -// the evaluator. -ParamMap[] paramGrid = new ParamGridBuilder() - .addGrid(lr.regParam(), new double[] {0.1, 0.01}) - .addGrid(lr.fitIntercept()) - .addGrid(lr.elasticNetParam(), new double[] {0.0, 0.5, 1.0}) - .build(); - -// In this case the estimator is simply the linear regression. -// A TrainValidationSplit requires an Estimator, a set of Estimator ParamMaps, and an Evaluator. -TrainValidationSplit trainValidationSplit = new TrainValidationSplit() - .setEstimator(lr) - .setEvaluator(new RegressionEvaluator()) - .setEstimatorParamMaps(paramGrid) - .setTrainRatio(0.8); // 80% for training and the remaining 20% for validation - -// Run train validation split, and choose the best set of parameters. -TrainValidationSplitModel model = trainValidationSplit.fit(training); - -// Make predictions on test data. model is the model with combination of parameters -// that performed best. -model.transform(test) - .select("features", "label", "prediction") - .show(); - -{% endhighlight %} -
    - -
    \ No newline at end of file diff --git a/docs/ml-linear-methods.md b/docs/ml-linear-methods.md index 0c13d7d0c82b..a8754835cab9 100644 --- a/docs/ml-linear-methods.md +++ b/docs/ml-linear-methods.md @@ -1,148 +1,8 @@ --- layout: global -title: Linear Methods - ML -displayTitle: ML - Linear Methods +title: Linear methods - spark.ml +displayTitle: Linear methods - spark.ml --- - -`\[ -\newcommand{\R}{\mathbb{R}} -\newcommand{\E}{\mathbb{E}} -\newcommand{\x}{\mathbf{x}} -\newcommand{\y}{\mathbf{y}} -\newcommand{\wv}{\mathbf{w}} -\newcommand{\av}{\mathbf{\alpha}} -\newcommand{\bv}{\mathbf{b}} -\newcommand{\N}{\mathbb{N}} -\newcommand{\id}{\mathbf{I}} -\newcommand{\ind}{\mathbf{1}} -\newcommand{\0}{\mathbf{0}} -\newcommand{\unit}{\mathbf{e}} -\newcommand{\one}{\mathbf{1}} -\newcommand{\zero}{\mathbf{0}} -\]` - - -In MLlib, we implement popular linear methods such as logistic -regression and linear least squares with $L_1$ or $L_2$ regularization. -Refer to [the linear methods in mllib](mllib-linear-methods.html) for -details. In `spark.ml`, we also include Pipelines API for [Elastic -net](http://en.wikipedia.org/wiki/Elastic_net_regularization), a hybrid -of $L_1$ and $L_2$ regularization proposed in [Zou et al, Regularization -and variable selection via the elastic -net](http://users.stat.umn.edu/~zouxx019/Papers/elasticnet.pdf). -Mathematically, it is defined as a convex combination of the $L_1$ and -the $L_2$ regularization terms: -`\[ -\alpha \left( \lambda \|\wv\|_1 \right) + (1-\alpha) \left( \frac{\lambda}{2}\|\wv\|_2^2 \right) , \alpha \in [0, 1], \lambda \geq 0 -\]` -By setting $\alpha$ properly, elastic net contains both $L_1$ and $L_2$ -regularization as special cases. For example, if a [linear -regression](https://en.wikipedia.org/wiki/Linear_regression) model is -trained with the elastic net parameter $\alpha$ set to $1$, it is -equivalent to a -[Lasso](http://en.wikipedia.org/wiki/Least_squares#Lasso_method) model. -On the other hand, if $\alpha$ is set to $0$, the trained model reduces -to a [ridge -regression](http://en.wikipedia.org/wiki/Tikhonov_regularization) model. -We implement Pipelines API for both linear regression and logistic -regression with elastic net regularization. - -## Example: Logistic Regression - -The following example shows how to train a logistic regression model -with elastic net regularization. `elasticNetParam` corresponds to -$\alpha$ and `regParam` corresponds to $\lambda$. - -
    - -
    -{% include_example scala/org/apache/spark/examples/ml/LogisticRegressionWithElasticNetExample.scala %} -
    - -
    -{% include_example java/org/apache/spark/examples/ml/JavaLogisticRegressionWithElasticNetExample.java %} -
    - -
    -{% include_example python/ml/logistic_regression_with_elastic_net.py %} -
    - -
    - -The `spark.ml` implementation of logistic regression also supports -extracting a summary of the model over the training set. Note that the -predictions and metrics which are stored as `Dataframe` in -`BinaryLogisticRegressionSummary` are annotated `@transient` and hence -only available on the driver. - -
    - -
    - -[`LogisticRegressionTrainingSummary`](api/scala/index.html#org.apache.spark.ml.classification.LogisticRegressionTrainingSummary) -provides a summary for a -[`LogisticRegressionModel`](api/scala/index.html#org.apache.spark.ml.classification.LogisticRegressionModel). -Currently, only binary classification is supported and the -summary must be explicitly cast to -[`BinaryLogisticRegressionTrainingSummary`](api/scala/index.html#org.apache.spark.ml.classification.BinaryLogisticRegressionTrainingSummary). -This will likely change when multiclass classification is supported. - -Continuing the earlier example: - -{% include_example scala/org/apache/spark/examples/ml/LogisticRegressionSummaryExample.scala %} -
    - -
    -[`LogisticRegressionTrainingSummary`](api/java/org/apache/spark/ml/classification/LogisticRegressionTrainingSummary.html) -provides a summary for a -[`LogisticRegressionModel`](api/java/org/apache/spark/ml/classification/LogisticRegressionModel.html). -Currently, only binary classification is supported and the -summary must be explicitly cast to -[`BinaryLogisticRegressionTrainingSummary`](api/java/org/apache/spark/ml/classification/BinaryLogisticRegressionTrainingSummary.html). -This will likely change when multiclass classification is supported. - -Continuing the earlier example: - -{% include_example java/org/apache/spark/examples/ml/JavaLogisticRegressionSummaryExample.java %} -
    - - -
    -Logistic regression model summary is not yet supported in Python. -
    - -
    - -## Example: Linear Regression - -The interface for working with linear regression models and model -summaries is similar to the logistic regression case. The following -example demonstrates training an elastic net regularized linear -regression model and extracting model summary statistics. - -
    - -
    -{% include_example scala/org/apache/spark/examples/ml/LinearRegressionWithElasticNetExample.scala %} -
    - -
    -{% include_example java/org/apache/spark/examples/ml/JavaLinearRegressionWithElasticNetExample.java %} -
    - -
    - -{% include_example python/ml/linear_regression_with_elastic_net.py %} -
    - -
    - -# Optimization - -The optimization algorithm underlying the implementation is called -[Orthant-Wise Limited-memory -QuasiNewton](http://research-srv.microsoft.com/en-us/um/people/jfgao/paper/icml07scalable.pdf) -(OWL-QN). It is an extension of L-BFGS that can effectively handle L1 -regularization and elastic net. - + > This section has been moved into the + [classification and regression section](ml-classification-regression.html). diff --git a/docs/ml-survival-regression.md b/docs/ml-survival-regression.md index ab275213b9a8..856ceb2f4e7f 100644 --- a/docs/ml-survival-regression.md +++ b/docs/ml-survival-regression.md @@ -1,96 +1,8 @@ --- layout: global -title: Survival Regression - ML -displayTitle: ML - Survival Regression +title: Survival Regression - spark.ml +displayTitle: Survival Regression - spark.ml --- - -`\[ -\newcommand{\R}{\mathbb{R}} -\newcommand{\E}{\mathbb{E}} -\newcommand{\x}{\mathbf{x}} -\newcommand{\y}{\mathbf{y}} -\newcommand{\wv}{\mathbf{w}} -\newcommand{\av}{\mathbf{\alpha}} -\newcommand{\bv}{\mathbf{b}} -\newcommand{\N}{\mathbb{N}} -\newcommand{\id}{\mathbf{I}} -\newcommand{\ind}{\mathbf{1}} -\newcommand{\0}{\mathbf{0}} -\newcommand{\unit}{\mathbf{e}} -\newcommand{\one}{\mathbf{1}} -\newcommand{\zero}{\mathbf{0}} -\]` - - -In `spark.ml`, we implement the [Accelerated failure time (AFT)](https://en.wikipedia.org/wiki/Accelerated_failure_time_model) -model which is a parametric survival regression model for censored data. -It describes a model for the log of survival time, so it's often called -log-linear model for survival analysis. Different from -[Proportional hazards](https://en.wikipedia.org/wiki/Proportional_hazards_model) model -designed for the same purpose, the AFT model is more easily to parallelize -because each instance contribute to the objective function independently. - -Given the values of the covariates $x^{'}$, for random lifetime $t_{i}$ of -subjects i = 1, ..., n, with possible right-censoring, -the likelihood function under the AFT model is given as: -`\[ -L(\beta,\sigma)=\prod_{i=1}^n[\frac{1}{\sigma}f_{0}(\frac{\log{t_{i}}-x^{'}\beta}{\sigma})]^{\delta_{i}}S_{0}(\frac{\log{t_{i}}-x^{'}\beta}{\sigma})^{1-\delta_{i}} -\]` -Where $\delta_{i}$ is the indicator of the event has occurred i.e. uncensored or not. -Using $\epsilon_{i}=\frac{\log{t_{i}}-x^{'}\beta}{\sigma}$, the log-likelihood function -assumes the form: -`\[ -\iota(\beta,\sigma)=\sum_{i=1}^{n}[-\delta_{i}\log\sigma+\delta_{i}\log{f_{0}}(\epsilon_{i})+(1-\delta_{i})\log{S_{0}(\epsilon_{i})}] -\]` -Where $S_{0}(\epsilon_{i})$ is the baseline survivor function, -and $f_{0}(\epsilon_{i})$ is corresponding density function. - -The most commonly used AFT model is based on the Weibull distribution of the survival time. -The Weibull distribution for lifetime corresponding to extreme value distribution for -log of the lifetime, and the $S_{0}(\epsilon)$ function is: -`\[ -S_{0}(\epsilon_{i})=\exp(-e^{\epsilon_{i}}) -\]` -the $f_{0}(\epsilon_{i})$ function is: -`\[ -f_{0}(\epsilon_{i})=e^{\epsilon_{i}}\exp(-e^{\epsilon_{i}}) -\]` -The log-likelihood function for AFT model with Weibull distribution of lifetime is: -`\[ -\iota(\beta,\sigma)= -\sum_{i=1}^n[\delta_{i}\log\sigma-\delta_{i}\epsilon_{i}+e^{\epsilon_{i}}] -\]` -Due to minimizing the negative log-likelihood equivalent to maximum a posteriori probability, -the loss function we use to optimize is $-\iota(\beta,\sigma)$. 
-The gradient functions for $\beta$ and $\log\sigma$ respectively are: -`\[ -\frac{\partial (-\iota)}{\partial \beta}=\sum_{1=1}^{n}[\delta_{i}-e^{\epsilon_{i}}]\frac{x_{i}}{\sigma} -\]` -`\[ -\frac{\partial (-\iota)}{\partial (\log\sigma)}=\sum_{i=1}^{n}[\delta_{i}+(\delta_{i}-e^{\epsilon_{i}})\epsilon_{i}] -\]` - -The AFT model can be formulated as a convex optimization problem, -i.e. the task of finding a minimizer of a convex function $-\iota(\beta,\sigma)$ -that depends coefficients vector $\beta$ and the log of scale parameter $\log\sigma$. -The optimization algorithm underlying the implementation is L-BFGS. -The implementation matches the result from R's survival function -[survreg](https://stat.ethz.ch/R-manual/R-devel/library/survival/html/survreg.html) - -## Example: - -
    - -
    -{% include_example scala/org/apache/spark/examples/ml/AFTSurvivalRegressionExample.scala %} -
    - -
    -{% include_example java/org/apache/spark/examples/ml/JavaAFTSurvivalRegressionExample.java %} -
    - -
    -{% include_example python/ml/aft_survival_regression.py %} -
    - -
    \ No newline at end of file + > This section has been moved into the + [classification and regression section](ml-classification-regression.html#survival-regression). diff --git a/docs/mllib-classification-regression.md b/docs/mllib-classification-regression.md index 0210950b8990..aaf8bd465c9a 100644 --- a/docs/mllib-classification-regression.md +++ b/docs/mllib-classification-regression.md @@ -1,10 +1,10 @@ --- layout: global -title: Classification and Regression - MLlib -displayTitle: MLlib - Classification and Regression +title: Classification and Regression - spark.mllib +displayTitle: Classification and Regression - spark.mllib --- -MLlib supports various methods for +The `spark.mllib` package supports various methods for [binary classification](http://en.wikipedia.org/wiki/Binary_classification), [multiclass classification](http://en.wikipedia.org/wiki/Multiclass_classification), and diff --git a/docs/mllib-clustering.md b/docs/mllib-clustering.md index 8fbced6c87d9..48d64cd402b1 100644 --- a/docs/mllib-clustering.md +++ b/docs/mllib-clustering.md @@ -1,7 +1,7 @@ --- layout: global -title: Clustering - MLlib -displayTitle: MLlib - Clustering +title: Clustering - spark.mllib +displayTitle: Clustering - spark.mllib --- [Clustering](https://en.wikipedia.org/wiki/Cluster_analysis) is an unsupervised learning problem whereby we aim to group subsets @@ -10,19 +10,19 @@ often used for exploratory analysis and/or as a component of a hierarchical [supervised learning](https://en.wikipedia.org/wiki/Supervised_learning) pipeline (in which distinct classifiers or regression models are trained for each cluster). -MLlib supports the following models: +The `spark.mllib` package supports the following models: * Table of contents {:toc} ## K-means -[k-means](http://en.wikipedia.org/wiki/K-means_clustering) is one of the +[K-means](http://en.wikipedia.org/wiki/K-means_clustering) is one of the most commonly used clustering algorithms that clusters the data points into a -predefined number of clusters. The MLlib implementation includes a parallelized +predefined number of clusters. The `spark.mllib` implementation includes a parallelized variant of the [k-means++](http://en.wikipedia.org/wiki/K-means%2B%2B) method called [kmeans||](http://theory.stanford.edu/~sergei/papers/vldb12-kmpar.pdf). -The implementation in MLlib has the following parameters: +The implementation in `spark.mllib` has the following parameters: * *k* is the number of desired clusters. * *maxIterations* is the maximum number of iterations to run. @@ -171,7 +171,7 @@ sameModel = KMeansModel.load(sc, "myModelPath") A [Gaussian Mixture Model](http://en.wikipedia.org/wiki/Mixture_model#Multivariate_Gaussian_mixture_model) represents a composite distribution whereby points are drawn from one of *k* Gaussian sub-distributions, -each with its own probability. The MLlib implementation uses the +each with its own probability. The `spark.mllib` implementation uses the [expectation-maximization](http://en.wikipedia.org/wiki/Expectation%E2%80%93maximization_algorithm) algorithm to induce the maximum-likelihood model given a set of samples. The implementation has the following parameters: @@ -308,13 +308,13 @@ graph given pairwise similarties as edge properties, described in [Lin and Cohen, Power Iteration Clustering](http://www.icml2010.org/papers/387.pdf). It computes a pseudo-eigenvector of the normalized affinity matrix of the graph via [power iteration](http://en.wikipedia.org/wiki/Power_iteration) and uses it to cluster vertices. 
-MLlib includes an implementation of PIC using GraphX as its backend. +`spark.mllib` includes an implementation of PIC using GraphX as its backend. It takes an `RDD` of `(srcId, dstId, similarity)` tuples and outputs a model with the clustering assignments. The similarities must be nonnegative. PIC assumes that the similarity measure is symmetric. A pair `(srcId, dstId)` regardless of the ordering should appear at most once in the input data. If a pair is missing from input, their similarity is treated as zero. -MLlib's PIC implementation takes the following (hyper-)parameters: +`spark.mllib`'s PIC implementation takes the following (hyper-)parameters: * `k`: number of clusters * `maxIterations`: maximum number of power iterations @@ -323,7 +323,7 @@ MLlib's PIC implementation takes the following (hyper-)parameters: **Examples** -In the following, we show code snippets to demonstrate how to use PIC in MLlib. +In the following, we show code snippets to demonstrate how to use PIC in `spark.mllib`.
    @@ -493,7 +493,7 @@ checkpointing can help reduce shuffle file sizes on disk and help with failure recovery. -All of MLlib's LDA models support: +All of `spark.mllib`'s LDA models support: * `describeTopics`: Returns topics as arrays of most important terms and term weights @@ -721,7 +721,7 @@ sameModel = LDAModel.load(sc, "myModelPath") ## Streaming k-means When data arrive in a stream, we may want to estimate clusters dynamically, -updating them as new data arrive. MLlib provides support for streaming k-means clustering, +updating them as new data arrive. `spark.mllib` provides support for streaming k-means clustering, with parameters to control the decay (or "forgetfulness") of the estimates. The algorithm uses a generalization of the mini-batch k-means update rule. For each batch of data, we assign all points to their nearest cluster, compute new cluster centers, then update each cluster using: diff --git a/docs/mllib-collaborative-filtering.md b/docs/mllib-collaborative-filtering.md index 7cd1b894e7cb..1ebb4654aef1 100644 --- a/docs/mllib-collaborative-filtering.md +++ b/docs/mllib-collaborative-filtering.md @@ -1,7 +1,7 @@ --- layout: global -title: Collaborative Filtering - MLlib -displayTitle: MLlib - Collaborative Filtering +title: Collaborative Filtering - spark.mllib +displayTitle: Collaborative Filtering - spark.mllib --- * Table of contents @@ -11,12 +11,12 @@ displayTitle: MLlib - Collaborative Filtering [Collaborative filtering](http://en.wikipedia.org/wiki/Recommender_system#Collaborative_filtering) is commonly used for recommender systems. These techniques aim to fill in the -missing entries of a user-item association matrix. MLlib currently supports +missing entries of a user-item association matrix. `spark.mllib` currently supports model-based collaborative filtering, in which users and products are described by a small set of latent factors that can be used to predict missing entries. -MLlib uses the [alternating least squares +`spark.mllib` uses the [alternating least squares (ALS)](http://dl.acm.org/citation.cfm?id=1608614) -algorithm to learn these latent factors. The implementation in MLlib has the +algorithm to learn these latent factors. The implementation in `spark.mllib` has the following parameters: * *numBlocks* is the number of blocks used to parallelize computation (set to -1 to auto-configure). @@ -34,7 +34,7 @@ The standard approach to matrix factorization based collaborative filtering trea the entries in the user-item matrix as *explicit* preferences given by the user to the item. It is common in many real-world use cases to only have access to *implicit feedback* (e.g. views, -clicks, purchases, likes, shares etc.). The approach used in MLlib to deal with such data is taken +clicks, purchases, likes, shares etc.). The approach used in `spark.mllib` to deal with such data is taken from [Collaborative Filtering for Implicit Feedback Datasets](http://dx.doi.org/10.1109/ICDM.2008.22). Essentially instead of trying to model the matrix of ratings directly, this approach treats the data @@ -119,4 +119,4 @@ a dependency. ## Tutorial The [training exercises](https://databricks-training.s3.amazonaws.com/index.html) from the Spark Summit 2014 include a hands-on tutorial for -[personalized movie recommendation with MLlib](https://databricks-training.s3.amazonaws.com/movie-recommendation-with-mllib.html). +[personalized movie recommendation with `spark.mllib`](https://databricks-training.s3.amazonaws.com/movie-recommendation-with-mllib.html). 
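As a quick illustration of the ALS parameters described above, here is a minimal Scala sketch; the input path, the comma-separated rating format, and the specific parameter values are illustrative assumptions, and `sc` is an existing `SparkContext`.

{% highlight scala %}
// Sketch: explicit- and implicit-feedback ALS with spark.mllib.
import org.apache.spark.mllib.recommendation.{ALS, Rating}

// Assumed input: lines of "userId,productId,rating".
val ratings = sc.textFile("data/mllib/als/test.data").map { line =>
  val Array(user, item, rate) = line.split(',')
  Rating(user.toInt, item.toInt, rate.toDouble)
}

// Explicit feedback: rank = 10, 10 iterations, lambda = 0.01.
val explicitModel = ALS.train(ratings, 10, 10, 0.01)

// Implicit feedback (views, clicks, ...): adds the confidence parameter alpha = 1.0.
val implicitModel = ALS.trainImplicit(ratings, 10, 10, 0.01, 1.0)
{% endhighlight %}

Use `trainImplicit` only when the "ratings" are implicit strength-of-interaction signals rather than explicit scores.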
diff --git a/docs/mllib-data-types.md b/docs/mllib-data-types.md index 3c0c0479674d..363dc7c13b30 100644 --- a/docs/mllib-data-types.md +++ b/docs/mllib-data-types.md @@ -1,7 +1,7 @@ --- layout: global title: Data Types - MLlib -displayTitle: MLlib - Data Types +displayTitle: Data Types - MLlib --- * Table of contents diff --git a/docs/mllib-decision-tree.md b/docs/mllib-decision-tree.md index 77ce34e91af3..a8612b6c84fe 100644 --- a/docs/mllib-decision-tree.md +++ b/docs/mllib-decision-tree.md @@ -1,7 +1,7 @@ --- layout: global -title: Decision Trees - MLlib -displayTitle: MLlib - Decision Trees +title: Decision Trees - spark.mllib +displayTitle: Decision Trees - spark.mllib --- * Table of contents @@ -15,7 +15,7 @@ feature scaling, and are able to capture non-linearities and feature interaction algorithms such as random forests and boosting are among the top performers for classification and regression tasks. -MLlib supports decision trees for binary and multiclass classification and for regression, +`spark.mllib` supports decision trees for binary and multiclass classification and for regression, using both continuous and categorical features. The implementation partitions data by rows, allowing distributed training with millions of instances. diff --git a/docs/mllib-dimensionality-reduction.md b/docs/mllib-dimensionality-reduction.md index ac3526908a9f..11d8e0bd1d23 100644 --- a/docs/mllib-dimensionality-reduction.md +++ b/docs/mllib-dimensionality-reduction.md @@ -1,7 +1,7 @@ --- layout: global -title: Dimensionality Reduction - MLlib -displayTitle: MLlib - Dimensionality Reduction +title: Dimensionality Reduction - spark.mllib +displayTitle: Dimensionality Reduction - spark.mllib --- * Table of contents @@ -11,7 +11,7 @@ displayTitle: MLlib - Dimensionality Reduction of reducing the number of variables under consideration. It can be used to extract latent features from raw and noisy features or compress data while maintaining the structure. -MLlib provides support for dimensionality reduction on the RowMatrix class. +`spark.mllib` provides support for dimensionality reduction on the RowMatrix class. ## Singular value decomposition (SVD) @@ -57,7 +57,7 @@ passes, $O(n)$ storage on each executor, and $O(n k)$ storage on the driver. ### SVD Example -MLlib provides SVD functionality to row-oriented matrices, provided in the +`spark.mllib` provides SVD functionality to row-oriented matrices, provided in the RowMatrix class.
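A condensed Scala sketch of the `RowMatrix` SVD call referenced above; `rows` is an assumed, pre-existing `RDD[Vector]` holding the data matrix row by row.

{% highlight scala %}
// Sketch: truncated SVD of a distributed row-oriented matrix.
import org.apache.spark.mllib.linalg.{Matrix, SingularValueDecomposition}
import org.apache.spark.mllib.linalg.distributed.RowMatrix

val mat = new RowMatrix(rows)
// Top 5 singular values and the corresponding singular vectors.
val svd: SingularValueDecomposition[RowMatrix, Matrix] = mat.computeSVD(5, computeU = true)
val U: RowMatrix = svd.U  // left singular vectors, stored as a distributed matrix
val s = svd.s             // singular values, a local dense vector
val V: Matrix = svd.V     // right singular vectors, a local dense matrix
{% endhighlight %}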
    @@ -141,7 +141,7 @@ statistical method to find a rotation such that the first coordinate has the lar possible, and each succeeding coordinate in turn has the largest variance possible. The columns of the rotation matrix are called principal components. PCA is used widely in dimensionality reduction. -MLlib supports PCA for tall-and-skinny matrices stored in row-oriented format and any Vectors. +`spark.mllib` supports PCA for tall-and-skinny matrices stored in row-oriented format and any Vectors.
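A matching sketch for PCA on the same assumed `RDD[Vector]` named `rows`.

{% highlight scala %}
// Sketch: principal component analysis with RowMatrix.
import org.apache.spark.mllib.linalg.Matrix
import org.apache.spark.mllib.linalg.distributed.RowMatrix

val mat = new RowMatrix(rows)
// Top 10 principal components, returned as a local matrix with 10 columns.
val pc: Matrix = mat.computePrincipalComponents(10)
// Project the rows into the 10-dimensional principal-component space.
val projected: RowMatrix = mat.multiply(pc)
{% endhighlight %}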
    diff --git a/docs/mllib-ensembles.md b/docs/mllib-ensembles.md index 50450e05d2ab..2416b6fa0aeb 100644 --- a/docs/mllib-ensembles.md +++ b/docs/mllib-ensembles.md @@ -1,7 +1,7 @@ --- layout: global -title: Ensembles - MLlib -displayTitle: MLlib - Ensembles +title: Ensembles - spark.mllib +displayTitle: Ensembles - spark.mllib --- * Table of contents @@ -9,7 +9,7 @@ displayTitle: MLlib - Ensembles An [ensemble method](http://en.wikipedia.org/wiki/Ensemble_learning) is a learning algorithm which creates a model composed of a set of other base models. -MLlib supports two major ensemble algorithms: [`GradientBoostedTrees`](api/scala/index.html#org.apache.spark.mllib.tree.GradientBoostedTrees) and [`RandomForest`](api/scala/index.html#org.apache.spark.mllib.tree.RandomForest). +`spark.mllib` supports two major ensemble algorithms: [`GradientBoostedTrees`](api/scala/index.html#org.apache.spark.mllib.tree.GradientBoostedTrees) and [`RandomForest`](api/scala/index.html#org.apache.spark.mllib.tree.RandomForest). Both use [decision trees](mllib-decision-tree.html) as their base models. ## Gradient-Boosted Trees vs. Random Forests @@ -33,9 +33,9 @@ Like decision trees, random forests handle categorical features, extend to the multiclass classification setting, do not require feature scaling, and are able to capture non-linearities and feature interactions. -MLlib supports random forests for binary and multiclass classification and for regression, +`spark.mllib` supports random forests for binary and multiclass classification and for regression, using both continuous and categorical features. -MLlib implements random forests using the existing [decision tree](mllib-decision-tree.html) +`spark.mllib` implements random forests using the existing [decision tree](mllib-decision-tree.html) implementation. Please see the decision tree guide for more information on trees. ### Basic algorithm @@ -155,9 +155,9 @@ Like decision trees, GBTs handle categorical features, extend to the multiclass classification setting, do not require feature scaling, and are able to capture non-linearities and feature interactions. -MLlib supports GBTs for binary classification and for regression, +`spark.mllib` supports GBTs for binary classification and for regression, using both continuous and categorical features. -MLlib implements GBTs using the existing [decision tree](mllib-decision-tree.html) implementation. Please see the decision tree guide for more information on trees. +`spark.mllib` implements GBTs using the existing [decision tree](mllib-decision-tree.html) implementation. Please see the decision tree guide for more information on trees. *Note*: GBTs do not yet support multiclass classification. For multiclass problems, please use [decision trees](mllib-decision-tree.html) or [Random Forests](mllib-ensembles.html#Random-Forest). @@ -171,7 +171,7 @@ The specific mechanism for re-labeling instances is defined by a loss function ( #### Losses -The table below lists the losses currently supported by GBTs in MLlib. +The table below lists the losses currently supported by GBTs in `spark.mllib`. Note that each loss is applicable to one of classification or regression, not both. Notation: $N$ = number of instances. $y_i$ = label of instance $i$. $x_i$ = features of instance $i$. $F(x_i)$ = model's predicted label for instance $i$. 
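To make the ensemble discussion above concrete, here is a short Scala sketch of training a random forest classifier with `spark.mllib`; the LIBSVM input path and the 70/30 split are illustrative assumptions, and `sc` is an existing `SparkContext`.

{% highlight scala %}
// Sketch: random forest for binary classification.
import org.apache.spark.mllib.tree.RandomForest
import org.apache.spark.mllib.util.MLUtils

val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")
val Array(training, test) = data.randomSplit(Array(0.7, 0.3))

val numClasses = 2
val categoricalFeaturesInfo = Map[Int, Int]()  // all features treated as continuous
val numTrees = 10
val featureSubsetStrategy = "auto"             // let the algorithm pick per numTrees
val impurity = "gini"
val maxDepth = 4
val maxBins = 32

val model = RandomForest.trainClassifier(training, numClasses, categoricalFeaturesInfo,
  numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins)

// Fraction of test points the forest misclassifies.
val testErr = test.map { point =>
  if (model.predict(point.features) != point.label) 1.0 else 0.0
}.mean()
{% endhighlight %}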
diff --git a/docs/mllib-evaluation-metrics.md b/docs/mllib-evaluation-metrics.md index 6924037b941f..774826c2703f 100644 --- a/docs/mllib-evaluation-metrics.md +++ b/docs/mllib-evaluation-metrics.md @@ -1,20 +1,20 @@ --- layout: global -title: Evaluation Metrics - MLlib -displayTitle: MLlib - Evaluation Metrics +title: Evaluation Metrics - spark.mllib +displayTitle: Evaluation Metrics - spark.mllib --- * Table of contents {:toc} -Spark's MLlib comes with a number of machine learning algorithms that can be used to learn from and make predictions +`spark.mllib` comes with a number of machine learning algorithms that can be used to learn from and make predictions on data. When these algorithms are applied to build machine learning models, there is a need to evaluate the performance -of the model on some criteria, which depends on the application and its requirements. Spark's MLlib also provides a +of the model on some criteria, which depends on the application and its requirements. `spark.mllib` also provides a suite of metrics for the purpose of evaluating the performance of machine learning models. Specific machine learning algorithms fall under broader types of machine learning applications like classification, regression, clustering, etc. Each of these types have well established metrics for performance evaluation and those -metrics that are currently available in Spark's MLlib are detailed in this section. +metrics that are currently available in `spark.mllib` are detailed in this section. ## Classification model evaluation diff --git a/docs/mllib-feature-extraction.md b/docs/mllib-feature-extraction.md index 5bee170c61fe..7796bac69756 100644 --- a/docs/mllib-feature-extraction.md +++ b/docs/mllib-feature-extraction.md @@ -1,7 +1,7 @@ --- layout: global -title: Feature Extraction and Transformation - MLlib -displayTitle: MLlib - Feature Extraction and Transformation +title: Feature Extraction and Transformation - spark.mllib +displayTitle: Feature Extraction and Transformation - spark.mllib --- * Table of contents @@ -31,7 +31,7 @@ The TF-IDF measure is simply the product of TF and IDF: TFIDF(t, d, D) = TF(t, d) \cdot IDF(t, D). \]` There are several variants on the definition of term frequency and document frequency. -In MLlib, we separate TF and IDF to make them flexible. +In `spark.mllib`, we separate TF and IDF to make them flexible. Our implementation of term frequency utilizes the [hashing trick](http://en.wikipedia.org/wiki/Feature_hashing). @@ -44,7 +44,7 @@ To reduce the chance of collision, we can increase the target feature dimension, the number of buckets of the hash table. The default feature dimension is `$2^{20} = 1,048,576$`. -**Note:** MLlib doesn't provide tools for text segmentation. +**Note:** `spark.mllib` doesn't provide tools for text segmentation. We refer users to the [Stanford NLP Group](http://nlp.stanford.edu/) and [scalanlp/chalk](https://github.com/scalanlp/chalk). @@ -86,7 +86,7 @@ val idf = new IDF().fit(tf) val tfidf: RDD[Vector] = idf.transform(tf) {% endhighlight %} -MLlib's IDF implementation provides an option for ignoring terms which occur in less than a +`spark.mllib`'s IDF implementation provides an option for ignoring terms which occur in less than a minimum number of documents. In such cases, the IDF for these terms is set to 0. This feature can be used by passing the `minDocFreq` value to the IDF constructor. 
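A one-step sketch of the `minDocFreq` option mentioned above, assuming the term-frequency RDD `tf` from the surrounding example.

{% highlight scala %}
// Sketch: ignore terms appearing in fewer than 2 documents (their IDF becomes 0).
import org.apache.spark.mllib.feature.IDF
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.rdd.RDD

val idfIgnoringRare = new IDF(minDocFreq = 2).fit(tf)
val tfidfFiltered: RDD[Vector] = idfIgnoringRare.transform(tf)
{% endhighlight %}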
@@ -134,7 +134,7 @@ idf = IDF().fit(tf) tfidf = idf.transform(tf) {% endhighlight %} -MLLib's IDF implementation provides an option for ignoring terms which occur in less than a +`spark.mllib`'s IDF implementation provides an option for ignoring terms which occur in less than a minimum number of documents. In such cases, the IDF for these terms is set to 0. This feature can be used by passing the `minDocFreq` value to the IDF constructor. diff --git a/docs/mllib-frequent-pattern-mining.md b/docs/mllib-frequent-pattern-mining.md index fe42896a05d8..2c8a8f236163 100644 --- a/docs/mllib-frequent-pattern-mining.md +++ b/docs/mllib-frequent-pattern-mining.md @@ -1,7 +1,7 @@ --- layout: global -title: Frequent Pattern Mining - MLlib -displayTitle: MLlib - Frequent Pattern Mining +title: Frequent Pattern Mining - spark.mllib +displayTitle: Frequent Pattern Mining - spark.mllib --- Mining frequent items, itemsets, subsequences, or other substructures is usually among the @@ -9,7 +9,7 @@ first steps to analyze a large-scale dataset, which has been an active research data mining for years. We refer users to Wikipedia's [association rule learning](http://en.wikipedia.org/wiki/Association_rule_learning) for more information. -MLlib provides a parallel implementation of FP-growth, +`spark.mllib` provides a parallel implementation of FP-growth, a popular algorithm to mining frequent itemsets. ## FP-growth @@ -22,13 +22,13 @@ Different from [Apriori-like](http://en.wikipedia.org/wiki/Apriori_algorithm) al the second step of FP-growth uses a suffix tree (FP-tree) structure to encode transactions without generating candidate sets explicitly, which are usually expensive to generate. After the second step, the frequent itemsets can be extracted from the FP-tree. -In MLlib, we implemented a parallel version of FP-growth called PFP, +In `spark.mllib`, we implemented a parallel version of FP-growth called PFP, as described in [Li et al., PFP: Parallel FP-growth for query recommendation](http://dx.doi.org/10.1145/1454008.1454027). PFP distributes the work of growing FP-trees based on the suffices of transactions, and hence more scalable than a single-machine implementation. We refer users to the papers for more details. -MLlib's FP-growth implementation takes the following (hyper-)parameters: +`spark.mllib`'s FP-growth implementation takes the following (hyper-)parameters: * `minSupport`: the minimum support for an itemset to be identified as frequent. For example, if an item appears 3 out of 5 transactions, it has a support of 3/5=0.6. @@ -126,7 +126,7 @@ PrefixSpan Approach](http://dx.doi.org/10.1109%2FTKDE.2004.77). We refer the reader to the referenced paper for formalizing the sequential pattern mining problem. -MLlib's PrefixSpan implementation takes the following parameters: +`spark.mllib`'s PrefixSpan implementation takes the following parameters: * `minSupport`: the minimum support required to be considered a frequent sequential pattern. diff --git a/docs/mllib-guide.md b/docs/mllib-guide.md index 3bc2b780601c..7fef6b5c61f9 100644 --- a/docs/mllib-guide.md +++ b/docs/mllib-guide.md @@ -66,7 +66,7 @@ We list major functionality from both below, with links to detailed guides. 
# spark.ml: high-level APIs for ML pipelines -* [Overview: estimators, transformers and pipelines](ml-intro.html) +* [Overview: estimators, transformers and pipelines](ml-guide.html) * [Extracting, transforming and selecting features](ml-features.html) * [Classification and regression](ml-classification-regression.html) * [Clustering](ml-clustering.html) diff --git a/docs/mllib-isotonic-regression.md b/docs/mllib-isotonic-regression.md index 85f9226b4341..8ede4407d584 100644 --- a/docs/mllib-isotonic-regression.md +++ b/docs/mllib-isotonic-regression.md @@ -1,7 +1,7 @@ --- layout: global -title: Isotonic regression - MLlib -displayTitle: MLlib - Regression +title: Isotonic regression - spark.mllib +displayTitle: Regression - spark.mllib --- ## Isotonic regression @@ -23,7 +23,7 @@ Essentially isotonic regression is a [monotonic function](http://en.wikipedia.org/wiki/Monotonic_function) best fitting the original data points. -MLlib supports a +`spark.mllib` supports a [pool adjacent violators algorithm](http://doi.org/10.1198/TECH.2010.10111) which uses an approach to [parallelizing isotonic regression](http://doi.org/10.1007/978-3-642-99789-1_10). diff --git a/docs/mllib-linear-methods.md b/docs/mllib-linear-methods.md index 132f8c354aa9..20b35612cab9 100644 --- a/docs/mllib-linear-methods.md +++ b/docs/mllib-linear-methods.md @@ -1,7 +1,7 @@ --- layout: global -title: Linear Methods - MLlib -displayTitle: MLlib - Linear Methods +title: Linear Methods - spark.mllib +displayTitle: Linear Methods - spark.mllib --- * Table of contents @@ -41,7 +41,7 @@ the objective function is of the form Here the vectors `$\x_i\in\R^d$` are the training data examples, for `$1\le i\le n$`, and `$y_i\in\R$` are their corresponding labels, which we want to predict. We call the method *linear* if $L(\wv; \x, y)$ can be expressed as a function of $\wv^T x$ and $y$. -Several of MLlib's classification and regression algorithms fall into this category, +Several of `spark.mllib`'s classification and regression algorithms fall into this category, and are discussed here. The objective function `$f$` has two parts: @@ -55,7 +55,7 @@ training error) and minimizing model complexity (i.e., to avoid overfitting). ### Loss functions The following table summarizes the loss functions and their gradients or sub-gradients for the -methods MLlib supports: +methods `spark.mllib` supports: @@ -83,7 +83,7 @@ methods MLlib supports: The purpose of the [regularizer](http://en.wikipedia.org/wiki/Regularization_(mathematics)) is to encourage simple models and avoid overfitting. We support the following -regularizers in MLlib: +regularizers in `spark.mllib`:
    @@ -115,7 +115,10 @@ especially when the number of training examples is small. ### Optimization -Under the hood, linear methods use convex optimization methods to optimize the objective functions. MLlib uses two methods, SGD and L-BFGS, described in the [optimization section](mllib-optimization.html). Currently, most algorithm APIs support Stochastic Gradient Descent (SGD), and a few support L-BFGS. Refer to [this optimization section](mllib-optimization.html#Choosing-an-Optimization-Method) for guidelines on choosing between optimization methods. +Under the hood, linear methods use convex optimization methods to optimize the objective functions. +`spark.mllib` uses two methods, SGD and L-BFGS, described in the [optimization section](mllib-optimization.html). +Currently, most algorithm APIs support Stochastic Gradient Descent (SGD), and a few support L-BFGS. +Refer to [this optimization section](mllib-optimization.html#Choosing-an-Optimization-Method) for guidelines on choosing between optimization methods. ## Classification @@ -126,16 +129,16 @@ The most common classification type is categories, usually named positive and negative. If there are more than two categories, it is called [multiclass classification](http://en.wikipedia.org/wiki/Multiclass_classification). -MLlib supports two linear methods for classification: linear Support Vector Machines (SVMs) +`spark.mllib` supports two linear methods for classification: linear Support Vector Machines (SVMs) and logistic regression. Linear SVMs supports only binary classification, while logistic regression supports both binary and multiclass classification problems. -For both methods, MLlib supports L1 and L2 regularized variants. +For both methods, `spark.mllib` supports L1 and L2 regularized variants. The training data set is represented by an RDD of [LabeledPoint](mllib-data-types.html) in MLlib, where labels are class indices starting from zero: $0, 1, 2, \ldots$. Note that, in the mathematical formulation in this guide, a binary label $y$ is denoted as either $+1$ (positive) or $-1$ (negative), which is convenient for the formulation. -*However*, the negative label is represented by $0$ in MLlib instead of $-1$, to be consistent with +*However*, the negative label is represented by $0$ in `spark.mllib` instead of $-1$, to be consistent with multiclass labeling. ### Linear Support Vector Machines (SVMs) @@ -207,7 +210,7 @@ val sameModel = SVMModel.load(sc, "myModelPath") The `SVMWithSGD.train()` method by default performs L2 regularization with the regularization parameter set to 1.0. If we want to configure this algorithm, we can customize `SVMWithSGD` further by creating a new object directly and -calling setter methods. All other MLlib algorithms support customization in +calling setter methods. All other `spark.mllib` algorithms support customization in this way as well. For example, the following code produces an L1 regularized variant of SVMs with regularization parameter set to 0.1, and runs the training algorithm for 200 iterations. @@ -293,7 +296,7 @@ public class SVMClassifier { The `SVMWithSGD.train()` method by default performs L2 regularization with the regularization parameter set to 1.0. If we want to configure this algorithm, we can customize `SVMWithSGD` further by creating a new object directly and -calling setter methods. All other MLlib algorithms support customization in +calling setter methods. All other `spark.mllib` algorithms support customization in this way as well. 
For example, the following code produces an L1 regularized variant of SVMs with regularization parameter set to 0.1, and runs the training algorithm for 200 iterations. @@ -375,7 +378,7 @@ Binary logistic regression can be generalized into train and predict multiclass classification problems. For example, for $K$ possible outcomes, one of the outcomes can be chosen as a "pivot", and the other $K - 1$ outcomes can be separately regressed against the pivot outcome. -In MLlib, the first class $0$ is chosen as the "pivot" class. +In `spark.mllib`, the first class $0$ is chosen as the "pivot" class. See Section 4.4 of [The Elements of Statistical Learning](http://statweb.stanford.edu/~tibs/ElemStatLearn/) for references. @@ -726,7 +729,7 @@ a dependency. ###Streaming linear regression When data arrive in a streaming fashion, it is useful to fit regression models online, -updating the parameters of the model as new data arrives. MLlib currently supports +updating the parameters of the model as new data arrives. `spark.mllib` currently supports streaming linear regression using ordinary least squares. The fitting is similar to that performed offline, except fitting occurs on each batch of data, so that the model continually updates to reflect the data from the stream. @@ -852,7 +855,7 @@ will get better! # Implementation (developer) -Behind the scene, MLlib implements a simple distributed version of stochastic gradient descent +Behind the scene, `spark.mllib` implements a simple distributed version of stochastic gradient descent (SGD), building on the underlying gradient descent primitive (as described in the optimization section). All provided algorithms take as input a regularization parameter (`regParam`) along with various parameters associated with stochastic diff --git a/docs/mllib-migration-guides.md b/docs/mllib-migration-guides.md index 774b85d1f773..73e4fddf67fc 100644 --- a/docs/mllib-migration-guides.md +++ b/docs/mllib-migration-guides.md @@ -1,7 +1,7 @@ --- layout: global -title: Old Migration Guides - MLlib -displayTitle: MLlib - Old Migration Guides +title: Old Migration Guides - spark.mllib +displayTitle: Old Migration Guides - spark.mllib description: MLlib migration guides from before Spark SPARK_VERSION_SHORT --- diff --git a/docs/mllib-naive-bayes.md b/docs/mllib-naive-bayes.md index 60ac6c7e5bb1..d0d594af6a4a 100644 --- a/docs/mllib-naive-bayes.md +++ b/docs/mllib-naive-bayes.md @@ -1,7 +1,7 @@ --- layout: global -title: Naive Bayes - MLlib -displayTitle: MLlib - Naive Bayes +title: Naive Bayes - spark.mllib +displayTitle: Naive Bayes - spark.mllib --- [Naive Bayes](http://en.wikipedia.org/wiki/Naive_Bayes_classifier) is a simple @@ -12,7 +12,7 @@ distribution of each feature given label, and then it applies Bayes' theorem to compute the conditional probability distribution of label given an observation and use it for prediction. -MLlib supports [multinomial naive +`spark.mllib` supports [multinomial naive Bayes](http://en.wikipedia.org/wiki/Naive_Bayes_classifier#Multinomial_naive_Bayes) and [Bernoulli naive Bayes](http://nlp.stanford.edu/IR-book/html/htmledition/the-bernoulli-model-1.html). These models are typically used for [document classification](http://nlp.stanford.edu/IR-book/html/htmledition/naive-bayes-text-classification-1.html). 
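Following the naive Bayes description above, a minimal Scala sketch of training and evaluating a multinomial model; the LIBSVM input path and the 60/40 split are illustrative assumptions, and `sc` is an existing `SparkContext`.

{% highlight scala %}
// Sketch: multinomial naive Bayes with additive smoothing.
import org.apache.spark.mllib.classification.NaiveBayes
import org.apache.spark.mllib.util.MLUtils

val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")
val Array(training, test) = data.randomSplit(Array(0.6, 0.4))

// lambda is the additive-smoothing parameter; modelType can also be "bernoulli".
val model = NaiveBayes.train(training, lambda = 1.0, modelType = "multinomial")

val accuracy = test.map(p => (model.predict(p.features), p.label))
  .filter { case (pred, label) => pred == label }
  .count().toDouble / test.count()
{% endhighlight %}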
diff --git a/docs/mllib-optimization.md b/docs/mllib-optimization.md index ad7bcd9bfd40..f90b66f8e2c4 100644 --- a/docs/mllib-optimization.md +++ b/docs/mllib-optimization.md @@ -1,7 +1,7 @@ --- layout: global -title: Optimization - MLlib -displayTitle: MLlib - Optimization +title: Optimization - spark.mllib +displayTitle: Optimization - spark.mllib --- * Table of contents @@ -87,7 +87,7 @@ in the `$t$`-th iteration, with the input parameter `$s=$ stepSize`. Note that s step-size for SGD methods can often be delicate in practice and is a topic of active research. **Gradients.** -A table of (sub)gradients of the machine learning methods implemented in MLlib, is available in +A table of (sub)gradients of the machine learning methods implemented in `spark.mllib`, is available in the classification and regression section. @@ -140,7 +140,7 @@ other first-order optimization. ### Choosing an Optimization Method -[Linear methods](mllib-linear-methods.html) use optimization internally, and some linear methods in MLlib support both SGD and L-BFGS. +[Linear methods](mllib-linear-methods.html) use optimization internally, and some linear methods in `spark.mllib` support both SGD and L-BFGS. Different optimization methods can have different convergence guarantees depending on the properties of the objective function, and we cannot cover the literature here. In general, when L-BFGS is available, we recommend using it instead of SGD since L-BFGS tends to converge faster (in fewer iterations). diff --git a/docs/mllib-pmml-model-export.md b/docs/mllib-pmml-model-export.md index 615287125c03..b532ad907dfc 100644 --- a/docs/mllib-pmml-model-export.md +++ b/docs/mllib-pmml-model-export.md @@ -1,21 +1,21 @@ --- layout: global -title: PMML model export - MLlib -displayTitle: MLlib - PMML model export +title: PMML model export - spark.mllib +displayTitle: PMML model export - spark.mllib --- * Table of contents {:toc} -## MLlib supported models +## `spark.mllib` supported models -MLlib supports model export to Predictive Model Markup Language ([PMML](http://en.wikipedia.org/wiki/Predictive_Model_Markup_Language)). +`spark.mllib` supports model export to Predictive Model Markup Language ([PMML](http://en.wikipedia.org/wiki/Predictive_Model_Markup_Language)). -The table below outlines the MLlib models that can be exported to PMML and their equivalent PMML model. +The table below outlines the `spark.mllib` models that can be exported to PMML and their equivalent PMML model.
    - + diff --git a/docs/mllib-statistics.md b/docs/mllib-statistics.md index de209f68e19c..652d215fa865 100644 --- a/docs/mllib-statistics.md +++ b/docs/mllib-statistics.md @@ -1,7 +1,7 @@ --- layout: global -title: Basic Statistics - MLlib -displayTitle: MLlib - Basic Statistics +title: Basic Statistics - spark.mllib +displayTitle: Basic Statistics - spark.mllib --- * Table of contents @@ -112,7 +112,7 @@ print(summary.numNonzeros()) ## Correlations -Calculating the correlation between two series of data is a common operation in Statistics. In MLlib +Calculating the correlation between two series of data is a common operation in Statistics. In `spark.mllib` we provide the flexibility to calculate pairwise correlations among many series. The supported correlation methods are currently Pearson's and Spearman's correlation. @@ -209,7 +209,7 @@ print(Statistics.corr(data, method="pearson")) ## Stratified sampling -Unlike the other statistics functions, which reside in MLlib, stratified sampling methods, +Unlike the other statistics functions, which reside in `spark.mllib`, stratified sampling methods, `sampleByKey` and `sampleByKeyExact`, can be performed on RDD's of key-value pairs. For stratified sampling, the keys can be thought of as a label and the value as a specific attribute. For example the key can be man or woman, or document ids, and the respective values can be the list of ages @@ -294,12 +294,12 @@ approxSample = data.sampleByKey(False, fractions); ## Hypothesis testing Hypothesis testing is a powerful tool in statistics to determine whether a result is statistically -significant, whether this result occurred by chance or not. MLlib currently supports Pearson's +significant, whether this result occurred by chance or not. `spark.mllib` currently supports Pearson's chi-squared ( $\chi^2$) tests for goodness of fit and independence. The input data types determine whether the goodness of fit or the independence test is conducted. The goodness of fit test requires an input type of `Vector`, whereas the independence test requires a `Matrix` as input. -MLlib also supports the input type `RDD[LabeledPoint]` to enable feature selection via chi-squared +`spark.mllib` also supports the input type `RDD[LabeledPoint]` to enable feature selection via chi-squared independence tests.
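A short sketch of the chi-squared APIs mentioned above; `labeledData` is an assumed, pre-existing `RDD[LabeledPoint]` whose features hold categorical counts.

{% highlight scala %}
// Sketch: goodness-of-fit and per-feature independence tests.
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.stat.Statistics

// Goodness of fit against the default (uniform) expected distribution.
val observed = Vectors.dense(10.0, 20.0, 30.0, 40.0)
println(Statistics.chiSqTest(observed))

// One independence-test result per feature, usable for feature selection.
val featureResults = Statistics.chiSqTest(labeledData)
featureResults.zipWithIndex.foreach { case (result, i) =>
  println(s"feature $i: p-value = ${result.pValue}")
}
{% endhighlight %}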
    @@ -438,7 +438,7 @@ for i, result in enumerate(featureTestResults):
    -Additionally, MLlib provides a 1-sample, 2-sided implementation of the Kolmogorov-Smirnov (KS) test +Additionally, `spark.mllib` provides a 1-sample, 2-sided implementation of the Kolmogorov-Smirnov (KS) test for equality of probability distributions. By providing the name of a theoretical distribution (currently solely supported for the normal distribution) and its parameters, or a function to calculate the cumulative distribution according to a given theoretical distribution, the user can @@ -522,7 +522,7 @@ print(testResult) # summary of the test including the p-value, test statistic, ### Streaming Significance Testing -MLlib provides online implementations of some tests to support use cases +`spark.mllib` provides online implementations of some tests to support use cases like A/B testing. These tests may be performed on a Spark Streaming `DStream[(Boolean,Double)]` where the first element of each tuple indicates control group (`false`) or treatment group (`true`) and the @@ -550,7 +550,7 @@ provides streaming hypothesis testing. ## Random data generation Random data generation is useful for randomized algorithms, prototyping, and performance testing. -MLlib supports generating random RDDs with i.i.d. values drawn from a given distribution: +`spark.mllib` supports generating random RDDs with i.i.d. values drawn from a given distribution: uniform, standard normal, or Poisson.
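A compact sketch of the random data generation API just described; `sc` is an existing `SparkContext`, and the sizes and partition counts are arbitrary choices.

{% highlight scala %}
// Sketch: i.i.d. random RDDs from the supported distributions.
import org.apache.spark.mllib.random.RandomRDDs._

// One million samples from the standard normal N(0, 1), in 10 partitions.
val normals = normalRDD(sc, 1000000L, 10)
// Map to another normal distribution, here N(1, 4) (mean 1, standard deviation 2).
val shifted = normals.map(x => 1.0 + 2.0 * x)
// One million samples from Poisson with mean 5.0.
val counts = poissonRDD(sc, 5.0, 1000000L, 10)
{% endhighlight %}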
    From 4a46b8859d3314b5b45a67cdc5c81fecb6e9e78c Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Thu, 10 Dec 2015 13:26:30 -0800 Subject: [PATCH 1736/1824] [SPARK-11563][CORE][REPL] Use RpcEnv to transfer REPL-generated classes. This avoids bringing up yet another HTTP server on the driver, and instead reuses the file server already managed by the driver's RpcEnv. As a bonus, the repl now inherits the security features of the network library. There's also a small change to create the directory for storing classes under the root temp dir for the application (instead of directly under java.io.tmpdir). Author: Marcelo Vanzin Closes #9923 from vanzin/SPARK-11563. --- .../org/apache/spark/HttpFileServer.scala | 5 ++ .../scala/org/apache/spark/HttpServer.scala | 26 ++++++---- .../scala/org/apache/spark/SparkContext.scala | 6 +++ .../org/apache/spark/executor/Executor.scala | 6 +-- .../scala/org/apache/spark/rpc/RpcEnv.scala | 18 +++++++ .../apache/spark/rpc/akka/AkkaRpcEnv.scala | 5 ++ .../spark/rpc/netty/NettyStreamManager.scala | 22 ++++++++- .../org/apache/spark/rpc/RpcEnvSuite.scala | 25 +++++++++- docs/configuration.md | 8 ---- docs/security.md | 8 ---- .../org/apache/spark/repl/SparkILoop.scala | 17 ++++--- .../org/apache/spark/repl/SparkIMain.scala | 28 ++--------- .../scala/org/apache/spark/repl/Main.scala | 23 ++++----- .../spark/repl/ExecutorClassLoader.scala | 36 +++++++------- .../spark/repl/ExecutorClassLoaderSuite.scala | 48 +++++++++++++++---- 15 files changed, 183 insertions(+), 98 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/HttpFileServer.scala b/core/src/main/scala/org/apache/spark/HttpFileServer.scala index 7cf7bc0dc681..77d8ec9bb160 100644 --- a/core/src/main/scala/org/apache/spark/HttpFileServer.scala +++ b/core/src/main/scala/org/apache/spark/HttpFileServer.scala @@ -71,6 +71,11 @@ private[spark] class HttpFileServer( serverUri + "/jars/" + file.getName } + def addDirectory(path: String, resourceBase: String): String = { + httpServer.addDirectory(path, resourceBase) + serverUri + path + } + def addFileToDir(file: File, dir: File) : String = { // Check whether the file is a directory. If it is, throw a more meaningful exception. 
// If we don't catch this, Guava throws a very confusing error message: diff --git a/core/src/main/scala/org/apache/spark/HttpServer.scala b/core/src/main/scala/org/apache/spark/HttpServer.scala index 8de3a6c04df3..faa3ef3d7561 100644 --- a/core/src/main/scala/org/apache/spark/HttpServer.scala +++ b/core/src/main/scala/org/apache/spark/HttpServer.scala @@ -23,10 +23,9 @@ import org.eclipse.jetty.server.ssl.SslSocketConnector import org.eclipse.jetty.util.security.{Constraint, Password} import org.eclipse.jetty.security.authentication.DigestAuthenticator import org.eclipse.jetty.security.{ConstraintMapping, ConstraintSecurityHandler, HashLoginService} - import org.eclipse.jetty.server.Server import org.eclipse.jetty.server.bio.SocketConnector -import org.eclipse.jetty.server.handler.{DefaultHandler, HandlerList, ResourceHandler} +import org.eclipse.jetty.servlet.{DefaultServlet, ServletContextHandler, ServletHolder} import org.eclipse.jetty.util.thread.QueuedThreadPool import org.apache.spark.util.Utils @@ -52,6 +51,11 @@ private[spark] class HttpServer( private var server: Server = null private var port: Int = requestedPort + private val servlets = { + val handler = new ServletContextHandler() + handler.setContextPath("/") + handler + } def start() { if (server != null) { @@ -65,6 +69,14 @@ private[spark] class HttpServer( } } + def addDirectory(contextPath: String, resourceBase: String): Unit = { + val holder = new ServletHolder() + holder.setInitParameter("resourceBase", resourceBase) + holder.setInitParameter("pathInfoOnly", "true") + holder.setServlet(new DefaultServlet()) + servlets.addServlet(holder, contextPath.stripSuffix("/") + "/*") + } + /** * Actually start the HTTP server on the given port. * @@ -85,21 +97,17 @@ private[spark] class HttpServer( val threadPool = new QueuedThreadPool threadPool.setDaemon(true) server.setThreadPool(threadPool) - val resHandler = new ResourceHandler - resHandler.setResourceBase(resourceBase.getAbsolutePath) - - val handlerList = new HandlerList - handlerList.setHandlers(Array(resHandler, new DefaultHandler)) + addDirectory("/", resourceBase.getAbsolutePath) if (securityManager.isAuthenticationEnabled()) { logDebug("HttpServer is using security") val sh = setupSecurityHandler(securityManager) // make sure we go through security handler to get resources - sh.setHandler(handlerList) + sh.setHandler(servlets) server.setHandler(sh) } else { logDebug("HttpServer is not using security") - server.setHandler(handlerList) + server.setHandler(servlets) } server.start() diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 8a62b71c3fa6..194ecc0a0434 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -457,6 +457,12 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli _env = createSparkEnv(_conf, isLocal, listenerBus) SparkEnv.set(_env) + // If running the REPL, register the repl's output dir with the file server. 
+ _conf.getOption("spark.repl.class.outputDir").foreach { path => + val replUri = _env.rpcEnv.fileServer.addDirectory("/classes", new File(path)) + _conf.set("spark.repl.class.uri", replUri) + } + _metadataCleaner = new MetadataCleaner(MetadataCleanerType.SPARK_CONTEXT, this.cleanup, _conf) _statusTracker = new SparkStatusTracker(this) diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 7b68dfe5ad06..552b644d13aa 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -364,9 +364,9 @@ private[spark] class Executor( val _userClassPathFirst: java.lang.Boolean = userClassPathFirst val klass = Utils.classForName("org.apache.spark.repl.ExecutorClassLoader") .asInstanceOf[Class[_ <: ClassLoader]] - val constructor = klass.getConstructor(classOf[SparkConf], classOf[String], - classOf[ClassLoader], classOf[Boolean]) - constructor.newInstance(conf, classUri, parent, _userClassPathFirst) + val constructor = klass.getConstructor(classOf[SparkConf], classOf[SparkEnv], + classOf[String], classOf[ClassLoader], classOf[Boolean]) + constructor.newInstance(conf, env, classUri, parent, _userClassPathFirst) } catch { case _: ClassNotFoundException => logError("Could not find org.apache.spark.repl.ExecutorClassLoader on classpath!") diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala index 3d7d281b0dd6..64a4a8bf7c5e 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala @@ -179,6 +179,24 @@ private[spark] trait RpcEnvFileServer { */ def addJar(file: File): String + /** + * Adds a local directory to be served via this file server. + * + * @param baseUri Leading URI path (files can be retrieved by appending their relative + * path to this base URI). This cannot be "files" nor "jars". + * @param path Path to the local directory. + * @return URI for the root of the directory in the file server. + */ + def addDirectory(baseUri: String, path: File): String + + /** Validates and normalizes the base URI for directories. 
*/ + protected def validateDirectoryUri(baseUri: String): String = { + val fixedBaseUri = "/" + baseUri.stripPrefix("/").stripSuffix("/") + require(fixedBaseUri != "/files" && fixedBaseUri != "/jars", + "Directory URI cannot be /files nor /jars.") + fixedBaseUri + } + } private[spark] case class RpcEnvConfig( diff --git a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala index 94dbec593c31..9d098154f719 100644 --- a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala @@ -273,6 +273,11 @@ private[akka] class AkkaFileServer( getFileServer().addJar(file) } + override def addDirectory(baseUri: String, path: File): String = { + val fixedBaseUri = validateDirectoryUri(baseUri) + getFileServer().addDirectory(fixedBaseUri, path.getAbsolutePath()) + } + def shutdown(): Unit = { if (httpFileServer != null) { httpFileServer.stop() diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala index a2768b4252dc..ecd96972455d 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala @@ -25,12 +25,22 @@ import org.apache.spark.rpc.RpcEnvFileServer /** * StreamManager implementation for serving files from a NettyRpcEnv. + * + * Three kinds of resources can be registered in this manager, all backed by actual files: + * + * - "/files": a flat list of files; used as the backend for [[SparkContext.addFile]]. + * - "/jars": a flat list of files; used as the backend for [[SparkContext.addJar]]. + * - arbitrary directories; all files under the directory become available through the manager, + * respecting the directory's hierarchy. + * + * Only streaming (openStream) is supported. 
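+ * Directory base URIs are normalized and validated (see RpcEnvFileServer.validateDirectoryUri),
+ * so they cannot clash with the reserved "/files" and "/jars" roots, and each base URI can only
+ * be registered once.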
*/ private[netty] class NettyStreamManager(rpcEnv: NettyRpcEnv) extends StreamManager with RpcEnvFileServer { private val files = new ConcurrentHashMap[String, File]() private val jars = new ConcurrentHashMap[String, File]() + private val dirs = new ConcurrentHashMap[String, File]() override def getChunk(streamId: Long, chunkIndex: Int): ManagedBuffer = { throw new UnsupportedOperationException() @@ -41,7 +51,10 @@ private[netty] class NettyStreamManager(rpcEnv: NettyRpcEnv) val file = ftype match { case "files" => files.get(fname) case "jars" => jars.get(fname) - case _ => throw new IllegalArgumentException(s"Invalid file type: $ftype") + case other => + val dir = dirs.get(ftype) + require(dir != null, s"Invalid stream URI: $ftype not found.") + new File(dir, fname) } require(file != null && file.isFile(), s"File not found: $streamId") @@ -60,4 +73,11 @@ private[netty] class NettyStreamManager(rpcEnv: NettyRpcEnv) s"${rpcEnv.address.toSparkURL}/jars/${file.getName()}" } + override def addDirectory(baseUri: String, path: File): String = { + val fixedBaseUri = validateDirectoryUri(baseUri) + require(dirs.putIfAbsent(fixedBaseUri.stripPrefix("/"), path) == null, + s"URI '$fixedBaseUri' already registered.") + s"${rpcEnv.address.toSparkURL}$fixedBaseUri" + } + } diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala index 6cc958a5f6bc..a61d0479aacd 100644 --- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala @@ -734,9 +734,28 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { val jar = new File(tempDir, "jar") Files.write(UUID.randomUUID().toString(), jar, UTF_8) + val dir1 = new File(tempDir, "dir1") + assert(dir1.mkdir()) + val subFile1 = new File(dir1, "file1") + Files.write(UUID.randomUUID().toString(), subFile1, UTF_8) + + val dir2 = new File(tempDir, "dir2") + assert(dir2.mkdir()) + val subFile2 = new File(dir2, "file2") + Files.write(UUID.randomUUID().toString(), subFile2, UTF_8) + val fileUri = env.fileServer.addFile(file) val emptyUri = env.fileServer.addFile(empty) val jarUri = env.fileServer.addJar(jar) + val dir1Uri = env.fileServer.addDirectory("/dir1", dir1) + val dir2Uri = env.fileServer.addDirectory("/dir2", dir2) + + // Try registering directories with invalid names. + Seq("/files", "/jars").foreach { uri => + intercept[IllegalArgumentException] { + env.fileServer.addDirectory(uri, dir1) + } + } val destDir = Utils.createTempDir() val sm = new SecurityManager(conf) @@ -745,7 +764,9 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { val files = Seq( (file, fileUri), (empty, emptyUri), - (jar, jarUri)) + (jar, jarUri), + (subFile1, dir1Uri + "/file1"), + (subFile2, dir2Uri + "/file2")) files.foreach { case (f, uri) => val destFile = new File(destDir, f.getName()) Utils.fetchFile(uri, destDir, conf, sm, hc, 0L, false) @@ -753,7 +774,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { } // Try to download files that do not exist. 
- Seq("files", "jars").foreach { root => + Seq("files", "jars", "dir1").foreach { root => intercept[Exception] { val uri = env.address.toSparkURL + s"/$root/doesNotExist" Utils.fetchFile(uri, destDir, conf, sm, hc, 0L, false) diff --git a/docs/configuration.md b/docs/configuration.md index fd61ddc244f4..873a2d0b303c 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -1053,14 +1053,6 @@ Apart from these, the following properties are also available, and may be useful to port + maxRetries.
    - - - - - diff --git a/docs/security.md b/docs/security.md index e1af221d446b..0bfc791c5744 100644 --- a/docs/security.md +++ b/docs/security.md @@ -169,14 +169,6 @@ configure those ports. - - - - - - - - diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala index 304b1e8cdbed..22749c460934 100644 --- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala +++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala @@ -253,7 +253,7 @@ class SparkILoop( case xs => xs find (_.name == cmd) } } - private var fallbackMode = false + private var fallbackMode = false private def toggleFallbackMode() { val old = fallbackMode @@ -261,9 +261,9 @@ class SparkILoop( System.setProperty("spark.repl.fallback", fallbackMode.toString) echo(s""" |Switched ${if (old) "off" else "on"} fallback mode without restarting. - | If you have defined classes in the repl, it would + | If you have defined classes in the repl, it would |be good to redefine them incase you plan to use them. If you still run - |into issues it would be good to restart the repl and turn on `:fallback` + |into issues it would be good to restart the repl and turn on `:fallback` |mode as first command. """.stripMargin) } @@ -350,7 +350,7 @@ class SparkILoop( shCommand, nullary("silent", "disable/enable automatic printing of results", verbosity), nullary("fallback", """ - |disable/enable advanced repl changes, these fix some issues but may introduce others. + |disable/enable advanced repl changes, these fix some issues but may introduce others. |This mode will be removed once these fixes stablize""".stripMargin, toggleFallbackMode), cmd("type", "[-v] ", "display the type of an expression without evaluating it", typeCommand), nullary("warnings", "show the suppressed warnings from the most recent line which had any", warningsCommand) @@ -1009,8 +1009,13 @@ class SparkILoop( val conf = new SparkConf() .setMaster(getMaster()) .setJars(jars) - .set("spark.repl.class.uri", intp.classServerUri) .setIfMissing("spark.app.name", "Spark shell") + // SparkContext will detect this configuration and register it with the RpcEnv's + // file server, setting spark.repl.class.uri to the actual URI for executors to + // use. This is sort of ugly but since executors are started as part of SparkContext + // initialization in certain cases, there's an initialization order issue that prevents + // this from being set after SparkContext is instantiated. 
+ .set("spark.repl.class.outputDir", intp.outputDir.getAbsolutePath()) if (execUri != null) { conf.set("spark.executor.uri", execUri) } @@ -1025,7 +1030,7 @@ class SparkILoop( val loader = Utils.getContextOrSparkClassLoader try { sqlContext = loader.loadClass(name).getConstructor(classOf[SparkContext]) - .newInstance(sparkContext).asInstanceOf[SQLContext] + .newInstance(sparkContext).asInstanceOf[SQLContext] logInfo("Created sql context (with Hive support)..") } catch { diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala index 829b12269fd2..7fcb423575d3 100644 --- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala +++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala @@ -37,7 +37,7 @@ import scala.reflect.{ ClassTag, classTag } import scala.tools.reflect.StdRuntimeTags._ import scala.util.control.ControlThrowable -import org.apache.spark.{Logging, HttpServer, SecurityManager, SparkConf} +import org.apache.spark.{Logging, SparkConf, SparkContext} import org.apache.spark.util.Utils import org.apache.spark.annotation.DeveloperApi @@ -96,10 +96,9 @@ import org.apache.spark.annotation.DeveloperApi private val SPARK_DEBUG_REPL: Boolean = (System.getenv("SPARK_DEBUG_REPL") == "1") /** Local directory to save .class files too */ - private lazy val outputDir = { - val tmp = System.getProperty("java.io.tmpdir") - val rootDir = conf.get("spark.repl.classdir", tmp) - Utils.createTempDir(rootDir) + private[repl] val outputDir = { + val rootDir = conf.getOption("spark.repl.classdir").getOrElse(Utils.getLocalDir(conf)) + Utils.createTempDir(root = rootDir, namePrefix = "repl") } if (SPARK_DEBUG_REPL) { echo("Output directory: " + outputDir) @@ -114,8 +113,6 @@ import org.apache.spark.annotation.DeveloperApi private val virtualDirectory = new PlainFile(outputDir) // "directory" for classfiles /** Jetty server that will serve our classes to worker nodes */ - private val classServerPort = conf.getInt("spark.replClassServer.port", 0) - private val classServer = new HttpServer(conf, outputDir, new SecurityManager(conf), classServerPort, "HTTP class server") private var currentSettings: Settings = initialSettings private var printResults = true // whether to print result lines private var totalSilence = false // whether to print anything @@ -124,22 +121,6 @@ import org.apache.spark.annotation.DeveloperApi private var bindExceptions = true // whether to bind the lastException variable private var _executionWrapper = "" // code to be wrapped around all lines - - // Start the classServer and store its URI in a spark system property - // (which will be passed to executors so that they can connect to it) - classServer.start() - if (SPARK_DEBUG_REPL) { - echo("Class server started, URI = " + classServer.uri) - } - - /** - * URI of the class server used to feed REPL compiled classes. - * - * @return The string representing the class server uri - */ - @DeveloperApi - def classServerUri = classServer.uri - /** We're going to go to some trouble to initialize the compiler asynchronously. 
* It's critical that nothing call into it until it's been initialized or we will * run into unrecoverable issues, but the perceived repl startup time goes @@ -994,7 +975,6 @@ import org.apache.spark.annotation.DeveloperApi @DeveloperApi def close() { reporter.flush() - classServer.stop() } /** diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala index 455a6b9a93aa..44650f25f7a1 100644 --- a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala +++ b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala @@ -28,11 +28,13 @@ import org.apache.spark.sql.SQLContext object Main extends Logging { val conf = new SparkConf() - val tmp = System.getProperty("java.io.tmpdir") - val rootDir = conf.get("spark.repl.classdir", tmp) - val outputDir = Utils.createTempDir(rootDir) + val rootDir = conf.getOption("spark.repl.classdir").getOrElse(Utils.getLocalDir(conf)) + val outputDir = Utils.createTempDir(root = rootDir, namePrefix = "repl") + val s = new Settings() + s.processArguments(List("-Yrepl-class-based", + "-Yrepl-outdir", s"${outputDir.getAbsolutePath}", + "-classpath", getAddedJars.mkString(File.pathSeparator)), true) // the creation of SecurityManager has to be lazy so SPARK_YARN_MODE is set if needed - lazy val classServer = new HttpServer(conf, outputDir, new SecurityManager(conf)) var sparkContext: SparkContext = _ var sqlContext: SQLContext = _ var interp = new SparkILoop // this is a public var because tests reset it. @@ -45,7 +47,6 @@ object Main extends Logging { } def main(args: Array[String]) { - val interpArguments = List( "-Yrepl-class-based", "-Yrepl-outdir", s"${outputDir.getAbsolutePath}", @@ -57,11 +58,7 @@ object Main extends Logging { if (!hasErrors) { if (getMaster == "yarn-client") System.setProperty("SPARK_YARN_MODE", "true") - // Start the classServer and store its URI in a spark system property - // (which will be passed to executors so that they can connect to it) - classServer.start() interp.process(settings) // Repl starts and goes in loop of R.E.P.L - classServer.stop() Option(sparkContext).map(_.stop) } } @@ -82,9 +79,13 @@ object Main extends Logging { val conf = new SparkConf() .setMaster(getMaster) .setJars(jars) - .set("spark.repl.class.uri", classServer.uri) .setIfMissing("spark.app.name", "Spark shell") - logInfo("Spark class server started at " + classServer.uri) + // SparkContext will detect this configuration and register it with the RpcEnv's + // file server, setting spark.repl.class.uri to the actual URI for executors to + // use. This is sort of ugly but since executors are started as part of SparkContext + // initialization in certain cases, there's an initialization order issue that prevents + // this from being set after SparkContext is instantiated. 
+ .set("spark.repl.class.outputDir", outputDir.getAbsolutePath()) if (execUri != null) { conf.set("spark.executor.uri", execUri) } diff --git a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala index a8859fcd4584..da8f0aa1e336 100644 --- a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala +++ b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala @@ -19,6 +19,7 @@ package org.apache.spark.repl import java.io.{IOException, ByteArrayOutputStream, InputStream} import java.net.{HttpURLConnection, URI, URL, URLEncoder} +import java.nio.channels.Channels import scala.util.control.NonFatal @@ -38,7 +39,11 @@ import org.apache.spark.util.ParentClassLoader * This class loader delegates getting/finding resources to parent loader, * which makes sense until REPL never provide resource dynamically. */ -class ExecutorClassLoader(conf: SparkConf, classUri: String, parent: ClassLoader, +class ExecutorClassLoader( + conf: SparkConf, + env: SparkEnv, + classUri: String, + parent: ClassLoader, userClassPathFirst: Boolean) extends ClassLoader with Logging { val uri = new URI(classUri) val directory = uri.getPath @@ -48,13 +53,12 @@ class ExecutorClassLoader(conf: SparkConf, classUri: String, parent: ClassLoader // Allows HTTP connect and read timeouts to be controlled for testing / debugging purposes private[repl] var httpUrlConnectionTimeoutMillis: Int = -1 - // Hadoop FileSystem object for our URI, if it isn't using HTTP - var fileSystem: FileSystem = { - if (Set("http", "https", "ftp").contains(uri.getScheme)) { - null - } else { - FileSystem.get(uri, SparkHadoopUtil.get.newConfiguration(conf)) - } + private val fetchFn: (String) => InputStream = uri.getScheme() match { + case "spark" => getClassFileInputStreamFromSparkRPC + case "http" | "https" | "ftp" => getClassFileInputStreamFromHttpServer + case _ => + val fileSystem = FileSystem.get(uri, SparkHadoopUtil.get.newConfiguration(conf)) + getClassFileInputStreamFromFileSystem(fileSystem) } override def getResource(name: String): URL = { @@ -90,6 +94,11 @@ class ExecutorClassLoader(conf: SparkConf, classUri: String, parent: ClassLoader } } + private def getClassFileInputStreamFromSparkRPC(path: String): InputStream = { + val channel = env.rpcEnv.openChannel(s"$classUri/$path") + Channels.newInputStream(channel) + } + private def getClassFileInputStreamFromHttpServer(pathInDirectory: String): InputStream = { val url = if (SparkEnv.get.securityManager.isAuthenticationEnabled()) { val uri = new URI(classUri + "/" + urlEncode(pathInDirectory)) @@ -126,7 +135,8 @@ class ExecutorClassLoader(conf: SparkConf, classUri: String, parent: ClassLoader } } - private def getClassFileInputStreamFromFileSystem(pathInDirectory: String): InputStream = { + private def getClassFileInputStreamFromFileSystem(fileSystem: FileSystem)( + pathInDirectory: String): InputStream = { val path = new Path(directory, pathInDirectory) if (fileSystem.exists(path)) { fileSystem.open(path) @@ -139,13 +149,7 @@ class ExecutorClassLoader(conf: SparkConf, classUri: String, parent: ClassLoader val pathInDirectory = name.replace('.', '/') + ".class" var inputStream: InputStream = null try { - inputStream = { - if (fileSystem != null) { - getClassFileInputStreamFromFileSystem(pathInDirectory) - } else { - getClassFileInputStreamFromHttpServer(pathInDirectory) - } - } + inputStream = fetchFn(pathInDirectory) val bytes = readAndTransformClass(name, inputStream) 
Some(defineClass(name, bytes, 0, bytes.length)) } catch { diff --git a/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala b/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala index c1211f7596b9..1360f09e7fa1 100644 --- a/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala +++ b/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala @@ -18,24 +18,29 @@ package org.apache.spark.repl import java.io.File -import java.net.{URL, URLClassLoader} +import java.net.{URI, URL, URLClassLoader} +import java.nio.channels.{FileChannel, ReadableByteChannel} import java.nio.charset.StandardCharsets +import java.nio.file.{Paths, StandardOpenOption} import java.util -import com.google.common.io.Files - import scala.concurrent.duration._ import scala.io.Source import scala.language.implicitConversions import scala.language.postfixOps +import com.google.common.io.Files import org.scalatest.BeforeAndAfterAll import org.scalatest.concurrent.Interruptor import org.scalatest.concurrent.Timeouts._ import org.scalatest.mock.MockitoSugar +import org.mockito.invocation.InvocationOnMock +import org.mockito.stubbing.Answer +import org.mockito.Matchers.anyString import org.mockito.Mockito._ import org.apache.spark._ +import org.apache.spark.rpc.RpcEnv import org.apache.spark.util.Utils class ExecutorClassLoaderSuite @@ -78,7 +83,7 @@ class ExecutorClassLoaderSuite test("child first") { val parentLoader = new URLClassLoader(urls2, null) - val classLoader = new ExecutorClassLoader(new SparkConf(), url1, parentLoader, true) + val classLoader = new ExecutorClassLoader(new SparkConf(), null, url1, parentLoader, true) val fakeClass = classLoader.loadClass("ReplFakeClass2").newInstance() val fakeClassVersion = fakeClass.toString assert(fakeClassVersion === "1") @@ -86,7 +91,7 @@ class ExecutorClassLoaderSuite test("parent first") { val parentLoader = new URLClassLoader(urls2, null) - val classLoader = new ExecutorClassLoader(new SparkConf(), url1, parentLoader, false) + val classLoader = new ExecutorClassLoader(new SparkConf(), null, url1, parentLoader, false) val fakeClass = classLoader.loadClass("ReplFakeClass1").newInstance() val fakeClassVersion = fakeClass.toString assert(fakeClassVersion === "2") @@ -94,7 +99,7 @@ class ExecutorClassLoaderSuite test("child first can fall back") { val parentLoader = new URLClassLoader(urls2, null) - val classLoader = new ExecutorClassLoader(new SparkConf(), url1, parentLoader, true) + val classLoader = new ExecutorClassLoader(new SparkConf(), null, url1, parentLoader, true) val fakeClass = classLoader.loadClass("ReplFakeClass3").newInstance() val fakeClassVersion = fakeClass.toString assert(fakeClassVersion === "2") @@ -102,7 +107,7 @@ class ExecutorClassLoaderSuite test("child first can fail") { val parentLoader = new URLClassLoader(urls2, null) - val classLoader = new ExecutorClassLoader(new SparkConf(), url1, parentLoader, true) + val classLoader = new ExecutorClassLoader(new SparkConf(), null, url1, parentLoader, true) intercept[java.lang.ClassNotFoundException] { classLoader.loadClass("ReplFakeClassDoesNotExist").newInstance() } @@ -110,7 +115,7 @@ class ExecutorClassLoaderSuite test("resource from parent") { val parentLoader = new URLClassLoader(urls2, null) - val classLoader = new ExecutorClassLoader(new SparkConf(), url1, parentLoader, true) + val classLoader = new ExecutorClassLoader(new SparkConf(), null, url1, parentLoader, true) val resourceName: String = parentResourceNames.head val is = 
classLoader.getResourceAsStream(resourceName) assert(is != null, s"Resource $resourceName not found") @@ -120,7 +125,7 @@ class ExecutorClassLoaderSuite test("resources from parent") { val parentLoader = new URLClassLoader(urls2, null) - val classLoader = new ExecutorClassLoader(new SparkConf(), url1, parentLoader, true) + val classLoader = new ExecutorClassLoader(new SparkConf(), null, url1, parentLoader, true) val resourceName: String = parentResourceNames.head val resources: util.Enumeration[URL] = classLoader.getResources(resourceName) assert(resources.hasMoreElements, s"Resource $resourceName not found") @@ -142,7 +147,7 @@ class ExecutorClassLoaderSuite SparkEnv.set(mockEnv) // Create an ExecutorClassLoader that's configured to load classes from the HTTP server val parentLoader = new URLClassLoader(Array.empty, null) - val classLoader = new ExecutorClassLoader(conf, classServer.uri, parentLoader, false) + val classLoader = new ExecutorClassLoader(conf, null, classServer.uri, parentLoader, false) classLoader.httpUrlConnectionTimeoutMillis = 500 // Check that this class loader can actually load classes that exist val fakeClass = classLoader.loadClass("ReplFakeClass2").newInstance() @@ -177,4 +182,27 @@ class ExecutorClassLoaderSuite failAfter(10 seconds)(tryAndFailToLoadABunchOfClasses())(interruptor) } + test("fetch classes using Spark's RpcEnv") { + val env = mock[SparkEnv] + val rpcEnv = mock[RpcEnv] + when(env.rpcEnv).thenReturn(rpcEnv) + when(rpcEnv.openChannel(anyString())).thenAnswer(new Answer[ReadableByteChannel]() { + override def answer(invocation: InvocationOnMock): ReadableByteChannel = { + val uri = new URI(invocation.getArguments()(0).asInstanceOf[String]) + val path = Paths.get(tempDir1.getAbsolutePath(), uri.getPath().stripPrefix("/")) + FileChannel.open(path, StandardOpenOption.READ) + } + }) + + val classLoader = new ExecutorClassLoader(new SparkConf(), env, "spark://localhost:1234", + getClass().getClassLoader(), false) + + val fakeClass = classLoader.loadClass("ReplFakeClass2").newInstance() + val fakeClassVersion = fakeClass.toString + assert(fakeClassVersion === "1") + intercept[java.lang.ClassNotFoundException] { + classLoader.loadClass("ReplFakeClassDoesNotExist").newInstance() + } + } + } From 6a6c1fc5c807ba4e8aba3e260537aa527ff5d46a Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Thu, 10 Dec 2015 14:21:15 -0800 Subject: [PATCH 1737/1824] [SPARK-11713] [PYSPARK] [STREAMING] Initial RDD updateStateByKey for PySpark Adding ability to define an initial state RDD for use with updateStateByKey PySpark. Added unit test and changed stateful_network_wordcount example to use initial RDD. Author: Bryan Cutler Closes #10082 from BryanCutler/initial-rdd-updateStateByKey-SPARK-11713. 
--- .../streaming/stateful_network_wordcount.py | 5 ++++- python/pyspark/streaming/dstream.py | 13 ++++++++++-- python/pyspark/streaming/tests.py | 20 +++++++++++++++++++ .../streaming/api/python/PythonDStream.scala | 14 +++++++++++-- 4 files changed, 47 insertions(+), 5 deletions(-) diff --git a/examples/src/main/python/streaming/stateful_network_wordcount.py b/examples/src/main/python/streaming/stateful_network_wordcount.py index 16ef646b7c42..f8bbc659c2ea 100644 --- a/examples/src/main/python/streaming/stateful_network_wordcount.py +++ b/examples/src/main/python/streaming/stateful_network_wordcount.py @@ -44,13 +44,16 @@ ssc = StreamingContext(sc, 1) ssc.checkpoint("checkpoint") + # RDD with initial state (key, value) pairs + initialStateRDD = sc.parallelize([(u'hello', 1), (u'world', 1)]) + def updateFunc(new_values, last_sum): return sum(new_values) + (last_sum or 0) lines = ssc.socketTextStream(sys.argv[1], int(sys.argv[2])) running_counts = lines.flatMap(lambda line: line.split(" "))\ .map(lambda word: (word, 1))\ - .updateStateByKey(updateFunc) + .updateStateByKey(updateFunc, initialRDD=initialStateRDD) running_counts.pprint() diff --git a/python/pyspark/streaming/dstream.py b/python/pyspark/streaming/dstream.py index acec850f02c2..f61137cb88c4 100644 --- a/python/pyspark/streaming/dstream.py +++ b/python/pyspark/streaming/dstream.py @@ -568,7 +568,7 @@ def invReduceFunc(t, a, b): self._ssc._jduration(slideDuration)) return DStream(dstream.asJavaDStream(), self._ssc, self._sc.serializer) - def updateStateByKey(self, updateFunc, numPartitions=None): + def updateStateByKey(self, updateFunc, numPartitions=None, initialRDD=None): """ Return a new "state" DStream where the state for each key is updated by applying the given function on the previous state of the key and the new values of the key. 
@@ -579,6 +579,9 @@ def updateStateByKey(self, updateFunc, numPartitions=None): if numPartitions is None: numPartitions = self._sc.defaultParallelism + if initialRDD and not isinstance(initialRDD, RDD): + initialRDD = self._sc.parallelize(initialRDD) + def reduceFunc(t, a, b): if a is None: g = b.groupByKey(numPartitions).mapValues(lambda vs: (list(vs), None)) @@ -590,7 +593,13 @@ def reduceFunc(t, a, b): jreduceFunc = TransformFunction(self._sc, reduceFunc, self._sc.serializer, self._jrdd_deserializer) - dstream = self._sc._jvm.PythonStateDStream(self._jdstream.dstream(), jreduceFunc) + if initialRDD: + initialRDD = initialRDD._reserialize(self._jrdd_deserializer) + dstream = self._sc._jvm.PythonStateDStream(self._jdstream.dstream(), jreduceFunc, + initialRDD._jrdd) + else: + dstream = self._sc._jvm.PythonStateDStream(self._jdstream.dstream(), jreduceFunc) + return DStream(dstream.asJavaDStream(), self._ssc, self._sc.serializer) diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index a2bfd79e1abc..4949cd68e321 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -403,6 +403,26 @@ def func(dstream): expected = [[('k', v)] for v in expected] self._test_func(input, func, expected) + def test_update_state_by_key_initial_rdd(self): + + def updater(vs, s): + if not s: + s = [] + s.extend(vs) + return s + + initial = [('k', [0, 1])] + initial = self.sc.parallelize(initial, 1) + + input = [[('k', i)] for i in range(2, 5)] + + def func(dstream): + return dstream.updateStateByKey(updater, initialRDD=initial) + + expected = [[0, 1, 2], [0, 1, 2, 3], [0, 1, 2, 3, 4]] + expected = [[('k', v)] for v in expected] + self._test_func(input, func, expected) + def test_failed_func(self): # Test failure in # TransformFunction.apply(rdd: Option[RDD[_]], time: Time) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala index 994309ddd0a3..056248ccc7bc 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala @@ -264,9 +264,19 @@ private[python] class PythonTransformed2DStream( */ private[python] class PythonStateDStream( parent: DStream[Array[Byte]], - reduceFunc: PythonTransformFunction) + reduceFunc: PythonTransformFunction, + initialRDD: Option[RDD[Array[Byte]]]) extends PythonDStream(parent, reduceFunc) { + def this( + parent: DStream[Array[Byte]], + reduceFunc: PythonTransformFunction) = this(parent, reduceFunc, None) + + def this( + parent: DStream[Array[Byte]], + reduceFunc: PythonTransformFunction, + initialRDD: JavaRDD[Array[Byte]]) = this(parent, reduceFunc, Some(initialRDD.rdd)) + super.persist(StorageLevel.MEMORY_ONLY) override val mustCheckpoint = true @@ -274,7 +284,7 @@ private[python] class PythonStateDStream( val lastState = getOrCompute(validTime - slideDuration) val rdd = parent.getOrCompute(validTime) if (rdd.isDefined) { - func(lastState, rdd, validTime) + func(lastState.orElse(initialRDD), rdd, validTime) } else { lastState } From 23a9e62bad9669e9ff5dc4bd714f58d12f9be0b5 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 10 Dec 2015 15:29:04 -0800 Subject: [PATCH 1738/1824] [SPARK-12251] Document and improve off-heap memory configurations This patch adds documentation for Spark configurations that affect off-heap memory and makes some naming and validation improvements 
for those configs. - Change `spark.memory.offHeapSize` to `spark.memory.offHeap.size`. This is fine because this configuration has not shipped in any Spark release yet (it's new in Spark 1.6). - Deprecated `spark.unsafe.offHeap` in favor of a new `spark.memory.offHeap.enabled` configuration. The motivation behind this change is to gather all memory-related configurations under the same prefix. - Add a check which prevents users from setting `spark.memory.offHeap.enabled=true` when `spark.memory.offHeap.size == 0`. After SPARK-11389 (#9344), which was committed in Spark 1.6, Spark enforces a hard limit on the amount of off-heap memory that it will allocate to tasks. As a result, enabling off-heap execution memory without setting `spark.memory.offHeap.size` will lead to immediate OOMs. The new configuration validation makes this scenario easier to diagnose, helping to avoid user confusion. - Document these configurations on the configuration page. Author: Josh Rosen Closes #10237 from JoshRosen/SPARK-12251. --- .../scala/org/apache/spark/SparkConf.scala | 4 +++- .../apache/spark/memory/MemoryManager.scala | 10 +++++++-- .../spark/memory/TaskMemoryManagerSuite.java | 21 +++++++++++++++---- .../sort/PackedRecordPointerSuite.java | 6 ++++-- .../sort/ShuffleInMemorySorterSuite.java | 4 ++-- .../sort/UnsafeShuffleWriterSuite.java | 2 +- .../map/AbstractBytesToBytesMapSuite.java | 4 ++-- .../sort/UnsafeExternalSorterSuite.java | 2 +- .../sort/UnsafeInMemorySorterSuite.java | 4 ++-- .../memory/StaticMemoryManagerSuite.scala | 2 +- .../memory/UnifiedMemoryManagerSuite.scala | 2 +- docs/configuration.md | 16 ++++++++++++++ .../sql/execution/joins/HashedRelation.scala | 6 +++++- .../UnsafeFixedWidthAggregationMapSuite.scala | 2 +- .../UnsafeKVExternalSorterSuite.scala | 2 +- 15 files changed, 65 insertions(+), 22 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index 19633a3ce6a0..d3384fb29773 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -597,7 +597,9 @@ private[spark] object SparkConf extends Logging { "spark.streaming.fileStream.minRememberDuration" -> Seq( AlternateConfig("spark.streaming.minRememberDuration", "1.5")), "spark.yarn.max.executor.failures" -> Seq( - AlternateConfig("spark.yarn.max.worker.failures", "1.5")) + AlternateConfig("spark.yarn.max.worker.failures", "1.5")), + "spark.memory.offHeap.enabled" -> Seq( + AlternateConfig("spark.unsafe.offHeap", "1.6")) ) /** diff --git a/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala b/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala index ae9e1ac0e246..e707e27d96b5 100644 --- a/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala +++ b/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala @@ -50,7 +50,7 @@ private[spark] abstract class MemoryManager( storageMemoryPool.incrementPoolSize(storageMemory) onHeapExecutionMemoryPool.incrementPoolSize(onHeapExecutionMemory) - offHeapExecutionMemoryPool.incrementPoolSize(conf.getSizeAsBytes("spark.memory.offHeapSize", 0)) + offHeapExecutionMemoryPool.incrementPoolSize(conf.getSizeAsBytes("spark.memory.offHeap.size", 0)) /** * Total available memory for storage, in bytes. This amount can vary over time, depending on @@ -182,7 +182,13 @@ private[spark] abstract class MemoryManager( * sun.misc.Unsafe. 
*/ final val tungstenMemoryMode: MemoryMode = { - if (conf.getBoolean("spark.unsafe.offHeap", false)) MemoryMode.OFF_HEAP else MemoryMode.ON_HEAP + if (conf.getBoolean("spark.memory.offHeap.enabled", false)) { + require(conf.getSizeAsBytes("spark.memory.offHeap.size", 0) > 0, + "spark.memory.offHeap.size must be > 0 when spark.memory.offHeap.enabled == true") + MemoryMode.OFF_HEAP + } else { + MemoryMode.ON_HEAP + } } /** diff --git a/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java b/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java index 711eed0193bc..776a2997cf91 100644 --- a/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java +++ b/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java @@ -29,7 +29,7 @@ public class TaskMemoryManagerSuite { public void leakedPageMemoryIsDetected() { final TaskMemoryManager manager = new TaskMemoryManager( new StaticMemoryManager( - new SparkConf().set("spark.unsafe.offHeap", "false"), + new SparkConf().set("spark.memory.offHeap.enabled", "false"), Long.MAX_VALUE, Long.MAX_VALUE, 1), @@ -41,8 +41,10 @@ public void leakedPageMemoryIsDetected() { @Test public void encodePageNumberAndOffsetOffHeap() { - final TaskMemoryManager manager = new TaskMemoryManager( - new TestMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "true")), 0); + final SparkConf conf = new SparkConf() + .set("spark.memory.offHeap.enabled", "true") + .set("spark.memory.offHeap.size", "1000"); + final TaskMemoryManager manager = new TaskMemoryManager(new TestMemoryManager(conf), 0); final MemoryBlock dataPage = manager.allocatePage(256, null); // In off-heap mode, an offset is an absolute address that may require more than 51 bits to // encode. This test exercises that corner-case: @@ -55,7 +57,7 @@ public void encodePageNumberAndOffsetOffHeap() { @Test public void encodePageNumberAndOffsetOnHeap() { final TaskMemoryManager manager = new TaskMemoryManager( - new TestMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false")), 0); + new TestMemoryManager(new SparkConf().set("spark.memory.offHeap.enabled", "false")), 0); final MemoryBlock dataPage = manager.allocatePage(256, null); final long encodedAddress = manager.encodePageNumberAndOffset(dataPage, 64); Assert.assertEquals(dataPage.getBaseObject(), manager.getPage(encodedAddress)); @@ -104,4 +106,15 @@ public void cooperativeSpilling() { assert(manager.cleanUpAllAllocatedMemory() == 0); } + @Test + public void offHeapConfigurationBackwardsCompatibility() { + // Tests backwards-compatibility with the old `spark.unsafe.offHeap` configuration, which + // was deprecated in Spark 1.6 and replaced by `spark.memory.offHeap.enabled` (see SPARK-12251). 
+ final SparkConf conf = new SparkConf() + .set("spark.unsafe.offHeap", "true") + .set("spark.memory.offHeap.size", "1000"); + final TaskMemoryManager manager = new TaskMemoryManager(new TestMemoryManager(conf), 0); + assert(manager.tungstenMemoryMode == MemoryMode.OFF_HEAP); + } + } diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/PackedRecordPointerSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/PackedRecordPointerSuite.java index 9a43f1f3a923..fe5abc5c2304 100644 --- a/core/src/test/java/org/apache/spark/shuffle/sort/PackedRecordPointerSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/PackedRecordPointerSuite.java @@ -35,7 +35,7 @@ public class PackedRecordPointerSuite { @Test public void heap() throws IOException { - final SparkConf conf = new SparkConf().set("spark.unsafe.offHeap", "false"); + final SparkConf conf = new SparkConf().set("spark.memory.offHeap.enabled", "false"); final TaskMemoryManager memoryManager = new TaskMemoryManager(new TestMemoryManager(conf), 0); final MemoryBlock page0 = memoryManager.allocatePage(128, null); @@ -54,7 +54,9 @@ public void heap() throws IOException { @Test public void offHeap() throws IOException { - final SparkConf conf = new SparkConf().set("spark.unsafe.offHeap", "true"); + final SparkConf conf = new SparkConf() + .set("spark.memory.offHeap.enabled", "true") + .set("spark.memory.offHeap.size", "10000"); final TaskMemoryManager memoryManager = new TaskMemoryManager(new TestMemoryManager(conf), 0); final MemoryBlock page0 = memoryManager.allocatePage(128, null); diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java index faa5a863ee63..0328e63e4543 100644 --- a/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java @@ -34,7 +34,7 @@ public class ShuffleInMemorySorterSuite { final TestMemoryManager memoryManager = - new TestMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false")); + new TestMemoryManager(new SparkConf().set("spark.memory.offHeap.enabled", "false")); final TaskMemoryManager taskMemoryManager = new TaskMemoryManager(memoryManager, 0); final TestMemoryConsumer consumer = new TestMemoryConsumer(taskMemoryManager); @@ -64,7 +64,7 @@ public void testBasicSorting() throws Exception { "Lychee", "Mango" }; - final SparkConf conf = new SparkConf().set("spark.unsafe.offHeap", "false"); + final SparkConf conf = new SparkConf().set("spark.memory.offHeap.enabled", "false"); final TaskMemoryManager memoryManager = new TaskMemoryManager(new TestMemoryManager(conf), 0); final MemoryBlock dataPage = memoryManager.allocatePage(2048, null); diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java index bc85918c59aa..5fe64bde3604 100644 --- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java @@ -108,7 +108,7 @@ public void setUp() throws IOException { spillFilesCreated.clear(); conf = new SparkConf() .set("spark.buffer.pageSize", "1m") - .set("spark.unsafe.offHeap", "false"); + .set("spark.memory.offHeap.enabled", "false"); taskMetrics = new TaskMetrics(); memoryManager = new TestMemoryManager(conf); taskMemoryManager = new 
TaskMemoryManager(memoryManager, 0); diff --git a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java index 8724a3498842..702ba5469b8b 100644 --- a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java +++ b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java @@ -85,8 +85,8 @@ public void setup() { memoryManager = new TestMemoryManager( new SparkConf() - .set("spark.unsafe.offHeap", "" + useOffHeapMemoryAllocator()) - .set("spark.memory.offHeapSize", "256mb")); + .set("spark.memory.offHeap.enabled", "" + useOffHeapMemoryAllocator()) + .set("spark.memory.offHeap.size", "256mb")); taskMemoryManager = new TaskMemoryManager(memoryManager, 0); tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "unsafe-test"); diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java index a1c9f6fab8e6..e0ee281e98b7 100644 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java @@ -58,7 +58,7 @@ public class UnsafeExternalSorterSuite { final LinkedList spillFilesCreated = new LinkedList(); final TestMemoryManager memoryManager = - new TestMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false")); + new TestMemoryManager(new SparkConf().set("spark.memory.offHeap.enabled", "false")); final TaskMemoryManager taskMemoryManager = new TaskMemoryManager(memoryManager, 0); // Use integer comparison for comparing prefixes (which are partition ids, in this case) final PrefixComparator prefixComparator = new PrefixComparator() { diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java index a203a09648ac..93efd033eb94 100644 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java @@ -46,7 +46,7 @@ private static String getStringFromDataPage(Object baseObject, long baseOffset, @Test public void testSortingEmptyInput() { final TaskMemoryManager memoryManager = new TaskMemoryManager( - new TestMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false")), 0); + new TestMemoryManager(new SparkConf().set("spark.memory.offHeap.enabled", "false")), 0); final TestMemoryConsumer consumer = new TestMemoryConsumer(memoryManager); final UnsafeInMemorySorter sorter = new UnsafeInMemorySorter(consumer, memoryManager, @@ -71,7 +71,7 @@ public void testSortingOnlyByIntegerPrefix() throws Exception { "Mango" }; final TaskMemoryManager memoryManager = new TaskMemoryManager( - new TestMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false")), 0); + new TestMemoryManager(new SparkConf().set("spark.memory.offHeap.enabled", "false")), 0); final TestMemoryConsumer consumer = new TestMemoryConsumer(memoryManager); final MemoryBlock dataPage = memoryManager.allocatePage(2048, null); final Object baseObject = dataPage.getBaseObject(); diff --git a/core/src/test/scala/org/apache/spark/memory/StaticMemoryManagerSuite.scala 
b/core/src/test/scala/org/apache/spark/memory/StaticMemoryManagerSuite.scala index 272253bc94e9..68cf26fc3ed5 100644 --- a/core/src/test/scala/org/apache/spark/memory/StaticMemoryManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/memory/StaticMemoryManagerSuite.scala @@ -47,7 +47,7 @@ class StaticMemoryManagerSuite extends MemoryManagerSuite { conf.clone .set("spark.memory.fraction", "1") .set("spark.testing.memory", maxOnHeapExecutionMemory.toString) - .set("spark.memory.offHeapSize", maxOffHeapExecutionMemory.toString), + .set("spark.memory.offHeap.size", maxOffHeapExecutionMemory.toString), maxOnHeapExecutionMemory = maxOnHeapExecutionMemory, maxStorageMemory = 0, numCores = 1) diff --git a/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala b/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala index 71221deeb4c2..e21a028b7fae 100644 --- a/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala @@ -42,7 +42,7 @@ class UnifiedMemoryManagerSuite extends MemoryManagerSuite with PrivateMethodTes val conf = new SparkConf() .set("spark.memory.fraction", "1") .set("spark.testing.memory", maxOnHeapExecutionMemory.toString) - .set("spark.memory.offHeapSize", maxOffHeapExecutionMemory.toString) + .set("spark.memory.offHeap.size", maxOffHeapExecutionMemory.toString) .set("spark.memory.storageFraction", storageFraction.toString) UnifiedMemoryManager(conf, numCores = 1) } diff --git a/docs/configuration.md b/docs/configuration.md index 873a2d0b303c..55cf4b2dac5f 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -738,6 +738,22 @@ Apart from these, the following properties are also available, and may be useful this description. 
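The rows added to this table document the two new settings; as a minimal sketch (illustrative, not taken from the patch) of how they are meant to be used together from a driver-side SparkConf:

    import org.apache.spark.SparkConf

    // Off-heap execution memory stays opt-in; enabling it without a positive size now fails fast.
    val conf = new SparkConf()
      .set("spark.memory.offHeap.enabled", "true") // replaces the deprecated spark.unsafe.offHeap
      .set("spark.memory.offHeap.size", "2g")      // must be > 0 when offHeap.enabled is true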
+ + + + + + + + + + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index aebfea583240..8c7099ab5a34 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -334,7 +334,11 @@ private[joins] final class UnsafeHashedRelation( // so that tests compile: val taskMemoryManager = new TaskMemoryManager( new StaticMemoryManager( - new SparkConf().set("spark.unsafe.offHeap", "false"), Long.MaxValue, Long.MaxValue, 1), 0) + new SparkConf().set("spark.memory.offHeap.enabled", "false"), + Long.MaxValue, + Long.MaxValue, + 1), + 0) val pageSizeBytes = Option(SparkEnv.get).map(_.memoryManager.pageSizeBytes) .getOrElse(new SparkConf().getSizeAsBytes("spark.buffer.pageSize", "16m")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala index 7ceaee38d131..5a8406789ab8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala @@ -61,7 +61,7 @@ class UnsafeFixedWidthAggregationMapSuite } test(name) { - val conf = new SparkConf().set("spark.unsafe.offHeap", "false") + val conf = new SparkConf().set("spark.memory.offHeap.enabled", "false") memoryManager = new TestMemoryManager(conf) taskMemoryManager = new TaskMemoryManager(memoryManager, 0) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala index 7b80963ec870..29027a664b4b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala @@ -109,7 +109,7 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite with SharedSQLContext { pageSize: Long, spill: Boolean): Unit = { val memoryManager = - new TestMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false")) + new TestMemoryManager(new SparkConf().set("spark.memory.offHeap.enabled", "false")) val taskMemMgr = new TaskMemoryManager(memoryManager, 0) TaskContext.setTaskContext(new TaskContextImpl( stageId = 0, From 5030923ea8bb94ac8fa8e432de9fc7089aa93986 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Thu, 10 Dec 2015 15:30:08 -0800 Subject: [PATCH 1739/1824] [SPARK-12155][SPARK-12253] Fix executor OOM in unified memory management **Problem.** In unified memory management, acquiring execution memory may lead to eviction of storage memory. However, the space freed from evicting cached blocks is distributed among all active tasks. Thus, an incorrect upper bound on the execution memory per task can cause the acquisition to fail, leading to OOM's and premature spills. **Example.** Suppose total memory is 1000B, cached blocks occupy 900B, `spark.memory.storageFraction` is 0.4, and there are two active tasks. In this case, the cap on task execution memory is 100B / 2 = 50B. If task A tries to acquire 200B, it will evict 100B of storage but can only acquire 50B because of the incorrect cap. 
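A minimal sketch of the arithmetic above, using the numbers from this example and the computeMaxExecutionPoolSize formula introduced below (illustrative, not taken from the patch):

    // Example: 1000B total memory, 900B cached, spark.memory.storageFraction = 0.4, 2 active tasks.
    val maxMemory = 1000L
    val storageUsed = 900L
    val storageRegionSize = (maxMemory * 0.4).toLong                      // 400B
    val numActiveTasks = 2

    // Old (incorrect) cap: only the execution pool's current size (1000 - 900 = 100B) is divided.
    val oldCap = (maxMemory - storageUsed) / numActiveTasks               // 50B

    // New cap: execution may also reclaim storage borrowed beyond storageRegionSize.
    val maxPoolSize = maxMemory - math.min(storageUsed, storageRegionSize)
    val newCap = maxPoolSize / numActiveTasks                             // 300B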
For another example, see this [regression test](https://github.com/andrewor14/spark/blob/fix-oom/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala#L233) that I stole from JoshRosen. **Solution.** Fix the cap on task execution memory. It should take into account the space that could have been freed by storage in addition to the current amount of memory available to execution. In the example above, the correct cap should have been 600B / 2 = 300B. This patch also guards against the race condition (SPARK-12253): (1) Existing tasks collectively occupy all execution memory (2) New task comes in and blocks while existing tasks spill (3) After tasks finish spilling, another task jumps in and puts in a large block, stealing the freed memory (4) New task still cannot acquire memory and goes back to sleep Author: Andrew Or Closes #10240 from andrewor14/fix-oom. --- .../spark/memory/ExecutionMemoryPool.scala | 57 +++++++++++++------ .../spark/memory/UnifiedMemoryManager.scala | 57 ++++++++++++++----- .../org/apache/spark/scheduler/Task.scala | 6 ++ .../memory/UnifiedMemoryManagerSuite.scala | 25 ++++++++ 4 files changed, 114 insertions(+), 31 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/memory/ExecutionMemoryPool.scala b/core/src/main/scala/org/apache/spark/memory/ExecutionMemoryPool.scala index 9023e1ac012b..dbb0ad8d5c67 100644 --- a/core/src/main/scala/org/apache/spark/memory/ExecutionMemoryPool.scala +++ b/core/src/main/scala/org/apache/spark/memory/ExecutionMemoryPool.scala @@ -70,11 +70,28 @@ private[memory] class ExecutionMemoryPool( * active tasks) before it is forced to spill. This can happen if the number of tasks increase * but an older task had a lot of memory already. * + * @param numBytes number of bytes to acquire + * @param taskAttemptId the task attempt acquiring memory + * @param maybeGrowPool a callback that potentially grows the size of this pool. It takes in + * one parameter (Long) that represents the desired amount of memory by + * which this pool should be expanded. + * @param computeMaxPoolSize a callback that returns the maximum allowable size of this pool + * at this given moment. This is not a field because the max pool + * size is variable in certain cases. For instance, in unified + * memory management, the execution pool can be expanded by evicting + * cached blocks, thereby shrinking the storage pool. + * * @return the number of bytes granted to the task. 
*/ - def acquireMemory(numBytes: Long, taskAttemptId: Long): Long = lock.synchronized { + private[memory] def acquireMemory( + numBytes: Long, + taskAttemptId: Long, + maybeGrowPool: Long => Unit = (additionalSpaceNeeded: Long) => Unit, + computeMaxPoolSize: () => Long = () => poolSize): Long = lock.synchronized { assert(numBytes > 0, s"invalid number of bytes requested: $numBytes") + // TODO: clean up this clunky method signature + // Add this task to the taskMemory map just so we can keep an accurate count of the number // of active tasks, to let other tasks ramp down their memory in calls to `acquireMemory` if (!memoryForTask.contains(taskAttemptId)) { @@ -91,25 +108,31 @@ private[memory] class ExecutionMemoryPool( val numActiveTasks = memoryForTask.keys.size val curMem = memoryForTask(taskAttemptId) - // How much we can grant this task; don't let it grow to more than 1 / numActiveTasks; - // don't let it be negative - val maxToGrant = - math.min(numBytes, math.max(0, (poolSize / numActiveTasks) - curMem)) + // In every iteration of this loop, we should first try to reclaim any borrowed execution + // space from storage. This is necessary because of the potential race condition where new + // storage blocks may steal the free execution memory that this task was waiting for. + maybeGrowPool(numBytes - memoryFree) + + // Maximum size the pool would have after potentially growing the pool. + // This is used to compute the upper bound of how much memory each task can occupy. This + // must take into account potential free memory as well as the amount this pool currently + // occupies. Otherwise, we may run into SPARK-12155 where, in unified memory management, + // we did not take into account space that could have been freed by evicting cached blocks. 
+ val maxPoolSize = computeMaxPoolSize() + val maxMemoryPerTask = maxPoolSize / numActiveTasks + val minMemoryPerTask = poolSize / (2 * numActiveTasks) + + // How much we can grant this task; keep its share within 0 <= X <= 1 / numActiveTasks + val maxToGrant = math.min(numBytes, math.max(0, maxMemoryPerTask - curMem)) // Only give it as much memory as is free, which might be none if it reached 1 / numTasks val toGrant = math.min(maxToGrant, memoryFree) - if (curMem < poolSize / (2 * numActiveTasks)) { - // We want to let each task get at least 1 / (2 * numActiveTasks) before blocking; - // if we can't give it this much now, wait for other tasks to free up memory - // (this happens if older tasks allocated lots of memory before N grew) - if (memoryFree >= math.min(maxToGrant, poolSize / (2 * numActiveTasks) - curMem)) { - memoryForTask(taskAttemptId) += toGrant - return toGrant - } else { - logInfo( - s"TID $taskAttemptId waiting for at least 1/2N of $poolName pool to be free") - lock.wait() - } + // We want to let each task get at least 1 / (2 * numActiveTasks) before blocking; + // if we can't give it this much now, wait for other tasks to free up memory + // (this happens if older tasks allocated lots of memory before N grew) + if (toGrant < numBytes && curMem + toGrant < minMemoryPerTask) { + logInfo(s"TID $taskAttemptId waiting for at least 1/2N of $poolName pool to be free") + lock.wait() } else { memoryForTask(taskAttemptId) += toGrant return toGrant diff --git a/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala b/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala index 0b9f6a9dc052..829f054dba0e 100644 --- a/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala +++ b/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala @@ -81,22 +81,51 @@ private[spark] class UnifiedMemoryManager private[memory] ( assert(numBytes >= 0) memoryMode match { case MemoryMode.ON_HEAP => - if (numBytes > onHeapExecutionMemoryPool.memoryFree) { - val extraMemoryNeeded = numBytes - onHeapExecutionMemoryPool.memoryFree - // There is not enough free memory in the execution pool, so try to reclaim memory from - // storage. We can reclaim any free memory from the storage pool. If the storage pool - // has grown to become larger than `storageRegionSize`, we can evict blocks and reclaim - // the memory that storage has borrowed from execution. - val memoryReclaimableFromStorage = - math.max(storageMemoryPool.memoryFree, storageMemoryPool.poolSize - storageRegionSize) - if (memoryReclaimableFromStorage > 0) { - // Only reclaim as much space as is necessary and available: - val spaceReclaimed = storageMemoryPool.shrinkPoolToFreeSpace( - math.min(extraMemoryNeeded, memoryReclaimableFromStorage)) - onHeapExecutionMemoryPool.incrementPoolSize(spaceReclaimed) + + /** + * Grow the execution pool by evicting cached blocks, thereby shrinking the storage pool. + * + * When acquiring memory for a task, the execution pool may need to make multiple + * attempts. Each attempt must be able to evict storage in case another task jumps in + * and caches a large block between the attempts. This is called once per attempt. + */ + def maybeGrowExecutionPool(extraMemoryNeeded: Long): Unit = { + if (extraMemoryNeeded > 0) { + // There is not enough free memory in the execution pool, so try to reclaim memory from + // storage. We can reclaim any free memory from the storage pool. 
If the storage pool + // has grown to become larger than `storageRegionSize`, we can evict blocks and reclaim + // the memory that storage has borrowed from execution. + val memoryReclaimableFromStorage = + math.max(storageMemoryPool.memoryFree, storageMemoryPool.poolSize - storageRegionSize) + if (memoryReclaimableFromStorage > 0) { + // Only reclaim as much space as is necessary and available: + val spaceReclaimed = storageMemoryPool.shrinkPoolToFreeSpace( + math.min(extraMemoryNeeded, memoryReclaimableFromStorage)) + onHeapExecutionMemoryPool.incrementPoolSize(spaceReclaimed) + } } } - onHeapExecutionMemoryPool.acquireMemory(numBytes, taskAttemptId) + + /** + * The size the execution pool would have after evicting storage memory. + * + * The execution memory pool divides this quantity among the active tasks evenly to cap + * the execution memory allocation for each task. It is important to keep this greater + * than the execution pool size, which doesn't take into account potential memory that + * could be freed by evicting storage. Otherwise we may hit SPARK-12155. + * + * Additionally, this quantity should be kept below `maxMemory` to arbitrate fairness + * in execution memory allocation across tasks, Otherwise, a task may occupy more than + * its fair share of execution memory, mistakenly thinking that other tasks can acquire + * the portion of storage memory that cannot be evicted. + */ + def computeMaxExecutionPoolSize(): Long = { + maxMemory - math.min(storageMemoryUsed, storageRegionSize) + } + + onHeapExecutionMemoryPool.acquireMemory( + numBytes, taskAttemptId, maybeGrowExecutionPool, computeMaxExecutionPoolSize) + case MemoryMode.OFF_HEAP => // For now, we only support on-heap caching of data, so we do not need to interact with // the storage pool when allocating off-heap memory. This will change in the future, though. diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index d4bc3a5c900f..9f27eed626be 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -92,6 +92,12 @@ private[spark] abstract class Task[T]( Utils.tryLogNonFatalError { // Release memory used by this thread for unrolling blocks SparkEnv.get.blockManager.memoryStore.releaseUnrollMemoryForThisTask() + // Notify any tasks waiting for execution memory to be freed to wake up and try to + // acquire memory again. This makes impossible the scenario where a task sleeps forever + // because there are no other tasks left to notify it. Since this is safe to do but may + // not be strictly necessary, we should revisit whether we can remove this in the future. 
+ val memoryManager = SparkEnv.get.memoryManager + memoryManager.synchronized { memoryManager.notifyAll() } } } finally { TaskContext.unset() diff --git a/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala b/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala index e21a028b7fae..6cc48597d38f 100644 --- a/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala @@ -230,4 +230,29 @@ class UnifiedMemoryManagerSuite extends MemoryManagerSuite with PrivateMethodTes assert(exception.getMessage.contains("larger heap size")) } + test("execution can evict cached blocks when there are multiple active tasks (SPARK-12155)") { + val conf = new SparkConf() + .set("spark.memory.fraction", "1") + .set("spark.memory.storageFraction", "0") + .set("spark.testing.memory", "1000") + val mm = UnifiedMemoryManager(conf, numCores = 2) + val ms = makeMemoryStore(mm) + assert(mm.maxMemory === 1000) + // Have two tasks each acquire some execution memory so that the memory pool registers that + // there are two active tasks: + assert(mm.acquireExecutionMemory(100L, 0, MemoryMode.ON_HEAP) === 100L) + assert(mm.acquireExecutionMemory(100L, 1, MemoryMode.ON_HEAP) === 100L) + // Fill up all of the remaining memory with storage. + assert(mm.acquireStorageMemory(dummyBlock, 800L, evictedBlocks)) + assertEvictBlocksToFreeSpaceNotCalled(ms) + assert(mm.storageMemoryUsed === 800) + assert(mm.executionMemoryUsed === 200) + // A task should still be able to allocate 100 bytes execution memory by evicting blocks + assert(mm.acquireExecutionMemory(100L, 0, MemoryMode.ON_HEAP) === 100L) + assertEvictBlocksToFreeSpaceCalled(ms, 100L) + assert(mm.executionMemoryUsed === 300) + assert(mm.storageMemoryUsed === 700) + assert(evictedBlocks.nonEmpty) + } + } From 24d3357d66e14388faf8709b368edca70ea96432 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Thu, 10 Dec 2015 15:31:46 -0800 Subject: [PATCH 1740/1824] [STREAMING][DOC][MINOR] Update the description of direct Kafka stream doc With the merge of [SPARK-8337](https://issues.apache.org/jira/browse/SPARK-8337), now the Python API has the same functionalities compared to Scala/Java, so here changing the description to make it more precise. zsxwing tdas , please review, thanks a lot. Author: jerryshao Closes #10246 from jerryshao/direct-kafka-doc-update. --- docs/streaming-kafka-integration.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/streaming-kafka-integration.md b/docs/streaming-kafka-integration.md index b00351b2fbcc..5be73c42560f 100644 --- a/docs/streaming-kafka-integration.md +++ b/docs/streaming-kafka-integration.md @@ -74,7 +74,7 @@ Next, we discuss how to use this approach in your streaming application. [Maven repository](http://search.maven.org/#search|ga|1|a%3A%22spark-streaming-kafka-assembly_2.10%22%20AND%20v%3A%22{{site.SPARK_VERSION_SHORT}}%22) and add it to `spark-submit` with `--jars`. ## Approach 2: Direct Approach (No Receivers) -This new receiver-less "direct" approach has been introduced in Spark 1.3 to ensure stronger end-to-end guarantees. Instead of using receivers to receive data, this approach periodically queries Kafka for the latest offsets in each topic+partition, and accordingly defines the offset ranges to process in each batch. 
When the jobs to process the data are launched, Kafka's simple consumer API is used to read the defined ranges of offsets from Kafka (similar to read files from a file system). Note that this is an experimental feature introduced in Spark 1.3 for the Scala and Java API. Spark 1.4 added a Python API, but it is not yet at full feature parity. +This new receiver-less "direct" approach has been introduced in Spark 1.3 to ensure stronger end-to-end guarantees. Instead of using receivers to receive data, this approach periodically queries Kafka for the latest offsets in each topic+partition, and accordingly defines the offset ranges to process in each batch. When the jobs to process the data are launched, Kafka's simple consumer API is used to read the defined ranges of offsets from Kafka (similar to read files from a file system). Note that this is an experimental feature introduced in Spark 1.3 for the Scala and Java API, in Spark 1.4 for the Python API. This approach has the following advantages over the receiver-based approach (i.e. Approach 1). From b1b4ee7f3541d92c8bc2b0b4fdadf46cfdb09504 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 10 Dec 2015 17:22:18 -0800 Subject: [PATCH 1741/1824] [SPARK-12258][SQL] passing null into ScalaUDF Check nullability and passing them into ScalaUDF. Closes #10249 Author: Davies Liu Closes #10259 from davies/udf_null. --- .../apache/spark/sql/catalyst/expressions/ScalaUDF.scala | 7 +++++-- .../test/scala/org/apache/spark/sql/DataFrameSuite.scala | 9 +++++---- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala index 03b89221ef2d..5deb2f81d173 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala @@ -1029,8 +1029,11 @@ case class ScalaUDF( // such as IntegerType, its javaType is `int` and the returned type of user-defined // function is Object. Trying to convert an Object to `int` will cause casting exception. val evalCode = evals.map(_.code).mkString - val funcArguments = converterTerms.zip(evals).map { - case (converter, eval) => s"$converter.apply(${eval.value})" + val funcArguments = converterTerms.zipWithIndex.map { + case (converter, i) => + val eval = evals(i) + val dt = children(i).dataType + s"$converter.apply(${eval.isNull} ? 
null : (${ctx.boxedType(dt)}) ${eval.value})" }.mkString(",") val callFunc = s"${ctx.boxedType(ctx.javaType(dataType))} $resultTerm = " + s"(${ctx.boxedType(ctx.javaType(dataType))})${catalystConverterTerm}" + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 605a6549dd68..8887dc68a50e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -1138,14 +1138,15 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } test("SPARK-11725: correctly handle null inputs for ScalaUDF") { - val df = Seq( + val df = sparkContext.parallelize(Seq( new java.lang.Integer(22) -> "John", - null.asInstanceOf[java.lang.Integer] -> "Lucy").toDF("age", "name") + null.asInstanceOf[java.lang.Integer] -> "Lucy")).toDF("age", "name") + // passing null into the UDF that could handle it val boxedUDF = udf[java.lang.Integer, java.lang.Integer] { - (i: java.lang.Integer) => if (i == null) null else i * 2 + (i: java.lang.Integer) => if (i == null) -10 else i * 2 } - checkAnswer(df.select(boxedUDF($"age")), Row(44) :: Row(null) :: Nil) + checkAnswer(df.select(boxedUDF($"age")), Row(44) :: Row(-10) :: Nil) val primitiveUDF = udf((i: Int) => i * 2) checkAnswer(df.select(primitiveUDF($"age")), Row(44) :: Row(null) :: Nil) From 518ab5101073ee35d62e33c8f7281a1e6342101e Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Fri, 11 Dec 2015 02:35:53 -0500 Subject: [PATCH 1742/1824] [SPARK-10991][ML] logistic regression training summary handle empty prediction col LogisticRegression training summary should still function if the predictionCol is set to an empty string or otherwise unset (related too https://issues.apache.org/jira/browse/SPARK-9718 ) Author: Holden Karau Author: Holden Karau Closes #9037 from holdenk/SPARK-10991-LogisticRegressionTrainingSummary-handle-empty-prediction-col. --- .../classification/LogisticRegression.scala | 20 +++++++++++++++++-- .../LogisticRegressionSuite.scala | 11 ++++++++++ 2 files changed, 29 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 19cc323d5073..486043e8d974 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -389,9 +389,10 @@ class LogisticRegression @Since("1.2.0") ( if (handlePersistence) instances.unpersist() val model = copyValues(new LogisticRegressionModel(uid, coefficients, intercept)) + val (summaryModel, probabilityColName) = model.findSummaryModelAndProbabilityCol() val logRegSummary = new BinaryLogisticRegressionTrainingSummary( - model.transform(dataset), - $(probabilityCol), + summaryModel.transform(dataset), + probabilityColName, $(labelCol), $(featuresCol), objectiveHistory) @@ -469,6 +470,21 @@ class LogisticRegressionModel private[ml] ( new NullPointerException()) } + /** + * If the probability column is set returns the current model and probability column, + * otherwise generates a new column and sets it as the probability column on a new copy + * of the current model. 
+ */ + private[classification] def findSummaryModelAndProbabilityCol(): + (LogisticRegressionModel, String) = { + $(probabilityCol) match { + case "" => + val probabilityColName = "probability_" + java.util.UUID.randomUUID.toString() + (copy(ParamMap.empty).setProbabilityCol(probabilityColName), probabilityColName) + case p => (this, p) + } + } + private[classification] def setSummary( summary: LogisticRegressionTrainingSummary): this.type = { this.trainingSummary = Some(summary) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index a9a6ff8a783d..1087afb0cdf7 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -99,6 +99,17 @@ class LogisticRegressionSuite assert(model.hasParent) } + test("empty probabilityCol") { + val lr = new LogisticRegression().setProbabilityCol("") + val model = lr.fit(dataset) + assert(model.hasSummary) + // Validate that we re-insert a probability column for evaluation + val fieldNames = model.summary.predictions.schema.fieldNames + assert((dataset.schema.fieldNames.toSet).subsetOf( + fieldNames.toSet)) + assert(fieldNames.exists(s => s.startsWith("probability_"))) + } + test("setThreshold, getThreshold") { val lr = new LogisticRegression // default From c119a34d1e9e599e302acfda92e5de681086a19f Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 11 Dec 2015 11:15:53 -0800 Subject: [PATCH 1743/1824] [SPARK-12258] [SQL] passing null into ScalaUDF (follow-up) This is a follow-up PR for #10259 Author: Davies Liu Closes #10266 from davies/null_udf2. --- .../sql/catalyst/expressions/ScalaUDF.scala | 31 ++++++++++--------- .../org/apache/spark/sql/DataFrameSuite.scala | 8 +++-- 2 files changed, 23 insertions(+), 16 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala index 5deb2f81d173..85faa19bbf5e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala @@ -1029,24 +1029,27 @@ case class ScalaUDF( // such as IntegerType, its javaType is `int` and the returned type of user-defined // function is Object. Trying to convert an Object to `int` will cause casting exception. val evalCode = evals.map(_.code).mkString - val funcArguments = converterTerms.zipWithIndex.map { - case (converter, i) => - val eval = evals(i) - val dt = children(i).dataType - s"$converter.apply(${eval.isNull} ? null : (${ctx.boxedType(dt)}) ${eval.value})" - }.mkString(",") - val callFunc = s"${ctx.boxedType(ctx.javaType(dataType))} $resultTerm = " + - s"(${ctx.boxedType(ctx.javaType(dataType))})${catalystConverterTerm}" + - s".apply($funcTerm.apply($funcArguments));" + val (converters, funcArguments) = converterTerms.zipWithIndex.map { case (converter, i) => + val eval = evals(i) + val argTerm = ctx.freshName("arg") + val convert = s"Object $argTerm = ${eval.isNull} ? 
null : $converter.apply(${eval.value});" + (convert, argTerm) + }.unzip - val callFunc = s"${ctx.boxedType(ctx.javaType(dataType))} $resultTerm = " + - s"(${ctx.boxedType(ctx.javaType(dataType))})${catalystConverterTerm}" + - s".apply($funcTerm.apply($funcArguments));" + val callFunc = s"${ctx.boxedType(dataType)} $resultTerm = " + + s"(${ctx.boxedType(dataType)})${catalystConverterTerm}" + + s".apply($funcTerm.apply(${funcArguments.mkString(", ")}));" - evalCode + s""" - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - Boolean ${ev.isNull}; + s""" + $evalCode + ${converters.mkString("\n")} $callFunc - ${ev.value} = $resultTerm; - ${ev.isNull} = $resultTerm == null; + boolean ${ev.isNull} = $resultTerm == null; + ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + ${ev.value} = $resultTerm; + } """ } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 8887dc68a50e..5353fefaf4b8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -1144,9 +1144,13 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { // passing null into the UDF that could handle it val boxedUDF = udf[java.lang.Integer, java.lang.Integer] { - (i: java.lang.Integer) => if (i == null) -10 else i * 2 + (i: java.lang.Integer) => if (i == null) -10 else null } - checkAnswer(df.select(boxedUDF($"age")), Row(44) :: Row(-10) :: Nil) + checkAnswer(df.select(boxedUDF($"age")), Row(null) :: Row(-10) :: Nil) + + sqlContext.udf.register("boxedUDF", + (i: java.lang.Integer) => (if (i == null) -10 else null): java.lang.Integer) + checkAnswer(sql("select boxedUDF(null), boxedUDF(-1)"), Row(-10, null) :: Nil) val primitiveUDF = udf((i: Int) => i * 2) checkAnswer(df.select(primitiveUDF($"age")), Row(44) :: Row(null) :: Nil) From 0fb9825556dbbcc98d7eafe9ddea8676301e09bb Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Fri, 11 Dec 2015 11:47:35 -0800 Subject: [PATCH 1744/1824] [SPARK-12146][SPARKR] SparkR jsonFile should support multiple input files MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * ```jsonFile``` should support multiple input files, such as: ```R jsonFile(sqlContext, c("path1", "path2")) # character vector as arguments jsonFile(sqlContext, "path1,path2") ``` * Meanwhile, ```jsonFile``` has been deprecated by Spark SQL and will be removed at Spark 2.0. So we mark ```jsonFile``` deprecated and use ```read.json``` at SparkR side. * Replace all ```jsonFile``` with ```read.json``` at test_sparkSQL.R, but still keep jsonFile test case. * If this PR is accepted, we should also make almost the same change for ```parquetFile```. cc felixcheung sun-rui shivaram Author: Yanbo Liang Closes #10145 from yanboliang/spark-12146. 
--- R/pkg/NAMESPACE | 1 + R/pkg/R/DataFrame.R | 102 +++++++++--------- R/pkg/R/SQLContext.R | 29 ++++-- R/pkg/inst/tests/testthat/test_sparkSQL.R | 120 ++++++++++++---------- examples/src/main/r/dataframe.R | 2 +- 5 files changed, 138 insertions(+), 116 deletions(-) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index ba64bc59edee..cab39d68c3f5 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -267,6 +267,7 @@ export("as.DataFrame", "createExternalTable", "dropTempTable", "jsonFile", + "read.json", "loadDF", "parquetFile", "read.df", diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index f4c4a2585e29..975b058c0aaf 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -24,14 +24,14 @@ setOldClass("jobj") #' @title S4 class that represents a DataFrame #' @description DataFrames can be created using functions like \link{createDataFrame}, -#' \link{jsonFile}, \link{table} etc. +#' \link{read.json}, \link{table} etc. #' @family DataFrame functions #' @rdname DataFrame #' @docType class #' #' @slot env An R environment that stores bookkeeping states of the DataFrame #' @slot sdf A Java object reference to the backing Scala DataFrame -#' @seealso \link{createDataFrame}, \link{jsonFile}, \link{table} +#' @seealso \link{createDataFrame}, \link{read.json}, \link{table} #' @seealso \url{https://spark.apache.org/docs/latest/sparkr.html#sparkr-dataframes} #' @export #' @examples @@ -77,7 +77,7 @@ dataFrame <- function(sdf, isCached = FALSE) { #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' printSchema(df) #'} setMethod("printSchema", @@ -102,7 +102,7 @@ setMethod("printSchema", #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' dfSchema <- schema(df) #'} setMethod("schema", @@ -126,7 +126,7 @@ setMethod("schema", #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' explain(df, TRUE) #'} setMethod("explain", @@ -157,7 +157,7 @@ setMethod("explain", #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' isLocal(df) #'} setMethod("isLocal", @@ -182,7 +182,7 @@ setMethod("isLocal", #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' showDF(df) #'} setMethod("showDF", @@ -207,7 +207,7 @@ setMethod("showDF", #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' df #'} setMethod("show", "DataFrame", @@ -234,7 +234,7 @@ setMethod("show", "DataFrame", #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' dtypes(df) #'} setMethod("dtypes", @@ -261,7 +261,7 @@ setMethod("dtypes", #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' columns(df) #' colnames(df) #'} @@ -376,7 +376,7 @@ setMethod("coltypes", #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- 
"path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' coltypes(df) <- c("character", "integer") #' coltypes(df) <- c(NA, "numeric") #'} @@ -423,7 +423,7 @@ setMethod("coltypes<-", #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' registerTempTable(df, "json_df") #' new_df <- sql(sqlContext, "SELECT * FROM json_df") #'} @@ -476,7 +476,7 @@ setMethod("insertInto", #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' cache(df) #'} setMethod("cache", @@ -504,7 +504,7 @@ setMethod("cache", #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' persist(df, "MEMORY_AND_DISK") #'} setMethod("persist", @@ -532,7 +532,7 @@ setMethod("persist", #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' persist(df, "MEMORY_AND_DISK") #' unpersist(df) #'} @@ -560,7 +560,7 @@ setMethod("unpersist", #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' newDF <- repartition(df, 2L) #'} setMethod("repartition", @@ -585,7 +585,7 @@ setMethod("repartition", #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' newRDD <- toJSON(df) #'} setMethod("toJSON", @@ -613,7 +613,7 @@ setMethod("toJSON", #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' saveAsParquetFile(df, "/tmp/sparkr-tmp/") #'} setMethod("saveAsParquetFile", @@ -637,7 +637,7 @@ setMethod("saveAsParquetFile", #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' distinctDF <- distinct(df) #'} setMethod("distinct", @@ -672,7 +672,7 @@ setMethod("unique", #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' collect(sample(df, FALSE, 0.5)) #' collect(sample(df, TRUE, 0.5)) #'} @@ -711,7 +711,7 @@ setMethod("sample_frac", #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' count(df) #' } setMethod("count", @@ -741,7 +741,7 @@ setMethod("nrow", #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' ncol(df) #' } setMethod("ncol", @@ -762,7 +762,7 @@ setMethod("ncol", #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' dim(df) #' } setMethod("dim", @@ -786,7 +786,7 @@ setMethod("dim", #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' collected <- 
collect(df) #' firstName <- collected[[1]]$name #' } @@ -858,7 +858,7 @@ setMethod("collect", #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' limitedDF <- limit(df, 10) #' } setMethod("limit", @@ -879,7 +879,7 @@ setMethod("limit", #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' take(df, 2) #' } setMethod("take", @@ -908,7 +908,7 @@ setMethod("take", #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' head(df) #' } setMethod("head", @@ -931,7 +931,7 @@ setMethod("head", #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' first(df) #' } setMethod("first", @@ -952,7 +952,7 @@ setMethod("first", #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' rdd <- toRDD(df) #'} setMethod("toRDD", @@ -1298,7 +1298,7 @@ setMethod("select", #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' selectExpr(df, "col1", "(col2 * 5) as newCol") #' } setMethod("selectExpr", @@ -1327,7 +1327,7 @@ setMethod("selectExpr", #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' newDF <- withColumn(df, "newCol", df$col1 * 5) #' } setMethod("withColumn", @@ -1352,7 +1352,7 @@ setMethod("withColumn", #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' newDF <- mutate(df, newCol = df$col1 * 5, newCol2 = df$col1 * 2) #' names(newDF) # Will contain newCol, newCol2 #' newDF2 <- transform(df, newCol = df$col1 / 5, newCol2 = df$col1 * 2) @@ -1402,7 +1402,7 @@ setMethod("transform", #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' newDF <- withColumnRenamed(df, "col1", "newCol1") #' } setMethod("withColumnRenamed", @@ -1427,7 +1427,7 @@ setMethod("withColumnRenamed", #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' newDF <- rename(df, col1 = df$newCol1) #' } setMethod("rename", @@ -1471,7 +1471,7 @@ setClassUnion("characterOrColumn", c("character", "Column")) #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' arrange(df, df$col1) #' arrange(df, asc(df$col1), desc(abs(df$col2))) #' arrange(df, "col1", decreasing = TRUE) @@ -1547,7 +1547,7 @@ setMethod("orderBy", #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' filter(df, "col1 > 0") #' filter(df, df$col2 != "abcdefg") #' } @@ -1591,8 +1591,8 @@ setMethod("where", #'\dontrun{ #' sc <- sparkR.init() #' 
sqlContext <- sparkRSQL.init(sc) -#' df1 <- jsonFile(sqlContext, path) -#' df2 <- jsonFile(sqlContext, path2) +#' df1 <- read.json(sqlContext, path) +#' df2 <- read.json(sqlContext, path2) #' join(df1, df2) # Performs a Cartesian #' join(df1, df2, df1$col1 == df2$col2) # Performs an inner join based on expression #' join(df1, df2, df1$col1 == df2$col2, "right_outer") @@ -1648,8 +1648,8 @@ setMethod("join", #'\dontrun{ #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) -#' df1 <- jsonFile(sqlContext, path) -#' df2 <- jsonFile(sqlContext, path2) +#' df1 <- read.json(sqlContext, path) +#' df2 <- read.json(sqlContext, path2) #' merge(df1, df2) # Performs a Cartesian #' merge(df1, df2, by = "col1") # Performs an inner join based on expression #' merge(df1, df2, by.x = "col1", by.y = "col2", all.y = TRUE) @@ -1781,8 +1781,8 @@ generateAliasesForIntersectedCols <- function (x, intersectedColNames, suffix) { #'\dontrun{ #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) -#' df1 <- jsonFile(sqlContext, path) -#' df2 <- jsonFile(sqlContext, path2) +#' df1 <- read.json(sqlContext, path) +#' df2 <- read.json(sqlContext, path2) #' unioned <- unionAll(df, df2) #' } setMethod("unionAll", @@ -1824,8 +1824,8 @@ setMethod("rbind", #'\dontrun{ #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) -#' df1 <- jsonFile(sqlContext, path) -#' df2 <- jsonFile(sqlContext, path2) +#' df1 <- read.json(sqlContext, path) +#' df2 <- read.json(sqlContext, path2) #' intersectDF <- intersect(df, df2) #' } setMethod("intersect", @@ -1851,8 +1851,8 @@ setMethod("intersect", #'\dontrun{ #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) -#' df1 <- jsonFile(sqlContext, path) -#' df2 <- jsonFile(sqlContext, path2) +#' df1 <- read.json(sqlContext, path) +#' df2 <- read.json(sqlContext, path2) #' exceptDF <- except(df, df2) #' } #' @rdname except @@ -1892,7 +1892,7 @@ setMethod("except", #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' write.df(df, "myfile", "parquet", "overwrite") #' saveDF(df, parquetPath2, "parquet", mode = saveMode, mergeSchema = mergeSchema) #' } @@ -1957,7 +1957,7 @@ setMethod("saveDF", #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' saveAsTable(df, "myfile") #' } setMethod("saveAsTable", @@ -1998,7 +1998,7 @@ setMethod("saveAsTable", #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' describe(df) #' describe(df, "col1") #' describe(df, "col1", "col2") @@ -2054,7 +2054,7 @@ setMethod("summary", #' sc <- sparkR.init() #' sqlCtx <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- read.json(sqlCtx, path) #' dropna(df) #' } setMethod("dropna", @@ -2108,7 +2108,7 @@ setMethod("na.omit", #' sc <- sparkR.init() #' sqlCtx <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- read.json(sqlCtx, path) #' fillna(df, 1) #' fillna(df, list("age" = 20, "name" = "unknown")) #' } diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R index f678c70a7a77..9243d70e66f7 100644 --- a/R/pkg/R/SQLContext.R +++ b/R/pkg/R/SQLContext.R @@ -208,24 +208,33 @@ setMethod("toDF", signature(x = "RDD"), #' @param sqlContext SQLContext to use #' @param path Path of file to 
read. A vector of multiple paths is allowed. #' @return DataFrame +#' @rdname read.json +#' @name read.json #' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" +#' df <- read.json(sqlContext, path) #' df <- jsonFile(sqlContext, path) #' } - -jsonFile <- function(sqlContext, path) { +read.json <- function(sqlContext, path) { # Allow the user to have a more flexible definiton of the text file path - path <- suppressWarnings(normalizePath(path)) - # Convert a string vector of paths to a string containing comma separated paths - path <- paste(path, collapse = ",") - sdf <- callJMethod(sqlContext, "jsonFile", path) + paths <- as.list(suppressWarnings(normalizePath(path))) + read <- callJMethod(sqlContext, "read") + sdf <- callJMethod(read, "json", paths) dataFrame(sdf) } +#' @rdname read.json +#' @name jsonFile +#' @export +jsonFile <- function(sqlContext, path) { + .Deprecated("read.json") + read.json(sqlContext, path) +} + #' JSON RDD #' @@ -299,7 +308,7 @@ parquetFile <- function(sqlContext, ...) { #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' registerTempTable(df, "table") #' new_df <- sql(sqlContext, "SELECT * FROM table") #' } @@ -323,7 +332,7 @@ sql <- function(sqlContext, sqlQuery) { #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' registerTempTable(df, "table") #' new_df <- table(sqlContext, "table") #' } @@ -396,7 +405,7 @@ tableNames <- function(sqlContext, databaseName = NULL) { #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' registerTempTable(df, "table") #' cacheTable(sqlContext, "table") #' } @@ -418,7 +427,7 @@ cacheTable <- function(sqlContext, tableName) { #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(sqlContext, path) #' registerTempTable(df, "table") #' uncacheTable(sqlContext, "table") #' } diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index 2051784427be..ed9b2c9d4d16 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -330,7 +330,7 @@ writeLines(mockLinesMapType, mapTypeJsonPath) test_that("Collect DataFrame with complex types", { # ArrayType - df <- jsonFile(sqlContext, complexTypeJsonPath) + df <- read.json(sqlContext, complexTypeJsonPath) ldf <- collect(df) expect_equal(nrow(ldf), 3) @@ -357,7 +357,7 @@ test_that("Collect DataFrame with complex types", { expect_equal(bob$height, 176.5) # StructType - df <- jsonFile(sqlContext, mapTypeJsonPath) + df <- read.json(sqlContext, mapTypeJsonPath) expect_equal(dtypes(df), list(c("info", "struct"), c("name", "string"))) ldf <- collect(df) @@ -371,10 +371,22 @@ test_that("Collect DataFrame with complex types", { expect_equal(bob$height, 176.5) }) -test_that("jsonFile() on a local file returns a DataFrame", { - df <- jsonFile(sqlContext, jsonPath) +test_that("read.json()/jsonFile() on a local file returns a DataFrame", { + df <- read.json(sqlContext, jsonPath) expect_is(df, "DataFrame") expect_equal(count(df), 3) + # read.json()/jsonFile() works with multiple input paths + jsonPath2 <- 
tempfile(pattern="jsonPath2", fileext=".json") + write.df(df, jsonPath2, "json", mode="overwrite") + jsonDF1 <- read.json(sqlContext, c(jsonPath, jsonPath2)) + expect_is(jsonDF1, "DataFrame") + expect_equal(count(jsonDF1), 6) + # Suppress warnings because jsonFile is deprecated + jsonDF2 <- suppressWarnings(jsonFile(sqlContext, c(jsonPath, jsonPath2))) + expect_is(jsonDF2, "DataFrame") + expect_equal(count(jsonDF2), 6) + + unlink(jsonPath2) }) test_that("jsonRDD() on a RDD with json string", { @@ -391,7 +403,7 @@ test_that("jsonRDD() on a RDD with json string", { }) test_that("test cache, uncache and clearCache", { - df <- jsonFile(sqlContext, jsonPath) + df <- read.json(sqlContext, jsonPath) registerTempTable(df, "table1") cacheTable(sqlContext, "table1") uncacheTable(sqlContext, "table1") @@ -400,7 +412,7 @@ test_that("test cache, uncache and clearCache", { }) test_that("test tableNames and tables", { - df <- jsonFile(sqlContext, jsonPath) + df <- read.json(sqlContext, jsonPath) registerTempTable(df, "table1") expect_equal(length(tableNames(sqlContext)), 1) df <- tables(sqlContext) @@ -409,7 +421,7 @@ test_that("test tableNames and tables", { }) test_that("registerTempTable() results in a queryable table and sql() results in a new DataFrame", { - df <- jsonFile(sqlContext, jsonPath) + df <- read.json(sqlContext, jsonPath) registerTempTable(df, "table1") newdf <- sql(sqlContext, "SELECT * FROM table1 where name = 'Michael'") expect_is(newdf, "DataFrame") @@ -445,7 +457,7 @@ test_that("insertInto() on a registered table", { }) test_that("table() returns a new DataFrame", { - df <- jsonFile(sqlContext, jsonPath) + df <- read.json(sqlContext, jsonPath) registerTempTable(df, "table1") tabledf <- table(sqlContext, "table1") expect_is(tabledf, "DataFrame") @@ -458,14 +470,14 @@ test_that("table() returns a new DataFrame", { }) test_that("toRDD() returns an RRDD", { - df <- jsonFile(sqlContext, jsonPath) + df <- read.json(sqlContext, jsonPath) testRDD <- toRDD(df) expect_is(testRDD, "RDD") expect_equal(count(testRDD), 3) }) test_that("union on two RDDs created from DataFrames returns an RRDD", { - df <- jsonFile(sqlContext, jsonPath) + df <- read.json(sqlContext, jsonPath) RDD1 <- toRDD(df) RDD2 <- toRDD(df) unioned <- unionRDD(RDD1, RDD2) @@ -487,7 +499,7 @@ test_that("union on mixed serialization types correctly returns a byte RRDD", { writeLines(textLines, textPath) textRDD <- textFile(sc, textPath) - df <- jsonFile(sqlContext, jsonPath) + df <- read.json(sqlContext, jsonPath) dfRDD <- toRDD(df) unionByte <- unionRDD(rdd, dfRDD) @@ -505,7 +517,7 @@ test_that("union on mixed serialization types correctly returns a byte RRDD", { test_that("objectFile() works with row serialization", { objectPath <- tempfile(pattern="spark-test", fileext=".tmp") - df <- jsonFile(sqlContext, jsonPath) + df <- read.json(sqlContext, jsonPath) dfRDD <- toRDD(df) saveAsObjectFile(coalesce(dfRDD, 1L), objectPath) objectIn <- objectFile(sc, objectPath) @@ -516,7 +528,7 @@ test_that("objectFile() works with row serialization", { }) test_that("lapply() on a DataFrame returns an RDD with the correct columns", { - df <- jsonFile(sqlContext, jsonPath) + df <- read.json(sqlContext, jsonPath) testRDD <- lapply(df, function(row) { row$newCol <- row$age + 5 row @@ -528,7 +540,7 @@ test_that("lapply() on a DataFrame returns an RDD with the correct columns", { }) test_that("collect() returns a data.frame", { - df <- jsonFile(sqlContext, jsonPath) + df <- read.json(sqlContext, jsonPath) rdf <- collect(df) 
expect_true(is.data.frame(rdf)) expect_equal(names(rdf)[1], "age") @@ -550,14 +562,14 @@ test_that("collect() returns a data.frame", { }) test_that("limit() returns DataFrame with the correct number of rows", { - df <- jsonFile(sqlContext, jsonPath) + df <- read.json(sqlContext, jsonPath) dfLimited <- limit(df, 2) expect_is(dfLimited, "DataFrame") expect_equal(count(dfLimited), 2) }) test_that("collect() and take() on a DataFrame return the same number of rows and columns", { - df <- jsonFile(sqlContext, jsonPath) + df <- read.json(sqlContext, jsonPath) expect_equal(nrow(collect(df)), nrow(take(df, 10))) expect_equal(ncol(collect(df)), ncol(take(df, 10))) }) @@ -584,7 +596,7 @@ test_that("collect() support Unicode characters", { }) test_that("multiple pipeline transformations result in an RDD with the correct values", { - df <- jsonFile(sqlContext, jsonPath) + df <- read.json(sqlContext, jsonPath) first <- lapply(df, function(row) { row$age <- row$age + 5 row @@ -601,7 +613,7 @@ test_that("multiple pipeline transformations result in an RDD with the correct v }) test_that("cache(), persist(), and unpersist() on a DataFrame", { - df <- jsonFile(sqlContext, jsonPath) + df <- read.json(sqlContext, jsonPath) expect_false(df@env$isCached) cache(df) expect_true(df@env$isCached) @@ -620,7 +632,7 @@ test_that("cache(), persist(), and unpersist() on a DataFrame", { }) test_that("schema(), dtypes(), columns(), names() return the correct values/format", { - df <- jsonFile(sqlContext, jsonPath) + df <- read.json(sqlContext, jsonPath) testSchema <- schema(df) expect_equal(length(testSchema$fields()), 2) expect_equal(testSchema$fields()[[1]]$dataType.toString(), "LongType") @@ -641,7 +653,7 @@ test_that("schema(), dtypes(), columns(), names() return the correct values/form }) test_that("names() colnames() set the column names", { - df <- jsonFile(sqlContext, jsonPath) + df <- read.json(sqlContext, jsonPath) names(df) <- c("col1", "col2") expect_equal(colnames(df)[2], "col2") @@ -661,7 +673,7 @@ test_that("names() colnames() set the column names", { }) test_that("head() and first() return the correct data", { - df <- jsonFile(sqlContext, jsonPath) + df <- read.json(sqlContext, jsonPath) testHead <- head(df) expect_equal(nrow(testHead), 3) expect_equal(ncol(testHead), 2) @@ -694,7 +706,7 @@ test_that("distinct() and unique on DataFrames", { jsonPathWithDup <- tempfile(pattern="sparkr-test", fileext=".tmp") writeLines(lines, jsonPathWithDup) - df <- jsonFile(sqlContext, jsonPathWithDup) + df <- read.json(sqlContext, jsonPathWithDup) uniques <- distinct(df) expect_is(uniques, "DataFrame") expect_equal(count(uniques), 3) @@ -705,7 +717,7 @@ test_that("distinct() and unique on DataFrames", { }) test_that("sample on a DataFrame", { - df <- jsonFile(sqlContext, jsonPath) + df <- read.json(sqlContext, jsonPath) sampled <- sample(df, FALSE, 1.0) expect_equal(nrow(collect(sampled)), count(df)) expect_is(sampled, "DataFrame") @@ -721,7 +733,7 @@ test_that("sample on a DataFrame", { }) test_that("select operators", { - df <- select(jsonFile(sqlContext, jsonPath), "name", "age") + df <- select(read.json(sqlContext, jsonPath), "name", "age") expect_is(df$name, "Column") expect_is(df[[2]], "Column") expect_is(df[["age"]], "Column") @@ -747,7 +759,7 @@ test_that("select operators", { }) test_that("select with column", { - df <- jsonFile(sqlContext, jsonPath) + df <- read.json(sqlContext, jsonPath) df1 <- select(df, "name") expect_equal(columns(df1), c("name")) expect_equal(count(df1), 3) @@ -770,8 +782,8 @@ 
test_that("select with column", { }) test_that("subsetting", { - # jsonFile returns columns in random order - df <- select(jsonFile(sqlContext, jsonPath), "name", "age") + # read.json returns columns in random order + df <- select(read.json(sqlContext, jsonPath), "name", "age") filtered <- df[df$age > 20,] expect_equal(count(filtered), 1) expect_equal(columns(filtered), c("name", "age")) @@ -808,7 +820,7 @@ test_that("subsetting", { }) test_that("selectExpr() on a DataFrame", { - df <- jsonFile(sqlContext, jsonPath) + df <- read.json(sqlContext, jsonPath) selected <- selectExpr(df, "age * 2") expect_equal(names(selected), "(age * 2)") expect_equal(collect(selected), collect(select(df, df$age * 2L))) @@ -819,12 +831,12 @@ test_that("selectExpr() on a DataFrame", { }) test_that("expr() on a DataFrame", { - df <- jsonFile(sqlContext, jsonPath) + df <- read.json(sqlContext, jsonPath) expect_equal(collect(select(df, expr("abs(-123)")))[1, 1], 123) }) test_that("column calculation", { - df <- jsonFile(sqlContext, jsonPath) + df <- read.json(sqlContext, jsonPath) d <- collect(select(df, alias(df$age + 1, "age2"))) expect_equal(names(d), c("age2")) df2 <- select(df, lower(df$name), abs(df$age)) @@ -915,7 +927,7 @@ test_that("column functions", { expect_equal(class(rank())[[1]], "Column") expect_equal(rank(1:3), as.numeric(c(1:3))) - df <- jsonFile(sqlContext, jsonPath) + df <- read.json(sqlContext, jsonPath) df2 <- select(df, between(df$age, c(20, 30)), between(df$age, c(10, 20))) expect_equal(collect(df2)[[2, 1]], TRUE) expect_equal(collect(df2)[[2, 2]], FALSE) @@ -983,7 +995,7 @@ test_that("column binary mathfunctions", { "{\"a\":4, \"b\":8}") jsonPathWithDup <- tempfile(pattern="sparkr-test", fileext=".tmp") writeLines(lines, jsonPathWithDup) - df <- jsonFile(sqlContext, jsonPathWithDup) + df <- read.json(sqlContext, jsonPathWithDup) expect_equal(collect(select(df, atan2(df$a, df$b)))[1, "ATAN2(a, b)"], atan2(1, 5)) expect_equal(collect(select(df, atan2(df$a, df$b)))[2, "ATAN2(a, b)"], atan2(2, 6)) expect_equal(collect(select(df, atan2(df$a, df$b)))[3, "ATAN2(a, b)"], atan2(3, 7)) @@ -1004,7 +1016,7 @@ test_that("column binary mathfunctions", { }) test_that("string operators", { - df <- jsonFile(sqlContext, jsonPath) + df <- read.json(sqlContext, jsonPath) expect_equal(count(where(df, like(df$name, "A%"))), 1) expect_equal(count(where(df, startsWith(df$name, "A"))), 1) expect_equal(first(select(df, substr(df$name, 1, 2)))[[1]], "Mi") @@ -1100,7 +1112,7 @@ test_that("when(), otherwise() and ifelse() on a DataFrame", { }) test_that("group by, agg functions", { - df <- jsonFile(sqlContext, jsonPath) + df <- read.json(sqlContext, jsonPath) df1 <- agg(df, name = "max", age = "sum") expect_equal(1, count(df1)) df1 <- agg(df, age2 = max(df$age)) @@ -1145,7 +1157,7 @@ test_that("group by, agg functions", { "{\"name\":\"ID2\", \"value\": \"-3\"}") jsonPath2 <- tempfile(pattern="sparkr-test", fileext=".tmp") writeLines(mockLines2, jsonPath2) - gd2 <- groupBy(jsonFile(sqlContext, jsonPath2), "name") + gd2 <- groupBy(read.json(sqlContext, jsonPath2), "name") df6 <- agg(gd2, value = "sum") df6_local <- collect(df6) expect_equal(42, df6_local[df6_local$name == "ID1",][1, 2]) @@ -1162,7 +1174,7 @@ test_that("group by, agg functions", { "{\"name\":\"Justin\", \"age\":1}") jsonPath3 <- tempfile(pattern="sparkr-test", fileext=".tmp") writeLines(mockLines3, jsonPath3) - df8 <- jsonFile(sqlContext, jsonPath3) + df8 <- read.json(sqlContext, jsonPath3) gd3 <- groupBy(df8, "name") gd3_local <- collect(sum(gd3)) 
expect_equal(60, gd3_local[gd3_local$name == "Andy",][1, 2]) @@ -1181,7 +1193,7 @@ test_that("group by, agg functions", { }) test_that("arrange() and orderBy() on a DataFrame", { - df <- jsonFile(sqlContext, jsonPath) + df <- read.json(sqlContext, jsonPath) sorted <- arrange(df, df$age) expect_equal(collect(sorted)[1,2], "Michael") @@ -1207,7 +1219,7 @@ test_that("arrange() and orderBy() on a DataFrame", { }) test_that("filter() on a DataFrame", { - df <- jsonFile(sqlContext, jsonPath) + df <- read.json(sqlContext, jsonPath) filtered <- filter(df, "age > 20") expect_equal(count(filtered), 1) expect_equal(collect(filtered)$name, "Andy") @@ -1230,7 +1242,7 @@ test_that("filter() on a DataFrame", { }) test_that("join() and merge() on a DataFrame", { - df <- jsonFile(sqlContext, jsonPath) + df <- read.json(sqlContext, jsonPath) mockLines2 <- c("{\"name\":\"Michael\", \"test\": \"yes\"}", "{\"name\":\"Andy\", \"test\": \"no\"}", @@ -1238,7 +1250,7 @@ test_that("join() and merge() on a DataFrame", { "{\"name\":\"Bob\", \"test\": \"yes\"}") jsonPath2 <- tempfile(pattern="sparkr-test", fileext=".tmp") writeLines(mockLines2, jsonPath2) - df2 <- jsonFile(sqlContext, jsonPath2) + df2 <- read.json(sqlContext, jsonPath2) joined <- join(df, df2) expect_equal(names(joined), c("age", "name", "name", "test")) @@ -1313,14 +1325,14 @@ test_that("join() and merge() on a DataFrame", { "{\"name\":\"Bob\", \"name_y\":\"Bob\", \"test\": \"yes\"}") jsonPath3 <- tempfile(pattern="sparkr-test", fileext=".tmp") writeLines(mockLines3, jsonPath3) - df3 <- jsonFile(sqlContext, jsonPath3) + df3 <- read.json(sqlContext, jsonPath3) expect_error(merge(df, df3), paste("The following column name: name_y occurs more than once in the 'DataFrame'.", "Please use different suffixes for the intersected columns.", sep = "")) }) test_that("toJSON() returns an RDD of the correct values", { - df <- jsonFile(sqlContext, jsonPath) + df <- read.json(sqlContext, jsonPath) testRDD <- toJSON(df) expect_is(testRDD, "RDD") expect_equal(getSerializedMode(testRDD), "string") @@ -1328,7 +1340,7 @@ test_that("toJSON() returns an RDD of the correct values", { }) test_that("showDF()", { - df <- jsonFile(sqlContext, jsonPath) + df <- read.json(sqlContext, jsonPath) s <- capture.output(showDF(df)) expected <- paste("+----+-------+\n", "| age| name|\n", @@ -1341,12 +1353,12 @@ test_that("showDF()", { }) test_that("isLocal()", { - df <- jsonFile(sqlContext, jsonPath) + df <- read.json(sqlContext, jsonPath) expect_false(isLocal(df)) }) test_that("unionAll(), rbind(), except(), and intersect() on a DataFrame", { - df <- jsonFile(sqlContext, jsonPath) + df <- read.json(sqlContext, jsonPath) lines <- c("{\"name\":\"Bob\", \"age\":24}", "{\"name\":\"Andy\", \"age\":30}", @@ -1383,7 +1395,7 @@ test_that("unionAll(), rbind(), except(), and intersect() on a DataFrame", { }) test_that("withColumn() and withColumnRenamed()", { - df <- jsonFile(sqlContext, jsonPath) + df <- read.json(sqlContext, jsonPath) newDF <- withColumn(df, "newAge", df$age + 2) expect_equal(length(columns(newDF)), 3) expect_equal(columns(newDF)[3], "newAge") @@ -1395,7 +1407,7 @@ test_that("withColumn() and withColumnRenamed()", { }) test_that("mutate(), transform(), rename() and names()", { - df <- jsonFile(sqlContext, jsonPath) + df <- read.json(sqlContext, jsonPath) newDF <- mutate(df, newAge = df$age + 2) expect_equal(length(columns(newDF)), 3) expect_equal(columns(newDF)[3], "newAge") @@ -1425,7 +1437,7 @@ test_that("mutate(), transform(), rename() and names()", { }) 
test_that("write.df() on DataFrame and works with read.parquet", { - df <- jsonFile(sqlContext, jsonPath) + df <- read.json(sqlContext, jsonPath) write.df(df, parquetPath, "parquet", mode="overwrite") parquetDF <- read.parquet(sqlContext, parquetPath) expect_is(parquetDF, "DataFrame") @@ -1433,7 +1445,7 @@ test_that("write.df() on DataFrame and works with read.parquet", { }) test_that("read.parquet()/parquetFile() works with multiple input paths", { - df <- jsonFile(sqlContext, jsonPath) + df <- read.json(sqlContext, jsonPath) write.df(df, parquetPath, "parquet", mode="overwrite") parquetPath2 <- tempfile(pattern = "parquetPath2", fileext = ".parquet") write.df(df, parquetPath2, "parquet", mode="overwrite") @@ -1452,7 +1464,7 @@ test_that("read.parquet()/parquetFile() works with multiple input paths", { }) test_that("describe() and summarize() on a DataFrame", { - df <- jsonFile(sqlContext, jsonPath) + df <- read.json(sqlContext, jsonPath) stats <- describe(df, "age") expect_equal(collect(stats)[1, "summary"], "count") expect_equal(collect(stats)[2, "age"], "24.5") @@ -1470,7 +1482,7 @@ test_that("describe() and summarize() on a DataFrame", { }) test_that("dropna() and na.omit() on a DataFrame", { - df <- jsonFile(sqlContext, jsonPathNa) + df <- read.json(sqlContext, jsonPathNa) rows <- collect(df) # drop with columns @@ -1556,7 +1568,7 @@ test_that("dropna() and na.omit() on a DataFrame", { }) test_that("fillna() on a DataFrame", { - df <- jsonFile(sqlContext, jsonPathNa) + df <- read.json(sqlContext, jsonPathNa) rows <- collect(df) # fill with value @@ -1665,7 +1677,7 @@ test_that("Method as.data.frame as a synonym for collect()", { }) test_that("attach() on a DataFrame", { - df <- jsonFile(sqlContext, jsonPath) + df <- read.json(sqlContext, jsonPath) expect_error(age) attach(df) expect_is(age, "DataFrame") @@ -1713,7 +1725,7 @@ test_that("Method coltypes() to get and set R's data types of a DataFrame", { list("a"="b", "c"="d", "e"="f"))))) expect_equal(coltypes(x), "map") - df <- selectExpr(jsonFile(sqlContext, jsonPath), "name", "(age * 1.21) as age") + df <- selectExpr(read.json(sqlContext, jsonPath), "name", "(age * 1.21) as age") expect_equal(dtypes(df), list(c("name", "string"), c("age", "decimal(24,2)"))) df1 <- select(df, cast(df$age, "integer")) diff --git a/examples/src/main/r/dataframe.R b/examples/src/main/r/dataframe.R index 53b817144f6a..62f60e57eebe 100644 --- a/examples/src/main/r/dataframe.R +++ b/examples/src/main/r/dataframe.R @@ -35,7 +35,7 @@ printSchema(df) # Create a DataFrame from a JSON file path <- file.path(Sys.getenv("SPARK_HOME"), "examples/src/main/resources/people.json") -peopleDF <- jsonFile(sqlContext, path) +peopleDF <- read.json(sqlContext, path) printSchema(peopleDF) # Register this DataFrame as a table. From aa305dcaf5b4148aba9e669e081d0b9235f50857 Mon Sep 17 00:00:00 2001 From: anabranch Date: Fri, 11 Dec 2015 12:55:56 -0800 Subject: [PATCH 1745/1824] [SPARK-11964][DOCS][ML] Add in Pipeline Import/Export Documentation Adding in Pipeline Import and Export Documentation. Author: anabranch Author: Bill Chambers Closes #10179 from anabranch/master. --- docs/ml-guide.md | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/docs/ml-guide.md b/docs/ml-guide.md index 5c96c2b7d5cc..44a316a07dfe 100644 --- a/docs/ml-guide.md +++ b/docs/ml-guide.md @@ -192,6 +192,10 @@ Parameters belong to specific instances of `Estimator`s and `Transformer`s. 
For example, if we have two `LogisticRegression` instances `lr1` and `lr2`, then we can build a `ParamMap` with both `maxIter` parameters specified: `ParamMap(lr1.maxIter -> 10, lr2.maxIter -> 20)`. This is useful if there are two algorithms with the `maxIter` parameter in a `Pipeline`. +## Saving and Loading Pipelines + +Often times it is worth it to save a model or a pipeline to disk for later use. In Spark 1.6, a model import/export functionality was added to the Pipeline API. Most basic transformers are supported as well as some of the more basic ML models. Please refer to the algorithm's API documentation to see if saving and loading is supported. + # Code examples This section gives code examples illustrating the functionality discussed above. @@ -455,6 +459,15 @@ val pipeline = new Pipeline() // Fit the pipeline to training documents. val model = pipeline.fit(training) +// now we can optionally save the fitted pipeline to disk +model.save("/tmp/spark-logistic-regression-model") + +// we can also save this unfit pipeline to disk +pipeline.save("/tmp/unfit-lr-model") + +// and load it back in during production +val sameModel = Pipeline.load("/tmp/spark-logistic-regression-model") + // Prepare test documents, which are unlabeled (id, text) tuples. val test = sqlContext.createDataFrame(Seq( (4L, "spark i j k"), From 713e6959d21d24382ef99bbd7e9da751a7ed388c Mon Sep 17 00:00:00 2001 From: proflin Date: Fri, 11 Dec 2015 13:50:36 -0800 Subject: [PATCH 1746/1824] [SPARK-12273][STREAMING] Make Spark Streaming web UI list Receivers in order Currently the Streaming web UI does NOT list Receivers in order; however, it seems more convenient for the users if Receivers are listed in order. ![spark-12273](https://cloud.githubusercontent.com/assets/15843379/11736602/0bb7f7a8-a00b-11e5-8e86-96ba9297fb12.png) Author: proflin Closes #10264 from proflin/Spark-12273. --- .../scala/org/apache/spark/streaming/ui/StreamingPage.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala index 4588b2163cd4..88a4483e8068 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala @@ -392,8 +392,9 @@ private[ui] class StreamingPage(parent: StreamingTab) maxX: Long, minY: Double, maxY: Double): Seq[Node] = { - val content = listener.receivedEventRateWithBatchTime.map { case (streamId, eventRates) => - generateInputDStreamRow(jsCollector, streamId, eventRates, minX, maxX, minY, maxY) + val content = listener.receivedEventRateWithBatchTime.toList.sortBy(_._1).map { + case (streamId, eventRates) => + generateInputDStreamRow(jsCollector, streamId, eventRates, minX, maxX, minY, maxY) }.foldLeft[Seq[Node]](Nil)(_ ++ _) // scalastyle:off From 1b8220387e6903564f765fabb54be0420c3e99d7 Mon Sep 17 00:00:00 2001 From: Mike Dusenberry Date: Fri, 11 Dec 2015 14:21:33 -0800 Subject: [PATCH 1747/1824] [SPARK-11497][MLLIB][PYTHON] PySpark RowMatrix Constructor Has Type Erasure Issue As noted in PR #9441, implementing `tallSkinnyQR` uncovered a bug with our PySpark `RowMatrix` constructor. As discussed on the dev list [here](http://apache-spark-developers-list.1001551.n3.nabble.com/K-Means-And-Class-Tags-td10038.html), there appears to be an issue with type erasure with RDDs coming from Java, and by extension from PySpark. 
Although we are attempting to construct a `RowMatrix` from an `RDD[Vector]` in [PythonMLlibAPI](https://github.com/apache/spark/blob/master/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala#L1115), the `Vector` type is erased, resulting in an `RDD[Object]`. Thus, when calling Scala's `tallSkinnyQR` from PySpark, we get a Java `ClassCastException` in which an `Object` cannot be cast to a Spark `Vector`. As noted in the aforementioned dev list thread, this issue was also encountered with `DecisionTrees`, and the fix involved an explicit `retag` of the RDD with a `Vector` type. `IndexedRowMatrix` and `CoordinateMatrix` do not appear to have this issue likely due to their related helper functions in `PythonMLlibAPI` creating the RDDs explicitly from DataFrames with pattern matching, thus preserving the types. This PR currently contains that retagging fix applied to the `createRowMatrix` helper function in `PythonMLlibAPI`. This PR blocks #9441, so once this is merged, the other can be rebased. cc holdenk Author: Mike Dusenberry Closes #9458 from dusenberrymw/SPARK-11497_PySpark_RowMatrix_Constructor_Has_Type_Erasure_Issue. --- .../org/apache/spark/mllib/api/python/PythonMLLibAPI.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index 2aa6aec0b434..8d546e3d6099 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -1143,7 +1143,7 @@ private[python] class PythonMLLibAPI extends Serializable { * Wrapper around RowMatrix constructor. */ def createRowMatrix(rows: JavaRDD[Vector], numRows: Long, numCols: Int): RowMatrix = { - new RowMatrix(rows.rdd, numRows, numCols) + new RowMatrix(rows.rdd.retag(classOf[Vector]), numRows, numCols) } /** From aea676ca2d07c72b1a752e9308c961118e5bfc3c Mon Sep 17 00:00:00 2001 From: BenFradet Date: Fri, 11 Dec 2015 15:43:00 -0800 Subject: [PATCH 1748/1824] [SPARK-12217][ML] Document invalid handling for StringIndexer Added a paragraph regarding StringIndexer#setHandleInvalid to the ml-features documentation. I wonder if I should also add a snippet to the code example, input welcome. Author: BenFradet Closes #10257 from BenFradet/SPARK-12217. --- docs/ml-features.md | 36 ++++++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/docs/ml-features.md b/docs/ml-features.md index 6494fed0a01e..8b00cc652dc7 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -459,6 +459,42 @@ column, we should get the following: "a" gets index `0` because it is the most frequent, followed by "c" with index `1` and "b" with index `2`. +Additionally, there are two strategies regarding how `StringIndexer` will handle +unseen labels when you have fit a `StringIndexer` on one dataset and then use it +to transform another: + +- throw an exception (which is the default) +- skip the row containing the unseen label entirely + +**Examples** + +Let's go back to our previous example but this time reuse our previously defined +`StringIndexer` on the following dataset: + +~~~~ + id | category +----|---------- + 0 | a + 1 | b + 2 | c + 3 | d +~~~~ + +If you've not set how `StringIndexer` handles unseen labels or set it to +"error", an exception will be thrown. 
+However, if you have called `setHandleInvalid("skip")`, the following dataset +will be generated: + +~~~~ + id | category | categoryIndex +----|----------|--------------- + 0 | a | 0.0 + 1 | b | 2.0 + 2 | c | 1.0 +~~~~ + +Notice that the row containing "d" does not appear. +
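For readers who want to try the behaviour documented above, here is a minimal sketch (not part of the patch); it assumes the Spark 1.6-era Scala API, where `StringIndexer` exposes the `handleInvalid` param, and an existing `sqlContext`. The datasets mirror the tables shown above.

```
import org.apache.spark.ml.feature.StringIndexer

// Fit the indexer on a dataset that only contains the labels a, b and c.
val train = sqlContext.createDataFrame(Seq(
  (0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")
)).toDF("id", "category")

// The default is "error": transforming a dataset with an unseen label throws.
// With "skip", rows containing unseen labels are dropped instead.
val indexer = new StringIndexer()
  .setInputCol("category")
  .setOutputCol("categoryIndex")
  .setHandleInvalid("skip")
val model = indexer.fit(train)

// This dataset contains the unseen label "d"; its row is filtered out.
val test = sqlContext.createDataFrame(Seq(
  (0, "a"), (1, "b"), (2, "c"), (3, "d")
)).toDF("id", "category")

model.transform(test).show()
```

The output should match the three-row table above, with the row for "d" removed.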
    From a0ff6d16ef4bcc1b6ff7282e82a9b345d8449454 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Fri, 11 Dec 2015 18:02:24 -0800 Subject: [PATCH 1749/1824] [SPARK-11978][ML] Move dataset_example.py to examples/ml and rename to dataframe_example.py Since ```Dataset``` has a new meaning in Spark 1.6, we should rename it to avoid confusion. #9873 finished the work of Scala example, here we focus on the Python one. Move dataset_example.py to ```examples/ml``` and rename to ```dataframe_example.py```. BTW, fix minor missing issues of #9873. cc mengxr Author: Yanbo Liang Closes #9957 from yanboliang/SPARK-11978. --- .../dataframe_example.py} | 56 +++++++++++-------- .../spark/examples/ml/DataFrameExample.scala | 8 +-- 2 files changed, 38 insertions(+), 26 deletions(-) rename examples/src/main/python/{mllib/dataset_example.py => ml/dataframe_example.py} (53%) diff --git a/examples/src/main/python/mllib/dataset_example.py b/examples/src/main/python/ml/dataframe_example.py similarity index 53% rename from examples/src/main/python/mllib/dataset_example.py rename to examples/src/main/python/ml/dataframe_example.py index e23ecc0c5d30..d2644ca33565 100644 --- a/examples/src/main/python/mllib/dataset_example.py +++ b/examples/src/main/python/ml/dataframe_example.py @@ -16,8 +16,8 @@ # """ -An example of how to use DataFrame as a dataset for ML. Run with:: - bin/spark-submit examples/src/main/python/mllib/dataset_example.py +An example of how to use DataFrame for ML. Run with:: + bin/spark-submit examples/src/main/python/ml/dataframe_example.py """ from __future__ import print_function @@ -28,36 +28,48 @@ from pyspark import SparkContext from pyspark.sql import SQLContext -from pyspark.mllib.util import MLUtils from pyspark.mllib.stat import Statistics - -def summarize(dataset): - print("schema: %s" % dataset.schema().json()) - labels = dataset.map(lambda r: r.label) - print("label average: %f" % labels.mean()) - features = dataset.map(lambda r: r.features) - summary = Statistics.colStats(features) - print("features average: %r" % summary.mean()) - if __name__ == "__main__": if len(sys.argv) > 2: - print("Usage: dataset_example.py ", file=sys.stderr) + print("Usage: dataframe_example.py ", file=sys.stderr) exit(-1) - sc = SparkContext(appName="DatasetExample") + sc = SparkContext(appName="DataFrameExample") sqlContext = SQLContext(sc) if len(sys.argv) == 2: input = sys.argv[1] else: input = "data/mllib/sample_libsvm_data.txt" - points = MLUtils.loadLibSVMFile(sc, input) - dataset0 = sqlContext.inferSchema(points).setName("dataset0").cache() - summarize(dataset0) + + # Load input data + print("Loading LIBSVM file with UDT from " + input + ".") + df = sqlContext.read.format("libsvm").load(input).cache() + print("Schema from LIBSVM:") + df.printSchema() + print("Loaded training data as a DataFrame with " + + str(df.count()) + " records.") + + # Show statistical summary of labels. + labelSummary = df.describe("label") + labelSummary.show() + + # Convert features column to an RDD of vectors. + features = df.select("features").map(lambda r: r.features) + summary = Statistics.colStats(features) + print("Selected features column with average values:\n" + + str(summary.mean())) + + # Save the records in a parquet file. tempdir = tempfile.NamedTemporaryFile(delete=False).name os.unlink(tempdir) - print("Save dataset as a Parquet file to %s." 
% tempdir) - dataset0.saveAsParquetFile(tempdir) - print("Load it back and summarize it again.") - dataset1 = sqlContext.parquetFile(tempdir).setName("dataset1").cache() - summarize(dataset1) + print("Saving to " + tempdir + " as Parquet file.") + df.write.parquet(tempdir) + + # Load the records back. + print("Loading Parquet file with UDT from " + tempdir) + newDF = sqlContext.read.parquet(tempdir) + print("Schema from Parquet:") + newDF.printSchema() shutil.rmtree(tempdir) + + sc.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DataFrameExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DataFrameExample.scala index 424f00158c2f..0a477abae567 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DataFrameExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DataFrameExample.scala @@ -44,10 +44,10 @@ object DataFrameExample { def main(args: Array[String]) { val defaultParams = Params() - val parser = new OptionParser[Params]("DatasetExample") { - head("Dataset: an example app using DataFrame as a Dataset for ML.") + val parser = new OptionParser[Params]("DataFrameExample") { + head("DataFrameExample: an example app using DataFrame for ML.") opt[String]("input") - .text(s"input path to dataset") + .text(s"input path to dataframe") .action((x, c) => c.copy(input = x)) checkConfig { params => success @@ -88,7 +88,7 @@ object DataFrameExample { // Save the records in a parquet file. val tmpDir = Files.createTempDir() tmpDir.deleteOnExit() - val outputDir = new File(tmpDir, "dataset").toString + val outputDir = new File(tmpDir, "dataframe").toString println(s"Saving to $outputDir as Parquet file.") df.write.parquet(outputDir) From 1e799d617a28cd0eaa8f22d103ea8248c4655ae5 Mon Sep 17 00:00:00 2001 From: Ankur Dave Date: Fri, 11 Dec 2015 19:07:48 -0800 Subject: [PATCH 1750/1824] [SPARK-12298][SQL] Fix infinite loop in DataFrame.sortWithinPartitions Modifies the String overload to call the Column overload and ensures this is called in a test. Author: Ankur Dave Closes #10271 from ankurdave/SPARK-12298. --- sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala | 2 +- .../src/test/scala/org/apache/spark/sql/DataFrameSuite.scala | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index da180a2ba09d..497bd4826677 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -609,7 +609,7 @@ class DataFrame private[sql]( */ @scala.annotation.varargs def sortWithinPartitions(sortCol: String, sortCols: String*): DataFrame = { - sortWithinPartitions(sortCol, sortCols : _*) + sortWithinPartitions((sortCol +: sortCols).map(Column(_)) : _*) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 5353fefaf4b8..c0bbf73ab118 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -1090,8 +1090,8 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } // Distribute into one partition and order by. This partition should contain all the values. - val df6 = data.repartition(1, $"a").sortWithinPartitions($"b".asc) - // Walk each partition and verify that it is sorted descending and not globally sorted. 
+ val df6 = data.repartition(1, $"a").sortWithinPartitions("b") + // Walk each partition and verify that it is sorted ascending and not globally sorted. df6.rdd.foreachPartition { p => var previousValue: Int = -1 var allSequential: Boolean = true From 1e3526c2d3de723225024fedd45753b556e18fc6 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Fri, 11 Dec 2015 20:55:16 -0800 Subject: [PATCH 1751/1824] [SPARK-12158][SPARKR][SQL] Fix 'sample' functions that break R unit test cases The existing sample functions miss the parameter `seed`, however, the corresponding function interface in `generics` has such a parameter. Thus, although the function caller can call the function with the 'seed', we are not using the value. This could cause SparkR unit tests failed. For example, I hit it in another PR: https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder/47213/consoleFull Author: gatorsmile Closes #10160 from gatorsmile/sampleR. --- R/pkg/R/DataFrame.R | 17 +++++++++++------ R/pkg/inst/tests/testthat/test_sparkSQL.R | 4 ++++ 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 975b058c0aaf..764597d1e32b 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -662,6 +662,7 @@ setMethod("unique", #' @param x A SparkSQL DataFrame #' @param withReplacement Sampling with replacement or not #' @param fraction The (rough) sample target fraction +#' @param seed Randomness seed value #' #' @family DataFrame functions #' @rdname sample @@ -677,13 +678,17 @@ setMethod("unique", #' collect(sample(df, TRUE, 0.5)) #'} setMethod("sample", - # TODO : Figure out how to send integer as java.lang.Long to JVM so - # we can send seed as an argument through callJMethod signature(x = "DataFrame", withReplacement = "logical", fraction = "numeric"), - function(x, withReplacement, fraction) { + function(x, withReplacement, fraction, seed) { if (fraction < 0.0) stop(cat("Negative fraction value:", fraction)) - sdf <- callJMethod(x@sdf, "sample", withReplacement, fraction) + if (!missing(seed)) { + # TODO : Figure out how to send integer as java.lang.Long to JVM so + # we can send seed as an argument through callJMethod + sdf <- callJMethod(x@sdf, "sample", withReplacement, fraction, as.integer(seed)) + } else { + sdf <- callJMethod(x@sdf, "sample", withReplacement, fraction) + } dataFrame(sdf) }) @@ -692,8 +697,8 @@ setMethod("sample", setMethod("sample_frac", signature(x = "DataFrame", withReplacement = "logical", fraction = "numeric"), - function(x, withReplacement, fraction) { - sample(x, withReplacement, fraction) + function(x, withReplacement, fraction, seed) { + sample(x, withReplacement, fraction, seed) }) #' nrow diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index ed9b2c9d4d16..071fd310fd58 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -724,6 +724,10 @@ test_that("sample on a DataFrame", { sampled2 <- sample(df, FALSE, 0.1, 0) # set seed for predictable result expect_true(count(sampled2) < 3) + count1 <- count(sample(df, FALSE, 0.1, 0)) + count2 <- count(sample(df, FALSE, 0.1, 0)) + expect_equal(count1, count2) + # Also test sample_frac sampled3 <- sample_frac(df, FALSE, 0.1, 0) # set seed for predictable result expect_true(count(sampled3) < 3) From 03138b67d3ef7f5278ea9f8b9c75f0e357ef79d8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jean-Baptiste=20Onofr=C3=A9?= Date: Sat, 12 Dec 2015 08:51:52 +0000 Subject: [PATCH 1752/1824] 
[SPARK-11193] Use Java ConcurrentHashMap instead of SynchronizedMap trait in order to avoid ClassCastException due to KryoSerializer in KinesisReceiver MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Author: Jean-Baptiste Onofré Closes #10203 from jbonofre/SPARK-11193. --- .../streaming/kinesis/KinesisReceiver.scala | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala index 05080835fc4a..80edda59e171 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala @@ -17,6 +17,7 @@ package org.apache.spark.streaming.kinesis import java.util.UUID +import java.util.concurrent.ConcurrentHashMap import scala.collection.JavaConverters._ import scala.collection.mutable @@ -124,8 +125,7 @@ private[kinesis] class KinesisReceiver[T]( private val seqNumRangesInCurrentBlock = new mutable.ArrayBuffer[SequenceNumberRange] /** Sequence number ranges of data added to each generated block */ - private val blockIdToSeqNumRanges = new mutable.HashMap[StreamBlockId, SequenceNumberRanges] - with mutable.SynchronizedMap[StreamBlockId, SequenceNumberRanges] + private val blockIdToSeqNumRanges = new ConcurrentHashMap[StreamBlockId, SequenceNumberRanges] /** * The centralized kinesisCheckpointer that checkpoints based on the given checkpointInterval. @@ -135,8 +135,8 @@ private[kinesis] class KinesisReceiver[T]( /** * Latest sequence number ranges that have been stored successfully. * This is used for checkpointing through KCL */ - private val shardIdToLatestStoredSeqNum = new mutable.HashMap[String, String] - with mutable.SynchronizedMap[String, String] + private val shardIdToLatestStoredSeqNum = new ConcurrentHashMap[String, String] + /** * This is called when the KinesisReceiver starts and must be non-blocking. * The KCL creates and manages the receiving/processing thread pool through Worker.run(). @@ -222,7 +222,7 @@ private[kinesis] class KinesisReceiver[T]( /** Get the latest sequence number for the given shard that can be checkpointed through KCL */ private[kinesis] def getLatestSeqNumToCheckpoint(shardId: String): Option[String] = { - shardIdToLatestStoredSeqNum.get(shardId) + Option(shardIdToLatestStoredSeqNum.get(shardId)) } /** @@ -257,7 +257,7 @@ private[kinesis] class KinesisReceiver[T]( * for next block. Internally, this is synchronized with `rememberAddedRange()`. 
*/ private def finalizeRangesForCurrentBlock(blockId: StreamBlockId): Unit = { - blockIdToSeqNumRanges(blockId) = SequenceNumberRanges(seqNumRangesInCurrentBlock.toArray) + blockIdToSeqNumRanges.put(blockId, SequenceNumberRanges(seqNumRangesInCurrentBlock.toArray)) seqNumRangesInCurrentBlock.clear() logDebug(s"Generated block $blockId has $blockIdToSeqNumRanges") } @@ -265,7 +265,7 @@ private[kinesis] class KinesisReceiver[T]( /** Store the block along with its associated ranges */ private def storeBlockWithRanges( blockId: StreamBlockId, arrayBuffer: mutable.ArrayBuffer[T]): Unit = { - val rangesToReportOption = blockIdToSeqNumRanges.remove(blockId) + val rangesToReportOption = Option(blockIdToSeqNumRanges.remove(blockId)) if (rangesToReportOption.isEmpty) { stop("Error while storing block into Spark, could not find sequence number ranges " + s"for block $blockId") @@ -294,7 +294,7 @@ private[kinesis] class KinesisReceiver[T]( // Note that we are doing this sequentially because the array of sequence number ranges // is assumed to be rangesToReport.ranges.foreach { range => - shardIdToLatestStoredSeqNum(range.shardId) = range.toSeqNumber + shardIdToLatestStoredSeqNum.put(range.shardId, range.toSeqNumber) } } From 98b212d36b34ab490c391ea2adf5b141e4fb9289 Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Sat, 12 Dec 2015 17:47:01 -0800 Subject: [PATCH 1753/1824] [SPARK-12199][DOC] Follow-up: Refine example code in ml-features.md https://issues.apache.org/jira/browse/SPARK-12199 Follow-up PR of SPARK-11551. Fix some errors in ml-features.md mengxr Author: Xusen Yin Closes #10193 from yinxusen/SPARK-12199. --- docs/ml-features.md | 22 +++++++++---------- .../examples/ml/JavaBinarizerExample.java | 2 +- .../python/ml/polynomial_expansion_example.py | 6 ++--- ....scala => ElementwiseProductExample.scala} | 0 4 files changed, 15 insertions(+), 15 deletions(-) rename examples/src/main/scala/org/apache/spark/examples/ml/{ElementWiseProductExample.scala => ElementwiseProductExample.scala} (100%) diff --git a/docs/ml-features.md b/docs/ml-features.md index 8b00cc652dc7..158f3f201899 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -63,7 +63,7 @@ the [IDF Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.IDF) for mor `Word2VecModel`. The model maps each word to a unique fixed-size vector. The `Word2VecModel` transforms each document into a vector using the average of all words in the document; this vector can then be used for as features for prediction, document similarity calculations, etc. -Please refer to the [MLlib user guide on Word2Vec](mllib-feature-extraction.html#Word2Vec) for more +Please refer to the [MLlib user guide on Word2Vec](mllib-feature-extraction.html#word2Vec) for more details. In the following code segment, we start with a set of documents, each of which is represented as a sequence of words. For each document, we transform it into a feature vector. This feature vector could then be passed to a learning algorithm. @@ -411,7 +411,7 @@ for more details on the API. Refer to the [DCT Java docs](api/java/org/apache/spark/ml/feature/DCT.html) for more details on the API. -{% include_example java/org/apache/spark/examples/ml/JavaDCTExample.java %}} +{% include_example java/org/apache/spark/examples/ml/JavaDCTExample.java %}
    @@ -669,7 +669,7 @@ for more details on the API. The following example demonstrates how to load a dataset in libsvm format and then normalize each row to have unit $L^2$ norm and unit $L^\infty$ norm.
    -
    +
    Refer to the [Normalizer Scala docs](api/scala/index.html#org.apache.spark.ml.feature.Normalizer) for more details on the API. @@ -677,7 +677,7 @@ for more details on the API. {% include_example scala/org/apache/spark/examples/ml/NormalizerExample.scala %}
    -
    +
    Refer to the [Normalizer Java docs](api/java/org/apache/spark/ml/feature/Normalizer.html) for more details on the API. @@ -685,7 +685,7 @@ for more details on the API. {% include_example java/org/apache/spark/examples/ml/JavaNormalizerExample.java %}
    -
    +
    Refer to the [Normalizer Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.Normalizer) for more details on the API. @@ -709,7 +709,7 @@ Note that if the standard deviation of a feature is zero, it will return default The following example demonstrates how to load a dataset in libsvm format and then normalize each feature to have unit standard deviation.
    -
    +
    Refer to the [StandardScaler Scala docs](api/scala/index.html#org.apache.spark.ml.feature.StandardScaler) for more details on the API. @@ -717,7 +717,7 @@ for more details on the API. {% include_example scala/org/apache/spark/examples/ml/StandardScalerExample.scala %}
    -
    +
    Refer to the [StandardScaler Java docs](api/java/org/apache/spark/ml/feature/StandardScaler.html) for more details on the API. @@ -725,7 +725,7 @@ for more details on the API. {% include_example java/org/apache/spark/examples/ml/JavaStandardScalerExample.java %}
    -
    +
    Refer to the [StandardScaler Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.StandardScaler) for more details on the API. @@ -788,7 +788,7 @@ More details can be found in the API docs for [Bucketizer](api/scala/index.html# The following example demonstrates how to bucketize a column of `Double`s into another index-wised column.
    -
    +
    Refer to the [Bucketizer Scala docs](api/scala/index.html#org.apache.spark.ml.feature.Bucketizer) for more details on the API. @@ -796,7 +796,7 @@ for more details on the API. {% include_example scala/org/apache/spark/examples/ml/BucketizerExample.scala %}
    -
    +
    Refer to the [Bucketizer Java docs](api/java/org/apache/spark/ml/feature/Bucketizer.html) for more details on the API. @@ -804,7 +804,7 @@ for more details on the API. {% include_example java/org/apache/spark/examples/ml/JavaBucketizerExample.java %}
    -
    +
    Refer to the [Bucketizer Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.Bucketizer) for more details on the API. diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaBinarizerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaBinarizerExample.java index 9698cac50437..1eda1f694fc2 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaBinarizerExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaBinarizerExample.java @@ -59,7 +59,7 @@ public static void main(String[] args) { DataFrame binarizedDataFrame = binarizer.transform(continuousDataFrame); DataFrame binarizedFeatures = binarizedDataFrame.select("binarized_feature"); for (Row r : binarizedFeatures.collect()) { - Double binarized_value = r.getDouble(0); + Double binarized_value = r.getDouble(0); System.out.println(binarized_value); } // $example off$ diff --git a/examples/src/main/python/ml/polynomial_expansion_example.py b/examples/src/main/python/ml/polynomial_expansion_example.py index 3d4fafd1a42e..89f5cbe8f2f4 100644 --- a/examples/src/main/python/ml/polynomial_expansion_example.py +++ b/examples/src/main/python/ml/polynomial_expansion_example.py @@ -30,9 +30,9 @@ # $example on$ df = sqlContext\ - .createDataFrame([(Vectors.dense([-2.0, 2.3]), ), - (Vectors.dense([0.0, 0.0]), ), - (Vectors.dense([0.6, -1.1]), )], + .createDataFrame([(Vectors.dense([-2.0, 2.3]),), + (Vectors.dense([0.0, 0.0]),), + (Vectors.dense([0.6, -1.1]),)], ["features"]) px = PolynomialExpansion(degree=2, inputCol="features", outputCol="polyFeatures") polyDF = px.transform(df) diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/ElementWiseProductExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/ElementwiseProductExample.scala similarity index 100% rename from examples/src/main/scala/org/apache/spark/examples/ml/ElementWiseProductExample.scala rename to examples/src/main/scala/org/apache/spark/examples/ml/ElementwiseProductExample.scala From 8af2f8c61ae4a59d129fb3530d0f6e9317f4bff8 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Sat, 12 Dec 2015 21:58:55 -0800 Subject: [PATCH 1754/1824] [SPARK-12267][CORE] Store the remote RpcEnv address to send the correct disconnetion message Author: Shixiong Zhu Closes #10261 from zsxwing/SPARK-12267. 
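The bookkeeping introduced by this change can be summarised with a small, self-contained sketch. The `RpcAddress` case class, the `ConnectionTracker` name, and the `postToAll` callback below are simplified stand-ins for illustration only, not Spark's internal `NettyRpcHandler` types.

```
import java.util.concurrent.ConcurrentHashMap

// Simplified stand-in for illustration; not Spark's internal class.
case class RpcAddress(host: String, port: Int)

class ConnectionTracker(postToAll: Any => Unit) {
  // Maps a client's socket address to the listening address of its RpcEnv.
  private val remoteAddresses = new ConcurrentHashMap[RpcAddress, RpcAddress]()

  /** Called for every inbound message that carries the sender's listening address. */
  def onMessage(clientAddr: RpcAddress, remoteEnvAddress: RpcAddress): Unit = {
    // putIfAbsent returns null only on the first insertion, so the
    // "connected" event fires exactly once per remote RpcEnv address.
    if (remoteAddresses.putIfAbsent(clientAddr, remoteEnvAddress) == null) {
      postToAll(("connected", remoteEnvAddress))
    }
  }

  /** Called when the underlying network channel is closed. */
  def onChannelClosed(clientAddr: RpcAddress): Unit = {
    postToAll(("disconnected", clientAddr))
    // Also report the disconnection against the address the remote RpcEnv
    // listens on, since that is the address other endpoints refer to it by.
    val remoteEnvAddress = remoteAddresses.remove(clientAddr)
    if (remoteEnvAddress != null) {
      postToAll(("disconnected", remoteEnvAddress))
    }
  }
}
```

The key point is the reverse mapping: without it, a disconnection is only reported for the client's ephemeral socket address, which no other endpoint knows about.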
--- .../spark/deploy/master/ApplicationInfo.scala | 1 + .../apache/spark/deploy/worker/Worker.scala | 2 +- .../apache/spark/rpc/netty/NettyRpcEnv.scala | 21 ++++++++++ .../org/apache/spark/rpc/RpcEnvSuite.scala | 42 +++++++++++++++++++ 4 files changed, 65 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala index ac553b71115d..7e2cf956c725 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala @@ -66,6 +66,7 @@ private[spark] class ApplicationInfo( nextExecutorId = 0 removedExecutors = new ArrayBuffer[ExecutorDesc] executorLimit = Integer.MAX_VALUE + appUIUrlAtHistoryServer = None } private def newExecutorId(useID: Option[Int] = None): Int = { diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index 1afc1ff59f2f..f41efb097b4b 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -690,7 +690,7 @@ private[deploy] object Worker extends Logging { val conf = new SparkConf val args = new WorkerArguments(argStrings, conf) val rpcEnv = startRpcEnvAndEndpoint(args.host, args.port, args.webUiPort, args.cores, - args.memory, args.masters, args.workDir) + args.memory, args.masters, args.workDir, conf = conf) rpcEnv.awaitTermination() } diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala index 68c5f44145b0..f82fd4eb5756 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala @@ -553,6 +553,9 @@ private[netty] class NettyRpcHandler( // A variable to track whether we should dispatch the RemoteProcessConnected message. private val clients = new ConcurrentHashMap[TransportClient, JBoolean]() + // A variable to track the remote RpcEnv addresses of all clients + private val remoteAddresses = new ConcurrentHashMap[RpcAddress, RpcAddress]() + override def receive( client: TransportClient, message: ByteBuffer, @@ -580,6 +583,12 @@ private[netty] class NettyRpcHandler( // Create a new message with the socket address of the client as the sender. RequestMessage(clientAddr, requestMessage.receiver, requestMessage.content) } else { + // The remote RpcEnv listens to some port, we should also fire a RemoteProcessConnected for + // the listening address + val remoteEnvAddress = requestMessage.senderAddress + if (remoteAddresses.putIfAbsent(clientAddr, remoteEnvAddress) == null) { + dispatcher.postToAll(RemoteProcessConnected(remoteEnvAddress)) + } requestMessage } } @@ -591,6 +600,12 @@ private[netty] class NettyRpcHandler( if (addr != null) { val clientAddr = RpcAddress(addr.getHostName, addr.getPort) dispatcher.postToAll(RemoteProcessConnectionError(cause, clientAddr)) + // If the remove RpcEnv listens to some address, we should also fire a + // RemoteProcessConnectionError for the remote RpcEnv listening address + val remoteEnvAddress = remoteAddresses.get(clientAddr) + if (remoteEnvAddress != null) { + dispatcher.postToAll(RemoteProcessConnectionError(cause, remoteEnvAddress)) + } } else { // If the channel is closed before connecting, its remoteAddress will be null. 
// See java.net.Socket.getRemoteSocketAddress @@ -606,6 +621,12 @@ private[netty] class NettyRpcHandler( val clientAddr = RpcAddress(addr.getHostName, addr.getPort) nettyEnv.removeOutbox(clientAddr) dispatcher.postToAll(RemoteProcessDisconnected(clientAddr)) + val remoteEnvAddress = remoteAddresses.remove(clientAddr) + // If the remove RpcEnv listens to some address, we should also fire a + // RemoteProcessDisconnected for the remote RpcEnv listening address + if (remoteEnvAddress != null) { + dispatcher.postToAll(RemoteProcessDisconnected(remoteEnvAddress)) + } } else { // If the channel is closed before connecting, its remoteAddress will be null. In this case, // we can ignore it since we don't fire "Associated". diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala index a61d0479aacd..6d153eb04e04 100644 --- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala @@ -545,6 +545,48 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { } } + test("network events between non-client-mode RpcEnvs") { + val events = new mutable.ArrayBuffer[(Any, Any)] with mutable.SynchronizedBuffer[(Any, Any)] + env.setupEndpoint("network-events-non-client", new ThreadSafeRpcEndpoint { + override val rpcEnv = env + + override def receive: PartialFunction[Any, Unit] = { + case "hello" => + case m => events += "receive" -> m + } + + override def onConnected(remoteAddress: RpcAddress): Unit = { + events += "onConnected" -> remoteAddress + } + + override def onDisconnected(remoteAddress: RpcAddress): Unit = { + events += "onDisconnected" -> remoteAddress + } + + override def onNetworkError(cause: Throwable, remoteAddress: RpcAddress): Unit = { + events += "onNetworkError" -> remoteAddress + } + + }) + + val anotherEnv = createRpcEnv(new SparkConf(), "remote", 0, clientMode = false) + // Use anotherEnv to find out the RpcEndpointRef + val rpcEndpointRef = anotherEnv.setupEndpointRef( + "local", env.address, "network-events-non-client") + val remoteAddress = anotherEnv.address + rpcEndpointRef.send("hello") + eventually(timeout(5 seconds), interval(5 millis)) { + assert(events.contains(("onConnected", remoteAddress))) + } + + anotherEnv.shutdown() + anotherEnv.awaitTermination() + eventually(timeout(5 seconds), interval(5 millis)) { + assert(events.contains(("onConnected", remoteAddress))) + assert(events.contains(("onDisconnected", remoteAddress))) + } + } + test("sendWithReply: unserializable error") { env.setupEndpoint("sendWithReply-unserializable-error", new RpcEndpoint { override val rpcEnv = env From 2aecda284e22ec608992b6221e2f5ffbd51fcd24 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Sun, 13 Dec 2015 22:06:39 -0800 Subject: [PATCH 1755/1824] [SPARK-12281][CORE] Fix a race condition when reporting ExecutorState in the shutdown hook 1. Make sure workers and masters exit so that no worker or master will still be running when triggering the shutdown hook. 2. Set ExecutorState to FAILED if it's still RUNNING when executing the shutdown hook. 
This should fix the potential exceptions when exiting a local cluster ``` java.lang.AssertionError: assertion failed: executor 4 state transfer from RUNNING to RUNNING is illegal at scala.Predef$.assert(Predef.scala:179) at org.apache.spark.deploy.master.Master$$anonfun$receive$1.applyOrElse(Master.scala:260) at org.apache.spark.rpc.netty.Inbox$$anonfun$process$1.apply$mcV$sp(Inbox.scala:116) at org.apache.spark.rpc.netty.Inbox.safelyCall(Inbox.scala:204) at org.apache.spark.rpc.netty.Inbox.process(Inbox.scala:100) at org.apache.spark.rpc.netty.Dispatcher$MessageLoop.run(Dispatcher.scala:215) at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142) at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617) at java.lang.Thread.run(Thread.java:745) java.lang.IllegalStateException: Shutdown hooks cannot be modified during shutdown. at org.apache.spark.util.SparkShutdownHookManager.add(ShutdownHookManager.scala:246) at org.apache.spark.util.ShutdownHookManager$.addShutdownHook(ShutdownHookManager.scala:191) at org.apache.spark.util.ShutdownHookManager$.addShutdownHook(ShutdownHookManager.scala:180) at org.apache.spark.deploy.worker.ExecutorRunner.start(ExecutorRunner.scala:73) at org.apache.spark.deploy.worker.Worker$$anonfun$receive$1.applyOrElse(Worker.scala:474) at org.apache.spark.rpc.netty.Inbox$$anonfun$process$1.apply$mcV$sp(Inbox.scala:116) at org.apache.spark.rpc.netty.Inbox.safelyCall(Inbox.scala:204) at org.apache.spark.rpc.netty.Inbox.process(Inbox.scala:100) at org.apache.spark.rpc.netty.Dispatcher$MessageLoop.run(Dispatcher.scala:215) at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142) at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617) at java.lang.Thread.run(Thread.java:745) ``` Author: Shixiong Zhu Closes #10269 from zsxwing/executor-state. 
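The state guard described in point 2 can be illustrated with a standalone sketch. The `ExecutorState` enumeration, the `ExecutorRunnerSketch` class, and `killProcess` below are simplified stand-ins, not Spark's actual classes.

```
// A minimal sketch of the shutdown-hook guard, using illustrative names only.
object ExecutorState extends Enumeration {
  val LAUNCHING, RUNNING, FAILED, EXITED = Value
}

class ExecutorRunnerSketch {
  @volatile private var state = ExecutorState.LAUNCHING

  def start(): Unit = {
    // Register the shutdown hook before the executor is actually launched.
    sys.addShutdownHook {
      // If the executor still looks RUNNING while the JVM is shutting down,
      // downgrade it to FAILED before reporting, so the master never sees a
      // RUNNING -> RUNNING transition during shutdown.
      if (state == ExecutorState.RUNNING) {
        state = ExecutorState.FAILED
      }
      killProcess("Worker shutting down")
    }
    // ... launch the executor process and update `state` accordingly ...
    state = ExecutorState.RUNNING
  }

  private def killProcess(reason: String): Unit = {
    // Report the final `state` and `reason` back to the worker (elided here).
  }
}
```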
--- .../scala/org/apache/spark/deploy/LocalSparkCluster.scala | 2 ++ .../main/scala/org/apache/spark/deploy/master/Master.scala | 5 ++--- .../org/apache/spark/deploy/worker/ExecutorRunner.scala | 5 +++++ 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala b/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala index 83ccaadfe744..5bb62d37d637 100644 --- a/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala +++ b/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala @@ -75,6 +75,8 @@ class LocalSparkCluster( // Stop the workers before the master so they don't get upset that it disconnected workerRpcEnvs.foreach(_.shutdown()) masterRpcEnvs.foreach(_.shutdown()) + workerRpcEnvs.foreach(_.awaitTermination()) + masterRpcEnvs.foreach(_.awaitTermination()) masterRpcEnvs.clear() workerRpcEnvs.clear() } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index 04b20e0d6ab9..1355e1ad1b52 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -257,9 +257,8 @@ private[deploy] class Master( exec.state = state if (state == ExecutorState.RUNNING) { - if (oldState != ExecutorState.LAUNCHING) { - logWarning(s"Executor $execId state transfer from $oldState to RUNNING is unexpected") - } + assert(oldState == ExecutorState.LAUNCHING, + s"executor $execId state transfer from $oldState to RUNNING is illegal") appInfo.resetRetryCount() } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala index 25a17473e4b5..9a42487bb37a 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala @@ -71,6 +71,11 @@ private[deploy] class ExecutorRunner( workerThread.start() // Shutdown hook that kills actors on shutdown. shutdownHook = ShutdownHookManager.addShutdownHook { () => + // It's possible that we arrive here before calling `fetchAndRunExecutor`, then `state` will + // be `ExecutorState.RUNNING`. In this case, we should set `state` to `FAILED`. + if (state == ExecutorState.RUNNING) { + state = ExecutorState.FAILED + } killProcess(Some("Worker shutting down")) } } From 834e71489bf560302f9d743dff669df1134e9b74 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Sun, 13 Dec 2015 22:57:01 -0800 Subject: [PATCH 1756/1824] [SPARK-12213][SQL] use multiple partitions for single distinct query Currently, we can generate different plans for a query with a single distinct aggregation (depending on spark.sql.specializeSingleDistinctAggPlanning): one works better on low cardinality columns, the other works better on high cardinality columns (the default one). This PR changes the planner to generate a single plan (three aggregations and two exchanges), which works better in both cases, so we can safely remove the flag `spark.sql.specializeSingleDistinctAggPlanning` (introduced in 1.6).
For a query like `SELECT COUNT(DISTINCT a) FROM table`, the generated plan will be ``` AGG-4 (count distinct) Shuffle to a single reducer Partial-AGG-3 (count distinct, no grouping) Partial-AGG-2 (grouping on a) Shuffle by a Partial-AGG-1 (grouping on a) ``` This PR also includes a large refactor of the aggregation code (removing 500+ lines of code). cc yhuai nongli marmbrus Author: Davies Liu Closes #10228 from davies/single_distinct. --- .../spark/sql/catalyst/CatalystConf.scala | 7 - .../DistinctAggregationRewriter.scala | 11 +- .../scala/org/apache/spark/sql/SQLConf.scala | 15 - .../aggregate/AggregationIterator.scala | 417 ++++++----------- .../aggregate/SortBasedAggregate.scala | 29 +- .../SortBasedAggregationIterator.scala | 47 +- .../aggregate/TungstenAggregate.scala | 25 +- .../TungstenAggregationIterator.scala | 439 +++--------------- .../spark/sql/execution/aggregate/utils.scala | 280 +++++------ .../execution/AggregationQuerySuite.scala | 142 +++--- 10 files changed, 422 insertions(+), 990 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala index 7c2b8a940788..2c7c58e66b85 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala @@ -19,8 +19,6 @@ package org.apache.spark.sql.catalyst private[spark] trait CatalystConf { def caseSensitiveAnalysis: Boolean - - protected[spark] def specializeSingleDistinctAggPlanning: Boolean } /** @@ -31,13 +29,8 @@ object EmptyConf extends CatalystConf { override def caseSensitiveAnalysis: Boolean = { throw new UnsupportedOperationException } - - protected[spark] override def specializeSingleDistinctAggPlanning: Boolean = { - throw new UnsupportedOperationException - } } /** A CatalystConf that can be used for local testing. */ case class SimpleCatalystConf(caseSensitiveAnalysis: Boolean) extends CatalystConf { - protected[spark] override def specializeSingleDistinctAggPlanning: Boolean = true } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala index 9c78f6d4cc71..4e7d1341028c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala @@ -123,15 +123,8 @@ case class DistinctAggregationRewriter(conf: CatalystConf) extends Rule[LogicalP .filter(_.isDistinct) .groupBy(_.aggregateFunction.children.toSet) - val shouldRewrite = if (conf.specializeSingleDistinctAggPlanning) { - // When the flag is set to specialize single distinct agg planning, - // we will rely on our Aggregation strategy to handle queries with a single - // distinct column. - distinctAggGroups.size > 1 - } else { - distinctAggGroups.size >= 1 - } - if (shouldRewrite) { + // Aggregation strategy can handle the query with single distinct + if (distinctAggGroups.size > 1) { // Create the attributes for the grouping id and the group by clause.
val gid = new AttributeReference("gid", IntegerType, false)() val groupByMap = a.groupingExpressions.collect { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 58adf64e4986..3d819262859f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -449,18 +449,6 @@ private[spark] object SQLConf { doc = "When true, we could use `datasource`.`path` as table in SQL query" ) - val SPECIALIZE_SINGLE_DISTINCT_AGG_PLANNING = - booleanConf("spark.sql.specializeSingleDistinctAggPlanning", - defaultValue = Some(false), - isPublic = false, - doc = "When true, if a query only has a single distinct column and it has " + - "grouping expressions, we will use our planner rule to handle this distinct " + - "column (other cases are handled by DistinctAggregationRewriter). " + - "When false, we will always use DistinctAggregationRewriter to plan " + - "aggregation queries with DISTINCT keyword. This is an internal flag that is " + - "used to benchmark the performance impact of using DistinctAggregationRewriter to " + - "plan aggregation queries with a single distinct column.") - object Deprecated { val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" val EXTERNAL_SORT = "spark.sql.planner.externalSort" @@ -579,9 +567,6 @@ private[sql] class SQLConf extends Serializable with CatalystConf { private[spark] def runSQLOnFile: Boolean = getConf(RUN_SQL_ON_FILES) - protected[spark] override def specializeSingleDistinctAggPlanning: Boolean = - getConf(SPECIALIZE_SINGLE_DISTINCT_AGG_PLANNING) - /** ********************** SQLConf functionality methods ************ */ /** Set Spark SQL configuration properties. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala index 008478a6a0e1..0c74df0aa5fd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala @@ -17,15 +17,15 @@ package org.apache.spark.sql.execution.aggregate +import scala.collection.mutable.ArrayBuffer + import org.apache.spark.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ -import scala.collection.mutable.ArrayBuffer - /** - * The base class of [[SortBasedAggregationIterator]]. + * The base class of [[SortBasedAggregationIterator]] and [[TungstenAggregationIterator]]. * It mainly contains two parts: * 1. It initializes aggregate functions. * 2. It creates two functions, `processRow` and `generateOutput` based on [[AggregateMode]] of @@ -33,64 +33,58 @@ import scala.collection.mutable.ArrayBuffer * is used to generate result. 
*/ abstract class AggregationIterator( - groupingKeyAttributes: Seq[Attribute], - valueAttributes: Seq[Attribute], - nonCompleteAggregateExpressions: Seq[AggregateExpression], - nonCompleteAggregateAttributes: Seq[Attribute], - completeAggregateExpressions: Seq[AggregateExpression], - completeAggregateAttributes: Seq[Attribute], + groupingExpressions: Seq[NamedExpression], + inputAttributes: Seq[Attribute], + aggregateExpressions: Seq[AggregateExpression], + aggregateAttributes: Seq[Attribute], initialInputBufferOffset: Int, resultExpressions: Seq[NamedExpression], - newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection), - outputsUnsafeRows: Boolean) - extends Iterator[InternalRow] with Logging { + newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection)) + extends Iterator[UnsafeRow] with Logging { /////////////////////////////////////////////////////////////////////////// // Initializing functions. /////////////////////////////////////////////////////////////////////////// - // An Seq of all AggregateExpressions. - // It is important that all AggregateExpressions with the mode Partial, PartialMerge or Final - // are at the beginning of the allAggregateExpressions. - protected val allAggregateExpressions = - nonCompleteAggregateExpressions ++ completeAggregateExpressions - - require( - allAggregateExpressions.map(_.mode).distinct.length <= 2, - s"$allAggregateExpressions are not supported becuase they have more than 2 distinct modes.") - /** - * The distinct modes of AggregateExpressions. Right now, we can handle the following mode: - * - Partial-only: all AggregateExpressions have the mode of Partial; - * - PartialMerge-only: all AggregateExpressions have the mode of PartialMerge); - * - Final-only: all AggregateExpressions have the mode of Final; - * - Final-Complete: some AggregateExpressions have the mode of Final and - * others have the mode of Complete; - * - Complete-only: nonCompleteAggregateExpressions is empty and we have AggregateExpressions - * with mode Complete in completeAggregateExpressions; and - * - Grouping-only: there is no AggregateExpression. - */ - protected val aggregationMode: (Option[AggregateMode], Option[AggregateMode]) = - nonCompleteAggregateExpressions.map(_.mode).distinct.headOption -> - completeAggregateExpressions.map(_.mode).distinct.headOption + * The following combinations of AggregationMode are supported: + * - Partial + * - PartialMerge (for single distinct) + * - Partial and PartialMerge (for single distinct) + * - Final + * - Complete (for SortBasedAggregate with functions that does not support Partial) + * - Final and Complete (currently not used) + * + * TODO: AggregateMode should have only two modes: Update and Merge, AggregateExpression + * could have a flag to tell it's final or not. + */ + { + val modes = aggregateExpressions.map(_.mode).distinct.toSet + require(modes.size <= 2, + s"$aggregateExpressions are not supported because they have more than 2 distinct modes.") + require(modes.subsetOf(Set(Partial, PartialMerge)) || modes.subsetOf(Set(Final, Complete)), + s"$aggregateExpressions can't have Partial/PartialMerge and Final/Complete in the same time.") + } // Initialize all AggregateFunctions by binding references if necessary, // and set inputBufferOffset and mutableBufferOffset. 
- protected val allAggregateFunctions: Array[AggregateFunction] = { + protected def initializeAggregateFunctions( + expressions: Seq[AggregateExpression], + startingInputBufferOffset: Int): Array[AggregateFunction] = { var mutableBufferOffset = 0 - var inputBufferOffset: Int = initialInputBufferOffset - val functions = new Array[AggregateFunction](allAggregateExpressions.length) + var inputBufferOffset: Int = startingInputBufferOffset + val functions = new Array[AggregateFunction](expressions.length) var i = 0 - while (i < allAggregateExpressions.length) { - val func = allAggregateExpressions(i).aggregateFunction - val funcWithBoundReferences: AggregateFunction = allAggregateExpressions(i).mode match { + while (i < expressions.length) { + val func = expressions(i).aggregateFunction + val funcWithBoundReferences: AggregateFunction = expressions(i).mode match { case Partial | Complete if func.isInstanceOf[ImperativeAggregate] => // We need to create BoundReferences if the function is not an // expression-based aggregate function (it does not support code-gen) and the mode of // this function is Partial or Complete because we will call eval of this // function's children in the update method of this aggregate function. // Those eval calls require BoundReferences to work. - BindReferences.bindReference(func, valueAttributes) + BindReferences.bindReference(func, inputAttributes) case _ => // We only need to set inputBufferOffset for aggregate functions with mode // PartialMerge and Final. @@ -117,15 +111,18 @@ abstract class AggregationIterator( functions } + protected val aggregateFunctions: Array[AggregateFunction] = + initializeAggregateFunctions(aggregateExpressions, initialInputBufferOffset) + // Positions of those imperative aggregate functions in allAggregateFunctions. // For example, we have func1, func2, func3, func4 in aggregateFunctions, and // func2 and func3 are imperative aggregate functions. // ImperativeAggregateFunctionPositions will be [1, 2]. - private[this] val allImperativeAggregateFunctionPositions: Array[Int] = { + protected[this] val allImperativeAggregateFunctionPositions: Array[Int] = { val positions = new ArrayBuffer[Int]() var i = 0 - while (i < allAggregateFunctions.length) { - allAggregateFunctions(i) match { + while (i < aggregateFunctions.length) { + aggregateFunctions(i) match { case agg: DeclarativeAggregate => case _ => positions += i } @@ -134,17 +131,9 @@ abstract class AggregationIterator( positions.toArray } - // All AggregateFunctions functions with mode Partial, PartialMerge, or Final. - private[this] val nonCompleteAggregateFunctions: Array[AggregateFunction] = - allAggregateFunctions.take(nonCompleteAggregateExpressions.length) - - // All imperative aggregate functions with mode Partial, PartialMerge, or Final. - private[this] val nonCompleteImperativeAggregateFunctions: Array[ImperativeAggregate] = - nonCompleteAggregateFunctions.collect { case func: ImperativeAggregate => func } - // The projection used to initialize buffer values for all expression-based aggregates. - private[this] val expressionAggInitialProjection = { - val initExpressions = allAggregateFunctions.flatMap { + protected[this] val expressionAggInitialProjection = { + val initExpressions = aggregateFunctions.flatMap { case ae: DeclarativeAggregate => ae.initialValues // For the positions corresponding to imperative aggregate functions, we'll use special // no-op expressions which are ignored during projection code-generation. 
@@ -154,248 +143,112 @@ abstract class AggregationIterator( } // All imperative AggregateFunctions. - private[this] val allImperativeAggregateFunctions: Array[ImperativeAggregate] = + protected[this] val allImperativeAggregateFunctions: Array[ImperativeAggregate] = allImperativeAggregateFunctionPositions - .map(allAggregateFunctions) + .map(aggregateFunctions) .map(_.asInstanceOf[ImperativeAggregate]) - /////////////////////////////////////////////////////////////////////////// - // Methods and fields used by sub-classes. - /////////////////////////////////////////////////////////////////////////// - // Initializing functions used to process a row. - protected val processRow: (MutableRow, InternalRow) => Unit = { - val rowToBeProcessed = new JoinedRow - val aggregationBufferSchema = allAggregateFunctions.flatMap(_.aggBufferAttributes) - aggregationMode match { - // Partial-only - case (Some(Partial), None) => - val updateExpressions = nonCompleteAggregateFunctions.flatMap { - case ae: DeclarativeAggregate => ae.updateExpressions - case agg: AggregateFunction => Seq.fill(agg.aggBufferAttributes.length)(NoOp) - } - val expressionAggUpdateProjection = - newMutableProjection(updateExpressions, aggregationBufferSchema ++ valueAttributes)() - - (currentBuffer: MutableRow, row: InternalRow) => { - expressionAggUpdateProjection.target(currentBuffer) - // Process all expression-based aggregate functions. - expressionAggUpdateProjection(rowToBeProcessed(currentBuffer, row)) - // Process all imperative aggregate functions. - var i = 0 - while (i < nonCompleteImperativeAggregateFunctions.length) { - nonCompleteImperativeAggregateFunctions(i).update(currentBuffer, row) - i += 1 - } - } - - // PartialMerge-only or Final-only - case (Some(PartialMerge), None) | (Some(Final), None) => - val inputAggregationBufferSchema = if (initialInputBufferOffset == 0) { - // If initialInputBufferOffset, the input value does not contain - // grouping keys. - // This part is pretty hacky. - allAggregateFunctions.flatMap(_.inputAggBufferAttributes).toSeq - } else { - groupingKeyAttributes ++ allAggregateFunctions.flatMap(_.inputAggBufferAttributes) - } - // val inputAggregationBufferSchema = - // groupingKeyAttributes ++ - // allAggregateFunctions.flatMap(_.cloneBufferAttributes) - val mergeExpressions = nonCompleteAggregateFunctions.flatMap { - case ae: DeclarativeAggregate => ae.mergeExpressions - case agg: AggregateFunction => Seq.fill(agg.aggBufferAttributes.length)(NoOp) - } - // This projection is used to merge buffer values for all expression-based aggregates. - val expressionAggMergeProjection = - newMutableProjection( - mergeExpressions, - aggregationBufferSchema ++ inputAggregationBufferSchema)() - - (currentBuffer: MutableRow, row: InternalRow) => { - // Process all expression-based aggregate functions. - expressionAggMergeProjection.target(currentBuffer)(rowToBeProcessed(currentBuffer, row)) - // Process all imperative aggregate functions. - var i = 0 - while (i < nonCompleteImperativeAggregateFunctions.length) { - nonCompleteImperativeAggregateFunctions(i).merge(currentBuffer, row) - i += 1 - } - } - - // Final-Complete - case (Some(Final), Some(Complete)) => - val completeAggregateFunctions: Array[AggregateFunction] = - allAggregateFunctions.takeRight(completeAggregateExpressions.length) - // All imperative aggregate functions with mode Complete. 
- val completeImperativeAggregateFunctions: Array[ImperativeAggregate] = - completeAggregateFunctions.collect { case func: ImperativeAggregate => func } - - // The first initialInputBufferOffset values of the input aggregation buffer is - // for grouping expressions and distinct columns. - val groupingAttributesAndDistinctColumns = valueAttributes.take(initialInputBufferOffset) - - val completeOffsetExpressions = - Seq.fill(completeAggregateFunctions.map(_.aggBufferAttributes.length).sum)(NoOp) - // We do not touch buffer values of aggregate functions with the Final mode. - val finalOffsetExpressions = - Seq.fill(nonCompleteAggregateFunctions.map(_.aggBufferAttributes.length).sum)(NoOp) - - val mergeInputSchema = - aggregationBufferSchema ++ - groupingAttributesAndDistinctColumns ++ - nonCompleteAggregateFunctions.flatMap(_.inputAggBufferAttributes) - val mergeExpressions = - nonCompleteAggregateFunctions.flatMap { - case ae: DeclarativeAggregate => ae.mergeExpressions - case agg: AggregateFunction => Seq.fill(agg.aggBufferAttributes.length)(NoOp) - } ++ completeOffsetExpressions - val finalExpressionAggMergeProjection = - newMutableProjection(mergeExpressions, mergeInputSchema)() - - val updateExpressions = - finalOffsetExpressions ++ completeAggregateFunctions.flatMap { - case ae: DeclarativeAggregate => ae.updateExpressions - case agg: AggregateFunction => Seq.fill(agg.aggBufferAttributes.length)(NoOp) - } - val completeExpressionAggUpdateProjection = - newMutableProjection(updateExpressions, aggregationBufferSchema ++ valueAttributes)() - - (currentBuffer: MutableRow, row: InternalRow) => { - val input = rowToBeProcessed(currentBuffer, row) - // For all aggregate functions with mode Complete, update buffers. - completeExpressionAggUpdateProjection.target(currentBuffer)(input) - var i = 0 - while (i < completeImperativeAggregateFunctions.length) { - completeImperativeAggregateFunctions(i).update(currentBuffer, row) - i += 1 - } - - // For all aggregate functions with mode Final, merge buffers. - finalExpressionAggMergeProjection.target(currentBuffer)(input) - i = 0 - while (i < nonCompleteImperativeAggregateFunctions.length) { - nonCompleteImperativeAggregateFunctions(i).merge(currentBuffer, row) - i += 1 + protected def generateProcessRow( + expressions: Seq[AggregateExpression], + functions: Seq[AggregateFunction], + inputAttributes: Seq[Attribute]): (MutableRow, InternalRow) => Unit = { + val joinedRow = new JoinedRow + if (expressions.nonEmpty) { + val mergeExpressions = functions.zipWithIndex.flatMap { + case (ae: DeclarativeAggregate, i) => + expressions(i).mode match { + case Partial | Complete => ae.updateExpressions + case PartialMerge | Final => ae.mergeExpressions } - } - - // Complete-only - case (None, Some(Complete)) => - val completeAggregateFunctions: Array[AggregateFunction] = - allAggregateFunctions.takeRight(completeAggregateExpressions.length) - // All imperative aggregate functions with mode Complete. 
- val completeImperativeAggregateFunctions: Array[ImperativeAggregate] = - completeAggregateFunctions.collect { case func: ImperativeAggregate => func } - - val updateExpressions = - completeAggregateFunctions.flatMap { - case ae: DeclarativeAggregate => ae.updateExpressions - case agg: AggregateFunction => Seq.fill(agg.aggBufferAttributes.length)(NoOp) - } - val completeExpressionAggUpdateProjection = - newMutableProjection(updateExpressions, aggregationBufferSchema ++ valueAttributes)() - - (currentBuffer: MutableRow, row: InternalRow) => { - val input = rowToBeProcessed(currentBuffer, row) - // For all aggregate functions with mode Complete, update buffers. - completeExpressionAggUpdateProjection.target(currentBuffer)(input) - var i = 0 - while (i < completeImperativeAggregateFunctions.length) { - completeImperativeAggregateFunctions(i).update(currentBuffer, row) - i += 1 + case (agg: AggregateFunction, _) => Seq.fill(agg.aggBufferAttributes.length)(NoOp) + } + val updateFunctions = functions.zipWithIndex.collect { + case (ae: ImperativeAggregate, i) => + expressions(i).mode match { + case Partial | Complete => + (buffer: MutableRow, row: InternalRow) => ae.update(buffer, row) + case PartialMerge | Final => + (buffer: MutableRow, row: InternalRow) => ae.merge(buffer, row) } + } + // This projection is used to merge buffer values for all expression-based aggregates. + val aggregationBufferSchema = functions.flatMap(_.aggBufferAttributes) + val updateProjection = + newMutableProjection(mergeExpressions, aggregationBufferSchema ++ inputAttributes)() + + (currentBuffer: MutableRow, row: InternalRow) => { + // Process all expression-based aggregate functions. + updateProjection.target(currentBuffer)(joinedRow(currentBuffer, row)) + // Process all imperative aggregate functions. + var i = 0 + while (i < updateFunctions.length) { + updateFunctions(i)(currentBuffer, row) + i += 1 } - + } + } else { // Grouping only. - case (None, None) => (currentBuffer: MutableRow, row: InternalRow) => {} - - case other => - sys.error( - s"Could not evaluate ${nonCompleteAggregateExpressions} because we do not " + - s"support evaluate modes $other in this iterator.") + (currentBuffer: MutableRow, row: InternalRow) => {} } } - // Initializing the function used to generate the output row. - protected val generateOutput: (InternalRow, MutableRow) => InternalRow = { - val rowToBeEvaluated = new JoinedRow - val safeOutputRow = new SpecificMutableRow(resultExpressions.map(_.dataType)) - val mutableOutput = if (outputsUnsafeRows) { - UnsafeProjection.create(resultExpressions.map(_.dataType).toArray).apply(safeOutputRow) - } else { - safeOutputRow - } - - aggregationMode match { - // Partial-only or PartialMerge-only: every output row is basically the values of - // the grouping expressions and the corresponding aggregation buffer. 
- case (Some(Partial), None) | (Some(PartialMerge), None) => - // Because we cannot copy a joinedRow containing a UnsafeRow (UnsafeRow does not - // support generic getter), we create a mutable projection to output the - // JoinedRow(currentGroupingKey, currentBuffer) - val bufferSchema = nonCompleteAggregateFunctions.flatMap(_.aggBufferAttributes) - val resultProjection = - newMutableProjection( - groupingKeyAttributes ++ bufferSchema, - groupingKeyAttributes ++ bufferSchema)() - resultProjection.target(mutableOutput) - - (currentGroupingKey: InternalRow, currentBuffer: MutableRow) => { - resultProjection(rowToBeEvaluated(currentGroupingKey, currentBuffer)) - // rowToBeEvaluated(currentGroupingKey, currentBuffer) - } + protected val processRow: (MutableRow, InternalRow) => Unit = + generateProcessRow(aggregateExpressions, aggregateFunctions, inputAttributes) - // Final-only, Complete-only and Final-Complete: every output row contains values representing - // resultExpressions. - case (Some(Final), None) | (Some(Final) | None, Some(Complete)) => - val bufferSchemata = - allAggregateFunctions.flatMap(_.aggBufferAttributes) - val evalExpressions = allAggregateFunctions.map { - case ae: DeclarativeAggregate => ae.evaluateExpression - case agg: AggregateFunction => NoOp - } - val expressionAggEvalProjection = newMutableProjection(evalExpressions, bufferSchemata)() - val aggregateResultSchema = nonCompleteAggregateAttributes ++ completeAggregateAttributes - // TODO: Use unsafe row. - val aggregateResult = new SpecificMutableRow(aggregateResultSchema.map(_.dataType)) - expressionAggEvalProjection.target(aggregateResult) - val resultProjection = - newMutableProjection( - resultExpressions, groupingKeyAttributes ++ aggregateResultSchema)() - resultProjection.target(mutableOutput) + protected val groupingProjection: UnsafeProjection = + UnsafeProjection.create(groupingExpressions, inputAttributes) + protected val groupingAttributes = groupingExpressions.map(_.toAttribute) - (currentGroupingKey: InternalRow, currentBuffer: MutableRow) => { - // Generate results for all expression-based aggregate functions. - expressionAggEvalProjection(currentBuffer) - // Generate results for all imperative aggregate functions. - var i = 0 - while (i < allImperativeAggregateFunctions.length) { - aggregateResult.update( - allImperativeAggregateFunctionPositions(i), - allImperativeAggregateFunctions(i).eval(currentBuffer)) - i += 1 - } - resultProjection(rowToBeEvaluated(currentGroupingKey, aggregateResult)) + // Initializing the function used to generate the output row. + protected def generateResultProjection(): (UnsafeRow, MutableRow) => UnsafeRow = { + val joinedRow = new JoinedRow + val modes = aggregateExpressions.map(_.mode).distinct + val bufferAttributes = aggregateFunctions.flatMap(_.aggBufferAttributes) + if (modes.contains(Final) || modes.contains(Complete)) { + val evalExpressions = aggregateFunctions.map { + case ae: DeclarativeAggregate => ae.evaluateExpression + case agg: AggregateFunction => NoOp + } + val aggregateResult = new SpecificMutableRow(aggregateAttributes.map(_.dataType)) + val expressionAggEvalProjection = newMutableProjection(evalExpressions, bufferAttributes)() + expressionAggEvalProjection.target(aggregateResult) + + val resultProjection = + UnsafeProjection.create(resultExpressions, groupingAttributes ++ aggregateAttributes) + + (currentGroupingKey: UnsafeRow, currentBuffer: MutableRow) => { + // Generate results for all expression-based aggregate functions. 
+ expressionAggEvalProjection(currentBuffer) + // Generate results for all imperative aggregate functions. + var i = 0 + while (i < allImperativeAggregateFunctions.length) { + aggregateResult.update( + allImperativeAggregateFunctionPositions(i), + allImperativeAggregateFunctions(i).eval(currentBuffer)) + i += 1 } - + resultProjection(joinedRow(currentGroupingKey, aggregateResult)) + } + } else if (modes.contains(Partial) || modes.contains(PartialMerge)) { + val resultProjection = UnsafeProjection.create( + groupingAttributes ++ bufferAttributes, + groupingAttributes ++ bufferAttributes) + (currentGroupingKey: UnsafeRow, currentBuffer: MutableRow) => { + resultProjection(joinedRow(currentGroupingKey, currentBuffer)) + } + } else { // Grouping-only: we only output values of grouping expressions. - case (None, None) => - val resultProjection = - newMutableProjection(resultExpressions, groupingKeyAttributes)() - resultProjection.target(mutableOutput) - - (currentGroupingKey: InternalRow, currentBuffer: MutableRow) => { - resultProjection(currentGroupingKey) - } - - case other => - sys.error( - s"Could not evaluate ${nonCompleteAggregateExpressions} because we do not " + - s"support evaluate modes $other in this iterator.") + val resultProjection = UnsafeProjection.create(resultExpressions, groupingAttributes) + (currentGroupingKey: UnsafeRow, currentBuffer: MutableRow) => { + resultProjection(currentGroupingKey) + } } } + protected val generateOutput: (UnsafeRow, MutableRow) => UnsafeRow = + generateResultProjection() + /** Initializes buffer values for all aggregate functions. */ protected def initializeBuffer(buffer: MutableRow): Unit = { expressionAggInitialProjection.target(buffer)(EmptyRow) @@ -405,10 +258,4 @@ abstract class AggregationIterator( i += 1 } } - - /** - * Creates a new aggregation buffer and initializes buffer values - * for all aggregate functions. 
- */ - protected def newBuffer: MutableRow } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala index ee982453c328..c5470a6989de 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala @@ -29,10 +29,8 @@ import org.apache.spark.sql.execution.metric.SQLMetrics case class SortBasedAggregate( requiredChildDistributionExpressions: Option[Seq[Expression]], groupingExpressions: Seq[NamedExpression], - nonCompleteAggregateExpressions: Seq[AggregateExpression], - nonCompleteAggregateAttributes: Seq[Attribute], - completeAggregateExpressions: Seq[AggregateExpression], - completeAggregateAttributes: Seq[Attribute], + aggregateExpressions: Seq[AggregateExpression], + aggregateAttributes: Seq[Attribute], initialInputBufferOffset: Int, resultExpressions: Seq[NamedExpression], child: SparkPlan) @@ -42,10 +40,8 @@ case class SortBasedAggregate( "numInputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of input rows"), "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) - override def outputsUnsafeRows: Boolean = false - + override def outputsUnsafeRows: Boolean = true override def canProcessUnsafeRows: Boolean = false - override def canProcessSafeRows: Boolean = true override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute) @@ -76,31 +72,24 @@ case class SortBasedAggregate( if (!hasInput && groupingExpressions.nonEmpty) { // This is a grouped aggregate and the input iterator is empty, // so return an empty iterator. - Iterator[InternalRow]() + Iterator[UnsafeRow]() } else { - val groupingKeyProjection = - UnsafeProjection.create(groupingExpressions, child.output) - val outputIter = new SortBasedAggregationIterator( - groupingKeyProjection, - groupingExpressions.map(_.toAttribute), + groupingExpressions, child.output, iter, - nonCompleteAggregateExpressions, - nonCompleteAggregateAttributes, - completeAggregateExpressions, - completeAggregateAttributes, + aggregateExpressions, + aggregateAttributes, initialInputBufferOffset, resultExpressions, newMutableProjection, - outputsUnsafeRows, numInputRows, numOutputRows) if (!hasInput && groupingExpressions.isEmpty) { // There is no input and there is no grouping expressions. // We need to output a single row as the output. 
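A side note on the branch above (no input rows and no grouping expressions): a global aggregate must still emit exactly one row, which is what outputForEmptyGroupingKeyWithoutInput provides below; for example, a count over an empty table is 0, not an empty result. A tiny plain-Scala illustration of that behaviour, with invented names, is:

// Plain-Scala illustration (not Spark code) of why an ungrouped aggregate over an
// empty input must still produce exactly one row, while a grouped one produces none.
object EmptyInputAggregateSketch {
  def countRows(rows: Seq[Int], groupBy: Option[Int => Int]): Seq[(Option[Int], Long)] =
    groupBy match {
      case Some(key) => rows.groupBy(key).toSeq.map { case (k, vs) => (Some(k), vs.size.toLong) }
      case None      => Seq((None, rows.size.toLong))  // always a single row, even for empty input
    }

  def main(args: Array[String]): Unit = {
    println(countRows(Nil, None))                       // List((None,0))
    println(countRows(Nil, Some((i: Int) => i % 2)))    // List()
  }
}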
numOutputRows += 1 - Iterator[InternalRow](outputIter.outputForEmptyGroupingKeyWithoutInput()) + Iterator[UnsafeRow](outputIter.outputForEmptyGroupingKeyWithoutInput()) } else { outputIter } @@ -109,7 +98,7 @@ case class SortBasedAggregate( } override def simpleString: String = { - val allAggregateExpressions = nonCompleteAggregateExpressions ++ completeAggregateExpressions + val allAggregateExpressions = aggregateExpressions val keyString = groupingExpressions.mkString("[", ",", "]") val functionString = allAggregateExpressions.mkString("[", ",", "]") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala index fe5c3195f867..ac920aa8bc7f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala @@ -24,37 +24,34 @@ import org.apache.spark.sql.execution.metric.LongSQLMetric /** * An iterator used to evaluate [[AggregateFunction]]. It assumes the input rows have been - * sorted by values of [[groupingKeyAttributes]]. + * sorted by values of [[groupingExpressions]]. */ class SortBasedAggregationIterator( - groupingKeyProjection: InternalRow => InternalRow, - groupingKeyAttributes: Seq[Attribute], + groupingExpressions: Seq[NamedExpression], valueAttributes: Seq[Attribute], inputIterator: Iterator[InternalRow], - nonCompleteAggregateExpressions: Seq[AggregateExpression], - nonCompleteAggregateAttributes: Seq[Attribute], - completeAggregateExpressions: Seq[AggregateExpression], - completeAggregateAttributes: Seq[Attribute], + aggregateExpressions: Seq[AggregateExpression], + aggregateAttributes: Seq[Attribute], initialInputBufferOffset: Int, resultExpressions: Seq[NamedExpression], newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection), - outputsUnsafeRows: Boolean, numInputRows: LongSQLMetric, numOutputRows: LongSQLMetric) extends AggregationIterator( - groupingKeyAttributes, + groupingExpressions, valueAttributes, - nonCompleteAggregateExpressions, - nonCompleteAggregateAttributes, - completeAggregateExpressions, - completeAggregateAttributes, + aggregateExpressions, + aggregateAttributes, initialInputBufferOffset, resultExpressions, - newMutableProjection, - outputsUnsafeRows) { - - override protected def newBuffer: MutableRow = { - val bufferSchema = allAggregateFunctions.flatMap(_.aggBufferAttributes) + newMutableProjection) { + + /** + * Creates a new aggregation buffer and initializes buffer values + * for all aggregate functions. + */ + private def newBuffer: MutableRow = { + val bufferSchema = aggregateFunctions.flatMap(_.aggBufferAttributes) val bufferRowSize: Int = bufferSchema.length val genericMutableBuffer = new GenericMutableRow(bufferRowSize) @@ -76,10 +73,10 @@ class SortBasedAggregationIterator( /////////////////////////////////////////////////////////////////////////// // The partition key of the current partition. - private[this] var currentGroupingKey: InternalRow = _ + private[this] var currentGroupingKey: UnsafeRow = _ // The partition key of next partition. - private[this] var nextGroupingKey: InternalRow = _ + private[this] var nextGroupingKey: UnsafeRow = _ // The first row of next partition. 
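For orientation, the group-at-a-time scan that this iterator performs over sorted input (tracking the current key, the next key, and the first row of the next group) can be imitated over a plain sorted sequence. The sketch below is a simplified stand-in, not the Spark iterator; all names in it are invented:

// Toy version of a sort-based aggregation scan: input is assumed sorted by key,
// and rows are folded into a buffer until the key changes, emitting one row per group.
object SortedGroupScanSketch {
  def sumPerGroup(sortedRows: Iterator[(String, Int)]): Iterator[(String, Long)] =
    new Iterator[(String, Long)] {
      private val buffered = sortedRows.buffered
      def hasNext: Boolean = buffered.hasNext
      def next(): (String, Long) = {
        val currentKey = buffered.head._1          // analogous to the next grouping key
        var sum = 0L                               // analogous to the aggregation buffer
        while (buffered.hasNext && buffered.head._1 == currentKey) {
          sum += buffered.next()._2
        }
        (currentKey, sum)
      }
    }

  def main(args: Array[String]): Unit = {
    val rows = Iterator(("a", 1), ("a", 2), ("b", 5), ("c", 3), ("c", 4))
    sumPerGroup(rows).foreach(println)             // (a,3) (b,5) (c,7)
  }
}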
private[this] var firstRowInNextGroup: InternalRow = _ @@ -94,7 +91,7 @@ class SortBasedAggregationIterator( if (inputIterator.hasNext) { initializeBuffer(sortBasedAggregationBuffer) val inputRow = inputIterator.next() - nextGroupingKey = groupingKeyProjection(inputRow).copy() + nextGroupingKey = groupingProjection(inputRow).copy() firstRowInNextGroup = inputRow.copy() numInputRows += 1 sortedInputHasNewGroup = true @@ -120,7 +117,7 @@ class SortBasedAggregationIterator( while (!findNextPartition && inputIterator.hasNext) { // Get the grouping key. val currentRow = inputIterator.next() - val groupingKey = groupingKeyProjection(currentRow) + val groupingKey = groupingProjection(currentRow) numInputRows += 1 // Check if the current row belongs the current input row. @@ -146,7 +143,7 @@ class SortBasedAggregationIterator( override final def hasNext: Boolean = sortedInputHasNewGroup - override final def next(): InternalRow = { + override final def next(): UnsafeRow = { if (hasNext) { // Process the current group. processCurrentSortedGroup() @@ -162,8 +159,8 @@ class SortBasedAggregationIterator( } } - def outputForEmptyGroupingKeyWithoutInput(): InternalRow = { + def outputForEmptyGroupingKeyWithoutInput(): UnsafeRow = { initializeBuffer(sortBasedAggregationBuffer) - generateOutput(new GenericInternalRow(0), sortBasedAggregationBuffer) + generateOutput(UnsafeRow.createFromByteArray(0, 0), sortBasedAggregationBuffer) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala index 920de615e1d8..b8849c827048 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala @@ -30,21 +30,18 @@ import org.apache.spark.sql.types.StructType case class TungstenAggregate( requiredChildDistributionExpressions: Option[Seq[Expression]], groupingExpressions: Seq[NamedExpression], - nonCompleteAggregateExpressions: Seq[AggregateExpression], - nonCompleteAggregateAttributes: Seq[Attribute], - completeAggregateExpressions: Seq[AggregateExpression], - completeAggregateAttributes: Seq[Attribute], + aggregateExpressions: Seq[AggregateExpression], + aggregateAttributes: Seq[Attribute], initialInputBufferOffset: Int, resultExpressions: Seq[NamedExpression], child: SparkPlan) extends UnaryNode { private[this] val aggregateBufferAttributes = { - (nonCompleteAggregateExpressions ++ completeAggregateExpressions) - .flatMap(_.aggregateFunction.aggBufferAttributes) + aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes) } - require(TungstenAggregate.supportsAggregate(groupingExpressions, aggregateBufferAttributes)) + require(TungstenAggregate.supportsAggregate(aggregateBufferAttributes)) override private[sql] lazy val metrics = Map( "numInputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of input rows"), @@ -53,9 +50,7 @@ case class TungstenAggregate( "spillSize" -> SQLMetrics.createSizeMetric(sparkContext, "spill size")) override def outputsUnsafeRows: Boolean = true - override def canProcessUnsafeRows: Boolean = true - override def canProcessSafeRows: Boolean = true override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute) @@ -94,10 +89,8 @@ case class TungstenAggregate( val aggregationIterator = new TungstenAggregationIterator( groupingExpressions, - nonCompleteAggregateExpressions, - 
nonCompleteAggregateAttributes, - completeAggregateExpressions, - completeAggregateAttributes, + aggregateExpressions, + aggregateAttributes, initialInputBufferOffset, resultExpressions, newMutableProjection, @@ -119,7 +112,7 @@ case class TungstenAggregate( } override def simpleString: String = { - val allAggregateExpressions = nonCompleteAggregateExpressions ++ completeAggregateExpressions + val allAggregateExpressions = aggregateExpressions testFallbackStartsAt match { case None => @@ -135,9 +128,7 @@ case class TungstenAggregate( } object TungstenAggregate { - def supportsAggregate( - groupingExpressions: Seq[Expression], - aggregateBufferAttributes: Seq[Attribute]): Boolean = { + def supportsAggregate(aggregateBufferAttributes: Seq[Attribute]): Boolean = { val aggregationBufferSchema = StructType.fromAttributes(aggregateBufferAttributes) UnsafeFixedWidthAggregationMap.supportsAggregationBufferSchema(aggregationBufferSchema) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala index 04391443920a..582fdbe54706 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala @@ -17,17 +17,15 @@ package org.apache.spark.sql.execution.aggregate -import scala.collection.mutable.ArrayBuffer - -import org.apache.spark.unsafe.KVIterator -import org.apache.spark.{InternalAccumulator, Logging, TaskContext} +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.execution.{UnsafeKVExternalSorter, UnsafeFixedWidthAggregationMap} import org.apache.spark.sql.execution.metric.LongSQLMetric +import org.apache.spark.sql.execution.{UnsafeFixedWidthAggregationMap, UnsafeKVExternalSorter} import org.apache.spark.sql.types.StructType +import org.apache.spark.unsafe.KVIterator +import org.apache.spark.{InternalAccumulator, Logging, TaskContext} /** * An iterator used to evaluate aggregate functions. It operates on [[UnsafeRow]]s. @@ -63,15 +61,11 @@ import org.apache.spark.sql.types.StructType * * @param groupingExpressions * expressions for grouping keys - * @param nonCompleteAggregateExpressions + * @param aggregateExpressions * [[AggregateExpression]] containing [[AggregateFunction]]s with mode [[Partial]], * [[PartialMerge]], or [[Final]]. - * @param nonCompleteAggregateAttributes the attributes of the nonCompleteAggregateExpressions' + * @param aggregateAttributes the attributes of the aggregateExpressions' * outputs when they are stored in the final aggregation buffer. - * @param completeAggregateExpressions - * [[AggregateExpression]] containing [[AggregateFunction]]s with mode [[Complete]]. - * @param completeAggregateAttributes the attributes of completeAggregateExpressions' outputs - * when they are stored in the final aggregation buffer. * @param resultExpressions * expressions for generating output rows. 
* @param newMutableProjection @@ -83,10 +77,8 @@ import org.apache.spark.sql.types.StructType */ class TungstenAggregationIterator( groupingExpressions: Seq[NamedExpression], - nonCompleteAggregateExpressions: Seq[AggregateExpression], - nonCompleteAggregateAttributes: Seq[Attribute], - completeAggregateExpressions: Seq[AggregateExpression], - completeAggregateAttributes: Seq[Attribute], + aggregateExpressions: Seq[AggregateExpression], + aggregateAttributes: Seq[Attribute], initialInputBufferOffset: Int, resultExpressions: Seq[NamedExpression], newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection), @@ -97,378 +89,62 @@ class TungstenAggregationIterator( numOutputRows: LongSQLMetric, dataSize: LongSQLMetric, spillSize: LongSQLMetric) - extends Iterator[UnsafeRow] with Logging { + extends AggregationIterator( + groupingExpressions, + originalInputAttributes, + aggregateExpressions, + aggregateAttributes, + initialInputBufferOffset, + resultExpressions, + newMutableProjection) with Logging { /////////////////////////////////////////////////////////////////////////// // Part 1: Initializing aggregate functions. /////////////////////////////////////////////////////////////////////////// - // A Seq containing all AggregateExpressions. - // It is important that all AggregateExpressions with the mode Partial, PartialMerge or Final - // are at the beginning of the allAggregateExpressions. - private[this] val allAggregateExpressions: Seq[AggregateExpression] = - nonCompleteAggregateExpressions ++ completeAggregateExpressions - - // Check to make sure we do not have more than three modes in our AggregateExpressions. - // If we have, users are hitting a bug and we throw an IllegalStateException. - if (allAggregateExpressions.map(_.mode).distinct.length > 2) { - throw new IllegalStateException( - s"$allAggregateExpressions should have no more than 2 kinds of modes.") - } - // Remember spill data size of this task before execute this operator so that we can // figure out how many bytes we spilled for this operator. private val spillSizeBefore = TaskContext.get().taskMetrics().memoryBytesSpilled - // - // The modes of AggregateExpressions. Right now, we can handle the following mode: - // - Partial-only: - // All AggregateExpressions have the mode of Partial. - // For this case, aggregationMode is (Some(Partial), None). - // - PartialMerge-only: - // All AggregateExpressions have the mode of PartialMerge). - // For this case, aggregationMode is (Some(PartialMerge), None). - // - Final-only: - // All AggregateExpressions have the mode of Final. - // For this case, aggregationMode is (Some(Final), None). - // - Final-Complete: - // Some AggregateExpressions have the mode of Final and - // others have the mode of Complete. For this case, - // aggregationMode is (Some(Final), Some(Complete)). - // - Complete-only: - // nonCompleteAggregateExpressions is empty and we have AggregateExpressions - // with mode Complete in completeAggregateExpressions. For this case, - // aggregationMode is (None, Some(Complete)). - // - Grouping-only: - // There is no AggregateExpression. For this case, AggregationMode is (None,None). - // - private[this] var aggregationMode: (Option[AggregateMode], Option[AggregateMode]) = { - nonCompleteAggregateExpressions.map(_.mode).distinct.headOption -> - completeAggregateExpressions.map(_.mode).distinct.headOption - } - - // Initialize all AggregateFunctions by binding references, if necessary, - // and setting inputBufferOffset and mutableBufferOffset. 
- private def initializeAllAggregateFunctions( - startingInputBufferOffset: Int): Array[AggregateFunction] = { - var mutableBufferOffset = 0 - var inputBufferOffset: Int = startingInputBufferOffset - val functions = new Array[AggregateFunction](allAggregateExpressions.length) - var i = 0 - while (i < allAggregateExpressions.length) { - val func = allAggregateExpressions(i).aggregateFunction - val aggregateExpressionIsNonComplete = i < nonCompleteAggregateExpressions.length - // We need to use this mode instead of func.mode in order to handle aggregation mode switching - // when switching to sort-based aggregation: - val mode = if (aggregateExpressionIsNonComplete) aggregationMode._1 else aggregationMode._2 - val funcWithBoundReferences = mode match { - case Some(Partial) | Some(Complete) if func.isInstanceOf[ImperativeAggregate] => - // We need to create BoundReferences if the function is not an - // expression-based aggregate function (it does not support code-gen) and the mode of - // this function is Partial or Complete because we will call eval of this - // function's children in the update method of this aggregate function. - // Those eval calls require BoundReferences to work. - BindReferences.bindReference(func, originalInputAttributes) - case _ => - // We only need to set inputBufferOffset for aggregate functions with mode - // PartialMerge and Final. - val updatedFunc = func match { - case function: ImperativeAggregate => - function.withNewInputAggBufferOffset(inputBufferOffset) - case function => function - } - inputBufferOffset += func.aggBufferSchema.length - updatedFunc - } - val funcWithUpdatedAggBufferOffset = funcWithBoundReferences match { - case function: ImperativeAggregate => - // Set mutableBufferOffset for this function. It is important that setting - // mutableBufferOffset happens after all potential bindReference operations - // because bindReference will create a new instance of the function. - function.withNewMutableAggBufferOffset(mutableBufferOffset) - case function => function - } - mutableBufferOffset += funcWithUpdatedAggBufferOffset.aggBufferSchema.length - functions(i) = funcWithUpdatedAggBufferOffset - i += 1 - } - functions - } - - private[this] var allAggregateFunctions: Array[AggregateFunction] = - initializeAllAggregateFunctions(initialInputBufferOffset) - - // Positions of those imperative aggregate functions in allAggregateFunctions. - // For example, say that we have func1, func2, func3, func4 in aggregateFunctions, and - // func2 and func3 are imperative aggregate functions. Then - // allImperativeAggregateFunctionPositions will be [1, 2]. Note that this does not need to be - // updated when falling back to sort-based aggregation because the positions of the aggregate - // functions do not change in that case. - private[this] val allImperativeAggregateFunctionPositions: Array[Int] = { - val positions = new ArrayBuffer[Int]() - var i = 0 - while (i < allAggregateFunctions.length) { - allAggregateFunctions(i) match { - case agg: DeclarativeAggregate => - case _ => positions += i - } - i += 1 - } - positions.toArray - } - /////////////////////////////////////////////////////////////////////////// // Part 2: Methods and fields used by setting aggregation buffer values, // processing input rows from inputIter, and generating output // rows. /////////////////////////////////////////////////////////////////////////// - // The projection used to initialize buffer values for all expression-based aggregates. 
- // Note that this projection does not need to be updated when switching to sort-based aggregation - // because the schema of empty aggregation buffers does not change in that case. - private[this] val expressionAggInitialProjection: MutableProjection = { - val initExpressions = allAggregateFunctions.flatMap { - case ae: DeclarativeAggregate => ae.initialValues - // For the positions corresponding to imperative aggregate functions, we'll use special - // no-op expressions which are ignored during projection code-generation. - case i: ImperativeAggregate => Seq.fill(i.aggBufferAttributes.length)(NoOp) - } - newMutableProjection(initExpressions, Nil)() - } - // Creates a new aggregation buffer and initializes buffer values. - // This function should be only called at most three times (when we create the hash map, - // when we switch to sort-based aggregation, and when we create the re-used buffer for - // sort-based aggregation). + // This function should be only called at most two times (when we create the hash map, + // and when we create the re-used buffer for sort-based aggregation). private def createNewAggregationBuffer(): UnsafeRow = { - val bufferSchema = allAggregateFunctions.flatMap(_.aggBufferAttributes) + val bufferSchema = aggregateFunctions.flatMap(_.aggBufferAttributes) val buffer: UnsafeRow = UnsafeProjection.create(bufferSchema.map(_.dataType)) .apply(new GenericMutableRow(bufferSchema.length)) // Initialize declarative aggregates' buffer values expressionAggInitialProjection.target(buffer)(EmptyRow) // Initialize imperative aggregates' buffer values - allAggregateFunctions.collect { case f: ImperativeAggregate => f }.foreach(_.initialize(buffer)) + aggregateFunctions.collect { case f: ImperativeAggregate => f }.foreach(_.initialize(buffer)) buffer } - // Creates a function used to process a row based on the given inputAttributes. - private def generateProcessRow( - inputAttributes: Seq[Attribute]): (UnsafeRow, InternalRow) => Unit = { - - val aggregationBufferAttributes = allAggregateFunctions.flatMap(_.aggBufferAttributes) - val joinedRow = new JoinedRow() - - aggregationMode match { - // Partial-only - case (Some(Partial), None) => - val updateExpressions = allAggregateFunctions.flatMap { - case ae: DeclarativeAggregate => ae.updateExpressions - case agg: AggregateFunction => Seq.fill(agg.aggBufferAttributes.length)(NoOp) - } - val imperativeAggregateFunctions: Array[ImperativeAggregate] = - allAggregateFunctions.collect { case func: ImperativeAggregate => func} - val expressionAggUpdateProjection = - newMutableProjection(updateExpressions, aggregationBufferAttributes ++ inputAttributes)() - - (currentBuffer: UnsafeRow, row: InternalRow) => { - expressionAggUpdateProjection.target(currentBuffer) - // Process all expression-based aggregate functions. 
- expressionAggUpdateProjection(joinedRow(currentBuffer, row)) - // Process all imperative aggregate functions - var i = 0 - while (i < imperativeAggregateFunctions.length) { - imperativeAggregateFunctions(i).update(currentBuffer, row) - i += 1 - } - } - - // PartialMerge-only or Final-only - case (Some(PartialMerge), None) | (Some(Final), None) => - val mergeExpressions = allAggregateFunctions.flatMap { - case ae: DeclarativeAggregate => ae.mergeExpressions - case agg: AggregateFunction => Seq.fill(agg.aggBufferAttributes.length)(NoOp) - } - val imperativeAggregateFunctions: Array[ImperativeAggregate] = - allAggregateFunctions.collect { case func: ImperativeAggregate => func} - // This projection is used to merge buffer values for all expression-based aggregates. - val expressionAggMergeProjection = - newMutableProjection(mergeExpressions, aggregationBufferAttributes ++ inputAttributes)() - - (currentBuffer: UnsafeRow, row: InternalRow) => { - // Process all expression-based aggregate functions. - expressionAggMergeProjection.target(currentBuffer)(joinedRow(currentBuffer, row)) - // Process all imperative aggregate functions. - var i = 0 - while (i < imperativeAggregateFunctions.length) { - imperativeAggregateFunctions(i).merge(currentBuffer, row) - i += 1 - } - } - - // Final-Complete - case (Some(Final), Some(Complete)) => - val completeAggregateFunctions: Array[AggregateFunction] = - allAggregateFunctions.takeRight(completeAggregateExpressions.length) - val completeImperativeAggregateFunctions: Array[ImperativeAggregate] = - completeAggregateFunctions.collect { case func: ImperativeAggregate => func } - val nonCompleteAggregateFunctions: Array[AggregateFunction] = - allAggregateFunctions.take(nonCompleteAggregateExpressions.length) - val nonCompleteImperativeAggregateFunctions: Array[ImperativeAggregate] = - nonCompleteAggregateFunctions.collect { case func: ImperativeAggregate => func } - - val completeOffsetExpressions = - Seq.fill(completeAggregateFunctions.map(_.aggBufferAttributes.length).sum)(NoOp) - val mergeExpressions = - nonCompleteAggregateFunctions.flatMap { - case ae: DeclarativeAggregate => ae.mergeExpressions - case agg: AggregateFunction => Seq.fill(agg.aggBufferAttributes.length)(NoOp) - } ++ completeOffsetExpressions - val finalMergeProjection = - newMutableProjection(mergeExpressions, aggregationBufferAttributes ++ inputAttributes)() - - // We do not touch buffer values of aggregate functions with the Final mode. - val finalOffsetExpressions = - Seq.fill(nonCompleteAggregateFunctions.map(_.aggBufferAttributes.length).sum)(NoOp) - val updateExpressions = finalOffsetExpressions ++ completeAggregateFunctions.flatMap { - case ae: DeclarativeAggregate => ae.updateExpressions - case agg: AggregateFunction => Seq.fill(agg.aggBufferAttributes.length)(NoOp) - } - val completeUpdateProjection = - newMutableProjection(updateExpressions, aggregationBufferAttributes ++ inputAttributes)() - - (currentBuffer: UnsafeRow, row: InternalRow) => { - val input = joinedRow(currentBuffer, row) - // For all aggregate functions with mode Complete, update buffers. - completeUpdateProjection.target(currentBuffer)(input) - var i = 0 - while (i < completeImperativeAggregateFunctions.length) { - completeImperativeAggregateFunctions(i).update(currentBuffer, row) - i += 1 - } - - // For all aggregate functions with mode Final, merge buffer values in row to - // currentBuffer. 
- finalMergeProjection.target(currentBuffer)(input) - i = 0 - while (i < nonCompleteImperativeAggregateFunctions.length) { - nonCompleteImperativeAggregateFunctions(i).merge(currentBuffer, row) - i += 1 - } - } - - // Complete-only - case (None, Some(Complete)) => - val completeAggregateFunctions: Array[AggregateFunction] = - allAggregateFunctions.takeRight(completeAggregateExpressions.length) - // All imperative aggregate functions with mode Complete. - val completeImperativeAggregateFunctions: Array[ImperativeAggregate] = - completeAggregateFunctions.collect { case func: ImperativeAggregate => func } - - val updateExpressions = completeAggregateFunctions.flatMap { - case ae: DeclarativeAggregate => ae.updateExpressions - case agg: AggregateFunction => Seq.fill(agg.aggBufferAttributes.length)(NoOp) - } - val completeExpressionAggUpdateProjection = - newMutableProjection(updateExpressions, aggregationBufferAttributes ++ inputAttributes)() - - (currentBuffer: UnsafeRow, row: InternalRow) => { - // For all aggregate functions with mode Complete, update buffers. - completeExpressionAggUpdateProjection.target(currentBuffer)(joinedRow(currentBuffer, row)) - var i = 0 - while (i < completeImperativeAggregateFunctions.length) { - completeImperativeAggregateFunctions(i).update(currentBuffer, row) - i += 1 - } - } - - // Grouping only. - case (None, None) => (currentBuffer: UnsafeRow, row: InternalRow) => {} - - case other => - throw new IllegalStateException( - s"${aggregationMode} should not be passed into TungstenAggregationIterator.") - } - } - // Creates a function used to generate output rows. - private def generateResultProjection(): (UnsafeRow, UnsafeRow) => UnsafeRow = { - - val groupingAttributes = groupingExpressions.map(_.toAttribute) - val bufferAttributes = allAggregateFunctions.flatMap(_.aggBufferAttributes) - - aggregationMode match { - // Partial-only or PartialMerge-only: every output row is basically the values of - // the grouping expressions and the corresponding aggregation buffer. - case (Some(Partial), None) | (Some(PartialMerge), None) => - val groupingKeySchema = StructType.fromAttributes(groupingAttributes) - val bufferSchema = StructType.fromAttributes(bufferAttributes) - val unsafeRowJoiner = GenerateUnsafeRowJoiner.create(groupingKeySchema, bufferSchema) - - (currentGroupingKey: UnsafeRow, currentBuffer: UnsafeRow) => { - unsafeRowJoiner.join(currentGroupingKey, currentBuffer) - } - - // Final-only, Complete-only and Final-Complete: a output row is generated based on - // resultExpressions. 
- case (Some(Final), None) | (Some(Final) | None, Some(Complete)) => - val joinedRow = new JoinedRow() - val evalExpressions = allAggregateFunctions.map { - case ae: DeclarativeAggregate => ae.evaluateExpression - case agg: AggregateFunction => NoOp - } - val expressionAggEvalProjection = newMutableProjection(evalExpressions, bufferAttributes)() - // These are the attributes of the row produced by `expressionAggEvalProjection` - val aggregateResultSchema = nonCompleteAggregateAttributes ++ completeAggregateAttributes - val aggregateResult = new SpecificMutableRow(aggregateResultSchema.map(_.dataType)) - expressionAggEvalProjection.target(aggregateResult) - val resultProjection = - UnsafeProjection.create(resultExpressions, groupingAttributes ++ aggregateResultSchema) - - val allImperativeAggregateFunctions: Array[ImperativeAggregate] = - allAggregateFunctions.collect { case func: ImperativeAggregate => func} - - (currentGroupingKey: UnsafeRow, currentBuffer: UnsafeRow) => { - // Generate results for all expression-based aggregate functions. - expressionAggEvalProjection(currentBuffer) - // Generate results for all imperative aggregate functions. - var i = 0 - while (i < allImperativeAggregateFunctions.length) { - aggregateResult.update( - allImperativeAggregateFunctionPositions(i), - allImperativeAggregateFunctions(i).eval(currentBuffer)) - i += 1 - } - resultProjection(joinedRow(currentGroupingKey, aggregateResult)) - } - - // Grouping-only: a output row is generated from values of grouping expressions. - case (None, None) => - val resultProjection = UnsafeProjection.create(resultExpressions, groupingAttributes) - - (currentGroupingKey: UnsafeRow, currentBuffer: UnsafeRow) => { - resultProjection(currentGroupingKey) - } - - case other => - throw new IllegalStateException( - s"${aggregationMode} should not be passed into TungstenAggregationIterator.") + override protected def generateResultProjection(): (UnsafeRow, MutableRow) => UnsafeRow = { + val modes = aggregateExpressions.map(_.mode).distinct + if (modes.nonEmpty && !modes.contains(Final) && !modes.contains(Complete)) { + // Fast path for partial aggregation, UnsafeRowJoiner is usually faster than projection + val groupingAttributes = groupingExpressions.map(_.toAttribute) + val bufferAttributes = aggregateFunctions.flatMap(_.aggBufferAttributes) + val groupingKeySchema = StructType.fromAttributes(groupingAttributes) + val bufferSchema = StructType.fromAttributes(bufferAttributes) + val unsafeRowJoiner = GenerateUnsafeRowJoiner.create(groupingKeySchema, bufferSchema) + + (currentGroupingKey: UnsafeRow, currentBuffer: MutableRow) => { + unsafeRowJoiner.join(currentGroupingKey, currentBuffer.asInstanceOf[UnsafeRow]) + } + } else { + super.generateResultProjection() } } - // An UnsafeProjection used to extract grouping keys from the input rows. - private[this] val groupProjection = - UnsafeProjection.create(groupingExpressions, originalInputAttributes) - - // A function used to process a input row. Its first argument is the aggregation buffer - // and the second argument is the input row. - private[this] var processRow: (UnsafeRow, InternalRow) => Unit = - generateProcessRow(originalInputAttributes) - - // A function used to generate output rows based on the grouping keys (first argument) - // and the corresponding aggregation buffer (second argument). - private[this] var generateOutput: (UnsafeRow, UnsafeRow) => UnsafeRow = - generateResultProjection() - // An aggregation buffer containing initial buffer values. 
It is used to // initialize other aggregation buffers. private[this] val initialAggregationBuffer: UnsafeRow = createNewAggregationBuffer() @@ -482,7 +158,7 @@ class TungstenAggregationIterator( // all groups and their corresponding aggregation buffers for hash-based aggregation. private[this] val hashMap = new UnsafeFixedWidthAggregationMap( initialAggregationBuffer, - StructType.fromAttributes(allAggregateFunctions.flatMap(_.aggBufferAttributes)), + StructType.fromAttributes(aggregateFunctions.flatMap(_.aggBufferAttributes)), StructType.fromAttributes(groupingExpressions.map(_.toAttribute)), TaskContext.get().taskMemoryManager(), 1024 * 16, // initial capacity @@ -499,7 +175,7 @@ class TungstenAggregationIterator( if (groupingExpressions.isEmpty) { // If there is no grouping expressions, we can just reuse the same buffer over and over again. // Note that it would be better to eliminate the hash map entirely in the future. - val groupingKey = groupProjection.apply(null) + val groupingKey = groupingProjection.apply(null) val buffer: UnsafeRow = hashMap.getAggregationBufferFromUnsafeRow(groupingKey) while (inputIter.hasNext) { val newInput = inputIter.next() @@ -511,7 +187,7 @@ class TungstenAggregationIterator( while (inputIter.hasNext) { val newInput = inputIter.next() numInputRows += 1 - val groupingKey = groupProjection.apply(newInput) + val groupingKey = groupingProjection.apply(newInput) var buffer: UnsafeRow = null if (i < fallbackStartsAt) { buffer = hashMap.getAggregationBufferFromUnsafeRow(groupingKey) @@ -565,25 +241,18 @@ class TungstenAggregationIterator( private def switchToSortBasedAggregation(): Unit = { logInfo("falling back to sort based aggregation.") - // Set aggregationMode, processRow, and generateOutput for sort-based aggregation. - val newAggregationMode = aggregationMode match { - case (Some(Partial), None) => (Some(PartialMerge), None) - case (None, Some(Complete)) => (Some(Final), None) - case (Some(Final), Some(Complete)) => (Some(Final), None) + // Basically the value of the KVIterator returned by externalSorter + // will be just aggregation buffer, so we rewrite the aggregateExpressions to reflect it. + val newExpressions = aggregateExpressions.map { + case agg @ AggregateExpression(_, Partial, _) => + agg.copy(mode = PartialMerge) + case agg @ AggregateExpression(_, Complete, _) => + agg.copy(mode = Final) case other => other } - aggregationMode = newAggregationMode - - allAggregateFunctions = initializeAllAggregateFunctions(startingInputBufferOffset = 0) - - // Basically the value of the KVIterator returned by externalSorter - // will just aggregation buffer. At here, we use inputAggBufferAttributes. - val newInputAttributes: Seq[Attribute] = - allAggregateFunctions.flatMap(_.inputAggBufferAttributes) - - // Set up new processRow and generateOutput. - processRow = generateProcessRow(newInputAttributes) - generateOutput = generateResultProjection() + val newFunctions = initializeAggregateFunctions(newExpressions, 0) + val newInputAttributes = newFunctions.flatMap(_.inputAggBufferAttributes) + sortBasedProcessRow = generateProcessRow(newExpressions, newFunctions, newInputAttributes) // Step 5: Get the sorted iterator from the externalSorter. sortedKVIterator = externalSorter.sortedIterator() @@ -632,6 +301,9 @@ class TungstenAggregationIterator( // The aggregation buffer used by the sort-based aggregation. 
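The fallback logic above (hash-based processing until fallbackStartsAt, then sort-based merging of buffers that were already partially aggregated) is also the motivation for rewriting Partial to PartialMerge and Complete to Final in switchToSortBasedAggregation. A rough, self-contained imitation with plain collections, using invented names and a simple count, could look like this:

// Rough standalone analogue of "hash-based until it no longer fits, then sort-based":
// aggregate into a bounded map; once the bound is hit, emit tiny pre-aggregated buffers,
// then sort everything by key and merge buffers per key in a second pass.
object HashThenSortFallbackSketch {
  def countByKey(rows: Seq[String], maxHashGroups: Int): List[(String, Long)] = {
    val hash = scala.collection.mutable.Map.empty[String, Long]
    val spilled = scala.collection.mutable.ArrayBuffer.empty[(String, Long)]
    rows.foreach { key =>
      if (hash.contains(key) || hash.size < maxHashGroups) {
        hash(key) = hash.getOrElse(key, 0L) + 1L   // hash-based update
      } else {
        spilled += ((key, 1L))                     // fallback: emit a minimal partial buffer
      }
    }
    // Sort-based merge pass over the spilled buffers plus the hash map contents.
    (hash.toSeq ++ spilled).sortBy(_._1)
      .foldLeft(List.empty[(String, Long)]) {
        case ((k, c) :: rest, (key, cnt)) if k == key => (k, c + cnt) :: rest
        case (acc, (key, cnt))                        => (key, cnt) :: acc
      }
      .reverse
  }

  def main(args: Array[String]): Unit = {
    println(countByKey(Seq("a", "b", "a", "c", "b", "c", "c"), maxHashGroups = 2))
    // List((a,2), (b,2), (c,3))
  }
}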
private[this] val sortBasedAggregationBuffer: UnsafeRow = createNewAggregationBuffer() + // The function used to process rows in a group + private[this] var sortBasedProcessRow: (MutableRow, InternalRow) => Unit = null + // Processes rows in the current group. It will stop when it find a new group. private def processCurrentSortedGroup(): Unit = { // First, we need to copy nextGroupingKey to currentGroupingKey. @@ -640,7 +312,7 @@ class TungstenAggregationIterator( // We create a variable to track if we see the next group. var findNextPartition = false // firstRowInNextGroup is the first row of this group. We first process it. - processRow(sortBasedAggregationBuffer, firstRowInNextGroup) + sortBasedProcessRow(sortBasedAggregationBuffer, firstRowInNextGroup) // The search will stop when we see the next group or there is no // input row left in the iter. @@ -655,16 +327,15 @@ class TungstenAggregationIterator( // Check if the current row belongs the current input row. if (currentGroupingKey.equals(groupingKey)) { - processRow(sortBasedAggregationBuffer, inputAggregationBuffer) + sortBasedProcessRow(sortBasedAggregationBuffer, inputAggregationBuffer) hasNext = sortedKVIterator.next() } else { // We find a new group. findNextPartition = true // copyFrom will fail when - nextGroupingKey.copyFrom(groupingKey) // = groupingKey.copy() - firstRowInNextGroup.copyFrom(inputAggregationBuffer) // = inputAggregationBuffer.copy() - + nextGroupingKey.copyFrom(groupingKey) + firstRowInNextGroup.copyFrom(inputAggregationBuffer) } } // We have not seen a new group. It means that there is no new row in the input diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala index 76b938cdb694..83379ae90f70 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala @@ -42,16 +42,45 @@ object Utils { SortBasedAggregate( requiredChildDistributionExpressions = Some(groupingAttributes), groupingExpressions = groupingAttributes, - nonCompleteAggregateExpressions = Nil, - nonCompleteAggregateAttributes = Nil, - completeAggregateExpressions = completeAggregateExpressions, - completeAggregateAttributes = completeAggregateAttributes, + aggregateExpressions = completeAggregateExpressions, + aggregateAttributes = completeAggregateAttributes, initialInputBufferOffset = 0, resultExpressions = resultExpressions, child = child ) :: Nil } + private def createAggregate( + requiredChildDistributionExpressions: Option[Seq[Expression]] = None, + groupingExpressions: Seq[NamedExpression] = Nil, + aggregateExpressions: Seq[AggregateExpression] = Nil, + aggregateAttributes: Seq[Attribute] = Nil, + initialInputBufferOffset: Int = 0, + resultExpressions: Seq[NamedExpression] = Nil, + child: SparkPlan): SparkPlan = { + val usesTungstenAggregate = TungstenAggregate.supportsAggregate( + aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)) + if (usesTungstenAggregate) { + TungstenAggregate( + requiredChildDistributionExpressions = requiredChildDistributionExpressions, + groupingExpressions = groupingExpressions, + aggregateExpressions = aggregateExpressions, + aggregateAttributes = aggregateAttributes, + initialInputBufferOffset = initialInputBufferOffset, + resultExpressions = resultExpressions, + child = child) + } else { + SortBasedAggregate( + requiredChildDistributionExpressions = 
requiredChildDistributionExpressions, + groupingExpressions = groupingExpressions, + aggregateExpressions = aggregateExpressions, + aggregateAttributes = aggregateAttributes, + initialInputBufferOffset = initialInputBufferOffset, + resultExpressions = resultExpressions, + child = child) + } + } + def planAggregateWithoutDistinct( groupingExpressions: Seq[NamedExpression], aggregateExpressions: Seq[AggregateExpression], @@ -59,9 +88,6 @@ object Utils { resultExpressions: Seq[NamedExpression], child: SparkPlan): Seq[SparkPlan] = { // Check if we can use TungstenAggregate. - val usesTungstenAggregate = TungstenAggregate.supportsAggregate( - groupingExpressions, - aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)) // 1. Create an Aggregate Operator for partial aggregations. @@ -73,29 +99,14 @@ object Utils { groupingAttributes ++ partialAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes) - val partialAggregate = if (usesTungstenAggregate) { - TungstenAggregate( - requiredChildDistributionExpressions = None: Option[Seq[Expression]], + val partialAggregate = createAggregate( + requiredChildDistributionExpressions = None, groupingExpressions = groupingExpressions, - nonCompleteAggregateExpressions = partialAggregateExpressions, - nonCompleteAggregateAttributes = partialAggregateAttributes, - completeAggregateExpressions = Nil, - completeAggregateAttributes = Nil, + aggregateExpressions = partialAggregateExpressions, + aggregateAttributes = partialAggregateAttributes, initialInputBufferOffset = 0, resultExpressions = partialResultExpressions, child = child) - } else { - SortBasedAggregate( - requiredChildDistributionExpressions = None: Option[Seq[Expression]], - groupingExpressions = groupingExpressions, - nonCompleteAggregateExpressions = partialAggregateExpressions, - nonCompleteAggregateAttributes = partialAggregateAttributes, - completeAggregateExpressions = Nil, - completeAggregateAttributes = Nil, - initialInputBufferOffset = 0, - resultExpressions = partialResultExpressions, - child = child) - } // 2. Create an Aggregate Operator for final aggregations. 
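A quick aside before the final-aggregation operator below: the Partial followed by Final split that createAggregate plans can be imitated with ordinary collections. The sketch is purely illustrative; the sum-by-key example and all names in it are invented, and it glosses over shuffles and buffer schemas:

// Toy two-phase aggregation: a "partial" pass produces per-partition buffers,
// and a "final" pass merges those buffers per key, mirroring the Partial -> Final plan.
object TwoPhaseAggregationSketch {
  def partialSums(partition: Seq[(String, Int)]): Map[String, Long] =
    partition.groupBy(_._1).map { case (k, vs) => k -> vs.map(_._2.toLong).sum }

  def finalSums(partials: Seq[Map[String, Long]]): Map[String, Long] =
    partials.flatten.groupBy(_._1).map { case (k, vs) => k -> vs.map(_._2).sum }

  def main(args: Array[String]): Unit = {
    val partition1 = Seq(("a", 1), ("b", 2), ("a", 3))
    val partition2 = Seq(("b", 4), ("c", 5))
    val merged = finalSums(Seq(partialSums(partition1), partialSums(partition2)))
    println(merged)   // e.g. Map(a -> 4, b -> 6, c -> 5)
  }
}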
val finalAggregateExpressions = aggregateExpressions.map(_.copy(mode = Final)) @@ -105,29 +116,14 @@ object Utils { expr => aggregateFunctionToAttribute(expr.aggregateFunction, expr.isDistinct) } - val finalAggregate = if (usesTungstenAggregate) { - TungstenAggregate( - requiredChildDistributionExpressions = Some(groupingAttributes), - groupingExpressions = groupingAttributes, - nonCompleteAggregateExpressions = finalAggregateExpressions, - nonCompleteAggregateAttributes = finalAggregateAttributes, - completeAggregateExpressions = Nil, - completeAggregateAttributes = Nil, - initialInputBufferOffset = groupingExpressions.length, - resultExpressions = resultExpressions, - child = partialAggregate) - } else { - SortBasedAggregate( + val finalAggregate = createAggregate( requiredChildDistributionExpressions = Some(groupingAttributes), groupingExpressions = groupingAttributes, - nonCompleteAggregateExpressions = finalAggregateExpressions, - nonCompleteAggregateAttributes = finalAggregateAttributes, - completeAggregateExpressions = Nil, - completeAggregateAttributes = Nil, + aggregateExpressions = finalAggregateExpressions, + aggregateAttributes = finalAggregateAttributes, initialInputBufferOffset = groupingExpressions.length, resultExpressions = resultExpressions, child = partialAggregate) - } finalAggregate :: Nil } @@ -140,99 +136,99 @@ object Utils { resultExpressions: Seq[NamedExpression], child: SparkPlan): Seq[SparkPlan] = { - val aggregateExpressions = functionsWithDistinct ++ functionsWithoutDistinct - val usesTungstenAggregate = TungstenAggregate.supportsAggregate( - groupingExpressions, - aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)) - // functionsWithDistinct is guaranteed to be non-empty. Even though it may contain more than one // DISTINCT aggregate function, all of those functions will have the same column expressions. // For example, it would be valid for functionsWithDistinct to be // [COUNT(DISTINCT foo), MAX(DISTINCT foo)], but [COUNT(DISTINCT bar), COUNT(DISTINCT foo)] is // disallowed because those two distinct aggregates have different column expressions. - val distinctColumnExpressions = functionsWithDistinct.head.aggregateFunction.children - val namedDistinctColumnExpressions = distinctColumnExpressions.map { + val distinctExpressions = functionsWithDistinct.head.aggregateFunction.children + val namedDistinctExpressions = distinctExpressions.map { case ne: NamedExpression => ne case other => Alias(other, other.toString)() } - val distinctColumnAttributes = namedDistinctColumnExpressions.map(_.toAttribute) + val distinctAttributes = namedDistinctExpressions.map(_.toAttribute) val groupingAttributes = groupingExpressions.map(_.toAttribute) // 1. Create an Aggregate Operator for partial aggregations. val partialAggregate: SparkPlan = { - val partialAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Partial)) - val partialAggregateAttributes = - partialAggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes) + val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Partial)) + val aggregateAttributes = aggregateExpressions.map { + expr => aggregateFunctionToAttribute(expr.aggregateFunction, expr.isDistinct) + } // We will group by the original grouping expression, plus an additional expression for the // DISTINCT column. For example, for AVG(DISTINCT value) GROUP BY key, the grouping // expressions will be [key, value]. 
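The comment above describes the standard rewrite for a single distinct aggregate: group by the original keys plus the distinct column first, so that later phases see deduplicated values. A collections-only illustration of the same idea (not the planner code itself; names and the count example are invented) is:

// COUNT(DISTINCT value) GROUP BY key, done in two steps: first deduplicate on
// (key, value), which is what grouping by [key, value] achieves, then aggregate by key.
object DistinctRewriteSketch {
  def countDistinct(rows: Seq[(String, Int)]): Map[String, Int] = {
    val dedupedByKeyAndValue: Seq[(String, Int)] = rows.distinct   // stands in for "group by [key, value]"
    dedupedByKeyAndValue.groupBy(_._1).map { case (k, vs) => k -> vs.size }
  }

  def main(args: Array[String]): Unit = {
    val rows = Seq(("a", 1), ("a", 1), ("a", 2), ("b", 7), ("b", 7))
    println(countDistinct(rows))   // e.g. Map(a -> 2, b -> 1)
  }
}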
- val partialAggregateGroupingExpressions = - groupingExpressions ++ namedDistinctColumnExpressions - val partialAggregateResult = - groupingAttributes ++ - distinctColumnAttributes ++ - partialAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes) - if (usesTungstenAggregate) { - TungstenAggregate( - requiredChildDistributionExpressions = None, - groupingExpressions = partialAggregateGroupingExpressions, - nonCompleteAggregateExpressions = partialAggregateExpressions, - nonCompleteAggregateAttributes = partialAggregateAttributes, - completeAggregateExpressions = Nil, - completeAggregateAttributes = Nil, - initialInputBufferOffset = 0, - resultExpressions = partialAggregateResult, - child = child) - } else { - SortBasedAggregate( - requiredChildDistributionExpressions = None, - groupingExpressions = partialAggregateGroupingExpressions, - nonCompleteAggregateExpressions = partialAggregateExpressions, - nonCompleteAggregateAttributes = partialAggregateAttributes, - completeAggregateExpressions = Nil, - completeAggregateAttributes = Nil, - initialInputBufferOffset = 0, - resultExpressions = partialAggregateResult, - child = child) - } + createAggregate( + groupingExpressions = groupingExpressions ++ namedDistinctExpressions, + aggregateExpressions = aggregateExpressions, + aggregateAttributes = aggregateAttributes, + resultExpressions = groupingAttributes ++ distinctAttributes ++ + aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes), + child = child) } // 2. Create an Aggregate Operator for partial merge aggregations. val partialMergeAggregate: SparkPlan = { - val partialMergeAggregateExpressions = - functionsWithoutDistinct.map(_.copy(mode = PartialMerge)) - val partialMergeAggregateAttributes = - partialMergeAggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes) - val partialMergeAggregateResult = - groupingAttributes ++ - distinctColumnAttributes ++ - partialMergeAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes) - if (usesTungstenAggregate) { - TungstenAggregate( - requiredChildDistributionExpressions = Some(groupingAttributes), - groupingExpressions = groupingAttributes ++ distinctColumnAttributes, - nonCompleteAggregateExpressions = partialMergeAggregateExpressions, - nonCompleteAggregateAttributes = partialMergeAggregateAttributes, - completeAggregateExpressions = Nil, - completeAggregateAttributes = Nil, - initialInputBufferOffset = (groupingAttributes ++ distinctColumnAttributes).length, - resultExpressions = partialMergeAggregateResult, - child = partialAggregate) - } else { - SortBasedAggregate( - requiredChildDistributionExpressions = Some(groupingAttributes), - groupingExpressions = groupingAttributes ++ distinctColumnAttributes, - nonCompleteAggregateExpressions = partialMergeAggregateExpressions, - nonCompleteAggregateAttributes = partialMergeAggregateAttributes, - completeAggregateExpressions = Nil, - completeAggregateAttributes = Nil, - initialInputBufferOffset = (groupingAttributes ++ distinctColumnAttributes).length, - resultExpressions = partialMergeAggregateResult, - child = partialAggregate) + val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge)) + val aggregateAttributes = aggregateExpressions.map { + expr => aggregateFunctionToAttribute(expr.aggregateFunction, expr.isDistinct) } + createAggregate( + requiredChildDistributionExpressions = + Some(groupingAttributes ++ distinctAttributes), + groupingExpressions = groupingAttributes ++ distinctAttributes, + 
aggregateExpressions = aggregateExpressions, + aggregateAttributes = aggregateAttributes, + initialInputBufferOffset = (groupingAttributes ++ distinctAttributes).length, + resultExpressions = groupingAttributes ++ distinctAttributes ++ + aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes), + child = partialAggregate) + } + + // 3. Create an Aggregate operator for partial aggregation (for distinct) + val distinctColumnAttributeLookup = distinctExpressions.zip(distinctAttributes).toMap + val rewrittenDistinctFunctions = functionsWithDistinct.map { + // Children of an AggregateFunction with DISTINCT keyword has already + // been evaluated. At here, we need to replace original children + // to AttributeReferences. + case agg @ AggregateExpression(aggregateFunction, mode, true) => + aggregateFunction.transformDown(distinctColumnAttributeLookup) + .asInstanceOf[AggregateFunction] } - // 3. Create an Aggregate Operator for the final aggregation. + val partialDistinctAggregate: SparkPlan = { + val mergeAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge)) + // The attributes of the final aggregation buffer, which is presented as input to the result + // projection: + val mergeAggregateAttributes = mergeAggregateExpressions.map { + expr => aggregateFunctionToAttribute(expr.aggregateFunction, expr.isDistinct) + } + val (distinctAggregateExpressions, distinctAggregateAttributes) = + rewrittenDistinctFunctions.zipWithIndex.map { case (func, i) => + // We rewrite the aggregate function to a non-distinct aggregation because + // its input will have distinct arguments. + // We just keep the isDistinct setting to true, so when users look at the query plan, + // they still can see distinct aggregations. + val expr = AggregateExpression(func, Partial, isDistinct = true) + // Use original AggregationFunction to lookup attributes, which is used to build + // aggregateFunctionToAttribute + val attr = aggregateFunctionToAttribute(functionsWithDistinct(i).aggregateFunction, true) + (expr, attr) + }.unzip + + val partialAggregateResult = groupingAttributes ++ + mergeAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes) ++ + distinctAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes) + createAggregate( + groupingExpressions = groupingAttributes, + aggregateExpressions = mergeAggregateExpressions ++ distinctAggregateExpressions, + aggregateAttributes = mergeAggregateAttributes ++ distinctAggregateAttributes, + initialInputBufferOffset = (groupingAttributes ++ distinctAttributes).length, + resultExpressions = partialAggregateResult, + child = partialMergeAggregate) + } + + // 4. Create an Aggregate Operator for the final aggregation. val finalAndCompleteAggregate: SparkPlan = { val finalAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Final)) // The attributes of the final aggregation buffer, which is presented as input to the result @@ -241,49 +237,27 @@ object Utils { expr => aggregateFunctionToAttribute(expr.aggregateFunction, expr.isDistinct) } - val distinctColumnAttributeLookup = - distinctColumnExpressions.zip(distinctColumnAttributes).toMap - val (completeAggregateExpressions, completeAggregateAttributes) = functionsWithDistinct.map { - // Children of an AggregateFunction with DISTINCT keyword has already - // been evaluated. At here, we need to replace original children - // to AttributeReferences. 
- case agg @ AggregateExpression(aggregateFunction, mode, true) => - val rewrittenAggregateFunction = aggregateFunction - .transformDown(distinctColumnAttributeLookup) - .asInstanceOf[AggregateFunction] + val (distinctAggregateExpressions, distinctAggregateAttributes) = + rewrittenDistinctFunctions.zipWithIndex.map { case (func, i) => // We rewrite the aggregate function to a non-distinct aggregation because // its input will have distinct arguments. // We just keep the isDistinct setting to true, so when users look at the query plan, // they still can see distinct aggregations. - val rewrittenAggregateExpression = - AggregateExpression(rewrittenAggregateFunction, Complete, isDistinct = true) - - val aggregateFunctionAttribute = aggregateFunctionToAttribute(agg.aggregateFunction, true) - (rewrittenAggregateExpression, aggregateFunctionAttribute) + val expr = AggregateExpression(func, Final, isDistinct = true) + // Use original AggregationFunction to lookup attributes, which is used to build + // aggregateFunctionToAttribute + val attr = aggregateFunctionToAttribute(functionsWithDistinct(i).aggregateFunction, true) + (expr, attr) }.unzip - if (usesTungstenAggregate) { - TungstenAggregate( - requiredChildDistributionExpressions = Some(groupingAttributes), - groupingExpressions = groupingAttributes, - nonCompleteAggregateExpressions = finalAggregateExpressions, - nonCompleteAggregateAttributes = finalAggregateAttributes, - completeAggregateExpressions = completeAggregateExpressions, - completeAggregateAttributes = completeAggregateAttributes, - initialInputBufferOffset = (groupingAttributes ++ distinctColumnAttributes).length, - resultExpressions = resultExpressions, - child = partialMergeAggregate) - } else { - SortBasedAggregate( - requiredChildDistributionExpressions = Some(groupingAttributes), - groupingExpressions = groupingAttributes, - nonCompleteAggregateExpressions = finalAggregateExpressions, - nonCompleteAggregateAttributes = finalAggregateAttributes, - completeAggregateExpressions = completeAggregateExpressions, - completeAggregateAttributes = completeAggregateAttributes, - initialInputBufferOffset = (groupingAttributes ++ distinctColumnAttributes).length, - resultExpressions = resultExpressions, - child = partialMergeAggregate) - } + + createAggregate( + requiredChildDistributionExpressions = Some(groupingAttributes), + groupingExpressions = groupingAttributes, + aggregateExpressions = finalAggregateExpressions ++ distinctAggregateExpressions, + aggregateAttributes = finalAggregateAttributes ++ distinctAggregateAttributes, + initialInputBufferOffset = groupingAttributes.length, + resultExpressions = resultExpressions, + child = partialDistinctAggregate) } finalAndCompleteAggregate :: Nil diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala index 064c0004b801..5550198c02fb 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.hive.execution import scala.collection.JavaConverters._ -import org.apache.spark.SparkException import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction} @@ -552,80 +551,73 @@ abstract class AggregationQuerySuite 
extends QueryTest with SQLTestUtils with Te } test("single distinct column set") { - Seq(true, false).foreach { specializeSingleDistinctAgg => - val conf = - (SQLConf.SPECIALIZE_SINGLE_DISTINCT_AGG_PLANNING.key, - specializeSingleDistinctAgg.toString) - withSQLConf(conf) { - // DISTINCT is not meaningful with Max and Min, so we just ignore the DISTINCT keyword. - checkAnswer( - sqlContext.sql( - """ - |SELECT - | min(distinct value1), - | sum(distinct value1), - | avg(value1), - | avg(value2), - | max(distinct value1) - |FROM agg2 - """.stripMargin), - Row(-60, 70.0, 101.0/9.0, 5.6, 100)) - - checkAnswer( - sqlContext.sql( - """ - |SELECT - | mydoubleavg(distinct value1), - | avg(value1), - | avg(value2), - | key, - | mydoubleavg(value1 - 1), - | mydoubleavg(distinct value1) * 0.1, - | avg(value1 + value2) - |FROM agg2 - |GROUP BY key - """.stripMargin), - Row(120.0, 70.0/3.0, -10.0/3.0, 1, 67.0/3.0 + 100.0, 12.0, 20.0) :: - Row(100.0, 1.0/3.0, 1.0, 2, -2.0/3.0 + 100.0, 10.0, 2.0) :: - Row(null, null, 3.0, 3, null, null, null) :: - Row(110.0, 10.0, 20.0, null, 109.0, 11.0, 30.0) :: Nil) - - checkAnswer( - sqlContext.sql( - """ - |SELECT - | key, - | mydoubleavg(distinct value1), - | mydoublesum(value2), - | mydoublesum(distinct value1), - | mydoubleavg(distinct value1), - | mydoubleavg(value1) - |FROM agg2 - |GROUP BY key - """.stripMargin), - Row(1, 120.0, -10.0, 40.0, 120.0, 70.0/3.0 + 100.0) :: - Row(2, 100.0, 3.0, 0.0, 100.0, 1.0/3.0 + 100.0) :: - Row(3, null, 3.0, null, null, null) :: - Row(null, 110.0, 60.0, 30.0, 110.0, 110.0) :: Nil) - - checkAnswer( - sqlContext.sql( - """ - |SELECT - | count(value1), - | count(*), - | count(1), - | count(DISTINCT value1), - | key - |FROM agg2 - |GROUP BY key - """.stripMargin), - Row(3, 3, 3, 2, 1) :: - Row(3, 4, 4, 2, 2) :: - Row(0, 2, 2, 0, 3) :: - Row(3, 4, 4, 3, null) :: Nil) - } - } + // DISTINCT is not meaningful with Max and Min, so we just ignore the DISTINCT keyword. 
+ checkAnswer( + sqlContext.sql( + """ + |SELECT + | min(distinct value1), + | sum(distinct value1), + | avg(value1), + | avg(value2), + | max(distinct value1) + |FROM agg2 + """.stripMargin), + Row(-60, 70.0, 101.0/9.0, 5.6, 100)) + + checkAnswer( + sqlContext.sql( + """ + |SELECT + | mydoubleavg(distinct value1), + | avg(value1), + | avg(value2), + | key, + | mydoubleavg(value1 - 1), + | mydoubleavg(distinct value1) * 0.1, + | avg(value1 + value2) + |FROM agg2 + |GROUP BY key + """.stripMargin), + Row(120.0, 70.0/3.0, -10.0/3.0, 1, 67.0/3.0 + 100.0, 12.0, 20.0) :: + Row(100.0, 1.0/3.0, 1.0, 2, -2.0/3.0 + 100.0, 10.0, 2.0) :: + Row(null, null, 3.0, 3, null, null, null) :: + Row(110.0, 10.0, 20.0, null, 109.0, 11.0, 30.0) :: Nil) + + checkAnswer( + sqlContext.sql( + """ + |SELECT + | key, + | mydoubleavg(distinct value1), + | mydoublesum(value2), + | mydoublesum(distinct value1), + | mydoubleavg(distinct value1), + | mydoubleavg(value1) + |FROM agg2 + |GROUP BY key + """.stripMargin), + Row(1, 120.0, -10.0, 40.0, 120.0, 70.0/3.0 + 100.0) :: + Row(2, 100.0, 3.0, 0.0, 100.0, 1.0/3.0 + 100.0) :: + Row(3, null, 3.0, null, null, null) :: + Row(null, 110.0, 60.0, 30.0, 110.0, 110.0) :: Nil) + + checkAnswer( + sqlContext.sql( + """ + |SELECT + | count(value1), + | count(*), + | count(1), + | count(DISTINCT value1), + | key + |FROM agg2 + |GROUP BY key + """.stripMargin), + Row(3, 3, 3, 2, 1) :: + Row(3, 4, 4, 2, 2) :: + Row(0, 2, 2, 0, 3) :: + Row(3, 4, 4, 3, null) :: Nil) } test("single distinct multiple columns set") { From ed87f6d3b48a85391628c29c43d318c26e2c6de7 Mon Sep 17 00:00:00 2001 From: yucai Date: Sun, 13 Dec 2015 23:08:21 -0800 Subject: [PATCH 1757/1824] [SPARK-12275][SQL] No plan for BroadcastHint in some condition When SparkStrategies.BasicOperators's "case BroadcastHint(child) => apply(child)" is hit, it only recursively invokes BasicOperators.apply with this "child". It makes many strategies have no change to process this plan, which probably leads to "No plan" issue, so we use planLater to go through all strategies. https://issues.apache.org/jira/browse/SPARK-12275 Author: yucai Closes #10265 from yucai/broadcast_hint. 
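A minimal sketch may make the one-line fix below easier to follow. None of these types are Spark's; they are illustrative stand-ins that model why recursing into a single strategy can leave a plan unplanned, while a planLater-style placeholder lets every strategy see the child:

    // Toy model of the planner behaviour behind SPARK-12275; all names here are
    // illustrative stand-ins, not Spark's actual classes.
    object PlanLaterSketch {
      sealed trait LogicalOp
      case class Hint(child: LogicalOp) extends LogicalOp
      case class Scan(table: String) extends LogicalOp

      sealed trait PhysicalOp
      case class PhysicalScan(table: String) extends PhysicalOp
      // Stand-in for planLater(): a placeholder the planner later fills in by
      // running *all* strategies over `child`.
      case class PlanLater(child: LogicalOp) extends PhysicalOp

      type Strategy = LogicalOp => Seq[PhysicalOp]

      // Only this strategy knows how to plan a Scan.
      val dataSource: Strategy = {
        case Scan(t) => PhysicalScan(t) :: Nil
        case _       => Nil
      }

      // Buggy shape: the hint case recurses into this strategy only, so a Scan
      // wrapped in a Hint never reaches `dataSource` and planning yields Nil.
      lazy val buggyBasicOperators: Strategy = {
        case Hint(child) => buggyBasicOperators(child)
        case _           => Nil
      }

      // Fixed shape: emit a placeholder and let the planner try all strategies.
      val fixedBasicOperators: Strategy = {
        case Hint(child) => PlanLater(child) :: Nil
        case _           => Nil
      }

      def main(args: Array[String]): Unit = {
        val plan = Hint(Scan("t"))
        println(buggyBasicOperators(plan)) // List(), i.e. the "No plan" failure
        println(fixedBasicOperators(plan)) // List(PlanLater(Scan(t)))
      }
    }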
--- .../org/apache/spark/sql/execution/SparkStrategies.scala | 2 +- .../scala/org/apache/spark/sql/DataFrameJoinSuite.scala | 7 +++++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 25e98c0bdd43..688555cf136e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -364,7 +364,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case e @ EvaluatePython(udf, child, _) => BatchPythonEvaluation(udf, e.output, planLater(child)) :: Nil case LogicalRDD(output, rdd) => PhysicalRDD(output, rdd, "ExistingRDD") :: Nil - case BroadcastHint(child) => apply(child) + case BroadcastHint(child) => planLater(child) :: Nil case _ => Nil } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala index 56ad71ea4f48..c70397f9853a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala @@ -120,5 +120,12 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext { // planner should not crash without a join broadcast(df1).queryExecution.executedPlan + + // SPARK-12275: no physical plan for BroadcastHint in some condition + withTempPath { path => + df1.write.parquet(path.getCanonicalPath) + val pf1 = sqlContext.read.parquet(path.getCanonicalPath) + assert(df1.join(broadcast(pf1)).count() === 4) + } } } From e25f1fe42747be71c6b6e6357ca214f9544e3a46 Mon Sep 17 00:00:00 2001 From: BenFradet Date: Mon, 14 Dec 2015 13:50:30 +0000 Subject: [PATCH 1758/1824] [MINOR][DOC] Fix broken word2vec link Follow-up of [SPARK-12199](https://issues.apache.org/jira/browse/SPARK-12199) and #10193 where a broken link has been left as is. Author: BenFradet Closes #10282 from BenFradet/SPARK-12199. --- docs/ml-features.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/ml-features.md b/docs/ml-features.md index 158f3f201899..677e4bfb916e 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -63,7 +63,7 @@ the [IDF Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.IDF) for mor `Word2VecModel`. The model maps each word to a unique fixed-size vector. The `Word2VecModel` transforms each document into a vector using the average of all words in the document; this vector can then be used for as features for prediction, document similarity calculations, etc. -Please refer to the [MLlib user guide on Word2Vec](mllib-feature-extraction.html#word2Vec) for more +Please refer to the [MLlib user guide on Word2Vec](mllib-feature-extraction.html#word2vec) for more details. In the following code segment, we start with a set of documents, each of which is represented as a sequence of words. For each document, we transform it into a feature vector. This feature vector could then be passed to a learning algorithm. From b51a4cdff3a7e640a8a66f7a9c17021f3056fd34 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 14 Dec 2015 09:59:42 -0800 Subject: [PATCH 1759/1824] [SPARK-12016] [MLLIB] [PYSPARK] Wrap Word2VecModel when loading it in pyspark JIRA: https://issues.apache.org/jira/browse/SPARK-12016 We should not directly use Word2VecModel in pyspark. 
We need to wrap it in a Word2VecModelWrapper when loading it in pyspark. Author: Liang-Chi Hsieh Closes #10100 from viirya/fix-load-py-wordvecmodel. --- .../mllib/api/python/PythonMLLibAPI.scala | 33 ---------- .../api/python/Word2VecModelWrapper.scala | 62 +++++++++++++++++++ python/pyspark/mllib/feature.py | 6 +- 3 files changed, 67 insertions(+), 34 deletions(-) create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/api/python/Word2VecModelWrapper.scala diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index 8d546e3d6099..29160a10e16b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -680,39 +680,6 @@ private[python] class PythonMLLibAPI extends Serializable { } } - private[python] class Word2VecModelWrapper(model: Word2VecModel) { - def transform(word: String): Vector = { - model.transform(word) - } - - /** - * Transforms an RDD of words to its vector representation - * @param rdd an RDD of words - * @return an RDD of vector representations of words - */ - def transform(rdd: JavaRDD[String]): JavaRDD[Vector] = { - rdd.rdd.map(model.transform) - } - - def findSynonyms(word: String, num: Int): JList[Object] = { - val vec = transform(word) - findSynonyms(vec, num) - } - - def findSynonyms(vector: Vector, num: Int): JList[Object] = { - val result = model.findSynonyms(vector, num) - val similarity = Vectors.dense(result.map(_._2)) - val words = result.map(_._1) - List(words, similarity).map(_.asInstanceOf[Object]).asJava - } - - def getVectors: JMap[String, JList[Float]] = { - model.getVectors.map({case (k, v) => (k, v.toList.asJava)}).asJava - } - - def save(sc: SparkContext, path: String): Unit = model.save(sc, path) - } - /** * Java stub for Python mllib DecisionTree.train(). * This stub returns a handle to the Java object instead of the content of the Java object. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/Word2VecModelWrapper.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/Word2VecModelWrapper.scala new file mode 100644 index 000000000000..0f55980481dc --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/Word2VecModelWrapper.scala @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.mllib.api.python + +import java.util.{ArrayList => JArrayList, List => JList, Map => JMap} +import scala.collection.JavaConverters._ + +import org.apache.spark.SparkContext +import org.apache.spark.api.java.JavaRDD +import org.apache.spark.mllib.feature.Word2VecModel +import org.apache.spark.mllib.linalg.{Vector, Vectors} + +/** + * Wrapper around Word2VecModel to provide helper methods in Python + */ +private[python] class Word2VecModelWrapper(model: Word2VecModel) { + def transform(word: String): Vector = { + model.transform(word) + } + + /** + * Transforms an RDD of words to its vector representation + * @param rdd an RDD of words + * @return an RDD of vector representations of words + */ + def transform(rdd: JavaRDD[String]): JavaRDD[Vector] = { + rdd.rdd.map(model.transform) + } + + def findSynonyms(word: String, num: Int): JList[Object] = { + val vec = transform(word) + findSynonyms(vec, num) + } + + def findSynonyms(vector: Vector, num: Int): JList[Object] = { + val result = model.findSynonyms(vector, num) + val similarity = Vectors.dense(result.map(_._2)) + val words = result.map(_._1) + List(words, similarity).map(_.asInstanceOf[Object]).asJava + } + + def getVectors: JMap[String, JList[Float]] = { + model.getVectors.map({case (k, v) => (k, v.toList.asJava)}).asJava + } + + def save(sc: SparkContext, path: String): Unit = model.save(sc, path) +} diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py index 7b077b058c3f..7254679ebb53 100644 --- a/python/pyspark/mllib/feature.py +++ b/python/pyspark/mllib/feature.py @@ -504,7 +504,8 @@ def load(cls, sc, path): """ jmodel = sc._jvm.org.apache.spark.mllib.feature \ .Word2VecModel.load(sc._jsc.sc(), path) - return Word2VecModel(jmodel) + model = sc._jvm.Word2VecModelWrapper(jmodel) + return Word2VecModel(model) @ignore_unicode_prefix @@ -546,6 +547,9 @@ class Word2Vec(object): >>> sameModel = Word2VecModel.load(sc, path) >>> model.transform("a") == sameModel.transform("a") True + >>> syms = sameModel.findSynonyms("a", 2) + >>> [s[0] for s in syms] + [u'b', u'c'] >>> from shutil import rmtree >>> try: ... rmtree(path) From fb3778de685881df66bf0222b520f94dca99e8c8 Mon Sep 17 00:00:00 2001 From: Shivaram Venkataraman Date: Mon, 14 Dec 2015 16:13:55 -0800 Subject: [PATCH 1760/1824] [SPARK-12327] Disable commented code lintr temporarily cc yhuai felixcheung shaneknapp Author: Shivaram Venkataraman Closes #10300 from shivaram/comment-lintr-disable. --- R/pkg/.lintr | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/pkg/.lintr b/R/pkg/.lintr index 038236fc149e..39c872663ad4 100644 --- a/R/pkg/.lintr +++ b/R/pkg/.lintr @@ -1,2 +1,2 @@ -linters: with_defaults(line_length_linter(100), camel_case_linter = NULL, open_curly_linter(allow_single_line = TRUE), closed_curly_linter(allow_single_line = TRUE)) +linters: with_defaults(line_length_linter(100), camel_case_linter = NULL, open_curly_linter(allow_single_line = TRUE), closed_curly_linter(allow_single_line = TRUE), commented_code_linter = NULL) exclusions: list("inst/profile/general.R" = 1, "inst/profile/shell.R") From 9ea1a8efca7869618c192546ef4e6d94c17689b5 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 14 Dec 2015 16:48:11 -0800 Subject: [PATCH 1761/1824] [SPARK-12274][SQL] WrapOption should not have type constraint for child I think it was a mistake, and we have not catched it so far until https://github.com/apache/spark/pull/10260 which begin to check if the `fromRowExpression` is resolved. 
Author: Wenchen Fan Closes #10263 from cloud-fan/encoder. --- .../org/apache/spark/sql/catalyst/expressions/objects.scala | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala index b2facfda2444..96bc4fe67a98 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala @@ -296,15 +296,12 @@ case class UnwrapOption( * (in the case of reference types) equality with null. * @param child The expression to evaluate and wrap. */ -case class WrapOption(child: Expression) - extends UnaryExpression with ExpectsInputTypes { +case class WrapOption(child: Expression) extends UnaryExpression { override def dataType: DataType = ObjectType(classOf[Option[_]]) override def nullable: Boolean = true - override def inputTypes: Seq[AbstractDataType] = ObjectType :: Nil - override def eval(input: InternalRow): Any = throw new UnsupportedOperationException("Only code-generated evaluation is supported") From d13ff82cba10a1ce9889c4b416f1e43c717e3f10 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Mon, 14 Dec 2015 18:33:45 -0800 Subject: [PATCH 1762/1824] [SPARK-12188][SQL][FOLLOW-UP] Code refactoring and comment correction in Dataset APIs marmbrus This PR is to address your comment. Thanks for your review! Author: gatorsmile Closes #10214 from gatorsmile/followup12188. --- sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 3bd18a14f9e8..dc69822e9290 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -79,7 +79,7 @@ class Dataset[T] private[sql]( /** * The encoder where the expressions used to construct an object from an input row have been - * bound to the ordinals of the given schema. + * bound to the ordinals of this [[Dataset]]'s output schema. */ private[sql] val boundTEncoder = resolvedTEncoder.bind(logicalPlan.output) From 606f99b942a25dc3d86138b00026462ffe58cded Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Mon, 14 Dec 2015 19:42:16 -0800 Subject: [PATCH 1763/1824] [SPARK-12288] [SQL] Support UnsafeRow in Coalesce/Except/Intersect. Support UnsafeRow for the Coalesce/Except/Intersect. Could you review if my code changes are ok? davies Thank you! Author: gatorsmile Closes #10285 from gatorsmile/unsafeSupportCIE. 
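Before the overrides below, a small hedged sketch of the contract they participate in. The types and logic here are toy stand-ins, not the actual Spark rule: a physical operator advertises whether it emits UnsafeRows and which formats it can consume, and a preparation step reconciles children whose formats disagree, which is what the new RowFormatConvertersSuite cases check:

    // Toy paraphrase of the row-format reconciliation exercised below; names and
    // behaviour are illustrative only.
    object RowFormatSketch {
      final case class Node(
          name: String,
          outputsUnsafeRows: Boolean,
          canProcessUnsafeRows: Boolean,
          canProcessSafeRows: Boolean,
          children: Seq[Node] = Nil)

      // If an operator accepts both formats but its children disagree, upgrade the
      // safe side so every input arrives as UnsafeRows.
      def reconcile(parent: Node): Node =
        if (parent.canProcessUnsafeRows && parent.canProcessSafeRows &&
            parent.children.map(_.outputsUnsafeRows).toSet.size > 1) {
          val converted = parent.children.map { c =>
            if (c.outputsUnsafeRows) c
            else c.copy(name = s"ConvertToUnsafe(${c.name})", outputsUnsafeRows = true)
          }
          parent.copy(children = converted, outputsUnsafeRows = true)
        } else {
          parent
        }

      def main(args: Array[String]): Unit = {
        val safeChild   = Node("safeScan", outputsUnsafeRows = false, canProcessUnsafeRows = false, canProcessSafeRows = true)
        val unsafeChild = Node("unsafeScan", outputsUnsafeRows = true, canProcessUnsafeRows = true, canProcessSafeRows = false)
        val except = Node("Except", outputsUnsafeRows = true,
          canProcessUnsafeRows = true, canProcessSafeRows = true, children = Seq(safeChild, unsafeChild))
        // Mirrors "except requires all of its input rows' formats to agree".
        println(reconcile(except).children.map(_.name)) // List(ConvertToUnsafe(safeScan), unsafeScan)
      }
    }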
--- .../spark/sql/execution/basicOperators.scala | 12 ++++++- .../execution/RowFormatConvertersSuite.scala | 35 +++++++++++++++++++ 2 files changed, 46 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index a42aea0b96d4..b3e4688557ba 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -137,7 +137,7 @@ case class Union(children: Seq[SparkPlan]) extends SparkPlan { } } } - override def outputsUnsafeRows: Boolean = children.forall(_.outputsUnsafeRows) + override def outputsUnsafeRows: Boolean = children.exists(_.outputsUnsafeRows) override def canProcessUnsafeRows: Boolean = true override def canProcessSafeRows: Boolean = true protected override def doExecute(): RDD[InternalRow] = @@ -250,7 +250,9 @@ case class Coalesce(numPartitions: Int, child: SparkPlan) extends UnaryNode { child.execute().coalesce(numPartitions, shuffle = false) } + override def outputsUnsafeRows: Boolean = child.outputsUnsafeRows override def canProcessUnsafeRows: Boolean = true + override def canProcessSafeRows: Boolean = true } /** @@ -263,6 +265,10 @@ case class Except(left: SparkPlan, right: SparkPlan) extends BinaryNode { protected override def doExecute(): RDD[InternalRow] = { left.execute().map(_.copy()).subtract(right.execute().map(_.copy())) } + + override def outputsUnsafeRows: Boolean = children.exists(_.outputsUnsafeRows) + override def canProcessUnsafeRows: Boolean = true + override def canProcessSafeRows: Boolean = true } /** @@ -275,6 +281,10 @@ case class Intersect(left: SparkPlan, right: SparkPlan) extends BinaryNode { protected override def doExecute(): RDD[InternalRow] = { left.execute().map(_.copy()).intersection(right.execute().map(_.copy())) } + + override def outputsUnsafeRows: Boolean = children.exists(_.outputsUnsafeRows) + override def canProcessUnsafeRows: Boolean = true + override def canProcessSafeRows: Boolean = true } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala index 13d68a103a22..2328899bb2f8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala @@ -58,6 +58,41 @@ class RowFormatConvertersSuite extends SparkPlanTest with SharedSQLContext { assert(!preparedPlan.outputsUnsafeRows) } + test("coalesce can process unsafe rows") { + val plan = Coalesce(1, outputsUnsafe) + val preparedPlan = sqlContext.prepareForExecution.execute(plan) + assert(getConverters(preparedPlan).size === 1) + assert(preparedPlan.outputsUnsafeRows) + } + + test("except can process unsafe rows") { + val plan = Except(outputsUnsafe, outputsUnsafe) + val preparedPlan = sqlContext.prepareForExecution.execute(plan) + assert(getConverters(preparedPlan).size === 2) + assert(preparedPlan.outputsUnsafeRows) + } + + test("except requires all of its input rows' formats to agree") { + val plan = Except(outputsSafe, outputsUnsafe) + assert(plan.canProcessSafeRows && plan.canProcessUnsafeRows) + val preparedPlan = sqlContext.prepareForExecution.execute(plan) + assert(preparedPlan.outputsUnsafeRows) + } + + test("intersect can process unsafe rows") { + val plan = Intersect(outputsUnsafe, 
outputsUnsafe) + val preparedPlan = sqlContext.prepareForExecution.execute(plan) + assert(getConverters(preparedPlan).size === 2) + assert(preparedPlan.outputsUnsafeRows) + } + + test("intersect requires all of its input rows' formats to agree") { + val plan = Intersect(outputsSafe, outputsUnsafe) + assert(plan.canProcessSafeRows && plan.canProcessUnsafeRows) + val preparedPlan = sqlContext.prepareForExecution.execute(plan) + assert(preparedPlan.outputsUnsafeRows) + } + test("execute() fails an assertion if inputs rows are of different formats") { val e = intercept[AssertionError] { Union(Seq(outputsSafe, outputsUnsafe)).execute() From c59df8c51609a0d6561ae1868e7970b516fb1811 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Tue, 15 Dec 2015 11:38:57 +0000 Subject: [PATCH 1764/1824] [SPARK-12332][TRIVIAL][TEST] Fix minor typo in ResetSystemProperties Fix a minor typo (unbalanced bracket) in ResetSystemProperties. Author: Holden Karau Closes #10303 from holdenk/SPARK-12332-trivial-typo-in-ResetSystemProperties-comment. --- .../scala/org/apache/spark/util/ResetSystemProperties.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/test/scala/org/apache/spark/util/ResetSystemProperties.scala b/core/src/test/scala/org/apache/spark/util/ResetSystemProperties.scala index c58db5e606f7..60fb7abb66d3 100644 --- a/core/src/test/scala/org/apache/spark/util/ResetSystemProperties.scala +++ b/core/src/test/scala/org/apache/spark/util/ResetSystemProperties.scala @@ -45,7 +45,7 @@ private[spark] trait ResetSystemProperties extends BeforeAndAfterEach { this: Su var oldProperties: Properties = null override def beforeEach(): Unit = { - // we need SerializationUtils.clone instead of `new Properties(System.getProperties()` because + // we need SerializationUtils.clone instead of `new Properties(System.getProperties())` because // the later way of creating a copy does not copy the properties but it initializes a new // Properties object with the given properties as defaults. They are not recognized at all // by standard Scala wrapper over Java Properties then. From bc1ff9f4a41401599d3a87fb3c23a2078228a29b Mon Sep 17 00:00:00 2001 From: jerryshao Date: Tue, 15 Dec 2015 09:41:40 -0800 Subject: [PATCH 1765/1824] [STREAMING][MINOR] Fix typo in function name of StateImpl cc\ tdas zsxwing , please review. Thanks a lot. Author: jerryshao Closes #10305 from jerryshao/fix-typo-state-impl. --- streaming/src/main/scala/org/apache/spark/streaming/State.scala | 2 +- .../scala/org/apache/spark/streaming/rdd/MapWithStateRDD.scala | 2 +- .../scala/org/apache/spark/streaming/MapWithStateSuite.scala | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/State.scala b/streaming/src/main/scala/org/apache/spark/streaming/State.scala index b47bdda2c213..42424d67d883 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/State.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/State.scala @@ -206,7 +206,7 @@ private[streaming] class StateImpl[S] extends State[S] { * Update the internal data and flags in `this` to the given state that is going to be timed out. * This method allows `this` object to be reused across many state records. 
*/ - def wrapTiminoutState(newState: S): Unit = { + def wrapTimingOutState(newState: S): Unit = { this.state = newState defined = true timingOut = true diff --git a/streaming/src/main/scala/org/apache/spark/streaming/rdd/MapWithStateRDD.scala b/streaming/src/main/scala/org/apache/spark/streaming/rdd/MapWithStateRDD.scala index ed95171f73ee..fdf61674a37f 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/rdd/MapWithStateRDD.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/rdd/MapWithStateRDD.scala @@ -67,7 +67,7 @@ private[streaming] object MapWithStateRDDRecord { // data returned if (removeTimedoutData && timeoutThresholdTime.isDefined) { newStateMap.getByTime(timeoutThresholdTime.get).foreach { case (key, state, _) => - wrappedState.wrapTiminoutState(state) + wrappedState.wrapTimingOutState(state) val returned = mappingFunction(batchTime, key, None, wrappedState) mappedData ++= returned newStateMap.remove(key) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/MapWithStateSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/MapWithStateSuite.scala index 4b08085e09b1..6b21433f1781 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/MapWithStateSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/MapWithStateSuite.scala @@ -125,7 +125,7 @@ class MapWithStateSuite extends SparkFunSuite state.remove() testState(None, shouldBeRemoved = true) - state.wrapTiminoutState(3) + state.wrapTimingOutState(3) testState(Some(3), shouldBeTimingOut = true) } From b24c12d7338b47b637698e7458ba90f34cba28c0 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Tue, 15 Dec 2015 16:29:39 -0800 Subject: [PATCH 1766/1824] [MINOR][ML] Rename weights to coefficients for examples/DeveloperApiExample Rename ```weights``` to ```coefficients``` for examples/DeveloperApiExample. cc mengxr jkbradley Author: Yanbo Liang Closes #10280 from yanboliang/spark-coefficients. --- .../examples/ml/JavaDeveloperApiExample.java | 22 +++++++++---------- .../examples/ml/DeveloperApiExample.scala | 16 +++++++------- 2 files changed, 19 insertions(+), 19 deletions(-) diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java index 0b4c0d9ba9f8..b9dd3ad95771 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java @@ -89,7 +89,7 @@ public static void main(String[] args) throws Exception { } if (sumPredictions != 0.0) { throw new Exception("MyJavaLogisticRegression predicted something other than 0," + - " even though all weights are 0!"); + " even though all coefficients are 0!"); } jsc.stop(); @@ -149,12 +149,12 @@ public MyJavaLogisticRegressionModel train(DataFrame dataset) { // Extract columns from data using helper method. JavaRDD oldDataset = extractLabeledPoints(dataset).toJavaRDD(); - // Do learning to estimate the weight vector. + // Do learning to estimate the coefficients vector. int numFeatures = oldDataset.take(1).get(0).features().size(); - Vector weights = Vectors.zeros(numFeatures); // Learning would happen here. + Vector coefficients = Vectors.zeros(numFeatures); // Learning would happen here. // Create a model, and return it. 
- return new MyJavaLogisticRegressionModel(uid(), weights).setParent(this); + return new MyJavaLogisticRegressionModel(uid(), coefficients).setParent(this); } @Override @@ -173,12 +173,12 @@ public MyJavaLogisticRegression copy(ParamMap extra) { class MyJavaLogisticRegressionModel extends ClassificationModel { - private Vector weights_; - public Vector weights() { return weights_; } + private Vector coefficients_; + public Vector coefficients() { return coefficients_; } - public MyJavaLogisticRegressionModel(String uid, Vector weights) { + public MyJavaLogisticRegressionModel(String uid, Vector coefficients) { this.uid_ = uid; - this.weights_ = weights; + this.coefficients_ = coefficients; } private String uid_ = Identifiable$.MODULE$.randomUID("myJavaLogReg"); @@ -208,7 +208,7 @@ public String uid() { * modifier. */ public Vector predictRaw(Vector features) { - double margin = BLAS.dot(features, weights_); + double margin = BLAS.dot(features, coefficients_); // There are 2 classes (binary classification), so we return a length-2 vector, // where index i corresponds to class i (i = 0, 1). return Vectors.dense(-margin, margin); @@ -222,7 +222,7 @@ public Vector predictRaw(Vector features) { /** * Number of features the model was trained on. */ - public int numFeatures() { return weights_.size(); } + public int numFeatures() { return coefficients_.size(); } /** * Create a copy of the model. @@ -235,7 +235,7 @@ public Vector predictRaw(Vector features) { */ @Override public MyJavaLogisticRegressionModel copy(ParamMap extra) { - return copyValues(new MyJavaLogisticRegressionModel(uid(), weights_), extra) + return copyValues(new MyJavaLogisticRegressionModel(uid(), coefficients_), extra) .setParent(parent()); } } diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala index 3758edc56198..c1f63c6a1dce 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala @@ -75,7 +75,7 @@ object DeveloperApiExample { prediction }.sum assert(sumPredictions == 0.0, - "MyLogisticRegression predicted something other than 0, even though all weights are 0!") + "MyLogisticRegression predicted something other than 0, even though all coefficients are 0!") sc.stop() } @@ -124,12 +124,12 @@ private class MyLogisticRegression(override val uid: String) // Extract columns from data using helper method. val oldDataset = extractLabeledPoints(dataset) - // Do learning to estimate the weight vector. + // Do learning to estimate the coefficients vector. val numFeatures = oldDataset.take(1)(0).features.size - val weights = Vectors.zeros(numFeatures) // Learning would happen here. + val coefficients = Vectors.zeros(numFeatures) // Learning would happen here. // Create a model, and return it. - new MyLogisticRegressionModel(uid, weights).setParent(this) + new MyLogisticRegressionModel(uid, coefficients).setParent(this) } override def copy(extra: ParamMap): MyLogisticRegression = defaultCopy(extra) @@ -142,7 +142,7 @@ private class MyLogisticRegression(override val uid: String) */ private class MyLogisticRegressionModel( override val uid: String, - val weights: Vector) + val coefficients: Vector) extends ClassificationModel[Vector, MyLogisticRegressionModel] with MyLogisticRegressionParams { @@ -163,7 +163,7 @@ private class MyLogisticRegressionModel( * confidence for that label. 
*/ override protected def predictRaw(features: Vector): Vector = { - val margin = BLAS.dot(features, weights) + val margin = BLAS.dot(features, coefficients) // There are 2 classes (binary classification), so we return a length-2 vector, // where index i corresponds to class i (i = 0, 1). Vectors.dense(-margin, margin) @@ -173,7 +173,7 @@ private class MyLogisticRegressionModel( override val numClasses: Int = 2 /** Number of features the model was trained on. */ - override val numFeatures: Int = weights.size + override val numFeatures: Int = coefficients.size /** * Create a copy of the model. @@ -182,7 +182,7 @@ private class MyLogisticRegressionModel( * This is used for the default implementation of [[transform()]]. */ override def copy(extra: ParamMap): MyLogisticRegressionModel = { - copyValues(new MyLogisticRegressionModel(uid, weights), extra).setParent(parent) + copyValues(new MyLogisticRegressionModel(uid, coefficients), extra).setParent(parent) } } // scalastyle:on println From 86ea64dd146757c8f997d05fb5bb44f6aa58515c Mon Sep 17 00:00:00 2001 From: Nong Li Date: Tue, 15 Dec 2015 16:55:58 -0800 Subject: [PATCH 1767/1824] [SPARK-12271][SQL] Improve error message when Dataset.as[ ] has incompatible schemas. Author: Nong Li Closes #10260 from nongli/spark-11271. --- .../apache/spark/sql/catalyst/ScalaReflection.scala | 2 +- .../sql/catalyst/encoders/ExpressionEncoder.scala | 1 + .../spark/sql/catalyst/expressions/objects.scala | 12 +++++++----- .../scala/org/apache/spark/sql/DatasetSuite.scala | 10 +++++++++- 4 files changed, 18 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 9013fd050b5f..ecff8605706d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -184,7 +184,7 @@ object ScalaReflection extends ScalaReflection { val TypeRef(_, _, Seq(optType)) = t val className = getClassNameFromType(optType) val newTypePath = s"""- option value class: "$className"""" +: walkedTypePath - WrapOption(constructorFor(optType, path, newTypePath)) + WrapOption(constructorFor(optType, path, newTypePath), dataTypeFor(optType)) case t if t <:< localTypeOf[java.lang.Integer] => val boxedType = classOf[java.lang.Integer] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index 3e8420ecb9cc..363178b0e21a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -251,6 +251,7 @@ case class ExpressionEncoder[T]( val plan = Project(Alias(unbound, "")() :: Nil, LocalRelation(schema)) val analyzedPlan = SimpleAnalyzer.execute(plan) + SimpleAnalyzer.checkAnalysis(analyzedPlan) val optimizedPlan = SimplifyCasts(analyzedPlan) // In order to construct instances of inner classes (for example those declared in a REPL cell), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala index 96bc4fe67a98..10ec75eca37f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala @@ -23,11 +23,9 @@ import scala.reflect.ClassTag import org.apache.spark.SparkConf import org.apache.spark.serializer._ import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer -import org.apache.spark.sql.catalyst.plans.logical.{Project, LocalRelation} -import org.apache.spark.sql.catalyst.util.GenericArrayData import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} +import org.apache.spark.sql.catalyst.util.GenericArrayData import org.apache.spark.sql.types._ /** @@ -295,13 +293,17 @@ case class UnwrapOption( * Converts the result of evaluating `child` into an option, checking both the isNull bit and * (in the case of reference types) equality with null. * @param child The expression to evaluate and wrap. + * @param optType The type of this option. */ -case class WrapOption(child: Expression) extends UnaryExpression { +case class WrapOption(child: Expression, optType: DataType) + extends UnaryExpression with ExpectsInputTypes { override def dataType: DataType = ObjectType(classOf[Option[_]]) override def nullable: Boolean = true + override def inputTypes: Seq[AbstractDataType] = optType :: Nil + override def eval(input: InternalRow): Any = throw new UnsupportedOperationException("Only code-generated evaluation is supported") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 542e4d6c43b9..8f8db318261d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -481,10 +481,18 @@ class DatasetSuite extends QueryTest with SharedSQLContext { val ds = Seq(2 -> 2.toByte, 3 -> 3.toByte).toDF("a", "b").as[ClassData] assert(ds.collect().toSeq == Seq(ClassData("2", 2), ClassData("3", 3))) } -} + test("verify mismatching field names fail with a good error") { + val ds = Seq(ClassData("a", 1)).toDS() + val e = intercept[AnalysisException] { + ds.as[ClassData2].collect() + } + assert(e.getMessage.contains("cannot resolve 'c' given input columns a, b"), e.getMessage) + } +} case class ClassData(a: String, b: Int) +case class ClassData2(c: String, d: Int) case class ClassNullableData(a: String, b: Integer) /** From 28112657ea5919451291c21b4b8e1eb3db0ec8d4 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Tue, 15 Dec 2015 17:02:14 -0800 Subject: [PATCH 1768/1824] [SPARK-12236][SQL] JDBC filter tests all pass if filters are not really pushed down https://issues.apache.org/jira/browse/SPARK-12236 Currently JDBC filters are not tested properly. All the tests pass even if the filters are not pushed down due to Spark-side filtering. In this PR, Firstly, I corrected the tests to properly check the pushed down filters by removing Spark-side filtering. Also, `!=` was being tested which is actually not pushed down. So I removed them. Lastly, I moved the `stripSparkFilter()` function to `SQLTestUtils` as this functions would be shared for all tests for pushed down filters. This function would be also shared with ORC datasource as the filters for that are also not being tested properly. Author: hyukjinkwon Closes #10221 from HyukjinKwon/SPARK-12236. 
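As a reading aid for the diff below, here is a short sketch of the false positive being removed. It assumes the `foobar` table and the `sql`/`stripSparkFilter` helpers that appear in the suites in this patch, so it is an excerpt-style illustration rather than a standalone program:

    val df = sql("SELECT * FROM foobar WHERE THEID = 1")

    // Old-style assertion: passes even if the JDBC source ignores the predicate,
    // because Spark's own Filter node drops the extra rows afterwards.
    assert(df.collect().size === 1)

    // New-style assertion: with the Spark-side Filter stripped, the count reflects
    // only what the source returned, so a missing pushdown now fails the test.
    assert(stripSparkFilter(df).collect().size === 1)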
--- .../datasources/parquet/ParquetFilterSuite.scala | 15 --------------- .../org/apache/spark/sql/jdbc/JDBCSuite.scala | 10 ++++------ .../org/apache/spark/sql/test/SQLTestUtils.scala | 15 +++++++++++++++ 3 files changed, 19 insertions(+), 21 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index daf41bc292cc..6178e37d2a58 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -110,21 +110,6 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex checkBinaryFilterPredicate(predicate, filterClass, Seq(Row(expected)))(df) } - /** - * Strip Spark-side filtering in order to check if a datasource filters rows correctly. - */ - protected def stripSparkFilter(df: DataFrame): DataFrame = { - val schema = df.schema - val childRDD = df - .queryExecution - .executedPlan.asInstanceOf[org.apache.spark.sql.execution.Filter] - .child - .execute() - .map(row => Row.fromSeq(row.toSeq(schema))) - - sqlContext.createDataFrame(childRDD, schema) - } - test("filter pushdown - boolean") { withParquetDataFrame((true :: false :: Nil).map(b => Tuple1.apply(Option(b)))) { implicit df => checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 8c24aa3151bc..a36094715299 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -176,12 +176,10 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext } test("SELECT * WHERE (simple predicates)") { - assert(sql("SELECT * FROM foobar WHERE THEID < 1").collect().size === 0) - assert(sql("SELECT * FROM foobar WHERE THEID != 2").collect().size === 2) - assert(sql("SELECT * FROM foobar WHERE THEID = 1").collect().size === 1) - assert(sql("SELECT * FROM foobar WHERE NAME = 'fred'").collect().size === 1) - assert(sql("SELECT * FROM foobar WHERE NAME > 'fred'").collect().size === 2) - assert(sql("SELECT * FROM foobar WHERE NAME != 'fred'").collect().size === 2) + assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE THEID < 1")).collect().size === 0) + assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE THEID = 1")).collect().size === 1) + assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE NAME = 'fred'")).collect().size === 1) + assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE NAME > 'fred'")).collect().size === 2) } test("SELECT * WHERE (quoted strings)") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index 9214569f18e9..e87da1527c4d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -179,6 +179,21 @@ private[sql] trait SQLTestUtils try f finally sqlContext.sql(s"USE default") } + /** + * Strip Spark-side filtering in order to check if a datasource filters rows correctly. 
+ */ + protected def stripSparkFilter(df: DataFrame): DataFrame = { + val schema = df.schema + val childRDD = df + .queryExecution + .executedPlan.asInstanceOf[org.apache.spark.sql.execution.Filter] + .child + .execute() + .map(row => Row.fromSeq(row.toSeq(schema))) + + sqlContext.createDataFrame(childRDD, schema) + } + /** * Turn a logical plan into a [[DataFrame]]. This should be removed once we have an easier * way to construct [[DataFrame]] directly out of local data without relying on implicits. From 31b391019ff6eb5a483f4b3e62fd082de7ff8416 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jean-Baptiste=20Onofr=C3=A9?= Date: Tue, 15 Dec 2015 18:06:30 -0800 Subject: [PATCH 1769/1824] [SPARK-12105] [SQL] add convenient show functions MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Author: Jean-Baptiste Onofré Closes #10130 from jbonofre/SPARK-12105. --- .../org/apache/spark/sql/DataFrame.scala | 25 ++++++++++++------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 497bd4826677..b69d4411425d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -160,17 +160,24 @@ class DataFrame private[sql]( } } + /** + * Compose the string representing rows for output + */ + def showString(): String = { + showString(20) + } + /** * Compose the string representing rows for output - * @param _numRows Number of rows to show + * @param numRows Number of rows to show * @param truncate Whether truncate long strings and align cells right */ - private[sql] def showString(_numRows: Int, truncate: Boolean = true): String = { - val numRows = _numRows.max(0) + def showString(numRows: Int, truncate: Boolean = true): String = { + val _numRows = numRows.max(0) val sb = new StringBuilder - val takeResult = take(numRows + 1) - val hasMoreData = takeResult.length > numRows - val data = takeResult.take(numRows) + val takeResult = take(_numRows + 1) + val hasMoreData = takeResult.length > _numRows + val data = takeResult.take(_numRows) val numCols = schema.fieldNames.length // For array values, replace Seq and Array with square brackets @@ -224,10 +231,10 @@ class DataFrame private[sql]( sb.append(sep) - // For Data that has more than "numRows" records + // For Data that has more than "_numRows" records if (hasMoreData) { - val rowsString = if (numRows == 1) "row" else "rows" - sb.append(s"only showing top $numRows $rowsString\n") + val rowsString = if (_numRows == 1) "row" else "rows" + sb.append(s"only showing top $_numRows $rowsString\n") } sb.toString() From 840bd2e008da5b22bfa73c587ea2c57666fffc60 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Tue, 15 Dec 2015 18:11:53 -0800 Subject: [PATCH 1770/1824] [HOTFIX] Compile error from commit 31b3910 --- sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index b69d4411425d..33b03be1138b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -234,7 +234,7 @@ class DataFrame private[sql]( // For Data that has more than "_numRows" records if (hasMoreData) { val rowsString = if (_numRows == 1) "row" else "rows" - sb.append(s"only 
showing top $_numRows $rowsString\n") + sb.append(s"only showing top ${_numRows} $rowsString\n") } sb.toString() From f725b2ec1ab0d89e35b5e2d3ddeddb79fec85f6d Mon Sep 17 00:00:00 2001 From: tedyu Date: Tue, 15 Dec 2015 18:15:10 -0800 Subject: [PATCH 1771/1824] [SPARK-12056][CORE] Part 2 Create a TaskAttemptContext only after calling setConf This is continuation of SPARK-12056 where change is applied to SqlNewHadoopRDD.scala andrewor14 FYI Author: tedyu Closes #10164 from tedyu/master. --- .../spark/sql/execution/datasources/SqlNewHadoopRDD.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala index 56cb63d9eff2..eea780cbaa7e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala @@ -148,14 +148,14 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag]( } inputMetrics.setBytesReadCallback(bytesReadCallback) - val attemptId = newTaskAttemptID(jobTrackerId, id, isMap = true, split.index, 0) - val hadoopAttemptContext = newTaskAttemptContext(conf, attemptId) val format = inputFormatClass.newInstance format match { case configurable: Configurable => configurable.setConf(conf) case _ => } + val attemptId = newTaskAttemptID(jobTrackerId, id, isMap = true, split.index, 0) + val hadoopAttemptContext = newTaskAttemptContext(conf, attemptId) private[this] var reader: RecordReader[Void, V] = null /** From 369127f03257e7081d2aa1fc445e773b26f0d5e3 Mon Sep 17 00:00:00 2001 From: Lianhui Wang Date: Tue, 15 Dec 2015 18:16:22 -0800 Subject: [PATCH 1772/1824] [SPARK-12130] Replace shuffleManagerClass with shortShuffleMgrNames in ExternalShuffleBlockResolver Replace shuffleManagerClassName with shortShuffleMgrName is to reduce time of string's comparison. and put sort's comparison on the front. cc JoshRosen andrewor14 Author: Lianhui Wang Closes #10131 from lianhuiwang/spark-12130. --- .../scala/org/apache/spark/shuffle/ShuffleManager.scala | 4 ++++ .../org/apache/spark/shuffle/hash/HashShuffleManager.scala | 2 ++ .../org/apache/spark/shuffle/sort/SortShuffleManager.scala | 2 ++ .../main/scala/org/apache/spark/storage/BlockManager.scala | 2 +- .../network/shuffle/ExternalShuffleBlockResolver.java | 7 +++---- .../apache/spark/network/sasl/SaslIntegrationSuite.java | 3 +-- .../network/shuffle/ExternalShuffleBlockResolverSuite.java | 6 +++--- .../network/shuffle/ExternalShuffleIntegrationSuite.java | 4 ++-- 8 files changed, 18 insertions(+), 12 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala index 978366d1a1d1..a3444bf4daa3 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala @@ -28,6 +28,10 @@ import org.apache.spark.{TaskContext, ShuffleDependency} * boolean isDriver as parameters. */ private[spark] trait ShuffleManager { + + /** Return short name for the ShuffleManager */ + val shortName: String + /** * Register a shuffle with the manager and obtain a handle for it to pass to tasks. 
*/ diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala index d2e2fc4c110a..4f30da0878ee 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala @@ -34,6 +34,8 @@ private[spark] class HashShuffleManager(conf: SparkConf) extends ShuffleManager private val fileShuffleBlockResolver = new FileShuffleBlockResolver(conf) + override val shortName: String = "hash" + /* Register a shuffle with the manager and obtain a handle for it to pass to tasks. */ override def registerShuffle[K, V, C]( shuffleId: Int, diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala index 66b6bbc61fe8..9b1a27952842 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala @@ -79,6 +79,8 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager */ private[this] val numMapsForShuffle = new ConcurrentHashMap[Int, Int]() + override val shortName: String = "sort" + override val shuffleBlockResolver = new IndexShuffleBlockResolver(conf) /** diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index ed05143877e2..540e1ec003a2 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -200,7 +200,7 @@ private[spark] class BlockManager( val shuffleConfig = new ExecutorShuffleInfo( diskBlockManager.localDirs.map(_.toString), diskBlockManager.subDirsPerLocalDir, - shuffleManager.getClass.getName) + shuffleManager.shortName) val MAX_ATTEMPTS = 3 val SLEEP_TIME_SECS = 5 diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java index e5cb68c8a4db..fe933ed650ca 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java @@ -183,11 +183,10 @@ public ManagedBuffer getBlockData(String appId, String execId, String blockId) { String.format("Executor is not registered (appId=%s, execId=%s)", appId, execId)); } - if ("org.apache.spark.shuffle.hash.HashShuffleManager".equals(executor.shuffleManager)) { - return getHashBasedShuffleBlockData(executor, blockId); - } else if ("org.apache.spark.shuffle.sort.SortShuffleManager".equals(executor.shuffleManager) - || "org.apache.spark.shuffle.unsafe.UnsafeShuffleManager".equals(executor.shuffleManager)) { + if ("sort".equals(executor.shuffleManager) || "tungsten-sort".equals(executor.shuffleManager)) { return getSortBasedShuffleBlockData(executor, shuffleId, mapId, reduceId); + } else if ("hash".equals(executor.shuffleManager)) { + return getHashBasedShuffleBlockData(executor, blockId); } else { throw new UnsupportedOperationException( "Unsupported shuffle manager: " + executor.shuffleManager); diff --git a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java 
b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java index f573d962fe36..0ea631ea14d7 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java @@ -221,8 +221,7 @@ public synchronized void onBlockFetchFailure(String blockId, Throwable t) { // Register an executor so that the next steps work. ExecutorShuffleInfo executorInfo = new ExecutorShuffleInfo( - new String[] { System.getProperty("java.io.tmpdir") }, 1, - "org.apache.spark.shuffle.sort.SortShuffleManager"); + new String[] { System.getProperty("java.io.tmpdir") }, 1, "sort"); RegisterExecutor regmsg = new RegisterExecutor("app-1", "0", executorInfo); client1.sendRpcSync(regmsg.toByteBuffer(), TIMEOUT_MS); diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java index a9958232a1d2..60a1b8b0451f 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java @@ -83,7 +83,7 @@ public void testBadRequests() throws IOException { // Nonexistent shuffle block resolver.registerExecutor("app0", "exec3", - dataContext.createExecutorInfo("org.apache.spark.shuffle.sort.SortShuffleManager")); + dataContext.createExecutorInfo("sort")); try { resolver.getBlockData("app0", "exec3", "shuffle_1_1_0"); fail("Should have failed"); @@ -96,7 +96,7 @@ public void testBadRequests() throws IOException { public void testSortShuffleBlocks() throws IOException { ExternalShuffleBlockResolver resolver = new ExternalShuffleBlockResolver(conf, null); resolver.registerExecutor("app0", "exec0", - dataContext.createExecutorInfo("org.apache.spark.shuffle.sort.SortShuffleManager")); + dataContext.createExecutorInfo("sort")); InputStream block0Stream = resolver.getBlockData("app0", "exec0", "shuffle_0_0_0").createInputStream(); @@ -115,7 +115,7 @@ public void testSortShuffleBlocks() throws IOException { public void testHashShuffleBlocks() throws IOException { ExternalShuffleBlockResolver resolver = new ExternalShuffleBlockResolver(conf, null); resolver.registerExecutor("app0", "exec0", - dataContext.createExecutorInfo("org.apache.spark.shuffle.hash.HashShuffleManager")); + dataContext.createExecutorInfo("hash")); InputStream block0Stream = resolver.getBlockData("app0", "exec0", "shuffle_1_0_0").createInputStream(); diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java index 2095f41d79c1..5e706bf40169 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java @@ -49,8 +49,8 @@ public class ExternalShuffleIntegrationSuite { static String APP_ID = "app-id"; - static String SORT_MANAGER = "org.apache.spark.shuffle.sort.SortShuffleManager"; - static String HASH_MANAGER = "org.apache.spark.shuffle.hash.HashShuffleManager"; + static String SORT_MANAGER = "sort"; + static String HASH_MANAGER = "hash"; // Executor 0 is sort-based static TestShuffleDataContext dataContext0; From 
c2de99a7c3a52b0da96517c7056d2733ef45495f Mon Sep 17 00:00:00 2001 From: Timothy Chen Date: Tue, 15 Dec 2015 18:20:00 -0800 Subject: [PATCH 1773/1824] [SPARK-12351][MESOS] Add documentation about submitting Spark with mesos cluster mode. Adding more documentation about submitting jobs with mesos cluster mode. Author: Timothy Chen Closes #10086 from tnachen/mesos_supervise_docs. --- docs/running-on-mesos.md | 26 +++++++++++++++++++++----- docs/submitting-applications.md | 15 ++++++++++++++- 2 files changed, 35 insertions(+), 6 deletions(-) diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index a197d0e37302..3193e1785348 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -150,14 +150,30 @@ it does not need to be redundantly passed in as a system property. Spark on Mesos also supports cluster mode, where the driver is launched in the cluster and the client can find the results of the driver from the Mesos Web UI. -To use cluster mode, you must start the MesosClusterDispatcher in your cluster via the `sbin/start-mesos-dispatcher.sh` script, -passing in the Mesos master url (e.g: mesos://host:5050). +To use cluster mode, you must start the `MesosClusterDispatcher` in your cluster via the `sbin/start-mesos-dispatcher.sh` script, +passing in the Mesos master URL (e.g: mesos://host:5050). This starts the `MesosClusterDispatcher` as a daemon running on the host. -From the client, you can submit a job to Mesos cluster by running `spark-submit` and specifying the master url -to the url of the MesosClusterDispatcher (e.g: mesos://dispatcher:7077). You can view driver statuses on the +If you like to run the `MesosClusterDispatcher` with Marathon, you need to run the `MesosClusterDispatcher` in the foreground (i.e: `bin/spark-class org.apache.spark.deploy.mesos.MesosClusterDispatcher`). + +From the client, you can submit a job to Mesos cluster by running `spark-submit` and specifying the master URL +to the URL of the `MesosClusterDispatcher` (e.g: mesos://dispatcher:7077). You can view driver statuses on the Spark cluster Web UI. -Note that jars or python files that are passed to spark-submit should be URIs reachable by Mesos slaves. +For example: +{% highlight bash %} +./bin/spark-submit \ + --class org.apache.spark.examples.SparkPi \ + --master mesos://207.184.161.138:7077 \ + --deploy-mode cluster + --supervise + --executor-memory 20G \ + --total-executor-cores 100 \ + http://path/to/examples.jar \ + 1000 +{% endhighlight %} + + +Note that jars or python files that are passed to spark-submit should be URIs reachable by Mesos slaves, as the Spark driver doesn't automatically upload local jars. # Mesos Run Modes diff --git a/docs/submitting-applications.md b/docs/submitting-applications.md index ac2a14eb56fe..acbb0f298fe4 100644 --- a/docs/submitting-applications.md +++ b/docs/submitting-applications.md @@ -115,6 +115,18 @@ export HADOOP_CONF_DIR=XXX --master spark://207.184.161.138:7077 \ examples/src/main/python/pi.py \ 1000 + +# Run on a Mesos cluster in cluster deploy mode with supervise +./bin/spark-submit \ + --class org.apache.spark.examples.SparkPi \ + --master mesos://207.184.161.138:7077 \ + --deploy-mode cluster + --supervise + --executor-memory 20G \ + --total-executor-cores 100 \ + http://path/to/examples.jar \ + 1000 + {% endhighlight %} # Master URLs @@ -132,9 +144,10 @@ The master URL passed to Spark can be in one of the following formats:
    ") + } else { + if (!forceAdd) { + stackTrace.remove() + } + } +} + +function expandAllThreadStackTrace(toggleButton) { + $('.accordion-heading').each(function() { + //get thread ID + if (!$(this).hasClass("hidden")) { + var trId = $(this).attr('id').match(/thread_([0-9]+)_tr/m)[1] + toggleThreadStackTrace(trId, true) + } + }) + if (toggleButton) { + $('.expandbutton').toggleClass('hidden') + } +} + +function collapseAllThreadStackTrace(toggleButton) { + $('.accordion-body').each(function() { + $(this).remove() + }) + if (toggleButton) { + $('.expandbutton').toggleClass('hidden'); + } +} + + +// inOrOut - true: over, false: out +function onMouseOverAndOut(threadId) { + $("#" + threadId + "_td_id").toggleClass("threaddump-td-mouseover"); + $("#" + threadId + "_td_name").toggleClass("threaddump-td-mouseover"); + $("#" + threadId + "_td_state").toggleClass("threaddump-td-mouseover"); +} + +function onSearchStringChange() { + var searchString = $('#search').val().toLowerCase(); + //remove the stacktrace + collapseAllThreadStackTrace(false) + if (searchString.length == 0) { + $('tr').each(function() { + $(this).removeClass('hidden') + }) + } else { + $('tr').each(function(){ + if($(this).attr('id') && $(this).attr('id').match(/thread_[0-9]+_tr/) ) { + var children = $(this).children() + var found = false + for (i = 0; i < children.length; i++) { + if (children.eq(i).text().toLowerCase().indexOf(searchString) >= 0) { + found = true + } + } + if (found) { + $(this).removeClass('hidden') + } else { + $(this).addClass('hidden') + } + } + }); + } +} diff --git a/core/src/main/resources/org/apache/spark/ui/static/webui.css b/core/src/main/resources/org/apache/spark/ui/static/webui.css index c628a0c70655..b54e33a96fa2 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/webui.css +++ b/core/src/main/resources/org/apache/spark/ui/static/webui.css @@ -221,10 +221,8 @@ a.expandbutton { cursor: pointer; } -.executor-thread { - background: #E6E6E6; -} - -.non-executor-thread { - background: #FAFAFA; +.threaddump-td-mouseover { + background-color: #49535a !important; + color: white; + cursor:pointer; } \ No newline at end of file diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala index b0a2cb4aa4d4..58575d154ce5 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala @@ -60,44 +60,49 @@ private[ui] class ExecutorThreadDumpPage(parent: ExecutorsTab) extends WebUIPage } } }.map { thread => - val threadName = thread.threadName - val className = "accordion-heading " + { - if (threadName.contains("Executor task launch")) { - "executor-thread" - } else { - "non-executor-thread" - } - } -
    - -
    + + + + + + } + +
    +

    Updated at {UIUtils.formatDate(time)}

    + { + // scalastyle:off +

    + Expand All +

    +

    +
    +
    +
    +
    + Search:
    +
    +

    + // scalastyle:on } - -
    -

    Updated at {UIUtils.formatDate(time)}

    - { - // scalastyle:off -

    - Expand All -

    -

    - // scalastyle:on - } -
    {dumpRows}
    -
    +
    MLlib model PMML model
    `spark.mllib` model PMML model
    spark.replClassServer.port (random) - Port for the driver's HTTP class server to listen on. - This is only relevant for the Spark shell. -
    spark.rpc.numRetries 3 Jetty-based. Not used by TorrentBroadcast, which sends data through the block manager instead.
    Executor Driver (random) Class file server spark.replClassServer.port Jetty-based. Only used in Spark shells.
    Executor / Driver Executor / Driver
    spark.memory.offHeap.enabled true + If true, Spark will attempt to use off-heap memory for certain operations. If off-heap memory use is enabled, then spark.memory.offHeap.size must be positive. +
    spark.memory.offHeap.size 0 + The absolute amount of memory which can be used for off-heap allocation. + This setting has no impact on heap memory usage, so if your executors' total memory consumption must fit within some hard limit then be sure to shrink your JVM heap size accordingly. + This must be set to a positive value when spark.memory.offHeap.enabled=true. +
    spark.memory.useLegacyMode false
    mesos://HOST:PORT Connect to the given Mesos cluster. The port must be whichever one your cluster is configured to use, which is 5050 by default. Or, for a Mesos cluster using ZooKeeper, use mesos://zk://.... + To submit with --deploy-mode cluster, the HOST:PORT should be configured to connect to the MesosClusterDispatcher.
    yarn Connect to a YARN cluster in - client or cluster mode depending on the value of --deploy-mode. + client or cluster mode depending on the value of --deploy-mode. The cluster location will be found based on the HADOOP_CONF_DIR or YARN_CONF_DIR variable.
    yarn-client Equivalent to yarn with --deploy-mode client, From a63d9edcfb8a714a17492517927aa114dea8fea0 Mon Sep 17 00:00:00 2001 From: CodingCat Date: Tue, 15 Dec 2015 18:21:00 -0800 Subject: [PATCH 1774/1824] [SPARK-9516][UI] Improvement of Thread Dump Page https://issues.apache.org/jira/browse/SPARK-9516 - [x] new look of Thread Dump Page - [x] click column title to sort - [x] grep - [x] search as you type squito JoshRosen It's ready for the review now Author: CodingCat Closes #7910 from CodingCat/SPARK-9516. --- .../org/apache/spark/ui/static/sorttable.js | 6 +- .../org/apache/spark/ui/static/table.js | 72 ++++++++++++++++++ .../org/apache/spark/ui/static/webui.css | 10 +-- .../ui/exec/ExecutorThreadDumpPage.scala | 73 ++++++++++--------- 4 files changed, 118 insertions(+), 43 deletions(-) diff --git a/core/src/main/resources/org/apache/spark/ui/static/sorttable.js b/core/src/main/resources/org/apache/spark/ui/static/sorttable.js index a73d9a5cbc21..ff241470f32d 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/sorttable.js +++ b/core/src/main/resources/org/apache/spark/ui/static/sorttable.js @@ -169,7 +169,7 @@ sorttable = { for (var i=0; i
    " +
    +            stackTraceText +  "
    {threadId}{thread.threadName}{thread.threadState}
    + + + + + + {dumpRows} +
    Thread ID Thread Name Thread State
    +
    }.getOrElse(Text("Error fetching thread dump")) UIUtils.headerSparkPage(s"Thread dump for executor $executorId", content, parent) } From 765a488494dac0ed38d2b81742c06467b79d96b2 Mon Sep 17 00:00:00 2001 From: "Richard W. Eggert II" Date: Tue, 15 Dec 2015 18:22:58 -0800 Subject: [PATCH 1775/1824] [SPARK-9026][SPARK-4514] Modifications to JobWaiter, FutureAction, and AsyncRDDActions to support non-blocking operation These changes rework the implementations of `SimpleFutureAction`, `ComplexFutureAction`, `JobWaiter`, and `AsyncRDDActions` such that asynchronous callbacks on the generated `Futures` NEVER block waiting for a job to complete. A small amount of mutex synchronization is necessary to protect the internal fields that manage cancellation, but these locks are only held very briefly and in practice should almost never cause any blocking to occur. The existing blocking APIs of these classes are retained, but they simply delegate to the underlying non-blocking API and `Await` the results with indefinite timeouts. Associated JIRA ticket: https://issues.apache.org/jira/browse/SPARK-9026 Also fixes: https://issues.apache.org/jira/browse/SPARK-4514 This pull request contains all my own original work, which I release to the Spark project under its open source license. Author: Richard W. Eggert II Closes #9264 from reggert/fix-futureaction. --- .../scala/org/apache/spark/FutureAction.scala | 164 +++++++----------- .../apache/spark/rdd/AsyncRDDActions.scala | 48 ++--- .../apache/spark/scheduler/DAGScheduler.scala | 8 +- .../apache/spark/scheduler/JobWaiter.scala | 48 ++--- .../test/scala/org/apache/spark/Smuggle.scala | 82 +++++++++ .../org/apache/spark/StatusTrackerSuite.scala | 26 +++ .../spark/rdd/AsyncRDDActionsSuite.scala | 33 +++- 7 files changed, 251 insertions(+), 158 deletions(-) create mode 100644 core/src/test/scala/org/apache/spark/Smuggle.scala diff --git a/core/src/main/scala/org/apache/spark/FutureAction.scala b/core/src/main/scala/org/apache/spark/FutureAction.scala index 48792a958130..2a8220ff4009 100644 --- a/core/src/main/scala/org/apache/spark/FutureAction.scala +++ b/core/src/main/scala/org/apache/spark/FutureAction.scala @@ -20,13 +20,15 @@ package org.apache.spark import java.util.Collections import java.util.concurrent.TimeUnit +import scala.concurrent._ +import scala.concurrent.duration.Duration +import scala.util.Try + +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.api.java.JavaFutureAction import org.apache.spark.rdd.RDD -import org.apache.spark.scheduler.{JobFailed, JobSucceeded, JobWaiter} +import org.apache.spark.scheduler.JobWaiter -import scala.concurrent._ -import scala.concurrent.duration.Duration -import scala.util.{Failure, Try} /** * A future for the result of an action to support cancellation. This is an extension of the @@ -105,6 +107,7 @@ trait FutureAction[T] extends Future[T] { * A [[FutureAction]] holding the result of an action that triggers a single job. Examples include * count, collect, reduce. 
*/ +@DeveloperApi class SimpleFutureAction[T] private[spark](jobWaiter: JobWaiter[_], resultFunc: => T) extends FutureAction[T] { @@ -116,142 +119,96 @@ class SimpleFutureAction[T] private[spark](jobWaiter: JobWaiter[_], resultFunc: } override def ready(atMost: Duration)(implicit permit: CanAwait): SimpleFutureAction.this.type = { - if (!atMost.isFinite()) { - awaitResult() - } else jobWaiter.synchronized { - val finishTime = System.currentTimeMillis() + atMost.toMillis - while (!isCompleted) { - val time = System.currentTimeMillis() - if (time >= finishTime) { - throw new TimeoutException - } else { - jobWaiter.wait(finishTime - time) - } - } - } + jobWaiter.completionFuture.ready(atMost) this } @throws(classOf[Exception]) override def result(atMost: Duration)(implicit permit: CanAwait): T = { - ready(atMost)(permit) - awaitResult() match { - case scala.util.Success(res) => res - case scala.util.Failure(e) => throw e - } + jobWaiter.completionFuture.ready(atMost) + assert(value.isDefined, "Future has not completed properly") + value.get.get } override def onComplete[U](func: (Try[T]) => U)(implicit executor: ExecutionContext) { - executor.execute(new Runnable { - override def run() { - func(awaitResult()) - } - }) + jobWaiter.completionFuture onComplete {_ => func(value.get)} } override def isCompleted: Boolean = jobWaiter.jobFinished override def isCancelled: Boolean = _cancelled - override def value: Option[Try[T]] = { - if (jobWaiter.jobFinished) { - Some(awaitResult()) - } else { - None - } - } - - private def awaitResult(): Try[T] = { - jobWaiter.awaitResult() match { - case JobSucceeded => scala.util.Success(resultFunc) - case JobFailed(e: Exception) => scala.util.Failure(e) - } - } + override def value: Option[Try[T]] = + jobWaiter.completionFuture.value.map {res => res.map(_ => resultFunc)} def jobIds: Seq[Int] = Seq(jobWaiter.jobId) } +/** + * Handle via which a "run" function passed to a [[ComplexFutureAction]] + * can submit jobs for execution. + */ +@DeveloperApi +trait JobSubmitter { + /** + * Submit a job for execution and return a FutureAction holding the result. + * This is a wrapper around the same functionality provided by SparkContext + * to enable cancellation. + */ + def submitJob[T, U, R]( + rdd: RDD[T], + processPartition: Iterator[T] => U, + partitions: Seq[Int], + resultHandler: (Int, U) => Unit, + resultFunc: => R): FutureAction[R] +} + + /** * A [[FutureAction]] for actions that could trigger multiple Spark jobs. Examples include take, - * takeSample. Cancellation works by setting the cancelled flag to true and interrupting the - * action thread if it is being blocked by a job. + * takeSample. Cancellation works by setting the cancelled flag to true and cancelling any pending + * jobs. */ -class ComplexFutureAction[T] extends FutureAction[T] { +@DeveloperApi +class ComplexFutureAction[T](run : JobSubmitter => Future[T]) + extends FutureAction[T] { self => - // Pointer to the thread that is executing the action. It is set when the action is run. - @volatile private var thread: Thread = _ + @volatile private var _cancelled = false - // A flag indicating whether the future has been cancelled. This is used in case the future - // is cancelled before the action was even run (and thus we have no thread to interrupt). - @volatile private var _cancelled: Boolean = false - - @volatile private var jobs: Seq[Int] = Nil + @volatile private var subActions: List[FutureAction[_]] = Nil // A promise used to signal the future. 
- private val p = promise[T]() + private val p = Promise[T]().tryCompleteWith(run(jobSubmitter)) - override def cancel(): Unit = this.synchronized { + override def cancel(): Unit = synchronized { _cancelled = true - if (thread != null) { - thread.interrupt() - } - } - - /** - * Executes some action enclosed in the closure. To properly enable cancellation, the closure - * should use runJob implementation in this promise. See takeAsync for example. - */ - def run(func: => T)(implicit executor: ExecutionContext): this.type = { - scala.concurrent.future { - thread = Thread.currentThread - try { - p.success(func) - } catch { - case e: Exception => p.failure(e) - } finally { - // This lock guarantees when calling `thread.interrupt()` in `cancel`, - // thread won't be set to null. - ComplexFutureAction.this.synchronized { - thread = null - } - } - } - this + p.tryFailure(new SparkException("Action has been cancelled")) + subActions.foreach(_.cancel()) } - /** - * Runs a Spark job. This is a wrapper around the same functionality provided by SparkContext - * to enable cancellation. - */ - def runJob[T, U, R]( + private def jobSubmitter = new JobSubmitter { + def submitJob[T, U, R]( rdd: RDD[T], processPartition: Iterator[T] => U, partitions: Seq[Int], resultHandler: (Int, U) => Unit, - resultFunc: => R) { - // If the action hasn't been cancelled yet, submit the job. The check and the submitJob - // command need to be in an atomic block. - val job = this.synchronized { + resultFunc: => R): FutureAction[R] = self.synchronized { + // If the action hasn't been cancelled yet, submit the job. The check and the submitJob + // command need to be in an atomic block. if (!isCancelled) { - rdd.context.submitJob(rdd, processPartition, partitions, resultHandler, resultFunc) + val job = rdd.context.submitJob( + rdd, + processPartition, + partitions, + resultHandler, + resultFunc) + subActions = job :: subActions + job } else { throw new SparkException("Action has been cancelled") } } - - this.jobs = jobs ++ job.jobIds - - // Wait for the job to complete. If the action is cancelled (with an interrupt), - // cancel the job and stop the execution. This is not in a synchronized block because - // Await.ready eventually waits on the monitor in FutureJob.jobWaiter. 
- try { - Await.ready(job, Duration.Inf) - } catch { - case e: InterruptedException => - job.cancel() - throw new SparkException("Action has been cancelled") - } } override def isCancelled: Boolean = _cancelled @@ -276,10 +233,11 @@ class ComplexFutureAction[T] extends FutureAction[T] { override def value: Option[Try[T]] = p.future.value - def jobIds: Seq[Int] = jobs + def jobIds: Seq[Int] = subActions.flatMap(_.jobIds) } + private[spark] class JavaFutureActionWrapper[S, T](futureAction: FutureAction[S], converter: S => T) extends JavaFutureAction[T] { @@ -303,7 +261,7 @@ class JavaFutureActionWrapper[S, T](futureAction: FutureAction[S], converter: S Await.ready(futureAction, timeout) futureAction.value.get match { case scala.util.Success(value) => converter(value) - case Failure(exception) => + case scala.util.Failure(exception) => if (isCancelled) { throw new CancellationException("Job cancelled").initCause(exception) } else { diff --git a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala index d5e853613b05..14f541f937b4 100644 --- a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala @@ -19,13 +19,12 @@ package org.apache.spark.rdd import java.util.concurrent.atomic.AtomicLong -import org.apache.spark.util.ThreadUtils - import scala.collection.mutable.ArrayBuffer -import scala.concurrent.ExecutionContext +import scala.concurrent.{Future, ExecutionContext} import scala.reflect.ClassTag -import org.apache.spark.{ComplexFutureAction, FutureAction, Logging} +import org.apache.spark.{JobSubmitter, ComplexFutureAction, FutureAction, Logging} +import org.apache.spark.util.ThreadUtils /** * A set of asynchronous RDD actions available through an implicit conversion. @@ -65,17 +64,23 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi * Returns a future for retrieving the first num elements of the RDD. */ def takeAsync(num: Int): FutureAction[Seq[T]] = self.withScope { - val f = new ComplexFutureAction[Seq[T]] val callSite = self.context.getCallSite - - f.run { - // This is a blocking action so we should use "AsyncRDDActions.futureExecutionContext" which - // is a cached thread pool. - val results = new ArrayBuffer[T](num) - val totalParts = self.partitions.length - var partsScanned = 0 - self.context.setCallSite(callSite) - while (results.size < num && partsScanned < totalParts) { + val localProperties = self.context.getLocalProperties + // Cached thread pool to handle aggregation of subtasks. + implicit val executionContext = AsyncRDDActions.futureExecutionContext + val results = new ArrayBuffer[T](num) + val totalParts = self.partitions.length + + /* + Recursively triggers jobs to scan partitions until either the requested + number of elements are retrieved, or the partitions to scan are exhausted. + This implementation is non-blocking, asynchronously handling the + results of each job and triggering the next job using callbacks on futures. + */ + def continue(partsScanned: Int)(implicit jobSubmitter: JobSubmitter) : Future[Seq[T]] = + if (results.size >= num || partsScanned >= totalParts) { + Future.successful(results.toSeq) + } else { // The number of partitions to try in this iteration. It is ok for this number to be // greater than totalParts because we actually cap it at totalParts in runJob. 
var numPartsToTry = 1 @@ -97,19 +102,20 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi val p = partsScanned until math.min(partsScanned + numPartsToTry, totalParts) val buf = new Array[Array[T]](p.size) - f.runJob(self, + self.context.setCallSite(callSite) + self.context.setLocalProperties(localProperties) + val job = jobSubmitter.submitJob(self, (it: Iterator[T]) => it.take(left).toArray, p, (index: Int, data: Array[T]) => buf(index) = data, Unit) - - buf.foreach(results ++= _.take(num - results.size)) - partsScanned += numPartsToTry + job.flatMap {_ => + buf.foreach(results ++= _.take(num - results.size)) + continue(partsScanned + numPartsToTry) + } } - results.toSeq - }(AsyncRDDActions.futureExecutionContext) - f + new ComplexFutureAction[Seq[T]](continue(0)(_)) } /** diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 5582720bbcff..8d0e0c8624a5 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -24,6 +24,7 @@ import java.util.concurrent.atomic.AtomicInteger import scala.collection.Map import scala.collection.mutable.{HashMap, HashSet, Stack} +import scala.concurrent.Await import scala.concurrent.duration._ import scala.language.existentials import scala.language.postfixOps @@ -610,11 +611,12 @@ class DAGScheduler( properties: Properties): Unit = { val start = System.nanoTime val waiter = submitJob(rdd, func, partitions, callSite, resultHandler, properties) - waiter.awaitResult() match { - case JobSucceeded => + Await.ready(waiter.completionFuture, atMost = Duration.Inf) + waiter.completionFuture.value.get match { + case scala.util.Success(_) => logInfo("Job %d finished: %s, took %f s".format (waiter.jobId, callSite.shortForm, (System.nanoTime - start) / 1e9)) - case JobFailed(exception: Exception) => + case scala.util.Failure(exception) => logInfo("Job %d failed: %s, took %f s".format (waiter.jobId, callSite.shortForm, (System.nanoTime - start) / 1e9)) // SPARK-8644: Include user stack trace in exceptions coming from DAGScheduler. diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala b/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala index 382b09422a4a..4326135186a7 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala @@ -17,6 +17,10 @@ package org.apache.spark.scheduler +import java.util.concurrent.atomic.AtomicInteger + +import scala.concurrent.{Future, Promise} + /** * An object that waits for a DAGScheduler job to complete. As tasks finish, it passes their * results to the given handler function. @@ -28,17 +32,15 @@ private[spark] class JobWaiter[T]( resultHandler: (Int, T) => Unit) extends JobListener { - private var finishedTasks = 0 - - // Is the job as a whole finished (succeeded or failed)? - @volatile - private var _jobFinished = totalTasks == 0 - - def jobFinished: Boolean = _jobFinished - + private val finishedTasks = new AtomicInteger(0) // If the job is finished, this will be its result. In the case of 0 task jobs (e.g. zero // partition RDDs), we set the jobResult directly to JobSucceeded. 
- private var jobResult: JobResult = if (jobFinished) JobSucceeded else null + private val jobPromise: Promise[Unit] = + if (totalTasks == 0) Promise.successful(()) else Promise() + + def jobFinished: Boolean = jobPromise.isCompleted + + def completionFuture: Future[Unit] = jobPromise.future /** * Sends a signal to the DAGScheduler to cancel the job. The cancellation itself is handled @@ -49,29 +51,17 @@ private[spark] class JobWaiter[T]( dagScheduler.cancelJob(jobId) } - override def taskSucceeded(index: Int, result: Any): Unit = synchronized { - if (_jobFinished) { - throw new UnsupportedOperationException("taskSucceeded() called on a finished JobWaiter") + override def taskSucceeded(index: Int, result: Any): Unit = { + // resultHandler call must be synchronized in case resultHandler itself is not thread safe. + synchronized { + resultHandler(index, result.asInstanceOf[T]) } - resultHandler(index, result.asInstanceOf[T]) - finishedTasks += 1 - if (finishedTasks == totalTasks) { - _jobFinished = true - jobResult = JobSucceeded - this.notifyAll() + if (finishedTasks.incrementAndGet() == totalTasks) { + jobPromise.success(()) } } - override def jobFailed(exception: Exception): Unit = synchronized { - _jobFinished = true - jobResult = JobFailed(exception) - this.notifyAll() - } + override def jobFailed(exception: Exception): Unit = + jobPromise.failure(exception) - def awaitResult(): JobResult = synchronized { - while (!_jobFinished) { - this.wait() - } - return jobResult - } } diff --git a/core/src/test/scala/org/apache/spark/Smuggle.scala b/core/src/test/scala/org/apache/spark/Smuggle.scala new file mode 100644 index 000000000000..01694a6e6f74 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/Smuggle.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import java.util.UUID +import java.util.concurrent.locks.ReentrantReadWriteLock + +import scala.collection.mutable + +/** + * Utility wrapper to "smuggle" objects into tasks while bypassing serialization. + * This is intended for testing purposes, primarily to make locks, semaphores, and + * other constructs that would not survive serialization available from within tasks. + * A Smuggle reference is itself serializable, but after being serialized and + * deserialized, it still refers to the same underlying "smuggled" object, as long + * as it was deserialized within the same JVM. This can be useful for tests that + * depend on the timing of task completion to be deterministic, since one can "smuggle" + * a lock or semaphore into the task, and then the task can block until the test gives + * the go-ahead to proceed via the lock. 
+ */ +class Smuggle[T] private(val key: Symbol) extends Serializable { + def smuggledObject: T = Smuggle.get(key) +} + + +object Smuggle { + /** + * Wraps the specified object to be smuggled into a serialized task without + * being serialized itself. + * + * @param smuggledObject + * @tparam T + * @return Smuggle wrapper around smuggledObject. + */ + def apply[T](smuggledObject: T): Smuggle[T] = { + val key = Symbol(UUID.randomUUID().toString) + lock.writeLock().lock() + try { + smuggledObjects += key -> smuggledObject + } finally { + lock.writeLock().unlock() + } + new Smuggle(key) + } + + private val lock = new ReentrantReadWriteLock + private val smuggledObjects = mutable.WeakHashMap.empty[Symbol, Any] + + private def get[T](key: Symbol) : T = { + lock.readLock().lock() + try { + smuggledObjects(key).asInstanceOf[T] + } finally { + lock.readLock().unlock() + } + } + + /** + * Implicit conversion of a Smuggle wrapper to the object being smuggled. + * + * @param smuggle the wrapper to unpack. + * @tparam T + * @return the smuggled object represented by the wrapper. + */ + implicit def unpackSmuggledObject[T](smuggle : Smuggle[T]): T = smuggle.smuggledObject + +} diff --git a/core/src/test/scala/org/apache/spark/StatusTrackerSuite.scala b/core/src/test/scala/org/apache/spark/StatusTrackerSuite.scala index 46516e8d2529..5483f2b8434a 100644 --- a/core/src/test/scala/org/apache/spark/StatusTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/StatusTrackerSuite.scala @@ -86,4 +86,30 @@ class StatusTrackerSuite extends SparkFunSuite with Matchers with LocalSparkCont Set(firstJobId, secondJobId)) } } + + test("getJobIdsForGroup() with takeAsync()") { + sc = new SparkContext("local", "test", new SparkConf(false)) + sc.setJobGroup("my-job-group2", "description") + sc.statusTracker.getJobIdsForGroup("my-job-group2") shouldBe empty + val firstJobFuture = sc.parallelize(1 to 1000, 1).takeAsync(1) + val firstJobId = eventually(timeout(10 seconds)) { + firstJobFuture.jobIds.head + } + eventually(timeout(10 seconds)) { + sc.statusTracker.getJobIdsForGroup("my-job-group2") should be (Seq(firstJobId)) + } + } + + test("getJobIdsForGroup() with takeAsync() across multiple partitions") { + sc = new SparkContext("local", "test", new SparkConf(false)) + sc.setJobGroup("my-job-group2", "description") + sc.statusTracker.getJobIdsForGroup("my-job-group2") shouldBe empty + val firstJobFuture = sc.parallelize(1 to 1000, 2).takeAsync(999) + val firstJobId = eventually(timeout(10 seconds)) { + firstJobFuture.jobIds.head + } + eventually(timeout(10 seconds)) { + sc.statusTracker.getJobIdsForGroup("my-job-group2") should have size 2 + } + } } diff --git a/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala index ec99f2a1bad6..de015ebd5d23 100644 --- a/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.rdd import java.util.concurrent.Semaphore -import scala.concurrent.{Await, TimeoutException} +import scala.concurrent._ import scala.concurrent.duration.Duration import scala.concurrent.ExecutionContext.Implicits.global @@ -27,7 +27,7 @@ import org.scalatest.BeforeAndAfterAll import org.scalatest.concurrent.Timeouts import org.scalatest.time.SpanSugar._ -import org.apache.spark.{LocalSparkContext, SparkContext, SparkException, SparkFunSuite} +import org.apache.spark._ class AsyncRDDActionsSuite 
extends SparkFunSuite with BeforeAndAfterAll with Timeouts { @@ -197,4 +197,33 @@ class AsyncRDDActionsSuite extends SparkFunSuite with BeforeAndAfterAll with Tim Await.result(f, Duration(20, "milliseconds")) } } + + private def testAsyncAction[R](action: RDD[Int] => FutureAction[R]): Unit = { + val executionContextInvoked = Promise[Unit] + val fakeExecutionContext = new ExecutionContext { + override def execute(runnable: Runnable): Unit = { + executionContextInvoked.success(()) + } + override def reportFailure(t: Throwable): Unit = () + } + val starter = Smuggle(new Semaphore(0)) + starter.drainPermits() + val rdd = sc.parallelize(1 to 100, 4).mapPartitions {itr => starter.acquire(1); itr} + val f = action(rdd) + f.onComplete(_ => ())(fakeExecutionContext) + // Here we verify that registering the callback didn't cause a thread to be consumed. + assert(!executionContextInvoked.isCompleted) + // Now allow the executors to proceed with task processing. + starter.release(rdd.partitions.length) + // Waiting for the result verifies that the tasks were successfully processed. + Await.result(executionContextInvoked.future, atMost = 15.seconds) + } + + test("SimpleFutureAction callback must not consume a thread while waiting") { + testAsyncAction(_.countAsync()) + } + + test("ComplexFutureAction callback must not consume a thread while waiting") { + testAsyncAction((_.takeAsync(100))) + } } From 63ccdef81329e785807f37b4e918a9247fc70e3c Mon Sep 17 00:00:00 2001 From: jerryshao Date: Tue, 15 Dec 2015 18:24:23 -0800 Subject: [PATCH 1776/1824] [SPARK-10123][DEPLOY] Support specifying deploy mode from configuration Please help to review, thanks a lot. Author: jerryshao Closes #10195 from jerryshao/SPARK-10123. --- .../spark/deploy/SparkSubmitArguments.scala | 5 ++- .../spark/deploy/SparkSubmitSuite.scala | 41 +++++++++++++++++++ docs/configuration.md | 15 +++++-- .../apache/spark/launcher/SparkLauncher.java | 3 ++ .../launcher/SparkSubmitCommandBuilder.java | 7 ++-- 5 files changed, 64 insertions(+), 7 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index 18a1c52ae53f..915ef81b4eae 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -176,7 +176,10 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S packages = Option(packages).orElse(sparkProperties.get("spark.jars.packages")).orNull packagesExclusions = Option(packagesExclusions) .orElse(sparkProperties.get("spark.jars.excludes")).orNull - deployMode = Option(deployMode).orElse(env.get("DEPLOY_MODE")).orNull + deployMode = Option(deployMode) + .orElse(sparkProperties.get("spark.submit.deployMode")) + .orElse(env.get("DEPLOY_MODE")) + .orNull numExecutors = Option(numExecutors) .getOrElse(sparkProperties.get("spark.executor.instances").orNull) keytab = Option(keytab).orElse(sparkProperties.get("spark.yarn.keytab")).orNull diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index d494b0caab85..2626f5a16dfb 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -136,6 +136,47 @@ class SparkSubmitSuite appArgs.childArgs should be (Seq("--master", "local", "some", "--weird", "args")) } + test("specify 
deploy mode through configuration") { + val clArgs = Seq( + "--master", "yarn", + "--conf", "spark.submit.deployMode=client", + "--class", "org.SomeClass", + "thejar.jar" + ) + val appArgs = new SparkSubmitArguments(clArgs) + val (_, _, sysProps, _) = prepareSubmitEnvironment(appArgs) + + appArgs.deployMode should be ("client") + sysProps("spark.submit.deployMode") should be ("client") + + // Both cmd line and configuration are specified, cmdline option takes the priority + val clArgs1 = Seq( + "--master", "yarn", + "--deploy-mode", "cluster", + "--conf", "spark.submit.deployMode=client", + "-class", "org.SomeClass", + "thejar.jar" + ) + val appArgs1 = new SparkSubmitArguments(clArgs1) + val (_, _, sysProps1, _) = prepareSubmitEnvironment(appArgs1) + + appArgs1.deployMode should be ("cluster") + sysProps1("spark.submit.deployMode") should be ("cluster") + + // Neither cmdline nor configuration are specified, client mode is the default choice + val clArgs2 = Seq( + "--master", "yarn", + "--class", "org.SomeClass", + "thejar.jar" + ) + val appArgs2 = new SparkSubmitArguments(clArgs2) + appArgs2.deployMode should be (null) + + val (_, _, sysProps2, _) = prepareSubmitEnvironment(appArgs2) + appArgs2.deployMode should be ("client") + sysProps2("spark.submit.deployMode") should be ("client") + } + test("handles YARN cluster mode") { val clArgs = Seq( "--deploy-mode", "cluster", diff --git a/docs/configuration.md b/docs/configuration.md index 55cf4b2dac5f..38d3d059f9d3 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -48,7 +48,7 @@ The following format is accepted: 1y (years) -Properties that specify a byte size should be configured with a unit of size. +Properties that specify a byte size should be configured with a unit of size. The following format is accepted: 1b (bytes) @@ -192,6 +192,15 @@ of the most common options to set are: allowed master URL's. + + spark.submit.deployMode + (none) + + The deploy mode of Spark driver program, either "client" or "cluster", + Which means to launch driver program locally ("client") + or remotely ("cluster") on one of the nodes inside the cluster. + + Apart from these, the following properties are also available, and may be useful in some situations: @@ -1095,7 +1104,7 @@ Apart from these, the following properties are also available, and may be useful spark.rpc.lookupTimeout 120s - Duration for an RPC remote endpoint lookup operation to wait before timing out. + Duration for an RPC remote endpoint lookup operation to wait before timing out. @@ -1559,7 +1568,7 @@ Apart from these, the following properties are also available, and may be useful spark.streaming.stopGracefullyOnShutdown false - If true, Spark shuts down the StreamingContext gracefully on JVM + If true, Spark shuts down the StreamingContext gracefully on JVM shutdown rather than immediately. diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java b/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java index dd1c93af6ca4..20e6003a00c1 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java @@ -40,6 +40,9 @@ public class SparkLauncher { /** The Spark master. */ public static final String SPARK_MASTER = "spark.master"; + /** The Spark deploy mode. */ + public static final String DEPLOY_MODE = "spark.submit.deployMode"; + /** Configuration key for the driver memory. 
*/ public static final String DRIVER_MEMORY = "spark.driver.memory"; /** Configuration key for the driver class path. */ diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java index 312df0b269f3..a95f0f17517d 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java @@ -294,10 +294,11 @@ private void constructEnvVarArgs( private boolean isClientMode(Map userProps) { String userMaster = firstNonEmpty(master, userProps.get(SparkLauncher.SPARK_MASTER)); - // Default master is "local[*]", so assume client mode in that case. + String userDeployMode = firstNonEmpty(deployMode, userProps.get(SparkLauncher.DEPLOY_MODE)); + // Default master is "local[*]", so assume client mode in that case return userMaster == null || - "client".equals(deployMode) || - (!userMaster.equals("yarn-cluster") && deployMode == null); + "client".equals(userDeployMode) || + (!userMaster.equals("yarn-cluster") && userDeployMode == null); } /** From 8a215d2338c6286253e20122640592f9d69896c8 Mon Sep 17 00:00:00 2001 From: Naveen Date: Tue, 15 Dec 2015 18:25:22 -0800 Subject: [PATCH 1777/1824] [SPARK-9886][CORE] Fix to use ShutdownHookManager in ExternalBlockStore.scala Author: Naveen Closes #10313 from naveenminchu/branch-fix-SPARK-9886. --- .../spark/storage/ExternalBlockStore.scala | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/ExternalBlockStore.scala b/core/src/main/scala/org/apache/spark/storage/ExternalBlockStore.scala index db965d54bafd..94883a54a74e 100644 --- a/core/src/main/scala/org/apache/spark/storage/ExternalBlockStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/ExternalBlockStore.scala @@ -22,7 +22,7 @@ import java.nio.ByteBuffer import scala.util.control.NonFatal import org.apache.spark.Logging -import org.apache.spark.util.Utils +import org.apache.spark.util.{ShutdownHookManager, Utils} /** @@ -177,15 +177,6 @@ private[spark] class ExternalBlockStore(blockManager: BlockManager, executorId: } } - private def addShutdownHook() { - Runtime.getRuntime.addShutdownHook(new Thread("ExternalBlockStore shutdown hook") { - override def run(): Unit = Utils.logUncaughtExceptions { - logDebug("Shutdown hook called") - externalBlockManager.map(_.shutdown()) - } - }) - } - // Create concrete block manager and fall back to Tachyon by default for backward compatibility. 
private def createBlkManager(): Option[ExternalBlockManager] = { val clsName = blockManager.conf.getOption(ExternalBlockStore.BLOCK_MANAGER_NAME) @@ -196,7 +187,10 @@ private[spark] class ExternalBlockStore(blockManager: BlockManager, executorId: .newInstance() .asInstanceOf[ExternalBlockManager] instance.init(blockManager, executorId) - addShutdownHook(); + ShutdownHookManager.addShutdownHook { () => + logDebug("Shutdown hook called") + externalBlockManager.map(_.shutdown()) + } Some(instance) } catch { case NonFatal(t) => From c5b6b398d5e368626e589feede80355fb74c2bd8 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Tue, 15 Dec 2015 18:28:16 -0800 Subject: [PATCH 1778/1824] [SPARK-12062][CORE] Change Master to asyc rebuild UI when application completes This change builds the event history of completed apps asynchronously so the RPC thread will not be blocked and allow new workers to register/remove if the event log history is very large and takes a long time to rebuild. Author: Bryan Cutler Closes #10284 from BryanCutler/async-MasterUI-SPARK-12062. --- .../apache/spark/deploy/master/Master.scala | 79 ++++++++++++------- .../spark/deploy/master/MasterMessages.scala | 2 + 2 files changed, 52 insertions(+), 29 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index 1355e1ad1b52..fc42bf06e40a 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -21,9 +21,11 @@ import java.io.FileNotFoundException import java.net.URLEncoder import java.text.SimpleDateFormat import java.util.Date -import java.util.concurrent.{ScheduledFuture, TimeUnit} +import java.util.concurrent.{ConcurrentHashMap, ScheduledFuture, TimeUnit} import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} +import scala.concurrent.duration.Duration +import scala.concurrent.{Await, ExecutionContext, Future} import scala.language.postfixOps import scala.util.Random @@ -56,6 +58,10 @@ private[deploy] class Master( private val forwardMessageThread = ThreadUtils.newDaemonSingleThreadScheduledExecutor("master-forward-message-thread") + private val rebuildUIThread = + ThreadUtils.newDaemonSingleThreadExecutor("master-rebuild-ui-thread") + private val rebuildUIContext = ExecutionContext.fromExecutor(rebuildUIThread) + private val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf) private def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss") // For application IDs @@ -78,7 +84,8 @@ private[deploy] class Master( private val addressToApp = new HashMap[RpcAddress, ApplicationInfo] private val completedApps = new ArrayBuffer[ApplicationInfo] private var nextAppNumber = 0 - private val appIdToUI = new HashMap[String, SparkUI] + // Using ConcurrentHashMap so that master-rebuild-ui-thread can add a UI after asyncRebuildUI + private val appIdToUI = new ConcurrentHashMap[String, SparkUI] private val drivers = new HashSet[DriverInfo] private val completedDrivers = new ArrayBuffer[DriverInfo] @@ -191,6 +198,7 @@ private[deploy] class Master( checkForWorkerTimeOutTask.cancel(true) } forwardMessageThread.shutdownNow() + rebuildUIThread.shutdownNow() webUi.stop() restServer.foreach(_.stop()) masterMetricsSystem.stop() @@ -367,6 +375,10 @@ private[deploy] class Master( case CheckForWorkerTimeOut => { timeOutDeadWorkers() } + + case AttachCompletedRebuildUI(appId) => + // An asyncRebuildSparkUI has completed, so need to attach to master 
webUi + Option(appIdToUI.get(appId)).foreach { ui => webUi.attachSparkUI(ui) } } override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { @@ -809,7 +821,7 @@ private[deploy] class Master( if (completedApps.size >= RETAINED_APPLICATIONS) { val toRemove = math.max(RETAINED_APPLICATIONS / 10, 1) completedApps.take(toRemove).foreach( a => { - appIdToUI.remove(a.id).foreach { ui => webUi.detachSparkUI(ui) } + Option(appIdToUI.remove(a.id)).foreach { ui => webUi.detachSparkUI(ui) } applicationMetricsSystem.removeSource(a.appSource) }) completedApps.trimStart(toRemove) @@ -818,7 +830,7 @@ private[deploy] class Master( waitingApps -= app // If application events are logged, use them to rebuild the UI - rebuildSparkUI(app) + asyncRebuildSparkUI(app) for (exec <- app.executors.values) { killExecutor(exec) @@ -923,49 +935,57 @@ private[deploy] class Master( * Return the UI if successful, else None */ private[master] def rebuildSparkUI(app: ApplicationInfo): Option[SparkUI] = { + val futureUI = asyncRebuildSparkUI(app) + Await.result(futureUI, Duration.Inf) + } + + /** Rebuild a new SparkUI asynchronously to not block RPC event loop */ + private[master] def asyncRebuildSparkUI(app: ApplicationInfo): Future[Option[SparkUI]] = { val appName = app.desc.name val notFoundBasePath = HistoryServer.UI_PATH_PREFIX + "/not-found" - try { - val eventLogDir = app.desc.eventLogDir - .getOrElse { - // Event logging is not enabled for this application - app.appUIUrlAtHistoryServer = Some(notFoundBasePath) - return None - } - + val eventLogDir = app.desc.eventLogDir + .getOrElse { + // Event logging is disabled for this application + app.appUIUrlAtHistoryServer = Some(notFoundBasePath) + return Future.successful(None) + } + val futureUI = Future { val eventLogFilePrefix = EventLoggingListener.getLogPath( - eventLogDir, app.id, appAttemptId = None, compressionCodecName = app.desc.eventLogCodec) + eventLogDir, app.id, appAttemptId = None, compressionCodecName = app.desc.eventLogCodec) val fs = Utils.getHadoopFileSystem(eventLogDir, hadoopConf) val inProgressExists = fs.exists(new Path(eventLogFilePrefix + - EventLoggingListener.IN_PROGRESS)) + EventLoggingListener.IN_PROGRESS)) - if (inProgressExists) { + val eventLogFile = if (inProgressExists) { // Event logging is enabled for this application, but the application is still in progress logWarning(s"Application $appName is still in progress, it may be terminated abnormally.") - } - - val (eventLogFile, status) = if (inProgressExists) { - (eventLogFilePrefix + EventLoggingListener.IN_PROGRESS, " (in progress)") + eventLogFilePrefix + EventLoggingListener.IN_PROGRESS } else { - (eventLogFilePrefix, " (completed)") + eventLogFilePrefix } val logInput = EventLoggingListener.openEventLog(new Path(eventLogFile), fs) val replayBus = new ReplayListenerBus() val ui = SparkUI.createHistoryUI(new SparkConf, replayBus, new SecurityManager(conf), appName, HistoryServer.UI_PATH_PREFIX + s"/${app.id}", app.startTime) - val maybeTruncated = eventLogFile.endsWith(EventLoggingListener.IN_PROGRESS) try { - replayBus.replay(logInput, eventLogFile, maybeTruncated) + replayBus.replay(logInput, eventLogFile, inProgressExists) } finally { logInput.close() } - appIdToUI(app.id) = ui - webUi.attachSparkUI(ui) + + Some(ui) + }(rebuildUIContext) + + futureUI.onSuccess { case Some(ui) => + appIdToUI.put(app.id, ui) + self.send(AttachCompletedRebuildUI(app.id)) // Application UI is successfully rebuilt, so link the Master UI to it + // NOTE - app.appUIUrlAtHistoryServer 
is volatile app.appUIUrlAtHistoryServer = Some(ui.basePath) - Some(ui) - } catch { + }(ThreadUtils.sameThread) + + futureUI.onFailure { case fnf: FileNotFoundException => // Event logging is enabled for this application, but no event logs are found val title = s"Application history not found (${app.id})" @@ -974,7 +994,7 @@ private[deploy] class Master( msg += " Did you specify the correct logging directory?" msg = URLEncoder.encode(msg, "UTF-8") app.appUIUrlAtHistoryServer = Some(notFoundBasePath + s"?msg=$msg&title=$title") - None + case e: Exception => // Relay exception message to application UI page val title = s"Application history load error (${app.id})" @@ -984,8 +1004,9 @@ private[deploy] class Master( msg = URLEncoder.encode(msg, "UTF-8") app.appUIUrlAtHistoryServer = Some(notFoundBasePath + s"?msg=$msg&exception=$exception&title=$title") - None - } + }(ThreadUtils.sameThread) + + futureUI } /** Generate a new app ID given a app's submission date */ diff --git a/core/src/main/scala/org/apache/spark/deploy/master/MasterMessages.scala b/core/src/main/scala/org/apache/spark/deploy/master/MasterMessages.scala index a952cee36eb4..a055d097674c 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/MasterMessages.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/MasterMessages.scala @@ -39,4 +39,6 @@ private[master] object MasterMessages { case object BoundPortsRequest case class BoundPortsResponse(rpcEndpointPort: Int, webUIPort: Int, restPort: Option[Int]) + + case class AttachCompletedRebuildUI(appId: String) } From a89e8b6122ee5a1517fbcf405b1686619db56696 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 15 Dec 2015 18:29:19 -0800 Subject: [PATCH 1779/1824] [SPARK-10477][SQL] using DSL in ColumnPruningSuite to improve readability Author: Wenchen Fan Closes #8645 from cloud-fan/test. 
--- .../spark/sql/catalyst/dsl/package.scala | 7 ++-- .../optimizer/ColumnPruningSuite.scala | 41 +++++++++++-------- 2 files changed, 27 insertions(+), 21 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index af594c25c54c..e50971173c49 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -275,13 +275,14 @@ package object dsl { def unionAll(otherPlan: LogicalPlan): LogicalPlan = Union(logicalPlan, otherPlan) - // TODO specify the output column names def generate( generator: Generator, join: Boolean = false, outer: Boolean = false, - alias: Option[String] = None): LogicalPlan = - Generate(generator, join = join, outer = outer, alias, Nil, logicalPlan) + alias: Option[String] = None, + outputNames: Seq[String] = Nil): LogicalPlan = + Generate(generator, join = join, outer = outer, alias, + outputNames.map(UnresolvedAttribute(_)), logicalPlan) def insertInto(tableName: String, overwrite: Boolean = false): LogicalPlan = InsertIntoTable( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala index 4a1e7ceaf394..9bf61ae09178 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.expressions.Explode import org.apache.spark.sql.catalyst.plans.PlanTest -import org.apache.spark.sql.catalyst.plans.logical.{Project, LocalRelation, Generate, LogicalPlan} +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ @@ -35,12 +35,11 @@ class ColumnPruningSuite extends PlanTest { test("Column pruning for Generate when Generate.join = false") { val input = LocalRelation('a.int, 'b.array(StringType)) - val query = Generate(Explode('b), false, false, None, 's.string :: Nil, input).analyze + val query = input.generate(Explode('b), join = false).analyze + val optimized = Optimize.execute(query) - val correctAnswer = - Generate(Explode('b), false, false, None, 's.string :: Nil, - Project('b.attr :: Nil, input)).analyze + val correctAnswer = input.select('b).generate(Explode('b), join = false).analyze comparePlans(optimized, correctAnswer) } @@ -49,16 +48,19 @@ class ColumnPruningSuite extends PlanTest { val input = LocalRelation('a.int, 'b.int, 'c.array(StringType)) val query = - Project(Seq('a, 's), - Generate(Explode('c), true, false, None, 's.string :: Nil, - input)).analyze + input + .generate(Explode('c), join = true, outputNames = "explode" :: Nil) + .select('a, 'explode) + .analyze + val optimized = Optimize.execute(query) val correctAnswer = - Project(Seq('a, 's), - Generate(Explode('c), true, false, None, 's.string :: Nil, - Project(Seq('a, 'c), - input))).analyze + input + .select('a, 'c) + .generate(Explode('c), join = true, outputNames = "explode" :: Nil) + .select('a, 'explode) + .analyze comparePlans(optimized, correctAnswer) } @@ -67,15 +69,18 @@ class ColumnPruningSuite extends 
PlanTest { val input = LocalRelation('b.array(StringType)) val query = - Project(('s + 1).as("s+1") :: Nil, - Generate(Explode('b), true, false, None, 's.string :: Nil, - input)).analyze + input + .generate(Explode('b), join = true, outputNames = "explode" :: Nil) + .select(('explode + 1).as("result")) + .analyze + val optimized = Optimize.execute(query) val correctAnswer = - Project(('s + 1).as("s+1") :: Nil, - Generate(Explode('b), false, false, None, 's.string :: Nil, - input)).analyze + input + .generate(Explode('b), join = false, outputNames = "explode" :: Nil) + .select(('explode + 1).as("result")) + .analyze comparePlans(optimized, correctAnswer) } From ca0690b5ef10b14ce57a0c30d5308eb02f163f39 Mon Sep 17 00:00:00 2001 From: Devaraj K Date: Tue, 15 Dec 2015 18:30:59 -0800 Subject: [PATCH 1780/1824] [SPARK-4117][YARN] Spark on Yarn handle AM being told command from RM Spark on Yarn handle AM being told command from RM When RM throws ApplicationAttemptNotFoundException for allocate invocation, making the ApplicationMaster to finish immediately without any retries. Author: Devaraj K Closes #10129 from devaraj-kavali/SPARK-4117. --- .../org/apache/spark/deploy/yarn/ApplicationMaster.scala | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 1970f7d150fe..fc742df73d73 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -376,7 +376,14 @@ private[spark] class ApplicationMaster( case i: InterruptedException => case e: Throwable => { failureCount += 1 - if (!NonFatal(e) || failureCount >= reporterMaxFailures) { + // this exception was introduced in hadoop 2.4 and this code would not compile + // with earlier versions if we refer it directly. + if ("org.apache.hadoop.yarn.exceptions.ApplicationAttemptNotFoundException" == + e.getClass().getName()) { + logError("Exception from Reporter thread.", e) + finish(FinalApplicationStatus.FAILED, ApplicationMaster.EXIT_REPORTER_FAILURE, + e.getMessage) + } else if (!NonFatal(e) || failureCount >= reporterMaxFailures) { finish(FinalApplicationStatus.FAILED, ApplicationMaster.EXIT_REPORTER_FAILURE, "Exception was thrown " + s"$failureCount time(s) from Reporter thread.") From d52bf47e13e0186590437f71040100d2f6f11da9 Mon Sep 17 00:00:00 2001 From: proflin Date: Tue, 15 Dec 2015 20:22:56 -0800 Subject: [PATCH 1781/1824] =?UTF-8?q?[SPARK-12304][STREAMING]=20Make=20Spa?= =?UTF-8?q?rk=20Streaming=20web=20UI=20display=20more=20fri=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …endly Receiver graphs Currently, the Spark Streaming web UI uses the same maxY when displays 'Input Rate Times& Histograms' and 'Per-Receiver Times& Histograms'. This may lead to somewhat un-friendly graphs: once we have tens of Receivers or more, every 'Per-Receiver Times' line almost hits the ground. This issue proposes to calculate a new maxY against the original one, which is shared among all the `Per-Receiver Times& Histograms' graphs. Before: ![before-5](https://cloud.githubusercontent.com/assets/15843379/11761362/d790c356-a0fa-11e5-860e-4b834603de1d.png) After: ![after-5](https://cloud.githubusercontent.com/assets/15843379/11761361/cfabf692-a0fa-11e5-97d0-4ad124aaca2a.png) Author: proflin Closes #10318 from proflin/SPARK-12304. 
--- .../org/apache/spark/streaming/ui/StreamingPage.scala | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala index 88a4483e8068..b3692c3ea302 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala @@ -392,9 +392,15 @@ private[ui] class StreamingPage(parent: StreamingTab) maxX: Long, minY: Double, maxY: Double): Seq[Node] = { + val maxYCalculated = listener.receivedEventRateWithBatchTime.values + .flatMap { case streamAndRates => streamAndRates.map { case (_, eventRate) => eventRate } } + .reduceOption[Double](math.max) + .map(_.ceil.toLong) + .getOrElse(0L) + val content = listener.receivedEventRateWithBatchTime.toList.sortBy(_._1).map { case (streamId, eventRates) => - generateInputDStreamRow(jsCollector, streamId, eventRates, minX, maxX, minY, maxY) + generateInputDStreamRow(jsCollector, streamId, eventRates, minX, maxX, minY, maxYCalculated) }.foldLeft[Seq[Node]](Nil)(_ ++ _) // scalastyle:off From 0f6936b5f1c9b0be1c33b98ffb62a72ae0c3e2a8 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Tue, 15 Dec 2015 22:22:49 -0800 Subject: [PATCH 1782/1824] [SPARK-12249][SQL] JDBC non-equality comparison operator not pushed down. https://issues.apache.org/jira/browse/SPARK-12249 Currently `!=` operator is not pushed down correctly. I simply added a case for this. Author: hyukjinkwon Closes #10233 from HyukjinKwon/SPARK-12249. --- .../apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala | 1 + .../src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala | 2 ++ 2 files changed, 3 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala index 1c348ed62fc7..c18a2d2cc076 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala @@ -281,6 +281,7 @@ private[sql] class JDBCRDD( */ private def compileFilter(f: Filter): String = f match { case EqualTo(attr, value) => s"$attr = ${compileValue(value)}" + case Not(EqualTo(attr, value)) => s"$attr != ${compileValue(value)}" case LessThan(attr, value) => s"$attr < ${compileValue(value)}" case GreaterThan(attr, value) => s"$attr > ${compileValue(value)}" case LessThanOrEqual(attr, value) => s"$attr <= ${compileValue(value)}" diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index a36094715299..aca144305734 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -177,9 +177,11 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext test("SELECT * WHERE (simple predicates)") { assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE THEID < 1")).collect().size === 0) + assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE THEID != 2")).collect().size === 2) assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE THEID = 1")).collect().size === 1) assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE NAME = 'fred'")).collect().size === 1) assert(stripSparkFilter(sql("SELECT * FROM foobar 
WHERE NAME > 'fred'")).collect().size === 2) + assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE NAME != 'fred'")).collect().size === 2) } test("SELECT * WHERE (quoted strings)") { From 7f443a6879fa33ca8adb682bd85df2d56fb5fcda Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Tue, 15 Dec 2015 22:25:08 -0800 Subject: [PATCH 1783/1824] [SPARK-12314][SQL] isnull operator not pushed down for JDBC datasource. https://issues.apache.org/jira/browse/SPARK-12314 `IsNull` filter is not being pushed down for JDBC datasource. It looks it is SQL standard according to [SQL-92](http://www.contrib.andrew.cmu.edu/~shadow/sql/sql1992.txt), SQL:1999, [SQL:2003](http://www.wiscorp.com/sql_2003_standard.zip) and [SQL:201x](http://www.wiscorp.com/sql20nn.zip) and I believe most databases support this. In this PR, I simply added the case for `IsNull` filter to produce a proper filter string. Author: hyukjinkwon This patch had conflicts when merged, resolved by Committer: Reynold Xin Closes #10286 from HyukjinKwon/SPARK-12314. --- .../apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala | 1 + .../src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala | 1 + 2 files changed, 2 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala index c18a2d2cc076..3271b46be18f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala @@ -286,6 +286,7 @@ private[sql] class JDBCRDD( case GreaterThan(attr, value) => s"$attr > ${compileValue(value)}" case LessThanOrEqual(attr, value) => s"$attr <= ${compileValue(value)}" case GreaterThanOrEqual(attr, value) => s"$attr >= ${compileValue(value)}" + case IsNull(attr) => s"$attr IS NULL" case _ => null } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index aca144305734..0305667ff66e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -182,6 +182,7 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE NAME = 'fred'")).collect().size === 1) assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE NAME > 'fred'")).collect().size === 2) assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE NAME != 'fred'")).collect().size === 2) + assert(stripSparkFilter(sql("SELECT * FROM nulltypes WHERE A IS NULL")).collect().size === 1) } test("SELECT * WHERE (quoted strings)") { From 2aad2d372469aaf2773876cae98ef002fef03aa3 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Tue, 15 Dec 2015 22:30:35 -0800 Subject: [PATCH 1784/1824] [SPARK-12315][SQL] isnotnull operator not pushed down for JDBC datasource. https://issues.apache.org/jira/browse/SPARK-12315 `IsNotNull` filter is not being pushed down for JDBC datasource. It looks it is SQL standard according to [SQL-92](http://www.contrib.andrew.cmu.edu/~shadow/sql/sql1992.txt), SQL:1999, [SQL:2003](http://www.wiscorp.com/sql_2003_standard.zip) and [SQL:201x](http://www.wiscorp.com/sql20nn.zip) and I believe most databases support this. In this PR, I simply added the case for `IsNotNull` filter to produce a proper filter string. 
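As a rough sketch of the translation being extended here (not the actual `JDBCRDD` code; `compileFilterSketch` is a made-up name and value quoting is omitted for brevity):

```scala
import org.apache.spark.sql.sources.{EqualTo, Filter, IsNull, Not}

// Sketch: turn a data source Filter into a SQL predicate string, or None when it
// cannot be pushed down to the database.
def compileFilterSketch(f: Filter): Option[String] = f match {
  case EqualTo(attr, value)      => Some(s"$attr = $value")   // real code quotes string values
  case Not(EqualTo(attr, value)) => Some(s"$attr != $value")  // added by SPARK-12249 above
  case IsNull(attr)              => Some(s"$attr IS NULL")    // the case this PR adds
  case _                         => None
}

println(compileFilterSketch(IsNull("NAME")))       // Some(NAME IS NULL)
println(compileFilterSketch(EqualTo("THEID", 1)))  // Some(THEID = 1)
```

Anything that falls through to the wildcard case is left for Spark to evaluate after the rows are fetched, so pushdown here is an optimization rather than a correctness change.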
Author: hyukjinkwon This patch had conflicts when merged, resolved by Committer: Reynold Xin Closes #10287 from HyukjinKwon/SPARK-12315. --- .../apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala | 1 + .../src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala | 2 ++ 2 files changed, 3 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala index 3271b46be18f..2d38562e0901 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala @@ -287,6 +287,7 @@ private[sql] class JDBCRDD( case LessThanOrEqual(attr, value) => s"$attr <= ${compileValue(value)}" case GreaterThanOrEqual(attr, value) => s"$attr >= ${compileValue(value)}" case IsNull(attr) => s"$attr IS NULL" + case IsNotNull(attr) => s"$attr IS NOT NULL" case _ => null } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 0305667ff66e..d6aeb523ea8d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -183,6 +183,8 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE NAME > 'fred'")).collect().size === 2) assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE NAME != 'fred'")).collect().size === 2) assert(stripSparkFilter(sql("SELECT * FROM nulltypes WHERE A IS NULL")).collect().size === 1) + assert(stripSparkFilter( + sql("SELECT * FROM nulltypes WHERE A IS NOT NULL")).collect().size === 0) } test("SELECT * WHERE (quoted strings)") { From 554d840a9ade79722c96972257435a05e2aa9d88 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 15 Dec 2015 22:32:51 -0800 Subject: [PATCH 1785/1824] Style fix for the previous 3 JDBC filter push down commits. 
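Purely illustrative usage, with a hypothetical connection URL and table name: after this change a query like the one below can include the IS NOT NULL predicate in the WHERE clause sent to the database rather than filtering the rows inside Spark.

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext

val sc = new SparkContext(new SparkConf().setAppName("jdbc-null-pushdown").setMaster("local[2]"))
val sqlContext = new SQLContext(sc)

// Hypothetical H2 URL and table; a suitable JDBC driver must be on the classpath
// and the table must already exist for this to actually run.
val people = sqlContext.read.format("jdbc")
  .options(Map("url" -> "jdbc:h2:mem:testdb0", "dbtable" -> "TEST.PEOPLE"))
  .load()

// The generated WHERE clause can now carry "NAME IS NOT NULL".
println(people.filter("NAME IS NOT NULL").count())

sc.stop()
```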
--- .../org/apache/spark/sql/jdbc/JDBCSuite.scala | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index d6aeb523ea8d..2b91f62c2fa2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -176,15 +176,14 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext } test("SELECT * WHERE (simple predicates)") { - assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE THEID < 1")).collect().size === 0) - assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE THEID != 2")).collect().size === 2) - assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE THEID = 1")).collect().size === 1) - assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE NAME = 'fred'")).collect().size === 1) - assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE NAME > 'fred'")).collect().size === 2) - assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE NAME != 'fred'")).collect().size === 2) - assert(stripSparkFilter(sql("SELECT * FROM nulltypes WHERE A IS NULL")).collect().size === 1) - assert(stripSparkFilter( - sql("SELECT * FROM nulltypes WHERE A IS NOT NULL")).collect().size === 0) + assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE THEID < 1")).collect().size == 0) + assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE THEID != 2")).collect().size == 2) + assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE THEID = 1")).collect().size == 1) + assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE NAME = 'fred'")).collect().size == 1) + assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE NAME > 'fred'")).collect().size == 2) + assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE NAME != 'fred'")).collect().size == 2) + assert(stripSparkFilter(sql("SELECT * FROM nulltypes WHERE A IS NULL")).collect().size == 1) + assert(stripSparkFilter(sql("SELECT * FROM nulltypes WHERE A IS NOT NULL")).collect().size == 0) } test("SELECT * WHERE (quoted strings)") { From 18ea11c3a84e5eafd81fa0fe7c09224e79c4e93f Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 16 Dec 2015 00:57:07 -0800 Subject: [PATCH 1786/1824] Revert "[HOTFIX] Compile error from commit 31b3910" This reverts commit 840bd2e008da5b22bfa73c587ea2c57666fffc60. --- sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 33b03be1138b..b69d4411425d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -234,7 +234,7 @@ class DataFrame private[sql]( // For Data that has more than "_numRows" records if (hasMoreData) { val rowsString = if (_numRows == 1) "row" else "rows" - sb.append(s"only showing top ${_numRows} $rowsString\n") + sb.append(s"only showing top $_numRows $rowsString\n") } sb.toString() From 1a3d0cd9f013aee1f03b1c632c91ae0951bccbb0 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 16 Dec 2015 00:57:34 -0800 Subject: [PATCH 1787/1824] Revert "[SPARK-12105] [SQL] add convenient show functions" This reverts commit 31b391019ff6eb5a483f4b3e62fd082de7ff8416. 
--- .../org/apache/spark/sql/DataFrame.scala | 25 +++++++------------ 1 file changed, 9 insertions(+), 16 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index b69d4411425d..497bd4826677 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -160,24 +160,17 @@ class DataFrame private[sql]( } } - /** - * Compose the string representing rows for output - */ - def showString(): String = { - showString(20) - } - /** * Compose the string representing rows for output - * @param numRows Number of rows to show + * @param _numRows Number of rows to show * @param truncate Whether truncate long strings and align cells right */ - def showString(numRows: Int, truncate: Boolean = true): String = { - val _numRows = numRows.max(0) + private[sql] def showString(_numRows: Int, truncate: Boolean = true): String = { + val numRows = _numRows.max(0) val sb = new StringBuilder - val takeResult = take(_numRows + 1) - val hasMoreData = takeResult.length > _numRows - val data = takeResult.take(_numRows) + val takeResult = take(numRows + 1) + val hasMoreData = takeResult.length > numRows + val data = takeResult.take(numRows) val numCols = schema.fieldNames.length // For array values, replace Seq and Array with square brackets @@ -231,10 +224,10 @@ class DataFrame private[sql]( sb.append(sep) - // For Data that has more than "_numRows" records + // For Data that has more than "numRows" records if (hasMoreData) { - val rowsString = if (_numRows == 1) "row" else "rows" - sb.append(s"only showing top $_numRows $rowsString\n") + val rowsString = if (numRows == 1) "row" else "rows" + sb.append(s"only showing top $numRows $rowsString\n") } sb.toString() From a6325fc401f68d9fa30cc947c44acc9d64ebda7b Mon Sep 17 00:00:00 2001 From: Timothy Hunter Date: Wed, 16 Dec 2015 10:12:33 -0800 Subject: [PATCH 1788/1824] [SPARK-12324][MLLIB][DOC] Fixes the sidebar in the ML documentation This fixes the sidebar, using a pure CSS mechanism to hide it when the browser's viewport is too narrow. Credit goes to the original author Titan-C (mentioned in the NOTICE). Note that I am not a CSS expert, so I can only address comments up to some extent. Default view: screen shot 2015-12-14 at 12 46 39 pm When collapsed manually by the user: screen shot 2015-12-14 at 12 54 02 pm Disappears when column is too narrow: screen shot 2015-12-14 at 12 47 22 pm Can still be opened by the user if necessary: screen shot 2015-12-14 at 12 51 15 pm Author: Timothy Hunter Closes #10297 from thunterdb/12324. --- NOTICE | 9 ++- docs/_layouts/global.html | 35 +++++++--- docs/css/main.css | 137 ++++++++++++++++++++++++++++++++------ docs/js/main.js | 2 +- 4 files changed, 149 insertions(+), 34 deletions(-) diff --git a/NOTICE b/NOTICE index 7f7769f73047..571f8c2fff7f 100644 --- a/NOTICE +++ b/NOTICE @@ -606,4 +606,11 @@ Vis.js uses and redistributes the following third-party libraries: - keycharm https://github.com/AlexDM0/keycharm - The MIT License \ No newline at end of file + The MIT License + +=============================================================================== + +The CSS style for the navigation sidebar of the documentation was originally +submitted by Óscar Nájera for the scikit-learn project. The scikit-learn project +is distributed under the 3-Clause BSD license. 
+=============================================================================== diff --git a/docs/_layouts/global.html b/docs/_layouts/global.html index 0b5b0cd48af6..3089474c1338 100755 --- a/docs/_layouts/global.html +++ b/docs/_layouts/global.html @@ -1,3 +1,4 @@ + @@ -127,20 +128,32 @@
    {% if page.url contains "/ml" %} - {% include nav-left-wrapper-ml.html nav-mllib=site.data.menu-mllib nav-ml=site.data.menu-ml %} - {% endif %} - + {% include nav-left-wrapper-ml.html nav-mllib=site.data.menu-mllib nav-ml=site.data.menu-ml %} + + +
    + {% if page.displayTitle %} +

    {{ page.displayTitle }}

    + {% else %} +

    {{ page.title }}

    + {% endif %} + + {{ content }} -
    - {% if page.displayTitle %} -

    {{ page.displayTitle }}

    - {% else %} -

    {{ page.title }}

    - {% endif %} +
    + {% else %} +
    + {% if page.displayTitle %} +

    {{ page.displayTitle }}

    + {% else %} +

    {{ page.title }}

    + {% endif %} - {{ content }} + {{ content }} -
    +
    + {% endif %} +
    diff --git a/docs/css/main.css b/docs/css/main.css index 356b324d6303..175e8004fca0 100755 --- a/docs/css/main.css +++ b/docs/css/main.css @@ -40,17 +40,14 @@ } body .container-wrapper { - position: absolute; - width: 100%; - display: flex; -} - -body #content { + background-color: #FFF; + color: #1D1F22; + max-width: 1024px; + margin-top: 10px; + margin-left: auto; + margin-right: auto; + border-radius: 15px; position: relative; - - line-height: 1.6; /* Inspired by Github's wiki style */ - background-color: white; - padding-left: 15px; } .title { @@ -101,6 +98,24 @@ a:hover code { max-width: 914px; } +.content { + z-index: 1; + position: relative; + background-color: #FFF; + max-width: 914px; + line-height: 1.6; /* Inspired by Github's wiki style */ + padding-left: 15px; +} + +.content-with-sidebar { + z-index: 1; + position: relative; + background-color: #FFF; + max-width: 914px; + line-height: 1.6; /* Inspired by Github's wiki style */ + padding-left: 30px; +} + .dropdown-menu { /* Remove the default 2px top margin which causes a small gap between the hover trigger area and the popup menu */ @@ -171,24 +186,104 @@ a.anchorjs-link:hover { text-decoration: none; } * The left navigation bar. */ .left-menu-wrapper { - position: absolute; - height: 100%; - - width: 256px; - margin-top: -20px; - padding-top: 20px; + margin-left: 0px; + margin-right: 0px; background-color: #F0F8FC; + border-top-width: 0px; + border-left-width: 0px; + border-bottom-width: 0px; + margin-top: 0px; + width: 210px; + float: left; + position: absolute; } .left-menu { - position: fixed; - max-width: 350px; - - padding-right: 10px; - width: 256px; + padding: 0px; + width: 199px; } .left-menu h3 { margin-left: 10px; line-height: 30px; +} + +/** + * The collapsing button for the navigation bar. + */ +.nav-trigger { + position: fixed; + clip: rect(0, 0, 0, 0); +} + +.nav-trigger + label:after { + content: '»'; +} + +label { + z-index: 10; +} + +label[for="nav-trigger"] { + position: fixed; + margin-left: 0px; + padding-top: 100px; + padding-left: 5px; + width: 10px; + height: 80%; + cursor: pointer; + background-size: contain; + background-color: #D4F0FF; +} + +label[for="nav-trigger"]:hover { + background-color: #BEE9FF; +} + +.nav-trigger:checked + label { + margin-left: 200px; +} + +.nav-trigger:checked + label:after { + content: '«'; +} + +.nav-trigger:checked ~ div.content-with-sidebar { + margin-left: 200px; +} + +.nav-trigger + label, div.content-with-sidebar { + transition: left 0.4s; +} + +/** + * Rules to collapse the menu automatically when the screen becomes too thin. + */ + +@media all and (max-width: 780px) { + + div.content-with-sidebar { + margin-left: 200px; + } + .nav-trigger + label:after { + content: '«'; + } + label[for="nav-trigger"] { + margin-left: 200px; + } + + .nav-trigger:checked + label { + margin-left: 0px; + } + .nav-trigger:checked + label:after { + content: '»'; + } + .nav-trigger:checked ~ div.content-with-sidebar { + margin-left: 0px; + } + + div.container-index { + margin-left: -215px; + } + } \ No newline at end of file diff --git a/docs/js/main.js b/docs/js/main.js index f5d66b16f7b2..2329eb8327dd 100755 --- a/docs/js/main.js +++ b/docs/js/main.js @@ -83,7 +83,7 @@ $(function() { // Display anchor links when hovering over headers. For documentation of the // configuration options, see the AnchorJS documentation. 
anchors.options = { - placement: 'left' + placement: 'right' }; anchors.add(); From 54c512ba906edfc25b8081ad67498e99d884452b Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 16 Dec 2015 10:22:48 -0800 Subject: [PATCH 1789/1824] [SPARK-8745] [SQL] remove GenerateProjection cc rxin Author: Davies Liu Closes #10316 from davies/remove_generate_projection. --- .../codegen/GenerateProjection.scala | 238 ------------------ .../codegen/GenerateSafeProjection.scala | 4 + .../expressions/CodeGenerationSuite.scala | 5 +- .../expressions/ExpressionEvalHelper.scala | 39 +-- .../expressions/MathFunctionsSuite.scala | 4 +- .../CodegenExpressionCachingSuite.scala | 18 -- .../sql/execution/local/ExpandNode.scala | 4 +- .../spark/sql/execution/local/LocalNode.scala | 18 -- 8 files changed, 11 insertions(+), 319 deletions(-) delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala deleted file mode 100644 index f229f2000d8e..000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala +++ /dev/null @@ -1,238 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.expressions.codegen - -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.types._ - -/** - * Java can not access Projection (in package object) - */ -abstract class BaseProjection extends Projection {} - -abstract class CodeGenMutableRow extends MutableRow with BaseGenericInternalRow - -/** - * Generates bytecode that produces a new [[InternalRow]] object based on a fixed set of input - * [[Expression Expressions]] and a given input [[InternalRow]]. The returned [[InternalRow]] - * object is custom generated based on the output types of the [[Expression]] to avoid boxing of - * primitive values. - */ -object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { - - protected def canonicalize(in: Seq[Expression]): Seq[Expression] = - in.map(ExpressionCanonicalizer.execute) - - protected def bind(in: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] = - in.map(BindReferences.bindReference(_, inputSchema)) - - // Make Mutablility optional... 
- protected def create(expressions: Seq[Expression]): Projection = { - val ctx = newCodeGenContext() - val columns = expressions.zipWithIndex.map { - case (e, i) => - s"private ${ctx.javaType(e.dataType)} c$i = ${ctx.defaultValue(e.dataType)};\n" - }.mkString("\n") - - val initColumns = expressions.zipWithIndex.map { - case (e, i) => - val eval = e.gen(ctx) - s""" - { - // column$i - ${eval.code} - nullBits[$i] = ${eval.isNull}; - if (!${eval.isNull}) { - c$i = ${eval.value}; - } - } - """ - }.mkString("\n") - - val getCases = (0 until expressions.size).map { i => - s"case $i: return c$i;" - }.mkString("\n") - - val updateCases = expressions.zipWithIndex.map { case (e, i) => - s"case $i: { c$i = (${ctx.boxedType(e.dataType)})value; return;}" - }.mkString("\n") - - val specificAccessorFunctions = ctx.primitiveTypes.map { jt => - val cases = expressions.zipWithIndex.flatMap { - case (e, i) if ctx.javaType(e.dataType) == jt => - Some(s"case $i: return c$i;") - case _ => None - }.mkString("\n") - if (cases.length > 0) { - val getter = "get" + ctx.primitiveTypeName(jt) - s""" - public $jt $getter(int i) { - if (isNullAt(i)) { - return ${ctx.defaultValue(jt)}; - } - switch (i) { - $cases - } - throw new IllegalArgumentException("Invalid index: " + i - + " in $getter"); - }""" - } else { - "" - } - }.filter(_.length > 0).mkString("\n") - - val specificMutatorFunctions = ctx.primitiveTypes.map { jt => - val cases = expressions.zipWithIndex.flatMap { - case (e, i) if ctx.javaType(e.dataType) == jt => - Some(s"case $i: { c$i = value; return; }") - case _ => None - }.mkString("\n") - if (cases.length > 0) { - val setter = "set" + ctx.primitiveTypeName(jt) - s""" - public void $setter(int i, $jt value) { - nullBits[i] = false; - switch (i) { - $cases - } - throw new IllegalArgumentException("Invalid index: " + i + - " in $setter}"); - }""" - } else { - "" - } - }.filter(_.length > 0).mkString("\n") - - val hashValues = expressions.zipWithIndex.map { case (e, i) => - val col = s"c$i" - val nonNull = e.dataType match { - case BooleanType => s"$col ? 0 : 1" - case ByteType | ShortType | IntegerType | DateType => s"$col" - case LongType | TimestampType => s"$col ^ ($col >>> 32)" - case FloatType => s"Float.floatToIntBits($col)" - case DoubleType => - s"(int)(Double.doubleToLongBits($col) ^ (Double.doubleToLongBits($col) >>> 32))" - case BinaryType => s"java.util.Arrays.hashCode($col)" - case _ => s"$col.hashCode()" - } - s"isNullAt($i) ? 0 : ($nonNull)" - } - - val hashUpdates: String = hashValues.map( v => - s""" - result *= 37; result += $v;""" - ).mkString("\n") - - val columnChecks = expressions.zipWithIndex.map { case (e, i) => - s""" - if (nullBits[$i] != row.nullBits[$i] || - (!nullBits[$i] && !(${ctx.genEqual(e.dataType, s"c$i", s"row.c$i")}))) { - return false; - } - """ - }.mkString("\n") - - val copyColumns = expressions.zipWithIndex.map { case (e, i) => - s"""if (!nullBits[$i]) arr[$i] = c$i;""" - }.mkString("\n") - - val code = s""" - public SpecificProjection generate($exprType[] expr) { - return new SpecificProjection(expr); - } - - class SpecificProjection extends ${classOf[BaseProjection].getName} { - private $exprType[] expressions; - ${declareMutableStates(ctx)} - ${declareAddedFunctions(ctx)} - - public SpecificProjection($exprType[] expr) { - expressions = expr; - ${initMutableStates(ctx)} - } - - public java.lang.Object apply(java.lang.Object r) { - // GenerateProjection does not work with UnsafeRows. 
- assert(!(r instanceof ${classOf[UnsafeRow].getName})); - return new SpecificRow((InternalRow) r); - } - - final class SpecificRow extends ${classOf[CodeGenMutableRow].getName} { - - $columns - - public SpecificRow(InternalRow ${ctx.INPUT_ROW}) { - $initColumns - } - - public int numFields() { return ${expressions.length};} - protected boolean[] nullBits = new boolean[${expressions.length}]; - public void setNullAt(int i) { nullBits[i] = true; } - public boolean isNullAt(int i) { return nullBits[i]; } - - public java.lang.Object genericGet(int i) { - if (isNullAt(i)) return null; - switch (i) { - $getCases - } - return null; - } - public void update(int i, java.lang.Object value) { - if (value == null) { - setNullAt(i); - return; - } - nullBits[i] = false; - switch (i) { - $updateCases - } - } - $specificAccessorFunctions - $specificMutatorFunctions - - public int hashCode() { - int result = 37; - $hashUpdates - return result; - } - - public boolean equals(java.lang.Object other) { - if (other instanceof SpecificRow) { - SpecificRow row = (SpecificRow) other; - $columnChecks - return true; - } - return super.equals(other); - } - - public InternalRow copy() { - java.lang.Object[] arr = new java.lang.Object[${expressions.length}]; - ${copyColumns} - return new ${classOf[GenericInternalRow].getName}(arr); - } - } - } - """ - - logDebug(s"MutableRow, initExprs: ${expressions.mkString(",")} code:\n" + - CodeFormatter.format(code)) - - compile(code).generate(ctx.references.toArray).asInstanceOf[Projection] - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala index b7926bda3de1..13634b69457a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala @@ -22,6 +22,10 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayBasedMapData} import org.apache.spark.sql.types._ +/** + * Java can not access Projection (in package object) + */ +abstract class BaseProjection extends Projection {} /** * Generates byte code that produces a [[MutableRow]] object (not an [[UnsafeRow]]) that can update diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index cd2ef7dcd0cd..0c42e2fc7c5e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -18,8 +18,8 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.{Row, RandomDataGenerator} -import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.types._ @@ -38,7 +38,6 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { val futures = (1 to 20).map { _ => future { 
GeneratePredicate.generate(EqualTo(Literal(1), Literal(1))) - GenerateProjection.generate(EqualTo(Literal(1), Literal(1)) :: Nil) GenerateMutableProjection.generate(EqualTo(Literal(1), Literal(1)) :: Nil) GenerateOrdering.generate(Add(Literal(1), Literal(1)).asc :: Nil) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index 465f7d08aa14..f869a96edb1c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -43,7 +43,6 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { val catalystValue = CatalystTypeConverters.convertToCatalyst(expected) checkEvaluationWithoutCodegen(expression, catalystValue, inputRow) checkEvaluationWithGeneratedMutableProjection(expression, catalystValue, inputRow) - checkEvaluationWithGeneratedProjection(expression, catalystValue, inputRow) if (GenerateUnsafeProjection.canSupport(expression.dataType)) { checkEvalutionWithUnsafeProjection(expression, catalystValue, inputRow) } @@ -120,42 +119,6 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { } } - protected def checkEvaluationWithGeneratedProjection( - expression: Expression, - expected: Any, - inputRow: InternalRow = EmptyRow): Unit = { - - val plan = generateProject( - GenerateProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil), - expression) - - val actual = plan(inputRow) - val expectedRow = InternalRow(expected) - - // We reimplement hashCode in generated `SpecificRow`, make sure it's consistent with our - // interpreted version. 
- if (actual.hashCode() != expectedRow.hashCode()) { - val ctx = new CodeGenContext - val evaluated = expression.gen(ctx) - fail( - s""" - |Mismatched hashCodes for values: $actual, $expectedRow - |Hash Codes: ${actual.hashCode()} != ${expectedRow.hashCode()} - |Expressions: $expression - |Code: $evaluated - """.stripMargin) - } - - if (actual != expectedRow) { - val input = if (inputRow == EmptyRow) "" else s", input: $inputRow" - fail("Incorrect Evaluation in codegen mode: " + - s"$expression, actual: $actual, expected: $expectedRow$input") - } - if (actual.copy() != expectedRow) { - fail(s"Copy of generated Row is wrong: actual: ${actual.copy()}, expected: $expectedRow") - } - } - protected def checkEvalutionWithUnsafeProjection( expression: Expression, expected: Any, @@ -202,7 +165,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { checkEvaluationWithOptimization(expression, expected) var plan = generateProject( - GenerateProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil), + GenerateMutableProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil)(), expression) var actual = plan(inputRow).get(0, expression.dataType) assert(checkResult(actual, expected)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala index 88ed9fdd6465..4ad65db0977c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala @@ -20,9 +20,9 @@ package org.apache.spark.sql.catalyst.expressions import com.google.common.math.LongMath import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateProjection, GenerateMutableProjection} import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection import org.apache.spark.sql.catalyst.optimizer.DefaultOptimizer import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project} import org.apache.spark.sql.types._ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenExpressionCachingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenExpressionCachingSuite.scala index 2d3f98dbbd3d..c9616cdb26c2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenExpressionCachingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenExpressionCachingSuite.scala @@ -34,12 +34,6 @@ class CodegenExpressionCachingSuite extends SparkFunSuite { assert(instance.apply(null).getBoolean(0) === false) } - test("GenerateProjection should initialize expressions") { - val expr = And(NondeterministicExpression(), NondeterministicExpression()) - val instance = GenerateProjection.generate(Seq(expr)) - assert(instance.apply(null).getBoolean(0) === false) - } - test("GenerateMutableProjection should initialize expressions") { val expr = And(NondeterministicExpression(), NondeterministicExpression()) val instance = GenerateMutableProjection.generate(Seq(expr))() @@ -64,18 +58,6 @@ class CodegenExpressionCachingSuite extends SparkFunSuite { 
assert(instance2.apply(null).getBoolean(0) === true) } - test("GenerateProjection should not share expression instances") { - val expr1 = MutableExpression() - val instance1 = GenerateProjection.generate(Seq(expr1)) - assert(instance1.apply(null).getBoolean(0) === false) - - val expr2 = MutableExpression() - expr2.mutableState = true - val instance2 = GenerateProjection.generate(Seq(expr2)) - assert(instance1.apply(null).getBoolean(0) === false) - assert(instance2.apply(null).getBoolean(0) === true) - } - test("GenerateMutableProjection should not share expression instances") { val expr1 = MutableExpression() val instance1 = GenerateMutableProjection.generate(Seq(expr1))() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ExpandNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ExpandNode.scala index 2aff156d18b5..85111bd6d1c9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ExpandNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ExpandNode.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.local import org.apache.spark.sql.SQLConf import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, Projection} +import org.apache.spark.sql.catalyst.expressions._ case class ExpandNode( conf: SQLConf, @@ -36,7 +36,7 @@ case class ExpandNode( override def open(): Unit = { child.open() - groups = projections.map(ee => newProjection(ee, child.output)).toArray + groups = projections.map(ee => newMutableProjection(ee, child.output)()).toArray idx = groups.length } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala index d3381eac91d4..6a882c9234df 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala @@ -103,24 +103,6 @@ abstract class LocalNode(conf: SQLConf) extends QueryPlan[LocalNode] with Loggin result } - protected def newProjection( - expressions: Seq[Expression], - inputSchema: Seq[Attribute]): Projection = { - log.debug( - s"Creating Projection: $expressions, inputSchema: $inputSchema") - try { - GenerateProjection.generate(expressions, inputSchema) - } catch { - case NonFatal(e) => - if (isTesting) { - throw e - } else { - log.error("Failed to generate projection, fallback to interpret", e) - new InterpretedProjection(expressions, inputSchema) - } - } - } - protected def newMutableProjection( expressions: Seq[Expression], inputSchema: Seq[Attribute]): () => MutableProjection = { From 2eb5af5f0d3c424dc617bb1a18dd0210ea9ba0bc Mon Sep 17 00:00:00 2001 From: Jeff Zhang Date: Wed, 16 Dec 2015 10:32:32 -0800 Subject: [PATCH 1790/1824] [SPARK-12318][SPARKR] Save mode in SparkR should be error by default shivaram Please help review. Author: Jeff Zhang Closes #10290 from zjffdu/SPARK-12318. 
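For context, and as a hedged sketch rather than part of the patch: the Scala `DataFrameWriter` that SparkR now mirrors defaults to `SaveMode.ErrorIfExists`, so a write to an existing path fails unless a mode is chosen explicitly. The output path below is arbitrary.

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.{SQLContext, SaveMode}

val sc = new SparkContext(new SparkConf().setAppName("save-mode-sketch").setMaster("local[2]"))
val sqlContext = new SQLContext(sc)
import sqlContext.implicits._

val df = sc.parallelize(Seq(("Alice", 29), ("Bob", 35))).toDF("name", "age")

df.write.parquet("/tmp/people.parquet")                           // default: error if the path already exists
df.write.mode(SaveMode.Overwrite).parquet("/tmp/people.parquet")  // replace existing data explicitly
df.write.mode(SaveMode.Append).parquet("/tmp/people.parquet")     // or add to it

sc.stop()
```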
--- R/pkg/R/DataFrame.R | 10 +++++----- docs/sparkr.md | 9 ++++++++- 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 764597d1e32b..380a13fe2b7c 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -1886,7 +1886,7 @@ setMethod("except", #' @param df A SparkSQL DataFrame #' @param path A name for the table #' @param source A name for external data source -#' @param mode One of 'append', 'overwrite', 'error', 'ignore' save mode +#' @param mode One of 'append', 'overwrite', 'error', 'ignore' save mode (it is 'error' by default) #' #' @family DataFrame functions #' @rdname write.df @@ -1903,7 +1903,7 @@ setMethod("except", #' } setMethod("write.df", signature(df = "DataFrame", path = "character"), - function(df, path, source = NULL, mode = "append", ...){ + function(df, path, source = NULL, mode = "error", ...){ if (is.null(source)) { sqlContext <- get(".sparkRSQLsc", envir = .sparkREnv) source <- callJMethod(sqlContext, "getConf", "spark.sql.sources.default", @@ -1928,7 +1928,7 @@ setMethod("write.df", #' @export setMethod("saveDF", signature(df = "DataFrame", path = "character"), - function(df, path, source = NULL, mode = "append", ...){ + function(df, path, source = NULL, mode = "error", ...){ write.df(df, path, source, mode, ...) }) @@ -1951,7 +1951,7 @@ setMethod("saveDF", #' @param df A SparkSQL DataFrame #' @param tableName A name for the table #' @param source A name for external data source -#' @param mode One of 'append', 'overwrite', 'error', 'ignore' save mode +#' @param mode One of 'append', 'overwrite', 'error', 'ignore' save mode (it is 'error' by default) #' #' @family DataFrame functions #' @rdname saveAsTable @@ -1968,7 +1968,7 @@ setMethod("saveDF", setMethod("saveAsTable", signature(df = "DataFrame", tableName = "character", source = "character", mode = "character"), - function(df, tableName, source = NULL, mode="append", ...){ + function(df, tableName, source = NULL, mode="error", ...){ if (is.null(source)) { sqlContext <- get(".sparkRSQLsc", envir = .sparkREnv) source <- callJMethod(sqlContext, "getConf", "spark.sql.sources.default", diff --git a/docs/sparkr.md b/docs/sparkr.md index 01148786b79d..9ddd2eda3fe8 100644 --- a/docs/sparkr.md +++ b/docs/sparkr.md @@ -148,7 +148,7 @@ printSchema(people)
    The data sources API can also be used to save out DataFrames into multiple file formats. For example we can save the DataFrame from the previous example -to a Parquet file using `write.df` +to a Parquet file using `write.df` (Until Spark 1.6, the default mode for writes was `append`. It was changed in Spark 1.7 to `error` to match the Scala API)
    {% highlight r %} @@ -387,3 +387,10 @@ The following functions are masked by the SparkR package: Since part of SparkR is modeled on the `dplyr` package, certain functions in SparkR share the same names with those in `dplyr`. Depending on the load order of the two packages, some functions from the package loaded first are masked by those in the package loaded after. In such case, prefix such calls with the package name, for instance, `SparkR::cume_dist(x)` or `dplyr::cume_dist(x)`. You can inspect the search path in R with [`search()`](https://stat.ethz.ch/R-manual/R-devel/library/base/html/search.html) + + +# Migration Guide + +## Upgrading From SparkR 1.6 to 1.7 + + - Until Spark 1.6, the default mode for writes was `append`. It was changed in Spark 1.7 to `error` to match the Scala API. From 22f6cd86fc2e2d6f6ad2c3aae416732c46ebf1b1 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Wed, 16 Dec 2015 10:34:30 -0800 Subject: [PATCH 1791/1824] [SPARK-12310][SPARKR] Add write.json and write.parquet for SparkR Add ```write.json``` and ```write.parquet``` for SparkR, and deprecated ```saveAsParquetFile```. Author: Yanbo Liang Closes #10281 from yanboliang/spark-12310. --- R/pkg/NAMESPACE | 4 +- R/pkg/R/DataFrame.R | 51 +++++++++-- R/pkg/R/generics.R | 16 +++- R/pkg/inst/tests/testthat/test_sparkSQL.R | 104 ++++++++++++---------- 4 files changed, 119 insertions(+), 56 deletions(-) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index cab39d68c3f5..ccc01fe16960 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -92,7 +92,9 @@ exportMethods("arrange", "with", "withColumn", "withColumnRenamed", - "write.df") + "write.df", + "write.json", + "write.parquet") exportClasses("Column") diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 380a13fe2b7c..0cfa12b997d6 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -596,17 +596,44 @@ setMethod("toJSON", RDD(jrdd, serializedMode = "string") }) -#' saveAsParquetFile +#' write.json +#' +#' Save the contents of a DataFrame as a JSON file (one object per line). Files written out +#' with this method can be read back in as a DataFrame using read.json(). +#' +#' @param x A SparkSQL DataFrame +#' @param path The directory where the file is saved +#' +#' @family DataFrame functions +#' @rdname write.json +#' @name write.json +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlContext <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- read.json(sqlContext, path) +#' write.json(df, "/tmp/sparkr-tmp/") +#'} +setMethod("write.json", + signature(x = "DataFrame", path = "character"), + function(x, path) { + write <- callJMethod(x@sdf, "write") + invisible(callJMethod(write, "json", path)) + }) + +#' write.parquet #' #' Save the contents of a DataFrame as a Parquet file, preserving the schema. Files written out -#' with this method can be read back in as a DataFrame using parquetFile(). +#' with this method can be read back in as a DataFrame using read.parquet(). 
#' #' @param x A SparkSQL DataFrame #' @param path The directory where the file is saved #' #' @family DataFrame functions -#' @rdname saveAsParquetFile -#' @name saveAsParquetFile +#' @rdname write.parquet +#' @name write.parquet #' @export #' @examples #'\dontrun{ @@ -614,12 +641,24 @@ setMethod("toJSON", #' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" #' df <- read.json(sqlContext, path) -#' saveAsParquetFile(df, "/tmp/sparkr-tmp/") +#' write.parquet(df, "/tmp/sparkr-tmp1/") +#' saveAsParquetFile(df, "/tmp/sparkr-tmp2/") #'} +setMethod("write.parquet", + signature(x = "DataFrame", path = "character"), + function(x, path) { + write <- callJMethod(x@sdf, "write") + invisible(callJMethod(write, "parquet", path)) + }) + +#' @rdname write.parquet +#' @name saveAsParquetFile +#' @export setMethod("saveAsParquetFile", signature(x = "DataFrame", path = "character"), function(x, path) { - invisible(callJMethod(x@sdf, "saveAsParquetFile", path)) + .Deprecated("write.parquet") + write.parquet(x, path) }) #' Distinct diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index c383e6e78b8b..62be2ddc8f52 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -519,10 +519,6 @@ setGeneric("sample_frac", #' @export setGeneric("sampleBy", function(x, col, fractions, seed) { standardGeneric("sampleBy") }) -#' @rdname saveAsParquetFile -#' @export -setGeneric("saveAsParquetFile", function(x, path) { standardGeneric("saveAsParquetFile") }) - #' @rdname saveAsTable #' @export setGeneric("saveAsTable", function(df, tableName, source, mode, ...) { @@ -541,6 +537,18 @@ setGeneric("write.df", function(df, path, ...) { standardGeneric("write.df") }) #' @export setGeneric("saveDF", function(df, path, ...) { standardGeneric("saveDF") }) +#' @rdname write.json +#' @export +setGeneric("write.json", function(x, path) { standardGeneric("write.json") }) + +#' @rdname write.parquet +#' @export +setGeneric("write.parquet", function(x, path) { standardGeneric("write.parquet") }) + +#' @rdname write.parquet +#' @export +setGeneric("saveAsParquetFile", function(x, path) { standardGeneric("saveAsParquetFile") }) + #' @rdname schema #' @export setGeneric("schema", function(x) { standardGeneric("schema") }) diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index 071fd310fd58..135c7576e529 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -371,22 +371,49 @@ test_that("Collect DataFrame with complex types", { expect_equal(bob$height, 176.5) }) -test_that("read.json()/jsonFile() on a local file returns a DataFrame", { +test_that("read/write json files", { + # Test read.df + df <- read.df(sqlContext, jsonPath, "json") + expect_is(df, "DataFrame") + expect_equal(count(df), 3) + + # Test read.df with a user defined schema + schema <- structType(structField("name", type = "string"), + structField("age", type = "double")) + + df1 <- read.df(sqlContext, jsonPath, "json", schema) + expect_is(df1, "DataFrame") + expect_equal(dtypes(df1), list(c("name", "string"), c("age", "double"))) + + # Test loadDF + df2 <- loadDF(sqlContext, jsonPath, "json", schema) + expect_is(df2, "DataFrame") + expect_equal(dtypes(df2), list(c("name", "string"), c("age", "double"))) + + # Test read.json df <- read.json(sqlContext, jsonPath) expect_is(df, "DataFrame") expect_equal(count(df), 3) - # read.json()/jsonFile() works with multiple input paths + + # Test write.df jsonPath2 <- tempfile(pattern="jsonPath2", fileext=".json") 
write.df(df, jsonPath2, "json", mode="overwrite") - jsonDF1 <- read.json(sqlContext, c(jsonPath, jsonPath2)) + + # Test write.json + jsonPath3 <- tempfile(pattern="jsonPath3", fileext=".json") + write.json(df, jsonPath3) + + # Test read.json()/jsonFile() works with multiple input paths + jsonDF1 <- read.json(sqlContext, c(jsonPath2, jsonPath3)) expect_is(jsonDF1, "DataFrame") expect_equal(count(jsonDF1), 6) # Suppress warnings because jsonFile is deprecated - jsonDF2 <- suppressWarnings(jsonFile(sqlContext, c(jsonPath, jsonPath2))) + jsonDF2 <- suppressWarnings(jsonFile(sqlContext, c(jsonPath2, jsonPath3))) expect_is(jsonDF2, "DataFrame") expect_equal(count(jsonDF2), 6) unlink(jsonPath2) + unlink(jsonPath3) }) test_that("jsonRDD() on a RDD with json string", { @@ -454,6 +481,9 @@ test_that("insertInto() on a registered table", { expect_equal(count(sql(sqlContext, "select * from table1")), 2) expect_equal(first(sql(sqlContext, "select * from table1 order by age"))$name, "Bob") dropTempTable(sqlContext, "table1") + + unlink(jsonPath2) + unlink(parquetPath2) }) test_that("table() returns a new DataFrame", { @@ -848,33 +878,6 @@ test_that("column calculation", { expect_equal(count(df2), 3) }) -test_that("read.df() from json file", { - df <- read.df(sqlContext, jsonPath, "json") - expect_is(df, "DataFrame") - expect_equal(count(df), 3) - - # Check if we can apply a user defined schema - schema <- structType(structField("name", type = "string"), - structField("age", type = "double")) - - df1 <- read.df(sqlContext, jsonPath, "json", schema) - expect_is(df1, "DataFrame") - expect_equal(dtypes(df1), list(c("name", "string"), c("age", "double"))) - - # Run the same with loadDF - df2 <- loadDF(sqlContext, jsonPath, "json", schema) - expect_is(df2, "DataFrame") - expect_equal(dtypes(df2), list(c("name", "string"), c("age", "double"))) -}) - -test_that("write.df() as parquet file", { - df <- read.df(sqlContext, jsonPath, "json") - write.df(df, parquetPath, "parquet", mode="overwrite") - df2 <- read.df(sqlContext, parquetPath, "parquet") - expect_is(df2, "DataFrame") - expect_equal(count(df2), 3) -}) - test_that("test HiveContext", { ssc <- callJMethod(sc, "sc") hiveCtx <- tryCatch({ @@ -895,6 +898,8 @@ test_that("test HiveContext", { df3 <- sql(hiveCtx, "select * from json2") expect_is(df3, "DataFrame") expect_equal(count(df3), 3) + + unlink(jsonPath2) }) test_that("column operators", { @@ -1333,6 +1338,9 @@ test_that("join() and merge() on a DataFrame", { expect_error(merge(df, df3), paste("The following column name: name_y occurs more than once in the 'DataFrame'.", "Please use different suffixes for the intersected columns.", sep = "")) + + unlink(jsonPath2) + unlink(jsonPath3) }) test_that("toJSON() returns an RDD of the correct values", { @@ -1396,6 +1404,8 @@ test_that("unionAll(), rbind(), except(), and intersect() on a DataFrame", { # Test base::intersect is working expect_equal(length(intersect(1:20, 3:23)), 18) + + unlink(jsonPath2) }) test_that("withColumn() and withColumnRenamed()", { @@ -1440,31 +1450,35 @@ test_that("mutate(), transform(), rename() and names()", { detach(airquality) }) -test_that("write.df() on DataFrame and works with read.parquet", { - df <- read.json(sqlContext, jsonPath) +test_that("read/write Parquet files", { + df <- read.df(sqlContext, jsonPath, "json") + # Test write.df and read.df write.df(df, parquetPath, "parquet", mode="overwrite") - parquetDF <- read.parquet(sqlContext, parquetPath) - expect_is(parquetDF, "DataFrame") - expect_equal(count(df), 
count(parquetDF)) -}) + df2 <- read.df(sqlContext, parquetPath, "parquet") + expect_is(df2, "DataFrame") + expect_equal(count(df2), 3) -test_that("read.parquet()/parquetFile() works with multiple input paths", { - df <- read.json(sqlContext, jsonPath) - write.df(df, parquetPath, "parquet", mode="overwrite") + # Test write.parquet/saveAsParquetFile and read.parquet/parquetFile parquetPath2 <- tempfile(pattern = "parquetPath2", fileext = ".parquet") - write.df(df, parquetPath2, "parquet", mode="overwrite") - parquetDF <- read.parquet(sqlContext, c(parquetPath, parquetPath2)) + write.parquet(df, parquetPath2) + parquetPath3 <- tempfile(pattern = "parquetPath3", fileext = ".parquet") + suppressWarnings(saveAsParquetFile(df, parquetPath3)) + parquetDF <- read.parquet(sqlContext, c(parquetPath2, parquetPath3)) expect_is(parquetDF, "DataFrame") expect_equal(count(parquetDF), count(df) * 2) - parquetDF2 <- suppressWarnings(parquetFile(sqlContext, parquetPath, parquetPath2)) + parquetDF2 <- suppressWarnings(parquetFile(sqlContext, parquetPath2, parquetPath3)) expect_is(parquetDF2, "DataFrame") expect_equal(count(parquetDF2), count(df) * 2) # Test if varargs works with variables saveMode <- "overwrite" mergeSchema <- "true" - parquetPath3 <- tempfile(pattern = "parquetPath3", fileext = ".parquet") - write.df(df, parquetPath2, "parquet", mode = saveMode, mergeSchema = mergeSchema) + parquetPath4 <- tempfile(pattern = "parquetPath3", fileext = ".parquet") + write.df(df, parquetPath3, "parquet", mode = saveMode, mergeSchema = mergeSchema) + + unlink(parquetPath2) + unlink(parquetPath3) + unlink(parquetPath4) }) test_that("describe() and summarize() on a DataFrame", { From 26d70bd2b42617ff731b6e9e6d77933b38597ebe Mon Sep 17 00:00:00 2001 From: Yu ISHIKAWA Date: Wed, 16 Dec 2015 10:43:45 -0800 Subject: [PATCH 1792/1824] [SPARK-12215][ML][DOC] User guide section for KMeans in spark.ml cc jkbradley Author: Yu ISHIKAWA Closes #10244 from yu-iskw/SPARK-12215. --- docs/ml-clustering.md | 71 +++++++++++++++++++ .../spark/examples/ml/JavaKMeansExample.java | 8 ++- .../spark/examples/ml/KMeansExample.scala | 49 ++++++------- 3 files changed, 100 insertions(+), 28 deletions(-) diff --git a/docs/ml-clustering.md b/docs/ml-clustering.md index a59f7e3005a3..440c455cd077 100644 --- a/docs/ml-clustering.md +++ b/docs/ml-clustering.md @@ -11,6 +11,77 @@ In this section, we introduce the pipeline API for [clustering in mllib](mllib-c * This will become a table of contents (this text will be scraped). {:toc} +## K-means + +[k-means](http://en.wikipedia.org/wiki/K-means_clustering) is one of the +most commonly used clustering algorithms that clusters the data points into a +predefined number of clusters. The MLlib implementation includes a parallelized +variant of the [k-means++](http://en.wikipedia.org/wiki/K-means%2B%2B) method +called [kmeans||](http://theory.stanford.edu/~sergei/papers/vldb12-kmpar.pdf). + +`KMeans` is implemented as an `Estimator` and generates a `KMeansModel` as the base model. + +### Input Columns + + + + + + + + + + + + + + + + + + +
    Param nameType(s)DefaultDescription
    featuresColVector"features"Feature vector
    + +### Output Columns + + + + + + + + + + + + + + + + + + +
    Param nameType(s)DefaultDescription
    predictionColInt"prediction"Predicted cluster center
    + +### Example + +
    + +
    +Refer to the [Scala API docs](api/scala/index.html#org.apache.spark.ml.clustering.KMeans) for more details. + +{% include_example scala/org/apache/spark/examples/ml/KMeansExample.scala %} +
    + +
    +Refer to the [Java API docs](api/java/org/apache/spark/ml/clustering/KMeans.html) for more details. + +{% include_example java/org/apache/spark/examples/ml/JavaKMeansExample.java %} +
    + +
    + + ## Latent Dirichlet allocation (LDA) `LDA` is implemented as an `Estimator` that supports both `EMLDAOptimizer` and `OnlineLDAOptimizer`, diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaKMeansExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaKMeansExample.java index 47665ff2b1f3..96481d882a5d 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaKMeansExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaKMeansExample.java @@ -23,6 +23,9 @@ import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.Function; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.catalyst.expressions.GenericRow; +// $example on$ import org.apache.spark.ml.clustering.KMeansModel; import org.apache.spark.ml.clustering.KMeans; import org.apache.spark.mllib.linalg.Vector; @@ -30,11 +33,10 @@ import org.apache.spark.mllib.linalg.Vectors; import org.apache.spark.sql.DataFrame; import org.apache.spark.sql.Row; -import org.apache.spark.sql.SQLContext; -import org.apache.spark.sql.catalyst.expressions.GenericRow; import org.apache.spark.sql.types.Metadata; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; +// $example off$ /** @@ -74,6 +76,7 @@ public static void main(String[] args) { JavaSparkContext jsc = new JavaSparkContext(conf); SQLContext sqlContext = new SQLContext(jsc); + // $example on$ // Loads data JavaRDD points = jsc.textFile(inputFile).map(new ParsePoint()); StructField[] fields = {new StructField("features", new VectorUDT(), false, Metadata.empty())}; @@ -91,6 +94,7 @@ public static void main(String[] args) { for (Vector center: centers) { System.out.println(center); } + // $example off$ jsc.stop(); } diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/KMeansExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/KMeansExample.scala index 5ce38462d118..af90652b55a1 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/KMeansExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/KMeansExample.scala @@ -17,57 +17,54 @@ package org.apache.spark.examples.ml -import org.apache.spark.{SparkContext, SparkConf} -import org.apache.spark.mllib.linalg.{VectorUDT, Vectors} -import org.apache.spark.ml.clustering.KMeans -import org.apache.spark.sql.{Row, SQLContext} -import org.apache.spark.sql.types.{StructField, StructType} +// scalastyle:off println +import org.apache.spark.{SparkConf, SparkContext} +// $example on$ +import org.apache.spark.ml.clustering.KMeans +import org.apache.spark.mllib.linalg.Vectors +// $example off$ +import org.apache.spark.sql.{DataFrame, SQLContext} /** * An example demonstrating a k-means clustering. 
* Run with * {{{ - * bin/run-example ml.KMeansExample + * bin/run-example ml.KMeansExample * }}} */ object KMeansExample { - final val FEATURES_COL = "features" - def main(args: Array[String]): Unit = { - if (args.length != 2) { - // scalastyle:off println - System.err.println("Usage: ml.KMeansExample ") - // scalastyle:on println - System.exit(1) - } - val input = args(0) - val k = args(1).toInt - // Creates a Spark context and a SQL context val conf = new SparkConf().setAppName(s"${this.getClass.getSimpleName}") val sc = new SparkContext(conf) val sqlContext = new SQLContext(sc) - // Loads data - val rowRDD = sc.textFile(input).filter(_.nonEmpty) - .map(_.split(" ").map(_.toDouble)).map(Vectors.dense).map(Row(_)) - val schema = StructType(Array(StructField(FEATURES_COL, new VectorUDT, false))) - val dataset = sqlContext.createDataFrame(rowRDD, schema) + // $example on$ + // Crates a DataFrame + val dataset: DataFrame = sqlContext.createDataFrame(Seq( + (1, Vectors.dense(0.0, 0.0, 0.0)), + (2, Vectors.dense(0.1, 0.1, 0.1)), + (3, Vectors.dense(0.2, 0.2, 0.2)), + (4, Vectors.dense(9.0, 9.0, 9.0)), + (5, Vectors.dense(9.1, 9.1, 9.1)), + (6, Vectors.dense(9.2, 9.2, 9.2)) + )).toDF("id", "features") // Trains a k-means model val kmeans = new KMeans() - .setK(k) - .setFeaturesCol(FEATURES_COL) + .setK(2) + .setFeaturesCol("features") + .setPredictionCol("prediction") val model = kmeans.fit(dataset) // Shows the result - // scalastyle:off println println("Final Centers: ") model.clusterCenters.foreach(println) - // scalastyle:on println + // $example off$ sc.stop() } } +// scalastyle:on println From ad8c1f0b840284d05da737fb2cc5ebf8848f4490 Mon Sep 17 00:00:00 2001 From: Timothy Chen Date: Wed, 16 Dec 2015 10:54:15 -0800 Subject: [PATCH 1793/1824] [SPARK-12345][MESOS] Filter SPARK_HOME when submitting Spark jobs with Mesos cluster mode. SPARK_HOME is now causing problem with Mesos cluster mode since spark-submit script has been changed recently to take precendence when running spark-class scripts to look in SPARK_HOME if it's defined. We should skip passing SPARK_HOME from the Spark client in cluster mode with Mesos, since Mesos shouldn't use this configuration but should use spark.executor.home instead. Author: Timothy Chen Closes #10332 from tnachen/scheduler_ui. --- .../apache/spark/deploy/rest/mesos/MesosRestServer.scala | 7 ++++++- .../scheduler/cluster/mesos/MesosSchedulerUtils.scala | 2 +- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala index 868cc35d06ef..24510db2bd0b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala @@ -94,7 +94,12 @@ private[mesos] class MesosSubmitRequestServlet( val driverMemory = sparkProperties.get("spark.driver.memory") val driverCores = sparkProperties.get("spark.driver.cores") val appArgs = request.appArgs - val environmentVariables = request.environmentVariables + // We don't want to pass down SPARK_HOME when launching Spark apps + // with Mesos cluster mode since it's populated by default on the client and it will + // cause spark-submit script to look for files in SPARK_HOME instead. + // We only need the ability to specify where to find spark-submit script + // which user can user spark.executor.home or spark.home configurations. 
+ val environmentVariables = request.environmentVariables.filter(!_.equals("SPARK_HOME")) val name = request.sparkProperties.get("spark.app.name").getOrElse(mainClass) // Construct driver description diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala index 721861fbbc51..573355ba5813 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala @@ -34,7 +34,7 @@ import org.apache.spark.util.Utils /** * Shared trait for implementing a Mesos Scheduler. This holds common state and helper - * methods and Mesos scheduler will use. + * methods the Mesos scheduler will use. */ private[mesos] trait MesosSchedulerUtils extends Logging { // Lock used to wait for scheduler to be registered From 7b6dc29d0ebbfb3bb941130f8542120b6bc3e234 Mon Sep 17 00:00:00 2001 From: Yu ISHIKAWA Date: Wed, 16 Dec 2015 10:55:42 -0800 Subject: [PATCH 1794/1824] [SPARK-6518][MLLIB][EXAMPLE][DOC] Add example code and user guide for bisecting k-means This PR includes only an example code in order to finish it quickly. I'll send another PR for the docs soon. Author: Yu ISHIKAWA Closes #9952 from yu-iskw/SPARK-6518. --- docs/mllib-clustering.md | 35 ++++++++++ docs/mllib-guide.md | 1 + .../mllib/JavaBisectingKMeansExample.java | 69 +++++++++++++++++++ .../mllib/BisectingKMeansExample.scala | 60 ++++++++++++++++ 4 files changed, 165 insertions(+) create mode 100644 examples/src/main/java/org/apache/spark/examples/mllib/JavaBisectingKMeansExample.java create mode 100644 examples/src/main/scala/org/apache/spark/examples/mllib/BisectingKMeansExample.scala diff --git a/docs/mllib-clustering.md b/docs/mllib-clustering.md index 48d64cd402b1..93cd0c1c61ae 100644 --- a/docs/mllib-clustering.md +++ b/docs/mllib-clustering.md @@ -718,6 +718,41 @@ sameModel = LDAModel.load(sc, "myModelPath")
+## Bisecting k-means + +Bisecting k-means can often be much faster than regular k-means, but it will generally produce a different clustering. + +Bisecting k-means is a kind of [hierarchical clustering](https://en.wikipedia.org/wiki/Hierarchical_clustering). +Hierarchical clustering is one of the most commonly used methods of cluster analysis; it seeks to build a hierarchy of clusters. +Strategies for hierarchical clustering generally fall into two types: + +- Agglomerative: This is a "bottom up" approach: each observation starts in its own cluster, and pairs of clusters are merged as one moves up the hierarchy. +- Divisive: This is a "top down" approach: all observations start in one cluster, and splits are performed recursively as one moves down the hierarchy. + +The bisecting k-means algorithm is a kind of divisive algorithm. +The implementation in MLlib has the following parameters: + +* *k*: the desired number of leaf clusters (default: 4). The actual number could be smaller if there are no divisible leaf clusters. +* *maxIterations*: the max number of k-means iterations to split clusters (default: 20) +* *minDivisibleClusterSize*: the minimum number of points (if >= 1.0) or the minimum proportion of points (if < 1.0) of a divisible cluster (default: 1) +* *seed*: a random seed (default: hash value of the class name) + +**Examples** + +
    +
    +Refer to the [`BisectingKMeans` Scala docs](api/scala/index.html#org.apache.spark.mllib.clustering.BisectingKMeans) and [`BisectingKMeansModel` Scala docs](api/scala/index.html#org.apache.spark.mllib.clustering.BisectingKMeansModel) for details on the API. + +{% include_example scala/org/apache/spark/examples/mllib/BisectingKMeansExample.scala %} +
    + +
    +Refer to the [`BisectingKMeans` Java docs](api/java/org/apache/spark/mllib/clustering/BisectingKMeans.html) and [`BisectingKMeansModel` Java docs](api/java/org/apache/spark/mllib/clustering/BisectingKMeansModel.html) for details on the API. + +{% include_example java/org/apache/spark/examples/mllib/JavaBisectingKMeansExample.java %} +
    +
    + ## Streaming k-means When data arrive in a stream, we may want to estimate clusters dynamically, diff --git a/docs/mllib-guide.md b/docs/mllib-guide.md index 7fef6b5c61f9..680ed4861d9f 100644 --- a/docs/mllib-guide.md +++ b/docs/mllib-guide.md @@ -49,6 +49,7 @@ We list major functionality from both below, with links to detailed guides. * [Gaussian mixture](mllib-clustering.html#gaussian-mixture) * [power iteration clustering (PIC)](mllib-clustering.html#power-iteration-clustering-pic) * [latent Dirichlet allocation (LDA)](mllib-clustering.html#latent-dirichlet-allocation-lda) + * [bisecting k-means](mllib-clustering.html#bisecting-kmeans) * [streaming k-means](mllib-clustering.html#streaming-k-means) * [Dimensionality reduction](mllib-dimensionality-reduction.html) * [singular value decomposition (SVD)](mllib-dimensionality-reduction.html#singular-value-decomposition-svd) diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaBisectingKMeansExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaBisectingKMeansExample.java new file mode 100644 index 000000000000..0001500f4fa5 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaBisectingKMeansExample.java @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib; + +import java.util.ArrayList; + +// $example on$ +import com.google.common.collect.Lists; +// $example off$ +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +// $example on$ +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.mllib.clustering.BisectingKMeans; +import org.apache.spark.mllib.clustering.BisectingKMeansModel; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.linalg.Vectors; +// $example off$ + +/** + * Java example for graph clustering using power iteration clustering (PIC). 
+ */ +public class JavaBisectingKMeansExample { + public static void main(String[] args) { + SparkConf sparkConf = new SparkConf().setAppName("JavaBisectingKMeansExample"); + JavaSparkContext sc = new JavaSparkContext(sparkConf); + + // $example on$ + ArrayList localData = Lists.newArrayList( + Vectors.dense(0.1, 0.1), Vectors.dense(0.3, 0.3), + Vectors.dense(10.1, 10.1), Vectors.dense(10.3, 10.3), + Vectors.dense(20.1, 20.1), Vectors.dense(20.3, 20.3), + Vectors.dense(30.1, 30.1), Vectors.dense(30.3, 30.3) + ); + JavaRDD data = sc.parallelize(localData, 2); + + BisectingKMeans bkm = new BisectingKMeans() + .setK(4); + BisectingKMeansModel model = bkm.run(data); + + System.out.println("Compute Cost: " + model.computeCost(data)); + for (Vector center: model.clusterCenters()) { + System.out.println(""); + } + Vector[] clusterCenters = model.clusterCenters(); + for (int i = 0; i < clusterCenters.length; i++) { + Vector clusterCenter = clusterCenters[i]; + System.out.println("Cluster Center " + i + ": " + clusterCenter); + } + // $example off$ + + sc.stop(); + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/BisectingKMeansExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/BisectingKMeansExample.scala new file mode 100644 index 000000000000..3a596cccb87d --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/BisectingKMeansExample.scala @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib + +// scalastyle:off println +// $example on$ +import org.apache.spark.mllib.clustering.BisectingKMeans +import org.apache.spark.mllib.linalg.{Vector, Vectors} +// $example off$ +import org.apache.spark.{SparkConf, SparkContext} + +/** + * An example demonstrating a bisecting k-means clustering in spark.mllib. + * + * Run with + * {{{ + * bin/run-example mllib.BisectingKMeansExample + * }}} + */ +object BisectingKMeansExample { + + def main(args: Array[String]) { + val sparkConf = new SparkConf().setAppName("mllib.BisectingKMeansExample") + val sc = new SparkContext(sparkConf) + + // $example on$ + // Loads and parses data + def parse(line: String): Vector = Vectors.dense(line.split(" ").map(_.toDouble)) + val data = sc.textFile("data/mllib/kmeans_data.txt").map(parse).cache() + + // Clustering the data into 6 clusters by BisectingKMeans. 
+ val bkm = new BisectingKMeans().setK(6) + val model = bkm.run(data) + + // Show the compute cost and the cluster centers + println(s"Compute Cost: ${model.computeCost(data)}") + model.clusterCenters.zipWithIndex.foreach { case (center, idx) => + println(s"Cluster Center ${idx}: ${center}") + } + // $example off$ + + sc.stop() + } +} +// scalastyle:on println From 860dc7f2f8dd01f2562ba83b7af27ba29d91cb62 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Wed, 16 Dec 2015 11:05:37 -0800 Subject: [PATCH 1795/1824] [SPARK-9694][ML] Add random seed Param to Scala CrossValidator Add random seed Param to Scala CrossValidator Author: Yanbo Liang Closes #9108 from yanboliang/spark-9694. --- .../org/apache/spark/ml/tuning/CrossValidator.scala | 11 ++++++++--- .../scala/org/apache/spark/mllib/util/MLUtils.scala | 8 ++++++++ 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index 5c09f1aaff80..40f8857fc586 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -29,8 +29,9 @@ import org.apache.spark.ml.classification.OneVsRestParams import org.apache.spark.ml.evaluation.Evaluator import org.apache.spark.ml.feature.RFormulaModel import org.apache.spark.ml.param._ -import org.apache.spark.ml.util.DefaultParamsReader.Metadata +import org.apache.spark.ml.param.shared.HasSeed import org.apache.spark.ml.util._ +import org.apache.spark.ml.util.DefaultParamsReader.Metadata import org.apache.spark.mllib.util.MLUtils import org.apache.spark.sql.DataFrame import org.apache.spark.sql.types.StructType @@ -39,7 +40,7 @@ import org.apache.spark.sql.types.StructType /** * Params for [[CrossValidator]] and [[CrossValidatorModel]]. */ -private[ml] trait CrossValidatorParams extends ValidatorParams { +private[ml] trait CrossValidatorParams extends ValidatorParams with HasSeed { /** * Param for number of folds for cross validation. Must be >= 2. 
* Default: 3 @@ -85,6 +86,10 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) @Since("1.2.0") def setNumFolds(value: Int): this.type = set(numFolds, value) + /** @group setParam */ + @Since("2.0.0") + def setSeed(value: Long): this.type = set(seed, value) + @Since("1.4.0") override def fit(dataset: DataFrame): CrossValidatorModel = { val schema = dataset.schema @@ -95,7 +100,7 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) val epm = $(estimatorParamMaps) val numModels = epm.length val metrics = new Array[Double](epm.length) - val splits = MLUtils.kFold(dataset.rdd, $(numFolds), 0) + val splits = MLUtils.kFold(dataset.rdd, $(numFolds), $(seed)) splits.zipWithIndex.foreach { case ((training, validation), splitIndex) => val trainingDataset = sqlCtx.createDataFrame(training, schema).cache() val validationDataset = sqlCtx.createDataFrame(validation, schema).cache() diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala index 414ea99cfd8c..4c9151f0cb4f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala @@ -265,6 +265,14 @@ object MLUtils { */ @Since("1.0.0") def kFold[T: ClassTag](rdd: RDD[T], numFolds: Int, seed: Int): Array[(RDD[T], RDD[T])] = { + kFold(rdd, numFolds, seed.toLong) + } + + /** + * Version of [[kFold()]] taking a Long seed. + */ + @Since("2.0.0") + def kFold[T: ClassTag](rdd: RDD[T], numFolds: Int, seed: Long): Array[(RDD[T], RDD[T])] = { val numFoldsF = numFolds.toFloat (1 to numFolds).map { fold => val sampler = new BernoulliCellSampler[T]((fold - 1) / numFoldsF, fold / numFoldsF, From d252b2d544a75f6c5523be3492494955050acf50 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Wed, 16 Dec 2015 11:07:54 -0800 Subject: [PATCH 1796/1824] [SPARK-12309][ML] Use sqlContext from MLlibTestSparkContext for spark.ml test suites Use ```sqlContext``` from ```MLlibTestSparkContext``` rather than creating new one for spark.ml test suites. I have checked thoroughly and found there are four test cases need to update. cc mengxr jkbradley Author: Yanbo Liang Closes #10279 from yanboliang/spark-12309. 
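To make the pattern this patch standardizes on concrete, here is a minimal, self-contained sketch under stated assumptions: ScalaTest is on the classpath, and `SharedSQLContextForTests`/`ExampleSuite` are illustrative names rather than the classes touched by the diff below. A base trait owns one `SparkContext`/`SQLContext` pair, and suites reuse the inherited `sqlContext` instead of calling `new SQLContext(sc)` themselves.

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext
import org.scalatest.{BeforeAndAfterAll, FunSuite, Suite}

// Hypothetical stand-in for a shared test context trait: one SparkContext and
// one SQLContext are created per suite and reused by every test.
trait SharedSQLContextForTests extends BeforeAndAfterAll { self: Suite =>
  @transient var sc: SparkContext = _
  @transient var sqlContext: SQLContext = _

  override def beforeAll(): Unit = {
    super.beforeAll()
    sc = new SparkContext(
      new SparkConf().setMaster("local[2]").setAppName(getClass.getSimpleName))
    sqlContext = SQLContext.getOrCreate(sc) // reuse the active context, never allocate a second one
  }

  override def afterAll(): Unit = {
    try sc.stop() finally super.afterAll()
  }
}

class ExampleSuite extends FunSuite with SharedSQLContextForTests {
  test("uses the inherited sqlContext") {
    // No `new SQLContext(sc)` here -- the shared instance is used directly.
    val df = sqlContext.createDataFrame(Seq((1, "a"), (2, "b"))).toDF("id", "value")
    assert(df.count() === 2)
  }
}
```

Keeping a single shared context also avoids the cost of spinning up a new `SQLContext` per test and matches how `SQLContext.getOrCreate` is meant to be used.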
--- .../scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala | 4 +--- .../scala/org/apache/spark/ml/feature/NormalizerSuite.scala | 3 +-- .../scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala | 4 +--- mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala | 2 +- .../org/apache/spark/ml/tuning/CrossValidatorSuite.scala | 3 +-- 5 files changed, 5 insertions(+), 11 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala index 09183fe65b72..035bfc07b684 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala @@ -21,13 +21,11 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.{Row, SQLContext} +import org.apache.spark.sql.Row class MinMaxScalerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { test("MinMaxScaler fit basic case") { - val sqlContext = new SQLContext(sc) - val data = Array( Vectors.dense(1, 0, Long.MinValue), Vectors.dense(2, 0, 0), diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala index de3d438ce83b..468833901995 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala @@ -22,7 +22,7 @@ import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -import org.apache.spark.sql.{DataFrame, Row, SQLContext} +import org.apache.spark.sql.{DataFrame, Row} class NormalizerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { @@ -61,7 +61,6 @@ class NormalizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa Vectors.sparse(3, Seq()) ) - val sqlContext = new SQLContext(sc) dataFrame = sqlContext.createDataFrame(sc.parallelize(data, 2).map(NormalizerSuite.FeatureData)) normalizer = new Normalizer() .setInputCol("features") diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala index 74706a23e093..8acc3369c489 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala @@ -24,7 +24,7 @@ import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.types.StructType -import org.apache.spark.sql.{DataFrame, Row, SQLContext} +import org.apache.spark.sql.{DataFrame, Row} class VectorSlicerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { @@ -54,8 +54,6 @@ class VectorSlicerSuite extends SparkFunSuite with MLlibTestSparkContext with De } test("Test vector slicer") { - val sqlContext = new SQLContext(sc) - val data = Array( Vectors.sparse(5, Seq((0, -2.0), (1, 2.3))), Vectors.dense(-2.0, 2.3, 0.0, 0.0, 1.0), diff --git 
a/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala b/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala index 460849c79f04..4e2d0e93bd41 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala @@ -42,7 +42,7 @@ private[ml] object TreeTests extends SparkFunSuite { data: RDD[LabeledPoint], categoricalFeatures: Map[Int, Int], numClasses: Int): DataFrame = { - val sqlContext = new SQLContext(data.sparkContext) + val sqlContext = SQLContext.getOrCreate(data.sparkContext) import sqlContext.implicits._ val df = data.toDF() val numFeatures = data.first().features.size diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala index dd6366050c02..d281084f913c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala @@ -29,7 +29,7 @@ import org.apache.spark.ml.regression.LinearRegression import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext} -import org.apache.spark.sql.{DataFrame, SQLContext} +import org.apache.spark.sql.DataFrame import org.apache.spark.sql.types.StructType class CrossValidatorSuite @@ -39,7 +39,6 @@ class CrossValidatorSuite override def beforeAll(): Unit = { super.beforeAll() - val sqlContext = new SQLContext(sc) dataset = sqlContext.createDataFrame( sc.parallelize(generateLogisticInput(1.0, 1.0, 100, 42), 2)) } From 6a880afa831348b413ba95b98ff089377b950666 Mon Sep 17 00:00:00 2001 From: Jeff Zhang Date: Wed, 16 Dec 2015 11:29:47 -0800 Subject: [PATCH 1797/1824] [SPARK-12361][PYSPARK][TESTS] Should set PYSPARK_DRIVER_PYTHON before Python tests Although this patch still doesn't solve the issue why the return code is 0 (see JIRA description), it resolves the issue of python version mismatch. Author: Jeff Zhang Closes #10322 from zjffdu/SPARK-12361. --- python/run-tests.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/run-tests.py b/python/run-tests.py index f5857f8c6221..ee73eb1506ca 100755 --- a/python/run-tests.py +++ b/python/run-tests.py @@ -56,7 +56,8 @@ def print_red(text): def run_individual_python_test(test_name, pyspark_python): env = dict(os.environ) - env.update({'SPARK_TESTING': '1', 'PYSPARK_PYTHON': which(pyspark_python)}) + env.update({'SPARK_TESTING': '1', 'PYSPARK_PYTHON': which(pyspark_python), + 'PYSPARK_DRIVER_PYTHON': which(pyspark_python)}) LOGGER.debug("Starting test(%s): %s", pyspark_python, test_name) start_time = time.time() try: From 8148cc7a5c9f52c82c2eb7652d9aeba85e72d406 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Wed, 16 Dec 2015 11:53:04 -0800 Subject: [PATCH 1798/1824] [SPARK-11608][MLLIB][DOC] Added migration guide for MLlib 1.6 No known breaking changes, but some deprecations and changes of behavior. CC: mengxr Author: Joseph K. Bradley Closes #10235 from jkbradley/mllib-guide-update-1.6. 
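One deprecation described in the guide text below, the rename of `weights` to `coefficients` on linear models, can be illustrated with a small sketch. This is a hedged example, not code from the patch: it assumes an existing `SparkContext` named `sc`, and the toy data is made up.

```scala
import org.apache.spark.ml.regression.LinearRegression
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.sql.SQLContext

val sqlContext = SQLContext.getOrCreate(sc)
val training = sqlContext.createDataFrame(Seq(
  (1.0, Vectors.dense(0.0, 1.1)),
  (0.0, Vectors.dense(2.0, 1.0)),
  (3.0, Vectors.dense(4.0, 0.5)),
  (2.5, Vectors.dense(3.0, 0.7))
)).toDF("label", "features")

val model = new LinearRegression().setMaxIter(10).fit(training)

// Since 1.6, read the fitted parameters through `coefficients`;
// the old `weights` field still compiles but is deprecated.
println(s"coefficients: ${model.coefficients}, intercept: ${model.intercept}")
```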
--- docs/mllib-guide.md | 38 ++++++++++++++++++++-------------- docs/mllib-migration-guides.md | 19 +++++++++++++++++ 2 files changed, 42 insertions(+), 15 deletions(-) diff --git a/docs/mllib-guide.md b/docs/mllib-guide.md index 680ed4861d9f..7ef91a178ccd 100644 --- a/docs/mllib-guide.md +++ b/docs/mllib-guide.md @@ -74,7 +74,7 @@ We list major functionality from both below, with links to detailed guides. * [Advanced topics](ml-advanced.html) Some techniques are not available yet in spark.ml, most notably dimensionality reduction -Users can seemlessly combine the implementation of these techniques found in `spark.mllib` with the rest of the algorithms found in `spark.ml`. +Users can seamlessly combine the implementation of these techniques found in `spark.mllib` with the rest of the algorithms found in `spark.ml`. # Dependencies @@ -101,24 +101,32 @@ MLlib is under active development. The APIs marked `Experimental`/`DeveloperApi` may change in future releases, and the migration guide below will explain all changes between releases. -## From 1.4 to 1.5 +## From 1.5 to 1.6 -In the `spark.mllib` package, there are no break API changes but several behavior changes: +There are no breaking API changes in the `spark.mllib` or `spark.ml` packages, but there are +deprecations and changes of behavior. -* [SPARK-9005](https://issues.apache.org/jira/browse/SPARK-9005): - `RegressionMetrics.explainedVariance` returns the average regression sum of squares. -* [SPARK-8600](https://issues.apache.org/jira/browse/SPARK-8600): `NaiveBayesModel.labels` become - sorted. -* [SPARK-3382](https://issues.apache.org/jira/browse/SPARK-3382): `GradientDescent` has a default - convergence tolerance `1e-3`, and hence iterations might end earlier than 1.4. +Deprecations: -In the `spark.ml` package, there exists one break API change and one behavior change: +* [SPARK-11358](https://issues.apache.org/jira/browse/SPARK-11358): + In `spark.mllib.clustering.KMeans`, the `runs` parameter has been deprecated. +* [SPARK-10592](https://issues.apache.org/jira/browse/SPARK-10592): + In `spark.ml.classification.LogisticRegressionModel` and + `spark.ml.regression.LinearRegressionModel`, the `weights` field has been deprecated in favor of + the new name `coefficients`. This helps disambiguate from instance (row) "weights" given to + algorithms. -* [SPARK-9268](https://issues.apache.org/jira/browse/SPARK-9268): Java's varargs support is removed - from `Params.setDefault` due to a - [Scala compiler bug](https://issues.scala-lang.org/browse/SI-9013). -* [SPARK-10097](https://issues.apache.org/jira/browse/SPARK-10097): `Evaluator.isLargerBetter` is - added to indicate metric ordering. Metrics like RMSE no longer flip signs as in 1.4. +Changes of behavior: + +* [SPARK-7770](https://issues.apache.org/jira/browse/SPARK-7770): + `spark.mllib.tree.GradientBoostedTrees`: `validationTol` has changed semantics in 1.6. + Previously, it was a threshold for absolute change in error. Now, it resembles the behavior of + `GradientDescent`'s `convergenceTol`: For large errors, it uses relative error (relative to the + previous error); for small errors (`< 0.01`), it uses absolute error. +* [SPARK-11069](https://issues.apache.org/jira/browse/SPARK-11069): + `spark.ml.feature.RegexTokenizer`: Previously, it did not convert strings to lowercase before + tokenizing. Now, it converts to lowercase by default, with an option not to. This matches the + behavior of the simpler `Tokenizer` transformer. 
## Previous Spark versions diff --git a/docs/mllib-migration-guides.md b/docs/mllib-migration-guides.md index 73e4fddf67fc..f3daef2dbadb 100644 --- a/docs/mllib-migration-guides.md +++ b/docs/mllib-migration-guides.md @@ -7,6 +7,25 @@ description: MLlib migration guides from before Spark SPARK_VERSION_SHORT The migration guide for the current Spark version is kept on the [MLlib Programming Guide main page](mllib-guide.html#migration-guide). +## From 1.4 to 1.5 + +In the `spark.mllib` package, there are no breaking API changes but several behavior changes: + +* [SPARK-9005](https://issues.apache.org/jira/browse/SPARK-9005): + `RegressionMetrics.explainedVariance` returns the average regression sum of squares. +* [SPARK-8600](https://issues.apache.org/jira/browse/SPARK-8600): `NaiveBayesModel.labels` become + sorted. +* [SPARK-3382](https://issues.apache.org/jira/browse/SPARK-3382): `GradientDescent` has a default + convergence tolerance `1e-3`, and hence iterations might end earlier than 1.4. + +In the `spark.ml` package, there exists one breaking API change and one behavior change: + +* [SPARK-9268](https://issues.apache.org/jira/browse/SPARK-9268): Java's varargs support is removed + from `Params.setDefault` due to a + [Scala compiler bug](https://issues.scala-lang.org/browse/SI-9013). +* [SPARK-10097](https://issues.apache.org/jira/browse/SPARK-10097): `Evaluator.isLargerBetter` is + added to indicate metric ordering. Metrics like RMSE no longer flip signs as in 1.4. + ## From 1.3 to 1.4 In the `spark.mllib` package, there were several breaking changes, but all in `DeveloperApi` or `Experimental` APIs: From 1a8b2a17db7ab7a213d553079b83274aeebba86f Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Wed, 16 Dec 2015 12:59:22 -0800 Subject: [PATCH 1799/1824] [SPARK-12364][ML][SPARKR] Add ML example for SparkR We have DataFrame example for SparkR, we also need to add ML example under ```examples/src/main/r```. cc mengxr jkbradley shivaram Author: Yanbo Liang Closes #10324 from yanboliang/spark-12364. --- examples/src/main/r/ml.R | 54 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 54 insertions(+) create mode 100644 examples/src/main/r/ml.R diff --git a/examples/src/main/r/ml.R b/examples/src/main/r/ml.R new file mode 100644 index 000000000000..a0c903939cbb --- /dev/null +++ b/examples/src/main/r/ml.R @@ -0,0 +1,54 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +# To run this example use +# ./bin/sparkR examples/src/main/r/ml.R + +# Load SparkR library into your R session +library(SparkR) + +# Initialize SparkContext and SQLContext +sc <- sparkR.init(appName="SparkR-ML-example") +sqlContext <- sparkRSQL.init(sc) + +# Train GLM of family 'gaussian' +training1 <- suppressWarnings(createDataFrame(sqlContext, iris)) +test1 <- training1 +model1 <- glm(Sepal_Length ~ Sepal_Width + Species, training1, family = "gaussian") + +# Model summary +summary(model1) + +# Prediction +predictions1 <- predict(model1, test1) +head(select(predictions1, "Sepal_Length", "prediction")) + +# Train GLM of family 'binomial' +training2 <- filter(training1, training1$Species != "setosa") +test2 <- training2 +model2 <- glm(Species ~ Sepal_Length + Sepal_Width, data = training2, family = "binomial") + +# Model summary +summary(model2) + +# Prediction (Currently the output of prediction for binomial GLM is the indexed label, +# we need to transform back to the original string label later) +predictions2 <- predict(model2, test2) +head(select(predictions2, "Species", "prediction")) + +# Stop the SparkContext now +sparkR.stop() From a783a8ed49814a09fde653433a3d6de398ddf888 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 16 Dec 2015 13:18:56 -0800 Subject: [PATCH 1800/1824] [SPARK-12320][SQL] throw exception if the number of fields does not line up for Tuple encoder Author: Wenchen Fan Closes #10293 from cloud-fan/err-msg. --- .../spark/sql/catalyst/dsl/package.scala | 3 +- .../catalyst/encoders/ExpressionEncoder.scala | 36 ++++++++++- .../expressions/complexTypeExtractors.scala | 10 ++-- .../encoders/EncoderResolutionSuite.scala | 60 ++++++++++++++++--- .../expressions/ComplexTypeSuite.scala | 2 +- 5 files changed, 93 insertions(+), 18 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index e50971173c49..8102c93c6f10 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -227,9 +227,10 @@ package object dsl { AttributeReference(s, mapType, nullable = true)() /** Creates a new AttributeReference of type struct */ - def struct(fields: StructField*): AttributeReference = struct(StructType(fields)) def struct(structType: StructType): AttributeReference = AttributeReference(s, structType, nullable = true)() + def struct(attrs: AttributeReference*): AttributeReference = + struct(StructType.fromAttributes(attrs)) } implicit class DslAttribute(a: AttributeReference) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index 363178b0e21a..7a4401cf5810 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -244,9 +244,41 @@ case class ExpressionEncoder[T]( def resolve( schema: Seq[Attribute], outerScopes: ConcurrentMap[String, AnyRef]): ExpressionEncoder[T] = { - val positionToAttribute = AttributeMap.toIndex(schema) + def fail(st: StructType, maxOrdinal: Int): Unit = { + throw new AnalysisException(s"Try to map ${st.simpleString} to Tuple${maxOrdinal + 1}, " + + "but failed as the number of fields does not line up.\n" + + " - Input schema: " + 
StructType.fromAttributes(schema).simpleString + "\n" + + " - Target schema: " + this.schema.simpleString) + } + + var maxOrdinal = -1 + fromRowExpression.foreach { + case b: BoundReference => if (b.ordinal > maxOrdinal) maxOrdinal = b.ordinal + case _ => + } + if (maxOrdinal >= 0 && maxOrdinal != schema.length - 1) { + fail(StructType.fromAttributes(schema), maxOrdinal) + } + val unbound = fromRowExpression transform { - case b: BoundReference => positionToAttribute(b.ordinal) + case b: BoundReference => schema(b.ordinal) + } + + val exprToMaxOrdinal = scala.collection.mutable.HashMap.empty[Expression, Int] + unbound.foreach { + case g: GetStructField => + val maxOrdinal = exprToMaxOrdinal.getOrElse(g.child, -1) + if (maxOrdinal < g.ordinal) { + exprToMaxOrdinal.update(g.child, g.ordinal) + } + case _ => + } + exprToMaxOrdinal.foreach { + case (expr, maxOrdinal) => + val schema = expr.dataType.asInstanceOf[StructType] + if (maxOrdinal != schema.length - 1) { + fail(schema, maxOrdinal) + } } val plan = Project(Alias(unbound, "")() :: Nil, LocalRelation(schema)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index 10ce10aaf6da..58f6a7ec8a5f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -104,14 +104,14 @@ object ExtractValue { case class GetStructField(child: Expression, ordinal: Int, name: Option[String] = None) extends UnaryExpression { - private lazy val field = child.dataType.asInstanceOf[StructType](ordinal) + private[sql] lazy val childSchema = child.dataType.asInstanceOf[StructType] - override def dataType: DataType = field.dataType - override def nullable: Boolean = child.nullable || field.nullable - override def toString: String = s"$child.${name.getOrElse(field.name)}" + override def dataType: DataType = childSchema(ordinal).dataType + override def nullable: Boolean = child.nullable || childSchema(ordinal).nullable + override def toString: String = s"$child.${name.getOrElse(childSchema(ordinal).name)}" protected override def nullSafeEval(input: Any): Any = - input.asInstanceOf[InternalRow].get(ordinal, field.dataType) + input.asInstanceOf[InternalRow].get(ordinal, childSchema(ordinal).dataType) override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { nullSafeCodeGen(ctx, ev, eval => { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala index 0289988342e7..815a03f7c1a8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala @@ -64,22 +64,21 @@ class EncoderResolutionSuite extends PlanTest { val innerCls = classOf[StringLongClass] val cls = classOf[ComplexClass] - val structType = new StructType().add("a", IntegerType).add("b", LongType) - val attrs = Seq('a.int, 'b.struct(structType)) + val attrs = Seq('a.int, 'b.struct('a.int, 'b.long)) val fromRowExpr: Expression = encoder.resolve(attrs, null).fromRowExpression val expected: Expression = NewInstance( cls, Seq( 'a.int.cast(LongType), If( - 'b.struct(structType).isNull, + 
'b.struct('a.int, 'b.long).isNull, Literal.create(null, ObjectType(innerCls)), NewInstance( innerCls, Seq( toExternalString( - GetStructField('b.struct(structType), 0, Some("a")).cast(StringType)), - GetStructField('b.struct(structType), 1, Some("b"))), + GetStructField('b.struct('a.int, 'b.long), 0, Some("a")).cast(StringType)), + GetStructField('b.struct('a.int, 'b.long), 1, Some("b"))), false, ObjectType(innerCls)) )), @@ -94,8 +93,7 @@ class EncoderResolutionSuite extends PlanTest { ExpressionEncoder[Long]) val cls = classOf[StringLongClass] - val structType = new StructType().add("a", StringType).add("b", ByteType) - val attrs = Seq('a.struct(structType), 'b.int) + val attrs = Seq('a.struct('a.string, 'b.byte), 'b.int) val fromRowExpr: Expression = encoder.resolve(attrs, null).fromRowExpression val expected: Expression = NewInstance( classOf[Tuple2[_, _]], @@ -103,8 +101,8 @@ class EncoderResolutionSuite extends PlanTest { NewInstance( cls, Seq( - toExternalString(GetStructField('a.struct(structType), 0, Some("a"))), - GetStructField('a.struct(structType), 1, Some("b")).cast(LongType)), + toExternalString(GetStructField('a.struct('a.string, 'b.byte), 0, Some("a"))), + GetStructField('a.struct('a.string, 'b.byte), 1, Some("b")).cast(LongType)), false, ObjectType(cls)), 'b.int.cast(LongType)), @@ -113,6 +111,50 @@ class EncoderResolutionSuite extends PlanTest { compareExpressions(fromRowExpr, expected) } + test("the real number of fields doesn't match encoder schema: tuple encoder") { + val encoder = ExpressionEncoder[(String, Long)] + + { + val attrs = Seq('a.string, 'b.long, 'c.int) + assert(intercept[AnalysisException](encoder.resolve(attrs, null)).message == + "Try to map struct to Tuple2, " + + "but failed as the number of fields does not line up.\n" + + " - Input schema: struct\n" + + " - Target schema: struct<_1:string,_2:bigint>") + } + + { + val attrs = Seq('a.string) + assert(intercept[AnalysisException](encoder.resolve(attrs, null)).message == + "Try to map struct to Tuple2, " + + "but failed as the number of fields does not line up.\n" + + " - Input schema: struct\n" + + " - Target schema: struct<_1:string,_2:bigint>") + } + } + + test("the real number of fields doesn't match encoder schema: nested tuple encoder") { + val encoder = ExpressionEncoder[(String, (Long, String))] + + { + val attrs = Seq('a.string, 'b.struct('x.long, 'y.string, 'z.int)) + assert(intercept[AnalysisException](encoder.resolve(attrs, null)).message == + "Try to map struct to Tuple2, " + + "but failed as the number of fields does not line up.\n" + + " - Input schema: struct>\n" + + " - Target schema: struct<_1:string,_2:struct<_1:bigint,_2:string>>") + } + + { + val attrs = Seq('a.string, 'b.struct('x.long)) + assert(intercept[AnalysisException](encoder.resolve(attrs, null)).message == + "Try to map struct to Tuple2, " + + "but failed as the number of fields does not line up.\n" + + " - Input schema: struct>\n" + + " - Target schema: struct<_1:string,_2:struct<_1:bigint,_2:string>>") + } + } + private def toExternalString(e: Expression): Expression = { Invoke(e, "toString", ObjectType(classOf[String]), Nil) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala index 62fd47234b33..9f1b19253e7c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala +++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala @@ -165,7 +165,7 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { "b", create_row(Map("a" -> "b"))) checkEvaluation(quickResolve('c.array(StringType).at(0).getItem(1)), "b", create_row(Seq("a", "b"))) - checkEvaluation(quickResolve('c.struct(StructField("a", IntegerType)).at(0).getField("a")), + checkEvaluation(quickResolve('c.struct('a.int).at(0).getField("a")), 1, create_row(create_row(1))) } From edf65cd961b913ef54104770630a50fd4b120b4b Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Wed, 16 Dec 2015 13:22:34 -0800 Subject: [PATCH 1801/1824] [SPARK-12164][SQL] Decode the encoded values and then display Based on the suggestions from marmbrus cloud-fan in https://github.com/apache/spark/pull/10165 , this PR is to print the decoded values(user objects) in `Dataset.show` ```scala implicit val kryoEncoder = Encoders.kryo[KryoClassData] val ds = Seq(KryoClassData("a", 1), KryoClassData("b", 2), KryoClassData("c", 3)).toDS() ds.show(20, false); ``` The current output is like ``` +--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ |value | +--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ |[1, 0, 111, 114, 103, 46, 97, 112, 97, 99, 104, 101, 46, 115, 112, 97, 114, 107, 46, 115, 113, 108, 46, 75, 114, 121, 111, 67, 108, 97, 115, 115, 68, 97, 116, -31, 1, 1, -126, 97, 2]| |[1, 0, 111, 114, 103, 46, 97, 112, 97, 99, 104, 101, 46, 115, 112, 97, 114, 107, 46, 115, 113, 108, 46, 75, 114, 121, 111, 67, 108, 97, 115, 115, 68, 97, 116, -31, 1, 1, -126, 98, 4]| |[1, 0, 111, 114, 103, 46, 97, 112, 97, 99, 104, 101, 46, 115, 112, 97, 114, 107, 46, 115, 113, 108, 46, 75, 114, 121, 111, 67, 108, 97, 115, 115, 68, 97, 116, -31, 1, 1, -126, 99, 6]| +--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ ``` After the fix, it will be like the below if and only if the users override the `toString` function in the class `KryoClassData` ```scala override def toString: String = s"KryoClassData($a, $b)" ``` ``` +-------------------+ |value | +-------------------+ |KryoClassData(a, 1)| |KryoClassData(b, 2)| |KryoClassData(c, 3)| +-------------------+ ``` If users do not override the `toString` function, the results will be like ``` +---------------------------------------+ |value | +---------------------------------------+ |org.apache.spark.sql.KryoClassData68ef| |org.apache.spark.sql.KryoClassData6915| |org.apache.spark.sql.KryoClassData693b| +---------------------------------------+ ``` Question: Should we add another optional parameter in the function `show`? It will decide if the function `show` will display the hex values or the object values? Author: gatorsmile Closes #10215 from gatorsmile/showDecodedValue. 
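The behavior described in the commit message can be reproduced end to end with a self-contained sketch. The class name, app name, and local master below are illustrative, and the snippet assumes the 1.6-era APIs `Encoders.kryo` and `SQLContext.createDataset`; overriding `toString` is what makes the decoded output readable.

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.{Encoders, SQLContext}

// Illustrative class: without a toString override, show() would fall back to the
// default Object.toString form mentioned in the description.
case class KryoClassData(a: String, b: Int) {
  override def toString: String = s"KryoClassData($a, $b)"
}

object ShowDecodedValuesSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(
      new SparkConf().setAppName("ShowDecodedValuesSketch").setMaster("local[2]"))
    val sqlContext = SQLContext.getOrCreate(sc)

    // With a Kryo encoder the underlying row holds a binary blob; after this change,
    // show() prints the decoded objects (via toString) rather than the raw bytes.
    implicit val kryoEncoder = Encoders.kryo[KryoClassData]
    val ds = sqlContext.createDataset(Seq(
      KryoClassData("a", 1), KryoClassData("b", 2), KryoClassData("c", 3)))
    ds.show(20, false)

    sc.stop()
  }
}
```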
--- .../org/apache/spark/sql/DataFrame.scala | 50 +------------- .../scala/org/apache/spark/sql/Dataset.scala | 37 ++++++++++- .../spark/sql/execution/Queryable.scala | 65 +++++++++++++++++++ .../org/apache/spark/sql/DataFrameSuite.scala | 15 +++++ .../org/apache/spark/sql/DatasetSuite.scala | 14 ++++ 5 files changed, 133 insertions(+), 48 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 497bd4826677..6250e952169d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -165,13 +165,11 @@ class DataFrame private[sql]( * @param _numRows Number of rows to show * @param truncate Whether truncate long strings and align cells right */ - private[sql] def showString(_numRows: Int, truncate: Boolean = true): String = { + override private[sql] def showString(_numRows: Int, truncate: Boolean = true): String = { val numRows = _numRows.max(0) - val sb = new StringBuilder val takeResult = take(numRows + 1) val hasMoreData = takeResult.length > numRows val data = takeResult.take(numRows) - val numCols = schema.fieldNames.length // For array values, replace Seq and Array with square brackets // For cells that are beyond 20 characters, replace it with the first 17 and "..." @@ -179,6 +177,7 @@ class DataFrame private[sql]( row.toSeq.map { cell => val str = cell match { case null => "null" + case binary: Array[Byte] => binary.map("%02X".format(_)).mkString("[", " ", "]") case array: Array[_] => array.mkString("[", ", ", "]") case seq: Seq[_] => seq.mkString("[", ", ", "]") case _ => cell.toString @@ -187,50 +186,7 @@ class DataFrame private[sql]( }: Seq[String] } - // Initialise the width of each column to a minimum value of '3' - val colWidths = Array.fill(numCols)(3) - - // Compute the width of each column - for (row <- rows) { - for ((cell, i) <- row.zipWithIndex) { - colWidths(i) = math.max(colWidths(i), cell.length) - } - } - - // Create SeparateLine - val sep: String = colWidths.map("-" * _).addString(sb, "+", "+", "+\n").toString() - - // column names - rows.head.zipWithIndex.map { case (cell, i) => - if (truncate) { - StringUtils.leftPad(cell, colWidths(i)) - } else { - StringUtils.rightPad(cell, colWidths(i)) - } - }.addString(sb, "|", "|", "|\n") - - sb.append(sep) - - // data - rows.tail.map { - _.zipWithIndex.map { case (cell, i) => - if (truncate) { - StringUtils.leftPad(cell.toString, colWidths(i)) - } else { - StringUtils.rightPad(cell.toString, colWidths(i)) - } - }.addString(sb, "|", "|", "|\n") - } - - sb.append(sep) - - // For Data that has more than "numRows" records - if (hasMoreData) { - val rowsString = if (numRows == 1) "row" else "rows" - sb.append(s"only showing top $numRows $rowsString\n") - } - - sb.toString() + formatString ( rows, numRows, hasMoreData, truncate ) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index dc69822e9290..79b4244ac0cd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -225,7 +225,42 @@ class Dataset[T] private[sql]( * * @since 1.6.0 */ - def show(numRows: Int, truncate: Boolean): Unit = toDF().show(numRows, truncate) + // scalastyle:off println + def show(numRows: Int, truncate: Boolean): Unit = println(showString(numRows, truncate)) + // scalastyle:on println + + /** + * Compose the 
string representing rows for output + * @param _numRows Number of rows to show + * @param truncate Whether truncate long strings and align cells right + */ + override private[sql] def showString(_numRows: Int, truncate: Boolean = true): String = { + val numRows = _numRows.max(0) + val takeResult = take(numRows + 1) + val hasMoreData = takeResult.length > numRows + val data = takeResult.take(numRows) + + // For array values, replace Seq and Array with square brackets + // For cells that are beyond 20 characters, replace it with the first 17 and "..." + val rows: Seq[Seq[String]] = schema.fieldNames.toSeq +: (data.map { + case r: Row => r + case tuple: Product => Row.fromTuple(tuple) + case o => Row(o) + } map { row => + row.toSeq.map { cell => + val str = cell match { + case null => "null" + case binary: Array[Byte] => binary.map("%02X".format(_)).mkString("[", " ", "]") + case array: Array[_] => array.mkString("[", ", ", "]") + case seq: Seq[_] => seq.mkString("[", ", ", "]") + case _ => cell.toString + } + if (truncate && str.length > 20) str.substring(0, 17) + "..." else str + }: Seq[String] + }) + + formatString ( rows, numRows, hasMoreData, truncate ) + } /** * Returns a new [[Dataset]] that has exactly `numPartitions` partitions. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Queryable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Queryable.scala index f2f5997d1b7c..b397d42612cf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Queryable.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Queryable.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution import scala.util.control.NonFatal +import org.apache.commons.lang3.StringUtils import org.apache.spark.sql.SQLContext import org.apache.spark.sql.types.StructType @@ -42,4 +43,68 @@ private[sql] trait Queryable { def explain(extended: Boolean): Unit def explain(): Unit + + private[sql] def showString(_numRows: Int, truncate: Boolean = true): String + + /** + * Format the string representing rows for output + * @param rows The rows to show + * @param numRows Number of rows to show + * @param hasMoreData Whether some rows are not shown due to the limit + * @param truncate Whether truncate long strings and align cells right + * + */ + private[sql] def formatString ( + rows: Seq[Seq[String]], + numRows: Int, + hasMoreData : Boolean, + truncate: Boolean = true): String = { + val sb = new StringBuilder + val numCols = schema.fieldNames.length + + // Initialise the width of each column to a minimum value of '3' + val colWidths = Array.fill(numCols)(3) + + // Compute the width of each column + for (row <- rows) { + for ((cell, i) <- row.zipWithIndex) { + colWidths(i) = math.max(colWidths(i), cell.length) + } + } + + // Create SeparateLine + val sep: String = colWidths.map("-" * _).addString(sb, "+", "+", "+\n").toString() + + // column names + rows.head.zipWithIndex.map { case (cell, i) => + if (truncate) { + StringUtils.leftPad(cell, colWidths(i)) + } else { + StringUtils.rightPad(cell, colWidths(i)) + } + }.addString(sb, "|", "|", "|\n") + + sb.append(sep) + + // data + rows.tail.map { + _.zipWithIndex.map { case (cell, i) => + if (truncate) { + StringUtils.leftPad(cell.toString, colWidths(i)) + } else { + StringUtils.rightPad(cell.toString, colWidths(i)) + } + }.addString(sb, "|", "|", "|\n") + } + + sb.append(sep) + + // For Data that has more than "numRows" records + if (hasMoreData) { + val rowsString = if (numRows == 1) "row" else "rows" + sb.append(s"only 
showing top $numRows $rowsString\n") + } + + sb.toString() + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index c0bbf73ab118..0644bdaaa35c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -585,6 +585,21 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { assert(df.showString(10) === expectedAnswer) } + test("showString: binary") { + val df = Seq( + ("12".getBytes, "ABC.".getBytes), + ("34".getBytes, "12346".getBytes) + ).toDF() + val expectedAnswer = """+-------+----------------+ + || _1| _2| + |+-------+----------------+ + ||[31 32]| [41 42 43 2E]| + ||[33 34]|[31 32 33 34 36]| + |+-------+----------------+ + |""".stripMargin + assert(df.showString(10) === expectedAnswer) + } + test("showString: minimum column width") { val df = Seq( (1, 1), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 8f8db318261d..f1b6b98dc160 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -426,6 +426,20 @@ class DatasetSuite extends QueryTest with SharedSQLContext { assert(ds.toString == "[_1: int, _2: int]") } + test("showString: Kryo encoder") { + implicit val kryoEncoder = Encoders.kryo[KryoData] + val ds = Seq(KryoData(1), KryoData(2)).toDS() + + val expectedAnswer = """+-----------+ + || value| + |+-----------+ + ||KryoData(1)| + ||KryoData(2)| + |+-----------+ + |""".stripMargin + assert(ds.showString(10) === expectedAnswer) + } + test("Kryo encoder") { implicit val kryoEncoder = Encoders.kryo[KryoData] val ds = Seq(KryoData(1), KryoData(2)).toDS() From 9657ee87888422c5596987fe760b49117a0ea4e2 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Wed, 16 Dec 2015 13:24:49 -0800 Subject: [PATCH 1802/1824] [SPARK-11677][SQL] ORC filter tests all pass if filters are actually not pushed down. Currently ORC filters are not tested properly. All the tests pass even if the filters are not pushed down or disabled. In this PR, I add some logics for this. Since ORC does not filter record by record fully, this checks the count of the result and if it contains the expected values. Author: hyukjinkwon Closes #9687 from HyukjinKwon/SPARK-11677. --- .../spark/sql/hive/orc/OrcQuerySuite.scala | 53 +++++++++++++------ 1 file changed, 36 insertions(+), 17 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala index 7efeab528c1d..2156806d21f9 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala @@ -350,28 +350,47 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { withTempPath { dir => withSQLConf(SQLConf.ORC_FILTER_PUSHDOWN_ENABLED.key -> "true") { import testImplicits._ - val path = dir.getCanonicalPath - sqlContext.range(10).coalesce(1).write.orc(path) + + // For field "a", the first column has odds integers. This is to check the filtered count + // when `isNull` is performed. For Field "b", `isNotNull` of ORC file filters rows + // only when all the values are null (maybe this works differently when the data + // or query is complicated). 
So, simply here a column only having `null` is added. + val data = (0 until 10).map { i => + val maybeInt = if (i % 2 == 0) None else Some(i) + val nullValue: Option[String] = None + (maybeInt, nullValue) + } + createDataFrame(data).toDF("a", "b").write.orc(path) val df = sqlContext.read.orc(path) - def checkPredicate(pred: Column, answer: Seq[Long]): Unit = { - checkAnswer(df.where(pred), answer.map(Row(_))) + def checkPredicate(pred: Column, answer: Seq[Row]): Unit = { + val sourceDf = stripSparkFilter(df.where(pred)) + val data = sourceDf.collect().toSet + val expectedData = answer.toSet + + // When a filter is pushed to ORC, ORC can apply it to rows. So, we can check + // the number of rows returned from the ORC to make sure our filter pushdown work. + // A tricky part is, ORC does not process filter rows fully but return some possible + // results. So, this checks if the number of result is less than the original count + // of data, and then checks if it contains the expected data. + val isOrcFiltered = sourceDf.count < 10 && expectedData.subsetOf(data) + assert(isOrcFiltered) } - checkPredicate('id === 5, Seq(5L)) - checkPredicate('id <=> 5, Seq(5L)) - checkPredicate('id < 5, 0L to 4L) - checkPredicate('id <= 5, 0L to 5L) - checkPredicate('id > 5, 6L to 9L) - checkPredicate('id >= 5, 5L to 9L) - checkPredicate('id.isNull, Seq.empty[Long]) - checkPredicate('id.isNotNull, 0L to 9L) - checkPredicate('id.isin(1L, 3L, 5L), Seq(1L, 3L, 5L)) - checkPredicate('id > 0 && 'id < 3, 1L to 2L) - checkPredicate('id < 1 || 'id > 8, Seq(0L, 9L)) - checkPredicate(!('id > 3), 0L to 3L) - checkPredicate(!('id > 0 && 'id < 3), Seq(0L) ++ (3L to 9L)) + checkPredicate('a === 5, List(5).map(Row(_, null))) + checkPredicate('a <=> 5, List(5).map(Row(_, null))) + checkPredicate('a < 5, List(1, 3).map(Row(_, null))) + checkPredicate('a <= 5, List(1, 3, 5).map(Row(_, null))) + checkPredicate('a > 5, List(7, 9).map(Row(_, null))) + checkPredicate('a >= 5, List(5, 7, 9).map(Row(_, null))) + checkPredicate('a.isNull, List(null).map(Row(_, null))) + checkPredicate('b.isNotNull, List()) + checkPredicate('a.isin(3, 5, 7), List(3, 5, 7).map(Row(_, null))) + checkPredicate('a > 0 && 'a < 3, List(1).map(Row(_, null))) + checkPredicate('a < 1 || 'a > 8, List(9).map(Row(_, null))) + checkPredicate(!('a > 3), List(1, 3).map(Row(_, null))) + checkPredicate(!('a > 0 && 'a < 3), List(3, 5, 7, 9).map(Row(_, null))) } } } From 3a44aebd0c5331f6ff00734fa44ef63f8d18cfbb Mon Sep 17 00:00:00 2001 From: Martin Menestret Date: Wed, 16 Dec 2015 14:05:35 -0800 Subject: [PATCH 1803/1824] [SPARK-9690][ML][PYTHON] pyspark CrossValidator random seed Extend CrossValidator with HasSeed in PySpark. This PR replaces [https://github.com/apache/spark/pull/7997] CC: yanboliang thunterdb mmenestret Would one of you mind taking a look? Thanks! Author: Joseph K. Bradley Author: Martin MENESTRET Closes #10268 from jkbradley/pyspark-cv-seed. 
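The seed support added here mirrors the Scala-side `CrossValidator` seed introduced earlier in this series (patch 1795). As a hedged sketch of why a fixed seed matters, the Scala API can be used as below; it assumes an existing `SparkContext` named `sc`, and the toy training data and parameter grid are made up.

```scala
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.sql.SQLContext

val sqlContext = SQLContext.getOrCreate(sc)

// Toy training data using the default "label"/"features" column names.
val training = sqlContext.createDataFrame(
  (0 until 40).map(i => ((i % 2).toDouble, Vectors.dense(i.toDouble, (i % 2) * 3.0)))
).toDF("label", "features")

val lr = new LogisticRegression()
val grid = new ParamGridBuilder()
  .addGrid(lr.regParam, Array(0.1, 0.01))
  .build()

val cv = new CrossValidator()
  .setEstimator(lr)
  .setEstimatorParamMaps(grid)
  .setEvaluator(new BinaryClassificationEvaluator())
  .setNumFolds(3)
  .setSeed(42L) // the same seed yields the same fold assignment on every run

val cvModel = cv.fit(training)
println(s"average metric per param map: ${cvModel.avgMetrics.mkString(", ")}")
```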
--- python/pyspark/ml/tuning.py | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py index 705ee5368575..08f8db57f440 100644 --- a/python/pyspark/ml/tuning.py +++ b/python/pyspark/ml/tuning.py @@ -19,8 +19,9 @@ import numpy as np from pyspark import since -from pyspark.ml.param import Params, Param from pyspark.ml import Estimator, Model +from pyspark.ml.param import Params, Param +from pyspark.ml.param.shared import HasSeed from pyspark.ml.util import keyword_only from pyspark.sql.functions import rand @@ -89,7 +90,7 @@ def build(self): return [dict(zip(keys, prod)) for prod in itertools.product(*grid_values)] -class CrossValidator(Estimator): +class CrossValidator(Estimator, HasSeed): """ K-fold cross validation. @@ -129,9 +130,11 @@ class CrossValidator(Estimator): numFolds = Param(Params._dummy(), "numFolds", "number of folds for cross validation") @keyword_only - def __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3): + def __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3, + seed=None): """ - __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3) + __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3,\ + seed=None) """ super(CrossValidator, self).__init__() #: param for estimator to be cross-validated @@ -151,9 +154,11 @@ def __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numF @keyword_only @since("1.4.0") - def setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3): + def setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3, + seed=None): """ - setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3): + setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3,\ + seed=None): Sets params for cross validator. """ kwargs = self.setParams._input_kwargs @@ -225,9 +230,10 @@ def _fit(self, dataset): numModels = len(epm) eva = self.getOrDefault(self.evaluator) nFolds = self.getOrDefault(self.numFolds) + seed = self.getOrDefault(self.seed) h = 1.0 / nFolds randCol = self.uid + "_rand" - df = dataset.select("*", rand(0).alias(randCol)) + df = dataset.select("*", rand(seed).alias(randCol)) metrics = np.zeros(numModels) for i in range(nFolds): validateLB = i * h From 27b98e99d21a0cc34955337f82a71a18f9220ab2 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 16 Dec 2015 15:48:11 -0800 Subject: [PATCH 1804/1824] [SPARK-12380] [PYSPARK] use SQLContext.getOrCreate in mllib MLlib should use SQLContext.getOrCreate() instead of creating new SQLContext. Author: Davies Liu Closes #10338 from davies/create_context. 
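The same reuse pattern exists on the JVM side, which is what the Python wrappers ultimately call into. A small sketch of why `getOrCreate` is preferable to constructing fresh contexts (the app name and local master are illustrative):

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext

// getOrCreate hands back the already-active context when one exists,
// instead of allocating another one the way `new SQLContext(sc)` would.
val sc = SparkContext.getOrCreate(
  new SparkConf().setAppName("ctx-reuse").setMaster("local[2]"))
val first = SQLContext.getOrCreate(sc)
val second = SQLContext.getOrCreate(sc)
assert(first eq second) // the single shared instance is returned both times
```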
--- python/pyspark/mllib/common.py | 6 +++--- python/pyspark/mllib/evaluation.py | 10 +++++----- python/pyspark/mllib/feature.py | 4 +--- 3 files changed, 9 insertions(+), 11 deletions(-) diff --git a/python/pyspark/mllib/common.py b/python/pyspark/mllib/common.py index a439a488de5c..9fda1b1682f5 100644 --- a/python/pyspark/mllib/common.py +++ b/python/pyspark/mllib/common.py @@ -102,7 +102,7 @@ def _java2py(sc, r, encoding="bytes"): return RDD(jrdd, sc) if clsName == 'DataFrame': - return DataFrame(r, SQLContext(sc)) + return DataFrame(r, SQLContext.getOrCreate(sc)) if clsName in _picklable_classes: r = sc._jvm.SerDe.dumps(r) @@ -125,7 +125,7 @@ def callJavaFunc(sc, func, *args): def callMLlibFunc(name, *args): """ Call API in PythonMLLibAPI """ - sc = SparkContext._active_spark_context + sc = SparkContext.getOrCreate() api = getattr(sc._jvm.PythonMLLibAPI(), name) return callJavaFunc(sc, api, *args) @@ -135,7 +135,7 @@ class JavaModelWrapper(object): Wrapper for the model in JVM """ def __init__(self, java_model): - self._sc = SparkContext._active_spark_context + self._sc = SparkContext.getOrCreate() self._java_model = java_model def __del__(self): diff --git a/python/pyspark/mllib/evaluation.py b/python/pyspark/mllib/evaluation.py index 8c87ee9df213..22e68ea5b451 100644 --- a/python/pyspark/mllib/evaluation.py +++ b/python/pyspark/mllib/evaluation.py @@ -44,7 +44,7 @@ class BinaryClassificationMetrics(JavaModelWrapper): def __init__(self, scoreAndLabels): sc = scoreAndLabels.ctx - sql_ctx = SQLContext(sc) + sql_ctx = SQLContext.getOrCreate(sc) df = sql_ctx.createDataFrame(scoreAndLabels, schema=StructType([ StructField("score", DoubleType(), nullable=False), StructField("label", DoubleType(), nullable=False)])) @@ -103,7 +103,7 @@ class RegressionMetrics(JavaModelWrapper): def __init__(self, predictionAndObservations): sc = predictionAndObservations.ctx - sql_ctx = SQLContext(sc) + sql_ctx = SQLContext.getOrCreate(sc) df = sql_ctx.createDataFrame(predictionAndObservations, schema=StructType([ StructField("prediction", DoubleType(), nullable=False), StructField("observation", DoubleType(), nullable=False)])) @@ -197,7 +197,7 @@ class MulticlassMetrics(JavaModelWrapper): def __init__(self, predictionAndLabels): sc = predictionAndLabels.ctx - sql_ctx = SQLContext(sc) + sql_ctx = SQLContext.getOrCreate(sc) df = sql_ctx.createDataFrame(predictionAndLabels, schema=StructType([ StructField("prediction", DoubleType(), nullable=False), StructField("label", DoubleType(), nullable=False)])) @@ -338,7 +338,7 @@ class RankingMetrics(JavaModelWrapper): def __init__(self, predictionAndLabels): sc = predictionAndLabels.ctx - sql_ctx = SQLContext(sc) + sql_ctx = SQLContext.getOrCreate(sc) df = sql_ctx.createDataFrame(predictionAndLabels, schema=sql_ctx._inferSchema(predictionAndLabels)) java_model = callMLlibFunc("newRankingMetrics", df._jdf) @@ -424,7 +424,7 @@ class MultilabelMetrics(JavaModelWrapper): def __init__(self, predictionAndLabels): sc = predictionAndLabels.ctx - sql_ctx = SQLContext(sc) + sql_ctx = SQLContext.getOrCreate(sc) df = sql_ctx.createDataFrame(predictionAndLabels, schema=sql_ctx._inferSchema(predictionAndLabels)) java_class = sc._jvm.org.apache.spark.mllib.evaluation.MultilabelMetrics diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py index 7254679ebb53..acd7ec57d69d 100644 --- a/python/pyspark/mllib/feature.py +++ b/python/pyspark/mllib/feature.py @@ -30,7 +30,7 @@ from py4j.protocol import Py4JJavaError -from pyspark import SparkContext, since +from 
pyspark import since from pyspark.rdd import RDD, ignore_unicode_prefix from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper from pyspark.mllib.linalg import ( @@ -100,8 +100,6 @@ def transform(self, vector): :return: normalized vector. If the norm of the input is zero, it will return the input vector. """ - sc = SparkContext._active_spark_context - assert sc is not None, "SparkContext should be initialized first" if isinstance(vector, RDD): vector = vector.map(_convert_to_vector) else: From 861549acdbc11920cde51fc57752a8bc241064e5 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Wed, 16 Dec 2015 16:13:48 -0800 Subject: [PATCH 1805/1824] [MINOR] Add missing interpolation in NettyRPCEnv ``` Exception in thread "main" org.apache.spark.rpc.RpcTimeoutException: Cannot receive any reply in ${timeout.duration}. This timeout is controlled by spark.rpc.askTimeout at org.apache.spark.rpc.RpcTimeout.org$apache$spark$rpc$RpcTimeout$$createRpcTimeoutException(RpcTimeout.scala:48) at org.apache.spark.rpc.RpcTimeout$$anonfun$addMessageIfTimeout$1.applyOrElse(RpcTimeout.scala:63) at org.apache.spark.rpc.RpcTimeout$$anonfun$addMessageIfTimeout$1.applyOrElse(RpcTimeout.scala:59) at scala.runtime.AbstractPartialFunction.apply(AbstractPartialFunction.scala:33) ``` Author: Andrew Or Closes #10334 from andrewor14/rpc-typo. --- .../src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala index f82fd4eb5756..de3db6ba624f 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala @@ -232,7 +232,7 @@ private[netty] class NettyRpcEnv( val timeoutCancelable = timeoutScheduler.schedule(new Runnable { override def run(): Unit = { promise.tryFailure( - new TimeoutException("Cannot receive any reply in ${timeout.duration}")) + new TimeoutException(s"Cannot receive any reply in ${timeout.duration}")) } }, timeout.duration.toNanos, TimeUnit.NANOSECONDS) promise.future.onComplete { v => From ce5fd4008e890ef8ebc2d3cb703a666783ad6c02 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Wed, 16 Dec 2015 17:05:57 -0800 Subject: [PATCH 1806/1824] MAINTENANCE: Automated closing of pull requests. This commit exists to close the following pull requests on Github: Closes #1217 (requested by ankurdave, srowen) Closes #4650 (requested by andrewor14) Closes #5307 (requested by vanzin) Closes #5664 (requested by andrewor14) Closes #5713 (requested by marmbrus) Closes #5722 (requested by andrewor14) Closes #6685 (requested by srowen) Closes #7074 (requested by srowen) Closes #7119 (requested by andrewor14) Closes #7997 (requested by jkbradley) Closes #8292 (requested by srowen) Closes #8975 (requested by andrewor14, vanzin) Closes #8980 (requested by andrewor14, davies) From 38d9795a4fa07086d65ff705ce86648345618736 Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Wed, 16 Dec 2015 19:01:05 -0800 Subject: [PATCH 1807/1824] [SPARK-10248][CORE] track exceptions in dagscheduler event loop in tests `DAGSchedulerEventLoop` normally only logs errors (so it can continue to process more events, from other jobs). However, this is not desirable in the tests -- the tests should be able to easily detect any exception, and also shouldn't silently succeed if there is an exception. This was suggested by mateiz on https://github.com/apache/spark/pull/7699. 
It may have already turned up an issue in "zero split job". Author: Imran Rashid Closes #8466 from squito/SPARK-10248. --- .../apache/spark/scheduler/DAGScheduler.scala | 5 ++-- .../spark/scheduler/DAGSchedulerSuite.scala | 28 +++++++++++++++++-- 2 files changed, 29 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 8d0e0c8624a5..b128ed50cad5 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -805,7 +805,8 @@ class DAGScheduler( private[scheduler] def cleanUpAfterSchedulerStop() { for (job <- activeJobs) { - val error = new SparkException("Job cancelled because SparkContext was shut down") + val error = + new SparkException(s"Job ${job.jobId} cancelled because SparkContext was shut down") job.listener.jobFailed(error) // Tell the listeners that all of the running stages have ended. Don't bother // cancelling the stages because if the DAG scheduler is stopped, the entire application @@ -1295,7 +1296,7 @@ class DAGScheduler( case TaskResultLost => // Do nothing here; the TaskScheduler handles these failures and resubmits the task. - case other => + case _: ExecutorLostFailure | TaskKilled | UnknownReason => // Unrecognized failure - also do nothing. If the task fails repeatedly, the TaskScheduler // will abort the job. } diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 653d41fc053c..2869f0fde4c5 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -45,6 +45,13 @@ class DAGSchedulerEventProcessLoopTester(dagScheduler: DAGScheduler) case NonFatal(e) => onError(e) } } + + override def onError(e: Throwable): Unit = { + logError("Error in DAGSchedulerEventLoop: ", e) + dagScheduler.stop() + throw e + } + } /** @@ -300,13 +307,18 @@ class DAGSchedulerSuite test("zero split job") { var numResults = 0 + var failureReason: Option[Exception] = None val fakeListener = new JobListener() { - override def taskSucceeded(partition: Int, value: Any) = numResults += 1 - override def jobFailed(exception: Exception) = throw exception + override def taskSucceeded(partition: Int, value: Any): Unit = numResults += 1 + override def jobFailed(exception: Exception): Unit = { + failureReason = Some(exception) + } } val jobId = submit(new MyRDD(sc, 0, Nil), Array(), listener = fakeListener) assert(numResults === 0) cancel(jobId) + assert(failureReason.isDefined) + assert(failureReason.get.getMessage() === "Job 0 cancelled ") } test("run trivial job") { @@ -1675,6 +1687,18 @@ class DAGSchedulerSuite assert(stackTraceString.contains("org.scalatest.FunSuite")) } + test("catch errors in event loop") { + // this is a test of our testing framework -- make sure errors in event loop don't get ignored + + // just run some bad event that will throw an exception -- we'll give a null TaskEndReason + val rdd1 = new MyRDD(sc, 1, Nil) + submit(rdd1, Array(0)) + intercept[Exception] { + complete(taskSets(0), Seq( + (null, makeMapStatus("hostA", 1)))) + } + } + test("simple map stage submission") { val shuffleMapRdd = new MyRDD(sc, 2, Nil) val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(1)) From f590178d7a06221a93286757c68b23919bee9f03 Mon Sep 17 00:00:00 
2001 From: tedyu Date: Wed, 16 Dec 2015 19:02:12 -0800 Subject: [PATCH 1808/1824] [SPARK-12365][CORE] Use ShutdownHookManager where Runtime.getRuntime.addShutdownHook() is called SPARK-9886 fixed ExternalBlockStore.scala This PR fixes the remaining references to Runtime.getRuntime.addShutdownHook() Author: tedyu Closes #10325 from ted-yu/master. --- .../spark/deploy/ExternalShuffleService.scala | 18 +++++--------- .../deploy/mesos/MesosClusterDispatcher.scala | 13 ++++------ .../spark/util/ShutdownHookManager.scala | 4 ++++ scalastyle-config.xml | 12 ++++++++++ .../hive/thriftserver/SparkSQLCLIDriver.scala | 24 +++++++++---------- 5 files changed, 38 insertions(+), 33 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala b/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala index e8a1e35c3fc4..7fc96e4f764b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala +++ b/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala @@ -28,7 +28,7 @@ import org.apache.spark.network.sasl.SaslServerBootstrap import org.apache.spark.network.server.{TransportServerBootstrap, TransportServer} import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler import org.apache.spark.network.util.TransportConf -import org.apache.spark.util.Utils +import org.apache.spark.util.{ShutdownHookManager, Utils} /** * Provides a server from which Executors can read shuffle files (rather than reading directly from @@ -118,19 +118,13 @@ object ExternalShuffleService extends Logging { server = newShuffleService(sparkConf, securityManager) server.start() - installShutdownHook() + ShutdownHookManager.addShutdownHook { () => + logInfo("Shutting down shuffle service.") + server.stop() + barrier.countDown() + } // keep running until the process is terminated barrier.await() } - - private def installShutdownHook(): Unit = { - Runtime.getRuntime.addShutdownHook(new Thread("External Shuffle Service shutdown thread") { - override def run() { - logInfo("Shutting down shuffle service.") - server.stop() - barrier.countDown() - } - }) - } } diff --git a/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala index 5d4e5b899dfd..389eff5e0645 100644 --- a/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala +++ b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala @@ -22,7 +22,7 @@ import java.util.concurrent.CountDownLatch import org.apache.spark.deploy.mesos.ui.MesosClusterUI import org.apache.spark.deploy.rest.mesos.MesosRestServer import org.apache.spark.scheduler.cluster.mesos._ -import org.apache.spark.util.SignalLogger +import org.apache.spark.util.{ShutdownHookManager, SignalLogger} import org.apache.spark.{Logging, SecurityManager, SparkConf} /* @@ -103,14 +103,11 @@ private[mesos] object MesosClusterDispatcher extends Logging { } val dispatcher = new MesosClusterDispatcher(dispatcherArgs, conf) dispatcher.start() - val shutdownHook = new Thread() { - override def run() { - logInfo("Shutdown hook is shutting down dispatcher") - dispatcher.stop() - dispatcher.awaitShutdown() - } + ShutdownHookManager.addShutdownHook { () => + logInfo("Shutdown hook is shutting down dispatcher") + dispatcher.stop() + dispatcher.awaitShutdown() } - Runtime.getRuntime.addShutdownHook(shutdownHook) dispatcher.awaitShutdown() } } diff --git 
a/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala b/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala index 620f226a23e1..1a0f3b477ba3 100644 --- a/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala +++ b/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala @@ -162,7 +162,9 @@ private[spark] object ShutdownHookManager extends Logging { val hook = new Thread { override def run() {} } + // scalastyle:off runtimeaddshutdownhook Runtime.getRuntime.addShutdownHook(hook) + // scalastyle:on runtimeaddshutdownhook Runtime.getRuntime.removeShutdownHook(hook) } catch { case ise: IllegalStateException => return true @@ -228,7 +230,9 @@ private [util] class SparkShutdownHookManager { .invoke(shm, hookTask, Integer.valueOf(fsPriority + 30)) case Failure(_) => + // scalastyle:off runtimeaddshutdownhook Runtime.getRuntime.addShutdownHook(new Thread(hookTask, "Spark Shutdown Hook")); + // scalastyle:on runtimeaddshutdownhook } } diff --git a/scalastyle-config.xml b/scalastyle-config.xml index dab1ebddc666..6925e18737b7 100644 --- a/scalastyle-config.xml +++ b/scalastyle-config.xml @@ -157,6 +157,18 @@ This file is divided into 3 sections: ]]> + + Runtime\.getRuntime\.addShutdownHook + + + Class\.forName - try { - h.flush() - } catch { - case e: IOException => - logWarning("WARNING: Failed to write command history file: " + e.getMessage) - } - case _ => - } + ShutdownHookManager.addShutdownHook { () => + reader.getHistory match { + case h: FileHistory => + try { + h.flush() + } catch { + case e: IOException => + logWarning("WARNING: Failed to write command history file: " + e.getMessage) + } + case _ => } - })) + } // TODO: missing /* From fdb38227564c1af40cbfb97df420b23eb04c002b Mon Sep 17 00:00:00 2001 From: Rohit Agarwal Date: Wed, 16 Dec 2015 19:04:33 -0800 Subject: [PATCH 1809/1824] [SPARK-12186][WEB UI] Send the complete request URI including the query string when redirecting. Author: Rohit Agarwal Closes #10180 from mindprince/SPARK-12186. --- .../scala/org/apache/spark/deploy/history/HistoryServer.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala index d4f327cc588f..f31fef0eccc3 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala @@ -103,7 +103,9 @@ class HistoryServer( // Note we don't use the UI retrieved from the cache; the cache loader above will register // the app's UI, and all we need to do is redirect the user to the same URI that was // requested, and the proper data should be served at that point. - res.sendRedirect(res.encodeRedirectURL(req.getRequestURI())) + // Also, make sure that the redirect url contains the query string present in the request. + val requestURI = req.getRequestURI + Option(req.getQueryString).map("?" + _).getOrElse("") + res.sendRedirect(res.encodeRedirectURL(requestURI)) } // SPARK-5983 ensure TRACE is not supported From d1508dd9b765489913bc948575a69ebab82f217b Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Wed, 16 Dec 2015 19:47:49 -0800 Subject: [PATCH 1810/1824] [SPARK-12386][CORE] Fix NPE when spark.executor.port is set. Author: Marcelo Vanzin Closes #10339 from vanzin/SPARK-12386. 
--- core/src/main/scala/org/apache/spark/SparkEnv.scala | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 84230e32a446..52acde1b414e 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -256,7 +256,12 @@ object SparkEnv extends Logging { if (rpcEnv.isInstanceOf[AkkaRpcEnv]) { rpcEnv.asInstanceOf[AkkaRpcEnv].actorSystem } else { - val actorSystemPort = if (port == 0) 0 else rpcEnv.address.port + 1 + val actorSystemPort = + if (port == 0 || rpcEnv.address == null) { + port + } else { + rpcEnv.address.port + 1 + } // Create a ActorSystem for legacy codes AkkaUtils.createActorSystem( actorSystemName + "ActorSystem", From 97678edeaaafc19ea18d044233a952d2e2e89fbc Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Wed, 16 Dec 2015 20:01:47 -0800 Subject: [PATCH 1811/1824] [SPARK-12390] Clean up unused serializer parameter in BlockManager No change in functionality is intended. This only changes internal API. Author: Andrew Or Closes #10343 from andrewor14/clean-bm-serializer. --- .../apache/spark/storage/BlockManager.scala | 29 +++++++------------ .../org/apache/spark/storage/DiskStore.scala | 10 ------- 2 files changed, 11 insertions(+), 28 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 540e1ec003a2..6074fc58d70d 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -1190,20 +1190,16 @@ private[spark] class BlockManager( def dataSerializeStream( blockId: BlockId, outputStream: OutputStream, - values: Iterator[Any], - serializer: Serializer = defaultSerializer): Unit = { + values: Iterator[Any]): Unit = { val byteStream = new BufferedOutputStream(outputStream) - val ser = serializer.newInstance() + val ser = defaultSerializer.newInstance() ser.serializeStream(wrapForCompression(blockId, byteStream)).writeAll(values).close() } /** Serializes into a byte buffer. */ - def dataSerialize( - blockId: BlockId, - values: Iterator[Any], - serializer: Serializer = defaultSerializer): ByteBuffer = { + def dataSerialize(blockId: BlockId, values: Iterator[Any]): ByteBuffer = { val byteStream = new ByteBufferOutputStream(4096) - dataSerializeStream(blockId, byteStream, values, serializer) + dataSerializeStream(blockId, byteStream, values) byteStream.toByteBuffer } @@ -1211,24 +1207,21 @@ private[spark] class BlockManager( * Deserializes a ByteBuffer into an iterator of values and disposes of it when the end of * the iterator is reached. */ - def dataDeserialize( - blockId: BlockId, - bytes: ByteBuffer, - serializer: Serializer = defaultSerializer): Iterator[Any] = { + def dataDeserialize(blockId: BlockId, bytes: ByteBuffer): Iterator[Any] = { bytes.rewind() - dataDeserializeStream(blockId, new ByteBufferInputStream(bytes, true), serializer) + dataDeserializeStream(blockId, new ByteBufferInputStream(bytes, true)) } /** * Deserializes a InputStream into an iterator of values and disposes of it when the end of * the iterator is reached. 
*/ - def dataDeserializeStream( - blockId: BlockId, - inputStream: InputStream, - serializer: Serializer = defaultSerializer): Iterator[Any] = { + def dataDeserializeStream(blockId: BlockId, inputStream: InputStream): Iterator[Any] = { val stream = new BufferedInputStream(inputStream) - serializer.newInstance().deserializeStream(wrapForCompression(blockId, stream)).asIterator + defaultSerializer + .newInstance() + .deserializeStream(wrapForCompression(blockId, stream)) + .asIterator } def stop(): Unit = { diff --git a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala index c008b9dc1632..6c4477184d5b 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala @@ -144,16 +144,6 @@ private[spark] class DiskStore(blockManager: BlockManager, diskManager: DiskBloc getBytes(blockId).map(buffer => blockManager.dataDeserialize(blockId, buffer)) } - /** - * A version of getValues that allows a custom serializer. This is used as part of the - * shuffle short-circuit code. - */ - def getValues(blockId: BlockId, serializer: Serializer): Option[Iterator[Any]] = { - // TODO: Should bypass getBytes and use a stream based implementation, so that - // we won't use a lot of memory during e.g. external sort merge. - getBytes(blockId).map(bytes => blockManager.dataDeserialize(blockId, bytes, serializer)) - } - override def remove(blockId: BlockId): Boolean = { val file = diskManager.getFile(blockId.name) if (file.exists()) { From 437583f692e30b8dc03b339a34e92595d7b992ba Mon Sep 17 00:00:00 2001 From: David Tolpin Date: Wed, 16 Dec 2015 22:10:24 -0800 Subject: [PATCH 1812/1824] [SPARK-11904][PYSPARK] reduceByKeyAndWindow does not require checkpointing when invFunc is None when invFunc is None, `reduceByKeyAndWindow(func, None, winsize, slidesize)` is equivalent to reduceByKey(func).window(winsize, slidesize).reduceByKey(winsize, slidesize) and no checkpoint is necessary. The corresponding Scala code does exactly that, but Python code always creates a windowed stream with obligatory checkpointing. The patch fixes this. I do not know how to unit-test this. Author: David Tolpin Closes #9888 from dtolpin/master. 
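A Scala sketch of the equivalent non-incremental pipeline, which needs no checkpoint directory; only the inverse-function form keeps state across slides and therefore requires `ssc.checkpoint(...)`. The socket source, host and port below are illustrative:

```
import org.apache.spark.SparkConf
import org.apache.spark.streaming.{Seconds, StreamingContext}

val conf = new SparkConf().setAppName("window-demo").setMaster("local[2]")
val ssc = new StreamingContext(conf, Seconds(1))

val counts = ssc.socketTextStream("localhost", 9999)
  .flatMap(_.split(" "))
  .map(word => (word, 1))

// reduceByKeyAndWindow(func, window, slide) without an inverse function
// boils down to: reduce per batch, window, then reduce again over the window.
val windowedCounts = counts
  .reduceByKey(_ + _)
  .window(Seconds(30), Seconds(10))
  .reduceByKey(_ + _)

windowedCounts.print()
ssc.start()
ssc.awaitTermination()
```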
--- python/pyspark/streaming/dstream.py | 45 +++++++++++++++-------------- 1 file changed, 23 insertions(+), 22 deletions(-) diff --git a/python/pyspark/streaming/dstream.py b/python/pyspark/streaming/dstream.py index f61137cb88c4..b994a53bf2b8 100644 --- a/python/pyspark/streaming/dstream.py +++ b/python/pyspark/streaming/dstream.py @@ -542,31 +542,32 @@ def reduceByKeyAndWindow(self, func, invFunc, windowDuration, slideDuration=None reduced = self.reduceByKey(func, numPartitions) - def reduceFunc(t, a, b): - b = b.reduceByKey(func, numPartitions) - r = a.union(b).reduceByKey(func, numPartitions) if a else b - if filterFunc: - r = r.filter(filterFunc) - return r - - def invReduceFunc(t, a, b): - b = b.reduceByKey(func, numPartitions) - joined = a.leftOuterJoin(b, numPartitions) - return joined.mapValues(lambda kv: invFunc(kv[0], kv[1]) - if kv[1] is not None else kv[0]) - - jreduceFunc = TransformFunction(self._sc, reduceFunc, reduced._jrdd_deserializer) if invFunc: + def reduceFunc(t, a, b): + b = b.reduceByKey(func, numPartitions) + r = a.union(b).reduceByKey(func, numPartitions) if a else b + if filterFunc: + r = r.filter(filterFunc) + return r + + def invReduceFunc(t, a, b): + b = b.reduceByKey(func, numPartitions) + joined = a.leftOuterJoin(b, numPartitions) + return joined.mapValues(lambda kv: invFunc(kv[0], kv[1]) + if kv[1] is not None else kv[0]) + + jreduceFunc = TransformFunction(self._sc, reduceFunc, reduced._jrdd_deserializer) jinvReduceFunc = TransformFunction(self._sc, invReduceFunc, reduced._jrdd_deserializer) + if slideDuration is None: + slideDuration = self._slideDuration + dstream = self._sc._jvm.PythonReducedWindowedDStream( + reduced._jdstream.dstream(), + jreduceFunc, jinvReduceFunc, + self._ssc._jduration(windowDuration), + self._ssc._jduration(slideDuration)) + return DStream(dstream.asJavaDStream(), self._ssc, self._sc.serializer) else: - jinvReduceFunc = None - if slideDuration is None: - slideDuration = self._slideDuration - dstream = self._sc._jvm.PythonReducedWindowedDStream(reduced._jdstream.dstream(), - jreduceFunc, jinvReduceFunc, - self._ssc._jduration(windowDuration), - self._ssc._jduration(slideDuration)) - return DStream(dstream.asJavaDStream(), self._ssc, self._sc.serializer) + return reduced.window(windowDuration, slideDuration).reduceByKey(func, numPartitions) def updateStateByKey(self, updateFunc, numPartitions=None, initialRDD=None): """ From 9d66c4216ad830812848c657bbcd8cd50949e199 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Wed, 16 Dec 2015 23:18:53 -0800 Subject: [PATCH 1813/1824] [SPARK-12057][SQL] Prevent failure on corrupt JSON records This PR makes JSON parser and schema inference handle more cases where we have unparsed records. It is based on #10043. The last commit fixes the failed test and updates the logic of schema inference. Regarding the schema inference change, if we have something like ``` {"f1":1} [1,2,3] ``` originally, we will get a DF without any column. After this change, we will get a DF with columns `f1` and `_corrupt_record`. Basically, for the second row, `[1,2,3]` will be the value of `_corrupt_record`. When merge this PR, please make sure that the author is simplyianm. JIRA: https://issues.apache.org/jira/browse/SPARK-12057 Closes #10043 Author: Ian Macalinao Author: Yin Huai Closes #10288 from yhuai/handleCorruptJson. 
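A spark-shell style Scala sketch of the resulting behaviour, using records that mirror the examples above (names are illustrative):

```
// Assumes a spark-shell style `sc` and `sqlContext`.
val records = sc.parallelize(Seq(
  """{"name": "alice", "age": 30}""",
  """[1, 2, 3]""",
  """{"name": "bob", """))

val df = sqlContext.read.json(records)
df.printSchema()
// The inferred schema keeps the fields from the parsable row and adds a
// string `_corrupt_record` column instead of coming back empty.

df.filter("_corrupt_record IS NOT NULL").show()
// The unparsed lines ("[1, 2, 3]" and the truncated object) appear as rows
// carrying their raw text in `_corrupt_record`, with nulls elsewhere.
```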
--- .../datasources/json/InferSchema.scala | 37 +++++++++++++++++-- .../datasources/json/JacksonParser.scala | 19 ++++++---- .../datasources/json/JsonSuite.scala | 37 +++++++++++++++++++ .../datasources/json/TestJsonData.scala | 9 ++++- 4 files changed, 90 insertions(+), 12 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala index 922fd5b21167..59ba4ae2cba0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala @@ -61,7 +61,10 @@ private[json] object InferSchema { StructType(Seq(StructField(columnNameOfCorruptRecords, StringType))) } } - }.treeAggregate[DataType](StructType(Seq()))(compatibleRootType, compatibleRootType) + }.treeAggregate[DataType]( + StructType(Seq()))( + compatibleRootType(columnNameOfCorruptRecords), + compatibleRootType(columnNameOfCorruptRecords)) canonicalizeType(rootType) match { case Some(st: StructType) => st @@ -170,12 +173,38 @@ private[json] object InferSchema { case other => Some(other) } + private def withCorruptField( + struct: StructType, + columnNameOfCorruptRecords: String): StructType = { + if (!struct.fieldNames.contains(columnNameOfCorruptRecords)) { + // If this given struct does not have a column used for corrupt records, + // add this field. + struct.add(columnNameOfCorruptRecords, StringType, nullable = true) + } else { + // Otherwise, just return this struct. + struct + } + } + /** * Remove top-level ArrayType wrappers and merge the remaining schemas */ - private def compatibleRootType: (DataType, DataType) => DataType = { - case (ArrayType(ty1, _), ty2) => compatibleRootType(ty1, ty2) - case (ty1, ArrayType(ty2, _)) => compatibleRootType(ty1, ty2) + private def compatibleRootType( + columnNameOfCorruptRecords: String): (DataType, DataType) => DataType = { + // Since we support array of json objects at the top level, + // we need to check the element type and find the root level data type. + case (ArrayType(ty1, _), ty2) => compatibleRootType(columnNameOfCorruptRecords)(ty1, ty2) + case (ty1, ArrayType(ty2, _)) => compatibleRootType(columnNameOfCorruptRecords)(ty1, ty2) + // If we see any other data type at the root level, we get records that cannot be + // parsed. So, we use the struct as the data type and add the corrupt field to the schema. + case (struct: StructType, NullType) => struct + case (NullType, struct: StructType) => struct + case (struct: StructType, o) if !o.isInstanceOf[StructType] => + withCorruptField(struct, columnNameOfCorruptRecords) + case (o, struct: StructType) if !o.isInstanceOf[StructType] => + withCorruptField(struct, columnNameOfCorruptRecords) + // If we get anything else, we call compatibleType. + // Usually, when we reach here, ty1 and ty2 are two StructTypes. 
case (ty1, ty2) => compatibleType(ty1, ty2) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala index bfa140504105..55a1c24e9e00 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala @@ -31,6 +31,8 @@ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils +private[json] class SparkSQLJsonProcessingException(msg: String) extends RuntimeException(msg) + object JacksonParser { def parse( @@ -110,7 +112,7 @@ object JacksonParser { lowerCaseValue.equals("-inf")) { value.toFloat } else { - sys.error(s"Cannot parse $value as FloatType.") + throw new SparkSQLJsonProcessingException(s"Cannot parse $value as FloatType.") } case (VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT, DoubleType) => @@ -127,7 +129,7 @@ object JacksonParser { lowerCaseValue.equals("-inf")) { value.toDouble } else { - sys.error(s"Cannot parse $value as DoubleType.") + throw new SparkSQLJsonProcessingException(s"Cannot parse $value as DoubleType.") } case (VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT, dt: DecimalType) => @@ -174,7 +176,11 @@ object JacksonParser { convertField(factory, parser, udt.sqlType) case (token, dataType) => - sys.error(s"Failed to parse a value for data type $dataType (current token: $token).") + // We cannot parse this token based on the given data type. So, we throw a + // SparkSQLJsonProcessingException and this exception will be caught by + // parseJson method. + throw new SparkSQLJsonProcessingException( + s"Failed to parse a value for data type $dataType (current token: $token).") } } @@ -267,15 +273,14 @@ object JacksonParser { array.toArray[InternalRow](schema) } case _ => - sys.error( - s"Failed to parse record $record. Please make sure that each line of " + - "the file (or each string in the RDD) is a valid JSON object or " + - "an array of JSON objects.") + failedRecord(record) } } } catch { case _: JsonProcessingException => failedRecord(record) + case _: SparkSQLJsonProcessingException => + failedRecord(record) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index ba7718c86463..baa258ad2615 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -1427,4 +1427,41 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } } } + + test("SPARK-12057 additional corrupt records do not throw exceptions") { + // Test if we can query corrupt records. + withSQLConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD.key -> "_unparsed") { + withTempTable("jsonTable") { + val schema = StructType( + StructField("_unparsed", StringType, true) :: + StructField("dummy", StringType, true) :: Nil) + + { + // We need to make sure we can infer the schema. + val jsonDF = sqlContext.read.json(additionalCorruptRecords) + assert(jsonDF.schema === schema) + } + + { + val jsonDF = sqlContext.read.schema(schema).json(additionalCorruptRecords) + jsonDF.registerTempTable("jsonTable") + + // In HiveContext, backticks should be used to access columns starting with a underscore. 
+ checkAnswer( + sql( + """ + |SELECT dummy, _unparsed + |FROM jsonTable + """.stripMargin), + Row("test", null) :: + Row(null, """[1,2,3]""") :: + Row(null, """":"test", "a":1}""") :: + Row(null, """42""") :: + Row(null, """ ","ian":"test"}""") :: Nil + ) + } + } + } + } + } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala index 713d1da1cb51..cb61f7eeca0d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala @@ -188,6 +188,14 @@ private[json] trait TestJsonData { """{"b":"str_b_4", "a":"str_a_4", "c":"str_c_4"}""" :: """]""" :: Nil) + def additionalCorruptRecords: RDD[String] = + sqlContext.sparkContext.parallelize( + """{"dummy":"test"}""" :: + """[1,2,3]""" :: + """":"test", "a":1}""" :: + """42""" :: + """ ","ian":"test"}""" :: Nil) + def emptyRecords: RDD[String] = sqlContext.sparkContext.parallelize( """{""" :: @@ -197,7 +205,6 @@ private[json] trait TestJsonData { """{"b": [{"c": {}}]}""" :: """]""" :: Nil) - lazy val singleRow: RDD[String] = sqlContext.sparkContext.parallelize("""{"a":123}""" :: Nil) def empty: RDD[String] = sqlContext.sparkContext.parallelize(Seq[String]()) From 5a514b61bbfb609c505d8d65f2483068a56f1f70 Mon Sep 17 00:00:00 2001 From: echo2mei <534384876@qq.com> Date: Thu, 17 Dec 2015 07:59:17 -0800 Subject: [PATCH 1814/1824] Once driver register successfully, stop it to connect to master. This commit is to resolve SPARK-12396. Author: echo2mei <534384876@qq.com> Closes #10354 from echoTomei/master. --- .../main/scala/org/apache/spark/deploy/client/AppClient.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala index 1e2f469214b8..3cf7464a1561 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala @@ -130,6 +130,7 @@ private[spark] class AppClient( if (registered.get) { registerMasterFutures.get.foreach(_.cancel(true)) registerMasterThreadPool.shutdownNow() + registrationRetryTimer.cancel(true) } else if (nthRetry >= REGISTRATION_RETRIES) { markDead("All masters are unresponsive! Giving up.") } else { From cd3d937b0cc89cb5e4098f7d9f5db2712e3de71e Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 17 Dec 2015 08:01:27 -0800 Subject: [PATCH 1815/1824] Revert "Once driver register successfully, stop it to connect to master." This reverts commit 5a514b61bbfb609c505d8d65f2483068a56f1f70. --- .../main/scala/org/apache/spark/deploy/client/AppClient.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala index 3cf7464a1561..1e2f469214b8 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala @@ -130,7 +130,6 @@ private[spark] class AppClient( if (registered.get) { registerMasterFutures.get.foreach(_.cancel(true)) registerMasterThreadPool.shutdownNow() - registrationRetryTimer.cancel(true) } else if (nthRetry >= REGISTRATION_RETRIES) { markDead("All masters are unresponsive! 
Giving up.") } else { From a170d34a1b309fecc76d1370063e0c4f44dc2142 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 17 Dec 2015 08:04:11 -0800 Subject: [PATCH 1816/1824] [SPARK-12395] [SQL] fix resulting columns of outer join For API DataFrame.join(right, usingColumns, joinType), if the joinType is right_outer or full_outer, the resulting join columns could be wrong (will be null). The order of columns had been changed to match that with MySQL and PostgreSQL [1]. This PR also fix the nullability of output for outer join. [1] http://www.postgresql.org/docs/9.2/static/queries-table-expressions.html Author: Davies Liu Closes #10353 from davies/fix_join. --- .../org/apache/spark/sql/DataFrame.scala | 25 +++++++++++++++---- .../apache/spark/sql/DataFrameJoinSuite.scala | 20 ++++++++++++--- 2 files changed, 36 insertions(+), 9 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 6250e952169d..d74131231499 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -36,7 +36,7 @@ import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.plans.{Inner, JoinType} +import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection, SqlParser} import org.apache.spark.sql.execution.{EvaluatePython, ExplainCommand, FileRelation, LogicalRDD, QueryExecution, Queryable, SQLExecution} import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, LogicalRelation} @@ -455,10 +455,8 @@ class DataFrame private[sql]( // Analyze the self join. The assumption is that the analyzer will disambiguate left vs right // by creating a new instance for one of the branch. val joined = sqlContext.executePlan( - Join(logicalPlan, right.logicalPlan, joinType = Inner, None)).analyzed.asInstanceOf[Join] + Join(logicalPlan, right.logicalPlan, JoinType(joinType), None)).analyzed.asInstanceOf[Join] - // Project only one of the join columns. - val joinedCols = usingColumns.map(col => withPlan(joined.right).resolve(col)) val condition = usingColumns.map { col => catalyst.expressions.EqualTo( withPlan(joined.left).resolve(col), @@ -467,9 +465,26 @@ class DataFrame private[sql]( catalyst.expressions.And(cond, eqTo) } + // Project only one of the join columns. 
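A spark-shell style Scala sketch of the corrected result (column names and data are illustrative, and `import sqlContext.implicits._` is assumed):

```
val left = Seq((1, "a"), (3, "c")).toDF("id", "l")
val right = Seq((1, "x"), (5, "z")).toDF("id", "r")

left.join(right, Seq("id"), "outer").show()
// Expected rows, in some order: (1, a, x), (3, c, null), (5, null, z).
// Previously the shared `id` column could come back null for right/full outer
// joins; it is now taken from the right side for right_outer and coalesced
// from both sides for full_outer, so the last row keeps id = 5.
```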
+ val joinedCols = JoinType(joinType) match { + case Inner | LeftOuter | LeftSemi => + usingColumns.map(col => withPlan(joined.left).resolve(col)) + case RightOuter => + usingColumns.map(col => withPlan(joined.right).resolve(col)) + case FullOuter => + usingColumns.map { col => + val leftCol = withPlan(joined.left).resolve(col) + val rightCol = withPlan(joined.right).resolve(col) + Alias(Coalesce(Seq(leftCol, rightCol)), col)() + } + } + // The nullability of output of joined could be different than original column, + // so we can only compare them by exprId + val joinRefs = condition.map(_.references.toSeq.map(_.exprId)).getOrElse(Nil) + val resultCols = joinedCols ++ joined.output.filterNot(e => joinRefs.contains(e.exprId)) withPlan { Project( - joined.output.filterNot(joinedCols.contains(_)), + resultCols, Join( joined.left, joined.right, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala index c70397f9853a..39a65413bd59 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala @@ -43,16 +43,28 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext { } test("join - join using multiple columns and specifying join type") { - val df = Seq(1, 2, 3).map(i => (i, i + 1, i.toString)).toDF("int", "int2", "str") - val df2 = Seq(1, 2, 3).map(i => (i, i + 1, (i + 1).toString)).toDF("int", "int2", "str") + val df = Seq((1, 2, "1"), (3, 4, "3")).toDF("int", "int2", "str") + val df2 = Seq((1, 3, "1"), (5, 6, "5")).toDF("int", "int2", "str") + + checkAnswer( + df.join(df2, Seq("int", "str"), "inner"), + Row(1, "1", 2, 3) :: Nil) checkAnswer( df.join(df2, Seq("int", "str"), "left"), - Row(1, 2, "1", null) :: Row(2, 3, "2", null) :: Row(3, 4, "3", null) :: Nil) + Row(1, "1", 2, 3) :: Row(3, "3", 4, null) :: Nil) checkAnswer( df.join(df2, Seq("int", "str"), "right"), - Row(null, null, null, 2) :: Row(null, null, null, 3) :: Row(null, null, null, 4) :: Nil) + Row(1, "1", 2, 3) :: Row(5, "5", null, 6) :: Nil) + + checkAnswer( + df.join(df2, Seq("int", "str"), "outer"), + Row(1, "1", 2, 3) :: Row(3, "3", 4, null) :: Row(5, "5", null, 6) :: Nil) + + checkAnswer( + df.join(df2, Seq("int", "str"), "left_semi"), + Row(1, "1", 2) :: Nil) } test("join - join using self join") { From 6e0771665b3c9330fc0a5b2c7740a796b4cd712e Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Thu, 17 Dec 2015 09:19:46 -0800 Subject: [PATCH 1817/1824] [SQL] Update SQLContext.read.text doc Since we rename the column name from ```text``` to ```value``` for DataFrame load by ```SQLContext.read.text```, we need to update doc. Author: Yanbo Liang Closes #10349 from yanboliang/text-value. --- python/pyspark/sql/readwriter.py | 2 +- .../src/main/scala/org/apache/spark/sql/DataFrameReader.scala | 2 +- .../spark/sql/execution/datasources/text/DefaultSource.scala | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 2e75f0c8a182..a3d7eca04b61 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -207,7 +207,7 @@ def parquet(self, *paths): @ignore_unicode_prefix @since(1.6) def text(self, paths): - """Loads a text file and returns a [[DataFrame]] with a single string column named "text". + """Loads a text file and returns a [[DataFrame]] with a single string column named "value". 
Each line in the text file is a new row in the resulting DataFrame. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 3ed1e55adec6..c1a8f19313a7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -339,7 +339,7 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging { } /** - * Loads a text file and returns a [[DataFrame]] with a single string column named "text". + * Loads a text file and returns a [[DataFrame]] with a single string column named "value". * Each line in the text file is a new row in the resulting DataFrame. For example: * {{{ * // Scala: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala index fbd387bc2ef4..4a1cbe4c38fa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala @@ -76,7 +76,7 @@ private[sql] class TextRelation( (@transient val sqlContext: SQLContext) extends HadoopFsRelation(maybePartitionSpec, parameters) { - /** Data schema is always a single column, named "text". */ + /** Data schema is always a single column, named "value". */ override def dataSchema: StructType = new StructType().add("value", StringType) /** This is an internal data source that outputs internal row format. */ From 86e405f357711ae93935853a912bc13985c259db Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Thu, 17 Dec 2015 09:55:37 -0800 Subject: [PATCH 1818/1824] [SPARK-12220][CORE] Make Utils.fetchFile support files that contain special characters This PR encodes and decodes the file name to fix the issue. Author: Shixiong Zhu Closes #10208 from zsxwing/uri. 
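A self-contained Scala sketch of the URI round trip this relies on; the helper names and the URL below are illustrative stand-ins for the new Utils methods:

```
import java.net.URI

// Encode a bare file name into a raw URI path; "file"/"localhost" are only
// placeholders so URI does not read the name as a scheme or host, and the
// leading "/" added to make the path absolute is stripped again.
def encodeFileName(name: String): String = {
  require(!name.contains("/") && !name.contains("\\"))
  new URI("file", null, "localhost", -1, "/" + name, null, null).getRawPath.substring(1)
}

// Recover the original name from the raw path of a fetched URL.
def decodeFileName(uri: URI): String = {
  val rawName = uri.getRawPath.split("/").last
  new URI("file:///" + rawName).getPath.substring(1)
}

val url = s"spark://host:1234/files/${encodeFileName("my report.txt")}"
// url == "spark://host:1234/files/my%20report.txt"
assert(decodeFileName(new URI(url)) == "my report.txt")
```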
--- .../org/apache/spark/HttpFileServer.scala | 6 ++--- .../spark/rpc/netty/NettyStreamManager.scala | 5 ++-- .../scala/org/apache/spark/util/Utils.scala | 26 ++++++++++++++++++- .../org/apache/spark/rpc/RpcEnvSuite.scala | 4 +++ .../org/apache/spark/util/UtilsSuite.scala | 11 ++++++++ 5 files changed, 46 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/HttpFileServer.scala b/core/src/main/scala/org/apache/spark/HttpFileServer.scala index 77d8ec9bb160..46f9f9e9af7d 100644 --- a/core/src/main/scala/org/apache/spark/HttpFileServer.scala +++ b/core/src/main/scala/org/apache/spark/HttpFileServer.scala @@ -63,12 +63,12 @@ private[spark] class HttpFileServer( def addFile(file: File) : String = { addFileToDir(file, fileDir) - serverUri + "/files/" + file.getName + serverUri + "/files/" + Utils.encodeFileNameToURIRawPath(file.getName) } def addJar(file: File) : String = { addFileToDir(file, jarDir) - serverUri + "/jars/" + file.getName + serverUri + "/jars/" + Utils.encodeFileNameToURIRawPath(file.getName) } def addDirectory(path: String, resourceBase: String): String = { @@ -85,7 +85,7 @@ private[spark] class HttpFileServer( throw new IllegalArgumentException(s"$file cannot be a directory.") } Files.copy(file, new File(dir, file.getName)) - dir + "/" + file.getName + dir + "/" + Utils.encodeFileNameToURIRawPath(file.getName) } } diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala index ecd96972455d..394cde4fa076 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala @@ -22,6 +22,7 @@ import java.util.concurrent.ConcurrentHashMap import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} import org.apache.spark.network.server.StreamManager import org.apache.spark.rpc.RpcEnvFileServer +import org.apache.spark.util.Utils /** * StreamManager implementation for serving files from a NettyRpcEnv. @@ -64,13 +65,13 @@ private[netty] class NettyStreamManager(rpcEnv: NettyRpcEnv) override def addFile(file: File): String = { require(files.putIfAbsent(file.getName(), file) == null, s"File ${file.getName()} already registered.") - s"${rpcEnv.address.toSparkURL}/files/${file.getName()}" + s"${rpcEnv.address.toSparkURL}/files/${Utils.encodeFileNameToURIRawPath(file.getName())}" } override def addJar(file: File): String = { require(jars.putIfAbsent(file.getName(), file) == null, s"JAR ${file.getName()} already registered.") - s"${rpcEnv.address.toSparkURL}/jars/${file.getName()}" + s"${rpcEnv.address.toSparkURL}/jars/${Utils.encodeFileNameToURIRawPath(file.getName())}" } override def addDirectory(baseUri: String, path: File): String = { diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 9dbe66e7eefb..fce89dfccfe2 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -331,6 +331,30 @@ private[spark] object Utils extends Logging { } /** + * A file name may contain some invalid URI characters, such as " ". This method will convert the + * file name to a raw path accepted by `java.net.URI(String)`. 
+ * + * Note: the file name must not contain "/" or "\" + */ + def encodeFileNameToURIRawPath(fileName: String): String = { + require(!fileName.contains("/") && !fileName.contains("\\")) + // `file` and `localhost` are not used. Just to prevent URI from parsing `fileName` as + // scheme or host. The prefix "/" is required because URI doesn't accept a relative path. + // We should remove it after we get the raw path. + new URI("file", null, "localhost", -1, "/" + fileName, null, null).getRawPath.substring(1) + } + + /** + * Get the file name from uri's raw path and decode it. If the raw path of uri ends with "/", + * return the name before the last "/". + */ + def decodeFileNameInURI(uri: URI): String = { + val rawPath = uri.getRawPath + val rawFileName = rawPath.split("/").last + new URI("file:///" + rawFileName).getPath.substring(1) + } + + /** * Download a file or directory to target directory. Supports fetching the file in a variety of * ways, including HTTP, Hadoop-compatible filesystems, and files on a standard filesystem, based * on the URL parameter. Fetching directories is only supported from Hadoop-compatible @@ -351,7 +375,7 @@ private[spark] object Utils extends Logging { hadoopConf: Configuration, timestamp: Long, useCache: Boolean) { - val fileName = url.split("/").last + val fileName = decodeFileNameInURI(new URI(url)) val targetFile = new File(targetDir, fileName) val fetchCacheEnabled = conf.getBoolean("spark.files.useFetchCache", defaultValue = true) if (useCache && fetchCacheEnabled) { diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala index 6d153eb04e04..49e3e0191c38 100644 --- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala @@ -771,6 +771,8 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { val tempDir = Utils.createTempDir() val file = new File(tempDir, "file") Files.write(UUID.randomUUID().toString(), file, UTF_8) + val fileWithSpecialChars = new File(tempDir, "file name") + Files.write(UUID.randomUUID().toString(), fileWithSpecialChars, UTF_8) val empty = new File(tempDir, "empty") Files.write("", empty, UTF_8); val jar = new File(tempDir, "jar") @@ -787,6 +789,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { Files.write(UUID.randomUUID().toString(), subFile2, UTF_8) val fileUri = env.fileServer.addFile(file) + val fileWithSpecialCharsUri = env.fileServer.addFile(fileWithSpecialChars) val emptyUri = env.fileServer.addFile(empty) val jarUri = env.fileServer.addJar(jar) val dir1Uri = env.fileServer.addDirectory("/dir1", dir1) @@ -805,6 +808,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { val files = Seq( (file, fileUri), + (fileWithSpecialChars, fileWithSpecialCharsUri), (empty, emptyUri), (jar, jarUri), (subFile1, dir1Uri + "/file1"), diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index 68b0da76bc13..fdb51d440eff 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -734,4 +734,15 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { conf.set("spark.executor.instances", "0")) === true) } + test("encodeFileNameToURIRawPath") { + assert(Utils.encodeFileNameToURIRawPath("abc") === "abc") + 
assert(Utils.encodeFileNameToURIRawPath("abc xyz") === "abc%20xyz") + assert(Utils.encodeFileNameToURIRawPath("abc:xyz") === "abc:xyz") + } + + test("decodeFileNameInURI") { + assert(Utils.decodeFileNameInURI(new URI("files:///abc/xyz")) === "xyz") + assert(Utils.decodeFileNameInURI(new URI("files:///abc")) === "abc") + assert(Utils.decodeFileNameInURI(new URI("files:///abc%20xyz")) === "abc xyz") + } } From 8184568810e8a2e7d5371db2c6a0366ef4841f70 Mon Sep 17 00:00:00 2001 From: Iulian Dragos Date: Fri, 18 Dec 2015 03:19:31 +0900 Subject: [PATCH 1819/1824] [SPARK-12345][MESOS] Properly filter out SPARK_HOME in the Mesos REST server Fix problem with #10332, this one should fix Cluster mode on Mesos Author: Iulian Dragos Closes #10359 from dragos/issue/fix-spark-12345-one-more-time. --- .../org/apache/spark/deploy/rest/mesos/MesosRestServer.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala index 24510db2bd0b..c0b93596508f 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala @@ -99,7 +99,7 @@ private[mesos] class MesosSubmitRequestServlet( // cause spark-submit script to look for files in SPARK_HOME instead. // We only need the ability to specify where to find spark-submit script // which user can user spark.executor.home or spark.home configurations. - val environmentVariables = request.environmentVariables.filter(!_.equals("SPARK_HOME")) + val environmentVariables = request.environmentVariables.filterKeys(!_.equals("SPARK_HOME")) val name = request.sparkProperties.get("spark.app.name").getOrElse(mainClass) // Construct driver description From 540b5aeadc84d1a5d61bda4414abd6bf35dc7ff9 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Thu, 17 Dec 2015 13:23:48 -0800 Subject: [PATCH 1820/1824] [SPARK-12410][STREAMING] Fix places that use '.' and '|' directly in split String.split accepts a regular expression, so we should escape "." and "|". Author: Shixiong Zhu Closes #10361 from zsxwing/reg-bug. 
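A small Scala sketch of why the escaping matters, since `String.split` takes a regular expression and "|" / "." are metacharacters rather than literal separators:

```
val genres = "Action|Adventure|Sci-Fi"

genres.split("|").take(4).toList
// roughly List(A, c, t, i): the regex "|" matches the empty string, so the
// split happens between every character (leading-empty handling varies by JDK)

genres.split("\\|").toList
// List(Action, Adventure, Sci-Fi)

"org.apache.spark.streaming.util.FileBasedWriteAheadLog".split("\\.").lastOption
// Some(FileBasedWriteAheadLog); with an unescaped "." every character matches
// and the token array ends up empty, which is why getCallerName could never
// find a class name before this fix.
```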
--- .../main/scala/org/apache/spark/examples/ml/MovieLensALS.scala | 2 +- .../apache/spark/streaming/util/FileBasedWriteAheadLog.scala | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala b/examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala index 3ae53e57dbdb..02ed746954f2 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala @@ -50,7 +50,7 @@ object MovieLensALS { def parseMovie(str: String): Movie = { val fields = str.split("::") assert(fields.size == 3) - Movie(fields(0).toInt, fields(1), fields(2).split("|")) + Movie(fields(0).toInt, fields(1), fields(2).split("\\|")) } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala index a99b57083583..b946e0d8e927 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala @@ -253,7 +253,7 @@ private[streaming] object FileBasedWriteAheadLog { def getCallerName(): Option[String] = { val stackTraceClasses = Thread.currentThread.getStackTrace().map(_.getClassName) - stackTraceClasses.find(!_.contains("WriteAheadLog")).flatMap(_.split(".").lastOption) + stackTraceClasses.find(!_.contains("WriteAheadLog")).flatMap(_.split("\\.").lastOption) } /** Convert a sequence of files to a sequence of sorted LogInfo objects */ From e096a652b92fc64a7b3457cd0766ab324bcc980b Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 17 Dec 2015 14:16:49 -0800 Subject: [PATCH 1821/1824] [SPARK-12397][SQL] Improve error messages for data sources when they are not found Point users to spark-packages.org to find them. Author: Reynold Xin Closes #10351 from rxin/SPARK-12397. 
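Roughly what a user now sees when a short-name source is not on the classpath (spark-shell style sketch; the path is hypothetical and no spark-avro package is installed):

```
try {
  sqlContext.read.format("avro").load("/tmp/events.avro")
} catch {
  case e: ClassNotFoundException =>
    // "Failed to find data source: avro. Please use Spark package
    //  http://spark-packages.org/package/databricks/spark-avro"
    println(e.getMessage)
}
```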
--- .../datasources/ResolvedDataSource.scala | 50 ++++++++++++------- .../sql/sources/ResolvedDataSourceSuite.scala | 17 +++++++ 2 files changed, 49 insertions(+), 18 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala index 86a306b8f941..e02ee6cd6b90 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala @@ -57,24 +57,38 @@ object ResolvedDataSource extends Logging { val serviceLoader = ServiceLoader.load(classOf[DataSourceRegister], loader) serviceLoader.asScala.filter(_.shortName().equalsIgnoreCase(provider)).toList match { - /** the provider format did not match any given registered aliases */ - case Nil => Try(loader.loadClass(provider)).orElse(Try(loader.loadClass(provider2))) match { - case Success(dataSource) => dataSource - case Failure(error) => - if (provider.startsWith("org.apache.spark.sql.hive.orc")) { - throw new ClassNotFoundException( - "The ORC data source must be used with Hive support enabled.", error) - } else { - throw new ClassNotFoundException( - s"Failed to load class for data source: $provider.", error) - } - } - /** there is exactly one registered alias */ - case head :: Nil => head.getClass - /** There are multiple registered aliases for the input */ - case sources => sys.error(s"Multiple sources found for $provider, " + - s"(${sources.map(_.getClass.getName).mkString(", ")}), " + - "please specify the fully qualified class name.") + // the provider format did not match any given registered aliases + case Nil => + Try(loader.loadClass(provider)).orElse(Try(loader.loadClass(provider2))) match { + case Success(dataSource) => + // Found the data source using fully qualified path + dataSource + case Failure(error) => + if (provider.startsWith("org.apache.spark.sql.hive.orc")) { + throw new ClassNotFoundException( + "The ORC data source must be used with Hive support enabled.", error) + } else { + if (provider == "avro" || provider == "com.databricks.spark.avro") { + throw new ClassNotFoundException( + s"Failed to find data source: $provider. Please use Spark package " + + "http://spark-packages.org/package/databricks/spark-avro", + error) + } else { + throw new ClassNotFoundException( + s"Failed to find data source: $provider. 
Please find packages at " + + "http://spark-packages.org", + error) + } + } + } + case head :: Nil => + // there is exactly one registered alias + head.getClass + case sources => + // There are multiple registered aliases for the input + sys.error(s"Multiple sources found for $provider " + + s"(${sources.map(_.getClass.getName).mkString(", ")}), " + + "please specify the fully qualified class name.") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala index 27d1cd92fca1..cb6e5179b31f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala @@ -57,4 +57,21 @@ class ResolvedDataSourceSuite extends SparkFunSuite { ResolvedDataSource.lookupDataSource("org.apache.spark.sql.parquet") === classOf[org.apache.spark.sql.execution.datasources.parquet.DefaultSource]) } + + test("error message for unknown data sources") { + val error1 = intercept[ClassNotFoundException] { + ResolvedDataSource.lookupDataSource("avro") + } + assert(error1.getMessage.contains("spark-packages")) + + val error2 = intercept[ClassNotFoundException] { + ResolvedDataSource.lookupDataSource("com.databricks.spark.avro") + } + assert(error2.getMessage.contains("spark-packages")) + + val error3 = intercept[ClassNotFoundException] { + ResolvedDataSource.lookupDataSource("asfdwefasdfasdf") + } + assert(error3.getMessage.contains("spark-packages")) + } } From ed6ebda5c898bad76194fe3a090bef5a14f861c2 Mon Sep 17 00:00:00 2001 From: Evan Chen Date: Thu, 17 Dec 2015 14:22:30 -0800 Subject: [PATCH 1822/1824] [SPARK-12376][TESTS] Spark Streaming Java8APISuite fails in assertOrderInvariantEquals method org.apache.spark.streaming.Java8APISuite.java is failing due to trying to sort immutable list in assertOrderInvariantEquals method. Author: Evan Chen Closes #10336 from evanyc15/SPARK-12376-StreamingJavaAPISuite. --- .../org/apache/spark/streaming/Java8APISuite.java | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/extras/java8-tests/src/test/java/org/apache/spark/streaming/Java8APISuite.java b/extras/java8-tests/src/test/java/org/apache/spark/streaming/Java8APISuite.java index 89e0c7fdf7ee..e8a0dfc0f0a5 100644 --- a/extras/java8-tests/src/test/java/org/apache/spark/streaming/Java8APISuite.java +++ b/extras/java8-tests/src/test/java/org/apache/spark/streaming/Java8APISuite.java @@ -439,9 +439,14 @@ public void testPairFlatMap() { */ public static > void assertOrderInvariantEquals( List> expected, List> actual) { - expected.forEach((List list) -> Collections.sort(list)); - actual.forEach((List list) -> Collections.sort(list)); - Assert.assertEquals(expected, actual); + expected.forEach(list -> Collections.sort(list)); + List> sortedActual = new ArrayList<>(); + actual.forEach(list -> { + List sortedList = new ArrayList<>(list); + Collections.sort(sortedList); + sortedActual.add(sortedList); + }); + Assert.assertEquals(expected, sortedActual); } @Test From 658f66e6208a52367e3b43a6fee9c90f33fb6226 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Thu, 17 Dec 2015 15:16:35 -0800 Subject: [PATCH 1823/1824] [SPARK-8641][SQL] Native Spark Window functions This PR removes Hive windows functions from Spark and replaces them with (native) Spark ones. The PR is on par with Hive in terms of features. This has the following advantages: * Better memory management. 
* The ability to use spark UDAFs in Window functions. cc rxin / yhuai Author: Herman van Hovell Closes #9819 from hvanhovell/SPARK-8641-2. --- .../sql/catalyst/analysis/Analyzer.scala | 81 ++- .../sql/catalyst/analysis/CheckAnalysis.scala | 39 +- .../catalyst/analysis/FunctionRegistry.scala | 12 +- .../expressions/aggregate/interfaces.scala | 7 +- .../expressions/windowExpressions.scala | 318 +++++++-- .../analysis/AnalysisErrorSuite.scala | 31 +- .../apache/spark/sql/execution/Window.scala | 649 ++++++++++-------- .../spark/sql/expressions/WindowSpec.scala | 54 +- .../org/apache/spark/sql/functions.scala | 16 +- .../spark/sql/DataFrameWindowSuite.scala} | 200 +++--- .../HiveWindowFunctionQuerySuite.scala | 10 +- .../apache/spark/sql/hive/HiveContext.scala | 1 - .../org/apache/spark/sql/hive/HiveQl.scala | 22 +- .../org/apache/spark/sql/hive/hiveUDFs.scala | 224 ------ .../sql/hive/execution/WindowQuerySuite.scala | 230 +++++++ 15 files changed, 1148 insertions(+), 746 deletions(-) rename sql/{hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala => core/src/test/scala/org/apache/spark/sql/DataFrameWindowSuite.scala} (57%) create mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/WindowQuerySuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index ca00a5e49f66..64dd83a91571 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -77,6 +77,8 @@ class Analyzer( ResolveGenerate :: ResolveFunctions :: ResolveAliases :: + ResolveWindowOrder :: + ResolveWindowFrame :: ExtractWindowExpressions :: GlobalAggregates :: ResolveAggregateFunctions :: @@ -127,14 +129,12 @@ class Analyzer( // Lookup WindowSpecDefinitions. This rule works with unresolved children. case WithWindowDefinition(windowDefinitions, child) => child.transform { - case plan => plan.transformExpressions { + case p => p.transformExpressions { case UnresolvedWindowExpression(c, WindowSpecReference(windowName)) => val errorMessage = s"Window specification $windowName is not defined in the WINDOW clause." val windowSpecDefinition = - windowDefinitions - .get(windowName) - .getOrElse(failAnalysis(errorMessage)) + windowDefinitions.getOrElse(windowName, failAnalysis(errorMessage)) WindowExpression(c, windowSpecDefinition) } } @@ -577,6 +577,10 @@ class Analyzer( AggregateExpression(max, Complete, isDistinct = false) case min: Min if isDistinct => AggregateExpression(min, Complete, isDistinct = false) + // AggregateWindowFunctions are AggregateFunctions that can only be evaluated within + // the context of a Window clause. They do not need to be wrapped in an + // AggregateExpression. + case wf: AggregateWindowFunction => wf // We get an aggregate function, we need to wrap it in an AggregateExpression. case agg: AggregateFunction => AggregateExpression(agg, Complete, isDistinct) // This function is not an aggregate function, just return the resolved one. @@ -597,11 +601,17 @@ class Analyzer( } def containsAggregates(exprs: Seq[Expression]): Boolean = { - exprs.foreach(_.foreach { - case agg: AggregateExpression => return true - case _ => - }) - false + // Collect all Windowed Aggregate Expressions. 
+ val windowedAggExprs = exprs.flatMap { expr => + expr.collect { + case WindowExpression(ae: AggregateExpression, _) => ae + } + }.toSet + + // Find the first Aggregate Expression that is not Windowed. + exprs.exists(_.collectFirst { + case ae: AggregateExpression if !windowedAggExprs.contains(ae) => ae + }.isDefined) } } @@ -875,26 +885,37 @@ class Analyzer( // Now, we extract regular expressions from expressionsWithWindowFunctions // by using extractExpr. + val seenWindowAggregates = new ArrayBuffer[AggregateExpression] val newExpressionsWithWindowFunctions = expressionsWithWindowFunctions.map { _.transform { // Extracts children expressions of a WindowFunction (input parameters of // a WindowFunction). case wf : WindowFunction => - val newChildren = wf.children.map(extractExpr(_)) + val newChildren = wf.children.map(extractExpr) wf.withNewChildren(newChildren) // Extracts expressions from the partition spec and order spec. case wsc @ WindowSpecDefinition(partitionSpec, orderSpec, _) => - val newPartitionSpec = partitionSpec.map(extractExpr(_)) + val newPartitionSpec = partitionSpec.map(extractExpr) val newOrderSpec = orderSpec.map { so => val newChild = extractExpr(so.child) so.copy(child = newChild) } wsc.copy(partitionSpec = newPartitionSpec, orderSpec = newOrderSpec) + // Extract Windowed AggregateExpression + case we @ WindowExpression( + AggregateExpression(function, mode, isDistinct), + spec: WindowSpecDefinition) => + val newChildren = function.children.map(extractExpr) + val newFunction = function.withNewChildren(newChildren).asInstanceOf[AggregateFunction] + val newAgg = AggregateExpression(newFunction, mode, isDistinct) + seenWindowAggregates += newAgg + WindowExpression(newAgg, spec) + // Extracts AggregateExpression. For example, for SUM(x) - Sum(y) OVER (...), // we need to extract SUM(x). - case agg: AggregateExpression => + case agg: AggregateExpression if !seenWindowAggregates.contains(agg) => val withName = Alias(agg, s"_w${extractedExprBuffer.length}")() extractedExprBuffer += withName withName.toAttribute @@ -1102,6 +1123,42 @@ class Analyzer( } } } + + /** + * Check and add proper window frames for all window functions. + */ + object ResolveWindowFrame extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case logical: LogicalPlan => logical transformExpressions { + case WindowExpression(wf: WindowFunction, + WindowSpecDefinition(_, _, f: SpecifiedWindowFrame)) + if wf.frame != UnspecifiedFrame && wf.frame != f => + failAnalysis(s"Window Frame $f must match the required frame ${wf.frame}") + case WindowExpression(wf: WindowFunction, + s @ WindowSpecDefinition(_, o, UnspecifiedFrame)) + if wf.frame != UnspecifiedFrame => + WindowExpression(wf, s.copy(frameSpecification = wf.frame)) + case we @ WindowExpression(e, s @ WindowSpecDefinition(_, o, UnspecifiedFrame)) => + val frame = SpecifiedWindowFrame.defaultWindowFrame(o.nonEmpty, acceptWindowFrame = true) + we.copy(windowSpec = s.copy(frameSpecification = frame)) + } + } + } + + /** + * Check and add order to [[AggregateWindowFunction]]s. 
+ */ + object ResolveWindowOrder extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case logical: LogicalPlan => logical transformExpressions { + case WindowExpression(wf: WindowFunction, spec) if spec.orderSpec.isEmpty => + failAnalysis(s"WindowFunction $wf requires window to be ordered") + case WindowExpression(rank: RankLike, spec) if spec.resolved => + val order = spec.orderSpec.map(_.child) + WindowExpression(rank.withOrder(order), spec) + } + } + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 7b2c93d63d67..440f67991380 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -70,15 +70,32 @@ trait CheckAnalysis { failAnalysis( s"invalid cast from ${c.child.dataType.simpleString} to ${c.dataType.simpleString}") - case WindowExpression(UnresolvedWindowFunction(name, _), _) => - failAnalysis( - s"Could not resolve window function '$name'. " + - "Note that, using window functions currently requires a HiveContext") + case w @ WindowExpression(AggregateExpression(_, _, true), _) => + failAnalysis(s"Distinct window functions are not supported: $w") + + case w @ WindowExpression(_: OffsetWindowFunction, WindowSpecDefinition(_, order, + SpecifiedWindowFrame(frame, + FrameBoundary(l), + FrameBoundary(h)))) + if order.isEmpty || frame != RowFrame || l != h => + failAnalysis("An offset window function can only be evaluated in an ordered " + + s"row-based window frame with a single offset: $w") + + case w @ WindowExpression(e, s) => + // Only allow window functions with an aggregate expression or an offset window + // function. + e match { + case _: AggregateExpression | _: OffsetWindowFunction | _: AggregateWindowFunction => + case _ => + failAnalysis(s"Expression '$e' not supported within a window function.") + } + // Make sure the window specification is valid. + s.validate match { + case Some(m) => + failAnalysis(s"Window specification $s is not valid because $m") + case None => w + } - case w @ WindowExpression(windowFunction, windowSpec) if windowSpec.validate.nonEmpty => - // The window spec is not valid. - val reason = windowSpec.validate.get - failAnalysis(s"Window specification $windowSpec is not valid because $reason") } operator match { @@ -204,10 +221,12 @@ trait CheckAnalysis { s"unresolved operator ${operator.simpleString}") case o if o.expressions.exists(!_.deterministic) && - !o.isInstanceOf[Project] && !o.isInstanceOf[Filter] & !o.isInstanceOf[Aggregate] => + !o.isInstanceOf[Project] && !o.isInstanceOf[Filter] && + !o.isInstanceOf[Aggregate] && !o.isInstanceOf[Window] => // The rule above is used to check Aggregate operator. 
failAnalysis( - s"""nondeterministic expressions are only allowed in Project or Filter, found: + s"""nondeterministic expressions are only allowed in + |Project, Filter, Aggregate or Window, found: | ${o.expressions.map(_.prettyString).mkString(",")} |in operator ${operator.simpleString} """.stripMargin) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index f9c04d7ec0b0..12c24cc76822 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -283,7 +283,17 @@ object FunctionRegistry { expression[Sha2]("sha2"), expression[SparkPartitionID]("spark_partition_id"), expression[InputFileName]("input_file_name"), - expression[MonotonicallyIncreasingID]("monotonically_increasing_id") + expression[MonotonicallyIncreasingID]("monotonically_increasing_id"), + + // window functions + expression[Lead]("lead"), + expression[Lag]("lag"), + expression[RowNumber]("row_number"), + expression[CumeDist]("cume_dist"), + expression[NTile]("ntile"), + expression[Rank]("rank"), + expression[DenseRank]("dense_rank"), + expression[PercentRank]("percent_rank") ) val builtin: SimpleFunctionRegistry = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala index 3b441de34a49..b6d2ddc5b136 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions.aggregate import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenFallback, GeneratedExpressionCode, CodeGenContext} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types._ @@ -144,9 +144,6 @@ sealed abstract class AggregateFunction extends Expression with ImplicitCastInpu */ def defaultResult: Option[Literal] = None - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = - throw new UnsupportedOperationException(s"Cannot evaluate expression: $this") - /** * Wraps this [[AggregateFunction]] in an [[AggregateExpression]] because * [[AggregateExpression]] is the container of an [[AggregateFunction]], aggregation mode, @@ -187,7 +184,7 @@ sealed abstract class AggregateFunction extends Expression with ImplicitCastInpu * `inputAggBufferOffset`, but not on the correctness of the attribute ids in `aggBufferAttributes` * and `inputAggBufferAttributes`. 
*/ -abstract class ImperativeAggregate extends AggregateFunction { +abstract class ImperativeAggregate extends AggregateFunction with CodegenFallback { /** * The offset of this function's first buffer value in the underlying shared mutable aggregation diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala index 1680aa8252ec..06252ac4fc61 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala @@ -17,9 +17,10 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.UnresolvedException -import org.apache.spark.sql.types.{DataType, NumericType} +import org.apache.spark.sql.catalyst.expressions.aggregate.{NoOp, DeclarativeAggregate} +import org.apache.spark.sql.types._ /** * The trait of the Window Specification (specified in the OVER clause or WINDOW clause) for @@ -117,6 +118,19 @@ sealed trait FrameBoundary { def notFollows(other: FrameBoundary): Boolean } +/** + * Extractor for making working with frame boundaries easier. + */ +object FrameBoundary { + def apply(boundary: FrameBoundary): Option[Int] = unapply(boundary) + def unapply(boundary: FrameBoundary): Option[Int] = boundary match { + case CurrentRow => Some(0) + case ValuePreceding(offset) => Some(-offset) + case ValueFollowing(offset) => Some(offset) + case _ => None + } +} + /** UNBOUNDED PRECEDING boundary. */ case object UnboundedPreceding extends FrameBoundary { def notFollows(other: FrameBoundary): Boolean = other match { @@ -243,85 +257,281 @@ object SpecifiedWindowFrame { } } +case class UnresolvedWindowExpression( + child: Expression, + windowSpec: WindowSpecReference) extends UnaryExpression with Unevaluable { + + override def dataType: DataType = throw new UnresolvedException(this, "dataType") + override def foldable: Boolean = throw new UnresolvedException(this, "foldable") + override def nullable: Boolean = throw new UnresolvedException(this, "nullable") + override lazy val resolved = false +} + +case class WindowExpression( + windowFunction: Expression, + windowSpec: WindowSpecDefinition) extends Expression with Unevaluable { + + override def children: Seq[Expression] = windowFunction :: windowSpec :: Nil + + override def dataType: DataType = windowFunction.dataType + override def foldable: Boolean = windowFunction.foldable + override def nullable: Boolean = windowFunction.nullable + + override def toString: String = s"$windowFunction $windowSpec" +} + /** - * Every window function needs to maintain a output buffer for its output. - * It should expect that for a n-row window frame, it will be called n times - * to retrieve value corresponding with these n rows. + * A window function is a function that can only be evaluated in the context of a window operator. */ trait WindowFunction extends Expression { - def init(): Unit + /** Frame in which the window operator must be executed. */ + def frame: WindowFrame = UnspecifiedFrame +} + +/** + * An offset window function is a window function that returns the value of the input column offset + * by a number of rows within the partition. 
For instance: an OffsetWindowfunction for value x with + * offset -2, will get the value of x 2 rows back in the partition. + */ +abstract class OffsetWindowFunction + extends Expression with WindowFunction with Unevaluable with ImplicitCastInputTypes { + /** + * Input expression to evaluate against a row which a number of rows below or above (depending on + * the value and sign of the offset) the current row. + */ + val input: Expression + + /** + * Default result value for the function when the input expression returns NULL. The default will + * evaluated against the current row instead of the offset row. + */ + val default: Expression - def reset(): Unit + /** + * (Foldable) expression that contains the number of rows between the current row and the row + * where the input expression is evaluated. + */ + val offset: Expression - def prepareInputParameters(input: InternalRow): AnyRef + /** + * Direction (above = 1/below = -1) of the number of rows between the current row and the row + * where the input expression is evaluated. + */ + val direction: SortDirection - def update(input: AnyRef): Unit + override def children: Seq[Expression] = Seq(input, offset, default) - def batchUpdate(inputs: Array[AnyRef]): Unit + /* + * The result of an OffsetWindowFunction is dependent on the frame in which the + * OffsetWindowFunction is executed, the input expression and the default expression. Even when + * both the input and the default expression are foldable, the result is still not foldable due to + * the frame. + */ + override def foldable: Boolean = input.foldable && (default == null || default.foldable) - def evaluate(): Unit + override def nullable: Boolean = input.nullable && (default == null || default.nullable) - def get(index: Int): Any + override lazy val frame = { + // This will be triggered by the Analyzer. 
+ val offsetValue = offset.eval() match { + case o: Int => o + case x => throw new AnalysisException( + s"Offset expression must be a foldable integer expression: $x") + } + val boundary = direction match { + case Ascending => ValueFollowing(offsetValue) + case Descending => ValuePreceding(offsetValue) + } + SpecifiedWindowFrame(RowFrame, boundary, boundary) + } + + override def dataType: DataType = input.dataType - def newInstance(): WindowFunction + override def inputTypes: Seq[AbstractDataType] = + Seq(AnyDataType, IntegerType, TypeCollection(input.dataType, NullType)) + + override def toString: String = s"$prettyName($input, $offset, $default)" } -case class UnresolvedWindowFunction( - name: String, - children: Seq[Expression]) - extends Expression with WindowFunction with Unevaluable { +case class Lead(input: Expression, offset: Expression, default: Expression) + extends OffsetWindowFunction { - override def dataType: DataType = throw new UnresolvedException(this, "dataType") - override def foldable: Boolean = throw new UnresolvedException(this, "foldable") - override def nullable: Boolean = throw new UnresolvedException(this, "nullable") - override lazy val resolved = false + def this(input: Expression, offset: Expression) = this(input, offset, Literal(null)) - override def init(): Unit = throw new UnresolvedException(this, "init") - override def reset(): Unit = throw new UnresolvedException(this, "reset") - override def prepareInputParameters(input: InternalRow): AnyRef = - throw new UnresolvedException(this, "prepareInputParameters") - override def update(input: AnyRef): Unit = throw new UnresolvedException(this, "update") - override def batchUpdate(inputs: Array[AnyRef]): Unit = - throw new UnresolvedException(this, "batchUpdate") - override def evaluate(): Unit = throw new UnresolvedException(this, "evaluate") - override def get(index: Int): Any = throw new UnresolvedException(this, "get") + def this(input: Expression) = this(input, Literal(1)) - override def toString: String = s"'$name(${children.mkString(",")})" + def this() = this(Literal(null)) - override def newInstance(): WindowFunction = throw new UnresolvedException(this, "newInstance") + override val direction = Ascending } -case class UnresolvedWindowExpression( - child: UnresolvedWindowFunction, - windowSpec: WindowSpecReference) extends UnaryExpression with Unevaluable { +case class Lag(input: Expression, offset: Expression, default: Expression) + extends OffsetWindowFunction { - override def dataType: DataType = throw new UnresolvedException(this, "dataType") - override def foldable: Boolean = throw new UnresolvedException(this, "foldable") - override def nullable: Boolean = throw new UnresolvedException(this, "nullable") - override lazy val resolved = false + def this(input: Expression, offset: Expression) = this(input, offset, Literal(null)) + + def this(input: Expression) = this(input, Literal(1)) + + def this() = this(Literal(null)) + + override val direction = Descending } -case class WindowExpression( - windowFunction: WindowFunction, - windowSpec: WindowSpecDefinition) extends Expression with Unevaluable { +abstract class AggregateWindowFunction extends DeclarativeAggregate with WindowFunction { + self: Product => + override val frame = SpecifiedWindowFrame(RowFrame, UnboundedPreceding, CurrentRow) + override def dataType: DataType = IntegerType + override def nullable: Boolean = false + override def supportsPartial: Boolean = false + override lazy val mergeExpressions = + throw new 
UnsupportedOperationException("Window Functions do not support merging.") +} - override def children: Seq[Expression] = windowFunction :: windowSpec :: Nil +abstract class RowNumberLike extends AggregateWindowFunction { + override def children: Seq[Expression] = Nil + override def inputTypes: Seq[AbstractDataType] = Nil + protected val zero = Literal(0) + protected val one = Literal(1) + protected val rowNumber = AttributeReference("rowNumber", IntegerType, nullable = false)() + override val aggBufferAttributes: Seq[AttributeReference] = rowNumber :: Nil + override val initialValues: Seq[Expression] = zero :: Nil + override val updateExpressions: Seq[Expression] = Add(rowNumber, one) :: Nil +} - override def dataType: DataType = windowFunction.dataType - override def foldable: Boolean = windowFunction.foldable - override def nullable: Boolean = windowFunction.nullable +/** + * A [[SizeBasedWindowFunction]] needs the size of the current window for its calculation. + */ +trait SizeBasedWindowFunction extends AggregateWindowFunction { + protected def n: AttributeReference = SizeBasedWindowFunction.n +} - override def toString: String = s"$windowFunction $windowSpec" +object SizeBasedWindowFunction { + val n = AttributeReference("window__partition__size", IntegerType, nullable = false)() +} + +case class RowNumber() extends RowNumberLike { + override val evaluateExpression = rowNumber +} + +case class CumeDist() extends RowNumberLike with SizeBasedWindowFunction { + override def dataType: DataType = DoubleType + // The frame for CUME_DIST is Range based instead of Row based, because CUME_DIST must + // return the same value for equal values in the partition. + override val frame = SpecifiedWindowFrame(RangeFrame, UnboundedPreceding, CurrentRow) + override val evaluateExpression = Divide(Cast(rowNumber, DoubleType), Cast(n, DoubleType)) +} + +case class NTile(buckets: Expression) extends RowNumberLike with SizeBasedWindowFunction { + def this() = this(Literal(1)) + + // Validate buckets. Note that this could be relaxed, the bucket value only needs to constant + // for each partition. + buckets.eval() match { + case b: Int if b > 0 => // Ok + case x => throw new AnalysisException( + "Buckets expression must be a foldable positive integer expression: $x") + } + + private val bucket = AttributeReference("bucket", IntegerType, nullable = false)() + private val bucketThreshold = + AttributeReference("bucketThreshold", IntegerType, nullable = false)() + private val bucketSize = AttributeReference("bucketSize", IntegerType, nullable = false)() + private val bucketsWithPadding = + AttributeReference("bucketsWithPadding", IntegerType, nullable = false)() + private def bucketOverflow(e: Expression) = + If(GreaterThanOrEqual(rowNumber, bucketThreshold), e, zero) + + override val aggBufferAttributes = Seq( + rowNumber, + bucket, + bucketThreshold, + bucketSize, + bucketsWithPadding + ) + + override val initialValues = Seq( + zero, + zero, + zero, + Cast(Divide(n, buckets), IntegerType), + Cast(Remainder(n, buckets), IntegerType) + ) + + override val updateExpressions = Seq( + Add(rowNumber, one), + Add(bucket, bucketOverflow(one)), + Add(bucketThreshold, bucketOverflow( + Add(bucketSize, If(LessThan(bucket, bucketsWithPadding), one, zero)))), + NoOp, + NoOp + ) + + override val evaluateExpression = bucket } /** - * Extractor for making working with frame boundaries easier. 
+ * A RankLike function is a WindowFunction that changes its value based on a change in the value of + * the order of the window in which is processed. For instance, when the value of 'x' changes in a + * window ordered by 'x' the rank function also changes. The size of the change of the rank function + * is (typically) not dependent on the size of the change in 'x'. */ -object FrameBoundaryExtractor { - def unapply(boundary: FrameBoundary): Option[Int] = boundary match { - case CurrentRow => Some(0) - case ValuePreceding(offset) => Some(-offset) - case ValueFollowing(offset) => Some(offset) - case _ => None +abstract class RankLike extends AggregateWindowFunction { + override def inputTypes: Seq[AbstractDataType] = children.map(_ => AnyDataType) + + /** Store the values of the window 'order' expressions. */ + protected val orderAttrs = children.map{ expr => + AttributeReference(expr.prettyString, expr.dataType)() } + + /** Predicate that detects if the order attributes have changed. */ + protected val orderEquals = children.zip(orderAttrs) + .map(EqualNullSafe.tupled) + .reduceOption(And) + .getOrElse(Literal(true)) + + protected val orderInit = children.map(e => Literal.create(null, e.dataType)) + protected val rank = AttributeReference("rank", IntegerType, nullable = false)() + protected val rowNumber = AttributeReference("rowNumber", IntegerType, nullable = false)() + protected val zero = Literal(0) + protected val one = Literal(1) + protected val increaseRowNumber = Add(rowNumber, one) + + /** + * Different RankLike implementations use different source expressions to update their rank value. + * Rank for instance uses the number of rows seen, whereas DenseRank uses the number of changes. + */ + protected def rankSource: Expression = rowNumber + + /** Increase the rank when the current rank == 0 or when the one of order attributes changes. 
*/ + protected val increaseRank = If(And(orderEquals, Not(EqualTo(rank, zero))), rank, rankSource) + + override val aggBufferAttributes: Seq[AttributeReference] = rank +: rowNumber +: orderAttrs + override val initialValues = zero +: one +: orderInit + override val updateExpressions = increaseRank +: increaseRowNumber +: children + override val evaluateExpression: Expression = rank + + def withOrder(order: Seq[Expression]): RankLike +} + +case class Rank(children: Seq[Expression]) extends RankLike { + def this() = this(Nil) + override def withOrder(order: Seq[Expression]): Rank = Rank(order) +} + +case class DenseRank(children: Seq[Expression]) extends RankLike { + def this() = this(Nil) + override def withOrder(order: Seq[Expression]): DenseRank = DenseRank(order) + override protected def rankSource = Add(rank, one) + override val updateExpressions = increaseRank +: children + override val aggBufferAttributes = rank +: orderAttrs + override val initialValues = zero +: orderInit +} + +case class PercentRank(children: Seq[Expression]) extends RankLike with SizeBasedWindowFunction { + def this() = this(Nil) + override def withOrder(order: Seq[Expression]): PercentRank = PercentRank(order) + override def dataType: DataType = DoubleType + override val evaluateExpression = If(GreaterThan(n, one), + Divide(Cast(Subtract(rank, one), DoubleType), Cast(Subtract(n, one), DoubleType)), + Literal(0.0d)) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index ee435578743f..12079992b5b8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.{Complete, Count, Sum, AggregateExpression} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.Inner import org.apache.spark.sql.catalyst.dsl.expressions._ @@ -133,17 +134,37 @@ class AnalysisErrorSuite extends AnalysisTest { "requires int type" :: "'null' is of date type" :: Nil) errorTest( - "unresolved window function", + "invalid window function", testRelation2.select( WindowExpression( - UnresolvedWindowFunction( - "lead", - UnresolvedAttribute("c") :: Nil), + Literal(0), WindowSpecDefinition( UnresolvedAttribute("a") :: Nil, SortOrder(UnresolvedAttribute("b"), Ascending) :: Nil, UnspecifiedFrame)).as('window)), - "lead" :: "window functions currently requires a HiveContext" :: Nil) + "not supported within a window function" :: Nil) + + errorTest( + "distinct window function", + testRelation2.select( + WindowExpression( + AggregateExpression(Count(UnresolvedAttribute("b")), Complete, isDistinct = true), + WindowSpecDefinition( + UnresolvedAttribute("a") :: Nil, + SortOrder(UnresolvedAttribute("b"), Ascending) :: Nil, + UnspecifiedFrame)).as('window)), + "Distinct window functions are not supported" :: Nil) + + errorTest( + "offset window function", + testRelation2.select( + WindowExpression( + new Lead(UnresolvedAttribute("b")), + WindowSpecDefinition( + UnresolvedAttribute("a") :: Nil, + SortOrder(UnresolvedAttribute("b"), Ascending) :: Nil, + SpecifiedWindowFrame(RangeFrame, ValueFollowing(1), 
ValueFollowing(2)))).as('window)), + "window frame" :: "must match the required frame" :: Nil) errorTest( "too many generators", diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala index b1280c32a6a4..9852b6e7beeb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala @@ -17,12 +17,15 @@ package org.apache.spark.sql.execution +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer + import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.types.IntegerType import org.apache.spark.rdd.RDD -import org.apache.spark.util.collection.CompactBuffer /** * This class calculates and outputs (windowed) aggregates over the rows in a single (sorted) @@ -42,6 +45,8 @@ import org.apache.spark.util.collection.CompactBuffer * - Moving frame: Every time we move to a new row to process, we remove some rows from the frame * and we add some rows to the frame. Examples are: * 1 PRECEDING AND CURRENT ROW and 1 FOLLOWING AND 2 FOLLOWING. + * - Offset frame: The frame consist of one row, which is an offset number of rows away from the + * current row. Only [[OffsetWindowFunction]]s can be processed in an offset frame. * * Different frame boundaries can be used in Growing, Shrinking and Moving frames. A frame * boundary can be either Row or Range based: @@ -122,12 +127,10 @@ case class Window( // Create the projection which returns the current 'value'. val current = newMutableProjection(expr :: Nil, child.output)() // Flip the sign of the offset when processing the order is descending - val boundOffset = - if (sortExpr.direction == Descending) { - -offset - } else { - offset - } + val boundOffset = sortExpr.direction match { + case Descending => -offset + case Ascending => offset + } // Create the projection which returns the current 'value' modified by adding the offset. val boundExpr = Add(expr, Cast(Literal.create(boundOffset, IntegerType), expr.dataType)) val bound = newMutableProjection(boundExpr :: Nil, child.output)() @@ -149,43 +152,102 @@ case class Window( } /** - * Create a frame processor. - * - * This method uses Code Generation. It can only be used on the executor side. - * - * @param frame boundaries. - * @param functions to process in the frame. - * @param ordinal at which the processor starts writing to the output. - * @return a frame processor. + * Collection containing an entry for each window frame to process. Each entry contains a frames' + * WindowExpressions and factory function for the WindowFrameFunction. */ - private[this] def createFrameProcessor( - frame: WindowFrame, - functions: Array[WindowFunction], - ordinal: Int): WindowFunctionFrame = frame match { - // Growing Frame. - case SpecifiedWindowFrame(frameType, UnboundedPreceding, FrameBoundaryExtractor(high)) => - val uBoundOrdering = createBoundOrdering(frameType, high) - new UnboundedPrecedingWindowFunctionFrame(ordinal, functions, uBoundOrdering) - - // Shrinking Frame. - case SpecifiedWindowFrame(frameType, FrameBoundaryExtractor(low), UnboundedFollowing) => - val lBoundOrdering = createBoundOrdering(frameType, low) - new UnboundedFollowingWindowFunctionFrame(ordinal, functions, lBoundOrdering) - - // Moving Frame. 
- case SpecifiedWindowFrame(frameType, - FrameBoundaryExtractor(low), FrameBoundaryExtractor(high)) => - val lBoundOrdering = createBoundOrdering(frameType, low) - val uBoundOrdering = createBoundOrdering(frameType, high) - new SlidingWindowFunctionFrame(ordinal, functions, lBoundOrdering, uBoundOrdering) - - // Entire Partition Frame. - case SpecifiedWindowFrame(_, UnboundedPreceding, UnboundedFollowing) => - new UnboundedWindowFunctionFrame(ordinal, functions) - - // Error - case fr => - sys.error(s"Unsupported Frame $fr for functions: $functions") + private[this] lazy val windowFrameExpressionFactoryPairs = { + type FrameKey = (String, FrameType, Option[Int], Option[Int]) + type ExpressionBuffer = mutable.Buffer[Expression] + val framedFunctions = mutable.Map.empty[FrameKey, (ExpressionBuffer, ExpressionBuffer)] + + // Add a function and its function to the map for a given frame. + def collect(tpe: String, fr: SpecifiedWindowFrame, e: Expression, fn: Expression): Unit = { + val key = (tpe, fr.frameType, FrameBoundary(fr.frameStart), FrameBoundary(fr.frameEnd)) + val (es, fns) = framedFunctions.getOrElseUpdate( + key, (ArrayBuffer.empty[Expression], ArrayBuffer.empty[Expression])) + es.append(e) + fns.append(fn) + } + + // Collect all valid window functions and group them by their frame. + windowExpression.foreach { x => + x.foreach { + case e @ WindowExpression(function, spec) => + val frame = spec.frameSpecification.asInstanceOf[SpecifiedWindowFrame] + function match { + case AggregateExpression(f, _, _) => collect("AGGREGATE", frame, e, f) + case f: AggregateWindowFunction => collect("AGGREGATE", frame, e, f) + case f: OffsetWindowFunction => collect("OFFSET", frame, e, f) + case f => sys.error(s"Unsupported window function: $f") + } + case _ => + } + } + + // Map the groups to a (unbound) expression and frame factory pair. + var numExpressions = 0 + framedFunctions.toSeq.map { + case (key, (expressions, functionSeq)) => + val ordinal = numExpressions + val functions = functionSeq.toArray + + // Construct an aggregate processor if we need one. + def processor = AggregateProcessor(functions, ordinal, child.output, newMutableProjection) + + // Create the factory + val factory = key match { + // Offset Frame + case ("OFFSET", RowFrame, Some(offset), Some(h)) if offset == h => + target: MutableRow => + new OffsetWindowFunctionFrame( + target, + ordinal, + functions, + child.output, + newMutableProjection, + offset) + + // Growing Frame. + case ("AGGREGATE", frameType, None, Some(high)) => + target: MutableRow => { + new UnboundedPrecedingWindowFunctionFrame( + target, + processor, + createBoundOrdering(frameType, high)) + } + + // Shrinking Frame. + case ("AGGREGATE", frameType, Some(low), None) => + target: MutableRow => { + new UnboundedFollowingWindowFunctionFrame( + target, + processor, + createBoundOrdering(frameType, low)) + } + + // Moving Frame. + case ("AGGREGATE", frameType, Some(low), Some(high)) => + target: MutableRow => { + new SlidingWindowFunctionFrame( + target, + processor, + createBoundOrdering(frameType, low), + createBoundOrdering(frameType, high)) + } + + // Entire Partition Frame. + case ("AGGREGATE", frameType, None, None) => + target: MutableRow => { + new UnboundedWindowFunctionFrame(target, processor) + } + } + + // Keep track of the number of expressions. This is a side-effect in a map... + numExpressions += expressions.size + + // Create the Frame Expression - Factory pair. 
+ (expressions, factory) + } } /** @@ -210,43 +272,16 @@ case class Window( } protected override def doExecute(): RDD[InternalRow] = { - // Prepare processing. - // Group the window expression by their processing frame. - val windowExprs = windowExpression.flatMap { - _.collect { - case e: WindowExpression => e - } - } - - // Create Frame processor factories and order the unbound window expressions by the frame they - // are processed in; this is the order in which their results will be written to window - // function result buffer. - val framedWindowExprs = windowExprs.groupBy(_.windowSpec.frameSpecification) - val factories = Array.ofDim[() => WindowFunctionFrame](framedWindowExprs.size) - val unboundExpressions = scala.collection.mutable.Buffer.empty[Expression] - framedWindowExprs.zipWithIndex.foreach { - case ((frame, unboundFrameExpressions), index) => - // Track the ordinal. - val ordinal = unboundExpressions.size - - // Track the unbound expressions - unboundExpressions ++= unboundFrameExpressions - - // Bind the expressions. - val functions = unboundFrameExpressions.map { e => - BindReferences.bindReference(e.windowFunction, child.output) - }.toArray - - // Create the frame processor factory. - factories(index) = () => createFrameProcessor(frame, functions, ordinal) - } + // Unwrap the expressions and factories from the map. + val expressions = windowFrameExpressionFactoryPairs.flatMap(_._1) + val factories = windowFrameExpressionFactoryPairs.map(_._2).toArray // Start processing. child.execute().mapPartitions { stream => new Iterator[InternalRow] { // Get all relevant projections. - val result = createResultProjection(unboundExpressions) + val result = createResultProjection(expressions) val grouping = UnsafeProjection.create(partitionSpec, child.output) // Manage the stream and the grouping. @@ -266,14 +301,15 @@ case class Window( fetchNextRow() // Manage the current partition. - var rows: CompactBuffer[InternalRow] = _ - val frames: Array[WindowFunctionFrame] = factories.map(_()) + val rows = ArrayBuffer.empty[InternalRow] + val windowFunctionResult = new SpecificMutableRow(expressions.map(_.dataType)) + val frames = factories.map(_(windowFunctionResult)) val numFrames = frames.length private[this] def fetchNextPartition() { // Collect all the rows in the current partition. // Before we start to fetch new input rows, make a copy of nextGroup. val currentGroup = nextGroup.copy() - rows = new CompactBuffer + rows.clear() while (nextRowAvailable && nextGroup == currentGroup) { rows += nextRow.copy() fetchNextRow() @@ -297,7 +333,6 @@ case class Window( override final def hasNext: Boolean = rowIndex < rowsSize || nextRowAvailable val join = new JoinedRow - val windowFunctionResult = new GenericMutableRow(unboundExpressions.size) override final def next(): InternalRow = { // Load the next partition if we need to. if (rowIndex >= rowsSize && nextRowAvailable) { @@ -308,7 +343,7 @@ case class Window( // Get the results for the window frames. var i = 0 while (i < numFrames) { - frames(i).write(windowFunctionResult) + frames(i).write() i += 1 } @@ -355,140 +390,96 @@ private[execution] final case class RangeBoundOrdering( * A window function calculates the results of a number of window functions for a window frame. * Before use a frame must be prepared by passing it all the rows in the current partition. After * preparation the update method can be called to fill the output rows. - * - * TODO How to improve performance? 
A few thoughts: - * - Window functions are expensive due to its distribution and ordering requirements. - * Unfortunately it is up to the Spark engine to solve this. Improvements in the form of project - * Tungsten are on the way. - * - The window frame processing bit can be improved though. But before we start doing that we - * need to see how much of the time and resources are spent on partitioning and ordering, and - * how much time and resources are spent processing the partitions. There are a couple ways to - * improve on the current situation: - * - Reduce memory footprint by performing streaming calculations. This can only be done when - * there are no Unbound/Unbounded Following calculations present. - * - Use Tungsten style memory usage. - * - Use code generation in general, and use the approach to aggregation taken in the - * GeneratedAggregate class in specific. - * - * @param ordinal of the first column written by this frame. - * @param functions to calculate the row values with. */ -private[execution] abstract class WindowFunctionFrame( - ordinal: Int, - functions: Array[WindowFunction]) { - - // Make sure functions are initialized. - functions.foreach(_.init()) - - /** Number of columns the window function frame is managing */ - val numColumns = functions.length - - /** - * Create a fresh thread safe copy of the frame. - * - * @return the copied frame. - */ - def copy: WindowFunctionFrame - - /** - * Create new instances of the functions. - * - * @return an array containing copies of the current window functions. - */ - protected final def copyFunctions: Array[WindowFunction] = functions.map(_.newInstance()) - +private[execution] abstract class WindowFunctionFrame { /** * Prepare the frame for calculating the results for a partition. * * @param rows to calculate the frame results for. */ - def prepare(rows: CompactBuffer[InternalRow]): Unit + def prepare(rows: ArrayBuffer[InternalRow]): Unit /** - * Write the result for the current row to the given target row. - * - * @param target row to write the result for the current row to. + * Write the current results to the target row. */ - def write(target: GenericMutableRow): Unit + def write(): Unit +} - /** Reset the current window functions. */ - protected final def reset(): Unit = { - var i = 0 - while (i < numColumns) { - functions(i).reset() - i += 1 - } - } +/** + * The offset window frame calculates frames containing LEAD/LAG statements. + * + * @param target to write results to. + * @param expressions to shift a number of rows. + * @param inputSchema required for creating a projection. + * @param newMutableProjection function used to create the projection. + * @param offset by which rows get moved within a partition. + */ +private[execution] final class OffsetWindowFunctionFrame( + target: MutableRow, + ordinal: Int, + expressions: Array[Expression], + inputSchema: Seq[Attribute], + newMutableProjection: (Seq[Expression], Seq[Attribute]) => () => MutableProjection, + offset: Int) extends WindowFunctionFrame { - /** Prepare an input row for processing. */ - protected final def prepare(input: InternalRow): Array[AnyRef] = { - val prepared = new Array[AnyRef](numColumns) - var i = 0 - while (i < numColumns) { - prepared(i) = functions(i).prepareInputParameters(input) - i += 1 - } - prepared - } + /** Rows of the partition currently being processed. */ + private[this] var input: ArrayBuffer[InternalRow] = null - /** Evaluate a prepared buffer (iterator). 
*/ - protected final def evaluatePrepared(iterator: java.util.Iterator[Array[AnyRef]]): Unit = { - reset() - while (iterator.hasNext) { - val prepared = iterator.next() - var i = 0 - while (i < numColumns) { - functions(i).update(prepared(i)) - i += 1 - } - } - evaluate() - } + /** Index of the input row currently used for output. */ + private[this] var inputIndex = 0 - /** Evaluate a prepared buffer (array). */ - protected final def evaluatePrepared(prepared: Array[Array[AnyRef]], - fromIndex: Int, toIndex: Int): Unit = { - var i = 0 - while (i < numColumns) { - val function = functions(i) - function.reset() - var j = fromIndex - while (j < toIndex) { - function.update(prepared(j)(i)) - j += 1 - } - function.evaluate() - i += 1 - } - } + /** Index of the current output row. */ + private[this] var outputIndex = 0 - /** Update an array of window functions. */ - protected final def update(input: InternalRow): Unit = { - var i = 0 - while (i < numColumns) { - val aggregate = functions(i) - val preparedInput = aggregate.prepareInputParameters(input) - aggregate.update(preparedInput) - i += 1 + /** Row used when there is no valid input. */ + private[this] val emptyRow = new GenericInternalRow(inputSchema.size) + + /** Row used to combine the offset and the current row. */ + private[this] val join = new JoinedRow + + /** Create the projection. */ + private[this] val projection = { + // Collect the expressions and bind them. + val numInputAttributes = inputSchema.size + val boundExpressions = Seq.fill(ordinal)(NoOp) ++ expressions.toSeq.map { + case e: OffsetWindowFunction => + val input = BindReferences.bindReference(e.input, inputSchema) + if (e.default == null || e.default.foldable && e.default.eval() == null) { + // Without default value. + input + } else { + // With default value. + val default = BindReferences.bindReference(e.default, inputSchema).transform { + // Shift the input reference to its default version. + case BoundReference(o, dataType, nullable) => + BoundReference(o + numInputAttributes, dataType, nullable) + } + org.apache.spark.sql.catalyst.expressions.Coalesce(input :: default :: Nil) + } + case e => + BindReferences.bindReference(e, inputSchema) } + + // Create the projection. + newMutableProjection(boundExpressions, Nil)().target(target) } - /** Evaluate the window functions. */ - protected final def evaluate(): Unit = { - var i = 0 - while (i < numColumns) { - functions(i).evaluate() - i += 1 - } + override def prepare(rows: ArrayBuffer[InternalRow]): Unit = { + input = rows + inputIndex = offset + outputIndex = 0 } - /** Fill a target row with the current window function results. */ - protected final def fill(target: GenericMutableRow, rowIndex: Int): Unit = { - var i = 0 - while (i < numColumns) { - target.update(ordinal + i, functions(i).get(rowIndex)) - i += 1 + override def write(): Unit = { + val size = input.size + if (inputIndex >= 0 && inputIndex < size) { + join(input(inputIndex), input(outputIndex)) + } else { + join(emptyRow, input(outputIndex)) } + projection(join) + inputIndex += 1 + outputIndex += 1 } } @@ -496,19 +487,19 @@ private[execution] abstract class WindowFunctionFrame( * The sliding window frame calculates frames with the following SQL form: * ... BETWEEN 1 PRECEDING AND 1 FOLLOWING * - * @param ordinal of the first column written by this frame. - * @param functions to calculate the row values with. + * @param target to write results to. + * @param processor to calculate the row values with. 
* @param lbound comparator used to identify the lower bound of an output row. * @param ubound comparator used to identify the upper bound of an output row. */ private[execution] final class SlidingWindowFunctionFrame( - ordinal: Int, - functions: Array[WindowFunction], + target: MutableRow, + processor: AggregateProcessor, lbound: BoundOrdering, - ubound: BoundOrdering) extends WindowFunctionFrame(ordinal, functions) { + ubound: BoundOrdering) extends WindowFunctionFrame { /** Rows of the partition currently being processed. */ - private[this] var input: CompactBuffer[InternalRow] = null + private[this] var input: ArrayBuffer[InternalRow] = null /** Index of the first input row with a value greater than the upper bound of the current * output row. */ @@ -518,30 +509,25 @@ private[execution] final class SlidingWindowFunctionFrame( * current output row. */ private[this] var inputLowIndex = 0 - /** Buffer used for storing prepared input for the window functions. */ - private[this] val buffer = new java.util.ArrayDeque[Array[AnyRef]] - /** Index of the row we are currently writing. */ private[this] var outputIndex = 0 /** Prepare the frame for calculating a new partition. Reset all variables. */ - override def prepare(rows: CompactBuffer[InternalRow]): Unit = { + override def prepare(rows: ArrayBuffer[InternalRow]): Unit = { input = rows inputHighIndex = 0 inputLowIndex = 0 outputIndex = 0 - buffer.clear() } /** Write the frame columns for the current row to the given target row. */ - override def write(target: GenericMutableRow): Unit = { + override def write(): Unit = { var bufferUpdated = outputIndex == 0 // Add all rows to the buffer for which the input row value is equal to or less than // the output row upper bound. while (inputHighIndex < input.size && - ubound.compare(input, inputHighIndex, outputIndex) <= 0) { - buffer.offer(prepare(input(inputHighIndex))) + ubound.compare(input, inputHighIndex, outputIndex) <= 0) { inputHighIndex += 1 bufferUpdated = true } @@ -549,25 +535,21 @@ private[execution] final class SlidingWindowFunctionFrame( // Drop all rows from the buffer for which the input row value is smaller than // the output row lower bound. while (inputLowIndex < inputHighIndex && - lbound.compare(input, inputLowIndex, outputIndex) < 0) { - buffer.pop() + lbound.compare(input, inputLowIndex, outputIndex) < 0) { inputLowIndex += 1 bufferUpdated = true } // Only recalculate and update when the buffer changes. if (bufferUpdated) { - evaluatePrepared(buffer.iterator()) - fill(target, outputIndex) + processor.initialize(input.size) + processor.update(input, inputLowIndex, inputHighIndex) + processor.evaluate(target) } // Move to the next row. outputIndex += 1 } - - /** Copy the frame. */ - override def copy: SlidingWindowFunctionFrame = - new SlidingWindowFunctionFrame(ordinal, copyFunctions, lbound, ubound) } /** @@ -578,36 +560,25 @@ private[execution] final class SlidingWindowFunctionFrame( * Its results are the same for each and every row in the partition. This class can be seen as a * special case of a sliding window, but is optimized for the unbound case. * - * @param ordinal of the first column written by this frame. - * @param functions to calculate the row values with. + * @param target to write results to. + * @param processor to calculate the row values with. */ private[execution] final class UnboundedWindowFunctionFrame( - ordinal: Int, - functions: Array[WindowFunction]) extends WindowFunctionFrame(ordinal, functions) { - - /** Index of the row we are currently writing. 
*/ - private[this] var outputIndex = 0 + target: MutableRow, + processor: AggregateProcessor) extends WindowFunctionFrame { /** Prepare the frame for calculating a new partition. Process all rows eagerly. */ - override def prepare(rows: CompactBuffer[InternalRow]): Unit = { - reset() - outputIndex = 0 - val iterator = rows.iterator - while (iterator.hasNext) { - update(iterator.next()) - } - evaluate() + override def prepare(rows: ArrayBuffer[InternalRow]): Unit = { + processor.initialize(rows.size) + processor.update(rows, 0, rows.size) } /** Write the frame columns for the current row to the given target row. */ - override def write(target: GenericMutableRow): Unit = { - fill(target, outputIndex) - outputIndex += 1 + override def write(): Unit = { + // Unfortunately we cannot assume that evaluation is deterministic. So we need to re-evaluate + // for each row. + processor.evaluate(target) } - - /** Copy the frame. */ - override def copy: UnboundedWindowFunctionFrame = - new UnboundedWindowFunctionFrame(ordinal, copyFunctions) } /** @@ -620,58 +591,53 @@ private[execution] final class UnboundedWindowFunctionFrame( * is not the case when there is no lower bound, given the additive nature of most aggregates * streaming updates and partial evaluation suffice and no buffering is needed. * - * @param ordinal of the first column written by this frame. - * @param functions to calculate the row values with. + * @param target to write results to. + * @param processor to calculate the row values with. * @param ubound comparator used to identify the upper bound of an output row. */ private[execution] final class UnboundedPrecedingWindowFunctionFrame( - ordinal: Int, - functions: Array[WindowFunction], - ubound: BoundOrdering) extends WindowFunctionFrame(ordinal, functions) { + target: MutableRow, + processor: AggregateProcessor, + ubound: BoundOrdering) extends WindowFunctionFrame { /** Rows of the partition currently being processed. */ - private[this] var input: CompactBuffer[InternalRow] = null + private[this] var input: ArrayBuffer[InternalRow] = null /** Index of the first input row with a value greater than the upper bound of the current - * output row. */ + * output row. */ private[this] var inputIndex = 0 /** Index of the row we are currently writing. */ private[this] var outputIndex = 0 /** Prepare the frame for calculating a new partition. */ - override def prepare(rows: CompactBuffer[InternalRow]): Unit = { - reset() + override def prepare(rows: ArrayBuffer[InternalRow]): Unit = { input = rows inputIndex = 0 outputIndex = 0 + processor.initialize(input.size) } /** Write the frame columns for the current row to the given target row. */ - override def write(target: GenericMutableRow): Unit = { + override def write(): Unit = { var bufferUpdated = outputIndex == 0 // Add all rows to the aggregates for which the input row value is equal to or less than // the output row upper bound. while (inputIndex < input.size && ubound.compare(input, inputIndex, outputIndex) <= 0) { - update(input(inputIndex)) + processor.update(input(inputIndex)) inputIndex += 1 bufferUpdated = true } // Only recalculate and update when the buffer changes. if (bufferUpdated) { - evaluate() - fill(target, outputIndex) + processor.evaluate(target) } // Move to the next row. outputIndex += 1 } - - /** Copy the frame. 
*/ - override def copy: UnboundedPrecedingWindowFunctionFrame = - new UnboundedPrecedingWindowFunctionFrame(ordinal, copyFunctions, ubound) } /** @@ -686,45 +652,34 @@ private[execution] final class UnboundedPrecedingWindowFunctionFrame( * buffer and must do full recalculation after each row. Reverse iteration would be possible if * the commutativity of the used window functions can be guaranteed. * - * @param ordinal of the first column written by this frame. - * @param functions to calculate the row values with. + * @param target to write results to. + * @param processor to calculate the row values with. * @param lbound comparator used to identify the lower bound of an output row. */ private[execution] final class UnboundedFollowingWindowFunctionFrame( - ordinal: Int, - functions: Array[WindowFunction], - lbound: BoundOrdering) extends WindowFunctionFrame(ordinal, functions) { - - /** Buffer used for storing prepared input for the window functions. */ - private[this] var buffer: Array[Array[AnyRef]] = _ + target: MutableRow, + processor: AggregateProcessor, + lbound: BoundOrdering) extends WindowFunctionFrame { /** Rows of the partition currently being processed. */ - private[this] var input: CompactBuffer[InternalRow] = null + private[this] var input: ArrayBuffer[InternalRow] = null /** Index of the first input row with a value equal to or greater than the lower bound of the - * current output row. */ + * current output row. */ private[this] var inputIndex = 0 /** Index of the row we are currently writing. */ private[this] var outputIndex = 0 /** Prepare the frame for calculating a new partition. */ - override def prepare(rows: CompactBuffer[InternalRow]): Unit = { + override def prepare(rows: ArrayBuffer[InternalRow]): Unit = { input = rows inputIndex = 0 outputIndex = 0 - val size = input.size - buffer = Array.ofDim(size) - var i = 0 - while (i < size) { - buffer(i) = prepare(input(i)) - i += 1 - } - evaluatePrepared(buffer, 0, buffer.length) } /** Write the frame columns for the current row to the given target row. */ - override def write(target: GenericMutableRow): Unit = { + override def write(): Unit = { var bufferUpdated = outputIndex == 0 // Drop all rows from the buffer for which the input row value is smaller than @@ -736,15 +691,151 @@ private[execution] final class UnboundedFollowingWindowFunctionFrame( // Only recalculate and update when the buffer changes. if (bufferUpdated) { - evaluatePrepared(buffer, inputIndex, buffer.length) - fill(target, outputIndex) + processor.initialize(input.size) + processor.update(input, inputIndex, input.size) + processor.evaluate(target) } // Move to the next row. outputIndex += 1 } +} + +/** + * This class prepares and manages the processing of a number of [[AggregateFunction]]s within a + * single frame. The [[WindowFunctionFrame]] takes care of processing the frame in the correct way; + * this reduces the processing of an [[AggregateWindowFunction]] to processing the underlying + * [[AggregateFunction]]. All [[AggregateFunction]]s are processed in [[Complete]] mode. + * + * [[SizeBasedWindowFunction]]s are initialized in a slightly different way. These functions + * require the size of the partition being processed; this value is exposed to them when the processor is + * constructed. + * + * Processing of distinct aggregates is currently not supported. + * + * The implementation is split into an object which takes care of construction, and the actual + * processor class. 
+ */ +private[execution] object AggregateProcessor { + def apply(functions: Array[Expression], + ordinal: Int, + inputAttributes: Seq[Attribute], + newMutableProjection: (Seq[Expression], Seq[Attribute]) => () => MutableProjection): + AggregateProcessor = { + val aggBufferAttributes = mutable.Buffer.empty[AttributeReference] + val initialValues = mutable.Buffer.empty[Expression] + val updateExpressions = mutable.Buffer.empty[Expression] + val evaluateExpressions = mutable.Buffer.fill[Expression](ordinal)(NoOp) + val imperatives = mutable.Buffer.empty[ImperativeAggregate] + + // Check if there are any SizeBasedWindowFunctions. If there are, we add the partition size to + // the aggregation buffer. Note that the ordinal of the partition size value will always be 0. + val trackPartitionSize = functions.exists(_.isInstanceOf[SizeBasedWindowFunction]) + if (trackPartitionSize) { + aggBufferAttributes += SizeBasedWindowFunction.n + initialValues += NoOp + updateExpressions += NoOp + } + + // Add an AggregateFunction to the AggregateProcessor. + functions.foreach { + case agg: DeclarativeAggregate => + aggBufferAttributes ++= agg.aggBufferAttributes + initialValues ++= agg.initialValues + updateExpressions ++= agg.updateExpressions + evaluateExpressions += agg.evaluateExpression + case agg: ImperativeAggregate => + val offset = aggBufferAttributes.size + val imperative = BindReferences.bindReference(agg + .withNewInputAggBufferOffset(offset) + .withNewMutableAggBufferOffset(offset), + inputAttributes) + imperatives += imperative + aggBufferAttributes ++= imperative.aggBufferAttributes + val noOps = Seq.fill(imperative.aggBufferAttributes.size)(NoOp) + initialValues ++= noOps + updateExpressions ++= noOps + evaluateExpressions += imperative + case other => + sys.error(s"Unsupported Aggregate Function: $other") + } + + // Create the projections. + val initialProjection = newMutableProjection( + initialValues, + Seq(SizeBasedWindowFunction.n))() + val updateProjection = newMutableProjection( + updateExpressions, + aggBufferAttributes ++ inputAttributes)() + val evaluateProjection = newMutableProjection( + evaluateExpressions, + aggBufferAttributes)() + + // Create the processor + new AggregateProcessor( + aggBufferAttributes.toArray, + initialProjection, + updateProjection, + evaluateProjection, + imperatives.toArray, + trackPartitionSize) + } +} + +/** + * This class manages the processing of a number of aggregate functions. See the documentation of + * the object for more information. + */ +private[execution] final class AggregateProcessor( + private[this] val bufferSchema: Array[AttributeReference], + private[this] val initialProjection: MutableProjection, + private[this] val updateProjection: MutableProjection, + private[this] val evaluateProjection: MutableProjection, + private[this] val imperatives: Array[ImperativeAggregate], + private[this] val trackPartitionSize: Boolean) { + + private[this] val join = new JoinedRow + private[this] val numImperatives = imperatives.length + private[this] val buffer = new SpecificMutableRow(bufferSchema.toSeq.map(_.dataType)) + initialProjection.target(buffer) + updateProjection.target(buffer) + + /** Create the initial state. */ + def initialize(size: Int): Unit = { + // Some initialization expressions are dependent on the partition size so we have to + // initialize the size before initializing all other fields, and we have to pass the buffer to + // the initialization projection. 
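For intuition about the contract the frames program against, here is a deliberately simplified stand-in, not Spark's AggregateProcessor: a toy SimpleAvgProcessor (invented name) with the same initialize/update/evaluate shape, operating on plain doubles instead of an aggregation buffer, bound projections, and InternalRows.

final class SimpleAvgProcessor {
  // Toy "aggregation buffer": a running sum and count, reset by initialize().
  private var sum = 0.0
  private var count = 0L

  // partitionSize is only needed by size-based functions such as cume_dist; unused in this toy.
  def initialize(partitionSize: Int): Unit = { sum = 0.0; count = 0L }

  def update(value: Double): Unit = { sum += value; count += 1 }

  def update(values: IndexedSeq[Double], begin: Int, end: Int): Unit = {
    var i = begin
    while (i < end) { update(values(i)); i += 1 }
  }

  // The real processor projects its buffer into a target row; here we simply return the value.
  def evaluate(): Double = if (count == 0) Double.NaN else sum / count
}

object SimpleAvgProcessorExample {
  def main(args: Array[String]): Unit = {
    val rows = IndexedSeq(1.0, 2.0, 3.0, 4.0)
    val p = new SimpleAvgProcessor
    p.initialize(rows.size)
    p.update(rows, 0, 2); println(p.evaluate()) // 1.5, "unbounded preceding" style growth
    p.update(rows, 2, 4); println(p.evaluate()) // 2.5, keep the state and extend the range
  }
}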
+ if (trackPartitionSize) { + buffer.setInt(0, size) + } + initialProjection(buffer) + var i = 0 + while (i < numImperatives) { + imperatives(i).initialize(buffer) + i += 1 + } + } + + /** Update the buffer. */ + def update(input: InternalRow): Unit = { + updateProjection(join(buffer, input)) + var i = 0 + while (i < numImperatives) { + imperatives(i).update(buffer, input) + i += 1 + } + } + + /** Bulk update the given buffer. */ + def update(input: ArrayBuffer[InternalRow], begin: Int, end: Int): Unit = { + var i = begin + while (i < end) { + update(input(i)) + i += 1 + } + } - /** Copy the frame. */ - override def copy: UnboundedFollowingWindowFunctionFrame = - new UnboundedFollowingWindowFunctionFrame(ordinal, copyFunctions, lbound) + /** Evaluate buffer. */ + def evaluate(target: MutableRow): Unit = + evaluateProjection.target(target)(buffer) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala index 893e800a6143..9397fb84105a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala @@ -140,57 +140,7 @@ class WindowSpec private[sql]( * Converts this [[WindowSpec]] into a [[Column]] with an aggregate expression. */ private[sql] def withAggregate(aggregate: Column): Column = { - val windowExpr = aggregate.expr match { - // First, we check if we get an aggregate function without the DISTINCT keyword. - // Right now, we do not support using a DISTINCT aggregate function as a - // window function. - case AggregateExpression(aggregateFunction, _, isDistinct) if !isDistinct => - aggregateFunction match { - case Average(child) => WindowExpression( - UnresolvedWindowFunction("avg", child :: Nil), - WindowSpecDefinition(partitionSpec, orderSpec, frame)) - case Sum(child) => WindowExpression( - UnresolvedWindowFunction("sum", child :: Nil), - WindowSpecDefinition(partitionSpec, orderSpec, frame)) - case Count(children) => WindowExpression( - UnresolvedWindowFunction("count", children), - WindowSpecDefinition(partitionSpec, orderSpec, frame)) - case First(child, ignoreNulls) => WindowExpression( - // TODO this is a hack for Hive UDAF first_value - UnresolvedWindowFunction( - "first_value", - child :: ignoreNulls :: Nil), - WindowSpecDefinition(partitionSpec, orderSpec, frame)) - case Last(child, ignoreNulls) => WindowExpression( - // TODO this is a hack for Hive UDAF last_value - UnresolvedWindowFunction( - "last_value", - child :: ignoreNulls :: Nil), - WindowSpecDefinition(partitionSpec, orderSpec, frame)) - case Min(child) => WindowExpression( - UnresolvedWindowFunction("min", child :: Nil), - WindowSpecDefinition(partitionSpec, orderSpec, frame)) - case Max(child) => WindowExpression( - UnresolvedWindowFunction("max", child :: Nil), - WindowSpecDefinition(partitionSpec, orderSpec, frame)) - case x => - throw new UnsupportedOperationException(s"$x is not supported in a window operation.") - } - - case AggregateExpression(aggregateFunction, _, isDistinct) if isDistinct => - throw new UnsupportedOperationException( - s"Distinct aggregate function ${aggregateFunction} is not supported " + - s"in window operation.") - - case wf: WindowFunction => - WindowExpression( - wf, - WindowSpecDefinition(partitionSpec, orderSpec, frame)) - - case x => - throw new UnsupportedOperationException(s"$x is not supported in a window operation.") - } - - new Column(windowExpr) + val spec = 
WindowSpecDefinition(partitionSpec, orderSpec, frame) + new Column(WindowExpression(aggregate.expr, spec)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index e79defbbbdee..65733dcf83e7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -577,7 +577,7 @@ object functions extends LegacyFunctions { * @group window_funcs * @since 1.6.0 */ - def cume_dist(): Column = withExpr { UnresolvedWindowFunction("cume_dist", Nil) } + def cume_dist(): Column = withExpr { new CumeDist } /** * @group window_funcs @@ -597,7 +597,7 @@ object functions extends LegacyFunctions { * @group window_funcs * @since 1.6.0 */ - def dense_rank(): Column = withExpr { UnresolvedWindowFunction("dense_rank", Nil) } + def dense_rank(): Column = withExpr { new DenseRank } /** * Window function: returns the value that is `offset` rows before the current row, and @@ -648,7 +648,7 @@ object functions extends LegacyFunctions { * @since 1.4.0 */ def lag(e: Column, offset: Int, defaultValue: Any): Column = withExpr { - UnresolvedWindowFunction("lag", e.expr :: Literal(offset) :: Literal(defaultValue) :: Nil) + Lag(e.expr, Literal(offset), Literal(defaultValue)) } /** @@ -700,7 +700,7 @@ object functions extends LegacyFunctions { * @since 1.4.0 */ def lead(e: Column, offset: Int, defaultValue: Any): Column = withExpr { - UnresolvedWindowFunction("lead", e.expr :: Literal(offset) :: Literal(defaultValue) :: Nil) + Lead(e.expr, Literal(offset), Literal(defaultValue)) } /** @@ -713,7 +713,7 @@ object functions extends LegacyFunctions { * @group window_funcs * @since 1.4.0 */ - def ntile(n: Int): Column = withExpr { UnresolvedWindowFunction("ntile", lit(n).expr :: Nil) } + def ntile(n: Int): Column = withExpr { new NTile(Literal(n)) } /** * @group window_funcs @@ -735,7 +735,7 @@ object functions extends LegacyFunctions { * @group window_funcs * @since 1.6.0 */ - def percent_rank(): Column = withExpr { UnresolvedWindowFunction("percent_rank", Nil) } + def percent_rank(): Column = withExpr { new PercentRank } /** * Window function: returns the rank of rows within a window partition. @@ -750,7 +750,7 @@ object functions extends LegacyFunctions { * @group window_funcs * @since 1.4.0 */ - def rank(): Column = withExpr { UnresolvedWindowFunction("rank", Nil) } + def rank(): Column = withExpr { new Rank } /** * @group window_funcs @@ -765,7 +765,7 @@ object functions extends LegacyFunctions { * @group window_funcs * @since 1.6.0 */ - def row_number(): Column = withExpr { UnresolvedWindowFunction("row_number", Nil) } + def row_number(): Column = withExpr { RowNumber() } ////////////////////////////////////////////////////////////////////////////////////////////// // Non-aggregate functions diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowSuite.scala similarity index 57% rename from sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowSuite.scala index 2c98f1c3cc49..b50d7604e0ec 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowSuite.scala @@ -15,16 +15,15 @@ * limitations under the License. 
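To make the functions.scala change above concrete, a short usage sketch against the public DataFrame API (hedged: assumes a Spark 1.6-style session with a SQLContext named sqlContext and its implicits imported, as the test suites do; the data and column names are illustrative only):

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
import sqlContext.implicits._

val df = Seq(("a", 1), ("a", 2), ("b", 3), ("b", 4)).toDF("key", "value")
val w = Window.partitionBy($"key").orderBy($"value")

// Each of these now resolves directly to a native window expression (RowNumber, Rank,
// DenseRank, Lag, Lead, NTile) instead of an UnresolvedWindowFunction placeholder.
df.select(
  $"key", $"value",
  row_number().over(w),
  rank().over(w),
  dense_rank().over(w),
  lag($"value", 1).over(w),
  lead($"value", 1).over(w),
  ntile(2).over(w)
).show()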
*/ -package org.apache.spark.sql.hive +package org.apache.spark.sql -import org.apache.spark.sql.{Row, QueryTest} -import org.apache.spark.sql.expressions.Window +import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction, Window} +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.functions._ -import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.types.{DataType, LongType, StructType} -class HiveDataFrameWindowSuite extends QueryTest with TestHiveSingleton { - import hiveContext.implicits._ - import hiveContext.sql +class DataFrameWindowSuite extends QueryTest with SharedSQLContext { + import testImplicits._ test("reuse window partitionBy") { val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") @@ -55,10 +54,7 @@ class HiveDataFrameWindowSuite extends QueryTest with TestHiveSingleton { checkAnswer( df.select( lead("value", 1).over(Window.partitionBy($"key").orderBy($"value"))), - sql( - """SELECT - | lead(value) OVER (PARTITION BY key ORDER BY value) - | FROM window_table""".stripMargin).collect()) + Row("1") :: Row(null) :: Row("2") :: Row(null) :: Nil) } test("lag") { @@ -68,10 +64,7 @@ class HiveDataFrameWindowSuite extends QueryTest with TestHiveSingleton { checkAnswer( df.select( lag("value", 1).over(Window.partitionBy($"key").orderBy($"value"))), - sql( - """SELECT - | lag(value) OVER (PARTITION BY key ORDER BY value) - | FROM window_table""".stripMargin).collect()) + Row(null) :: Row("1") :: Row(null) :: Row("2") :: Nil) } test("lead with default value") { @@ -81,10 +74,7 @@ class HiveDataFrameWindowSuite extends QueryTest with TestHiveSingleton { checkAnswer( df.select( lead("value", 2, "n/a").over(Window.partitionBy("key").orderBy("value"))), - sql( - """SELECT - | lead(value, 2, "n/a") OVER (PARTITION BY key ORDER BY value) - | FROM window_table""".stripMargin).collect()) + Seq(Row("1"), Row("1"), Row("n/a"), Row("n/a"), Row("2"), Row("n/a"), Row("n/a"))) } test("lag with default value") { @@ -94,10 +84,7 @@ class HiveDataFrameWindowSuite extends QueryTest with TestHiveSingleton { checkAnswer( df.select( lag("value", 2, "n/a").over(Window.partitionBy($"key").orderBy($"value"))), - sql( - """SELECT - | lag(value, 2, "n/a") OVER (PARTITION BY key ORDER BY value) - | FROM window_table""".stripMargin).collect()) + Seq(Row("n/a"), Row("n/a"), Row("1"), Row("1"), Row("n/a"), Row("n/a"), Row("2"))) } test("rank functions in unspecific window") { @@ -112,78 +99,52 @@ class HiveDataFrameWindowSuite extends QueryTest with TestHiveSingleton { count("key").over(Window.partitionBy("value").orderBy("key")), sum("key").over(Window.partitionBy("value").orderBy("key")), ntile(2).over(Window.partitionBy("value").orderBy("key")), - rowNumber().over(Window.partitionBy("value").orderBy("key")), - denseRank().over(Window.partitionBy("value").orderBy("key")), + row_number().over(Window.partitionBy("value").orderBy("key")), + dense_rank().over(Window.partitionBy("value").orderBy("key")), rank().over(Window.partitionBy("value").orderBy("key")), - cumeDist().over(Window.partitionBy("value").orderBy("key")), - percentRank().over(Window.partitionBy("value").orderBy("key"))), - sql( - s"""SELECT - |key, - |max(key) over (partition by value order by key), - |min(key) over (partition by value order by key), - |avg(key) over (partition by value order by key), - |count(key) over (partition by value order by key), - |sum(key) over (partition by value order by key), - |ntile(2) over (partition by 
value order by key), - |row_number() over (partition by value order by key), - |dense_rank() over (partition by value order by key), - |rank() over (partition by value order by key), - |cume_dist() over (partition by value order by key), - |percent_rank() over (partition by value order by key) - |FROM window_table""".stripMargin).collect()) + cume_dist().over(Window.partitionBy("value").orderBy("key")), + percent_rank().over(Window.partitionBy("value").orderBy("key"))), + Row(1, 1, 1, 1.0d, 1, 1, 1, 1, 1, 1, 1.0d, 0.0d) :: + Row(1, 1, 1, 1.0d, 1, 1, 1, 1, 1, 1, 1.0d / 3.0d, 0.0d) :: + Row(2, 2, 1, 5.0d / 3.0d, 3, 5, 1, 2, 2, 2, 1.0d, 0.5d) :: + Row(2, 2, 1, 5.0d / 3.0d, 3, 5, 2, 3, 2, 2, 1.0d, 0.5d) :: Nil) } test("aggregation and rows between") { - val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") + val df = Seq((1, "1"), (2, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") df.registerTempTable("window_table") checkAnswer( df.select( avg("key").over(Window.partitionBy($"value").orderBy($"key").rowsBetween(-1, 2))), - sql( - """SELECT - | avg(key) OVER - | (PARTITION BY value ORDER BY key ROWS BETWEEN 1 preceding and 2 following) - | FROM window_table""".stripMargin).collect()) + Seq(Row(4.0d / 3.0d), Row(4.0d / 3.0d), Row(3.0d / 2.0d), Row(2.0d), Row(2.0d))) } - test("aggregation and range betweens") { - val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") + test("aggregation and range between") { + val df = Seq((1, "1"), (1, "1"), (3, "1"), (2, "2"), (2, "1"), (2, "2")).toDF("key", "value") df.registerTempTable("window_table") checkAnswer( df.select( avg("key").over(Window.partitionBy($"value").orderBy($"key").rangeBetween(-1, 1))), - sql( - """SELECT - | avg(key) OVER - | (PARTITION BY value ORDER BY key RANGE BETWEEN 1 preceding and 1 following) - | FROM window_table""".stripMargin).collect()) + Seq(Row(4.0d / 3.0d), Row(4.0d / 3.0d), Row(7.0d / 4.0d), Row(5.0d / 2.0d), + Row(2.0d), Row(2.0d))) } - test("aggregation and rows betweens with unbounded") { + test("aggregation and rows between with unbounded") { val df = Seq((1, "1"), (2, "2"), (2, "3"), (1, "3"), (3, "2"), (4, "3")).toDF("key", "value") df.registerTempTable("window_table") checkAnswer( df.select( $"key", - last("value").over( + last("key").over( Window.partitionBy($"value").orderBy($"key").rowsBetween(0, Long.MaxValue)), - last("value").over( + last("key").over( Window.partitionBy($"value").orderBy($"key").rowsBetween(Long.MinValue, 0)), - last("value").over(Window.partitionBy($"value").orderBy($"key").rowsBetween(-1, 3))), - sql( - """SELECT - | key, - | last_value(value) OVER - | (PARTITION BY value ORDER BY key ROWS between current row and unbounded following), - | last_value(value) OVER - | (PARTITION BY value ORDER BY key ROWS between unbounded preceding and current row), - | last_value(value) OVER - | (PARTITION BY value ORDER BY key ROWS between 1 preceding and 3 following) - | FROM window_table""".stripMargin).collect()) + last("key").over(Window.partitionBy($"value").orderBy($"key").rowsBetween(-1, 1))), + Seq(Row(1, 1, 1, 1), Row(2, 3, 2, 3), Row(3, 3, 3, 3), Row(1, 4, 1, 2), Row(2, 4, 2, 4), + Row(4, 4, 4, 4))) } - test("aggregation and range betweens with unbounded") { + test("aggregation and range between with unbounded") { val df = Seq((5, "1"), (5, "2"), (4, "2"), (6, "2"), (3, "1"), (2, "2")).toDF("key", "value") df.registerTempTable("window_table") checkAnswer( @@ -200,18 +161,12 @@ class HiveDataFrameWindowSuite extends QueryTest with TestHiveSingleton 
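As a hand-check of the "aggregation and rows between" expectation above: partition value = "1" holds keys [1, 1, 2] after ordering, and rowsBetween(-1, 2) means one row preceding through two rows following, so the three frames are [1, 1, 2], [1, 1, 2] and [1, 2], giving averages 4/3, 4/3 and 3/2; partition value = "2" holds [2, 2], and both frames cover [2, 2], giving 2.0. A hypothetical standalone snippet (not part of the suite; RowsBetweenCheck is an invented name) reproducing that arithmetic:

object RowsBetweenCheck extends App {
  val partition = Seq(1, 1, 2) // ordered keys of partition value = "1"
  val avgs = partition.indices.map { i =>
    // Clamp the inclusive [i - 1, i + 2] frame to the partition boundaries.
    val frame = partition.slice(math.max(i - 1, 0), math.min(i + 3, partition.size))
    frame.sum.toDouble / frame.size
  }
  println(avgs.mkString(", ")) // 1.333..., 1.333..., 1.5
}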
{ avg("key").over(Window.partitionBy("value").orderBy("key").rangeBetween(-1, 0)) .as("avg_key3") ), - sql( - """SELECT - | key, - | last_value(value) OVER - | (PARTITION BY value ORDER BY key RANGE BETWEEN 2 preceding and 1 preceding) == "2", - | avg(key) OVER - | (PARTITION BY value ORDER BY key RANGE BETWEEN unbounded preceding and 1 following), - | avg(key) OVER - | (PARTITION BY value ORDER BY key RANGE BETWEEN current row and unbounded following), - | avg(key) OVER - | (PARTITION BY value ORDER BY key RANGE BETWEEN 1 preceding and current row) - | FROM window_table""".stripMargin).collect()) + Seq(Row(3, null, 3.0d, 4.0d, 3.0d), + Row(5, false, 4.0d, 5.0d, 5.0d), + Row(2, null, 2.0d, 17.0d / 4.0d, 2.0d), + Row(4, true, 11.0d / 3.0d, 5.0d, 4.0d), + Row(5, true, 17.0d / 4.0d, 11.0d / 2.0d, 4.5d), + Row(6, true, 17.0d / 4.0d, 6.0d, 11.0d / 2.0d))) } test("reverse sliding range frame") { @@ -254,6 +209,87 @@ class HiveDataFrameWindowSuite extends QueryTest with TestHiveSingleton { sum($"value").over(window.rangeBetween(1, Long.MaxValue))), Row(1, 13, null) :: Row(2, 13, 2) :: Row(4, 7, 9) :: Row(3, 11, 6) :: Row(2, 13, 2) :: Row(1, 13, null) :: Nil) + } + + test("statistical functions") { + val df = Seq(("a", 1), ("a", 1), ("a", 2), ("a", 2), ("b", 4), ("b", 3), ("b", 2)). + toDF("key", "value") + val window = Window.partitionBy($"key") + checkAnswer( + df.select( + $"key", + var_pop($"value").over(window), + var_samp($"value").over(window), + approxCountDistinct($"value").over(window)), + Seq.fill(4)(Row("a", 1.0d / 4.0d, 1.0d / 3.0d, 2)) + ++ Seq.fill(3)(Row("b", 2.0d / 3.0d, 1.0d, 3))) + } + + test("window function with aggregates") { + val df = Seq(("a", 1), ("a", 1), ("a", 2), ("a", 2), ("b", 4), ("b", 3), ("b", 2)). + toDF("key", "value") + val window = Window.orderBy() + checkAnswer( + df.groupBy($"key") + .agg( + sum($"value"), + sum(sum($"value")).over(window) - sum($"value")), + Seq(Row("a", 6, 9), Row("b", 9, 6))) + } + + test("window function with udaf") { + val udaf = new UserDefinedAggregateFunction { + def inputSchema: StructType = new StructType() + .add("a", LongType) + .add("b", LongType) + + def bufferSchema: StructType = new StructType() + .add("product", LongType) + + def dataType: DataType = LongType + + def deterministic: Boolean = true + + def initialize(buffer: MutableAggregationBuffer): Unit = { + buffer(0) = 0L + } + def update(buffer: MutableAggregationBuffer, input: Row): Unit = { + if (!(input.isNullAt(0) || input.isNullAt(1))) { + buffer(0) = buffer.getLong(0) + input.getLong(0) * input.getLong(1) + } + } + + def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = { + buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0) + } + + def evaluate(buffer: Row): Any = + buffer.getLong(0) + } + val df = Seq( + ("a", 1, 1), + ("a", 1, 5), + ("a", 2, 10), + ("a", 2, -1), + ("b", 4, 7), + ("b", 3, 8), + ("b", 2, 4)) + .toDF("key", "a", "b") + val window = Window.partitionBy($"key").orderBy($"a").rangeBetween(Long.MinValue, 0L) + checkAnswer( + df.select( + $"key", + $"a", + $"b", + udaf($"a", $"b").over(window)), + Seq( + Row("a", 1, 1, 6), + Row("a", 1, 5, 6), + Row("a", 2, 10, 24), + Row("a", 2, -1, 24), + Row("b", 4, 7, 60), + Row("b", 3, 8, 32), + Row("b", 2, 4, 8))) } } diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala index 92bb9e6d73af..98bbdf0653c2 100644 --- 
a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala @@ -454,6 +454,9 @@ class HiveWindowFunctionQuerySuite extends HiveComparisonTest with BeforeAndAfte |window w1 as (distribute by p_mfgr sort by p_name rows between 2 preceding and 2 following) """.stripMargin, reset = false) + /* Disabled because: + - Spark uses a different default stddev. + - Tiny numerical differences in stddev results. createQueryTest("windowing.q -- 15. testExpressions", s""" |select p_mfgr,p_name, p_size, @@ -472,7 +475,7 @@ class HiveWindowFunctionQuerySuite extends HiveComparisonTest with BeforeAndAfte |window w1 as (distribute by p_mfgr sort by p_mfgr, p_name | rows between 2 preceding and 2 following) """.stripMargin, reset = false) - + */ createQueryTest("windowing.q -- 16. testMultipleWindows", s""" |select p_mfgr,p_name, p_size, @@ -530,6 +533,9 @@ class HiveWindowFunctionQuerySuite extends HiveComparisonTest with BeforeAndAfte // when running this test suite under Java 7 and 8. // We change the original sql query a little bit for making the test suite passed // under different JDK + /* Disabled because: + - Spark uses a different default stddev. + - Tiny numerical differences in stddev results. createQueryTest("windowing.q -- 20. testSTATs", """ |select p_mfgr,p_name, p_size, sdev, sdev_pop, uniq_data, var, cor, covarp @@ -547,7 +553,7 @@ class HiveWindowFunctionQuerySuite extends HiveComparisonTest with BeforeAndAfte |) t lateral view explode(uniq_size) d as uniq_data |order by p_mfgr,p_name, p_size, sdev, sdev_pop, uniq_data, var, cor, covarp """.stripMargin, reset = false) - + */ createQueryTest("windowing.q -- 21. testDISTs", """ |select p_mfgr,p_name, p_size, diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 5958777b0d06..0eeb62ca2cb3 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -476,7 +476,6 @@ class HiveContext private[hive]( catalog.CreateTables :: catalog.PreInsertionCasts :: ExtractPythonUDFs :: - ResolveHiveWindowFunction :: PreInsertCastAndRename :: (if (conf.runSQLOnFile) new ResolveDataSource(self) :: Nil else Nil) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 091caab921fe..da41b659e3fc 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -353,6 +353,14 @@ private[hive] object HiveQl extends Logging { } /** Extractor for matching Hive's AST Tokens. */ + private[hive] case class Token(name: String, children: Seq[ASTNode]) extends Node { + def getName(): String = name + def getChildren(): java.util.List[Node] = { + val col = new java.util.ArrayList[Node](children.size) + children.foreach(col.add(_)) + col + } + } object Token { /** @return matches of the form (tokenName, children). 
*/ def unapply(t: Any): Option[(String, Seq[ASTNode])] = t match { @@ -360,6 +368,7 @@ private[hive] object HiveQl extends Logging { CurrentOrigin.setPosition(t.getLine, t.getCharPositionInLine) Some((t.getText, Option(t.getChildren).map(_.asScala.toList).getOrElse(Nil).asInstanceOf[Seq[ASTNode]])) + case t: Token => Some((t.name, t.children)) case _ => None } } @@ -1617,17 +1626,8 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C UnresolvedExtractValue(nodeToExpr(child), nodeToExpr(ordinal)) /* Window Functions */ - case Token("TOK_FUNCTION", Token(name, Nil) +: args :+ Token("TOK_WINDOWSPEC", spec)) => - val function = UnresolvedWindowFunction(name, args.map(nodeToExpr)) - nodesToWindowSpecification(spec) match { - case reference: WindowSpecReference => - UnresolvedWindowExpression(function, reference) - case definition: WindowSpecDefinition => - WindowExpression(function, definition) - } - case Token("TOK_FUNCTIONSTAR", Token(name, Nil) :: Token("TOK_WINDOWSPEC", spec) :: Nil) => - // Safe to use Literal(1)? - val function = UnresolvedWindowFunction(name, Literal(1) :: Nil) + case Token(name, args :+ Token("TOK_WINDOWSPEC", spec)) => + val function = nodeToExpr(Token(name, args)) nodesToWindowSpecification(spec) match { case reference: WindowSpecReference => UnresolvedWindowExpression(function, reference) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala index 2e8c026259ef..a1787fc92d6d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala @@ -260,230 +260,6 @@ private[hive] case class HiveGenericUDF(funcWrapper: HiveFunctionWrapper, childr } } -/** - * Resolves [[UnresolvedWindowFunction]] to [[HiveWindowFunction]]. - */ -private[spark] object ResolveHiveWindowFunction extends Rule[LogicalPlan] { - private def shouldResolveFunction( - unresolvedWindowFunction: UnresolvedWindowFunction, - windowSpec: WindowSpecDefinition): Boolean = { - unresolvedWindowFunction.childrenResolved && windowSpec.childrenResolved - } - - def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { - case p: LogicalPlan if !p.childrenResolved => p - - // We are resolving WindowExpressions at here. When we get here, we have already - // replaced those WindowSpecReferences. - case p: LogicalPlan => - p transformExpressions { - // We will not start to resolve the function unless all arguments are resolved - // and all expressions in window spec are fixed. - case WindowExpression( - u @ UnresolvedWindowFunction(name, children), - windowSpec: WindowSpecDefinition) if shouldResolveFunction(u, windowSpec) => - // First, let's find the window function info. - val windowFunctionInfo: WindowFunctionInfo = - Option(FunctionRegistry.getWindowFunctionInfo(name.toLowerCase)).getOrElse( - throw new AnalysisException(s"Couldn't find window function $name")) - - // Get the class of this function. - // In Hive 0.12, there is no windowFunctionInfo.getFunctionClass. So, we use - // windowFunctionInfo.getfInfo().getFunctionClass for both Hive 0.13 and Hive 0.13.1. - val functionClass = windowFunctionInfo.getFunctionClass() - val newChildren = - // Rank(), DENSE_RANK(), CUME_DIST(), and PERCENT_RANK() do not take explicit - // input parameters and requires implicit parameters, which - // are expressions in Order By clause. 
- if (classOf[GenericUDAFRank].isAssignableFrom(functionClass)) { - if (children.nonEmpty) { - throw new AnalysisException(s"$name does not take input parameters.") - } - windowSpec.orderSpec.map(_.child) - } else { - children - } - - // If the class is UDAF, we need to use UDAFBridge. - val isUDAFBridgeRequired = - if (classOf[UDAF].isAssignableFrom(functionClass)) { - true - } else { - false - } - - // Create the HiveWindowFunction. For the meaning of isPivotResult, see the doc of - // HiveWindowFunction. - val windowFunction = - HiveWindowFunction( - new HiveFunctionWrapper(functionClass.getName), - windowFunctionInfo.isPivotResult, - isUDAFBridgeRequired, - newChildren) - - // Second, check if the specified window function can accept window definition. - windowSpec.frameSpecification match { - case frame: SpecifiedWindowFrame if !windowFunctionInfo.isSupportsWindow => - // This Hive window function does not support user-speficied window frame. - throw new AnalysisException( - s"Window function $name does not take a frame specification.") - case frame: SpecifiedWindowFrame if windowFunctionInfo.isSupportsWindow && - windowFunctionInfo.isPivotResult => - // These two should not be true at the same time when a window frame is defined. - // If so, throw an exception. - throw new AnalysisException(s"Could not handle Hive window function $name because " + - s"it supports both a user specified window frame and pivot result.") - case _ => // OK - } - // Resolve those UnspecifiedWindowFrame because the physical Window operator still needs - // a window frame specification to work. - val newWindowSpec = windowSpec.frameSpecification match { - case UnspecifiedFrame => - val newWindowFrame = - SpecifiedWindowFrame.defaultWindowFrame( - windowSpec.orderSpec.nonEmpty, - windowFunctionInfo.isSupportsWindow) - WindowSpecDefinition(windowSpec.partitionSpec, windowSpec.orderSpec, newWindowFrame) - case _ => windowSpec - } - - // Finally, we create a WindowExpression with the resolved window function and - // specified window spec. - WindowExpression(windowFunction, newWindowSpec) - } - } -} - -/** - * A [[WindowFunction]] implementation wrapping Hive's window function. - * @param funcWrapper The wrapper for the Hive Window Function. - * @param pivotResult If it is true, the Hive function will return a list of values representing - * the values of the added columns. Otherwise, a single value is returned for - * current row. - * @param isUDAFBridgeRequired If it is true, the function returned by functionWrapper's - * createFunction is UDAF, we need to use GenericUDAFBridge to wrap - * it as a GenericUDAFResolver2. - * @param children Input parameters. - */ -private[hive] case class HiveWindowFunction( - funcWrapper: HiveFunctionWrapper, - pivotResult: Boolean, - isUDAFBridgeRequired: Boolean, - children: Seq[Expression]) extends WindowFunction - with HiveInspectors with Unevaluable { - - // Hive window functions are based on GenericUDAFResolver2. - type UDFType = GenericUDAFResolver2 - - @transient - protected lazy val resolver: GenericUDAFResolver2 = - if (isUDAFBridgeRequired) { - new GenericUDAFBridge(funcWrapper.createFunction[UDAF]()) - } else { - funcWrapper.createFunction[GenericUDAFResolver2]() - } - - @transient - protected lazy val inputInspectors = children.map(toInspector).toArray - - // The GenericUDAFEvaluator used to evaluate the window function. 
- @transient - protected lazy val evaluator: GenericUDAFEvaluator = { - val parameterInfo = new SimpleGenericUDAFParameterInfo(inputInspectors, false, false) - resolver.getEvaluator(parameterInfo) - } - - // The object inspector of values returned from the Hive window function. - @transient - protected lazy val returnInspector = { - evaluator.init(GenericUDAFEvaluator.Mode.COMPLETE, inputInspectors) - } - - override val dataType: DataType = - if (!pivotResult) { - inspectorToDataType(returnInspector) - } else { - // If pivotResult is true, we should take the element type out as the data type of this - // function. - inspectorToDataType(returnInspector) match { - case ArrayType(dt, _) => dt - case _ => - sys.error( - s"error resolve the data type of window function ${funcWrapper.functionClassName}") - } - } - - override def nullable: Boolean = true - - @transient - lazy val inputProjection = new InterpretedProjection(children) - - @transient - private var hiveEvaluatorBuffer: AggregationBuffer = _ - // Output buffer. - private var outputBuffer: Any = _ - - @transient - private lazy val inputDataTypes: Array[DataType] = children.map(_.dataType).toArray - - override def init(): Unit = { - evaluator.init(GenericUDAFEvaluator.Mode.COMPLETE, inputInspectors) - } - - // Reset the hiveEvaluatorBuffer and outputPosition - override def reset(): Unit = { - // We create a new aggregation buffer to workaround the bug in GenericUDAFRowNumber. - // Basically, GenericUDAFRowNumberEvaluator.reset calls RowNumberBuffer.init. - // However, RowNumberBuffer.init does not really reset this buffer. - hiveEvaluatorBuffer = evaluator.getNewAggregationBuffer - evaluator.reset(hiveEvaluatorBuffer) - } - - override def prepareInputParameters(input: InternalRow): AnyRef = { - wrap( - inputProjection(input), - inputInspectors, - new Array[AnyRef](children.length), - inputDataTypes) - } - - // Add input parameters for a single row. - override def update(input: AnyRef): Unit = { - evaluator.iterate(hiveEvaluatorBuffer, input.asInstanceOf[Array[AnyRef]]) - } - - override def batchUpdate(inputs: Array[AnyRef]): Unit = { - var i = 0 - while (i < inputs.length) { - evaluator.iterate(hiveEvaluatorBuffer, inputs(i).asInstanceOf[Array[AnyRef]]) - i += 1 - } - } - - override def evaluate(): Unit = { - outputBuffer = unwrap(evaluator.evaluate(hiveEvaluatorBuffer), returnInspector) - } - - override def get(index: Int): Any = { - if (!pivotResult) { - // if pivotResult is false, we will get a single value for all rows in the frame. - outputBuffer - } else { - // if pivotResult is true, we will get a ArrayData having the same size with the size - // of the window frame. At here, we will return the result at the position of - // index in the output buffer. - outputBuffer.asInstanceOf[ArrayData].get(index, dataType) - } - } - - override def toString: String = { - s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})" - } - - override def newInstance(): WindowFunction = - new HiveWindowFunction(funcWrapper, pivotResult, isUDAFBridgeRequired, children) -} - /** * Converts a Hive Generic User Defined Table Generating Function (UDTF) to a * [[Generator]]. 
Note that the semantics of Generators do not allow diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/WindowQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/WindowQuerySuite.scala new file mode 100644 index 000000000000..c05dbfd7608d --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/WindowQuerySuite.scala @@ -0,0 +1,230 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.execution + +import org.apache.spark.sql._ +import org.apache.spark.sql.hive.test.{TestHive, TestHiveSingleton} +import org.apache.spark.sql.test.SQLTestUtils + +/** + * This suite contains a couple of Hive window tests which fail in the typical setup due to tiny + * numerical differences or due semantic differences between Hive and Spark. + */ +class WindowQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { + + override def beforeAll(): Unit = { + sql("DROP TABLE IF EXISTS part") + sql( + """ + |CREATE TABLE part( + | p_partkey INT, + | p_name STRING, + | p_mfgr STRING, + | p_brand STRING, + | p_type STRING, + | p_size INT, + | p_container STRING, + | p_retailprice DOUBLE, + | p_comment STRING) + """.stripMargin) + val testData1 = TestHive.getHiveFile("data/files/part_tiny.txt").getCanonicalPath + sql( + s""" + |LOAD DATA LOCAL INPATH '$testData1' overwrite into table part + """.stripMargin) + } + + override def afterAll(): Unit = { + sql("DROP TABLE IF EXISTS part") + } + + test("windowing.q -- 15. testExpressions") { + // Moved because: + // - Spark uses a different default stddev (sample instead of pop) + // - Tiny numerical differences in stddev results. 
+ // - Different StdDev behavior when n=1 (NaN instead of 0) + checkAnswer(sql(s""" + |select p_mfgr,p_name, p_size, + |rank() over(distribute by p_mfgr sort by p_name) as r, + |dense_rank() over(distribute by p_mfgr sort by p_name) as dr, + |cume_dist() over(distribute by p_mfgr sort by p_name) as cud, + |percent_rank() over(distribute by p_mfgr sort by p_name) as pr, + |ntile(3) over(distribute by p_mfgr sort by p_name) as nt, + |count(p_size) over(distribute by p_mfgr sort by p_name) as ca, + |avg(p_size) over(distribute by p_mfgr sort by p_name) as avg, + |stddev(p_size) over(distribute by p_mfgr sort by p_name) as st, + |first_value(p_size % 5) over(distribute by p_mfgr sort by p_name) as fv, + |last_value(p_size) over(distribute by p_mfgr sort by p_name) as lv, + |first_value(p_size) over w1 as fvW1 + |from part + |window w1 as (distribute by p_mfgr sort by p_mfgr, p_name + | rows between 2 preceding and 2 following) + """.stripMargin), + // scalastyle:off + Seq( + Row("Manufacturer#1", "almond antique burnished rose metallic", 2, 1, 1, 0.3333333333333333, 0.0, 1, 2, 2.0, 0.0, 2, 2, 2), + Row("Manufacturer#1", "almond antique burnished rose metallic", 2, 1, 1, 0.3333333333333333, 0.0, 1, 2, 2.0, 0.0, 2, 2, 2), + Row("Manufacturer#1", "almond antique chartreuse lavender yellow", 34, 3, 2, 0.5, 0.4, 2, 3, 12.666666666666666, 18.475208614068027, 2, 34, 2), + Row("Manufacturer#1", "almond antique salmon chartreuse burlywood", 6, 4, 3, 0.6666666666666666, 0.6, 2, 4, 11.0, 15.448840301675292, 2, 6, 2), + Row("Manufacturer#1", "almond aquamarine burnished black steel", 28, 5, 4, 0.8333333333333334, 0.8, 3, 5, 14.4, 15.388307249337076, 2, 28, 34), + Row("Manufacturer#1", "almond aquamarine pink moccasin thistle", 42, 6, 5, 1.0, 1.0, 3, 6, 19.0, 17.787636155487327, 2, 42, 6), + Row("Manufacturer#2", "almond antique violet chocolate turquoise", 14, 1, 1, 0.2, 0.0, 1, 1, 14.0, Double.NaN, 4, 14, 14), + Row("Manufacturer#2", "almond antique violet turquoise frosted", 40, 2, 2, 0.4, 0.25, 1, 2, 27.0, 18.384776310850235, 4, 40, 14), + Row("Manufacturer#2", "almond aquamarine midnight light salmon", 2, 3, 3, 0.6, 0.5, 2, 3, 18.666666666666668, 19.42506971244462, 4, 2, 14), + Row("Manufacturer#2", "almond aquamarine rose maroon antique", 25, 4, 4, 0.8, 0.75, 2, 4, 20.25, 16.17353805861084, 4, 25, 40), + Row("Manufacturer#2", "almond aquamarine sandy cyan gainsboro", 18, 5, 5, 1.0, 1.0, 3, 5, 19.8, 14.042791745233567, 4, 18, 2), + Row("Manufacturer#3", "almond antique chartreuse khaki white", 17, 1, 1, 0.2, 0.0, 1, 1, 17.0,Double.NaN, 2, 17, 17), + Row("Manufacturer#3", "almond antique forest lavender goldenrod", 14, 2, 2, 0.4, 0.25, 1, 2, 15.5, 2.1213203435596424, 2, 14, 17), + Row("Manufacturer#3", "almond antique metallic orange dim", 19, 3, 3, 0.6, 0.5, 2, 3, 16.666666666666668, 2.516611478423583, 2, 19, 17), + Row("Manufacturer#3", "almond antique misty red olive", 1, 4, 4, 0.8, 0.75, 2, 4, 12.75, 8.098353742170895, 2, 1, 14), + Row("Manufacturer#3", "almond antique olive coral navajo", 45, 5, 5, 1.0, 1.0, 3, 5, 19.2, 16.037456157383566, 2, 45, 19), + Row("Manufacturer#4", "almond antique gainsboro frosted violet", 10, 1, 1, 0.2, 0.0, 1, 1, 10.0, Double.NaN, 0, 10, 10), + Row("Manufacturer#4", "almond antique violet mint lemon", 39, 2, 2, 0.4, 0.25, 1, 2, 24.5, 20.506096654409877, 0, 39, 10), + Row("Manufacturer#4", "almond aquamarine floral ivory bisque", 27, 3, 3, 0.6, 0.5, 2, 3, 25.333333333333332, 14.571661996262929, 0, 27, 10), + Row("Manufacturer#4", "almond aquamarine yellow dodger mint", 
7, 4, 4, 0.8, 0.75, 2, 4, 20.75, 15.01943185787443, 0, 7, 39), + Row("Manufacturer#4", "almond azure aquamarine papaya violet", 12, 5, 5, 1.0, 1.0, 3, 5, 19.0, 13.583077707206124, 0, 12, 27), + Row("Manufacturer#5", "almond antique blue firebrick mint", 31, 1, 1, 0.2, 0.0, 1, 1, 31.0, Double.NaN, 1, 31, 31), + Row("Manufacturer#5", "almond antique medium spring khaki", 6, 2, 2, 0.4, 0.25, 1, 2, 18.5, 17.67766952966369, 1, 6, 31), + Row("Manufacturer#5", "almond antique sky peru orange", 2, 3, 3, 0.6, 0.5, 2, 3, 13.0, 15.716233645501712, 1, 2, 31), + Row("Manufacturer#5", "almond aquamarine dodger light gainsboro", 46, 4, 4, 0.8, 0.75, 2, 4, 21.25, 20.902551678363736, 1, 46, 6), + Row("Manufacturer#5", "almond azure blanched chiffon midnight", 23, 5, 5, 1.0, 1.0, 3, 5, 21.6, 18.1190507477627, 1, 23, 2))) + // scalastyle:on + } + + test("windowing.q -- 20. testSTATs") { + // Moved because: + // - Spark uses a different default stddev/variance (sample instead of pop) + // - Tiny numerical differences in aggregation results. + checkAnswer(sql(""" + |select p_mfgr,p_name, p_size, sdev, sdev_pop, uniq_data, var, cor, covarp + |from ( + |select p_mfgr,p_name, p_size, + |stddev_pop(p_retailprice) over w1 as sdev, + |stddev_pop(p_retailprice) over w1 as sdev_pop, + |collect_set(p_size) over w1 as uniq_size, + |var_pop(p_retailprice) over w1 as var, + |corr(p_size, p_retailprice) over w1 as cor, + |covar_pop(p_size, p_retailprice) over w1 as covarp + |from part + |window w1 as (distribute by p_mfgr sort by p_mfgr, p_name + | rows between 2 preceding and 2 following) + |) t lateral view explode(uniq_size) d as uniq_data + |order by p_mfgr,p_name, p_size, sdev, sdev_pop, uniq_data, var, cor, covarp + """.stripMargin), + // scalastyle:off + Seq( + Row("Manufacturer#1", "almond antique burnished rose metallic", 2, 258.10677784349247, 258.10677784349247, 2, 66619.10876874997, 0.811328754177887, 2801.7074999999995), + Row("Manufacturer#1", "almond antique burnished rose metallic", 2, 258.10677784349247, 258.10677784349247, 6, 66619.10876874997, 0.811328754177887, 2801.7074999999995), + Row("Manufacturer#1", "almond antique burnished rose metallic", 2, 258.10677784349247, 258.10677784349247, 34, 66619.10876874997, 0.811328754177887, 2801.7074999999995), + Row("Manufacturer#1", "almond antique burnished rose metallic", 2, 273.70217881648085, 273.70217881648085, 2, 74912.88268888886, 1.0, 4128.782222222221), + Row("Manufacturer#1", "almond antique burnished rose metallic", 2, 273.70217881648085, 273.70217881648085, 34, 74912.88268888886, 1.0, 4128.782222222221), + Row("Manufacturer#1", "almond antique chartreuse lavender yellow", 34, 230.9015158547037, 230.9015158547037, 2, 53315.510023999974, 0.6956393773976641, 2210.7864), + Row("Manufacturer#1", "almond antique chartreuse lavender yellow", 34, 230.9015158547037, 230.9015158547037, 6, 53315.510023999974, 0.6956393773976641, 2210.7864), + Row("Manufacturer#1", "almond antique chartreuse lavender yellow", 34, 230.9015158547037, 230.9015158547037, 28, 53315.510023999974, 0.6956393773976641, 2210.7864), + Row("Manufacturer#1", "almond antique chartreuse lavender yellow", 34, 230.9015158547037, 230.9015158547037, 34, 53315.510023999974, 0.6956393773976641, 2210.7864), + Row("Manufacturer#1", "almond antique salmon chartreuse burlywood", 6, 202.73109328368943, 202.73109328368943, 2, 41099.89618399999, 0.6307859771012139, 2009.9536000000007), + Row("Manufacturer#1", "almond antique salmon chartreuse burlywood", 6, 202.73109328368943, 202.73109328368943, 6, 
41099.89618399999, 0.6307859771012139, 2009.9536000000007), + Row("Manufacturer#1", "almond antique salmon chartreuse burlywood", 6, 202.73109328368943, 202.73109328368943, 28, 41099.89618399999, 0.6307859771012139, 2009.9536000000007), + Row("Manufacturer#1", "almond antique salmon chartreuse burlywood", 6, 202.73109328368943, 202.73109328368943, 34, 41099.89618399999, 0.6307859771012139, 2009.9536000000007), + Row("Manufacturer#1", "almond antique salmon chartreuse burlywood", 6, 202.73109328368943, 202.73109328368943, 42, 41099.89618399999, 0.6307859771012139, 2009.9536000000007), + Row("Manufacturer#1", "almond aquamarine burnished black steel", 28, 121.60645179738611, 121.60645179738611, 6, 14788.129118749992, 0.2036684720435979, 331.1337500000004), + Row("Manufacturer#1", "almond aquamarine burnished black steel", 28, 121.60645179738611, 121.60645179738611, 28, 14788.129118749992, 0.2036684720435979, 331.1337500000004), + Row("Manufacturer#1", "almond aquamarine burnished black steel", 28, 121.60645179738611, 121.60645179738611, 34, 14788.129118749992, 0.2036684720435979, 331.1337500000004), + Row("Manufacturer#1", "almond aquamarine burnished black steel", 28, 121.60645179738611, 121.60645179738611, 42, 14788.129118749992, 0.2036684720435979, 331.1337500000004), + Row("Manufacturer#1", "almond aquamarine pink moccasin thistle", 42, 96.57515864168516, 96.57515864168516, 6, 9326.761266666656, -1.4442181184933883E-4, -0.20666666666708502), + Row("Manufacturer#1", "almond aquamarine pink moccasin thistle", 42, 96.57515864168516, 96.57515864168516, 28, 9326.761266666656, -1.4442181184933883E-4, -0.20666666666708502), + Row("Manufacturer#1", "almond aquamarine pink moccasin thistle", 42, 96.57515864168516, 96.57515864168516, 42, 9326.761266666656, -1.4442181184933883E-4, -0.20666666666708502), + Row("Manufacturer#2", "almond antique violet chocolate turquoise", 14, 142.23631697518977, 142.23631697518977, 2, 20231.16986666666, -0.4936952655452319, -1113.7466666666658), + Row("Manufacturer#2", "almond antique violet chocolate turquoise", 14, 142.23631697518977, 142.23631697518977, 14, 20231.16986666666, -0.4936952655452319, -1113.7466666666658), + Row("Manufacturer#2", "almond antique violet chocolate turquoise", 14, 142.23631697518977, 142.23631697518977, 40, 20231.16986666666, -0.4936952655452319, -1113.7466666666658), + Row("Manufacturer#2", "almond antique violet turquoise frosted", 40, 137.7630649884068, 137.7630649884068, 2, 18978.662074999997, -0.5205630897335946, -1004.4812499999995), + Row("Manufacturer#2", "almond antique violet turquoise frosted", 40, 137.7630649884068, 137.7630649884068, 14, 18978.662074999997, -0.5205630897335946, -1004.4812499999995), + Row("Manufacturer#2", "almond antique violet turquoise frosted", 40, 137.7630649884068, 137.7630649884068, 25, 18978.662074999997, -0.5205630897335946, -1004.4812499999995), + Row("Manufacturer#2", "almond antique violet turquoise frosted", 40, 137.7630649884068, 137.7630649884068, 40, 18978.662074999997, -0.5205630897335946, -1004.4812499999995), + Row("Manufacturer#2", "almond aquamarine midnight light salmon", 2, 130.03972279269132, 130.03972279269132, 2, 16910.329504000005, -0.46908967495720255, -766.1791999999995), + Row("Manufacturer#2", "almond aquamarine midnight light salmon", 2, 130.03972279269132, 130.03972279269132, 14, 16910.329504000005, -0.46908967495720255, -766.1791999999995), + Row("Manufacturer#2", "almond aquamarine midnight light salmon", 2, 130.03972279269132, 130.03972279269132, 18, 16910.329504000005, 
-0.46908967495720255, -766.1791999999995), + Row("Manufacturer#2", "almond aquamarine midnight light salmon", 2, 130.03972279269132, 130.03972279269132, 25, 16910.329504000005, -0.46908967495720255, -766.1791999999995), + Row("Manufacturer#2", "almond aquamarine midnight light salmon", 2, 130.03972279269132, 130.03972279269132, 40, 16910.329504000005, -0.46908967495720255, -766.1791999999995), + Row("Manufacturer#2", "almond aquamarine rose maroon antique", 25, 135.55100986344593, 135.55100986344593, 2, 18374.076275000018, -0.6091405874714462, -1128.1787499999987), + Row("Manufacturer#2", "almond aquamarine rose maroon antique", 25, 135.55100986344593, 135.55100986344593, 18, 18374.076275000018, -0.6091405874714462, -1128.1787499999987), + Row("Manufacturer#2", "almond aquamarine rose maroon antique", 25, 135.55100986344593, 135.55100986344593, 25, 18374.076275000018, -0.6091405874714462, -1128.1787499999987), + Row("Manufacturer#2", "almond aquamarine rose maroon antique", 25, 135.55100986344593, 135.55100986344593, 40, 18374.076275000018, -0.6091405874714462, -1128.1787499999987), + Row("Manufacturer#2", "almond aquamarine sandy cyan gainsboro", 18, 156.44019460768035, 156.44019460768035, 2, 24473.534488888898, -0.9571686373491605, -1441.4466666666676), + Row("Manufacturer#2", "almond aquamarine sandy cyan gainsboro", 18, 156.44019460768035, 156.44019460768035, 18, 24473.534488888898, -0.9571686373491605, -1441.4466666666676), + Row("Manufacturer#2", "almond aquamarine sandy cyan gainsboro", 18, 156.44019460768035, 156.44019460768035, 25, 24473.534488888898, -0.9571686373491605, -1441.4466666666676), + Row("Manufacturer#3", "almond antique chartreuse khaki white", 17, 196.77422668858057, 196.77422668858057, 14, 38720.0962888889, 0.5557168646224995, 224.6944444444446), + Row("Manufacturer#3", "almond antique chartreuse khaki white", 17, 196.77422668858057, 196.77422668858057, 17, 38720.0962888889, 0.5557168646224995, 224.6944444444446), + Row("Manufacturer#3", "almond antique chartreuse khaki white", 17, 196.77422668858057, 196.77422668858057, 19, 38720.0962888889, 0.5557168646224995, 224.6944444444446), + Row("Manufacturer#3", "almond antique forest lavender goldenrod", 14, 275.1414418985261, 275.1414418985261, 1, 75702.81305000003, -0.6720833036576083, -1296.9000000000003), + Row("Manufacturer#3", "almond antique forest lavender goldenrod", 14, 275.1414418985261, 275.1414418985261, 14, 75702.81305000003, -0.6720833036576083, -1296.9000000000003), + Row("Manufacturer#3", "almond antique forest lavender goldenrod", 14, 275.1414418985261, 275.1414418985261, 17, 75702.81305000003, -0.6720833036576083, -1296.9000000000003), + Row("Manufacturer#3", "almond antique forest lavender goldenrod", 14, 275.1414418985261, 275.1414418985261, 19, 75702.81305000003, -0.6720833036576083, -1296.9000000000003), + Row("Manufacturer#3", "almond antique metallic orange dim", 19, 260.23473614412046, 260.23473614412046, 1, 67722.11789600001, -0.5703526513979519, -2129.0664), + Row("Manufacturer#3", "almond antique metallic orange dim", 19, 260.23473614412046, 260.23473614412046, 14, 67722.11789600001, -0.5703526513979519, -2129.0664), + Row("Manufacturer#3", "almond antique metallic orange dim", 19, 260.23473614412046, 260.23473614412046, 17, 67722.11789600001, -0.5703526513979519, -2129.0664), + Row("Manufacturer#3", "almond antique metallic orange dim", 19, 260.23473614412046, 260.23473614412046, 19, 67722.11789600001, -0.5703526513979519, -2129.0664), + Row("Manufacturer#3", "almond antique metallic orange 
dim", 19, 260.23473614412046, 260.23473614412046, 45, 67722.11789600001, -0.5703526513979519, -2129.0664), + Row("Manufacturer#3", "almond antique misty red olive", 1, 275.913996235693, 275.913996235693, 1, 76128.53331875002, -0.5774768996448021, -2547.7868749999993), + Row("Manufacturer#3", "almond antique misty red olive", 1, 275.913996235693, 275.913996235693, 14, 76128.53331875002, -0.5774768996448021, -2547.7868749999993), + Row("Manufacturer#3", "almond antique misty red olive", 1, 275.913996235693, 275.913996235693, 19, 76128.53331875002, -0.5774768996448021, -2547.7868749999993), + Row("Manufacturer#3", "almond antique misty red olive", 1, 275.913996235693, 275.913996235693, 45, 76128.53331875002, -0.5774768996448021, -2547.7868749999993), + Row("Manufacturer#3", "almond antique olive coral navajo", 45, 260.58159187137954, 260.58159187137954, 1, 67902.7660222222, -0.8710736366736884, -4099.731111111111), + Row("Manufacturer#3", "almond antique olive coral navajo", 45, 260.58159187137954, 260.58159187137954, 19, 67902.7660222222, -0.8710736366736884, -4099.731111111111), + Row("Manufacturer#3", "almond antique olive coral navajo", 45, 260.58159187137954, 260.58159187137954, 45, 67902.7660222222, -0.8710736366736884, -4099.731111111111), + Row("Manufacturer#4", "almond antique gainsboro frosted violet", 10, 170.1301188959661, 170.1301188959661, 10, 28944.25735555556, -0.6656975320098423, -1347.4777777777779), + Row("Manufacturer#4", "almond antique gainsboro frosted violet", 10, 170.1301188959661, 170.1301188959661, 27, 28944.25735555556, -0.6656975320098423, -1347.4777777777779), + Row("Manufacturer#4", "almond antique gainsboro frosted violet", 10, 170.1301188959661, 170.1301188959661, 39, 28944.25735555556, -0.6656975320098423, -1347.4777777777779), + Row("Manufacturer#4", "almond antique violet mint lemon", 39, 242.26834609323197, 242.26834609323197, 7, 58693.95151875002, -0.8051852719193339, -2537.328125), + Row("Manufacturer#4", "almond antique violet mint lemon", 39, 242.26834609323197, 242.26834609323197, 10, 58693.95151875002, -0.8051852719193339, -2537.328125), + Row("Manufacturer#4", "almond antique violet mint lemon", 39, 242.26834609323197, 242.26834609323197, 27, 58693.95151875002, -0.8051852719193339, -2537.328125), + Row("Manufacturer#4", "almond antique violet mint lemon", 39, 242.26834609323197, 242.26834609323197, 39, 58693.95151875002, -0.8051852719193339, -2537.328125), + Row("Manufacturer#4", "almond aquamarine floral ivory bisque", 27, 234.10001662537323, 234.10001662537323, 7, 54802.81778400003, -0.6046935574240581, -1719.8079999999995), + Row("Manufacturer#4", "almond aquamarine floral ivory bisque", 27, 234.10001662537323, 234.10001662537323, 10, 54802.81778400003, -0.6046935574240581, -1719.8079999999995), + Row("Manufacturer#4", "almond aquamarine floral ivory bisque", 27, 234.10001662537323, 234.10001662537323, 12, 54802.81778400003, -0.6046935574240581, -1719.8079999999995), + Row("Manufacturer#4", "almond aquamarine floral ivory bisque", 27, 234.10001662537323, 234.10001662537323, 27, 54802.81778400003, -0.6046935574240581, -1719.8079999999995), + Row("Manufacturer#4", "almond aquamarine floral ivory bisque", 27, 234.10001662537323, 234.10001662537323, 39, 54802.81778400003, -0.6046935574240581, -1719.8079999999995), + Row("Manufacturer#4", "almond aquamarine yellow dodger mint", 7, 247.33427141977316, 247.33427141977316, 7, 61174.241818750015, -0.5508665654707869, -1719.0368749999975), + Row("Manufacturer#4", "almond aquamarine yellow dodger mint", 7, 
247.33427141977316, 247.33427141977316, 12, 61174.241818750015, -0.5508665654707869, -1719.0368749999975), + Row("Manufacturer#4", "almond aquamarine yellow dodger mint", 7, 247.33427141977316, 247.33427141977316, 27, 61174.241818750015, -0.5508665654707869, -1719.0368749999975), + Row("Manufacturer#4", "almond aquamarine yellow dodger mint", 7, 247.33427141977316, 247.33427141977316, 39, 61174.241818750015, -0.5508665654707869, -1719.0368749999975), + Row("Manufacturer#4", "almond azure aquamarine papaya violet", 12, 283.33443305668936, 283.33443305668936, 7, 80278.4009555556, -0.7755740084632333, -1867.4888888888881), + Row("Manufacturer#4", "almond azure aquamarine papaya violet", 12, 283.33443305668936, 283.33443305668936, 12, 80278.4009555556, -0.7755740084632333, -1867.4888888888881), + Row("Manufacturer#4", "almond azure aquamarine papaya violet", 12, 283.33443305668936, 283.33443305668936, 27, 80278.4009555556, -0.7755740084632333, -1867.4888888888881), + Row("Manufacturer#5", "almond antique blue firebrick mint", 31, 83.69879024746344, 83.69879024746344, 2, 7005.487488888881, 0.3900430308728505, 418.9233333333353), + Row("Manufacturer#5", "almond antique blue firebrick mint", 31, 83.69879024746344, 83.69879024746344, 6, 7005.487488888881, 0.3900430308728505, 418.9233333333353), + Row("Manufacturer#5", "almond antique blue firebrick mint", 31, 83.69879024746344, 83.69879024746344, 31, 7005.487488888881, 0.3900430308728505, 418.9233333333353), + Row("Manufacturer#5", "almond antique medium spring khaki", 6, 316.68049612345885, 316.68049612345885, 2, 100286.53662500005, -0.7136129117761831, -4090.853749999999), + Row("Manufacturer#5", "almond antique medium spring khaki", 6, 316.68049612345885, 316.68049612345885, 6, 100286.53662500005, -0.7136129117761831, -4090.853749999999), + Row("Manufacturer#5", "almond antique medium spring khaki", 6, 316.68049612345885, 316.68049612345885, 31, 100286.53662500005, -0.7136129117761831, -4090.853749999999), + Row("Manufacturer#5", "almond antique medium spring khaki", 6, 316.68049612345885, 316.68049612345885, 46, 100286.53662500005, -0.7136129117761831, -4090.853749999999), + Row("Manufacturer#5", "almond antique sky peru orange", 2, 285.4050629824216, 285.4050629824216, 2, 81456.04997600004, -0.712858514567818, -3297.2011999999986), + Row("Manufacturer#5", "almond antique sky peru orange", 2, 285.4050629824216, 285.4050629824216, 6, 81456.04997600004, -0.712858514567818, -3297.2011999999986), + Row("Manufacturer#5", "almond antique sky peru orange", 2, 285.4050629824216, 285.4050629824216, 23, 81456.04997600004, -0.712858514567818, -3297.2011999999986), + Row("Manufacturer#5", "almond antique sky peru orange", 2, 285.4050629824216, 285.4050629824216, 31, 81456.04997600004, -0.712858514567818, -3297.2011999999986), + Row("Manufacturer#5", "almond antique sky peru orange", 2, 285.4050629824216, 285.4050629824216, 46, 81456.04997600004, -0.712858514567818, -3297.2011999999986), + Row("Manufacturer#5", "almond aquamarine dodger light gainsboro", 46, 285.43749038756283, 285.43749038756283, 2, 81474.56091875004, -0.9841287871533909, -4871.028125000002), + Row("Manufacturer#5", "almond aquamarine dodger light gainsboro", 46, 285.43749038756283, 285.43749038756283, 6, 81474.56091875004, -0.9841287871533909, -4871.028125000002), + Row("Manufacturer#5", "almond aquamarine dodger light gainsboro", 46, 285.43749038756283, 285.43749038756283, 23, 81474.56091875004, -0.9841287871533909, -4871.028125000002), + Row("Manufacturer#5", "almond aquamarine dodger 
light gainsboro", 46, 285.43749038756283, 285.43749038756283, 46, 81474.56091875004, -0.9841287871533909, -4871.028125000002), + Row("Manufacturer#5", "almond azure blanched chiffon midnight", 23, 315.9225931564038, 315.9225931564038, 2, 99807.08486666666, -0.9978877469246935, -5664.856666666666), + Row("Manufacturer#5", "almond azure blanched chiffon midnight", 23, 315.9225931564038, 315.9225931564038, 23, 99807.08486666666, -0.9978877469246935, -5664.856666666666), + Row("Manufacturer#5", "almond azure blanched chiffon midnight", 23, 315.9225931564038, 315.9225931564038, 46, 99807.08486666666, -0.9978877469246935, -5664.856666666666))) + // scalastyle:on + } +} From cdde93d45401a2f7a97a46b25387c297dcf21d2f Mon Sep 17 00:00:00 2001 From: Jo Voordeckers Date: Mon, 16 Nov 2015 17:18:44 -0800 Subject: [PATCH 1824/1824] Messos scheduler does not respect all args from the Submit request --- .../scheduler/cluster/mesos/MesosClusterScheduler.scala | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala index a6d9374eb9e8..4a21a779d2ce 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala @@ -423,6 +423,10 @@ private[spark] class MesosClusterScheduler( "--driver-cores", desc.cores.toString, "--driver-memory", s"${desc.mem}M") + val replicatedOptionsBlacklist = Set( + "spark.jars" // Avoids duplicate classes in classpath + ) + // Assume empty main class means we're running python if (!desc.command.mainClass.equals("")) { options ++= Seq("--class", desc.command.mainClass) @@ -440,6 +444,9 @@ private[spark] class MesosClusterScheduler( .mkString(",") options ++= Seq("--py-files", formattedFiles) } + desc.schedulerProperties + .filter { case (key, _) => !replicatedOptionsBlacklist.contains(key) } + .foreach { case (key, value) => options ++= Seq("--conf", Seq(key, "=\"", value, "\"").mkString("")) } options }